Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions datajoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _declare(self):
field = p.heading[key]
if field.name not in primary_key_fields:
primary_key_fields.add(field.name)
sql += self._field_to_SQL(field)
sql += self._field_to_sql(field)
else:
logger.debug('Field definition of {} in {} ignored'.format(
field.name, p.full_class_name))
Expand Down Expand Up @@ -355,6 +355,8 @@ def get_module(cls, module_name):
check within `package.subpackage` but not inside `package`).
3. Globally accessible module with the same name.
"""
# from IPython import embed
# embed()
mod_obj = importlib.import_module(cls.__module__)
attr = getattr(mod_obj, module_name, None)
if isinstance(attr, ModuleType):
Expand All @@ -363,7 +365,8 @@ def get_module(cls, module_name):
try:
return importlib.import_module('.' + module_name, mod_obj.__package__)
except ImportError:
try:
return importlib.import_module(module_name)
except ImportError:
return None
pass
try:
return importlib.import_module(module_name)
except ImportError:
return None
4 changes: 2 additions & 2 deletions datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def save(self, filename=None):
"""
if filename is None:
import datajoint as dj
filename = dj.config['config.file']
filename = LOCALCONFIG
with open(filename, 'w') as fid:
json.dump(self._conf, fid)
json.dump(self._conf, fid, indent=4)

def load(self, filename):
"""
Expand Down
43 changes: 31 additions & 12 deletions datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from . import DataJointError
from .relational import Relation
from .blob import pack

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -33,7 +34,7 @@ def __init__(self, conn=None, dbname=None, class_name=None, definition=None):
# register with a fake module, enclosed in back quotes
self.conn.bind('`{0}`'.format(dbname), dbname)

#TODO: delay the loading until first use (move out of __init__)
# TODO: delay the loading until first use (move out of __init__)
self.conn.load_headings()
if self.class_name not in self.conn.table_names[self.dbname]:
if definition is None:
Expand Down Expand Up @@ -78,11 +79,29 @@ def primary_key(self):
"""
return self.heading.primary_key


def iter_insert(self, iter, **kwargs):
"""
Inserts an entire batch of entries. Additional keyword arguments are passed it insert.

:param iter: Must be an iterator that generates a sequence of valid arguments for insert.
"""
for row in iter:
self.insert(row, **kwargs)

def batch_insert(self, data, **kwargs):
"""
Inserts an entire batch of entries. Additional keyword arguments are passed it insert.

:param data: must be iterable, each row must be a valid argument for insert
"""
self.iter_insert(data.__iter__())

def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress (issue #8)
"""
Insert one data tuple.
Insert one data tuple, one data record, or one dictionary.

:param tup: Data tuple. Can be an iterable in matching order, a dict with named fields, or an np.void.
:param tup: Data tuple, record, or dictionary.
:param ignore_errors=False: Ignores errors if True.
:param replace=False: Replaces data tuple if True.

Expand All @@ -92,19 +111,18 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress
b.insert(dict(subject_id = 7, species="mouse",\\
real_id = 1007, date_of_birth = "2014-09-01"))
"""
# todo: do we support records and named tuples for tup?

if issubclass(type(tup), tuple) or issubclass(type(tup), list):
if isinstance(tup, tuple) or isinstance(tup, list) or isinstance(tup, np.ndarray):
value_list = ','.join([repr(q) for q in tup])
attribute_list = '`'+'`,`'.join(self.heading.names[0:len(tup)]) + '`'
elif issubclass(type(tup), dict):
attribute_list = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`'
elif isinstance(tup, dict):
value_list = ','.join([repr(tup[q])
for q in self.heading.names if q in tup])
for q in self.heading.names if q in tup])
attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup]) + '`'
elif issubclass(type(tup), np.void):
elif isinstance(tup, np.void):
value_list = ','.join([repr(tup[q])
for q in self.heading.names if q in tup])
attribute_list = '`' + '`,`'.join(tup.dtype.fields) + '`'
for q in self.heading.names if q in tup.dtype.fields])
attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup.dtype.fields]) + '`'
else:
raise DataJointError('Datatype %s cannot be inserted' % type(tup))
if replace:
Expand All @@ -115,10 +133,11 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress
sql = 'INSERT'
sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name,
attribute_list, value_list)

logger.info(sql)
self.conn.query(sql)

def delete(self): # TODO: (issues #14 and #15)
def delete(self): # TODO: (issues #14 and #15)
pass

def drop(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# Connection information for testing
CONN_INFO = {
'host': environ.get('DJ_TEST_HOST', 'localhost'),
'user': environ.get('DJ_TEST_USER', 'datajoint'),
'passwd': environ.get('DJ_TEST_PASSWORD', 'datajoint')
'user': environ.get('DJ_TEST_USER', 'travis'),
'passwd': environ.get('DJ_TEST_PASSWORD', '')
}
# Prefix for all databases used during testing
PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'dj')
Expand Down
107 changes: 107 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
__author__ = 'fabee'

from .schemata.schema1 import test1

from . import BASE_CONN, CONN_INFO, PREFIX, cleanup
from datajoint.connection import Connection
from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, assert_list_equal
from datajoint import DataJointError
import numpy as np


def setup():
"""
Setup connections and bindings
"""
pass


class TestTableObject(object):
def __init__(self):
self.relvar = None
self.setup()

"""
Test cases for Table objects
"""

def setup(self):
"""
Create a connection object and prepare test modules
as follows:
test1 - has conn and bounded
"""
cleanup() # drop all databases with PREFIX
self.conn = Connection(**CONN_INFO)
test1.conn = self.conn
self.conn.bind(test1.__name__, PREFIX + '_test1')
self.relvar = test1.Subjects()

def teardown(self):
cleanup()

def test_tuple_insert(self):
"Test whether tuple insert works"
testt = (1, 'Peter', 'mouse')
self.relvar.insert(testt)
testt2 = tuple((self.relvar & 'subject_id = 1').fetch()[0])
assert_equal(testt2, testt, "Inserted and fetched tuple do not match!")

def test_record_insert(self):
"Test whether record insert works"
tmp = np.array([(2, 'Klara', 'monkey')],
dtype=[('subject_id', '>i4'), ('real_id', 'O'), ('species', 'O')])

self.relvar.insert(tmp[0])
testt2 = (self.relvar & 'subject_id = 2').fetch()[0]
assert_equal(tuple(tmp[0]), tuple(testt2), "Inserted and fetched record do not match!")

def test_record_insert_different_order(self):
"Test whether record insert works"
tmp = np.array([('Klara', 2, 'monkey')],
dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')])

self.relvar.insert(tmp[0])
testt2 = (self.relvar & 'subject_id = 2').fetch()[0]
assert_equal((2, 'Klara', 'monkey'), tuple(testt2), "Inserted and fetched record do not match!")

def test_dict_insert(self):
"Test whether record insert works"
tmp = {'real_id': 'Brunhilda',
'subject_id': 3,
'species': 'human'}

self.relvar.insert(tmp)
testt2 = (self.relvar & 'subject_id = 3').fetch()[0]
assert_equal((3, 'Brunhilda', 'human'), tuple(testt2), "Inserted and fetched record do not match!")


def test_batch_insert(self):
"Test whether record insert works"
tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')],
dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')])

self.relvar.batch_insert(tmp)

expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'),
(3, 'Brunhilda', 'mouse')],
dtype=[('subject_id', '<i4'), ('real_id', 'O'), ('species', 'O')])
delivered = self.relvar.fetch()

for e,d in zip(expected, delivered):
assert_equal(tuple(e), tuple(d),'Inserted and fetched records do not match')

def test_iter_insert(self):
"Test whether record insert works"
tmp = np.array([('Klara', 2, 'monkey'), ('Brunhilda', 3, 'mouse'), ('Mickey', 1, 'human')],
dtype=[('real_id', 'O'), ('subject_id', '>i4'), ('species', 'O')])

self.relvar.iter_insert(tmp.__iter__())

expected = np.array([(1, 'Mickey', 'human'), (2, 'Klara', 'monkey'),
(3, 'Brunhilda', 'mouse')],
dtype=[('subject_id', '<i4'), ('real_id', 'O'), ('species', 'O')])
delivered = self.relvar.fetch()

for e,d in zip(expected, delivered):
assert_equal(tuple(e), tuple(d),'Inserted and fetched records do not match')