From 759cba9312eb5fb0c9865d9eba04aaee8b48f9ab Mon Sep 17 00:00:00 2001 From: eywalker Date: Mon, 18 May 2015 15:56:25 -0500 Subject: [PATCH 1/2] Fix bug in load_dependency and change pop_rel to populate_relations --- datajoint/autopopulate.py | 6 +- datajoint/connection.py | 2 +- datajoint/erd.py | 147 ++++++++++++++++++++++++++++++-------- 3 files changed, 122 insertions(+), 33 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 49f21cc50..20276aa7e 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -16,7 +16,7 @@ class AutoPopulate(metaclass=abc.ABCMeta): """ @abc.abstractproperty - def pop_rel(self): + def populate_relation(self): """ Derived classes must implement the read-only property pop_rel (populate relation) which is the relational expression (a Relation object) that defines how keys are generated for the populate call. @@ -42,10 +42,10 @@ def populate(self, restriction=None, suppress_errors=False, reserve_jobs=False): """ assert not reserve_jobs, NotImplemented # issue #5 error_list = [] if suppress_errors else None - if not isinstance(self.pop_rel, RelationalOperand): + if not isinstance(self.populate_relation, RelationalOperand): raise DataJointError('Invalid pop_rel value') self.conn._cancel_transaction() # rollback previous transaction, if any - unpopulated = (self.pop_rel - self.target) & restriction + unpopulated = (self.populate_relation - self.target) & restriction for key in unpopulated.project(): self.conn._start_transaction() if key in self.target: # already populated diff --git a/datajoint/connection.py b/datajoint/connection.py index 64607de5b..64c062f1d 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -248,7 +248,7 @@ def load_dependencies(self, dbname): # TODO: Perhaps consider making this "priv """ foreign_key_regexp = re.compile(r""" - FOREIGN KEY\s+\((?P[`\w ,]+)\)\s+ # list of keys in this table + FOREIGN\ KEY\s+\((?P[`\w ,]+)\)\s+ # list of keys in this 
table REFERENCES\s+(?P[^\s]+)\s+ # table referenced \((?P[`\w ,]+)\) # list of keys in the referenced table """, re.X) diff --git a/datajoint/erd.py b/datajoint/erd.py index af6107390..2fb82f9e2 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -17,36 +17,60 @@ class RelGraph(DiGraph): """ - Represents relations between tables and databases + A directed graph representing relations between tables within and across + multiple databases found. """ @property def node_labels(self): + """ + :return: dictionary of key : label pairs for plotting + """ return {k: attr['label'] for k, attr in self.node.items()} @property def pk_edges(self): - return [edge for edge in self.edges() \ + """ + :return: list of edges representing primary key foreign relations + """ + return [edge for edge in self.edges() if self[edge[0]][edge[1]].get('rel')=='parent'] @property - def nonpk_edges(self): - return [edge for edge in self.edges() \ + def non_pk_edges(self): + """ + :return: list of edges representing non primary key foreign relations + """ + return [edge for edge in self.edges() if self[edge[0]][edge[1]].get('rel')=='referenced'] - def highlight(nodes): + def highlight(self, nodes): + """ + Highlights specified nodes when plotting + :param nodes: list of nodes to be highlighted + """ for node in nodes: self.node[node]['highlight'] = True - def remove_highlight(nodes=None): + def remove_highlight(self, nodes=None): + """ + Remove highlights from specified nodes when plotting. If specified node is not + highlighted to begin with, nothing happens. + :param nodes: list of nodes to remove highlights from + """ if not nodes: nodes = self.nodes_iter() for node in nodes: self.node[node]['highlight'] = False + # TODO: make this taken in various config parameters for plotting def plot(self): + """ + Plots an entity relation diagram (ERD) among all nodes that is part + of the current graph. 
+ """ if not self.nodes(): # There is nothing to plot - logger.warning('No table to plot in ERD') + logger.warning('Nothing to plot') return pos = pygraphviz_layout(self, prog='dot') fig = plt.figure(figsize=[10,7]) @@ -61,20 +85,33 @@ def plot(self): # draw primary key relations nx.draw_networkx_edges(self, pos, self.pk_edges, arrows=False) # draw non-primary key relations - nx.draw_networkx_edges(self, pos, self.nonpk_edges, style='dashed', arrows=False) + nx.draw_networkx_edges(self, pos, self.non_pk_edges, style='dashed', arrows=False) apos = np.array(list(pos.values())) - xmax = apos[:,0].max() + 200 #TODO: use something more sensible then hard fixed number - xmin = apos[:,0].min() - 100 + xmax = apos[:, 0].max() + 200 #TODO: use something more sensible then hard fixed number + xmin = apos[:, 0].min() - 100 ax.set_xlim(xmin, xmax) - ax.axis('off') # hide axis + ax.axis('off') # hide axis def restrict_by_modules(self, modules, fill=False): + """ + Creates a subgraph containing only tables in the specified modules. + :param modules: list of module names + :param fill: set True to automatically include nodes connecting two nodes in the specified modules + :return: a subgraph with specified nodes + """ nodes = [n for n in self.nodes() if self.node[n].get('mod') in modules] if fill: - nodes = self.fill_connection_nodes(nodes) + nodes = self.fill_connection_nodes(nodes) return self.subgraph(nodes) def restrict_by_tables(self, tables, fill=False): + """ + Creates a subgraph containing only specified tables. 
+ :param tables: list of tables to keep in the subgraph + :param fill: set True to automatically include nodes connecting two nodes in the specified list + of tables + :return: a subgraph with specified nodes + """ nodes = [n for n in self.nodes() if self.node[n].get('label') in tables] if fill: nodes = self.fill_connection_nodes(nodes) @@ -91,14 +128,17 @@ def fill_connection_nodes(self, nodes): """ For given set of nodes, find and add nodes that serves as connection points for two nodes in the set. + :param nodes: list of nodes for which connection nodes are to be filled in """ - H = self.subgraph(self.ancestors_of_all(nodes)) - return H.descendants_of_all(nodes) + graph = self.subgraph(self.ancestors_of_all(nodes)) + return graph.descendants_of_all(nodes) def ancestors_of_all(self, nodes): """ - Find and return a set including all ancestors of the given - nodes. The set will also contain the given nodes as well. + Find and return a set of all ancestors of the given + nodes. The set will also contain the specified nodes. + :param nodes: list of nodes for which ancestors are to be found + :return: a set containing passed in nodes and all of their ancestors """ s = set() for n in nodes: @@ -107,8 +147,10 @@ def ancestors_of_all(self, nodes): def descendants_of_all(self, nodes): """ - Find and return a set including all descendents of the given + Find and return a set including all descendants of the given nodes. The set will also contain the given nodes as well. + :param nodes: list of nodes for which descendants are to be found + :return: a set containing passed in nodes and all of their descendants """ s = set() for n in nodes: @@ -120,6 +162,8 @@ def ancestors(self, node): Find and return a set containing all ancestors of the specified node. For convenience in plotting, this set will also include the specified node as well (may change in future). 
+ :param node: node for which all ancestors are to be discovered + :return: a set containing the node and all of its ancestors """ s = {node} for p in self.predecessors_iter(node): @@ -128,28 +172,50 @@ def ancestors(self, node): def descendants(self, node): """ - Find and return a set containing all descendents of the specified + Find and return a set containing all descendants of the specified node. For convenience in plotting, this set will also include the specified node as well (may change in future). + :param node: node for which all descendants are to be discovered + :return: a set containing the node and all of its descendants """ s = {node} for c in self.successors_iter(node): s.update(self.descendants(c)) return s + def up_down_neighbors(self, node, ups, downs, prev=None): + """ + Returns a set of all nodes that can be reached from the specified node by + moving up and down the ancestry tree with a specific number of ups and downs. + + Example: + up_down_neighbors(node, ups=2, downs=1) will return all nodes that can be reached by + any combination of two up tracings and one down tracing of the ancestry tree. This includes + all children of a grand-parent (two ups and one down), all grand parents of all children (one down + and then two ups), and all siblings' parents (one up, one down, and one up). - def updown_neighbors(self, node, ups, downs, prev=None): + It must be noted that except for some special cases, there is no generalized interpretation for + the relationship among nodes captured by this method. However, it does tend to produce a fairly + good concise view of the relationships surrounding the specified node. + + + :param node: node to base all discovery on + :param ups: number of times to go up the ancestry tree (go up to parent) + :param downs: number of times to go down the ancestry tree (go down to children) + :param prev: previously visited node. 
This will be excluded from up down search in this recursion + :return: a set of all nodes that can be reached within specified numbers of ups and downs from the source node + """ s = {node} if ups > 0: for x in self.predecessors_iter(node): if x == prev: continue - s.update(self.updown_neighbors(x, ups-1, downs, node)) + s.update(self.up_down_neighbors(x, ups-1, downs, node)) if downs > 0: for x in self.successors_iter(node): if x == prev: continue - s.update(self.updown_neighbors(x, ups, downs-1, node)) + s.update(self.up_down_neighbors(x, ups, downs-1, node)) return s def n_neighbors(self, node, n, prev=None): @@ -180,14 +246,18 @@ def n_neighbors(self, node, n, prev=None): s.update(self.n_neighbors(x, n-1, prev)) return s -full_table_ptrn = re.compile(r'`(.*)`\.`(.*)`') + class DBConnGraph(RelGraph): """ - Represents relational structure of the - connected databases + Represents relational structure of the databases and tables associated with a connection object """ def __init__(self, conn, *args, **kwargs): + """ + Initializes graph associated with a connection object + :param conn: connection object for which the relational graph is to be constructed + """ + # this is calling the networkx.DiGraph initializer super().__init__(*args, **kwargs) if conn.is_connected: self._conn = conn @@ -196,24 +266,43 @@ def __init__(self, conn, *args, **kwargs): self.update_graph() def full_table_to_class(self, full_table_name): + """ + Converts full table reference of form `database`.`table` into the corresponding + module_name.class_name format. For the module name, only the actual name of the module is used, with + all of its package reference removed. 
+ :param full_table_name: full name of the table in the form `database`.`table` + :return: name in the form of module_name.class_name if corresponding module and class exists + """ + full_table_ptrn = re.compile(r'`(.*)`\.`(.*)`') m = full_table_ptrn.match(full_table_name) dbname = m.group(1) table_name = m.group(2) - mod_name = self._conn.modules[dbname] + mod_name = self._conn.db_to_mod[dbname] + mod_name = mod_name.split('.')[-1] class_name = to_camel_case(table_name) return '{}.{}'.format(mod_name, class_name) - def update_graph(self): + def update_graph(self, reload=False): + """ + Update the connection graph. Set reload=True to cause the connection object's + table heading information to be reloaded as well + """ + if reload: + self._conn.load_headings(force=True) + self.clear() + + # create primary key foreign connections for table, parents in self._conn.parents.items(): label = self.full_table_to_class(table) mod, cls = label.split('.') - self.add_node(table, label=label, \ + self.add_node(table, label=label, mod=mod, cls=cls) for parent in parents: self.add_edge(parent, table, rel='parent') + # create non primary key foreign connections for table, referenced in self._conn.referenced.items(): for ref in referenced: self.add_edge(ref, table, rel='referenced') @@ -222,10 +311,10 @@ def copy_graph(self): """ Return copy of the graph represented by this object at the time of call. Note that the returned graph is no longer - bound to connection. + bound to a connection. 
""" return RelGraph(self) def subgraph(self, *args, **kwargs): - return RelGraph(self).subgraph(*args, **kwargs) + return RelGraph(self).subgraph(*args, **kwargs) From 290a2d9ccc4f52a799e854b3857306390246129b Mon Sep 17 00:00:00 2001 From: eywalker Date: Tue, 19 May 2015 12:41:42 -0500 Subject: [PATCH 2/2] Fix non-pk foreign reference and change return of erd --- datajoint/connection.py | 30 +++++++++++++++++------------- datajoint/erd.py | 5 ++++- datajoint/free_relation.py | 38 ++++---------------------------------- demos/demo1.py | 15 +++++++++++++++ tests/test_connection.py | 3 ++- 5 files changed, 42 insertions(+), 49 deletions(-) diff --git a/datajoint/connection.py b/datajoint/connection.py index 64c062f1d..51ed2d1bd 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -298,37 +298,44 @@ def clear_dependencies(self, dbname=None): if key in self.referenced: self.referenced.pop(key) - def parents_of(self, child_table): #TODO: this function is not clear to me after reading the docu + def parents_of(self, child_table): """ - Returns a list of tables that are parents for the childTable based on - primary foreign keys. + Returns a list of tables that are parents of the specified child_table. Parent-child relationship is defined + based on the presence of primary-key foreign reference: table that holds a foreign key relation to another table + is the child table. :param child_table: the child table + :return: list of parent tables """ return self.parents.get(child_table, []).copy() - def children_of(self, parent_table):#TODO: this function is not clear to me after reading the docu + def children_of(self, parent_table): """ - Returns a list of tables for which parent_table is a parent (primary foreign key) + Returns a list of tables for which parent_table is a parent (primary foreign key). 
Parent-child relationship + is defined based on the presence of primary-key foreign reference: table that holds a foreign key relation to + another table is the child table. :param parent_table: parent table + :return: list of child tables """ return [child_table for child_table, parents in self.parents.items() if parent_table in parents] def referenced_by(self, referencing_table): """ - Returns a list of tables that are referenced by non-primary foreign key + Returns a list of tables that are referenced by non-primary foreign key relation by the referencing_table. :param referencing_table: referencing table + :return: list of tables that are referenced by the target table """ return self.referenced.get(referencing_table, []).copy() def referencing(self, referenced_table): """ - Returns a list of tables that references referencedTable as non-primary foreign key + Returns a list of tables that reference referenced_table as non-primary foreign key :param referenced_table: referenced table + :return: list of tables that refer to the target table """ return [referencing for referencing, referenced in self.referenced.items() if referenced_table in referenced] @@ -346,7 +353,7 @@ def __del__(self): logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info)) self._conn.close() - def erd(self, databases=None, tables=None, fill=True, reload=True): + def erd(self, databases=None, tables=None, fill=True, reload=False): """ Creates Entity Relation Diagram for the database or specified subset of tables. @@ -354,10 +361,7 @@ def erd(self, databases=None, tables=None, fill=True, reload=True): Set `fill` to False to only display specified tables. 
(By default connection tables are automatically included) """ - if reload: - self.load_headings() # load all tables and relations for bound databases - - self._graph.update_graph() # update the graph + self._graph.update_graph(reload=reload) # update the graph graph = self._graph.copy_graph() if databases: @@ -366,7 +370,7 @@ def erd(self, databases=None, tables=None, fill=True, reload=True): if tables: graph = graph.restrict_by_tables(tables, fill) - graph.plot() + return graph def query(self, query, args=(), as_dict=False): """ diff --git a/datajoint/erd.py b/datajoint/erd.py index 2fb82f9e2..b36376c7c 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -63,7 +63,7 @@ def remove_highlight(self, nodes=None): for node in nodes: self.node[node]['highlight'] = False - # TODO: make this taken in various config parameters for plotting + # TODO: make this take in various config parameters for plotting def plot(self): """ Plots an entity relation diagram (ERD) among all nodes that is part @@ -92,6 +92,9 @@ def plot(self): ax.set_xlim(xmin, xmax) ax.axis('off') # hide axis + def __repr__(self): + pass + def restrict_by_modules(self, modules, fill=False): """ Creates a subgraph containing only tables in the specified modules. diff --git a/datajoint/free_relation.py b/datajoint/free_relation.py index fd87bee6a..241ada906 100644 --- a/datajoint/free_relation.py +++ b/datajoint/free_relation.py @@ -249,42 +249,12 @@ def alter_attribute(self, attr_name, new_definition): sql = self.field_to_sql(parse_attribute_definition(new_definition)) self._alter('CHANGE COLUMN `%s` %s' % (attr_name, sql[:-2])) - def erd(self, subset=None, prog='dot'): + def erd(self, subset=None): """ Plot the schema's entity relationship diagram (ERD). 
- The layout programs can be 'dot' (default), 'neato', 'fdp', 'sfdp', 'circo', 'twopi' - """ - if not subset: - g = self.graph - else: - g = self.graph.copy() - # todo: make erd work (github issue #7) - """ - g = self.graph - else: - g = self.graph.copy() - for i in g.nodes(): - if i not in subset: - g.remove_node(i) - def tablelist(tier): - return [i for i in g if self.tables[i].tier==tier] - - pos=nx.graphviz_layout(g,prog=prog,args='') - plt.figure(figsize=(8,8)) - nx.draw_networkx_edges(g, pos, alpha=0.3) - nx.draw_networkx_nodes(g, pos, nodelist=tablelist('manual'), - node_color='g', node_size=200, alpha=0.3) - nx.draw_networkx_nodes(g, pos, nodelist=tablelist('computed'), - node_color='r', node_size=200, alpha=0.3) - nx.draw_networkx_nodes(g, pos, nodelist=tablelist('imported'), - node_color='b', node_size=200, alpha=0.3) - nx.draw_networkx_nodes(g, pos, nodelist=tablelist('lookup'), - node_color='gray', node_size=120, alpha=0.3) - nx.draw_networkx_labels(g, pos, nodelist = subset, font_weight='bold', font_size=9) - nx.draw(g,pos,alpha=0,with_labels=false) - plt.show() """ + def _alter(self, alter_statement): """ Execute ALTER TABLE statement for this table. 
The schema @@ -377,8 +347,8 @@ def _declare(self): # add secondary foreign key attributes for r in referenced: - keys = (x for x in r.heading.attrs.values() if x.in_key) - for field in keys: + for key in r.primary_key: + field = r.heading[key] if field.name not in primary_key_fields | non_key_fields: non_key_fields.add(field.name) sql += self._field_to_sql(field) diff --git a/demos/demo1.py b/demos/demo1.py index d1b793d3f..e85d6ead3 100644 --- a/demos/demo1.py +++ b/demos/demo1.py @@ -55,9 +55,24 @@ class Scan(dj.Relation): definition = """ demo1.Scan (manual) # a two-photon imaging session -> demo1.Session + -> Config scan_id : tinyint # two-photon session within this experiment ---- depth : float # depth from surface wavelength : smallint # (nm) laser wavelength mwatts: numeric(4,1) # (mW) laser power to brain + """ + +class Config(dj.Relation): + definition = """ + demo1.Config (manual) # configuration for scanner + config_id : tinyint # unique id for config setup + --- + ->ConfigParam + """ + +class ConfigParam(dj.Relation): + definition = """ + demo1.ConfigParam (lookup) # params for configurations + param_set_id : tinyint # id for params """ \ No newline at end of file diff --git a/tests/test_connection.py b/tests/test_connection.py index 429e06304..266b9f3fb 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -121,7 +121,8 @@ def test_bind_to_non_existing_database(self): cur = BASE_CONN.cursor() # Ensure target database doesn't exist - cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) + if cur.execute("SHOW DATABASES LIKE '{}'".format(db_name)): + cur.execute("DROP DATABASE IF EXISTS `{}`".format(db_name)) # Bind module to non-existing database self.conn.bind(module, db_name) # Check that target database was created