Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 31 additions & 255 deletions datajoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Subjects(dj.Base):
'''

"""

@abc.abstractproperty
def definition(self):
"""
Expand All @@ -42,7 +41,27 @@ def definition(self):
"""
pass

def __init__(self):
@property
def full_class_name(self):
"""
:return: full class name including the entire package hierarchy
"""
return '{}.{}'.format(self.__module__, self.class_name)

@property
def access_name(self):
"""
:return: name by which this class should be accessible as
"""
if self._use_package:
parent = self.__module__.split('.')[-2]
else:
parent = self.__module__.split('.')[-1]
return parent + '.' + self.class_name



def __init__(self): #TODO: support taking in conn obj
self.class_name = self.__class__.__name__
module = self.__module__
mod_obj = importlib.import_module(module)
Expand All @@ -61,6 +80,7 @@ def __init__(self):
self.conn = conn
try:
if self._use_package:
# the database is bound to the package
pkg_name = '.'.join(module.split('.')[:-1])
dbname = self.conn.mod_to_db[pkg_name]
else:
Expand All @@ -69,256 +89,7 @@ def __init__(self):
raise DataJointError(
'Module {} is not bound to a database. See datajoint.connection.bind'.format(self.__module__))
self.dbname = dbname
self.declare()
super().__init__(conn=conn, dbname=dbname, class_name=self.__class__.__name__)

@property
def is_declared(self):
self.conn.load_headings(self.dbname)
return self.class_name in self.conn.table_names[self.dbname]

def declare(self):
"""
Declare the table in database if it doesn't already exist.

:raises: DataJointError if the table cannot be declared.
"""
if not self.is_declared:
self._declare()
if not self.is_declared:
raise DataJointError(
'Table could not be declared for %s' % self.class_name)

def _field_to_sql(self, field):
"""
Converts an attribute definition tuple into SQL code.
:param field: attribute definition
:rtype : SQL code
"""
mysql_constants = ['CURRENT_TIMESTAMP']
if field.nullable:
default = 'DEFAULT NULL'
else:
default = 'NOT NULL'
# if some default specified
if field.default:
# enclose value in quotes (even numeric), except special SQL values
# or values already enclosed by the user
if field.default.upper() in mysql_constants or field.default[:1] in ["'", '"']:
default = '%s DEFAULT %s' % (default, field.default)
else:
default = '%s DEFAULT "%s"' % (default, field.default)

# TODO: escape instead! - same goes for Matlab side implementation
assert not any((c in r'\"' for c in field.comment)), \
'Illegal characters in attribute comment "%s"' % field.comment

return '`{name}` {type} {default} COMMENT "{comment}",\n'.format(
name=field.name, type=field.type, default=default, comment=field.comment)

def _declare(self):
"""
Declares the table in the data base if no table in the database matches this object.
"""
if not self.definition:
raise DataJointError('Table declaration is missing!')
table_info, parents, referenced, fieldDefs, indexDefs = self._parse_declaration()
defined_name = table_info['module'] + '.' + table_info['className']
expected_name = self.__module__.split('.')[-1] + '.' + self.class_name
if not defined_name == expected_name:
raise DataJointError('Table name {} does not match the declared'
'name {}'.format(expected_name, defined_name))

# compile the CREATE TABLE statement
# TODO: support prefix
table_name = role_to_prefix[
table_info['tier']] + from_camel_case(self.class_name)
sql = 'CREATE TABLE `%s`.`%s` (\n' % (self.dbname, table_name)

# add inherited primary key fields
primary_key_fields = set()
non_key_fields = set()
for p in parents:
for key in p.primary_key:
field = p.heading[key]
if field.name not in primary_key_fields:
primary_key_fields.add(field.name)
sql += self._field_to_sql(field)
else:
logger.debug('Field definition of {} in {} ignored'.format(
field.name, p.full_class_name))

# add newly defined primary key fields
for field in (f for f in fieldDefs if f.in_key):
if field.nullable:
raise DataJointError('Primary key {} cannot be nullable'.format(
field.name))
if field.name in primary_key_fields:
raise DataJointError('Duplicate declaration of the primary key '
'{key}. Check to make sure that the key '
'is not declared already in referenced '
'tables'.format(key=field.name))
primary_key_fields.add(field.name)
sql += self._field_to_sql(field)

# add secondary foreign key attributes
for r in referenced:
keys = (x for x in r.heading.attrs.values() if x.in_key)
for field in keys:
if field.name not in primary_key_fields | non_key_fields:
non_key_fields.add(field.name)
sql += self._field_to_sql(field)

# add dependent attributes
for field in (f for f in fieldDefs if not f.in_key):
non_key_fields.add(field.name)
sql += self._field_to_sql(field)

# add primary key declaration
assert len(primary_key_fields) > 0, 'table must have a primary key'
keys = ', '.join(primary_key_fields)
sql += 'PRIMARY KEY (%s),\n' % keys

# add foreign key declarations
for ref in parents + referenced:
keys = ', '.join(ref.primary_key)
sql += 'FOREIGN KEY (%s) REFERENCES %s (%s) ON UPDATE CASCADE ON DELETE RESTRICT,\n' % \
(keys, ref.full_table_name, keys)

# add secondary index declarations
# gather implicit indexes due to foreign keys first
implicit_indices = []
for fk_source in parents + referenced:
implicit_indices.append(fk_source.primary_key)

# for index in indexDefs:
# TODO: finish this up...

# close the declaration
sql = '%s\n) ENGINE = InnoDB, COMMENT "%s"' % (
sql[:-2], table_info['comment'])

# make sure that the table does not alredy exist
self.conn.load_headings(self.dbname, force=True)
if not self.is_declared:
# execute declaration
logger.debug('\n<SQL>\n' + sql + '</SQL>\n\n')
self.conn.query(sql)
self.conn.load_headings(self.dbname, force=True)

def _parse_declaration(self):
"""
Parse declaration and create new SQL table accordingly.
"""
parents = []
referenced = []
index_defs = []
field_defs = []
declaration = re.split(r'\s*\n\s*', self.definition.strip())

# remove comment lines
declaration = [x for x in declaration if not x.startswith('#')]
ptrn = """
^(?P<module>\w+)\.(?P<className>\w+)\s* # module.className
\(\s*(?P<tier>\w+)\s*\)\s* # (tier)
\#\s*(?P<comment>.*)$ # comment
"""
p = re.compile(ptrn, re.X)
table_info = p.match(declaration[0]).groupdict()
if table_info['tier'] not in Role.__members__:
raise DataJointError('InvalidTableTier: Invalid tier {tier} for table\
{module}.{cls}'.format(tier=table_info['tier'],
module=table_info[
'module'],
cls=table_info['className']))
table_info['tier'] = Role[table_info['tier']] # convert into enum

in_key = True # parse primary keys
field_ptrn = """
^[a-z][a-z\d_]*\s* # name
(=\s*\S+(\s+\S+)*\s*)? # optional defaults
:\s*\w.*$ # type, comment
"""
fieldP = re.compile(field_ptrn, re.I + re.X) # ignore case and verbose

for line in declaration[1:]:
if line.startswith('---'):
in_key = False # start parsing non-PK fields
elif line.startswith('->'):
# foreign key
module_name, class_name = line[2:].strip().split('.')
rel = self.get_base(module_name, class_name)
(parents if in_key else referenced).append(rel)
elif re.match(r'^(unique\s+)?index[^:]*$', line):
index_defs.append(self._parse_index_def(line))
elif fieldP.match(line):
field_defs.append(self._parse_attr_def(line, in_key))
else:
raise DataJointError(
'Invalid table declaration line "%s"' % line)

return table_info, parents, referenced, field_defs, index_defs

def _parse_attr_def(self, line, in_key=False): # todo add docu for in_key
"""
Parse attribute definition line in the declaration and returns
an attribute tuple.

:param line: attribution line
:param in_key:
:returns: attribute tuple
"""
line = line.strip()
attr_ptrn = """
^(?P<name>[a-z][a-z\d_]*)\s* # field name
(=\s*(?P<default>\S+(\s+\S+)*?)\s*)? # default value
:\s*(?P<type>\w[^\#]*[^\#\s])\s* # datatype
(\#\s*(?P<comment>\S*(\s+\S+)*)\s*)?$ # comment
"""

attrP = re.compile(attr_ptrn, re.I + re.X)
m = attrP.match(line)
assert m, 'Invalid field declaration "%s"' % line
attr_info = m.groupdict()
if not attr_info['comment']:
attr_info['comment'] = ''
if not attr_info['default']:
attr_info['default'] = ''
attr_info['nullable'] = attr_info['default'].lower() == 'null'
assert (not re.match(r'^bigint', attr_info['type'], re.I) or not attr_info['nullable']), \
'BIGINT attributes cannot be nullable in "%s"' % line

attr_info['in_key'] = in_key
attr_info['autoincrement'] = None
attr_info['numeric'] = None
attr_info['string'] = None
attr_info['is_blob'] = None
attr_info['computation'] = None
attr_info['dtype'] = None

return Heading.AttrTuple(**attr_info)

def _parse_index_def(self, line):
"""
Parses index definition.

:param line: definition line
:return: groupdict with index info
"""
line = line.strip()
index_ptrn = """
^(?P<unique>UNIQUE)?\s*INDEX\s* # [UNIQUE] INDEX
\((?P<attributes>[^\)]+)\)$ # (attr1, attr2)
"""
indexP = re.compile(index_ptrn, re.I + re.X)
m = indexP.match(line)
assert m, 'Invalid index declaration "%s"' % line
index_info = m.groupdict()
attributes = re.split(r'\s*,\s*', index_info['attributes'].strip())
index_info['attributes'] = attributes
assert len(attributes) == len(set(attributes)), \
'Duplicate attributes in index declaration "%s"' % line
return index_info

def get_base(self, module_name, class_name):
"""
Expand All @@ -330,12 +101,15 @@ def get_base(self, module_name, class_name):
:returns: the base relation
"""
mod_obj = self.get_module(module_name)
if not mod_obj:
raise DataJointError('Module named {mod_name} was not found. Please make'
' sure that it is in the path or you import the module.'.format(mod_name=module_name))
try:
ret = getattr(mod_obj, class_name)()
except KeyError:
ret = self.__class__(conn=self.conn,
dbname=self.conn.schemas[module_name],
class_name=class_name)
except AttributeError:
ret = Table(conn=self.conn,
dbname=self.conn.mod_to_db[mod_obj.__name__],
class_name=class_name)
return ret

@classmethod
Expand All @@ -358,6 +132,8 @@ def get_module(cls, module_name):
# from IPython import embed
# embed()
mod_obj = importlib.import_module(cls.__module__)
if cls.__module__.split('.')[-1] == module_name:
return mod_obj
attr = getattr(mod_obj, module_name, None)
if isinstance(attr, ModuleType):
return attr
Expand Down
1 change: 0 additions & 1 deletion datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(self, host, user, passwd, init_fun=None):
print("Connected", user + '@' + host + ':' + str(port))
self._conn.autocommit(True)

self.mod_to_db2 = {} # database indexed by module names
self.db_to_mod = {} # modules indexed by dbnames
self.mod_to_db = {} # database names indexed by modules
self.table_names = {} # tables names indexed by [dbname][class_name]
Expand Down
Loading