Skip to content

Commit

Permalink
Merge pull request #2952 from centerofci/constraint_create_sql
Browse files Browse the repository at this point in the history
Move constraint creation to SQL
  • Loading branch information
silentninja committed Jun 23, 2023
2 parents c5a7790 + cd73063 commit 9ccbb35
Show file tree
Hide file tree
Showing 9 changed files with 823 additions and 271 deletions.
90 changes: 43 additions & 47 deletions db/constraints/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
"""TODO: This needs to be consolidated with db.constraints.operations.create."""

from alembic.migration import MigrationContext
from alembic.operations import Operations
from sqlalchemy import MetaData
from abc import ABC, abstractmethod
import json

from db.columns.operations.select import get_column_names_from_attnums
from db.constraints.utils import naming_convention
from db.tables.operations.select import reflect_table_from_oid
from db.metadata import get_empty_metadata
from db.connection import execute_msar_func_with_engine
from db.constraints.utils import (
get_constraint_match_char_from_type, get_constraint_char_from_action
)


class Constraint(ABC):
Expand Down Expand Up @@ -41,35 +40,29 @@ def __init__(
self.options = options

def add_constraint(self, schema, engine, connection_to_use):
# TODO reuse metadata
metadata = get_empty_metadata()
table = reflect_table_from_oid(self.table_oid, engine, connection_to_use=connection_to_use, metadata=metadata)
referent_table = reflect_table_from_oid(self.referent_table_oid, engine, connection_to_use=connection_to_use, metadata=metadata)
columns_name = get_column_names_from_attnums(self.table_oid, self.columns_attnum, engine, connection_to_use=connection_to_use, metadata=metadata)
referent_columns_name = get_column_names_from_attnums(
[self.referent_table_oid],
self.referent_columns,
match_type = get_constraint_match_char_from_type(self.options.get('match'))
on_update = get_constraint_char_from_action(self.options.get('onupdate'))
on_delete = get_constraint_char_from_action(self.options.get('ondelete'))
return execute_msar_func_with_engine(
engine,
connection_to_use=connection_to_use,
metadata=metadata,
)
# TODO reuse metadata
metadata = MetaData(bind=engine, schema=schema, naming_convention=naming_convention)
opts = {
'target_metadata': metadata
}
ctx = MigrationContext.configure(connection_to_use, opts=opts)
op = Operations(ctx)
op.create_foreign_key(
self.name,
table.name,
referent_table.name,
columns_name,
referent_columns_name,
source_schema=table.schema,
referent_schema=referent_table.schema,
**self.options
)
'add_constraints',
self.table_oid,
json.dumps(
[
{
'name': self.name,
'type': 'f',
'columns': self.columns_attnum,
'deferrable': self.options.get('deferrable'),
'fkey_relation_id': self.referent_table_oid,
'fkey_columns': self.referent_columns,
'fkey_update_action': on_update,
'fkey_delete_action': on_delete,
'fkey_match_type': match_type,
}
]
)
).fetchone()[0]


class UniqueConstraint(Constraint):
Expand All @@ -80,17 +73,20 @@ def __init__(self, name, table_oid, columns_attnum):
self.columns_attnum = columns_attnum

def add_constraint(self, schema, engine, connection_to_use):
# TODO reuse metadata
metadata = get_empty_metadata()
table = reflect_table_from_oid(self.table_oid, engine, connection_to_use=connection_to_use, metadata=metadata)
columns = get_column_names_from_attnums(self.table_oid, self.columns_attnum, engine, connection_to_use=connection_to_use, metadata=metadata)
metadata = MetaData(bind=engine, schema=schema, naming_convention=naming_convention)
opts = {
'target_metadata': metadata
}
ctx = MigrationContext.configure(connection_to_use, opts=opts)
op = Operations(ctx)
op.create_unique_constraint(self.name, table.name, columns, table.schema)
return execute_msar_func_with_engine(
engine,
'add_constraints',
self.table_oid,
json.dumps(
[
{
'name': self.name,
'type': 'u',
'columns': self.columns_attnum
}
],
)
).fetchone()[0]

def constraint_type(self):
return "unique"
44 changes: 20 additions & 24 deletions db/constraints/operations/create.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,34 @@
from alembic.migration import MigrationContext
from alembic.operations import Operations
from sqlalchemy import MetaData
"""TODO This needs to be consolidated with db.constraints.base"""
import json

from db.columns.operations.select import get_column_names_from_attnums
from db.constraints.utils import get_constraint_type_from_char, ConstraintType, naming_convention
from db.tables.operations.select import reflect_table_from_oid
from db.metadata import get_empty_metadata
from db.connection import execute_msar_func_with_engine
from db.constraints.utils import get_constraint_type_from_char, ConstraintType


def create_unique_constraint(table_name, schema, engine, columns, constraint_name=None):
with engine.begin() as conn:
metadata = MetaData(bind=engine, schema=schema, naming_convention=naming_convention)
opts = {
'target_metadata': metadata
}
ctx = MigrationContext.configure(conn, opts=opts)
op = Operations(ctx)
op.create_unique_constraint(constraint_name, table_name, columns, schema)
return execute_msar_func_with_engine(
engine,
'add_constraints',
schema,
table_name,
json.dumps([{'name': constraint_name, 'type': 'u', 'columns': columns}])
).fetchone()[0]


def create_constraint(schema, engine, constraint_obj):
with engine.begin() as conn:
constraint_obj.add_constraint(schema, engine, conn)
return constraint_obj.add_constraint(schema, engine, conn)


def copy_constraint(table_oid, engine, constraint, from_column_attnum, to_column_attnum):
# TODO reuse metadata
metadata = get_empty_metadata()
table = reflect_table_from_oid(table_oid, engine, metadata=metadata)
def copy_constraint(_, engine, constraint, from_column_attnum, to_column_attnum):
constraint_type = get_constraint_type_from_char(constraint.contype)
if constraint_type == ConstraintType.UNIQUE.value:
column_attnums = constraint.conkey
changed_column_attnums = [to_column_attnum if attnum == from_column_attnum else attnum for attnum in column_attnums]
columns = get_column_names_from_attnums(table_oid, changed_column_attnums, engine, metadata=metadata)
create_unique_constraint(table.name, table.schema, engine, columns)
return execute_msar_func_with_engine(
engine,
'copy_constraint',
constraint.oid,
from_column_attnum,
to_column_attnum
).fetchone()[0]
else:
raise NotImplementedError
95 changes: 58 additions & 37 deletions db/constraints/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
"""Utilities for database constraints."""
from enum import Enum

from sqlalchemy import CheckConstraint, ForeignKeyConstraint, PrimaryKeyConstraint, UniqueConstraint
from sqlalchemy.dialects.postgresql import ExcludeConstraint

from db.columns.operations.select import get_column_name_from_attnum
from db.tables.operations.select import reflect_table_from_oid


class ConstraintType(Enum):
FOREIGN_KEY = 'foreignkey'
Expand All @@ -29,6 +24,7 @@ class ConstraintMatch(Enum):
SIMPLE = 'SIMPLE'


# TODO Remove this. It's incorrect, and not robust.
# Naming conventions for constraints follow standard Postgres conventions
# described in https://stackoverflow.com/a/4108266
naming_convention = {
Expand All @@ -40,21 +36,16 @@ class ConstraintMatch(Enum):
}


def get_constraint_type_from_class(constraint):
if type(constraint) == CheckConstraint:
return ConstraintType.CHECK.value
elif type(constraint) == ForeignKeyConstraint:
return ConstraintType.FOREIGN_KEY.value
elif type(constraint) == PrimaryKeyConstraint:
return ConstraintType.PRIMARY_KEY.value
elif type(constraint) == UniqueConstraint:
return ConstraintType.UNIQUE.value
elif type(constraint) == ExcludeConstraint:
return ConstraintType.EXCLUDE.value
return None
def get_constraint_type_from_char(constraint_char):
"""
Map the char for a constraint to the string used for creating it in SQL.
Args:
constraint_char: Single character, matching pg_constraints.
def get_constraint_type_from_char(constraint_char):
Returns:
The string used for creating the constraint in SQL.
"""
char_type_map = {
"c": ConstraintType.CHECK.value,
"f": ConstraintType.FOREIGN_KEY.value,
Expand All @@ -65,39 +56,69 @@ def get_constraint_type_from_char(constraint_char):
return char_type_map.get(constraint_char)


def get_constraint_action_from_char(action_char):
def _get_char_action_map(reverse=False):
action_map = {
"a": ConstraintAction.NO_ACTION.value,
"r": ConstraintAction.RESTRICT.value,
"c": ConstraintAction.CASCADE.value,
"n": ConstraintAction.SET_NULL.value,
"d": ConstraintAction.SET_DEFAULT.value,
}
if reverse:
action_map = {v: k for k, v in action_map.items()}
return action_map


def get_constraint_action_from_char(action_char):
"""
Map the action_char to a string giving the on update or on delecte action.
Args:
action_char: Single character, matching pg_constraints.
"""
action_map = _get_char_action_map()
return action_map.get(action_char)


def get_constraint_match_type_from_char(match_char):
def get_constraint_char_from_action(action):
"""
Map the on update or on delete action to a single character.
Args:
action: Single character, matching pg_constraints.
"""
action_map = _get_char_action_map(reverse=True)
return action_map.get(action)


def _get_char_match_map(reverse=False):
match_map = {
"f": ConstraintMatch.FULL.value,
"p": ConstraintMatch.PARTIAL.value,
"s": ConstraintMatch.SIMPLE.value,
}
if reverse:
match_map = {v: k for k, v in match_map.items()}
return match_map


def get_constraint_match_type_from_char(match_char):
"""
Map the match_char to a string giving the match type.
Args:
match_char: Single character, matching pg_constraints.
"""
match_map = _get_char_match_map()
return match_map.get(match_char)


def get_constraint_name(engine, constraint_type, table_oid, column_0_attnum, metadata, connection_to_use=None):
table_name = reflect_table_from_oid(table_oid, engine, connection_to_use=connection_to_use, metadata=metadata).name
column_0_name = get_column_name_from_attnum(table_oid, column_0_attnum, engine, metadata=metadata, connection_to_use=connection_to_use)
data = {
'table_name': table_name,
'column_0_name': column_0_name
}
if constraint_type == ConstraintType.UNIQUE.value:
return naming_convention['uq'] % data
if constraint_type == ConstraintType.FOREIGN_KEY.value:
return naming_convention['fk'] % data
if constraint_type == ConstraintType.PRIMARY_KEY.value:
return naming_convention['pk'] % data
if constraint_type == ConstraintType.CHECK.value:
return naming_convention['ck'] % data
return None
def get_constraint_match_char_from_type(match_type):
"""
Map the match_type to a single character.
Args:
match_type: Single character, matching pg_constraints.
"""
match_map = _get_char_match_map(reverse=True)
return match_map.get(match_type)

0 comments on commit 9ccbb35

Please sign in to comment.