Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def iter_insert(self, iter, **kwargs):
"""
Inserts an entire batch of entries. Additional keyword arguments are passed it insert.

:param iter: Must be an iterator that generates a sequence of valid arguments for insert.
:param iter: Must be an iterator that generates a sequence of valid arguments for insert.
"""
for row in iter:
self.insert(row, **kwargs)
Expand Down Expand Up @@ -113,15 +113,23 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress
"""

if isinstance(tup, tuple) or isinstance(tup, list) or isinstance(tup, np.ndarray):
value_list = ','.join([repr(q) for q in tup])
value_list = ','.join([repr(val) if not name in self.heading.blobs else '%s'
for name, val in zip(self.heading.names, tup)])
args = tuple(pack(val) for name, val in zip(self.heading.names, tup) if name in self.heading.blobs)
attribute_list = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`'

elif isinstance(tup, dict):
value_list = ','.join([repr(tup[q])
for q in self.heading.names if q in tup])
attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup]) + '`'
value_list = ','.join([repr(tup[name]) if not name in self.heading.blobs else '%s'
for name in self.heading.names if name in tup])
args = tuple(pack(tup[name]) for name in self.heading.names
if (name in tup and name in self.heading.blobs) )
attribute_list = '`' + '`,`'.join([name for name in self.heading.names if name in tup]) + '`'
elif isinstance(tup, np.void):
value_list = ','.join([repr(tup[q])
for q in self.heading.names if q in tup.dtype.fields])
value_list = ','.join([repr(tup[name]) if not name in self.heading.blobs else '%s'
for name in self.heading.names if name in tup.dtype.fields])

args = tuple(pack(tup[name]) for name in self.heading.names
if (name in tup.dtype.fields and name in self.heading.blobs) )
attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup.dtype.fields]) + '`'
else:
raise DataJointError('Datatype %s cannot be inserted' % type(tup))
Expand All @@ -133,9 +141,8 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress
sql = 'INSERT'
sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name,
attribute_list, value_list)

logger.info(sql)
self.conn.query(sql)
self.conn.query(sql, args=args)

def delete(self): # TODO: (issues #14 and #15)
pass
Expand Down
17 changes: 17 additions & 0 deletions tests/schemata/schema1/test4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Test 1 Schema definition - fully bound and has connection object
"""
__author__ = 'fabee'

import datajoint as dj


class Matrix(dj.Base):
definition = """
test4.Matrix (manual) # Some numpy array

matrix_id : int # unique matrix id
---
data : longblob # data
comment : varchar(1000) # comment
"""
10 changes: 10 additions & 0 deletions tests/test_blob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
__author__ = 'fabee'
import numpy as np
from datajoint.blob import pack, unpack
from numpy.testing import assert_array_equal


def test_pack():
x = np.random.randn(10, 10)

assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")
23 changes: 20 additions & 3 deletions tests/test_table.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
__author__ = 'fabee'

from .schemata.schema1 import test1
from .schemata.schema1 import test1, test4

from . import BASE_CONN, CONN_INFO, PREFIX, cleanup
from datajoint.connection import Connection
from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal
from datajoint import DataJointError
import numpy as np

from numpy.testing import assert_array_equal

def setup():
"""
Expand All @@ -34,8 +34,11 @@ def setup(self):
cleanup() # drop all databases with PREFIX
self.conn = Connection(**CONN_INFO)
test1.conn = self.conn
test4.conn = self.conn
self.conn.bind(test1.__name__, PREFIX + '_test1')
self.conn.bind(test4.__name__, PREFIX + '_test4')
self.relvar = test1.Subjects()
self.relvar_blob = test4.Matrix()

def teardown(self):
cleanup()
Expand All @@ -47,6 +50,13 @@ def test_tuple_insert(self):
testt2 = tuple((self.relvar & 'subject_id = 1').fetch()[0])
assert_equal(testt2, testt, "Inserted and fetched tuple do not match!")

def test_list_insert(self):
"Test whether tuple insert works"
testt = [1, 'Peter', 'mouse']
self.relvar.insert(testt)
testt2 = list((self.relvar & 'subject_id = 1').fetch()[0])
assert_equal(testt2, testt, "Inserted and fetched tuple do not match!")

def test_record_insert(self):
"Test whether record insert works"
tmp = np.array([(2, 'Klara', 'monkey')],
Expand Down Expand Up @@ -104,4 +114,11 @@ def test_iter_insert(self):
delivered = self.relvar.fetch()

for e,d in zip(expected, delivered):
assert_equal(tuple(e), tuple(d),'Inserted and fetched records do not match')
assert_equal(tuple(e), tuple(d),'Inserted and fetched records do not match')

def test_blob_insert(self):
x = np.random.randn(10)
t = (0, x, 'this is a random image')
self.relvar_blob.insert(t)
x2 = self.relvar_blob.fetch()[0][1]
assert_array_equal(x,x2, 'inserted blob does not match')