Skip to content

Commit

Permalink
Merge pull request jupyterhub#4315 from minrk/3.x
Browse files Browse the repository at this point in the history
Backport PR jupyterhub#4302: sqlalchemy 2 compatibility
  • Loading branch information
minrk committed Jan 27, 2023
2 parents 43b0897 + 3dccb5d commit 193ebc9
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 50 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ env:
# UTF-8 content may be interpreted as ascii and causes errors without this.
LANG: C.UTF-8
PYTEST_ADDOPTS: "--verbose --color=yes"
SQLALCHEMY_WARN_20: "1"

permissions:
contents: read
Expand Down Expand Up @@ -140,7 +141,7 @@ jobs:
- name: Install Python dependencies
run: |
pip install --upgrade pip
pip install ".[test]"
pip install -e ".[test]"
if [ "${{ matrix.oldest_dependencies }}" != "" ]; then
# take any dependencies in requirements.txt such as tornado>=5.0
Expand All @@ -152,6 +153,7 @@ jobs:
if [ "${{ matrix.main_dependencies }}" != "" ]; then
pip install git+https://github.com/ipython/traitlets#egg=traitlets --force
pip install --upgrade --pre sqlalchemy
fi
if [ "${{ matrix.legacy_notebook }}" != "" ]; then
pip uninstall jupyter_server --yes
Expand Down
6 changes: 4 additions & 2 deletions jupyterhub/apihandlers/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ async def post(self):
# create the group
self.log.info("Creating new group %s with %i users", name, len(users))
self.log.debug("Users: %s", usernames)
group = orm.Group(name=name, users=users)
group = orm.Group(name=name)
self.db.add(group)
group.users = users
self.db.commit()
created.append(group)
self.write(json.dumps([self.group_model(group) for group in created]))
Expand Down Expand Up @@ -131,8 +132,9 @@ async def post(self, group_name):
# create the group
self.log.info("Creating new group %s with %i users", group_name, len(users))
self.log.debug("Users: %s", usernames)
group = orm.Group(name=group_name, users=users)
group = orm.Group(name=group_name)
self.db.add(group)
group.users = users
self.db.commit()
self.write(json.dumps(self.group_model(group)))
self.set_status(201)
Expand Down
4 changes: 2 additions & 2 deletions jupyterhub/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,9 +1948,9 @@ async def init_users(self):
user = orm.User.find(db, name)
if user is None:
user = orm.User(name=name, admin=True)
db.add(user)
roles.assign_default_roles(self.db, entity=user)
new_users.append(user)
db.add(user)
else:
user.admin = True
# the admin_users config variable will never be used after this point.
Expand Down Expand Up @@ -2343,6 +2343,7 @@ def init_services(self):
if orm_service is None:
# not found, create a new one
orm_service = orm.Service(name=name)
self.db.add(orm_service)
if spec.get('admin', False):
self.log.warning(
f"Service {name} sets `admin: True`, which is deprecated in JupyterHub 2.0."
Expand All @@ -2351,7 +2352,6 @@ def init_services(self):
"the Service admin flag will be ignored."
)
roles.update_roles(self.db, entity=orm_service, roles=['admin'])
self.db.add(orm_service)
orm_service.admin = spec.get('admin', False)
self.db.commit()
service = Service(
Expand Down
4 changes: 2 additions & 2 deletions jupyterhub/oauth/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,16 +257,16 @@ def save_authorization_code(self, client_id, code, request, *args, **kwargs):
raise ValueError("No such client: %s" % client_id)

orm_code = orm.OAuthCode(
client=orm_client,
code=code['code'],
# oauth has 5 minutes to complete
expires_at=int(orm.OAuthCode.now() + 300),
scopes=list(request.scopes),
user=request.user.orm_user,
redirect_uri=orm_client.redirect_uri,
session_id=request.session_id,
)
self.db.add(orm_code)
orm_code.client = orm_client
orm_code.user = request.user.orm_user
self.db.commit()

def get_authorization_code_scopes(self, client_id, code, redirect_uri, request):
Expand Down
60 changes: 43 additions & 17 deletions jupyterhub/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import alembic.command
import alembic.config
import sqlalchemy
from alembic.script import ScriptDirectory
from packaging.version import parse as parse_version
from sqlalchemy import (
Boolean,
Column,
Expand All @@ -24,8 +26,8 @@
inspect,
or_,
select,
text,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import (
Session,
backref,
Expand All @@ -34,6 +36,13 @@
relationship,
sessionmaker,
)

try:
from sqlalchemy.orm import declarative_base
except ImportError:
# sqlalchemy < 1.4
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy.pool import StaticPool
from sqlalchemy.types import LargeBinary, Text, TypeDecorator
from tornado.log import app_log
Expand Down Expand Up @@ -749,6 +758,7 @@ def new(
session_id=session_id,
scopes=list(scopes),
)
db.add(orm_token)
orm_token.token = token
if user:
assert user.id is not None
Expand All @@ -759,7 +769,6 @@ def new(
if expires_in is not None:
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)

db.add(orm_token)
db.commit()
return token

Expand Down Expand Up @@ -901,7 +910,7 @@ def register_ping_connection(engine):
"""

@event.listens_for(engine, "engine_connect")
def ping_connection(connection, branch):
def ping_connection(connection, branch=None):
if branch:
# "branch" refers to a sub-connection of a connection,
# we don't want to bother pinging on these.
Expand All @@ -912,11 +921,17 @@ def ping_connection(connection, branch):
save_should_close_with_result = connection.should_close_with_result
connection.should_close_with_result = False

if parse_version(sqlalchemy.__version__) < parse_version("1.4"):
one = [1]
else:
one = 1

try:
# run a SELECT 1. use a core select() so that
# run a SELECT 1. use a core select() so that
# the SELECT of a scalar value without a table is
# appropriately formatted for the backend
connection.scalar(select([1]))
with connection.begin() as transaction:
connection.scalar(select(one))
except exc.DBAPIError as err:
# catch SQLAlchemy's DBAPIError, which is a wrapper
# for the DBAPI's exception. It includes a .connection_invalidated
Expand All @@ -931,7 +946,8 @@ def ping_connection(connection, branch):
# itself and establish a new connection. The disconnect detection
# here also causes the whole connection pool to be invalidated
# so that all stale connections are discarded.
connection.scalar(select([1]))
with connection.begin() as transaction:
connection.scalar(select(one))
else:
raise
finally:
Expand All @@ -955,7 +971,13 @@ def check_db_revision(engine):

from .dbutil import _temp_alembic_ini

with _temp_alembic_ini(engine.url) as ini:
if hasattr(engine.url, "render_as_string"):
# sqlalchemy >= 1.4
engine_url = engine.url.render_as_string(hide_password=False)
else:
engine_url = str(engine.url)

with _temp_alembic_ini(engine_url) as ini:
cfg = alembic.config.Config(ini)
scripts = ScriptDirectory.from_config(cfg)
head = scripts.get_heads()[0]
Expand Down Expand Up @@ -990,9 +1012,10 @@ def check_db_revision(engine):

# check database schema version
# it should always be defined at this point
alembic_revision = engine.execute(
'SELECT version_num FROM alembic_version'
).first()[0]
with engine.begin() as connection:
alembic_revision = connection.execute(
text('SELECT version_num FROM alembic_version')
).first()[0]
if alembic_revision == head:
app_log.debug("database schema version found: %s", alembic_revision)
else:
Expand All @@ -1009,13 +1032,16 @@ def mysql_large_prefix_check(engine):
"""Check mysql has innodb_large_prefix set"""
if not str(engine.url).startswith('mysql'):
return False
variables = dict(
engine.execute(
'show variables where variable_name like '
'"innodb_large_prefix" or '
'variable_name like "innodb_file_format";'
).fetchall()
)
with engine.begin() as connection:
variables = dict(
connection.execute(
text(
'show variables where variable_name like '
'"innodb_large_prefix" or '
'variable_name like "innodb_file_format";'
)
).fetchall()
)
if (
variables.get('innodb_file_format', 'Barracuda') == 'Barracuda'
and variables.get('innodb_large_prefix', 'ON') == 'ON'
Expand Down
9 changes: 5 additions & 4 deletions jupyterhub/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,11 +521,12 @@ async def test_get_self(app):
db.add(oauth_client)
db.commit()
oauth_token = orm.APIToken(
user=u.orm_user,
oauth_client=oauth_client,
token=token,
)
db.add(oauth_token)
oauth_token.user = u.orm_user
oauth_token.oauth_client = oauth_client

db.commit()
r = await api_request(
app,
Expand Down Expand Up @@ -2117,13 +2118,13 @@ async def shutdown():

def stop():
stop.called = True
loop.call_later(1, real_stop)
loop.call_later(2, real_stop)

real_cleanup = app.cleanup

def cleanup():
cleanup.called = True
return real_cleanup()
loop.call_later(1, real_cleanup)

app.cleanup = cleanup

Expand Down
34 changes: 20 additions & 14 deletions jupyterhub/tests/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,9 @@ def test_spawner_delete_cascade(db):
db.add(user)
db.commit()

spawner = orm.Spawner(user=user)
spawner = orm.Spawner()
db.add(spawner)
spawner.user = user
db.commit()
spawner.server = server = orm.Server()
db.commit()
Expand All @@ -350,16 +352,19 @@ def test_user_delete_cascade(db):
# these should all be deleted automatically when the user goes away
user.new_api_token()
api_token = user.api_tokens[0]
spawner = orm.Spawner(user=user)
spawner = orm.Spawner()
db.add(spawner)
spawner.user = user
db.commit()
spawner.server = server = orm.Server()
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
oauth_code = orm.OAuthCode()
db.add(oauth_code)
oauth_token = orm.APIToken(
oauth_client=oauth_client,
user=user,
)
oauth_code.client = oauth_client
oauth_code.user = user
oauth_token = orm.APIToken()
db.add(oauth_token)
oauth_token.oauth_client = oauth_client
oauth_token.user = user
db.commit()

# record all of the ids
Expand Down Expand Up @@ -390,13 +395,14 @@ def test_oauth_client_delete_cascade(db):

# create a bunch of objects that reference the User
# these should all be deleted automatically when the user goes away
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
oauth_code = orm.OAuthCode()
db.add(oauth_code)
oauth_token = orm.APIToken(
oauth_client=oauth_client,
user=user,
)
oauth_code.client = oauth_client
oauth_code.user = user
oauth_token = orm.APIToken()
db.add(oauth_token)
oauth_token.oauth_client = oauth_client
oauth_token.user = user
db.commit()
assert user.api_tokens == [oauth_token]

Expand Down Expand Up @@ -517,11 +523,11 @@ def test_expiring_oauth_token(app, user):
db.add(client)
orm_token = orm.APIToken(
token=token,
oauth_client=client,
user=user,
expires_at=now() + timedelta(seconds=30),
)
db.add(orm_token)
orm_token.oauth_client = client
orm_token.user = user
db.commit()

found = orm.APIToken.find(db, token)
Expand Down
7 changes: 3 additions & 4 deletions jupyterhub/tests/test_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,11 +1045,10 @@ async def test_oauth_token_page(app):
user = app.users[orm.User.find(app.db, name)]
client = orm.OAuthClient(identifier='token')
app.db.add(client)
oauth_token = orm.APIToken(
oauth_client=client,
user=user,
)
oauth_token = orm.APIToken()
app.db.add(oauth_token)
oauth_token.oauth_client = client
oauth_token.user = user
app.db.commit()
r = await get_page('token', app, cookies=cookies)
r.raise_for_status()
Expand Down
8 changes: 8 additions & 0 deletions jupyterhub/tests/test_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
# Distributed under the terms of the Modified BSD License.
import json
import os
import warnings

import pytest
from pytest import mark
from sqlalchemy.exc import SADeprecationWarning
from tornado.log import app_log

from .. import orm, roles
Expand Down Expand Up @@ -343,7 +345,13 @@ async def test_creating_roles(app, role, role_def, response_type, response):
# make sure no warnings/info logged when the role exists and its definition hasn't been changed
elif response_type == 'no-log':
with pytest.warns(response) as record:
# don't catch already-suppressed sqlalchemy warnings
warnings.simplefilter("ignore", SADeprecationWarning)
roles.create_role(db, role_def)

for warning in record.list:
# show warnings for debugging
print("Unexpected warning", warning)
assert not record.list
role = orm.Role.find(db, role_def['name'])
assert role is not None
Expand Down

0 comments on commit 193ebc9

Please sign in to comment.