From 86d2226719eab4421fdf0c62429269dc69172b97 Mon Sep 17 00:00:00 2001 From: Adrien Russo Date: Mon, 22 Nov 2021 15:37:44 -0500 Subject: [PATCH] Fix #663 - Fixed incompatibility of DatabaseSchemaIntrospection with django Fix search path issue typo Fix for psql support --- .../postgresql_backend/_constraints.py | 102 -------------- .../postgresql_backend/introspection.py | 132 +++++------------- 2 files changed, 35 insertions(+), 199 deletions(-) delete mode 100644 django_tenants/postgresql_backend/_constraints.py diff --git a/django_tenants/postgresql_backend/_constraints.py b/django_tenants/postgresql_backend/_constraints.py deleted file mode 100644 index d7105ebc..00000000 --- a/django_tenants/postgresql_backend/_constraints.py +++ /dev/null @@ -1,102 +0,0 @@ -from django.db.models.indexes import Index - - -def get_constraints(self, cursor, table_name): - """ - Retrieve any constraints or keys (unique, pk, fk, check, index) across - one or more columns. Also retrieve the definition of expression-based - indexes. - """ - constraints = {} - # Loop over the key table, collecting things as constraints. The column - # array must return column names in the same order in which they were - # created. - # The subquery containing generate_series can be replaced with - # "WITH ORDINALITY" when support for PostgreSQL 9.3 is dropped. - cursor.execute(""" - SELECT - c.conname, - array( - SELECT attname - FROM ( - SELECT unnest(c.conkey) AS colid, - generate_series(1, array_length(c.conkey, 1)) AS arridx - ) AS cols - JOIN pg_attribute AS ca ON cols.colid = ca.attnum - WHERE ca.attrelid = c.conrelid - ORDER BY cols.arridx - ), - c.contype, - (SELECT fkc.relname || '.' || fka.attname - FROM pg_attribute AS fka - JOIN pg_class AS fkc ON fka.attrelid = fkc.oid - WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]), - cl.reloptions - FROM pg_constraint AS c - JOIN pg_class AS cl ON c.conrelid = cl.oid - JOIN pg_namespace AS ns ON cl.relnamespace = ns.oid - WHERE ns.nspname = %s AND cl.relname = %s - """, [self.connection.schema_name, table_name]) - for constraint, columns, kind, used_cols, options in cursor.fetchall(): - constraints[constraint] = { - "columns": columns, - "primary_key": kind == "p", - "unique": kind in ["p", "u"], - "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None, - "check": kind == "c", - "index": False, - "definition": None, - "options": options, - } - # Now get indexes - # The row_number() function for ordering the index fields can be - # replaced by WITH ORDINALITY in the unnest() functions when support - # for PostgreSQL 9.3 is dropped. - cursor.execute(""" - SELECT - indexname, array_agg(attname ORDER BY rnum), indisunique, indisprimary, - array_agg(ordering ORDER BY rnum), amname, exprdef, s2.attoptions - FROM ( - SELECT - row_number() OVER () as rnum, c2.relname as indexname, - idx.*, attr.attname, am.amname, - CASE - WHEN idx.indexprs IS NOT NULL THEN - pg_get_indexdef(idx.indexrelid) - END AS exprdef, - CASE am.amname - WHEN 'btree' THEN - CASE (option & 1) - WHEN 1 THEN 'DESC' ELSE 'ASC' - END - END as ordering, - c2.reloptions as attoptions - FROM ( - SELECT - *, unnest(i.indkey) as key, unnest(i.indoption) as option - FROM pg_index i - ) idx - LEFT JOIN pg_class c ON idx.indrelid = c.oid - LEFT JOIN pg_namespace n ON n.oid = c.relnamespace - LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid - LEFT JOIN pg_am am ON c2.relam = am.oid - LEFT JOIN pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key - WHERE c.relname = %s and n.nspname = %s - ) s2 - GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions; - """, [table_name, self.connection.schema_name]) - for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall(): - if index not in constraints: - constraints[index] = { - "columns": columns if columns != [None] else [], - "orders": orders if orders != [None] else [], - "primary_key": primary, - "unique": unique, - "foreign_key": None, - "check": False, - "index": True, - "type": Index.suffix if type_ == 'btree' else type_, - "definition": definition, - "options": options, - } - return constraints diff --git a/django_tenants/postgresql_backend/introspection.py b/django_tenants/postgresql_backend/introspection.py index ed6624cb..281ff252 100644 --- a/django_tenants/postgresql_backend/introspection.py +++ b/django_tenants/postgresql_backend/introspection.py @@ -1,112 +1,50 @@ from django.db.backends.postgresql.introspection import DatabaseIntrospection -from django.db.backends.base.introspection import TableInfo, FieldInfo -from django.utils.encoding import force_text -from . import _constraints + +class DatabaseSchemaIntrospectionSearchPathContext: + """ + This context manager restores the original search path of the cursor + once the method of the introspection class has been called. + """ + def __init__(self, cursor, connection): + self.cursor = cursor + self.connection = connection + self.original_search_path = None + + def __enter__(self): + self.cursor.execute('SHOW search_path') + self.original_search_path = self.cursor.fetchone()[0].split(',') + self.cursor.execute(f"SET search_path = '{self.connection.schema_name}'") + + def __exit__(self, *args, **kwargs): + formatted_search_paths = ', '.join( + f"'{search_path.strip()}'" + for search_path in self.original_search_path + ) + self.cursor.execute(f'SET search_path = {formatted_search_paths}') class DatabaseSchemaIntrospection(DatabaseIntrospection): """ database schema introspection class """ - _get_indexes_query = """ - SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary - FROM pg_catalog.pg_class c, - INNER JOIN pg_catalog.pg_index idx ON c.oid = idx.indrelid - INNER JOIN pg_catalog.pg_class c2 ON idx.indexrelid = c2.oid - INNER JOIN pg_catalog.pg_attribute attr ON attr.attrelid = c.oid and attr.attnum = idx.indkey[0] - INNER JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE c.relname = %s AND n.nspname = %s - """ def get_table_list(self, cursor): - """ - Returns a list of table names in the current database and schema. - """ - - cursor.execute(""" - SELECT c.relname, c.relkind - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE c.relkind IN ('r', 'v', '') - AND n.nspname = '%s' - AND pg_catalog.pg_table_is_visible(c.oid)""" % self.connection.schema_name) - - return [TableInfo(row[0], {'r': 't', 'v': 'v'}.get(row[1])) - for row in cursor.fetchall() - if row[0] not in self.ignored_tables] + with DatabaseSchemaIntrospectionSearchPathContext(cursor=cursor, connection=self.connection): + return super().get_table_list(cursor) def get_table_description(self, cursor, table_name): - "Returns a description of the table, with the DB-API cursor.description interface." - # As cursor.description does not return reliably the nullable property, - # we have to query the information_schema (#7783) - cursor.execute(""" - SELECT column_name, is_nullable, column_default - FROM information_schema.columns - WHERE table_schema = %s and table_name = %s""", [self.connection.schema_name, table_name]) - field_map = {line[0]: line[1:] for line in cursor.fetchall()} - cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) - return [FieldInfo(*((force_text(line[0]),) + line[1:6] + - (field_map[force_text(line[0])][0] == 'YES', field_map[force_text(line[0])][1]))) - for line in cursor.description] - - def get_indexes(self, cursor, table_name): - # This query retrieves each index on the given table, including the - # first associated field name - cursor.execute(self._get_indexes_query, [table_name, self.connection.schema_name]) - indexes = {} - for row in cursor.fetchall(): - # row[1] (idx.indkey) is stored in the DB as an array. It comes out as - # a string of space-separated integers. This designates the field - # indexes (1-based) of the fields that have indexes on the table. - # Here, we skip any indexes across multiple fields. - if ' ' in row[1]: - continue - if row[0] not in indexes: - indexes[row[0]] = {'primary_key': False, 'unique': False} - # It's possible to have the unique and PK constraints in separate indexes. - if row[3]: - indexes[row[0]]['primary_key'] = True - if row[2]: - indexes[row[0]]['unique'] = True - return indexes + with DatabaseSchemaIntrospectionSearchPathContext(cursor=cursor, connection=self.connection): + return super().get_table_description(cursor, table_name) - def get_relations(self, cursor, table_name): - """ - Returns a dictionary of {field_name: (field_name_other_table, other_table)} - representing all relationships to the given table. - """ - cursor.execute(""" - SELECT c2.relname, a1.attname, a2.attname - FROM pg_constraint con - LEFT JOIN pg_class c1 ON con.conrelid = c1.oid - LEFT JOIN pg_namespace n ON n.oid = c1.relnamespace - LEFT JOIN pg_class c2 ON con.confrelid = c2.oid - LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1] - LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1] - WHERE c1.relname = %s and n.nspname = %s - AND con.contype = 'f'""", [table_name, self.connection.schema_name]) - relations = {} - for row in cursor.fetchall(): - relations[row[1]] = (row[2], row[0]) - return relations - - get_constraints = _constraints.get_constraints + def get_sequences(self, cursor, table_name, table_fields=()): + with DatabaseSchemaIntrospectionSearchPathContext(cursor=cursor, connection=self.connection): + return super().get_sequences(cursor, table_name, table_fields) def get_key_columns(self, cursor, table_name): - key_columns = [] - cursor.execute(""" - SELECT kcu.column_name, ccu.table_name AS referenced_table, ccu.column_name AS referenced_column - FROM information_schema.constraint_column_usage ccu - LEFT JOIN information_schema.key_column_usage kcu - ON ccu.constraint_catalog = kcu.constraint_catalog - AND ccu.constraint_schema = kcu.constraint_schema - AND ccu.constraint_name = kcu.constraint_name - LEFT JOIN information_schema.table_constraints tc - ON ccu.constraint_catalog = tc.constraint_catalog - AND ccu.constraint_schema = tc.constraint_schema - AND ccu.constraint_name = tc.constraint_name - WHERE kcu.table_name = %s AND tc.constraint_type = 'FOREIGN KEY' AND tc.table_schema = %s - """, [table_name, self.connection.schema_name]) - key_columns.extend(cursor.fetchall()) - return key_columns + with DatabaseSchemaIntrospectionSearchPathContext(cursor=cursor, connection=self.connection): + return super().get_key_columns(cursor, table_name) + + def get_constraints(self, cursor, table_name): + with DatabaseSchemaIntrospectionSearchPathContext(cursor=cursor, connection=self.connection): + return super().get_constraints(cursor, table_name)