Commit 4578939

Merge branch 'Edrolo-db-inspection'
shimizukawa committed Feb 14, 2022
2 parents ce0d578 + e2aa6cd commit 4578939
Showing 2 changed files with 333 additions and 5 deletions.
160 changes: 159 additions & 1 deletion django_redshift_backend/base.py
@@ -12,6 +12,7 @@

from django.conf import settings
from django.core.exceptions import FieldDoesNotExist
from django.db.backends.base.introspection import FieldInfo
from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.postgresql.base import (
    DatabaseFeatures as BasePGDatabaseFeatures,
@@ -20,8 +21,9 @@
    DatabaseSchemaEditor as BasePGDatabaseSchemaEditor,
    DatabaseClient,
    DatabaseCreation as BasePGDatabaseCreation,
-    DatabaseIntrospection,
+    DatabaseIntrospection as BasePGDatabaseIntrospection,
)
+from django.db.models import Index

from django.db.utils import NotSupportedError

@@ -563,6 +565,162 @@ class DatabaseCreation(BasePGDatabaseCreation):
    pass


class DatabaseIntrospection(BasePGDatabaseIntrospection):
    pass

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.
        """
        # Query the pg_catalog tables as cursor.description does not reliably
        # return the nullable property and information_schema.columns does not
        # contain details of materialized views.

        # This function is based on the version from the Django postgres backend
        # from before support for collations was introduced in Django 3.2
        cursor.execute("""
            SELECT
                a.attname AS column_name,
                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
                pg_get_expr(ad.adbin, ad.adrelid) AS column_default
            FROM pg_attribute a
            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
            JOIN pg_type t ON a.atttypid = t.oid
            JOIN pg_class c ON a.attrelid = c.oid
            JOIN pg_namespace n ON c.relnamespace = n.oid
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND c.relname = %s
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """, [table_name])
        field_map = {
            column_name: (is_nullable, column_default)
            for (column_name, is_nullable, column_default) in cursor.fetchall()
        }
        cursor.execute(
            "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
        )
        return [
            FieldInfo(
                name=column.name,
                type_code=column.type_code,
                display_size=column.display_size,
                internal_size=column.internal_size,
                precision=column.precision,
                scale=column.scale,
                null_ok=field_map[column.name][0],
                default=field_map[column.name][1],
                collation=None,  # Redshift doesn't support user-defined collation
                # https://docs.aws.amazon.com/redshift/latest/dg/c_collation_sequences.html
            )
            for column in cursor.description
        ]

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns. Also retrieve the definition of expression-based
        indexes.
        """
        # Based on code from Django 3.2
        constraints = {}
        # Loop over the key table, collecting things as constraints. The column
        # array must return column names in the same order in which they were
        # created.
        cursor.execute("""
            SELECT
                c.conname,
                c.conkey::int[],
                c.conrelid,
                c.contype,
                (SELECT fkc.relname || '.' || fka.attname
                 FROM pg_attribute AS fka
                 JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
                 WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1])
            FROM pg_constraint AS c
            JOIN pg_class AS cl ON c.conrelid = cl.oid
            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        """, [table_name])
        constraint_records = [
            (conname, conkey, conrelid, contype, used_cols) for
            (conname, conkey, conrelid, contype, used_cols) in cursor.fetchall()
        ]
        table_oid = list(constraint_records)[0][2]  # Assuming at least one constraint
        attribute_num_to_name_map = self._get_attribute_number_to_name_map_for_table(
            cursor, table_oid)

        for constraint, conkey, conrelid, kind, used_cols in constraint_records:
            constraints[constraint] = {
                "columns": [
                    attribute_num_to_name_map[column_id_int] for column_id_int in conkey
                ],
                "primary_key": kind == "p",
                "unique": kind in ["p", "u"],
                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
                "check": kind == "c",
                "index": False,
                "definition": None,
                "options": None,
            }

        # Now get indexes
        # Based on code from Django 1.7
        cursor.execute("""
            SELECT
                c2.relname,
                idx.indrelid,
                idx.indkey, -- type "int2vector", returns space-separated string
                idx.indisunique,
                idx.indisprimary
            FROM
                pg_catalog.pg_class c,
                pg_catalog.pg_class c2,
                pg_catalog.pg_index idx
            WHERE c.oid = idx.indrelid
                AND idx.indexrelid = c2.oid
                AND c.relname = %s
        """, [table_name])
        index_records = [
            (index_name, indrelid, indkey, unique, primary) for
            (index_name, indrelid, indkey, unique, primary) in cursor.fetchall()
        ]
        for index_name, indrelid, indkey, unique, primary in index_records:
            if index_name not in constraints:
                constraints[index_name] = {
                    "columns": [
                        attribute_num_to_name_map[int(column_id_str)]
                        for column_id_str in indkey.split(' ')
                    ],
                    "orders": [],  # Not implemented
                    "primary_key": primary,
                    "unique": unique,
                    "foreign_key": None,
                    "check": False,
                    "index": True,
                    "type": Index.suffix,  # Not implemented - assume default type
                    "definition": None,  # Not implemented
                    "options": None,  # Not implemented
                }

        return constraints

    def _get_attribute_number_to_name_map_for_table(self, cursor, table_oid):
        cursor.execute("""
            SELECT
                attrelid, -- table oid
                attnum,
                attname
            FROM pg_attribute
            WHERE pg_attribute.attrelid = %s
            ORDER BY attrelid, attnum;
        """, [table_oid])
        return {
            attnum: attname
            for _, attnum, attname in cursor.fetchall()
        }


class DatabaseWrapper(BasePGDatabaseWrapper):
    vendor = 'redshift'

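The introspection methods added above are reached through Django's standard connection.introspection API, which is what inspectdb and the migration machinery call into. A minimal usage sketch (illustrative only, not part of this commit; it assumes a configured Redshift 'default' connection and an existing table named 'testapp_testmodel'):

from django.db import connection

with connection.cursor() as cursor:
    # Per-column FieldInfo tuples, with null_ok and default taken from pg_catalog
    description = connection.introspection.get_table_description(cursor, 'testapp_testmodel')
    # Mapping of constraint/index name to details (columns, primary_key, unique, ...)
    constraints = connection.introspection.get_constraints(cursor, 'testapp_testmodel')
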
178 changes: 174 additions & 4 deletions tests/test_redshift_backend.py
@@ -2,6 +2,7 @@

import os
import unittest
from unittest import mock

import django
from django.db import connections
@@ -11,10 +12,7 @@


def norm_sql(sql):
-    r = sql
-    r = r.replace(' ', '')
-    r = r.replace('\n', '')
-    return r
+    return ' '.join(sql.split()).replace('( ', '(').replace(' )', ')').replace(' ;', ';')
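# Illustrative example of the normalisation above (the input SQL is hypothetical,
# not one of the queries under test):
#   norm_sql('SELECT  a,\n  b FROM t ;')  ->  'SELECT a, b FROM t;'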


class DatabaseWrapperTest(unittest.TestCase):
@@ -180,3 +178,175 @@ def test_sqlmigrate(self):
        sql_statements = collect_sql(plan)
        print('\n'.join(sql_statements))
        assert sql_statements  # It doesn't matter what SQL is generated.


class IntrospectionTest(unittest.TestCase):
    expected_table_description_metadata = norm_sql(
        u'''SELECT
            a.attname AS column_name,
            NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
            pg_get_expr(ad.adbin, ad.adrelid) AS column_default
        FROM pg_attribute a
        LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
        JOIN pg_type t ON a.atttypid = t.oid
        JOIN pg_class c ON a.attrelid = c.oid
        JOIN pg_namespace n ON c.relnamespace = n.oid
        WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
            AND c.relname = %s
            AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
            AND pg_catalog.pg_table_is_visible(c.oid)
        ''')

    expected_constraints_query = norm_sql(
        u'''SELECT
            c.conname,
            c.conkey::int[],
            c.conrelid,
            c.contype,
            (SELECT fkc.relname || '.' || fka.attname
             FROM pg_attribute AS fka
             JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
             WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1])
        FROM pg_constraint AS c
        JOIN pg_class AS cl ON c.conrelid = cl.oid
        WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        ''')

    expected_attributes_query = norm_sql(
        u'''SELECT
            attrelid, -- table oid
            attnum,
            attname
        FROM pg_attribute
        WHERE pg_attribute.attrelid = %s
        ORDER BY attrelid, attnum;
        ''')

    expected_indexes_query = norm_sql(
        u'''SELECT
            c2.relname,
            idx.indrelid,
            idx.indkey, -- type "int2vector", returns space-separated string
            idx.indisunique,
            idx.indisprimary
        FROM
            pg_catalog.pg_class c,
            pg_catalog.pg_class c2,
            pg_catalog.pg_index idx
        WHERE
            c.oid = idx.indrelid
            AND idx.indexrelid = c2.oid
            AND c.relname = %s
        ''')

    def test_get_table_description_does_not_use_unsupported_functions(self):
        conn = connections['default']
        with mock.patch.object(conn, 'cursor') as mock_cursor_method:
            mock_cursor = mock_cursor_method.return_value.__enter__.return_value
            from testapp.models import TestModel
            table_name = TestModel._meta.db_table

            _ = conn.introspection.get_table_description(mock_cursor, table_name)

            (
                select_metadata_call,
                fetchall_call,
                select_row_call
            ) = mock_cursor.method_calls

            call_method, call_args, call_kwargs = select_metadata_call
            self.assertEqual('execute', call_method)
            executed_sql = norm_sql(call_args[0])

            self.assertEqual(self.expected_table_description_metadata, executed_sql)

            self.assertNotIn('collation', executed_sql)
            self.assertNotIn('unnest', executed_sql)

            call_method, call_args, call_kwargs = select_row_call
            self.assertEqual(
                norm_sql('SELECT * FROM "testapp_testmodel" LIMIT 1'),
                call_args[0],
            )

    def test_get_constraints_does_not_use_unsupported_functions(self):
        conn = connections['default']
        with mock.patch.object(conn, 'cursor') as mock_cursor_method:
            mock_cursor = mock_cursor_method.return_value.__enter__.return_value
            from testapp.models import TestModel
            table_name = TestModel._meta.db_table

            mock_cursor.fetchall.side_effect = [
                # conname, conkey, conrelid, contype, used_cols
                [
                    (
                        'testapp_testmodel_testapp_testmodel_id_pkey',
                        [1],
                        12345678,
                        'p',
                        None,
                    ),
                ],
                [
                    # attrelid, attnum, attname
                    (12345678, 1, 'id'),
                    (12345678, 2, 'ctime'),
                    (12345678, 3, 'text'),
                    (12345678, 4, 'uuid'),
                ],
                # index_name, indrelid, indkey, unique, primary
                [
                    (
                        'testapp_testmodel_testapp_testmodel_id_pkey',
                        12345678,
                        '1',
                        True,
                        True,
                    ),
                ],
            ]

            table_constraints = conn.introspection.get_constraints(
                mock_cursor, table_name)

            expected_table_constraints = {
                'testapp_testmodel_testapp_testmodel_id_pkey': {
                    'columns': ['id'],
                    'primary_key': True,
                    'unique': True,
                    'foreign_key': None,
                    'check': False,
                    'index': False,
                    'definition': None,
                    'options': None,
                }
            }
            self.assertDictEqual(expected_table_constraints, table_constraints)

            calls = mock_cursor.method_calls

            # Should be a sequence of 3x execute and fetchall calls
            expected_call_sequence = ['execute', 'fetchall'] * 3
            actual_call_sequence = [name for (name, _args, _kwargs) in calls]
            self.assertEqual(expected_call_sequence, actual_call_sequence)

            # Constraints query
            call_method, call_args, call_kwargs = calls[0]
            executed_sql = norm_sql(call_args[0])
            self.assertNotIn('collation', executed_sql)
            self.assertNotIn('unnest', executed_sql)
            self.assertEqual(self.expected_constraints_query, executed_sql)

            # Attributes query
            call_method, call_args, call_kwargs = calls[2]
            executed_sql = norm_sql(call_args[0])
            self.assertNotIn('collation', executed_sql)
            self.assertNotIn('unnest', executed_sql)
            self.assertEqual(self.expected_attributes_query, executed_sql)

            # Indexes query
            call_method, call_args, call_kwargs = calls[4]
            executed_sql = norm_sql(call_args[0])
            self.assertNotIn('collation', executed_sql)
            self.assertNotIn('unnest', executed_sql)
            self.assertEqual(self.expected_indexes_query, executed_sql)
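
For orientation, the index-column resolution that the mocked fetchall data exercises can be reproduced in isolation; a minimal standalone sketch reusing the same illustrative values (not code from the backend itself):

attribute_num_to_name_map = {1: 'id', 2: 'ctime', 3: 'text', 4: 'uuid'}
indkey = '1'  # pg_index.indkey is an int2vector, delivered as a space-separated string
columns = [attribute_num_to_name_map[int(num)] for num in indkey.split(' ')]
assert columns == ['id']  # matches the 'columns' value asserted in expected_table_constraints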
