diff --git a/datajoint/connection.py b/datajoint/connection.py index d94d404bd..01a285769 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -83,28 +83,6 @@ def __eq__(self, other): """ return self.conn_info == other.conn_info - def is_same(self, host, user): - """ - true if the connection host and user name are the same - """ - if host is None: - host = self.conn_info['host'] - port = self.conn_info['port'] - else: - try: - host, port = host.split(':') - port = int(port) - except ValueError: - port = default_port - - if user is None: - user = self.conn_info['user'] - - return self.conn_info['host'] == host and \ - self.conn_info['port'] == port and \ - self.conn_info['user'] == user - - @property def is_connected(self): return self._conn.ping() @@ -180,14 +158,14 @@ def _load_headings(self, dbname, force=False): Setting force=True will result in reloading of the heading even if one already exists. """ - if not dbname in self.headings or force: + if dbname not in self.headings or force: logger.info('Loading table definitions from `{dbname}`...'.format(dbname=dbname)) self.table_names[dbname] = {} self.headings[dbname] = {} self.tableInfo[dbname] = {} cur = self.query('SHOW TABLE STATUS FROM `{dbname}` WHERE name REGEXP "{sqlPtrn}"'.format( - dbname=dbname, sqlPtrn=table_name_regexp_sql.pattern), asDict=True) + dbname=dbname, sqlPtrn=table_name_regexp_sql.pattern), as_dict=True) for info in cur: info = {k.lower(): v for k, v in info.items()} # lowercase it @@ -202,27 +180,27 @@ def _load_headings(self, dbname, force=False): def load_dependencies(self, dbname): # TODO: Perhaps consider making this "private" by preceding with underscore? """ - load dependencies (foreign keys) between tables by examnining their + load dependencies (foreign keys) between tables by examining their respective CREATE TABLE statements. 
""" - ptrn = r""" - FOREIGN\ KEY\s+\((?P[`\w ,]+)\)\s+ # list of keys in this table + foreign_key_regexp = re.compile(r""" + FOREIGN KEY\s+\((?P[`\w ,]+)\)\s+ # list of keys in this table REFERENCES\s+(?P[^\s]+)\s+ # table referenced \((?P[`\w ,]+)\) # list of keys in the referenced table - """ + """, re.X) logger.info('Loading dependencies for `{dbname}`'.format(dbname=dbname)) for tabName in self.tableInfo[dbname]: cur = self.query('SHOW CREATE TABLE `{dbname}`.`{tabName}`'.format(dbname=dbname, tabName=tabName), - asDict=True) + as_dict=True) table_def = cur.fetchone() full_table_name = '`%s`.`%s`' % (dbname, tabName) self.parents[full_table_name] = [] self.referenced[full_table_name] = [] - for m in re.finditer(ptrn, table_def["Create Table"], re.X): # iterate through foreign key statements + for m in foreign_key_regexp.finditer(table_def["Create Table"]): # iterate through foreign key statements assert m.group('attr1') == m.group('attr2'), \ 'Foreign keys must link identically named attributes' attrs = m.group('attr1') @@ -234,11 +212,7 @@ def load_dependencies(self, dbname): # TODO: Perhaps consider making this "priv if not re.search(r'`\.`', ref): # if referencing other table in same schema ref = '`%s`.%s' % (dbname, ref) # convert to full-table name - if is_primary: - self.parents[full_table_name].append(ref) - else: - self.referenced[full_table_name].append(ref) - + (self.parents if is_primary else self.referenced)[full_table_name].append(ref) self.parents.setdefault(ref, []) self.referenced.setdefault(ref, []) @@ -298,7 +272,6 @@ def __del__(self): logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) self._conn.close() - def erd(self, databases=None, tables=None, fill=True, reload=True): """ Creates Entity Relation Diagram for the database or specified subset of @@ -321,14 +294,14 @@ def erd(self, databases=None, tables=None, fill=True, reload=True): graph.plot() - def query(self, query, args=(), asDict=False): + def query(self, 
query, args=(), as_dict=False): """ Execute the specified query and return the tuple generator. - If asDict is set to True, the returned cursor objects returns + If as_dict is set to True, the returned cursor objects returns query results as dictionary. """ - cursor = pymysql.cursors.DictCursor if asDict else pymysql.cursors.Cursor + cursor = pymysql.cursors.DictCursor if as_dict else pymysql.cursors.Cursor cur = self._conn.cursor(cursor=cursor) # Log the query @@ -343,4 +316,4 @@ def cancel_transaction(self): self.query('ROLLBACK') def commit_transaction(self): - self.query('COMMIT') + self.query('COMMIT') \ No newline at end of file diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py index 9cf387615..d21e17580 100644 --- a/datajoint/free_relation.py +++ b/datajoint/free_relation.py @@ -28,7 +28,7 @@ class FreeRelation(RelationalOperand): even after tables are modified after the instance is created. """ - def __init__(self, conn=None, dbname=None, class_name=None, definition=None): + def __init__(self, conn, dbname, class_name=None, definition=None): self.class_name = class_name self._conn = conn self.dbname = dbname @@ -38,14 +38,21 @@ def __init__(self, conn=None, dbname=None, class_name=None, definition=None): # register with a fake module, enclosed in back quotes # necessary for loading mechanism self.conn.bind('`{0}`'.format(dbname), dbname) + super().__init__(conn) @property - def definition(self): - return self._definition + def from_clause(self): + return self.full_table_name @property - def conn(self): - return self._conn + def heading(self): + self.declare() + return self.conn.headings[self.dbname][self.table_name] + + + @property + def definition(self): + return self._definition @property def is_declared(self): @@ -91,15 +98,6 @@ def _field_to_sql(field): #TODO move this into Attribute Tuple return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( name=field.name, type=field.type, default=default, comment=field.comment) - 
@property - def sql(self): - return self.full_table_name, self.heading - - @property - def heading(self): - self.declare() - return self.conn.headings[self.dbname][self.table_name] - @property def full_table_name(self): """ @@ -112,9 +110,7 @@ def table_name(self): """ :return: name of the associated table """ - return self.conn.table_names[self.dbname][self.class_name] - - + return self.conn.table_names[self.dbname][self.class_name] if self.is_declared else None @property def primary_key(self): @@ -123,7 +119,6 @@ def primary_key(self): """ return self.heading.primary_key - def iter_insert(self, iter, **kwargs): """ Inserts an entire batch of entries. Additional keyword arguments are passed to insert. @@ -139,7 +134,7 @@ def batch_insert(self, data, **kwargs): :param data: must be iterable, each row must be a valid argument for insert """ - self.iter_insert(data.__iter__()) + self.iter_insert(data.__iter__(), **kwargs) def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress (issue #8) """ @@ -157,24 +152,26 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress """ if isinstance(tup, tuple) or isinstance(tup, list) or isinstance(tup, np.ndarray): - value_list = ','.join([repr(val) if not name in self.heading.blobs else '%s' - for name, val in zip(self.heading.names, tup)]) + value_list = ','.join([repr(val) if name not in self.heading.blobs else '%s' + for name, val in zip(self.heading.names, tup)]) args = tuple(pack(val) for name, val in zip(self.heading.names, tup) if name in self.heading.blobs) attribute_list = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`' elif isinstance(tup, dict): - value_list = ','.join([repr(tup[name]) if not name in self.heading.blobs else '%s' - for name in self.heading.names if name in tup]) + value_list = ','.join([repr(tup[name]) if name not in self.heading.blobs else '%s' + for name in self.heading.names if name in tup]) args = tuple(pack(tup[name]) for name in 
self.heading.names - if (name in tup and name in self.heading.blobs) ) - attribute_list = '`' + '`,`'.join([name for name in self.heading.names if name in tup]) + '`' + if name in tup and name in self.heading.blobs) + attribute_list = '`' + '`,`'.join( + [name for name in self.heading.names if name in tup]) + '`' elif isinstance(tup, np.void): - value_list = ','.join([repr(tup[name]) if not name in self.heading.blobs else '%s' - for name in self.heading.names if name in tup.dtype.fields]) + value_list = ','.join([repr(tup[name]) if name not in self.heading.blobs else '%s' + for name in self.heading.names if name in tup.dtype.fields]) args = tuple(pack(tup[name]) for name in self.heading.names - if (name in tup.dtype.fields and name in self.heading.blobs) ) - attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup.dtype.fields]) + '`' + if name in tup.dtype.fields and name in self.heading.blobs) + attribute_list = '`' + '`,`'.join( + [q for q in self.heading.names if q in tup.dtype.fields]) + '`' else: raise DataJointError('Datatype %s cannot be inserted' % type(tup)) if replace: @@ -188,25 +185,26 @@ def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress logger.info(sql) self.conn.query(sql, args=args) - def delete(self): # TODO: (issues #14 and #15) - pass + def delete(self): + # TODO: make cascading (issue #15) + self.conn.query('DELETE FROM ' + self.from_clause + self.where_clause) def drop(self): """ Drops the table associated to this object. 
""" # TODO: make cascading (issue #16) - self.conn.query('DROP TABLE %s' % self.full_table_name) - self.conn.clear_dependencies(dbname=self.dbname) - self.conn.load_headings(dbname=self.dbname, force=True) - logger.debug("Dropped table %s" % self.full_table_name) + + if self.is_declared: + self.conn.query('DROP TABLE %s' % self.full_table_name) + self.conn.clear_dependencies(dbname=self.dbname) + self.conn.load_headings(dbname=self.dbname, force=True) + logger.info("Dropped table %s" % self.full_table_name) def set_table_comment(self, comment): """ Update the table comment in the table definition. - :param comment: new comment as string - """ # TODO: add verification procedure (github issue #24) self.alter('COMMENT="%s"' % comment) @@ -303,12 +301,11 @@ def _parse_index_def(self, line): :return: groupdict with index info """ line = line.strip() - index_ptrn = """ + index_regexp = re.compile(""" ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX \((?P[^\)]+)\)$ # (attr1, attr2) - """ - indexP = re.compile(index_ptrn, re.I + re.X) - m = indexP.match(line) + """, re.I + re.X) + m = index_regexp.match(line) assert m, 'Invalid index declaration "%s"' % line index_info = m.groupdict() attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) @@ -317,52 +314,9 @@ def _parse_index_def(self, line): 'Duplicate attributes in index declaration "%s"' % line return index_info - def _parse_attr_def(self, line, in_key=False): - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. - - :param line: attribution line - :param in_key: set to True if attribute is in primary key set - :returns: attribute tuple - """ - line = line.strip() - attr_ptrn = """ - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? 
# default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """ - - attrP = re.compile(attr_ptrn, re.I + re.X) - m = attrP.match(line) - assert m, 'Invalid field declaration "%s"' % line - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['nullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - attr_info['in_key'] = in_key - attr_info['autoincrement'] = None - attr_info['numeric'] = None - attr_info['string'] = None - attr_info['is_blob'] = None - attr_info['computation'] = None - attr_info['dtype'] = None - - return Heading.AttrTuple(**attr_info) - def get_base(self, module_name, class_name): m = re.match(r'`(\w+)`', module_name) - if m: - dbname = m.group(1) - return FreeRelation(self.conn, dbname, class_name) - else: - return None + return FreeRelation(self.conn, m.group(1), class_name) if m else None @property def ref_name(self): @@ -460,7 +414,6 @@ def _declare(self): self.conn.query(sql) self.conn.load_headings(self.dbname, force=True) - def _parse_declaration(self): """ Parse declaration and create new SQL table accordingly. @@ -507,9 +460,47 @@ def _parse_declaration(self): elif re.match(r'^(unique\s+)?index[^:]*$', line): index_defs.append(self._parse_index_def(line)) elif fieldP.match(line): - field_defs.append(self._parse_attr_def(line, in_key)) + field_defs.append(parse_attribute_definition(line, in_key)) else: raise DataJointError( 'Invalid table declaration line "%s"' % line) return table_info, parents, referenced, field_defs, index_defs + + +def parse_attribute_definition(line, in_key=False): + """ + Parse attribute definition line in the declaration and returns + an attribute tuple. 
+ + :param line: attribution line + :param in_key: set to True if attribute is in primary key set + :returns: attribute tuple + """ + line = line.strip() + attribute_regexp = re.compile(""" + ^(?P<name>[a-z][a-z\d_]*)\s* # field name + (=\s*(?P<default>\S+(\s+\S+)*?)\s*)? # default value + :\s*(?P<type>\w[^\#]*[^\#\s])\s* # datatype + (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$ # comment + """, re.X) + m = attribute_regexp.match(line) + assert m, 'Invalid field declaration "%s"' % line + attr_info = m.groupdict() + if not attr_info['comment']: + attr_info['comment'] = '' + if not attr_info['default']: + attr_info['default'] = '' + attr_info['nullable'] = attr_info['default'].lower() == 'null' + assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ + 'BIGINT attributes cannot be nullable in "%s"' % line + + attr_info['in_key'] = in_key + attr_info['autoincrement'] = None + attr_info['numeric'] = None + attr_info['string'] = None + attr_info['is_blob'] = None + attr_info['computation'] = None + attr_info['dtype'] = None + + return Heading.AttrTuple(**attr_info) \ No newline at end of file diff --git a/datajoint/heading.py b/datajoint/heading.py index 5406cf271..4f108c25c 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -15,7 +15,7 @@ class Heading: ('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement', 'numeric', 'string', 'is_blob', 'computation', 'dtype')) - AttrTuple.as_dict = AttrTuple._asdict # rename the method into a nicer name + AttrTuple.as_dict = AttrTuple._asdict # renaming to make public def __init__(self, attributes): """ @@ -97,7 +97,7 @@ def init_from_database(cls, conn, dbname, table_name): """ cur = conn.query( 'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format( - table_name=table_name, dbname=dbname), asDict=True) + table_name=table_name, dbname=dbname), as_dict=True) attributes = cur.fetchall() rename_map = { @@ -210,7 +210,7 @@ def project(self, *attribute_list, **renamed_attributes): return 
Heading(attribute_list) - def join(self, other): + def __add__(self, other): """ join two headings """ @@ -221,8 +221,8 @@ def join(self, other): attribute_list.append(other.attributes[name].as_dict()) return Heading(attribute_list) - def resolve_computations(self): + def resolve(self): """ - Remove computations. To be done after computations have been resolved in a subquery + Remove attribute computations after they have been resolved in a subquery """ return Heading([dict(v.as_dict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file diff --git a/datajoint/relation.py b/datajoint/relation.py index 50a3b9661..40af819ef 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -47,15 +47,11 @@ def full_class_name(self): @property def ref_name(self): """ - :return: name by which this class should be accessible as + :return: name by which this class should be accessible """ - if self._use_package: - parent = self.__module__.split('.')[-2] - else: - parent = self.__module__.split('.')[-1] + parent = self.__module__.split('.')[-2 if self._use_package else -1] return parent + '.' + self.class_name - def __init__(self): #TODO: support taking in conn obj class_name = self.__class__.__name__ module_name = self.__module__ @@ -87,7 +83,6 @@ def __init__(self): #TODO: support taking in conn obj # initialize using super class's constructor super().__init__(conn, dbname, class_name) - def get_base(self, module_name, class_name): """ Loads the base relation from the module. 
If the base relation is not defined in @@ -105,8 +100,8 @@ def get_base(self, module_name, class_name): ret = getattr(mod_obj, class_name)() except AttributeError: ret = FreeRelation(conn=self.conn, - dbname=self.conn.mod_to_db[mod_obj.__name__], - class_name=class_name) + dbname=self.conn.mod_to_db[mod_obj.__name__], + class_name=class_name) return ret @classmethod @@ -142,4 +137,4 @@ def get_module(cls, module_name): try: return importlib.import_module(module_name) except ImportError: - return None + return None \ No newline at end of file diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index e5ea820ab..0f59180a3 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -21,30 +21,28 @@ class RelationalOperand(metaclass=abc.ABCMeta): When fetching data from the database, this tree of objects is compiled into an SQL expression. It is a mixin class that provides relational operators, iteration, and fetch capability. RelationalOperand operators are: restrict, pro, and join. - """ - _restrictions = [] + """ - @abc.abstractproperty - def sql(self): - """ - The sql property returns the tuple: (SQL command, Heading object) for its relation. - The SQL command does not include the attribute list or the WHERE clause. - :return: sql, heading - """ - pass + def __init__(self, conn, restrictions=None): + self._conn = conn + self._restrictions = [] if restrictions is None else restrictions - @abc.abstractproperty + @property def conn(self): - """ - All relations must keep track of their connection object - :return: - """ - pass + return self._conn @property def restrictions(self): return self._restrictions - + + @abc.abstractproperty + def from_clause(self): + pass + + @abc.abstractproperty + def heading(self): + pass + def __mul__(self, other): """ relational join @@ -69,9 +67,10 @@ def project(self, *attributes, **renamed_attributes): relation cannot have more attributes than the original relation. 
""" # if the first attribute is a relation, it will be aggregated - group = attributes.pop[0] \ - if attributes and isinstance(attributes[0], RelationalOperand) else None - return self.aggregate(group, *attributes, **renamed_attributes) + group = None + if attributes and isinstance(attributes[0], RelationalOperand): + group = attributes.pop[0] + return Projection(self, group, *attributes, **renamed_attributes) def aggregate(self, _group, *attributes, **renamed_attributes): """ @@ -82,19 +81,7 @@ def aggregate(self, _group, *attributes, **renamed_attributes): """ if _group is not None and not isinstance(_group, RelationalOperand): raise DataJointError('The second argument must be a relation or None') - alias_parser = re.compile( - '^\s*(?P\S(.*\S)?)\s*->\s*(?P[a-z][a-z_0-9]*)\s*$') - - # expand extended attributes in the form 'sql_expression -> new_attribute' - _attributes = [] - for attribute in attributes: - alias_match = alias_parser.match(attribute) - if alias_match: - d = alias_match.group_dict() - renamed_attributes.update({d['alias']: d['sql_expression']}) - else: - _attributes += attribute - return Projection(self, _group, *_attributes, **renamed_attributes) + return Projection(self, _group, *attributes, **renamed_attributes) def __iand__(self, restriction): """ @@ -129,9 +116,14 @@ def __sub__(self, restriction): """ return self & Not(restriction) + def make_select(self, attribute_spec=None): + if attribute_spec is None: + attribute_spec = self.heading.as_sql + return 'SELECT ' + attribute_spec + ' FROM ' + self.from_clause + self.where_clause + @property def count(self): - cur = self.conn.query('SELECT count(*) FROM ' + self.sql[0] + self._where) + cur = self.conn.query(self.make_select('count(*)')) return cur.fetchone()[0] def __call__(self, *args, **kwargs): @@ -147,8 +139,9 @@ def fetch(self, offset=0, limit=None, order_by=None, descending=False): :return: the contents of the relation in the form of a structured numpy.array """ cur = self.cursor(offset, 
limit, order_by, descending) - ret = np.array(list(cur), dtype=self.heading.as_dtype) - for f in self.heading.blobs: + heading = self.heading + ret = np.array(list(cur), dtype=heading.as_dtype) + for f in heading.blobs: for i in range(len(ret)): ret[i][f] = unpack(ret[i][f]) return ret @@ -162,11 +155,10 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False): :return: cursor to the query """ if offset and limit is None: - raise DataJointError('') - sql, heading = self.sql - sql = 'SELECT ' + heading.as_sql + ' FROM ' + sql + raise DataJointError('offset cannot be set without setting a limit') + sql = self.make_select() if order_by is not None: - sql += ' ORDER BY ' + ', '.join(self._orderBy) + sql += ' ORDER BY ' + ', '.join(order_by) if descending: sql += ' DESC' if limit is not None: @@ -177,17 +169,18 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False): return self.conn.query(sql) def __repr__(self): - limit = 13 #TODO: move some of these display settings into the config - width = 12 + limit = 7 #TODO: move some of these display settings into the config + width = 14 + rel = self.project(*self.heading.non_blobs) template = '%%-%d.%ds' % (width, width) - repr_string = ' '.join([template % column for column in self.heading]) + '\n' - repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in self.heading]) + '\n' - tuples = self.project(*self.heading.non_blobs).fetch(limit=limit) - for tup in tuples: + columns = rel.heading.names + repr_string = ' '.join([template % column for column in columns]) + '\n' + repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in columns]) + '\n' + for tup in rel.fetch(limit=limit): repr_string += ' '.join([template % column for column in tup]) + '\n' if self.count > limit: repr_string += '...\n' - repr_string += '%d tuples\n' % self.count + repr_string += ' (%d tuples)\n' % self.count return repr_string def __iter__(self): @@ -200,11 +193,11 @@ def __iter__(self): cur, h = 
self.project().cursor() # project q = cur.fetchone() while q: - yield np.array([q, ], dtype=h.asdtype) + yield np.array([q, ], dtype=h.as_dtype) q = cur.fetchone() @property - def _where(self): + def where_clause(self): """ convert the restriction into an SQL WHERE """ @@ -224,14 +217,14 @@ def make_condition(arg): for r in self._restrictions: negate = isinstance(r, Not) if negate: - r = r.restrictions + r = r.restriction if isinstance(r, dict) or isinstance(r, np.void): r = make_condition(r) elif isinstance(r, np.ndarray) or isinstance(r, list): r = '('+') OR ('.join([make_condition(q) for q in r])+')' elif isinstance(r, RelationalOperand): common_attributes = ','.join([q for q in self.heading.names if r.heading.names]) - r = '(%s) in (SELECT %s FROM %s)' % (common_attributes, common_attributes, r.sql) + r = '(%s) in (SELECT %s FROM %s)' % (common_attributes, common_attributes, r.from_clause) assert isinstance(r, str), 'condition must be converted into a string' r = '('+r+')' @@ -244,77 +237,83 @@ def make_condition(arg): class Not: """ - inverse of a restriction + inverse restriction """ def __init__(self, restriction): - self.__restriction = restriction - - @property - def restriction(self): - return self.__restriction + self.restriction = restriction class Join(RelationalOperand): - subquery_counter = 0 + __counter = 0 - def __init__(self, rel1, rel2): - if not isinstance(rel2, RelationalOperand): + def __init__(self, arg1, arg2): + if not isinstance(arg2, RelationalOperand): raise DataJointError('a relation can only be joined with another relation') - if rel1.conn is not rel2.conn: + if arg1.conn != arg2.conn: raise DataJointError('Cannot join relations with different database connections') - self.conn = rel1.conn - self._rel1 = Subquery(rel1) - self._rel2 = Subquery(rel2) + self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1 + self._arg2 = Subquery(arg2) if arg2.heading.computed else arg2 + super().__init__(arg1.conn, self._arg1.restrictions + 
self._arg2.restrictions) @property - def conn(self): - return self._rel1.conn + def counter(self): + self.__counter += 1 + return self.__counter @property def heading(self): - return self._rel1.heading.join(self._rel2.heading) + return self._arg1.heading + self._arg2.heading @property - def counter(self): - self.subquery_counter += 1 - return self.subquery_counter - - @property - def sql(self): - return '%s NATURAL JOIN %s as `_j%x`' % (self._rel1.sql, self._rel2.sql, self.counter) + def from_clause(self): + return '%s NATURAL JOIN %s' % (self._arg1.from_clause, self._arg2.from_clause) class Projection(RelationalOperand): - subquery_counter = 0 - def __init__(self, relation, group=None, *attributes, **renamed_attributes): + def __init__(self, arg, group=None, *attributes, **renamed_attributes): """ See RelationalOperand.project() """ + alias_parser = re.compile( + '^\s*(?P<sql_expression>\S(.*\S)?)\s*->\s*(?P<alias>[a-z][a-z_0-9]*)\s*$') + # expand extended attributes in the form 'sql_expression -> new_attribute' + self._attributes = [] + self._renamed_attributes = renamed_attributes + for attribute in attributes: + alias_match = alias_parser.match(attribute) + if alias_match: + d = alias_match.groupdict() + self._renamed_attributes.update({d['alias']: d['sql_expression']}) + else: + self._attributes.append(attribute) + super().__init__(arg.conn) if group: - if relation.conn is not group.conn: + if arg.conn != group.conn: raise DataJointError('Cannot join relations with different database connections') self._group = Subquery(group) - self._relation = Subquery(relation) + self._arg = Subquery(arg) else: self._group = None - self._relation = relation - self._projection_attributes = attributes - self._renamed_attributes = renamed_attributes + if arg.heading.computed: + self._arg = Subquery(arg) + else: + # project without subquery + self._arg = arg + self._restrictions = self._arg.restrictions @property - def conn(self): - return self._relation.conn + def heading(self): + return 
self._arg.heading.project(*self._attributes, **self._renamed_attributes) @property - def sql(self): - sql, heading = self._relation.sql - heading = heading.project(self._projection_attributes, self._renamed_attributes) - if self._group is not None: - group_sql, group_heading = self._group.sql - sql = ("(%s) NATURAL LEFT JOIN (%s) GROUP BY `%s`" % - (sql, group_sql, '`,`'.join(heading.primary_key))) - return sql, heading + def from_clause(self): + if self._group is None: + return self._arg.from_clause + else: + return "(%s) NATURAL LEFT JOIN (%s) GROUP BY `%s`" % ( + self._arg.from_clause, self._group.from_clause, + '`,`'.join(self.heading.primary_key)) class Subquery(RelationalOperand): @@ -322,22 +321,21 @@ class Subquery(RelationalOperand): A Subquery encapsulates its argument in a SELECT statement, enabling its use as a subquery. The attribute list and the WHERE clause are resolved. """ - _counter = 0 + __counter = 0 - def __init__(self, rel): - self._rel = rel + def __init__(self, arg): + self._arg = arg + super().__init__(arg.conn) @property - def conn(self): - return self._rel.conn + def counter(self): + Subquery.__counter += 1 + return Subquery.__counter @property - def counter(self): - Subquery._counter += 1 - return Subquery._counter + def from_clause(self): + return '(' + self._arg.make_select() + ') as `_s%x`' % self.counter @property - def sql(self): - return ('(SELECT ' + self._rel.heading.as_sql + - ' FROM ' + self._rel.sql + self._rel.where + ') as `_s%x`' % self.counter),\ - self._rel.heading.clear_aliases() \ No newline at end of file + def heading(self): + return self._arg.heading.resolve() \ No newline at end of file diff --git a/demos/rundemo1.py b/demos/rundemo1.py index 88a8f4502..2f9dc650a 100644 --- a/demos/rundemo1.py +++ b/demos/rundemo1.py @@ -4,39 +4,102 @@ @author: dimitri """ - +import logging import demo1 -s = demo1.Subject() -e = demo1.Experiment() - -s.insert(dict(subject_id=1, - real_id="George", - species="monkey", - 
date_of_birth="2011-01-01", - sex="M", - caretaker="Arthur", - animal_notes="this is a test")) - -s.insert(dict(subject_id=2, - real_id='1373', - date_of_birth="2014-08-01", - caretaker="Joe")) - -s.insert((3, 'Dennis', 'monkey', '2012-09-01')) -s.insert((12430, 'C0430', 'mouse', '2012-09-01', 'M')) -s.insert((12431, 'C0431', 'mouse', '2012-09-01', 'F')) - -print('inserted keys into Subject:') -for key in s: - print(key) - -e.insert(dict(subject_id=1, - experiment=1, - experiment_date="2014-08-28", - experiment_notes="my first experiment")) - -e.insert(dict(subject_id=1, - experiment=2, - experiment_date="2014-08-28", - experiment_notes="my second experiment")) +logging.basicConfig(level=logging.DEBUG) + +subject = demo1.Subject() +experiment = demo1.Experiment() +session = demo1.Session() +scan = demo1.Scan() + +scan.drop() +session.drop() +experiment.drop() +subject.drop() + +subject.insert(dict(subject_id=1, + real_id="George", + species="monkey", + date_of_birth="2011-01-01", + sex="M", + caretaker="Arthur", + animal_notes="this is a test")) + +subject.insert(dict(subject_id=2, + real_id='1373', + date_of_birth="2014-08-01", + caretaker="Joe")) + +subject.insert((3, 'Alice', 'monkey', '2012-09-01')) +subject.insert((4, 'Dennis', 'monkey', '2012-09-01')) +subject.insert((5, 'Warren', 'monkey', '2012-09-01')) +subject.insert((6, 'Franky', 'monkey', '2012-09-01')) +subject.insert((7, 'Simon', 'monkey', '2012-09-01', 'F')) +subject.insert((8, 'Ferocious', 'monkey', '2012-09-01', 'M')) +subject.insert((9, 'Simon', 'monkey', '2012-09-01', 'm')) +subject.insert((10, 'Ferocious', 'monkey', '2012-09-01', 'F')) +subject.insert((11, 'Simon', 'monkey', '2012-09-01', 'm')) +subject.insert((12, 'Ferocious', 'monkey', '2012-09-01', 'M')) +subject.insert((13, 'Dauntless', 'monkey', '2012-09-01', 'F')) +subject.insert((14, 'Dawn', 'monkey', '2012-09-01', 'F')) + +subject.insert((12430, 'C0430', 'mouse', '2012-09-01', 'M')) +subject.insert((12431, 'C0431', 'mouse', '2012-09-01', 
'F')) + +print(subject) +print(subject.project()) +print(subject.project(name='real_id', dob='date_of_birth', sex='sex') & 'sex="M"') + +(subject & dict(subject_id=12431)).delete() +print(subject) + +experiment.insert(dict( + subject_id=1, + experiment=1, + experiment_date="2014-08-28", + experiment_notes="my first experiment")) + +experiment.insert(dict( + subject_id=1, + experiment=2, + experiment_date="2014-08-28", + experiment_notes="my second experiment")) + +experiment.insert(dict( + subject_id=2, + experiment=1, + experiment_date="2015-05-01" +)) + +print(experiment) +print(experiment * subject) +print(subject & experiment) +print(subject - experiment) + +session.insert(dict( + subject_id=1, + experiment=2, + session_id=1, + setup=0, + lens="20x" +)) + +scan.insert(dict( + subject_id=1, + experiment=2, + session_id=1, + scan_id=1, + depth=250, + wavelength=980, + mwatts=30.5 +)) + +print((scan * experiment) % ('wavelength->lambda', 'experiment_date')) + +# cleanup +scan.drop() +session.drop() +experiment.drop() +subject.drop() \ No newline at end of file