Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to disable cut_edges updater with use_cut_edges flag #375

Merged
merged 1 commit into from Jan 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 14 additions & 7 deletions gerrychain/partition/partition.py
Expand Up @@ -16,7 +16,6 @@ class Partition:
:ivar dict parts: Maps district IDs to the set of nodes in that district.
:ivar dict subgraphs: Maps district IDs to the induced subgraph of that district.
"""
default_updaters = {"cut_edges": cut_edges}
__slots__ = (
'graph',
'subgraphs',
Expand All @@ -30,26 +29,28 @@ class Partition:
)

def __init__(
self, graph=None, assignment=None, updaters=None, parent=None, flips=None
self, graph=None, assignment=None, updaters=None, parent=None, flips=None,
use_cut_edges=True
):
"""
:param graph: Underlying graph.
:param assignment: Dictionary assigning nodes to districts.
:param updaters: Dictionary of functions to track data about the partition.
The keys are stored as attributes on the partition class,
which the functions compute.
:param use_cut_edges: If `False`, do not include `cut_edges` updater by default
and do not calculate edge flows.
"""
if parent is None:
self._first_time(graph, assignment, updaters)
self._first_time(graph, assignment, updaters, use_cut_edges)
else:
self._from_parent(parent, flips)

self._cache = dict()
self.subgraphs = SubgraphView(self.graph, self.parts)

def _first_time(self, graph, assignment, updaters):
def _first_time(self, graph, assignment, updaters, use_cut_edges):
pizzimathy marked this conversation as resolved.
Show resolved Hide resolved
self.graph = graph

self.assignment = get_assignment(assignment, graph)

if set(self.assignment) != set(graph):
Expand All @@ -58,7 +59,11 @@ def _first_time(self, graph, assignment, updaters):
if updaters is None:
updaters = dict()

self.updaters = self.default_updaters.copy()
if use_cut_edges:
self.updaters = {"cut_edges": cut_edges}
else:
self.updaters = {}

self.updaters.update(updaters)
pizzimathy marked this conversation as resolved.
Show resolved Hide resolved

self.parent = None
Expand All @@ -77,7 +82,9 @@ def _from_parent(self, parent, flips):
self.updaters = parent.updaters

self.flows = flows_from_changes(parent.assignment, flips)
self.edge_flows = compute_edge_flows(self)

if "cut_edges" in self.updaters:
self.edge_flows = compute_edge_flows(self)

def __repr__(self):
number_of_parts = len(self)
Expand Down
4 changes: 1 addition & 3 deletions tests/partition/test_partition.py
Expand Up @@ -56,7 +56,7 @@ def example_geographic_partition():

def test_geographic_partition_can_be_instantiated(example_geographic_partition):
partition = example_geographic_partition
assert partition.updaters == GeographicPartition.default_updaters
assert isinstance(partition, GeographicPartition)


def test_Partition_parts_is_a_dictionary_of_parts_to_nodes(example_partition):
Expand Down Expand Up @@ -144,11 +144,9 @@ def test_repr(example_partition):

def test_partition_has_default_updaters(example_partition):
partition = example_partition
default_updaters = partition.default_updaters
should_have_updaters = {"cut_edges": cut_edges}

for updater in should_have_updaters:
assert default_updaters.get(updater, None) is not None
assert should_have_updaters[updater](partition) == partition[updater]


Expand Down