-
Notifications
You must be signed in to change notification settings - Fork 30
/
slimmer.py
98 lines (82 loc) · 2.89 KB
/
slimmer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import networkx as nx
import logging
# Module-level logger; every log call in this module goes through it.
logger = logging.getLogger(__name__)
def get_minimal_subgraph(g, nodes):
    """
    Given a set of focus nodes, extract a subgraph that excludes non-informative nodes.

    The candidate nodes are the focus ``nodes`` plus all of their ancestors in
    ``g`` (``nx.ancestors``: every node with a path *to* a focus node). Each
    candidate is mapped to the set of focus nodes it subsumes. A candidate and
    a direct parent that subsume exactly the same focus nodes are grouped.
    Within each group, only the focus nodes and the most specific members (no
    child inside the group) are kept; the other members are dropped as
    redundant. Candidates that never join a group are always kept. Roughly,
    what survives are the focus nodes and the ancestors that distinguish
    between them (e.g. MRCAs of pairs of focus nodes).

    Note: no property chain reasoning is performed. As a result, edge labels
    are lost: the returned graph is rebuilt by ``remove_nodes``, which links
    each kept node to its nearest kept ancestors with ``pred='subClassOf'``.

    Args:
        g: a networkx directed graph. Edges are assumed to point from ancestor
            to descendant, as implied by the use of ``nx.ancestors``.
        nodes: iterable of focus nodes; each must be present in ``g``.

    Returns:
        A new ``nx.MultiDiGraph`` (the result of ``remove_nodes``).
    """
    logger.info("Slimming %s to %s", g, nodes)
    # Materialize once: `nodes` is both iterated and used for membership tests
    # below, which would silently fail for a one-shot iterator and cost
    # O(len(nodes)) per test on a list.
    focus = set(nodes)

    # maps ancestor nodes (and each focus node itself) to the members of the
    # focus node set they subsume
    mm = {}
    for n in focus:
        ancs = nx.ancestors(g, n)
        ancs.add(n)
        for a in ancs:
            mm.setdefault(a, set()).add(n)
    # the focus nodes plus every ancestor of any of them
    subnodes = set(mm)

    # Merge graph: link a node with a direct parent when both subsume the same
    # focus nodes. A parent subsumes a superset of what its child subsumes
    # (it is an ancestor of every focus node the child reaches), so equal
    # sizes imply equal sets.
    egraph = nx.MultiDiGraph()
    # TODO: ensure edge labels are preserved
    for a, aset in mm.items():
        for p in g.predecessors(a):
            # Defensive: a predecessor of a subsuming node is itself an
            # ancestor of a focus node, so this should always hold. Checking
            # first keeps the log line below from indexing a missing key.
            if p not in mm:
                continue
            logger.info(" cmp %s -> %s // %s %s", len(aset), len(mm[p]), a, p)
            if len(aset) == len(mm[p]):
                egraph.add_edge(p, a)
                egraph.add_edge(a, p)
                logger.info("will merge %s <-> %s (members identical)", p, a)

    # Within each group of mutually-merged nodes, keep the focus nodes and the
    # most specific members; everything else in the group is redundant.
    disposable = set()
    for cliq in nx.strongly_connected_components(egraph):
        for n in cliq:
            is_src = n in focus
            if is_src:
                logger.info("Preserving: %s in %s", n, cliq)
            # a leaf of the group has no child (successor in g) inside it
            is_leaf = not any(c in cliq for c in g.successors(n))
            if is_leaf:
                logger.info("Clique leaf: %s in %s", n, cliq)
            elif not is_src:
                disposable.add(n)
        leaders = cliq & focus
        if len(leaders) > 1:
            # several focus nodes subsuming identical sets; not expected for
            # acyclic input
            logger.info("UHOH: %s", leaders)

    subg = g.subgraph(subnodes)
    return remove_nodes(subg, disposable)
def remove_nodes(g, rmnodes):
    """
    Return a copy of ``g`` without ``rmnodes``, bridging over the removed nodes.

    Every retained node is copied with its node attributes. Its incoming
    edges are rebuilt: it gets one edge from each nearest retained
    predecessor, i.e. each predecessor not in ``rmnodes`` that is reachable
    by walking backwards through removed nodes only (computed by
    ``_traverse``). Original edge keys and attributes are not copied; every
    new edge carries ``pred='subClassOf'``.

    Args:
        g: a networkx directed graph or graph view.
        rmnodes: iterable of nodes to drop.

    Returns:
        A new ``nx.MultiDiGraph``.
    """
    logger.info("Removing %s from %s", rmnodes, g)
    # Build the lookup set once, not once per retained node; this also makes
    # membership O(1) and safe for one-shot iterables.
    rm = set(rmnodes)
    newg = nx.MultiDiGraph()
    for n, nd in g.nodes(data=True):
        if n in rm:
            continue
        newg.add_node(n, **nd)
        for p in _traverse(g, {n}, rm, set()):
            newg.add_edge(p, n, pred='subClassOf')
    return newg
def _traverse(g, nset, rmnodes, acc):
    """
    Collect the nearest ancestors of ``nset`` that are not in ``rmnodes``.

    Walks predecessor edges backwards from every node in ``nset``. Removed
    nodes (members of ``rmnodes``) are passed through and expanded further;
    retained predecessors are collected and not expanded. The starting nodes
    are expanded but only collected if reached again as a predecessor.

    Implemented iteratively with a visited set, so long chains of removed
    nodes cannot exhaust the recursion limit, cycles through removed nodes
    terminate, and each removed node is expanded at most once. Neither
    ``nset`` nor ``acc`` is mutated.

    Args:
        g: any object with a networkx-style ``predecessors(node)`` method.
        nset: starting nodes.
        rmnodes: container of removed nodes (supports ``in``).
        acc: nodes already collected; included in the result.

    Returns:
        A new set: ``acc`` plus every retained predecessor found.
    """
    found = set(acc)
    pending = list(nset)
    seen = set(pending)
    while pending:
        n = pending.pop()
        for p in g.predecessors(n):
            if p not in rmnodes:
                found.add(p)
            elif p not in seen:
                # removed node: look through it to its own predecessors
                seen.add(p)
                pending.append(p)
    return found