In [1]:
import re, torch, json, pickle
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering
from sklearn.neighbors import kneighbors_graph
import warnings
from nltk.corpus import stopwords
# Session-wide suppression of pandas' chained-assignment warning.
# NOTE(review): none of the cells visible here write into a slice of another
# frame; consider removing this filter so genuine chained-assignment bugs surface.
warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)

# NLTK English stop words; canon_clean() drops any token found in this set.
stop_words = set(stopwords.words('english'))

In [2]:
# bill -> subject-key mapping; its values are looked up in subject_originals below.
with open('bill_subjects.json') as subjects_file:
    bill_subjects = json.load(subjects_file)

In [3]:
# Close the file deterministically instead of leaving the handle to the GC.
# NOTE(review): pickle.load can execute arbitrary code; only load this file
# from a trusted source.
with open('subjects_original.pkl', 'rb') as pkl_file:
    subject_originals = pickle.load(pkl_file)

In [4]:
# Resolve each bill's subject key to the original subject text; bills whose
# key is absent from subject_originals are left out.
so = {}
for bill_key, orig_key in bill_subjects.items():
    if orig_key in subject_originals:
        so[bill_key] = subject_originals[orig_key]

In [5]:
# Two-column frame (bill_id, subject) built from `so`, without missing subjects.
subs = (
    pd.DataFrame.from_dict(so, orient='index', columns=['subject'])
    .rename_axis('bill_id')
    .reset_index()
    .dropna(subset=['subject'])
)

In [6]:
# map_location='cpu' lets the file load on a CPU-only machine even if the
# tensors were saved from a GPU; the next cell converts them to NumPy anyway.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
digests = torch.load('digests.pt', map_location='cpu')

In [7]:
# Convert every digest embedding to a NumPy array. Values that are already
# arrays are passed through, so re-running this cell does not fail on
# `ndarray.cpu()`. detach() is needed because .numpy() refuses tensors that
# track gradients.
digests = {
    key: emb.detach().cpu().numpy() if isinstance(emb, torch.Tensor) else np.asarray(emb)
    for key, emb in digests.items()
}

In [8]:
# Raw digest table; its `bill_id` column is mapped to bill IDs in the next cell.
digest_df = pd.read_csv(filepath_or_buffer='ca_leg/legislation_data/digest.csv')

In [9]:
# Mapping from the digest table's `bill_id` (relabelled `bill_version` after the
# merge below) to a bill ID. Use `with` so the file handle is closed.
# NOTE(review): pickle.load can execute arbitrary code; trusted files only.
with open('bill_id_mapping.pkl', 'rb') as mapping_file:
    bv2b = pickle.load(mapping_file)
digest_df['bill_ID'] = digest_df['bill_id'].map(bv2b)

In [10]:
# Attach subjects to digests. Both frames have a `bill_id` column, so pandas
# suffixes them `_x` (digest version id) and `_y` (subs key, dropped here).
full_table = (
    digest_df
    .merge(subs, left_on='bill_ID', right_on='bill_id', how='right')
    .drop(columns=['bill_id_y'])
)

# The relabel below is positional. It is only correct when digest.csv has
# exactly two columns (bill_id first), so verify that layout instead of
# silently mislabelling columns.
assert full_table.shape[1] == 4 and list(full_table.columns[[0, 2, 3]]) == ['bill_id_x', 'bill_ID', 'subject'], (
    f"unexpected merge layout: {list(full_table.columns)}"
)
full_table.columns = ['bill_version', 'digest', 'bill_ID', 'subject']

In [11]:
def text_clean(title):
    """Normalise digest text to lowercase alphanumeric words.

    Parenthesised spans (non-greedy, single-line) are removed, every character
    other than ASCII letters, digits and whitespace becomes a space, whitespace
    runs collapse to one space, and the result is stripped and lowercased.
    NOTE(review): the output is used as the lookup key into `digests`, so it
    presumably must match the cleaning used when digests.pt was built; keep
    this exact.

    Args:
        title: Raw digest text; any non-str value (e.g. NaN) is accepted.

    Returns:
        The cleaned string, or '' when `title` is not a str.
    """
    if isinstance(title, str):
        without_parens = re.sub(r'\(.*?\)', '', title)
        alnum_only = re.sub(r'[^a-zA-Z0-9\s]', ' ', without_parens)
        single_spaced = re.sub(r'\s+', ' ', alnum_only).strip()
        return single_spaced.lower()
    return ''

# Clean each digest, look up its embedding by the cleaned text, and keep only
# rows for which an embedding exists.
full_table = (
    full_table
    .assign(digest=lambda df: df['digest'].apply(text_clean))
    .assign(digest_emb=lambda df: df['digest'].map(digests))
    .loc[lambda df: df['digest_emb'].notna()]
)

In [12]:
# One row per bill: the element-wise mean of all its digest embeddings.
ft = (
    full_table
    .groupby('bill_ID')['digest_emb']
    .agg(lambda embs: np.vstack(embs).mean(axis=0))
    .reset_index()
)

In [26]:
# Attach each bill's original subject text (from `so`) to its averaged embedding.
ft = ft.assign(subject=ft['bill_ID'].map(so))

In [65]:
from sentence_transformers import SentenceTransformer

# Generic legislative/calendar words stripped from subjects before embedding.
_CANON_NOISE_WORDS = (
    'california', 'state', 'bill', 'law', 'act', 'amendment', 'proposition',
    'measure', 'initiative', 'program', 'code', 'section', 'chapter', 'month',
    'awareness', 'prevention', 'day', 'week', 'memorial', 'highway',
    'department', 'council',
)
_CANON_NOISE_RE = re.compile(r'\b(?:' + '|'.join(_CANON_NOISE_WORDS) + r')\b')


def canon_clean(txt):
    """Canonicalise a subject string for embedding.

    Hyphens become spaces, text is lowercased, every non a-z/whitespace
    character becomes a space, whole-word occurrences of _CANON_NOISE_WORDS are
    removed, and tokens in the module-level `stop_words` set are dropped.

    Args:
        txt: Subject text. Non-str values (None/NaN) are tolerated.

    Returns:
        Space-joined remaining tokens, or '' when `txt` is not a str.
    """
    if not isinstance(txt, str):
        # Mirror text_clean: missing subjects become '' instead of raising TypeError.
        return ''
    txt = txt.replace('-', ' ').lower()
    txt = re.sub(r'[^a-z\s]', ' ', txt)
    # A single pass suffices: removals insert spaces, so they cannot create new
    # whole-word matches.
    txt = _CANON_NOISE_RE.sub(' ', txt)
    return ' '.join(word for word in txt.split() if word not in stop_words)

def embed_subjects(sub, model_name='all-MiniLM-L6-v2', batch_size=128, truncate_dim=64):
    """Embed the canonicalised `subject` column of a frame.

    Each subject is passed through canon_clean() and then encoded with a
    SentenceTransformer, L2-normalised and truncated to `truncate_dim` values.

    Args:
        sub: DataFrame with a `subject` column.
        model_name: SentenceTransformer model to load (loaded on every call).
        batch_size: Encoding batch size.
        truncate_dim: Number of embedding dimensions to keep.

    Returns:
        np.ndarray with one row per row of `sub`.
    """
    subjs = [canon_clean(s) for s in sub['subject'].tolist()]
    model = SentenceTransformer(model_name)
    embs = model.encode(
        subjs,
        batch_size=batch_size,
        normalize_embeddings=True,
        show_progress_bar=True,
        truncate_dim=truncate_dim,
    )
    return np.asarray(embs)

# Subject-text embeddings for every bill in `ft`.
# NOTE(review): subj_embs is not referenced by any later cell shown here.
subj_embs = embed_subjects(sub=ft)

Batches:   0%|          | 0/361 [00:00<?, ?it/s]

In [13]:
PCA_COMPONENTS = 155   # dimensionality of the reduced digest embeddings
KNN_NEIGHBORS = 120    # neighbours per bill in the clustering connectivity graph

# random_state pins PCA's randomized SVD solver (which svd_solver='auto' may
# pick depending on data shape and sklearn version) so reruns reproduce the
# same projection. The hand-curated `groups` mapping later in the notebook
# depends on exact cluster labels.
# NOTE(review): if the original run used the randomized solver, labels may
# differ from the ones `groups` was curated against; re-verify once.
pca = PCA(n_components=PCA_COMPONENTS, random_state=0)
x1 = pca.fit_transform(np.vstack(ft['digest_emb']))

# Keep the kNN graph sparse. AgglomerativeClustering accepts a sparse
# connectivity matrix, whereas .toarray() would materialise a dense
# n_samples x n_samples matrix.
connectivity = kneighbors_graph(x1, n_neighbors=KNN_NEIGHBORS, include_self=False)

In [14]:
# Connectivity-constrained agglomerative clustering. distance_threshold makes
# the model record merge distances (distances_), which linkage() needs to
# rebuild a SciPy linkage matrix.
clusterer = AgglomerativeClustering(
    n_clusters=None,
    distance_threshold=3,
    connectivity=connectivity,
    compute_full_tree=True,
)
clusterer.fit(x1)

In [35]:
from scipy.cluster.hierarchy import fcluster

def linkage(model):
    """Convert a fitted AgglomerativeClustering model into a SciPy linkage matrix.

    Row i holds the two merged node ids (ids below n_samples are leaves; the
    rest refer to earlier merges), the merge distance, and the number of leaves
    under the new node.

    Args:
        model: Fitted estimator exposing children_, distances_ and labels_.

    Returns:
        float ndarray of shape (n_merges, 4), usable with scipy's fcluster.
    """
    n_leaves = len(model.labels_)
    merge_sizes = []
    for left, right in model.children_:
        size = 0
        for node in (left, right):
            size += 1 if node < n_leaves else merge_sizes[node - n_leaves]
        merge_sizes.append(size)
    return np.column_stack([model.children_, model.distances_, merge_sizes]).astype(float)

# SciPy-format linkage matrix rebuilt from the fitted clusterer's merge tree.
Z = linkage(model=clusterer)

(46,)

In [55]:
# Re-cut the merge tree at distance 5.8 (larger than the threshold of 3 used to
# fit clusterer) and show how many flat clusters result.
FLAT_CUT_DISTANCE = 5.8
labels2 = fcluster(Z, FLAT_CUT_DISTANCE, criterion='distance')
np.unique(labels2).shape

(114,)

In [56]:
# Copy of ft with both clusterings attached: the model's own labels and the
# flat labels cut from the rebuilt tree.
ftd = ft.assign(label=clusterer.labels_, label2=labels2)

In [184]:
# Hand-assigned names for flat-cluster labels. A key is either a single label
# or an inclusive (first, last) range of labels. Later entries would override
# earlier ones if ranges overlapped.
groups = [
    ((0, 4), 'Special Holidays'),
    (5, 'Memorial Highways'),
    (6, 'Joint Rules'),
    ((7, 8), 'Declarations'),
    ((9, 10), 'Income Tax'),
    ((11, 12), "Sales Tax"),
    (13, "Property Tax"),
    (14, 'Budget Plans'),
    (15, 'Validations'),
    (16, 'Special Legal Cases'),
    (17, "Use of Funds"),
    (18, 'Local Government Administration'),
    ((19, 20), 'Energy & Utilities'),
    ((21, 24), 'Environmental Resources'),
    (25, 'Waste Management'),
    ((26, 27), 'Environmental Resources'),
    ((28, 30), 'Horse Racing & Gambling'),
    (31, 'Agriculture & Farming'),
    (32, 'Animal Welfare'),
    (33, 'Wildlife Conservation'),
    (34, 'Weapons & Firearms'),
    ((35, 36), 'Pharmaceuticals'),
    (37, "Alcohol"),
    (38, 'Cannabis'),
    (39, 'Tobacco'),
    ((40, 41), 'Consumer Protections & Advertising'),
    (42, 'Public Records'),
    (43, 'Voting & Elections'),
    (44, 'Government Transparency'),
    (45, 'Privacy'),
    (46, 'Military & Veterans'),
    (47, 'Peace Officers'),
    (48, 'Civil Rights'),
    (49, 'Restorative Justice'),
    (50, "Special Legal Cases"),
    (51, 'Local Initiatives'),
    (52, 'Insurance'),
    (53, 'Civil Actions & Arbitration'),
    ((54, 55), 'Financial Institutions'),
    (56, 'Securities & Investments'),
    ((57, 60), 'Transportation'),
    (61, 'Fire Prevention'),
    (62, 'Emergency Services & Disaster Relief'),
    (63, 'Occupational Health & Safety'),
    (64, 'Tenancy'),
    (65, 'Housing Development & Planning'),
    (66, 'Housing Assistance Programs'),
    (67, 'Local Initiatives'),
    (68, 'Local Government Administration'),
    (69, 'Regional Parks'),
    (70, 'Local Government Administration'),
    (71, 'Public Contracts'),
    (72, 'Telecommunications'),
    (73, 'Grant Programs'),
    (74, 'Special Holidays'),
    ((75, 76), 'State Initiatives'),
    ((77, 79), 'Postsecondary Education'),
    ((80, 82), 'Education'),
    (83, 'Workforce Development'),
    ((84, 85), 'Education'),
    (86, 'Youth Social Services'),
    (87, 'Abuse Protections'),
    (88, 'Marriage & Family Law'),
    ((89, 92), 'Criminal Procedure & Law'),
    ((93, 95), 'Correctional System'),
    (96, 'Public Employees'),
    (97, 'Employee Protections'),
    (98, 'Unemployment'),
    ((99, 100), 'Employee Protections'),
    (101, 'Public Employees'),
    ((102, 104), 'Public Social Services'),
    (105, 'Youth Social Services'),
    (106, 'Childcare'),
    (107, 'Professional Licensing'),
    ((108, 110), 'Health & Medicine'),
    (111, 'Public Social Services'),
    ((112, 113), 'Healthcare'),
]


def _expand_codes(spec):
    """Return the labels covered by an int or an inclusive (first, last) tuple."""
    if isinstance(spec, tuple):
        return range(spec[0], spec[1] + 1)
    return (spec,)


# Flatten to label -> group name.
group_dict = {code: name for spec, name in groups for code in _expand_codes(spec)}

# scipy's fcluster numbers flat clusters starting at 1, while the `groups`
# mapping is keyed from 0. Any label missing from group_dict would silently
# become NaN here, and later cluster code -1 and NaN in the exported JSON.
# Fail loudly instead.
unmapped_labels = sorted(set(ftd['label2'].tolist()) - set(group_dict))
if unmapped_labels:
    raise ValueError(
        f"label2 values with no entry in group_dict: {unmapped_labels}. "
        "fcluster labels are 1-based; re-check the groups mapping (possible off-by-one)."
    )
ftd['group'] = ftd['label2'].map(group_dict)

In [191]:
# Integer code per group name (categories inferred from the values, as
# astype('category') does); a missing group gets code -1.
ftd['cluster'] = pd.Categorical(ftd['group']).codes

In [192]:
# bill_ID -> cluster code, built column-wise instead of via iterrows (slow, and
# its row dtype depends on the other columns). .tolist() yields plain Python
# scalars, so json.dump never sees NumPy integer types.
bill_labels = dict(zip(ftd['bill_ID'].tolist(), ftd['cluster'].astype(int).tolist()))

with open('bill_labels.json', 'w') as f:
    json.dump(bill_labels, f)

In [194]:
# cluster code -> group name, one entry per distinct pair, in code order for a
# stable file. Plain Python ints are used as keys; json writes them as strings.
cluster_groups = ftd[['cluster', 'group']].drop_duplicates().sort_values('cluster')
subject_key = dict(zip(cluster_groups['cluster'].astype(int).tolist(), cluster_groups['group'].tolist()))

with open('subject_key.json', 'w') as f:
    # allow_nan=False: an unnamed (NaN) group must not be written as invalid JSON.
    json.dump(subject_key, f, allow_nan=False)