diff --git a/datajoint/base.py b/datajoint/base.py index f0f476f2c..45370c182 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -33,7 +33,6 @@ class Subjects(dj.Base): ''' """ - @abc.abstractproperty def definition(self): """ @@ -42,7 +41,27 @@ def definition(self): """ pass - def __init__(self): + @property + def full_class_name(self): + """ + :return: full class name including the entire package hierarchy + """ + return '{}.{}'.format(self.__module__, self.class_name) + + @property + def access_name(self): + """ + :return: name by which this class should be accessible as + """ + if self._use_package: + parent = self.__module__.split('.')[-2] + else: + parent = self.__module__.split('.')[-1] + return parent + '.' + self.class_name + + + + def __init__(self): #TODO: support taking in conn obj self.class_name = self.__class__.__name__ module = self.__module__ mod_obj = importlib.import_module(module) @@ -61,6 +80,7 @@ def __init__(self): self.conn = conn try: if self._use_package: + # the database is bound to the package pkg_name = '.'.join(module.split('.')[:-1]) dbname = self.conn.mod_to_db[pkg_name] else: @@ -69,256 +89,7 @@ def __init__(self): raise DataJointError( 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__)) self.dbname = dbname - self.declare() - super().__init__(conn=conn, dbname=dbname, class_name=self.__class__.__name__) - - @property - def is_declared(self): - self.conn.load_headings(self.dbname) - return self.class_name in self.conn.table_names[self.dbname] - - def declare(self): - """ - Declare the table in database if it doesn't already exist. - - :raises: DataJointError if the table cannot be declared. - """ - if not self.is_declared: - self._declare() - if not self.is_declared: - raise DataJointError( - 'Table could not be declared for %s' % self.class_name) - - def _field_to_sql(self, field): - """ - Converts an attribute definition tuple into SQL code. - :param field: attribute definition - :rtype : SQL code - """ - mysql_constants = ['CURRENT_TIMESTAMP'] - if field.nullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes (even numeric), except special SQL values - # or values already enclosed by the user - if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']: - default = '%s DEFAULT %s' % (default, field.default) - else: - default = '%s DEFAULT "%s"' % (default, field.default) - - # TODO: escape instead! - same goes for Matlab side implementation - assert not any((c in r'\"' for c in field.comment)), \ - 'Illegal characters in attribute comment "%s"' % field.comment - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) - - def _declare(self): - """ - Declares the table in the data base if no table in the database matches this object. - """ - if not self.definition: - raise DataJointError('Table declaration is missing!') - table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration() - defined_name = table_info['module'] + '.' + table_info['className'] - expected_name = self.__module__.split('.')[-1] + '.' 
+ self.class_name - if not defined_name == expected_name: - raise DataJointError('Table name {} does not match the declared' - 'name {}'.format(expected_name, defined_name)) - - # compile the CREATE TABLE statement - # TODO: support prefix - table_name = role_to_prefix[ - table_info['tier']] + from_camel_case(self.class_name) - sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) - - # add inherited primary key fields - primary_key_fields = set() - non_key_fields = set() - for p in parents: - for key in p.primary_key: - field = p.heading[key] - if field.name not in primary_key_fields: - primary_key_fields.add(field.name) - sql += self._field_to_sql(field) - else: - logger.debug('Field definition of {} in {} ignored'.format( - field.name, p.full_class_name)) - - # add newly defined primary key fields - for field in (f for f in fieldDefs if f.in_key): - if field.nullable: - raise DataJointError('Primary key {} cannot be nullable'.format( - field.name)) - if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary key ' - '{key}. Check to make sure that the key ' - 'is not declared already in referenced ' - 'tables'.format(key=field.name)) - primary_key_fields.add(field.name) - sql += self._field_to_sql(field) - - # add secondary foreign key attributes - for r in referenced: - keys = (x for x in r.heading.attrs.values() if x.in_key) - for field in keys: - if field.name not in primary_key_fields | non_key_fields: - non_key_fields.add(field.name) - sql += self._field_to_sql(field) - # add dependent attributes - for field in (f for f in fieldDefs if not f.in_key): - non_key_fields.add(field.name) - sql += self._field_to_sql(field) - - # add primary key declaration - assert len(primary_key_fields) > 0, 'table must have a primary key' - keys = ', '.join(primary_key_fields) - sql += 'PRIMARY KEY (%s),\n' % keys - - # add foreign key declarations - for ref in parents + referenced: - keys = ', '.join(ref.primary_key) - sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ - (keys, ref.full_table_name, keys) - - # add secondary index declarations - # gather implicit indexes due to foreign keys first - implicit_indices = [] - for fk_source in parents + referenced: - implicit_indices.append(fk_source.primary_key) - - # for index in indexDefs: - # TODO: finish this up... - - # close the declaration - sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( - sql[:-2], table_info['comment']) - - # make sure that the table does not alredy exist - self.conn.load_headings(self.dbname, force=True) - if not self.is_declared: - # execute declaration - logger.debug('\n\n' + sql + '\n\n') - self.conn.query(sql) - self.conn.load_headings(self.dbname, force=True) - - def _parse_declaration(self): - """ - Parse declaration and create new SQL table accordingly. 
- """ - parents = [] - referenced = [] - index_defs = [] - field_defs = [] - declaration = re.split(r'\s*\n\s*', self.definition.strip()) - - # remove comment lines - declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ - ^(?P\w+)\.(?P\w+)\s* # module.className - \(\s*(?P\w+)\s*\)\s* # (tier) - \#\s*(?P.*)$ # comment - """ - p = re.compile(ptrn, re.X) - table_info = p.match(declaration[0]).groupdict() - if table_info['tier'] not in Role.__members__: - raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ - {module}.{cls}'.format(tier=table_info['tier'], - module=table_info[ - 'module'], - cls=table_info['className'])) - table_info['tier'] = Role[table_info['tier']] # convert into enum - - in_key = True # parse primary keys - field_ptrn = """ - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """ - fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose - - for line in declaration[1:]: - if line.startswith('---'): - in_key = False # start parsing non-PK fields - elif line.startswith('->'): - # foreign key - module_name, class_name = line[2:].strip().split('.') - rel = self.get_base(module_name, class_name) - (parents if in_key else referenced).append(rel) - elif re.match(r'^(unique\s+)?index[^:]*$', line): - index_defs.append(self._parse_index_def(line)) - elif fieldP.match(line): - field_defs.append(self._parse_attr_def(line, in_key)) - else: - raise DataJointError( - 'Invalid table declaration line "%s"' % line) - - return table_info, parents, referenced, field_defs, index_defs - - def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. - - :param line: attribution line - :param in_key: - :returns: attribute tuple - """ - line = line.strip() - attr_ptrn = """ - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """ - - attrP = re.compile(attr_ptrn, re.I + re.X) - m = attrP.match(line) - assert m, 'Invalid field declaration "%s"' % line - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['nullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - attr_info['in_key'] = in_key - attr_info['autoincrement'] = None - attr_info['numeric'] = None - attr_info['string'] = None - attr_info['is_blob'] = None - attr_info['computation'] = None - attr_info['dtype'] = None - - return Heading.AttrTuple(**attr_info) - - def _parse_index_def(self, line): - """ - Parses index definition. 
- - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_ptrn = """ - ^(?P<unique>UNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P<attributes>[^\)]+)\)$ # (attr1, attr2) - """ - indexP = re.compile(index_ptrn, re.I + re.X) - m = indexP.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info def get_base(self, module_name, class_name): """ @@ -330,12 +101,15 @@ def get_base(self, module_name, class_name): :returns: the base relation """ mod_obj = self.get_module(module_name) + if not mod_obj: + raise DataJointError('Module named {mod_name} was not found. Please make' + ' sure that it is in the path or that it has been imported.'.format(mod_name=module_name)) try: ret = getattr(mod_obj, class_name)() - except KeyError: - ret = self.__class__(conn=self.conn, - dbname=self.conn.schemas[module_name], - class_name=class_name) + except AttributeError: + ret = Table(conn=self.conn, + dbname=self.conn.mod_to_db[mod_obj.__name__], + class_name=class_name) return ret @classmethod @@ -358,6 +132,8 @@ def get_module(cls, module_name): # from IPython import embed # embed() mod_obj = importlib.import_module(cls.__module__) + if cls.__module__.split('.')[-1] == module_name: + return mod_obj attr = getattr(mod_obj, module_name, None) if isinstance(attr, ModuleType): return attr diff --git a/datajoint/connection.py b/datajoint/connection.py index 4655866f7..d94d404bd 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -66,7 +66,6 @@ def __init__(self, host, user, passwd, init_fun=None): print("Connected", user + '@' + host + ':' + str(port)) self._conn.autocommit(True) - self.mod_to_db2 = {} # database indexed by module names self.db_to_mod = {} # modules indexed by dbnames self.mod_to_db = {} # database names indexed by modules self.table_names = {} # tables names indexed by [dbname][class_name] diff --git a/datajoint/table.py b/datajoint/table.py index 3966c40c3..b9a63f383 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -3,6 +3,10 @@ from . import DataJointError from .relational import Relation from .blob import pack +from .heading import Heading +import re +from .settings import Role, role_to_prefix +from .utils import from_camel_case logger = logging.getLogger(__name__) @@ -28,19 +32,57 @@ def __init__(self, conn=None, dbname=None, class_name=None, definition=None): self.class_name = class_name self.conn = conn self.dbname = dbname - self.conn.load_headings(self.dbname) + self.definition = definition if dbname not in self.conn.db_to_mod: # register with a fake module, enclosed in back quotes + # necessary for the loading mechanism self.conn.bind('`{0}`'.format(dbname), dbname) - # TODO: delay the loading until first use (move out of __init__) - self.conn.load_headings() - if self.class_name not in self.conn.table_names[self.dbname]: - if definition is None: - raise DataJointError('The table is not declared') - else: - declare(conn, definition, class_name) + + @property + def is_declared(self): + self.conn.load_headings(self.dbname) + return self.class_name in self.conn.table_names[self.dbname] + + def declare(self): + """ + Declare the table in the database if it doesn't already exist. + + :raises: DataJointError if the table cannot be declared.
+ """ + if not self.is_declared: + self._declare() + if not self.is_declared: + raise DataJointError( + 'Table could not be declared for %s' % self.class_name) + + @staticmethod + def _field_to_sql(field): #TODO move this into Attribute Tuple + """ + Converts an attribute definition tuple into SQL code. + :param field: attribute definition + :rtype : SQL code + """ + mysql_constants = ['CURRENT_TIMESTAMP'] + if field.nullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + # if some default specified + if field.default: + # enclose value in quotes (even numeric), except special SQL values + # or values already enclosed by the user + if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']: + default = '%s DEFAULT %s' % (default, field.default) + else: + default = '%s DEFAULT "%s"' % (default, field.default) + + if any((c in r'\"' for c in field.comment)): + raise DataJointError('Illegal characters in attribute comment "%s"' % field.comment) + + return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( + name=field.name, type=field.type, default=default, comment=field.comment) @property def sql(self): @@ -48,6 +90,7 @@ def sql(self): @property def heading(self): + self.declare() return self.conn.headings[self.dbname][self.table_name] @property @@ -65,12 +108,6 @@ def table_name(self): return self.conn.table_names[self.dbname][self.class_name] - @property - def full_class_name(self): - """ - :return: full class name - """ - return '{}.{}'.format(self.__module__, self.class_name) @property def primary_key(self): @@ -82,7 +119,7 @@ def primary_key(self): def iter_insert(self, iter, **kwargs): """ - Inserts an entire batch of entries. Additional keyword arguments are passed it insert. + Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param iter: Must be an iterator that generates a sequence of valid arguments for insert. """ @@ -91,7 +128,7 @@ def iter_insert(self, iter, **kwargs): def batch_insert(self, data, **kwargs): """ - Inserts an entire batch of entries. Additional keyword arguments are passed it insert. + Inserts an entire batch of entries. Additional keyword arguments are passed to insert. :param data: must be iterable, each row must be a valid argument for insert """ @@ -249,4 +286,222 @@ def _alter(self, alter_statement): sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) self.conn.query(sql) self.conn.load_headings(self.dbname, force=True) - # TODO: place table definition sync mechanism \ No newline at end of file + # TODO: place table definition sync mechanism + + def _parse_index_def(self, line): + """ + Parses index definition. + + :param line: definition line + :return: groupdict with index info + """ + line = line.strip() + index_ptrn = """ + ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX + \((?P[^\)]+)\)$ # (attr1, attr2) + """ + indexP = re.compile(index_ptrn, re.I + re.X) + m = indexP.match(line) + assert m, 'Invalid index declaration "%s"' % line + index_info = m.groupdict() + attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) + index_info['attributes'] = attributes + assert len(attributes) == len(set(attributes)), \ + 'Duplicate attributes in index declaration "%s"' % line + return index_info + + def _parse_attr_def(self, line, in_key=False): + """ + Parse attribute definition line in the declaration and returns + an attribute tuple. 
+ + :param line: attribute definition line + :param in_key: set to True if the attribute is part of the primary key + :returns: attribute tuple + """ + line = line.strip() + attr_ptrn = """ + ^(?P<name>[a-z][a-z\d_]*)\s* # field name + (=\s*(?P<default>\S+(\s+\S+)*?)\s*)? # default value + :\s*(?P<type>\w[^\#]*[^\#\s])\s* # datatype + (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$ # comment + """ + + attrP = re.compile(attr_ptrn, re.I + re.X) + m = attrP.match(line) + assert m, 'Invalid field declaration "%s"' % line + attr_info = m.groupdict() + if not attr_info['comment']: + attr_info['comment'] = '' + if not attr_info['default']: + attr_info['default'] = '' + attr_info['nullable'] = attr_info['default'].lower() == 'null' + assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ + 'BIGINT attributes cannot be nullable in "%s"' % line + + attr_info['in_key'] = in_key + attr_info['autoincrement'] = None + attr_info['numeric'] = None + attr_info['string'] = None + attr_info['is_blob'] = None + attr_info['computation'] = None + attr_info['dtype'] = None + + return Heading.AttrTuple(**attr_info) + + def get_base(self, module_name, class_name): + # overridden in Base, which resolves module_name.class_name to a base relation + return None + + @property + def ref_name(self): + """ + :return: the name used to refer to this class, in the form module.class or `database`.class + """ + return '`{0}`'.format(self.dbname) + '.' + self.class_name + + def _declare(self): + """ + Declares the table in the database if no table in the database matches this object. + """ + if not self.definition: + raise DataJointError('Table definition is missing!') + table_info, parents, referenced, field_defs, index_defs = self._parse_declaration() + defined_name = table_info['module'] + '.' + table_info['className'] + if self._use_package: + parent = self.__module__.split('.')[-2] + else: + parent = self.__module__.split('.')[-1] + expected_name = parent + '.' + self.class_name + if not defined_name == expected_name: + raise DataJointError('Table name {} does not match the declared ' + 'name {}'.format(expected_name, defined_name)) + + # compile the CREATE TABLE statement + # TODO: support prefix + table_name = role_to_prefix[table_info['tier']] + from_camel_case(self.class_name) + sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) + + # add inherited primary key fields + primary_key_fields = set() + non_key_fields = set() + for p in parents: + for key in p.primary_key: + field = p.heading[key] + if field.name not in primary_key_fields: + primary_key_fields.add(field.name) + sql += self._field_to_sql(field) + else: + logger.debug('Field definition of {} in {} ignored'.format( + field.name, p.full_class_name)) + + # add newly defined primary key fields + for field in (f for f in field_defs if f.in_key): + if field.nullable: + raise DataJointError('Primary key {} cannot be nullable'.format( + field.name)) + if field.name in primary_key_fields: + raise DataJointError('Duplicate declaration of the primary key ' + '{key}. 
Check to make sure that the key ' + 'is not declared already in referenced ' + 'tables'.format(key=field.name)) + primary_key_fields.add(field.name) + sql += self._field_to_sql(field) + + # add secondary foreign key attributes + for r in referenced: + keys = (x for x in r.heading.attrs.values() if x.in_key) + for field in keys: + if field.name not in primary_key_fields | non_key_fields: + non_key_fields.add(field.name) + sql += self._field_to_sql(field) + + # add dependent attributes + for field in (f for f in field_defs if not f.in_key): + non_key_fields.add(field.name) + sql += self._field_to_sql(field) + + # add primary key declaration + assert len(primary_key_fields) > 0, 'table must have a primary key' + keys = ', '.join(primary_key_fields) + sql += 'PRIMARY KEY (%s),\n' % keys + + # add foreign key declarations + for ref in parents + referenced: + keys = ', '.join(ref.primary_key) + sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ + (keys, ref.full_table_name, keys) + + # add secondary index declarations + # gather implicit indexes due to foreign keys first + implicit_indices = [] + for fk_source in parents + referenced: + implicit_indices.append(fk_source.primary_key) + + # for index in index_defs: + # TODO: finish this up... + + # close the declaration + sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( + sql[:-2], table_info['comment']) + + # make sure that the table does not already exist + self.conn.load_headings(self.dbname, force=True) + if not self.is_declared: + # execute declaration + logger.debug('\n\n' + sql + '\n\n') + self.conn.query(sql) + self.conn.load_headings(self.dbname, force=True) + + + def _parse_declaration(self): + """ + Parse the declaration and return its components: table info, parent relations, referenced relations, field definitions, and index definitions. + """ + parents = [] + referenced = [] + index_defs = [] + field_defs = [] + declaration = re.split(r'\s*\n\s*', self.definition.strip()) + + # remove comment lines + declaration = [x for x in declaration if not x.startswith('#')] + ptrn = """ + ^(?P<module>\w+)\.(?P<className>\w+)\s* # module.className + \(\s*(?P<tier>\w+)\s*\)\s* # (tier) + \#\s*(?P<comment>.*)$ # comment + """ + p = re.compile(ptrn, re.X) + table_info = p.match(declaration[0]).groupdict() + if table_info['tier'] not in Role.__members__: + raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ + {module}.{cls}'.format(tier=table_info['tier'], + module=table_info[ + 'module'], + cls=table_info['className'])) + table_info['tier'] = Role[table_info['tier']] # convert into enum + + in_key = True # parse primary keys + field_ptrn = """ + ^[a-z][a-z\d_]*\s* # name + (=\s*\S+(\s+\S+)*\s*)?
# optional defaults + :\s*\w.*$ # type, comment + """ + fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose + + for line in declaration[1:]: + if line.startswith('---'): + in_key = False # start parsing non-PK fields + elif line.startswith('->'): + # foreign key + module_name, class_name = line[2:].strip().split('.') + rel = self.get_base(module_name, class_name) + (parents if in_key else referenced).append(rel) + elif re.match(r'^(unique\s+)?index[^:]*$', line): + index_defs.append(self._parse_index_def(line)) + elif fieldP.match(line): + field_defs.append(self._parse_attr_def(line, in_key)) + else: + raise DataJointError( + 'Invalid table declaration line "%s"' % line) + + return table_info, parents, referenced, field_defs, index_defs diff --git a/tests/__init__.py b/tests/__init__.py index e6f7b5099..2fa1bea0a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -14,8 +14,8 @@ # Connection information for testing CONN_INFO = { 'host': environ.get('DJ_TEST_HOST', 'localhost'), - 'user': environ.get('DJ_TEST_USER', 'travis'), - 'passwd': environ.get('DJ_TEST_PASSWORD', '') + 'user': environ.get('DJ_TEST_USER', 'datajoint'), + 'passwd': environ.get('DJ_TEST_PASSWORD', 'datajoint') } # Prefix for all databases used during testing PREFIX = environ.get('DJ_TEST_DB_PREFIX', 'dj') @@ -46,6 +46,33 @@ def cleanup(): cur.execute('DROP DATABASE `{}`'.format(db)) cur.execute('SET FOREIGN_KEY_CHECKS=1') # set foreign key check back on +def setup_sample_db(): + """ + Helper method to setup databases with tables to be used + during the test + """ + cur = BASE_CONN.cursor() + cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test1`".format(PREFIX)) + cur.execute("CREATE DATABASE IF NOT EXISTS `{}_test2`".format(PREFIX)) + query1 = """ + CREATE TABLE `{prefix}_test1`.`subjects` + ( + subject_id SMALLINT COMMENT 'Unique subject ID', + subject_name VARCHAR(255) COMMENT 'Subject name', + subject_email VARCHAR(255) COMMENT 'Subject email address', + PRIMARY KEY (subject_id) + ) + """.format(prefix=PREFIX) + cur.execute(query1) + query2 = """ + CREATE TABLE `{prefix}_test2`.`experimenter` + ( + experimenter_id SMALLINT COMMENT 'Unique experimenter ID', + experimenter_name VARCHAR(255) COMMENT 'Experimenter name', + PRIMARY KEY (experimenter_id) + )""".format(prefix=PREFIX) + cur.execute(query2) + diff --git a/tests/schemata/schema1/__init__.py b/tests/schemata/schema1/__init__.py index 281a53bba..6032e7bd6 100644 --- a/tests/schemata/schema1/__init__.py +++ b/tests/schemata/schema1/__init__.py @@ -2,6 +2,4 @@ import datajoint as dj print(__name__) -from .test1 import * -from .test2 import * from .test3 import * \ No newline at end of file diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py index fbc215864..d0d276e6d 100644 --- a/tests/schemata/schema1/test1.py +++ b/tests/schemata/schema1/test1.py @@ -1,10 +1,10 @@ """ -Test 1 Schema definition - fully bound and has connection object +Test 1 Schema definition """ __author__ = 'eywalker' import datajoint as dj - +from .. 
import schema2 class Subjects(dj.Base): definition = """ @@ -15,3 +15,36 @@ class Subjects(dj.Base): real_id : varchar(40) # real-world name species = "mouse" : enum('mouse', 'monkey', 'human') # species """ + +# test reference to another table in same schema +class Experiments(dj.Base): + definition = """ + test1.Experiments (imported) # Experiment info + -> test1.Subjects + exp_id : int # unique id for experiment + --- + exp_data_file : varchar(255) # data file + """ + +# refers to a table in dj_test2 (bound to test2) but without a class +class Session(dj.Base): + definition = """ + test1.Session (manual) # Experiment sessions + -> test1.Subjects + -> test2.Experimenter + session_id : int # unique session id + --- + session_comment : varchar(255) # comment about the session + """ + +class Match(dj.Base): + definition = """ + test1.Match (manual) # Match between subject and color + -> schema2.Subjects + --- + dob : date # date of birth + """ + + +class Empty(dj.Base): + pass diff --git a/tests/schemata/schema1/test2.py b/tests/schemata/schema1/test2.py index f25b8d2a1..920ef9b10 100644 --- a/tests/schemata/schema1/test2.py +++ b/tests/schemata/schema1/test2.py @@ -1,12 +1,14 @@ """ -Test 2 Schema definition - has conn but not bound +Test 2 Schema definition """ __author__ = 'eywalker' import datajoint as dj +from . import test1 as alias #from ..schema2 import test2 as test1 +# references to another schema class Experiments(dj.Base): definition = """ test2.Experiments (manual) # Basic subject info @@ -15,4 +17,29 @@ class Experiments(dj.Base): --- real_id : varchar(40) # real-world name species = "mouse" : enum('mouse', 'monkey', 'human') # species + """ + +# references to another schema +class Conditions(dj.Base): + definition = """ + test2.Conditions (manual) # Subject conditions + -> alias.Subjects + condition_name : varchar(255) # description of the condition + """ + +class FoodPreference(dj.Base): + definition = """ + test2.FoodPreference (manual) # Food preference of each subject + -> animals.Subjects + preferred_food : enum('banana', 'apple', 'oranges') + """ + +class Session(dj.Base): + definition = """ + test2.Session (manual) # Experiment sessions + -> test1.Subjects + -> test2.Experimenter + session_id : int # unique session id + --- + session_comment : varchar(255) # comment about the session """ \ No newline at end of file diff --git a/tests/schemata/schema1/test3.py b/tests/schemata/schema1/test3.py index 7551854c1..59e1c84fb 100644 --- a/tests/schemata/schema1/test3.py +++ b/tests/schemata/schema1/test3.py @@ -1,5 +1,7 @@ """ Test 3 Schema definition - no binding, no conn + +To be bound at the package level """ __author__ = 'eywalker' @@ -8,10 +10,12 @@ class Subjects(dj.Base): definition = """ - test3.Subjects (manual) # Basic subject info + schema1.Subjects (manual) # Basic subject info subject_id : int # unique subject id + dob : date # date of birth --- real_id : varchar(40) # real-world name species = "mouse" : enum('mouse', 'monkey', 'human') # species - """ \ No newline at end of file + """ + diff --git a/tests/schemata/schema2/__init__.py b/tests/schemata/schema2/__init__.py index 25d6f77b7..e6b482590 100644 --- a/tests/schemata/schema2/__init__.py +++ b/tests/schemata/schema2/__init__.py @@ -1 +1,2 @@ __author__ = 'eywalker' +from .test1 import * \ No newline at end of file diff --git a/tests/schemata/schema2/test2.py b/tests/schemata/schema2/test1.py similarity index 80% rename from tests/schemata/schema2/test2.py rename to tests/schemata/schema2/test1.py index 
05ed84a82..efcd71385 100644 --- a/tests/schemata/schema2/test2.py +++ b/tests/schemata/schema2/test1.py @@ -8,9 +8,8 @@ class Subjects(dj.Base): - _table_def = """ - test2.Subjects (manual) # Basic subject info - + definition = """ + schema2.Subjects (manual) # Basic subject info pop_id : int # unique experiment id --- real_id : varchar(40) # real-world name diff --git a/tests/test_base.py b/tests/test_base.py index 61ad99e8d..49a5b6810 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -2,14 +2,15 @@ Collection of test cases for base module. Tests functionalities such as creating tables using docstring table declarations """ +from .schemata import schema1, schema2 from .schemata.schema1 import test1, test2, test3 __author__ = 'eywalker' -from . import BASE_CONN, CONN_INFO, PREFIX, cleanup +from . import BASE_CONN, CONN_INFO, PREFIX, cleanup, setup_sample_db from datajoint.connection import Connection -from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true +from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true, raises from datajoint import DataJointError @@ -20,43 +21,43 @@ def setup(): pass -class TestBaseObject(object): +class TestBaseInstantiations(object): """ - Test cases for Base objects + Test cases for instantiating Base objects """ + def __init__(self): + self.conn = None + def setup(self): """ Create a connection object and prepare test modules as follows: test1 - has conn and bounded - test2 - has conn but not bound - test3 - no conn and not bound """ - cleanup() # drop all databases with PREFIX - self.conn = Connection(**CONN_INFO) - test1.conn = self.conn - self.conn.bind(test1.__name__, PREFIX+'_test1') - - test2.conn = self.conn + cleanup() # drop all databases with PREFIX + #test1.conn = self.conn + #self.conn.bind(test1.__name__, PREFIX+'_test1') - from .schemata.schema2 import test2 as test2_2 - test2_2.conn = self.conn - self.conn.bind(test2_2.__name__, PREFIX+'_test2_2') + #test2.conn = self.conn - test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 + #test3.__dict__.pop('conn', None) # make sure conn is not defined in test3 + test1.__dict__.pop('conn', None) + schema1.__dict__.pop('conn', None) # make sure conn is not defined at schema level def teardown(self): cleanup() + def test_instantiation_from_unbound_module_should_fail(self): """ Attempting to instantiate a Base derivative from a module with connection defined but not bound to a database should raise error """ + test1.conn = self.conn with assert_raises(DataJointError) as e: - test2.Experiments() + test1.Subjects() assert_regexp_matches(e.exception.args[0], r".*not bound.*") def test_instantiation_from_module_without_conn_should_fail(self): @@ -65,7 +66,7 @@ def test_instantiation_from_module_without_conn_should_fail(self): `conn` object should raise error """ with assert_raises(DataJointError) as e: - test3.Subjects() + test1.Subjects() assert_regexp_matches(e.exception.args[0], r".*define.*conn.*") def test_instantiation_of_base_derivatives(self): @@ -73,20 +74,133 @@ def test_instantiation_of_base_derivatives(self): Test instantiation and initialization of objects derived from Base class """ + test1.conn = self.conn + self.conn.bind(test1.__name__, PREFIX + '_test1') s = test1.Subjects() assert_equal(s.dbname, PREFIX + '_test1') assert_equal(s.conn, self.conn) assert_equal(s.definition, test1.Subjects.definition) - def test_declaration_status(self): - b = test1.Subjects() - 
assert_true(b.is_declared) - def test_declaration_from_doc_string(self): - cur = BASE_CONN.cursor() - assert_equal(cur.execute("SHOW TABLES IN `{}` LIKE 'subjects'".format(PREFIX + '_test1')), 0) - test1.Subjects().declare() - assert_equal(cur.execute("SHOW TABLES IN `{}` LIKE 'subjects'".format(PREFIX + '_test1')), 1) + + def test_packagelevel_binding(self): + schema2.conn = self.conn + self.conn.bind(schema2.__name__, PREFIX + '_test1') + s = schema2.test1.Subjects() + + +class TestBaseDeclaration(object): + """ + Test declaration (creation of table) from + definition in Base under various circumstances + """ + + def setup(self): + cleanup() + + self.conn = Connection(**CONN_INFO) + test1.conn = self.conn + self.conn.bind(test1.__name__, PREFIX + '_test1') + test2.conn = self.conn + self.conn.bind(test2.__name__, PREFIX + '_test2') + + def test_is_declared(self): + """ + The table should not be created immediately after instantiation, + but should be created when the declare method is called. + """ + s = test1.Subjects() + assert_false(s.is_declared) + s.declare() + assert_true(s.is_declared) + + def test_calling_heading_should_trigger_declaration(self): + s = test1.Subjects() + assert_false(s.is_declared) + a = s.heading + assert_true(s.is_declared) + + def test_foreign_key_ref_in_same_schema(self): + s = test1.Experiments() + assert_true('subject_id' in s.heading.primary_key) + + def test_foreign_key_ref_in_another_schema(self): + s = test2.Experiments() + assert_true('subject_id' in s.heading.primary_key) + + def test_aliased_module_name_should_resolve(self): + """ + Module names that were aliased in the definition should + be properly resolved. + """ + s = test2.Conditions() + assert_true('subject_id' in s.heading.primary_key) + + def test_reference_to_unknown_module_in_definition_should_fail(self): + """ + Module names in a table definition that are not aliased via an import + should result in an error + """ + s = test2.FoodPreference() + with assert_raises(DataJointError) as e: + s.declare() + + +class TestBaseWithExistingTables(object): + """ + Test the behavior of Base derivatives when some of the tables + already exist in the database + """ + def setup(self): + cleanup() + self.conn = Connection(**CONN_INFO) + setup_sample_db() + test1.conn = self.conn + self.conn.bind(test1.__name__, PREFIX + '_test1') + test2.conn = self.conn + self.conn.bind(test2.__name__, PREFIX + '_test2') + self.conn.load_headings(force=True) + + schema2.conn = self.conn + self.conn.bind(schema2.__name__, PREFIX + '_package') + + def teardown(self): + schema1.__dict__.pop('conn', None) + cleanup() + + def test_detection_of_existing_table(self): + """ + The Base instance should be able to detect if the + corresponding table already exists in the database + """ + s = test1.Subjects() + assert_true(s.is_declared) + + def test_definition_referring_to_existing_table_without_class(self): + s1 = test1.Session() + assert_true('experimenter_id' in s1.primary_key) + + s2 = test2.Session() + assert_true('experimenter_id' in s2.primary_key) + + def test_reference_to_package_level_table(self): + s = test1.Match() + s.declare() + assert_true('pop_id' in s.primary_key) + +@raises(TypeError) +def test_instantiation_of_base_derivative_without_definition_should_fail(): + test1.Empty() + + + + + + + + + diff --git a/tests/test_connection.py b/tests/test_connection.py index b54ad82a8..cd8af40fd 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,8 +1,10 @@ """ Collection of test cases to test connection module. 
""" +from .schemata import schema1 from .schemata.schema1 import test1 + __author__ = 'eywalker' from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) from nose.tools import assert_true, assert_raises, assert_equal @@ -81,6 +83,8 @@ class TestConnectionWithoutBindings(object): """ def setup(self): self.conn = dj.Connection(**CONN_INFO) + test1.__dict__.pop('conn', None) + schema1.__dict__.pop('conn', None) setup_sample_db() def teardown(self): @@ -103,7 +107,10 @@ def test_bind_to_existing_database(self): self.check_binding(db_name, module) def test_bind_at_package_level(self): - pass + db_name = PREFIX + '_test1' + package = schema1.__name__ + self.conn.bind(package, db_name) + self.check_binding(db_name, package) def test_bind_to_non_existing_database(self): """ diff --git a/tests/test_core.py b/tests/test_utils.py similarity index 93% rename from tests/test_core.py rename to tests/test_utils.py index bfbfb0dd2..9bce1de8a 100644 --- a/tests/test_core.py +++ b/tests/test_utils.py @@ -3,7 +3,6 @@ """ __author__ = 'eywalker' -from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) from nose.tools import assert_true, assert_raises, assert_equal from datajoint.utils import to_camel_case, from_camel_case from datajoint import DataJointError