In [None]:
%config InlineBackend.figure_formats = ['svg']
import os
import maup
import json
import numpy as np
import pandas as pd
import geopandas as gpd
from tqdm import tqdm
from gerrychain import Graph
from pcompress import Replay
from collections import defaultdict
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from ast import literal_eval
from scipy import sparse

In [None]:
import warnings; warnings.filterwarnings('ignore', 'GeoSeries.isna', UserWarning)

In [None]:
# --- Notebook parameters -----------------------------------------------------
# Figures are written to `<output_dir>/<output_prefix>/`, and `output_prefix`
# also prefixes every figure file name.
output_dir = '../../WI/outputs'
output_prefix = 'wi_congress_milwaukee_coi_preservation'
# Title placed on every figure.
figure_title = 'COI preservation in Wisconsin (Congress)'
# NOTE(review): `block_shapefile_path` and `proj` are not referenced by any cell
# visible in this notebook — confirm whether they are still needed.
block_shapefile_path = '../../WI/data/tl_2020_55_tabblock20'
proj = 'EPSG:32615'
# Directory walked recursively for plan assignment CSVs.
plans_path = '../../WI/data/proposed_plans'
# JSON table of per-block records; the notebook reads its `state`, `county`,
# `tract`, `block` and `P1_001N` columns.
block_pops_path = '../../WI/data/tl_2020_55_block_total_pop.json'
# COI source. When `clusters_dir` is set, `clusters_path` is replaced by
# `<output_dir>/<clusters_dir>/csv_geo`. A `clusters_path` ending in '.csv' is
# read as a single table; otherwise it is listed as a directory of CSVs.
clusters_dir = None
clusters_path = None #'../../WI/data/WI_phaseC_20210929_fixed.csv'
# NOTE(review): the three paths below default to None, yet `vtd_graph_path` is
# used unconditionally and `baf_path` whenever `vtd_level_assignments` is True —
# presumably they are overridden at run time; confirm.
vtd_graph_path = None #'/Users/pjrule/Dropbox/MGGG/plan-evaluation-reporting/dual_graphs/wi_vtds_0_indexed.json'
baf_path = None #'../../WI/data/BlockAssign_ST55_WI'
# A falsy `chain_path` skips the ensemble replay, leaving `chain_scores` empty.
chain_path = None #'/Users/pjrule/Dropbox/raw_chains_6_10_2021/wisconsin_state_house_0.05_bal_100000_steps_county_aware.chain'
# NOTE(review): `excluded_clusters` is not referenced by any visible cell.
excluded_clusters = None
use_clusters = True
# True: units are VTDs built from the block assignment file.
# False: units are the graph's own nodes, matched to blocks by GEOID20.
vtd_level_assignments = True
# True: plan ids are `<parent directory>/<file stem>`; False: `<file stem>` only.
hierarchical_paths = True
# Applied only when clusters come from a single CSV (filters its `clusters` column).
selected_clusters = ('4', '5', '6', '7-1', '7-2', '8-1', '8-2', '8-3')
# Prefix used when building VTD ids from the block assignment file.
state_fips_code = '55'
# Applied only when clusters come from a directory: skips ids containing '-'.
exclude_subclusters = True
# Plans to score and plot: `id` must match a discovered plan id, `label` and
# `color` are used for the plot legend and line colour.
selected_plans = [
   {'id': 'congress/PMC-Congress-VTDs', 'label': 'PMC', 'color': 'tab:green'},
   {'id': 'congress/Enacted-VTDs', 'label': 'Enacted', 'color': 'tab:blue'},
]

In [None]:
# Create the figure directory up front; `exist_ok=True` makes re-running this cell safe.
os.makedirs(os.path.join(output_dir, output_prefix), exist_ok=True)

In [None]:
# When a clusters directory name is configured it takes precedence over any
# explicit `clusters_path`: clusters are then read from `<output_dir>/<clusters_dir>/csv_geo`.
if clusters_dir is not None:
  clusters_path = os.path.join(output_dir, clusters_dir, 'csv_geo')

In [None]:
# Ids of the plans to load, kept as a set for quick membership checks.
selected_plan_ids = set(plan_spec['id'] for plan_spec in selected_plans)

In [None]:
# Load the dual graph. Later cells read a `GEOID20` attribute and a total
# population attribute (`TOTPOP20` or `TOTPOP`) from its nodes.
# NOTE(review): `vtd_graph_path` defaults to None in the parameters cell —
# presumably it is overridden at run time; confirm.
graph = Graph.from_json(vtd_graph_path)

In [None]:
# Build `blocks_by_vtd`: VTD id -> set of block ids, from the block assignment file.
if vtd_level_assignments:
  # The file is expected at `<baf_path>/<directory name>_VTD.txt`. `normpath`
  # strips a trailing separator so the directory name is never empty, and
  # `basename` keeps this independent of the platform's path separator.
  baf_name = os.path.basename(os.path.normpath(baf_path))
  vtd_block_path = os.path.join(baf_path, baf_name + '_VTD.txt')
  # Read everything as strings so zero-padded identifiers survive.
  vtd_block_df = pd.read_csv(vtd_block_path, sep='|', dtype=str).set_index('BLOCKID')
  # VTD id = state FIPS + 3-digit county + 6-digit district code.
  vtd_block_df['vtd_id'] = (
    state_fips_code
    + vtd_block_df['COUNTYFP'].str.zfill(3)
    + vtd_block_df['DISTRICT'].str.zfill(6)
  )
  blocks_by_vtd = defaultdict(set)
  for block, vtd in vtd_block_df['vtd_id'].items():
    blocks_by_vtd[vtd].add(block)

In [None]:
# Build `blocks_by_coi`: COI / cluster id -> set of 2020 block GEOIDs.
if clusters_path is not None and clusters_path.endswith('.csv'):
  # Single-table source: one row per COI, indexed by `id`, whose `blocks_2020`
  # cell holds a Python-literal collection of block ids.
  clusters_df = pd.read_csv(clusters_path).set_index('id')
  clusters_df['blocks_2020'] = clusters_df['blocks_2020'].apply(literal_eval)
  # NOTE(review): when `use_clusters` is False this cell leaves `blocks_by_coi`
  # undefined, although later cells require it — confirm the intended behaviour.
  if use_clusters:
    if selected_clusters is not None:
      clusters_df = clusters_df[clusters_df['clusters'].isin(selected_clusters)]
    blocks_by_coi = {coi: set(blocks) for coi, blocks in clusters_df['blocks_2020'].items()}
elif use_clusters:
  # Directory source: one CSV per cluster, each with a `GEOID20` column.
  if clusters_path is None:
    # os.listdir(None) would silently scan the current working directory.
    raise ValueError('use_clusters is set but neither clusters_dir nor clusters_path is configured')
  blocks_by_coi = {}
  for cluster_csv in os.listdir(clusters_path):
    if cluster_csv.endswith('.csv'):
      # Cluster id = last '_'-separated token of the file stem, minus its first character.
      cluster_id = cluster_csv[:-4].split('_')[-1][1:]
      # NOTE(review): `selected_clusters` is not applied in this branch — confirm that is intended.
      if not exclude_subclusters or '-' not in cluster_id:
        cluster_blocks = pd.read_csv(os.path.join(clusters_path, cluster_csv))['GEOID20']
        blocks_by_coi[cluster_id] = set(cluster_blocks.astype(str))

In [None]:
# Load per-block population records and index them by the 15-character block GEOID.
block_pop_df = pd.read_json(block_pops_path)
# GEOID20 = state (2) + county (3) + tract (6) + block (4). Every component is
# zero-padded because numeric-looking codes can lose their leading zeros when
# parsed; padding the state as well keeps GEOIDs correct for FIPS codes below 10.
block_pop_df['GEOID20'] = (
  block_pop_df['state'].astype(str).str.zfill(2) +
  block_pop_df['county'].astype(str).str.zfill(3) +
  block_pop_df['tract'].astype(str).str.zfill(6) +
  block_pop_df['block'].astype(str).str.zfill(4)
)
block_pop_df = block_pop_df.set_index('GEOID20')

In [None]:
# Plain dict (block GEOID20 -> `P1_001N` value) for per-block lookups in the cells below.
# NOTE(review): `P1_001N` is presumably the 2020 Census total-population field — confirm.
block_pops = dict(block_pop_df['P1_001N'])

In [None]:
# Fix one position per unit; every per-unit array and matrix below uses this order.
if not vtd_level_assignments:
  # Units are the graph's nodes: index them by node label and by GEOID20.
  node_ordering = dict(zip(graph.nodes, range(len(graph.nodes))))
  node_geoid_ordering = {
    attrs['GEOID20']: position
    for position, (_, attrs) in enumerate(graph.nodes(data=True))
  }
  num_units = len(graph.nodes)
else:
  # Units are VTDs, in the order they were first seen in the block assignment file.
  node_ordering = {vtd_id: position for position, vtd_id in enumerate(blocks_by_vtd)}
  num_units = len(blocks_by_vtd)

In [None]:
# Population shared between each COI (row) and each unit (column).
unit_coi_inter_pops = np.zeros((len(blocks_by_coi), num_units))
if not vtd_level_assignments:
  # Units are matched to blocks by GEOID20, so each COI block fills exactly one cell.
  for coi_row, member_blocks in enumerate(blocks_by_coi.values()):
    for block_id in member_blocks:
      unit_coi_inter_pops[coi_row, node_geoid_ordering[block_id]] = block_pops[block_id]
else:
  # Units are VTDs: add up the population of the blocks a VTD shares with each COI.
  for unit_col, (_vtd_id, vtd_blocks) in tqdm(enumerate(blocks_by_vtd.items())):
    for coi_row, member_blocks in enumerate(blocks_by_coi.values()):
      shared_blocks = vtd_blocks & member_blocks
      unit_coi_inter_pops[coi_row, unit_col] = sum(block_pops[b] for b in shared_blocks)

# Stored sparse for the matrix products in `thresholded_scores`.
unit_coi_inter_pops = sparse.csr_matrix(unit_coi_inter_pops)

In [None]:
# Total population of each COI, in the same order as the rows of `unit_coi_inter_pops`.
coi_pops = np.array([
  sum(block_pops[block_id] for block_id in member_blocks)
  for member_blocks in blocks_by_coi.values()
])

In [None]:
# Pick whichever total-population attribute the graph carries. Inspect the first
# node's attributes rather than assuming a node labelled 0 exists, so graphs with
# non-integer or non-zero-based node labels work too.
totpop_col = 'TOTPOP20' if 'TOTPOP20' in next(iter(graph.nodes(data=True)))[1] else 'TOTPOP'

In [None]:
# Population of each unit, ordered like `node_ordering`.
if vtd_level_assignments:
  # Primary source: the graph's population attribute, keyed by GEOID20. VTDs
  # that are absent from the graph count as 0.
  unit_pops_by_geoid = {data['GEOID20']: data[totpop_col] for _, data in graph.nodes(data=True)}
  unit_pops = np.array([unit_pops_by_geoid.get(vtd, 0.0) for vtd in blocks_by_vtd])
  # Cross-check against populations summed from the VTDs' blocks; only a small
  # total discrepancy is tolerated.
  unit_pops_alt = np.array([sum(block_pops[b] for b in blocks) for blocks in blocks_by_vtd.values()])
  pop_mismatch = np.abs(unit_pops_alt - unit_pops).sum()
  assert pop_mismatch < 20, f'graph and block-level VTD populations differ by {pop_mismatch} in total'
else:
  # Iterate the node labels (the keys of `node_ordering`), not its positional
  # values: the positions only coincide with the labels for 0..n-1 integer graphs.
  unit_pops = np.array([block_pops[graph.nodes[k]['GEOID20']] for k in node_ordering])

In [None]:
# Total population over all units; `thresholded_scores` divides it by the
# number of districts to obtain the ideal district population.
totpop = unit_pops.sum()

In [None]:
# Load every selected plan found under `plans_path` as {unit id (str) -> district}.
plans = {}
for outer_path, _dirnames, filenames in os.walk(plans_path):
  for filename in filenames:
    if not filename.endswith('.csv'):
      continue
    full_path = os.path.join(outer_path, filename)
    # Plan id: the file stem, optionally prefixed by its parent directory name.
    # Separators are normalised so the id is the same on every platform.
    path_parts = full_path[:-len('.csv')].replace(os.sep, '/').split('/')
    short_name = '/'.join(path_parts[-2:]) if hierarchical_paths else path_parts[-1]
    if short_name not in selected_plan_ids:
      continue
    # NOTE(review): ids are parsed by pandas before being cast to str, so GEOIDs
    # with a leading zero would lose it — confirm for states with FIPS below 10.
    df = pd.read_csv(full_path)
    if 'GEOID20' in df.columns:
      key_col, assignment_col = 'GEOID20', 'assignment'
    elif 'BLOCKID' in df.columns:
      # Two-column file: whichever column is not BLOCKID holds the district.
      assert len(df.columns) == 2
      key_col = 'BLOCKID'
      assignment_col = [col for col in df.columns if col != 'BLOCKID'][0]
    else:
      # Unrecognised layout: skipped here, reported by the check below.
      continue
    df[key_col] = df[key_col].astype(str)
    plans[short_name] = dict(df.set_index(key_col)[assignment_col])

# Fail here with a clear message instead of a bare KeyError in the plotting cells.
missing_plans = sorted(selected_plan_ids - set(plans))
if missing_plans:
  raise ValueError(
    f'selected plans not found under {plans_path!r} '
    f'(or lacking a GEOID20/BLOCKID column): {missing_plans}'
  )

In [None]:
def geoid_assignment_to_matrix(assignment):
  """Converts a GEOID-keyed assignment to a per-district binary encoding.

  Args:
    assignment: Mapping from unit GEOID to district label. Labels are cast to
      int and must be 0-indexed or 1-indexed.

  Returns:
    A `(num_units, num_districts)` array holding 1 where the unit (row) is
    assigned to the district (column) and 0 elsewhere. Rows follow
    `node_ordering` when `vtd_level_assignments` is set, and
    `node_geoid_ordering` otherwise.

  Raises:
    AssertionError: If the smallest district label is neither 0 nor 1.
    KeyError: If a GEOID is missing from the unit ordering.
  """
  # Cast once so float-typed labels (e.g. from a CSV column parsed as float)
  # yield an integer column count and integer column indices.
  districts = {node: int(dist) for node, dist in assignment.items()}
  min_assignment = min(districts.values())
  assert min_assignment in (0, 1)
  num_districts = max(districts.values()) - min_assignment + 1
  # The ordering depends only on the notebook mode, so choose it once.
  ordering = node_ordering if vtd_level_assignments else node_geoid_ordering
  dist_mat = np.zeros((num_units, num_districts))
  for node, dist in districts.items():
    dist_mat[ordering[node], dist - min_assignment] = 1
  return dist_mat

In [None]:
def assignment_to_matrix(assignment):
  """Converts a 1-indexed, node-keyed assignment to a per-district binary encoding.

  Args:
    assignment: Mapping from graph node label to a 1-indexed district label.

  Returns:
    A `(num_units, num_districts)` array holding 1 where the node's unit (row)
    is assigned to the district (column) and 0 elsewhere.

  Raises:
    AssertionError: If the smallest district label is not 1.
    KeyError: If a node cannot be located in `node_ordering`.
  """
  districts = {node: int(dist) for node, dist in assignment.items()}
  assert min(districts.values()) == 1
  dist_mat = np.zeros((num_units, max(districts.values())))
  for node, dist in districts.items():
    if vtd_level_assignments:
      # VTD mode: `node_ordering` is keyed by VTD id, i.e. the node's GEOID20.
      row = node_ordering[graph.nodes[node]['GEOID20']]
    else:
      # Block mode: `node_ordering` is keyed by the node label itself.
      row = node_ordering[node]
    dist_mat[row, dist - 1] = 1
  return dist_mat

In [None]:
def thresholded_scores(dist_mat, threshold_intervals=20):
  """Counts preserved COIs at a sweep of population thresholds.

  For every threshold fraction, a COI is counted when the largest population it
  places in any single district reaches that fraction of either the ideal
  district population or the COI's own population.

  Args:
    dist_mat: `(num_units, num_districts)` binary district encoding.
    threshold_intervals: Controls the sweep; fractions are `step / (40 *
      threshold_intervals)` for `step` from `32 * threshold_intervals` up to
      (but excluding) `40 * threshold_intervals - 40`.

  Returns:
    Dict mapping each threshold fraction to the number of COIs counted.
  """
  # Population each COI places in each district, then its best single district.
  pops_by_coi_and_district = unit_coi_inter_pops @ dist_mat
  best_district_pop = np.max(pops_by_coi_and_district, axis=1)
  ideal_dist_pop = totpop / dist_mat.shape[1]

  def count_preserved(fraction):
    """Number of COIs meeting either criterion at `fraction`."""
    fills_district = best_district_pop >= fraction * ideal_dist_pop
    mostly_whole = best_district_pop >= fraction * coi_pops
    return np.logical_or(fills_district, mostly_whole).sum()

  first_step = int(32 * threshold_intervals)
  stop_step = 40 * threshold_intervals - 40
  return {
    step / (40 * threshold_intervals): count_preserved(step / (40 * threshold_intervals))
    for step in range(first_step, stop_step)
  }

In [None]:
# Score every loaded plan: plan id -> {threshold fraction -> COI count}.
plan_scores = dict(
  (plan_id, thresholded_scores(geoid_assignment_to_matrix(plan_assignment)))
  for plan_id, plan_assignment in plans.items()
)

In [None]:
# Score every step of the recorded ensemble; stays empty when no chain is configured.
chain_scores = []
if chain_path:
  for step_idx, partition in tqdm(enumerate(Replay(graph, chain_path))):
    step_matrix = assignment_to_matrix(partition.assignment)
    chain_scores.append(thresholded_scores(step_matrix))

In [None]:
# One line per selected plan: COI score as a function of the threshold fraction.
fig, ax = plt.subplots(figsize=(6, 6))
for plan in selected_plans:
  plan_trace = plan_scores[plan['id']]
  ax.plot(
    plan_trace.keys(),
    plan_trace.values(),
    label=plan['label'],
    color=plan['color'],
  )
ax.set_xlabel('Threshold')
ax.set_ylabel('Score')
ax.set_title(figure_title)
plt.legend()
traces_path = os.path.join(output_dir, output_prefix, f'{output_prefix}_traces.png')
plt.savefig(traces_path, dpi=300)
plt.show()

In [None]:
# One figure per threshold: ensemble score distribution with each plan's score marked.
for threshold in plan_scores[next(iter(plan_scores))]:
  fig, ax = plt.subplots(figsize=(8, 8))
  # Without a chain there is nothing to histogram; skipping avoids a density
  # histogram of an empty list and a legend entry for data that is not there.
  if chain_scores:
    ax.hist([c[threshold] for c in chain_scores], alpha=0.3, density=True, label='County-aware ensemble', color='k')
  for plan in selected_plans:
    plan_score = plan_scores[plan['id']][threshold]
    ax.axvline(plan_score, color=plan['color'], label=f"{plan['label']} ({plan_score})", linewidth=3)
  # The threshold is shown as a percentage truncated to one decimal place.
  ax.set_xlabel(f'Score ({int(threshold * 1000) / 10}% population inclusion)')
  ax.set_ylabel('Ensemble frequency')
  ax.set_title(figure_title)
  ax.legend()
  fig.savefig(os.path.join(output_dir, output_prefix, f'{output_prefix}_hist_{threshold}.png'), dpi=300)
  # Close this specific figure so the loop does not accumulate open figures.
  plt.close(fig)