Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed schema validation to handle null preprocessing values for strings #1344

Merged
merged 2 commits into from
Oct 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ludwig/features/date_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DateFeatureMixin:
'missing_value_strategy': {'type': 'string', 'enum': MISSING_VALUE_STRATEGY_OPTIONS},
'fill_value': {'type': 'string'},
'computed_fill_value': {'type': 'string'},
'datetime_format': {'type': 'string'},
'datetime_format': {'type': ['string', 'null']},
}

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions ludwig/features/sequence_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class SequenceFeatureMixin:
'padding': {'type': 'string', 'enum': ['right', 'left']},
'tokenizer': {'type': 'string', 'enum': sorted(list(tokenizer_registry.keys()))},
'lowercase': {'type': 'boolean'},
'vocab_file': {'type': ['string', 'null']},
'missing_value_strategy': {'type': 'string', 'enum': MISSING_VALUE_STRATEGY_OPTIONS},
'fill_value': {'type': 'string'},
'computed_fill_value': {'type': 'string'},
Expand Down
6 changes: 3 additions & 3 deletions ludwig/features/text_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ class TextFeatureMixin:

preprocessing_schema = {
'char_tokenizer': {'type': 'string', 'enum': sorted(list(tokenizer_registry.keys()))},
'char_vocab_file': {'type': 'string'},
'char_vocab_file': {'type': ['string', 'null']},
'char_sequence_length_limit': {'type': 'integer', 'minimum': 0},
'char_most_common': {'type': 'integer', 'minimum': 0},
'word_tokenizer': {'type': 'string', 'enum': sorted(list(tokenizer_registry.keys()))},
'pretrained_model_name_or_path': {'type': 'string'},
'word_vocab_file': {'type': 'string'},
'pretrained_model_name_or_path': {'type': ['string', 'null']},
'word_vocab_file': {'type': ['string', 'null']},
'word_sequence_length_limit': {'type': 'integer', 'minimum': 0},
'word_most_common': {'type': 'integer', 'minimum': 0},
'padding_symbol': {'type': 'string'},
Expand Down
47 changes: 47 additions & 0 deletions tests/ludwig/utils/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@

import pytest
from jsonschema.exceptions import ValidationError

from ludwig.features.audio_feature import AudioFeatureMixin
from ludwig.features.bag_feature import BagFeatureMixin
from ludwig.features.binary_feature import BinaryFeatureMixin
from ludwig.features.category_feature import CategoryFeatureMixin
from ludwig.features.date_feature import DateFeatureMixin
from ludwig.features.h3_feature import H3FeatureMixin
from ludwig.features.image_feature import ImageFeatureMixin
from ludwig.features.numerical_feature import NumericalFeatureMixin
from ludwig.features.sequence_feature import SequenceFeatureMixin
from ludwig.features.set_feature import SetFeatureMixin
from ludwig.features.text_feature import TextFeatureMixin
from ludwig.features.timeseries_feature import TimeseriesFeatureMixin
from ludwig.features.vector_feature import VectorFeatureMixin
from ludwig.utils.defaults import merge_with_defaults

from ludwig.utils.schema import validate_config, OUTPUT_FEATURE_TYPES
Expand Down Expand Up @@ -249,3 +263,36 @@ def test_config_fill_values():
}
with pytest.raises(ValidationError):
validate_config(config)


def test_validate_with_preprocessing_defaults():
    """Validate a config whose input features use their preprocessing defaults.

    One input feature is built per feature mixin listed below, each with
    that mixin's ``preprocessing_defaults`` passed as its ``preprocessing``
    section. The config is validated as written, then passed through
    ``merge_with_defaults`` and validated a second time; the test passes
    when neither ``validate_config`` call raises.
    """
    destination_folder = '/tmp/destination_folder'

    # Audio and image builders also receive a destination folder as their
    # first positional argument; all other builders only take preprocessing.
    input_features = [
        audio_feature(
            destination_folder,
            preprocessing=AudioFeatureMixin.preprocessing_defaults,
        ),
        bag_feature(
            preprocessing=BagFeatureMixin.preprocessing_defaults,
        ),
        binary_feature(
            preprocessing=BinaryFeatureMixin.preprocessing_defaults,
        ),
        category_feature(
            preprocessing=CategoryFeatureMixin.preprocessing_defaults,
        ),
        date_feature(
            preprocessing=DateFeatureMixin.preprocessing_defaults,
        ),
        h3_feature(
            preprocessing=H3FeatureMixin.preprocessing_defaults,
        ),
        image_feature(
            destination_folder,
            preprocessing=ImageFeatureMixin.preprocessing_defaults,
        ),
        numerical_feature(
            preprocessing=NumericalFeatureMixin.preprocessing_defaults,
        ),
        sequence_feature(
            preprocessing=SequenceFeatureMixin.preprocessing_defaults,
        ),
        set_feature(
            preprocessing=SetFeatureMixin.preprocessing_defaults,
        ),
        text_feature(
            preprocessing=TextFeatureMixin.preprocessing_defaults,
        ),
        timeseries_feature(
            preprocessing=TimeseriesFeatureMixin.preprocessing_defaults,
        ),
        vector_feature(
            preprocessing=VectorFeatureMixin.preprocessing_defaults,
        ),
    ]
    output_features = [{"name": "target", "type": "category"}]
    training_section = {
        "decay": True,
        "learning_rate": 0.001,
        "validation_field": "target",
        "validation_metric": "accuracy",
    }

    config = {
        "input_features": input_features,
        "output_features": output_features,
        "training": training_section,
    }

    # First pass: the config exactly as assembled above.
    validate_config(config)

    # Second pass: the same config after merge_with_defaults has processed it.
    config = merge_with_defaults(config)
    validate_config(config)