diff --git a/datajoint/connection.py b/datajoint/connection.py
index ab00e3ab7..6ece857e3 100644
--- a/datajoint/connection.py
+++ b/datajoint/connection.py
@@ -4,7 +4,7 @@
 
 """
 from contextlib import contextmanager
-import pymysql as connector
+import pymysql as client
 import logging
 from . import config
 from . import DataJointError
@@ -59,7 +59,7 @@ def __init__(self, host, user, passwd, init_fun=None):
         else:
             port = config['database.port']
         self.conn_info = dict(host=host, port=port, user=user, passwd=passwd)
-        self._conn = connector.connect(init_command=init_fun, **self.conn_info)
+        self._conn = client.connect(init_command=init_fun, **self.conn_info)
         if self.is_connected:
             logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
         else:
@@ -96,15 +96,15 @@ def query(self, query, args=(), as_dict=False):
         Execute the specified query and return the tuple generator (cursor).
 
         :param query: mysql query
-        :param args: additional arguments for the connector.cursor
+        :param args: additional arguments for the client.cursor
         :param as_dict: If as_dict is set to True, the returned cursor objects returns
                         query results as dictionary.
         """
-        cursor = connector.cursors.DictCursor if as_dict else connector.cursors.Cursor
+        cursor = client.cursors.DictCursor if as_dict else client.cursors.Cursor
         cur = self._conn.cursor(cursor=cursor)
 
         # Log the query
-        logger.debug("Executing SQL:" + query)
+        logger.debug("Executing SQL:" + query[0:300])
         cur.execute(query, args)
 
         return cur
diff --git a/datajoint/relation.py b/datajoint/relation.py
index 1129de7f0..25f354607 100644
--- a/datajoint/relation.py
+++ b/datajoint/relation.py
@@ -2,6 +2,7 @@
 import numpy as np
 import logging
 import abc
+import binascii
 
 from . import config
 from . import DataJointError
@@ -181,24 +182,41 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False
 
         >>> relation.insert1(dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"))
         """
+
         heading = self.heading
 
-        if isinstance(tup, np.void):  # np.array insert
-            for field in tup.dtype.fields:
-                if field not in heading:
-                    raise KeyError(u'{0:s} is not in the attribute list'.format(field))
-            values = ['%s' if heading[name].is_blob else tup[name] for name in heading if name in tup.dtype.fields]
-            attributes = [name for name in heading if name in tup.dtype.fields]
-            args = tuple(pack(tup[name]) for name in heading
-                         if name in tup.dtype.fields and heading[name].is_blob)
-        elif isinstance(tup, Mapping):  # dict-based insert
-            for field in tup.keys():
+        def check_fields(fields):
+            for field in fields:
                 if field not in heading:
                     raise KeyError(u'{0:s} is not in the attribute list'.format(field))
-            values = ['%s' if heading[name].is_blob else tup[name] for name in heading if name in tup]
-            attributes = [name for name in heading if name in tup]
-            args = tuple(pack(tup[name]) for name in heading
-                         if name in tup and heading[name].is_blob)
+
+        def make_attribute(name, value):
+            """
+            Return (name, placeholder, value) for one attribute: the text to embed in the query
+            and the argument (if any) for the mysql API; name is None when the attribute is omitted.
+            """
+            if heading[name].is_blob:
+                value = pack(value)
+                # This is a temporary hack to address issue #131 (slow blob inserts).
+                # When this problem is fixed by pymysql or python, then pass blob as query argument.
+                placeholder = '0x' + binascii.b2a_hex(value).decode('ascii')
+                value = None
+            elif heading[name].numeric:
+                if np.isnan(value):
+                    name = None  # omit nans
+                placeholder = '%s'
+                value = repr(int(value) if isinstance(value, bool) else value)
+            else:
+                placeholder = '%s'
+            return name, placeholder, value
+
+        if isinstance(tup, np.void):  # np.array insert
+            check_fields(tup.dtype.fields)
+            attributes = [make_attribute(name, tup[name])
+                          for name in heading if name in tup.dtype.fields]
+        elif isinstance(tup, Mapping):  # dict-based insert
+            check_fields(tup.keys())
+            attributes = [make_attribute(name, tup[name]) for name in heading if name in tup]
         else:  # positional insert
             try:
                 if len(tup) != len(heading):
@@ -209,14 +227,11 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False
             except TypeError:
                 raise DataJointError('Datatype %s cannot be inserted' % type(tup))
             else:
-                values = ['%s' if heading[name].is_blob else value for name, value in zip(heading, tup)]
-                attributes = heading.names
-                args = tuple(pack(value) for name, value in zip(heading, tup) if heading[name].is_blob)
-
-        value_list = ','.join(map(lambda elem: repr(elem) if elem != '%s' else elem , values))
-        attribute_list = '`' + '`,`'.join(attributes) + '`'
-
-        skip = skip_duplicates and (self & {a: v for a, v in zip(attributes, values) if heading[a].in_key})
+                attributes = [make_attribute(name, value) for name, value in zip(heading, tup)]
+        if not attributes:
+            raise DataJointError('Empty tuple')
+        skip = skip_duplicates and (
+            self & {name: value for name, _, value in attributes if name and heading[name].in_key})
         if not skip:
             if replace:
                 sql = 'REPLACE'
@@ -224,10 +239,14 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False
                 sql = 'INSERT IGNORE'
             else:
                 sql = 'INSERT'
-            sql += " INTO %s (%s) VALUES (%s)" % (self.from_clause, attribute_list, value_list)
-            logger.info(sql)
-            self.connection.query(sql, args=args)
-
+            attributes = [a for a in attributes if a[0]]  # omit dropped attributes
+            if not attributes:
+                # every attribute was dropped (e.g. all values were NaN): nothing left to insert
+                raise DataJointError('Empty tuple')
+            names, placeholders, values = tuple(zip(*attributes))
+            sql += " INTO %s (`%s`) VALUES (%s)" % (
+                self.from_clause, '`,`'.join(names), ','.join(placeholders))
+            self.connection.query(sql, args=tuple(v for v in values if v is not None))
 
     def delete_quick(self):
         """
diff --git a/tests/schema.py b/tests/schema.py
index f5f680338..587c69e1c 100644
--- a/tests/schema.py
+++ b/tests/schema.py
@@ -20,9 +20,7 @@ class Auto(dj.Lookup):
     contents = (
         dict(name="Godel"),
         dict(name="Escher"),
-        dict(name="Bach")
-    )
-
+        dict(name="Bach"))
 
 @schema
 class User(dj.Lookup):
diff --git a/tests/test_nan.py b/tests/test_nan.py
new file mode 100644
index 000000000..8eb0a3d55
--- /dev/null
+++ b/tests/test_nan.py
@@ -0,0 +1,27 @@
+import numpy as np
+from nose.tools import assert_true, assert_false, assert_equal, assert_list_equal
+import datajoint as dj
+from . import PREFIX, CONN_INFO
+
+
+schema = dj.schema(PREFIX + '_nantest', locals(), connection=dj.conn(**CONN_INFO))
+
+
+@schema
+class NanTest(dj.Manual):
+    definition = """
+    id :int
+    ---
+    value=null :double
+    """
+
+
+def test_insert_nan():
+    rel = NanTest()
+    a = np.array([1, 2, np.nan, np.pi, np.nan])
+    rel.insert(((i, value) for i, value in enumerate(a)))
+    b = rel.fetch.order_by('id')['value']
+    assert_true((np.isnan(a) == np.isnan(b)).all(),
+                'incorrect handling of Nans')
+    assert_true(np.allclose(a[np.logical_not(np.isnan(a))], b[np.logical_not(np.isnan(b))]),
+                'incorrect storage of floats')