diff --git a/datajoint/base.py b/datajoint/base.py index 79515a0af..75dc7db2a 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -12,7 +12,8 @@ # table names have prefixes that designate their roles in the processing chain logger = logging.getLogger(__name__) -Role = Enum('Role', 'manual lookup imported computed job') # Todo: Shouldn't this go into the settings module? +# Todo: Shouldn't this go into the settings module? +Role = Enum('Role', 'manual lookup imported computed job') role_to_prefix = { Role.manual: '', Role.lookup: '#', @@ -26,6 +27,7 @@ class Base(_Relational): + """ Base integrates all data manipulation and data declaration functions. An instance of the class provides an interface to a single table in the database. @@ -45,14 +47,14 @@ class Base(_Relational): property, which is a string in CamelCase. The actual table name is obtained by converting className from CamelCase to underscore_separated_words and prefixing according to the table's role. - + The table declaration can be specified in the doc string of the inheriting class, in the DataJoint table declaration syntax. - + Base also implements the methods insert and delete to insert and delete tuples from the table. It can also be an argument in relational operators: restrict, join, pro, and aggr. See class :mod:`datajoint.relational`. - + Base instances return their table's heading by looking it up in the connection object. This ensures that Base instances contain the current table definition even after tables are modified after the instance is created. @@ -61,7 +63,7 @@ class Base(_Relational): instantiated directly. :param dbname=None: Name of the database. Only used when Base is instantiated directly. :param class_name=None: Class name. Only used when Base is instantiated directly. - :param declaration=None: + :param table_def=None: Declaration of the table. Only used when Base is instantiated directly. Example for a usage of Base:: @@ -80,21 +82,23 @@ class Subjects(dj.Base): """ - def __init__(self, conn=None, dbname=None, class_name=None, declaration=None): + def __init__(self, conn=None, dbname=None, class_name=None, table_def=None): self._use_package = False if self.__class__ is Base: # instantiate without subclassing if not (conn and dbname and class_name): - raise DataJointError('Missing argument: please specify conn, dbname, and class name.') + raise DataJointError( + 'Missing argument: please specify conn, dbname, and class name.') self.class_name = class_name self.conn = conn self.dbname = dbname - self.declaration = declaration # todo: why is this set as declaration and not as _table_def? - if dbname not in self.conn.modules: # register with a fake module, enclosed in back quotes + self._table_def = table_def + # register with a fake module, enclosed in back quotes + if dbname not in self.conn.modules: self.conn.bind('`{0}`'.format(dbname), dbname) else: # instantiate a derived class - if conn or dbname or class_name or declaration: + if conn or dbname or class_name or table_def: raise DataJointError( 'With derived classes, constructor arguments are ignored') # TODO: consider changing this to a warning instead self.class_name = self.__class__.__name__ @@ -119,15 +123,14 @@ def __init__(self, conn=None, dbname=None, class_name=None, declaration=None): except KeyError: raise DataJointError( 'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__)) - # take table declaration from the deriving class' _table_def string - # todo: declaration and _table_def seem to be redundant! + if hasattr(self, '_table_def'): - self.declaration = self._table_def + self._table_def = self._table_def else: - self.declaration = None - + self._table_def = None - def insert(self, tup, ignore_errors=False, replace=False): # todo: do we support records and named tuples for tup? + # todo: do we support records and named tuples for tup? + def insert(self, tup, ignore_errors=False, replace=False): """ Insert one data tuple. @@ -146,10 +149,13 @@ def insert(self, tup, ignore_errors=False, replace=False): # todo: do we suppor valueList = ','.join([repr(q) for q in tup]) fieldList = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`' elif issubclass(type(tup), dict): - valueList = ','.join([repr(tup[q]) for q in self.heading.names if q in tup]) - fieldList = '`' + '`,`'.join([q for q in self.heading.names if q in tup]) + '`' + valueList = ','.join([repr(tup[q]) + for q in self.heading.names if q in tup]) + fieldList = '`' + \ + '`,`'.join([q for q in self.heading.names if q in tup]) + '`' elif issubclass(type(tup), np.void): - valueList = ','.join([repr(tup[q]) for q in self.heading.names if q in tup]) + valueList = ','.join([repr(tup[q]) + for q in self.heading.names if q in tup]) fieldList = '`' + '`,`'.join(tup.dtype.fields) + '`' else: raise DataJointError('Datatype %s cannot be inserted' % type(tup)) @@ -159,7 +165,8 @@ def insert(self, tup, ignore_errors=False, replace=False): # todo: do we suppor sql = 'INSERT IGNORE' else: sql = 'INSERT' - sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, fieldList, valueList) + sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, + fieldList, valueList) logger.info(sql) self.conn.query(sql) @@ -173,7 +180,6 @@ def drop(self): self.conn.load_headings(dbname=self.dbname, force=True) logger.debug("Dropped table %s" % self.full_table_name) - @property def sql(self): return self.full_table_name + self._whereClause @@ -199,7 +205,6 @@ def table_name(self): self.declare() return self.conn.table_names[self.dbname][self.class_name] - @property def full_table_name(self): """ @@ -230,7 +235,8 @@ def declare(self): if not self.is_declared: self._declare() if not self.is_declared: - raise DataJointError('Table could not be declared for %s' % self.class_name) + raise DataJointError( + 'Table could not be declared for %s' % self.class_name) """ Data definition functionalities @@ -259,7 +265,8 @@ def add_attribute(self, definition, after=None): :param after=None: After which attribute of the table the new attribute is inserted. If None, the attribute is inserted in front. """ - position = ' FIRST' if after is None else (' AFTER %s' % after if after else '') + position = ' FIRST' if after is None else ( + ' AFTER %s' % after if after else '') sql = self._field_to_SQL(self._parse_attr_def(definition)) self._alter('ADD COLUMN %s%s' % (sql[:-2], position)) @@ -365,7 +372,6 @@ def get_base(self, module_name, class_name): class_name=class_name) return ret - # //////////////////////////////////////////////////////////// # Private Methods # //////////////////////////////////////////////////////////// @@ -394,7 +400,7 @@ def _field_to_SQL(self, field): assert not any((c in r'\"' for c in field.comment)), \ 'Illegal characters in attribute comment "%s"' % field.comment - return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( \ + return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( name=field.name, type=field.type, default=default, comment=field.comment) def _alter(self, alter_statement): @@ -413,7 +419,7 @@ def _declare(self): """ Declares the table in the data base if no table in the database matches this object. """ - if not self.declaration: + if not self._table_def: raise DataJointError('Table declaration is missing!') table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration() defined_name = table_info['module'] + '.' + table_info['className'] @@ -424,7 +430,8 @@ def _declare(self): # compile the CREATE TABLE statement # TODO: support prefix - table_name = role_to_prefix[table_info['tier']] + from_camel_case(self.class_name) + table_name = role_to_prefix[ + table_info['tier']] + from_camel_case(self.class_name) sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name) # add inherited primary key fields @@ -446,9 +453,9 @@ def _declare(self): raise DataJointError('Primary key {} cannot be nullable'.format( field.name)) if field.name in primary_key_fields: - raise DataJointError('Duplicate declaration of the primary key ' \ - '{key}. Check to make sure that the key ' \ - 'is not declared already in referenced ' \ + raise DataJointError('Duplicate declaration of the primary key ' + '{key}. Check to make sure that the key ' + 'is not declared already in referenced ' 'tables'.format(key=field.name)) primary_key_fields.add(field.name) sql += self._field_to_SQL(field) @@ -477,18 +484,18 @@ def _declare(self): sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \ (keys, ref.full_table_name, keys) - # add secondary index declarations # gather implicit indexes due to foreign keys first implicit_indices = [] for fk_source in parents + referenced: implicit_indices.append(fk_source.primary_key) - #for index in indexDefs: - #TODO: finish this up... + # for index in indexDefs: + # TODO: finish this up... # close the declaration - sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (sql[:-2], table_info['comment']) + sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % ( + sql[:-2], table_info['comment']) # make sure that the table does not alredy exist self.conn.load_headings(self.dbname, force=True) @@ -506,7 +513,7 @@ def _parse_declaration(self): referenced = [] index_defs = [] field_defs = [] - declaration = re.split(r'\s*\n\s*', self.declaration.strip()) + declaration = re.split(r'\s*\n\s*', self._table_def.strip()) # remove comment lines declaration = [x for x in declaration if not x.startswith('#')] @@ -520,7 +527,8 @@ def _parse_declaration(self): if table_info['tier'] not in Role.__members__: raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\ {module}.{cls}'.format(tier=table_info['tier'], - module=table_info['module'], + module=table_info[ + 'module'], cls=table_info['className'])) table_info['tier'] = Role[table_info['tier']] # convert into enum @@ -545,7 +553,8 @@ def _parse_declaration(self): elif fieldP.match(line): field_defs.append(self._parse_attr_def(line, in_key)) else: - raise DataJointError('Invalid table declaration line "%s"' % line) + raise DataJointError( + 'Invalid table declaration line "%s"' % line) return table_info, parents, referenced, field_defs, index_defs diff --git a/tests/test_base.py b/tests/test_base.py index 7939ddc06..4c9f2c98f 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -76,7 +76,7 @@ def test_instantiation_of_base_derivatives(self): s = test1.Subjects() assert_equal(s.dbname, PREFIX + '_test1') assert_equal(s.conn, self.conn) - assert_equal(s.declaration, test1.Subjects._table_def) + assert_equal(s._table_def, test1.Subjects._table_def) def test_declaration_status(self): b = test1.Subjects()