diff --git a/lib/tool_shed/util/shed_util_common.py b/lib/tool_shed/util/shed_util_common.py index 548f8ede40e4..3f107c652ab3 100644 --- a/lib/tool_shed/util/shed_util_common.py +++ b/lib/tool_shed/util/shed_util_common.py @@ -5,10 +5,10 @@ import string from typing import TYPE_CHECKING -import sqlalchemy.orm.exc from sqlalchemy import ( - and_, false, + func, + select, true, ) @@ -89,46 +89,41 @@ def count_repositories_in_category(app: "ToolShedApp", category_id: str) -> int: - sa_session = app.model.session - return ( - sa_session.query(app.model.RepositoryCategoryAssociation) - .filter(app.model.RepositoryCategoryAssociation.table.c.category_id == app.security.decode_id(category_id)) - .count() + stmt = ( + select(func.count()) + .select_from(app.model.RepositoryCategoryAssociation) + .where(app.model.RepositoryCategoryAssociation.category_id == app.security.decode_id(category_id)) ) + return app.model.session.scalar(stmt) def get_categories(app: "ToolShedApp"): """Get all categories from the database.""" sa_session = app.model.session - return ( - sa_session.query(app.model.Category) - .filter(app.model.Category.table.c.deleted == false()) - .order_by(app.model.Category.table.c.name) - .all() - ) + stmt = select(app.model.Category).where(app.model.Category.deleted == false()).order_by(app.model.Category.name) + return sa_session.scalars(stmt).all() def get_category(app: "ToolShedApp", id: str): """Get a category from the database.""" sa_session = app.model.session - return sa_session.query(app.model.Category).get(app.security.decode_id(id)) + return sa_session.get(app.model.Category, app.security.decode_id(id)) def get_category_by_name(app: "ToolShedApp", name: str): """Get a category from the database via name.""" sa_session = app.model.session - try: - return sa_session.query(app.model.Category).filter_by(name=name).one() - except sqlalchemy.orm.exc.NoResultFound: - return None + stmt = select(app.model.Category).filter_by(name=name).limit(1) + return sa_session.scalars(stmt).first() def get_repository_categories(app, id): """Get categories of a repository on the tool shed side from the database via id""" sa_session = app.model.session - return sa_session.query(app.model.RepositoryCategoryAssociation).filter( - app.model.RepositoryCategoryAssociation.table.c.repository_id == app.security.decode_id(id) + stmt = select(app.model.RepositoryCategoryAssociation).where( + app.model.RepositoryCategoryAssociation.repository_id == app.security.decode_id(id) ) + return sa_session.scalars(stmt).all() def get_repository_file_contents(app, file_path, repository_id, is_admin=False): @@ -336,9 +331,7 @@ def handle_email_alerts(app, host, repository, content_alert_str="", new_repo_al subject = f"Galaxy tool shed alert for new repository named {str(repository.name)}" subject = subject[:80] email_alerts = [] - for user in sa_session.query(app.model.User).filter( - and_(app.model.User.table.c.deleted == false(), app.model.User.table.c.new_repo_alert == true()) - ): + for user in get_users_with_repo_alert(sa_session.query, app.model.User): if admin_only: if user.email in admin_users: email_alerts.append(user.email) @@ -468,3 +461,8 @@ def tool_shed_is_this_tool_shed(toolshed_base_url, trans=None): "set_image_paths", "tool_shed_is_this_tool_shed", ) + + +def get_users_with_repo_alert(session, user_model): + stmt = select(user_model).where(user_model.deleted == false()).where(user_model.new_repo_alert == true()) + return session.scalars(stmt)