From 90c597e06025c37087400255ff2f0714ebc1613c Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 1 May 2015 22:53:43 -0500 Subject: [PATCH 1/5] minor: changed logging levels to 'info' for create schema and create table messages --- datajoint/connection.py | 8 ++++---- demos/demo1.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/datajoint/connection.py b/datajoint/connection.py index 2653950a3..6be43a2a8 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -140,11 +140,11 @@ def bind(self, module, dbname): self.mod_to_db[module] = dbname elif count == 0: # Database doesn't exist, attempt to create - logger.warning("Database `{dbname}` could not be found. " - "Attempting to create the database.".format(dbname=dbname)) + logger.info("Database `{dbname}` could not be found. " + "Attempting to create the database.".format(dbname=dbname)) try: - cur = self.query("CREATE DATABASE `{dbname}`".format(dbname=dbname)) - logger.warning('Created database `{dbname}`.'.format(dbname=dbname)) + self.query("CREATE DATABASE `{dbname}`".format(dbname=dbname)) + logger.info('Created database `{dbname}`.'.format(dbname=dbname)) self.db_to_mod[dbname] = module self.mod_to_db[module] = dbname except pymysql.OperationalError: diff --git a/demos/demo1.py b/demos/demo1.py index 18e54a4bf..81335b65b 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -12,8 +12,6 @@ conn = dj.conn() # connect to database; conn must be defined in module namespace conn.bind(module=__name__, dbname='dj_test') # bind this module to the database - - class Subject(dj.Base): _table_def = """ demo1.Subject (manual) # Basic subject info From d0dd7b3d26ac9d522353aef4abd3d356a6970b5b Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sat, 2 May 2015 14:59:35 -0500 Subject: [PATCH 2/5] mostly PEP8 type of changes to prepare for new Relational --- datajoint/autopopulate.py | 29 +++-- datajoint/base.py | 14 +-- datajoint/fetch.py | 1 - datajoint/heading.py | 220 +++++++++++++++++++------------------- datajoint/relational.py | 112 ++++++++++--------- datajoint/settings.py | 4 +- datajoint/task.py | 53 --------- demos/demo1.py | 40 +++---- demos/rundemo1.py | 30 +++--- 9 files changed, 223 insertions(+), 280 deletions(-) delete mode 100644 datajoint/task.py diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 30996ac06..624c01f78 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -3,41 +3,40 @@ import abc #noinspection PyExceptionInherit,PyCallingNonCallable + + class AutoPopulate(metaclass=abc.ABCMeta): """ Class datajoint.AutoPopulate is a mixin that adds the method populate() to a dj.Relvar class. Auto-populated relvars must inherit from both datajoint.Relvar and datajoint.AutoPopulate, - must define the property popRel, and must define the callback method makeTuples. + must define the property pop_rel, and must define the callback method make_tuples. """ @abc.abstractproperty - def popRel(self): + def pop_rel(self): """ - Derived classes must implement the read-only property popRel (populate relation) which is the relational + Derived classes must implement the read-only property pop_rel (populate relation) which is the relational expression (a dj.Relvar object) that defines how keys are generated for the populate call. 
""" pass - @abc.abstractmethod - def makeTuples(self, key): + def make_tuples(self, key): """ - Derived classes must implement methods makeTuples that fetches data from parent tables, restricting by + Derived classes must implement methods make_tuples that fetches data from parent tables, restricting by the given key, computes dependent attributes, and inserts the new tuples into self. """ pass - - def populate(self, catchErrors=False, reserveJobs=False, restrict=None): + def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): """ - rel.populate() will call rel.makeTuples(key) for every primary key in self.popRel + rel.populate() will call rel.make_tuples(key) for every primary key in self.pop_rel for which there is not already a tuple in rel. """ - self.conn.cancel_transaction() # enumerate unpopulated keys - unpopulated = self.popRel + unpopulated = self.pop_rel if ~isinstance(unpopulated, _Relational): unpopulated = unpopulated() # instantiate @@ -47,7 +46,7 @@ def populate(self, catchErrors=False, reserveJobs=False, restrict=None): unpopulated = unpopulated(*args, **kwargs) # - self # TODO: implement antijoin # execute - if catchErrors: + if catch_errors: errKeys, errors = [], [] for key in unpopulated.fetch(): self.conn.start_transaction() @@ -59,15 +58,15 @@ def populate(self, catchErrors=False, reserveJobs=False, restrict=None): pprint.pprint(key) try: - self.makeTuples(key) + self.make_tuples(key) except Exception as e: self.conn.cancel_transaction() - if not catchErrors: + if not catch_errors: raise print(e) errors += [e] errKeys+= [key] else: self.conn.commit_transaction() - if catchErrors: + if catch_errors: return errors, errKeys diff --git a/datajoint/base.py b/datajoint/base.py index a43c60268..5d01ec1a4 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -72,7 +72,7 @@ class Base(_Relational): class Subjects(dj.Base): - _table_def = ''' + _definition = ''' test1.Subjects (manual) # Basic subject info subject_id : int # unique subject id @@ -93,7 +93,7 @@ def __init__(self, conn=None, dbname=None, class_name=None, table_def=None): self.class_name = class_name self.conn = conn self.dbname = dbname - self._table_def = table_def + self._definition = table_def # register with a fake module, enclosed in back quotes if dbname not in self.conn.db_to_mod: @@ -126,10 +126,10 @@ def __init__(self, conn=None, dbname=None, class_name=None, table_def=None): raise DataJointError( 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__)) - if hasattr(self, '_table_def'): - self._table_def = self._table_def + if hasattr(self, '_definition'): + self._definition = self._definition else: - self._table_def = None + self._definition = None # todo: do we support records and named tuples for tup? def insert(self, tup, ignore_errors=False, replace=False): @@ -421,7 +421,7 @@ def _declare(self): """ Declares the table in the data base if no table in the database matches this object. """ - if not self._table_def: + if not self._definition: raise DataJointError('Table declaration is missing!') table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration() defined_name = table_info['module'] + '.' 
+ table_info['className'] @@ -515,7 +515,7 @@ def _parse_declaration(self): referenced = [] index_defs = [] field_defs = [] - declaration = re.split(r'\s*\n\s*', self._table_def.strip()) + declaration = re.split(r'\s*\n\s*', self._definition.strip()) # remove comment lines declaration = [x for x in declaration if not x.startswith('#')] diff --git a/datajoint/fetch.py b/datajoint/fetch.py index e658bf1e6..5a89e55d7 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -16,7 +16,6 @@ class Fetch: """ Fetch defines callable objects that fetch data from a relation """ - def __init__(self, relational): self.rel = relational self._orderBy = None diff --git a/datajoint/heading.py b/datajoint/heading.py index ca7495db0..93be6ef1c 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -13,52 +13,51 @@ class Heading: """ - local class for relationals' headings. + local class for relations' headings. """ + AttrTuple = namedtuple('AttrTuple', + ('name', 'type', 'in_key', 'nullable', 'default', 'comment', 'autoincrement', + 'numeric', 'string', 'is_blob', 'computation', 'dtype')) - AttrTuple = namedtuple('AttrTuple',('name','type','isKey','isNullable', - 'default','comment','isAutoincrement','isNumeric','isString','isBlob', - 'computation','dtype')) - - def __init__(self, attrs): - # Input: attrs -list of dicts with attribute descriptions - self.attrs = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attrs]) + def __init__(self, attributes): + # Input: attributes -list of dicts with attribute descriptions + self.attributes = OrderedDict([(q['name'], Heading.AttrTuple(**q)) for q in attributes]) @property def names(self): - return [k for k in self.attrs] + return [k for k in self.attributes] @property def primary_key(self): - return [k for k,v in self.attrs.items() if v.isKey] + return [k for k, v in self.attributes.items() if v.in_key] @property def dependent_fields(self): - return [k for k,v in self.attrs.items() if not v.isKey] + return [k for k, v in self.attributes.items() if not v.in_key] @property def blobs(self): - return [k for k,v in self.attrs.items() if v.isBlob] + return [k for k, v in self.attributes.items() if v.is_blob] @property def non_blobs(self): - return [k for k,v in self.attrs.items() if not v.isBlob] + return [k for k, v in self.attributes.items() if not v.is_blob] @property def computed(self): - return [k for k,v in self.attrs.items() if v.computation] + return [k for k, v in self.attributes.items() if v.computation] - def __getitem__(self,name): + def __getitem__(self, name): """shortcut to the attribute""" - return self.attrs[name] + return self.attributes[name] def __repr__(self): - autoIncrementString = {False:'', True:' auto_increment'} + autoincrement_string = {False: '', True: ' auto_increment'} return '\n'.join(['%-20s : %-28s # %s' % ( - k if v.default is None else '%s="%s"'%(k,v.default), - '%s%s' % (v.type, autoIncrementString[v.isAutoincrement]), + k if v.default is None else '%s="%s"' % (k, v.default), + '%s%s' % (v.type, autoincrement_string[v.autoincrement]), v.comment) - for k,v in self.attrs.items()]) + for k, v in self.attributes.items()]) @property def asdtype(self): @@ -67,26 +66,26 @@ def asdtype(self): """ return np.dtype(dict( names=self.names, - formats=[v.dtype for k,v in self.attrs.items()])) + formats=[v.dtype for k,v in self.attributes.items()])) @property - def asSQL(self): - """represent heading as SQL field list""" - attrNames = ['`%s`' % name if self.attrs[name].computation is None else '%s as `%s`' % 
(self.attrs[name].computation, name) - for name in self.names] - return ','.join(attrNames) - - # Use heading as a dictionary like object + def as_sql(self): # TODO: replace with __str__? + """ + represent heading as SQL field list + """ + return ','.join(['`%s`' % name + if self.attributes[name].computation is None + else '%s as `%s`' % (self.attributes[name].computation, name) + for name in self.names]) def keys(self): - return self.attrs.keys() + return self.attributes.keys() def values(self): - return self.attrs.values() + return self.attributes.values() def items(self): - return self.attrs.items() - + return self.attributes.items() @classmethod def init_from_database(cls, conn, dbname, tabname): @@ -95,137 +94,134 @@ def init_from_database(cls, conn, dbname, tabname): """ cur = conn.query( 'SHOW FULL COLUMNS FROM `{tabname}` IN `{dbname}`'.format( - tabname=tabname, dbname=dbname),asDict=True) - attrs = cur.fetchall() + tabname=tabname, dbname=dbname), asDict=True) + attributes = cur.fetchall() rename_map = { - 'Field' : 'name', - 'Type' : 'type', - 'Null' : 'isNullable', + 'Field': 'name', + 'Type': 'type', + 'Null': 'nullable', 'Default': 'default', - 'Key' : 'isKey', + 'Key': 'in_key', 'Comment': 'comment'} - dropFields = ('Privileges', 'Collation') # unncessary + fields_to_drop = ('Privileges', 'Collation') # rename and drop attributes - attrs = [{rename_map[k] if k in rename_map else k: v - for k, v in x.items() if k not in dropFields} - for x in attrs] - numTypes ={ - ('float',False):np.float32, - ('float',True):np.float32, - ('double',False):np.float32, - ('double',True):np.float64, - ('tinyint',False):np.int8, - ('tinyint',True):np.uint8, - ('smallint',False):np.int16, - ('smallint',True):np.uint16, - ('mediumint',False):np.int32, - ('mediumint',True):np.uint32, - ('int',False):np.int32, - ('int',True):np.uint32, - ('bigint',False):np.int64, - ('bigint',True):np.uint64 + attributes = [{rename_map[k] if k in rename_map else k: v + for k, v in x.items() if k not in fields_to_drop} + for x in attributes] + + numeric_types = { + ('float', False): np.float32, + ('float', True): np.float32, + ('double', False): np.float32, + ('double', True): np.float64, + ('tinyint', False): np.int8, + ('tinyint', True): np.uint8, + ('smallint', False): np.int16, + ('smallint', True): np.uint16, + ('mediumint', False): np.int32, + ('mediumint', True): np.uint32, + ('int', False): np.int32, + ('int', True): np.uint32, + ('bigint', False): np.int64, + ('bigint', True): np.uint64 # TODO: include decimal and numeric datatypes } - # additional attribute properties - for attr in attrs: - attr['isNullable'] = (attr['isNullable'] == 'YES') - attr['isKey'] = (attr['isKey'] == 'PRI') - attr['isAutoincrement'] = bool(re.search(r'auto_increment', attr['Extra'], flags=re.IGNORECASE)) - attr['isNumeric'] = bool(re.match(r'(tiny|small|medium|big)?int|decimal|double|float', attr['type'])) - attr['isString'] = bool(re.match(r'(var)?char|enum|date|time|timestamp', attr['type'])) - attr['isBlob'] = bool(re.match(r'(tiny|medium|long)?blob', attr['type'])) + for attr in attributes: + attr['nullable'] = (attr['nullable'] == 'YES') + attr['in_key'] = (attr['in_key'] == 'PRI') + attr['autoincrement'] = bool(re.search(r'auto_increment', attr['Extra'], flags=re.IGNORECASE)) + attr['numeric'] = bool(re.match(r'(tiny|small|medium|big)?int|decimal|double|float', attr['type'])) + attr['string'] = bool(re.match(r'(var)?char|enum|date|time|timestamp', attr['type'])) + attr['is_blob'] = bool(re.match(r'(tiny|medium|long)?blob', 
attr['type'])) # strip field lengths off integer types attr['type'] = re.sub(r'((tiny|small|medium|big)?int)\(\d+\)', r'\1', attr['type']) attr['computation'] = None - if not (attr['isNumeric'] or attr['isString'] or attr['isBlob']): + if not (attr['numeric'] or attr['string'] or attr['is_blob']): raise DataJointError('Unsupported field type {field} in `{dbname}`.`{tabname}`'.format( field=attr['type'], dbname=dbname, tabname=tabname)) attr.pop('Extra') # fill out the dtype. All floats and non-nullable integers are turned into specific dtypes attr['dtype'] = object - if attr['isNumeric'] : - isInteger = bool(re.match(r'(tiny|small|medium|big)?int',attr['type'])) - isFloat = bool(re.match(r'(double|float)',attr['type'])) - if isInteger and not attr['isNullable'] or isFloat: + if attr['numeric']: + isInteger = bool(re.match(r'(tiny|small|medium|big)?int', attr['type'])) + isFloat = bool(re.match(r'(double|float)',attr['type'])) + if isInteger and not attr['nullable'] or isFloat: isUnsigned = bool(re.match('\sunsigned', attr['type'], flags=re.IGNORECASE)) t = attr['type'] - t = re.sub(r'\(.*\)','',t) # remove parentheses - t = re.sub(r' unsigned$','',t) # remove unsigned - assert (t,isUnsigned) in numTypes, 'dtype not found for type %s' % t - attr['dtype'] = numTypes[(t,isUnsigned)] - - return cls(attrs) + t = re.sub(r'\(.*\)', '', t) # remove parentheses + t = re.sub(r' unsigned$', '', t) # remove unsigned + assert (t, isUnsigned) in numeric_types, 'dtype not found for type %s' % t + attr['dtype'] = numeric_types[(t, isUnsigned)] + return cls(attributes) - def pro(self, *attrList, **renameDict): + def pro(self, *attribute_list, **rename_dict): """ derive a new heading by selecting, renaming, or computing attributes. In relational algebra these operators are known as project, rename, and expand. The primary key is always included. 
""" - # include all if '*' is in attrSet, always include primary key - attrSet = set(self.names) if '*' in attrList \ - else set(attrList).union(self.primary_key) + # include all if '*' is in attribute_set, always include primary key + attribute_set = set(self.names) if '*' in attribute_list \ + else set(attribute_list).union(self.primary_key) # report missing attributes - missing = attrSet.difference(self.names) + missing = attribute_set.difference(self.names) if missing: raise DataJointError('Attributes %s are not found' % str(missing)) - # make attrList a list of dicts for initializing a Heading - attrList = [v._asdict() for k,v in self.attrs.items() - if k in attrSet and k not in renameDict.values()] + # make attribute_list a list of dicts for initializing a Heading + attribute_list = [v._asdict() for k, v in self.attributes.items() + if k in attribute_set and k not in rename_dict.values()] # add renamed and computed attributes - for newName, computation in renameDict.items(): + for new_name, computation in rename_dict.items(): if computation in self.names: # renamed attribute - newAttr = self.attrs[computation]._asdict() - newAttr['name'] = newName - newAttr['computation'] = '`' + computation + '`' + new_attr = self.attributes[computation]._asdict() + new_attr['name'] = new_name + new_attr['computation'] = '`' + computation + '`' else: # computed attribute - newAttr = dict( - name = newName, - type = 'computed', - isKey = False, - isNullable = False, - default = None, - comment = 'computed attribute', - isAutoincrement = False, - isNumeric = None, - isString = None, - isBlob = False, - computation = computation, - dtype = object) - attrList.append(newAttr) - - return Heading(attrList) - + new_attr = dict( + name=new_name, + type='computed', + in_key=False, + nullable=False, + default=None, + comment='computed attribute', + autoincrement=False, + numeric=None, + string=None, + is_blob=False, + computation=computation, + dtype=object) + attribute_list.append(new_attr) + + return Heading(attribute_list) def join(self, other): """ join two headings """ - assert isinstance(other,Heading) - attrList = [v._asdict() for v in self.attrs.values()] + assert isinstance(other, Heading) + attribute_list = [v._asdict() for v in self.attributes.values()] for name in other.names: if name not in self.names: - attrList.append(other.attrs[name]._asdict()) - return Heading(attrList) - + attribute_list.append(other.attributes[name]._asdict()) + return Heading(attribute_list) - def resolveComputations(self): + def resolve_computations(self): """ Remove computations. To be done after computations have been resolved in a subquery """ - return Heading( [dict(v._asdict(),computation=None) for v in self.attrs.values()] ) + return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) diff --git a/datajoint/relational.py b/datajoint/relational.py index f02edbc30..244f4cf3c 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -10,6 +10,7 @@ from datajoint import DataJointError from .fetch import Fetch + class _Relational(metaclass=abc.ABCMeta): """ Relational implements relational operators. 
@@ -26,7 +27,6 @@ class _Relational(metaclass=abc.ABCMeta):
     _offset = 0
     _order_by = []
 
-    #### abstract properties that subclasses must define #####
     @abc.abstractproperty
     def sql(self):
         return NotImplemented
@@ -35,45 +35,54 @@ def sql(self):
     def heading(self):
         return NotImplemented
 
-    ###### Relational algebra ##############
     def __mul__(self, other):
-        "relational join"
-        return Join(self,other)
+        """
+        relational join
+        """
+        return Join(self, other)
 
     def pro(self, *arg, _sub=None, **kwarg):
-        "relational projection abd aggregation"
+        """
+        relational projection and aggregation
+        """
         return Projection(self, _sub=_sub, *arg, **kwarg)
 
     def __iand__(self, restriction):
-        "in-place relational restriction or semijoin"
+        """
+        in-place relational restriction or semijoin
+        """
         if self._restrictions is None:
             self._restrictions = []
         self._restrictions.append(restriction)
         return self
 
     def __and__(self, restriction):
-        "relational restriction or semijoin"
+        """
+        relational restriction or semijoin
+        """
         if self._restrictions is None:
             self._restrictions = []
-        ret = copy(self)   # todo: why not deepcopy it?
+        ret = copy(self)  # todo: why not deepcopy it?
         ret._restrictions = list(ret._restrictions)  # copy restriction
         ret &= restriction
         return ret
 
     def __isub__(self, restriction):
-        "in-place inverted restriction aka antijoin"
+        """
+        in-place inverted restriction aka antijoin
+        """
         self &= Not(restriction)
         return self
 
     def __sub__(self, restriction):
-        "inverted restriction aka antijoin"
+        """
+        inverted restriction aka antijoin
+        """
         return self & Not(restriction)
 
-
-    ###### Fetching the data ##############
     @property
     def count(self):
-        sql = 'SELECT count(*) FROM ' + self.sql + self._whereClause
+        sql = 'SELECT count(*) FROM ' + self.sql + self._where_clause
         cur = self.conn.query(sql)
         return cur.fetchone()[0] #todo: should we assert that this only returns one result?
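
For orientation, the operators documented in this hunk compose as in the following sketch, written against the relations declared in demos/demo1.py (a usage sketch only: it assumes those tables are declared, bound to a database, and already hold data):

```python
import demo1

subjects = demo1.Subject()
experiments = demo1.Experiment()

monkeys = subjects & dict(species="monkey")        # restriction (semijoin with a dict)
named = (subjects * experiments).pro('real_id')    # join, then projection; the primary key is always kept
unstudied = subjects - experiments                 # antijoin: subjects without experiments

print(monkeys.count)    # runs SELECT count(*) with the WHERE clause assembled below
data = named.fetch()    # structured numpy array, per the patch-2 fetch API
```
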
@@ -98,7 +107,6 @@ def __repr__(self):
             ret_val += '%d tuples\n' % self.count
         return ret_val
 
-    ######## iterator  ###############
     def __iter__(self):
         """
         iterator  yields primary key tuples
@@ -112,61 +120,63 @@ def __iter__(self):
 
 
     @property
-    def _whereClause(self):
-        "make there WHERE clause based on the current restriction"
-
+    def _where_clause(self):
+        """
+        make the WHERE clause based on the current restriction
+        """
         if not self._restrictions:
             return ''
 
-        def makeCondition(arg):
+        def make_condition(arg):
             if isinstance(arg,dict):
-                conds = ['`%s`=%s'%(k,repr(v)) for k,v in arg.items()]
+                conds = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items()]
             elif isinstance(arg,np.void):
-                conds = ['`%s`=%s'%(k, arg[k]) for k in arg.dtype.fields]
+                conds = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields]
             else:
                 raise DataJointError('invalid restriction type')
             return ' AND '.join(conds)
 
-        condStr = []
+        condition_string = []
         for r in self._restrictions:
-            negate = isinstance(r,Not)
+            negate = isinstance(r, Not)
             if negate:
                 r = r._restriction
-            if isinstance(r,dict) or isinstance(r,np.void):
-                r = makeCondition(r)
-            elif isinstance(r,np.ndarray) or isinstance(r,list):
-                r = '('+') OR ('.join([makeCondition(q) for q in r])+')'
-            elif isinstance(r,_Relational):
-                commonAttrs = ','.join([q for q in self.heading.names if r.heading.names])
-                r = '(%s) in (SELECT %s FROM %s)' % (commonAttrs, commonAttrs, r.sql)
+            if isinstance(r, dict) or isinstance(r, np.void):
+                r = make_condition(r)
+            elif isinstance(r, np.ndarray) or isinstance(r, list):
+                r = '('+') OR ('.join([make_condition(q) for q in r])+')'
+            elif isinstance(r, _Relational):
+                common_attributes = ','.join([q for q in self.heading.names if q in r.heading.names])
+                r = '(%s) in (SELECT %s FROM %s)' % (common_attributes, common_attributes, r.sql)
 
-            assert isinstance(r,str), 'condition must be converted into a string'
+            assert isinstance(r, str), 'condition must be converted into a string'
             r = '('+r+')'
             if negate:
-                r = 'NOT '+r;
-            condStr.append(r)
+                r = 'NOT '+r
+            condition_string.append(r)
 
-        return ' WHERE ' + ' AND '.join(condStr)
+        return ' WHERE ' + ' AND '.join(condition_string)
 
 
 class Not:
-    "inverse of a restriction"
-    def __init__(self,restriction):
+    """
+    inverse of a restriction
+    """
+    def __init__(self, restriction):
         self._restriction = restriction
 
 
 class Join(_Relational):
+    alias_counter = 0
 
-    aliasCounter = 0
-
-    def __init__(self,rel1,rel2):
-        if not isinstance(rel2,_Relational):
+    def __init__(self, rel1, rel2):
+        if not isinstance(rel2, _Relational):
             raise DataJointError('relvars can only be joined with other relvars')
         if rel1.conn is not rel2.conn:
             raise DataJointError('Cannot join relations with different database connections')
         self.conn = rel1.conn
-        self._rel1 = rel1;
-        self._rel2 = rel2;
+        self._rel1 = rel1
+        self._rel2 = rel2
 
     @property
     def heading(self):
@@ -174,13 +184,12 @@ def heading(self):
 
     @property
     def sql(self):
-        Join.aliasCounter += 1
-        return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.aliasCounter)
+        Join.alias_counter += 1
+        return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.alias_counter)
 
 
 class Projection(_Relational):
-
-    aliasCounter = 0
+    alias_counter = 0
 
     def __init__(self, rel, *arg, _sub, **kwarg):
         if _sub and isinstance(_sub, _Relational):
@@ -201,18 +210,17 @@ def heading(self):
 
 
 class Subquery(_Relational):
-
-    aliasCounter = 0;
+    alias_counter = 0
 
     def __init__(self, rel):
-        self.conn = rel.conn;
-        self._rel = rel;
+        self.conn = rel.conn
+        self._rel = rel
@property def sql(self): - self.aliasCounter = self.aliasCounter + 1; - return '(SELECT ' + self._rel.heading.asSQL + ' FROM ' + self._rel.sql + ') as `s%x`' % self.aliasCounter + self.alias_counter += 1 + return '(SELECT ' + self._rel.heading.as_sql + ' FROM ' + self._rel.sql + ') as `s%x`' % self.alias_counter @property def heading(self): - return self._rel.heading.resolveComputations() + return self._rel.heading.resolve_computations() \ No newline at end of file diff --git a/datajoint/settings.py b/datajoint/settings.py index 9ad4d97c6..e1c5c29e3 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -24,6 +24,7 @@ 'config.varname': 'DJ_LOCAL_CONF' } + class Config(collections.MutableMapping): """ Stores datajoint settings. Behaves like a dictionary, but applies validator functions @@ -31,7 +32,6 @@ class Config(collections.MutableMapping): The default parameters are stored in datajoint.settings.default . If a local config file exists, the settings specified in this file override the default settings. - """ def __init__(self, *args, **kwargs): @@ -90,4 +90,4 @@ def load(self, filename): ############################################################################# logger = logging.getLogger() -logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable +logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable \ No newline at end of file diff --git a/datajoint/task.py b/datajoint/task.py deleted file mode 100644 index 156c46281..000000000 --- a/datajoint/task.py +++ /dev/null @@ -1,53 +0,0 @@ -import queue -import threading - - -def _ping(): - print("The task thread is running") - - -class TaskQueue: - """ - Executes tasks in a single parallel thread in FIFO sequence. - Example: - queue = TaskQueue() - queue.submit(func1, arg1, arg2, arg3) - queue.submit(func2) - queue.quit() # wait until the last task is done and stop thread - - Datajoint applications may use a task queue for delayed inserts. - """ - def __init__(self): - self.queue = queue.Queue() - self.thread = threading.Thread(target=self._worker) - self.thread.daemon = True - self.thread.start() - - def empty(self): - return self.queue.empty() - - def submit(self, func=_ping, *args): - """Submit task for execution""" - self.queue.put((func, args)) - - def quit(self, timeout=3.0): - """Wait until all tasks finish""" - self.queue.put('quit') - self.thread.join(timeout) - if self.thread.isAlive(): - raise Exception('Task thread is still executing. 
Try quitting again.')
-
-    def _worker(self):
-        while True:
-            msg = self.queue.get()
-            if msg=='quit':
-                self.queue.task_done()
-                break
-            fun, args = msg
-            try:
-                fun(*args)
-            except Exception as e:
-                print("Exception in the task thread:")
-                print(e)
-            self.queue.task_done()
-
diff --git a/demos/demo1.py b/demos/demo1.py
index 18e54a4bf..8fedcf7fb 100644
--- a/demos/demo1.py
+++ b/demos/demo1.py
@@ -13,11 +13,9 @@
 conn.bind(module=__name__, dbname='dj_test')   # bind this module to the database
 
 
-
 class Subject(dj.Base):
-    _table_def = """
+    _definition = """
     demo1.Subject (manual)     # Basic subject info
-
     subject_id       : int     # internal subject id
     ---
     real_id                     : varchar(40)    #  real-world name
@@ -28,13 +26,10 @@ class Subject(dj.Base):
     animal_notes=""             : varchar(4096)  # strain, genetic manipulations, etc
     """
 
-class Exp2(dj.Base):
-    pass
 
 class Experiment(dj.Base):
-    _table_def = """
+    _definition = """
     demo1.Experiment (manual)     # Basic subject info
-
     -> demo1.Subject
     experiment       : smallint  # experiment number for this subject
     ---
@@ -44,25 +39,24 @@ class Experiment(dj.Base):
     """
 
 
-class TwoPhotonSession(dj.Base):
-    _table_def = """
-    demo1.TwoPhotonSession (manual)  # a two-photon imaging session
-
+class Session(dj.Base):
+    _definition = """
+    demo1.Session (manual)  # a two-photon imaging session
     -> demo1.Experiment
-    tp_session  : tinyint  # two-photon session within this experiment
-    ----
+    session_id  : tinyint  # two-photon session within this experiment
+    -----------
     setup  : tinyint  # experimental setup
     lens   : tinyint  # lens e.g.: 10x, 20x, 25x, 60x
     """
 
-class EphysSetup(dj.Base):
-    _table_def = """
-    demo1.EphysSetup (manual)   # Ephys setup
-    setup_id    : tinyint   # unique seutp id
+
+
+class Scan(dj.Base):
+    _definition = """
+    demo1.Scan (manual)   # a two-photon scan
+    -> demo1.Session
+    scan_id    : tinyint   # scan number within this session
+    ----
+    setup  : tinyint  # experimental setup
+    lens   : tinyint  # lens e.g.: 10x, 20x, 25x, 60x
     """
 
-class EphysExperiment(dj.Base):
-    _table_def = """
-    demo1.EphysExperiment (manual)   # Ephys experiment
-    -> demo1.Subject
-    -> demo1.EphysSetup
-    """
\ No newline at end of file
diff --git a/demos/rundemo1.py b/demos/rundemo1.py
index ece931611..88a8f4502 100644
--- a/demos/rundemo1.py
+++ b/demos/rundemo1.py
@@ -8,35 +8,35 @@
 import demo1
 
 s = demo1.Subject()
-# insert as dict
+e = demo1.Experiment()
+
 s.insert(dict(subject_id=1,
-              real_id='George',
+              real_id="George",
               species="monkey",
               date_of_birth="2011-01-01",
               sex="M",
               caretaker="Arthur",
               animal_notes="this is a test"))
+
 s.insert(dict(subject_id=2,
               real_id='1373',
               date_of_birth="2014-08-01",
               caretaker="Joe"))
 
-# insert as tuple. 
Attributes must be in the same order as in table declaration -s.insert((3,'Dennis','monkey','2012-09-01')) - -# TODO: insert as ndarray +s.insert((3, 'Dennis', 'monkey', '2012-09-01')) +s.insert((12430, 'C0430', 'mouse', '2012-09-01', 'M')) +s.insert((12431, 'C0431', 'mouse', '2012-09-01', 'F')) print('inserted keys into Subject:') for key in s: print(key) +e.insert(dict(subject_id=1, + experiment=1, + experiment_date="2014-08-28", + experiment_notes="my first experiment")) -# -e = demo1.Experiment() -e.insert(dict(subject_id=1,experiment=1,experiment_date="2014-08-28",experiment_notes="my first experiment")) -e.insert(dict(subject_id=1,experiment=2,experiment_date="2014-08-28",experiment_notes="my second experiment")) - - -# drop the tables -#s.drop -#e.drop +e.insert(dict(subject_id=1, + experiment=2, + experiment_date="2014-08-28", + experiment_notes="my second experiment")) From 4da119a5b7629e0f7590213f989edf65eb63af2d Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Sun, 3 May 2015 23:52:18 -0500 Subject: [PATCH 3/5] minor changes, not working yet --- datajoint/autopopulate.py | 31 ++++------ datajoint/blob.py | 92 +++++++++++++-------------- datajoint/connection.py | 6 -- datajoint/fetch.py | 61 ------------------ datajoint/relational.py | 127 +++++++++++++++++++++++++------------- demos/demo1.py | 5 +- 6 files changed, 144 insertions(+), 178 deletions(-) delete mode 100644 datajoint/fetch.py diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 624c01f78..ff8005b29 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -1,4 +1,5 @@ -from .relational import _Relational +from .relational import Relation +from . import DataJointError import pprint import abc @@ -7,8 +8,8 @@ class AutoPopulate(metaclass=abc.ABCMeta): """ - Class datajoint.AutoPopulate is a mixin that adds the method populate() to a dj.Relvar class. - Auto-populated relvars must inherit from both datajoint.Relvar and datajoint.AutoPopulate, + AutoPopulate is a mixin class that adds the method populate() to a Base class. + Auto-populated relations must inherit from both Base and AutoPopulate, must define the property pop_rel, and must define the callback method make_tuples. """ @@ -16,14 +17,14 @@ class AutoPopulate(metaclass=abc.ABCMeta): def pop_rel(self): """ Derived classes must implement the read-only property pop_rel (populate relation) which is the relational - expression (a dj.Relvar object) that defines how keys are generated for the populate call. + expression (a Relation object) that defines how keys are generated for the populate call. """ pass @abc.abstractmethod def make_tuples(self, key): """ - Derived classes must implement methods make_tuples that fetches data from parent tables, restricting by + Derived classes must implement method make_tuples that fetches data from parent tables, restricting by the given key, computes dependent attributes, and inserts the new tuples into self. """ pass @@ -33,21 +34,15 @@ def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): rel.populate() will call rel.make_tuples(key) for every primary key in self.pop_rel for which there is not already a tuple in rel. 
""" + if not isinstance(self.pop_rel, Relation): + raise DataJointError('') self.conn.cancel_transaction() - # enumerate unpopulated keys - unpopulated = self.pop_rel - if ~isinstance(unpopulated, _Relational): - unpopulated = unpopulated() # instantiate - + unpopulated = self.pop_rel - self if not unpopulated.count: - print('Nothing to populate') - else: - unpopulated = unpopulated(*args, **kwargs) # - self # TODO: implement antijoin - - # execute + print('Nothing to populate', flush=True) # TODO: use logging? if catch_errors: - errKeys, errors = [], [] + error_keys, errors = [], [] for key in unpopulated.fetch(): self.conn.start_transaction() n = self(key).count @@ -65,8 +60,8 @@ def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): raise print(e) errors += [e] - errKeys+= [key] + error_keys += [key] else: self.conn.commit_transaction() if catch_errors: - return errors, errKeys + return errors, error_keys diff --git a/datajoint/blob.py b/datajoint/blob.py index 82ac9e338..e66ca11ab 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -3,29 +3,27 @@ import numpy as np from . import DataJointError - mxClassID = collections.OrderedDict( # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html - mxUNKNOWN_CLASS = None, - mxCELL_CLASS = None, # not yet implemented - mxSTRUCT_CLASS = None, # not yet implemented - mxLOGICAL_CLASS = np.dtype('bool'), - mxCHAR_CLASS = np.dtype('c'), - mxVOID_CLASS = None, - mxDOUBLE_CLASS = np.dtype('float64'), - mxSINGLE_CLASS = np.dtype('float32'), - mxINT8_CLASS = np.dtype('int8'), - mxUINT8_CLASS = np.dtype('uint8'), - mxINT16_CLASS = np.dtype('int16'), - mxUINT16_CLASS = np.dtype('uint16'), - mxINT32_CLASS = np.dtype('int32'), - mxUINT32_CLASS = np.dtype('uint32'), - mxINT64_CLASS = np.dtype('int64'), - mxUINT64_CLASS = np.dtype('uint64'), - mxFUNCTION_CLASS= None - ) + mxUNKNOWN_CLASS=None, + mxCELL_CLASS=None, # not implemented + mxSTRUCT_CLASS=None, # not implemented + mxLOGICAL_CLASS=np.dtype('bool'), + mxCHAR_CLASS=np.dtype('c'), + mxVOID_CLASS=None, + mxDOUBLE_CLASS=np.dtype('float64'), + mxSINGLE_CLASS=np.dtype('float32'), + mxINT8_CLASS=np.dtype('int8'), + mxUINT8_CLASS=np.dtype('uint8'), + mxINT16_CLASS=np.dtype('int16'), + mxUINT16_CLASS=np.dtype('uint16'), + mxINT32_CLASS=np.dtype('int32'), + mxUINT32_CLASS=np.dtype('uint32'), + mxINT64_CLASS=np.dtype('int64'), + mxUINT64_CLASS=np.dtype('uint64'), + mxFUNCTION_CLASS=None) -reverseClassID = {v:i for i,v in enumerate(mxClassID.values())} +reverseClassID = {v: i for i, v in enumerate(mxClassID.values())} def pack(obj): @@ -35,55 +33,53 @@ def pack(obj): if not isinstance(obj, np.ndarray): raise DataJointError("Only numpy arrays can be saved in blobs") - blob = b"mYm\0A" # TODO: extend to process other datatypes besides arrays - blob += np.asarray((len(obj.shape),)+obj.shape,dtype=np.uint64).tostring() + blob = b"mYm\0A" # TODO: extend to process other data types besides arrays + blob += np.asarray((len(obj.shape),) + obj.shape, dtype=np.uint64).tostring() - isComplex = np.iscomplexobj(obj) - if isComplex: - obj, objImag = np.real(obj), np.imag(obj) + is_complex = np.iscomplexobj(obj) + if is_complex: + obj, imaginary = np.real(obj), np.imag(obj) - typeNum = reverseClassID[obj.dtype] - blob+= np.asarray(typeNum, dtype=np.uint32).tostring() - blob+= np.int8(isComplex).tostring() + b'\0\0\0' - blob+= obj.tostring() + type_number = reverseClassID[obj.dtype] + blob += np.asarray(type_number, dtype=np.uint32).tostring() + blob += np.int8(is_complex).tostring() + b'\0\0\0' 
+ blob += obj.tostring() - if isComplex: - blob+= objImag.tostring() + if is_complex: + blob += imaginary.tostring() - if len(blob)>1000: - compressed = b'ZL123\0'+np.asarray(len(blob),dtype=np.uint64).tostring() + zlib.compress(blob) + if len(blob) > 1000: + compressed = b'ZL123\0'+np.asarray(len(blob), dtype=np.uint64).tostring() + zlib.compress(blob) if len(compressed) < len(blob): blob = compressed return blob - def unpack(blob): """ unpack blob into a numpy array """ # decompress if necessary - if blob[0:5]==b'ZL123': - blobLen = np.fromstring(blob[6:14],dtype=np.uint64)[0] + if blob[0:5] == b'ZL123': + blob_length = np.fromstring(blob[6:14], dtype=np.uint64)[0] blob = zlib.decompress(blob[14:]) - assert(len(blob)==blobLen) + assert len(blob) == blob_length - blobType = blob[4] - if blobType!=65: # TODO: also process structure arrays, cell arrays, etc. + blob_type = blob[4] + if blob_type != 65: # TODO: also process structure arrays, cell arrays, etc. raise DataJointError('only arrays are currently allowed in blobs') p = 5 - ndims = np.fromstring(blob[p:p+8], dtype=np.uint64) + dimensions = np.fromstring(blob[p:p+8], dtype=np.uint64) p += 8 - arrDims = np.fromstring(blob[p:p+8*ndims], dtype=np.uint64) - p += 8 * ndims - mxType, dtype = [q for q in mxClassID.items()][np.fromstring(blob[p:p+4],dtype=np.uint32)[0]] + array_shape = np.fromstring(blob[p:p+8*dimensions], dtype=np.uint64) + p += 8 * dimensions + mx_type, dtype = [q for q in mxClassID.items()][np.fromstring(blob[p:p+4], dtype=np.uint32)[0]] if dtype is None: - raise DataJointError('Unsupported matlab datatype '+mxType+' in blob') + raise DataJointError('Unsupported MATLAB data type '+mx_type+' in blob') p += 4 - complexity = np.fromstring(blob[p:p+4],dtype=np.uint32)[0] + is_complex = np.fromstring(blob[p:p+4], dtype=np.uint32)[0] p += 4 obj = np.fromstring(blob[p:], dtype=dtype) - if complexity: + if is_complex: obj = obj[:len(obj)/2] + 1j*obj[len(obj)/2:] - - return obj.reshape(arrDims) + return obj.reshape(array_shape) \ No newline at end of file diff --git a/datajoint/connection.py b/datajoint/connection.py index 6be43a2a8..d8adf2256 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -2,13 +2,9 @@ import re from .utils import to_camel_case from . import DataJointError -import os from .heading import Heading from .base import prefix_to_role import logging -import networkx as nx -from networkx import pygraphviz_layout -import matplotlib.pyplot as plt from .erd import DBConnGraph from . import config @@ -44,12 +40,10 @@ def conn(host=None, user=None, passwd=None, initFun=None, reset=False): return _connObj return conn - # The function conn is used by others to obtain the package wide persistent connection object conn = conn_container() - class Connection: """ A dj.Connection object manages a connection to a database server. 
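
Before the fetch.py removal and the Relation rework below, the blob.py rewrite earlier in this patch can be sanity-checked with a simple round trip; this sketch needs only numpy and the pack/unpack pair shown above:

```python
import numpy as np
from datajoint import blob

original = np.random.randn(3, 4)     # a float64 array, stored as mxDOUBLE_CLASS
serialized = blob.pack(original)     # "mYm"-format blob; zlib-compressed only when that saves space
restored = blob.unpack(serialized)

assert np.array_equal(original, restored)
assert restored.dtype == np.float64  # the class-ID table maps back to the same dtype
```
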
diff --git a/datajoint/fetch.py b/datajoint/fetch.py deleted file mode 100644 index 5a89e55d7..000000000 --- a/datajoint/fetch.py +++ /dev/null @@ -1,61 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Wed Aug 20 22:05:29 2014 - -@author: dimitri -""" - -from .blob import unpack -import numpy as np -import logging - -logger = logging.getLogger(__name__) - - -class Fetch: - """ - Fetch defines callable objects that fetch data from a relation - """ - def __init__(self, relational): - self.rel = relational - self._orderBy = None - self._offset = 0 - self._limit = None - - def limit(self, n, offset=0): - self._limit = n - self._offset = offset - return self - - def order_by(self, *attrs): - self._orderBy = attrs - return self - - def __call__(self, *attrs, **renames): - """ - fetch relation from database into an np.array - """ - cur = self._cursor(*attrs, **renames) - heading = self.rel.pro(*attrs, **renames).heading - ret = np.array(list(cur), dtype=heading.asdtype) - # unpack blobs - for i in range(len(ret)): - for f in heading.blobs: - ret[i][f] = unpack(ret[i][f]) - return ret - - def _cursor(self, *attrs, **renames): - rel = self.rel.pro(*attrs, **renames) - sql = 'SELECT ' + rel.heading.asSQL + ' FROM ' + rel.sql - # add ORDER BY clause - if self._orderBy: - sql += ' ORDER BY ' + ', '.join(self._orderBy) - - # add LIMIT clause - if self._limit: - sql += ' LIMIT %d' % self._limit - if self._offset: - sql += ' OFFSET %d ' % self._offset - - logger.debug(sql) - return self.rel.conn.query(sql) diff --git a/datajoint/relational.py b/datajoint/relational.py index 244f4cf3c..c7400b9d1 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -8,25 +8,22 @@ import abc from copy import copy from datajoint import DataJointError -from .fetch import Fetch +import logging +logger = logging.getLogger(__name__) -class _Relational(metaclass=abc.ABCMeta): + +class Relation(metaclass=abc.ABCMeta): """ - Relational implements relational operators. - Relational objects reference other relational objects linked by operators. - The leaves of this tree of objects are base relvars. - When fetching data from the database, this tree of objects is compiled into - and SQL expression. - It is a mixin class that provides relational operators, iteration, and - fetch capability. - Relational operators are: restrict, pro, aggr, and join. + Relation implements relational operators. + Relation objects reference other relation objects linked by operators. + The leaves of this tree of objects are base relations. + When fetching data from the database, this tree of objects is compiled into an SQL expression. + It is a mixin class that provides relational operators, iteration, and fetch capability. + Relation operators are: restrict, pro, and join. """ _restrictions = [] - _limit = None - _offset = 0 - _order_by = [] - + @abc.abstractproperty def sql(self): return NotImplemented @@ -34,6 +31,10 @@ def sql(self): @abc.abstractproperty def heading(self): return NotImplemented + + @property + def restrictions(self): + return self._restrictions def __mul__(self, other): """ @@ -41,11 +42,17 @@ def __mul__(self, other): """ return Join(self, other) - def pro(self, *arg, _sub=None, **kwarg): + def pro(self, select=None, rename=None, expand=None, aggregate=None): """ - relational projection abd aggregation + relational operators project, rename, expand, and aggregate. Primary key attributes are always included unless + renamed. + :param select: list of attributes to project; '*' stands for all attributes. 
+        :param rename: dictionary of renamed attributes
+        :param expand: dictionary of computed attributes, including summary operators on the aggregated relation
+        :param aggregate: a relation for which summary computations can be performed in expand
+        :return: projected Relation object
         """
-        return Projection(self, _sub=_sub, *arg, **kwarg)
+        return Projection(self, select, rename, expand, aggregate)
 
     def __iand__(self, restriction):
         """
@@ -84,11 +91,40 @@ def __sub__(self, restriction):
     def count(self):
         sql = 'SELECT count(*) FROM ' + self.sql + self._where_clause
         cur = self.conn.query(sql)
-        return cur.fetchone()[0] #todo: should we assert that this only returns one result?
+        return cur.fetchone()[0]
 
-    @property
-    def fetch(self):
-        return Fetch(self)
+    def __call__(self, offset=0, limit=None, order_by=None, descending=False):
+        """
+        fetches the relation from the database table into an np.array and unpacks blob attributes.
+        :param offset: the number of tuples to skip in the returned result
+        :param limit: the maximum number of tuples to return
+        :param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None.
+        :param descending: if True, the ordering in order_by is applied in descending order
+        :return: the contents of the relation in the form of a structured numpy.array
+        """
+        cur = self.cursor(offset, limit, order_by, descending)
+        ret = np.array(list(cur), dtype=self.heading.asdtype)
+        for f in self.heading.blobs:
+            for i in range(len(ret)):
+                ret[i][f] = unpack(ret[i][f])
+        return ret
+
+    def cursor(self, offset=0, limit=None, order_by=None, descending=False):
+        """
+        :param offset: the number of tuples to skip
+        :param limit: the maximum number of tuples to return
+        :param order_by: the list of attributes to order the results by
+        :param descending: if True, the ordering in order_by is applied in descending order
+        :return: cursor to the query
+        """
+        sql = 'SELECT ' + self.heading.as_sql + ' FROM ' + self.sql + self._where_clause
+        if order_by is not None:
+            sql += ' ORDER BY ' + ', '.join(order_by)
+            if descending:
+                sql += ' DESC'
+        if limit is not None:
+            sql += ' LIMIT %d' % limit
+            if offset:
+                sql += ' OFFSET %d' % offset
+        logger.debug(sql)
+        return self.conn.query(sql)
 
     def __repr__(self):
         header = self.heading.non_blobs
@@ -96,7 +132,7 @@ def __repr__(self):
         width = 12
         template = '%%-%d.%ds' % (width, width)
         ret_val = ' '.join([template % column for column in header]) + '\n'
-        ret_val += ' '.join(['+' + '-' * (width - 2) + '+' for column in header]) + '\n'
+        ret_val += ' '.join(['+' + '-' * (width - 2) + '+' for _ in header]) + '\n'
 
         tuples = self.fetch.limit(limit)(*header)
         for tup in tuples:
@@ -115,10 +151,9 @@ def __iter__(self):
         dtype = h.asdtype
         q = cur.fetchone()
         while q:
-            yield np.array([q,],dtype=dtype) #todo: why convert that into an array?
+            yield np.array([q, ], dtype=dtype)  #todo: why convert that into an array? 
            q = cur.fetchone()
 
-
     @property
     def _where_clause(self):
         """
@@ -126,26 +161,26 @@ def _where_clause(self):
         """
         if not self._restrictions:
             return ''
-
+
         def make_condition(arg):
             if isinstance(arg, dict):
-                conds = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items()]
+                conditions = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items()]
             elif isinstance(arg, np.void):
-                conds = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields]
+                conditions = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields]
             else:
                 raise DataJointError('invalid restriction type')
-            return ' AND '.join(conds)
+            return ' AND '.join(conditions)
 
         condition_string = []
         for r in self._restrictions:
             negate = isinstance(r, Not)
             if negate:
-                r = r._restriction
+                r = r._restriction  # unwrap the restriction held by Not
             if isinstance(r, dict) or isinstance(r, np.void):
                 r = make_condition(r)
             elif isinstance(r, np.ndarray) or isinstance(r, list):
                 r = '('+') OR ('.join([make_condition(q) for q in r])+')'
-            elif isinstance(r, _Relational):
+            elif isinstance(r, Relation):
                 common_attributes = ','.join([q for q in self.heading.names if q in r.heading.names])
                 r = '(%s) in (SELECT %s FROM %s)' % (common_attributes, common_attributes, r.sql)
 
@@ -166,11 +201,11 @@ def __init__(self, restriction):
         self._restriction = restriction
 
 
-class Join(_Relational):
+class Join(Relation):
     alias_counter = 0
 
     def __init__(self, rel1, rel2):
-        if not isinstance(rel2, _Relational):
+        if not isinstance(rel2, Relation):
             raise DataJointError('relvars can only be joined with other relvars')
         if rel1.conn is not rel2.conn:
             raise DataJointError('Cannot join relations with different database connections')
@@ -188,17 +223,21 @@ def sql(self):
         return '%s NATURAL JOIN %s as `j%x`' % (self._rel1.sql, self._rel2.sql, Join.alias_counter)
 
 
-class Projection(_Relational):
+class Projection(Relation):
     alias_counter = 0
 
-    def __init__(self, rel, *arg, _sub, **kwarg):
-        if _sub and isinstance(_sub, _Relational):
-            raise DataJointError('Relational join must receive two relations')
-        self.conn = rel.conn
-        self._rel = rel
-        self._sub = _sub
-        self._selection = arg
-        self._renames = kwarg
+    def __init__(self, relation, select, rename, expand, aggregate):
+        """
+        See Relation.pro()
+        """
+        if aggregate is not None and not isinstance(aggregate, Relation):
+            raise DataJointError('The aggregate argument must be a relation')
+        self.conn = relation.conn
+        self._relation = relation
+        self._select = select
+        self._rename = rename
+        self._expand = expand
+        self._aggregate = aggregate
 
     @property
     def sql(self):
@@ -209,7 +248,7 @@ def heading(self):
         return self._rel.heading.pro(*self._selection, **self._renames)
 
 
-class Subquery(_Relational):
+class Subquery(Relation):
     alias_counter = 0
 
     def __init__(self, rel):
diff --git a/demos/demo1.py b/demos/demo1.py
index 5e1e3b7b7..ea90f7ed9 100644
--- a/demos/demo1.py
+++ b/demos/demo1.py
@@ -59,4 +59,7 @@ class Scan(dj.Base):
     depth  : float     # depth from surface
     wavelength : smallint  # (nm)  laser wavelength
     mwatts: numeric(4,1)  # (mW) laser power to brain
-    """
\ No newline at end of file
+    """
+
+
+class Alignment(dj.Base):
+    pass  # stub; definition to be added
\ No newline at end of file

From d7b71a7c034fa460429b325faf6f849f28c012ca Mon Sep 17 00:00:00 2001
From: Dimitri Yatsenko
Date: Mon, 4 May 2015 02:38:20 -0500
Subject: [PATCH 4/5] split class Base into class Table and class Base

---
 datajoint/__init__.py     |   8 +-
 datajoint/autopopulate.py |   3 +-
 datajoint/base.py         | 310 +++++++++++++-------------------------
 datajoint/heading.py      |  33 ++--
 datajoint/relational.py   |  52 ++++---
datajoint/table.py | 91 +++++++++++ demos/demo1.py | 22 ++- 7 files changed, 263 insertions(+), 256 deletions(-) create mode 100644 datajoint/table.py diff --git a/datajoint/__init__.py b/datajoint/__init__.py index f44d86939..3c38699a7 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -7,16 +7,15 @@ 'Connection', 'Heading', 'Base', 'Not', 'AutoPopulate', 'TaskQueue', 'conn', 'DataJointError', 'blob'] -# ------------ define datajoint error before the import hierarchy is flattened ------------ + +# ----- define datajoint error before the import hierarchy is flattened -------- class DataJointError(Exception): """ - Base class for errors specific to DataJoint internal - operation. + Base class for errors specific to DataJoint internal operation. """ pass - # ----------- loads local configuration from file ---------------- from .settings import Config, logger config = Config() @@ -37,7 +36,6 @@ class DataJointError(Exception): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection from .base import Base -from .task import TaskQueue from .autopopulate import AutoPopulate from . import blob from .relational import Not diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index ff8005b29..8e32ef43d 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -51,7 +51,6 @@ def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): else: print('Populating:') pprint.pprint(key) - try: self.make_tuples(key) except Exception as e: @@ -64,4 +63,4 @@ def populate(self, catch_errors=False, reserve_jobs=False, restrict=None): else: self.conn.commit_transaction() if catch_errors: - return errors, error_keys + return errors, error_keys \ No newline at end of file diff --git a/datajoint/base.py b/datajoint/base.py index 5d01ec1a4..4f9ed9355 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -1,11 +1,12 @@ import importlib import re +import abc from types import ModuleType -import numpy as np from enum import Enum from .utils import from_camel_case from . import DataJointError -from .relational import _Relational +from .relational import Relation +from .table import Table from .heading import Heading import logging @@ -27,22 +28,11 @@ mysql_constants = ['CURRENT_TIMESTAMP'] -class Base(_Relational): - +class Base(Table, metaclass=abc.ABCMeta): """ - Base integrates all data manipulation and data declaration functions. - An instance of the class provides an interface to a single table in the database. - - An instance of the class can be produced in two ways: - - 1. direct instantiation (used mostly for debugging and odd jobs) - - 2. instantiation from a derived class (regular recommended use) + Base is a Table that implements data definition functions. + It is an abstract class with the abstract property 'definition'. - With direct instantiation, instance parameters must be explicitly specified. - With a derived class, all the instance parameters are taken from the module - of the deriving class. The module must declare the connection object conn. - The name of the deriving class is used as the table's className. The table associated with an instance of Base is identified by the className property, which is a string in CamelCase. 
The actual table name is obtained @@ -72,9 +62,8 @@ class Base(_Relational): class Subjects(dj.Base): - _definition = ''' + definition = ''' test1.Subjects (manual) # Basic subject info - subject_id : int # unique subject id --- real_id : varchar(40) # real-world name @@ -83,22 +72,8 @@ class Subjects(dj.Base): """ - def __init__(self, conn=None, dbname=None, class_name=None, table_def=None): - self._use_package = False - if self.__class__ is Base: - # instantiate without subclassing - if not (conn and dbname and class_name): - raise DataJointError( - 'Missing argument: please specify conn, dbname, and class name.') - self.class_name = class_name - self.conn = conn - self.dbname = dbname - self._definition = table_def - # register with a fake module, enclosed in back quotes - - if dbname not in self.conn.db_to_mod: - self.conn.bind('`{0}`'.format(dbname), dbname) - else: + + def __init__(self): # instantiate a derived class if conn or dbname or class_name or table_def: raise DataJointError( @@ -117,7 +92,7 @@ def __init__(self, conn=None, dbname=None, class_name=None, table_def=None): raise DataJointError( "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__)) try: - if (self._use_package): + if self._use_package: pkg_name = '.'.join(module.split('.')[:-1]) self.dbname = self.conn.mod_to_db[pkg_name] else: @@ -126,51 +101,6 @@ def __init__(self, conn=None, dbname=None, class_name=None, table_def=None): raise DataJointError( 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__)) - if hasattr(self, '_definition'): - self._definition = self._definition - else: - self._definition = None - - # todo: do we support records and named tuples for tup? - def insert(self, tup, ignore_errors=False, replace=False): - """ - Insert one data tuple. - - :param tup: Data tuple. Can be an iterable in matching order, a dict with named fields, or an np.void. - :param ignore_errors=False: Ignores errors if True. - :param replace=False: Replaces data tuple if True. 
- - Example:: - - b = djtest.Subject() - b.insert( dict(subject_id = 7, species="mouse",\\ - real_id = 1007, date_of_birth = "2014-09-01") ) - """ - - if issubclass(type(tup), tuple) or issubclass(type(tup), list): - valueList = ','.join([repr(q) for q in tup]) - fieldList = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`' - elif issubclass(type(tup), dict): - valueList = ','.join([repr(tup[q]) - for q in self.heading.names if q in tup]) - fieldList = '`' + \ - '`,`'.join([q for q in self.heading.names if q in tup]) + '`' - elif issubclass(type(tup), np.void): - valueList = ','.join([repr(tup[q]) - for q in self.heading.names if q in tup]) - fieldList = '`' + '`,`'.join(tup.dtype.fields) + '`' - else: - raise DataJointError('Datatype %s cannot be inserted' % type(tup)) - if replace: - sql = 'REPLACE' - elif ignore_errors: - sql = 'INSERT IGNORE' - else: - sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, - fieldList, valueList) - logger.info(sql) - self.conn.query(sql) def drop(self): """ @@ -182,14 +112,6 @@ def drop(self): self.conn.load_headings(dbname=self.dbname, force=True) logger.debug("Dropped table %s" % self.full_table_name) - @property - def sql(self): - return self.full_table_name + self._whereClause - - @property - def heading(self): - self.declare() - return self.conn.headings[self.dbname][self.table_name] @property def is_declared(self): @@ -199,21 +121,6 @@ def is_declared(self): self.conn.load_headings(self.dbname) return self.class_name in self.conn.table_names[self.dbname] - @property - def table_name(self): - """ - :return: name of the associated table - """ - self.declare() - return self.conn.table_names[self.dbname][self.class_name] - - @property - def full_table_name(self): - """ - :return: full name of the associated table - """ - return '`%s`.`%s`' % (self.dbname, self.table_name) - @property def full_class_name(self): """ @@ -241,18 +148,18 @@ def declare(self): 'Table could not be declared for %s' % self.class_name) """ - Data definition functionalities + Data definition methods """ - def set_table_comment(self, newComment): + def set_table_comment(self, comment): """ Update the table comment in the table declaration. 
- :param newComment: new comment as string + :param comment: new comment as string """ # TODO: add verification procedure (github issue #24) - self.alter('COMMENT="%s"' % newComment) + self.alter('COMMENT="%s"' % comment) def add_attribute(self, definition, after=None): """ @@ -269,7 +176,7 @@ def add_attribute(self, definition, after=None): """ position = ' FIRST' if after is None else ( ' AFTER %s' % after if after else '') - sql = self._field_to_SQL(self._parse_attr_def(definition)) + sql = self._field_to_SQL(_parse_attr_def(definition)) self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) def drop_attribute(self, attr_name): @@ -287,7 +194,7 @@ def alter_attribute(self, attr_name, new_definition): :param attr_name: field that is redefined :param new_definition: new definition of the field """ - sql = self._field_to_SQL(self._parse_attr_def(new_definition)) + sql = self._field_to_SQL(_parse_attr_def(new_definition)) self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) def erd(self, subset=None, prog='dot'): @@ -374,36 +281,7 @@ def get_base(self, module_name, class_name): class_name=class_name) return ret - # //////////////////////////////////////////////////////////// - # Private Methods - # //////////////////////////////////////////////////////////// - def _field_to_SQL(self, field): - """ - Converts an attribute definition tuple into SQL code. - - :param field: attribute definition - :rtype : SQL code - """ - if field.isNullable: - default = 'DEFAULT NULL' - else: - default = 'NOT NULL' - # if some default specified - if field.default: - # enclose value in quotes (even numeric), except special SQL values - # or values already enclosed by the user - if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']: - default = '%s DEFAULT %s' % (default, field.default) - else: - default = '%s DEFAULT "%s"' % (default, field.default) - - # TODO: escape instead! - same goes for Matlab side implementation - assert not any((c in r'\"' for c in field.comment)), \ - 'Illegal characters in attribute comment "%s"' % field.comment - - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( - name=field.name, type=field.type, default=default, comment=field.comment) def _alter(self, alter_statement): """ @@ -421,7 +299,7 @@ def _declare(self): """ Declares the table in the data base if no table in the database matches this object. """ - if not self._definition: + if not self.definition: raise DataJointError('Table declaration is missing!') table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration() defined_name = table_info['module'] + '.' 
+ table_info['className'] @@ -515,7 +393,7 @@ def _parse_declaration(self): referenced = [] index_defs = [] field_defs = [] - declaration = re.split(r'\s*\n\s*', self._definition.strip()) + declaration = re.split(r'\s*\n\s*', self.definition.strip()) # remove comment lines declaration = [x for x in declaration if not x.startswith('#')] @@ -551,72 +429,100 @@ def _parse_declaration(self): rel = self.get_base(module_name, class_name) (parents if in_key else referenced).append(rel) elif re.match(r'^(unique\s+)?index[^:]*$', line): - index_defs.append(self._parse_index_def(line)) + index_defs.append(_parse_index_def(line)) elif fieldP.match(line): - field_defs.append(self._parse_attr_def(line, in_key)) + field_defs.append(_parse_attr_def(line, in_key)) else: raise DataJointError( 'Invalid table declaration line "%s"' % line) return table_info, parents, referenced, field_defs, index_defs - def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key - """ - Parse attribute definition line in the declaration and returns - an attribute tuple. - - :param line: attribution line - :param in_key: - :returns: attribute tuple - """ - line = line.strip() - attr_ptrn = """ - ^(?P[a-z][a-z\d_]*)\s* # field name - (=\s*(?P\S+(\s+\S+)*?)\s*)? # default value - :\s*(?P\w[^\#]*[^\#\s])\s* # datatype - (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment - """ - - attrP = re.compile(attr_ptrn, re.I + re.X) - m = attrP.match(line) - assert m, 'Invalid field declaration "%s"' % line - attr_info = m.groupdict() - if not attr_info['comment']: - attr_info['comment'] = '' - if not attr_info['default']: - attr_info['default'] = '' - attr_info['isNullable'] = attr_info['default'].lower() == 'null' - assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['isNullable']), \ - 'BIGINT attributes cannot be nullable in "%s"' % line - - attr_info['isKey'] = in_key - attr_info['isAutoincrement'] = None - attr_info['isNumeric'] = None - attr_info['isString'] = None - attr_info['isBlob'] = None - attr_info['computation'] = None - attr_info['dtype'] = None - - return Heading.AttrTuple(**attr_info) - - def _parse_index_def(self, line): - """ - Parses index definition. - - :param line: definition line - :return: groupdict with index info - """ - line = line.strip() - index_ptrn = """ - ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX - \((?P[^\)]+)\)$ # (attr1, attr2) - """ - indexP = re.compile(index_ptrn, re.I + re.X) - m = indexP.match(line) - assert m, 'Invalid index declaration "%s"' % line - index_info = m.groupdict() - attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) - index_info['attributes'] = attributes - assert len(attributes) == len(set(attributes)), \ - 'Duplicate attributes in index declaration "%s"' % line - return index_info + +def _field_to_SQL(field): + """ + Converts an attribute definition tuple into SQL code. + :param field: attribute definition + :rtype : SQL code + """ + if field.isNullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + # if some default specified + if field.default: + # enclose value in quotes (even numeric), except special SQL values + # or values already enclosed by the user + if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']: + default = '%s DEFAULT %s' % (default, field.default) + else: + default = '%s DEFAULT "%s"' % (default, field.default) + + # TODO: escape instead! 
- same goes for Matlab side implementation
+    assert not any((c in r'\"' for c in field.comment)), \
+        'Illegal characters in attribute comment "%s"' % field.comment
+
+    return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
+        name=field.name, type=field.type, default=default, comment=field.comment)
+
+
+def _parse_attr_def(line, in_key=False):
+    """
+    Parses an attribute definition line of the declaration and returns
+    an attribute tuple.
+    :param line: attribute definition line
+    :param in_key: True if the attribute is declared above the '---' separator, i.e. belongs to the primary key
+    :returns: attribute tuple
+    """
+    line = line.strip()
+    attr_ptrn = """
+    ^(?P<name>[a-z][a-z\d_]*)\s*               # field name
+    (=\s*(?P<default>\S+(\s+\S+)*?)\s*)?       # default value
+    :\s*(?P<type>\w[^\#]*[^\#\s])\s*           # datatype
+    (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$      # comment
+    """
+
+    attrP = re.compile(attr_ptrn, re.I + re.X)
+    m = attrP.match(line)
+    assert m, 'Invalid field declaration "%s"' % line
+    attr_info = m.groupdict()
+    if not attr_info['comment']:
+        attr_info['comment'] = ''
+    if not attr_info['default']:
+        attr_info['default'] = ''
+    attr_info['isNullable'] = attr_info['default'].lower() == 'null'
+    assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['isNullable']), \
+        'BIGINT attributes cannot be nullable in "%s"' % line
+
+    attr_info['isKey'] = in_key
+    attr_info['isAutoincrement'] = None
+    attr_info['isNumeric'] = None
+    attr_info['isString'] = None
+    attr_info['isBlob'] = None
+    attr_info['computation'] = None
+    attr_info['dtype'] = None
+
+    return Heading.AttrTuple(**attr_info)
+
+
+def _parse_index_def(line):
+    """
+    Parses an index definition.
+
+    :param line: definition line
+    :return: groupdict with index info
+    """
+    line = line.strip()
+    index_ptrn = """
+    ^(?P<unique>UNIQUE)?\s*INDEX\s*    # [UNIQUE] INDEX
+    \((?P<attributes>[^\)]+)\)$        # (attr1, attr2)
+    """
+    indexP = re.compile(index_ptrn, re.I + re.X)
+    m = indexP.match(line)
+    assert m, 'Invalid index declaration "%s"' % line
+    index_info = m.groupdict()
+    attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
+    index_info['attributes'] = attributes
+    assert len(attributes) == len(set(attributes)), \
+        'Duplicate attributes in index declaration "%s"' % line
+    return index_info
diff --git a/datajoint/heading.py b/datajoint/heading.py
index 93be6ef1c..0176a2cbe 100644
--- a/datajoint/heading.py
+++ b/datajoint/heading.py
@@ -60,16 +60,16 @@ def __repr__(self):
                           for k, v in self.attributes.items()])
 
     @property
-    def asdtype(self):
+    def as_dtype(self):
         """
         represent the heading as a numpy dtype
         """
         return np.dtype(dict(
             names=self.names,
-            formats=[v.dtype for k,v in self.attributes.items()]))
+            formats=[v.dtype for k, v in self.attributes.items()]))
 
     @property
-    def as_sql(self):  # TODO: replace with __str__?
+ def as_sql(self): """ represent heading as SQL field list """ @@ -88,13 +88,13 @@ def items(self): return self.attributes.items() @classmethod - def init_from_database(cls, conn, dbname, tabname): + def init_from_database(cls, conn, dbname, table_name): """ initialize heading from a database table """ cur = conn.query( - 'SHOW FULL COLUMNS FROM `{tabname}` IN `{dbname}`'.format( - tabname=tabname, dbname=dbname), asDict=True) + 'SHOW FULL COLUMNS FROM `{table_name}` IN `{dbname}`'.format( + table_name=table_name, dbname=dbname), asDict=True) attributes = cur.fetchall() rename_map = { @@ -127,7 +127,7 @@ def init_from_database(cls, conn, dbname, tabname): ('int', True): np.uint32, ('bigint', False): np.int64, ('bigint', True): np.uint64 - # TODO: include decimal and numeric datatypes + # TODO: include types DECIMAL and NUMERIC } # additional attribute properties @@ -144,22 +144,22 @@ def init_from_database(cls, conn, dbname, tabname): attr['computation'] = None if not (attr['numeric'] or attr['string'] or attr['is_blob']): - raise DataJointError('Unsupported field type {field} in `{dbname}`.`{tabname}`'.format( - field=attr['type'], dbname=dbname, tabname=tabname)) + raise DataJointError('Unsupported field type {field} in `{dbname}`.`{table_name}`'.format( + field=attr['type'], dbname=dbname, table_name=table_name)) attr.pop('Extra') # fill out the dtype. All floats and non-nullable integers are turned into specific dtypes attr['dtype'] = object if attr['numeric']: - isInteger = bool(re.match(r'(tiny|small|medium|big)?int', attr['type'])) - isFloat = bool(re.match(r'(double|float)',attr['type'])) - if isInteger and not attr['nullable'] or isFloat: - isUnsigned = bool(re.match('\sunsigned', attr['type'], flags=re.IGNORECASE)) + is_integer = bool(re.match(r'(tiny|small|medium|big)?int', attr['type'])) + is_float = bool(re.match(r'(double|float)', attr['type'])) + if is_integer and not attr['nullable'] or is_float: + is_unsigned = bool(re.match('\sunsigned', attr['type'], flags=re.IGNORECASE)) t = attr['type'] t = re.sub(r'\(.*\)', '', t) # remove parentheses t = re.sub(r' unsigned$', '', t) # remove unsigned - assert (t, isUnsigned) in numeric_types, 'dtype not found for type %s' % t - attr['dtype'] = numeric_types[(t, isUnsigned)] + assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t + attr['dtype'] = numeric_types[(t, is_unsigned)] return cls(attributes) @@ -223,5 +223,4 @@ def resolve_computations(self): """ Remove computations. To be done after computations have been resolved in a subquery """ - return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) - + return Heading([dict(v._asdict(), computation=None) for v in self.attributes.values()]) \ No newline at end of file diff --git a/datajoint/relational.py b/datajoint/relational.py index c7400b9d1..57dc843b5 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -8,6 +8,7 @@ import abc from copy import copy from datajoint import DataJointError +from .blob import unpack import logging logger = logging.getLogger(__name__) @@ -15,7 +16,7 @@ class Relation(metaclass=abc.ABCMeta): """ - Relation implements relational operators. + Relation implements relational algebra and fetch methods. Relation objects reference other relation objects linked by operators. The leaves of this tree of objects are base relations. When fetching data from the database, this tree of objects is compiled into an SQL expression. 
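
A minimal sketch of such an operator tree (relation classes from the demos;
`key` is an assumed restriction)::

    key = dict(subject_id=1)        # restrict by primary-key values
    rel = Session() * Scan() & key  # Join and restriction grow the tree
    print(rel.sql)                  # the tree compiles into a single SELECT
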
@@ -111,47 +112,48 @@ def __call__(self, offset=0, limit=None, order_by=None, descending=False):
 
     def cursor(self, offset=0, limit=None, order_by=None, descending=False):
         """
-        :param attributes: projection attributes
-        :param renames: rename attributes
+        :param offset: the number of tuples to skip in the returned result
+        :param limit: the maximum number of tuples to return
+        :param order_by: the list of attributes to order the results. No ordering should be assumed if order_by=None.
+        :param descending: if True, the sort order is descending
         :return: cursor to the query
         """
-        rel = self.rel.pro(*attributes, **renames)
-        sql = 'SELECT ' + rel.heading.asSQL + ' FROM ' + rel.sql
-        if self._orderBy:
-            sql += ' ORDER BY ' + ', '.join(self._orderBy)
-        if self._limit:
-            sql += ' LIMIT %d' % self._limit
-        if self._offset:
-            sql += ' OFFSET %d ' % self._offset
+        if offset and limit is None:
+            raise DataJointError('offset cannot be used without limit')
+        sql = 'SELECT ' + self.heading.as_sql + ' FROM ' + self.sql
+        if order_by is not None:
+            sql += ' ORDER BY ' + ', '.join(order_by)
+            if descending:
+                sql += ' DESC'
+        if limit is not None:
+            sql += ' LIMIT %d' % limit
+            if offset:
+                sql += ' OFFSET %d' % offset
         logger.debug(sql)
-        return self.rel.conn.query(sql)
+        return self.conn.query(sql)
 
     def __repr__(self):
         header = self.heading.non_blobs
         limit = 13
         width = 12
         template = '%%-%d.%ds' % (width, width)
-        ret_val = ' '.join([template % column for column in header]) + '\n'
-        ret_val += ' '.join(['+' + '-' * (width - 2) + '+' for _ in header]) + '\n'
-
-        tuples = self.fetch.limit(limit)(*header)
+        repr_string = ' '.join([template % column for column in header]) + '\n'
+        repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in header]) + '\n'
+        tuples = self.pro(*header).fetch(limit=limit)
         for tup in tuples:
-            ret_val += ' '.join([template % column for column in tup]) + '\n'
-        cnt = self.count
-        if cnt > limit:
-            ret_val += '...\n'
-        ret_val += '%d tuples\n' % self.count
-        return ret_val
+            repr_string += ' '.join([template % column for column in tup]) + '\n'
+        if self.count > limit:
+            repr_string += '...\n'
+        repr_string += '%d tuples\n' % self.count
+        return repr_string
 
     def __iter__(self):
         """
         iterator  yields primary key tuples
         """
-        cur, h = self.fetch._cursor()
-        dtype = h.asdtype
+        projection = self.pro()
+        cur = projection.cursor()
         q = cur.fetchone()
         while q:
-            yield np.array([q, ], dtype=dtype) #todo: why convert that into an array?
+            yield np.array([q, ], dtype=projection.heading.as_dtype)
             q = cur.fetchone()
 
     @property
diff --git a/datajoint/table.py b/datajoint/table.py
new file mode 100644
index 000000000..5bf7a39d8
--- /dev/null
+++ b/datajoint/table.py
@@ -0,0 +1,91 @@
+import numpy as np
+from . import DataJointError
+from .relational import Relation
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class Table(Relation):
+    """
+    A Table object is a relation associated with a table.
+    A Table object provides insert and delete methods.
+    Table objects are only used internally and for debugging.
+    The table must already exist in the schema for the table object to work.
+    The table is identified by its "class name", or its CamelCase version.
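
For example, under role_to_prefix the CamelCase names map as follows (a sketch,
not part of the original file)::

    'Subject'  (manual)    ->  table `subject`
    'ScanInfo' (imported)  ->  table `_scan_info`
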
+ """ + + def __init__(self, conn=None, dbname=None, class_name=None): + self._use_package = False + self.class_name = class_name + self.conn = conn + self.dbname = dbname + if dbname not in self.conn.db_to_mod: + # register with a fake module, enclosed in back quotes + self.conn.bind('`{0}`'.format(dbname), dbname) + + @property + def sql(self): + return self.full_table_name + self._whereClause + + @property + def heading(self): + self.declare() + return self.conn.headings[self.dbname][self.table_name] + + @property + def full_table_name(self): + """ + :return: full name of the associated table + """ + return '`%s`.`%s`' % (self.dbname, self.table_name) + + @property + def table_name(self): + """ + :return: name of the associated table + """ + return self.conn.table_names[self.dbname][self.class_name] + + def insert(self, tup, ignore_errors=False, replace=False): + """ + Insert one data tuple. + + :param tup: Data tuple. Can be an iterable in matching order, a dict with named fields, or an np.void. + :param ignore_errors=False: Ignores errors if True. + :param replace=False: Replaces data tuple if True. + + Example:: + + b = djtest.Subject() + b.insert(dict(subject_id = 7, species="mouse",\\ + real_id = 1007, date_of_birth = "2014-09-01")) + """ + # todo: do we support records and named tuples for tup? + + if issubclass(type(tup), tuple) or issubclass(type(tup), list): + value_list = ','.join([repr(q) for q in tup]) + attribute_list = '`'+'`,`'.join(self.heading.names[0:len(tup)]) + '`' + elif issubclass(type(tup), dict): + value_list = ','.join([repr(tup[q]) + for q in self.heading.names if q in tup]) + attribute_list = '`' + '`,`'.join([q for q in self.heading.names if q in tup]) + '`' + elif issubclass(type(tup), np.void): + value_list = ','.join([repr(tup[q]) + for q in self.heading.names if q in tup]) + attribute_list = '`' + '`,`'.join(tup.dtype.fields) + '`' + else: + raise DataJointError('Datatype %s cannot be inserted' % type(tup)) + if replace: + sql = 'REPLACE' + elif ignore_errors: + sql = 'INSERT IGNORE' + else: + sql = 'INSERT' + sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, + attribute_list, value_list) + logger.info(sql) + self.conn.query(sql) + + def delete(self): # TODO + pass diff --git a/demos/demo1.py b/demos/demo1.py index ea90f7ed9..ba926f8c6 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -6,6 +6,7 @@ """ import datajoint as dj +import os print("Welcome to the database 'demo1'") @@ -14,7 +15,7 @@ class Subject(dj.Base): - _definition = """ + definition = """ demo1.Subject (manual) # Basic subject info subject_id : int # internal subject id --- @@ -28,11 +29,12 @@ class Subject(dj.Base): class Experiment(dj.Base): - _definition = """ + definition = """ demo1.Experiment (manual) # Basic subject info -> demo1.Subject experiment : smallint # experiment number for this subject --- + experiment_folder : varchar(255) # folder path experiment_date : date # experiment start date experiment_notes="" : varchar(4096) experiment_ts=CURRENT_TIMESTAMP : timestamp # automatic timestamp @@ -40,7 +42,7 @@ class Experiment(dj.Base): class Session(dj.Base): - _definition = """ + definition = """ demo1.Session (manual) # a two-photon imaging session -> demo1.Experiment session_id : tinyint # two-photon session within this experiment @@ -51,7 +53,7 @@ class Session(dj.Base): class Scan(dj.Base): - _definition = """ + definition = """ demo1.Scan (manual) # a two-photon imaging session -> demo1.Session scan_id : tinyint # two-photon session within this experiment @@ 
-61,5 +63,15 @@ class Scan(dj.Base): mwatts: numeric(4,1) # (mW) laser power to brain """ +class ScanInfo(dj.Base, dj.AutoPopulate): + definition = """ + + """ + + pop_rel = Session + + def make_tuples(self, key): + info = (Session()*Scan() & key).pro('experiment_folder').fetch() + filename = os.path.join(info.experiment_folder, 'scan_%03', ) + -class Alignment(dj.Base, ) \ No newline at end of file From 56b7493f64f60b609d69f122d97f226b7402c64a Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Mon, 4 May 2015 17:33:27 -0500 Subject: [PATCH 5/5] intermediate massive changes, not working yet --- datajoint/__init__.py | 10 +- datajoint/base.py | 487 +++------------------------------------- datajoint/connection.py | 21 +- datajoint/declare.py | 239 ++++++++++++++++++++ datajoint/erd.py | 1 + datajoint/heading.py | 6 - datajoint/relational.py | 67 ++++-- datajoint/settings.py | 12 + datajoint/table.py | 153 ++++++++++++- datajoint/utils.py | 11 +- demos/demo1.py | 6 +- 11 files changed, 488 insertions(+), 525 deletions(-) create mode 100644 datajoint/declare.py diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 3c38699a7..9e0142e2b 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -1,14 +1,13 @@ import logging import os -__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine" +__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine" __version__ = "0.2" __all__ = ['__author__', '__version__', 'Connection', 'Heading', 'Base', 'Not', - 'AutoPopulate', 'TaskQueue', 'conn', 'DataJointError', 'blob'] + 'AutoPopulate', 'conn', 'DataJointError', 'blob'] -# ----- define datajoint error before the import hierarchy is flattened -------- class DataJointError(Exception): """ Base class for errors specific to DataJoint internal operation. @@ -38,7 +37,4 @@ class DataJointError(Exception): from .base import Base from .autopopulate import AutoPopulate from . import blob -from .relational import Not - - - +from .relational import Not \ No newline at end of file diff --git a/datajoint/base.py b/datajoint/base.py index 4f9ed9355..96a34859d 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -1,61 +1,19 @@ import importlib -import re import abc from types import ModuleType from enum import Enum -from .utils import from_camel_case from . import DataJointError -from .relational import Relation from .table import Table -from .heading import Heading import logging - -# table names have prefixes that designate their roles in the processing chain logger = logging.getLogger(__name__) -# Todo: Shouldn't this go into the settings module? -Role = Enum('Role', 'manual lookup imported computed job') -role_to_prefix = { - Role.manual: '', - Role.lookup: '#', - Role.imported: '_', - Role.computed: '__', - Role.job: '~' -} -prefix_to_role = dict(zip(role_to_prefix.values(), role_to_prefix.keys())) - -mysql_constants = ['CURRENT_TIMESTAMP'] - class Base(Table, metaclass=abc.ABCMeta): """ Base is a Table that implements data definition functions. It is an abstract class with the abstract property 'definition'. - - The table associated with an instance of Base is identified by the className - property, which is a string in CamelCase. The actual table name is obtained - by converting className from CamelCase to underscore_separated_words and - prefixing according to the table's role. - - The table declaration can be specified in the doc string of the inheriting - class, in the DataJoint table declaration syntax. 
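
For reference, a declaration in that syntax, adapted from demos/demo1.py::

    demo1.Session (manual)   # a two-photon imaging session
    -> demo1.Experiment
    session_id : tinyint     # two-photon session within this experiment
    ---
    session_date : date      # session date
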
- - Base also implements the methods insert and delete to insert and delete tuples - from the table. It can also be an argument in relational operators: restrict, - join, pro, and aggr. See class :mod:`datajoint.relational`. - - Base instances return their table's heading by looking it up in the connection - object. This ensures that Base instances contain the current table definition - even after tables are modified after the instance is created. - - :param conn=None: :mod:`datajoint.connection.Connection` object. Only used when Base is - instantiated directly. - :param dbname=None: Name of the database. Only used when Base is instantiated directly. - :param class_name=None: Class name. Only used when Base is instantiated directly. - :param table_def=None: Declaration of the table. Only used when Base is instantiated directly. - Example for a usage of Base:: import datajoint as dj @@ -72,166 +30,39 @@ class Subjects(dj.Base): """ + @abc.abstractproperty + def definition(self): + """ + :return: string containing the table declaration using the DataJoint Data Definition Language. + The DataJoint DDL is described at: TODO + """ + pass def __init__(self): - # instantiate a derived class - if conn or dbname or class_name or table_def: - raise DataJointError( - 'With derived classes, constructor arguments are ignored') # TODO: consider changing this to a warning instead - self.class_name = self.__class__.__name__ - module = self.__module__ - mod_obj = importlib.import_module(module) + self.class_name = self.__class__.__name__ + module = self.__module__ + mod_obj = importlib.import_module(module) + try: + conn = mod_obj.conn + except AttributeError: try: - self.conn = mod_obj.conn + pkg_obj = importlib.import_module(mod_obj.__package__) + conn = pkg_obj.conn + use_package = True except AttributeError: - try: - pkg_obj = importlib.import_module(mod_obj.__package__) - self.conn = pkg_obj.conn - self._use_package = True - except AttributeError: - raise DataJointError( - "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__)) - try: - if self._use_package: - pkg_name = '.'.join(module.split('.')[:-1]) - self.dbname = self.conn.mod_to_db[pkg_name] - else: - self.dbname = self.conn.mod_to_db[module] - except KeyError: raise DataJointError( - 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__)) - - - def drop(self): - """ - Drops the table associated to this object. - """ - # TODO make cascading (github issue #16) - self.conn.query('DROP TABLE %s' % self.full_table_name) - self.conn.clear_dependencies(dbname=self.dbname) - self.conn.load_headings(dbname=self.dbname, force=True) - logger.debug("Dropped table %s" % self.full_table_name) - - - @property - def is_declared(self): - """ - :returns: True if table is found in the database - """ - self.conn.load_headings(self.dbname) - return self.class_name in self.conn.table_names[self.dbname] - - @property - def full_class_name(self): - """ - :return: full class name - """ - return '{}.{}'.format(self.__module__, self.class_name) - - @property - def primary_key(self): - """ - :return: primary key of the table - """ - return self.heading.primary_key - - def declare(self): - """ - Declare the table in database if it doesn't already exist. - - :raises: DataJointError if the table cannot be declared. 
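
The constructor above expects the defining module (or its package) to expose a
bound connection; a minimal sketch, assuming module and database names::

    # mymodule.py
    import datajoint as dj
    conn = dj.conn()                           # looked up via importlib
    conn.bind(module=__name__, dbname='mydb')  # maps this module to a database
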
- """ - if not self.is_declared: - self._declare() - if not self.is_declared: - raise DataJointError( - 'Table could not be declared for %s' % self.class_name) - - """ - Data definition methods - """ - - def set_table_comment(self, comment): - """ - Update the table comment in the table declaration. - - :param comment: new comment as string - - """ - # TODO: add verification procedure (github issue #24) - self.alter('COMMENT="%s"' % comment) - - def add_attribute(self, definition, after=None): - """ - Add a new attribute to the table. A full line from the table definition - is passed in as definition. - - The definition can specify where to place the new attribute. Use after=None - to add the attribute as the first attribute or after='attribute' to place it - after an existing attribute. - - :param definition: table definition - :param after=None: After which attribute of the table the new attribute is inserted. - If None, the attribute is inserted in front. - """ - position = ' FIRST' if after is None else ( - ' AFTER %s' % after if after else '') - sql = self._field_to_SQL(_parse_attr_def(definition)) - self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) - - def drop_attribute(self, attr_name): - """ - Drops the attribute attrName from this table. - - :param attr_name: Name of the attribute that is dropped. - """ - self._alter('DROP COLUMN `%s`' % attr_name) - - def alter_attribute(self, attr_name, new_definition): - """ - Alter the definition of the field attr_name in this table using the new definition. - - :param attr_name: field that is redefined - :param new_definition: new definition of the field - """ - sql = self._field_to_SQL(_parse_attr_def(new_definition)) - self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) - - def erd(self, subset=None, prog='dot'): - """ - Plot the schema's entity relationship diagram (ERD). - The layout programs can be 'dot' (default), 'neato', 'fdp', 'sfdp', 'circo', 'twopi' - """ - if not subset: - g = self.graph - else: - g = self.graph.copy() - # todo: make erd work (github issue #7) - """ - g = self.graph - else: - g = self.graph.copy() - for i in g.nodes(): - if i not in subset: - g.remove_node(i) - def tablelist(tier): - return [i for i in g if self.tables[i].tier==tier] + "Please define object 'conn' in '{}' or in its containing package.".format(self.__module__)) + try: + if use_package: + pkg_name = '.'.join(module.split('.')[:-1]) + dbname = self.conn.mod_to_db[pkg_name] + else: + dbname = self.conn.mod_to_db[module] + except KeyError: + raise DataJointError( + 'Module {} is not bound to a database. 
See datajoint.connection.bind'.format(self.__module__)) + super().__init__(self, conn=conn, dbname=dbname, class_name=self.__class__.__name__) - pos=nx.graphviz_layout(g,prog=prog,args='') - plt.figure(figsize=(8,8)) - nx.draw_networkx_edges(g, pos, alpha=0.3) - nx.draw_networkx_nodes(g, pos, nodelist=tablelist('manual'), - node_color='g', node_size=200, alpha=0.3) - nx.draw_networkx_nodes(g, pos, nodelist=tablelist('computed'), - node_color='r', node_size=200, alpha=0.3) - nx.draw_networkx_nodes(g, pos, nodelist=tablelist('imported'), - node_color='b', node_size=200, alpha=0.3) - nx.draw_networkx_nodes(g, pos, nodelist=tablelist('lookup'), - node_color='gray', node_size=120, alpha=0.3) - nx.draw_networkx_labels(g, pos, nodelist = subset, font_weight='bold', font_size=9) - nx.draw(g,pos,alpha=0,with_labels=false) - plt.show() - """ @classmethod def get_module(cls, module_name): @@ -262,267 +93,3 @@ def get_module(cls, module_name): return importlib.import_module(module_name) except ImportError: return None - - def get_base(self, module_name, class_name): - """ - Loads the base relation from the module. If the base relation is not defined in - the module, then construct it using Base constructor. - - :param module_name: module name - :param class_name: class name - :returns: the base relation - """ - mod_obj = self.get_module(module_name) - try: - ret = getattr(mod_obj, class_name)() - except KeyError: - ret = self.__class__(conn=self.conn, - dbname=self.conn.schemas[module_name], - class_name=class_name) - return ret - - - - def _alter(self, alter_statement): - """ - Execute ALTER TABLE statement for this table. The schema - will be reloaded within the connection object. - - :param alter_statement: alter statement - """ - sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) - self.conn.query(sql) - self.conn.load_headings(self.dbname, force=True) - # TODO: place table definition sync mechanism - - def _declare(self): - """ - Declares the table in the data base if no table in the database matches this object. - """ - if not self.definition: - raise DataJointError('Table declaration is missing!') - table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration() - defined_name = table_info['module'] + '.' + table_info['className'] - expected_name = self.__module__.split('.')[-1] + '.' + self.class_name - if not defined_name == expected_name: - raise DataJointError('Table name {} does not match the declared' - 'name {}'.format(expected_name, defined_name)) - - # compile the CREATE TABLE statement - # TODO: support prefix - table_name = role_to_prefix[ - table_info['tier']] + from_camel_case(self.class_name) - sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) - - # add inherited primary key fields - primary_key_fields = set() - non_key_fields = set() - for p in parents: - for key in p.primary_key: - field = p.heading[key] - if field.name not in primary_key_fields: - primary_key_fields.add(field.name) - sql += self._field_to_SQL(field) - else: - logger.debug('Field definition of {} in {} ignored'.format( - field.name, p.full_class_name)) - - # add newly defined primary key fields - for field in (f for f in fieldDefs if f.isKey): - if field.isNullable: - raise DataJointError('Primary key {} cannot be nullable'.format( - field.name)) - if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary key ' - '{key}. 
Check to make sure that the key '
-                                         'is not declared already in referenced '
-                                         'tables'.format(key=field.name))
-            primary_key_fields.add(field.name)
-            sql += self._field_to_SQL(field)
-
-        # add secondary foreign key attributes
-        for r in referenced:
-            keys = (x for x in r.heading.attrs.values() if x.isKey)
-            for field in keys:
-                if field.name not in primary_key_fields | non_key_fields:
-                    non_key_fields.add(field.name)
-                    sql += self._field_to_SQL(field)
-
-        # add dependent attributes
-        for field in (f for f in fieldDefs if not f.isKey):
-            non_key_fields.add(field.name)
-            sql += self._field_to_SQL(field)
-
-        # add primary key declaration
-        assert len(primary_key_fields) > 0, 'table must have a primary key'
-        keys = ', '.join(primary_key_fields)
-        sql += 'PRIMARY KEY (%s),\n' % keys
-
-        # add foreign key declarations
-        for ref in parents + referenced:
-            keys = ', '.join(ref.primary_key)
-            sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \
-                   (keys, ref.full_table_name, keys)
-
-        # add secondary index declarations
-        # gather implicit indexes due to foreign keys first
-        implicit_indices = []
-        for fk_source in parents + referenced:
-            implicit_indices.append(fk_source.primary_key)
-
-        # for index in indexDefs:
-        # TODO: finish this up...
-
-        # close the declaration
-        sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (
-            sql[:-2], table_info['comment'])
-
-        # make sure that the table does not alredy exist
-        self.conn.load_headings(self.dbname, force=True)
-        if not self.is_declared:
-            # execute declaration
-            logger.debug('\n\n' + sql + '\n\n')
-            self.conn.query(sql)
-            self.conn.load_headings(self.dbname, force=True)
-
-    def _parse_declaration(self):
-        """
-        Parse declaration and create new SQL table accordingly.
-        """
-        parents = []
-        referenced = []
-        index_defs = []
-        field_defs = []
-        declaration = re.split(r'\s*\n\s*', self.definition.strip())
-
-        # remove comment lines
-        declaration = [x for x in declaration if not x.startswith('#')]
-        ptrn = """
-        ^(?P<module>\w+)\.(?P<className>\w+)\s*    # module.className
-        \(\s*(?P<tier>\w+)\s*\)\s*                 # (tier)
-        \#\s*(?P<comment>.*)$                      # comment
-        """
-        p = re.compile(ptrn, re.X)
-        table_info = p.match(declaration[0]).groupdict()
-        if table_info['tier'] not in Role.__members__:
-            raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\
-                 {module}.{cls}'.format(tier=table_info['tier'],
-                                        module=table_info[
-                                            'module'],
-                                        cls=table_info['className']))
-        table_info['tier'] = Role[table_info['tier']]  # convert into enum
-
-        in_key = True  # parse primary keys
-        field_ptrn = """
-        ^[a-z][a-z\d_]*\s*         # name
-        (=\s*\S+(\s+\S+)*\s*)?     # optional defaults
-        :\s*\w.*$                  # type, comment
-        """
-        fieldP = re.compile(field_ptrn, re.I + re.X)  # ignore case and verbose
-
-        for line in declaration[1:]:
-            if line.startswith('---'):
-                in_key = False  # start parsing non-PK fields
-            elif line.startswith('->'):
-                # foreign key
-                module_name, class_name = line[2:].strip().split('.')
-                rel = self.get_base(module_name, class_name)
-                (parents if in_key else referenced).append(rel)
-            elif re.match(r'^(unique\s+)?index[^:]*$', line):
-                index_defs.append(_parse_index_def(line))
-            elif fieldP.match(line):
-                field_defs.append(_parse_attr_def(line, in_key))
-            else:
-                raise DataJointError(
-                    'Invalid table declaration line "%s"' % line)
-
-        return table_info, parents, referenced, field_defs, index_defs
-
-
-def _field_to_SQL(field):
-    """
-    Converts an attribute definition tuple into SQL code.
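
For intuition, a field tuple parsed from ``real_id : varchar(40)  # real-world name``
renders roughly as (a sketch, not verified output)::

    `real_id` varchar(40) NOT NULL COMMENT "real-world name",
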
- :param field: attribute definition
-    :rtype : SQL code
-    """
-    if field.isNullable:
-        default = 'DEFAULT NULL'
-    else:
-        default = 'NOT NULL'
-    # if some default specified
-    if field.default:
-        # enclose value in quotes (even numeric), except special SQL values
-        # or values already enclosed by the user
-        if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']:
-            default = '%s DEFAULT %s' % (default, field.default)
-        else:
-            default = '%s DEFAULT "%s"' % (default, field.default)
-
-    # TODO: escape instead! - same goes for Matlab side implementation
-    assert not any((c in r'\"' for c in field.comment)), \
-        'Illegal characters in attribute comment "%s"' % field.comment
-
-    return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
-        name=field.name, type=field.type, default=default, comment=field.comment)
-
-
-def _parse_attr_def(line, in_key=False):
-    """
-    Parses an attribute definition line of the declaration and returns
-    an attribute tuple.
-    :param line: attribute definition line
-    :param in_key: True if the attribute is declared above the '---' separator, i.e. belongs to the primary key
-    :returns: attribute tuple
-    """
-    line = line.strip()
-    attr_ptrn = """
-    ^(?P<name>[a-z][a-z\d_]*)\s*               # field name
-    (=\s*(?P<default>\S+(\s+\S+)*?)\s*)?       # default value
-    :\s*(?P<type>\w[^\#]*[^\#\s])\s*           # datatype
-    (\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$      # comment
-    """
-
-    attrP = re.compile(attr_ptrn, re.I + re.X)
-    m = attrP.match(line)
-    assert m, 'Invalid field declaration "%s"' % line
-    attr_info = m.groupdict()
-    if not attr_info['comment']:
-        attr_info['comment'] = ''
-    if not attr_info['default']:
-        attr_info['default'] = ''
-    attr_info['isNullable'] = attr_info['default'].lower() == 'null'
-    assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['isNullable']), \
-        'BIGINT attributes cannot be nullable in "%s"' % line
-
-    attr_info['isKey'] = in_key
-    attr_info['isAutoincrement'] = None
-    attr_info['isNumeric'] = None
-    attr_info['isString'] = None
-    attr_info['isBlob'] = None
-    attr_info['computation'] = None
-    attr_info['dtype'] = None
-
-    return Heading.AttrTuple(**attr_info)
-
-
-def _parse_index_def(line):
-    """
-    Parses an index definition.
-
-    :param line: definition line
-    :return: groupdict with index info
-    """
-    line = line.strip()
-    index_ptrn = """
-    ^(?P<unique>UNIQUE)?\s*INDEX\s*    # [UNIQUE] INDEX
-    \((?P<attributes>[^\)]+)\)$        # (attr1, attr2)
-    """
-    indexP = re.compile(index_ptrn, re.I + re.X)
-    m = indexP.match(line)
-    assert m, 'Invalid index declaration "%s"' % line
-    index_info = m.groupdict()
-    attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
-    index_info['attributes'] = attributes
-    assert len(attributes) == len(set(attributes)), \
-        'Duplicate attributes in index declaration "%s"' % line
-    return index_info
diff --git a/datajoint/connection.py b/datajoint/connection.py
index d8adf2256..ca6551cea 100644
--- a/datajoint/connection.py
+++ b/datajoint/connection.py
@@ -3,7 +3,7 @@
 from .utils import to_camel_case
 from . import DataJointError
 from .heading import Heading
-from .base import prefix_to_role
+from .settings import prefix_to_role
 import logging
 from .erd import DBConnGraph
 from . import config
@@ -22,7 +22,7 @@ def conn_container():
     """
     _connObj = None  # persistent connection object used by dj.conn()
 
-    def conn(host=None, user=None, passwd=None, initFun=None, reset=False):
+    def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False):
        """
        Manage a persistent connection object. 
This is one of several ways to configure and access a datajoint connection. @@ -35,10 +35,10 @@ def conn(host=None, user=None, passwd=None, initFun=None, reset=False): host = host if host is not None else config['database.host'] user = user if user is not None else config['database.user'] passwd = passwd if passwd is not None else config['database.password'] - initFun = initFun if initFun is not None else config['connection.init_function'] - _connObj = Connection(host, user, passwd, initFun) + init_fun = init_fun if init_fun is not None else config['connection.init_function'] + _connObj = Connection(host, user, passwd, init_fun) return _connObj - return conn + return conn_function # The function conn is used by others to obtain the package wide persistent connection object conn = conn_container() @@ -50,14 +50,14 @@ class Connection: It also catalogues modules, schemas, tables, and their dependencies (foreign keys) """ - def __init__(self, host, user, passwd, initFun=None): + def __init__(self, host, user, passwd, init_fun=None): if ':' in host: host, port = host.split(':') port = int(port) else: port = config['database.port'] self.conn_info = dict(host=host, port=port, user=user, passwd=passwd) - self._conn = pymysql.connect(init_command=initFun, **self.conn_info) + self._conn = pymysql.connect(init_command=init_fun, **self.conn_info) # TODO Do something if connection cannot be established if self.is_connected: print("Connected", user + '@' + host + ':' + str(port)) @@ -221,7 +221,8 @@ def load_dependencies(self, dbname): # TODO: Perhaps consider making this "priv self.referenced[full_table_name] = [] for m in re.finditer(ptrn, table_def["Create Table"], re.X): # iterate through foreign key statements - assert m.group('attr1') == m.group('attr2'), 'Foreign keys must link identically named attributes' + assert m.group('attr1') == m.group('attr2'), \ + 'Foreign keys must link identically named attributes' attrs = m.group('attr1') attrs = re.split(r',\s+', re.sub(r'`(.*?)`', r'\1', attrs)) # remove ` around attrs and split into list pk = self.headings[dbname][tabName].primary_key @@ -264,14 +265,14 @@ def parents_of(self, child_table): def children_of(self, parent_table): """ - Returnis a list of tables for which parentTable is a parent (primary foreign key) + Returns a list of tables for which parent_table is a parent (primary foreign key) """ return [child_table for child_table, parents in self.parents.items() if parent_table in parents] def referenced_by(self, referencing_table): """ Returns a list of tables that are referenced by non-primary foreign key - by the referencingTable. + by the referencing_table. """ return self.referenced.get(referencing_table, []).copy() diff --git a/datajoint/declare.py b/datajoint/declare.py new file mode 100644 index 000000000..af09334ae --- /dev/null +++ b/datajoint/declare.py @@ -0,0 +1,239 @@ +import re +import logging +from .heading import Heading +from . import DataJointError +from .utils import from_camel_case +from .settings import Role, role_to_prefix + +mysql_constants = ['CURRENT_TIMESTAMP'] + +logger = logging.getLogger(__name__) + + +def declare(conn, definition, class_name): + """ + Declares the table in the data base if no table in the database matches this object. + """ + table_info, parents, referenced, field_definitions, index_definitions = parse_declaration(definition) + defined_name = table_info['module'] + '.' 
+ table_info['className'] + if not defined_name == class_name: + raise DataJointError('Table name {} does not match the declared' + 'name {}'.format(class_name, defined_name)) + + # compile the CREATE TABLE statement + table_name = role_to_prefix[table_info['tier']] + from_camel_case(class_name) + sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) + + # add inherited primary key fields + primary_key_fields = set() + non_key_fields = set() + for p in parents: + for key in p.primary_key: + field = p.heading[key] + if field.name not in primary_key_fields: + primary_key_fields.add(field.name) + sql += field_to_sql(field) + else: + logger.debug('Field definition of {} in {} ignored'.format( + field.name, p.full_class_name)) + + # add newly defined primary key fields + for field in (f for f in field_definitions if f.isKey): + if field.nullable: + raise DataJointError('Primary key {} cannot be nullable'.format( + field.name)) + if field.name in primary_key_fields: + raise DataJointError('Duplicate declaration of the primary key ' + '{key}. Check to make sure that the key ' + 'is not declared already in referenced ' + 'tables'.format(key=field.name)) + primary_key_fields.add(field.name) + sql += field_to_sql(field) + + # add secondary foreign key attributes + for r in referenced: + keys = (x for x in r.heading.attrs.values() if x.isKey) + for field in keys: + if field.name not in primary_key_fields | non_key_fields: + non_key_fields.add(field.name) + sql += field_to_sql(field) + + # add dependent attributes + for field in (f for f in field_definitions if not f.isKey): + non_key_fields.add(field.name) + sql += field_to_sql(field) + + # add primary key declaration + assert len(primary_key_fields) > 0, 'table must have a primary key' + keys = ', '.join(primary_key_fields) + sql += 'PRIMARY KEY (%s),\n' % keys + + # add foreign key declarations + for ref in parents + referenced: + keys = ', '.join(ref.primary_key) + sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ + (keys, ref.full_table_name, keys) + + # add secondary index declarations + # gather implicit indexes due to foreign keys first + implicit_indices = [] + for fk_source in parents + referenced: + implicit_indices.append(fk_source.primary_key) + + # for index in index_definitions: + # TODO: finish this up... + + # close the declaration + sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( + sql[:-2], table_info['comment']) + + # make sure that the table does not alredy exist + self.conn.load_headings(self.dbname, force=True) + if not self.is_declared: + # execute declaration + logger.debug('\n\n' + sql + '\n\n') + self.conn.query(sql) + self.conn.load_headings(self.dbname, force=True) + + +def _parse_declaration(self): + """ + Parse declaration and create new SQL table accordingly. 
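
For example, the first declaration line parses into a groupdict like this
(a sketch; the tier string is converted to a Role enum afterwards)::

    'demo1.Subject (manual)  # Basic subject info'
    # -> {'module': 'demo1', 'className': 'Subject',
    #     'tier': 'manual', 'comment': 'Basic subject info'}
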
+ """ + parents = [] + referenced = [] + index_defs = [] + field_defs = [] + declaration = re.split(r'\s*\n\s*', self.definition.strip()) + + # remove comment lines + declaration = [x for x in declaration if not x.startswith('#')] + ptrn = """ + ^(?P\w+)\.(?P\w+)\s* # module.className + \(\s*(?P\w+)\s*\)\s* # (tier) + \#\s*(?P.*)$ # comment + """ + p = re.compile(ptrn, re.X) + table_info = p.match(declaration[0]).groupdict() + if table_info['tier'] not in Role.__members__: + raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ + {module}.{cls}'.format(tier=table_info['tier'], + module=table_info[ + 'module'], + cls=table_info['className'])) + table_info['tier'] = Role[table_info['tier']] # convert into enum + + in_key = True # parse primary keys + field_ptrn = """ + ^[a-z][a-z\d_]*\s* # name + (=\s*\S+(\s+\S+)*\s*)? # optional defaults + :\s*\w.*$ # type, comment + """ + fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose + + for line in declaration[1:]: + if line.startswith('---'): + in_key = False # start parsing non-PK fields + elif line.startswith('->'): + # foreign key + module_name, class_name = line[2:].strip().split('.') + rel = self.get_base(module_name, class_name) + (parents if in_key else referenced).append(rel) + elif re.match(r'^(unique\s+)?index[^:]*$', line): + index_defs.append(parse_index_defnition(line)) + elif fieldP.match(line): + field_defs.append(parse_attribute_definition(line, in_key)) + else: + raise DataJointError( + 'Invalid table declaration line "%s"' % line) + + return table_info, parents, referenced, field_defs, index_defs + + +def field_to_sql(field): + """ + Converts an attribute definition tuple into SQL code. + :param field: attribute definition + :rtype : SQL code + """ + if field.nullable: + default = 'DEFAULT NULL' + else: + default = 'NOT NULL' + # if some default specified + if field.default: + # enclose value in quotes (even numeric), except special SQL values + # or values already enclosed by the user + if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']: + default = '%s DEFAULT %s' % (default, field.default) + else: + default = '%s DEFAULT "%s"' % (default, field.default) + + # TODO: escape instead! - same goes for Matlab side implementation + assert not any((c in r'\"' for c in field.comment)), \ + 'Illegal characters in attribute comment "%s"' % field.comment + + return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( + name=field.name, type=field.type, default=default, comment=field.comment) + + +def parse_attribute_definition(line, in_key=False): # todo add docu for in_key + """ + Parse attribute definition line in the declaration and returns + an attribute tuple. + :param line: attribution line + :param in_key: + :returns: attribute tuple + """ + line = line.strip() + attr_ptrn = """ + ^(?P[a-z][a-z\d_]*)\s* # field name + (=\s*(?P\S+(\s+\S+)*?)\s*)? 
# default value + :\s*(?P\w[^\#]*[^\#\s])\s* # datatype + (\#\s*(?P\S*(\s+\S+)*)\s*)?$ # comment + """ + + attrP = re.compile(attr_ptrn, re.I + re.X) + m = attrP.match(line) + assert m, 'Invalid field declaration "%s"' % line + attr_info = m.groupdict() + if not attr_info['comment']: + attr_info['comment'] = '' + if not attr_info['default']: + attr_info['default'] = '' + attr_info['nullable'] = attr_info['default'].lower() == 'null' + assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \ + 'BIGINT attributes cannot be nullable in "%s"' % line + + attr_info['isKey'] = in_key + attr_info['isAutoincrement'] = None + attr_info['isNumeric'] = None + attr_info['isString'] = None + attr_info['isBlob'] = None + attr_info['computation'] = None + attr_info['dtype'] = None + + return Heading.AttrTuple(**attr_info) + + +def parse_index_definition(line): # why is this a method of Base instead of a local function? + """ + Parses index definition. + + :param line: definition line + :return: groupdict with index info + """ + line = line.strip() + index_ptrn = """ + ^(?PUNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX + \((?P[^\)]+)\)$ # (attr1, attr2) + """ + indexP = re.compile(index_ptrn, re.I + re.X) + m = indexP.match(line) + assert m, 'Invalid index declaration "%s"' % line + index_info = m.groupdict() + attributes = re.split(r'\s*,\s*', index_info['attributes'].strip()) + index_info['attributes'] = attributes + assert len(attributes) == len(set(attributes)), \ + 'Duplicate attributes in index declaration "%s"' % line + return index_info diff --git a/datajoint/erd.py b/datajoint/erd.py index a7e1bc024..af6107390 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -14,6 +14,7 @@ logger = logging.getLogger(__name__) + class RelGraph(DiGraph): """ Represents relations between tables and databases diff --git a/datajoint/heading.py b/datajoint/heading.py index 0176a2cbe..18e75df62 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -1,9 +1,3 @@ -# -*- coding: utf-8 -*- -""" -Created on Mon Aug 4 01:29:51 2014 - -@author: dimitri, eywalker -""" import re from collections import OrderedDict, namedtuple diff --git a/datajoint/relational.py b/datajoint/relational.py index 57dc843b5..f72bd1369 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- """ -Created on Thu Aug 7 17:00:02 2014 - -@author: dimitri, eywalker +classes for relational algebra """ + import numpy as np import abc from copy import copy @@ -43,17 +41,46 @@ def __mul__(self, other): """ return Join(self, other) - def pro(self, select=None, rename=None, expand=None, aggregate=None): + def __mod__(self, attribute_list): + """ + relational projection operator. + :param attribute_list: list of attribute specifications. + The attribute specifications are strings in following forms: + 'name' - specific attribute + 'name->new_name' - rename attribute. The old attribute is kept only if specifically included. + 'sql_expression->new_name' - extend attribute, i.e. a new computed attribute. + :return: a new relation with specified heading + """ + self.project(attribute_list) + + def project(self, *selection, **aliases): """ - relational operators project, rename, expand, and aggregate. Primary key attributes are always included unless - renamed. - :param select: list of attributes to project; '*' stands for all attributes. 
-        :param rename: dictionary of renamed attributes
-        :param expand: dictionary of computed attributes, including summary operators on the aggregated relation
-        :param aggregate: a relation for which summary computations can be performed in expand
-        :return: projected Relation object
+        Relational projection operator.
+        :param selection: a list of attribute names to be included in the result.
+        :param aliases: a dict of attributes to be renamed
+        :return: a new relation with the selected fields
+        Primary key attributes are always selected and cannot be excluded.
+        Therefore obj.project() produces a relation with only the primary key attributes.
+        If selection includes the string '*', all attributes are selected.
+        Each attribute can only be used once in selection or aliases. Therefore, the projected
+        relation cannot have more attributes than the original relation.
         """
-        return Projection(self, select, rename, expand, aggregate)
+        group = selection[0] if selection and isinstance(selection[0], Relation) else None
+        if group is not None:
+            selection = selection[1:]
+        return self.aggregate(group, *selection, **aliases)
+
+    def aggregate(self, group, *selection, **aliases):
+        """
+        Relational aggregation operator
+        :param group: a relation whose tuples are aggregated per tuple of self, or None for a plain projection
+        :param selection: attribute specifications, as in project()
+        :param aliases: renamed or computed attributes, as in project()
+        :return: a new relation with the aggregated attributes
+        """
+        if group is not None and not isinstance(group, Relation):
+            raise DataJointError('The second argument of aggregate must be a relation')
+        # TODO: convert the string notation for aliases
+
+        return Projection(group=group, *selection, **aliases)
 
     def __iand__(self, restriction):
         """
@@ -228,19 +255,15 @@ def sql(self):
 class Projection(Relation):
     alias_counter = 0
 
-    def __init__(self, relation, select, rename, expand, aggregate):
+    def __init__(self, relation, *attributes, **renames):
         """
-        See Relation.pro()
+        See Relation.project()
         """
-        if aggregate is not None and not isinstance(aggregate, Relation):
-            raise DataJointError('Relation join must receive two relations')
         self.conn = relation.conn
         self._relation = relation
-        self._select = select
-        self._rename = rename
-        self._expand = expand
-        self._aggregate = aggregate
-
+        self._projection_attributes = attributes
+        self._renamed_attributes = renames
 
     @property
     def sql(self):
-        return self._rel.sql
+        return self._relation.sql
diff --git a/datajoint/settings.py b/datajoint/settings.py
index e1c5c29e3..22afcadb6 100644
--- a/datajoint/settings.py
+++ b/datajoint/settings.py
@@ -8,10 +8,22 @@
 __author__ = 'eywalker'
 import logging
 import collections
+from enum import Enum
 
 validators = collections.defaultdict(lambda: lambda value: True)
 
+
+Role = Enum('Role', 'manual lookup imported computed job')
+role_to_prefix = {
+    Role.manual: '',
+    Role.lookup: '#',
+    Role.imported: '_',
+    Role.computed: '__',
+    Role.job: '~'
+}
+prefix_to_role = dict(zip(role_to_prefix.values(), role_to_prefix.keys()))
+
+
 default = {
     'database.host': 'localhost',
     'database.password': 'datajoint',
diff --git a/datajoint/table.py b/datajoint/table.py
index 5bf7a39d8..d1c6c3bb9 100644
--- a/datajoint/table.py
+++ b/datajoint/table.py
@@ -1,7 +1,8 @@
 import numpy as np
+import logging
 from . import DataJointError
 from .relational import Relation
-import logging
+from .declare import declare
 
 logger = logging.getLogger(__name__)
 
@@ -11,26 +12,43 @@ class Table(Relation):
     A Table object is a relation associated with a table.
     A Table object provides insert and delete methods.
     Table objects are only used internally and for debugging.
-    The table must already exist in the schema for the table object to work.
-    The table is identified by its "class name", or its CamelCase version.
+ The table must already exist in the schema for its Table object to work. + + The table associated with an instance of Base is identified by its 'class name'. + property, which is a string in CamelCase. The actual table name is obtained + by converting className from CamelCase to underscore_separated_words and + prefixing according to the table's role. + + Base instances obtain their table's heading by looking it up in the connection + object. This ensures that Base instances contain the current table definition + even after tables are modified after the instance is created. """ - def __init__(self, conn=None, dbname=None, class_name=None): + def __init__(self, conn=None, dbname=None, class_name=None, definition=None): self._use_package = False self.class_name = class_name self.conn = conn self.dbname = dbname + self.conn.load_headings(self.dbname) + if dbname not in self.conn.db_to_mod: # register with a fake module, enclosed in back quotes self.conn.bind('`{0}`'.format(dbname), dbname) + #TODO: delay the loading until first use (move out of __init__) + self.conn.load_headings() + if self.class_name not in self.conn.table_names[self.dbname]: + if definition is None: + raise DataJointError('The table is not declared') + else: + declare(conn, definition, class_name) + @property def sql(self): - return self.full_table_name + self._whereClause + return self.full_table_name @property def heading(self): - self.declare() return self.conn.headings[self.dbname][self.table_name] @property @@ -47,7 +65,22 @@ def table_name(self): """ return self.conn.table_names[self.dbname][self.class_name] - def insert(self, tup, ignore_errors=False, replace=False): + + @property + def full_class_name(self): + """ + :return: full class name + """ + return '{}.{}'.format(self.__module__, self.class_name) + + @property + def primary_key(self): + """ + :return: primary key of the table + """ + return self.heading.primary_key + + def insert(self, tup, ignore_errors=False, replace=False): # TODO: in progress (issue #8) """ Insert one data tuple. @@ -87,5 +120,109 @@ def insert(self, tup, ignore_errors=False, replace=False): logger.info(sql) self.conn.query(sql) - def delete(self): # TODO + def delete(self): # TODO: (issues #14 and #15) pass + + def drop(self): + """ + Drops the table associated to this object. + """ + # TODO: make cascading (issue #16) + self.conn.query('DROP TABLE %s' % self.full_table_name) + self.conn.clear_dependencies(dbname=self.dbname) + self.conn.load_headings(dbname=self.dbname, force=True) + logger.debug("Dropped table %s" % self.full_table_name) + + def set_table_comment(self, comment): + """ + Update the table comment in the table definition. + + :param comment: new comment as string + + """ + # TODO: add verification procedure (github issue #24) + self.alter('COMMENT="%s"' % comment) + + def add_attribute(self, definition, after=None): + """ + Add a new attribute to the table. A full line from the table definition + is passed in as definition. + + The definition can specify where to place the new attribute. Use after=None + to add the attribute as the first attribute or after='attribute' to place it + after an existing attribute. + + :param definition: table definition + :param after=None: After which attribute of the table the new attribute is inserted. + If None, the attribute is inserted in front. 
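
        Example (definition line and attribute names assumed)::

            table.add_attribute('weight = null : float  # (g) body weight',
                                after='real_id')
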
+ """ + position = ' FIRST' if after is None else ( + ' AFTER %s' % after if after else '') + sql = self.field_to_sql(parse_attribute_definition(definition)) + self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) + + def drop_attribute(self, attr_name): + """ + Drops the attribute attrName from this table. + + :param attr_name: Name of the attribute that is dropped. + """ + self._alter('DROP COLUMN `%s`' % attr_name) + + def alter_attribute(self, attr_name, new_definition): + """ + Alter the definition of the field attr_name in this table using the new definition. + + :param attr_name: field that is redefined + :param new_definition: new definition of the field + """ + sql = self.field_to_sql(parse_attribute_definition(new_definition)) + self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) + + def erd(self, subset=None, prog='dot'): + """ + Plot the schema's entity relationship diagram (ERD). + The layout programs can be 'dot' (default), 'neato', 'fdp', 'sfdp', 'circo', 'twopi' + """ + if not subset: + g = self.graph + else: + g = self.graph.copy() + # todo: make erd work (github issue #7) + """ + g = self.graph + else: + g = self.graph.copy() + for i in g.nodes(): + if i not in subset: + g.remove_node(i) + def tablelist(tier): + return [i for i in g if self.tables[i].tier==tier] + + pos=nx.graphviz_layout(g,prog=prog,args='') + plt.figure(figsize=(8,8)) + nx.draw_networkx_edges(g, pos, alpha=0.3) + nx.draw_networkx_nodes(g, pos, nodelist=tablelist('manual'), + node_color='g', node_size=200, alpha=0.3) + nx.draw_networkx_nodes(g, pos, nodelist=tablelist('computed'), + node_color='r', node_size=200, alpha=0.3) + nx.draw_networkx_nodes(g, pos, nodelist=tablelist('imported'), + node_color='b', node_size=200, alpha=0.3) + nx.draw_networkx_nodes(g, pos, nodelist=tablelist('lookup'), + node_color='gray', node_size=120, alpha=0.3) + nx.draw_networkx_labels(g, pos, nodelist = subset, font_weight='bold', font_size=9) + nx.draw(g,pos,alpha=0,with_labels=false) + plt.show() + """ + + def _alter(self, alter_statement): + """ + Execute ALTER TABLE statement for this table. The schema + will be reloaded within the connection object. + + :param alter_statement: alter statement + """ + sql = 'ALTER TABLE %s %s' % (self.full_table_name, alter_statement) + self.conn.query(sql) + self.conn.load_headings(self.dbname, force=True) + # TODO: place table definition sync mechanism \ No newline at end of file diff --git a/datajoint/utils.py b/datajoint/utils.py index 9d1bf85de..47aacdeeb 100644 --- a/datajoint/utils.py +++ b/datajoint/utils.py @@ -1,13 +1,7 @@ import re -# package-wide settings that control execution - -# setup root logger from . 
import DataJointError - - - def to_camel_case(s): """ Convert names with under score (_) separation @@ -24,7 +18,7 @@ def to_upper(matchobj): def from_camel_case(s): """ - Conver names in camel case into underscore + Convert names in camel case into underscore (_) separated names Example: @@ -37,7 +31,8 @@ def from_camel_case(s): raise DataJointError('String cannot begin with a digit') if not re.match(r'^[a-zA-Z0-9]*$', s): raise DataJointError('String can only contain alphanumeric characters') + def conv(matchobj): return ('_' if matchobj.groups()[0] else '') + matchobj.group(0).lower() - return re.sub(r'(\B[A-Z])|(\b[A-Z])', conv, s) + return re.sub(r'(\B[A-Z])|(\b[A-Z])', conv, s) \ No newline at end of file diff --git a/demos/demo1.py b/demos/demo1.py index ba926f8c6..689905730 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -63,11 +63,9 @@ class Scan(dj.Base): mwatts: numeric(4,1) # (mW) laser power to brain """ -class ScanInfo(dj.Base, dj.AutoPopulate): - definition = """ - - """ +class ScanInfo(dj.Base, dj.AutoPopulate): + definition = None pop_rel = Session def make_tuples(self, key):