diff --git a/datajoint/base.py b/datajoint/base.py
index f0f476f2c..45370c182 100644
--- a/datajoint/base.py
+++ b/datajoint/base.py
@@ -33,7 +33,6 @@ class Subjects(dj.Base):
'''
"""
-
@abc.abstractproperty
def definition(self):
"""
@@ -42,7 +41,27 @@ def definition(self):
"""
pass
- def __init__(self):
+ @property
+ def full_class_name(self):
+ """
+ :return: full class name including the entire package hierarchy
+ """
+ return '{}.{}'.format(self.__module__, self.class_name)
+
+ @property
+ def access_name(self):
+ """
+        :return: name by which this class should be accessible
+ """
+ if self._use_package:
+ parent = self.__module__.split('.')[-2]
+ else:
+ parent = self.__module__.split('.')[-1]
+ return parent + '.' + self.class_name
+
+    def __init__(self):  # TODO: support taking in a conn object
self.class_name = self.__class__.__name__
module = self.__module__
mod_obj = importlib.import_module(module)
@@ -61,6 +80,7 @@ def __init__(self):
self.conn = conn
try:
if self._use_package:
+ # the database is bound to the package
pkg_name = '.'.join(module.split('.')[:-1])
dbname = self.conn.mod_to_db[pkg_name]
else:
@@ -69,256 +89,7 @@ def __init__(self):
raise DataJointError(
'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
self.dbname = dbname
- self.declare()
- super().__init__(conn=conn, dbname=dbname, class_name=self.__class__.__name__)
-
- @property
- def is_declared(self):
- self.conn.load_headings(self.dbname)
- return self.class_name in self.conn.table_names[self.dbname]
-
- def declare(self):
- """
- Declare the table in database if it doesn't already exist.
-
- :raises: DataJointError if the table cannot be declared.
- """
- if not self.is_declared:
- self._declare()
- if not self.is_declared:
- raise DataJointError(
- 'Table could not be declared for %s' % self.class_name)
-
- def _field_to_sql(self, field):
- """
- Converts an attribute definition tuple into SQL code.
- :param field: attribute definition
- :rtype : SQL code
- """
- mysql_constants = ['CURRENT_TIMESTAMP']
- if field.nullable:
- default = 'DEFAULT NULL'
- else:
- default = 'NOT NULL'
- # if some default specified
- if field.default:
- # enclose value in quotes (even numeric), except special SQL values
- # or values already enclosed by the user
- if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']:
- default = '%s DEFAULT %s' % (default, field.default)
- else:
- default = '%s DEFAULT "%s"' % (default, field.default)
-
- # TODO: escape instead! - same goes for Matlab side implementation
- assert not any((c in r'\"' for c in field.comment)), \
- 'Illegal characters in attribute comment "%s"' % field.comment
-
- return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
- name=field.name, type=field.type, default=default, comment=field.comment)
-
- def _declare(self):
- """
- Declares the table in the data base if no table in the database matches this object.
- """
- if not self.definition:
- raise DataJointError('Table declaration is missing!')
- table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration()
- defined_name = table_info['module'] + '.' + table_info['className']
- expected_name = self.__module__.split('.')[-1] + '.' + self.class_name
- if not defined_name == expected_name:
- raise DataJointError('Table name {} does not match the declared'
- 'name {}'.format(expected_name, defined_name))
-
- # compile the CREATE TABLE statement
- # TODO: support prefix
- table_name = role_to_prefix[
- table_info['tier']] + from_camel_case(self.class_name)
- sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name)
-
- # add inherited primary key fields
- primary_key_fields = set()
- non_key_fields = set()
- for p in parents:
- for key in p.primary_key:
- field = p.heading[key]
- if field.name not in primary_key_fields:
- primary_key_fields.add(field.name)
- sql += self._field_to_sql(field)
- else:
- logger.debug('Field definition of {} in {} ignored'.format(
- field.name, p.full_class_name))
-
- # add newly defined primary key fields
- for field in (f for f in fieldDefs if f.in_key):
- if field.nullable:
- raise DataJointError('Primary key {} cannot be nullable'.format(
- field.name))
- if field.name in primary_key_fields:
- raise DataJointError('Duplicate declaration of the primary key '
- '{key}. Check to make sure that the key '
- 'is not declared already in referenced '
- 'tables'.format(key=field.name))
- primary_key_fields.add(field.name)
- sql += self._field_to_sql(field)
-
- # add secondary foreign key attributes
- for r in referenced:
- keys = (x for x in r.heading.attrs.values() if x.in_key)
- for field in keys:
- if field.name not in primary_key_fields | non_key_fields:
- non_key_fields.add(field.name)
- sql += self._field_to_sql(field)
- # add dependent attributes
- for field in (f for f in fieldDefs if not f.in_key):
- non_key_fields.add(field.name)
- sql += self._field_to_sql(field)
-
- # add primary key declaration
- assert len(primary_key_fields) > 0, 'table must have a primary key'
- keys = ', '.join(primary_key_fields)
- sql += 'PRIMARY KEY (%s),\n' % keys
-
- # add foreign key declarations
- for ref in parents + referenced:
- keys = ', '.join(ref.primary_key)
- sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \
- (keys, ref.full_table_name, keys)
-
- # add secondary index declarations
- # gather implicit indexes due to foreign keys first
- implicit_indices = []
- for fk_source in parents + referenced:
- implicit_indices.append(fk_source.primary_key)
-
- # for index in indexDefs:
- # TODO: finish this up...
-
- # close the declaration
- sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (
- sql[:-2], table_info['comment'])
-
- # make sure that the table does not alredy exist
- self.conn.load_headings(self.dbname, force=True)
- if not self.is_declared:
- # execute declaration
- logger.debug('\n\n' + sql + '\n\n')
- self.conn.query(sql)
- self.conn.load_headings(self.dbname, force=True)
-
- def _parse_declaration(self):
- """
- Parse declaration and create new SQL table accordingly.
- """
- parents = []
- referenced = []
- index_defs = []
- field_defs = []
- declaration = re.split(r'\s*\n\s*', self.definition.strip())
-
- # remove comment lines
- declaration = [x for x in declaration if not x.startswith('#')]
- ptrn = """
-        ^(?P<module>\w+)\.(?P<className>\w+)\s* # module.className
-        \(\s*(?P<tier>\w+)\s*\)\s* # (tier)
-        \#\s*(?P<comment>.*)$ # comment
- """
- p = re.compile(ptrn, re.X)
- table_info = p.match(declaration[0]).groupdict()
- if table_info['tier'] not in Role.__members__:
- raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\
- {module}.{cls}'.format(tier=table_info['tier'],
- module=table_info[
- 'module'],
- cls=table_info['className']))
- table_info['tier'] = Role[table_info['tier']] # convert into enum
-
- in_key = True # parse primary keys
- field_ptrn = """
- ^[a-z][a-z\d_]*\s* # name
- (=\s*\S+(\s+\S+)*\s*)? # optional defaults
- :\s*\w.*$ # type, comment
- """
- fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose
-
- for line in declaration[1:]:
- if line.startswith('---'):
- in_key = False # start parsing non-PK fields
- elif line.startswith('->'):
- # foreign key
- module_name, class_name = line[2:].strip().split('.')
- rel = self.get_base(module_name, class_name)
- (parents if in_key else referenced).append(rel)
- elif re.match(r'^(unique\s+)?index[^:]*$', line):
- index_defs.append(self._parse_index_def(line))
- elif fieldP.match(line):
- field_defs.append(self._parse_attr_def(line, in_key))
- else:
- raise DataJointError(
- 'Invalid table declaration line "%s"' % line)
-
- return table_info, parents, referenced, field_defs, index_defs
-
- def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key
- """
- Parse attribute definition line in the declaration and returns
- an attribute tuple.
-
- :param line: attribution line
- :param in_key:
- :returns: attribute tuple
- """
- line = line.strip()
- attr_ptrn = """
-        ^(?P<name>[a-z][a-z\d_]*)\s* # field name
-        (=\s*(?P<default>\S+(\s+\S+)*?)\s*)? # default value
-        :\s*(?P<type>\w[^\#]*[^\#\s])\s* # datatype
-        (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$ # comment
- """
-
- attrP = re.compile(attr_ptrn, re.I + re.X)
- m = attrP.match(line)
- assert m, 'Invalid field declaration "%s"' % line
- attr_info = m.groupdict()
- if not attr_info['comment']:
- attr_info['comment'] = ''
- if not attr_info['default']:
- attr_info['default'] = ''
- attr_info['nullable'] = attr_info['default'].lower() == 'null'
- assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \
- 'BIGINT attributes cannot be nullable in "%s"' % line
-
- attr_info['in_key'] = in_key
- attr_info['autoincrement'] = None
- attr_info['numeric'] = None
- attr_info['string'] = None
- attr_info['is_blob'] = None
- attr_info['computation'] = None
- attr_info['dtype'] = None
-
- return Heading.AttrTuple(**attr_info)
-
- def _parse_index_def(self, line):
- """
- Parses index definition.
-
- :param line: definition line
- :return: groupdict with index info
- """
- line = line.strip()
- index_ptrn = """
-        ^(?P<unique>UNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX
-        \((?P<attributes>[^\)]+)\)$ # (attr1, attr2)
- """
- indexP = re.compile(index_ptrn, re.I + re.X)
- m = indexP.match(line)
- assert m, 'Invalid index declaration "%s"' % line
- index_info = m.groupdict()
- attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
- index_info['attributes'] = attributes
- assert len(attributes) == len(set(attributes)), \
- 'Duplicate attributes in index declaration "%s"' % line
- return index_info
def get_base(self, module_name, class_name):
"""
@@ -330,12 +101,15 @@ def get_base(self, module_name, class_name):
:returns: the base relation
"""
mod_obj = self.get_module(module_name)
+ if not mod_obj:
+            raise DataJointError('Module named {mod_name} was not found. Please make'
+                                 ' sure that it is on the path or has been imported.'.format(mod_name=module_name))
try:
ret = getattr(mod_obj, class_name)()
- except KeyError:
- ret = self.__class__(conn=self.conn,
- dbname=self.conn.schemas[module_name],
- class_name=class_name)
+ except AttributeError:
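+            # the module defines no class for this table, so fall back to a generic
+            # Table object bound to the database that the module is mapped to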
+ ret = Table(conn=self.conn,
+ dbname=self.conn.mod_to_db[mod_obj.__name__],
+ class_name=class_name)
return ret
@classmethod
@@ -358,6 +132,8 @@ def get_module(cls, module_name):
# from IPython import embed
# embed()
mod_obj = importlib.import_module(cls.__module__)
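+        # the defining module referring to itself by its own (unqualified) name
+        # resolves directly, without looking it up as an attribute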
+ if cls.__module__.split('.')[-1] == module_name:
+ return mod_obj
attr = getattr(mod_obj, module_name, None)
if isinstance(attr, ModuleType):
return attr
diff --git a/datajoint/connection.py b/datajoint/connection.py
index 4655866f7..d94d404bd 100644
--- a/datajoint/connection.py
+++ b/datajoint/connection.py
@@ -66,7 +66,6 @@ def __init__(self, host, user, passwd, init_fun=None):
print("Connected", user + '@' + host + ':' + str(port))
self._conn.autocommit(True)
- self.mod_to_db2 = {} # database indexed by module names
self.db_to_mod = {} # modules indexed by dbnames
self.mod_to_db = {} # database names indexed by modules
self.table_names = {} # tables names indexed by [dbname][class_name]
diff --git a/datajoint/table.py b/datajoint/table.py
index 3966c40c3..b9a63f383 100644
--- a/datajoint/table.py
+++ b/datajoint/table.py
@@ -3,6 +3,10 @@
from . import DataJointError
from .relational import Relation
from .blob import pack
+from .heading import Heading
+import re
+from .settings import Role, role_to_prefix
+from .utils import from_camel_case
logger = logging.getLogger(__name__)
@@ -28,19 +32,57 @@ def __init__(self, conn=None, dbname=None, class_name=None, definition=None):
self.class_name = class_name
self.conn = conn
self.dbname = dbname
- self.conn.load_headings(self.dbname)
+ self.definition = definition
if dbname not in self.conn.db_to_mod:
# register with a fake module, enclosed in back quotes
+ # necessary for loading mechanism
self.conn.bind('`{0}`'.format(dbname), dbname)
- # TODO: delay the loading until first use (move out of __init__)
- self.conn.load_headings()
- if self.class_name not in self.conn.table_names[self.dbname]:
- if definition is None:
- raise DataJointError('The table is not declared')
- else:
- declare(conn, definition, class_name)
+
+ @property
+ def is_declared(self):
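+        # load the database's headings (if not yet loaded) and check for this table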
+ self.conn.load_headings(self.dbname)
+ return self.class_name in self.conn.table_names[self.dbname]
+
+ def declare(self):
+ """
+ Declare the table in database if it doesn't already exist.
+
+ :raises: DataJointError if the table cannot be declared.
+ """
+ if not self.is_declared:
+ self._declare()
+ if not self.is_declared:
+ raise DataJointError(
+ 'Table could not be declared for %s' % self.class_name)
+
+ @staticmethod
+ def _field_to_sql(field): #TODO move this into Attribute Tuple
+ """
+ Converts an attribute definition tuple into SQL code.
+ :param field: attribute definition
+        :return: SQL code declaring the attribute
+ """
+ mysql_constants = ['CURRENT_TIMESTAMP']
+ if field.nullable:
+ default = 'DEFAULT NULL'
+ else:
+ default = 'NOT NULL'
+ # if some default specified
+ if field.default:
+ # enclose value in quotes (even numeric), except special SQL values
+ # or values already enclosed by the user
+ if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']:
+ default = '%s DEFAULT %s' % (default, field.default)
+ else:
+ default = '%s DEFAULT "%s"' % (default, field.default)
+
+ if any((c in r'\"' for c in field.comment)):
+ raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment)
+
+ return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
+ name=field.name, type=field.type, default=default, comment=field.comment)
@property
def sql(self):
@@ -48,6 +90,7 @@ def sql(self):
@property
def heading(self):
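+        # accessing the heading triggers lazy declaration of the table if needed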
+ self.declare()
return self.conn.headings[self.dbname][self.table_name]
@property
@@ -65,12 +108,6 @@ def table_name(self):
return self.conn.table_names[self.dbname][self.class_name]
- @property
- def full_class_name(self):
- """
- :return: full class name
- """
- return '{}.{}'.format(self.__module__, self.class_name)
@property
def primary_key(self):
@@ -82,7 +119,7 @@ def primary_key(self):
def iter_insert(self, iter, **kwargs):
"""
- Inserts an entire batch of entries. Additional keyword arguments are passed it insert.
+ Inserts an entire batch of entries. Additional keyword arguments are passed to insert.
:param iter: Must be an iterator that generates a sequence of valid arguments for insert.
"""
@@ -91,7 +128,7 @@ def iter_insert(self, iter, **kwargs):
def batch_insert(self, data, **kwargs):
"""
- Inserts an entire batch of entries. Additional keyword arguments are passed it insert.
+ Inserts an entire batch of entries. Additional keyword arguments are passed to insert.
:param data: must be iterable, each row must be a valid argument for insert
"""
@@ -249,4 +286,222 @@ def _alter(self, alter_statement):
sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement)
self.conn.query(sql)
self.conn.load_headings(self.dbname, force=True)
- # TODO: place table definition sync mechanism
\ No newline at end of file
+ # TODO: place table definition sync mechanism
+
+ def _parse_index_def(self, line):
+ """
+ Parses index definition.
+
+ :param line: definition line
+ :return: groupdict with index info
+ """
+ line = line.strip()
+ index_ptrn = """
+        ^(?P<unique>UNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX
+        \((?P<attributes>[^\)]+)\)$ # (attr1, attr2)
+ """
+ indexP = re.compile(index_ptrn, re.I + re.X)
+ m = indexP.match(line)
+ assert m, 'Invalid index declaration "%s"' % line
+ index_info = m.groupdict()
+ attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
+ index_info['attributes'] = attributes
+ assert len(attributes) == len(set(attributes)), \
+ 'Duplicate attributes in index declaration "%s"' % line
+ return index_info
+
+ def _parse_attr_def(self, line, in_key=False):
+ """
+ Parse attribute definition line in the declaration and returns
+ an attribute tuple.
+
+        :param line: attribute definition line
+ :param in_key: set to True if attribute is in primary key set
+ :returns: attribute tuple
+ """
+ line = line.strip()
+ attr_ptrn = """
+        ^(?P<name>[a-z][a-z\d_]*)\s* # field name
+        (=\s*(?P<default>\S+(\s+\S+)*?)\s*)? # default value
+        :\s*(?P<type>\w[^\#]*[^\#\s])\s* # datatype
+        (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$ # comment
+ """
+
+ attrP = re.compile(attr_ptrn, re.I + re.X)
+ m = attrP.match(line)
+ assert m, 'Invalid field declaration "%s"' % line
+ attr_info = m.groupdict()
+ if not attr_info['comment']:
+ attr_info['comment'] = ''
+ if not attr_info['default']:
+ attr_info['default'] = ''
+ attr_info['nullable'] = attr_info['default'].lower() == 'null'
+ assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \
+ 'BIGINT attributes cannot be nullable in "%s"' % line
+
+ attr_info['in_key'] = in_key
+ attr_info['autoincrement'] = None
+ attr_info['numeric'] = None
+ attr_info['string'] = None
+ attr_info['is_blob'] = None
+ attr_info['computation'] = None
+ attr_info['dtype'] = None
+
+ return Heading.AttrTuple(**attr_info)
+
+ def get_base(self, module_name, class_name):
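+        # stub: Base overrides this to resolve foreign-key references in definitions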
+ return None
+
+ @property
+ def ref_name(self):
+ """
+        :return: the name used to refer to this class, in the form module.class or `database`.class
+ """
+ return '`{0}`'.format(self.dbname) + '.' + self.class_name
+
+ def _declare(self):
+ """
+ Declares the table in the database if no table in the database matches this object.
+ """
+ if not self.definition:
+ raise DataJointError('Table definition is missing!')
+ table_info, parents, referenced, field_defs, index_defs = self._parse_declaration()
+ defined_name = table_info['module'] + '.' + table_info['className']
+ if self._use_package:
+ parent = self.__module__.split('.')[-2]
+ else:
+ parent = self.__module__.split('.')[-1]
+ expected_name = parent + '.' + self.class_name
+ if not defined_name == expected_name:
+            raise DataJointError('Table name {} does not match the declared '
+                                 'name {}'.format(expected_name, defined_name))
+
+ # compile the CREATE TABLE statement
+ # TODO: support prefix
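+        # the table tier determines the name prefix (see role_to_prefix) and the
+        # class name is converted from CamelCase to snake_case for the SQL table name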
+ table_name = role_to_prefix[table_info['tier']] + from_camel_case(self.class_name)
+ sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name)
+
+ # add inherited primary key fields
+ primary_key_fields = set()
+ non_key_fields = set()
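+        # inherited primary key attributes are copied verbatim from each parent so
+        # that the child table's primary key subsumes its parents' keys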
+ for p in parents:
+ for key in p.primary_key:
+ field = p.heading[key]
+ if field.name not in primary_key_fields:
+ primary_key_fields.add(field.name)
+ sql += self._field_to_sql(field)
+ else:
+ logger.debug('Field definition of {} in {} ignored'.format(
+ field.name, p.full_class_name))
+
+ # add newly defined primary key fields
+ for field in (f for f in field_defs if f.in_key):
+ if field.nullable:
+ raise DataJointError('Primary key {} cannot be nullable'.format(
+ field.name))
+ if field.name in primary_key_fields:
+ raise DataJointError('Duplicate declaration of the primary key '
+ '{key}. Check to make sure that the key '
+ 'is not declared already in referenced '
+ 'tables'.format(key=field.name))
+ primary_key_fields.add(field.name)
+ sql += self._field_to_sql(field)
+
+ # add secondary foreign key attributes
+ for r in referenced:
+ keys = (x for x in r.heading.attrs.values() if x.in_key)
+ for field in keys:
+ if field.name not in primary_key_fields | non_key_fields:
+ non_key_fields.add(field.name)
+ sql += self._field_to_sql(field)
+
+ # add dependent attributes
+ for field in (f for f in field_defs if not f.in_key):
+ non_key_fields.add(field.name)
+ sql += self._field_to_sql(field)
+
+ # add primary key declaration
+ assert len(primary_key_fields) > 0, 'table must have a primary key'
+ keys = ', '.join(primary_key_fields)
+ sql += 'PRIMARY KEY (%s),\n' % keys
+
+ # add foreign key declarations
+ for ref in parents + referenced:
+ keys = ', '.join(ref.primary_key)
+ sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \
+ (keys, ref.full_table_name, keys)
+
+ # add secondary index declarations
+ # gather implicit indexes due to foreign keys first
+ implicit_indices = []
+ for fk_source in parents + referenced:
+ implicit_indices.append(fk_source.primary_key)
+
+ # for index in indexDefs:
+ # TODO: finish this up...
+
+ # close the declaration
+ sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (
+ sql[:-2], table_info['comment'])
+
+        # make sure that the table does not already exist
+ self.conn.load_headings(self.dbname, force=True)
+ if not self.is_declared:
+ # execute declaration
+ logger.debug('\n\n' + sql + '\n\n')
+ self.conn.query(sql)
+ self.conn.load_headings(self.dbname, force=True)
+
+
+ def _parse_declaration(self):
+ """
+ Parse declaration and create new SQL table accordingly.
+ """
+ parents = []
+ referenced = []
+ index_defs = []
+ field_defs = []
+ declaration = re.split(r'\s*\n\s*', self.definition.strip())
+
+ # remove comment lines
+ declaration = [x for x in declaration if not x.startswith('#')]
+ ptrn = """
+        ^(?P<module>\w+)\.(?P<className>\w+)\s* # module.className
+        \(\s*(?P<tier>\w+)\s*\)\s* # (tier)
+        \#\s*(?P<comment>.*)$ # comment
+ """
+ p = re.compile(ptrn, re.X)
+ table_info = p.match(declaration[0]).groupdict()
+ if table_info['tier'] not in Role.__members__:
+            raise DataJointError('InvalidTableTier: Invalid tier {tier} for table '
+                                 '{module}.{cls}'.format(tier=table_info['tier'],
+                                                         module=table_info['module'],
+                                                         cls=table_info['className']))
+ table_info['tier'] = Role[table_info['tier']] # convert into enum
+
+ in_key = True # parse primary keys
+ field_ptrn = """
+ ^[a-z][a-z\d_]*\s* # name
+ (=\s*\S+(\s+\S+)*\s*)? # optional defaults
+ :\s*\w.*$ # type, comment
+ """
+ fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose
+
+ for line in declaration[1:]:
+ if line.startswith('---'):
+ in_key = False # start parsing non-PK fields
+ elif line.startswith('->'):
+ # foreign key
+ module_name, class_name = line[2:].strip().split('.')
+ rel = self.get_base(module_name, class_name)
+ (parents if in_key else referenced).append(rel)
+ elif re.match(r'^(unique\s+)?index[^:]*$', line):
+ index_defs.append(self._parse_index_def(line))
+ elif fieldP.match(line):
+ field_defs.append(self._parse_attr_def(line, in_key))
+ else:
+ raise DataJointError(
+ 'Invalid table declaration line "%s"' % line)
+
+ return table_info, parents, referenced, field_defs, index_defs
diff --git a/tests/__init__.py b/tests/__init__.py
index e6f7b5099..2fa1bea0a 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -14,8 +14,8 @@
# Connection information for testing
CONN_INFO = {
'host': environ.get('DJ_TEST_HOST', 'localhost'),
- 'user': environ.get('DJ_TEST_USER', 'travis'),
- 'passwd': environ.get('DJ_TEST_PASSWORD', '')
+ 'user': environ.get('DJ_TEST_USER', 'datajoint'),
+ 'passwd': environ.get('DJ_TEST_PASSWORD', 'datajoint')
}
# Prefix for all databases used during testing
PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'dj')
@@ -46,6 +46,33 @@ def cleanup():
cur.execute('DROP DATABASE `{}`'.format(db))
cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on
+def setup_sample_db():
+ """
+    Helper function to set up databases with tables to be used
+    during the tests
+ """
+ cur = BASE_CONN.cursor()
+ cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test1`".format(PREFIX))
+ cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX))
+ query1 = """
+ CREATE TABLE `{prefix}_test1`.`subjects`
+ (
+ subject_id SMALLINT COMMENT 'Unique subject ID',
+ subject_name VARCHAR(255) COMMENT 'Subject name',
+ subject_email VARCHAR(255) COMMENT 'Subject email address',
+ PRIMARY KEY (subject_id)
+ )
+ """.format(prefix=PREFIX)
+ cur.execute(query1)
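+    # the `experimenter` table has no corresponding Base class and is referenced
+    # directly by table definitions in the tests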
+ query2 = """
+ CREATE TABLE `{prefix}_test2`.`experimenter`
+ (
+ experimenter_id SMALLINT COMMENT 'Unique experimenter ID',
+ experimenter_name VARCHAR(255) COMMENT 'Experimenter name',
+ PRIMARY KEY (experimenter_id)
+ )""".format(prefix=PREFIX)
+ cur.execute(query2)
+
diff --git a/tests/schemata/schema1/__init__.py b/tests/schemata/schema1/__init__.py
index 281a53bba..6032e7bd6 100644
--- a/tests/schemata/schema1/__init__.py
+++ b/tests/schemata/schema1/__init__.py
@@ -2,6 +2,4 @@
import datajoint as dj
print(__name__)
-from .test1 import *
-from .test2 import *
from .test3 import *
\ No newline at end of file
diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py
index fbc215864..d0d276e6d 100644
--- a/tests/schemata/schema1/test1.py
+++ b/tests/schemata/schema1/test1.py
@@ -1,10 +1,10 @@
"""
-Test 1 Schema definition - fully bound and has connection object
+Test 1 Schema definition
"""
__author__ = 'eywalker'
import datajoint as dj
-
+from .. import schema2
class Subjects(dj.Base):
definition = """
@@ -15,3 +15,36 @@ class Subjects(dj.Base):
real_id : varchar(40) # real-world name
species = "mouse" : enum('mouse', 'monkey', 'human') # species
"""
+
+# test reference to another table in the same schema
+class Experiments(dj.Base):
+ definition = """
+ test1.Experiments (imported) # Experiment info
+ -> test1.Subjects
+ exp_id : int # unique id for experiment
+ ---
+ exp_data_file : varchar(255) # data file
+ """
+
+# refers to a table in dj_test2 (bound to test2) but without a class
+class Session(dj.Base):
+ definition = """
+ test1.Session (manual) # Experiment sessions
+ -> test1.Subjects
+ -> test2.Experimenter
+ session_id : int # unique session id
+ ---
+ session_comment : varchar(255) # comment about the session
+ """
+
+class Match(dj.Base):
+ definition = """
+ test1.Match (manual) # Match between subject and color
+ -> schema2.Subjects
+ ---
+ dob : date # date of birth
+ """
+
+
+class Empty(dj.Base):
+ pass
diff --git a/tests/schemata/schema1/test2.py b/tests/schemata/schema1/test2.py
index f25b8d2a1..920ef9b10 100644
--- a/tests/schemata/schema1/test2.py
+++ b/tests/schemata/schema1/test2.py
@@ -1,12 +1,14 @@
"""
-Test 2 Schema definition - has conn but not bound
+Test 2 Schema definition
"""
__author__ = 'eywalker'
import datajoint as dj
+from . import test1 as alias
#from ..schema2 import test2 as test1
+# references to another schema
class Experiments(dj.Base):
definition = """
test2.Experiments (manual) # Basic subject info
@@ -15,4 +17,29 @@ class Experiments(dj.Base):
---
real_id : varchar(40) # real-world name
species = "mouse" : enum('mouse', 'monkey', 'human') # species
+ """
+
+# references to another schema
+class Conditions(dj.Base):
+ definition = """
+ test2.Conditions (manual) # Subject conditions
+ -> alias.Subjects
+ condition_name : varchar(255) # description of the condition
+ """
+
+class FoodPreference(dj.Base):
+ definition = """
+ test2.FoodPreference (manual) # Food preference of each subject
+ -> animals.Subjects
+ preferred_food : enum('banana', 'apple', 'oranges')
+ """
+
+class Session(dj.Base):
+ definition = """
+ test2.Session (manual) # Experiment sessions
+ -> test1.Subjects
+ -> test2.Experimenter
+ session_id : int # unique session id
+ ---
+ session_comment : varchar(255) # comment about the session
"""
\ No newline at end of file
diff --git a/tests/schemata/schema1/test3.py b/tests/schemata/schema1/test3.py
index 7551854c1..59e1c84fb 100644
--- a/tests/schemata/schema1/test3.py
+++ b/tests/schemata/schema1/test3.py
@@ -1,5 +1,7 @@
"""
Test 3 Schema definition - no binding, no conn
+
+To be bound at the package level
"""
__author__ = 'eywalker'
@@ -8,10 +10,12 @@
class Subjects(dj.Base):
definition = """
- test3.Subjects (manual) # Basic subject info
+ schema1.Subjects (manual) # Basic subject info
subject_id : int # unique subject id
+ dob : date # date of birth
---
real_id : varchar(40) # real-world name
species = "mouse" : enum('mouse', 'monkey', 'human') # species
- """
\ No newline at end of file
+ """
+
diff --git a/tests/schemata/schema2/__init__.py b/tests/schemata/schema2/__init__.py
index 25d6f77b7..e6b482590 100644
--- a/tests/schemata/schema2/__init__.py
+++ b/tests/schemata/schema2/__init__.py
@@ -1 +1,2 @@
__author__ = 'eywalker'
+from .test1 import *
\ No newline at end of file
diff --git a/tests/schemata/schema2/test2.py b/tests/schemata/schema2/test1.py
similarity index 80%
rename from tests/schemata/schema2/test2.py
rename to tests/schemata/schema2/test1.py
index 05ed84a82..efcd71385 100644
--- a/tests/schemata/schema2/test2.py
+++ b/tests/schemata/schema2/test1.py
@@ -8,9 +8,8 @@
class Subjects(dj.Base):
- _table_def = """
- test2.Subjects (manual) # Basic subject info
-
+ definition = """
+ schema2.Subjects (manual) # Basic subject info
pop_id : int # unique experiment id
---
real_id : varchar(40) # real-world name
diff --git a/tests/test_base.py b/tests/test_base.py
index 61ad99e8d..49a5b6810 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -2,14 +2,15 @@
Collection of test cases for base module. Tests functionalities such as
creating tables using docstring table declarations
"""
+from .schemata import schema1, schema2
from .schemata.schema1 import test1, test2, test3
__author__ = 'eywalker'
-from . import BASE_CONN, CONN_INFO, PREFIX, cleanup
+from . import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db
from datajoint.connection import Connection
-from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true
+from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises
from datajoint import DataJointError
@@ -20,43 +21,43 @@ def setup():
pass
-class TestBaseObject(object):
+class TestBaseInstantiations(object):
"""
- Test cases for Base objects
+ Test cases for instantiating Base objects
"""
+ def __init__(self):
+ self.conn = None
+
def setup(self):
"""
Create a connection object and prepare test modules
as follows:
test1 - has conn and bounded
- test2 - has conn but not bound
- test3 - no conn and not bound
"""
- cleanup() # drop all databases with PREFIX
-
self.conn = Connection(**CONN_INFO)
- test1.conn = self.conn
- self.conn.bind(test1.__name__, PREFIX+'_test1')
-
- test2.conn = self.conn
+ cleanup() # drop all databases with PREFIX
+ #test1.conn = self.conn
+ #self.conn.bind(test1.__name__, PREFIX+'_test1')
- from .schemata.schema2 import test2 as test2_2
- test2_2.conn = self.conn
- self.conn.bind(test2_2.__name__, PREFIX+'_test2_2')
+ #test2.conn = self.conn
- test3.__dict__.pop('conn', None) # make sure conn is not defined in test3
+ #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3
+ test1.__dict__.pop('conn', None)
+ schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level
def teardown(self):
cleanup()
+
def test_instantiation_from_unbound_module_should_fail(self):
"""
Attempting to instantiate a Base derivative from a module with
connection defined but not bound to a database should raise error
"""
+ test1.conn = self.conn
with assert_raises(DataJointError) as e:
- test2.Experiments()
+ test1.Subjects()
assert_regexp_matches(e.exception.args[0], r".*not bound.*")
def test_instantiation_from_module_without_conn_should_fail(self):
@@ -65,7 +66,7 @@ def test_instantiation_from_module_without_conn_should_fail(self):
`conn` object should raise error
"""
with assert_raises(DataJointError) as e:
- test3.Subjects()
+ test1.Subjects()
assert_regexp_matches(e.exception.args[0], r".*define.*conn.*")
def test_instantiation_of_base_derivatives(self):
@@ -73,20 +74,133 @@ def test_instantiation_of_base_derivatives(self):
Test instantiation and initialization of objects derived from
Base class
"""
+ test1.conn = self.conn
+ self.conn.bind(test1.__name__, PREFIX + '_test1')
s = test1.Subjects()
assert_equal(s.dbname, PREFIX + '_test1')
assert_equal(s.conn, self.conn)
assert_equal(s.definition, test1.Subjects.definition)
- def test_declaration_status(self):
- b = test1.Subjects()
- assert_true(b.is_declared)
- def test_declaration_from_doc_string(self):
- cur = BASE_CONN.cursor()
- assert_equal(cur.execute("SHOW TABLES IN `{}` LIKE 'subjects'".format(PREFIX + '_test1')), 0)
- test1.Subjects().declare()
- assert_equal(cur.execute("SHOW TABLES IN `{}` LIKE 'subjects'".format(PREFIX + '_test1')), 1)
+
+ def test_packagelevel_binding(self):
+ schema2.conn = self.conn
+ self.conn.bind(schema2.__name__, PREFIX + '_test1')
+ s = schema2.test1.Subjects()
+
+
+class TestBaseDeclaration(object):
+ """
+ Test declaration (creation of table) from
+ definition in Base under various circumstances
+ """
+
+ def setup(self):
+ cleanup()
+
+ self.conn = Connection(**CONN_INFO)
+ test1.conn = self.conn
+ self.conn.bind(test1.__name__, PREFIX + '_test1')
+ test2.conn = self.conn
+ self.conn.bind(test2.__name__, PREFIX + '_test2')
+
+ def test_is_declared(self):
+ """
+ The table should not be created immediately after instantiation,
+        but should be created when the declare method is called
+ """
+ s = test1.Subjects()
+ assert_false(s.is_declared)
+ s.declare()
+ assert_true(s.is_declared)
+
+ def test_calling_heading_should_trigger_declaration(self):
+ s = test1.Subjects()
+ assert_false(s.is_declared)
+ a = s.heading
+ assert_true(s.is_declared)
+
+ def test_foreign_key_ref_in_same_schema(self):
+ s = test1.Experiments()
+ assert_true('subject_id' in s.heading.primary_key)
+
+ def test_foreign_key_ref_in_another_schema(self):
+ s = test2.Experiments()
+ assert_true('subject_id' in s.heading.primary_key)
+
+ def test_aliased_module_name_should_resolve(self):
+ """
+ Module names that were aliased in the definition should
+ be properly resolved.
+ """
+ s = test2.Conditions()
+ assert_true('subject_id' in s.heading.primary_key)
+
+ def test_reference_to_unknown_module_in_definition_should_fail(self):
+ """
+ Module names in table definition that is not aliased via import
+ results in error
+ """
+ s = test2.FoodPreference()
+ with assert_raises(DataJointError) as e:
+ s.declare()
+
+
+class TestBaseWithExistingTables(object):
+ """
+    Test Base derivatives' behavior when some of the tables
+    already exist in the database
+ """
+ def setup(self):
+ cleanup()
+ self.conn = Connection(**CONN_INFO)
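+        # pre-create `subjects` and `experimenter` tables directly via SQL so the
+        # Base classes encounter tables that already exist in the database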
+ setup_sample_db()
+ test1.conn = self.conn
+ self.conn.bind(test1.__name__, PREFIX + '_test1')
+ test2.conn = self.conn
+ self.conn.bind(test2.__name__, PREFIX + '_test2')
+ self.conn.load_headings(force=True)
+
+ schema2.conn = self.conn
+ self.conn.bind(schema2.__name__, PREFIX + '_package')
+
+    def teardown(self):
+ schema1.__dict__.pop('conn', None)
+ cleanup()
+
+ def test_detection_of_existing_table(self):
+ """
+ The Base instance should be able to detect if the
+ corresponding table already exists in the database
+ """
+ s = test1.Subjects()
+ assert_true(s.is_declared)
+
+ def test_definition_referring_to_existing_table_without_class(self):
+ s1 = test1.Session()
+ assert_true('experimenter_id' in s1.primary_key)
+
+ s2 = test2.Session()
+ assert_true('experimenter_id' in s2.primary_key)
+
+ def test_reference_to_package_level_table(self):
+ s = test1.Match()
+ s.declare()
+ assert_true('pop_id' in s.primary_key)
+
+@raises(TypeError)
+def test_instantiation_of_base_derivative_without_definition_should_fail():
+ test1.Empty()
+
diff --git a/tests/test_connection.py b/tests/test_connection.py
index b54ad82a8..cd8af40fd 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -1,8 +1,10 @@
"""
Collection of test cases to test connection module.
"""
+from .schemata import schema1
from .schemata.schema1 import test1
+
__author__ = 'eywalker'
from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup)
from nose.tools import assert_true, assert_raises, assert_equal
@@ -81,6 +83,8 @@ class TestConnectionWithoutBindings(object):
"""
def setup(self):
self.conn = dj.Connection(**CONN_INFO)
+ test1.__dict__.pop('conn', None)
+ schema1.__dict__.pop('conn', None)
setup_sample_db()
def teardown(self):
@@ -103,7 +107,10 @@ def test_bind_to_existing_database(self):
self.check_binding(db_name, module)
def test_bind_at_package_level(self):
- pass
+ db_name = PREFIX + '_test1'
+ package = schema1.__name__
+ self.conn.bind(package, db_name)
+ self.check_binding(db_name, package)
def test_bind_to_non_existing_database(self):
"""
diff --git a/tests/test_core.py b/tests/test_utils.py
similarity index 93%
rename from tests/test_core.py
rename to tests/test_utils.py
index bfbfb0dd2..9bce1de8a 100644
--- a/tests/test_core.py
+++ b/tests/test_utils.py
@@ -3,7 +3,6 @@
"""
__author__ = 'eywalker'
-from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup)
from nose.tools import assert_true, assert_raises, assert_equal
from datajoint.utils import to_camel_case, from_camel_case
from datajoint import DataJointError