In [None]:
%config InlineBackend.figure_formats = ['svg']
import os
import json
import pickle
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from collections import Counter
from submission_analysis.crosswalk import Crosswalk
from tqdm import tqdm, trange
from scipy.cluster import hierarchy

In [None]:
# --- Inputs (paths relative to this notebook) ---
db_path = '../../WI/data/wi_cluster_db_20210820.pkl'  # pickled cluster database
block_2010_to_block_2020_crosswalk_path = '../../WI/data/tab2010_tab2020_st55_wi.txt'
block_2020_shp_path = '../../WI/data/tl_2020_55_tabblock20'  # 2020 tabulation blocks (indexed by GEOID20)
county_shp_path = '../data/tl_2020_us_county'  # nationwide counties; filtered to `state_fips_code` on load

# --- Clustering / naming ---
num_clusters = 40  # number of clusters requested from the database
state_fips_code = '55'  # Wisconsin
cluster_name_prefix = 'A'  # Moon's versioning scheme

# --- Outputs ---
output_dir = '../../WI/outputs'
output_prefix = 'WI_20210822_geo32'
crs = 'EPSG:32616'  # WGS 84 / UTM zone 16N; all geometries are projected into this CRS
output_formats = ['png']  # any subset of: 'shapefile', 'csv', 'html', 'png'
output_columns = ['districtr_id', 'submission_title', 'submission_text', 'area_name', 'area_text', 'labels']

# --- Manual cluster surgery ---
# Submission IDs to drop entirely. This is a set: `{}` would be an empty dict.
excluded_submissions = set()
# Submission ID (as a string) -> cluster number to move that submission into.
reassigned_submissions = {}

In [None]:
# Load the pickled cluster database. The `with` block closes the file handle,
# which `pickle.load(open(...))` left open.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open(db_path, 'rb') as db_file:
  db = pickle.load(db_file)

In [None]:
# Ask the database for a `num_clusters`-way clustering of the submissions.
# The result is a DataFrame indexed by submission; the cells below read its
# `clusters` column and assume cluster numbers run 1..max (TODO confirm with
# `clusters_from_number`).
clusters = db.clusters_from_number(num_clusters)

In [None]:
# Crosswalk used below (`cw.map_2010_block_groups`) to translate 2010 block
# groups into 2020 blocks; presumably built from the Census 2010->2020
# tabulation-block relationship file named in the path (confirm).
cw = Crosswalk(block_2010_to_block_2020_crosswalk_path)

In [None]:
# 2020 blocks keyed by GEOID20 and projected into the working CRS, so they can
# be joined to the per-cluster block counts.
blocks_2020_gdf = (
  gpd.read_file(block_2020_shp_path)
  .set_index('GEOID20')
  .to_crs(crs)
)

In [None]:
# Basemap: only this state's counties (STATEFP match), projected into the
# working CRS so they overlay the block heatmaps.
counties_gdf = (
  gpd.read_file(county_shp_path)
  .loc[lambda county_frame: county_frame['STATEFP'] == state_fips_code]
  .to_crs(crs)
)

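## ✂️ Cluster surgery ✂️

Manual corrections applied before generating outputs:
- Submissions listed in `reassigned_submissions` are moved to the given cluster.
- Submissions listed in `excluded_submissions` are dropped from both `clusters` and `db.coi_data`.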

In [None]:
# Apply manual reassignments and exclusions. Submission IDs are compared as
# strings on both paths: reassignments were already keyed by str(ID), while
# exclusions compared the raw index, so string IDs never matched an int index.
submission_ids = clusters.index.astype(str)
clusters['clusters'] = [
  reassigned_submissions.get(sub_id, cluster)
  for sub_id, cluster in zip(submission_ids, clusters['clusters'])
]
excluded_ids = {str(sub_id) for sub_id in excluded_submissions}
clusters = clusters[~submission_ids.isin(excluded_ids)].copy()
db.coi_data = db.coi_data[~db.coi_data.index.astype(str).isin(excluded_ids)]

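## Frequencies

For each cluster, count how many of its submissions cover each 2020 block. Each submission's 2010 block groups are translated to 2020 blocks with the crosswalk.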

In [None]:
# cluster_counts[i] is a Counter mapping each 2020 block to the number of
# submissions in cluster i + 1 whose 2010 block groups cover it.
# Each submission contributes at most 1 per block. Previously a block that
# straddled several of a submission's block groups was counted several times.

# 2010 block group -> set of 2020 blocks; reused across clusters.
_bg_to_blocks_2020 = {}

def _blocks_2020_for_bg(bg):
  """Return the set of 2020 blocks the crosswalk maps 2010 block group `bg` to (memoized)."""
  if bg not in _bg_to_blocks_2020:
    _bg_to_blocks_2020[bg] = set(cw.map_2010_block_groups([bg]))
  return _bg_to_blocks_2020[bg]

cluster_counts = []
# `default=0` keeps an empty `clusters` frame from crashing `max`.
for cluster_idx in trange(1, max(clusters['clusters'], default=0) + 1):
  cluster_df = clusters[clusters['clusters'] == cluster_idx]
  block_2020_count = Counter()
  for bgs in cluster_df['block_groups_2010']:
    # Union this submission's blocks first so it counts once per block.
    submission_blocks = set()
    for bg in bgs:
      submission_blocks |= _blocks_2020_for_bg(bg)
    block_2020_count.update(submission_blocks)
  cluster_counts.append(block_2020_count)

In [None]:
# Prepare the submissions table for output.
# - `area_text` <-> `area_name` are swapped (presumably the upstream columns
#   are mislabeled; confirm).
# - `clusters` is renamed to `cluster`.
# NOTE: not idempotent. Re-running swaps the columns back and re-prefixes the
# cluster IDs; re-run from the top instead.
clusters = clusters.rename(columns={
  'area_text': 'area_name',
  'area_name': 'area_text',
  'clusters': 'cluster'
})

# Normalize free-text columns to plain strings with missing values blanked.
# fillna('') catches NaN/None/pd.NA before astype(str) can turn them into
# 'nan'/'None'/'<NA>'. Literal 'nan' strings are still blanked as before.
for text_col in ('area_name', 'area_text', 'submission_text'):
  clusters[text_col] = (
    clusters[text_col].fillna('').astype(str).replace('nan', '')
  )
# A submission text of '0' is treated as "no text" (presumably a placeholder).
clusters['submission_text'] = clusters['submission_text'].replace('0', '')

# Cluster numbers become prefixed string IDs, e.g. 1 -> 'A1'.
clusters['cluster'] = cluster_name_prefix + clusters['cluster'].astype(str)
clusters.index.name = 'plan_id'

## Output formats (per cluster)

* `shapefile` - Shapefile containing the subset of 2020 blocks within the cluster, with `count` and `freq` attributes.
* `csv` - List of 2020 blocks within the cluster, with `count` and `freq` attributes.
* `html` - Table containing COI submissions (including labels) supporting the cluster.
* `png` - Block-level heatmap (based on `count`) of the cluster with the state's counties as the basemap.

### Attributes
* `count` - The number of supporting COIs a block appears in.
* `freq` - `count` divided by the largest `count` in the cluster. Values lie in (0, 1], and the most-covered block has `freq` = 1.

In [None]:
# Root for this run's outputs: <output_dir>/<output_prefix>/ (one
# sub-directory per output format is created next).
full_output_dir = os.path.join(output_dir, output_prefix)

In [None]:
# Create <full_output_dir>/<format>/ for every requested output format
# (no error if it already exists).
for format_dir in (os.path.join(full_output_dir, fmt) for fmt in output_formats):
  os.makedirs(format_dir, exist_ok=True)

In [None]:
# Write every requested per-cluster output. Entry i of `cluster_counts`
# belongs to cluster ID f'{cluster_name_prefix}{i + 1}'.
# Blocks with a smaller `count` are left off the PNG heatmap. Every counted
# block has count >= 1, so the default plots all of them.
min_plot_count = 1

for cluster_idx, counts in tqdm(enumerate(cluster_counts), total=len(cluster_counts)):
  cluster_id = f'{cluster_name_prefix}{cluster_idx + 1}'
  cluster_label = f'{output_prefix}_cluster_{cluster_id}'

  # Block-level table: `count` per 2020 block, `freq` = count / max count.
  df = pd.DataFrame.from_dict(counts, orient='index', columns=['count'])
  df.index.name = 'GEOID20'
  df['freq'] = df['count'] / df['count'].max()
  submissions = clusters.loc[clusters['cluster'] == cluster_id, output_columns]

  if 'csv' in output_formats:
    df.to_csv(os.path.join(full_output_dir, 'csv', f'{cluster_label}.csv'))

  if 'html' in output_formats:
    submissions.to_html(
      os.path.join(full_output_dir, 'html', f'{cluster_label}.html'), index=False
    )

  if 'shapefile' in output_formats or 'png' in output_formats:
    # Attach block geometries by GEOID20. The geometry column is set
    # explicitly rather than assigning `.crs` after the fact.
    gdf = gpd.GeoDataFrame(
      df.join(blocks_2020_gdf[['geometry']]), geometry='geometry', crs=crs
    )

    if 'shapefile' in output_formats:
      gdf.to_file(os.path.join(full_output_dir, 'shapefile', cluster_label))

    if 'png' in output_formats:
      plot_gdf = gdf[gdf['count'] >= min_plot_count]
      fig, ax = plt.subplots(figsize=(10, 8))
      counties_gdf.plot(color='#d8f3dc', edgecolor='#e5e5e5', ax=ax)
      plot_gdf.plot(ax=ax, column='count', cmap='viridis_r', linewidth=0,
                    edgecolor='none', antialiased=False)
      ax.axis('off')
      ax.set_title(f"Cluster {cluster_label} ({len(submissions)} submissions)")
      fig.savefig(os.path.join(full_output_dir, 'png', f'{cluster_label}.png'),
                  dpi=300, transparent=True, bbox_inches='tight')
      plt.show()
      # Close this specific figure so memory doesn't grow across clusters.
      plt.close(fig)