Merge pull request #578 from centerofci/table_patch
Update table PATCH endpoint to work with columns
pavish committed Aug 23, 2021
2 parents 41add0b + fe3cc30 commit 7b33cbd
Showing 16 changed files with 752 additions and 98 deletions.
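The table PATCH endpoint itself lives in the API layer, which is not rendered in this excerpt, so the request below is only a guess at the call shape, inferred from the db-layer keys accepted further down ('name', 'plain_type', 'type_options', and an empty object to drop a column). The URL, the top-level "columns" field name, and the client code are assumptions, not part of this diff.

# Hypothetical request against the updated table PATCH endpoint.
# The endpoint path and the "columns" field are assumptions inferred from the
# db-layer validation shown below, not taken from the API code in this commit.
# The db layer requires one entry per existing column of the table.
import requests

payload = {
    "columns": [
        {"name": "id"},                                              # rename a column
        {"plain_type": "VARCHAR", "type_options": {"length": 20}},   # retype a column
        {},                                                          # drop a column
    ]
}
requests.patch("http://localhost:8000/api/v0/tables/1/", json=payload)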
159 changes: 126 additions & 33 deletions db/columns.py
@@ -8,12 +8,14 @@
DefaultClause, func
)
from sqlalchemy.ext import compiler
-from sqlalchemy.exc import DataError
+from sqlalchemy.exc import DataError, InternalError
from sqlalchemy.schema import DDLElement
from psycopg2.errors import InvalidTextRepresentation, InvalidParameterValue

from db import constants, tables, constraints
from db.types import alteration
from db.types.base import get_db_type_name
from db.utils import execute_statement

logger = logging.getLogger(__name__)

@@ -199,7 +201,7 @@ def init_mathesar_table_column_list_with_defaults(column_list):
return default_columns + given_columns


-def get_column_index_from_name(table_oid, column_name, engine):
+def get_column_index_from_name(table_oid, column_name, engine, connection_to_use=None):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Did not recognize type")
pg_attribute = Table("pg_attribute", MetaData(), autoload_with=engine)
@@ -209,8 +211,7 @@ def get_column_index_from_name(table_oid, column_name, engine):
pg_attribute.c.attname == column_name
)
)
-    with engine.begin() as conn:
-        result = conn.execute(sel).fetchone()[0]
+    result = execute_statement(engine, sel, connection_to_use).fetchone()[0]

# Account for dropped columns that don't appear in the SQLAlchemy tables
sel = (
@@ -220,8 +221,7 @@ def get_column_index_from_name(table_oid, column_name, engine):
pg_attribute.c.attnum < result,
))
)
-    with engine.begin() as conn:
-        dropped_count = conn.execute(sel).fetchone()[0]
+    dropped_count = execute_statement(engine, sel, connection_to_use).fetchone()[0]

return result - 1 - dropped_count

@@ -316,9 +316,18 @@ def rename_column(table_oid, column_index, new_column_name, engine, **kwargs):
)


-def retype_column(table_oid, column_index, new_type, engine, **kwargs):
+def _handle_retype_data_errors(e):
+    if (
+        type(e.orig) == InvalidParameterValue
+        or type(e.orig) == InvalidTextRepresentation
+    ):
+        raise InvalidTypeOptionError
+    else:
+        raise e
+
+
+def retype_column(table_oid, column_index, new_type, engine, type_options={}):
    table = tables.reflect_table_from_oid(table_oid, engine)
-    type_options = kwargs.get("type_options", {})
try:
alteration.alter_column_type(
table.schema,
@@ -330,20 +339,48 @@ def retype_column(table_oid, column_index, new_type, engine, **kwargs):
type_options=type_options
)
except DataError as e:
-        if (
-            type(e.orig) == InvalidParameterValue
-            or type(e.orig) == InvalidTextRepresentation
-        ):
-            raise InvalidTypeOptionError
-        else:
-            raise e
+        _handle_retype_data_errors(e)

return get_mathesar_column_with_engine(
tables.reflect_table_from_oid(table_oid, engine).columns[column_index],
engine
)


def _check_type_option_equivalence(type_options_1, type_options_2):
NULL_OPTIONS = [None, {}]
if type_options_1 in NULL_OPTIONS and type_options_2 in NULL_OPTIONS:
return True
elif type_options_1 == type_options_2:
return True
return False


def retype_column_in_connection(table, connection, engine, column_index, new_type, type_options={}):
column = table.columns[column_index]
column_db_type = get_db_type_name(column.type, engine)
column_type_options = MathesarColumn.from_column(column).type_options
# Don't try to retype the column if it's already the correct type.
if (new_type == column_db_type) and _check_type_option_equivalence(type_options, column_type_options):
return
try:
alteration.alter_column_type(
table.schema,
table.name,
column.name,
new_type,
engine,
friendly_names=False,
type_options=type_options,
connection_to_use=connection,
table_to_use=table
)
except DataError as e:
_handle_retype_data_errors(e)
except InternalError as e:
raise e.orig


def change_column_nullable(table_oid, column_index, nullable, engine, **kwargs):
table = tables.reflect_table_from_oid(table_oid, engine)
column = table.columns[column_index]
Expand All @@ -368,6 +405,50 @@ def get_mathesar_column_with_engine(col, engine):
return new_column


def _validate_columns_for_batch_update(table, column_data):
ALLOWED_KEYS = ['name', 'plain_type', 'type_options']
if len(column_data) != len(table.columns):
raise ValueError('Number of columns passed in must equal number of columns in table')
for single_column_data in column_data:
for key in single_column_data.keys():
if key not in ALLOWED_KEYS:
allowed_key_list = ', '.join(ALLOWED_KEYS)
raise ValueError(f'Key "{key}" found in columns. Keys allowed are: {allowed_key_list}')


def _batch_update_column_types(table, column_data_list, connection, engine):
for index, column_data in enumerate(column_data_list):
if 'plain_type' in column_data:
new_type = column_data['plain_type']
type_options = column_data.get('type_options', {})
if type_options is None:
type_options = {}
retype_column_in_connection(table, connection, engine, index, new_type, type_options)


def _batch_alter_table_columns(table, column_data_list, connection):
ctx = MigrationContext.configure(connection)
op = Operations(ctx)
with op.batch_alter_table(table.name, schema=table.schema) as batch_op:
for index, column_data in enumerate(column_data_list):
column = table.columns[index]
if 'name' in column_data and column.name != column_data['name']:
batch_op.alter_column(
column.name,
new_column_name=column_data['name']
)
elif len(column_data.keys()) == 0:
batch_op.drop_column(column.name)


def batch_update_columns(table_oid, engine, column_data_list):
table = tables.reflect_table_from_oid(table_oid, engine)
_validate_columns_for_batch_update(table, column_data_list)
with engine.begin() as conn:
_batch_update_column_types(table, column_data_list, conn, engine)
_batch_alter_table_columns(table, column_data_list, conn)


def drop_column(table_oid, column_index, engine):
column_index = int(column_index)
table = tables.reflect_table_from_oid(table_oid, engine)
@@ -378,8 +459,11 @@ def drop_column(table_oid, column_index, engine):
op.drop_column(table.name, column.name, schema=table.schema)


-def get_column_default(table_oid, column_index, engine):
-    table = tables.reflect_table_from_oid(table_oid, engine)
+def get_column_default(table_oid, column_index, engine, connection_to_use=None, table_to_use=None):
+    if table_to_use is None:
+        table = tables.reflect_table_from_oid(table_oid, engine)
+    else:
+        table = table_to_use
column = table.columns[column_index]
if column.server_default is None:
return None
@@ -409,8 +493,7 @@ def get_column_default(table_oid, column_index, engine):
))
)

-    with engine.begin() as conn:
-        result = conn.execute(query).first()[0]
+    result = execute_statement(engine, query, connection_to_use).first()[0]

# Here, we get the 'adbin' value for the current column, stored in the attrdef
# system table. The prefix of this value tells us whether the default is static
@@ -422,31 +505,41 @@ def get_column_default(table_oid, column_index, engine):
# Ex: "'test default string'::character varying" or "'2020-01-01'::date"
# Here, we execute the cast to get the proper python value
cast_sql_text = column.server_default.arg.text
-    with engine.begin() as conn:
-        return conn.execute(select(text(cast_sql_text))).first()[0]
+    return execute_statement(engine, select(text(cast_sql_text)), connection_to_use).first()[0]


+def _alter_column_default(ctx, table_name, column_name, schema, default_clause):
+    op = Operations(ctx)
+    op.alter_column(table_name, column_name, schema=schema, server_default=default_clause)
+
+
-def set_column_default(table_oid, column_index, default, engine, **kwargs):
+def set_column_default(table_oid, column_index, default, engine, connection_to_use=None,
+                       table_to_use=None, **_):
    # Note: default should be textual SQL that produces the desired default
-    table = tables.reflect_table_from_oid(table_oid, engine)
+    if table_to_use is None:
+        table = tables.reflect_table_from_oid(table_oid, engine)
+    else:
+        table = table_to_use
column = table.columns[column_index]
default_clause = DefaultClause(str(default)) if default is not None else default
try:
-        with engine.begin() as conn:
-            ctx = MigrationContext.configure(conn)
-            op = Operations(ctx)
-            op.alter_column(
-                table.name, column.name, schema=table.schema, server_default=default_clause
-            )
+        if connection_to_use is None:
+            with engine.begin() as conn:
+                ctx = MigrationContext.configure(conn)
+                _alter_column_default(ctx, table.name, column.name, table.schema, default_clause)
+        else:
+            ctx = MigrationContext.configure(connection_to_use)
+            _alter_column_default(ctx, table.name, column.name, table.schema, default_clause)
except DataError as e:
if (type(e.orig) == InvalidTextRepresentation):
raise InvalidDefaultError
else:
raise e
-    return get_mathesar_column_with_engine(
-        tables.reflect_table_from_oid(table_oid, engine).columns[column_index],
-        engine
-    )
+    if connection_to_use is None:
+        return get_mathesar_column_with_engine(
+            tables.reflect_table_from_oid(table_oid, engine).columns[column_index],
+            engine
+        )


def _gen_col_name(table, column_name):
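Taken together, the new helpers in db/columns.py let a caller describe a whole table's columns in one list. Below is a minimal usage sketch, assuming an engine and the OID of a three-column table; the concrete type name and type options are illustrative and not part of this diff. _validate_columns_for_batch_update requires exactly one entry per existing column, allows only the keys 'name', 'plain_type' and 'type_options', and treats an empty dict as a request to drop that column.

# Hypothetical usage of the new db.columns.batch_update_columns helper.
# Assumes `engine` is a SQLAlchemy engine for the target database and that
# the table identified by `table_oid` has exactly three columns.
from db import columns

column_data_list = [
    {"name": "id"},                                              # rename only
    {"plain_type": "VARCHAR", "type_options": {"length": 20}},   # retype with options
    {},                                                          # empty dict: drop this column
]
# Retypes are applied first via retype_column_in_connection; renames and drops
# then run in one Alembic batch_alter_table block on the same connection.
columns.batch_update_columns(table_oid, engine, column_data_list)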
10 changes: 7 additions & 3 deletions db/tables.py
@@ -17,7 +17,7 @@
TEMP_SCHEMA = f"{constants.MATHESAR_PREFIX}temp_schema"
TEMP_TABLE = f"{constants.MATHESAR_PREFIX}temp_table_%s"

-SUPPORTED_TABLE_UPDATE_ARGS = {'name'}
+SUPPORTED_TABLE_UPDATE_ARGS = {'name', 'sa_columns'}


def create_string_column_table(name, schema, column_names, engine):
@@ -94,9 +94,13 @@ def rename_table(name, schema, engine, rename_to):
op.rename_table(table.name, rename_to, schema=table.schema)


-def update_table(name, schema, engine, update_data):
+def update_table(table_name, table_oid, schema, engine, update_data):
+    if 'name' in update_data and 'sa_columns' in update_data:
+        raise ValueError('Only name or columns can be passed in, not both.')
    if 'name' in update_data:
-        rename_table(name, schema, engine, update_data['name'])
+        rename_table(table_name, schema, engine, update_data['name'])
+    if 'sa_columns' in update_data:
+        columns.batch_update_columns(table_oid, engine, update_data['sa_columns'])


def extract_columns_from_table(
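In db/tables.py, update_table now takes both the table name and its OID and dispatches on which key is present in update_data; 'name' and 'sa_columns' are mutually exclusive. A hedged sketch of the two call shapes, reusing column_data_list from the previous sketch (one entry per existing column is still required):

# Illustrative calls only; the real caller is the API layer, which is not
# rendered in this excerpt.
from db import tables

# Rename the table:
tables.update_table("Roster", table_oid, "public", engine, {"name": "Class Roster"})

# Batch-update its columns via the new 'sa_columns' key:
tables.update_table(
    "Roster", table_oid, "public", engine, {"sa_columns": column_data_list}
)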
3 changes: 2 additions & 1 deletion db/tests/resources/roster_create.sql
@@ -2,7 +2,8 @@ CREATE TABLE "Roster" (
mathesar_id integer NOT NULL,
"Student Number" uuid,
"Student Name" character varying(100),
"Student Email" character varying(150), "Teacher" character varying(100),
"Student Email" character varying(150),
"Teacher" character varying(100),
"Teacher Email" character varying(150),
"Subject" character varying(20),
"Grade" integer
