Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add database schema update and database migration logic #520

Merged
merged 12 commits into from
May 28, 2024
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker

from jupyter_scheduler.orm import Base
from jupyter_scheduler.scheduler import Scheduler
Expand Down
29 changes: 26 additions & 3 deletions jupyter_scheduler/orm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
import os
from sqlite3 import OperationalError
from uuid import uuid4

import sqlalchemy.types as types
from sqlalchemy import Boolean, Column, Integer, String, create_engine
from sqlalchemy import Boolean, Column, Integer, String, create_engine, inspect
from sqlalchemy.orm import declarative_base, declarative_mixin, registry, sessionmaker
from sqlalchemy.sql import text

from jupyter_scheduler.models import EmailNotifications, Status
from jupyter_scheduler.utils import get_utc_timestamp
Expand Down Expand Up @@ -91,6 +91,7 @@ class CommonColumns:

class Job(CommonColumns, Base):
__tablename__ = "jobs"
__table_args__ = {"extend_existing": True}
job_id = Column(String(36), primary_key=True, default=generate_uuid)
job_definition_id = Column(String(36))
status = Column(String(64), default=Status.STOPPED)
Expand All @@ -104,6 +105,7 @@ class Job(CommonColumns, Base):

class JobDefinition(CommonColumns, Base):
__tablename__ = "job_definitions"
__table_args__ = {"extend_existing": True}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to do some testing to understand why this is necessary if we are defining update_db_schema(). Essentially, this argument is needed to extend the table metadata (e.g. column names & types) associated with the table. If this line is omitted in test_orm.py, the test fails with the error:

sqlalchemy.exc.InvalidRequestError: Table 'jobs' is already defined for this MetaData instance.  Specify 'extend_existing=True' to redefine options and columns on an existing Table object.

So both the table migration (defined in update_db_schema()) and table metadata migration (defined here by setting __table_args__ = {"extend_existing": True}) must be performed when a new column is added. 😵

No action needed. I just felt it was necessary to call this out to other readers in this review.

job_definition_id = Column(String(36), primary_key=True, default=generate_uuid)
schedule = Column(String(256))
timezone = Column(String(36))
Expand All @@ -112,8 +114,29 @@ class JobDefinition(CommonColumns, Base):
active = Column(Boolean, default=True)


def create_tables(db_url, drop_tables=False):
def update_db_schema(engine, Base):
    """Add columns declared on the models but missing from the database.

    For every table registered on ``Base.metadata`` that already exists in
    the database, the model's columns are compared with the columns the
    database reports, and an ``ALTER TABLE ... ADD COLUMN`` statement is
    issued for each column the database lacks. Tables that do not exist yet
    are skipped. Existing columns are never altered or dropped.

    Args:
        engine: SQLAlchemy engine connected to the database to update.
        Base: Declarative base whose ``metadata`` describes the desired schema.
    """
    inspector = inspect(engine)
    # Conditionally quotes identifiers (reserved words, mixed case, ...);
    # ordinary lowercase names are emitted unchanged.
    quote = engine.dialect.identifier_preparer.quote

    # engine.begin() commits when the block succeeds and rolls back on error.
    # A plain engine.connect() block does not commit on exit in SQLAlchemy 2.0,
    # so the ALTER statements would only persist on drivers that autocommit DDL.
    with engine.begin() as connection:
        for table_name, table in Base.metadata.tables.items():
            if not inspector.has_table(table_name):
                continue

            existing_names = {col["name"] for col in inspector.get_columns(table_name)}

            for column in table.columns:
                # Compare by database column name, which is what the
                # inspector reports, rather than by the model-side key.
                if column.name in existing_names:
                    continue

                column_type = column.type.compile(dialect=engine.dialect)
                nullable = "NULL" if column.nullable else "NOT NULL"
                # NOTE(review): adding a NOT NULL column without a default is
                # rejected by some databases when the table already has rows.
                statement = (
                    f"ALTER TABLE {quote(table_name)} "
                    f"ADD COLUMN {quote(column.name)} {column_type} {nullable}"
                )
                connection.execute(text(statement))


def create_tables(db_url, drop_tables=False, Base=Base):
engine = create_engine(db_url)
update_db_schema(engine, Base)

try:
if drop_tables:
Base.metadata.drop_all(engine)
Expand Down
73 changes: 73 additions & 0 deletions jupyter_scheduler/tests/test_orm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Type

import pytest
from sqlalchemy import Column, Integer, String, inspect
from sqlalchemy.orm import DeclarativeMeta, sessionmaker

from jupyter_scheduler.orm import (
create_session,
create_tables,
declarative_base,
generate_uuid,
)


@pytest.fixture
def initial_db(jp_scheduler_db_url) -> tuple[Type[DeclarativeMeta], sessionmaker, str]:
    """Set up a database whose ``jobs`` table has the pre-migration schema.

    One job row is inserted so later tests can check that existing data
    survives a schema update.

    Returns:
        A tuple of the declarative base the table was defined on, the session
        factory for the database, and the id of the inserted job.
    """
    TestBase = declarative_base()

    class InitialJob(TestBase):
        __tablename__ = "jobs"
        job_id = Column(String(36), primary_key=True, default=generate_uuid)
        runtime_environment_name = Column(String(256), nullable=False)
        input_filename = Column(String(256), nullable=False)

    seed_job = InitialJob(runtime_environment_name="abc", input_filename="input.ipynb")

    create_tables(db_url=jp_scheduler_db_url, Base=TestBase)
    session_factory = create_session(jp_scheduler_db_url)

    db_session = session_factory()
    db_session.add(seed_job)
    db_session.commit()
    # Read the generated primary key while the instance is still attached.
    seeded_job_id = seed_job.job_id
    db_session.close()

    return TestBase, session_factory, seeded_job_id


@pytest.fixture
def updated_job_model(initial_db) -> Type[DeclarativeMeta]:
    """Redefine the ``jobs`` model on the same base with an extra column.

    ``extend_existing`` is required because the base's metadata already holds
    a ``jobs`` table from the ``initial_db`` fixture.

    Returns:
        The new model class, which adds the integer column ``new_column``.
    """
    declarative_root, _session_factory, _job_id = initial_db

    class UpdatedJob(declarative_root):
        __tablename__ = "jobs"
        __table_args__ = {"extend_existing": True}
        job_id = Column(String(36), primary_key=True, default=generate_uuid)
        runtime_environment_name = Column(String(256), nullable=False)
        input_filename = Column(String(256), nullable=False)
        new_column = Column("new_column", Integer)

    return UpdatedJob


def test_create_tables_with_new_column(jp_scheduler_db_url, initial_db, updated_job_model):
    """Check that ``create_tables`` adds a newly declared column in place.

    The database starts with the schema and row from ``initial_db``. After
    the model gains ``new_column`` and ``create_tables`` runs again, the
    column must exist in the database and the original row must still be
    readable with its values unchanged.
    """
    Base, Session, initial_job_id = initial_db

    # Context-managed sessions are closed even when an assertion fails.
    with Session() as session:
        initial_columns = {col["name"] for col in inspect(session.bind).get_columns("jobs")}
        assert "new_column" not in initial_columns

    JobModel = updated_job_model
    create_tables(db_url=jp_scheduler_db_url, Base=Base)

    with Session() as session:
        updated_columns = {col["name"] for col in inspect(session.bind).get_columns("jobs")}
        assert "new_column" in updated_columns

        updated_job = session.query(JobModel).filter(JobModel.job_id == initial_job_id).one()
        assert hasattr(updated_job, "new_column")
        assert updated_job.runtime_environment_name == "abc"
        assert updated_job.input_filename == "input.ipynb"
Loading