Skip to content

Commit

Permalink
fix: Fix bug where SQL registry was incorrectly writing infra config …
Browse files Browse the repository at this point in the history
…around online stores (#3394)

fix: Fix bug where SQL registry was incorrectly writing info around sqlite online store

Signed-off-by: Danny Chiao <danny@tecton.ai>

Signed-off-by: Danny Chiao <danny@tecton.ai>
  • Loading branch information
adchia committed Dec 15, 2022
1 parent fd97254 commit 6bcf77c
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 84 deletions.
178 changes: 94 additions & 84 deletions sdk/python/feast/infra/registry/sql.py
Expand Up @@ -207,14 +207,14 @@ def get_stream_feature_view(
self, name: str, project: str, allow_cache: bool = False
):
return self._get_object(
stream_feature_views,
name,
project,
StreamFeatureViewProto,
StreamFeatureView,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
table=stream_feature_views,
name=name,
project=project,
proto_class=StreamFeatureViewProto,
python_class=StreamFeatureView,
id_field_name="feature_view_name",
proto_field_name="feature_view_proto",
not_found_exception=FeatureViewNotFoundException,
)

def list_stream_feature_views(
Expand All @@ -230,101 +230,105 @@ def list_stream_feature_views(

def apply_entity(self, entity: Entity, project: str, commit: bool = True):
return self._apply_object(
entities, project, "entity_name", entity, "entity_proto"
table=entities,
project=project,
id_field_name="entity_name",
obj=entity,
proto_field_name="entity_proto",
)

def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity:
return self._get_object(
entities,
name,
project,
EntityProto,
Entity,
"entity_name",
"entity_proto",
EntityNotFoundException,
table=entities,
name=name,
project=project,
proto_class=EntityProto,
python_class=Entity,
id_field_name="entity_name",
proto_field_name="entity_proto",
not_found_exception=EntityNotFoundException,
)

def get_feature_view(
self, name: str, project: str, allow_cache: bool = False
) -> FeatureView:
return self._get_object(
feature_views,
name,
project,
FeatureViewProto,
FeatureView,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
table=feature_views,
name=name,
project=project,
proto_class=FeatureViewProto,
python_class=FeatureView,
id_field_name="feature_view_name",
proto_field_name="feature_view_proto",
not_found_exception=FeatureViewNotFoundException,
)

def get_on_demand_feature_view(
self, name: str, project: str, allow_cache: bool = False
) -> OnDemandFeatureView:
return self._get_object(
on_demand_feature_views,
name,
project,
OnDemandFeatureViewProto,
OnDemandFeatureView,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
table=on_demand_feature_views,
name=name,
project=project,
proto_class=OnDemandFeatureViewProto,
python_class=OnDemandFeatureView,
id_field_name="feature_view_name",
proto_field_name="feature_view_proto",
not_found_exception=FeatureViewNotFoundException,
)

def get_request_feature_view(self, name: str, project: str):
return self._get_object(
request_feature_views,
name,
project,
RequestFeatureViewProto,
RequestFeatureView,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
table=request_feature_views,
name=name,
project=project,
proto_class=RequestFeatureViewProto,
python_class=RequestFeatureView,
id_field_name="feature_view_name",
proto_field_name="feature_view_proto",
not_found_exception=FeatureViewNotFoundException,
)

def get_feature_service(
self, name: str, project: str, allow_cache: bool = False
) -> FeatureService:
return self._get_object(
feature_services,
name,
project,
FeatureServiceProto,
FeatureService,
"feature_service_name",
"feature_service_proto",
FeatureServiceNotFoundException,
table=feature_services,
name=name,
project=project,
proto_class=FeatureServiceProto,
python_class=FeatureService,
id_field_name="feature_service_name",
proto_field_name="feature_service_proto",
not_found_exception=FeatureServiceNotFoundException,
)

def get_saved_dataset(
self, name: str, project: str, allow_cache: bool = False
) -> SavedDataset:
return self._get_object(
saved_datasets,
name,
project,
SavedDatasetProto,
SavedDataset,
"saved_dataset_name",
"saved_dataset_proto",
SavedDatasetNotFound,
table=saved_datasets,
name=name,
project=project,
proto_class=SavedDatasetProto,
python_class=SavedDataset,
id_field_name="saved_dataset_name",
proto_field_name="saved_dataset_proto",
not_found_exception=SavedDatasetNotFound,
)

def get_validation_reference(
self, name: str, project: str, allow_cache: bool = False
) -> ValidationReference:
return self._get_object(
validation_references,
name,
project,
ValidationReferenceProto,
ValidationReference,
"validation_reference_name",
"validation_reference_proto",
ValidationReferenceNotFound,
table=validation_references,
name=name,
project=project,
proto_class=ValidationReferenceProto,
python_class=ValidationReference,
id_field_name="validation_reference_name",
proto_field_name="validation_reference_proto",
not_found_exception=ValidationReferenceNotFound,
)

def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]:
Expand Down Expand Up @@ -364,14 +368,14 @@ def get_data_source(
self, name: str, project: str, allow_cache: bool = False
) -> DataSource:
return self._get_object(
data_sources,
name,
project,
DataSourceProto,
DataSource,
"data_source_name",
"data_source_proto",
DataSourceObjectNotFoundException,
table=data_sources,
name=name,
project=project,
proto_class=DataSourceProto,
python_class=DataSource,
id_field_name="data_source_name",
proto_field_name="data_source_proto",
not_found_exception=DataSourceObjectNotFoundException,
)

def list_data_sources(
Expand Down Expand Up @@ -556,22 +560,28 @@ def delete_validation_reference(self, name: str, project: str, commit: bool = Tr

def update_infra(self, infra: Infra, project: str, commit: bool = True):
self._apply_object(
managed_infra, project, "infra_name", infra, "infra_proto", name="infra_obj"
table=managed_infra,
project=project,
id_field_name="infra_name",
obj=infra,
proto_field_name="infra_proto",
name="infra_obj",
)

def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
infra_object = self._get_object(
managed_infra,
"infra_obj",
project,
InfraProto,
Infra,
"infra_name",
"infra_proto",
None,
table=managed_infra,
name="infra_obj",
project=project,
proto_class=InfraProto,
python_class=Infra,
id_field_name="infra_name",
proto_field_name="infra_proto",
not_found_exception=None,
)
infra_object = infra_object or InfraProto()
return Infra.from_proto(infra_object)
if infra_object:
return infra_object
return Infra()

def apply_user_metadata(
self,
Expand Down
28 changes: 28 additions & 0 deletions sdk/python/tests/unit/test_sql_registry.py
Expand Up @@ -28,6 +28,8 @@
from feast.errors import FeatureViewNotFoundException
from feast.feature_view import FeatureView
from feast.field import Field
from feast.infra.infra_object import Infra
from feast.infra.online_stores.sqlite import SqliteTable
from feast.infra.registry.sql import SqlRegistry
from feast.on_demand_feature_view import on_demand_feature_view
from feast.repo_config import RegistryConfig
Expand Down Expand Up @@ -258,10 +260,20 @@ def test_apply_feature_view_success(sql_registry):
and feature_view.features[3].dtype == Array(Bytes)
and feature_view.entities[0] == "fs1_my_entity_1"
)
assert feature_view.ttl == timedelta(minutes=5)

# After the first apply, the created_timestamp should be the same as the last_update_timestamp.
assert feature_view.created_timestamp == feature_view.last_updated_timestamp

# Modify the feature view and apply again to test if diffing the online store table works
fv1.ttl = timedelta(minutes=6)
sql_registry.apply_feature_view(fv1, project)
feature_views = sql_registry.list_feature_views(project)
assert len(feature_views) == 1
feature_view = sql_registry.get_feature_view("my_feature_view_1", project)
assert feature_view.ttl == timedelta(minutes=6)

# Delete feature view
sql_registry.delete_feature_view("my_feature_view_1", project)
feature_views = sql_registry.list_feature_views(project)
assert len(feature_views) == 0
Expand Down Expand Up @@ -570,6 +582,22 @@ def test_update_infra(sql_registry):
project = "project"
infra = sql_registry.get_infra(project=project)

assert len(infra.infra_objects) == 0

# Should run update infra successfully
sql_registry.update_infra(infra, project)

# Should run update infra successfully when adding
new_infra = Infra()
new_infra.infra_objects.append(
SqliteTable(
path="/tmp/my_path.db",
name="my_table",
)
)
sql_registry.update_infra(new_infra, project)
infra = sql_registry.get_infra(project=project)
assert len(infra.infra_objects) == 1

# Try again since second time, infra should be not-empty
sql_registry.teardown()

0 comments on commit 6bcf77c

Please sign in to comment.