Skip to content

Commit

Permalink
fix: Upgrade sqlalchemy from 1.x to 2.x regarding PVE-2022-51668. (#4065
Browse files Browse the repository at this point in the history
)

* fix: Upgrade sqlalchemy from 1.x to 2.x regarding PVE-2022-51668.

Signed-off-by: Shuchu Han <shuchu.han@gmail.com>

* fix: fix typo.

Signed-off-by: Shuchu Han <shuchu.han@gmail.com>

---------

Signed-off-by: Shuchu Han <shuchu.han@gmail.com>
  • Loading branch information
shuchu committed Apr 3, 2024
1 parent 7f1557b commit ec4c15c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
45 changes: 24 additions & 21 deletions sdk/python/feast/infra/registry/sql.py
Expand Up @@ -205,7 +205,7 @@ def teardown(self):
saved_datasets,
validation_references,
}:
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = delete(t)
conn.execute(stmt)

Expand Down Expand Up @@ -399,7 +399,7 @@ def apply_feature_service(
)

def delete_data_source(self, name: str, project: str, commit: bool = True):
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = delete(data_sources).where(
data_sources.c.data_source_name == name,
data_sources.c.project_id == project,
Expand Down Expand Up @@ -441,16 +441,19 @@ def _list_on_demand_feature_views(self, project: str) -> List[OnDemandFeatureVie
)

def _list_project_metadata(self, project: str) -> List[ProjectMetadata]:
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.project_id == project,
)
rows = conn.execute(stmt).all()
if rows:
project_metadata = ProjectMetadata(project_name=project)
for row in rows:
if row["metadata_key"] == FeastMetadataKeys.PROJECT_UUID.value:
project_metadata.project_uuid = row["metadata_value"]
if (
row._mapping["metadata_key"]
== FeastMetadataKeys.PROJECT_UUID.value
):
project_metadata.project_uuid = row._mapping["metadata_value"]
break
# TODO(adchia): Add other project metadata in a structured way
return [project_metadata]
Expand Down Expand Up @@ -557,7 +560,7 @@ def apply_user_metadata(
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(table).where(
getattr(table.c, "feature_view_name") == name,
table.c.project_id == project,
Expand Down Expand Up @@ -612,11 +615,11 @@ def get_user_metadata(
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(table).where(getattr(table.c, "feature_view_name") == name)
row = conn.execute(stmt).first()
if row:
return row["user_metadata"]
return row._mapping["user_metadata"]
else:
raise FeatureViewNotFoundException(feature_view.name, project=project)

Expand Down Expand Up @@ -674,7 +677,7 @@ def _apply_object(
name = name or (obj.name if hasattr(obj, "name") else None)
assert name, f"name needs to be provided for {obj}"

with self.engine.connect() as conn:
with self.engine.begin() as conn:
update_datetime = datetime.utcnow()
update_time = int(update_datetime.timestamp())
stmt = select(table).where(
Expand Down Expand Up @@ -723,7 +726,7 @@ def _apply_object(

def _maybe_init_project_metadata(self, project):
# Initialize project metadata if needed
with self.engine.connect() as conn:
with self.engine.begin() as conn:
update_datetime = datetime.utcnow()
update_time = int(update_datetime.timestamp())
stmt = select(feast_metadata).where(
Expand All @@ -732,7 +735,7 @@ def _maybe_init_project_metadata(self, project):
)
row = conn.execute(stmt).first()
if row:
usage.set_current_project_uuid(row["metadata_value"])
usage.set_current_project_uuid(row._mapping["metadata_value"])
else:
new_project_uuid = f"{uuid.uuid4()}"
values = {
Expand All @@ -753,7 +756,7 @@ def _delete_object(
id_field_name: str,
not_found_exception: Optional[Callable],
):
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = delete(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
Expand All @@ -777,13 +780,13 @@ def _get_object(
):
self._maybe_init_project_metadata(project)

with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
row = conn.execute(stmt).first()
if row:
_proto = proto_class.FromString(row[proto_field_name])
_proto = proto_class.FromString(row._mapping[proto_field_name])
return python_class.from_proto(_proto)
if not_found_exception:
raise not_found_exception(name, project)
Expand All @@ -799,20 +802,20 @@ def _list_objects(
proto_field_name: str,
):
self._maybe_init_project_metadata(project)
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(table).where(table.c.project_id == project)
rows = conn.execute(stmt).all()
if rows:
return [
python_class.from_proto(
proto_class.FromString(row[proto_field_name])
proto_class.FromString(row._mapping[proto_field_name])
)
for row in rows
]
return []

def _set_last_updated_metadata(self, last_updated: datetime, project: str):
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
Expand Down Expand Up @@ -846,7 +849,7 @@ def _set_last_updated_metadata(self, last_updated: datetime, project: str):
conn.execute(insert_stmt)

def _get_last_updated_metadata(self, project: str):
with self.engine.connect() as conn:
with self.engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
Expand All @@ -855,13 +858,13 @@ def _get_last_updated_metadata(self, project: str):
row = conn.execute(stmt).first()
if not row:
return None
update_time = int(row["last_updated_timestamp"])
update_time = int(row._mapping["last_updated_timestamp"])

return datetime.utcfromtimestamp(update_time)

def _get_all_projects(self) -> Set[str]:
projects = set()
with self.engine.connect() as conn:
with self.engine.begin() as conn:
for table in {
entities,
data_sources,
Expand All @@ -872,6 +875,6 @@ def _get_all_projects(self) -> Set[str]:
stmt = select(table)
rows = conn.execute(stmt).all()
for row in rows:
projects.add(row["project_id"])
projects.add(row._mapping["project_id"])

return projects
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -57,7 +57,7 @@
"pygments>=2.12.0,<3",
"PyYAML>=5.4.0,<7",
"requests",
"SQLAlchemy[mypy]>1,<2",
"SQLAlchemy[mypy]>1",
"tabulate>=0.8.0,<1",
"tenacity>=7,<9",
"toml>=0.10.0,<1",
Expand Down

0 comments on commit ec4c15c

Please sign in to comment.