In [None]:
%config InlineBackend.figure_formats = ['retina']
import os
import random
import pickle
import numpy as np
import pandas as pd
import geopandas as gpd
from ast import literal_eval
import matplotlib.pyplot as plt
from IPython.display import display, Markdown
from artemis.general.pareto_efficiency import is_pareto_efficient
from submission_analysis.mc_clustering import geo_semantic_chain
from tqdm import tqdm
from collections import Counter

In [None]:
# --- Input files (relative to the notebook's working directory) --------------
db_path = 'mi_cluster_db_20210809_new_format.pkl'  # pickled cluster database
bg_2010_shp_path = 'tl_2010_26_bg10.shp'           # 2010 block-group shapefile

# --- Clustering --------------------------------------------------------------
# Number of clusters requested from the database via db.clusters_from_number.
num_clusters = 36
# Clusters to subdivide, as {cluster id: value handed to geo_semantic_chain as
# its fourth argument}.
cluster_splits = {22: 2, 32: 2}

# --- Chain settings ----------------------------------------------------------
chain_beta = 2.0       # third argument of geo_semantic_chain
chain_length = 25_000  # fifth argument of geo_semantic_chain
rng_seed = 42          # seed handed to random.seed

# Labels ignored when comparing label sets and when charting label frequency.
exclude_labels = [
    'theory',
    'personal-unusable-incoherent',
    'ideology',
    'named neighborhood',
]

In [None]:
# Seed both global generators. Which one geo_semantic_chain draws from is not
# visible from this notebook, so seeding only `random` may leave the chain
# non-reproducible across kernel restarts.
random.seed(rng_seed)
np.random.seed(rng_seed)

In [None]:
# NOTE: unpickling can execute arbitrary code embedded in the file; only load
# a database that comes from a trusted source.
# The context manager closes the file handle once the object is loaded.
with open(db_path, 'rb') as db_file:
  db = pickle.load(db_file)

In [None]:
# Store each row's position (0 .. len-1) of db.coi_data in an 'idx' column.
# An 'idx' column is read further down to pick rows/columns out of
# db.coi_total_dissimilarities -- presumably this one, carried over into the
# frame returned by db.clusters_from_number (TODO confirm).
db.coi_data['idx'] = range(len(db.coi_data))

In [None]:
# Fetch the clustering for the configured `num_clusters`. Later cells read the
# 'clusters', 'labels', 'idx', 'block_groups_2010', 'submission_text',
# 'area_text' and 'area_name' columns of this frame.
clusters_df = db.clusters_from_number(num_clusters)

In [None]:
# The 'labels' column arrives as Python-literal strings; parse each into the
# object it describes. Values that are no longer strings are passed through
# unchanged, so re-running this cell on already-parsed data does not raise.
clusters_df['labels'] = clusters_df['labels'].apply(
    lambda value: literal_eval(value) if isinstance(value, str) else value)

In [None]:
# Read the block-group shapefile, then key it by its 'GEOID10' column so that
# block groups can be selected by id with .loc.
bg_gdf = gpd.read_file(bg_2010_shp_path)
bg_gdf = bg_gdf.set_index('GEOID10')

In [None]:
def jaccard_similarity_matrix(cluster_df, excluded_labels=None):
  """Generates a Jaccard similarity matrix over unique labels.

  Every row of `cluster_df` is reduced to its set of labels (minus the
  excluded ones) and each pair of rows is scored as |A & B| / |A | B|.

  Args:
    cluster_df: Frame with a 'labels' column whose values are iterables of
      hashable labels.
    excluded_labels: Iterable of labels to ignore. When omitted, the
      notebook-level `exclude_labels` is used.

  Returns:
    A float array of shape (n, n) with n = len(cluster_df). A pair whose
    label sets are both empty scores 0: the union size is clamped to 1 so the
    division is always defined.
  """
  if excluded_labels is None:
    excluded_labels = exclude_labels
  # Coerce to a set: subtracting a list from a set raises TypeError.
  excluded = set(excluded_labels)

  label_sets = [set(labels) - excluded for labels in cluster_df['labels']]
  # set().union(...) also works for an empty frame, unlike set.union(*()).
  column_of = {
    label: col for col, label in enumerate(set().union(*label_sets))
  }

  n = len(label_sets)
  label_vectors = np.zeros((n, len(column_of)), dtype=int)
  for row, labels in enumerate(label_sets):
    for label in labels:
      label_vectors[row, column_of[label]] = 1

  # For 0/1 indicator rows the dot product counts the shared labels, and
  # |A | B| = |A| + |B| - |A & B|.
  inter = label_vectors @ label_vectors.T
  sizes = label_vectors.sum(axis=1)
  union = sizes[:, None] + sizes[None, :] - inter
  return inter / np.maximum(union, 1)

In [None]:
# For every cluster listed in `cluster_splits`, run a geo/semantic chain and
# keep (a) every state it yields, (b) the normalized score of each state and
# (c) the indices of the Pareto-efficient states.
cluster_states = {}
cluster_scores = {}
cluster_pareto_fronts = {}

# The second loop variable must not be called `num_clusters`: that would
# overwrite the notebook-level setting used by db.clusters_from_number, and a
# later re-run of that cell would silently use this value instead.
for cluster_id, num_subclusters in cluster_splits.items():
  cluster_df = clusters_df[clusters_df['clusters'] == int(cluster_id)]
  semantic_sims = jaccard_similarity_matrix(cluster_df)
  indices = cluster_df['idx']
  # Restrict the dissimilarity matrix to this cluster's rows and columns.
  geo_distances = db.coi_total_dissimilarities[indices][:, indices]
  chain = geo_semantic_chain(geo_distances, semantic_sims, chain_beta, num_subclusters, chain_length)

  states = list(tqdm(chain))
  # Row 0 holds the 'semantic' score of each state, row 1 the 'geo' score.
  scores = np.array([(state.scores['semantic'], state.scores['geo']) for state in states]).T.copy()
  # Both rows are divided by the score of the first state, so the chain's
  # starting point sits at 1.
  # normalize semantic similarities: 1 is random baseline, higher is better
  scores[0] = scores[0] / scores[0, 0]
  # normalize geographic distances: 1 is random baseline, lower is better
  scores[1] = scores[1] / scores[1, 0]
  cluster_states[cluster_id] = states
  # Stored after normalization on purpose: these are the normalized scores.
  cluster_scores[cluster_id] = scores
  # is_pareto_efficient receives one row per state; the semantic score is
  # negated so that, like the geographic distance, smaller is better.
  pareto_front = np.where(is_pareto_efficient(np.array([-scores[0], scores[1]]).T))[0]
  cluster_pareto_fronts[cluster_id] = pareto_front

  fig, ax = plt.subplots(figsize=(10, 6))
  ax.scatter(scores[0], scores[1])
  ax.set_xlabel('Semantic similarity (higher better)')
  ax.set_ylabel('Geographic distance (lower better)')
  ax.set_title(f'All steps (cluster {cluster_id})')
  plt.show()
  plt.close(fig)

  fig, ax = plt.subplots(figsize=(10, 6))
  ax.scatter(scores[0, pareto_front], scores[1, pareto_front])
  ax.set_xlabel('Semantic similarity (higher better)')
  ax.set_ylabel('Geographic distance (lower better)')
  ax.set_title(f'Pareto front (cluster {cluster_id})')
  # Number each point by its position within the Pareto front.
  for idx, (x, y) in enumerate(scores[:, pareto_front].T):
    ax.annotate(str(idx), (x, y), textcoords='offset points', ha='center', xytext=(-10, 0))
  plt.show()
  plt.close(fig)

In [None]:
def plot_split(cluster_id, pareto_id, min_occurrences=3):
  """Maps and describes the subclusters of one Pareto-efficient split.

  Draws two maps of the block groups covered by each subcluster (all block
  groups, and only those occurring at least `min_occurrences` times), then for
  each subcluster shows a label-frequency bar chart and the member rows.

  Args:
    cluster_id: Key into `cluster_states` and `cluster_pareto_fronts`.
    pareto_id: Position within that cluster's Pareto front, i.e. the number
      annotated next to a point in the 'Pareto front' scatter plot. It is
      translated to the corresponding chain step before the state is looked up.
    min_occurrences: Minimum number of member rows that must contain a block
      group for it to appear on the second map.

  Raises:
    IndexError: If `pareto_id` does not address a point of the Pareto front.
  """
  cluster_df = clusters_df[clusters_df['clusters'] == int(cluster_id)]
  pareto_front = cluster_pareto_fronts[cluster_id]
  try:
    step = int(pareto_front[pareto_id])
  except IndexError:
    raise IndexError(
        f'pareto_id {pareto_id} is out of range: the Pareto front of cluster '
        f'{cluster_id} has {len(pareto_front)} points') from None
  state = cluster_states[cluster_id][step]

  # Partition indices are positional within `cluster_df`; select the member
  # rows of every partition once and reuse them below.
  members = {
    label: cluster_df.iloc[sorted(indices)]
    for label, indices in state.partitions.items()
  }

  fig_all, ax_all = plt.subplots(dpi=300)
  fig_frequent, ax_frequent = plt.subplots(dpi=300)
  colors = ['red', 'blue', 'green', 'yellow']
  for label, member_df in members.items():
    # Wrap around rather than fail when there are more partitions than colors.
    color = colors[label % len(colors)]
    # How many member rows contain each block group.
    bg_counts = Counter(
        bg for bgs in member_df['block_groups_2010'] for bg in bgs)
    frequent_bgs = [bg for bg, count in bg_counts.items() if count >= min_occurrences]
    # Skip empty selections instead of asking geopandas to plot nothing.
    if bg_counts:
      bg_gdf.loc[list(bg_counts)].plot(ax=ax_all, color=color, alpha=0.5)
    if frequent_bgs:
      bg_gdf.loc[frequent_bgs].plot(ax=ax_frequent, color=color, alpha=0.5)
    # Empty scatters only exist to give the legend one entry per subcluster.
    ax_all.scatter([], [], label=str(label), color=color)
    ax_frequent.scatter([], [], label=str(label), color=color)
  ax_all.legend()
  ax_frequent.legend()
  ax_all.set_title('Subclusters (≥ 1 occurrence)')
  ax_frequent.set_title(f'Subclusters (≥ {min_occurrences} occurrences)')
  plt.show()
  plt.close(fig_all)
  plt.close(fig_frequent)

  for label, member_df in members.items():
    display(Markdown(f'# Subcluster {label}'))
    label_freq = Counter(
        text_label
        for text_labels in member_df['labels']
        for text_label in text_labels
        if text_label not in exclude_labels)
    # pandas cannot draw a bar chart from an empty frame; skip it in that case.
    if label_freq:
      counts_df = pd.DataFrame(label_freq.most_common(), columns=['label', 'count']).set_index('label')
      fig, ax = plt.subplots(dpi=100)
      counts_df.plot.barh(ax=ax)
      ax.set_title(f'Label frequency (subcluster {label})')
      plt.show()
      plt.close(fig)

    # Show every member row in full, without row or column-width truncation.
    with pd.option_context('display.max_rows', None, 'display.max_colwidth', None):
      display(member_df[['submission_text', 'area_text', 'area_name', 'labels']])

In [None]:
# Render the maps, label-frequency charts and member rows for cluster 32 (one
# of the keys of `cluster_splits`), passing 7 as `pareto_id`.
plot_split(32, 7)