Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions datajoint/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
mxClassID = OrderedDict((
# see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html
('mxUNKNOWN_CLASS', None),
('mxCELL_CLASS', None), # TODO: implement
('mxSTRUCT_CLASS', None), # TODO: implement
('mxCELL_CLASS', None),
('mxSTRUCT_CLASS', None),
('mxLOGICAL_CLASS', np.dtype('bool')),
('mxCHAR_CLASS', np.dtype('c')),
('mxVOID_CLASS', None),
Expand All @@ -23,7 +23,7 @@
('mxUINT64_CLASS', np.dtype('uint64')),
('mxFUNCTION_CLASS', None)))

reverseClassID = {v: i for i, v in enumerate(mxClassID.values())}
reverseClassID = {dtype: i for i, dtype in enumerate(mxClassID.values())}
dtypeList = list(mxClassID.values())


Expand Down
2 changes: 2 additions & 0 deletions datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import DataJointError
import logging
from . import config
from .erd import ERD

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,6 +55,7 @@ class Connection:
"""

def __init__(self, host, user, passwd, init_fun=None):
self.erd = ERD()
if ':' in host:
host, port = host.split(':')
port = int(port)
Expand Down
4 changes: 1 addition & 3 deletions datajoint/declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pyparsing as pp
import logging

from . import DataJointError
from . import DataJointError


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,8 +69,6 @@ def declare(full_table_name, definition, context):
return sql




def compile_attribute(line, in_key=False):
"""
Convert attribute definition from DataJoint format to SQL
Expand Down
87 changes: 84 additions & 3 deletions datajoint/erd.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,101 @@
import re
import logging

import pyparsing as pp
import re
import networkx as nx
from networkx import DiGraph
from networkx import pygraphviz_layout
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import transforms
from collections import defaultdict

from .utils import to_camel_case
from . import DataJointError


logger = logging.getLogger(__name__)


class ERD:
_parents = defaultdict(set)
_children = defaultdict(set)
_references = defaultdict(set)
_referenced = defaultdict(set)

def load_dependencies(self, connection, full_table_name, primary_key):
# fetch the CREATE TABLE statement
cur = connection.query('SHOW CREATE TABLE %s' % full_table_name)
create_statement = cur.fetchone()
if not create_statement:
raise DataJointError('Could not load the definition table %s' % full_table_name)
create_statement = create_statement[1].split('\n')

# build foreign key parser
database = full_table_name.split('.')[0].strip('`')
add_database = lambda string, loc, toc: ['`{database}`.`{table}`'.format(database=database, table=toc[0])]

parser = pp.CaselessLiteral('CONSTRAINT').suppress()
parser += pp.QuotedString('`').suppress()
parser += pp.CaselessLiteral('FOREIGN KEY').suppress()
parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes')
parser += pp.CaselessLiteral('REFERENCES')
parser += pp.Or([
pp.QuotedString('`').setParseAction(add_database),
pp.Combine(pp.QuotedString('`', unquoteResults=False) +
'.' + pp.QuotedString('`', unquoteResults=False))
]).setResultsName('referenced_table')
parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes')

# parse foreign keys
for line in create_statement:
try:
result = parser.parseString(line)
except pp.ParseException:
pass
else:
if result.referenced_attributes != result.attributes:
raise DataJointError(
"%s's foreign key refers to differently named attributes in %s"
% (self.__class__.__name__, result.referenced_table))
if all(q in primary_key for q in [s.strip('` ') for s in result.attributes.split(',')]):
self._parents[full_table_name].add(result.referenced_table)
self._children[result.referenced_table].add(full_table_name)
else:
self._referenced[full_table_name].add(result.referenced_table)
self._references[result.referenced_table].add(full_table_name)

@property
def parents(self):
return self._parents

@property
def children(self):
return self._children

@property
def references(self):
return self._references

@property
def referenced(self):
return self._referenced




def to_camel_case(s):
"""
Convert names with under score (_) separation
into camel case names.
Example:
>>>to_camel_case("table_name")
"TableName"
"""
def to_upper(match):
return match.group(0)[-1].upper()
return re.sub('(^|[_\W])+[a-zA-Z]', to_upper, s)



class RelGraph(DiGraph):
"""
A directed graph representing relations between tables within and across
Expand Down
22 changes: 20 additions & 2 deletions datajoint/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def decorator(cls):
full_table_name=instance.full_table_name,
definition=instance.definition,
context=context))
connection.erd.load_dependencies(connection, instance.full_table_name, instance.primary_key)
return cls

return decorator



class Relation(RelationalOperand, metaclass=abc.ABCMeta):
"""
Relation is an abstract class that represents a base relation, i.e. a table in the database.
Expand Down Expand Up @@ -111,13 +111,31 @@ def iter_insert(self, rows, **kwargs):
for row in rows:
self.insert(row, **kwargs)

# ------------- dependencies ---------- #
@property
def parents(self):
return self.connection.erd.parents[self.full_table_name]

@property
def children(self):
return self.connection.erd.children[self.full_table_name]

@property
def references(self):
return self.connection.erd.references[self.full_table_name]

@property
def referenced(self):
return self.connection.erd.referenced[self.full_table_name]


# --------- SQL functionality --------- #
@property
def is_declared(self):
cur = self.connection.query(
'SHOW TABLES in `{database}`LIKE "{table_name}"'.format(
database=self.database, table_name=self.table_name))
return cur.rowcount>0
return cur.rowcount > 0

def batch_insert(self, data, **kwargs):
"""
Expand Down