Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def load_dependencies(self, dbname): # TODO: Perhaps consider making this "priv
"""

foreign_key_regexp = re.compile(r"""
FOREIGN KEY\s+\((?P<attr1>[`\w ,]+)\)\s+ # list of keys in this table
FOREIGN\ KEY\s+\((?P<attr1>[`\w ,]+)\)\s+ # list of keys in this table
REFERENCES\s+(?P<ref>[^\s]+)\s+ # table referenced
\((?P<attr2>[`\w ,]+)\) # list of keys in the referenced table
""", re.X)
Expand Down Expand Up @@ -298,37 +298,44 @@ def clear_dependencies(self, dbname=None):
if key in self.referenced:
self.referenced.pop(key)

def parents_of(self, child_table): #TODO: this function is not clear to me after reading the docu
def parents_of(self, child_table):
"""
Returns a list of tables that are parents for the childTable based on
primary foreign keys.
Returns a list of tables that are parents of the specified child_table. Parent-child relationship is defined
based on the presence of primary-key foreign reference: table that holds a foreign key relation to another table
is the child table.

:param child_table: the child table
:return: list of parent tables
"""
return self.parents.get(child_table, []).copy()

def children_of(self, parent_table):#TODO: this function is not clear to me after reading the docu
def children_of(self, parent_table):
"""
Returns a list of tables for which parent_table is a parent (primary foreign key)
Returns a list of tables for which parent_table is a parent (primary foreign key). Parent-child relationship
is defined based on the presence of primary-key foreign reference: table that holds a foreign key relation to
another table is the child table.

:param parent_table: parent table
:return: list of child tables
"""
return [child_table for child_table, parents in self.parents.items() if parent_table in parents]

def referenced_by(self, referencing_table):
"""
Returns a list of tables that are referenced by non-primary foreign key
Returns a list of tables that are referenced by non-primary foreign key relation
by the referencing_table.

:param referencing_table: referencing table
:return: list of tables that are referenced by the target table
"""
return self.referenced.get(referencing_table, []).copy()

def referencing(self, referenced_table):
"""
Returns a list of tables that references referencedTable as non-primary foreign key
Returns a list of tables that references referenced_table as non-primary foreign key

:param referenced_table: referenced table
:return: list of tables that refers to the target table
"""
return [referencing for referencing, referenced in self.referenced.items()
if referenced_table in referenced]
Expand All @@ -346,18 +353,15 @@ def __del__(self):
logger.info('Disconnecting {user}@{host}:{port}'.format(**self.conn_info))
self._conn.close()

def erd(self, databases=None, tables=None, fill=True, reload=True):
def erd(self, databases=None, tables=None, fill=True, reload=False):
"""
Creates Entity Relation Diagram for the database or specified subset of
tables.

Set `fill` to False to only display specified tables. (By default
connection tables are automatically included)
"""
if reload:
self.load_headings() # load all tables and relations for bound databases

self._graph.update_graph() # update the graph
self._graph.update_graph(reload=reload) # update the graph

graph = self._graph.copy_graph()
if databases:
Expand All @@ -366,7 +370,7 @@ def erd(self, databases=None, tables=None, fill=True, reload=True):
if tables:
graph = graph.restrict_by_tables(tables, fill)

graph.plot()
return graph

def query(self, query, args=(), as_dict=False):
"""
Expand Down
150 changes: 121 additions & 29 deletions datajoint/erd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,60 @@

class RelGraph(DiGraph):
"""
Represents relations between tables and databases
A directed graph representing relations between tables within and across
multiple databases found.
"""

@property
def node_labels(self):
"""
:return: dictionary of key : label pairs for plotting
"""
return {k: attr['label'] for k, attr in self.node.items()}

@property
def pk_edges(self):
return [edge for edge in self.edges() \
"""
:return: list of edges representing primary key foreign relations
"""
return [edge for edge in self.edges()
if self[edge[0]][edge[1]].get('rel')=='parent']

@property
def nonpk_edges(self):
return [edge for edge in self.edges() \
def non_pk_edges(self):
"""
:return: list of edges representing non primary key foreign relations
"""
return [edge for edge in self.edges()
if self[edge[0]][edge[1]].get('rel')=='referenced']

def highlight(nodes):
def highlight(self, nodes):
"""
Highlights specified nodes when plotting
:param nodes: list of nodes to be highlighted
"""
for node in nodes:
self.node[node]['highlight'] = True

def remove_highlight(nodes=None):
def remove_highlight(self, nodes=None):
"""
Remove highlights from specified nodes when plotting. If specified node is not
highlighted to begin with, nothing happens.
:param nodes: list of nodes to remove highlights from
"""
if not nodes:
nodes = self.nodes_iter()
for node in nodes:
self.node[node]['highlight'] = False

# TODO: make this take in various config parameters for plotting
def plot(self):
"""
Plots an entity relation diagram (ERD) among all nodes that is part
of the current graph.
"""
if not self.nodes(): # There is nothing to plot
logger.warning('No table to plot in ERD')
logger.warning('Nothing to plot')
return
pos = pygraphviz_layout(self, prog='dot')
fig = plt.figure(figsize=[10,7])
Expand All @@ -61,20 +85,36 @@ def plot(self):
# draw primary key relations
nx.draw_networkx_edges(self, pos, self.pk_edges, arrows=False)
# draw non-primary key relations
nx.draw_networkx_edges(self, pos, self.nonpk_edges, style='dashed', arrows=False)
nx.draw_networkx_edges(self, pos, self.non_pk_edges, style='dashed', arrows=False)
apos = np.array(list(pos.values()))
xmax = apos[:,0].max() + 200 #TODO: use something more sensible then hard fixed number
xmin = apos[:,0].min() - 100
xmax = apos[:, 0].max() + 200 #TODO: use something more sensible then hard fixed number
xmin = apos[:, 0].min() - 100
ax.set_xlim(xmin, xmax)
ax.axis('off') # hide axis
ax.axis('off') # hide axis

def __repr__(self):
pass

def restrict_by_modules(self, modules, fill=False):
"""
Creates a subgraph containing only tables in the specified modules.
:param modules: list of module names
:param fill: set True to automatically include nodes connecting two nodes in the specified modules
:return: a subgraph with specified nodes
"""
nodes = [n for n in self.nodes() if self.node[n].get('mod') in modules]
if fill:
nodes = self.fill_connection_nodes(nodes)
nodes = self.fill_connection_nodes(nodes)
return self.subgraph(nodes)

def restrict_by_tables(self, tables, fill=False):
"""
Creates a subgraph containing only specified tables.
:param tables: list of tables to keep in the subgraph
:param fill: set True to automatically include nodes connecting two nodes in the specified list
of tables
:return: a subgraph with specified nodes
"""
nodes = [n for n in self.nodes() if self.node[n].get('label') in tables]
if fill:
nodes = self.fill_connection_nodes(nodes)
Expand All @@ -91,14 +131,17 @@ def fill_connection_nodes(self, nodes):
"""
For given set of nodes, find and add nodes that serves as
connection points for two nodes in the set.
:param nodes: list of nodes for which connection nodes are to be filled in
"""
H = self.subgraph(self.ancestors_of_all(nodes))
return H.descendants_of_all(nodes)
graph = self.subgraph(self.ancestors_of_all(nodes))
return graph.descendants_of_all(nodes)

def ancestors_of_all(self, nodes):
"""
Find and return a set including all ancestors of the given
nodes. The set will also contain the given nodes as well.
Find and return a set of all ancestors of the given
nodes. The set will also contain the specified nodes.
:param nodes: list of nodes for which ancestors are to be found
:return: a set containing passed in nodes and all of their ancestors
"""
s = set()
for n in nodes:
Expand All @@ -107,8 +150,10 @@ def ancestors_of_all(self, nodes):

def descendants_of_all(self, nodes):
"""
Find and return a set including all descendents of the given
Find and return a set including all descendants of the given
nodes. The set will also contain the given nodes as well.
:param nodes: list of nodes for which descendants are to be found
:return: a set containing passed in nodes and all of their descendants
"""
s = set()
for n in nodes:
Expand All @@ -120,6 +165,8 @@ def ancestors(self, node):
Find and return a set containing all ancestors of the specified
node. For convenience in plotting, this set will also include
the specified node as well (may change in future).
:param node: node for which all ancestors are to be discovered
:return: a set containing the node and all of its ancestors
"""
s = {node}
for p in self.predecessors_iter(node):
Expand All @@ -128,28 +175,50 @@ def ancestors(self, node):

def descendants(self, node):
"""
Find and return a set containing all descendents of the specified
Find and return a set containing all descendants of the specified
node. For convenience in plotting, this set will also include
the specified node as well (may change in future).
:param node: node for which all descendants are to be discovered
:return: a set containing the node and all of its descendants
"""
s = {node}
for c in self.successors_iter(node):
s.update(self.descendants(c))
return s

def up_down_neighbors(self, node, ups, downs, prev=None):
"""
Returns a set of all nodes that can be reached from the specified node by
moving up and down the ancestry tree with specific number of ups and downs.

Example:
up_down_neighbors(node, ups=2, downs=1) will return all nodes that can be reached by
any combinations of two up tracing and 1 down tracing of the ancestry tree. This includes
all children of a grand-parent (two ups and one down), all grand parents of all children (one down
and then two ups), and all siblings parents (one up, one down, and one up).

def updown_neighbors(self, node, ups, downs, prev=None):
It must be noted that except for some special cases, there is no generalized interpretations for
the relationship among nodes captured by this method. However, it does tend to produce a fairy
good concise view of the relationships surrounding the specified node.


:param node: node to base all discovery on
:param ups: number of times to go up the ancestry tree (go up to parent)
:param downs: number of times to go down the ancestry tree (go down to children)
:param prev: previously visited node. This will be excluded from up down search in this recursion
:return: a set of all nodes that can be reached within specified numbers of ups and downs from the source node
"""
s = {node}
if ups > 0:
for x in self.predecessors_iter(node):
if x == prev:
continue
s.update(self.updown_neighbors(x, ups-1, downs, node))
s.update(self.up_down_neighbors(x, ups-1, downs, node))
if downs > 0:
for x in self.successors_iter(node):
if x == prev:
continue
s.update(self.updown_neighbors(x, ups, downs-1, node))
s.update(self.up_down_neighbors(x, ups, downs-1, node))
return s

def n_neighbors(self, node, n, prev=None):
Expand Down Expand Up @@ -180,14 +249,18 @@ def n_neighbors(self, node, n, prev=None):
s.update(self.n_neighbors(x, n-1, prev))
return s

full_table_ptrn = re.compile(r'`(.*)`\.`(.*)`')


class DBConnGraph(RelGraph):
"""
Represents relational structure of the
connected databases
Represents relational structure of the databases and tables associated with a connection object
"""
def __init__(self, conn, *args, **kwargs):
"""
Initializes graph associated with a connection object
:param conn: connection object for which the relational graph is to be constructed
"""
# this is calling the networkx.DiGraph initializer
super().__init__(*args, **kwargs)
if conn.is_connected:
self._conn = conn
Expand All @@ -196,24 +269,43 @@ def __init__(self, conn, *args, **kwargs):
self.update_graph()

def full_table_to_class(self, full_table_name):
"""
Converts full table reference of form `database`.`table` into the corresponding
module_name.class_name format. For the module name, only the actual name of the module is used, with
all of its package reference removed.
:param full_table_name: full name of the table in the form `database`.`table`
:return: name in the form of module_name.class_name if corresponding module and class exists
"""
full_table_ptrn = re.compile(r'`(.*)`\.`(.*)`')
m = full_table_ptrn.match(full_table_name)
dbname = m.group(1)
table_name = m.group(2)
mod_name = self._conn.modules[dbname]
mod_name = self._conn.db_to_mod[dbname]
mod_name = mod_name.split('.')[-1]
class_name = to_camel_case(table_name)
return '{}.{}'.format(mod_name, class_name)

def update_graph(self):
def update_graph(self, reload=False):
"""
Update the connection graph. Set reload=True to cause the connection object's
table heading information to be reloaded as well
"""
if reload:
self._conn.load_headings(force=True)

self.clear()

# create primary key foreign connections
for table, parents in self._conn.parents.items():
label = self.full_table_to_class(table)
mod, cls = label.split('.')

self.add_node(table, label=label, \
self.add_node(table, label=label,
mod=mod, cls=cls)
for parent in parents:
self.add_edge(parent, table, rel='parent')

# create non primary key foreign connections
for table, referenced in self._conn.referenced.items():
for ref in referenced:
self.add_edge(ref, table, rel='referenced')
Expand All @@ -222,10 +314,10 @@ def copy_graph(self):
"""
Return copy of the graph represented by this object at the
time of call. Note that the returned graph is no longer
bound to connection.
bound to a connection.
"""
return RelGraph(self)

def subgraph(self, *args, **kwargs):
return RelGraph(self).subgraph(*args, **kwargs)
return RelGraph(self).subgraph(*args, **kwargs)

Loading