In [183]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import normalize
from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score, silhouette_samples
import torch.nn.functional as F
from tqdm import tqdm
import re, collections, itertools, torch
import json

In [184]:
# Load the saved mapping of subject text -> embedding tensor.
# map_location='cpu' lets this load on a machine without CUDA even if the
# tensors were saved from a GPU; the `.cpu()` calls downstream become no-ops.
# NOTE(review): torch.load unpickles the file, so only load trusted artifacts.
# If the installed torch supports it (>= 1.13) and the file holds only tensors,
# consider weights_only=True.
subjects = torch.load('subject_embeddings.pt', map_location='cpu')

In [185]:
# Build one row per subject, in the dict's iteration order. This is the same
# order that `subjects.keys()` yields when labels are attached later.
X = np.stack([embedding.cpu().numpy() for embedding in subjects.values()])
# Scale every row to unit L2 norm, then project onto 101 principal components.
X = PCA(n_components=101, random_state=42).fit_transform(
    normalize(X, norm='l2', axis=1)
)

In [186]:
# Sweep n_clusters over 350, 355 and 360. Fit a Ward agglomerative clustering
# for each value and score the partition by its mean silhouette coefficient.
clustering = {}
for n in tqdm(range(350, 365, 5)):
    clusterer = AgglomerativeClustering(n_clusters=n, linkage='ward')
    clusters = clusterer.fit_predict(X)
    silhouette = silhouette_score(X, clusters)
    clustering[n] = dict(clusters=clusters, silhouette=silhouette)

# Show the candidate cluster counts, best mean silhouette first.
(
    pd.DataFrame.from_dict(clustering, orient='index')
    .sort_values('silhouette', ascending=False)
)

100%|██████████| 3/3 [01:30<00:00, 30.13s/it]


Unnamed: 0,clusters,silhouette
360,"[119, 2, 14, 30, 240, 212, 346, 36, 233, 30, 1...",0.065593
355,"[119, 13, 14, 30, 240, 212, 346, 36, 233, 30, ...",0.065075
350,"[119, 13, 14, 30, 240, 212, 346, 36, 233, 30, ...",0.064957


In [187]:
# Pick the cluster count with the highest mean silhouette from the sweep above
# (360 in the recorded run) instead of hard-coding it. That way, editing the
# sweep range cannot leave this cell pointing at a missing or sub-optimal key.
best_n = max(clustering, key=lambda k: clustering[k]['silhouette'])
clusters = clustering[best_n]['clusters']
# Per-subject silhouette coefficients, used below to rank members of a cluster.
silhouettes = silhouette_samples(X, clusters)

In [188]:
# One row per embedded subject, holding its cluster label and silhouette value.
# Row order matches the rows of X.
subj = pd.DataFrame(
    dict(
        subject=list(subjects.keys()),
        label=clusters,
        silhouette=silhouettes,
    )
)

In [189]:
# Map each cluster label to a representative subject: the member with the
# highest silhouette coefficient. idxmax keeps the first row on ties, so the
# pick is deterministic. Keys come out in ascending label order.
labels = (
    subj.loc[subj.groupby('label')['silhouette'].idxmax()]
    .set_index('label')['subject']
    .to_dict()
)

In [190]:
# Load the bill -> subject mapping. Its subjects are looked up against the
# clustered subject strings below.
# NOTE(review): presumably keys are bill ids and values are subject strings —
# confirm the schema.
# JSON is UTF-8 by spec, so decode explicitly rather than relying on the
# platform's locale encoding.
with open('bill_subjects.json', 'r', encoding='utf-8') as f:
    bill_subjects = json.load(f)

In [191]:
# Build the subject text -> cluster label lookup without iterrows.
# `.tolist()` yields plain Python ints, which keeps the mapping
# JSON-serializable (numpy int64 is not).
subj_labels = dict(zip(subj['subject'], subj['label'].tolist()))

# Assign each bill the cluster of its subject. Bills whose subject was not
# embedded/clustered get the sentinel label -1.
bill_subjects_clean = {
    bill: subj_labels.get(subject, -1)
    for bill, subject in bill_subjects.items()
}

In [182]:
# Persist the bill -> cluster-label mapping (-1 marks bills whose subject was
# not clustered).
# NOTE(review): this cell's execution count (182) is lower than that of the
# cell that builds bill_subjects_clean (191). The saved file may therefore
# predate the current mapping; re-run this cell after that one.
with open('bill_labels.json', mode='w') as f:
    json.dump(obj=bill_subjects_clean, fp=f)

In [192]:
# Inspect the per-subject table: subject text, assigned cluster label, and
# silhouette coefficient.
subj

Unnamed: 0,subject,label,silhouette
0,Public Utilities Commission: reports.,119,0.224525
1,Volunteer firefighters: federal reimbursements.,2,0.029722
2,Public postsecondary education: Student Civic ...,14,-0.134431
3,Income taxes: exclusion.,30,-0.006432
4,Sale of water by local public entities: excise...,240,-0.032011
...,...,...,...
22642,State department budgets: zero-based budget pi...,44,-0.091683
22643,Fire prevention: local assistance grant progra...,168,-0.026203
22644,Child care and development: California State ...,67,0.001516
22645,Department of Motor Vehicles: records: confi...,162,-0.054022
