10 changes: 5 additions & 5 deletions datajoint/connection.py
@@ -4,7 +4,7 @@
"""

from contextlib import contextmanager
import pymysql as connector
import pymysql as client
import logging
from . import config
from . import DataJointError
@@ -59,7 +59,7 @@ def __init__(self, host, user, passwd, init_fun=None):
else:
port = config['database.port']
self.conn_info = dict(host=host, port=port, user=user, passwd=passwd)
self._conn = connector.connect(init_command=init_fun, **self.conn_info)
self._conn = client.connect(init_command=init_fun, **self.conn_info)
if self.is_connected:
logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
else:
@@ -96,15 +96,15 @@ def query(self, query, args=(), as_dict=False):
Execute the specified query and return the tuple generator (cursor).

:param query: mysql query
:param args: additional arguments for the connector.cursor
:param args: additional arguments for the client.cursor
:param as_dict: If as_dict is set to True, the returned cursor returns
query results as dictionaries.
"""
cursor = connector.cursors.DictCursor if as_dict else connector.cursors.Cursor
cursor = client.cursors.DictCursor if as_dict else client.cursors.Cursor
cur = self._conn.cursor(cursor=cursor)

# Log the query
logger.debug("Executing SQL:" + query)
logger.debug("Executing SQL:" + query[0:300])
cur.execute(query, args)
return cur

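For context, a minimal standalone sketch of what the updated query helper does after this change, assuming a live pymysql connection (the run_query and conn names are illustrative, not DataJoint's API): the cursor class is selected by as_dict, and only the first 300 characters of the statement are logged, presumably because inlined blob literals can make statements very long.

import logging
import pymysql as client

logger = logging.getLogger(__name__)


def run_query(conn, sql, args=(), as_dict=False):
    # choose a dict-returning cursor when as_dict is requested
    cursor = client.cursors.DictCursor if as_dict else client.cursors.Cursor
    cur = conn.cursor(cursor=cursor)
    # log only the first 300 characters of the statement
    logger.debug("Executing SQL:" + sql[0:300])
    cur.execute(sql, args)
    return cur


# usage sketch (connection parameters are placeholders):
# conn = client.connect(host='localhost', user='dj', passwd='...', port=3306)
# rows = run_query(conn, "SELECT * FROM mytable WHERE id=%s", args=(1,), as_dict=True).fetchall()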
68 changes: 42 additions & 26 deletions datajoint/relation.py
@@ -2,6 +2,7 @@
import numpy as np
import logging
import abc
import binascii

from . import config
from . import DataJointError
@@ -181,24 +182,41 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False
>>> relation.insert1(dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"))

"""

heading = self.heading

if isinstance(tup, np.void): # np.array insert
for field in tup.dtype.fields:
if field not in heading:
raise KeyError(u'{0:s} is not in the attribute list'.format(field))
values = ['%s' if heading[name].is_blob else tup[name] for name in heading if name in tup.dtype.fields]
attributes = [name for name in heading if name in tup.dtype.fields]
args = tuple(pack(tup[name]) for name in heading
if name in tup.dtype.fields and heading[name].is_blob)
elif isinstance(tup, Mapping): # dict-based insert
for field in tup.keys():
def check_fields(fields):
for field in fields:
if field not in heading:
raise KeyError(u'{0:s} is not in the attribute list'.format(field))
values = ['%s' if heading[name].is_blob else tup[name] for name in heading if name in tup]
attributes = [name for name in heading if name in tup]
args = tuple(pack(tup[name]) for name in heading
if name in tup and heading[name].is_blob)

def make_attribute(name, value):
"""
For a given attribute, return its value or value placeholder as a string to be included
in the query, and the value, if any, to be submitted for processing by the mysql API.
"""
if heading[name].is_blob:
value = pack(value)
# This is a temporary hack to address issue #131 (slow blob inserts).
# When this problem is fixed by pymysql or python, then pass blob as query argument.
placeholder = '0x' + binascii.b2a_hex(value).decode('ascii')
value = None
elif heading[name].numeric:
if np.isnan(value):
name = None # omit nans
placeholder = '%s'
value = repr(int(value) if isinstance(value, bool) else value)
else:
placeholder = '%s'
return name, placeholder, value

if isinstance(tup, np.void): # np.array insert
check_fields(tup.dtype.fields)
attributes = [make_attribute(name, tup[name])
for name in heading if name in tup.dtype.fields]
elif isinstance(tup, Mapping): # dict-based insert
check_fields(tup.keys())
attributes = [make_attribute(name, tup[name]) for name in heading if name in tup]
else: # positional insert
try:
if len(tup) != len(heading):
@@ -209,25 +227,23 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False
except TypeError:
raise DataJointError('Datatype %s cannot be inserted' % type(tup))
else:
values = ['%s' if heading[name].is_blob else value for name, value in zip(heading, tup)]
attributes = heading.names
args = tuple(pack(value) for name, value in zip(heading, tup) if heading[name].is_blob)

value_list = ','.join(map(lambda elem: repr(elem) if elem != '%s' else elem , values))
attribute_list = '`' + '`,`'.join(attributes) + '`'

skip = skip_duplicates and (self & {a: v for a, v in zip(attributes, values) if heading[a].in_key})
attributes = [make_attribute(name, value) for name, value in zip(heading, tup)]
if not attributes:
raise DataJointError('Empty tuple')
skip = skip_duplicates and (
self & {name: value for name, _, value in attributes if heading[name].in_key})
if not skip:
if replace:
sql = 'REPLACE'
elif ignore_errors:
sql = 'INSERT IGNORE'
else:
sql = 'INSERT'
sql += " INTO %s (%s) VALUES (%s)" % (self.from_clause, attribute_list, value_list)
logger.info(sql)
self.connection.query(sql, args=args)

attributes = (a for a in attributes if a[0]) # omit dropped attributes
names, placeholders, values = tuple(zip(*attributes))
sql += " INTO %s (`%s`) VALUES (%s)" % (
self.from_clause, '`,`'.join(names), ','.join(placeholders))
self.connection.query(sql, args=tuple(v for v in values if v is not None))

def delete_quick(self):
"""
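The heart of this change is make_attribute: packed blobs are inlined into the statement as hexadecimal literals rather than passed as %s arguments (the temporary workaround for the slow blob inserts tracked in issue #131), and NaN values in numeric fields are dropped so that the column falls back to its default (NULL). Below is a minimal sketch of that logic in isolation, assuming a pack() serializer that returns bytes; the function and argument names are illustrative, not DataJoint's API.

import binascii
import numpy as np


def make_attribute_sketch(name, value, is_blob, is_numeric, pack=bytes):
    """Return (name, placeholder, value); name is None when the attribute is omitted."""
    if is_blob:
        # inline the packed blob as a hex literal instead of passing it as a %s argument
        # (temporary workaround for slow blob inserts, issue #131)
        placeholder = '0x' + binascii.b2a_hex(pack(value)).decode('ascii')
        value = None                    # nothing is submitted as a query argument
    elif is_numeric:
        if np.isnan(value):
            name = None                 # omit NaNs so the column defaults to NULL
        placeholder = '%s'
        value = repr(int(value) if isinstance(value, bool) else value)
    else:
        placeholder = '%s'
    return name, placeholder, value


# a NaN numeric value is dropped entirely (its name comes back as None):
print(make_attribute_sketch('value', float('nan'), is_blob=False, is_numeric=True))   # (None, '%s', 'nan')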
4 changes: 1 addition & 3 deletions tests/schema.py
@@ -20,9 +20,7 @@ class Auto(dj.Lookup):
contents = (
dict(name="Godel"),
dict(name="Escher"),
dict(name="Bach")
)

dict(name="Bach"))

@schema
class User(dj.Lookup):
27 changes: 27 additions & 0 deletions tests/test_nan.py
@@ -0,0 +1,27 @@
import numpy as np
from nose.tools import assert_true, assert_false, assert_equal, assert_list_equal
import datajoint as dj
from . import PREFIX, CONN_INFO


schema = dj.schema(PREFIX + '_nantest', locals(), connection=dj.conn(**CONN_INFO))


@schema
class NanTest(dj.Manual):
definition = """
id :int
---
value=null :double
"""


def test_insert_nan():
rel = NanTest()
a = np.array([1, 2, np.nan, np.pi, np.nan])
rel.insert(((i, value) for i, value in enumerate(a)))
b = rel.fetch.order_by('id')['value']
assert_true((np.isnan(a) == np.isnan(b)).all(),
'incorrect handling of Nans')
assert_true(np.allclose(a[np.logical_not(np.isnan(a))], b[np.logical_not(np.isnan(b))]),
'incorrect storage of floats')