Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 35 additions & 19 deletions datajoint/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,14 @@ def insert(self, rows, **kwargs):
for row in rows:
self.insert1(row, **kwargs)

def insert1(self, tup, replace=False, ignore_errors=False):
def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False):
"""
Insert one data record or one Mapping (like a dict).

:param tup: Data record, a Mapping (like a dict), or a list or tuple with ordered values.
:param replace=False: Replaces data tuple if True.
:param ignore_errors=False: If True, ignore errors: e.g. constraint violations or duplicates
:param ignore_errors=False: If True, ignore errors: e.g. constraint violations.
:param skip_duplicates=False: If True, ignore duplicate inserts.

Example::
relation.insert1(dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"))
Expand All @@ -157,22 +158,23 @@ def insert1(self, tup, replace=False, ignore_errors=False):
for fieldname in tup.dtype.fields:
if fieldname not in heading:
raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname))
value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s'
for name in heading if name in tup.dtype.fields])
values = [tup[name] if not heading[name].is_blob else '%s'
for name in heading if name in tup.dtype.fields]
attributes = [q for q in heading if q in tup.dtype.fields]

args = tuple(pack(tup[name]) for name in heading
if name in tup.dtype.fields and heading[name].is_blob)
attribute_list = '`' + '`,`'.join(q for q in heading if q in tup.dtype.fields) + '`'

elif isinstance(tup, Mapping): # dict-based insert
for fieldname in tup.keys():
if fieldname not in heading:
raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname))
value_list = ','.join(repr(tup[name]) if not heading[name].is_blob else '%s'
for name in heading if name in tup)
values = [tup[name] if not heading[name].is_blob else '%s'
for name in heading if name in tup]
attributes = [name for name in heading if name in tup]

args = tuple(pack(tup[name]) for name in heading
if name in tup and heading[name].is_blob)
attribute_list = '`' + '`,`'.join(name for name in heading if name in tup) + '`'

else: # positional insert
try:
if len(tup) != len(self.heading):
Expand All @@ -182,18 +184,32 @@ def insert1(self, tup, replace=False, ignore_errors=False):
raise DataJointError('Datatype %s cannot be inserted' % type(tup))
else:
pairs = zip(heading, tup)
value_list = ','.join('%s' if heading[name].is_blob else repr(value) for name, value in pairs)
attribute_list = '`' + '`,`'.join(heading.names) + '`'
values = ['%s' if heading[name].is_blob else value for name, value in pairs]
attributes = heading.names

args = tuple(pack(value) for name, value in pairs if heading[name].is_blob)
if replace:
sql = 'REPLACE'
elif ignore_errors:
sql = 'INSERT IGNORE'

value_list = ','.join(map(lambda elem: repr(elem) if elem != '%s' else elem , values))
attribute_list = '`' + '`,`'.join(attributes) + '`'


if skip_duplicates:
key = {a:v for a,v in zip(attributes, values) if heading[a].in_key}
not_in_table = len(self & key) == 0
else:
sql = 'INSERT'
sql += " INTO %s (%s) VALUES (%s)" % (self.from_clause, attribute_list, value_list)
logger.info(sql)
self.connection.query(sql, args=args)
not_in_table = True

if not_in_table:
if replace:
sql = 'REPLACE'
elif ignore_errors:
sql = 'INSERT IGNORE'
else:
sql = 'INSERT'

sql += " INTO %s (%s) VALUES (%s)" % (self.from_clause, attribute_list, value_list)
logger.info(sql)
self.connection.query(sql, args=args)

def delete_quick(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion datajoint/user_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _prepare(self):
Checks whether the instance has a property called `contents` and inserts its elements.
"""
if hasattr(self, 'contents'):
self.insert(self.contents, ignore_errors=True)
self.insert(self.contents, ignore_errors=False, skip_duplicates=True)


class Imported(Relation, AutoPopulate):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
assert_tuple_equal, assert_dict_equal, raises

from . import schema
from pymysql import IntegrityError


class TestRelation:
Expand Down Expand Up @@ -38,6 +39,7 @@ def test_contents(self):
assert_list_equal(list(u['subject_id']), sorted([s[0] for s in self.subject.contents]))

def test_delete_quick(self):
"""Tests quick deletion"""
tmp = np.array([
(2, 'Klara', 'monkey', '2010-01-01', ''),
(1, 'Peter', 'mouse', '2015-01-01', '')],
Expand All @@ -47,3 +49,22 @@ def test_delete_quick(self):
assert_true(len(s) == 2, 'insert did not work.')
s.delete_quick()
assert_true(len(s) == 0, 'delete did not work.')

def test_skip_duplicate(self):
    """
    Insert a batch that contains a repeated row with ``skip_duplicates=True``.

    The test makes no explicit assertion; it passes as long as the insert
    call does not raise.
    """
    klara = (2, 'Klara', 'monkey', '2010-01-01', '')
    peter = (1, 'Peter', 'mouse', '2015-01-01', '')
    record_dtype = self.subject.heading.as_dtype
    # The first two records are identical on purpose.
    records = np.array([klara, klara, peter], dtype=record_dtype)
    self.subject.insert(records, skip_duplicates=True)

@raises(IntegrityError)
def test_not_skip_duplicate(self):
    """
    Insert a batch that contains a repeated row with ``skip_duplicates=False``.

    The ``raises`` decorator makes the test pass only if the insert call
    fails with ``IntegrityError``.
    """
    klara = (2, 'Klara', 'monkey', '2010-01-01', '')
    peter = (1, 'Peter', 'mouse', '2015-01-01', '')
    record_dtype = self.subject.heading.as_dtype
    # The first two records are identical on purpose.
    records = np.array([klara, klara, peter], dtype=record_dtype)
    self.subject.insert(records, skip_duplicates=False)