diff --git a/datajoint/connection.py b/datajoint/connection.py index 6ece857e3..23b3db47a 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -14,6 +14,20 @@ logger = logging.getLogger(__name__) +# The following is a temporary hack to address issue # +# START HACK +hack = False +if hack: + import binascii + + def escape_bytes(bs, mapping=None): + assert isinstance(bs, (bytes, bytearray)) + return '0x' + binascii.b2a_hex(bs).decode('ascii') + + connector.connections.escape_bytes = escape_bytes +# END HACK + + def conn(host=None, user=None, passwd=None, init_fun=None, reset=False): """ Returns a persistent connection object to be shared by multiple modules. @@ -160,8 +174,6 @@ def transaction(self): >>> import datajoint as dj >>> with dj.conn().transaction as conn: >>> # transaction is open here - - """ try: self.start_transaction() diff --git a/datajoint/kill.py b/datajoint/kill.py index b4c0c526c..48b9cbd54 100644 --- a/datajoint/kill.py +++ b/datajoint/kill.py @@ -37,11 +37,11 @@ def kill(restriction=None, connection=None): break if response: try: - id = int(response) + pid = int(response) except ValueError: pass # ignore non-numeric input else: try: - connection.query('kill %d' % id) + connection.query('kill %d' % pid) except pymysql.err.InternalError: - print('Process not found') \ No newline at end of file + print('Process not found') diff --git a/datajoint/relation.py b/datajoint/relation.py index 25f354607..e26b2bac4 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -32,7 +32,6 @@ def table_name(self): """ :return: the name of the table in the database """ - raise NotImplementedError('Relation subclasses must define property table_name') @property @abc.abstractmethod @@ -40,7 +39,6 @@ def definition(self): """ :return: a string containing the table definition using the DataJoint DDL """ - pass # -------------- required by RelationalOperand ----------------- # @property diff --git a/datajoint/schema.py b/datajoint/schema.py index a96f1678e..002c90f21 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -36,7 +36,7 @@ def __init__(self, database, context, connection=None): logger.info('Created database `{database}`.'.format(database=database)) except pymysql.OperationalError: raise DataJointError("Database named `{database}` was not defined, and" - "an attempt to create has failed. Check" + " an attempt to create has failed. Check" " permissions.".format(database=database)) def __call__(self, cls): diff --git a/tests/test_nan.py b/tests/test_nan.py index 8eb0a3d55..298573cca 100644 --- a/tests/test_nan.py +++ b/tests/test_nan.py @@ -1,9 +1,8 @@ import numpy as np -from nose.tools import assert_true, assert_false, assert_equal, assert_list_equal +from nose.tools import assert_true import datajoint as dj from . import PREFIX, CONN_INFO - schema = dj.schema(PREFIX + '_nantest', locals(), connection=dj.conn(**CONN_INFO)) diff --git a/tests/test_relation.py b/tests/test_relation.py index c7c357e23..5f834f935 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -1,6 +1,6 @@ from numpy.testing import assert_array_equal import numpy as np -from nose.tools import assert_raises, assert_equal, \ +from nose.tools import assert_raises, assert_equal, assert_not_equal, \ assert_false, assert_true, assert_list_equal, \ assert_tuple_equal, assert_dict_equal, raises @@ -43,6 +43,40 @@ def test_contents(self): u = self.subject.fetch(order_by=['subject_id']) assert_list_equal(list(u['subject_id']), sorted([s[0] for s in self.subject.contents])) + @raises(KeyError) + def test_misnamed_attribute(self): + self.user.insert1(dict(user="Bob")) + + @raises(dj.DataJointError) + def test_empty_insert(self): + self.user.insert1(()) + + @raises(dj.DataJointError) + def test_wrong_arguments_insert(self): + self.user.insert1(('First', 'Second')) + + @raises(dj.DataJointError) + def test_wrong_insert_type(self): + self.user.insert1(3) + + def test_replace(self): + """ + Test replacing or ignoring duplicate entries + """ + key = dict(subject_id=7) + date = "2015-01-01" + self.subject.insert1( + dict(key, real_id=7, date_of_birth=date, subject_notes="")) + assert_equal(date, str((self.subject & key).fetch1['date_of_birth']), 'incorrect insert') + date = "2015-01-02" + self.subject.insert1( + dict(key, real_id=7, date_of_birth=date, subject_notes=""), skip_duplicates=True) + assert_not_equal(date, str((self.subject & key).fetch1['date_of_birth']), + 'inappropriate replace') + self.subject.insert1( + dict(key, real_id=7, date_of_birth=date, subject_notes=""), replace=True) + assert_equal(date, str((self.subject & key).fetch1['date_of_birth']), "replace failed") + def test_delete_quick(self): """Tests quick deletion""" tmp = np.array([