In [1]:
import random
import networkx as nx
from gerrychain.random import random
import csv
import os
from functools import partial
import json
import random
import numpy as np
from datetime import datetime

import geopandas as gpd
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib

from gerrychain import (
    Election,
    Graph,
    MarkovChain,
    Partition,
    accept,
    constraints,
    updaters,
)
from gerrychain.metrics import efficiency_gap, mean_median
from gerrychain.proposals import recom
from gerrychain.updaters import cut_edges, Tally
from gerrychain.tree import PopulatedGraph, contract_leaves_until_balanced_or_none, recursive_tree_part, predecessors, bipartition_tree
from networkx.algorithms import tree
from collections import deque, namedtuple
from state_dictionaries import *

In [2]:
from functools import wraps
import errno
import os
import signal

class TimeoutError(Exception):
    """Raised by the alarm handler of the ``timeout`` decorator.

    Note: inside this notebook the name shadows Python's built-in
    ``TimeoutError``; this class derives directly from ``Exception``.
    """

def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """Build a decorator that aborts the wrapped call after ``seconds``.

    The decorated function runs with a real-time interval timer armed; if the
    timer fires first, the SIGALRM handler raises ``TimeoutError`` with
    ``error_message``.

    Parameters
    ----------
    seconds : int or float
        Time budget for one call. ``0`` disarms the timer (no timeout).
        Fractional values are accepted.
    error_message : str
        Message carried by the raised ``TimeoutError``. The default is the
        OS description of ``errno.ETIME``, evaluated once at definition time.

    Returns
    -------
    callable
        A decorator preserving the wrapped function's metadata.

    Notes
    -----
    Relies on ``signal.SIGALRM``, so it only works on Unix and only when the
    decorated function is called from the main thread.
    """
    def decorator(func):
        def _handle_timeout(signum, frame):
            raise TimeoutError(error_message)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Remember whichever handler was active so it can be put back;
            # otherwise every decorated call permanently hijacks SIGALRM.
            previous_handler = signal.signal(signal.SIGALRM, _handle_timeout)
            signal.setitimer(signal.ITIMER_REAL, seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Disarm first so a late alarm cannot hit the restored handler.
                signal.setitimer(signal.ITIMER_REAL, 0)
                # signal.signal() returns None when the previous handler was
                # not installed from Python; it cannot be re-installed then.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        return wrapper

    return decorator

In [3]:
# Module-level switch read by division_bipartition_tree: when True, the first
# attempt on each spanning tree tries split_tree_at_division (a cut on a
# division boundary) before falling back to the generic balanced-cut search.
first_check_counties = True

def division_random_spanning_tree(graph, division_col="COUNTYFP10", low_weight = 1, high_weight = 10):
    """Draw a random spanning tree that favours edges inside one division.

    Every edge of ``graph`` gets its ``"weight"`` attribute overwritten:
    ``low_weight`` plus a uniform draw from ``random.random()`` when both
    endpoints carry the same ``division_col`` value, ``high_weight`` plus a
    uniform draw otherwise. The Kruskal minimum spanning tree on those
    weights is returned.

    Parameters
    ----------
    graph : networkx-style graph whose nodes carry ``division_col``.
    division_col : node attribute naming the division (e.g. county).
    low_weight, high_weight : base weights for within-division and
        cross-division edges respectively.

    Returns
    -------
    The minimum spanning tree produced by
    ``networkx.algorithms.tree.minimum_spanning_tree``.
    """
    for edge in graph.edges:
        first_end, second_end = edge[0], edge[1]
        within_division = (
            graph.nodes[first_end][division_col] == graph.nodes[second_end][division_col]
        )
        base_weight = low_weight if within_division else high_weight
        # The random jitter breaks ties so repeated calls yield different trees.
        graph.edges[edge]["weight"] = base_weight + random.random()

    return tree.minimum_spanning_tree(graph, algorithm="kruskal", weight="weight")

def split_tree_at_division(h, choice=random.choice, division_col="COUNTYFP10"):
    """Look for a population-balanced cut that falls on a division boundary.

    Starting from the leaves of the tree wrapped by ``h``, nodes are
    contracted into their BFS parent (relative to a randomly chosen root of
    degree > 1). Before each contraction the edge leaf-parent is tested: if
    the two endpoints have different ``division_col`` values and
    ``h.has_ideal_population(leaf)`` holds, the node set accumulated at the
    leaf is returned.

    Parameters
    ----------
    h : populated-graph wrapper exposing ``graph``, ``degree``, ``subsets``,
        ``has_ideal_population`` and ``contract_node`` (the caller in this
        file passes a ``PopulatedGraph``). It is mutated by the contractions.
    choice : callable used to pick the root from the candidate list.
    division_col : node attribute naming the division.

    Returns
    -------
    ``h.subsets[leaf]`` for the first qualifying leaf, or ``None`` when the
    queue of leaves runs out without finding one.
    """
    root_candidates = [node for node in h if h.degree(node) > 1]
    root = choice(root_candidates)
    # BFS predecessor map: tells each node which neighbour lies toward the root.
    parent_of = predecessors(h.graph, root)

    pending = deque(node for node in h if h.degree(node) == 1)
    while pending:
        leaf = pending.popleft()
        parent = parent_of[leaf]

        crosses_division = (
            h.graph.nodes[parent][division_col] != h.graph.nodes[leaf][division_col]
        )
        if crosses_division and h.has_ideal_population(leaf):
            return h.subsets[leaf]

        # Not a usable cut: fold the leaf into its parent and keep going.
        h.contract_node(leaf, parent)
        if h.degree(parent) == 1 and parent != root:
            pending.append(parent)

    return None


def division_bipartition_tree(
    graph,
    pop_col,
    pop_target,
    epsilon,
    division_col="COUNTYFP10",
    node_repeats=1,
    spanning_tree=None,
    choice=random.choice,
    attempts_before_giveup = 100):
    """Find a population-balanced subtree, preferring cuts on division lines.

    A division-weighted random spanning tree is drawn (unless one is passed
    in). On the first try with each tree, and only when the module-level
    flag ``first_check_counties`` is true, ``split_tree_at_division`` is
    asked for a balanced cut lying on a division boundary. If that yields
    nothing, the generic ``contract_leaves_until_balanced_or_none`` search is
    run on the same tree. After ``node_repeats`` tries a new tree is drawn.

    Parameters
    ----------
    graph : graph whose nodes carry ``pop_col`` and ``division_col``.
    pop_col : node attribute holding the population.
    pop_target : target population of the subtree.
    epsilon : population tolerance handed to ``PopulatedGraph``.
    division_col : node attribute naming the division (e.g. county).
    node_repeats : tries per spanning tree before a new tree is drawn.
    spanning_tree : optional initial spanning tree.
    choice : callable used to pick a root node.
    attempts_before_giveup : maximum number of *re-drawn* spanning trees.

    Returns
    -------
    The node set of a balanced subtree, or an empty ``set()`` when none was
    found within the attempt budget.
    """
    populations = {node: graph.nodes[node][pop_col] for node in graph}

    balanced_subtree = None
    if spanning_tree is None:
        spanning_tree = division_random_spanning_tree(graph, division_col=division_col)
    restarts = 0
    counter = 0
    while balanced_subtree is None and counter < attempts_before_giveup:
        # ">=" (not "==") so a node_repeats below 1 still redraws the tree
        # and advances the counter instead of looping forever.
        if restarts >= node_repeats:
            spanning_tree = division_random_spanning_tree(graph, division_col=division_col)
            restarts = 0
            counter += 1
        if first_check_counties and restarts == 0:
            h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon)
            balanced_subtree = split_tree_at_division(h, choice=choice, division_col=division_col)
        if balanced_subtree is None:
            # The division pass contracts nodes in place, so start from a
            # fresh wrapper around the same tree.
            h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon)
            balanced_subtree = contract_leaves_until_balanced_or_none(h, choice=choice)
        restarts += 1

    # Decide on the result itself, not on the counter: a subtree found during
    # the very last permitted attempt must not be thrown away.
    if balanced_subtree is None:
        return set()
    return balanced_subtree

In [4]:
def num_splits(partition, unit_df, geo_id ='GEOID10', division_col = "COUNTYFP10"):
    """Count the divisions (e.g. counties) split across several districts.

    Parameters
    ----------
    partition : object whose ``assignment`` maps each row label of
        ``unit_df`` (the graph node id) to a district.
    unit_df : pandas.DataFrame with one row per unit, containing the
        ``geo_id`` and ``division_col`` columns. A ``"current"`` column
        holding each unit's district is written to it as a side effect.
    geo_id : column holding the unit identifier.
    division_col : column holding the division identifier.

    Returns
    -------
    int
        Number of distinct ``division_col`` values whose units fall in more
        than one district.

    Raises
    ------
    KeyError
        If a row label is missing from ``partition.assignment``.
    """
    part_of_node = dict(partition.assignment)

    # Nodes are not labelled by geo_id, so translate row label -> district
    # into geo_id -> district in one pass. If a geo_id repeats, its first row
    # decides, matching the earlier "first matching index" lookup.
    part_of_unit = {}
    for node, unit_id in zip(unit_df.index, unit_df[geo_id]):
        if unit_id not in part_of_unit:
            part_of_unit[unit_id] = part_of_node[node]
    unit_df["current"] = unit_df[geo_id].map(part_of_unit)

    districts_per_division = unit_df.groupby(division_col)["current"].nunique()
    return int((districts_per_division > 1).sum())

In [5]:
def addPartitionToJSON(graph, partition, state, level, keyname):
    """Store a partition's assignment on the graph and dump the graph to JSON.

    The assignment is read in node order ``0 .. len(graph.nodes) - 1``,
    attached to the graph under the node attribute ``keyname`` via
    ``graph.add_data``, and the graph is written to
    ``./Seeding-Division-Splits/Output/<level>_seed/<state>.json``. After
    writing, every node's stored value is asserted to equal the assignment.

    Parameters
    ----------
    graph : graph exposing ``add_data``, ``to_json`` and ``nodes``.
    partition : object whose ``assignment`` can be indexed by node id.
    state : state abbreviation used as the file name.
    level : plan level used in the output directory name.
    keyname : node attribute / column name for the assignment.
    """
    out_path = "".join(
        ("./Seeding-Division-Splits/Output/", level, "_seed/", state, ".json")
    )

    node_to_part = partition.assignment
    node_count = len(graph.nodes)
    ordered_assignment = [node_to_part[node] for node in range(node_count)]

    graph.add_data(pd.DataFrame({keyname: ordered_assignment}))
    graph.to_json(out_path)

    # Sanity check: the attribute now on each node matches what was assigned.
    stored = dict(graph.nodes.data())
    for node, expected in enumerate(ordered_assignment):
        assert expected == stored[node][keyname]

    print("Done writing to", out_path)

In [6]:
# @timeout(180)
def generate_good_seed_plan(state, level, keyname):
    """Try to produce a seed plan that splits fewer counties than enacted.

    Loads the state's dual graph, builds a starting partition, and runs a
    ReCom chain whose bipartition step (``division_bipartition_tree``)
    prefers cuts on ``COUNTYFP10`` boundaries. The first ``burn`` states of
    the chain are skipped; the next one is evaluated. The better of that
    plan and the starting plan is written to JSON with
    ``addPartitionToJSON`` under the node attribute ``keyname``.

    Parameters
    ----------
    state : str
        State abbreviation; key into ``cd_dict`` / ``senate_dict`` /
        ``house_dict`` / ``splitsDict`` and name of the JSON file.
    level : str
        ``"CD"``, ``"SD"`` or ``"HD"``. Selects the population tolerance,
        number of districts, enacted-split threshold and burn-in length.
    keyname : str
        Node attribute under which the resulting assignment is stored.

    Returns
    -------
    tuple or None
        ``(splits, n_divisions, threshold)`` where ``splits`` is the split
        count of the plan that was actually written. ``None`` when ``level``
        is not recognised or no valid random starting plan could be built.
    """
    # NOTE(review): the graph is always read from the HD_seed directory,
    # whatever `level` is, while results go to <level>_seed — confirm intended.
    graph_path = "./Seeding-Division-Splits/Output/HD_seed/" + state + ".json"

    graph = Graph.from_json(graph_path)
    with open(graph_path) as json_data:
        data = json.load(json_data)
    unit_df = pd.DataFrame(data['nodes'])
    pop_col = "TOTPOP"

    if level == "CD":
        ep = 0.01
        k = cd_dict[state]
        threshold = splitsDict[state][0]
        burn = 10*k
    elif level == "SD":
        ep = 0.05
        k = senate_dict[state]
        threshold = splitsDict[state][1]
        burn = 15*k
    elif level == "HD":
        ep = 0.05
        k = house_dict[state]
        threshold = splitsDict[state][2]
        burn = 10*k
    else:
        print("error.")
        return None

    # Node attribute holding the starting assignment. Set this to -1 by hand
    # to start from a random recursive_tree_part plan instead.
    ass_col = "SPLITS_SEED_" + level

    unit_col = "GEOID10" #  change if not GEOID10
    division_col = "COUNTYFP10" #  change if not COUNTYFP10
    n_divisions = splitsDict[state][3]

    # Named plan_updaters so the imported gerrychain `updaters` module is not
    # shadowed inside this function.
    plan_updaters = {
        "population": Tally(pop_col, alias="population")
    }

    if ass_col == -1:
        seeded = _seed_partition_from_scratch(
            graph, unit_df, k, pop_col, ep, division_col, plan_updaters, burn + 5
        )
        if seeded is None:
            return None
        initial_partition, chain = seeded
    else:
        initial_partition = Partition(graph, assignment=ass_col, updaters=plan_updaters)
        chain = _build_division_chain(initial_partition, pop_col, ep, division_col, burn + 5)

    orig_county_splits = num_splits(initial_partition, unit_df, geo_id =unit_col, division_col = division_col)

    print("Starting chain, original county splits: %d / %d" % (orig_county_splits, n_divisions))
    print("Want to get below " + str(threshold) + " county splits [using enacted " + level + "s]")

    if orig_county_splits < threshold or orig_county_splits < 2: # might not be good to have second Bool...
        print("Success!")
        print("Seed plan already respects county splits well.")
        addPartitionToJSON(graph, initial_partition, state, level, keyname)
        return orig_county_splits, n_divisions, threshold

    print("We won't record the first " + str(burn) + " tries to get to a low # of splits...")
    begin = datetime.now()
    for step, part in enumerate(chain):
        if step < burn:
            continue
        s = num_splits(part, unit_df, geo_id =unit_col, division_col = division_col)
        end = datetime.now()
        if s < threshold:
            print("Success!")
        else:
            print("Failure")
        print("In ", str(end-begin), "found a partition with county splits = " + str(s) + "/" + str(n_divisions))
        if s > orig_county_splits: # only update JSON if you improve on county splits
            # Fall back to the starting plan and report *its* split count, so
            # the returned number describes the plan that is written.
            print("Chain result is worse than the starting plan; keeping the starting plan.")
            part = initial_partition
            s = orig_county_splits
        addPartitionToJSON(graph, part, state, level, keyname)
        return s, n_divisions, threshold


def _build_division_chain(initial_partition, pop_col, ep, division_col, total_steps):
    """Create the division-aware ReCom chain starting at ``initial_partition``."""
    ideal_population = sum(initial_partition["population"].values()) / len(initial_partition)
    division_proposal = partial(
        recom,
        pop_col=pop_col,
        pop_target=ideal_population,
        epsilon=ep,
        method=partial(division_bipartition_tree, division_col=division_col),
        node_repeats=2,
    )
    return MarkovChain(
        proposal=division_proposal,
        constraints=[constraints.within_percent_of_ideal_population(initial_partition, ep)],
        accept=accept.always_accept,
        initial_state=initial_partition,
        total_steps=total_steps,
    )


def _seed_partition_from_scratch(graph, unit_df, k, pop_col, ep, division_col, plan_updaters, total_steps):
    """Draw a random k-way plan with recursive_tree_part and wrap it in a chain.

    Makes two attempts, retrying on ``ValueError``. Returns
    ``(initial_partition, chain)``, or ``None`` when both attempts fail.
    """
    print("Running 'recursive_tree_part'...")
    for attempt in range(2):
        try:
            cddict = recursive_tree_part(graph,
                                         range(k),
                                         unit_df[pop_col].sum()/k,
                                         pop_col,
                                         ep,
                                         node_repeats=1)
            if attempt == 0:
                print("Ran 'recursive_tree_part'!")
            initial_partition = Partition(graph, cddict, updaters=plan_updaters)
            chain = _build_division_chain(initial_partition, pop_col, ep, division_col, total_steps)
            return initial_partition, chain
        except ValueError:
            if attempt == 0:
                print("Take 2: recursive_tree_part...")
    print("Value Error: the given initial_state is not valid according to is_valid")
    return None

In [7]:
def generate_seeds_for(level, states=states_list):
    """Generate a seed plan for each state and log the outcome.

    For every state ``generate_good_seed_plan`` is called with the key name
    ``"SPLITS_SEED_" + level``. One line per state is appended to
    ``./Seeding-Division-Splits/Output/<level>_seed/states.txt``: the split
    counts on success (flagged when not below the enacted threshold), or
    ``"<state>: ERROR"`` when anything goes wrong.

    Parameters
    ----------
    level : str
        Plan level (``"CD"``, ``"SD"`` or ``"HD"``).
    states : iterable of str
        State abbreviations to process; defaults to ``states_list``.
    """
    infoFile_path = "./Seeding-Division-Splits/Output/" + level + "_seed/states.txt"
    for state in states:
        # presumably skipped because Delaware has a single congressional
        # district, so there is nothing to partition — confirm
        if level == "CD" and state == "DE":
            continue
        begin = datetime.now()
        print("Generating seed for", state)
        try:
            n_s, n_d, thresh = generate_good_seed_plan(state, level, "SPLITS_SEED_" + level)
            summary = state + ": " + str(n_s) + "/" + str(n_d) + " (" + str(thresh) + ") --- splits/counties (enacted splits)"
            if n_s >= thresh:
                summary += " --- NOT BETTER THAN ENACTED"
            with open(infoFile_path, "a") as infoFile:
                infoFile.write(summary + "\n")
        except Exception as exc:
            # Catch Exception rather than everything, so interrupting the
            # kernel still stops the loop, and show what actually failed.
            print(state + " raised an error: " + repr(exc))
            with open(infoFile_path, "a") as infoFile:
                infoFile.write(state + ": ERROR\n")
        finally:
            end = datetime.now()
            print("Finished", state, "in", str(end-begin))
            print("-------")

In [8]:
%%time
# Build state-house ("HD") seed plans for a hand-picked subset of states; one
# result line per state is appended to the HD_seed/states.txt log.
generate_seeds_for("HD", states=["AZ","NV","NJ","NC","TN","TX"])

Generating seed for AZ
Starting chain, original county splits: 10 / 15
Want to get below 10 county splits [using enacted HDs]
We won't record the first 600 tries to get to a low # of splits...
AZ raised an error
Finished AZ in 0:00:18.792292
-------
Generating seed for NV
Starting chain, original county splits: 5 / 17
Want to get below 5 county splits [using enacted HDs]
We won't record the first 420 tries to get to a low # of splits...
NV raised an error
Finished NV in 0:00:00.972586
-------
Generating seed for NJ
NJ raised an error
Finished NJ in 0:00:04.356765
-------
Generating seed for NC
NC raised an error
Finished NC in 0:00:00.243951
-------
Generating seed for TN
TN raised an error
Finished TN in 0:00:00.170410
-------
Generating seed for TX
TX raised an error
Finished TX in 0:00:00.636239
-------
CPU times: user 24.2 s, sys: 581 ms, total: 24.8 s
Wall time: 25.2 s
