From a13831ffb35c8095e9199494b92c2513623d3336 Mon Sep 17 00:00:00 2001 From: Fabian Sinz Date: Wed, 26 Aug 2015 22:48:57 -0500 Subject: [PATCH] skip_duplicates in insert1 --- datajoint/relation.py | 54 ++++++++++++++++++++++++------------- datajoint/user_relations.py | 2 +- tests/test_relation.py | 21 +++++++++++++++ 3 files changed, 57 insertions(+), 20 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index c93826c3a..1dfc67bf8 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -140,13 +140,14 @@ def insert(self, rows, **kwargs): for row in rows: self.insert1(row, **kwargs) - def insert1(self, tup, replace=False, ignore_errors=False): + def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False): """ Insert one data record or one Mapping (like a dict). :param tup: Data record, a Mapping (like a dict), or a list or tuple with ordered values. :param replace=False: Replaces data tuple if True. - :param ignore_errors=False: If True, ignore errors: e.g. constraint violations or duplicates + :param ignore_errors=False: If True, ignore errors: e.g. constraint violations. + :param skip_dublicates=False: If True, ignore duplicate inserts. Example:: relation.insert1(dict(subject_id=7, species="mouse", date_of_birth="2014-09-01")) @@ -157,22 +158,23 @@ def insert1(self, tup, replace=False, ignore_errors=False): for fieldname in tup.dtype.fields: if fieldname not in heading: raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname)) - value_list = ','.join([repr(tup[name]) if not heading[name].is_blob else '%s' - for name in heading if name in tup.dtype.fields]) + values = [tup[name] if not heading[name].is_blob else '%s' + for name in heading if name in tup.dtype.fields] + attributes = [q for q in heading if q in tup.dtype.fields] + args = tuple(pack(tup[name]) for name in heading if name in tup.dtype.fields and heading[name].is_blob) - attribute_list = '`' + '`,`'.join(q for q in heading if q in tup.dtype.fields) + '`' elif isinstance(tup, Mapping): # dict-based insert for fieldname in tup.keys(): if fieldname not in heading: raise KeyError(u'{0:s} is not in the attribute list'.format(fieldname)) - value_list = ','.join(repr(tup[name]) if not heading[name].is_blob else '%s' - for name in heading if name in tup) + values = [tup[name] if not heading[name].is_blob else '%s' + for name in heading if name in tup] + attributes = [name for name in heading if name in tup] + args = tuple(pack(tup[name]) for name in heading if name in tup and heading[name].is_blob) - attribute_list = '`' + '`,`'.join(name for name in heading if name in tup) + '`' - else: # positional insert try: if len(tup) != len(self.heading): @@ -182,18 +184,32 @@ def insert1(self, tup, replace=False, ignore_errors=False): raise DataJointError('Datatype %s cannot be inserted' % type(tup)) else: pairs = zip(heading, tup) - value_list = ','.join('%s' if heading[name].is_blob else repr(value) for name, value in pairs) - attribute_list = '`' + '`,`'.join(heading.names) + '`' + values = ['%s' if heading[name].is_blob else value for name, value in pairs] + attributes = heading.names + args = tuple(pack(value) for name, value in pairs if heading[name].is_blob) - if replace: - sql = 'REPLACE' - elif ignore_errors: - sql = 'INSERT IGNORE' + + value_list = ','.join(map(lambda elem: repr(elem) if elem != '%s' else elem , values)) + attribute_list = '`' + '`,`'.join(attributes) + '`' + + + if skip_duplicates: + key = {a:v for a,v in zip(attributes, values) if heading[a].in_key} + not_in_table = len(self & key) == 0 else: - sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.from_clause, attribute_list, value_list) - logger.info(sql) - self.connection.query(sql, args=args) + not_in_table = True + + if not_in_table: + if replace: + sql = 'REPLACE' + elif ignore_errors: + sql = 'INSERT IGNORE' + else: + sql = 'INSERT' + + sql += " INTO %s (%s) VALUES (%s)" % (self.from_clause, attribute_list, value_list) + logger.info(sql) + self.connection.query(sql, args=args) def delete_quick(self): """ diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 2e140f62f..49480fa53 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -54,7 +54,7 @@ def _prepare(self): Checks whether the instance has a property called `contents` and inserts its elements. """ if hasattr(self, 'contents'): - self.insert(self.contents, ignore_errors=True) + self.insert(self.contents, ignore_errors=False, skip_duplicates=True) class Imported(Relation, AutoPopulate): diff --git a/tests/test_relation.py b/tests/test_relation.py index a2de3749b..decac5e3d 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -5,6 +5,7 @@ assert_tuple_equal, assert_dict_equal, raises from . import schema +from pymysql import IntegrityError class TestRelation: @@ -38,6 +39,7 @@ def test_contents(self): assert_list_equal(list(u['subject_id']), sorted([s[0] for s in self.subject.contents])) def test_delete_quick(self): + """Tests quick deletion""" tmp = np.array([ (2, 'Klara', 'monkey', '2010-01-01', ''), (1, 'Peter', 'mouse', '2015-01-01', '')], @@ -47,3 +49,22 @@ def test_delete_quick(self): assert_true(len(s) == 2, 'insert did not work.') s.delete_quick() assert_true(len(s) == 0, 'delete did not work.') + + def test_skip_duplicate(self): + """Tests if duplicates are properly skipped.""" + tmp = np.array([ + (2, 'Klara', 'monkey', '2010-01-01', ''), + (2, 'Klara', 'monkey', '2010-01-01', ''), + (1, 'Peter', 'mouse', '2015-01-01', '')], + dtype=self.subject.heading.as_dtype) + self.subject.insert(tmp, skip_duplicates=True) + + @raises(IntegrityError) + def test_not_skip_duplicate(self): + """Tests if duplicates are not skipped.""" + tmp = np.array([ + (2, 'Klara', 'monkey', '2010-01-01', ''), + (2, 'Klara', 'monkey', '2010-01-01', ''), + (1, 'Peter', 'mouse', '2015-01-01', '')], + dtype=self.subject.heading.as_dtype) + self.subject.insert(tmp, skip_duplicates=False)