In [None]:
# Imports

import os

from gerrychain import Graph, GeographicPartition, Partition, Election, accept
from gerrychain.updaters import Tally, cut_edges
import geopandas as gpd
import numpy as np
from gerrychain.random import random
import copy
import seaborn as sns

from gerrychain import MarkovChain
from gerrychain.constraints import single_flip_contiguous
from gerrychain.proposals import recom, propose_random_flip
from gerrychain.accept import always_accept
from gerrychain.metrics import polsby_popper
from gerrychain import constraints
from gerrychain.constraints import no_vanishing_districts

from collections import defaultdict, Counter

import matplotlib.pyplot as plt

import networkx as nx

import pandas

import math

from itertools import combinations_with_replacement

#from IPython.display import clear_output

from functools import partial


In [None]:
# Setup -- SLOW: downloads and parses the statewide PA VTD shapefile.

# Column names used throughout the notebook.
county_col = "COUNTYFP10"
pop_col = "TOT_POP"
uid = "GEOID10"

shapefile = "https://github.com/mggg-states/PA-shapefiles/raw/master/PA/PA_VTD.zip"
df = gpd.read_file(shapefile)

# Build the adjacency graph, copy every dataframe column onto the nodes, then
# relabel the nodes (keyed by the dataframe index) to their GEOID10 values.
graph = Graph.from_geodataframe(df, ignore_errors=True)
graph.add_data(df, list(df))
graph = nx.relabel_nodes(graph, df[uid])

# Every county code present, and a node -> county code lookup.
counties = set(df[county_col])
countydict = dict(graph.nodes(data=county_col))

In [None]:
# Total population and number of districts; together they give the ideal district size.
totpop = 0
num_districts = 18
for n in graph.nodes():
    # graph.node[...] was removed in networkx 2.4; graph.nodes[...] is the supported accessor.
    # Cast to int so population sums below are integer sums.
    graph.nodes[n]["TOT_POP"] = int(graph.nodes[n]["TOT_POP"])
    totpop += graph.nodes[n]["TOT_POP"]

In [None]:
# Updaters attached to every Partition built in this notebook.
updaters1 = dict(
    polsby_popper=polsby_popper,
    cut_edges=cut_edges,
    population=Tally(pop_col, alias="population"),
)

In [None]:
# Cast these plan-assignment columns to int. Presumably this makes the district
# labels match the integer district numbers used later (e.g. district_i + 1).
# NOTE(review): the other four plan columns are assumed to be ints already.
for n in graph.nodes():
    for plan_col in ('538CPCT__1', '538DEM_PL', '538GOP_PL', '8THGRADE_1'):
        graph.nodes[n][plan_col] = int(graph.nodes[n][plan_col])

# One Partition per plan column, built in this fixed order.
partitions = [
    Partition(graph, plan_col, updaters1)
    for plan_col in ("2011_PLA_1", "GOV", "TS", "REMEDIAL_P",
                     "538CPCT__1", "538DEM_PL", "538GOP_PL", "8THGRADE_1")
]
(partition_2011, partition_GOV, partition_TS, partition_REMEDIAL,
 partition_CPCT, partition_DEM, partition_GOP, partition_8th) = partitions

In [None]:
# Starting plan: the "GOV" assignment as a GeographicPartition with the shared updaters.
# `updaters` must be passed by keyword. A positional argument after
# `assignment=` is a SyntaxError.
starting_partition = GeographicPartition(
    graph,
    assignment="GOV",
    updaters=updaters1,
)

In [None]:
def county_splits_dict(partition):
    '''Return, for each district, a Counter of how many VTDs it takes from each county.

    Keys are the district labels from ``partition.assignment.parts``. Each value
    is a ``Counter`` mapping COUNTY_ID -> number of that county's VTDs assigned
    to the district. Counties the district does not touch are absent.
    '''
    return {
        district: Counter(countydict[vtd] for vtd in vtds)
        for district, vtds in partition.assignment.parts.items()
    }

In [None]:
def district_splits_dict(county_splits):
    '''Invert a district -> county-counter mapping into county -> list of districts.

    ``county_splits`` maps district -> mapping keyed by county (as returned by
    county_splits_dict). Every county in ``counties`` gets an entry. The entry lists,
    in ``county_splits`` order, the districts containing at least one VTD of that
    county, and is an empty list if no district does.
    '''
    district_splits = {county: [] for county in counties}
    for county in counties:
        for district, county_counter in county_splits.items():
            if county in county_counter:
                district_splits[county].append(district)
    return district_splits
            

In [None]:
# various functions to measure splits according to the proposed PA rule. Feel free to ignore

def pieces_allowed():
    '''Return, per county, the maximum number of district pieces allowed by the proposed PA rule.

    The allowance is ceil(county_pop / ideal_district_pop) + 1, where
    ideal_district_pop = totpop / num_districts. A county with no VTDs gets 1.

    Returns:
        dict mapping county code -> allowed number of pieces.
    '''
    ideal_pop = totpop / num_districts
    # One pass over the nodes (instead of one subgraph per county). graph.nodes
    # replaces graph.node, which was removed in networkx 2.4.
    county_pops = Counter()
    for _, attrs in graph.nodes(data=True):
        county_pops[attrs[county_col]] += attrs["TOT_POP"]
    return {county: math.ceil(county_pops[county] / ideal_pop) + 1 for county in counties}

def other_pieces_allowed():
    '''Return, per county, ceil(county_pop / ideal_district_pop) pieces (no +1 allowance).

    ideal_district_pop = totpop / num_districts. A county with no VTDs gets 0.

    Returns:
        dict mapping county code -> allowed number of pieces.
    '''
    ideal_pop = totpop / num_districts
    # One pass over the nodes (instead of one subgraph per county). graph.nodes
    # replaces graph.node, which was removed in networkx 2.4.
    county_pops = Counter()
    for _, attrs in graph.nodes(data=True):
        county_pops[attrs[county_col]] += attrs["TOT_POP"]
    return {county: math.ceil(county_pops[county] / ideal_pop) for county in counties}

def too_many_pieces(partition):
    '''Return how many counties are split among more districts than pieces_allowed() permits.'''
    splits = district_splits_dict(county_splits_dict(partition))
    limits = pieces_allowed()
    return sum(len(splits[county]) > limits[county] for county in counties)

def other_too_many_pieces(partition):
    '''Return how many counties are split among more districts than other_pieces_allowed() permits.'''
    splits = district_splits_dict(county_splits_dict(partition))
    limits = other_pieces_allowed()
    return sum(len(splits[county]) > limits[county] for county in counties)

def how_many_more(partition):
    '''Return the total number of pieces by which counties exceed pieces_allowed().

    Counties at or under their allowance contribute 0.
    '''
    splits = district_splits_dict(county_splits_dict(partition))
    limits = pieces_allowed()
    return sum(max(len(splits[county]) - limits[county], 0) for county in counties)

In [None]:
def cut_in_county(part, sg):
    '''Return how many of the partition's cut edges are also edges of the subgraph ``sg``.'''
    sg_edges = sg.edges()
    return sum(1 for edge in part["cut_edges"] if edge in sg_edges)

In [None]:
def our_split_score_1(part):
    '''Sum over counties of the fraction of each county's internal edges that are cut.

    Counties whose induced subgraph has no edges (e.g. a single VTD) are
    skipped rather than raising ZeroDivisionError. Such a county has no
    internal edge that could be cut.
    '''
    score = 0  # not named `sum`, to avoid shadowing the builtin
    for county in counties:
        sg = graph.subgraph(n for n, v in graph.nodes(data=True) if v[county_col] == county)
        n_edges = len(sg.edges())
        if n_edges == 0:
            continue
        score += cut_in_county(part, sg) / n_edges
    return score

In [None]:
def our_split_score_2(part):
    '''Return the fraction of cut edges whose endpoints lie in different counties.

    County codes from ``countydict`` are compared as integers. Returns 0.0
    when the partition has no cut edges instead of dividing by zero.
    '''
    cut = part["cut_edges"]
    if len(cut) == 0:
        return 0.0
    ce_btn_counties = sum(
        1 for ce in cut
        if int(countydict[str(ce[0])]) != int(countydict[str(ce[1])])
    )
    return ce_btn_counties / len(cut)

In [None]:
#def vtds_per_county(county_splits):
#    vtds = {}
#    
#    for counter in county_splits.values():
#        for county in counter.keys():
#            if county in vtds:
#                vtds[county] += counter[county]
#            else:
#                vtds[county] = counter[county]
#    return vtds

In [None]:
def pops_per_county(county_splits, rev):
    '''Return {county: total "TOT_POP" of its VTDs}, using the county -> VTD lists in ``rev``.

    ``county_splits`` is accepted for call compatibility but is not used.
    '''
    return {
        county: sum(graph.nodes[vtd]["TOT_POP"] for vtd in vtds)
        for county, vtds in rev.items()
    }

In [None]:
#def vtds_per_district(county_splits):
#    vtds = {}
#    
#    for district in county_splits.keys():
#        sum = 0
#        counter = county_splits[district]
#        for vtd in counter.values():
#            sum += vtd
#        vtds[district] = sum
#    return vtds        

In [None]:
def pops_per_district(partition):
    '''Return {district: total "TOT_POP"} for districts numbered 1..num_districts.

    Districts with no VTDs are omitted. Assignment values that are not equal to
    one of 1..num_districts are ignored. Keys are ints in ascending order.
    Makes a single pass over the assignment rather than one pass per district.
    '''
    valid = range(1, num_districts + 1)
    totals = {}
    for vtd, district in dict(partition.assignment).items():
        if district in valid:
            totals[district] = totals.get(district, 0) + graph.nodes[vtd]["TOT_POP"]
    # Re-key from `valid` so keys are the int district numbers in ascending order.
    return {d: totals[d] for d in valid if d in totals}

In [None]:
#def total_vtds(vtds):
#    total = 0
#    
#    for county in vtds.keys():
#        total += vtds[county]
#    return total

In [None]:
#def total_pops(pops):
#    total = 0
#    for pop in pops.values():
#        total += pop
#    return total

In [None]:
def VTDs_to_Counties(partition):
    '''
    Break each district's population down by county.

    ``partition.parts`` maps district -> collection of VTD ids. For every
    district, the result holds a dict {county: population} with an entry for
    every county in ``counties`` (0 where the district has none of that county).
    Each VTD adds ``graph.nodes[vtd][pop_col]`` to the entry for ``countydict[vtd]``.
    '''
    result = {}
    for district, vtds in dict(partition.parts).items():
        pop_by_county = dict.fromkeys(counties, 0)
        for vtd in vtds:
            pop_by_county[countydict[vtd]] += graph.nodes[vtd][pop_col]
        result[district] = pop_by_county
    return result

In [None]:
def dictionary_to_score(dictionary):
    '''Score a nested {outer: {inner: population}} mapping.

    Each outer key with populations p_k and total T = sum(p_k) contributes
    T * sum_k sqrt(p_k / T). The contributions are summed.
    '''
    score = 0
    for pops in dictionary.values():
        total = sum(pops.values())
        score += total * sum(np.sqrt(pop / total) for pop in pops.values())
    return score

def invert_dict(dictionary):
    '''Swap the two key levels of a nested mapping: {a: {b: v}} -> {b: {a: v}}.

    Returns a ``defaultdict(dict)``.
    '''
    flipped = defaultdict(dict)
    triples = ((inner_key, outer_key, value)
               for outer_key, inner in dictionary.items()
               for inner_key, value in inner.items())
    for inner_key, outer_key, value in triples:
        flipped[inner_key][outer_key] = value
    return flipped
    
def moon_score(partition):
    '''Return dictionary_to_score of the district -> county population table plus that of its county -> district inversion.'''
    district_table = VTDs_to_Counties(partition)
    district_side = dictionary_to_score(district_table)
    county_side = dictionary_to_score(invert_dict(district_table))
    return district_side + county_side

In [None]:
#def p_i_given_j(county_splits, vtds, district_i, county_j):
#    counter = county_splits[district_i]
#    intersection = counter[str(county_j)]
#    
#    return intersection / vtds[str(county_j)]

In [None]:
# Per-county lookup tables, built from a single grouping pass over the graph
# instead of three separate scans of every node for every county.
#   countynodelist:     county -> frozenset of its VTD ids
#   county_subgraphs:   county -> subgraph induced on those VTDs
#   county_edges:       county -> number of edges inside the county
#   county_edge_count:  same counts; cut_edges_in_county reads this name
#   total_edges:        within-county edges summed over all counties
_nodes_by_county = defaultdict(list)
for _node, _county in graph.nodes(data=county_col):
    _nodes_by_county[_county].append(_node)

countynodelist = {
    county: frozenset(_nodes_by_county.get(county, ())) for county in counties
}
county_subgraphs = {county: graph.subgraph(countynodelist[county]) for county in counties}
county_edges = {county: len(county_subgraphs[county].edges()) for county in counties}
county_edge_count = dict(county_edges)
total_edges = sum(county_edges.values())

def cut_edges_in_county(partition):
    '''Return the sum over counties of the fraction of each county's internal edges that are cut.

    Only cut edges whose two endpoints map to the same county (via
    ``countydict.get``) are counted. Each county's count is divided by
    ``county_edge_count[county]``, and the ratios are summed. The result is a
    float (0 when no cut edge lies inside a county), not an integer.
    '''
    within_county = Counter()
    for edge in partition["cut_edges"]:
        county_1 = countydict.get(edge[0])
        county_2 = countydict.get(edge[1])
        if county_1 == county_2:
            within_county[county_1] += 1
    return sum(count / county_edge_count[county] for county, count in within_county.items())

In [None]:
def cut_edges_in_district(partition):
    '''Return the fraction of the partition's cut edges whose endpoints lie in different counties.

    Counties are looked up with ``countydict.get``, so an endpoint missing
    from ``countydict`` is treated as county ``None``. Returns 0.0 when there
    are no cut edges instead of raising ZeroDivisionError.
    '''
    cut_edge_set = partition["cut_edges"]
    num_cut_edges = len(cut_edge_set)
    if num_cut_edges == 0:
        return 0.0
    cut_edges_between = sum(
        1 for edge in cut_edge_set
        if countydict.get(edge[0]) != countydict.get(edge[1])
    )
    return cut_edges_between / num_cut_edges

In [None]:
#def q_j(vtds_d,county_j,total):
#    return vtds[county_j] / total

In [None]:
#def power_entropy(county_splits,vtds,total,alpha):
#    entropy = 0
#    for county_j in counties:
#        inner_sum = 0
#        q = q_j(vtds,county_j,total)
#        for district_i in range(num_districts):
#            p = p_i_given_j(county_splits, vtds,district_i+1,county_j)
#            inner_sum += p ** (1-alpha)
#        entropy += 1/q * (inner_sum-1)
#        #print(1/q * (inner_sum-1))
#    return entropy

In [None]:
#def Shannon_entropy(county_splits, vtds, total):
#    entropy = 0
#    for county_j in counties:
#        inner_sum = 0
#        q = q_j(vtds,county_j,total)
#        for district_i in range(num_districts):
#            p = p_i_given_j(county_splits, vtds,district_i+1,county_j)
#            if p != 0:
#                inner_sum += p * math.log(1/p)
#        entropy += q * (inner_sum)
#        #print(1/q * (inner_sum-1))
#    return entropy

In [None]:
#def p_i(vtds,district_i,total):
#    return vtds[district_i] / total

In [None]:
#def q_j_given_i(county_splits, vtds_d, district_i, county_j):
#    counter = county_splits[district_i]
#    intersection = counter[str(county_j)]
#    
#    return intersection / vtds_d[district_i]

In [None]:
#def other_power_entropy(county_splits,vtds_d,total,alpha):
#    entropy = 0
#    for district_i in range(num_districts):
#        innersum = 0
#        p = p_i(vtds_d,district_i+1,total)
#        for county_j in counties:
#            q = q_j_given_i(county_splits,vtds_d,district_i+1,county_j)
#            innersum += q ** (1-alpha)
#        entropy += 1/p * (innersum-1)
#    return entropy

In [None]:
#def symmetric_power_entropy(county_splits,vtds_c,vtds_d,total,alpha):
#    return power_entropy(county_splits,vtds_c,total,alpha) + other_power_entropy(county_splits,vtds_d,total,alpha)

In [None]:
def edge_entropy(partition):
    '''Return a county-weighted entropy of the district labels on within-county edges.

    Each edge of a county's induced subgraph is labelled by the unordered set
    of its endpoints' districts: {d} if both ends are in district d, and
    {d1, d2} if the edge is cut. For each label, p = (edges with that label) /
    (edges in the county). The county contributes
    -sum(p * ln p) * county_edges / total_edges, where total_edges is the
    edge count of the whole graph, including edges between counties.

    Counties whose induced subgraph has no edges contribute nothing, rather
    than dividing by zero.
    '''
    assignment = partition.assignment
    total_edges = len(graph.edges())

    # Group nodes by county once per call, instead of scanning every node for every county.
    nodes_by_county = defaultdict(list)
    for node, county in graph.nodes(data=county_col):
        nodes_by_county[county].append(node)

    entropy = 0
    for county in counties:
        county_subgraph = graph.subgraph(nodes_by_county.get(county, []))
        county_edge_list = list(county_subgraph.edges())
        county_edges = len(county_edge_list)
        if county_edges == 0:
            continue
        # Count every label in a single pass over the county's edges.
        label_counts = Counter(
            frozenset((assignment[u], assignment[v])) for u, v in county_edge_list
        )
        for count in label_counts.values():
            p_ij = count / county_edges
            entropy -= p_ij * np.log(p_ij) * county_edges / total_edges
    return entropy

def num_of_splittings(partition):
    '''Return the number of (district, county) pairs in which the district holds at least one VTD of the county.'''
    return sum(len(county_counts) for county_counts in county_splits_dict(partition).values())

In [None]:
###################################################3

# returns a dictionary that maps a county to a list of VTDs that are in the county

def reverse_countydict():
    '''Return a dict mapping each county code in ``counties`` to the list of its VTD ids.

    Every county gets a key (possibly with an empty list). VTDs appear in
    ``countydict`` iteration order, and VTDs whose county is not in ``counties``
    are ignored. The result is built in one pass over ``countydict`` rather
    than one pass per county.
    '''
    rev = {county: [] for county in counties}
    for vtd, county in countydict.items():
        if county in rev:
            rev[county].append(vtd)
    return rev

In [None]:
####################################################3

# calculates the population of a given county

def county_pop(rev, county_j):
    '''Return the summed "TOT_POP" of the VTDs listed for ``county_j`` in ``rev``.'''
    return sum(graph.nodes[vtd]["TOT_POP"] for vtd in rev[county_j])

In [None]:
###################################################33

# calculates population of given district

def district_pop(part, district_i):
    '''Return the summed "TOT_POP" of the VTDs in district ``district_i`` of ``part``.'''
    members = dict(part.parts)[district_i]
    return sum(graph.nodes[vtd]["TOT_POP"] for vtd in members)

In [None]:
#############################################3

#calculates population of intersection of given district and county

def intersection_pop(part, county_vtds, county_j, district_i):
    '''Return the summed "TOT_POP" of VTDs that are both in county ``county_j`` and district ``district_i``.

    ``county_vtds`` maps county -> list of VTD ids; ``part.parts`` maps
    district -> its VTDs. The district's VTD collection is looked up once,
    rather than rebuilding ``dict(part.parts)`` for every VTD of the county.
    '''
    county_members = county_vtds[county_j]
    if not county_members:
        # Empty county: nothing to intersect, and the district need not exist.
        return 0
    district_members = dict(part.parts)[district_i]
    return sum(graph.nodes[vtd]["TOT_POP"] for vtd in county_members if vtd in district_members)

In [None]:
#calculates power entropy

def power_entropy(partition, rev, alpha):
    '''Return sum_j (1 / q_j) * (sum_i p_ij ** (1 - alpha) - 1) over counties j.

    q_j = county_pop / totpop, and p_ij is the share of county j's population
    in district i (for i = 1..num_districts). Districts with p_ij == 0 are
    skipped, as relative_entropy does. Without the skip, 0 ** (1 - alpha)
    raises ZeroDivisionError for alpha > 1 and counts as 1 for alpha == 1.
    For alpha < 1 the result is unchanged, because those terms are 0.
    '''
    entropy = 0
    for county_j in counties:
        inner_sum = 0
        cpop = county_pop(rev, county_j)
        q = cpop / totpop
        for district_i in range(num_districts):
            p = intersection_pop(partition, rev, county_j, district_i + 1) / cpop
            if p != 0:
                inner_sum += p ** (1 - alpha)
        entropy += 1 / q * (inner_sum - 1)
    return entropy

In [None]:
###########################################33

# calculates Shannon entropy
# rev is dictionary mapping county to list of VTDs in county

def relative_entropy(partition, rev):
    '''Return the population-weighted base-2 entropy of how counties are split among districts.

    For each county j with population share q_j = county_pop / totpop,
    H_j = sum_i p_ij * log2(1 / p_ij) over districts i = 1..num_districts,
    where p_ij is the share of county j's population in district i (zero
    shares are skipped). Returns sum_j q_j * H_j.
    '''
    total = 0
    for county in counties:
        county_total = county_pop(rev, county)
        shares = (intersection_pop(partition, rev, county, d) / county_total
                  for d in range(1, num_districts + 1))
        county_entropy = sum(p * math.log(1 / p, 2) for p in shares if p != 0)
        total += (county_total / totpop) * county_entropy
    return total

In [None]:
#county_vtds = county_splits_dict(starting_partition)
#print(power_entropy(county_splits,vtds,total,0.5))
# rev: county code -> list of VTD ids in that county (built by reverse_countydict);
# passed to relative_entropy in the entropy loop below.
rev = reverse_countydict() 

#print(Shannon_entropy(starting_partition,rev))

In [None]:
d = county_splits_dict(starting_partition)
# Number of counties whose VTDs fall in more than one district of the starting plan.
sum(
    sum(1 for district_counts in d.values() if county in district_counts) > 1
    for county in counties
)
#print(our_split_score_1(starting_partition))
#print(our_split_score_2(starting_partition))
#print(too_many_pieces(starting_partition))
#print(other_too_many_pieces(starting_partition))
#print(district_splits_dict(county_splits))


In [None]:
# ReCom proposal targeting the ideal district population (totpop / num_districts)
# within epsilon = 0.02.
proposal = partial(
    recom,
    pop_col="TOT_POP",
    pop_target=totpop / num_districts,
    epsilon=0.02,
    node_repeats=1,
)

# Compactness: reject plans with more than twice the starting plan's cut edges.
compactness_bound = constraints.UpperBound(
    lambda p: len(p["cut_edges"]),
    2 * len(starting_partition["cut_edges"]),
)

# NOTE(review): total_steps=0 and the chain is not iterated in this notebook
# section; confirm whether it is still needed.
chain = MarkovChain(
    proposal,
    constraints=[
        constraints.within_percent_of_ideal_population(starting_partition, 0.05),
        compactness_bound,
    ],
    accept=accept.always_accept,
    initial_state=starting_partition,
    total_steps=0,
)

In [None]:
# County-split entropy (relative_entropy) for each plan in `partitions`, in order.
entropy = []
t = 0
for t, part in enumerate(partitions, start=1):
    entropy.append(relative_entropy(part, rev))
    if t % 100 == 0:
        print("finished chain " + str(t))

#np.savetxt("PA_cuts.txt", cuts)
#np.savetxt("PA_splittings.txt", splittings)
#np.savetxt("PA_power_entropy.txt", power)
#np.savetxt("PA_Shannon_entropy.txt",Shannon)
#np.savetxt("PA_Score1.txt",score_1)
#np.savetxt("PA_Score2.txt",score_2)
#np.savetxt("PA_moon.txt",moon)
print(entropy)

In [None]:
#colors = ['hotpink']
#labels = ['VTD']
#plt.figure()
#for i in range(1):
#    sns.distplot(score_1,kde=False, color=colors[i],label=labels[i])
#plt.legend()
#plt.xlabel("Score 1")
#plt.show()

#plt.figure()
#for i in range(1):
#    sns.distplot(score_2,kde=False, color=colors[i],label=labels[i])
#plt.legend()
#plt.xlabel("Score 2")
#plt.show()

#plt.figure()
#for i in range(1):
#    sns.distplot(pieces,kde=False, color=colors[i],label=labels[i])
#plt.legend()
#plt.xlabel("Too Many Splits")
#plt.show()

#plt.figure()
#for i in range(1):
#    sns.distplot(power,kde=False, color=colors[i],label=labels[i])
#plt.legend()
#plt.xlabel("Power Entropy (alpha = 4/5)")
#plt.show()

#plt.figure()
#for i in range(1):
#    sns.distplot(Shannon,kde=False, color=colors[i],label=labels[i])
#plt.legend()
#plt.xlabel("Shannon Entropy")
#plt.show()