Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tf-legacy] Employ a fallback str2bool mapping from the feature column's distinct values when the feature's values aren't boolean-like. #1471

Merged
merged 2 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions ludwig/features/binary_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class BinaryFeatureMixin:
},
"fill_value": fill_value_schema,
"computed_fill_value": fill_value_schema,
"fallback_true_label": {'type': 'string'},
}

@staticmethod
Expand All @@ -80,14 +81,25 @@ def get_feature_meta(column, preprocessing_parameters, backend):
f"found: {distinct_values.values.tolist()}"
)

str2bool = {v: strings_utils.str2bool(v) for v in distinct_values}
bool2str = [
k for k, v in sorted(str2bool.items(), key=lambda item: item[1])
]
if 'fallback_true_label' in preprocessing_parameters:
fallback_true_label = preprocessing_parameters['fallback_true_label']
else:
fallback_true_label = sorted(distinct_values)[0]
logger.warning(
f"In case binary feature {column.name} doesn't have conventional boolean values, "
f"we will interpret {fallback_true_label} as 1 and the other values as 0. "
f"If this is incorrect, please use the category feature type or "
f"manually specify the true value with `preprocessing.fallback_true_label`.")

str2bool = {v: strings_utils.str2bool(
v, fallback_true_label) for v in distinct_values}
bool2str = [k for k, v in sorted(
str2bool.items(), key=lambda item: item[1])]

return {
"str2bool": str2bool,
"bool2str": bool2str,
"fallback_true_label": fallback_true_label
}

@staticmethod
Expand Down
22 changes: 19 additions & 3 deletions ludwig/utils/strings_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,23 @@ def strip_accents(s):
if unicodedata.category(c) != 'Mn')


def str2bool(v):
    """Return True when the lower-cased string form of ``v`` is in BOOL_TRUE_STRS."""
    lowered = str(v).lower()
    return lowered in BOOL_TRUE_STRS
def str2bool(v, fallback_true_label=None):
    """Return the bool representation of the given value ``v``.

    The lower-cased string form of ``v`` is first checked against the global
    ``BOOL_TRUE_STRS`` / ``BOOL_FALSE_STRS`` lists. Only when it is in neither
    list is ``fallback_true_label`` consulted, strictly as a fallback.

    Args:
        v: Value to get the bool representation for.
        fallback_true_label: Label to interpret as True when ``v`` is not a
            conventional boolean string. The comparison is case-sensitive.

    Returns:
        True or False per the global lists; otherwise whether ``v`` matches
        ``fallback_true_label``.

    Raises:
        ValueError: If ``v`` is not in the global lists and no
            ``fallback_true_label`` was given.
    """
    v_str = str(v).lower()
    if v_str in BOOL_TRUE_STRS:
        return True
    if v_str in BOOL_FALSE_STRS:
        return False
    if fallback_true_label is None:
        raise ValueError(
            f'Cannot automatically map value {v} to a boolean and no `fallback_true_label` specified.')
    # The label may come from a config as a string while the column holds
    # non-string values (e.g. ints), so also accept a match on string forms.
    # Case is deliberately preserved: 'human' does not match 'Human'.
    return v == fallback_true_label or str(v) == str(fallback_true_label)


def match_replace(string_to_match, list_regex):
Expand Down Expand Up @@ -149,7 +164,8 @@ def create_vocabulary(
elif vocab_file is not None:
vocab = load_vocabulary(vocab_file)

processed_lines = data.map(lambda line: tokenizer(line.lower() if lowercase else line))
processed_lines = data.map(lambda line: tokenizer(
line.lower() if lowercase else line))
processed_counts = processed_lines.explode().value_counts(sort=False)
processed_counts = processor.compute(processed_counts)
unit_counts = Counter(dict(processed_counts))
Expand Down
24 changes: 24 additions & 0 deletions tests/ludwig/utils/test_strings_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from ludwig.utils import strings_utils


def test_str_to_bool():
    """Exercise ``strings_utils.str2bool`` with and without a fallback label."""
    # Global bool mappings are used.
    assert strings_utils.str2bool('True') is True
    assert strings_utils.str2bool('true') is True
    assert strings_utils.str2bool('0') is False

    # Error raised if non-mapped value is encountered and no fallback is specified.
    # str2bool raises ValueError specifically, so do not accept arbitrary exceptions.
    with pytest.raises(ValueError):
        strings_utils.str2bool('bot')

    # Fallback label is used (matching is case-sensitive).
    assert strings_utils.str2bool('bot', fallback_true_label='bot') is True
    assert strings_utils.str2bool('human', fallback_true_label='bot') is False
    assert strings_utils.str2bool('human', fallback_true_label='human') is True
    assert strings_utils.str2bool('human', fallback_true_label='Human') is False

    # Fallback label is used, strictly as a fallback.
    assert strings_utils.str2bool('True', fallback_true_label='False') is True