Skip to content

Commit

Permalink
[#1871] Add function to extract table names from SQL statement
Browse files Browse the repository at this point in the history
The function performs an EXPLAIN query with the provided statement and
parses its output looking for table names.

For Postgres >= 9.x it uses the FORMAT JSON option to get and parse a
JSON object.

For older versions of Postgres the plain text option is used.
  • Loading branch information
amercader committed Aug 5, 2014
1 parent d9e6240 commit 8f026f5
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 4 deletions.
108 changes: 108 additions & 0 deletions ckanext/datastore/helpers.py
@@ -1,7 +1,16 @@
import logging
import json
import re

import sqlalchemy
import sqlparse

import paste.deploy.converters as converters


log = logging.getLogger(__name__)


def get_list(input, strip_values=True):
'''Transforms a string or list to a list'''
if input is None:
Expand Down Expand Up @@ -33,3 +42,102 @@ def _strip(input):
if isinstance(input, basestring) and len(input) and input[0] == input[-1]:
return input.strip().strip('"')
return input


def get_table_names_from_sql(context, sql):
    '''Parses the output of EXPLAIN (FORMAT JSON) looking for table names

    It performs an EXPLAIN query against the provided SQL, and parses
    the output recursively looking for "Relation Name".

    This requires Postgres 9.x. If an older version of Postgres is being run,
    it falls back to parse the text output of EXPLAIN. This is harder to parse
    and maintain, so it will be deprecated once support for Postgres < 9.x is
    dropped.

    :param context: a CKAN context dict. It must contain a 'connection' key
        with the current DB connection.
    :type context: dict
    :param sql: the SQL statement to parse for table names
    :type sql: string

    :returns: every "Relation Name" found in the query plan, in plan order
        and without de-duplication; an empty list if the plan returned by
        the database is not valid JSON
    :rtype: list of strings

    :raises sqlalchemy.exc.ProgrammingError: if the EXPLAIN query fails for
        any reason other than the FORMAT option not being understood
    '''

    def _get_table_names_from_plan(plan):
        '''Collects "Relation Name" from a plan node and all its children'''
        table_names = []

        if plan.get('Relation Name'):
            table_names.append(plan['Relation Name'])

        # Child nodes (joins, subqueries, ...) are nested under "Plans"
        for child_plan in plan.get('Plans', []):
            table_names.extend(_get_table_names_from_plan(child_plan))

        return table_names

    try:
        # NOTE(review): the statement is interpolated verbatim because it is
        # the statement itself and cannot be bound as a parameter. Callers
        # are presumably expected to have validated it beforehand (eg that
        # it is a single statement) -- confirm at the call sites.
        result = context['connection'].execute(
            'EXPLAIN (FORMAT JSON) {0}'.format(sql)).fetchone()
    except sqlalchemy.exc.ProgrammingError as e:
        if 'syntax error at or near "format"' in str(e).lower():
            # Old version of Postgres, parse the text output instead
            return _get_table_names_from_sql_text(context, sql)
        raise

    query_plan = result['QUERY PLAN']

    # Some DB driver configurations hand back JSON columns already decoded
    # into Python objects; only decode when we actually got serialized text.
    if not isinstance(query_plan, (list, dict)):
        try:
            query_plan = json.loads(query_plan)
        except ValueError:
            log.error('Could not parse query plan')
            return []

    return _get_table_names_from_plan(query_plan[0]['Plan'])


def _get_table_names_from_sql_text(context, sql):
    '''Parses the output of EXPLAIN looking for table names

    It performs an EXPLAIN query against the provided SQL, and parses
    the output looking for "Scan on".

    Note that double quotes are removed from table names.

    This is to be used only on Postgres 8.x.

    This function should not be called directly, use
    `get_table_names_from_sql`.

    :param context: a CKAN context dict. It must contain a 'connection' key
        with the current DB connection.
    :type context: dict
    :param sql: the SQL statement to parse for table names
    :type sql: string

    :returns: the name following the first "Scan on" of each plan line that
        has one, in output order and without de-duplication
    :rtype: list of strings
    '''

    results = context['connection'].execute(
        'EXPLAIN {0}'.format(sql))

    # Capture only the identifier right after "Scan on": either a
    # double-quoted name (which may contain spaces) or a run of
    # non-whitespace characters. A greedy `.*` would also swallow the
    # alias and most of the "(cost=... rows=... width=...)" suffix.
    pattern = re.compile(r'Scan on ("(?:[^"]|"")+"|\S+)')

    table_names = []
    for result in results:
        match = pattern.search(result['QUERY PLAN'])
        if match:
            table_names.append(match.group(1).strip('"'))

    return table_names
72 changes: 68 additions & 4 deletions ckanext/datastore/tests/test_helpers.py
@@ -1,9 +1,17 @@
import ckanext.datastore.helpers as helpers
import pylons
import sqlalchemy.orm as orm
import nose

import ckanext.datastore.helpers as datastore_helpers
import ckanext.datastore.tests.helpers as datastore_test_helpers
import ckanext.datastore.db as db


eq_ = nose.tools.eq_

class TestTypeGetters(object):
def test_get_list(self):
get_list = helpers.get_list
get_list = datastore_helpers.get_list
assert get_list(None) is None
assert get_list([]) == []
assert get_list('') == []
Expand All @@ -28,7 +36,63 @@ def test_is_single_statement(self):
'SELECT * FROM "foo"; SELECT * FROM "abc"']

for single in singles:
assert helpers.is_single_statement(single) is True
assert datastore_helpers.is_single_statement(single) is True

for multiple in multiples:
assert helpers.is_single_statement(multiple) is False
assert datastore_helpers.is_single_statement(multiple) is False


class TestGetTables(object):
    '''Runs get_table_names_from_sql against tables created for the test'''

    @classmethod
    def setup_class(cls):
        '''Binds a session to the datastore write URL and creates the tables

        The database is passed through `clear_db` first, then three
        single-column tables (test_a, test_b and "TEST_C") are created.
        '''
        write_url = pylons.config['ckan.datastore.write_url']
        engine = db._get_engine({'connection_url': write_url})
        session_factory = orm.sessionmaker(bind=engine)
        cls.Session = orm.scoped_session(session_factory)

        datastore_test_helpers.clear_db(cls.Session)

        for statement in (
                'CREATE TABLE test_a (id_a text)',
                'CREATE TABLE test_b (id_b text)',
                'CREATE TABLE "TEST_C" (id_c text)'):
            cls.Session.execute(statement)

    @classmethod
    def teardown_class(cls):
        '''Passes the session through `clear_db` once all tests have run'''
        datastore_test_helpers.clear_db(cls.Session)

    def test_get_table_names(self):
        '''Each statement must yield the expected names, in any order'''

        context = {
            'connection': self.Session.connection()
        }

        # (SQL statement, table names expected to be reported for it)
        test_cases = [
            ('SELECT * FROM test_a', ['test_a']),
            ('SELECT * FROM public.test_a', ['test_a']),
            ('SELECT * FROM "TEST_C"', ['TEST_C']),
            ('SELECT * FROM public."TEST_C"', ['TEST_C']),
            ('SELECT * FROM pg_catalog.pg_database', ['pg_database']),
            ('SELECT rolpassword FROM pg_roles', ['pg_authid']),
            ('''SELECT p.rolpassword
FROM pg_roles p
JOIN test_b b
ON p.rolpassword = b.id_b''', ['pg_authid', 'test_b']),
            ('''SELECT id_a, id_b, id_c
FROM (
SELECT *
FROM (
SELECT *
FROM "TEST_C") AS c,
test_b) AS b,
test_a AS a''', ['test_a', 'test_b', 'TEST_C']),
            ('INSERT INTO test_a VALUES (\'a\')', ['test_a']),
        ]

        for statement, expected_names in test_cases:
            found_names = datastore_helpers.get_table_names_from_sql(
                context, statement)
            eq_(sorted(found_names), sorted(expected_names))

0 comments on commit 8f026f5

Please sign in to comment.