diff --git a/ludwig/automl/base_config.py b/ludwig/automl/base_config.py index 1358490f27e..375d66c6480 100644 --- a/ludwig/automl/base_config.py +++ b/ludwig/automl/base_config.py @@ -344,6 +344,8 @@ def infer_type( :return: (str) feature type """ num_distinct_values = field.num_distinct_values + if num_distinct_values == 0: + return CATEGORY distinct_values = field.distinct_values if num_distinct_values <= 2 and missing_value_percent == 0: # Check that all distinct values are conventional bools. @@ -385,6 +387,9 @@ def should_exclude(idx: int, field: FieldInfo, dtype: str, row_count: int, targe if field.name in targets: return False + if field.num_distinct_values == 0: + return True + distinct_value_percent = float(field.num_distinct_values) / row_count if distinct_value_percent == 1.0: upper_name = field.name.upper() diff --git a/tests/ludwig/automl/test_base_config.py b/tests/ludwig/automl/test_base_config.py index 11cbe73d92f..4f68885ace7 100644 --- a/tests/ludwig/automl/test_base_config.py +++ b/tests/ludwig/automl/test_base_config.py @@ -16,6 +16,7 @@ (2, ['1.5', '3.7'], 0, 0, 0.1, NUMERICAL), (ROW_COUNT, [], 3, 0, 0.0, TEXT), (ROW_COUNT, [], 1, ROW_COUNT, 0.0, IMAGE), + (0, [], 0, 0, 0.0, CATEGORY), ]) def test_infer_type(num_distinct_values, distinct_values, avg_words, img_values, missing_vals, expected): field = FieldInfo( @@ -36,6 +37,7 @@ def test_infer_type(num_distinct_values, distinct_values, avg_words, img_values, (0, ROW_COUNT, TEXT, 'name', False), (0, ROW_COUNT, NUMERICAL, TARGET_NAME, False), (0, ROW_COUNT - 1, NUMERICAL, 'id', False), + (0, 0, CATEGORY, 'empty_col', True), ]) def test_should_exclude(idx, num_distinct_values, dtype, name, expected): field = FieldInfo(