From c247cbbed624c5706a043c8a4f8b6ff3e2355845 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 14 Sep 2015 01:44:48 -0500 Subject: [PATCH 01/41] update erm to load all tables in a database --- datajoint/erd.py | 56 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index aae2f4055..00ad9774a 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -7,7 +7,14 @@ import pyparsing as pp import networkx as nx from networkx import DiGraph -from networkx import pygraphviz_layout +import re + +# use pygraphviz if available +try: + from networkx import pygraphviz_layout +except: + pygraphviz_layout = None + import matplotlib.pyplot as plt from . import DataJointError from .utils import to_camel_case @@ -17,10 +24,16 @@ class RelGraph(DiGraph): """ - A directed graph representing relations between tables within and across - multiple databases found. + A directed graph representing dependencies between Relations within and across + multiple databases. """ + def __init__(self, name_map=None): + if name_map is None: + name_map = {} + + super().__init__() + @property def node_labels(self): """ @@ -47,7 +60,7 @@ def non_pk_edges(self): def highlight(self, nodes): """ Highlights specified nodes when plotting - :param nodes: list of nodes to be highlighted + :param nodes: list of nodes, specified by full table names, to be highlighted """ for node in nodes: self.node[node]['highlight'] = True @@ -56,7 +69,7 @@ def remove_highlight(self, nodes=None): """ Remove highlights from specified nodes when plotting. If specified node is not highlighted to begin with, nothing happens. - :param nodes: list of nodes to remove highlights from + :param nodes: list of nodes, specified by full table names, to remove highlights from """ if not nodes: nodes = self.nodes_iter() @@ -111,7 +124,7 @@ def restrict_by_modules(self, modules, fill=False): def restrict_by_tables(self, tables, fill=False): """ Creates a subgraph containing only specified tables. - :param tables: list of tables to keep in the subgraph + :param tables: list of tables to keep in the subgraph. Tables are specified using full table names :param fill: set True to automatically include nodes connecting two nodes in the specified list of tables :return: a subgraph with specified nodes @@ -256,6 +269,7 @@ class ERM(RelGraph): def __init__(self, conn, *args, **kwargs): super().__init__(*args, **kwargs) + self._databases = set() self._conn = conn self._parents = dict() self._referenced = dict() @@ -293,7 +307,29 @@ def copy_graph(self, *args, **kwargs): def subgraph(self, *args, **kwargs): return RelGraph(self).subgraph(*args, **kwargs) - def load_dependencies(self, full_table_name): + def register_database(self, database): + """ + Register the database to be monitored + :param database: name of database to be monitored + """ + self._databases.add(database) + + + def load_dependencies(self): + for database in self._databases: + self.load_dependencies_for_database(database) + + def load_dependencies_for_database(self, database): + sql_table_name_regexp = re.compile('^(#|_|__|~)?[a-z][a-z0-9_]*$') + + cur = self._conn.query('SHOW TABLES FROM `{database}`'.format(database=database)) + + for info in cur: + table_name = info[0] + full_table_name = '`{database}`.`{table_name}`'.format(database=database, table_name=table_name) + self.load_dependencies_for_table(full_table_name) + + def load_dependencies_for_table(self, full_table_name): # check if already loaded. Use clear_dependencies before reloading if full_table_name in self._parents: return @@ -366,20 +402,25 @@ def clear_dependencies(self, full_table_name): if full_table_name in self._references[ref]: self._references[ref].remove(full_table_name) + @property def parents(self): + self.load_dependencies() return self._parents @property def children(self): + self.load_dependencies() return self._children @property def references(self): + self.load_dependencies() return self._references @property def referenced(self): + self.load_dependencies() return self._referenced def get_descendants(self, full_table_name): @@ -388,6 +429,7 @@ def get_descendants(self, full_table_name): :return: list of all children and references, in order of dependence. This is helpful for cascading delete or drop operations. """ + self.load_dependencies() ret = defaultdict(lambda: 0) def recurse(full_table_name, level): From 917f4c79e4dfe8d48a72e0f5ab2efc9cabc11092 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 14 Sep 2015 01:46:30 -0500 Subject: [PATCH 02/41] register database for monitoring --- datajoint/schema.py | 44 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/datajoint/schema.py b/datajoint/schema.py index 002c90f21..36e6384b3 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -11,11 +11,14 @@ class schema: """ A schema object can be used as a decorator that associates a Relation class to a database as - well as a namespace for looking up foreign key references. + well as a namespace for looking up foreign key references in table declaration. """ def __init__(self, database, context, connection=None): """ + Associates the specified database with this schema object. If the target database does not exist + already, will attempt on creating the database. + :param database: name of the database to associate the decorated class with :param context: dictionary for looking up foreign keys references, usually set to locals() :param connection: Connection object. Defaults to datajoint.conn() @@ -27,8 +30,8 @@ def __init__(self, database, context, connection=None): self.context = context # if the database does not exist, create it - cur = connection.query("SHOW DATABASES LIKE '{database}'".format(database=database)) - if cur.rowcount == 0: + + if not self.exists: logger.info("Database `{database}` could not be found. " "Attempting to create the database.".format(database=database)) try: @@ -39,9 +42,37 @@ def __init__(self, database, context, connection=None): " an attempt to create has failed. Check" " permissions.".format(database=database)) + # TODO: replace with a call on connection object + connection.erm.register_database(database) + + def drop(self): + """ + Drop the associated database if it exists + """ + if self.exists: + logger.info("Dropping `{database}`.".format(database=self.database)) + try: + self.connection.query("DROP DATABASE `{database}`".format(database=self.database)) + logger.info("Database `{database}` was dropped successfully.".format(database=self.database)) + except pymysql.OperationalError: + raise DataJointError("An attempt to drop database named `{database}` " + "has failed. Check permissions.".format(database=self.database)) + else: + logger.info("Database named `{database}` does not exist. Doing nothing.".format(database=self.database)) + + @property + def exists(self): + """ + :return: true if the associated database exists on the server + """ + cur = self.connection.query("SHOW DATABASES LIKE '{database}'".format(database=self.database)) + return cur.rowcount > 0 + + + def __call__(self, cls): """ - The decorator binds its argument class object to a database + Binds the passed in class object to a database. This is intended to be used as a decorator. :param cls: class to be decorated """ @@ -56,7 +87,7 @@ def process_relation_class(relation_class, context): relation_class().declare() if issubclass(cls, Part): - raise DataJointError('The schema decorator should not apply to part relations') + raise DataJointError('The schema decorator should not be applied to Part relations') process_relation_class(cls, context=self.context) @@ -72,9 +103,10 @@ def process_relation_class(relation_class, context): if is_sub: parts.append(part) part._master = cls + # TODO: look into local namespace for the subclasses process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) elif issubclass(part, Relation): - raise DataJointError('Part relations must subclass from datajoint.Part') + raise DataJointError('Part relations must be a subclass of datajoint.Part') # invoke Relation._prepare() on class and its part relations. cls()._prepare() From cab3c5c958d0ac0e60ce73ea88774ae95d74cb03 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 14 Sep 2015 01:47:32 -0500 Subject: [PATCH 03/41] minor tweaks and comment updates --- datajoint/jobs.py | 12 +++++------- datajoint/kill.py | 3 ++- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/datajoint/jobs.py b/datajoint/jobs.py index db41e0eac..ab3d8f691 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -56,14 +56,12 @@ def reserve(self, table_name, key): :param key: the dict of the job's primary key :return: True if reserved job successfully. False = the jobs is already taken """ - job_key = dict(table_name=table_name, key_hash=key_hash(key)) + try: - self.insert1( - dict(job_key, - status="reserved", - host=os.uname().nodename, - pid=os.getpid())) - except pymysql.err.IntegrityError: + job_key = dict(table_name=table_name, key_hash=key_hash(key), + status='reserved', host=os.uname().nodename, pid=os.getpid()) + self.insert1(job_key) + except pymysql.err.IntegrityError: #TODO check for other exceptions! return False else: return True diff --git a/datajoint/kill.py b/datajoint/kill.py index 48b9cbd54..8496f99ab 100644 --- a/datajoint/kill.py +++ b/datajoint/kill.py @@ -26,7 +26,7 @@ def kill(restriction=None, connection=None): while True: print(' ID USER STATE TIME INFO') print('+--+ +----------+ +-----------+ +--+') - for process in connection.query(query, as_dict=True).fetchall(): + for process in connection.query(query, as_dict=True).fetchall(): try: print('{ID:>4d} {USER:<12s} {STATE:<12s} {TIME:>5d} {INFO}'.format(**process)) except TypeError as err: @@ -40,6 +40,7 @@ def kill(restriction=None, connection=None): pid = int(response) except ValueError: pass # ignore non-numeric input + #TODO: check behavior when invalid input given else: try: connection.query('kill %d' % pid) From ff8085368b698417d8c3a14e002ebb10ee989950 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 14 Sep 2015 01:48:17 -0500 Subject: [PATCH 04/41] remove pygraphviz from requirements due to instability --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e8a5dcbf0..af4d48f13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,3 @@ networkx matplotlib sphinx_rtd_theme mock -pygraphviz==1.3rc1 From a38c7816084f62f2cec8b44e8fb507375b2dc240 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 14 Sep 2015 01:50:08 -0500 Subject: [PATCH 05/41] rename schema to Schema and add alias schema --- datajoint/__init__.py | 2 +- datajoint/schema.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 47ce89b30..7bb7e28e6 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -56,5 +56,5 @@ class DataJointError(Exception): from .user_relations import Manual, Lookup, Imported, Computed, Part from .relational_operand import Not from .heading import Heading -from .schema import schema +from .schema import Schema as schema from .kill import kill diff --git a/datajoint/schema.py b/datajoint/schema.py index 36e6384b3..889fb76d9 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) -class schema: +class Schema: """ A schema object can be used as a decorator that associates a Relation class to a database as well as a namespace for looking up foreign key references in table declaration. From 61adad4d8ab0d954cbcf332f6e50861d87a162c4 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 14 Sep 2015 01:50:50 -0500 Subject: [PATCH 06/41] update docstring for AutoPopulate --- datajoint/autopopulate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 1cef4445e..c5fa0d1b0 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -60,8 +60,7 @@ def populate(self, restriction=None, suppress_errors=False, :param restriction: restriction on rel.populated_from - target :param suppress_errors: suppresses error if true - :param reserve_jobs: currently not implemented - :param batch: batch size of a single job + :param reserve_jobs: if true, reserves job to populate in asynchronous fashion :param order: "original"|"reverse"|"random" - the order of execution """ if not isinstance(self.populated_from, RelationalOperand): From 3b5997165668e735ae343843f146e4994c15126e Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 14 Sep 2015 01:51:28 -0500 Subject: [PATCH 07/41] update heading loading logic --- datajoint/relation.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index ce02351dd..b6ab9e90f 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -25,6 +25,7 @@ class Relation(RelationalOperand, metaclass=abc.ABCMeta): """ _heading = None _context = None + database = None # ---------- abstract properties ------------ # @property @@ -58,8 +59,9 @@ def heading(self): """ if self._heading is None: self._heading = Heading() # instance-level heading - if not self._heading: + if not self._heading: # heading is not initialized self.declare() + self._heading.init_from_database(self.connection, self.database, self.table_name) return self._heading def declare(self): @@ -69,9 +71,6 @@ def declare(self): if not self.is_declared: self.connection.query( declare(self.full_table_name, self.definition, self._context)) - if self.is_declared: - self.connection.erm.load_dependencies(self.full_table_name) - self._heading.init_from_database(self.connection, self.database, self.table_name) @property def from_clause(self): @@ -124,6 +123,7 @@ def referenced(self): """ return self.connection.erm.referenced[self.full_table_name] + #TODO: implement this inside the relation object in connection @property def descendants(self): """ @@ -137,7 +137,11 @@ def descendants(self): for table in self.connection.erm.get_descendants(self.full_table_name)) return [relation for relation in relations if relation.is_declared] + def _repr_helper(self): + """ + :return: String representation of this object + """ return "%s.%s()" % (self.__module__, self.__class__.__name__) # --------- SQL functionality --------- # From 9e8e076f14d395ce9c9a5a34650e3e8b3b2924f6 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 14 Sep 2015 01:51:58 -0500 Subject: [PATCH 08/41] load dependencies prior to fetching erd --- datajoint/connection.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datajoint/connection.py b/datajoint/connection.py index 58eb65dd6..6d4948c9b 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -82,6 +82,8 @@ def __repr__(self): connected=connected, **self.conn_info) def erd(self, *args, **kwargs): + # load all dependencies + self.erm.load_dependencies() return self.erm.copy_graph(*args, **kwargs) @property From 6e6531c24313a8f44ae509a7e0804c69f4c0787a Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 14 Sep 2015 01:52:44 -0500 Subject: [PATCH 09/41] update description --- tests/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/schema.py b/tests/schema.py index 3b6b0a479..e1444a189 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -1,5 +1,5 @@ """ -Test schema definition +Sample scheme with realistic tables for testing """ import random From 9ca448b3de04d292407552348170c069eeae6f84 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 15 Sep 2015 18:43:36 -0500 Subject: [PATCH 10/41] add command line erd --- datajoint/erd.py | 158 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 145 insertions(+), 13 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 00ad9774a..69db7bf84 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -6,8 +6,13 @@ from collections import defaultdict import pyparsing as pp import networkx as nx -from networkx import DiGraph import re +from networkx import DiGraph +from functools import cmp_to_key +import operator +from . import Manual, Lookup, Imported, Computed, Part + +from collections import OrderedDict # use pygraphviz if available try: @@ -17,6 +22,7 @@ import matplotlib.pyplot as plt from . import DataJointError +from functools import wraps from .utils import to_camel_case logger = logging.getLogger(__name__) @@ -28,11 +34,11 @@ class RelGraph(DiGraph): multiple databases. """ - def __init__(self, name_map=None): + def __init__(self, *args, name_map=None, **kwargs): if name_map is None: name_map = {} - super().__init__() + super().__init__(*args, **kwargs) @property def node_labels(self): @@ -106,8 +112,8 @@ def plot(self): ax.axis('off') # hide axis def __repr__(self): - #TODO: provide string version of ERD - return "RelGraph: to be implemented" + return self.repr_path() + def restrict_by_modules(self, modules, fill=False): """ @@ -174,6 +180,7 @@ def descendants_of_all(self, nodes): s.update(self.descendants(n)) return s + def ancestors(self, node): """ Find and return a set containing all ancestors of the specified @@ -258,6 +265,115 @@ def n_neighbors(self, node, n, directed=False, prev=None): s.update(self.n_neighbors(x, n-1, prev)) return s + @property + def root_nodes(self): + return {node for node in self.nodes() if len(self.predecessors(node)) == 0} + + @property + def leaf_nodes(self): + return {node for node in self.nodes() if len(self.successors(node)) == 0} + + def nodes_by_depth(self): + """ + Return all nodes, ordered by their depth in the hierarchy + :returns: list of nodes, ordered by depth from shallowest to deepest + """ + ret = defaultdict(lambda: 0) + roots = self.root_nodes + + def recurse(node, depth): + if depth > ret[node]: + ret[node] = depth + for child in self.successors_iter(node): + recurse(child, depth+1) + + for root in roots: + recurse(root, 0) + + return sorted(ret.items(), key=operator.itemgetter(1)) + + def get_longest_path(self): + if not self.edges(): + return [] + + node_depth_list = self.nodes_by_depth() + node_depth_lookup = dict(node_depth_list) + path = [] + + leaf = node_depth_list[-1][0] + predecessors = [leaf] + while predecessors: + leaf = sorted(predecessors, key=node_depth_lookup.get)[-1] + path.insert(0, leaf) + predecessors = self.predecessors(leaf) + + return path + + def remove_edges_in_path(self, path): + if len(path) <= 1: # no path exists! + return + for a, b in zip(path[:-1], path[1:]): + self.remove_edge(a, b) + + def longest_paths(self): + g = self.copy_graph() + paths = [] + path = g.get_longest_path() + while path: + paths.append(path) + g.remove_edges_in_path(path) + path = g.get_longest_path() + return paths + + def repr_path(self): + paths = self.longest_paths() + k = cmp_to_key(self.compare_path) + sorted_paths = sorted(paths, key=k) + n = max([len(x) for x in self.nodes()]) + repr = '' + for path in sorted_paths: + repr += self.repr_path_with_depth(path, n) + return repr + + def compare_path(self, path1, path2): + node_depth_lookup = dict(self.nodes_by_depth()) + for node1, node2 in zip(path1, path2): + if node_depth_lookup[node1] != node_depth_lookup[node2]: + return -1 if node_depth_lookup[node1] < node_depth_lookup[node2] else 1 + if node1 != node2: + return -1 if node1 < node2 else 1 + return 0 + + def repr_path_with_depth(self, path, n=20, m=2): + node_depth_lookup = dict(self.nodes_by_depth()) + arrow = '-' * (m-1) + '>' + space = '-' * n + pattern = '{:%ds}' % n + repr = '' + prev_depth = 0 + first = True + for (i, node) in enumerate(path): + depth = node_depth_lookup[node] + if first: + repr += (' '*(n+m))*(depth-prev_depth) + else: + repr += space.join(['-'*m]*(depth-prev_depth))[:-1] + '>' + first = False + prev_depth = depth + if i == len(path)-1: + repr += node + else: + repr += node.ljust(n, '-') + repr += '\n' + return repr + + +def require_dep_loading(f): + @wraps(f) + def wrapper(self, *args, **kwargs): + self.load_dependencies() + return f(self, *args, **kwargs) + return wrapper class ERM(RelGraph): """ @@ -314,12 +430,18 @@ def register_database(self, database): """ self._databases.add(database) - def load_dependencies(self): + """ + Load dependencies for all monitored databases + """ for database in self._databases: self.load_dependencies_for_database(database) def load_dependencies_for_database(self, database): + """ + Load dependencies for all tables found in the specified database + :param database: database for which dependencies will be loaded + """ sql_table_name_regexp = re.compile('^(#|_|__|~)?[a-z][a-z0-9_]*$') cur = self._conn.query('SHOW TABLES FROM `{database}`'.format(database=database)) @@ -330,6 +452,10 @@ def load_dependencies_for_database(self, database): self.load_dependencies_for_table(full_table_name) def load_dependencies_for_table(self, full_table_name): + """ + Load dependencies for the specified table + :param full_table_name: table for which dependencies will be loaded, specified in full table name + """ # check if already loaded. Use clear_dependencies before reloading if full_table_name in self._parents: return @@ -391,10 +517,15 @@ def load_dependencies_for_table(self, full_table_name): else: self._referenced[full_table_name].append(result.referenced_table) self._references[result.referenced_table].append(full_table_name) - self.update_graph() - def clear_dependencies(self, full_table_name): + def clear_dependencies(self): + pass + + def clear_dependencies_for_database(self, database): + pass + + def clear_dependencies_for_table(self, full_table_name): for ref in self._parents.pop(full_table_name, []): if full_table_name in self._children[ref]: self._children[ref].remove(full_table_name) @@ -402,34 +533,34 @@ def clear_dependencies(self, full_table_name): if full_table_name in self._references[ref]: self._references[ref].remove(full_table_name) - @property + @require_dep_loading def parents(self): - self.load_dependencies() return self._parents @property + @require_dep_loading def children(self): self.load_dependencies() return self._children @property + @require_dep_loading def references(self): - self.load_dependencies() return self._references @property + @require_dep_loading def referenced(self): - self.load_dependencies() return self._referenced + @require_dep_loading def get_descendants(self, full_table_name): """ :param full_table_name: a table name in the format `database`.`table_name` :return: list of all children and references, in order of dependence. This is helpful for cascading delete or drop operations. """ - self.load_dependencies() ret = defaultdict(lambda: 0) def recurse(full_table_name, level): @@ -441,3 +572,4 @@ def recurse(full_table_name, level): recurse(full_table_name, 0) return sorted(ret.keys(), key=ret.__getitem__) + From 588f3046dbd13e737f5e9c665b981dfaa401e6ab Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 15 Sep 2015 18:44:16 -0500 Subject: [PATCH 11/41] move logging verbosity to top --- tests/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 149175c82..e1123a920 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -9,11 +9,15 @@ import logging from os import environ + +# turn on verbose logging +logging.basicConfig(level=logging.DEBUG) + import datajoint as dj __all__ = ['__author__', 'PREFIX', 'CONN_INFO'] -logging.basicConfig(level=logging.DEBUG) + # Connection for testing CONN_INFO = dict( @@ -27,11 +31,10 @@ def setup_package(): """ Package-level unit test setup - :return: + Turns off safemode """ dj.config['safemode'] = False - def teardown_package(): """ Package-level unit test teardown. From c42c586d839e326a9bdf04913e44a030d5257f35 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 15 Sep 2015 18:44:38 -0500 Subject: [PATCH 12/41] minor touch up on Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 77a432a28..137068e47 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ all: @echo 'MakeFile for DataJoint packaging ' @echo ' ' @echo 'make sdist Creates source distribution ' - @echo 'make wheel Creates Wheel dstribution ' + @echo 'make wheel Creates Wheel distribution ' @echo 'make pypi Package and upload to PyPI ' @echo 'make pypitest Package and upload to PyPI test server' @echo 'make purge Remove all build related directories ' From 01a857af9a8c69beda7ee8bc6dfdd5da9282d429 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 15 Sep 2015 18:45:41 -0500 Subject: [PATCH 13/41] follow pep8 --- datajoint/autopopulate.py | 1 + datajoint/schema.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index c5fa0d1b0..8b9300710 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -111,6 +111,7 @@ def populate(self, restriction=None, suppress_errors=False, jobs.complete(table_name, key) return error_list + def progress(self, restriction=None, display=True): """ report progress of populating this table diff --git a/datajoint/schema.py b/datajoint/schema.py index 889fb76d9..28d2e70ab 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -68,8 +68,6 @@ def exists(self): cur = self.connection.query("SHOW DATABASES LIKE '{database}'".format(database=self.database)) return cur.rowcount > 0 - - def __call__(self, cls): """ Binds the passed in class object to a database. This is intended to be used as a decorator. From 349f7f7ef1bc20771c64784e0730d4f7ef044b1d Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 15 Sep 2015 18:46:07 -0500 Subject: [PATCH 14/41] Load erm dependencies when new table declared --- datajoint/relation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datajoint/relation.py b/datajoint/relation.py index b6ab9e90f..42ddd80ca 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) + class Relation(RelationalOperand, metaclass=abc.ABCMeta): """ Relation is an abstract class that represents a base relation, i.e. a table in the database. @@ -62,6 +63,7 @@ def heading(self): if not self._heading: # heading is not initialized self.declare() self._heading.init_from_database(self.connection, self.database, self.table_name) + return self._heading def declare(self): @@ -71,6 +73,8 @@ def declare(self): if not self.is_declared: self.connection.query( declare(self.full_table_name, self.definition, self._context)) + #TODO: reconsider loading time + self.connection.erm.load_dependencies() @property def from_clause(self): @@ -306,7 +310,7 @@ def drop_quick(self): """ if self.is_declared: self.connection.query('DROP TABLE %s' % self.full_table_name) - self.connection.erm.clear_dependencies(self.full_table_name) + self.connection.erm.clear_dependencies_for_table(self.full_table_name) if self._heading: self._heading.reset() logger.info("Dropped table %s" % self.full_table_name) From 1f7488192ebf963f8bb37323819300d3c29c0ebf Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 15 Sep 2015 18:50:45 -0500 Subject: [PATCH 15/41] fix import error --- datajoint/erd.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 69db7bf84..ba341ca3b 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -10,7 +10,6 @@ from networkx import DiGraph from functools import cmp_to_key import operator -from . import Manual, Lookup, Imported, Computed, Part from collections import OrderedDict @@ -114,7 +113,6 @@ def plot(self): def __repr__(self): return self.repr_path() - def restrict_by_modules(self, modules, fill=False): """ Creates a subgraph containing only tables in the specified modules. From 6289bd4083e4498f5c7d7298df2d3127e88002b2 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Thu, 1 Oct 2015 17:59:17 -0500 Subject: [PATCH 16/41] Fix dependency loading timings and repr path ordering --- datajoint/erd.py | 61 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index ba341ca3b..e976d7d49 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -178,6 +178,8 @@ def descendants_of_all(self, nodes): s.update(self.descendants(n)) return s + def copy_graph(self, *args, **kwargs): + return self.__class__(self, *args, **kwargs) def ancestors(self, node): """ @@ -291,6 +293,10 @@ def recurse(node, depth): return sorted(ret.items(), key=operator.itemgetter(1)) def get_longest_path(self): + """ + :returns: a list of graph nodes defining th longest path in the graph + """ + # no path exists if there is not an edge! if not self.edges(): return [] @@ -308,12 +314,19 @@ def get_longest_path(self): return path def remove_edges_in_path(self, path): + """ + Removes all shared edges between this graph and the path + :param path: a list of nodes defining a path. All edges in this path will be removed from the graph if found + """ if len(path) <= 1: # no path exists! return for a, b in zip(path[:-1], path[1:]): self.remove_edge(a, b) def longest_paths(self): + """ + :return: list of paths from longest to shortest. A path is a list of nodes. + """ g = self.copy_graph() paths = [] path = g.get_longest_path() @@ -324,22 +337,45 @@ def longest_paths(self): return paths def repr_path(self): + """ + Construct string representation of the erm, summarizing dependencies between + tables + :return: string representation of the erm + """ paths = self.longest_paths() + + # turn comparator into Key object for use in sort k = cmp_to_key(self.compare_path) sorted_paths = sorted(paths, key=k) + + # table name will be padded to match the longest table name n = max([len(x) for x in self.nodes()]) - repr = '' + rep = '' for path in sorted_paths: - repr += self.repr_path_with_depth(path, n) - return repr + rep += self.repr_path_with_depth(path, n) + return rep def compare_path(self, path1, path2): + """ + Comparator between two paths: path1 and path2 based on a combination of rules. + Path 1 is greater than path2 if: + 1) i^th node in path1 is at greater depth than the i^th node in path2 OR + 2) if i^th nodes are at the same depth, i^th node in path 1 is alphabetically less than i^th node + in path 2 + 3) if neither of the above statement is true even if path1 and path2 are switched, proceed to i+1^th node + If path2 is a subpath start at node 1, then path1 is greater than path2 + :param path1: path 1 of 2 to be compared + :param path2: path 2 of 2 to be compared + :return: return 1 if path1 is greater than path2, -1 if path1 is less than path2, and 0 if they are identical + """ node_depth_lookup = dict(self.nodes_by_depth()) for node1, node2 in zip(path1, path2): if node_depth_lookup[node1] != node_depth_lookup[node2]: return -1 if node_depth_lookup[node1] < node_depth_lookup[node2] else 1 if node1 != node2: return -1 if node1 < node2 else 1 + if len(node1) != len(node2): + return -1 if len(node1) < len(node2) else 1 return 0 def repr_path_with_depth(self, path, n=20, m=2): @@ -393,11 +429,9 @@ def __init__(self, conn, *args, **kwargs): self._conn = conn else: raise DataJointError('The connection is broken') #TODO: make better exception message - self.update_graph() def update_graph(self, reload=False): self.clear() - # create primary key foreign connections for table, parents in self._parents.items(): mod, cls = (x.strip('`') for x in table.split('.')) @@ -410,6 +444,7 @@ def update_graph(self, reload=False): for ref in referenced: self.add_edge(ref, table, rel='referenced') + @require_dep_loading def copy_graph(self, *args, **kwargs): """ Return copy of the graph represented by this object at the @@ -418,6 +453,7 @@ def copy_graph(self, *args, **kwargs): """ return RelGraph(self, *args, **kwargs) + @require_dep_loading def subgraph(self, *args, **kwargs): return RelGraph(self).subgraph(*args, **kwargs) @@ -440,14 +476,17 @@ def load_dependencies_for_database(self, database): Load dependencies for all tables found in the specified database :param database: database for which dependencies will be loaded """ - sql_table_name_regexp = re.compile('^(#|_|__|~)?[a-z][a-z0-9_]*$') + #sql_table_name_regexp = re.compile('^(#|_|__|~)?[a-z][a-z0-9_]*$') cur = self._conn.query('SHOW TABLES FROM `{database}`'.format(database=database)) for info in cur: table_name = info[0] - full_table_name = '`{database}`.`{table_name}`'.format(database=database, table_name=table_name) - self.load_dependencies_for_table(full_table_name) + # TODO: fix this criteria! It will exclude ANY tables ending with 'jobs' + # exclude tables ending with 'jobs' from erd + if not table_name.lower().endswith('jobs'): + full_table_name = '`{database}`.`{table_name}`'.format(database=database, table_name=table_name) + self.load_dependencies_for_table(full_table_name) def load_dependencies_for_table(self, full_table_name): """ @@ -517,6 +556,12 @@ def load_dependencies_for_table(self, full_table_name): self._references[result.referenced_table].append(full_table_name) self.update_graph() + + def __repr__(self): + # Make sure that all dependencies are loaded before printing repr + self.load_dependencies() + return super().__repr__() + def clear_dependencies(self): pass From 457080782f8472d7ca5a11b6bed43c2beeeb0f58 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 2 Oct 2015 16:35:28 -0500 Subject: [PATCH 17/41] Rename Relation to BaseRelation --- datajoint/__init__.py | 4 +-- datajoint/autopopulate.py | 2 +- datajoint/{relation.py => base_relation.py} | 4 +-- datajoint/erd.py | 37 ++++++++++++++++++--- datajoint/jobs.py | 4 +-- datajoint/schema.py | 4 +-- datajoint/user_relations.py | 12 +++---- 7 files changed, 48 insertions(+), 19 deletions(-) rename datajoint/{relation.py => base_relation.py} (99%) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 7bb7e28e6..f9ad15b43 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -15,7 +15,7 @@ __version__ = "0.2" __all__ = ['__author__', '__version__', 'config', 'conn', 'kill', - 'Connection', 'Heading', 'Relation', 'FreeRelation', 'Not', 'schema', + 'Connection', 'Heading', 'BaseRelation', 'FreeRelation', 'Not', 'schema', 'Manual', 'Lookup', 'Imported', 'Computed', 'Part'] @@ -52,7 +52,7 @@ class DataJointError(Exception): # ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection -from .relation import Relation +from .base_relation import BaseRelation from .user_relations import Manual, Lookup, Imported, Computed, Part from .relational_operand import Not from .heading import Heading diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 8b9300710..5b3befcc1 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -5,7 +5,7 @@ import random from .relational_operand import RelationalOperand from . import DataJointError -from .relation import FreeRelation +from .base_relation import FreeRelation # noinspection PyExceptionInherit,PyCallingNonCallable diff --git a/datajoint/relation.py b/datajoint/base_relation.py similarity index 99% rename from datajoint/relation.py rename to datajoint/base_relation.py index f25bb10d5..64b82a1c1 100644 --- a/datajoint/relation.py +++ b/datajoint/base_relation.py @@ -17,7 +17,7 @@ -class Relation(RelationalOperand, metaclass=abc.ABCMeta): +class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): """ Relation is an abstract class that represents a base relation, i.e. a table in the database. To make it a concrete class, override the abstract properties specifying the connection, @@ -359,7 +359,7 @@ def _prepare(self): pass -class FreeRelation(Relation): +class FreeRelation(BaseRelation): """ A base relation without a dedicated class. The table name is explicitly set. """ diff --git a/datajoint/erd.py b/datajoint/erd.py index e976d7d49..e82ae37f1 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -6,7 +6,6 @@ from collections import defaultdict import pyparsing as pp import networkx as nx -import re from networkx import DiGraph from functools import cmp_to_key import operator @@ -23,9 +22,37 @@ from . import DataJointError from functools import wraps from .utils import to_camel_case +from .base_relation import BaseRelation logger = logging.getLogger(__name__) +from inspect import isabstract + + +def get_concrete_descendants(cls): + desc = [] + child= cls.__subclasses__() + for c in child: + if not isabstract(c): + desc.append(c) + desc.extend(get_concrete_descendants(c)) + return desc + + +def parse_base_relations(rels): + name_map = {} + for r in rels: + try: + name_map[r().full_table_name] = r.__name__ + except: + pass + return name_map + + +def get_table_relation_name_map(): + rels = get_concrete_descendants(BaseRelation) + return parse_base_relations(rels) + class RelGraph(DiGraph): """ @@ -35,7 +62,7 @@ class RelGraph(DiGraph): def __init__(self, *args, name_map=None, **kwargs): if name_map is None: - name_map = {} + self.name_map = {} super().__init__(*args, **kwargs) @@ -90,6 +117,9 @@ def plot(self): if not self.nodes(): # There is nothing to plot logger.warning('Nothing to plot') return + if pygraphviz_layout is None: + logger.warning('Failed to load Pygraphviz - plotting not supported at this time') + return pos = pygraphviz_layout(self, prog='dot') fig = plt.figure(figsize=[10, 7]) ax = fig.add_subplot(111) @@ -484,7 +514,7 @@ def load_dependencies_for_database(self, database): table_name = info[0] # TODO: fix this criteria! It will exclude ANY tables ending with 'jobs' # exclude tables ending with 'jobs' from erd - if not table_name.lower().endswith('jobs'): + if not table_name == '~jobs': full_table_name = '`{database}`.`{table_name}`'.format(database=database, table_name=table_name) self.load_dependencies_for_table(full_table_name) @@ -556,7 +586,6 @@ def load_dependencies_for_table(self, full_table_name): self._references[result.referenced_table].append(full_table_name) self.update_graph() - def __repr__(self): # Make sure that all dependencies are loaded before printing repr self.load_dependencies() diff --git a/datajoint/jobs.py b/datajoint/jobs.py index ab3d8f691..eaa631d77 100644 --- a/datajoint/jobs.py +++ b/datajoint/jobs.py @@ -1,7 +1,7 @@ import hashlib import os import pymysql -from .relation import Relation +from .base_relation import BaseRelation def key_hash(key): @@ -14,7 +14,7 @@ def key_hash(key): return hashed.hexdigest() -class JobRelation(Relation): +class JobRelation(BaseRelation): """ A base relation with no definition. Allows reserving jobs """ diff --git a/datajoint/schema.py b/datajoint/schema.py index 28d2e70ab..57cd27528 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -3,7 +3,7 @@ from . import conn, DataJointError from .heading import Heading -from .relation import Relation +from .base_relation import BaseRelation from .user_relations import Part logger = logging.getLogger(__name__) @@ -103,7 +103,7 @@ def process_relation_class(relation_class, context): part._master = cls # TODO: look into local namespace for the subclasses process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) - elif issubclass(part, Relation): + elif issubclass(part, BaseRelation): raise DataJointError('Part relations must be a subclass of datajoint.Part') # invoke Relation._prepare() on class and its part relations. diff --git a/datajoint/user_relations.py b/datajoint/user_relations.py index 8ffcd2e2f..4f5875bac 100644 --- a/datajoint/user_relations.py +++ b/datajoint/user_relations.py @@ -3,13 +3,13 @@ """ import abc -from .relation import Relation +from .base_relation import BaseRelation from .autopopulate import AutoPopulate from .utils import from_camel_case from . import DataJointError -class Part(Relation, metaclass=abc.ABCMeta): +class Part(BaseRelation, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are details of an entry in another relation and if this table is populated by this relation. For example, the entries inheriting from @@ -29,7 +29,7 @@ def table_name(self): return self.master().table_name + '__' + from_camel_case(self.__class__.__name__) -class Manual(Relation, metaclass=abc.ABCMeta): +class Manual(BaseRelation, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are entered manually. """ @@ -42,7 +42,7 @@ def table_name(self): return from_camel_case(self.__class__.__name__) -class Lookup(Relation, metaclass=abc.ABCMeta): +class Lookup(BaseRelation, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are for lookup. This is currently equivalent to defining the table as Manual and serves semantic @@ -64,7 +64,7 @@ def _prepare(self): self.insert(self.contents, skip_duplicates=True) -class Imported(Relation, AutoPopulate, metaclass=abc.ABCMeta): +class Imported(BaseRelation, AutoPopulate, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are imported from external data sources. The inherited class must at least provide the function `_make_tuples`. @@ -78,7 +78,7 @@ def table_name(self): return "_" + from_camel_case(self.__class__.__name__) -class Computed(Relation, AutoPopulate, metaclass=abc.ABCMeta): +class Computed(BaseRelation, AutoPopulate, metaclass=abc.ABCMeta): """ Inherit from this class if the table's values are computed from other relations in the schema. The inherited class must at least provide the function `_make_tuples`. From fafa4f839f4c4e5dfb0b074c99222e5409f4bc77 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 2 Oct 2015 17:13:16 -0500 Subject: [PATCH 18/41] Do name lookup based on the BaseRelation class in erm repr --- datajoint/erd.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index e82ae37f1..2fe6661fb 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -60,10 +60,7 @@ class RelGraph(DiGraph): multiple databases. """ - def __init__(self, *args, name_map=None, **kwargs): - if name_map is None: - self.name_map = {} - + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @property @@ -71,7 +68,20 @@ def node_labels(self): """ :return: dictionary of key : label pairs for plotting """ - return {k: attr['label'] for k, attr in self.node.items()} + name_map = get_table_relation_name_map() + return {k: self.get_label(k, name_map) for k in self.nodes()} + + def get_label(self, node, name_map=None): + label = self.node[node].get('label', '') + if label.strip(): + return label + + # it's not efficient to recreate name-map on every call! + if name_map is not None and node in name_map: + return name_map[node] + + return '.'.join(x.strip('`') for x in node.split('.')) + @property def pk_edges(self): @@ -145,6 +155,7 @@ def __repr__(self): def restrict_by_modules(self, modules, fill=False): """ + DEPRECATED - to be removed Creates a subgraph containing only tables in the specified modules. :param modules: list of module names :param fill: set True to automatically include nodes connecting two nodes in the specified modules @@ -379,7 +390,7 @@ def repr_path(self): sorted_paths = sorted(paths, key=k) # table name will be padded to match the longest table name - n = max([len(x) for x in self.nodes()]) + n = max([len(x) for x in self.node_labels.values()]) rep = '' for path in sorted_paths: rep += self.repr_path_with_depth(path, n) @@ -410,14 +421,14 @@ def compare_path(self, path1, path2): def repr_path_with_depth(self, path, n=20, m=2): node_depth_lookup = dict(self.nodes_by_depth()) - arrow = '-' * (m-1) + '>' + node_labels = self.node_labels space = '-' * n - pattern = '{:%ds}' % n repr = '' prev_depth = 0 first = True for (i, node) in enumerate(path): depth = node_depth_lookup[node] + label = node_labels[node] if first: repr += (' '*(n+m))*(depth-prev_depth) else: @@ -425,9 +436,9 @@ def repr_path_with_depth(self, path, n=20, m=2): first = False prev_depth = depth if i == len(path)-1: - repr += node + repr += label else: - repr += node.ljust(n, '-') + repr += label.ljust(n, '-') repr += '\n' return repr @@ -465,7 +476,7 @@ def update_graph(self, reload=False): # create primary key foreign connections for table, parents in self._parents.items(): mod, cls = (x.strip('`') for x in table.split('.')) - self.add_node(table, label=table, mod=mod, cls=cls) + self.add_node(table) for parent in parents: self.add_edge(parent, table, rel='parent') From d87661761cb6bb86496db6e08839183b4bdc2c44 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 2 Oct 2015 17:57:08 -0500 Subject: [PATCH 19/41] Include immeidate module name in erm repr --- datajoint/erd.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 2fe6661fb..ec2987034 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -43,7 +43,12 @@ def parse_base_relations(rels): name_map = {} for r in rels: try: - name_map[r().full_table_name] = r.__name__ + module = r.__module__ + parts = [] + if module != '__main__': + parts.append(module.split('.')[-1]) + parts.append(r.__name__) + name_map[r().full_table_name] = '.'.join(parts) except: pass return name_map @@ -82,6 +87,12 @@ def get_label(self, node, name_map=None): return '.'.join(x.strip('`') for x in node.split('.')) + @property + def lone_nodes(self): + """ + :return: list of nodes that are not connected to any other node + """ + return list(x for x in self.root_nodes if len(self.out_edges(x)) == 0) @property def pk_edges(self): @@ -390,12 +401,18 @@ def repr_path(self): sorted_paths = sorted(paths, key=k) # table name will be padded to match the longest table name - n = max([len(x) for x in self.node_labels.values()]) + node_labels = self.node_labels + n = max([len(x) for x in node_labels.values()]) + 1 rep = '' for path in sorted_paths: rep += self.repr_path_with_depth(path, n) + + for node in self.lone_nodes: + rep += node_labels[node] + '\n' + return rep + def compare_path(self, path1, path2): """ Comparator between two paths: path1 and path2 based on a combination of rules. @@ -423,24 +440,24 @@ def repr_path_with_depth(self, path, n=20, m=2): node_depth_lookup = dict(self.nodes_by_depth()) node_labels = self.node_labels space = '-' * n - repr = '' + rep = '' prev_depth = 0 first = True for (i, node) in enumerate(path): depth = node_depth_lookup[node] label = node_labels[node] if first: - repr += (' '*(n+m))*(depth-prev_depth) + rep += (' '*(n+m))*(depth-prev_depth) else: - repr += space.join(['-'*m]*(depth-prev_depth))[:-1] + '>' + rep += space.join(['-'*m]*(depth-prev_depth))[:-1] + '>' first = False prev_depth = depth if i == len(path)-1: - repr += label + rep += label else: - repr += label.ljust(n, '-') - repr += '\n' - return repr + rep += label.ljust(n, '-') + rep += '\n' + return rep def require_dep_loading(f): From 1fa99b8b53723c342787622fa870c040e388cc18 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 2 Oct 2015 18:15:40 -0500 Subject: [PATCH 20/41] Fix bug in the erd restrict by tables --- datajoint/erd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index ec2987034..10f7296c9 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -185,7 +185,7 @@ def restrict_by_tables(self, tables, fill=False): of tables :return: a subgraph with specified nodes """ - nodes = [n for n in self.nodes() if self.node[n].get('label') in tables] + nodes = [n for n in self.nodes() if n in tables] if fill: nodes = self.fill_connection_nodes(nodes) return self.subgraph(nodes) From 2f349f034fa16a33a4be02c7845ab1e6ad280eb1 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 2 Oct 2015 18:16:21 -0500 Subject: [PATCH 21/41] Clean up erd method handling --- datajoint/base_relation.py | 2 +- datajoint/connection.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/datajoint/base_relation.py b/datajoint/base_relation.py index 64b82a1c1..236c55a13 100644 --- a/datajoint/base_relation.py +++ b/datajoint/base_relation.py @@ -94,7 +94,7 @@ def erd(self, *args, **kwargs): """ :return: the entity relationship diagram object of this relation """ - erd = self.connection.erd() + erd = self.connection.erd(*args, **kwargs) nodes = erd.up_down_neighbors(self.full_table_name) return erd.restrict_by_tables(nodes) diff --git a/datajoint/connection.py b/datajoint/connection.py index 6d4948c9b..58eb65dd6 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -82,8 +82,6 @@ def __repr__(self): connected=connected, **self.conn_info) def erd(self, *args, **kwargs): - # load all dependencies - self.erm.load_dependencies() return self.erm.copy_graph(*args, **kwargs) @property From 1919aeefb56ca29db1afa2c82c245402c1619cfa Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 2 Oct 2015 18:19:12 -0500 Subject: [PATCH 22/41] Use full table name in back quotes when no class exists in erd --- datajoint/erd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 10f7296c9..8431ab249 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -84,8 +84,8 @@ def get_label(self, node, name_map=None): # it's not efficient to recreate name-map on every call! if name_map is not None and node in name_map: return name_map[node] - - return '.'.join(x.strip('`') for x in node.split('.')) + # no other name exists, so just use full table now + return node @property def lone_nodes(self): From 223ba904c09cc37027986ec6298c6c0a7d2374f6 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 2 Oct 2015 19:05:16 -0500 Subject: [PATCH 23/41] Addition documentation on base_relation --- datajoint/base_relation.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/datajoint/base_relation.py b/datajoint/base_relation.py index 236c55a13..9b9f65347 100644 --- a/datajoint/base_relation.py +++ b/datajoint/base_relation.py @@ -19,7 +19,7 @@ class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): """ - Relation is an abstract class that represents a base relation, i.e. a table in the database. + BaseRelation is an abstract class that represents a base relation, i.e. a table in the database. To make it a concrete class, override the abstract properties specifying the connection, table name, database, context, and definition. A Relation implements insert and delete methods in addition to inherited relational operators. @@ -170,7 +170,7 @@ def insert(self, rows, **kwargs): """ Insert a collection of rows. Additional keyword arguments are passed to insert1. - :param iter: Must be an iterator that generates a sequence of valid arguments for insert. + :param rows: An iterable where an element is a valid arguments for insert1. """ for row in rows: self.insert1(row, **kwargs) @@ -180,9 +180,9 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False Insert one data record or one Mapping (like a dict). :param tup: Data record, a Mapping (like a dict), or a list or tuple with ordered values. - :param replace=False: Replaces data tuple if True. + :param replace=False: If True, replaces the matching data tuple in the table if it exists. :param ignore_errors=False: If True, ignore errors: e.g. constraint violations. - :param skip_dublicates=False: If True, ignore duplicate inserts. + :param skip_dublicates=False: If True, silently skip duplicate inserts. Example:: @@ -193,14 +193,18 @@ def insert1(self, tup, replace=False, ignore_errors=False, skip_duplicates=False heading = self.heading def check_fields(fields): + """ + Validates that all items in `fields` are valid attributes in the heading + """ for field in fields: if field not in heading: raise KeyError(u'{0:s} is not in the attribute list'.format(field)) def make_attribute(name, value): """ - For a given attribute, return its value or value placeholder as a string to be included - in the query and the value, if any to be submitted for processing by mysql API. + For a given attribute `name` with `value, return its processed value or value placeholder + as a string to be included in the query and the value, if any to be submitted for + processing by mysql API. """ if heading[name].is_blob: value = pack(value) @@ -257,8 +261,10 @@ def make_attribute(name, value): def delete_quick(self): """ - Deletes the table without cascading and without user prompt. + Deletes the table without cascading and without user prompt. If this table has any dependent + table(s), this will fail. """ + #TODO: give a better exception message self.connection.query('DELETE FROM ' + self.from_clause + self.where_clause) def delete(self): @@ -316,13 +322,17 @@ def delete(self): def drop_quick(self): """ Drops the table associated with this relation without cascading and without user prompt. + If the table has any dependent table(s), this call will fail with an error. """ + #TODO: give a better exception message if self.is_declared: self.connection.query('DROP TABLE %s' % self.full_table_name) self.connection.erm.clear_dependencies_for_table(self.full_table_name) if self._heading: self._heading.reset() logger.info("Dropped table %s" % self.full_table_name) + else: + logger.info("Nothing to drop: table %s is not declared" % self.full_table_name) def drop(self): """ @@ -361,7 +371,8 @@ def _prepare(self): class FreeRelation(BaseRelation): """ - A base relation without a dedicated class. The table name is explicitly set. + A base relation without a dedicated class. Each instance is associated with a table + specified by full_table_name. """ def __init__(self, connection, full_table_name, definition=None, context=None): self.database, self._table_name = (s.strip('`') for s in full_table_name.split('.')) From eb147210b36615f55bec3a3ac59862da63003d66 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 2 Oct 2015 21:38:03 -0500 Subject: [PATCH 24/41] Created class AndList, work in progress --- datajoint/relational_operand.py | 172 +++++++++++++++++++------------- tests/test_cascading_delete.py | 2 +- 2 files changed, 106 insertions(+), 68 deletions(-) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index c7bc94ec0..933f552f7 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence import numpy as np import abc import re @@ -11,6 +11,82 @@ logger = logging.getLogger(__name__) +class AndList(Sequence): + """ + A list of restrictions to by applied to a relation. The restrictions are ANDed. + Each restriction can be a list or set or a relation whose elements are ORed. + But the elements that are lists can contain + """ + def __init__(self, heading): + self.heading = heading + self._list = [] + + def __len__(self): + return len(self._list) + + def __getitem__(self, i): + return self._list[i] + + def add(self, *args): + # remove Nones and duplicates + args = [r for r in args if r is not None and r not in self] + if args: + if any(is_empty_set(r) for r in args): + # if any condition is an empty list, return FALSE + self._list = ['FALSE'] + else: + self._list.extend(args) + + def where_clause(self): + """ + convert to a WHERE clause string + """ + def make_condition(arg, _negate=False): + if isinstance(arg, (str, AndList)): + return str(arg), _negate + + # semijoin or antijoin + if isinstance(arg, RelationalOperand): + common_attributes = [q for q in self.heading.names if q in arg.heading.names] + if not common_attributes: + condition = 'FALSE' if negate else 'TRUE' + else: + common_attributes = '`'+'`,`'.join(common_attributes)+'`' + condition = ['({fields}) {not_}in ({subquery})'.format( + fields=common_attributes, + not_="not " if negate else "", + subquery=arg.make_select(common_attributes))] + return condition, False # negate is cleared + + # mappings are turned into ANDed equality conditions + if isinstance(arg, Mapping): + condition = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items() if k in self.heading] + elif isinstance(arg, np.void): + # element of a record array + condition = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields if k in self.heading] + else: + raise DataJointError('invalid restriction type') + return ' AND '.join(condition) if condition else 'TRUE', _negate + + if not self: + return '' + + conditions = [] + for item in self: + negate = isinstance(item, Not) + if negate: + item = item.restriction + if isinstance(item, (list, tuple, set, np.ndarray)): + # sets of conditions are ORed + item = '(' + ') OR ('.join([make_condition(q)[0] for q in item]) + ')' + else: + item, negate = make_condition(item, negate) + if not item: + raise DataJointError('Empty condition') + conditions.append(('NOT %s' if negate else '%s') % item) + return ' WHERE ' + ' AND '.join(conditions) + + class RelationalOperand(metaclass=abc.ABCMeta): """ RelationalOperand implements relational algebra and fetch methods. @@ -25,12 +101,18 @@ class RelationalOperand(metaclass=abc.ABCMeta): @property def restrictions(self): - return [] if self._restrictions is None else self._restrictions + if self._restrictions is None: + self._restrictions = AndList(self.heading) + return self._restrictions @property def primary_key(self): return self.heading.primary_key + @property + def where_clause(self): + return self.restrictions.where_clause() + # --------- abstract properties ----------- @property @@ -107,13 +189,20 @@ def aggregate(self, group, *attributes, **renamed_attributes): Join(self, group, left=True), *attributes, **renamed_attributes) + def __iand__(self, restriction): + """ + in-place restriction by a single condition + """ + self.restrict(restriction) + def __and__(self, restriction): """ relational restriction or semijoin :return: a restricted copy of the argument """ ret = copy(self) - ret.restrict(restriction, *ret.restrictions) + ret._restrictions = None + ret.restrict(restriction, *list(self.restrictions)) return ret def restrict(self, *restrictions): @@ -124,25 +213,15 @@ def restrict(self, *restrictions): However, each member of restrictions can be a list of conditions, which are combined with OR. :param restrictions: list of restrictions. """ - # remove Nones and duplicates - restrictions = [r for r in restrictions if r is not None and r not in self.restrictions] - if restrictions: - if any(is_empty_set(r) for r in restrictions): - # if any condition is an empty list, return empty - self._restrictions = ['FALSE'] - else: - if self._restrictions is None: - self._restrictions = restrictions - else: - self._restrictions.extend(restrictions) + self.restrictions.add(*restrictions) def attributes_in_restrictions(self): """ :return: list of attributes that are probably used in the restrictions. This is used internally for optimizing SQL statements """ - where_clause = self.where_clause - return set(name for name in self.heading.names if name in where_clause) + s = self.restrictions.where_clause() # avoid calling multiple times + return set(name for name in self.heading.names if name in s) def __sub__(self, restriction): """ @@ -166,9 +245,8 @@ def make_select(self, select_fields=None): return 'SELECT {fields} FROM {from_}{where}{group}'.format( fields=select_fields if select_fields else self.select_fields, from_=self.from_clause, - where=self.where_clause, - group=' GROUP BY `%s`' % '`,`'.join(self.primary_key) if self._grouped else '' - ) + where=self.restrictions.where_clause(), + group=' GROUP BY `%s`' % '`,`'.join(self.primary_key) if self._grouped else '') def __len__(self): """ @@ -212,54 +290,10 @@ def fetch1(self): def fetch(self): return Fetch(self) - @property - def where_clause(self): - """ - convert the restriction into an SQL WHERE - """ - if not self.restrictions: - return '' - - def make_condition(arg, _negate=False): - if isinstance(arg, str): - condition = [arg] - elif isinstance(arg, RelationalOperand): - common_attributes = [q for q in self.heading.names if q in arg.heading.names] - if not common_attributes: - condition = ['FALSE' if negate else 'TRUE'] - else: - common_attributes = '`'+'`,`'.join(common_attributes)+'`' - condition = ['({fields}) {not_}in ({subquery})'.format( - fields=common_attributes, - not_="not " if negate else "", - subquery=arg.make_select(common_attributes))] - _negate = False - elif isinstance(arg, Mapping): - condition = ['`%s`=%s' % (k, repr(v)) for k, v in arg.items() if k in self.heading] - elif isinstance(arg, np.void): - # element of a record array - condition = ['`%s`=%s' % (k, arg[k]) for k in arg.dtype.fields if k in self.heading] - else: - raise DataJointError('invalid restriction type') - return ' AND '.join(condition) if condition else 'TRUE', _negate - - conditions = [] - for r in self.restrictions: - negate = isinstance(r, Not) - if negate: - r = r.restriction - if isinstance(r, (list, tuple, set, np.ndarray)): - r = '(' + ') OR ('.join([make_condition(q)[0] for q in r]) + ')' - else: - r, negate = make_condition(r, negate) - if r: - conditions.append('%s(%s)' % ('not ' if negate else '', r)) - return ' WHERE ' + ' AND '.join(conditions) - class Not: """ - inverse restriction + invert restriction """ def __init__(self, restriction): self.restriction = restriction @@ -351,13 +385,16 @@ def _grouped(self): def from_clause(self): return self._arg.from_clause - def __and__(self, restriction): + def restrict(self, *restrictions): """ When restricting on renamed attributes, enclose in subquery """ - has_restriction = isinstance(restriction, RelationalOperand) or restriction + has_restriction = any(isinstance(r, RelationalOperand) or r for r in restrictions) do_subquery = has_restriction and self.heading.computed - return Subquery(self) & restriction if do_subquery else super().__and__(restriction) + if do_subquery: + self._arg = Subquery(self) + self._restrictions = None + super().restrict(*restrictions) class Aggregation(Projection): @@ -370,6 +407,7 @@ class Subquery(RelationalOperand): """ A Subquery encapsulates its argument in a SELECT statement, enabling its use as a subquery. The attribute list and the WHERE clause are resolved. + As such, a subquery does not have any renamed attributes. """ __counter = 0 diff --git a/tests/test_cascading_delete.py b/tests/test_cascading_delete.py index 8182ad6c1..fa499ba1e 100644 --- a/tests/test_cascading_delete.py +++ b/tests/test_cascading_delete.py @@ -63,7 +63,7 @@ def test_delete_lookup(): 'schema is not populated') L().delete() assert_false(bool(L() or D() or E() or E.F()), 'incomplete delete') - A().delete() # delete all is necessary because delete L deletes from subtables. TODO: submit this as an issue + A().delete() # delete all is necessary because delete L deletes from subtables. # @staticmethod # def test_delete_lookup_restricted(): From 3ca58be9f369265246bd8caf6dc1e4b1929555ef Mon Sep 17 00:00:00 2001 From: dimitri-yatsenko Date: Sat, 3 Oct 2015 01:57:13 -0500 Subject: [PATCH 25/41] tentatively fixed issued #164. --- datajoint/base_relation.py | 3 --- datajoint/fetch.py | 8 ++----- datajoint/relational_operand.py | 39 ++++++++++++++++++++++----------- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/datajoint/base_relation.py b/datajoint/base_relation.py index 9b9f65347..10c98db7f 100644 --- a/datajoint/base_relation.py +++ b/datajoint/base_relation.py @@ -16,7 +16,6 @@ logger = logging.getLogger(__name__) - class BaseRelation(RelationalOperand, metaclass=abc.ABCMeta): """ BaseRelation is an abstract class that represents a base relation, i.e. a table in the database. @@ -328,8 +327,6 @@ def drop_quick(self): if self.is_declared: self.connection.query('DROP TABLE %s' % self.full_table_name) self.connection.erm.clear_dependencies_for_table(self.full_table_name) - if self._heading: - self._heading.reset() logger.info("Dropped table %s" % self.full_table_name) else: logger.info("Nothing to drop: table %s is not declared" % self.full_table_name) diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 395885528..61a911ba5 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -188,10 +188,7 @@ def keys(self, **kwargs): """ Iterator that returns primary keys. """ - b = dict(self.behavior, **kwargs) - if 'as_dict' not in kwargs: - b['as_dict'] = True - yield from self._relation.project().fetch.set_behavior(**b) + yield from self._relation.project().fetch.set_behavior(**dict(self.behavior, as_dict=True, **kwargs)) def __getitem__(self, item): """ @@ -216,8 +213,7 @@ def __getitem__(self, item): result, 0, result.strides) if attribute is PRIMARY_KEY else result[attribute] - for attribute in item - ] + for attribute in item] return return_values[0] if single_output else return_values def __repr__(self): diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 933f552f7..183730856 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -52,10 +52,10 @@ def make_condition(arg, _negate=False): condition = 'FALSE' if negate else 'TRUE' else: common_attributes = '`'+'`,`'.join(common_attributes)+'`' - condition = ['({fields}) {not_}in ({subquery})'.format( + condition = '({fields}) {not_}in ({subquery})'.format( fields=common_attributes, not_="not " if negate else "", - subquery=arg.make_select(common_attributes))] + subquery=arg.make_select(common_attributes)) return condition, False # negate is cleared # mappings are turned into ANDed equality conditions @@ -83,7 +83,7 @@ def make_condition(arg, _negate=False): item, negate = make_condition(item, negate) if not item: raise DataJointError('Empty condition') - conditions.append(('NOT %s' if negate else '%s') % item) + conditions.append(('NOT (%s)' if negate else '(%s)') % item) return ' WHERE ' + ' AND '.join(conditions) @@ -105,6 +105,9 @@ def restrictions(self): self._restrictions = AndList(self.heading) return self._restrictions + def clear_restrictions(self): + self._restrictions = None + @property def primary_key(self): return self.heading.primary_key @@ -201,7 +204,7 @@ def __and__(self, restriction): :return: a restricted copy of the argument """ ret = copy(self) - ret._restrictions = None + ret.clear_restrictions() ret.restrict(restriction, *list(self.restrictions)) return ret @@ -311,10 +314,10 @@ def __init__(self, arg1, arg2, left=False): raise DataJointError('Cannot join relations with different database connections') self._arg1 = Subquery(arg1) if arg1.heading.computed else arg1 self._arg2 = Subquery(arg2) if arg2.heading.computed else arg2 - self.restrict(*self._arg1.restrictions) - self.restrict(*self._arg2.restrictions) - self._left = left self._heading = self._arg1.heading.join(self._arg2.heading, left=left) + self.restrict(*list(self._arg1.restrictions)) + self.restrict(*list(self._arg2.restrictions)) + self._left = left def _repr_helper(self): return "(%r) * (%r)" % (self._arg1, self._arg2) @@ -364,7 +367,7 @@ def __init__(self, arg, *attributes, **renamed_attributes): if use_subquery: self._arg = Subquery(arg) else: - self.restrict(*arg.restrictions) + self.restrict(*list(arg.restrictions)) def _repr_helper(self): return "(%r).project(%r)" % (self._arg, self._attributes) @@ -375,7 +378,9 @@ def connection(self): @property def heading(self): - return self._arg.heading.project(*self._attributes, **self._renamed_attributes) + heading = self._arg.heading + heading = heading.project(*self._attributes, **self._renamed_attributes) + return heading @property def _grouped(self): @@ -385,15 +390,21 @@ def _grouped(self): def from_clause(self): return self._arg.from_clause + def __and__(self, restriction): + has_restriction = isinstance(restriction, RelationalOperand) or bool(restriction) + do_subquery = has_restriction and self.heading.computed + ret = Subquery(self) if do_subquery else self + ret.restrict(restriction) + return ret + def restrict(self, *restrictions): """ - When restricting on renamed attributes, enclose in subquery + Override restrict: when restricting on renamed attributes, enclose in subquery """ has_restriction = any(isinstance(r, RelationalOperand) or r for r in restrictions) do_subquery = has_restriction and self.heading.computed if do_subquery: - self._arg = Subquery(self) - self._restrictions = None + raise DataJointError('In-place restriction on renamed attributes is not allowed') super().restrict(*restrictions) @@ -433,7 +444,9 @@ def select_fields(self): @property def heading(self): - return self._arg.heading.resolve() + h = self._arg.heading + h = h.resolve() + return h def _repr_helper(self): return "%r" % self._arg From ca2ea15adae44858d03c2965911cd856462ae120 Mon Sep 17 00:00:00 2001 From: dimitri-yatsenko Date: Sat, 3 Oct 2015 02:14:45 -0500 Subject: [PATCH 26/41] minor cleanup --- datajoint/erd.py | 7 +++---- datajoint/relational_operand.py | 8 ++------ 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 8431ab249..35b64b7de 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -534,15 +534,14 @@ def load_dependencies_for_database(self, database): Load dependencies for all tables found in the specified database :param database: database for which dependencies will be loaded """ - #sql_table_name_regexp = re.compile('^(#|_|__|~)?[a-z][a-z0-9_]*$') + #sql_table_name_regexp = re.compile('^(#|_|__)?[a-z][a-z0-9_]*$') cur = self._conn.query('SHOW TABLES FROM `{database}`'.format(database=database)) for info in cur: table_name = info[0] - # TODO: fix this criteria! It will exclude ANY tables ending with 'jobs' - # exclude tables ending with 'jobs' from erd - if not table_name == '~jobs': + # exclude service tables from ERD + if not table_name.startswith('~'): full_table_name = '`{database}`.`{table_name}`'.format(database=database, table_name=table_name) self.load_dependencies_for_table(full_table_name) diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 183730856..3ee4e3bf4 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -378,9 +378,7 @@ def connection(self): @property def heading(self): - heading = self._arg.heading - heading = heading.project(*self._attributes, **self._renamed_attributes) - return heading + return self._arg.heading.project(*self._attributes, **self._renamed_attributes) @property def _grouped(self): @@ -444,9 +442,7 @@ def select_fields(self): @property def heading(self): - h = self._arg.heading - h = h.resolve() - return h + return self._arg.heading.resolve() def _repr_helper(self): return "%r" % self._arg From ec02b7525136ba7013c72c532afccadbbc80cd94 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 10:08:51 -0500 Subject: [PATCH 27/41] Improve parameterization of existing node selection functions --- datajoint/erd.py | 45 ++++++++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 8431ab249..42f552661 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -164,15 +164,14 @@ def plot(self): def __repr__(self): return self.repr_path() - def restrict_by_modules(self, modules, fill=False): + def restrict_by_database(self, databases, fill=False): """ - DEPRECATED - to be removed - Creates a subgraph containing only tables in the specified modules. - :param modules: list of module names - :param fill: set True to automatically include nodes connecting two nodes in the specified modules + Creates a subgraph containing only tables in the specified database. + :param databases: list of database names + :param fill: if True, automatically include nodes connecting two nodes in the specified modules :return: a subgraph with specified nodes """ - nodes = [n for n in self.nodes() if self.node[n].get('mod') in modules] + nodes = [n for n in self.nodes() if n.split('.')[0].strip('`') in databases] if fill: nodes = self.fill_connection_nodes(nodes) return self.subgraph(nodes) @@ -194,7 +193,7 @@ def restrict_by_tables_in_module(self, module, tables, fill=False): nodes = [n for n in self.nodes() if self.node[n].get('mod') in module and self.node[n].get('cls') in tables] if fill: - nodes = self.fill_connection_nodes(nodes) + nodes = self.fill_connection_nodes(nodes) return self.subgraph(nodes) def fill_connection_nodes(self, nodes): @@ -206,57 +205,69 @@ def fill_connection_nodes(self, nodes): graph = self.subgraph(self.ancestors_of_all(nodes)) return graph.descendants_of_all(nodes) - def ancestors_of_all(self, nodes): + def ancestors_of_all(self, nodes, n=-1): """ Find and return a set of all ancestors of the given nodes. The set will also contain the specified nodes. :param nodes: list of nodes for which ancestors are to be found + :param n: maximum number of generations to go up for each node. + If set to a negative number, will return all ancestors. :return: a set containing passed in nodes and all of their ancestors """ s = set() - for n in nodes: - s.update(self.ancestors(n)) + for node in nodes: + s.update(self.ancestors(node, n)) return s - def descendants_of_all(self, nodes): + def descendants_of_all(self, nodes, n=-1): """ Find and return a set including all descendants of the given nodes. The set will also contain the given nodes as well. :param nodes: list of nodes for which descendants are to be found + :param n: maximum number of generations to go down for each node. + If set to a negative number, will return all descendants. :return: a set containing passed in nodes and all of their descendants """ s = set() - for n in nodes: - s.update(self.descendants(n)) + for node in nodes: + s.update(self.descendants(node, n)) return s def copy_graph(self, *args, **kwargs): return self.__class__(self, *args, **kwargs) - def ancestors(self, node): + def ancestors(self, node, n=-1): """ Find and return a set containing all ancestors of the specified node. For convenience in plotting, this set will also include the specified node as well (may change in future). :param node: node for which all ancestors are to be discovered + :param n: maximum number of generations to go up. If set to a negative number, + will return all ancestors. :return: a set containing the node and all of its ancestors """ s = {node} + if n == 0: + return s for p in self.predecessors_iter(node): - s.update(self.ancestors(p)) + s.update(self.ancestors(p, n-1)) return s - def descendants(self, node): + def descendants(self, node, n=-1): """ Find and return a set containing all descendants of the specified node. For convenience in plotting, this set will also include the specified node as well (may change in future). :param node: node for which all descendants are to be discovered + :param n: maximum number of generations to go down. If set to a negative number, + will return all descendants :return: a set containing the node and all of its descendants """ s = {node} + if n == 0: + return s for c in self.successors_iter(node): - s.update(self.descendants(c)) + s.update(self.descendants(c, n-1)) return s def up_down_neighbors(self, node, ups=2, downs=2, _prev=None): From f6e1ae5685fe079f6a78c0dda0dd9fd429d214ed Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 10:09:35 -0500 Subject: [PATCH 28/41] Support different modes for BaseRelation erd --- datajoint/base_relation.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/datajoint/base_relation.py b/datajoint/base_relation.py index 9b9f65347..594eae3e9 100644 --- a/datajoint/base_relation.py +++ b/datajoint/base_relation.py @@ -90,13 +90,25 @@ def select_fields(self): """ return '*' - def erd(self, *args, **kwargs): + def erd(self, *args, fill=True, mode='updown', **kwargs): """ + :param mode: diffent methods of creating a graph pertaining to this relation. + Currently options includes the following: + * 'updown': Contains this relation and all other nodes that can be reached within specific + number of ups and downs in the graph. ups(=2) and downs(=2) are optional keyword arguments + * 'ancestors': Returs :return: the entity relationship diagram object of this relation """ - erd = self.connection.erd(*args, **kwargs) - nodes = erd.up_down_neighbors(self.full_table_name) - return erd.restrict_by_tables(nodes) + erd = self.connection.erd() + if mode == 'updown': + nodes = erd.up_down_neighbors(self.full_table_name, *args, **kwargs) + elif mode == 'ancestors': + nodes = erd.ancestors(self.full_table_name, *args, **kwargs) + elif mode == 'descendants': + nodes = erd.descendants(self.full_table_name, *args, **kwargs) + else: + raise DataJointError('Unsupported erd mode', 'Mode "%s" is currently not supported' % mode) + return erd.restrict_by_tables(nodes, fill=fill) # ------------- dependencies ---------- # @property From f2e181b957d2470083640d1418a2ec678a6a083b Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 11:17:32 -0500 Subject: [PATCH 29/41] Change restrict_by_database to restrict_by_databases --- datajoint/erd.py | 3 ++- datajoint/schema.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 42f552661..45cdba586 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -141,6 +141,7 @@ def plot(self): if pygraphviz_layout is None: logger.warning('Failed to load Pygraphviz - plotting not supported at this time') return + pos = pygraphviz_layout(self, prog='dot') fig = plt.figure(figsize=[10, 7]) ax = fig.add_subplot(111) @@ -164,7 +165,7 @@ def plot(self): def __repr__(self): return self.repr_path() - def restrict_by_database(self, databases, fill=False): + def restrict_by_databases(self, databases, fill=False): """ Creates a subgraph containing only tables in the specified database. :param databases: list of database names diff --git a/datajoint/schema.py b/datajoint/schema.py index 57cd27528..033d73346 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -82,6 +82,7 @@ def process_relation_class(relation_class, context): relation_class._connection = self.connection relation_class._heading = Heading() relation_class._context = context + # instantiate the class and declare the table in database if not already present relation_class().declare() if issubclass(cls, Part): @@ -89,7 +90,7 @@ def process_relation_class(relation_class, context): process_relation_class(cls, context=self.context) - # Process subordinate relations + # Process subordinate relations parts = list() for name in (name for name in dir(cls) if not name.startswith('_')): part = getattr(cls, name) From 1d6e63b09605c2bdec10d9d0453856ec99047aa1 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 11:47:07 -0500 Subject: [PATCH 30/41] Simplify Part discovery and processing --- datajoint/schema.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/datajoint/schema.py b/datajoint/schema.py index 033d73346..e926c998e 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -5,6 +5,7 @@ from .heading import Heading from .base_relation import BaseRelation from .user_relations import Part +import inspect logger = logging.getLogger(__name__) @@ -92,20 +93,13 @@ def process_relation_class(relation_class, context): # Process subordinate relations parts = list() - for name in (name for name in dir(cls) if not name.startswith('_')): - part = getattr(cls, name) - try: - is_sub = issubclass(part, Part) - except TypeError: - pass - else: - if is_sub: - parts.append(part) - part._master = cls - # TODO: look into local namespace for the subclasses - process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) - elif issubclass(part, BaseRelation): - raise DataJointError('Part relations must be a subclass of datajoint.Part') + is_part = lambda x: inspect.isclass(x) and issubclass(x, Part) + + for var, part in inspect.getmembers(cls, is_part): + parts.append(part) + part._master = cls + # TODO: look into local namespace for the subclasses + process_relation_class(part, context=dict(self.context, **{cls.__name__: cls})) # invoke Relation._prepare() on class and its part relations. cls()._prepare() From d6bff28ec23b9c2cd6f558cf30659a637f32d6a4 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 12:07:22 -0500 Subject: [PATCH 31/41] Add detailed error messages in declare --- datajoint/declare.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/datajoint/declare.py b/datajoint/declare.py index 4f2cd2463..e853cae40 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -38,8 +38,21 @@ def declare(full_table_name, definition, context): elif line.startswith('---') or line.startswith('___'): in_key = False # start parsing dependent attributes elif line.startswith('->'): - # foreign key - ref = eval(line[2:], context)() # TODO: surround this with try...except... to give a better error message + # foreign + # TODO: clean up import order + from .base_relation import BaseRelation + # TODO: break this step into finer steps, checking the type of reference before calling it + try: + ref = eval(line[2:], context)() + except NameError: + raise DataJointError('Foreign key reference %s could not be resolved' % line[2:]) + except TypeError: + raise DataJointError('Foreign key reference %s could not be instantiated.' + 'Make sure %s is a valid BaseRelation subclass' % line[2:]) + # TODO: consider the case where line[2:] is a function that returns an instance of BaseRelation + if not isinstance(ref, BaseRelation): + raise DataJointError('Foreign key reference %s must be a subclass of BaseRelation' % line[2:]) + foreign_key_sql.append( 'FOREIGN KEY ({primary_key})' ' REFERENCES {ref} ({primary_key})' From 5c4d6da489d4b401164095b50c6494da69eee377 Mon Sep 17 00:00:00 2001 From: dimitri-yatsenko Date: Sat, 3 Oct 2015 12:52:28 -0500 Subject: [PATCH 32/41] created the Dependencies class. The ERD class is not connected yet. --- datajoint/autopopulate.py | 8 ++- datajoint/base_relation.py | 31 ++++----- datajoint/connection.py | 9 +-- datajoint/dependencies.py | 127 +++++++++++++++++++++++++++++++++++++ datajoint/schema.py | 10 +-- 5 files changed, 154 insertions(+), 31 deletions(-) create mode 100644 datajoint/dependencies.py diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 5b3befcc1..753e42e5b 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -63,9 +63,6 @@ def populate(self, restriction=None, suppress_errors=False, :param reserve_jobs: if true, reserves job to populate in asynchronous fashion :param order: "original"|"reverse"|"random" - the order of execution """ - if not isinstance(self.populated_from, RelationalOperand): - raise DataJointError('Invalid populated_from value') - if self.connection.in_transaction: raise DataJointError('Populate cannot be called during a transaction.') @@ -73,6 +70,11 @@ def populate(self, restriction=None, suppress_errors=False, if order not in valid_order: raise DataJointError('The order argument must be one of %s' % str(valid_order)) + self.connection.dependencies.load() + + if not isinstance(self.populated_from, RelationalOperand): + raise DataJointError('Invalid populated_from value') + error_list = [] if suppress_errors else None jobs = self.connection.jobs[self.target.database] diff --git a/datajoint/base_relation.py b/datajoint/base_relation.py index 10c98db7f..c77a13127 100644 --- a/datajoint/base_relation.py +++ b/datajoint/base_relation.py @@ -72,8 +72,6 @@ def declare(self): if not self.is_declared: self.connection.query( declare(self.full_table_name, self.definition, self._context)) - #TODO: reconsider loading time - self.connection.erm.load_dependencies() @property def from_clause(self): @@ -89,13 +87,13 @@ def select_fields(self): """ return '*' - def erd(self, *args, **kwargs): - """ - :return: the entity relationship diagram object of this relation - """ - erd = self.connection.erd(*args, **kwargs) - nodes = erd.up_down_neighbors(self.full_table_name) - return erd.restrict_by_tables(nodes) + # def erd(self, *args, **kwargs): + # """ + # :return: the entity relationship diagram object of this relation + # """ + # erd = self.connection.erd(*args, **kwargs) + # nodes = erd.up_down_neighbors(self.full_table_name) + # return erd.restrict_by_tables(nodes) # ------------- dependencies ---------- # @property @@ -103,30 +101,29 @@ def parents(self): """ :return: the parent relation of this relation """ - return self.connection.erm.parents[self.full_table_name] + return self.connection.dependencies.parents[self.full_table_name] @property def children(self): """ :return: the child relations of this relation """ - return self.connection.erm.children[self.full_table_name] + return self.connection.dependencies.children[self.full_table_name] @property def references(self): """ :return: list of tables that this tables refers to """ - return self.connection.erm.references[self.full_table_name] + return self.connection.dependencies.references[self.full_table_name] @property def referenced(self): """ :return: list of tables for which this table is referenced by """ - return self.connection.erm.referenced[self.full_table_name] + return self.connection.dependencies.referenced[self.full_table_name] - #TODO: implement this inside the relation object in connection @property def descendants(self): """ @@ -137,10 +134,9 @@ def descendants(self): :return: list of descendants """ relations = (FreeRelation(self.connection, table) - for table in self.connection.erm.get_descendants(self.full_table_name)) + for table in self.connection.dependencies.get_descendants(self.full_table_name)) return [relation for relation in relations if relation.is_declared] - def _repr_helper(self): """ :return: String representation of this object @@ -271,6 +267,7 @@ def delete(self): Deletes the contents of the table and its dependent tables, recursively. User is prompted for confirmation if config['safemode'] is set to True. """ + self.connection.dependencies.load() # construct a list (OrderedDict) of relations to delete relations = OrderedDict((r.full_table_name, r) for r in self.descendants) @@ -326,7 +323,6 @@ def drop_quick(self): #TODO: give a better exception message if self.is_declared: self.connection.query('DROP TABLE %s' % self.full_table_name) - self.connection.erm.clear_dependencies_for_table(self.full_table_name) logger.info("Dropped table %s" % self.full_table_name) else: logger.info("Nothing to drop: table %s is not declared" % self.full_table_name) @@ -336,6 +332,7 @@ def drop(self): Drop the table and all tables that reference it, recursively. User is prompted for confirmation if config['safemode'] is set to True. """ + self.connection.dependencies.load() do_drop = True relations = self.descendants if config['safemode']: diff --git a/datajoint/connection.py b/datajoint/connection.py index 58eb65dd6..e518aacbc 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -8,7 +8,7 @@ import logging from . import config from . import DataJointError -from .erd import ERM +from .dependencies import Dependencies from .jobs import JobManager logger = logging.getLogger(__name__) @@ -67,7 +67,8 @@ def __init__(self, host, user, passwd, init_fun=None): self._conn.autocommit(True) self._in_transaction = False self.jobs = JobManager(self) - self.erm = ERM(self) + self.schemas = dict() + self.dependencies = Dependencies(self) def __del__(self): logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) @@ -81,8 +82,8 @@ def __repr__(self): return "DataJoint connection ({connected}) {user}@{host}:{port}".format( connected=connected, **self.conn_info) - def erd(self, *args, **kwargs): - return self.erm.copy_graph(*args, **kwargs) + def register(self, schema): + self.schemas[schema.database] = schema @property def is_connected(self): diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py new file mode 100644 index 000000000..6875cbd4f --- /dev/null +++ b/datajoint/dependencies.py @@ -0,0 +1,127 @@ +from collections import defaultdict +import pyparsing as pp +from . import DataJointError + + +class Dependencies: + """ + Lookup for dependencies between tables + """ + __primary_key_parser = (pp.CaselessLiteral('PRIMARY KEY') + + pp.QuotedString('(', endQuoteChar=')').setResultsName('primary_key')) + + def __init__(self, conn): + self._conn = conn + self._parents = dict() + self._referenced = dict() + self._children = defaultdict(list) + self._references = defaultdict(list) + + @property + def parents(self): + return self._parents + + @property + def children(self): + return self._children + + @property + def references(self): + return self._references + + @property + def referenced(self): + return self._referenced + + def get_descendants(self, full_table_name): + """ + :param full_table_name: a table name in the format `database`.`table_name` + :return: list of all children and references, in order of dependence. + This is helpful for cascading delete or drop operations. + """ + ret = defaultdict(lambda: 0) + + def recurse(name, level): + ret[name] = max(ret[name], level) + for child in self.children[name] + self.references[name]: + recurse(child, level+1) + + recurse(full_table_name, 0) + return sorted(ret.keys(), key=ret.__getitem__) + + @staticmethod + def __foreign_key_parser(database): + + def add_database(string, loc, toc): + return ['`{database}`.`{table}`'.format(database=database, table=toc[0])] + + return (pp.CaselessLiteral('CONSTRAINT').suppress() + + pp.QuotedString('`').suppress() + + pp.CaselessLiteral('FOREIGN KEY').suppress() + + pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes') + + pp.CaselessLiteral('REFERENCES') + + pp.Or([ + pp.QuotedString('`').setParseAction(add_database), + pp.Combine(pp.QuotedString('`', unquoteResults=False) + '.' + + pp.QuotedString('`', unquoteResults=False))]).setResultsName('referenced_table') + + pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes')) + + def load(self): + """ + Load dependencies for all tables that have not yet been loaded in all registered schemas + """ + for database in self._conn.schemas: + cur = self._conn.query('SHOW TABLES FROM `{database}`'.format(database=database)) + fk_parser = self.__foreign_key_parser(database) + # list tables that have not yet been loaded, exclude service tables that starts with ~ + tables = ('`{database}`.`{table_name}`'.format(database=database, table_name=info[0]) + for info in cur if not info[0].startswith('~')) + tables = (name for name in tables if name not in self._parents) + for name in tables: + self._parents[name] = list() + self._referenced[name] = list() + # fetch the CREATE TABLE statement + create_statement = self._conn.query('SHOW CREATE TABLE %s' % name).fetchone() + create_statement = create_statement[1].split('\n') + primary_key = None + for line in create_statement: + if primary_key is None: + try: + result = self.__primary_key_parser.parseString(line) + except pp.ParseException: + pass + else: + primary_key = [s.strip(' `') for s in result.primary_key.split(',')] + try: + result = fk_parser.parseString(line) + except pp.ParseException: + pass + else: + if not primary_key: + raise DataJointError('No primary key found %s' % name) + if result.referenced_attributes != result.attributes: + raise DataJointError( + "%s's foreign key refers to differently named attributes in %s" + % (self.__class__.__name__, result.referenced_table)) + if all(q in primary_key for q in [s.strip('` ') for s in result.attributes.split(',')]): + self._parents[name].append(result.referenced_table) + self._children[result.referenced_table].append(name) + else: + self._referenced[name].append(result.referenced_table) + self._references[result.referenced_table].append(name) + + def get_descendants(self, full_table_name): + """ + :param full_table_name: a table name in the format `database`.`table_name` + :return: list of all children and references, in order of dependence. + This is helpful for cascading delete or drop operations. + """ + ret = defaultdict(lambda: 0) + + def recurse(name, level): + ret[name] = max(ret[name], level) + for child in self.children[name] + self.references[name]: + recurse(child, level+1) + + recurse(full_table_name, 0) + return sorted(ret.keys(), key=ret.__getitem__) diff --git a/datajoint/schema.py b/datajoint/schema.py index 57cd27528..c4b295c2d 100644 --- a/datajoint/schema.py +++ b/datajoint/schema.py @@ -28,10 +28,8 @@ def __init__(self, database, context, connection=None): self.database = database self.connection = connection self.context = context - - # if the database does not exist, create it - if not self.exists: + # create schema logger.info("Database `{database}` could not be found. " "Attempting to create the database.".format(database=database)) try: @@ -41,9 +39,7 @@ def __init__(self, database, context, connection=None): raise DataJointError("Database named `{database}` was not defined, and" " an attempt to create has failed. Check" " permissions.".format(database=database)) - - # TODO: replace with a call on connection object - connection.erm.register_database(database) + connection.register(self) def drop(self): """ @@ -95,7 +91,7 @@ def process_relation_class(relation_class, context): part = getattr(cls, name) try: is_sub = issubclass(part, Part) - except TypeError: + except TypeError: # issubclass works for classes only pass else: if is_sub: From 48c98503bf38a76d2c198110a12191755987541c Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 14:13:10 -0500 Subject: [PATCH 33/41] Update comments --- datajoint/erd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datajoint/erd.py b/datajoint/erd.py index 45cdba586..159ecb08d 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -632,9 +632,11 @@ def __repr__(self): return super().__repr__() def clear_dependencies(self): + # TODO: complete the implementation pass def clear_dependencies_for_database(self, database): + # TODO: complete the implementation pass def clear_dependencies_for_table(self, full_table_name): From c21e5035878b11ab81f857c700366f05818f6c80 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 14:44:46 -0500 Subject: [PATCH 34/41] Remove redundant part of ERD and Dependencies, simplifying ERD --- datajoint/dependencies.py | 1 - datajoint/erd.py | 211 ++------------------------------------ 2 files changed, 10 insertions(+), 202 deletions(-) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index 6875cbd4f..105f970ed 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -51,7 +51,6 @@ def recurse(name, level): @staticmethod def __foreign_key_parser(database): - def add_database(string, loc, toc): return ['`{database}`.`{table}`'.format(database=database, table=toc[0])] diff --git a/datajoint/erd.py b/datajoint/erd.py index 49197ec98..458ee994c 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -472,217 +472,26 @@ def repr_path_with_depth(self, path, n=20, m=2): return rep -def require_dep_loading(f): - @wraps(f) - def wrapper(self, *args, **kwargs): - self.load_dependencies() - return f(self, *args, **kwargs) - return wrapper - -class ERM(RelGraph): +class ERD(RelGraph): """ - Entity Relation Map + Entity Relation Diagram Represents known relation between tables """ # _checked_dependencies = set() - def __init__(self, conn, *args, **kwargs): + def __init__(self, parents_dict, referenced_dict, *args, **kwargs): super().__init__(*args, **kwargs) - self._databases = set() - self._conn = conn - self._parents = dict() - self._referenced = dict() - self._children = defaultdict(list) - self._references = defaultdict(list) - if conn.is_connected: - self._conn = conn - else: - raise DataJointError('The connection is broken') #TODO: make better exception message - - def update_graph(self, reload=False): + self.clear() # create primary key foreign connections - for table, parents in self._parents.items(): - mod, cls = (x.strip('`') for x in table.split('.')) - self.add_node(table) + for full_table, parents in parents_dict.items(): + database, table = (x.strip('`') for x in full_table.split('.')) + self.add_node(full_table, database=database, table=table) for parent in parents: - self.add_edge(parent, table, rel='parent') + self.add_edge(parent, full_table, rel='parent') # create non primary key foreign connections - for table, referenced in self._referenced.items(): + for full_table, referenced in referenced_dict.items(): for ref in referenced: - self.add_edge(ref, table, rel='referenced') - - @require_dep_loading - def copy_graph(self, *args, **kwargs): - """ - Return copy of the graph represented by this object at the - time of call. Note that the returned graph is no longer - bound to a connection. - """ - return RelGraph(self, *args, **kwargs) - - @require_dep_loading - def subgraph(self, *args, **kwargs): - return RelGraph(self).subgraph(*args, **kwargs) - - def register_database(self, database): - """ - Register the database to be monitored - :param database: name of database to be monitored - """ - self._databases.add(database) - - def load_dependencies(self): - """ - Load dependencies for all monitored databases - """ - for database in self._databases: - self.load_dependencies_for_database(database) - - def load_dependencies_for_database(self, database): - """ - Load dependencies for all tables found in the specified database - :param database: database for which dependencies will be loaded - """ - #sql_table_name_regexp = re.compile('^(#|_|__)?[a-z][a-z0-9_]*$') - - cur = self._conn.query('SHOW TABLES FROM `{database}`'.format(database=database)) - - for info in cur: - table_name = info[0] - # exclude service tables from ERD - if not table_name.startswith('~'): - full_table_name = '`{database}`.`{table_name}`'.format(database=database, table_name=table_name) - self.load_dependencies_for_table(full_table_name) - - def load_dependencies_for_table(self, full_table_name): - """ - Load dependencies for the specified table - :param full_table_name: table for which dependencies will be loaded, specified in full table name - """ - # check if already loaded. Use clear_dependencies before reloading - if full_table_name in self._parents: - return - self._parents[full_table_name] = list() - self._referenced[full_table_name] = list() - - # fetch the CREATE TABLE statement - cur = self._conn.query('SHOW CREATE TABLE %s' % full_table_name) - create_statement = cur.fetchone() - if not create_statement: - raise DataJointError('Could not load the definition for %s' % full_table_name) - create_statement = create_statement[1].split('\n') - - # build foreign key fk_parser - database = full_table_name.split('.')[0].strip('`') - add_database = lambda string, loc, toc: ['`{database}`.`{table}`'.format(database=database, table=toc[0])] - - # primary key parser - pk_parser = pp.CaselessLiteral('PRIMARY KEY') - pk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('primary_key') - - # foreign key parser - fk_parser = pp.CaselessLiteral('CONSTRAINT').suppress() - fk_parser += pp.QuotedString('`').suppress() - fk_parser += pp.CaselessLiteral('FOREIGN KEY').suppress() - fk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('attributes') - fk_parser += pp.CaselessLiteral('REFERENCES') - fk_parser += pp.Or([ - pp.QuotedString('`').setParseAction(add_database), - pp.Combine(pp.QuotedString('`', unquoteResults=False) + - '.' + pp.QuotedString('`', unquoteResults=False)) - ]).setResultsName('referenced_table') - fk_parser += pp.QuotedString('(', endQuoteChar=')').setResultsName('referenced_attributes') - - # parse foreign keys - primary_key = None - for line in create_statement: - if primary_key is None: - try: - result = pk_parser.parseString(line) - except pp.ParseException: - pass - else: - primary_key = [s.strip(' `') for s in result.primary_key.split(',')] - try: - result = fk_parser.parseString(line) - except pp.ParseException: - pass - else: - if not primary_key: - raise DataJointError('No primary key found %s' % full_table_name) - if result.referenced_attributes != result.attributes: - raise DataJointError( - "%s's foreign key refers to differently named attributes in %s" - % (self.__class__.__name__, result.referenced_table)) - if all(q in primary_key for q in [s.strip('` ') for s in result.attributes.split(',')]): - self._parents[full_table_name].append(result.referenced_table) - self._children[result.referenced_table].append(full_table_name) - else: - self._referenced[full_table_name].append(result.referenced_table) - self._references[result.referenced_table].append(full_table_name) - self.update_graph() - - def __repr__(self): - # Make sure that all dependencies are loaded before printing repr - self.load_dependencies() - return super().__repr__() - - def clear_dependencies(self): - # TODO: complete the implementation - pass - - def clear_dependencies_for_database(self, database): - # TODO: complete the implementation - pass - - def clear_dependencies_for_table(self, full_table_name): - for ref in self._parents.pop(full_table_name, []): - if full_table_name in self._children[ref]: - self._children[ref].remove(full_table_name) - for ref in self._referenced.pop(full_table_name, []): - if full_table_name in self._references[ref]: - self._references[ref].remove(full_table_name) - - @property - @require_dep_loading - def parents(self): - return self._parents - - @property - @require_dep_loading - def children(self): - self.load_dependencies() - return self._children - - @property - @require_dep_loading - def references(self): - return self._references - - @property - @require_dep_loading - def referenced(self): - return self._referenced - - @require_dep_loading - def get_descendants(self, full_table_name): - """ - :param full_table_name: a table name in the format `database`.`table_name` - :return: list of all children and references, in order of dependence. - This is helpful for cascading delete or drop operations. - """ - ret = defaultdict(lambda: 0) - - def recurse(full_table_name, level): - if level > ret[full_table_name]: - ret[full_table_name] = level - for child in self.children[full_table_name] + self.references[full_table_name]: - recurse(child, level+1) - - recurse(full_table_name, 0) - return sorted(ret.keys(), key=ret.__getitem__) - - + self.add_edge(ref, full_table, rel='referenced') \ No newline at end of file From 61ebdd3594877e5f014a85b7a9a9eb41e4683574 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 14:56:41 -0500 Subject: [PATCH 35/41] Add factory method to create ERD from Dependencies --- datajoint/erd.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 458ee994c..4fa890996 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -59,7 +59,7 @@ def get_table_relation_name_map(): return parse_base_relations(rels) -class RelGraph(DiGraph): +class ERD(DiGraph): """ A directed graph representing dependencies between Relations within and across multiple databases. @@ -424,7 +424,6 @@ def repr_path(self): return rep - def compare_path(self, path1, path2): """ Comparator between two paths: path1 and path2 based on a combination of rules. @@ -471,27 +470,19 @@ def repr_path_with_depth(self, path, n=20, m=2): rep += '\n' return rep + @classmethod + def create_from_dependencies(cls, dependencies, *args, **kwargs): + obj = cls(*args, **kwargs) -class ERD(RelGraph): - """ - Entity Relation Diagram - - Represents known relation between tables - """ - # _checked_dependencies = set() - - def __init__(self, parents_dict, referenced_dict, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.clear() - # create primary key foreign connections - for full_table, parents in parents_dict.items(): + for full_table, parents in dependencies.parents.items(): database, table = (x.strip('`') for x in full_table.split('.')) - self.add_node(full_table, database=database, table=table) + obj.add_node(full_table, database=database, table=table) for parent in parents: - self.add_edge(parent, full_table, rel='parent') + obj.add_edge(parent, full_table, rel='parent') # create non primary key foreign connections - for full_table, referenced in referenced_dict.items(): + for full_table, referenced in dependencies.referenced.items(): for ref in referenced: - self.add_edge(ref, full_table, rel='referenced') \ No newline at end of file + obj.add_edge(ref, full_table, rel='referenced') + + return obj From f05eb5f53f0f6a22f2beb4b631e91c02e39a18b7 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 14:57:59 -0500 Subject: [PATCH 36/41] Add erd method to Dependencies for creating ERD from existing dependencies --- datajoint/dependencies.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index 105f970ed..38ae6944e 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -1,5 +1,6 @@ from collections import defaultdict import pyparsing as pp +from .erd import ERD from . import DataJointError @@ -17,10 +18,15 @@ def __init__(self, conn): self._children = defaultdict(list) self._references = defaultdict(list) + def erd(self): + self.load() + return ERD.create_from_dependencies(self) + @property def parents(self): return self._parents + @property def children(self): return self._children From da049a9c5ab34a45e42a353134342eb67a3a0dc9 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sat, 3 Oct 2015 14:58:30 -0500 Subject: [PATCH 37/41] Add erd convenience method to Connection --- datajoint/connection.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datajoint/connection.py b/datajoint/connection.py index e518aacbc..17d9677ed 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -92,6 +92,9 @@ def is_connected(self): """ return self._conn.ping() + def erd(self): + return self.dependencies.erd() + def query(self, query, args=(), as_dict=False): """ Execute the specified query and return the tuple generator (cursor). From fee51e8602868bcda5cb6ed3fd3c2a838b50dcb0 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 6 Oct 2015 00:08:59 -0500 Subject: [PATCH 38/41] Correctly handle ERD repr when ERD is empty --- datajoint/erd.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datajoint/erd.py b/datajoint/erd.py index 4fa890996..712afa1b3 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -406,6 +406,9 @@ def repr_path(self): tables :return: string representation of the erm """ + if len(self) == 0: + return "No relations to show" + paths = self.longest_paths() # turn comparator into Key object for use in sort From fb44713a0d8da96e71517db49b0b0386b2e32eaf Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 6 Oct 2015 00:09:38 -0500 Subject: [PATCH 39/41] Separate dependencies from ERD --- datajoint/connection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datajoint/connection.py b/datajoint/connection.py index 17d9677ed..02a6d9795 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -8,6 +8,7 @@ import logging from . import config from . import DataJointError +from datajoint.erd import ERD from .dependencies import Dependencies from .jobs import JobManager @@ -93,7 +94,8 @@ def is_connected(self): return self._conn.ping() def erd(self): - return self.dependencies.erd() + self.dependencies.load() + return ERD.create_from_dependencies(self.dependencies) def query(self, query, args=(), as_dict=False): """ From ba9f57858215d2c6685d9e0f5a03eaf2b0887724 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 6 Oct 2015 00:19:37 -0500 Subject: [PATCH 40/41] Remove erd method from Dependencies --- datajoint/dependencies.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index 38ae6944e..3ca943a35 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -1,6 +1,5 @@ from collections import defaultdict import pyparsing as pp -from .erd import ERD from . import DataJointError @@ -18,10 +17,6 @@ def __init__(self, conn): self._children = defaultdict(list) self._references = defaultdict(list) - def erd(self): - self.load() - return ERD.create_from_dependencies(self) - @property def parents(self): return self._parents From cc7b6c6eb9f884512e3489864883f7ee234f6cce Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 6 Oct 2015 00:20:23 -0500 Subject: [PATCH 41/41] ERD displays full module+class name in place of table name --- datajoint/erd.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/datajoint/erd.py b/datajoint/erd.py index 712afa1b3..5655f08e3 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -43,13 +43,9 @@ def parse_base_relations(rels): name_map = {} for r in rels: try: - module = r.__module__ - parts = [] - if module != '__main__': - parts.append(module.split('.')[-1]) - parts.append(r.__name__) - name_map[r().full_table_name] = '.'.join(parts) - except: + name_map[r().full_table_name] = '{module}.{cls}'.format(module=r.__module__, cls=r.__name__) + except TypeError: + # skip if failed to instantiate BaseRelation derivative pass return name_map