diff --git a/datajoint/blob.py b/datajoint/blob.py index 7dad919cf..7988e82d5 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -6,8 +6,8 @@ mxClassID = OrderedDict(( # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html ('mxUNKNOWN_CLASS', None), - ('mxCELL_CLASS', None), # TODO: implement - ('mxSTRUCT_CLASS', None), # TODO: implement + ('mxCELL_CLASS', None), + ('mxSTRUCT_CLASS', None), ('mxLOGICAL_CLASS', np.dtype('bool')), ('mxCHAR_CLASS', np.dtype('c')), ('mxVOID_CLASS', None), @@ -23,7 +23,7 @@ ('mxUINT64_CLASS', np.dtype('uint64')), ('mxFUNCTION_CLASS', None))) -reverseClassID = {v: i for i, v in enumerate(mxClassID.values())} +reverseClassID = {dtype: i for i, dtype in enumerate(mxClassID.values())} dtypeList = list(mxClassID.values()) diff --git a/datajoint/connection.py b/datajoint/connection.py index 0c668d9fe..325d1c021 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -3,6 +3,7 @@ from . import DataJointError import logging from . import config +from .erd import ERD logger = logging.getLogger(__name__) @@ -54,6 +55,7 @@ class Connection: """ def __init__(self, host, user, passwd, init_fun=None): + self.erd = ERD() if ':' in host: host, port = host.split(':') port = int(port) diff --git a/datajoint/declare.py b/datajoint/declare.py index c0c8e44a9..18679150a 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -2,7 +2,7 @@ import pyparsing as pp import logging -from . import DataJointError +from . 
class ERD:
    """
    Registry of foreign-key dependencies between tables.

    Four mappings are maintained, each keyed by a full table name and
    holding a set of full table names:

    * ``parents[t]``    -- tables that ``t`` references through a foreign key
      made up entirely of ``t``'s primary-key attributes.
    * ``children[t]``   -- tables that reference ``t`` through such a key.
    * ``referenced[t]`` -- tables that ``t`` references through a foreign key
      containing at least one non-primary attribute.
    * ``references[t]`` -- tables that reference ``t`` through such a key.

    The mappings are created per instance, so separate ERD objects do not
    share dependency information.
    """

    def __init__(self):
        # Instance-level state: class-level defaultdicts would be shared by
        # every ERD object (and therefore by every connection owning one).
        self._parents = defaultdict(set)
        self._children = defaultdict(set)
        self._references = defaultdict(set)
        self._referenced = defaultdict(set)

    @staticmethod
    def _foreign_key_parser(database):
        """
        Build the pyparsing grammar for one FOREIGN KEY constraint line of a
        ``SHOW CREATE TABLE`` statement.

        :param database: database name used to qualify referenced tables that
            appear without a database prefix.
        :return: parser exposing the named results ``attributes``,
            ``referenced_table`` and ``referenced_attributes``.
        """
        def add_database(string, loc, tokens):
            # A bare `table` refers to the declaring table's own database.
            return ['`{database}`.`{table}`'.format(database=database, table=tokens[0])]

        referenced_table = pp.Or([
            pp.QuotedString('`').setParseAction(add_database),
            pp.Combine(pp.QuotedString('`', unquoteResults=False) +
                       '.' +
                       pp.QuotedString('`', unquoteResults=False))
        ]).setResultsName('referenced_table')

        parser = pp.CaselessLiteral('CONSTRAINT').suppress()
        parser += pp.QuotedString('`').suppress()  # constraint name is not needed
        parser += pp.CaselessLiteral('FOREIGN KEY').suppress()
        parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes')
        parser += pp.CaselessLiteral('REFERENCES')
        parser += referenced_table
        parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes')
        return parser

    def load_dependencies(self, connection, full_table_name, primary_key):
        """
        Read the foreign keys of a table and record them in the registry.

        :param connection: object whose ``query`` method executes SQL and
            returns a cursor.
        :param full_table_name: table name in the form `database`.`table`.
        :param primary_key: collection of the table's primary-key attribute names.
        :raises DataJointError: if the CREATE TABLE statement cannot be fetched
            or a foreign key refers to differently named attributes.
        """
        # fetch the CREATE TABLE statement
        cur = connection.query('SHOW CREATE TABLE %s' % full_table_name)
        create_statement = cur.fetchone()
        if not create_statement:
            raise DataJointError('Could not load the definition table %s' % full_table_name)
        create_lines = create_statement[1].split('\n')

        database = full_table_name.split('.')[0].strip('`')
        parser = self._foreign_key_parser(database)

        for line in create_lines:
            try:
                result = parser.parseString(line)
            except pp.ParseException:
                # not a foreign key constraint line
                continue
            if result.referenced_attributes != result.attributes:
                # Report the offending table rather than this class's name.
                raise DataJointError(
                    "%s's foreign key refers to differently named attributes in %s"
                    % (full_table_name, result.referenced_table))
            foreign_key = [attr.strip('` ') for attr in result.attributes.split(',')]
            if all(attr in primary_key for attr in foreign_key):
                self._parents[full_table_name].add(result.referenced_table)
                self._children[result.referenced_table].add(full_table_name)
            else:
                self._referenced[full_table_name].add(result.referenced_table)
                self._references[result.referenced_table].add(full_table_name)

    @property
    def parents(self):
        """Mapping: table -> tables it references via primary-key-only foreign keys."""
        return self._parents

    @property
    def children(self):
        """Mapping: table -> tables referencing it via primary-key-only foreign keys."""
        return self._children

    @property
    def references(self):
        """Mapping: table -> tables referencing it via foreign keys with non-primary attributes."""
        return self._references

    @property
    def referenced(self):
        """Mapping: table -> tables it references via foreign keys with non-primary attributes."""
        return self._referenced
def to_camel_case(s):
    """
    Convert a name with underscore (_) separation into a camel case name.

    Example:
    >>> to_camel_case("table_name")
    'TableName'

    :param s: string to convert.
    :return: the string with each letter that starts the string or follows a
        run of underscores / non-word characters upper-cased, and those
        separator characters removed.
    """
    def to_upper(match):
        # The match is "<separators><letter>"; keep only the letter, upper-cased.
        return match.group(0)[-1].upper()
    # Raw string: '\W' in a normal string literal is an invalid escape sequence.
    return re.sub(r'(^|[_\W])+[a-zA-Z]', to_upper, s)