From a5a7d0675dd669470ec6ebf115da807ebe61f9cb Mon Sep 17 00:00:00 2001 From: John Davis Date: Wed, 11 Oct 2023 14:46:04 -0400 Subject: [PATCH] Fix SA2.0 usage in tool_shed.util.metadata_util --- lib/tool_shed/util/metadata_util.py | 32 ++++++++++++++++------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/lib/tool_shed/util/metadata_util.py b/lib/tool_shed/util/metadata_util.py index c41b62b7a7c0..0ca15a8aa919 100644 --- a/lib/tool_shed/util/metadata_util.py +++ b/lib/tool_shed/util/metadata_util.py @@ -5,7 +5,7 @@ TYPE_CHECKING, ) -from sqlalchemy import and_ +from sqlalchemy import select from galaxy.model.base import transaction from galaxy.tool_shed.util.hg_util import ( @@ -30,7 +30,6 @@ def get_all_dependencies(app, metadata_entry, processed_dependency_links=None): encoder = app.security.encode_id value_mapper = {"repository_id": encoder, "id": encoder, "user_id": encoder} metadata = metadata_entry.to_dict(value_mapper=value_mapper, view="element") - db = app.model.session returned_dependencies = [] required_metadata = get_dependencies_for_metadata_revision(app, metadata) if required_metadata is None: @@ -41,7 +40,9 @@ def get_all_dependencies(app, metadata_entry, processed_dependency_links=None): if dependency_link in processed_dependency_links: continue processed_dependency_links.append(dependency_link) - repository = db.query(app.model.Repository).get(app.security.decode_id(dependency_dict["repository_id"])) + repository = app.model.session.get( + app.model.Repository, app.security.decode_id(dependency_dict["repository_id"]) + ) dependency_dict["repository"] = repository.to_dict(value_mapper=value_mapper) if dependency_metadata.includes_tools: dependency_dict["tools"] = dependency_metadata.metadata["tools"] @@ -123,7 +124,7 @@ def get_latest_downloadable_changeset_revision(app, repository): def get_latest_repository_metadata(app, decoded_repository_id, downloadable=False): """Get last metadata defined for a specified repository from the database.""" sa_session = app.model.session - repository = sa_session.query(app.model.Repository).get(decoded_repository_id) + repository = sa_session.get(app.model.Repository, decoded_repository_id) if downloadable: changeset_revision = get_latest_downloadable_changeset_revision(app, repository) else: @@ -258,16 +259,10 @@ def repository_metadata_by_changeset_revision( # Make sure there are no duplicate records, and return the single unique record for the changeset_revision. # Duplicate records were somehow created in the past. The cause of this issue has been resolved, but we'll # leave this method as is for a while longer to ensure all duplicate records are removed. + sa_session = model_mapping.context - all_metadata_records = ( - sa_session.query(model_mapping.RepositoryMetadata) - .filter( - and_( - model_mapping.RepositoryMetadata.table.c.repository_id == id, - model_mapping.RepositoryMetadata.table.c.changeset_revision == changeset_revision, - ) - ) - .all() + all_metadata_records = get_metadata_by_changeset( + sa_session, id, changeset_revision, model_mapping.RepositoryMetadata ) if len(all_metadata_records) > 1: # Delete all records older than the last one updated. @@ -285,7 +280,7 @@ def repository_metadata_by_changeset_revision( def get_repository_metadata_by_id(app, id): """Get repository metadata from the database""" sa_session = app.model.session - return sa_session.query(app.model.RepositoryMetadata).get(app.security.decode_id(id)) + return sa_session.get(app.model.RepositoryMetadata, app.security.decode_id(id)) def get_repository_metadata_by_repository_id_changeset_revision(app, id, changeset_revision, metadata_only=False): @@ -348,3 +343,12 @@ def is_malicious(app, id, changeset_revision, **kwd): if repository_metadata: return repository_metadata.malicious return False + + +def get_metadata_by_changeset(session, repository_id, changeset_revision, repository_metadata_model): + stmt = ( + select(repository_metadata_model) + .where(repository_metadata_model.repository_id == repository_id) + .where(repository_metadata_model.changeset_revision == changeset_revision) + ) + return session.scalars(stmt).all()