In [26]:
import ast
import itertools
import json
import pickle
import re
import warnings

import numpy as np
import pandas as pd
import torch
from nltk.corpus import stopwords
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AgglomerativeClustering

stop_words = set(stopwords.words('english'))

warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)

In [27]:
# Mapping loaded from bill_subjects.json; its values are later looked up as
# keys of subject_originals (presumably bill_id -> subject key; verify).
with open('bill_subjects.json') as subjects_file:
    bill_subjects = json.load(subjects_file)

In [28]:
# Load the mapping used below to translate bill_subjects values into the
# original subject text. The `with` block closes the file right away (a bare
# open() passed to pickle.load leaves the handle open until garbage collection).
# NOTE: pickle.load can execute arbitrary code -- only load a trusted file here.
with open('subjects_original.pkl', 'rb') as originals_file:
    subject_originals = pickle.load(originals_file)

In [29]:
# Keep only bills whose subject key has an original subject string,
# mapping each bill key straight to that original text.
so = {}
for bill_key, subject_key in bill_subjects.items():
    if subject_key in subject_originals:
        so[bill_key] = subject_originals[subject_key]

In [30]:
# One row per bill: `bill_id` (the former dict key) and its original `subject`.
subs = (
    pd.DataFrame.from_dict(so, orient='index', columns=['subject'])
    .reset_index()
    .rename(columns={'index': 'bill_id'})
)
# Drop bills without a subject. `.copy()` makes this an independent frame, so
# adding columns to `subs` later is not chained assignment on a slice
# (which is what triggers SettingWithCopyWarning).
subs = subs.loc[subs['subject'].notna()].copy()

In [31]:
# Boilerplate words that are blanked out of every subject fragment.
_SUBJECT_BOILERPLATE = re.compile(
    r'\b(?:california|state|bill|law|act|amendment|proposition|measure|initiative|program|code|section|chapter|month|awareness|prevention)\b'
)
_CANONICAL_ROUNDS = ('first', 'second', 'third')


def _scrub_subject_text(txt):
    """Drop hyphens, turn non-letter/non-whitespace chars into spaces, and blank out boilerplate words."""
    txt = txt.replace('-', '')
    txt = re.sub(r'[^a-z\s]', ' ', txt)
    return _SUBJECT_BOILERPLATE.sub(' ', txt)


def canonical(sub, round='first'):
    """Reduce a ':'-delimited subject string to a short lowercase topic phrase.

    Parameters
    ----------
    sub : str
        Raw subject, possibly made of ':'-separated segments.
    round : {'first', 'second', 'third'}
        Which part of ``sub`` to start from (the name shadows the builtin
        ``round`` but is kept for caller compatibility):

        - ``'first'``: the segment before the first ':'.
        - ``'second'``: the segment after the first ':' (or the first segment
          when ``sub`` contains no ':').
        - ``'third'``: segments 2 and 3 joined when there are more than two
          segments, otherwise the whole string with ':' replaced by spaces.

    Returns
    -------
    str
        Lowercase text made of ASCII letters separated by single spaces,
        with hyphens removed and boilerplate words dropped. If the scrubbed
        starting text is 7 characters or shorter, a broader text is scrubbed
        and used instead: the whole string when there are at most two
        segments or ``round == 'third'``, otherwise segments 2 and 3.

    Raises
    ------
    ValueError
        If ``round`` is not one of the three accepted values (previously
        this surfaced as an UnboundLocalError).
    """
    if round not in _CANONICAL_ROUNDS:
        raise ValueError(f"round must be one of {_CANONICAL_ROUNDS}, got {round!r}")

    parts = sub.split(':')
    whole = sub.lower().replace(':', ' ').strip()

    def tail():
        # Only called when len(parts) > 2, so parts[2] exists.
        return ' '.join([parts[1].lower(), parts[2].lower()])

    if round == 'third':
        txt = tail() if len(parts) > 2 else whole
    else:
        use_second = round == 'second' and len(parts) >= 2
        txt = parts[1 if use_second else 0].lower()
    txt = _scrub_subject_text(txt)

    # NOTE: the length is measured before whitespace is collapsed, so the
    # spaces left behind by removed words count toward the 7-char threshold.
    # Kept as-is because saved clusters were built with this behaviour.
    if len(txt) <= 7:
        txt = whole if (len(parts) <= 2 or round == 'third') else tail()
        txt = _scrub_subject_text(txt)

    return re.sub(r'\s+', ' ', txt).strip()

def sub_clean(sub):
    """Turn a raw subject into a short, stopword-free topic phrase.

    Tries the canonical() rounds 'first', 'second' and 'third' in order and
    keeps the first result that is at least 5 characters long (the 'third'
    result is used if none qualifies). Then it drops words found in the
    module-level ``stop_words`` set.
    """
    for stage in ('first', 'second', 'third'):
        candidate = canonical(sub, round=stage)
        if len(candidate) >= 5:
            break
    candidate = re.sub(r'[^a-z\s]', ' ', candidate.lower().strip())
    kept = [token.strip() for token in candidate.split() if token not in stop_words]
    return ' '.join(kept).strip()

# Cleaned topic phrase for every bill; this is the text that gets embedded.
subs['top_subject'] = subs['subject'].map(sub_clean)

In [32]:
def embed_subjects(subjs, model_name='all-MiniLM-L6-v2', batch_size=128):
    """Encode texts into L2-normalised sentence embeddings.

    Parameters
    ----------
    subjs : list of str
        Texts to embed; row ``i`` of the result belongs to ``subjs[i]``.
    model_name : str, optional
        SentenceTransformer checkpoint to load. Defaults to the model used
        throughout this notebook.
    batch_size : int, optional
        Batch size passed to ``encode``.

    Returns
    -------
    numpy.ndarray
        One normalised embedding row per input text.
    """
    model = SentenceTransformer(model_name)
    embs = model.encode(subjs, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True)
    return np.asarray(embs)

In [33]:
# Embed each distinct cleaned subject once. pandas' unique() keeps first
# appearance order, so row i of `embs` is the i-th distinct top_subject.
unique_top_subjects = subs['top_subject'].dropna().unique().tolist()
embs = embed_subjects(unique_top_subjects)

Batches:   0%|          | 0/117 [00:00<?, ?it/s]

In [230]:
from sklearn.neighbors import kneighbors_graph

# Neighbour count for the k-NN graph that restricts which subjects the
# Ward clustering below may merge.
KNN_NEIGHBORS = 124

X = np.asarray(embs, dtype=np.float32)
connectivity = kneighbors_graph(X, n_neighbors=KNN_NEIGHBORS, metric='cosine')

In [248]:
# Try candidate Ward distance thresholds and stop at the first one producing
# between 65 and 84 clusters. The candidate list currently holds a single value.
for threshold in [5.4]:
    m2 = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=threshold,
        linkage='ward',
        metric='euclidean',
        compute_full_tree=True,
        connectivity=connectivity,
    )
    l2 = m2.fit_predict(X)
    n_found = np.unique(l2).size
    print(threshold, n_found)
    if 65 <= n_found < 85:
        break

5.4 77


In [249]:
# Same order-preserving unique() as used for `embs`, so row i here lines up
# with row i of X and l2.
unique_subjects = subs['top_subject'].dropna().unique()
subjs = unique_subjects.tolist()
subj_df = pd.DataFrame(data={'subject': subjs, 'cluster': l2})

In [250]:
# One row per cluster listing its member subjects. dict.fromkeys de-duplicates
# while keeping first-appearance order; iterating a set of strings depends on
# the per-process hash seed, which made the saved CSV non-reproducible.
labeling_df = (
    subj_df.groupby('cluster')
    .agg({'subject': lambda members: list(dict.fromkeys(members))})
    .reset_index()
)

In [None]:
# Attach each bill's cluster through its cleaned subject. Both frames have a
# 'subject' column, so pandas suffixes them; only bill_id/cluster are used here.
sbs = subs.merge(subj_df, left_on='top_subject', right_on='subject', how='left')
# Vectorised bill_id -> cluster mapping instead of a row-by-row iterrows loop
# (later duplicates still win). .tolist() yields plain Python scalars, which
# json.dump can serialise (numpy integer scalars cannot be serialised).
bill_labels = dict(zip(sbs['bill_id'], sbs['cluster'].tolist()))

In [None]:
# Persist the bill -> cluster mapping so later sessions can skip clustering.
with open('bill_labels.json', mode='w') as labels_out:
    json.dump(bill_labels, labels_out)

In [257]:
# NOTE: the 'subject' column holds Python lists; CSV stores their str() repr
# (e.g. "['a', 'b']"), so read it back with a parser rather than as plain text.
labeling_df.to_csv(path_or_buf='subject_clusters.csv', index=False)

In [34]:
# Reload the saved bill -> cluster mapping (JSON keys come back as str).
with open('bill_labels.json') as labels_in:
    bill_labels = json.load(labels_in)

In [35]:
# Bills without a saved label get the sentinel cluster -1.
subs['cluster'] = subs['bill_id'].map(lambda bill_id: bill_labels.get(bill_id, -1))

In [36]:
# subject_clusters.csv stores each cluster's member list as its str() repr,
# so parse it back into real lists. Without this, the reloaded frame holds
# strings while the in-memory one held lists. literal_eval only evaluates
# Python literals, so it is safe to use on this file.
labeling_df = pd.read_csv('subject_clusters.csv', converters={'subject': ast.literal_eval})

In [37]:
def normalize(v):
    """Return ``v`` with every row scaled to unit L2 norm.

    ``v`` must be 2-D (norms are taken along axis 1). All-zero rows are
    divided by 1 and stay zero instead of becoming NaN.
    """
    n = np.linalg.norm(v, axis=1, keepdims=True)
    n[n == 0] = 1.0
    return v / n


def cluster_representatives(embs, df, top_k=3):
    """Pick representative subject strings for every cluster.

    For each cluster the result lists the ``top_k`` members most cosine-similar
    to the cluster centroid (most similar first). It then appends, if not
    already present, the member with the highest mean similarity to the other
    members.

    Parameters
    ----------
    embs : array-like, 2-D
        One embedding row per distinct ``top_subject``. Row ``i`` must belong to
        the i-th distinct ``top_subject`` of ``df`` in first-appearance order.
    df : pandas.DataFrame
        Needs ``top_subject`` and ``cluster`` columns. Duplicate subjects are
        dropped (first kept); presumably duplicates share a cluster -- verify.
    top_k : int, optional
        Number of centroid-nearest members to keep per cluster.

    Returns
    -------
    dict
        Cluster label (as produced by ``np.unique``) -> de-duplicated list of
        subject strings, in the deterministic order described above.

    Raises
    ------
    ValueError
        If ``embs`` does not have exactly one row per distinct ``top_subject``.
    """
    X = normalize(embs)
    df = df.drop_duplicates(subset=['top_subject'])
    if len(df) != len(X):
        # Misaligned inputs would otherwise silently pick the wrong rows.
        raise ValueError(
            f"embs has {len(X)} rows but df has {len(df)} distinct top_subject values"
        )
    labels = df['cluster'].to_numpy()
    texts = df['top_subject'].to_list()
    reps = {}
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        Xc = X[idx]
        centroid = normalize(Xc.mean(axis=0, keepdims=True))
        sims = (Xc @ centroid.T).ravel()
        picks = [texts[int(idx[i])] for i in np.argsort(-sims)[:top_k]]
        S = Xc @ Xc.T
        # Drop each row's self-similarity (1 for unit rows, assuming no all-zero
        # embeddings) before averaging over the other members.
        avg_sim = (S.sum(axis=1) - 1) / max(len(idx) - 1, 1)
        picks.append(texts[int(idx[np.argmax(avg_sim)])])
        # dict.fromkeys de-duplicates but keeps order; set() order depends on
        # the string hash seed, which made downstream label text non-reproducible.
        reps[c] = list(dict.fromkeys(picks))

    return reps

# `embs` must still be the subject embeddings from embed_subjects (one row per
# distinct top_subject); re-running after it has been re-bound breaks this.
reps = cluster_representatives(embs=embs, df=subs, top_k=3)

In [38]:
model = SentenceTransformer('all-MiniLM-L6-v2')

# Embed each cluster's representative phrases (joined into one string) as a
# single label vector, keyed by the cluster id as a plain int. The per-cluster
# result uses its own name so the subject-level `embs` array is not overwritten,
# which would break re-running cluster_representatives.
embedded_labels = {}
for c, texts in reps.items():
    text = " ".join(texts)
    # NOTE(review): `output_dimension` is not a documented argument of
    # SentenceTransformer.encode (recent versions document `truncate_dim` for
    # shortening vectors); it may be silently ignored. Confirm the stored
    # vectors really have the intended 64 dimensions.
    label_emb = model.encode(text, batch_size=32, normalize_embeddings=True, output_dimension=64)
    embedded_labels[int(c)] = label_emb

with open('embedded_labels.pkl', 'wb') as f:
    pickle.dump(embedded_labels, f)