diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 78bb9b8f0d..864f251157 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,6 +29,7 @@ env: # UTF-8 content may be interpreted as ascii and causes errors without this. LANG: C.UTF-8 PYTEST_ADDOPTS: "--verbose --color=yes" + SQLALCHEMY_WARN_20: "1" permissions: contents: read @@ -140,7 +141,7 @@ jobs: - name: Install Python dependencies run: | pip install --upgrade pip - pip install ".[test]" + pip install -e ".[test]" if [ "${{ matrix.oldest_dependencies }}" != "" ]; then # take any dependencies in requirements.txt such as tornado>=5.0 @@ -152,6 +153,7 @@ jobs: if [ "${{ matrix.main_dependencies }}" != "" ]; then pip install git+https://github.com/ipython/traitlets#egg=traitlets --force + pip install --upgrade --pre sqlalchemy fi if [ "${{ matrix.legacy_notebook }}" != "" ]; then pip uninstall jupyter_server --yes diff --git a/jupyterhub/apihandlers/groups.py b/jupyterhub/apihandlers/groups.py index c5799f1587..f8e776ea34 100644 --- a/jupyterhub/apihandlers/groups.py +++ b/jupyterhub/apihandlers/groups.py @@ -94,8 +94,9 @@ async def post(self): # create the group self.log.info("Creating new group %s with %i users", name, len(users)) self.log.debug("Users: %s", usernames) - group = orm.Group(name=name, users=users) + group = orm.Group(name=name) self.db.add(group) + group.users = users self.db.commit() created.append(group) self.write(json.dumps([self.group_model(group) for group in created])) @@ -131,8 +132,9 @@ async def post(self, group_name): # create the group self.log.info("Creating new group %s with %i users", group_name, len(users)) self.log.debug("Users: %s", usernames) - group = orm.Group(name=group_name, users=users) + group = orm.Group(name=group_name) self.db.add(group) + group.users = users self.db.commit() self.write(json.dumps(self.group_model(group))) self.set_status(201) diff --git a/jupyterhub/app.py b/jupyterhub/app.py index c967b84a8b..9c74fc0614 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -1948,9 +1948,9 @@ async def init_users(self): user = orm.User.find(db, name) if user is None: user = orm.User(name=name, admin=True) + db.add(user) roles.assign_default_roles(self.db, entity=user) new_users.append(user) - db.add(user) else: user.admin = True # the admin_users config variable will never be used after this point. @@ -2343,6 +2343,7 @@ def init_services(self): if orm_service is None: # not found, create a new one orm_service = orm.Service(name=name) + self.db.add(orm_service) if spec.get('admin', False): self.log.warning( f"Service {name} sets `admin: True`, which is deprecated in JupyterHub 2.0." @@ -2351,7 +2352,6 @@ def init_services(self): "the Service admin flag will be ignored." ) roles.update_roles(self.db, entity=orm_service, roles=['admin']) - self.db.add(orm_service) orm_service.admin = spec.get('admin', False) self.db.commit() service = Service( diff --git a/jupyterhub/oauth/provider.py b/jupyterhub/oauth/provider.py index bb72e3383c..9f45a07ec8 100644 --- a/jupyterhub/oauth/provider.py +++ b/jupyterhub/oauth/provider.py @@ -257,16 +257,16 @@ def save_authorization_code(self, client_id, code, request, *args, **kwargs): raise ValueError("No such client: %s" % client_id) orm_code = orm.OAuthCode( - client=orm_client, code=code['code'], # oauth has 5 minutes to complete expires_at=int(orm.OAuthCode.now() + 300), scopes=list(request.scopes), - user=request.user.orm_user, redirect_uri=orm_client.redirect_uri, session_id=request.session_id, ) self.db.add(orm_code) + orm_code.client = orm_client + orm_code.user = request.user.orm_user self.db.commit() def get_authorization_code_scopes(self, client_id, code, redirect_uri, request): diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 0aa1e1ca4d..02f1942a9b 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -8,7 +8,9 @@ import alembic.command import alembic.config +import sqlalchemy from alembic.script import ScriptDirectory +from packaging.version import parse as parse_version from sqlalchemy import ( Boolean, Column, @@ -24,8 +26,8 @@ inspect, or_, select, + text, ) -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import ( Session, backref, @@ -34,6 +36,13 @@ relationship, sessionmaker, ) + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + # sqlalchemy < 1.4 + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.pool import StaticPool from sqlalchemy.types import LargeBinary, Text, TypeDecorator from tornado.log import app_log @@ -749,6 +758,7 @@ def new( session_id=session_id, scopes=list(scopes), ) + db.add(orm_token) orm_token.token = token if user: assert user.id is not None @@ -759,7 +769,6 @@ def new( if expires_in is not None: orm_token.expires_at = cls.now() + timedelta(seconds=expires_in) - db.add(orm_token) db.commit() return token @@ -901,7 +910,7 @@ def register_ping_connection(engine): """ @event.listens_for(engine, "engine_connect") - def ping_connection(connection, branch): + def ping_connection(connection, branch=None): if branch: # "branch" refers to a sub-connection of a connection, # we don't want to bother pinging on these. @@ -912,11 +921,17 @@ def ping_connection(connection, branch): save_should_close_with_result = connection.should_close_with_result connection.should_close_with_result = False + if parse_version(sqlalchemy.__version__) < parse_version("1.4"): + one = [1] + else: + one = 1 + try: - # run a SELECT 1. use a core select() so that + # run a SELECT 1. use a core select() so that # the SELECT of a scalar value without a table is # appropriately formatted for the backend - connection.scalar(select([1])) + with connection.begin() as transaction: + connection.scalar(select(one)) except exc.DBAPIError as err: # catch SQLAlchemy's DBAPIError, which is a wrapper # for the DBAPI's exception. It includes a .connection_invalidated @@ -931,7 +946,8 @@ def ping_connection(connection, branch): # itself and establish a new connection. The disconnect detection # here also causes the whole connection pool to be invalidated # so that all stale connections are discarded. - connection.scalar(select([1])) + with connection.begin() as transaction: + connection.scalar(select(one)) else: raise finally: @@ -955,7 +971,13 @@ def check_db_revision(engine): from .dbutil import _temp_alembic_ini - with _temp_alembic_ini(engine.url) as ini: + if hasattr(engine.url, "render_as_string"): + # sqlalchemy >= 1.4 + engine_url = engine.url.render_as_string(hide_password=False) + else: + engine_url = str(engine.url) + + with _temp_alembic_ini(engine_url) as ini: cfg = alembic.config.Config(ini) scripts = ScriptDirectory.from_config(cfg) head = scripts.get_heads()[0] @@ -990,9 +1012,10 @@ def check_db_revision(engine): # check database schema version # it should always be defined at this point - alembic_revision = engine.execute( - 'SELECT version_num FROM alembic_version' - ).first()[0] + with engine.begin() as connection: + alembic_revision = connection.execute( + text('SELECT version_num FROM alembic_version') + ).first()[0] if alembic_revision == head: app_log.debug("database schema version found: %s", alembic_revision) else: @@ -1009,13 +1032,16 @@ def mysql_large_prefix_check(engine): """Check mysql has innodb_large_prefix set""" if not str(engine.url).startswith('mysql'): return False - variables = dict( - engine.execute( - 'show variables where variable_name like ' - '"innodb_large_prefix" or ' - 'variable_name like "innodb_file_format";' - ).fetchall() - ) + with engine.begin() as connection: + variables = dict( + connection.execute( + text( + 'show variables where variable_name like ' + '"innodb_large_prefix" or ' + 'variable_name like "innodb_file_format";' + ) + ).fetchall() + ) if ( variables.get('innodb_file_format', 'Barracuda') == 'Barracuda' and variables.get('innodb_large_prefix', 'ON') == 'ON' diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 2bd8c811a6..528edb1a45 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -521,11 +521,12 @@ async def test_get_self(app): db.add(oauth_client) db.commit() oauth_token = orm.APIToken( - user=u.orm_user, - oauth_client=oauth_client, token=token, ) db.add(oauth_token) + oauth_token.user = u.orm_user + oauth_token.oauth_client = oauth_client + db.commit() r = await api_request( app, @@ -2117,13 +2118,13 @@ async def shutdown(): def stop(): stop.called = True - loop.call_later(1, real_stop) + loop.call_later(2, real_stop) real_cleanup = app.cleanup def cleanup(): cleanup.called = True - return real_cleanup() + loop.call_later(1, real_cleanup) app.cleanup = cleanup diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 3009a11d61..74deff3bce 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -323,7 +323,9 @@ def test_spawner_delete_cascade(db): db.add(user) db.commit() - spawner = orm.Spawner(user=user) + spawner = orm.Spawner() + db.add(spawner) + spawner.user = user db.commit() spawner.server = server = orm.Server() db.commit() @@ -350,16 +352,19 @@ def test_user_delete_cascade(db): # these should all be deleted automatically when the user goes away user.new_api_token() api_token = user.api_tokens[0] - spawner = orm.Spawner(user=user) + spawner = orm.Spawner() + db.add(spawner) + spawner.user = user db.commit() spawner.server = server = orm.Server() - oauth_code = orm.OAuthCode(client=oauth_client, user=user) + oauth_code = orm.OAuthCode() db.add(oauth_code) - oauth_token = orm.APIToken( - oauth_client=oauth_client, - user=user, - ) + oauth_code.client = oauth_client + oauth_code.user = user + oauth_token = orm.APIToken() db.add(oauth_token) + oauth_token.oauth_client = oauth_client + oauth_token.user = user db.commit() # record all of the ids @@ -390,13 +395,14 @@ def test_oauth_client_delete_cascade(db): # create a bunch of objects that reference the User # these should all be deleted automatically when the user goes away - oauth_code = orm.OAuthCode(client=oauth_client, user=user) + oauth_code = orm.OAuthCode() db.add(oauth_code) - oauth_token = orm.APIToken( - oauth_client=oauth_client, - user=user, - ) + oauth_code.client = oauth_client + oauth_code.user = user + oauth_token = orm.APIToken() db.add(oauth_token) + oauth_token.oauth_client = oauth_client + oauth_token.user = user db.commit() assert user.api_tokens == [oauth_token] @@ -517,11 +523,11 @@ def test_expiring_oauth_token(app, user): db.add(client) orm_token = orm.APIToken( token=token, - oauth_client=client, - user=user, expires_at=now() + timedelta(seconds=30), ) db.add(orm_token) + orm_token.oauth_client = client + orm_token.user = user db.commit() found = orm.APIToken.find(db, token) diff --git a/jupyterhub/tests/test_pages.py b/jupyterhub/tests/test_pages.py index 424bb7743f..e8db7bca58 100644 --- a/jupyterhub/tests/test_pages.py +++ b/jupyterhub/tests/test_pages.py @@ -1045,11 +1045,10 @@ async def test_oauth_token_page(app): user = app.users[orm.User.find(app.db, name)] client = orm.OAuthClient(identifier='token') app.db.add(client) - oauth_token = orm.APIToken( - oauth_client=client, - user=user, - ) + oauth_token = orm.APIToken() app.db.add(oauth_token) + oauth_token.oauth_client = client + oauth_token.user = user app.db.commit() r = await get_page('token', app, cookies=cookies) r.raise_for_status() diff --git a/jupyterhub/tests/test_roles.py b/jupyterhub/tests/test_roles.py index 6dd5bdfeff..cb5f3f534c 100644 --- a/jupyterhub/tests/test_roles.py +++ b/jupyterhub/tests/test_roles.py @@ -3,9 +3,11 @@ # Distributed under the terms of the Modified BSD License. import json import os +import warnings import pytest from pytest import mark +from sqlalchemy.exc import SADeprecationWarning from tornado.log import app_log from .. import orm, roles @@ -343,7 +345,13 @@ async def test_creating_roles(app, role, role_def, response_type, response): # make sure no warnings/info logged when the role exists and its definition hasn't been changed elif response_type == 'no-log': with pytest.warns(response) as record: + # don't catch already-suppressed sqlalchemy warnings + warnings.simplefilter("ignore", SADeprecationWarning) roles.create_role(db, role_def) + + for warning in record.list: + # show warnings for debugging + print("Unexpected warning", warning) assert not record.list role = orm.Role.find(db, role_def['name']) assert role is not None diff --git a/jupyterhub/tests/utils.py b/jupyterhub/tests/utils.py index 2e535d4b32..365a07b593 100644 --- a/jupyterhub/tests/utils.py +++ b/jupyterhub/tests/utils.py @@ -6,12 +6,27 @@ import pytest import requests from certipy import Certipy +from sqlalchemy import text from jupyterhub import metrics, orm from jupyterhub.objects import Server from jupyterhub.roles import assign_default_roles, update_roles from jupyterhub.utils import url_path_join as ujoin +try: + from sqlalchemy.exc import RemovedIn20Warning +except ImportError: + + class RemovedIn20Warning(DeprecationWarning): + """ + I only exist so I can be used in warnings filters in pytest.ini + + I will never be displayed. + + sqlalchemy 1.4 introduces RemovedIn20Warning, + but we still test against older sqlalchemy. + """ + class _AsyncRequests: """Wrapper around requests to return a Future from request methods @@ -84,8 +99,8 @@ def new_func(app, *args, **kwargs): def _check(_=None): temp_session = app.session_factory() try: - temp_session.execute('CREATE TABLE dummy (foo INT)') - temp_session.execute('DROP TABLE dummy') + temp_session.execute(text('CREATE TABLE dummy (foo INT)')) + temp_session.execute(text('DROP TABLE dummy')) finally: temp_session.close() diff --git a/jupyterhub/user.py b/jupyterhub/user.py index 19e8011a0d..b50ba0384d 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -416,9 +416,10 @@ def all_spawners(self, include_default=True): yield orm_spawner def _new_orm_spawner(self, server_name): - """Creat the low-level orm Spawner object""" - orm_spawner = orm.Spawner(user=self.orm_user, name=server_name) + """Create the low-level orm Spawner object""" + orm_spawner = orm.Spawner(name=server_name) self.db.add(orm_spawner) + orm_spawner.user = self.orm_user self.db.commit() assert server_name in self.orm_spawners return orm_spawner diff --git a/pytest.ini b/pytest.ini index 82c19dcd68..1a57882e46 100644 --- a/pytest.ini +++ b/pytest.ini @@ -18,3 +18,7 @@ markers = slow: mark a test as slow role: mark as a test for roles selenium: web tests that run with selenium + +filterwarnings = + error:.*:jupyterhub.tests.utils.RemovedIn20Warning + ignore:.*event listener has changed as of version 2.0.*:sqlalchemy.exc.SADeprecationWarning