diff --git a/gerrychain/graph/graph.py b/gerrychain/graph/graph.py index 3d06dda7..5bb43650 100644 --- a/gerrychain/graph/graph.py +++ b/gerrychain/graph/graph.py @@ -4,7 +4,6 @@ import warnings import networkx -import retworkx from networkx.classes.function import frozen from networkx.readwrite import json_graph import pandas as pd @@ -20,6 +19,13 @@ def json_serialize(input_object): if pd.api.types.is_integer_dtype(input_object): # handle int64 return int(input_object) +try: + import rustworkx +except ImportError: + _has_rustworkx = False +else: + _has_rustworkx = True + class Graph(networkx.Graph): """Represents a graph to be partitioned. It is based on :class:`networkx.Graph`. @@ -366,32 +372,29 @@ class FrozenGraph: "rustworkx_networkx_mapping" ] - def __init__( - self, - graph: Graph, - pygraph: retworkx.PyGraph = None, - mappings: Tuple[dict, dict] = None - ): + def __init__(self, graph: Graph, pygraph: "rustworkx.PyGraph" = None): self.graph = networkx.classes.function.freeze(graph) self.graph.join = frozen self.graph.add_data = frozen self.size = len(self.graph) - if pygraph: - self.pygraph = pygraph - else: - self.pygraph = retworkx.networkx_converter(graph, keep_attributes=True) + if _has_rustworkx: + if pygraph is None: + self.pygraph = rustworkx.networkx_converter(graph, keep_attributes=True) + else: + self.pygraph = pygraph - if mappings: - self.retworkx_networkx_mapping, self.networkx_retworkx_mapping = mappings - else: - self.retworkx_networkx_mapping = { + self.rustworkx_networkx_mapping = { n: self.pygraph[n]["__networkx_node__"] for n in self.pygraph.node_indexes() } - self.networkx_retworkx_mapping = { + self.networkx_rustworkx_mapping = { self.pygraph[n]["__networkx_node__"]: n for n in self.pygraph.node_indexes() } + else: + self.pygraph = None + self.rustworkx_networkx_mapping = None + self.networkx_rustworkx_mapping = None def __len__(self): return self.size @@ -426,15 +429,20 @@ def lookup(self, node, field): return self.graph.nodes[node][field] def subgraph(self, nodes): + if self.pygraph is None: + return FrozenGraph(self.graph.subgraph(nodes)) + return FrozenGraph( self.graph.subgraph(nodes), self.pygraph.subgraph( - [self.networkx_retworkx_mapping[x] for x in nodes] + [self.networkx_rustworkx_mapping[x] for x in nodes] ) ) - # @functools.cache def pygraph_pop_lookup(self, field: str): + if self.pygraph is None: + raise ValueError("No rustworkx graph available.") + attrs = [0] * len(self.pygraph.node_indexes()) for node in self.pygraph.node_indexes(): attrs[node] = float(self.pygraph[node][field])