diff --git a/.gitignore b/.gitignore index ef4ea964a..ecd81049f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .ipynb_checkpoints/ +*.json */.*.swp */.*.swo */*.pyc diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 8e32ef43d..fa0e9300b 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -2,9 +2,11 @@ from . import DataJointError import pprint import abc +import logging #noinspection PyExceptionInherit,PyCallingNonCallable +logger = logging.getLogger(__name__) class AutoPopulate(metaclass=abc.ABCMeta): """ @@ -29,6 +31,10 @@ def make_tuples(self, key): """ pass + @property + def target(self): + return self + def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): """ rel.populate() will call rel.make_tuples(key) for every primary key in self.pop_rel @@ -38,9 +44,9 @@ def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): raise DataJointError('') self.conn.cancel_transaction() - unpopulated = self.pop_rel - self + unpopulated = self.pop_rel - self.target if not unpopulated.count: - print('Nothing to populate', flush=True) # TODO: use logging? + logger.info('Nothing to populate', flush=True) if catch_errors: error_keys, errors = [], [] for key in unpopulated.fetch(): diff --git a/datajoint/base.py b/datajoint/base.py index 45370c182..eed8b2729 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -1,14 +1,10 @@ import importlib import abc from types import ModuleType -from enum import Enum from . import DataJointError from .table import Table import logging -import re -from .settings import Role, role_to_prefix -from .utils import from_camel_case -from .heading import Heading + logger = logging.getLogger(__name__) @@ -49,7 +45,7 @@ def full_class_name(self): return '{}.{}'.format(self.__module__, self.class_name) @property - def access_name(self): + def ref_name(self): """ :return: name by which this class should be accessible as """ @@ -60,12 +56,12 @@ def access_name(self): return parent + '.' + self.class_name - def __init__(self): #TODO: support taking in conn obj - self.class_name = self.__class__.__name__ - module = self.__module__ - mod_obj = importlib.import_module(module) + class_name = self.__class__.__name__ + module_name = self.__module__ + mod_obj = importlib.import_module(module_name) self._use_package = False + # first, find the conn object try: conn = mod_obj.conn except AttributeError: @@ -76,19 +72,20 @@ def __init__(self): #TODO: support taking in conn obj self._use_package = True except AttributeError: raise DataJointError( - "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__)) - self.conn = conn + "Please define object 'conn' in '{}' or in its containing package.".format(module_name)) + # now use the conn object to determine the dbname this belongs to try: if self._use_package: # the database is bound to the package - pkg_name = '.'.join(module.split('.')[:-1]) - dbname = self.conn.mod_to_db[pkg_name] + pkg_name = '.'.join(module_name.split('.')[:-1]) + dbname = conn.mod_to_db[pkg_name] else: - dbname = self.conn.mod_to_db[module] + dbname = conn.mod_to_db[module_name] except KeyError: raise DataJointError( - 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__)) - self.dbname = dbname + 'Module {} is not bound to a database. 
See datajoint.connection.bind'.format(module_name)) + # initialize using super class's constructor + super().__init__(conn, dbname, class_name) def get_base(self, module_name, class_name): diff --git a/datajoint/heading.py b/datajoint/heading.py index 507fdd38c..5406cf271 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -8,13 +8,19 @@ class Heading: """ local class for relations' headings. + Heading contains the property attributes, which is an OrderedDict in which the keys are + the attribute names and the values are AttrTuples. """ AttrTuple = namedtuple('AttrTuple', - ('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement', - 'numeric', 'string', 'is_blob', 'computation', 'dtype')) + ('name', 'type', 'in_key', 'nullable', 'default', + 'comment', 'autoincrement', 'numeric', 'string', 'is_blob', + 'computation', 'dtype')) + AttrTuple.as_dict = AttrTuple._asdict # rename the method into a nicer name def __init__(self, attributes): - # Input: attributes -list of dicts with attribute descriptions + """ + :param attributes: a list of dicts with the same keys as AttrTuple + """ self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) @property @@ -91,7 +97,7 @@ def init_from_database(cls, conn, dbname, table_name): """ cur = conn.query( 'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format( - table_name=table_name, dbname=dbname), asDict=True) + table_name=table_name, dbname=dbname), asDict=True) attributes = cur.fetchall() rename_map = { @@ -145,7 +151,7 @@ def init_from_database(cls, conn, dbname, table_name): field=attr['type'], dbname=dbname, table_name=table_name)) attr.pop('Extra') - # fill out the dtype. All floats and non-nullable integers are turned into specific dtypes + # fill out dtype. All floats and non-nullable integers are turned into specific dtypes attr['dtype'] = object if attr['numeric']: is_integer = bool(re.match(r'(tiny|small|medium|big)?int', attr['type'])) @@ -160,30 +166,29 @@ def init_from_database(cls, conn, dbname, table_name): return cls(attributes) - def pro(self, *attribute_list, **rename_dict): + def project(self, *attribute_list, **renamed_attributes): """ derive a new heading by selecting, renaming, or computing attributes. In relational algebra these operators are known as project, rename, and expand. The primary key is always included. 
""" - # include all if '*' is in attribute_set, always include primary key - attribute_set = set(self.names) if '*' in attribute_list \ - else set(attribute_list).union(self.primary_key) - - # report missing attributes - missing = attribute_set.difference(self.names) + # check missing attributes + missing = [a for a in attribute_list if a not in self.names] if missing: - raise DataJointError('Attributes %s are not found' % str(missing)) + raise DataJointError('Attributes `%s` are not found' % '`, `'.join(missing)) + + # always add primary key attributes + attribute_list = self.primary_key + [a for a in attribute_list if a not in self.primary_key] - # make attribute_list a list of dicts for initializing a Heading - attribute_list = [v._asdict() for k, v in self.attributes.items() - if k in attribute_set and k not in rename_dict.values()] + # convert attribute_list into a list of dicts but exclude renamed attributes + attribute_list = [v.as_dict() for k, v in self.attributes.items() + if k in attribute_list and k not in renamed_attributes.values()] # add renamed and computed attributes - for new_name, computation in rename_dict.items(): + for new_name, computation in renamed_attributes.items(): if computation in self.names: # renamed attribute - new_attr = self.attributes[computation]._asdict() + new_attr = self.attributes[computation].as_dict() new_attr['name'] = new_name new_attr['computation'] = '`' + computation + '`' else: @@ -210,14 +215,14 @@ def join(self, other): join two headings """ assert isinstance(other, Heading) - attribute_list = [v._asdict() for v in self.attributes.values()] + attribute_list = [v.as_dict() for v in self.attributes.values()] for name in other.names: if name not in self.names: - attribute_list.append(other.attributes[name]._asdict()) + attribute_list.append(other.attributes[name].as_dict()) return Heading(attribute_list) def resolve_computations(self): """ Remove computations. To be done after computations have been resolved in a subquery """ - return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file + return Heading([dict(v.as_dict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file diff --git a/datajoint/relational.py b/datajoint/relational.py index 3e8fc42bf..cdcfa5458 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -4,6 +4,7 @@ import numpy as np import abc +import re from copy import copy from datajoint import DataJointError from .blob import unpack @@ -25,11 +26,20 @@ class Relation(metaclass=abc.ABCMeta): @abc.abstractproperty def sql(self): - return NotImplemented - + """ + The sql property returns the tuple: (SQL command, Heading object) for its relation. + The SQL command does not include the attribute list or the WHERE clause. + :return: sql, heading + """ + pass + @abc.abstractproperty - def heading(self): - return NotImplemented + def conn(self): + """ + All relations must keep track of their connection object + :return: + """ + pass @property def restrictions(self): @@ -41,46 +51,50 @@ def __mul__(self, other): """ return Join(self, other) - def __mod__(self, attribute_list): + def __mod__(self, attributes=None): """ - relational projection operator. - :param attribute_list: list of attribute specifications. - The attribute specifications are strings in following forms: - 'name' - specific attribute - 'name->new_name' - rename attribute. The old attribute is kept only if specifically included. 
-        'sql_expression->new_name' - extend attribute, i.e. a new computed attribute.
-        :return: a new relation with specified heading
+        relational projection operator. See Relation.project
         """
-        self.project(attribute_list)
+        return self.project(*attributes)
 
-    def project(self, *selection, **aliases):
+    def project(self, *attributes, **renamed_attributes):
         """
         Relational projection operator.
         :param attributes: a list of attribute names to be included in the result.
-        :param renames: a dict of attributes to be renamed
         :return: a new relation with selected fields
         Primary key attributes are always selected and cannot be excluded.
         Therefore obj.project() produces a relation with only the primary key attributes.
-        If selection includes the string '*', all attributes are selected.
-        Each attribute can only be used once in attributes or renames. Therefore, the projected
+        If attributes includes the string '*', all attributes are selected.
+        Each attribute can only be used once in attributes or renamed_attributes. Therefore, the projected
         relation cannot have more attributes than the original relation.
         """
-        group = selection.pop[0] if selection and isinstance(selection[0], Relation) else None
-        return self.aggregate(group, *selection, **aliases)
+        # if the first attribute is a relation, it will be aggregated
+        if attributes and isinstance(attributes[0], Relation):
+            group, attributes = attributes[0], attributes[1:]
+        else:
+            group = None
+        return self.aggregate(group, *attributes, **renamed_attributes)
 
-    def aggregate(self, group, *selection, **aliases):
+    def aggregate(self, _group, *attributes, **renamed_attributes):
         """
         Relational aggregation operator
-        :param grouped_relation:
+        :param group: relation whose tuples can be used in aggregation operators
         :param extensions:
-        :return:
+        :return: a relation representing the aggregation/projection operator result
         """
-        if group is not None and not isinstance(group, Relation):
-            raise DataJointError('The second argument of aggregate must be a relation')
-        # convert the string notation for aliases to
-        # handling of the variable group is unclear here
-        # and thus ommitted
-        return Projection(self, *selection, **aliases)
+        if _group is not None and not isinstance(_group, Relation):
+            raise DataJointError('The second argument must be a relation or None')
+        alias_parser = re.compile(
+            '^\s*(?P<sql_expression>\S(.*\S)?)\s*->\s*(?P<alias>[a-z][a-z_0-9]*)\s*$')
+
+        # expand extended attributes in the form 'sql_expression -> new_attribute'
+        _attributes = []
+        for attribute in attributes:
+            alias_match = alias_parser.match(attribute)
+            if alias_match:
+                d = alias_match.groupdict()
+                renamed_attributes.update({d['alias']: d['sql_expression']})
+            else:
+                _attributes.append(attribute)
+        return Projection(self, _group, *_attributes, **renamed_attributes)
 
     def __iand__(self, restriction):
         """
@@ -104,7 +118,7 @@ def __and__(self, restriction):
 
     def __isub__(self, restriction):
         """
-        in-place inverted restriction aka antijoin
+        in-place antijoin (inverted restriction)
         """
         self &= Not(restriction)
         return self
@@ -117,14 +131,13 @@ def __sub__(self, restriction):
 
     @property
     def count(self):
-        sql = 'SELECT count(*) FROM ' + self.sql + self._where_clause
-        cur = self.conn.query(sql)
+        cur = self.conn.query('SELECT count(*) FROM ' + self.sql[0] + self._where)
         return cur.fetchone()[0]
 
-    def fetch(self, *args, **kwargs):
-        return self(*args, **kwargs)
+    def __call__(self, *args, **kwargs):
+        return self.fetch(*args, **kwargs)
 
-    def __call__(self, offset=0, limit=None, order_by=None, descending=False):
+    def fetch(self, offset=0, limit=None, order_by=None, descending=False):
""" fetches the relation from the database table into an np.array and unpacks blob attributes. :param offset: the number of tuples to skip in the returned result @@ -150,7 +163,8 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False): """ if offset and limit is None: raise DataJointError('') - sql = 'SELECT ' + self.heading.as_sql + ' FROM ' + self.sql + sql, heading = self.sql + sql = 'SELECT ' + heading.as_sql + ' FROM ' + sql if order_by is not None: sql += ' ORDER BY ' + ', '.join(self._orderBy) if descending: @@ -179,19 +193,22 @@ def __repr__(self): def __iter__(self): """ iterator yields primary key tuples + Example: + for key in relation: + (schema.Relation & key).fetch('field') """ - cur, h = self.project().cursor() + cur, h = self.project().cursor() # project q = cur.fetchone() while q: yield np.array([q, ], dtype=h.asdtype) q = cur.fetchone() @property - def _where_clause(self): + def _where(self): """ - make there WHERE clause based on the current restriction + convert the restriction into an SQL WHERE """ - if not self._restrictions: + if not self.restrictions: return '' def make_condition(arg): @@ -230,64 +247,97 @@ class Not: inverse of a restriction """ def __init__(self, restriction): - self._restriction = restriction + self.__restriction = restriction + + @property + def restriction(self): + return self.__restriction class Join(Relation): - alias_counter = 0 + subquery_counter = 0 def __init__(self, rel1, rel2): if not isinstance(rel2, Relation): - raise DataJointError('relvars can only be joined with other relvars') + raise DataJointError('a relation can only be joined with another relation') if rel1.conn is not rel2.conn: raise DataJointError('Cannot join relations with different database connections') self.conn = rel1.conn - self._rel1 = rel1 - self._rel2 = rel2 - + self._rel1 = Subquery(rel1) + self._rel2 = Subquery(rel2) + + @property + def conn(self): + return self._rel1.conn + @property def heading(self): return self._rel1.heading.join(self._rel2.heading) - - @property + + @property + def counter(self): + self.subquery_counter += 1 + return self.subquery_counter + + @property def sql(self): - Join.alias_counter += 1 - return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.alias_counter) + return '%s NATURAL JOIN %s as `_j%x`' % (self._rel1.sql, self._rel2.sql, self.counter) class Projection(Relation): - alias_counter = 0 + subquery_counter = 0 - def __init__(self, relation, *attributes, **renames): + def __init__(self, relation, group=None, *attributes, **renamed_attributes): """ See Relation.project() """ - self.conn = relation.conn - self._relation = relation + if group: + if relation.conn is not group.conn: + raise DataJointError('Cannot join relations with different database connections') + self._group = Subquery(group) + self._relation = Subquery(relation) + else: + self._group = None + self._relation = relation self._projection_attributes = attributes - self._renamed_attributes = renames + self._renamed_attributes = renamed_attributes - @property - def sql(self): - return self._relation.sql - @property - def heading(self): - return self._relation.heading.pro(*self._projection_attributes, **self._renamed_attributes) + def conn(self): + return self._relation.conn + + @property + def sql(self): + sql, heading = self._relation.sql + heading = heading.pro(self._projection_attributes, self._renamed_attributes) + if self._group is not None: + group_sql, group_heading = self._group.sql + sql = ("(%s) NATURAL LEFT JOIN (%s) 
GROUP BY `%s`" % + (sql, group_sql, '`,`'.join(heading.primary_key))) + return sql, heading class Subquery(Relation): - alias_counter = 0 - + """ + A Subquery encapsulates its argument in a SELECT statement, enabling its use as a subquery. + The attribute list and the WHERE clause are resolved. + """ + _counter = 0 + def __init__(self, rel): - self.conn = rel.conn self._rel = rel - + @property - def sql(self): - self.alias_counter += 1 - return '(SELECT ' + self._rel.heading.as_sql + ' FROM ' + self._rel.sql + ') as `s%x`' % self.alias_counter - + def conn(self): + return self._rel.conn + @property - def heading(self): - return self._rel.heading.resolve_computations() \ No newline at end of file + def counter(self): + Subquery._counter += 1 + return Subquery._counter + + @property + def sql(self): + return ('(SELECT ' + self._rel.heading.as_sql + + ' FROM ' + self._rel.sql + self._rel.where + ') as `_s%x`' % self.counter),\ + self._rel.heading.clear_aliases() \ No newline at end of file diff --git a/datajoint/table.py b/datajoint/table.py index b9a63f383..e275e0c5b 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -30,15 +30,22 @@ class Table(Relation): def __init__(self, conn=None, dbname=None, class_name=None, definition=None): self.class_name = class_name - self.conn = conn + self._conn = conn self.dbname = dbname - self.definition = definition + self._definition = definition if dbname not in self.conn.db_to_mod: # register with a fake module, enclosed in back quotes # necessary for loading mechanism self.conn.bind('`{0}`'.format(dbname), dbname) + @property + def definition(self): + return self._definition + + @property + def conn(self): + return self._conn @property def is_declared(self): @@ -86,7 +93,7 @@ def _field_to_sql(field): #TODO move this into Attribute Tuple @property def sql(self): - return self.full_table_name + return self.full_table_name, self.heading @property def heading(self): @@ -350,7 +357,12 @@ def _parse_attr_def(self, line, in_key=False): return Heading.AttrTuple(**attr_info) def get_base(self, module_name, class_name): - return None + m = re.match(r'`(\w+)`', module_name) + if m: + dbname = m.group(1) + return Table(self.conn, dbname, class_name) + else: + return None @property def ref_name(self): @@ -367,12 +379,8 @@ def _declare(self): raise DataJointError('Table definition is missing!') table_info, parents, referenced, field_defs, index_defs = self._parse_declaration() defined_name = table_info['module'] + '.' + table_info['className'] - if self._use_package: - parent = self.__module__.split('.')[-2] - else: - parent = self.__module__.split('.')[-1] - expected_name = parent + '.' 
+ self.class_name
-        if not defined_name == expected_name:
-            raise DataJointError('Table name {} does not match the declared'
-                                 'name {}'.format(expected_name, defined_name))
+
+        if not defined_name == self.ref_name:
+            raise DataJointError('Table name {} does not match the declared '
+                                 'name {}'.format(self.ref_name, defined_name))
 
@@ -466,7 +474,7 @@ def _parse_declaration(self):
         # remove comment lines
         declaration = [x for x in declaration if not x.startswith('#')]
         ptrn = """
-        ^(?P<module>\w+)\.(?P<className>\w+)\s*       # module.className
+        ^(?P<module>[\w\`]+)\.(?P<className>\w+)\s*   # module.className
         \(\s*(?P<tier>\w+)\s*\)\s*                    # (tier)
         \#\s*(?P<comment>.*)$                         # comment
         """
diff --git a/datajoint/utils.py b/datajoint/utils.py
index 47aacdeeb..af2e8e310 100644
--- a/datajoint/utils.py
+++ b/datajoint/utils.py
@@ -11,8 +11,8 @@ def to_camel_case(s):
     >>>to_camel_case("table_name")
     "TableName"
     """
-    def to_upper(matchobj):
-        return matchobj.group(0)[-1].upper()
+    def to_upper(match):
+        return match.group(0)[-1].upper()
 
     return re.sub('(^|[_\W])+[a-zA-Z]', to_upper, s)
 
@@ -26,13 +26,13 @@ def from_camel_case(s):
     "table_name"
     """
     if re.search(r'\s', s):
-        raise DataJointError('White space is not allowed')
+        raise DataJointError('Input cannot contain white space')
     if re.match(r'\d.*', s):
-        raise DataJointError('String cannot begin with a digit')
+        raise DataJointError('Input cannot begin with a digit')
     if not re.match(r'^[a-zA-Z0-9]*$', s):
         raise DataJointError('String can only contain alphanumeric characters')
 
-    def conv(matchobj):
-        return ('_' if matchobj.groups()[0] else '') + matchobj.group(0).lower()
+    def convert(match):
+        return ('_' if match.groups()[0] else '') + match.group(0).lower()
 
-    return re.sub(r'(\B[A-Z])|(\b[A-Z])', conv, s)
\ No newline at end of file
+    return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s)
\ No newline at end of file
diff --git a/demos/demo1.py b/demos/demo1.py
index 689905730..a601a0622 100644
--- a/demos/demo1.py
+++ b/demos/demo1.py
@@ -61,15 +61,4 @@ class Scan(dj.Base):
     depth : float             # depth from surface
     wavelength : smallint     # (nm) laser wavelength
     mwatts: numeric(4,1)      # (mW) laser power to brain
-    """
-
-
-class ScanInfo(dj.Base, dj.AutoPopulate):
-    definition = None
-    pop_rel = Session
-
-    def make_tuples(self, key):
-        info = (Session()*Scan() & key).pro('experiment_folder').fetch()
-        filename = os.path.join(info.experiment_folder, 'scan_%03', )
-
-
+    """
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
index 2fa1bea0a..26005a460 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -39,12 +39,17 @@ def cleanup():
     and then later reset FOREIGN_KEY_CHECKS flag
     """
     cur = BASE_CONN.cursor()
+    # cancel any unfinished transactions
+    cur.execute("ROLLBACK")
+    # start a transaction now
+    cur.execute("START TRANSACTION WITH CONSISTENT SNAPSHOT")
     cur.execute("SHOW DATABASES LIKE '{}\_%'".format(PREFIX))
     dbs = [x[0] for x in cur.fetchall()]
     cur.execute('SET FOREIGN_KEY_CHECKS=0')  # unset foreign key check while deleting
     for db in dbs:
         cur.execute('DROP DATABASE `{}`'.format(db))
     cur.execute('SET FOREIGN_KEY_CHECKS=1')  # set foreign key check back on
+    cur.execute("COMMIT")
 
 def setup_sample_db():
     """
diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py
index d0d276e6d..de36c821a 100644
--- a/tests/schemata/schema1/test1.py
+++ b/tests/schemata/schema1/test1.py
@@ -27,9 +27,9 @@ class Experiments(dj.Base):
     """
 
 # refers to a table in dj_test2 (bound to test2) but without a class
-class Session(dj.Base):
+class Sessions(dj.Base):
     definition = """
-    test1.Session (manual)     # Experiment sessions
+    test1.Sessions (manual)    # Experiment sessions
     -> test1.Subjects
     -> test2.Experimenter
session_id : int # unique session id @@ -45,6 +45,13 @@ class Match(dj.Base): dob : date # date of birth """ +# this tries to reference a table in database directly without ORM +class TrainingSession(dj.Base): + definition = """ + test1.TrainingSession (manual) # training sessions + -> `dj_test2`.Experimenter + session_id : int # training session id + """ class Empty(dj.Base): pass diff --git a/tests/test_base.py b/tests/test_base.py index 49a5b6810..d4addbbc0 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -178,7 +178,7 @@ def test_detection_of_existing_table(self): assert_true(s.is_declared) def test_definition_referring_to_existing_table_without_class(self): - s1 = test1.Session() + s1 = test1.Sessions() assert_true('experimenter_id' in s1.primary_key) s2 = test2.Session() @@ -189,6 +189,15 @@ def test_reference_to_package_level_table(self): s.declare() assert_true('pop_id' in s.primary_key) + def test_direct_reference_to_existing_table_should_fail(self): + """ + When deriving from Base, definition should not contain direct reference + to a database name + """ + s = test1.TrainingSession() + with assert_raises(DataJointError): + s.declare() + @raises(TypeError) def test_instantiation_of_base_derivative_without_definition_should_fail(): test1.Empty() diff --git a/tests/test_table.py b/tests/test_table.py index a8cf35133..bfa03a4c6 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -8,6 +8,7 @@ from datajoint import DataJointError import numpy as np from numpy.testing import assert_array_equal +from datajoint.table import Table def setup(): """ @@ -121,4 +122,53 @@ def test_blob_insert(self): t = (0, x, 'this is a random image') self.relvar_blob.insert(t) x2 = self.relvar_blob.fetch()[0][1] - assert_array_equal(x,x2, 'inserted blob does not match') \ No newline at end of file + assert_array_equal(x,x2, 'inserted blob does not match') + +class TestUnboundTables(object): + """ + Test usages of Table objects not connected to a module. + """ + def setup(self): + cleanup() + self.conn = Connection(**CONN_INFO) + + def test_creation_from_definition(self): + definition = """ + `dj_free`.Animals (manual) # my animal table + animal_id : int # unique id for the animal + --- + animal_name : varchar(128) # name of the animal + """ + table = Table(self.conn, 'dj_free', 'Animals', definition) + table.declare() + assert_true('animal_id' in table.primary_key) + + def test_reference_to_non_existant_table_should_fail(self): + definition = """ + `dj_free`.Recordings (manual) # recordings + -> `dj_free`.Animals + rec_session_id : int # recording session identifier + """ + table = Table(self.conn, 'dj_free', 'Recordings', definition) + assert_raises(DataJointError, table.declare) + + def test_reference_to_existing_table(self): + definition1 = """ + `dj_free`.Animals (manual) # my animal table + animal_id : int # unique id for the animal + --- + animal_name : varchar(128) # name of the animal + """ + table1 = Table(self.conn, 'dj_free', 'Animals', definition1) + table1.declare() + + definition2 = """ + `dj_free`.Recordings (manual) # recordings + -> `dj_free`.Animals + rec_session_id : int # recording session identifier + """ + table2 = Table(self.conn, 'dj_free', 'Recordings', definition2) + table2.declare() + assert_true('animal_id' in table2.primary_key) + +
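
A minimal usage sketch of the interface exercised by the new TestUnboundTables tests and the reworked projection operator. The sketch is illustrative and untested: the Connection and CONN_INFO import paths are assumed to match tests/__init__.py, and `dj_free`.Animals is the hypothetical table from the tests above.

    from datajoint import Connection      # assumed import path, as used in the test suite
    from datajoint.table import Table
    from tests import CONN_INFO           # assumed to be exported by tests/__init__.py

    conn = Connection(**CONN_INFO)

    # declare a free (unbound) table directly from a definition string
    definition = """
    `dj_free`.Animals (manual)  # my animal table
    animal_id : int             # unique id for the animal
    ---
    animal_name : varchar(128)  # name of the animal
    """
    animals = Table(conn, 'dj_free', 'Animals', definition)
    animals.declare()

    # projection, renaming, and computed attributes per Relation.project / Relation.aggregate
    keys_only = animals.project()                   # primary key attributes only
    renamed = animals.project(name='animal_name')   # rename animal_name -> name
    computed = animals.project('count(*) -> n')     # 'sql_expression -> alias' form
    print(animals.fetch())                          # fetch into a numpy array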