diff --git a/datajoint/base.py b/datajoint/base.py index 96a34859d..d9ca80581 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -5,6 +5,7 @@ from . import DataJointError from .table import Table import logging +from .declare import declare logger = logging.getLogger(__name__) @@ -42,16 +43,19 @@ def __init__(self): self.class_name = self.__class__.__name__ module = self.__module__ mod_obj = importlib.import_module(module) + use_package = False try: conn = mod_obj.conn except AttributeError: try: + # check if database bound at the package level instead pkg_obj = importlib.import_module(mod_obj.__package__) conn = pkg_obj.conn use_package = True except AttributeError: raise DataJointError( "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__)) + self.conn = conn try: if use_package: pkg_name = '.'.join(module.split('.')[:-1]) @@ -61,7 +65,8 @@ def __init__(self): except KeyError: raise DataJointError( 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__)) - super().__init__(self, conn=conn, dbname=dbname, class_name=self.__class__.__name__) + declare(self.conn, self.definition, self.full_class_name) + super().__init__(conn=conn, dbname=dbname, class_name=self.__class__.__name__) @classmethod diff --git a/datajoint/declare.py b/datajoint/declare.py index af09334ae..d95468e8e 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -14,14 +14,15 @@ def declare(conn, definition, class_name): """ Declares the table in the data base if no table in the database matches this object. """ - table_info, parents, referenced, field_definitions, index_definitions = parse_declaration(definition) + table_info, parents, referenced, field_definitions, index_definitions = _parse_declaration(conn, definition) defined_name = table_info['module'] + '.' + table_info['className'] - if not defined_name == class_name: - raise DataJointError('Table name {} does not match the declared' - 'name {}'.format(class_name, defined_name)) + # TODO: clean up this mess... currently just ignoring the name used to define the table + #if not defined_name == class_name: + # raise DataJointError('Table name {} does not match the declared' + # 'name {}'.format(class_name, defined_name)) # compile the CREATE TABLE statement - table_name = role_to_prefix[table_info['tier']] + from_camel_case(class_name) + table_name = role_to_prefix[table_info['tier']] + from_camel_case(defined_name) sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) # add inherited primary key fields @@ -88,7 +89,8 @@ def declare(conn, definition, class_name): sql[:-2], table_info['comment']) # make sure that the table does not alredy exist - self.conn.load_headings(self.dbname, force=True) + # TODO: there will be a problem with resolving the module here... + conn.load_headings(self.dbname, force=True) if not self.is_declared: # execute declaration logger.debug('\n\n' + sql + '\n\n') @@ -96,7 +98,7 @@ def declare(conn, definition, class_name): self.conn.load_headings(self.dbname, force=True) -def _parse_declaration(self): +def _parse_declaration(conn, definition): """ Parse declaration and create new SQL table accordingly. """ @@ -104,7 +106,7 @@ def _parse_declaration(self): referenced = [] index_defs = [] field_defs = [] - declaration = re.split(r'\s*\n\s*', self.definition.strip()) + declaration = re.split(r'\s*\n\s*', definition.strip()) # remove comment lines declaration = [x for x in declaration if not x.startswith('#')] @@ -205,11 +207,11 @@ def parse_attribute_definition(line, in_key=False): # todo add docu for in_key assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ 'BIGINT attributes cannot be nullable in "%s"' % line - attr_info['isKey'] = in_key - attr_info['isAutoincrement'] = None - attr_info['isNumeric'] = None - attr_info['isString'] = None - attr_info['isBlob'] = None + attr_info['in_key'] = in_key + attr_info['autoincrement'] = None + attr_info['numeric'] = None + attr_info['string'] = None + attr_info['is_blob'] = None attr_info['computation'] = None attr_info['dtype'] = None diff --git a/datajoint/heading.py b/datajoint/heading.py index 18e75df62..507fdd38c 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -81,6 +81,9 @@ def values(self): def items(self): return self.attributes.items() + def __iter__(self): + return iter(self.attributes) + @classmethod def init_from_database(cls, conn, dbname, table_name): """ diff --git a/datajoint/relational.py b/datajoint/relational.py index f72bd1369..3e8fc42bf 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -65,9 +65,8 @@ def project(self, *selection, **aliases): Each attribute can only be used once in attributes or renames. Therefore, the projected relation cannot have more attributes than the original relation. """ - return self.aggregate( - group=selection.pop[0] if selection and isinstance(selection[0], Relation) else None, - *selection, **aliases) + group = selection.pop[0] if selection and isinstance(selection[0], Relation) else None + return self.aggregate(group, *selection, **aliases) def aggregate(self, group, *selection, **aliases): """ @@ -79,8 +78,9 @@ def aggregate(self, group, *selection, **aliases): if group is not None and not isinstance(group, Relation): raise DataJointError('The second argument of aggregate must be a relation') # convert the string notation for aliases to - - return Projection(group=group, *attriutes, **aliases) + # handling of the variable group is unclear here + # and thus ommitted + return Projection(self, *selection, **aliases) def __iand__(self, restriction): """ @@ -121,6 +121,9 @@ def count(self): cur = self.conn.query(sql) return cur.fetchone()[0] + def fetch(self, *args, **kwargs): + return self(*args, **kwargs) + def __call__(self, offset=0, limit=None, order_by=None, descending=False): """ fetches the relation from the database table into an np.array and unpacks blob attributes. @@ -131,7 +134,7 @@ def __call__(self, offset=0, limit=None, order_by=None, descending=False): :return: the contents of the relation in the form of a structured numpy.array """ cur = self.cursor(offset, limit, order_by, descending) - ret = np.array(list(cur), dtype=self.heading.asdtype) + ret = np.array(list(cur), dtype=self.heading.as_dtype) for f in self.heading.blobs: for i in range(len(ret)): ret[i][f] = unpack(ret[i][f]) @@ -160,12 +163,12 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False): return self.conn.query(sql) def __repr__(self): - limit = 13 + limit = 13 #TODO: move some of these display settings into the config width = 12 template = '%%-%d.%ds' % (width, width) - repr_string = ' '.join([template % column for column in header]) + '\n' - repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in header]) + '\n' - tuples = self.pro(*self.heading.non_blobs).fetch(limit=limit) + repr_string = ' '.join([template % column for column in self.heading]) + '\n' + repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in self.heading]) + '\n' + tuples = self.project(*self.heading.non_blobs).fetch(limit=limit) for tup in tuples: repr_string += ' '.join([template % column for column in tup]) + '\n' if self.count > limit: @@ -177,7 +180,7 @@ def __iter__(self): """ iterator yields primary key tuples """ - cur, h = self.pro().cursor() + cur, h = self.project().cursor() q = cur.fetchone() while q: yield np.array([q, ], dtype=h.asdtype) @@ -266,11 +269,11 @@ def __init__(self, relation, *attributes, **renames): @property def sql(self): - return self._rel.sql + return self._relation.sql @property def heading(self): - return self._rel.heading.pro(*self._selection, **self._renames) + return self._relation.heading.pro(*self._projection_attributes, **self._renamed_attributes) class Subquery(Relation): diff --git a/datajoint/table.py b/datajoint/table.py index d1c6c3bb9..14971e868 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -2,7 +2,7 @@ import logging from . import DataJointError from .relational import Relation -from .declare import declare +from .declare import (declare, parse_attribute_definition) logger = logging.getLogger(__name__) diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py index add1d39a4..fbc215864 100644 --- a/tests/schemata/schema1/test1.py +++ b/tests/schemata/schema1/test1.py @@ -7,7 +7,7 @@ class Subjects(dj.Base): - _table_def = """ + definition = """ test1.Subjects (manual) # Basic subject info subject_id : int # unique subject id diff --git a/tests/schemata/schema1/test2.py b/tests/schemata/schema1/test2.py index 775c716e4..f25b8d2a1 100644 --- a/tests/schemata/schema1/test2.py +++ b/tests/schemata/schema1/test2.py @@ -8,7 +8,7 @@ class Experiments(dj.Base): - _table_def = """ + definition = """ test2.Experiments (manual) # Basic subject info -> test1.Subjects experiment_id : int # unique experiment id diff --git a/tests/schemata/schema1/test3.py b/tests/schemata/schema1/test3.py index bb764c5cb..7551854c1 100644 --- a/tests/schemata/schema1/test3.py +++ b/tests/schemata/schema1/test3.py @@ -7,7 +7,7 @@ class Subjects(dj.Base): - _table_def = """ + definition = """ test3.Subjects (manual) # Basic subject info subject_id : int # unique subject id