Skip to content

Commit

Permalink
Fix SA2.0 usage in tool_shed.util.shed_util_common
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Oct 12, 2023
1 parent a5a7d06 commit 1711e96
Showing 1 changed file with 21 additions and 23 deletions.
44 changes: 21 additions & 23 deletions lib/tool_shed/util/shed_util_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import string
from typing import TYPE_CHECKING

import sqlalchemy.orm.exc
from sqlalchemy import (
and_,
false,
func,
select,
true,
)

Expand Down Expand Up @@ -89,46 +89,41 @@


def count_repositories_in_category(app: "ToolShedApp", category_id: str) -> int:
sa_session = app.model.session
return (
sa_session.query(app.model.RepositoryCategoryAssociation)
.filter(app.model.RepositoryCategoryAssociation.table.c.category_id == app.security.decode_id(category_id))
.count()
stmt = (
select(func.count())
.select_from(app.model.RepositoryCategoryAssociation)
.where(app.model.RepositoryCategoryAssociation.category_id == app.security.decode_id(category_id))
)
return app.model.session.scalar(stmt)


def get_categories(app: "ToolShedApp"):
"""Get all categories from the database."""
sa_session = app.model.session
return (
sa_session.query(app.model.Category)
.filter(app.model.Category.table.c.deleted == false())
.order_by(app.model.Category.table.c.name)
.all()
)
stmt = select(app.model.Category).where(app.model.Category.deleted == false()).order_by(app.model.Category.name)
return sa_session.scalars(stmt).all()


def get_category(app: "ToolShedApp", id: str):
"""Get a category from the database."""
sa_session = app.model.session
return sa_session.query(app.model.Category).get(app.security.decode_id(id))
return sa_session.get(app.model.Category, app.security.decode_id(id))


def get_category_by_name(app: "ToolShedApp", name: str):
"""Get a category from the database via name."""
sa_session = app.model.session
try:
return sa_session.query(app.model.Category).filter_by(name=name).one()
except sqlalchemy.orm.exc.NoResultFound:
return None
stmt = select(app.model.Category).filter_by(name=name).limit(1)
return sa_session.scalars(stmt).first()


def get_repository_categories(app, id):
"""Get categories of a repository on the tool shed side from the database via id"""
sa_session = app.model.session
return sa_session.query(app.model.RepositoryCategoryAssociation).filter(
app.model.RepositoryCategoryAssociation.table.c.repository_id == app.security.decode_id(id)
stmt = select(app.model.RepositoryCategoryAssociation).where(
app.model.RepositoryCategoryAssociation.repository_id == app.security.decode_id(id)
)
return sa_session.scalars(stmt).all()


def get_repository_file_contents(app, file_path, repository_id, is_admin=False):
Expand Down Expand Up @@ -336,9 +331,7 @@ def handle_email_alerts(app, host, repository, content_alert_str="", new_repo_al
subject = f"Galaxy tool shed alert for new repository named {str(repository.name)}"
subject = subject[:80]
email_alerts = []
for user in sa_session.query(app.model.User).filter(
and_(app.model.User.table.c.deleted == false(), app.model.User.table.c.new_repo_alert == true())
):
for user in get_users_with_repo_alert(sa_session.query, app.model.User):
if admin_only:
if user.email in admin_users:
email_alerts.append(user.email)
Expand Down Expand Up @@ -468,3 +461,8 @@ def tool_shed_is_this_tool_shed(toolshed_base_url, trans=None):
"set_image_paths",
"tool_shed_is_this_tool_shed",
)


def get_users_with_repo_alert(session, user_model):
stmt = select(user_model).where(user_model.deleted == false()).where(user_model.new_repo_alert == true())
return session.scalars(stmt)

0 comments on commit 1711e96

Please sign in to comment.