Skip to content
This repository has been archived by the owner on Jan 4, 2022. It is now read-only.

Commit

Permalink
Use directed graph (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
edxu96 committed Jan 12, 2021
1 parent 3a80f08 commit 825d0a9
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
11 changes: 3 additions & 8 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,15 @@
IS_FIRST = lambda x: x == "high"
EDGES_NEW = {
("a", "g_hv"),
("g_hv", "a"),
("c", "g_hv"),
("g_hv", "c"),
("d", "g_lv"),
("g_lv", "d"),
("f", "g_lv"),
("g_lv", "f"),
("g_lv", "g_hv"),
("g_hv", "g_lv"),
}


@pt.mark.usefixtures("case_readme")
def test_split(case_readme: nx.Graph):
def test_split(case_readme: nx.DiGraph):
"""Check basic features of ``split`` method in ``GeoGraph``.
Note:
Expand All @@ -34,9 +29,9 @@ def test_split(case_readme: nx.Graph):
res.split("g", "g_hv", "g_lv", ATTR, IS_FIRST)

assert (
set(res.edges).difference(EDGES_NEW) == set()
set(res.edges) == EDGES_NEW
), "Terminals of associated edges should be renamed correctly."
assert nx.is_connected(res)
assert res.is_connected_graph

_new = res.new_
assert isinstance(_new, DataFrame)
Expand Down
50 changes: 45 additions & 5 deletions vsec/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
COLUMNS_POS = ["node", "x", "y"]


class Graph(nx.Graph):
class Graph(nx.DiGraph):
"""Graph with vertices to be split and edges to be contracted.
Note:
Expand All @@ -20,7 +20,7 @@ class Graph(nx.Graph):
- Edges don't have name in ``networkx``.
"""

def __init__(self, g: Optional[nx.Graph] = None):
def __init__(self, g: Optional[nx.DiGraph] = None):
"""Init an empty directed graph or existing directed graph.
Args:
Expand All @@ -34,6 +34,15 @@ def __init__(self, g: Optional[nx.Graph] = None):
self._new_dict = {}
self._renamed_dict = {}

idx = pd.MultiIndex.from_tuples(self.edges, names=COLUMNS)
self.edgelist = pd.DataFrame(
data={
"first": idx.get_level_values(0),
"second": idx.get_level_values(1),
},
index=idx,
)

def split(
self,
vertex: str,
Expand All @@ -44,6 +53,10 @@ def split(
):
"""Split a vertex and handle new vertices and associated edges.
Warning:
Only one edge attribute can be used to distinguish
associated edges to two clusters.
Args:
vertex: which ought to be modelled as an edge.
vertex_first: the first resulted vertex.
Expand All @@ -52,7 +65,10 @@ def split(
is_first: how to choose between resulted vertices. When None
is returned, an error will be logged.
"""
edges_asso = list(self.edges(nbunch=vertex, data=True))
# Find all the associated edges.
edges_asso = list(self.in_edges(nbunch=vertex, data=True)) + list(
self.out_edges(nbunch=vertex, data=True)
)

# Rename terminals of associated edges.
for u, v, attributes in edges_asso:
Expand All @@ -67,20 +83,33 @@ def split(
self.remove_edge(u, v)
if u == vertex:
self.add_edge(vertex_new, v, **attributes)

new_column = pd.Series(
[vertex_new], name="first", index=[(u, v)]
)
self._renamed_dict[(u, v)] = (vertex_new, v)
else:
self.add_edge(u, vertex_new, **attributes)
new_column = pd.Series(
[vertex_new], name="second", index=[(u, v)]
)
self._renamed_dict[(u, v)] = (u, vertex_new)

self.edgelist.update(new_column)
else:
logger.critical(
f"Unable to determine new terminal of edge ({u}, {v}) "
f"with attributes {attributes}."
)
break

# Validate if all associated edges have been renamed.
edges_asso = list(self.edges(nbunch=vertex, data=True))
assert len(edges_asso) == 0

# Add the resulted new edge and remove the original vertex.
self.add_edge(vertex_first, vertex_second, split_=True)
self.remove_node(vertex)
self.remove_node(vertex) # Removes the node and all adjacent edges.
self._new_dict[vertex] = (vertex_first, vertex_second)

@property
Expand All @@ -99,7 +128,7 @@ def new_(self) -> DataFrame:

@property
def renamed_(self) -> DataFrame:
"""Gather edges renamed because of split or contraction.
"""Gather renamed edges compared to initial graph.
Returns:
Edges renamed because of split or contraction.
Expand Down Expand Up @@ -174,3 +203,14 @@ def complete_edge_attr(self, attr: str) -> bool:
)
res = False
return res

@property
def is_connected_graph(self) -> bool:
"""Check if this undirected graph is connected.
Returns:
True if this undirected graph is connected.
"""
g = nx.Graph(self)
res = nx.is_connected(g)
return res

0 comments on commit 825d0a9

Please sign in to comment.