In [None]:
import os
import json
import pickle
import pandas as pd
import networkx as nx
import geopandas as gpd
import contextily as ctx
import plotly.express as px
import matplotlib.pyplot as plt
from collections import Counter
from submission_analysis.crosswalk import Crosswalk
from tqdm import tqdm, trange
from ast import literal_eval
from scipy.cluster import hierarchy

In [None]:
# --- Inputs ---
# Pickled cluster database; when set, cluster assignments are loaded from it.
db_path = '../../WI/data/wi_cluster_db_20210820.pkl'
# CSV of cluster assignments; only consulted when `db_path` is falsy.
clusters_path = None
# Crosswalk file handed to `Crosswalk` (2010 block groups -> 2020 blocks).
block_2010_to_block_2020_crosswalk_path = '../../WI/data/tab2010_tab2020_st55_wi.txt'
# 2020 block shapefile (indexed by GEOID20); only read for geographic outputs.
block_2020_shp_path = '../../WI/data/tl_2020_55_tabblock20'
# Basemap shapefile; filtered to `state_fips_code` when it has a STATEFP column.
base_shp_path = '../data/tl_2020_us_county'
# Number of clusters requested from the cluster database.
num_clusters = 40
state_fips_code = '55'
cluster_name_prefix = 'A'  # Moon's versioning scheme
# When true, a "core" geometry is also built for each cluster.
cluster_cores = True
cluster_core_threshold = 3  # minimum number of times a block must appear in a cluster to be considered "core"

# --- Outputs ---
# Outputs are written to `<output_dir>/<output_prefix>/<format>/`.
output_dir = '../../WI/outputs'
output_prefix = 'WI_20210822_geo32'
# CRS that loaded geometries are reprojected to and output frames are tagged with.
crs = 'EPSG:32616'
# Formats to generate; the branches in this notebook check for 'csv_geo',
# 'csv_comments', 'html', 'tex', 'shapefile', 'png' and 'png_summary'.
output_formats = ['tex', 'png_summary']
# Columns kept in per-cluster submission tables. Columns absent from the
# data are dropped, except 'portal_url', which is derived per cluster.
output_columns = ['districtr_id', 'portal_url', 'submission_title', 'submission_text', 'area_name', 'area_text', 'cluster']
# Submission IDs (matched against the clusters index) to drop entirely.
excluded_submissions = {}
# Maps str(submission ID) -> replacement cluster ID.
reassigned_submissions = {}
# Prefix joined with a submission's portal ID to form its portal link.
portal_url_prefix = 'https://portal.wisconsin-mapping.org/submission/'
# When true, the `area_name` and `area_text` columns are swapped after loading.
swap_area_columns = False

# Keep only the largest connected component of each COI when generating
# geographic outputs. Component size is measured by block count, which the
# helper below uses as a rough proxy for population.
force_connected = True
# JSON adjacency graph over 2020 blocks; only read when `force_connected` is set.
block_dual_graph_path = '../../WI/data/tl_2020_55_tabblock20.json'

In [None]:
# Geographic inputs are only needed when at least one geometry-backed
# output format has been requested.
geo_enabled = not set(output_formats).isdisjoint(
  ('shapefile', 'png', 'png_summary', 'csv_geo')
)

In [None]:
# Load cluster assignments, preferring the pickled cluster database and
# falling back to a CSV of precomputed clusters.
if db_path:
  # NOTE: unpickling can execute arbitrary code; only load trusted databases.
  with open(db_path, 'rb') as db_file:
    db = pickle.load(db_file)
  clusters = db.clusters_from_number(num_clusters)
elif clusters_path:
  clusters = pd.read_csv(clusters_path)
else:
  # Fail here with a clear message instead of a NameError further down.
  raise ValueError('Set either `db_path` or `clusters_path` to load clusters.')

# List-valued columns may arrive serialized as strings (e.g. from CSV).
for col in ('block_groups_2010', 'labels'):
  try:
    clusters[col] = clusters[col].apply(literal_eval)
  except ValueError:
    # `literal_eval` rejects values that are not strings of Python literals,
    # which is the case when the column is already parsed; leave it as-is.
    pass

In [None]:
# Crosswalk object; used below (via `map_2010_block_groups`) to translate
# 2010 block groups into 2020 blocks.
cw = Crosswalk(block_2010_to_block_2020_crosswalk_path)

In [None]:
# Optional: connected component filtering.
if force_connected:
  # Read the block adjacency graph; the context manager closes the file
  # promptly instead of leaving the handle to the garbage collector.
  with open(block_dual_graph_path) as graph_file:
    graph = nx.readwrite.json_graph.adjacency_graph(json.load(graph_file))
  # Re-key nodes by their GEOID20 attribute so they can be looked up with
  # the same IDs that index the block-level frames.
  graph = nx.relabel_nodes(graph, mapping=dict(graph.nodes('GEOID20')))

In [None]:
# Block geometries are only loaded when a geographic output was requested.
if geo_enabled:
  blocks_2020_gdf = (
    gpd.read_file(block_2020_shp_path)
    .set_index('GEOID20')  # key blocks by their 2020 GEOID
    .to_crs(crs)           # reproject to the configured CRS
  )

In [None]:
# Basemap geometries, reprojected to the configured CRS.
base_gdf = gpd.read_file(base_shp_path).to_crs(crs)
if 'STATEFP' in base_gdf.columns:
  # Restrict the basemap to the configured state when a state FIPS column exists.
  base_gdf = base_gdf.loc[base_gdf['STATEFP'].eq(state_fips_code)]

## ✂️ Cluster surgery (filtering) ✂️

In [None]:
# Apply manual overrides: a submission whose (stringified) ID appears in
# `reassigned_submissions` takes the cluster listed there; every other
# submission keeps its current cluster.
clusters['clusters'] = clusters.apply(
  lambda submission: reassigned_submissions.get(
    str(submission.name),
    submission['clusters']
  ),
  axis=1
)

# Drop excluded submissions; copy so later column edits act on a standalone frame.
clusters = clusters.loc[~clusters.index.isin(excluded_submissions)].copy()

## 2020 block frequencies

In [None]:
# cluster ID -> Counter mapping each 2020 block to the number of supporting
# submissions (accumulated through the 2010 block groups they cover).
cluster_counts = {}
for cluster_id in tqdm(clusters['clusters'].unique()):
  member_rows = clusters.loc[clusters['clusters'] == cluster_id]

  # How many of the cluster's submissions list each 2010 block group.
  bg_2010_count = Counter(
    bg
    for bgs in member_rows['block_groups_2010']
    for bg in bgs
  )

  # Push each block group's tally down to the 2020 blocks it maps to.
  # (Tallies built this way are always positive, so no filtering is needed.)
  block_2020_count = Counter()
  for bg, n_submissions in bg_2010_count.items():
    for block_2020 in cw.map_2010_block_groups([bg]):
      block_2020_count[block_2020] += n_submissions

  cluster_counts[cluster_id] = block_2020_count

## ✂️ More cluster surgery (typing) ✂️

In [None]:
clusters = clusters.rename(columns={'clusters': 'cluster'})
if swap_area_columns:
  # Correct for an error in the 8/9 databases (the two area columns are swapped).
  clusters = clusters.rename(columns={'area_text': 'area_name', 'area_name': 'area_text'})

# Force the free-text columns and the cluster ID to strings.
for text_col in ('area_name', 'area_text', 'submission_text', 'cluster'):
  clusters[text_col] = clusters[text_col].astype(str)

# Blank out placeholder values left behind by the string conversion
# ('nan' everywhere, plus '0' for the submission text).
clusters.loc[clusters['area_name'].isin(['nan']), 'area_name'] = ''
clusters.loc[clusters['area_text'].isin(['nan']), 'area_text'] = ''
clusters.loc[clusters['submission_text'].isin(['nan', '0']), 'submission_text'] = ''

# Prefix cluster IDs with the version letter and flatten label lists to text.
clusters['cluster'] = cluster_name_prefix + clusters['cluster']
clusters['labels'] = clusters['labels'].apply(', '.join)
clusters.index.name = 'plan_id'

## Output formats (per cluster)

* `shapefile` – Shapefile containing the subset of 2020 blocks within the cluster, with `count` and `freq` attributes.
* `csv_geo` - List of 2020 blocks within the cluster, with `count` and `freq` attributes.
* `csv_comments` – Table containing COI submissions (including labels) supporting the cluster.
* `html` - Table containing COI submissions (including labels) supporting the cluster.
* `tex` - Table containing COI submissions (without labels) supporting the cluster.
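* `png` - Block-level heatmap (based on `count`) of the cluster with the state's counties as the basemap.
* `png_summary` - A single map of all dissolved clusters over a web basemap (plus a second map of the cluster cores when `cluster_cores` is enabled).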

### Attributes
* `count` - The number of supporting COIs a block appears in.
* `freq` - `count`, but normalized (0-1).

In [None]:
# Root directory for this run's outputs; each format gets its own subdirectory.
full_output_dir = os.path.join(output_dir, output_prefix)

In [None]:
# Create one subdirectory per requested output format (no-op if it exists).
for fmt in output_formats:
  fmt_dir = os.path.join(full_output_dir, fmt)
  os.makedirs(fmt_dir, exist_ok=True)

In [None]:
# Keep the requested output columns that actually exist in the data.
# 'portal_url' is always allowed: it is derived per cluster later on.
filtered_output_columns = []
for col in output_columns:
  if col == 'portal_url' or col in clusters.columns:
    filtered_output_columns.append(col)

### $\LaTeX$ formatting
Before generating a $\LaTeX$ table from a submissions DataFrame, we fuse columns:
* The portal ID (index) and Districtr ID (`districtr_id`) are fused into the `Portal Link (Districtr)` column; the raw portal ID is replaced with a link to the appropriate mapping portal if a portal URL prefix is available.
* `submission_title` and `submission_text` are fused into the `Overall Submission Information` column. The title is **bolded** with `\textbf{}`.
* `area_name` and `area_text` are formatted similarly and fused into the `Individual Area Information` column.
* Labels, which are for internal use only, are removed.

We use Pandas' `.to_latex()` to generate an initial $\LaTeX$ table from this fused DataFrame; all columns are truncated at 2000 characters (approximately 7.143 tweets). We then apply some styling modifications:
* We use `supertabular` instead of `tabular` to enable stretching table entries across pages.
* We [use `arraystretch` to increase the table's vertical padding](https://tex.stackexchange.com/a/31704).
* We fix column widths:
  * `Portal Link (Districtr)` – .48in
  * `Overall Submission Information` - 3.5in
  * `Individual Area Information` - 2in

In [None]:
def format_tex_joint_columns(row, columns=('area_name', 'area_text')):
  """Merges a name column and a text column into one LaTeX-ready string.

  Both values are stripped of surrounding whitespace first. The result is:
    * bold "name:" followed by the text when both are non-empty;
    * bold "name." when only the name is non-empty;
    * the (possibly empty) text otherwise.

  Args:
    row: Mapping-like record (e.g. a DataFrame row) holding string values.
    columns: Pair of keys, (name column, text column).

  Returns:
    The fused string.
  """
  heading = row[columns[0]].strip()
  body = row[columns[1]].strip()
  if not heading:
    return body
  if body:
    return '\\textbf{' + heading + ':} ' + body
  return '\\textbf{' + heading + '.}'


def format_tex(submissions_df, portal_url_prefix=None, max_colwidth_chars=2000):
  """Generates LaTeX submission tables according to MGGG report specs.

  Args:
    submissions_df: Submissions DataFrame with a string index of the form
      "<portal_id>-<part_id>" and string columns `districtr_id`,
      `submission_title`, `submission_text`, `area_name` and `area_text`.
    portal_url_prefix: Optional URL prefix; when given, each portal ID is
      rendered as a hyperlink to the prefix followed by the portal ID.
    max_colwidth_chars: Maximum number of characters kept per table cell.

  Returns:
    The generated table as a single string, with the first five and last
    three lines of the `to_latex()` output removed (presumably the tabular
    preamble/header and footer, so only body rows remain).
  """
  submissions_tex = submissions_df.copy()
  submissions_tex['portal_id'] = submissions_tex.index.str.split('-').str[0]
  submissions_tex['part_id'] = submissions_tex.index.str.split('-').str[1]
  submissions_tex['plan_link'] = submissions_tex['portal_id']
  # Order rows by part ID (index level 1), then flatten back to columns.
  submissions_tex = submissions_tex.set_index(['portal_id', 'part_id']).sort_index(level=1).reset_index()

  if portal_url_prefix:
    # Raw strings keep the LaTeX backslashes literal without relying on
    # Python's deprecated handling of unrecognized escape sequences.
    submissions_tex['plan_link'] = submissions_tex['plan_link'].apply(
      lambda portal_id: r'\href{' + portal_url_prefix + portal_id.split('-')[0] + '}{' + portal_id + '}'
    )
  submissions_tex['districtr_id'] = submissions_tex['districtr_id'].str.split('-').str[0].str.strip()
  submissions_tex['Portal Link (Districtr)'] = submissions_tex['plan_link'] + ' (' + submissions_tex['districtr_id'] + ')'
  submissions_tex['Individual Area Information'] = submissions_tex.apply(format_tex_joint_columns, axis=1)
  submissions_tex['Overall Submission Information'] = submissions_tex.apply(
    lambda row: format_tex_joint_columns(row, ('submission_title', 'submission_text')),
    axis=1
  )
  display_cols = ['Portal Link (Districtr)', 'Overall Submission Information', 'Individual Area Information']
  # oof. (see https://stackoverflow.com/a/46974532)
  # All display columns are moved into the index before rendering.
  submissions_tex = submissions_tex[display_cols].set_index(display_cols)
  with pd.option_context('max_colwidth', max_colwidth_chars):
    tex = submissions_tex.to_latex(index=True, multirow=True)

  # Undo the escaping applied to the LaTeX markup generated above, so the
  # hyperlink and bold commands survive as commands. Order matters:
  # replacements run in insertion order.
  subs = {
    r'\textbackslash href\{': r'\href{',
    r'\}\{': '}{',
    r'\} (': '} (',
    r'\textbackslash n': ' ',  # suppress newlines
    r'\textbackslash textbf\{': r'\textbf{',
    r':\}': ':}',
    r'.\}': '.}'
  }
  for start, end in subs.items():
    tex = tex.replace(start, end)
  tex_lines = tex.split('\n')
  return '\n'.join(tex_lines[5:-3])

In [None]:
def largest_connected_block_component(gdf):
  """Finds the largest connected component of a block-level cluster.

  Connectivity is taken from the module-level block graph `graph`; blocks
  in `gdf` that are not nodes of that graph are ignored.

  Args:
    gdf: Block-level frame indexed by block ID.

  Returns:
    The rows of `gdf` that belong to the largest connected component, or
    an empty frame (same columns) when `gdf` is empty or none of its
    blocks appear in the graph.
  """
  # Fetching block-level 2020 populations is still annoying as of
  # 2021-08-25---the Census API still only has 2000/2010 data---
  # so we use block count as a rough proxy for population-weighted
  # component size.
  if gdf.empty:
    return gdf
  subgraph = nx.subgraph(graph, list(gdf.index))
  components = list(nx.connected_components(subgraph))
  if not components:
    # None of the frame's blocks are in the graph: there is no component
    # to keep, so return an empty frame rather than raising IndexError.
    return gdf.iloc[0:0]
  largest_component = sorted(components, key=len)[-1]
  component_ids = set(graph.nodes[b]['GEOID20'] for b in largest_component)
  return gdf[gdf.index.to_series().isin(component_ids)]

### Final outputs

In [None]:
# Dissolved (single-geometry) records per cluster, collected for the
# summary maps: one list for whole clusters and one for cluster cores.
clusters_dissolved = []
cluster_cores_dissolved = []

for cluster_short_id, counts in tqdm(cluster_counts.items()):
  cluster_id = cluster_name_prefix + str(cluster_short_id)
  cluster_label = f'{output_prefix}_cluster_{cluster_id}'

  # Per-block support table: `count` is the number of supporting
  # submissions, `freq` is `count` scaled by the cluster's maximum count.
  df = pd.DataFrame.from_dict(counts, orient='index', columns=['count'])
  df.index.name = 'GEOID20'
  df['freq'] = df['count'] / df['count'].max()

  # Submissions supporting this cluster, restricted to the output columns.
  submissions = clusters[clusters['cluster'] == cluster_id].copy()
  if portal_url_prefix:
    submissions['portal_url'] = portal_url_prefix + submissions.index.to_series().str.split('-').str[0]
  else:
    # The prefix is optional elsewhere in this notebook; without it there
    # is no URL to build, so leave the column blank instead of failing.
    submissions['portal_url'] = ''
  submissions = submissions[filtered_output_columns]

  if 'csv_geo' in output_formats:
    df.to_csv(f'{full_output_dir}/csv_geo/{cluster_label}.csv')

  if 'csv_comments' in output_formats:
    submissions.to_csv(f'{full_output_dir}/csv_comments/{cluster_label}.csv')

  if 'html' in output_formats:
    html = submissions.to_html()
    html = html.replace('\\n', '<br>').replace('\\t', ' ')
    if portal_url_prefix:
      # Turn each submission ID in the table into a link to its portal page.
      for portal_id in submissions.index:
        url = portal_url_prefix + portal_id.split('-')[0]
        html = html.replace(
          portal_id,
          f'<a href="{url}" target="_blank">{portal_id}</a>'
        )
    with open(f'{full_output_dir}/html/{cluster_label}.html', 'w') as f:
      f.write(html)

  if 'tex' in output_formats:
    with open(f'{full_output_dir}/tex/{cluster_label}.tex', 'w') as f:
      f.write(format_tex(submissions, portal_url_prefix))

  if geo_enabled:
    # Attach block geometries to the support table.
    gdf = gpd.GeoDataFrame(df).join(blocks_2020_gdf[['geometry']])
    gdf.crs = crs
    if force_connected:
      gdf = largest_connected_block_component(gdf)

    if cluster_cores:
      # The "core" is the set of blocks supported by at least
      # `cluster_core_threshold` submissions.
      core_gdf = gdf[gdf['count'] >= cluster_core_threshold]
      if force_connected:
        core_gdf = largest_connected_block_component(core_gdf)

    if 'shapefile' in output_formats:
      gdf.to_file(f'{full_output_dir}/shapefile/{cluster_label}')

    if 'png' in output_formats:
      fig, ax = plt.subplots(figsize=(10, 8), dpi=100)
      # The basemap is drawn twice: a thick black outline first, then a
      # light fill with faint borders on top of it.
      base_gdf.plot(ax=ax, edgecolor='black', linewidth=2)
      base_gdf.plot(color='#fffff5', edgecolor='#e5e5e5', ax=ax)
      gdf.plot(ax=ax, column='count', cmap='YlOrRd', #'viridis_r',
               linewidth=0, edgecolor='none', antialiased=False,
               vmin=0, vmax=10)
      ax.axis('off')
      plt.savefig(f'{full_output_dir}/png/{cluster_id}.png',
                  dpi=300, transparent=True, bbox_inches='tight')
      # Close this specific figure so figures do not accumulate in the loop.
      plt.close(fig)

    if 'png_summary' in output_formats:
      # Merge the cluster's blocks into one geometry for the summary map.
      dissolved = gdf.dissolve()
      if not dissolved.empty:
        clusters_dissolved.append({'cluster': cluster_id, 'geometry': dissolved.iloc[0].geometry})
      if cluster_cores:
        core_dissolved = core_gdf.dissolve()
        if not core_dissolved.empty:
          cluster_cores_dissolved.append({
            'cluster': cluster_id,
            'geometry': core_dissolved.iloc[0].geometry
          })

In [None]:
def plot_summary(dissolved_gdf):
  """Draws dissolved clusters (or cluster cores) over a web basemap.

  Each row of `dissolved_gdf` is drawn as a semi-transparent shape colored
  by its `cluster` value and labeled with that value at its centroid.

  Returns:
    The matplotlib axes holding the plot.
  """
  # Contextily expects the Web Mercator projection.
  mercator_gdf = dissolved_gdf.to_crs(epsg=3857)

  # Largest shapes first: a (futile?) attempt to improve the z-order such
  # that smaller clusters are more visible.
  mercator_gdf['area'] = [geom.area for geom in mercator_gdf.geometry]
  mercator_gdf = mercator_gdf.sort_values(by=['area'], ascending=False)

  # (see https://jcutrer.com/python/learn-geopandas-plotting-usmaps)
  ax = mercator_gdf.plot(figsize=(20, 16), alpha=0.6, edgecolor='black', column='cluster')
  for cluster_name, geom in zip(mercator_gdf['cluster'], mercator_gdf.geometry):
    ax.annotate(
      text=cluster_name,
      xy=geom.centroid.coords[0],
      ha='center',
      fontsize=16,
      fontname='Helvetica',
      fontweight='bold')
  ctx.add_basemap(ax, source=ctx.providers.CartoDB.Voyager)
  ax.axis('off')
  return ax

In [None]:
# Summary maps: one of all dissolved clusters and, optionally, one of
# their cores. Each is saved via the figure that owns the returned axes.
if 'png_summary' in output_formats:
  clusters_gdf = gpd.GeoDataFrame(clusters_dissolved)
  clusters_gdf.crs = crs
  summary_fig = plot_summary(clusters_gdf).get_figure()
  summary_fig.savefig(f'{full_output_dir}/png_summary/{output_prefix}_{cluster_name_prefix}_summary.png', dpi=300)
  plt.close(summary_fig)

  if cluster_cores:
    cluster_cores_gdf = gpd.GeoDataFrame(cluster_cores_dissolved)
    cluster_cores_gdf.crs = crs
    cores_fig = plot_summary(cluster_cores_gdf).get_figure()
    cores_fig.savefig(f'{full_output_dir}/png_summary/{output_prefix}_{cluster_name_prefix}_cores_summary.png', dpi=300)
    plt.close(cores_fig)