Skip to content

Commit

Permalink
generic upsert method that supports insert and update but not yet upsert
Browse files Browse the repository at this point in the history
  • Loading branch information
domoritz committed Sep 6, 2012
1 parent 42a2994 commit 6ef1487
Showing 1 changed file with 77 additions and 21 deletions.
98 changes: 77 additions & 21 deletions ckanext/datastore/db.py
Expand Up @@ -20,6 +20,10 @@
'%m-%d-%Y',
]
def _pluck(field, arr):
    '''Return the value stored under ``field`` for every item of ``arr``.'''
    return [x[field] for x in arr]


# Write modes accepted by upsert_data() through data_dict['method'].
INSERT = 'insert'
UPSERT = 'upsert'
UPDATE = 'update'
# Every recognised write mode; upsert_data() rejects any other value with a
# ValidationError.
_methods = [INSERT, UPSERT, UPDATE]


def _get_list(input):
Expand Down Expand Up @@ -370,10 +374,22 @@ def alter_table(context, data_dict):


def insert_data(context, data_dict):
    '''Write the records of ``data_dict`` using the INSERT method.

    Thin wrapper around upsert_data(): it sets ``data_dict['method']`` to
    INSERT (overwriting any value already stored there) and returns whatever
    upsert_data() returns.
    '''
    data_dict['method'] = INSERT
    outcome = upsert_data(context, data_dict)
    return outcome


def upsert_data(context, data_dict):
'''insert all data from records'''
if not data_dict.get('records'):
return

method = data_dict.get('method', UPSERT)

if method not in _methods:
raise p.toolkit.ValidationError({
'method': [u'Method {} is not defined'.format(method)]
})

fields = _get_fields(context, data_dict)
field_names = _pluck('id', fields)
sql_columns = ", ".join(['"%s"' % name for name in field_names]
Expand All @@ -383,21 +399,7 @@ def insert_data(context, data_dict):
## clean up and validate data

for num, record in enumerate(data_dict['records']):
# check record for sanity
if not isinstance(record, dict):
raise p.toolkit.ValidationError({
'records': [u'row {} is not a json object'.format(num)]
})
## check for extra fields in data
extra_keys = set(record.keys()) - set(field_names)

if extra_keys:
raise p.toolkit.ValidationError({
'records': [u'row {} has extra keys "{}"'.format(
num + 1,
', '.join(list(extra_keys))
)]
})
_validate_record(record, num, field_names)

full_text = []
row = []
Expand All @@ -414,13 +416,49 @@ def insert_data(context, data_dict):
row.append(' '.join(full_text))
rows.append(row)

sql_string = u'insert into "{0}" ({1}) values ({2}, to_tsvector(%s));'.format(
data_dict['resource_id'],
sql_columns,
', '.join(['%s' for field in field_names])
)
if method == INSERT:
sql_string = u'insert into "{res_id}" ({columns}) values ({values}, to_tsvector(%s));'.format(
res_id=data_dict['resource_id'],
columns=sql_columns,
values=', '.join(['%s' for field in field_names])
)
context['connection'].execute(sql_string, rows)

elif method == UPDATE:
sql_string = u'''
update {table}
set ({columns}) = ({values}, to_tsvector(%s))
where {primary_key} = {primary_value}
'''.format(
res_id=data_dict['resource_id'],
columns=sql_columns,
values=', '.join(['%s' for field in field_names]),
primary_key='',
primary_value=''
)
context['connection'].execute(sql_string, rows)

elif method == UPSERT:
# TODO
pass

context['connection'].execute(sql_string, rows)

def _validate_record(record, num, field_names):
    '''Check that a single record is a JSON object using only known fields.

    :param record: the record to check; it has to be a dict
    :param num: zero-based position of the record in the submitted list
    :param field_names: names of the fields a record is allowed to contain

    :raises p.toolkit.ValidationError: if the record is not a dict or if it
        has keys that are not listed in ``field_names``. Both messages refer
        to the record by its one-based row number.
    '''
    # Use the same one-based row number in every message so that both kinds
    # of error point at the same row for the same record.
    row_number = num + 1

    # check record for sanity
    if not isinstance(record, dict):
        raise p.toolkit.ValidationError({
            'records': [u'row {} is not a json object'.format(row_number)]
        })

    # check for extra fields in data
    extra_keys = set(record) - set(field_names)
    if extra_keys:
        raise p.toolkit.ValidationError({
            'records': [u'row {} has extra keys "{}"'.format(
                row_number,
                # sorted so the message does not depend on set ordering
                ', '.join(sorted(extra_keys))
            )]
        })


def _where(field_ids, data_dict):
Expand Down Expand Up @@ -666,6 +704,24 @@ def create(context, data_dict):
context['connection'].close()


def upsert(context, data_dict):
    '''
    Write the records of data_dict to the datastore in one transaction.

    The write method (insert, update or upsert) is read from
    data_dict['method'] by upsert_data(), which does the actual work.
    Any error results in total failure: the transaction is rolled back and
    the original exception is re-raised unchanged. The connection is closed
    whether or not the write succeeds.

    :returns: the data_dict that was passed in
    '''
    engine = _get_engine(context, data_dict)
    connection = engine.connect()
    # upsert_data() picks the connection up from the context
    context['connection'] = connection

    try:
        trans = connection.begin()
        try:
            upsert_data(context, data_dict)
            trans.commit()
        except Exception:
            # undo everything written so far, then let the caller see the
            # actual error
            trans.rollback()
            raise
        return data_dict
    finally:
        # always release the connection, as create() does
        connection.close()


def delete(context, data_dict):
engine = _get_engine(context, data_dict)
context['connection'] = engine.connect()
Expand Down

0 comments on commit 6ef1487

Please sign in to comment.