From ee04d7f1c48567fdefabaf304499836e8824bad5 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Wed, 6 May 2015 12:58:46 -0500 Subject: [PATCH 1/8] intermediate: working on relational algebra --- datajoint/base.py | 5 +--- datajoint/declare.py | 12 ++++----- datajoint/relational.py | 54 ++++++++++++++++++++++------------------- datajoint/utils.py | 14 +++++------ 4 files changed, 42 insertions(+), 43 deletions(-) diff --git a/datajoint/base.py b/datajoint/base.py index d9ca80581..7374d660e 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -1,7 +1,6 @@ import importlib import abc from types import ModuleType -from enum import Enum from . import DataJointError from .table import Table import logging @@ -28,7 +27,6 @@ class Subjects(dj.Base): real_id : varchar(40) # real-world name species = "mouse" : enum('mouse', 'monkey', 'human') # species ''' - """ @abc.abstractproperty @@ -68,7 +66,6 @@ def __init__(self): declare(self.conn, self.definition, self.full_class_name) super().__init__(conn=conn, dbname=dbname, class_name=self.__class__.__name__) - @classmethod def get_module(cls, module_name): """ @@ -97,4 +94,4 @@ def get_module(cls, module_name): try: return importlib.import_module(module_name) except ImportError: - return None + return None \ No newline at end of file diff --git a/datajoint/declare.py b/datajoint/declare.py index d95468e8e..1eb4a2124 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -44,8 +44,8 @@ def declare(conn, definition, class_name): raise DataJointError('Primary key {} cannot be nullable'.format( field.name)) if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary key ' - '{key}. Check to make sure that the key ' + raise DataJointError('Duplicate attribute {key} in the primary key. ' + 'Check to make sure that the key ' 'is not declared already in referenced ' 'tables'.format(key=field.name)) primary_key_fields.add(field.name) @@ -165,11 +165,9 @@ def field_to_sql(field): # if some default specified if field.default: # enclose value in quotes (even numeric), except special SQL values - # or values already enclosed by the user - if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']: - default = '%s DEFAULT %s' % (default, field.default) - else: - default = '%s DEFAULT "%s"' % (default, field.default) + add_quotes = field.default.upper() not in mysql_constants \ + and field.default[:1] not in ["'", '"'] + default += ' DEFAULT ' + ('"%s"' if add_quotes else '%s') % field.default # TODO: escape instead! - same goes for Matlab side implementation assert not any((c in r'\"' for c in field.comment)), \ diff --git a/datajoint/relational.py b/datajoint/relational.py index 3e8fc42bf..4757a3293 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -4,6 +4,7 @@ import numpy as np import abc +import re from copy import copy from datajoint import DataJointError from .blob import unpack @@ -41,46 +42,49 @@ def __mul__(self, other): """ return Join(self, other) - def __mod__(self, attribute_list): + def __mod__(self, attributes=None): """ - relational projection operator. - :param attribute_list: list of attribute specifications. - The attribute specifications are strings in following forms: - 'name' - specific attribute - 'name->new_name' - rename attribute. The old attribute is kept only if specifically included. - 'sql_expression->new_name' - extend attribute, i.e. a new computed attribute. - :return: a new relation with specified heading + relational projection operator. See Relation.project """ - self.project(attribute_list) + return self.project(*attributes) - def project(self, *selection, **aliases): + def project(self, *attributes, **aliases): """ Relational projection operator. :param attributes: a list of attribute names to be included in the result. - :param renames: a dict of attributes to be renamed :return: a new relation with selected fields Primary key attributes are always selected and cannot be excluded. Therefore obj.project() produces a relation with only the primary key attributes. - If selection includes the string '*', all attributes are selected. + If attributes includes the string '*', all attributes are selected. Each attribute can only be used once in attributes or renames. Therefore, the projected relation cannot have more attributes than the original relation. """ - group = selection.pop[0] if selection and isinstance(selection[0], Relation) else None - return self.aggregate(group, *selection, **aliases) + # if the first attribute is a relation, it will be aggregated + group = attributes.pop[0] \ + if attributes and isinstance(attributes[0], Relation) else None + return self.aggregate(group, *attributes, **aliases) - def aggregate(self, group, *selection, **aliases): + def aggregate(self, group, *attributes, **aliases): """ Relational aggregation operator - :param grouped_relation: + :param group: relation whose tuples can be used in aggregation operators :param extensions: - :return: + :return: a relation representing the aggregation/projection operator result """ if group is not None and not isinstance(group, Relation): - raise DataJointError('The second argument of aggregate must be a relation') - # convert the string notation for aliases to - # handling of the variable group is unclear here - # and thus ommitted - return Projection(self, *selection, **aliases) + raise DataJointError('The second argument must be a relation or None') + alias_parser = re.compile( + '^\s*(?P\S(.*\S)?)\s*->\s*(?P[a-z][a-z_0-9]*)\s*$') + + # expand extended attributes in the form 'sql_expression -> new_attribute' + new_selection = [] + for attribute in attributes: + alias_match = alias_parser.match(attribute) + if alias_match: + aliases.update(alias_match.groupdict()) + else: + new_selection += attribute + return Projection(self, group, *new_selection, **aliases) def __iand__(self, restriction): """ @@ -121,10 +125,10 @@ def count(self): cur = self.conn.query(sql) return cur.fetchone()[0] - def fetch(self, *args, **kwargs): + def __call__(self, *args, **kwargs): return self(*args, **kwargs) - def __call__(self, offset=0, limit=None, order_by=None, descending=False): + def fetch(self, offset=0, limit=None, order_by=None, descending=False): """ fetches the relation from the database table into an np.array and unpacks blob attributes. :param offset: the number of tuples to skip in the returned result @@ -180,7 +184,7 @@ def __iter__(self): """ iterator yields primary key tuples """ - cur, h = self.project().cursor() + cur, h = self.project().cursor() # project q = cur.fetchone() while q: yield np.array([q, ], dtype=h.asdtype) diff --git a/datajoint/utils.py b/datajoint/utils.py index 47aacdeeb..af2e8e310 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -11,8 +11,8 @@ def to_camel_case(s): >>>to_camel_case("table_name") "TableName" """ - def to_upper(matchobj): - return matchobj.group(0)[-1].upper() + def to_upper(match): + return match.group(0)[-1].upper() return re.sub('(^|[_\W])+[a-zA-Z]', to_upper, s) @@ -26,13 +26,13 @@ def from_camel_case(s): "table_name" """ if re.search(r'\s', s): - raise DataJointError('White space is not allowed') + raise DataJointError('Input cannot contain white space') if re.match(r'\d.*', s): - raise DataJointError('String cannot begin with a digit') + raise DataJointError('Input cannot begin with a digit') if not re.match(r'^[a-zA-Z0-9]*$', s): raise DataJointError('String can only contain alphanumeric characters') - def conv(matchobj): - return ('_' if matchobj.groups()[0] else '') + matchobj.group(0).lower() + def convert(match): + return ('_' if match.groups()[0] else '') + match.group(0).lower() - return re.sub(r'(\B[A-Z])|(\b[A-Z])', conv, s) \ No newline at end of file + return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s) \ No newline at end of file From f0828174a15049bcc2c9835c8a0c913ebd696602 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 7 May 2015 09:39:12 -0500 Subject: [PATCH 2/8] minor --- datajoint/relational.py | 44 ++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/datajoint/relational.py b/datajoint/relational.py index 4757a3293..714959cb8 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -64,27 +64,28 @@ def project(self, *attributes, **aliases): if attributes and isinstance(attributes[0], Relation) else None return self.aggregate(group, *attributes, **aliases) - def aggregate(self, group, *attributes, **aliases): + def aggregate(self, _group, *attributes, **aliases): """ Relational aggregation operator :param group: relation whose tuples can be used in aggregation operators :param extensions: :return: a relation representing the aggregation/projection operator result """ - if group is not None and not isinstance(group, Relation): + if _group is not None and not isinstance(_group, Relation): raise DataJointError('The second argument must be a relation or None') alias_parser = re.compile( '^\s*(?P\S(.*\S)?)\s*->\s*(?P[a-z][a-z_0-9]*)\s*$') # expand extended attributes in the form 'sql_expression -> new_attribute' - new_selection = [] + _attributes = [] for attribute in attributes: alias_match = alias_parser.match(attribute) if alias_match: - aliases.update(alias_match.groupdict()) + d = alias_match.group_dict() + aliases.update({d['alias']: d['sql_expression']}) else: - new_selection += attribute - return Projection(self, group, *new_selection, **aliases) + _attributes += attribute + return Projection(self, _group, *_attributes, **aliases) def __iand__(self, restriction): """ @@ -108,7 +109,7 @@ def __and__(self, restriction): def __isub__(self, restriction): """ - in-place inverted restriction aka antijoin + in-place antijoin (inverted restriction) """ self &= Not(restriction) return self @@ -121,8 +122,7 @@ def __sub__(self, restriction): @property def count(self): - sql = 'SELECT count(*) FROM ' + self.sql + self._where_clause - cur = self.conn.query(sql) + cur = self.conn.query('SELECT count(*) FROM ' + self.sql[0] + self._where) return cur.fetchone()[0] def __call__(self, *args, **kwargs): @@ -154,7 +154,8 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False): """ if offset and limit is None: raise DataJointError('') - sql = 'SELECT ' + self.heading.as_sql + ' FROM ' + self.sql + sql, heading = self.sql + sql = 'SELECT ' + heading.as_sql + ' FROM ' + sql if order_by is not None: sql += ' ORDER BY ' + ', '.join(self._orderBy) if descending: @@ -191,11 +192,11 @@ def __iter__(self): q = cur.fetchone() @property - def _where_clause(self): + def _where(self): """ - make there WHERE clause based on the current restriction + convert the restriction into an SQL WHERE """ - if not self._restrictions: + if not self.restrictions: return '' def make_condition(arg): @@ -234,7 +235,11 @@ class Not: inverse of a restriction """ def __init__(self, restriction): - self._restriction = restriction + self.__restriction = restriction + + @property + def restriction(self): + return self.__restriction class Join(Relation): @@ -242,7 +247,7 @@ class Join(Relation): def __init__(self, rel1, rel2): if not isinstance(rel2, Relation): - raise DataJointError('relvars can only be joined with other relvars') + raise DataJointError('a relation can only be joined with another relation') if rel1.conn is not rel2.conn: raise DataJointError('Cannot join relations with different database connections') self.conn = rel1.conn @@ -262,17 +267,19 @@ def sql(self): class Projection(Relation): alias_counter = 0 - def __init__(self, relation, *attributes, **renames): + def __init__(self, relation, group=None, *attributes, **renames): """ See Relation.project() """ self.conn = relation.conn + self._group = group self._relation = relation self._projection_attributes = attributes self._renamed_attributes = renames @property def sql(self): + if return self._relation.sql @property @@ -290,8 +297,9 @@ def __init__(self, rel): @property def sql(self): self.alias_counter += 1 - return '(SELECT ' + self._rel.heading.as_sql + ' FROM ' + self._rel.sql + ') as `s%x`' % self.alias_counter + return ('(SELECT ' + self._rel.heading.as_sql + + ' FROM ' + self._rel.sql + ') as `s%x`' % self.alias_counter) @property def heading(self): - return self._rel.heading.resolve_computations() \ No newline at end of file + return self._rel.heading.resolve_aliases() \ No newline at end of file From f2c952ef28b00ad3cc382e43ffd4ff909716c915 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 7 May 2015 11:38:16 -0500 Subject: [PATCH 3/8] cosmetic changes --- datajoint/base.py | 82 +++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 42 deletions(-) diff --git a/datajoint/base.py b/datajoint/base.py index ce1e31318..68b5cc3fb 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -17,18 +17,19 @@ class Base(Table, metaclass=abc.ABCMeta): Base is a Table that implements data definition functions. It is an abstract class with the abstract property 'definition'. - Example for a usage of Base:: + Example for a usage of Base: import datajoint as dj - class Subjects(dj.Base): + class Subject(dj.Base): definition = ''' - test1.Subjects (manual) # Basic subject info - subject_id : int # unique subject id + test1.Subjects (manual) # Basic info about experiment subject + subject_id : int # unique subject id --- - real_id : varchar(40) # real-world name - species = "mouse" : enum('mouse', 'monkey', 'human') # species + species = "mouse" : enum('mouse', 'monkey', 'human') + date_of_birth : date + subject_notes : varchar(1000) # notes about the subject ''' """ @@ -87,7 +88,8 @@ def declare(self): raise DataJointError( 'Table could not be declared for %s' % self.class_name) - def _field_to_sql(self, field): + @staticmethod + def _field_to_sql(field): """ Converts an attribute definition tuple into SQL code. :param field: attribute definition @@ -216,28 +218,26 @@ def _parse_declaration(self): # remove comment lines declaration = [x for x in declaration if not x.startswith('#')] - ptrn = """ + pattern = """ ^(?P\w+)\.(?P\w+)\s* # module.className \(\s*(?P\w+)\s*\)\s* # (tier) \#\s*(?P.*)$ # comment """ - p = re.compile(ptrn, re.X) + p = re.compile(pattern, re.X) table_info = p.match(declaration[0]).groupdict() if table_info['tier'] not in Role.__members__: raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ {module}.{cls}'.format(tier=table_info['tier'], - module=table_info[ - 'module'], + module=table_info['module'], cls=table_info['className'])) table_info['tier'] = Role[table_info['tier']] # convert into enum in_key = True # parse primary keys - field_ptrn = """ - ^[a-z][a-z\d_]*\s* # name - (=\s*\S+(\s+\S+)*\s*)? # optional defaults - :\s*\w.*$ # type, comment - """ - fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose + field_regexp = re.compile(""" + ^[a-z][a-z\d_]*\s* # name + (=\s*\S+(\s+\S+)*\s*)? # optional defaults + :\s*\w.*$ # type, comment + """, re.X) for line in declaration[1:]: if line.startswith('---'): @@ -249,15 +249,15 @@ def _parse_declaration(self): (parents if in_key else referenced).append(rel) elif re.match(r'^(unique\s+)?index[^:]*$', line): index_defs.append(self._parse_index_def(line)) - elif fieldP.match(line): + elif field_regexp.match(line): field_defs.append(self._parse_attr_def(line, in_key)) else: - raise DataJointError( - 'Invalid table declaration line "%s"' % line) + raise DataJointError('Invalid table declaration line "%s"' % line) return table_info, parents, referenced, field_defs, index_defs - def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key + @staticmethod + def _parse_attr_def(line, in_key=False): # todo add docu for in_key """ Parse attribute definition line in the declaration and returns an attribute tuple. @@ -267,17 +267,15 @@ def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key :returns: attribute tuple """ line = line.strip() - attr_ptrn = """ - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """ - - attrP = re.compile(attr_ptrn, re.I + re.X) - m = attrP.match(line) - assert m, 'Invalid field declaration "%s"' % line - attr_info = m.groupdict() + field_regexp = re.compile(""" + ^(?P[a-z][a-z\d_]*)\s* # field name + (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value + :\s*(?P\w[^\#]*[^\#\s])\s* # datatype + (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment + """, re.X) + match = field_regexp.match(line) + assert match, 'Invalid field declaration "%s"' % line + attr_info = match.groupdict() if not attr_info['comment']: attr_info['comment'] = '' if not attr_info['default']: @@ -296,7 +294,8 @@ def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key return Heading.AttrTuple(**attr_info) - def _parse_index_def(self, line): + @staticmethod + def _parse_index_def(line): """ Parses index definition. @@ -304,14 +303,13 @@ def _parse_index_def(self, line): :return: groupdict with index info """ line = line.strip() - index_ptrn = """ - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """ - indexP = re.compile(index_ptrn, re.I + re.X) - m = indexP.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() + index_regexp = re.compile(""" + ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX + \((?P[^\)]+)\)$ # (attr1, attr2) + """, re.I + re.X) + match = index_regexp.match(line) + assert match, 'Invalid index declaration "%s"' % line + index_info = match.groupdict() attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) index_info['attributes'] = attributes assert len(attributes) == len(set(attributes)), \ @@ -364,4 +362,4 @@ def get_module(cls, module_name): try: return importlib.import_module(module_name) except ImportError: - return None + return None \ No newline at end of file From 6c47bc2016d524328d77086e6a3b18f1f43702b3 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 7 May 2015 16:47:17 -0500 Subject: [PATCH 4/8] intermediate commit before merge --- .gitignore | 1 + datajoint/relational.py | 75 ++++++++++++++++++++++++++--------------- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index ef4ea964a..ecd81049f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .ipynb_checkpoints/ +*.json */.*.swp */.*.swo */*.pyc diff --git a/datajoint/relational.py b/datajoint/relational.py index 714959cb8..2022bead7 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -26,11 +26,20 @@ class Relation(metaclass=abc.ABCMeta): @abc.abstractproperty def sql(self): - return NotImplemented - + """ + The sql property returns the tuple: (SQL command, Heading object) for its relation. + The SQL command does not include the attribute list or the WHERE clause. + :return: sql, heading + """ + pass + @abc.abstractproperty - def heading(self): - return NotImplemented + def conn(self): + """ + All relations must keep track of their connection object + :return: + """ + pass @property def restrictions(self): @@ -184,6 +193,9 @@ def __repr__(self): def __iter__(self): """ iterator yields primary key tuples + Example: + for key in relation: + (schema.Relation & key).fetch('field') """ cur, h = self.project().cursor() # project q = cur.fetchone() @@ -243,7 +255,7 @@ def restriction(self): class Join(Relation): - alias_counter = 0 + subquery_counter = 0 def __init__(self, rel1, rel2): if not isinstance(rel2, Relation): @@ -253,19 +265,19 @@ def __init__(self, rel1, rel2): self.conn = rel1.conn self._rel1 = rel1 self._rel2 = rel2 - + @property def heading(self): return self._rel1.heading.join(self._rel2.heading) - - @property + + @property def sql(self): - Join.alias_counter += 1 - return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.alias_counter) + Join.subquery_counter += 1 + return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.subquery_counter) class Projection(Relation): - alias_counter = 0 + subquery_counter = 0 def __init__(self, relation, group=None, *attributes, **renames): """ @@ -277,29 +289,36 @@ def __init__(self, relation, group=None, *attributes, **renames): self._projection_attributes = attributes self._renamed_attributes = renames - @property - def sql(self): - if - return self._relation.sql - @property - def heading(self): - return self._relation.heading.pro(*self._projection_attributes, **self._renamed_attributes) + def sql(self): + sql = self._relation.sql + if self._group is not None: + sql = "NATURAL LEFT JOIN " + return self._relation.sql, \ + self._relation.heading.pro(*self._projection_attributes, **self._renamed_attributes) class Subquery(Relation): - alias_counter = 0 - + """ + A Subquery encapsulates its argument in a SELECT statement, enabling its use as a subquery. + The attribute list and the WHERE clause are resolved. + """ + _counter = 0 + def __init__(self, rel): - self.conn = rel.conn self._rel = rel - + + @property + def conn(self): + return self._rel.conn + + @property + def counter(self): + Subquery._counter += 1 + return Subquery._counter + @property def sql(self): - self.alias_counter += 1 return ('(SELECT ' + self._rel.heading.as_sql + - ' FROM ' + self._rel.sql + ') as `s%x`' % self.alias_counter) - - @property - def heading(self): - return self._rel.heading.resolve_aliases() \ No newline at end of file + ' FROM ' + self._rel.sql + self._rel.where + ') as `s%x`' % self.counter),\ + self._rel.heading.clear_aliases() \ No newline at end of file From 0590b8e107f3c2f8b6be03e3f7e9d88f45544bae Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 7 May 2015 19:12:42 -0500 Subject: [PATCH 5/8] intermediate commit before merge --- datajoint/heading.py | 8 ++++---- datajoint/relational.py | 42 +++++++++++++++++++++++++---------------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/datajoint/heading.py b/datajoint/heading.py index 507fdd38c..d288e0e13 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -145,7 +145,7 @@ def init_from_database(cls, conn, dbname, table_name): field=attr['type'], dbname=dbname, table_name=table_name)) attr.pop('Extra') - # fill out the dtype. All floats and non-nullable integers are turned into specific dtypes + # fill out dtype. All floats and non-nullable integers are turned into specific dtypes attr['dtype'] = object if attr['numeric']: is_integer = bool(re.match(r'(tiny|small|medium|big)?int', attr['type'])) @@ -160,7 +160,7 @@ def init_from_database(cls, conn, dbname, table_name): return cls(attributes) - def pro(self, *attribute_list, **rename_dict): + def pro(self, *attribute_list, **renamed_attributes): """ derive a new heading by selecting, renaming, or computing attributes. In relational algebra these operators are known as project, rename, and expand. @@ -177,10 +177,10 @@ def pro(self, *attribute_list, **rename_dict): # make attribute_list a list of dicts for initializing a Heading attribute_list = [v._asdict() for k, v in self.attributes.items() - if k in attribute_set and k not in rename_dict.values()] + if k in attribute_set and k not in renamed_attributes.values()] # add renamed and computed attributes - for new_name, computation in rename_dict.items(): + for new_name, computation in renamed_attributes.items(): if computation in self.names: # renamed attribute new_attr = self.attributes[computation]._asdict() diff --git a/datajoint/relational.py b/datajoint/relational.py index 2022bead7..94e0d04af 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -57,7 +57,7 @@ def __mod__(self, attributes=None): """ return self.project(*attributes) - def project(self, *attributes, **aliases): + def project(self, *attributes, **renamed_attributes): """ Relational projection operator. :param attributes: a list of attribute names to be included in the result. @@ -65,15 +65,15 @@ def project(self, *attributes, **aliases): Primary key attributes are always selected and cannot be excluded. Therefore obj.project() produces a relation with only the primary key attributes. If attributes includes the string '*', all attributes are selected. - Each attribute can only be used once in attributes or renames. Therefore, the projected + Each attribute can only be used once in attributes or renamed_attributes. Therefore, the projected relation cannot have more attributes than the original relation. """ # if the first attribute is a relation, it will be aggregated group = attributes.pop[0] \ if attributes and isinstance(attributes[0], Relation) else None - return self.aggregate(group, *attributes, **aliases) + return self.aggregate(group, *attributes, **renamed_attributes) - def aggregate(self, _group, *attributes, **aliases): + def aggregate(self, _group, *attributes, **renamed_attributes): """ Relational aggregation operator :param group: relation whose tuples can be used in aggregation operators @@ -91,10 +91,10 @@ def aggregate(self, _group, *attributes, **aliases): alias_match = alias_parser.match(attribute) if alias_match: d = alias_match.group_dict() - aliases.update({d['alias']: d['sql_expression']}) + renamed_attributes.update({d['alias']: d['sql_expression']}) else: _attributes += attribute - return Projection(self, _group, *_attributes, **aliases) + return Projection(self, _group, *_attributes, **renamed_attributes) def __iand__(self, restriction): """ @@ -279,23 +279,33 @@ def sql(self): class Projection(Relation): subquery_counter = 0 - def __init__(self, relation, group=None, *attributes, **renames): + def __init__(self, relation, group=None, *attributes, **renamed_attributes): """ See Relation.project() """ - self.conn = relation.conn - self._group = group - self._relation = relation + if group: + # TODO: assert that group.conn is same as relation.conn + self._group = Subquery(group) + self._relation = Subquery(relation) + else: + self._group = None + self._relation = relation self._projection_attributes = attributes - self._renamed_attributes = renames + self._renamed_attributes = renamed_attributes + + @property + def conn(self): + return self._relation.conn @property def sql(self): - sql = self._relation.sql + sql, heading = self._relation.sql + heading = heading.pro(self._projection_attributes, self._renamed_attributes) if self._group is not None: - sql = "NATURAL LEFT JOIN " - return self._relation.sql, \ - self._relation.heading.pro(*self._projection_attributes, **self._renamed_attributes) + group_sql, group_heading = self._group.sql + sql = ("(%s) NATURAL LEFT JOIN (%s) GROUP BY `%s`" % + (sql, group_sql, '`,`'.join(heading.primary_key))) + return sql, heading class Subquery(Relation): @@ -320,5 +330,5 @@ def counter(self): @property def sql(self): return ('(SELECT ' + self._rel.heading.as_sql + - ' FROM ' + self._rel.sql + self._rel.where + ') as `s%x`' % self.counter),\ + ' FROM ' + self._rel.sql + self._rel.where + ') as `_s%x`' % self.counter),\ self._rel.heading.clear_aliases() \ No newline at end of file From 2590ed16651d12593c8ab1ff08b84942de2c3ca2 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 7 May 2015 20:41:49 -0500 Subject: [PATCH 6/8] further progress on relational algebra --- datajoint/heading.py | 43 +++++++++++++++++++++++------------------ datajoint/relational.py | 19 +++++++++++++----- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/datajoint/heading.py b/datajoint/heading.py index d288e0e13..5406cf271 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -8,13 +8,19 @@ class Heading: """ local class for relations' headings. + Heading contains the property attributes, which is an OrderedDict in which the keys are + the attribute names and the values are AttrTuples. """ AttrTuple = namedtuple('AttrTuple', - ('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement', - 'numeric', 'string', 'is_blob', 'computation', 'dtype')) + ('name', 'type', 'in_key', 'nullable', 'default', + 'comment', 'autoincrement', 'numeric', 'string', 'is_blob', + 'computation', 'dtype')) + AttrTuple.as_dict = AttrTuple._asdict # rename the method into a nicer name def __init__(self, attributes): - # Input: attributes -list of dicts with attribute descriptions + """ + :param attributes: a list of dicts with the same keys as AttrTuple + """ self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) @property @@ -91,7 +97,7 @@ def init_from_database(cls, conn, dbname, table_name): """ cur = conn.query( 'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format( - table_name=table_name, dbname=dbname), asDict=True) + table_name=table_name, dbname=dbname), asDict=True) attributes = cur.fetchall() rename_map = { @@ -160,30 +166,29 @@ def init_from_database(cls, conn, dbname, table_name): return cls(attributes) - def pro(self, *attribute_list, **renamed_attributes): + def project(self, *attribute_list, **renamed_attributes): """ derive a new heading by selecting, renaming, or computing attributes. In relational algebra these operators are known as project, rename, and expand. The primary key is always included. """ - # include all if '*' is in attribute_set, always include primary key - attribute_set = set(self.names) if '*' in attribute_list \ - else set(attribute_list).union(self.primary_key) - - # report missing attributes - missing = attribute_set.difference(self.names) + # check missing attributes + missing = [a for a in attribute_list if a not in self.names] if missing: - raise DataJointError('Attributes %s are not found' % str(missing)) + raise DataJointError('Attributes `%s` are not found' % '`, `'.join(missing)) + + # always add primary key attributes + attribute_list = self.primary_key + [a for a in attribute_list if a not in self.primary_key] - # make attribute_list a list of dicts for initializing a Heading - attribute_list = [v._asdict() for k, v in self.attributes.items() - if k in attribute_set and k not in renamed_attributes.values()] + # convert attribute_list into a list of dicts but exclude renamed attributes + attribute_list = [v.as_dict() for k, v in self.attributes.items() + if k in attribute_list and k not in renamed_attributes.values()] # add renamed and computed attributes for new_name, computation in renamed_attributes.items(): if computation in self.names: # renamed attribute - new_attr = self.attributes[computation]._asdict() + new_attr = self.attributes[computation].as_dict() new_attr['name'] = new_name new_attr['computation'] = '`' + computation + '`' else: @@ -210,14 +215,14 @@ def join(self, other): join two headings """ assert isinstance(other, Heading) - attribute_list = [v._asdict() for v in self.attributes.values()] + attribute_list = [v.as_dict() for v in self.attributes.values()] for name in other.names: if name not in self.names: - attribute_list.append(other.attributes[name]._asdict()) + attribute_list.append(other.attributes[name].as_dict()) return Heading(attribute_list) def resolve_computations(self): """ Remove computations. To be done after computations have been resolved in a subquery """ - return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file + return Heading([dict(v.as_dict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file diff --git a/datajoint/relational.py b/datajoint/relational.py index 94e0d04af..cdcfa5458 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -263,17 +263,25 @@ def __init__(self, rel1, rel2): if rel1.conn is not rel2.conn: raise DataJointError('Cannot join relations with different database connections') self.conn = rel1.conn - self._rel1 = rel1 - self._rel2 = rel2 + self._rel1 = Subquery(rel1) + self._rel2 = Subquery(rel2) + + @property + def conn(self): + return self._rel1.conn @property def heading(self): return self._rel1.heading.join(self._rel2.heading) + @property + def counter(self): + self.subquery_counter += 1 + return self.subquery_counter + @property def sql(self): - Join.subquery_counter += 1 - return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.subquery_counter) + return '%s NATURAL JOIN %s as `_j%x`' % (self._rel1.sql, self._rel2.sql, self.counter) class Projection(Relation): @@ -284,7 +292,8 @@ def __init__(self, relation, group=None, *attributes, **renamed_attributes): See Relation.project() """ if group: - # TODO: assert that group.conn is same as relation.conn + if relation.conn is not group.conn: + raise DataJointError('Cannot join relations with different database connections') self._group = Subquery(group) self._relation = Subquery(relation) else: From 7801a36ecfc3e663f46e72db3de3037c019eba5d Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 7 May 2015 20:46:00 -0500 Subject: [PATCH 7/8] fixed Table.sql property --- datajoint/table.py | 11 ++++++----- demos/demo1.py | 13 +------------ 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/datajoint/table.py b/datajoint/table.py index 3966c40c3..4d23deec8 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -26,7 +26,7 @@ class Table(Relation): def __init__(self, conn=None, dbname=None, class_name=None, definition=None): self.class_name = class_name - self.conn = conn + self._conn = conn self.dbname = dbname self.conn.load_headings(self.dbname) @@ -42,13 +42,14 @@ def __init__(self, conn=None, dbname=None, class_name=None, definition=None): else: declare(conn, definition, class_name) + @property - def sql(self): - return self.full_table_name + def conn(self): + return self._conn @property - def heading(self): - return self.conn.headings[self.dbname][self.table_name] + def sql(self): + return self.full_table_name, self.conn.headings[self.dbname][self.table_name] @property def full_table_name(self): diff --git a/demos/demo1.py b/demos/demo1.py index 689905730..a601a0622 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -61,15 +61,4 @@ class Scan(dj.Base): depth : float # depth from surface wavelength : smallint # (nm) laser wavelength mwatts: numeric(4,1) # (mW) laser power to brain - """ - - -class ScanInfo(dj.Base, dj.AutoPopulate): - definition = None - pop_rel = Session - - def make_tuples(self, key): - info = (Session()*Scan() & key).pro('experiment_folder').fetch() - filename = os.path.join(info.experiment_folder, 'scan_%03', ) - - + """ \ No newline at end of file From 959f6b7c75368c86c689cb3cdcfddb7f053fc2cd Mon Sep 17 00:00:00 2001 From: Edgar Walker Date: Fri, 8 May 2015 03:16:31 -0500 Subject: [PATCH 8/8] Implement Table instantiation using definitions with direct database ref --- datajoint/autopopulate.py | 10 +++++-- datajoint/base.py | 3 +- datajoint/table.py | 17 ++++++----- tests/schemata/schema1/test1.py | 11 +++++-- tests/test_base.py | 11 ++++++- tests/test_table.py | 52 ++++++++++++++++++++++++++++++++- 6 files changed, 88 insertions(+), 16 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 8e32ef43d..fa0e9300b 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -2,9 +2,11 @@ from . import DataJointError import pprint import abc +import logging #noinspection PyExceptionInherit,PyCallingNonCallable +logger = logging.getLogger(__name__) class AutoPopulate(metaclass=abc.ABCMeta): """ @@ -29,6 +31,10 @@ def make_tuples(self, key): """ pass + @property + def target(self): + return self + def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): """ rel.populate() will call rel.make_tuples(key) for every primary key in self.pop_rel @@ -38,9 +44,9 @@ def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): raise DataJointError('') self.conn.cancel_transaction() - unpopulated = self.pop_rel - self + unpopulated = self.pop_rel - self.target if not unpopulated.count: - print('Nothing to populate', flush=True) # TODO: use logging? + logger.info('Nothing to populate', flush=True) if catch_errors: error_keys, errors = [], [] for key in unpopulated.fetch(): diff --git a/datajoint/base.py b/datajoint/base.py index 45370c182..0fd7db7dd 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -49,7 +49,7 @@ def full_class_name(self): return '{}.{}'.format(self.__module__, self.class_name) @property - def access_name(self): + def ref_name(self): """ :return: name by which this class should be accessible as """ @@ -60,7 +60,6 @@ def access_name(self): return parent + '.' + self.class_name - def __init__(self): #TODO: support taking in conn obj self.class_name = self.__class__.__name__ module = self.__module__ diff --git a/datajoint/table.py b/datajoint/table.py index b9a63f383..bc7c73273 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -350,7 +350,12 @@ def _parse_attr_def(self, line, in_key=False): return Heading.AttrTuple(**attr_info) def get_base(self, module_name, class_name): - return None + m = re.match(r'`(\w+)`', module_name) + if m: + dbname = m.group(1) + return Table(self.conn, dbname, class_name) + else: + return None @property def ref_name(self): @@ -367,12 +372,8 @@ def _declare(self): raise DataJointError('Table definition is missing!') table_info, parents, referenced, field_defs, index_defs = self._parse_declaration() defined_name = table_info['module'] + '.' + table_info['className'] - if self._use_package: - parent = self.__module__.split('.')[-2] - else: - parent = self.__module__.split('.')[-1] - expected_name = parent + '.' + self.class_name - if not defined_name == expected_name: + + if not defined_name == self.ref_name: raise DataJointError('Table name {} does not match the declared' 'name {}'.format(expected_name, defined_name)) @@ -466,7 +467,7 @@ def _parse_declaration(self): # remove comment lines declaration = [x for x in declaration if not x.startswith('#')] ptrn = """ - ^(?P\w+)\.(?P\w+)\s* # module.className + ^(?P[\w\`]+)\.(?P\w+)\s* # module.className \(\s*(?P\w+)\s*\)\s* # (tier) \#\s*(?P.*)$ # comment """ diff --git a/tests/schemata/schema1/test1.py b/tests/schemata/schema1/test1.py index d0d276e6d..de36c821a 100644 --- a/tests/schemata/schema1/test1.py +++ b/tests/schemata/schema1/test1.py @@ -27,9 +27,9 @@ class Experiments(dj.Base): """ # refers to a table in dj_test2 (bound to test2) but without a class -class Session(dj.Base): +class Sessions(dj.Base): definition = """ - test1.Session (manual) # Experiment sessions + test1.Sessions (manual) # Experiment sessions -> test1.Subjects -> test2.Experimenter session_id : int # unique session id @@ -45,6 +45,13 @@ class Match(dj.Base): dob : date # date of birth """ +# this tries to reference a table in database directly without ORM +class TrainingSession(dj.Base): + definition = """ + test1.TrainingSession (manual) # training sessions + -> `dj_test2`.Experimenter + session_id : int # training session id + """ class Empty(dj.Base): pass diff --git a/tests/test_base.py b/tests/test_base.py index 49a5b6810..d4addbbc0 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -178,7 +178,7 @@ def test_detection_of_existing_table(self): assert_true(s.is_declared) def test_definition_referring_to_existing_table_without_class(self): - s1 = test1.Session() + s1 = test1.Sessions() assert_true('experimenter_id' in s1.primary_key) s2 = test2.Session() @@ -189,6 +189,15 @@ def test_reference_to_package_level_table(self): s.declare() assert_true('pop_id' in s.primary_key) + def test_direct_reference_to_existing_table_should_fail(self): + """ + When deriving from Base, definition should not contain direct reference + to a database name + """ + s = test1.TrainingSession() + with assert_raises(DataJointError): + s.declare() + @raises(TypeError) def test_instantiation_of_base_derivative_without_definition_should_fail(): test1.Empty() diff --git a/tests/test_table.py b/tests/test_table.py index a8cf35133..bfa03a4c6 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -8,6 +8,7 @@ from datajoint import DataJointError import numpy as np from numpy.testing import assert_array_equal +from datajoint.table import Table def setup(): """ @@ -121,4 +122,53 @@ def test_blob_insert(self): t = (0, x, 'this is a random image') self.relvar_blob.insert(t) x2 = self.relvar_blob.fetch()[0][1] - assert_array_equal(x,x2, 'inserted blob does not match') \ No newline at end of file + assert_array_equal(x,x2, 'inserted blob does not match') + +class TestUnboundTables(object): + """ + Test usages of Table objects not connected to a module. + """ + def setup(self): + cleanup() + self.conn = Connection(**CONN_INFO) + + def test_creation_from_definition(self): + definition = """ + `dj_free`.Animals (manual) # my animal table + animal_id : int # unique id for the animal + --- + animal_name : varchar(128) # name of the animal + """ + table = Table(self.conn, 'dj_free', 'Animals', definition) + table.declare() + assert_true('animal_id' in table.primary_key) + + def test_reference_to_non_existant_table_should_fail(self): + definition = """ + `dj_free`.Recordings (manual) # recordings + -> `dj_free`.Animals + rec_session_id : int # recording session identifier + """ + table = Table(self.conn, 'dj_free', 'Recordings', definition) + assert_raises(DataJointError, table.declare) + + def test_reference_to_existing_table(self): + definition1 = """ + `dj_free`.Animals (manual) # my animal table + animal_id : int # unique id for the animal + --- + animal_name : varchar(128) # name of the animal + """ + table1 = Table(self.conn, 'dj_free', 'Animals', definition1) + table1.declare() + + definition2 = """ + `dj_free`.Recordings (manual) # recordings + -> `dj_free`.Animals + rec_session_id : int # recording session identifier + """ + table2 = Table(self.conn, 'dj_free', 'Recordings', definition2) + table2.declare() + assert_true('animal_id' in table2.primary_key) + +