In [1]:
import matplotlib.pyplot as plt
from gerrychain import (Partition, Graph, MarkovChain,
                        updaters, constraints, accept)
from gerrychain.proposals import recom
from gerrychain.constraints import contiguous
from functools import partial
import pandas

# Seed Python's global random number generator with a fixed value so that
# re-running the notebook from a fresh kernel gives reproducible results.
import random
RANDOM_SEED = 2024
random.seed(RANDOM_SEED)

In [2]:
# Load the graph that every later cell operates on. The path is relative,
# so the JSON file must sit in the notebook's working directory.
GRAPH_JSON_PATH = "./gerrymandria.json"
graph = Graph.from_json(GRAPH_JSON_PATH)

In [3]:
# Updaters are looked up on a partition by key (e.g. partition["population"]).
#   population -- Tally built over the "TOTPOP" column
#   cut_edges  -- gerrychain's cut_edges updater
my_updaters = dict(
    population=updaters.Tally("TOTPOP"),
    cut_edges=updaters.cut_edges,
)

# Starting plan for the chain.
# NOTE(review): "district" is presumably the node attribute that holds each
# node's initial district id -- confirm against the JSON file.
initial_partition = Partition(graph, assignment="district", updaters=my_updaters)

In [4]:
" Make the Proposal "

# Target population per district: the total over all districts divided by
# the number of districts in the initial partition.
# The "population" key refers to the updater registered on the partition,
# not to a column of the JSON file.
# NOTE(review): the original note expected this value to be 8 for the
# example data -- confirm against gerrymandria.json.
district_populations = initial_partition["population"]
ideal_population = sum(district_populations.values()) / len(initial_partition)

# Fix recom's keyword arguments up front so the resulting callable can be
# handed to the chain as its proposal.
recom_settings = dict(
    pop_col="TOTPOP",
    pop_target=ideal_population,
    epsilon=0.01,
    node_repeats=2,
)
proposal = partial(recom, **recom_settings)

In [5]:
" Set up the chain "
# Chain that starts from the initial partition, proposes with the recom
# partial, checks the `contiguous` constraint and uses the always_accept
# acceptance function, for 40 steps in total.
recom_chain = MarkovChain(
    initial_state=initial_partition,
    proposal=proposal,
    constraints=[contiguous],
    accept=accept.always_accept,
    total_steps=40,
)

In [6]:
" Run the chain "

# One recorded assignment per state yielded by the chain, in chain order.
assignment_list = []

for i, partition in enumerate(recom_chain):
    # end="\r" returns the cursor to the start of the line, so each progress
    # message overwrites the previous one instead of adding a new line.
    print(f"Finished step {i+1}/{len(recom_chain)}", end="\r")
    assignment_list.append(partition.assignment)

Finished step 40/40

In [None]:
" collect the assignment at each step of the chain so that we can watch the chain work in a fun animation "

%matplotlib inline
import matplotlib_inline.backend_inline
import matplotlib.cm as mcm
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
import io
import ipywidgets as widgets
from IPython.display import display, clear_output

# Render one PNG image per recorded assignment. The images are kept in
# memory in `frames`, in the same order as `assignment_list`.
frames = []

# The layout depends only on the graph's node attributes, not on the step,
# so it is built once instead of once per frame.
pos = {node: (data['x'], data['y']) for node, data in graph.nodes(data=True)}

for assignment in assignment_list:
    fig, ax = plt.subplots(figsize=(8,8))

    # Colour each node by its district id, wrapped modulo 20 before indexing
    # the tab20 colormap, and label it with that id.
    node_colors = [mcm.tab20(int(assignment[node]) % 20) for node in graph.nodes()]
    node_labels = {node: str(assignment[node]) for node in graph.nodes()}

    # Draw on the axes created above rather than relying on pyplot's
    # implicit "current axes".
    nx.draw_networkx_nodes(graph, pos, node_color=node_colors, ax=ax)
    nx.draw_networkx_edges(graph, pos, ax=ax)
    nx.draw_networkx_labels(graph, pos, labels=node_labels, ax=ax)
    ax.axis('off')

    # Save the figure to an in-memory PNG and reopen it as a PIL image.
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png')
    buffer.seek(0)
    image = Image.open(buffer)
    frames.append(image)

    # Close the figure once its image has been captured.
    plt.close(fig)

def show_frame(idx):
    """Replace the current output with the pre-rendered frame at ``idx``.

    Parameters
    ----------
    idx : int
        Position in the module-level ``frames`` list of the image to show.
    """
    # wait=True defers the clear until new output is available to replace it.
    clear_output(wait=True)
    selected_frame = frames[idx]
    display(selected_frame)

# Slider covering every valid index into `frames`.
last_frame_index = len(frames) - 1
slider = widgets.IntSlider(
    description='Frame:',
    value=0,
    min=0,
    max=last_frame_index,
    step=1,
)
slider.layout.width = '500px'

# Kept as the final expression so the notebook displays the interactive
# widget; show_frame is called with the slider's value as `idx`.
widgets.interactive(show_frame, idx=slider)