<a href="https://colab.research.google.com/github/drscook/MathVGerrmandering_CMAT_2022/blob/main/run_mcmc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## Colab setup cell: installs the packages this notebook needs and mounts Google Drive.
## As of June 13, 2022, the version of Gerrychain installed by a plain
## `pip install` failed with
## "module 'functools' has no attribute 'cached_property'".
## We avoid this by installing gerrychain and geopandas from conda-forge
## (via condacolab + mamba) instead.
! pip install -q condacolab
import condacolab
condacolab.install()  
! mamba install -q -y -c conda-forge gerrychain geopandas
from IPython import get_ipython
# NOTE(review): do_shutdown(True) presumably restarts the kernel so the
# conda-installed packages become importable; the lines after it may not run
# in this same execution — confirm whether this cell has to be re-run.
get_ipython().kernel.do_shutdown(True)

## The remaining packages install fine with a plain pip install
! pip install pandas-bokeh 
# Mount Google Drive at /content/drive; the data and graph files are read from there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import warnings, pathlib, functools, dataclasses, math
import numpy as np, pandas as pd, geopandas as gpd, networkx as nx, gerrychain as gc
import matplotlib.pyplot as plt, pandas_bokeh
from gerrychain.proposals import recom
from gerrychain.accept import always_accept

# Silence known noisy warnings (presumably from geopandas/shapely) while reading
# and plotting geometries. Patterns are regexes matched from the message start,
# so each is wrapped in '.*' on both sides.
warnings.filterwarnings('ignore', message='.*initial implementation of Parquet.*')
warnings.filterwarnings('ignore', message='.*Setting custom attributes on geometry objects is deprecated.*')
warnings.filterwarnings('ignore', message='.*Iteration over multi-part geometries is deprecated.*')

# Render Bokeh figures inline in the notebook.
pandas_bokeh.output_notebook()
root_path  = pathlib.Path('/content/drive/MyDrive/gerrymandering/summer2022')
data_file  = root_path / 'data/vtd/final.parquet'
graph_file = root_path / 'graph.json'

meters_per_mile = 1609.34
try:
    # Reuse the cached gerrychain graph if it has already been built.
    G = gc.Graph.from_json(graph_file)
except FileNotFoundError:
    # Otherwise build it from the geometries and cache it to graph_file.
    # Building is slow, so keep graph.json around to avoid rebuilding it.
    gdf = gpd.read_parquet(data_file, columns=['geometry'])
    G = gc.Graph.from_geodataframe(gdf, reproject=False)
    G.to_json(graph_file)

In [None]:
@dataclasses.dataclass
class MyGerryChain():
    """Load a districting plan and run a gerrychain ReCom Markov chain from it.

    Intended call order (each step depends on the previous one):
      1. ``MyGerryChain(...)`` -- reads the data file and graph and builds one
         Tally updater per ``tally`` column plus ``cut_edges``;
      2. optionally add more entries to ``self.updaters``;
      3. ``setup()`` -- builds ``self.initial_partition`` and calls every
         updater once on it;
      4. optionally add entries to ``self.constraints``;
      5. ``run(steps)`` -- stores every state of the chain in ``self.partitions``.

    Fields:
        pop_tol: population tolerance. NOTE(review): not read anywhere in this
            class; ReCom's tolerance is the ``epsilon`` argument of ``run``.
        plan: data-file column holding each node's initial district.
        tally: data-file columns summed per district (one Tally updater each).
    """
    pop_tol:float = 0.10
    plan   :str = 'plans2168'
    tally  :tuple = ('pop', 'p2:hispanic or latino',)

    def __post_init__(self):
        cols = ['fips', 'county', self.plan] + list(self.tally)
        # 'pop' drives ideal_pop/seats below and is ReCom's pop_col in run(),
        # so read it even when the caller's tally leaves it out.
        if 'pop' not in cols:
            cols.append('pop')
        self.geo = gpd.read_parquet(data_file, columns=['geometry'])
        self.data = pd.read_parquet(data_file, columns=cols)
        self.parts = np.unique(self.data[self.plan])
        # Population of a perfectly balanced district.
        self.ideal_pop = self.data['pop'].sum() / len(self.parts)
        # Each node's population as a fraction of one ideal district.
        self.data['seats'] = self.data['pop'] / self.ideal_pop
        self.graph = gc.Graph.from_json(graph_file)
        # Attach the data columns (including 'seats') to the graph's nodes.
        self.graph.add_data(self.data)
        self.updaters = {u: gc.updaters.Tally(u) for u in self.tally}
        self.updaters['cut_edges'] = gc.updaters.cut_edges
        self.constraints = dict()

    def setup(self):
        """Build ``self.initial_partition`` from ``self.plan`` and prime every updater on it.

        Call after all updaters have been added and before constraints are created.
        """
        self.initial_partition = gc.GeographicPartition(
            self.graph,
            assignment=self.plan,
            updaters=self.updaters,
        )
        # Initialize each updater with the initial partition: the stateful
        # updaters (PopDev, PolsbyPopper, Defect) set value/initial_value here,
        # and the constraint classes read func.value when they are constructed.
        for u in self.updaters.values():
            u(self.initial_partition)

    def run(self, steps=10, epsilon=0.02, node_repeats=2):
        """Run a ReCom chain from ``self.initial_partition`` and keep every state.

        Args:
            steps: chain length, passed to MarkovChain as ``total_steps``.
            epsilon: ReCom population tolerance relative to ``self.ideal_pop``
                (previously hard-coded to 0.02).
            node_repeats: ReCom ``node_repeats`` (previously hard-coded to 2).

        Must be called after ``setup()``. Every state produced by the chain is
        stored, in order, in ``self.partitions``.
        """
        proposal = functools.partial(recom,
            pop_col='pop',
            pop_target=self.ideal_pop,
            epsilon=epsilon,
            node_repeats=node_repeats,
            )

        chain = gc.MarkovChain(
            proposal=proposal,
            constraints=self.constraints.values(),
            accept=always_accept,
            initial_state=self.initial_partition,
            total_steps=steps
        )
        self.partitions = list(chain)

class PolsbyPopper():
    """gerrychain updater: Polsby-Popper compactness score of each district.

    A district's score is 4*pi*area / perimeter**2, read from the partition's
    'area' and 'perimeter' updaters. The sum of all scores is stored in
    ``self.value``; the sum from the first call is kept in ``self.initial_value``.
    Returns a dict mapping district -> score.
    """
    def __call__(self, partition):
        scores = {}
        for district, perimeter in partition['perimeter'].items():
            scores[district] = 4*math.pi*partition['area'][district] / perimeter**2
        total = sum(scores.values())
        self.value = total
        # Remember only the very first total as the baseline.
        if not hasattr(self, 'initial_value'):
            self.initial_value = total
        return scores

class PopDev():
    """gerrychain updater: relative population deviation of each district.

    On the first call, ``self.target`` is set to the total of ``partition['pop']``
    divided by ``len(partition)`` (assumed to be the number of districts) and is
    reused on every later call. Each district's deviation is pop / target - 1.
    ``self.value`` holds the largest absolute deviation; ``self.initial_value``
    holds that value from the first call. Returns a dict district -> deviation.
    """
    def __call__(self, partition):
        populations = partition['pop']
        try:
            target = self.target
        except AttributeError:
            target = self.target = sum(populations.values()) / len(partition)
        deviation = {}
        for district, pop in populations.items():
            deviation[district] = pop / target - 1
        self.value = max(map(abs, deviation.values()))
        if not hasattr(self, 'initial_value'):
            self.initial_value = self.value
        return deviation

class Defect():
    """gerrychain updater: county-splitting "defect" score per county.

    For each county, ``seats`` is the sum of the node attribute 'seats' over the
    county's nodes. A district *intersects* a county if any of its nodes lie in
    it, and is *contained* in a county if all of its nodes lie in that one county.
    The county's defect is
        |ceil(seats) - #intersecting districts| + |floor(seats) - #contained districts|.
    ``self.value`` is the sum over counties; ``self.initial_value`` is that sum
    from the first call. Returns a dict county -> defect.
    """
    def __call__(self, partition):
        attrs = partition.graph.nodes
        seats = {}            # county -> total seats, in first-seen order
        counties_of_part = {} # district -> set of counties it touches
        for node, part in partition.assignment.items():
            county = attrs[node]['county']
            counties_of_part.setdefault(part, set()).add(county)
            seats[county] = seats.get(county, 0.0) + attrs[node]['seats']

        intersecting = {county: set() for county in seats}
        contained = {county: set() for county in seats}
        for part, touched in counties_of_part.items():
            single_county = len(touched) == 1
            for county in touched:
                intersecting[county].add(part)
                if single_county:
                    contained[county].add(part)

        defects = {}
        for county, total in seats.items():
            defects[county] = (abs(math.ceil(total) - len(intersecting[county]))
                               + abs(math.floor(total) - len(contained[county])))
        self.value = sum(defects.values())
        if not hasattr(self, 'initial_value'):
            self.initial_value = self.value
        return defects

class PushHoldConstraint():
    """Stateful gerrychain constraint: push a metric down to ``target``, then hold it there.

    ``func`` is an updater object (e.g. PopDev, Defect) that, when called on a
    partition, stores a scalar in ``func.value``. It must already have been
    called once (so ``func.value`` exists) before this constraint is built.

    A proposal passes when its value is <= max(target, reference), where
    reference is the value of the state the proposal was generated from
    (``partition.parent``). So while the metric is above ``target`` it may not
    get worse, and once at or below ``target`` it may not rise above it.

    The reference comes from the parent rather than from the last proposal that
    passed here. A proposal can pass this constraint and still be rejected by
    another constraint; the chain then stays put, and that proposal's value must
    not become the new bar.

    Attributes:
        value: value of the most recent partition that passed this constraint.
        target: the push/hold threshold (defaults to func's current value).
    """
    def __init__(self, func, target=None):
        self.func = func
        self.value = self.func.value
        if target is None:
            self.target = self.value
        else:
            self.target = target
        # (parent partition, its metric value): the parent's value is computed
        # only once per chain state, however many proposals are tried from it.
        self._reference = None

    def _reference_value(self, partition):
        """Return the metric value of the state ``partition`` was proposed from."""
        parent = getattr(partition, 'parent', None)
        if parent is None:
            # Initial state (no parent): compare against the stored value.
            return self.value
        if self._reference is None or self._reference[0] is not parent:
            self.func(parent)
            self._reference = (parent, self.func.value)
        return self._reference[1]

    def __call__(self, partition):
        """Return True if ``partition`` satisfies the push/hold rule, else False."""
        reference = self._reference_value(partition)
        # Evaluate the proposal last so func.value reflects the proposal afterwards.
        self.func(partition)
        v = self.func.value
        if v <= max(self.target, reference):
            self.value = v
            return True
        return False

class HoldConstraint(PushHoldConstraint):
    """PushHoldConstraint whose target is ``func``'s current value.

    The metric may therefore never rise above the value it had when this
    constraint was created.
    """
    def __init__(self, func):
        # Leaving target unset makes the base class adopt func.value as the
        # target, which is exactly the level this subclass holds.
        super().__init__(func)

class PushConstraint(PushHoldConstraint):
    """PushHoldConstraint with a target of 0.

    While the metric is positive it may not increase; once it reaches 0 or
    below it may not exceed 0.
    """
    def __init__(self, func):
        zero_target = 0
        super().__init__(func, zero_target)


### PLOT ###
def plot(chain, file=None):
    """Draw every plan in ``chain.partitions`` on one interactive Bokeh map with a step slider.

    Args:
        chain: a ``MyGerryChain`` after ``run()``. Reads ``chain.partitions``,
            ``chain.geo`` (node geometries) and ``chain.data`` (for the county
            name shown on hover).
        file: optional path; when given, the generated HTML is also written there.
    """
    height = 600
    colormap = "Paired"

    # Map extent (presumably lon/lat degrees); width preserves the aspect ratio.
    xlim = [-106.2, -94.0]
    ylim = [ 25.4 ,  36.6]
    width = round((xlim[1] - xlim[0]) / (ylim[1] - ylim[0]) * height)

    # One column per chain step; each row is a node holding its district label.
    B = pd.concat([p.assignment.to_series() for p in chain.partitions], axis=1)
    # Spread the district labels evenly over 0..256 so the colormap separates them.
    # Key the mapping on the actual labels: enumerate() assumed labels 0..k-1,
    # so plans labelled e.g. 1..k left their highest district unmapped.
    labels = np.unique(B[0])
    codes = np.linspace(0, 256, len(labels)).round().astype(int)
    B = B.replace(dict(zip(labels.tolist(), codes.tolist())))
    steps = [str(x) for x in B.columns]
    B.columns = steps

    # chain.geo only carries geometry, so join 'county' from chain.data for the
    # @county hover field. NOTE(review): @vtd presumably comes from the index via
    # reset_index() — confirm the index is named 'vtd'.
    X = chain.geo.join(chain.data[['county']]).join(B).reset_index()
    fig = X.plot_bokeh(
        simplify_shapes=100,
        hovertool_string = '@county<br>@vtd<br>',
        slider=steps,
        slider_name="step",
        fill_alpha = 0.8,
        line_alpha = 0.00,
        show_colorbar = False,
        xlim = xlim,
        ylim = ylim,
        figsize = (width, height),
        colormap = colormap,
        return_html = True,
        show_figure = True,
    )
    # Save only when a path was given; previously a TypeError from open(None) was
    # swallowed, which could also hide unrelated TypeErrors.
    if file is not None:
        with open(file, 'w') as f:
            f.write(fig)

In [None]:
# Build the chain object: loads the data file and graph and creates the
# Tally / cut_edges updaters.
chain = MyGerryChain()

## add extra updaters (if needed) between init and setup
## (setup() calls every updater once on the initial plan, which records each
## stateful updater's value and initial_value)
chain.updaters['pop_dev'] = PopDev()
chain.updaters['polsby_popper'] = PolsbyPopper()
chain.updaters['defect'] = Defect()

## add extra constraints (if needed) between setup and run
## (the constraints read func.value, which only exists once setup() has run)
chain.setup()
# pop_dev: may not increase while above 0.03, and may not exceed 0.03 once at or below it.
chain.constraints['pop_dev'] = PushHoldConstraint(chain.updaters['pop_dev'], target=0.03)
# defect: may never exceed its value on the initial plan.
chain.constraints['defect']  = HoldConstraint(chain.updaters['defect'])

# Run the ReCom chain; every state is stored in chain.partitions.
chain.run(steps=10)

In [None]:
# Draw every plan in chain.partitions on one map with a step slider
# (no file argument, so no HTML file is written).
plot(chain)