Skip to content

Commit

Permalink
create multiple aliases, replace previous, unicode safe, avoid sql in…
Browse files Browse the repository at this point in the history
…jection
  • Loading branch information
domoritz committed Sep 7, 2012
1 parent 38a5101 commit d15bd8d
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 39 deletions.
38 changes: 29 additions & 9 deletions ckanext/datastore/db.py
Expand Up @@ -60,6 +60,9 @@ def _is_valid_field_name(name):
return True


_is_valid_table_name = _is_valid_field_name


def _validate_int(i, field_name):
try:
int(i)
Expand Down Expand Up @@ -236,13 +239,29 @@ def create_table(context, data_dict):

context['connection'].execute(sql_string)

# create alias view
alias = data_dict.get('alias', None)
if alias:
sql_alias_string = u'create view "{alias}" as select * from "{main}"'.format(
main=data_dict['resource_id'], alias=alias

def _get_aliases(context, data_dict):
res_id = data_dict['resource_id']
alias_sql = text(u'select name from "_table_metadata" where alias_of = :id')
results = context['connection'].execute(alias_sql, id=res_id).fetchall()
return [x[0] for x in results]


def create_alias(context, data_dict):
aliases = _get_list(data_dict.get('aliases', None))
if aliases:
# delete previous aliases
previous_aliases = _get_aliases(context, data_dict)
for alias in previous_aliases:
sql_alias_drop_string = u'drop view "{0}"'.format(alias)
context['connection'].execute(sql_alias_drop_string)

for alias in aliases:
sql_alias_string = u'create view "{alias}" as select * from "{main}"'.format(
main=data_dict['resource_id'],
alias=alias
)
context['connection'].execute(sql_alias_string)
context['connection'].execute(sql_alias_string)


def create_indexes(context, data_dict):
Expand Down Expand Up @@ -694,7 +713,7 @@ def format_results(context, results, data_dict):
return data_dict


def is_single_statement(sql):
def _is_single_statement(sql):
return not ';' in sql.strip(';')


Expand Down Expand Up @@ -740,6 +759,7 @@ def create(context, data_dict):
alter_table(context, data_dict)
insert_data(context, data_dict)
create_indexes(context, data_dict)
create_alias(context, data_dict)
trans.commit()
return data_dict
except IntegrityError, e:
Expand Down Expand Up @@ -787,10 +807,10 @@ def delete(context, data_dict):
_cache_types(context)

try:
# check if table existes
# check if table exists
trans = context['connection'].begin()
result = context['connection'].execute(
u'select * from pg_tables where tablename = %s',
u'select 1 from pg_tables where tablename = %s',
data_dict['resource_id']
).fetchone()
if not result:
Expand Down
28 changes: 18 additions & 10 deletions ckanext/datastore/logic/action.py
Expand Up @@ -14,8 +14,8 @@ def datastore_create(context, data_dict):
:param resource_id: resource id that the data is going to be stored under.
:type resource_id: string
:param alias: a name for a read only alias to the resource.
:type alias: string
:param aliases: name for read only aliases to the resource.
:type aliases: list or comma separated string
:param fields: fields/columns and their extra metadata.
:type fields: list of dictionaries
:param records: the data, eg: [{"dob": "2005", "some_stuff": ['a', b']}]
Expand All @@ -41,6 +41,14 @@ def datastore_create(context, data_dict):

data_dict['connection_url'] = pylons.config['ckan.datastore_write_url']

# validate aliases
aliases = db._get_list(data_dict.get('aliases', []))
for alias in aliases:
if not db._is_valid_table_name(alias):
raise p.toolkit.ValidationError({
'alias': ['{0} is not a valid alias name'.format(alias)]
})

result = db.create(context, data_dict)
result.pop('id')
result.pop('connection_url')
Expand Down Expand Up @@ -90,7 +98,7 @@ def datastore_delete(context, data_dict):
:param resource_id: resource id that the data will be deleted from.
:type resource_id: string
:param filter: filter to do deleting on over (eg {'name': 'fred'}).
If missing delete whole table.
If missing delete whole table and all dependent views.
:returns: original filters sent.
:rtype: dictionary
Expand Down Expand Up @@ -118,7 +126,7 @@ def datastore_delete(context, data_dict):
def datastore_search(context, data_dict):
'''Search a datastore table.
:param resource_id: id of the data that is going to be selected.
:param resource_id: id or alias of the data that is going to be selected.
:type resource_id: string
:param filters: matching conditions to select.
:type filters: dictionary
Expand Down Expand Up @@ -151,23 +159,23 @@ def datastore_search(context, data_dict):
'''
model = _get_or_bust(context, 'model')
id = _get_or_bust(data_dict, 'resource_id')
res_id = _get_or_bust(data_dict, 'resource_id')

data_dict['connection_url'] = pylons.config['ckan.datastore_read_url']

res_exists = model.Resource.get(id)
res_exists = model.Resource.get(res_id)

alias_exists = False
if not res_exists:
# assume id is an alias
alias_sql = text('select alias_of from "_table_metadata" where name = :id')
result = db._get_engine(None, data_dict).execute(alias_sql, id=id).fetchone()
alias_sql = text(u'select alias_of from "_table_metadata" where name = :id')
result = db._get_engine(None, data_dict).execute(alias_sql, id=res_id).fetchone()
if result:
alias_exists = model.Resource.get(result[0].strip('"'))

if not (res_exists or alias_exists):
raise p.toolkit.ObjectNotFound(p.toolkit._(
'Resource "{}" was not found.'.format(id)
'Resource "{}" was not found.'.format(res_id)
))

p.toolkit.check_access('datastore_search', context, data_dict)
Expand All @@ -192,7 +200,7 @@ def datastore_search_sql(context, data_dict):
'''
sql = _get_or_bust(data_dict, 'sql')

if not db.is_single_statement(sql):
if not db._is_single_statement(sql):
raise p.toolkit.ValidationError({
'query': ['Query is not a single statement or contains semicolons.'],
'hint': [('If you want to use semicolons, use character encoding'
Expand Down
74 changes: 54 additions & 20 deletions ckanext/datastore/tests/test_datastore.py
Expand Up @@ -21,6 +21,7 @@ def test_list(self):
assert db._get_list('') == []
assert db._get_list('foo') == ['foo']
assert db._get_list('foo, bar') == ['foo', 'bar']
assert db._get_list('foo_"bar, baz') == ['foo_"bar', 'baz']
assert db._get_list('"foo", "bar"') == ['foo', 'bar']
assert db._get_list(u'foo, bar') == ['foo', 'bar']
assert db._get_list(['foo', 'bar']) == ['foo', 'bar']
Expand Down Expand Up @@ -75,6 +76,21 @@ def test_create_empty_fails(self):
res_dict = json.loads(res.body)
assert res_dict['success'] is False

def test_create_invalid_alias_name(self):
resource = model.Package.get('annakarenina').resources[0]
data = {
'resource_id': resource.id,
'aliases': 'foo"bar',
'fields': [{'id': 'book', 'type': 'text'},
{'id': 'author', 'type': 'text'}]
}
postparams = '%s=1' % json.dumps(data)
auth = {'Authorization': str(self.sysadmin_user.apikey)}
res = self.app.post('/api/action/datastore_create', params=postparams,
extra_environ=auth, status=409)
res_dict = json.loads(res.body)
assert res_dict['success'] is False

def test_create_invalid_field_type(self):
resource = model.Package.get('annakarenina').resources[0]
data = {
Expand Down Expand Up @@ -165,10 +181,10 @@ def test_bad_records(self):

def test_create_basic(self):
resource = model.Package.get('annakarenina').resources[0]
alias = u'books1'
aliases = [u'great_list_of_books', u'another_list_of_b\xfcks']
data = {
'resource_id': resource.id,
'alias': alias,
'aliases': aliases,
'fields': [{'id': 'book', 'type': 'text'},
{'id': 'author', 'type': '_json'}],
'primary_key': 'book, author',
Expand Down Expand Up @@ -217,18 +233,19 @@ def test_create_basic(self):
assert results.rowcount == 2
model.Session.remove()

# check alias for resource
c = model.Session.connection()
# check aliases for resource
for alias in aliases:
c = model.Session.connection()

results = [row for row in c.execute('select * from "{0}"'.format(resource.id))]
results_alias = [row for row in c.execute('select * from "{0}"'.format(alias))]
results = [row for row in c.execute(u'select * from "{0}"'.format(resource.id))]
results_alias = [row for row in c.execute(u'select * from "{0}"'.format(alias))]

assert results == results_alias
assert results == results_alias

sql = ("select * from _table_metadata "
"where alias_of='{}' and name='{}'").format(resource.id, alias)
results = c.execute(sql)
assert results.rowcount == 1
sql = (u"select * from _table_metadata "
"where alias_of='{0}' and name='{1}'").format(resource.id, alias)
results = c.execute(sql)
assert results.rowcount == 1

# check to test to see if resource now has a datastore table
postparams = '%s=1' % json.dumps({'id': resource.id})
Expand Down Expand Up @@ -319,6 +336,7 @@ def test_create_basic(self):
####### insert again which should not fail because constraint is removed
data5 = {
'resource_id': resource.id,
'aliases': 'another_alias', # replaces aliases
'records': [{'book': 'warandpeace'}],
'primary_key': ''
}
Expand All @@ -331,6 +349,19 @@ def test_create_basic(self):

assert res_dict['success'] is True

# new aliases should replace old aliases
c = model.Session.connection()
for alias in aliases:
sql = (u"select * from _table_metadata "
"where alias_of='{0}' and name='{1}'").format(resource.id, alias)
results = c.execute(sql)
assert results.rowcount == 0

sql = (u"select * from _table_metadata "
"where alias_of='{0}' and name='{1}'").format(resource.id, 'another_alias')
results = c.execute(sql)
assert results.rowcount == 1

def test_guess_types(self):
resource = model.Package.get('annakarenina').resources[1]
data = {
Expand Down Expand Up @@ -609,7 +640,7 @@ def setup_class(cls):
resource = model.Package.get('annakarenina').resources[0]
cls.data = {
'resource_id': resource.id,
'alias': 'books2',
'aliases': 'books2',
'fields': [{'id': 'book', 'type': 'text'},
{'id': 'author', 'type': 'text'}],
'records': [{'book': 'annakarenina', 'author': 'tolstoy'},
Expand Down Expand Up @@ -647,7 +678,7 @@ def test_delete_basic(self):
c = model.Session.connection()

# alias should be deleted
results = c.execute("select 1 from pg_views where viewname = '{}'".format(self.data['alias']))
results = c.execute("select 1 from pg_views where viewname = '{}'".format(self.data['aliases']))
assert results.rowcount == 0

try:
Expand Down Expand Up @@ -739,7 +770,7 @@ def setup_class(cls):
resource = model.Package.get('annakarenina').resources[0]
cls.data = {
'resource_id': resource.id,
'alias': 'books3',
'aliases': 'books3',
'fields': [{'id': u'b\xfck', 'type': 'text'},
{'id': 'author', 'type': 'text'},
{'id': 'published'}],
Expand Down Expand Up @@ -785,13 +816,16 @@ def test_search_basic(self):
assert result['total'] == len(self.data['records'])
assert result['records'] == self.expected_records

data = {'resource_id': self.data['alias']}
def test_search_alias(self):
data = {'resource_id': self.data['aliases']}
postparams = '%s=1' % json.dumps(data)
auth = {'Authorization': str(self.sysadmin_user.apikey)}
res = self.app.post('/api/action/datastore_search', params=postparams,
extra_environ=auth)
res_dict_alias = json.loads(res.body)
assert res_dict_alias['result']['records'] == res_dict['result']['records']
result = res_dict_alias['result']
assert result['total'] == len(self.data['records'])
assert result['records'] == self.expected_records

def test_search_invalid_field(self):
data = {'resource_id': self.data['resource_id'],
Expand Down Expand Up @@ -1054,7 +1088,7 @@ def setup_class(cls):
resource = model.Package.get('annakarenina').resources[0]
cls.data = {
'resource_id': resource.id,
'alias': 'books4',
'aliases': 'books4',
'fields': [{'id': u'b\xfck', 'type': 'text'},
{'id': 'author', 'type': 'text'},
{'id': 'published'}],
Expand Down Expand Up @@ -1099,14 +1133,14 @@ def test_is_single_statement(self):
"select 'foo'||chr(59)||'bar'"]

for single in singles:
assert db.is_single_statement(single) is True
assert db._is_single_statement(single) is True

multiples = ['SELECT * FROM abc; SET LOCAL statement_timeout to'
'SET LOCAL statement_timeout to; SELECT * FROM abc',
'SELECT * FROM "foo"; SELECT * FROM "abc"']

for multiple in multiples:
assert db.is_single_statement(multiple) is False
assert db._is_single_statement(multiple) is False

def test_select_basic(self):
query = 'SELECT * FROM public."{}"'.format(self.data['resource_id'])
Expand All @@ -1121,7 +1155,7 @@ def test_select_basic(self):
assert result['records'] == self.expected_records

# test alias search
query = 'SELECT * FROM public."{}"'.format(self.data['alias'])
query = 'SELECT * FROM public."{}"'.format(self.data['aliases'])
data = {'sql': query}
postparams = json.dumps(data)
res = self.app.post('/api/action/datastore_search_sql', params=postparams,
Expand Down

0 comments on commit d15bd8d

Please sign in to comment.