Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 47 additions & 38 deletions datajoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# table names have prefixes that designate their roles in the processing chain
logger = logging.getLogger(__name__)

Role = Enum('Role', 'manual lookup imported computed job') # Todo: Shouldn't this go into the settings module?
# Todo: Shouldn't this go into the settings module?
Role = Enum('Role', 'manual lookup imported computed job')
role_to_prefix = {
Role.manual: '',
Role.lookup: '#',
Expand All @@ -26,6 +27,7 @@


class Base(_Relational):

"""
Base integrates all data manipulation and data declaration functions.
An instance of the class provides an interface to a single table in the database.
Expand All @@ -45,14 +47,14 @@ class Base(_Relational):
property, which is a string in CamelCase. The actual table name is obtained
by converting className from CamelCase to underscore_separated_words and
prefixing according to the table's role.

The table declaration can be specified in the doc string of the inheriting
class, in the DataJoint table declaration syntax.

Base also implements the methods insert and delete to insert and delete tuples
from the table. It can also be an argument in relational operators: restrict,
join, pro, and aggr. See class :mod:`datajoint.relational`.

Base instances return their table's heading by looking it up in the connection
object. This ensures that Base instances contain the current table definition
even after tables are modified after the instance is created.
Expand All @@ -61,7 +63,7 @@ class Base(_Relational):
instantiated directly.
:param dbname=None: Name of the database. Only used when Base is instantiated directly.
:param class_name=None: Class name. Only used when Base is instantiated directly.
:param declaration=None:
:param table_def=None: Declaration of the table. Only used when Base is instantiated directly.

Example for a usage of Base::

Expand All @@ -80,21 +82,23 @@ class Subjects(dj.Base):

"""

def __init__(self, conn=None, dbname=None, class_name=None, declaration=None):
def __init__(self, conn=None, dbname=None, class_name=None, table_def=None):
self._use_package = False
if self.__class__ is Base:
# instantiate without subclassing
if not (conn and dbname and class_name):
raise DataJointError('Missing argument: please specify conn, dbname, and class name.')
raise DataJointError(
'Missing argument: please specify conn, dbname, and class name.')
self.class_name = class_name
self.conn = conn
self.dbname = dbname
self.declaration = declaration # todo: why is this set as declaration and not as _table_def?
if dbname not in self.conn.modules: # register with a fake module, enclosed in back quotes
self._table_def = table_def
# register with a fake module, enclosed in back quotes
if dbname not in self.conn.modules:
self.conn.bind('`{0}`'.format(dbname), dbname)
else:
# instantiate a derived class
if conn or dbname or class_name or declaration:
if conn or dbname or class_name or table_def:
raise DataJointError(
'With derived classes, constructor arguments are ignored') # TODO: consider changing this to a warning instead
self.class_name = self.__class__.__name__
Expand All @@ -119,15 +123,14 @@ def __init__(self, conn=None, dbname=None, class_name=None, declaration=None):
except KeyError:
raise DataJointError(
'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
# take table declaration from the deriving class' _table_def string
# todo: declaration and _table_def seem to be redundant!

if hasattr(self, '_table_def'):
self.declaration = self._table_def
self._table_def = self._table_def
else:
self.declaration = None

self._table_def = None

def insert(self, tup, ignore_errors=False, replace=False): # todo: do we support records and named tuples for tup?
# todo: do we support records and named tuples for tup?
def insert(self, tup, ignore_errors=False, replace=False):
"""
Insert one data tuple.

Expand All @@ -146,10 +149,13 @@ def insert(self, tup, ignore_errors=False, replace=False): # todo: do we suppor
valueList = ','.join([repr(q) for q in tup])
fieldList = '`' + '`,`'.join(self.heading.names[0:len(tup)]) + '`'
elif issubclass(type(tup), dict):
valueList = ','.join([repr(tup[q]) for q in self.heading.names if q in tup])
fieldList = '`' + '`,`'.join([q for q in self.heading.names if q in tup]) + '`'
valueList = ','.join([repr(tup[q])
for q in self.heading.names if q in tup])
fieldList = '`' + \
'`,`'.join([q for q in self.heading.names if q in tup]) + '`'
elif issubclass(type(tup), np.void):
valueList = ','.join([repr(tup[q]) for q in self.heading.names if q in tup])
valueList = ','.join([repr(tup[q])
for q in self.heading.names if q in tup])
fieldList = '`' + '`,`'.join(tup.dtype.fields) + '`'
else:
raise DataJointError('Datatype %s cannot be inserted' % type(tup))
Expand All @@ -159,7 +165,8 @@ def insert(self, tup, ignore_errors=False, replace=False): # todo: do we suppor
sql = 'INSERT IGNORE'
else:
sql = 'INSERT'
sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name, fieldList, valueList)
sql += " INTO %s (%s) VALUES (%s)" % (self.full_table_name,
fieldList, valueList)
logger.info(sql)
self.conn.query(sql)

Expand All @@ -173,7 +180,6 @@ def drop(self):
self.conn.load_headings(dbname=self.dbname, force=True)
logger.debug("Dropped table %s" % self.full_table_name)


@property
def sql(self):
return self.full_table_name + self._whereClause
Expand All @@ -199,7 +205,6 @@ def table_name(self):
self.declare()
return self.conn.table_names[self.dbname][self.class_name]


@property
def full_table_name(self):
"""
Expand Down Expand Up @@ -230,7 +235,8 @@ def declare(self):
if not self.is_declared:
self._declare()
if not self.is_declared:
raise DataJointError('Table could not be declared for %s' % self.class_name)
raise DataJointError(
'Table could not be declared for %s' % self.class_name)

"""
Data definition functionalities
Expand Down Expand Up @@ -259,7 +265,8 @@ def add_attribute(self, definition, after=None):
:param after=None: After which attribute of the table the new attribute is inserted.
If None, the attribute is inserted in front.
"""
position = ' FIRST' if after is None else (' AFTER %s' % after if after else '')
position = ' FIRST' if after is None else (
' AFTER %s' % after if after else '')
sql = self._field_to_SQL(self._parse_attr_def(definition))
self._alter('ADD COLUMN %s%s' % (sql[:-2], position))

Expand Down Expand Up @@ -365,7 +372,6 @@ def get_base(self, module_name, class_name):
class_name=class_name)
return ret


# ////////////////////////////////////////////////////////////
# Private Methods
# ////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -394,7 +400,7 @@ def _field_to_SQL(self, field):
assert not any((c in r'\"' for c in field.comment)), \
'Illegal characters in attribute comment "%s"' % field.comment

return '`{name}` {type} {default} COMMENT "{comment}",\n'.format( \
return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
name=field.name, type=field.type, default=default, comment=field.comment)

def _alter(self, alter_statement):
Expand All @@ -413,7 +419,7 @@ def _declare(self):
"""
Declares the table in the data base if no table in the database matches this object.
"""
if not self.declaration:
if not self._table_def:
raise DataJointError('Table declaration is missing!')
table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration()
defined_name = table_info['module'] + '.' + table_info['className']
Expand All @@ -424,7 +430,8 @@ def _declare(self):

# compile the CREATE TABLE statement
# TODO: support prefix
table_name = role_to_prefix[table_info['tier']] + from_camel_case(self.class_name)
table_name = role_to_prefix[
table_info['tier']] + from_camel_case(self.class_name)
sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name)

# add inherited primary key fields
Expand All @@ -446,9 +453,9 @@ def _declare(self):
raise DataJointError('Primary key {} cannot be nullable'.format(
field.name))
if field.name in primary_key_fields:
raise DataJointError('Duplicate declaration of the primary key ' \
'{key}. Check to make sure that the key ' \
'is not declared already in referenced ' \
raise DataJointError('Duplicate declaration of the primary key '
'{key}. Check to make sure that the key '
'is not declared already in referenced '
'tables'.format(key=field.name))
primary_key_fields.add(field.name)
sql += self._field_to_SQL(field)
Expand Down Expand Up @@ -477,18 +484,18 @@ def _declare(self):
sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \
(keys, ref.full_table_name, keys)


# add secondary index declarations
# gather implicit indexes due to foreign keys first
implicit_indices = []
for fk_source in parents + referenced:
implicit_indices.append(fk_source.primary_key)

#for index in indexDefs:
#TODO: finish this up...
# for index in indexDefs:
# TODO: finish this up...

# close the declaration
sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (sql[:-2], table_info['comment'])
sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (
sql[:-2], table_info['comment'])

# make sure that the table does not alredy exist
self.conn.load_headings(self.dbname, force=True)
Expand All @@ -506,7 +513,7 @@ def _parse_declaration(self):
referenced = []
index_defs = []
field_defs = []
declaration = re.split(r'\s*\n\s*', self.declaration.strip())
declaration = re.split(r'\s*\n\s*', self._table_def.strip())

# remove comment lines
declaration = [x for x in declaration if not x.startswith('#')]
Expand All @@ -520,7 +527,8 @@ def _parse_declaration(self):
if table_info['tier'] not in Role.__members__:
raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\
{module}.{cls}'.format(tier=table_info['tier'],
module=table_info['module'],
module=table_info[
'module'],
cls=table_info['className']))
table_info['tier'] = Role[table_info['tier']] # convert into enum

Expand All @@ -545,7 +553,8 @@ def _parse_declaration(self):
elif fieldP.match(line):
field_defs.append(self._parse_attr_def(line, in_key))
else:
raise DataJointError('Invalid table declaration line "%s"' % line)
raise DataJointError(
'Invalid table declaration line "%s"' % line)

return table_info, parents, referenced, field_defs, index_defs

Expand Down
2 changes: 1 addition & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_instantiation_of_base_derivatives(self):
s = test1.Subjects()
assert_equal(s.dbname, PREFIX + '_test1')
assert_equal(s.conn, self.conn)
assert_equal(s.declaration, test1.Subjects._table_def)
assert_equal(s._table_def, test1.Subjects._table_def)

def test_declaration_status(self):
b = test1.Subjects()
Expand Down