From 75b065b857f5edf2a3e647c07e974304ee37949b Mon Sep 17 00:00:00 2001 From: peterrrock2 Date: Fri, 19 Jan 2024 15:03:35 -0700 Subject: [PATCH] Fix issue #319 --- .circleci/config.yml | 2 + gerrychain/__init__.py | 5 +- gerrychain/graph/graph.py | 12 ++++ gerrychain/partition/geographic.py | 3 +- gerrychain/proposals/tree_proposals.py | 66 +++++++++++++----- gerrychain/tree.py | 54 ++++++++++++--- tests/test_region_aware.py | 95 ++++++++++++++++++++++---- tests/test_reproducibility.py | 8 +-- 8 files changed, 200 insertions(+), 45 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index b21754e6..d2788eb9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -42,6 +42,8 @@ jobs: echo "backend: Agg" > "matplotlibrc" pytest -v --runslow --cov=gerrychain --junitxml=test-reports/junit.xml tests codecov + + no_output_timeout: 20m environment: PYTHONHASHSEED: "0" - store_test_results: diff --git a/gerrychain/__init__.py b/gerrychain/__init__.py index 27626923..2b138538 100644 --- a/gerrychain/__init__.py +++ b/gerrychain/__init__.py @@ -6,7 +6,10 @@ from .graph import Graph from .partition import GeographicPartition, Partition from .updaters.election import Election - + +# Will need to change this to a logging option later +# It might be good to see how often this happens +warnings.simplefilter("once") try: import geopandas diff --git a/gerrychain/graph/graph.py b/gerrychain/graph/graph.py index df6bb585..3f3dc28b 100644 --- a/gerrychain/graph/graph.py +++ b/gerrychain/graph/graph.py @@ -145,6 +145,18 @@ def from_file( :returns: The Graph object of the geometries from `filename`. :rtype: Graph + + .. Warning:: + + This method requires the optional ``geopandas`` dependency. + So please install ``gerrychain`` with the ``geo`` extra + via the command: + + .. code-block:: console + + pip install gerrychain[geo] + + or install ``geopandas`` separately. """ import geopandas as gp diff --git a/gerrychain/partition/geographic.py b/gerrychain/partition/geographic.py index 04ab4b2a..7dc03386 100644 --- a/gerrychain/partition/geographic.py +++ b/gerrychain/partition/geographic.py @@ -11,7 +11,8 @@ class GeographicPartition(Partition): - """A :class:`Partition` with areas, perimeters, and boundary information included. + """ + A :class:`Partition` with areas, perimeters, and boundary information included. These additional data allow you to compute compactness scores like `Polsby-Popper `_. """ diff --git a/gerrychain/proposals/tree_proposals.py b/gerrychain/proposals/tree_proposals.py index 6dfe84b5..d110117b 100644 --- a/gerrychain/proposals/tree_proposals.py +++ b/gerrychain/proposals/tree_proposals.py @@ -6,11 +6,19 @@ from ..tree import ( recursive_tree_part, bipartition_tree, bipartition_tree_random, _bipartition_tree_random_all, uniform_spanning_tree, - find_balanced_edge_cuts_memoization, + find_balanced_edge_cuts_memoization, ReselectException, ) from typing import Callable, Optional, Dict, Union +class MetagraphError(Exception): + """ + Raised when the partition we are trying to split is a low degree + node in the metagraph. + """ + pass + + def recom( partition: Partition, pop_col: str, @@ -70,26 +78,52 @@ def recom( :rtype: Partition """ - edge = random.choice(tuple(partition["cut_edges"])) - parts_to_merge = (partition.assignment.mapping[edge[0]], partition.assignment.mapping[edge[1]]) - - subgraph = partition.graph.subgraph( - partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]] - ) + bad_district_pairs = set() + n_parts = len(partition) + tot_pairs = n_parts * (n_parts - 1) / 2 # n choose 2 # Try to add the region aware in if the method accepts the weight dictionary if 'weight_dict' in signature(method).parameters: method = partial(method, weight_dict=weight_dict) - flips = recursive_tree_part( - subgraph.graph, - parts_to_merge, - pop_col=pop_col, - pop_target=pop_target, - epsilon=epsilon, - node_repeats=node_repeats, - method=method, - ) + while len(bad_district_pairs) < tot_pairs: + try: + while True: + edge = random.choice(tuple(partition["cut_edges"])) + # Need to sort the tuple so that the order is consistent + # in the bad_district_pairs set + parts_to_merge = [partition.assignment.mapping[edge[0]], + partition.assignment.mapping[edge[1]]] + parts_to_merge.sort() + + if tuple(parts_to_merge) not in bad_district_pairs: + break + + subgraph = partition.graph.subgraph( + partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]] + ) + + flips = recursive_tree_part( + subgraph.graph, + parts_to_merge, + pop_col=pop_col, + pop_target=pop_target, + epsilon=epsilon, + node_repeats=node_repeats, + method=method, + ) + break + + except Exception as e: + if isinstance(e, ReselectException): + bad_district_pairs.add(tuple(parts_to_merge)) + continue + else: + raise + + if len(bad_district_pairs) == tot_pairs: + raise MetagraphError(f"Bipartitioning failed for all {tot_pairs} district pairs." + f"Consider rerunning the chain with a different random seed.") return partition.flip(flips) diff --git a/gerrychain/tree.py b/gerrychain/tree.py index 8aa791e4..d575b0b6 100644 --- a/gerrychain/tree.py +++ b/gerrychain/tree.py @@ -35,6 +35,7 @@ import random from collections import deque, namedtuple from typing import Any, Callable, Dict, List, Optional, Set, Union, Hashable, Sequence, Tuple +import warnings def predecessors(h: nx.Graph, root: Any) -> Dict: @@ -295,6 +296,22 @@ def part_nodes(start): return cuts +class BipartitionWarning(UserWarning): + """ + Generally raised when it is proving difficult to find a balanced cut. + """ + pass + + +class ReselectException(Exception): + """ + Raised when the algorithm is unable to find a balanced cut after some + maximum number of attempts, but the user has allowed the algorithm to + reselect the pair of nodes to try and recombine. + """ + pass + + def bipartition_tree( graph: nx.Graph, pop_col: str, @@ -306,7 +323,8 @@ def bipartition_tree( weight_dict: Optional[Dict] = None, balance_edge_fn: Callable = find_balanced_edge_cuts_memoization, choice: Callable = random.choice, - max_attempts: Optional[int] = 10000 + max_attempts: Optional[int] = 10000, + allow_pair_reselection: bool = False ) -> Set: """ This function finds a balanced 2 partition of a graph by drawing a @@ -347,10 +365,15 @@ def bipartition_tree( :param max_attempts: The maximum number of attempts that should be made to bipartition. Defaults to 1000. :type max_attempts: Optional[int], optional + :param allow_pair_reselection: Whether we would like to return an error to the calling + function to ask it to reselect the pair of nodes to try and recombine. Defaults to False. + :type allow_pair_reselection: bool, optional :returns: A subset of nodes of ``graph`` (whose induced subgraph is connected). The other part of the partition is the complement of this subset. :rtype: Set + + :raises BipartitionWarning: If a possible cut cannot be found after 50 attempts. :raises RuntimeError: If a possible cut cannot be found after the maximum number of attempts. """ # Try to add the region-aware in if the spanning_tree_fn accepts a weight dictionary @@ -378,6 +401,17 @@ def bipartition_tree( restarts += 1 attempts += 1 + if attempts == 50 and not allow_pair_reselection: + warnings.warn("Failed to find a balanced cut after 50 attempts.\n" + "Consider running with the parameter\n" + "allow_pair_reselection=True to allow the algorithm to\n" + "select a different pair of nodes to try an recombine.", + BipartitionWarning) + + if allow_pair_reselection: + raise ReselectException(f"Failed to find a balanced cut after {max_attempts} attempts.\n" + f"Selecting a new district pair") + raise RuntimeError(f"Could not find a possible cut after {max_attempts} attempts.") @@ -589,13 +623,17 @@ def recursive_tree_part( min_pop = max(pop_target * (1 - epsilon), pop_target * (1 - epsilon) - debt) max_pop = min(pop_target * (1 + epsilon), pop_target * (1 + epsilon) - debt) new_pop_target = (min_pop + max_pop) / 2 - nodes = method( - graph.subgraph(remaining_nodes), - pop_col=pop_col, - pop_target=new_pop_target, - epsilon=(max_pop - min_pop) / (2 * new_pop_target), - node_repeats=node_repeats, - ) + + try: + nodes = method( + graph.subgraph(remaining_nodes), + pop_col=pop_col, + pop_target=new_pop_target, + epsilon=(max_pop - min_pop) / (2 * new_pop_target), + node_repeats=node_repeats, + ) + except Exception: + raise if nodes is None: raise BalanceError() diff --git a/tests/test_region_aware.py b/tests/test_region_aware.py index 56d97de9..f06dbe92 100644 --- a/tests/test_region_aware.py +++ b/tests/test_region_aware.py @@ -3,7 +3,10 @@ import pytest from functools import partial from concurrent.futures import ProcessPoolExecutor -from gerrychain import MarkovChain, Partition, accept, constraints, proposals, updaters, Graph, tree +from gerrychain import ( + MarkovChain, Partition, accept, + constraints, proposals, updaters, Graph, tree) +from gerrychain.tree import ReselectException, BipartitionWarning def total_reg_splits(partition, reg_attr, all_reg_names): """Returns the total number of region splits in the partition.""" @@ -18,8 +21,9 @@ def total_reg_splits(partition, reg_attr, all_reg_names): return sum(1 for value in split.values() if value > 0) -def run_chain_single(seed, category, names, steps): +def run_chain_single(seed, category, names, steps, weight, max_attempts=100000, reselect=False): from gerrychain import MarkovChain, Partition, accept, constraints, proposals, updaters, Graph, tree + from gerrychain.tree import ReselectException from functools import partial import random @@ -46,7 +50,9 @@ def run_chain_single(seed, category, names, steps): epsilon=epsilon, weight_dict=weights, node_repeats=10, - method=partial(tree.bipartition_tree, max_attempts=1000000)) + method=partial(tree.bipartition_tree, + max_attempts=max_attempts, + allow_pair_reselection=reselect)) weighted_chain = MarkovChain(proposal=weighted_proposal, constraints=[constraints.contiguous], @@ -65,24 +71,83 @@ def test_region_aware_muni(): n_samples = 30 region = "muni" region_names = [str(i) for i in range(1,17)] + + with ProcessPoolExecutor() as executor: + results = executor.map(partial(run_chain_single, + category=region, + names=region_names, + steps=500, + weight=0.5), + range(n_samples)) + + tot_splits = sum(results) + + random.seed(2018) + # Check if splits less than 5% of the time on average + assert (float(tot_splits) / (n_samples*len(region_names))) < 0.05 + + +def test_region_aware_muni_errors(): + region = "muni" + region_names = [str(i) for i in range(1,17)] + + with pytest.raises(RuntimeError) as exec_info: + # Random seed 0 should fail here + run_chain_single(seed=0, + category=region, + names=region_names, + steps=10000, + max_attempts=10, + weight=2.0) + + random.seed(2018) + assert "Could not find a possible cut after 10 attempts" in str(exec_info.value) + + +def test_region_aware_muni_warning(): + n_samples = 1 + region = "muni" + region_names = [str(i) for i in range(1,17)] + + with pytest.warns(UserWarning) as record: + # Random seed 2 should succeed, but drawing the + # tree is hard, so we should get a warning + run_chain_single(seed=2, + category=region, + names=region_names, + steps=500, + weight=1.0) + + random.seed(2018) + + assert record[0].category == BipartitionWarning + assert "Failed to find a balanced cut after 50 attempts." in str(record[0].message) + +@pytest.mark.slow +def test_region_aware_muni_reselect(): + n_samples = 30 + region = "muni" + region_names = [str(i) for i in range(1,17)] with ProcessPoolExecutor() as executor: results = executor.map(partial(run_chain_single, category=region, names=region_names, - steps=500), + steps=500, + weight=1.0, + reselect=True), range(n_samples)) tot_splits = sum(results) random.seed(2018) - # Check if splits less than 1% of the time on average - assert (float(tot_splits) / (n_samples*len(region_names))) < 0.01 + # Check if splits less than 5% of the time on average + assert (float(tot_splits) / (n_samples*len(region_names))) < 0.05 @pytest.mark.slow def test_region_aware_county(): - n_samples = 30 + n_samples = 100 region = "county2" region_names = [str(i) for i in range(1,9)] @@ -90,13 +155,15 @@ def test_region_aware_county(): results = executor.map(partial(run_chain_single, category=region, names=region_names, - steps=10000), + steps=5000, + weight=2.0), range(n_samples)) tot_splits = sum(results) random.seed(2018) # Check if splits less than 5% of the time on average + print(f"Final score: {float(tot_splits) / (n_samples*len(region_names))}") assert (float(tot_splits) / (n_samples*len(region_names))) < 0.05 @@ -144,7 +211,7 @@ def run_chain_dual(seed, steps): epsilon=epsilon, weight_dict=weights, node_repeats=10, - method=partial(tree.bipartition_tree, max_attempts=1000000)) + method=partial(tree.bipartition_tree, max_attempts=10000)) weighted_chain = MarkovChain(proposal=weighted_proposal, constraints=[constraints.contiguous], @@ -165,7 +232,7 @@ def test_region_aware_dual(): n_samples = 30 n_munis = 16 n_counties = 4 - + with ProcessPoolExecutor() as executor: results = executor.map(partial(run_chain_dual, steps=10000), @@ -176,8 +243,6 @@ def test_region_aware_dual(): random.seed(2018) - # Check if splits less than 1% of the time on average - # The condition on counties is stricter this time since the - # munis and districts can be made to fit neatly within the counties - assert (float(tot_muni_splits) / (n_samples*n_munis)) < 0.01 and \ - (float(tot_county_splits) / (n_samples*n_counties)) < 0.01 \ No newline at end of file + # Check if splits less than 5% of the time on average + assert (float(tot_muni_splits) / (n_samples*n_munis)) < 0.05 + assert (float(tot_county_splits) / (n_samples*n_counties)) < 0.05 \ No newline at end of file diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 141f8b99..765e1ff1 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -110,7 +110,7 @@ def test_pa_freeze(): result += str(len(partition.cut_edges)) result += str(count) + "\n" - # print(hashlib.sha256(result.encode()).hexdigest()) - assert hashlib.sha256(result.encode()).hexdigest() == "957e5bd59fc2730707c6549f52dc8834ac48e5f37f0e37b71a04f6734a287b14" - - # "3bef9ac8c0bfa025fb75e32aea3847757a8fba56b2b2be6f9b3b952088ae3b3c" \ No newline at end of file + # This needs to be changed every time we change the + # tests around + assert hashlib.sha256(result.encode()).hexdigest() == "0163d0bfff9e090c06ecae4c999a7e8e55e9cee65b08939f485739a72e5c30f4" + \ No newline at end of file