Skip to content

Commit

Permalink
Fix SA2.0 usage in tool_shed.webapp.controllers.user
Browse files Browse the repository at this point in the history
Also, add case_sensitive parameter to shared function in galaxy user manager
  • Loading branch information
jdavcs committed Oct 12, 2023
1 parent 49cafd0 commit c4e1c9b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 18 deletions.
7 changes: 5 additions & 2 deletions lib/galaxy/managers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,8 +849,11 @@ def get_users_by_ids(session: Session, user_ids):
# the tool_shed app, which has its own User model, which is different from
# galaxy.model.User. In that case, the tool_shed user model should be passed as
# the model_class argument.
def get_user_by_email(session, email: str, model_class=User):
stmt = select(model_class).filter(model_class.email == email).limit(1)
def get_user_by_email(session, email: str, model_class=User, case_sensitive=True):
filter_clause = model_class.email == email
if not case_sensitive:
filter_clause = func.lower(model_class.email) == func.lower(email)
stmt = select(model_class).where(filter_clause).limit(1)
return session.scalars(stmt).first()


Expand Down
24 changes: 8 additions & 16 deletions lib/tool_shed/webapp/controllers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import socket

from markupsafe import escape
from sqlalchemy import func

from galaxy import (
util,
web,
)
from galaxy.managers.api_keys import ApiKeyManager
from galaxy.managers.users import get_user_by_email
from galaxy.model.base import transaction
from galaxy.security.validate_user_input import (
validate_email,
Expand Down Expand Up @@ -247,18 +247,10 @@ def reset_password(self, trans, email=None, **kwd):
"Please check your email account for more instructions. "
"If you do not receive an email shortly, please contact an administrator." % (escape(email))
)
reset_user = (
trans.sa_session.query(trans.app.model.User)
.filter(trans.app.model.User.table.c.email == email)
.first()
)
reset_user = get_user_by_email(trans.sa_session, email, trans.app.model.User)
if not reset_user:
# Perform a case-insensitive check only if the user wasn't found
reset_user = (
trans.sa_session.query(trans.app.model.User)
.filter(func.lower(trans.app.model.User.table.c.email) == func.lower(email))
.first()
)
reset_user = get_user_by_email(trans.sa_session, email, trans.app.model.User, False)
if reset_user:
prt = trans.app.model.PasswordResetToken(reset_user)
trans.sa_session.add(prt)
Expand Down Expand Up @@ -291,7 +283,7 @@ def manage_user_info(self, trans, cntrller, **kwd):
params = util.Params(kwd)
user_id = params.get("id", None)
if user_id:
user = trans.sa_session.query(trans.app.model.User).get(trans.security.decode_id(user_id))
user = trans.sa_session.get(trans.app.model.User, trans.security.decode_id(user_id))
else:
user = trans.user
if not user:
Expand Down Expand Up @@ -336,7 +328,7 @@ def edit_username(self, trans, cntrller, **kwd):
status = params.get("status", "done")
user_id = params.get("user_id", None)
if user_id and is_admin:
user = trans.sa_session.query(trans.app.model.User).get(trans.security.decode_id(user_id))
user = trans.sa_session.get(trans.app.model.User, trans.security.decode_id(user_id))
else:
user = trans.user
if user and params.get("change_username_button", False):
Expand Down Expand Up @@ -371,7 +363,7 @@ def edit_info(self, trans, cntrller, **kwd):
status = params.get("status", "done")
user_id = params.get("user_id", None)
if user_id and is_admin:
user = trans.sa_session.query(trans.app.model.User).get(trans.security.decode_id(user_id))
user = trans.sa_session.get(trans.app.model.User, trans.security.decode_id(user_id))
elif user_id and (not trans.user or trans.user.id != trans.security.decode_id(user_id)):
message = "Invalid user id"
status = "error"
Expand Down Expand Up @@ -422,8 +414,8 @@ def edit_info(self, trans, cntrller, **kwd):
# Edit user information - webapp MUST BE 'galaxy'
user_type_fd_id = params.get("user_type_fd_id", "none")
if user_type_fd_id not in ["none"]:
user_type_form_definition = trans.sa_session.query(trans.app.model.FormDefinition).get(
trans.security.decode_id(user_type_fd_id)
user_type_form_definition = trans.sa_session.get(
trans.app.model.FormDefinition, trans.security.decode_id(user_type_fd_id)
)
elif user.values:
user_type_form_definition = user.values.form_definition
Expand Down

0 comments on commit c4e1c9b

Please sign in to comment.