In [None]:
%config InlineBackend.figure_formats = ['svg']
import warnings
from ast import literal_eval
from collections import defaultdict

import maup
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import graph_tool.all as gt
from gerrychain import Graph

In [None]:
# Load the VTD-level dual graph (a gerrychain Graph) serialized as JSON.
graph = Graph.from_json('michigan_dualgraph.json')

In [None]:
# Load all submissions and keep only rows whose 'type' is 'coi'
# (community-of-interest submissions).
submissions_df = (
  pd.read_csv('../mi_all_subs_pseudo_cois.csv')
  .loc[lambda df: df['type'] == 'coi']
)

In [None]:
# Block groups indexed by an 'id' built by prefixing the shapefile's LINK field
# with '26' (presumably Michigan's state FIPS code).
bg_gdf = (
  gpd.read_file('2010_Block_Groups_(v17a)')
  .assign(id=lambda gdf: '26' + gdf['LINK'])
  .set_index('id')
)

In [None]:
# Minor civil divisions (cities & townships), indexed by their FIPSCODE.
juris_gdf = (
  gpd.read_file('Minor_Civil_Divisions_(Cities_%26_Townships)_')
  .set_index('FIPSCODE')
)

In [None]:
# Map each block-group id to the FIPSCODE of the jurisdiction maup assigns it to.
# The 'GeoSeries.isna' UserWarning raised during this call is silenced only
# inside this block, instead of for the rest of the kernel session.
with warnings.catch_warnings():
  warnings.filterwarnings('ignore', 'GeoSeries.isna', UserWarning)
  bg_to_juris = dict(maup.assign(bg_gdf, juris_gdf))

In [None]:
# County of each block group = the first 5 characters of its id.
# NOTE(review): assumes ids start with the 2-digit state + 3-digit county FIPS,
# as standard Census GEOIDs do — confirm against the LINK field format.
bg_to_county = dict((bg_id, bg_id[:5]) for bg_id in bg_to_juris)

In [None]:
# VTD id -> jurisdiction, read from each dual-graph node's attribute dict.
vtd_to_juris = {
  attrs['VTD']: attrs['Jurisdicti']
  for _node, attrs in graph.nodes(data=True)
}

In [None]:
# VTD id -> county id (stringified 'county_fip' node attribute).
# NOTE(review): bg_to_county uses 5-character id prefixes; confirm that
# str(county_fip) has the same format (state prefix, zero padding), otherwise
# VTD-based and block-group-based county ids will never coincide.
vtd_to_county = {
  attrs['VTD']: str(attrs['county_fip'])
  for _node, attrs in graph.nodes(data=True)
}

In [None]:
# Parse each submission's Districtr payload and keep those whose plan carries a
# non-empty 'assignment'.
districtr_data = []
for row in submissions_df['districtr_data']:
  try:
    parsed = literal_eval(row)
  except (ValueError, SyntaxError):
    # Missing (NaN) or malformed payloads can't be parsed. Skip the row instead
    # of falling through with `parsed` left over from the previous iteration,
    # which would silently duplicate that submission.
    continue
  plan = parsed.get('plan') if isinstance(parsed, dict) else None
  if isinstance(plan, dict) and plan.get('assignment'):
    districtr_data.append(parsed)

In [None]:
def whole(small_to_dist, small_to_large):
  """Find the large units lying entirely inside a single district/COI label.

  Args:
    small_to_dist: maps small units (e.g. VTDs or block groups) to a
      district/COI label, or to a list of labels when a unit carries several.
    small_to_large: maps every small unit to the large unit containing it
      (e.g. a jurisdiction/MCD or county).

  Returns:
    The set of large units all of whose small units (per `small_to_large`)
    share at least one common label in `small_to_dist`.
  """
  members = defaultdict(set)         # large unit -> every small unit inside it
  for small_unit, large_unit in small_to_large.items():
    members[large_unit].add(small_unit)

  labels_touching = defaultdict(set)  # large unit -> labels seen inside it
  label_members = defaultdict(set)    # label -> small units carrying it
  for small_unit, label in small_to_dist.items():
    container = small_to_large[small_unit]
    found = label if isinstance(label, list) else [label]
    for lab in found:
      labels_touching[container].add(lab)
      label_members[lab].add(small_unit)

  return {
    large_unit
    for large_unit, touching in labels_touching.items()
    if any(members[large_unit] <= label_members[lab] for lab in touching)
  }

In [None]:
# Jurisdictions each submission contains wholly. A plan's assignment is keyed
# by either VTD ids or block-group ids; its first key selects the crosswalk.
whole_juris = []
for submission in districtr_data:
  assignment = submission['plan'].get('assignment')
  if not assignment:
    continue
  first_key = next(iter(assignment))
  crosswalk = vtd_to_juris if first_key in vtd_to_juris else bg_to_juris
  whole_juris.append(whole(assignment, crosswalk))

In [None]:
# Counties each submission contains wholly. A plan's assignment is keyed by
# either VTD ids or block-group ids; its first key selects the crosswalk.
whole_county = []
for submission in districtr_data:
  assignment = submission['plan'].get('assignment')
  if not assignment:
    continue
  first_key = next(iter(assignment))
  crosswalk = vtd_to_county if first_key in vtd_to_county else bg_to_county
  whole_county.append(whole(assignment, crosswalk))

In [None]:
def submission_whole_unit_bipartite_graph(whole_units):
  """Build an undirected bipartite graph linking submissions to whole units.

  Args:
    whole_units: a list with one set per submission, holding the units
      (e.g. jurisdictions or counties) that submission contains wholly.

  Returns:
    (graph, unit_vertices): a graph-tool graph whose first len(whole_units)
    vertices are the submissions (in input order), followed by one vertex per
    distinct unit; and a dict mapping each unit to its vertex. Unit vertices
    carry their unit id (as a string) in the vertex property 'label'.
  """
  bipartite = gt.Graph(directed=False)
  submission_vertices = [bipartite.add_vertex() for _ in whole_units]
  # set().union(...) also works for an empty list (set.union(*[]) raises
  # TypeError). Sorting by str fixes the unit vertex order across runs, which
  # string-set iteration order (hash randomization) does not.
  unique_units = sorted(set().union(*whole_units), key=str)
  unit_vertices = {unit: bipartite.add_vertex() for unit in unique_units}
  label_prop = bipartite.vp['label'] = bipartite.new_vertex_property('string')
  # Iterate (unit, vertex) pairs: iterating the dict itself yields the unit
  # keys, which would index the property map by unit id instead of vertex.
  for unit, vertex in unit_vertices.items():
    label_prop[vertex] = str(unit)
  for submission_vertex, units in zip(submission_vertices, whole_units):
    for unit in units:
      bipartite.add_edge(submission_vertex, unit_vertices[unit])
  return bipartite, unit_vertices

In [None]:
# Bipartite submission/unit graphs for jurisdictions and for counties. Each call
# also returns the unit -> vertex dict used to map inferred blocks back to units.
# NOTE(review): county_submission_graph / county_vertices are not used by any
# later cell shown here.
juris_submission_graph, juris_vertices = submission_whole_unit_bipartite_graph(whole_juris)
county_submission_graph, county_vertices = submission_whole_unit_bipartite_graph(whole_county)

In [None]:
# Fit a nested stochastic block model to the submission/jurisdiction graph,
# then refine it with merge-split MCMC sweeps at beta=inf (greedy moves only).
# The RNGs are seeded so the inferred hierarchy is reproducible across runs.
# NOTE(review): graph-tool may also need a single OpenMP thread
# (gt.openmp_set_num_threads(1)) for bit-exact reproducibility — confirm.
SEED = 42
N_REFINE_ROUNDS = 100   # outer refinement rounds
SWEEP_NITER = 10        # merge-split iterations per round
gt.seed_rng(SEED)
np.random.seed(SEED)
state = gt.minimize_nested_blockmodel_dl(juris_submission_graph)
for _ in range(N_REFINE_ROUNDS):
  state.multiflip_mcmc_sweep(niter=SWEEP_NITER, beta=np.inf)

In [None]:
# Draw the fitted hierarchy over a bipartite layout (submissions vs. units).
# The trailing ';' keeps the call's return value from being echoed as output.
# To write an image file instead, pass output='mi_juris_coi_bipartite.png'.
state.draw(layout='bipartite');

In [None]:
# One choropleth per hierarchy level (up to 5), coloring each jurisdiction by
# the block the nested SBM places it in at that level.
levels = state.get_levels()
for l, s in enumerate(levels[:5]):
  # For l > 0, s.get_blocks() labels the vertices of the level-(l-1) block
  # graph, not the original vertices; project_level(l) maps the level-l
  # partition back onto the original submission/jurisdiction graph.
  labels = state.project_level(l).get_blocks()
  unit_to_label = {
    unit: int(labels[vertex])
    for unit, vertex in juris_vertices.items()
  }
  fig, ax = plt.subplots(figsize=(8, 10))
  ax.axis('off')
  ax.set_title(f'Jurisdiction blocks, hierarchy level {l}')
  juris_gdf[f'level{l}'] = juris_gdf.index.map(unit_to_label)
  # Block ids are nominal labels, so color them categorically.
  juris_gdf.plot(column=f'level{l}', ax=ax, cmap='tab20', categorical=True)
  # Stop once this level's graph has collapsed to a single vertex.
  if s.get_N() == 1:
    break

In [None]:
def plot_hierarchy_levels(nested_state, unit_vertices, gdf, max_levels=5):
  """Draw one choropleth of `gdf` per level of a fitted nested block model.

  Args:
    nested_state: the fitted graph-tool NestedBlockState.
    unit_vertices: dict mapping unit ids (matching gdf's index) to vertices of
      the graph the state was fitted on.
    gdf: GeoDataFrame indexed by unit id. For every level drawn, a
      'level{l}' column holding each unit's block label is written to it.
    max_levels: maximum number of hierarchy levels to draw.
  """
  for l, level_state in enumerate(nested_state.get_levels()[:max_levels]):
    # Project level l onto the original graph. level_state.get_blocks() alone
    # indexes the level-(l-1) block graph for l > 0, not the original vertices.
    blocks = nested_state.project_level(l).get_blocks()
    unit_to_label = {unit: int(blocks[v]) for unit, v in unit_vertices.items()}
    column = f'level{l}'
    gdf[column] = gdf.index.map(unit_to_label)
    _, ax = plt.subplots(figsize=(8, 10))
    ax.axis('off')
    ax.set_title(f'Hierarchy level {l}')
    # Block ids are nominal labels, so color them categorically.
    gdf.plot(column=column, ax=ax, cmap='tab20', categorical=True)
    # Stop once this level's graph has collapsed to a single vertex.
    if level_state.get_N() == 1:
      break


# NOTE(review): the preceding cell draws these same jurisdiction maps; this
# helper makes the plots reusable for other unit types (e.g. counties).
plot_hierarchy_levels(state, juris_vertices, juris_gdf)