diff --git a/datajoint/table.py b/datajoint/table.py index 2dd226f0d..3966c40c3 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -84,7 +84,7 @@ def iter_insert(self, iter, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed it insert. - :param iter: Must be an iterator that generates a sequence of valid arguments for insert. + :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ for row in iter: self.insert(row, **kwargs) @@ -113,15 +113,23 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress """ if isinstance(tup, tuple) or isinstance(tup, list) or isinstance(tup, np.ndarray): - value_list = ','.join([repr(q) for q in tup]) + value_list = ','.join([repr(val) if not name in self.heading.blobs else '%s' + for name, val in zip(self.heading.names, tup)]) + args = tuple(pack(val) for name, val in zip(self.heading.names, tup) if name in self.heading.blobs) attribute_list = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`' + elif isinstance(tup, dict): - value_list = ','.join([repr(tup[q]) - for q in self.heading.names if q in tup]) - attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup]) + '`' + value_list = ','.join([repr(tup[name]) if not name in self.heading.blobs else '%s' + for name in self.heading.names if name in tup]) + args = tuple(pack(tup[name]) for name in self.heading.names + if (name in tup and name in self.heading.blobs) ) + attribute_list = '`' + '`,`'.join([name for name in self.heading.names if name in tup]) + '`' elif isinstance(tup, np.void): - value_list = ','.join([repr(tup[q]) - for q in self.heading.names if q in tup.dtype.fields]) + value_list = ','.join([repr(tup[name]) if not name in self.heading.blobs else '%s' + for name in self.heading.names if name in tup.dtype.fields]) + + args = tuple(pack(tup[name]) for name in self.heading.names + if (name in tup.dtype.fields and name in self.heading.blobs) ) attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup.dtype.fields]) + '`' else: raise DataJointError('Datatype %s cannot be inserted' % type(tup)) @@ -133,9 +141,8 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress sql = 'INSERT' sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, attribute_list, value_list) - logger.info(sql) - self.conn.query(sql) + self.conn.query(sql, args=args) def delete(self): # TODO: (issues #14 and #15) pass diff --git a/tests/schemata/schema1/test4.py b/tests/schemata/schema1/test4.py new file mode 100644 index 000000000..9e3beace1 --- /dev/null +++ b/tests/schemata/schema1/test4.py @@ -0,0 +1,17 @@ +""" +Test 1 Schema definition - fully bound and has connection object +""" +__author__ = 'fabee' + +import datajoint as dj + + +class Matrix(dj.Base): + definition = """ + test4.Matrix (manual) # Some numpy array + + matrix_id : int # unique matrix id + --- + data : longblob # data + comment : varchar(1000) # comment + """ diff --git a/tests/test_blob.py b/tests/test_blob.py new file mode 100644 index 000000000..0e8441f7e --- /dev/null +++ b/tests/test_blob.py @@ -0,0 +1,10 @@ +__author__ = 'fabee' +import numpy as np +from datajoint.blob import pack, unpack +from numpy.testing import assert_array_equal + + +def test_pack(): + x = np.random.randn(10, 10) + + assert_array_equal(x, unpack(pack(x)), "Arrays do not match!") diff --git a/tests/test_table.py b/tests/test_table.py index 98a1ec3d0..a8cf35133 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1,13 +1,13 @@ __author__ = 'fabee' -from .schemata.schema1 import test1 +from .schemata.schema1 import test1, test4 from . import BASE_CONN, CONN_INFO, PREFIX, cleanup from datajoint.connection import Connection from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal from datajoint import DataJointError import numpy as np - +from numpy.testing import assert_array_equal def setup(): """ @@ -34,8 +34,11 @@ def setup(self): cleanup() # drop all databases with PREFIX self.conn = Connection(**CONN_INFO) test1.conn = self.conn + test4.conn = self.conn self.conn.bind(test1.__name__, PREFIX + '_test1') + self.conn.bind(test4.__name__, PREFIX + '_test4') self.relvar = test1.Subjects() + self.relvar_blob = test4.Matrix() def teardown(self): cleanup() @@ -47,6 +50,13 @@ def test_tuple_insert(self): testt2 = tuple((self.relvar & 'subject_id = 1').fetch()[0]) assert_equal(testt2, testt, "Inserted and fetched tuple do not match!") + def test_list_insert(self): + "Test whether tuple insert works" + testt = [1, 'Peter', 'mouse'] + self.relvar.insert(testt) + testt2 = list((self.relvar & 'subject_id = 1').fetch()[0]) + assert_equal(testt2, testt, "Inserted and fetched tuple do not match!") + def test_record_insert(self): "Test whether record insert works" tmp = np.array([(2, 'Klara', 'monkey')], @@ -104,4 +114,11 @@ def test_iter_insert(self): delivered = self.relvar.fetch() for e,d in zip(expected, delivered): - assert_equal(tuple(e), tuple(d),'Inserted and fetched records do not match') \ No newline at end of file + assert_equal(tuple(e), tuple(d),'Inserted and fetched records do not match') + + def test_blob_insert(self): + x = np.random.randn(10) + t = (0, x, 'this is a random image') + self.relvar_blob.insert(t) + x2 = self.relvar_blob.fetch()[0][1] + assert_array_equal(x,x2, 'inserted blob does not match') \ No newline at end of file