diff --git a/feathr_project/feathr/anchor.py b/feathr_project/feathr/anchor.py index ec7162042..99ea2d975 100644 --- a/feathr_project/feathr/anchor.py +++ b/feathr_project/feathr/anchor.py @@ -30,7 +30,8 @@ def __init__(self, self.name = name self.features = features self.source = source - self.registry_tags=registry_tags + # adding typecheck for registry tags incase user could also do like this {"description", "this is a test feature"} as well as like this dictionary format {"description":"this is a test feature"} + self.registry_tags=registry_tags if isinstance(registry_tags,Dict) else dict([registry_tags]) self.validate_features() def validate_features(self): diff --git a/feathr_project/feathr/feature.py b/feathr_project/feathr/feature.py index b46585993..d052d2f3e 100644 --- a/feathr_project/feathr/feature.py +++ b/feathr_project/feathr/feature.py @@ -33,7 +33,8 @@ def __init__(self, FeatureBase.validate_feature_name(name) self.name = name self.feature_type = feature_type - self.registry_tags=registry_tags + # adding typecheck for registry tags incase user could also do like this {"description", "this is a test feature"} as well as like this dictionary format {"description":"this is a test feature"} + self.registry_tags=registry_tags if isinstance(registry_tags,Dict) else dict([registry_tags]) self.key = key if isinstance(key, List) else [key] # feature_alias: Rename the derived feature to `feature_alias`. Default to feature name. self.feature_alias = name diff --git a/feathr_project/feathr/source.py b/feathr_project/feathr/source.py index fb991612e..41c665679 100644 --- a/feathr_project/feathr/source.py +++ b/feathr_project/feathr/source.py @@ -48,7 +48,8 @@ def __init__(self, self.name = name self.event_timestamp_column = event_timestamp_column self.timestamp_format = timestamp_format - self.registry_tags = registry_tags + # adding typecheck for registry tags incase user could also do like this {"description", "this is a test feature"} as well as like this dictionary format {"description":"this is a test feature"} + self.registry_tags=registry_tags if isinstance(registry_tags,Dict) else dict([registry_tags]) def __eq__(self, other): """A source is equal to another if name is equal."""