Skip to content

Commit

Permalink
Fix automl to treat binary as categorical when missing values present
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Sep 7, 2021
1 parent 914261e commit 6f3a720
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
8 changes: 5 additions & 3 deletions ludwig/automl/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def get_field_metadata(
metadata = []
for idx, field in enumerate(fields):
missing_value_percent = 1 - float(field.nonnull_values) / row_count
dtype = infer_type(field)
dtype = infer_type(field, missing_value_percent)
metadata.append(
FieldMetadata(
name=field.name,
Expand Down Expand Up @@ -278,19 +278,21 @@ def get_field_metadata(


def infer_type(
field: FieldInfo
field: FieldInfo,
missing_value_percent: float,
) -> str:
"""
Perform type inference on field
# Inputs
:param field: (FieldInfo) object describing field
:param missing_value_percent: (float) percent of missing values in the column
# Return
:return: (str) feature type
"""
distinct_values = field.distinct_values
if distinct_values == 2:
if distinct_values == 2 and missing_value_percent == 0:
return BINARY

if field.image_values >= 3:
Expand Down
17 changes: 9 additions & 8 deletions tests/ludwig/automl/test_base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,23 @@
TARGET_NAME = 'target'


@pytest.mark.parametrize("distinct_values,avg_words,img_values,expected", [
(ROW_COUNT, 0, 0, NUMERICAL),
(10, 0, 0, CATEGORY),
(2, 0, 0, BINARY),
(ROW_COUNT, 3, 0, TEXT),
(ROW_COUNT, 1, ROW_COUNT, IMAGE),
@pytest.mark.parametrize("distinct_values,avg_words,img_values,missing_vals,expected", [
(ROW_COUNT, 0, 0, 0.0, NUMERICAL),
(10, 0, 0, 0.0, CATEGORY),
(2, 0, 0, 0.0, BINARY),
(2, 0, 0, 0.1, CATEGORY),
(ROW_COUNT, 3, 0, 0.0, TEXT),
(ROW_COUNT, 1, ROW_COUNT, 0.0, IMAGE),
])
def test_infer_type(distinct_values, avg_words, img_values, expected):
def test_infer_type(distinct_values, avg_words, img_values, missing_vals, expected):
field = FieldInfo(
name='foo',
dtype='object',
distinct_values=distinct_values,
avg_words=avg_words,
image_values=img_values,
)
assert infer_type(field) == expected
assert infer_type(field, missing_vals) == expected


@pytest.mark.parametrize("idx,distinct_values,dtype,name,expected", [
Expand Down

0 comments on commit 6f3a720

Please sign in to comment.