Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix default updaters bug in GeographicPartition #391

Merged
merged 2 commits into from Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 8 additions & 7 deletions gerrychain/partition/partition.py
Expand Up @@ -31,28 +31,29 @@ class Partition:
'_cache'
)

default_updaters = {"cut_edges": cut_edges}

def __init__(
self, graph=None, assignment=None, updaters=None, parent=None, flips=None,
use_cut_edges=True
use_default_updaters=True
):
"""
:param graph: Underlying graph.
:param assignment: Dictionary assigning nodes to districts.
:param updaters: Dictionary of functions to track data about the partition.
The keys are stored as attributes on the partition class,
which the functions compute.
:param use_cut_edges: If `False`, do not include `cut_edges` updater by default
and do not calculate edge flows.
:param use_default_updaters: If `False`, do not include default updaters.
"""
if parent is None:
self._first_time(graph, assignment, updaters, use_cut_edges)
self._first_time(graph, assignment, updaters, use_default_updaters)
else:
self._from_parent(parent, flips)

self._cache = dict()
self.subgraphs = SubgraphView(self.graph, self.parts)

def _first_time(self, graph, assignment, updaters, use_cut_edges):
def _first_time(self, graph, assignment, updaters, use_default_updaters):
if isinstance(graph, Graph):
self.graph = FrozenGraph(graph)
elif isinstance(graph, networkx.Graph):
Expand All @@ -71,8 +72,8 @@ def _first_time(self, graph, assignment, updaters, use_cut_edges):
if updaters is None:
updaters = dict()

if use_cut_edges:
self.updaters = {"cut_edges": cut_edges}
if use_default_updaters:
self.updaters = self.default_updaters
else:
self.updaters = {}

Expand Down
22 changes: 22 additions & 0 deletions tests/partition/test_partition.py
Expand Up @@ -153,3 +153,25 @@ def test_partition_has_default_updaters(example_partition):

def test_partition_has_keys(example_partition):
assert "cut_edges" in set(example_partition.keys())


def test_geographic_partition_has_keys(example_geographic_partition):
keys = set(example_geographic_partition.updaters.keys())

assert "perimeter" in keys
assert "exterior_boundaries" in keys
assert "interior_boundaries" in keys
assert "boundary_nodes" in keys
assert "cut_edges" in keys
assert "area" in keys
assert "cut_edges_by_part" in keys


def test_partition_has_default_updaters(example_geographic_partition):
assert hasattr(example_geographic_partition, "perimeter")
assert hasattr(example_geographic_partition, "exterior_boundaries")
assert hasattr(example_geographic_partition, "interior_boundaries")
assert hasattr(example_geographic_partition, "boundary_nodes")
assert hasattr(example_geographic_partition, "cut_edges")
assert hasattr(example_geographic_partition, "area")
assert hasattr(example_geographic_partition, "cut_edges_by_part")