diff --git a/datajoint/base.py b/datajoint/base.py index ce30adfe9..f0f476f2c 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -143,7 +143,7 @@ def _declare(self): field = p.heading[key] if field.name not in primary_key_fields: primary_key_fields.add(field.name) - sql += self._field_to_SQL(field) + sql += self._field_to_sql(field) else: logger.debug('Field definition of {} in {} ignored'.format( field.name, p.full_class_name)) @@ -355,6 +355,8 @@ def get_module(cls, module_name): check within `package.subpackage` but not inside `package`). 3. Globally accessible module with the same name. """ + # from IPython import embed + # embed() mod_obj = importlib.import_module(cls.__module__) attr = getattr(mod_obj, module_name, None) if isinstance(attr, ModuleType): @@ -363,7 +365,8 @@ def get_module(cls, module_name): try: return importlib.import_module('.' + module_name, mod_obj.__package__) except ImportError: - try: - return importlib.import_module(module_name) - except ImportError: - return None + pass + try: + return importlib.import_module(module_name) + except ImportError: + return None diff --git a/datajoint/settings.py b/datajoint/settings.py index d6ad20b72..97316c89a 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -83,9 +83,9 @@ def save(self, filename=None): """ if filename is None: import datajoint as dj - filename = dj.config['config.file'] + filename = LOCALCONFIG with open(filename, 'w') as fid: - json.dump(self._conf, fid) + json.dump(self._conf, fid, indent=4) def load(self, filename): """ diff --git a/datajoint/table.py b/datajoint/table.py index 77164e395..2dd226f0d 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -2,6 +2,7 @@ import logging from . import DataJointError from .relational import Relation +from .blob import pack logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def __init__(self, conn=None, dbname=None, class_name=None, definition=None): # register with a fake module, enclosed in back quotes self.conn.bind('`{0}`'.format(dbname), dbname) - #TODO: delay the loading until first use (move out of __init__) + # TODO: delay the loading until first use (move out of __init__) self.conn.load_headings() if self.class_name not in self.conn.table_names[self.dbname]: if definition is None: @@ -78,11 +79,29 @@ def primary_key(self): """ return self.heading.primary_key + + def iter_insert(self, iter, **kwargs): + """ + Inserts an entire batch of entries. Additional keyword arguments are passed it insert. + + :param iter: Must be an iterator that generates a sequence of valid arguments for insert. + """ + for row in iter: + self.insert(row, **kwargs) + + def batch_insert(self, data, **kwargs): + """ + Inserts an entire batch of entries. Additional keyword arguments are passed it insert. + + :param data: must be iterable, each row must be a valid argument for insert + """ + self.iter_insert(data.__iter__()) + def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress (issue #8) """ - Insert one data tuple. + Insert one data tuple, one data record, or one dictionary. - :param tup: Data tuple. Can be an iterable in matching order, a dict with named fields, or an np.void. + :param tup: Data tuple, record, or dictionary. :param ignore_errors=False: Ignores errors if True. :param replace=False: Replaces data tuple if True. @@ -92,19 +111,18 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress b.insert(dict(subject_id = 7, species="mouse",\\ real_id = 1007, date_of_birth = "2014-09-01")) """ - # todo: do we support records and named tuples for tup? - if issubclass(type(tup), tuple) or issubclass(type(tup), list): + if isinstance(tup, tuple) or isinstance(tup, list) or isinstance(tup, np.ndarray): value_list = ','.join([repr(q) for q in tup]) - attribute_list = '`'+'`,`'.join(self.heading.names[0:len(tup)]) + '`' - elif issubclass(type(tup), dict): + attribute_list = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`' + elif isinstance(tup, dict): value_list = ','.join([repr(tup[q]) - for q in self.heading.names if q in tup]) + for q in self.heading.names if q in tup]) attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup]) + '`' - elif issubclass(type(tup), np.void): + elif isinstance(tup, np.void): value_list = ','.join([repr(tup[q]) - for q in self.heading.names if q in tup]) - attribute_list = '`' + '`,`'.join(tup.dtype.fields) + '`' + for q in self.heading.names if q in tup.dtype.fields]) + attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup.dtype.fields]) + '`' else: raise DataJointError('Datatype %s cannot be inserted' % type(tup)) if replace: @@ -115,10 +133,11 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress sql = 'INSERT' sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, attribute_list, value_list) + logger.info(sql) self.conn.query(sql) - def delete(self): # TODO: (issues #14 and #15) + def delete(self): # TODO: (issues #14 and #15) pass def drop(self): diff --git a/tests/__init__.py b/tests/__init__.py index 5ad8544d8..e6f7b5099 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -14,8 +14,8 @@ # Connection information for testing CONN_INFO = { 'host': environ.get('DJ_TEST_HOST', 'localhost'), - 'user': environ.get('DJ_TEST_USER', 'datajoint'), - 'passwd': environ.get('DJ_TEST_PASSWORD', 'datajoint') + 'user': environ.get('DJ_TEST_USER', 'travis'), + 'passwd': environ.get('DJ_TEST_PASSWORD', '') } # Prefix for all databases used during testing PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'dj') diff --git a/tests/test_table.py b/tests/test_table.py new file mode 100644 index 000000000..98a1ec3d0 --- /dev/null +++ b/tests/test_table.py @@ -0,0 +1,107 @@ +__author__ = 'fabee' + +from .schemata.schema1 import test1 + +from . import BASE_CONN, CONN_INFO, PREFIX, cleanup +from datajoint.connection import Connection +from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal +from datajoint import DataJointError +import numpy as np + + +def setup(): + """ + Setup connections and bindings + """ + pass + + +class TestTableObject(object): + def __init__(self): + self.relvar = None + self.setup() + + """ + Test cases for Table objects + """ + + def setup(self): + """ + Create a connection object and prepare test modules + as follows: + test1 - has conn and bounded + """ + cleanup() # drop all databases with PREFIX + self.conn = Connection(**CONN_INFO) + test1.conn = self.conn + self.conn.bind(test1.__name__, PREFIX + '_test1') + self.relvar = test1.Subjects() + + def teardown(self): + cleanup() + + def test_tuple_insert(self): + "Test whether tuple insert works" + testt = (1, 'Peter', 'mouse') + self.relvar.insert(testt) + testt2 = tuple((self.relvar & 'subject_id = 1').fetch()[0]) + assert_equal(testt2, testt, "Inserted and fetched tuple do not match!") + + def test_record_insert(self): + "Test whether record insert works" + tmp = np.array([(2, 'Klara', 'monkey')], + dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')]) + + self.relvar.insert(tmp[0]) + testt2 = (self.relvar & 'subject_id = 2').fetch()[0] + assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!") + + def test_record_insert_different_order(self): + "Test whether record insert works" + tmp = np.array([('Klara', 2, 'monkey')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + + self.relvar.insert(tmp[0]) + testt2 = (self.relvar & 'subject_id = 2').fetch()[0] + assert_equal((2, 'Klara', 'monkey'), tuple(testt2), "Inserted and fetched record do not match!") + + def test_dict_insert(self): + "Test whether record insert works" + tmp = {'real_id': 'Brunhilda', + 'subject_id': 3, + 'species': 'human'} + + self.relvar.insert(tmp) + testt2 = (self.relvar & 'subject_id = 3').fetch()[0] + assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!") + + + def test_batch_insert(self): + "Test whether record insert works" + tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')], + dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')]) + + self.relvar.batch_insert(tmp) + + expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), + (3, 'Brunhilda', 'mouse')], + dtype=[('subject_id', 'i4'), ('species', 'O')]) + + self.relvar.iter_insert(tmp.__iter__()) + + expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'), + (3, 'Brunhilda', 'mouse')], + dtype=[('subject_id', '