diff --git a/ckanext/datastore/db.py b/ckanext/datastore/db.py
index 691f01b3d07..5a70e626f91 100644
--- a/ckanext/datastore/db.py
+++ b/ckanext/datastore/db.py
@@ -1,5 +1,6 @@
 import sqlalchemy
 from sqlalchemy.exc import ProgrammingError, IntegrityError
+from sqlalchemy import text
 import ckan.plugins as p
 import psycopg2.extras
 import json
@@ -299,7 +300,7 @@ def create_indexes(context, data_dict):
 
 def _drop_indexes(context, data_dict, unique=False):
     sql_drop_index = u'drop index "{0}" cascade'
-    sql_get_index_string = """
+    sql_get_index_string = u"""
         select
             i.relname as index_name
         from
@@ -312,12 +313,13 @@ def _drop_indexes(context, data_dict, unique=False):
             and t.relkind = 'r'
             and idx.indisunique = {unique}
             and idx.indisprimary = false
-            and t.relname = '{res_id}'
+            and t.relname = %s
         """
     sql_stmt = sql_get_index_string.format(
-        res_id=data_dict['resource_id'], unique='true' if unique else 'false')
-    indexes_to_drop = context['connection'].execute(sql_stmt).fetchall()
+        unique='true' if unique else 'false')
+    indexes_to_drop = context['connection'].execute(
+        sql_stmt, data_dict['resource_id']).fetchall()
     for index in indexes_to_drop:
         context['connection'].execute(sql_drop_index.format(index[0]))
 
@@ -397,65 +398,72 @@ def upsert_data(context, data_dict):
     field_names = _pluck('id', fields)
     records = data_dict['records']
     sql_columns = ", ".join(['"%s"' % name for name in field_names]
-                            + ['_full_text'])
+                            + ['"_full_text"'])
 
     if method in [UPDATE, UPSERT]:
-        key_parts = _get_unique_key(context, data_dict)
-
-        rows = []
-        ## clean up and validate data
-
-        for num, record in enumerate(records):
-            _validate_record(record, num, field_names)
-
-            full_text = []
-            row = []
-            for field in fields:
-                value = record.get(field['id'])
-                if field['type'].lower() == '_json' and value:
-                    full_text.extend(json_get_values(value))
-                    ## a tuple with an empty second value
-                    value = (json.dumps(value), '')
-                elif field['type'].lower() == 'text' and value:
-                    full_text.append(value)
-                row.append(value)
-
-            row.append(u' '.join(full_text))
+        unique_keys = _get_unique_key(context, data_dict)
+        if len(unique_keys) < 1:
+            raise p.toolkit.ValidationError({
+                'table': [u'table does not have a key defined']
+            })
 
-        if method == UPDATE:
-            # all key columns have to be defined
-            missing_columns = [field for field in fields
-                               if record.get(field['id']) == None]
-            if missing_columns:
-                raise p.toolkit.ValidationError({
-                    'key': [u'rows {0} are missing but needed as key'.format(
-                        ','.join(missing_columns))]
-                })
-            keys = [records[key] for key in key_parts]
-            row += keys
+    if method == INSERT:
+        rows = []
+        for num, record in enumerate(records):
+            _validate_record(record, num, field_names)
 
-        rows.append(row)
+            row = []
+            for field in fields:
+                value = record.get(field['id'])
+                if field['type'].lower() == '_json' and value:
+                    ## a tuple with an empty second value
+                    value = (json.dumps(value), '')
+                row.append(value)
+            row.append(_to_full_text(fields, record))
+            rows.append(row)
 
-    if method == INSERT:
         sql_string = u'insert into "{res_id}" ({columns}) values ({values}, to_tsvector(%s));'.format(
             res_id=data_dict['resource_id'],
             columns=sql_columns,
             values=', '.join(['%s' for field in field_names])
         )
+        context['connection'].execute(sql_string, rows)
 
     elif method == UPDATE:
-        sql_string = u'''
-            update {table}
-            set ({columns}) = ({values}, to_tsvector(%s))
-            where {primary_key} = {primary_value};
-        '''.format(
-            res_id=data_dict['resource_id'],
-            columns=sql_columns,
-            values=', '.join(['%s' for field in field_names]),
-            primary_key='({0})'.format(','.join(['"%s"' % part for part in key_parts])),
-            primary_value='({0})'.format(','.join(["'%s'"] * len(key_parts)))
-        )
-        context['connection'].execute(sql_string, rows)
+        for num, record in enumerate(records):
+            # all key columns have to be defined
+            missing_fields = [field for field in unique_keys
+                              if field not in record]
+            if missing_fields:
+                raise p.toolkit.ValidationError({
+                    'key': [u'fields "{0}" are missing but needed as key'.format(
+                        ', '.join(missing_fields))]
+                })
+            unique_values = [record[key] for key in unique_keys]
+
+            used_field_names = record.keys()
+            used_values = [record[field] for field in used_field_names]
+            full_text = _to_full_text(fields, record)
+
+            sql_string = u'''
+                update "{res_id}"
+                set ({columns}, "_full_text") = ({values}, to_tsvector(%s))
+                where ({primary_key}) = ({primary_value});
+            '''.format(
+                res_id=data_dict['resource_id'],
+                columns=u', '.join([u'"{0}"'.format(field) for field in used_field_names]),
+                values=u', '.join(['%s' for _ in used_field_names]),
+                primary_key=u','.join([u'"{0}"'.format(part) for part in unique_keys]),
+                primary_value=u','.join(["%s"] * len(unique_keys))
+            )
+            results = context['connection'].execute(
+                sql_string, used_values + [full_text] + unique_values)
+
+            # validate that exactly one row has been updated
+            if results.rowcount != 1:
+                raise p.toolkit.ValidationError({
+                    'key': [u'key "{0}" not found'.format(unique_values)]
+                })
 
     elif method == UPSERT:
         # TODO
@@ -477,7 +485,7 @@ def _get_unique_key(context, data_dict):
         and t.relkind = 'r'
         and idx.indisunique = true
         and idx.indisprimary = false
-        and t.relname = '%s'
+        and t.relname = %s
     '''
     key_parts = context['connection'].execute(sql_get_unique_key,
                                               data_dict['resource_id'])
     return [x[0] for x in key_parts]
@@ -501,6 +509,17 @@ def _validate_record(record, num, field_names):
         })
 
 
+def _to_full_text(fields, record):
+    full_text = []
+    for field in fields:
+        value = record.get(field['id'])
+        if field['type'].lower() == '_json' and value:
+            full_text.extend(json_get_values(value))
+        elif field['type'].lower() == 'text' and value:
+            full_text.append(value)
+    return ' '.join(full_text)
+
+
 def _where(field_ids, data_dict):
     'Return a SQL WHERE clause from data_dict filters and q'
     filters = data_dict.get('filters', {})
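For illustration, a sketch of what the UPDATE branch above ends up executing
for a single record; the resource id and record values here are hypothetical,
not taken from the patch:

    # Assume the table's unique key is [u'b\xfck'] and this record arrives:
    record = {u'b\xfck': u'annakarenina', u'author': u'tolstoy'}
    # used_field_names = record.keys(), so the branch renders roughly:
    #
    #     update "<resource-id>"
    #     set ("b\xfck", "author", "_full_text") = (%s, %s, to_tsvector(%s))
    #     where ("b\xfck") = (%s);
    #
    # and executes it with used_values + [full_text] + unique_values, so
    # record values only ever reach the database as bind parameters.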
diff --git a/ckanext/datastore/tests/test_datastore.py b/ckanext/datastore/tests/test_datastore.py
index a579c37a943..adff54b4129 100644
--- a/ckanext/datastore/tests/test_datastore.py
+++ b/ckanext/datastore/tests/test_datastore.py
@@ -7,6 +7,7 @@
 import ckan.tests as tests
 import ckanext.datastore.db as db
 import pprint
+import datetime
 
 
 def extract(d, keys):
@@ -447,6 +448,7 @@ def setup_class(cls):
             'fields': [{'id': u'b\xfck', 'type': 'text'},
                        {'id': 'author', 'type': 'text'},
                        {'id': 'published'}],
+            'primary_key': u'b\xfck',
             'records': [{u'b\xfck': 'annakarenina', 'author': 'tolstoy',
                         'published': '2005-03-01', 'nested': ['b', {'moo': 'moo'}]},
                         {u'b\xfck': 'warandpeace', 'author': 'tolstoy',
@@ -486,7 +488,7 @@ def test_insert(self):
         data = {
             'resource_id': self.data['resource_id'],
             'method': 'insert',
-            'records': [{u'b\xfck': 'hagji murat', 'author': 'tolstoy'}]
+            'records': [{u'b\xfck': 'hitchhikers guide to the galaxy', 'author': 'tolstoy'}]
         }
 
         postparams = '%s=1' % json.dumps(data)
@@ -502,6 +504,97 @@ def test_insert(self):
 
         assert results.rowcount == 3
 
+    def test_update(self):
+        c = model.Session.connection()
+        results = c.execute('select 1 from "{0}"'.format(self.data['resource_id']))
+        assert results.rowcount == 3
+        model.Session.remove()
+
to the galaxy" + + data = { + 'resource_id': self.data['resource_id'], + 'method': 'update', + 'records': [{u'b\xfck': hhguide, 'author': 'adams'}] + } + + postparams = '%s=1' % json.dumps(data) + auth = {'Authorization': str(self.sysadmin_user.apikey)} + res = self.app.post('/api/action/datastore_upsert', params=postparams, + extra_environ=auth) + res_dict = json.loads(res.body) + + assert res_dict['success'] is True + + c = model.Session.connection() + results = c.execute('select * from "{0}"'.format(self.data['resource_id'])) + assert results.rowcount == 3 + + records = results.fetchall() + assert records[2][u'b\xfck'] == hhguide + assert records[2].author == 'adams' + model.Session.remove() + + c = model.Session.connection() + results = c.execute("select * from \"{0}\" where author='{1}'".format(self.data['resource_id'], 'adams')) + assert results.rowcount == 1 + model.Session.remove() + + #update only the publish date + data = { + 'resource_id': self.data['resource_id'], + 'method': 'update', + 'records': [{u'b\xfck': hhguide, 'published': '1979-1-1'}] + } + + postparams = '%s=1' % json.dumps(data) + auth = {'Authorization': str(self.sysadmin_user.apikey)} + res = self.app.post('/api/action/datastore_upsert', params=postparams, + extra_environ=auth) + res_dict = json.loads(res.body) + + assert res_dict['success'] is True + + c = model.Session.connection() + results = c.execute('select * from "{0}"'.format(self.data['resource_id'])) + assert results.rowcount == 3 + + records = results.fetchall() + assert records[2][u'b\xfck'] == hhguide + assert records[2].author == 'adams' + assert records[2].published == datetime.datetime(1979, 1, 1) + model.Session.remove() + + def test_update_missing_key(self): + data = { + 'resource_id': self.data['resource_id'], + 'method': 'update', + 'records': [{'author': 'tolkien'}] + } + + postparams = '%s=1' % json.dumps(data) + auth = {'Authorization': str(self.sysadmin_user.apikey)} + res = self.app.post('/api/action/datastore_upsert', params=postparams, + extra_environ=auth, status=409) + res_dict = json.loads(res.body) + + assert res_dict['success'] is False + + def test_update_non_existing_key(self): + data = { + 'resource_id': self.data['resource_id'], + 'method': 'update', + 'records': [{u'b\xfck': '', 'author': 'tolkien'}] + } + + postparams = '%s=1' % json.dumps(data) + auth = {'Authorization': str(self.sysadmin_user.apikey)} + res = self.app.post('/api/action/datastore_upsert', params=postparams, + extra_environ=auth, status=409) + res_dict = json.loads(res.body) + + assert res_dict['success'] is False + class TestDatastoreDelete(tests.WsgiAppCase): sysadmin_user = None