diff --git a/augur/api/routes/user.py b/augur/api/routes/user.py index 1d7c689166..3362d61211 100644 --- a/augur/api/routes/user.py +++ b/augur/api/routes/user.py @@ -139,25 +139,26 @@ def generate_session(application): if not username: return jsonify({"status": "Invalid authorization code"}) - user = User.get_user(username) - if not user: - return jsonify({"status": "Invalid user"}) + with DatabaseSession(logger) as session: - seconds_to_expire = 86400 + user = User.get_user(session, username) + if not user: + return jsonify({"status": "Invalid user"}) - with DatabaseSession(logger) as session: + seconds_to_expire = 86400 existing_session = session.query(UserSessionToken).filter(UserSessionToken.user_id == user.user_id, UserSessionToken.application_id == application.id).first() if existing_session: existing_session.delete_refresh_tokens(session) - + session.delete(existing_session) + session.commit() - user_session_token = UserSessionToken.create(user.user_id, application.id, seconds_to_expire).token - refresh_token = RefreshToken.create(user_session_token) + user_session_token = UserSessionToken.create(session, user.user_id, application.id, seconds_to_expire).token + refresh_token = RefreshToken.create(session, user_session_token) - response = jsonify({"status": "Validated", "username": username, "access_token": user_session_token, "refresh_token" : refresh_token.id, "token_type": "Bearer", "expires": seconds_to_expire}) - response.headers["Cache-Control"] = "no-store" + response = jsonify({"status": "Validated", "username": username, "access_token": user_session_token, "refresh_token" : refresh_token.id, "token_type": "Bearer", "expires": seconds_to_expire}) + response.headers["Cache-Control"] = "no-store" return response @@ -172,25 +173,26 @@ def refresh_session(application): if request.args.get("grant_type") != "refresh_token": return jsonify({"status": "Invalid grant type"}) - session = Session() - refresh_token = session.query(RefreshToken).filter(RefreshToken.id == refresh_token_str).first() - if not refresh_token: - return jsonify({"status": "Invalid refresh token"}) + with DatabaseSession(logger) as session: - if refresh_token.user_session.application == application: - return jsonify({"status": "Applications do not match"}) + refresh_token = session.query(RefreshToken).filter(RefreshToken.id == refresh_token_str).first() + if not refresh_token: + return jsonify({"status": "Invalid refresh token"}) - user_session = refresh_token.user_session - user = user_session.user + if refresh_token.user_session.application == application: + return jsonify({"status": "Applications do not match"}) - new_user_session = UserSessionToken.create(user.user_id, user_session.application.id) - new_refresh_token = RefreshToken.create(new_user_session.token) - - session.delete(refresh_token) - session.delete(user_session) - session.commit() + user_session = refresh_token.user_session + user = user_session.user + + new_user_session_token = UserSessionToken.create(session, user.user_id, user_session.application.id).token + new_refresh_token_id = RefreshToken.create(session, new_user_session_token).id + + session.delete(refresh_token) + session.delete(user_session) + session.commit() - return jsonify({"status": "Validated", "refresh_token": new_refresh_token.id, "access_token": new_user_session.token, "expires": 86400}) + return jsonify({"status": "Validated", "refresh_token": new_refresh_token_id, "access_token": new_user_session_token, "expires": 86400}) @server.app.route(f"/{AUGUR_API_VERSION}/user/query", methods=['POST']) diff --git a/augur/api/view/api.py b/augur/api/view/api.py index 721c8164ef..f31fbd1057 100644 --- a/augur/api/view/api.py +++ b/augur/api/view/api.py @@ -1,6 +1,7 @@ from flask import Flask, render_template, render_template_string, request, abort, jsonify, redirect, url_for, session, flash from flask_login import current_user, login_required from augur.application.db.models import Repo +from augur.application.db.session import DatabaseSession # from augur.util.repo_load_controller import parse_org_url, parse_repo_url from .utils import * @@ -88,8 +89,10 @@ def user_remove_repo(): repo = int(repo) + with DatabaseSession(logger) as session: + result = current_user.remove_repo(session, group, repo)[0] - if current_user.remove_repo(group, repo)[0]: + if result: flash(f"Successfully removed repo {repo} from group {group}") else: flash("An error occurred removing repo from group") diff --git a/augur/api/view/augur_view.py b/augur/api/view/augur_view.py index a4ede35e30..3afa315a11 100644 --- a/augur/api/view/augur_view.py +++ b/augur/api/view/augur_view.py @@ -64,8 +64,15 @@ def load_user(user_id): user = User.get_user(db_session, user_id) groups = user.groups + tokens = user.tokens + applications = user.applications + for application in applications: + sessions = application.sessions for group in groups: repos = group.repos + for token in tokens: + application = token.application + db_session.expunge(user) if not user: return None diff --git a/augur/api/view/routes.py b/augur/api/view/routes.py index 0a8caa00cb..78a461a1d6 100644 --- a/augur/api/view/routes.py +++ b/augur/api/view/routes.py @@ -216,9 +216,11 @@ def authorize_user(): if not client_id or response_type != "code": return render_message("Invalid Request", "Something went wrong. You may need to return to the previous application and make the request again.") + + with DatabaseSession(logger) as session: - # TODO get application from client id - client = ClientApplication.get_by_id(client_id) + # TODO get application from client id + client = ClientApplication.get_by_id(session, client_id) return render_module("authorization", app = client, state = state) diff --git a/augur/application/db/models/augur_operations.py b/augur/application/db/models/augur_operations.py index c11b5ac66a..3f3b8566f7 100644 --- a/augur/application/db/models/augur_operations.py +++ b/augur/application/db/models/augur_operations.py @@ -839,9 +839,6 @@ def delete_refresh_tokens(self, session): session.delete(token) session.commit() - session.delete(self) - session.commit() - class ClientApplication(Base): __tablename__ = "client_applications" __table_args__ = (