In [24]:
import re, torch, json, pickle, itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
import networkx as nx
from community import community_louvain
from sklearn.mixture import GaussianMixture
from sentence_transformers import SentenceTransformer, util
from collections import Counter, defaultdict
import warnings

warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)

In [3]:
with open('bill_subjects.json', 'r') as f:
    bill_subjects = json.load(f)

In [4]:
subject_originals = pickle.load(open('subjects_original.pkl', 'rb'))

In [5]:
so = {k: subject_originals[v] for k, v in bill_subjects.items() if v in subject_originals}

In [5]:
from rapidfuzz import process, fuzz

def canonical(sub):
    txt = sub.split(':')[0].lower()
    txt = re.sub(r'[^a-z\s]', ' ', txt)
    txt = re.sub(r'(?:california|state|bill|law|act|amendment|proposition|measure|initiative|program|act|code|section|chapter|month|awareness|prevention)', '', txt)
    txt = re.sub(r'\s+', ' ', txt).strip()
    return txt

def fuzzy_bucket(subjects, thresh=87):
    canons = {s: canonical(s) for s in subjects}
    canon_list = list(set(canons.values()))
    canon_to_subjects = defaultdict(list)
    for s, c in canons.items():
        canon_to_subjects[c].append(s)
    buckets, assigned = [], set()
    with tqdm(total=len(canon_list), desc='Fuzzy Buckets') as pbar:
        for i, c in enumerate(canon_list):
            if c in assigned:
                pbar.update(1)
                continue
            group = []
            for j in range(i, len(canon_list)):
                d = canon_list[j]
                if d in assigned:
                    continue
                if fuzz.token_sort_ratio(c, d) >= thresh:
                    group.extend(canon_to_subjects[d])
                    assigned.add(d)
            if group:
                buckets.append(group)
            pbar.update(1)
    return buckets

def embed_subjects(subjs):
    model = SentenceTransformer('all-MiniLM-L6-v2')
    def sub_clean(s):
        s = s.lower().strip()
        s = re.sub(r'(?:california|state|bill|law|act|amendment|proposition|measure|initiative|program|act|code|section|chapter|month|awareness|prevention)', '', s)
        s = re.sub(r'[^a-z\s]', ' ', s)
        s = re.sub(r'\s+', ' ', s).strip()
        return s
    subjs = [sub_clean(s) for s in subjs]
    subjs = [s for s in subjs if s]
    embs = model.encode(subjs, normalize_embeddings=True, batch_size=128)
    return np.asarray(embs)

def gmm_soft_labels(embs, min_k=150, max_k=325):
    best_bic, best_gmm = np.inf, None
    for k in tqdm(np.linspace(min_k, max_k, 8, dtype=int)):
        gmm = GaussianMixture(k, covariance_type='diag', random_state=0).fit(embs)
        if gmm.bic(embs) < best_bic:
            best_bic, best_gmm = gmm.bic(embs), gmm
    probs = best_gmm.predict_proba(embs)
    return probs

def build_graph(subjects, fuzzy_buckets, soft_proba, alpha=0.35):
    idx = {s:i for i,s in enumerate(subjects)}
    N = len(subjects)
    fuzzy_adj = np.zeros((N,N), dtype=float)
    for bucket in fuzzy_buckets:
        for a,b in itertools.combinations(bucket, 2):
            fuzzy_adj[idx[a], idx[b]] = 1
            fuzzy_adj[idx[b], idx[a]] = 1
    embed_adj = soft_proba @ soft_proba.T
    W = alpha * fuzzy_adj + (1-alpha) * embed_adj
    G = nx.Graph()
    for i,s in enumerate(subjects):
        G.add_node(i, label=s)
    rows, cols = np.where(W > 0.01)
    for i,j in zip(rows, cols):
        if i < j:
            G.add_edge(i, j, weight=float(W[i,j]))
    return G

In [6]:
subs = list(set(list(so.values())))
buckets = fuzzy_bucket(subs, thresh=87)

Fuzzy Buckets: 100%|██████████| 13433/13433 [02:16<00:00, 98.68it/s]  


In [7]:
embs = embed_subjects(subs)

In [8]:
soft_proba = gmm_soft_labels(embs)

  0%|          | 0/8 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 8/8 [03:15<00:00, 24.39s/it]


In [9]:
G = build_graph(subs, buckets, soft_proba, alpha=0.35)

In [10]:
def consensus_clusters(G):
    part = community_louvain.best_partition(G, weight='weight', resolution=1.0)
    clusters = defaultdict(list)
    for node, cid in part.items():
        clusters[cid].append(G.nodes[node]['label'])
    return clusters

In [11]:
clusters = consensus_clusters(G)

In [12]:
def community_scale(G, r):
    part = community_louvain.best_partition(G, weight='weight', resolution=r)
    clusters = defaultdict(list)
    for node, cid in part.items():
        clusters[cid].append(G.nodes[node]['label'])
    return clusters

In [13]:
import igraph as ig, leidenalg as la
from sklearn.metrics.pairwise import cosine_similarity

cluster_ids = list(clusters.keys())
model = SentenceTransformer('all-MiniLM-L6-v2')

def cluster_embedding(texts, model):
    vecs = model.encode(texts, batch_size=128,
                        normalize_embeddings=True,
                        device='mps', show_progress_bar=True)
    return vecs.mean(0)

emb = {cid: cluster_embedding(clusters[cid], model)
       for cid in cluster_ids}
X = np.vstack([emb[cid] for cid in cluster_ids])

k = 15
sim = cosine_similarity(X)
knn_ix = np.argsort(-sim, axis=1)[:, 1:k+1]

edges, wts = [], []
for i, nbrs in enumerate(knn_ix):
    for j in nbrs:
        edges.append((i, j))
        wts.append(float(sim[i, j]))

g = ig.Graph(edges=edges, directed=False)
g.vs['name'] = cluster_ids
g.es['weight'] = wts

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/5 [00:00<?, ?it/s]

Batches:   0%|          | 0/5 [00:00<?, ?it/s]

Batches:   0%|          | 0/5 [00:00<?, ?it/s]

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/5 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Batches:   0%|          | 0/8 [00:00<?, ?it/s]

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [14]:
def leiden_cpm(gamma):
    part = la.find_partition(
        g,
        la.CPMVertexPartition,
        weights='weight',
        resolution_parameter=gamma,
        n_iterations=-1
    )
    return {g.vs[i]['name']: cid for i, cid in enumerate(part.membership)}

coarse05 = leiden_cpm(0.7)

In [15]:
labels = pd.DataFrame({
    'bill_id': list(so.keys()),
    'subject': list(so.values())
})

In [16]:
clust_reverse = {v: k for k, val in clusters.items() for v in val}

In [17]:
labels['cluster'] = labels['subject'].map(clust_reverse)
labels['coarse05'] = labels['cluster'].map(coarse05)

In [18]:
with open('bill_ids.txt', 'r') as f:
    bill_ids = f.read().splitlines()

In [19]:
with open('missed_bills.txt', 'r') as f:
    missed_bills = f.read().splitlines()

In [20]:
bill_id_mapping = pickle.load(open('bill_id_mapping.pkl', 'rb'))

In [21]:
missing_variations = [k for k, v in bill_id_mapping.items() if v in missed_bills]

In [22]:
bill_labels = {}
for _, row in labels[['bill_id', 'cluster']].drop_duplicates().iterrows():
    bill_labels[row['bill_id']] = row['cluster']

In [23]:
with open('bill_labels.json', 'w') as f:
    json.dump(bill_labels, f)

In [24]:
digests = pd.read_csv('ca_leg/legislation_data/digest.csv')

In [25]:
digest_embeddings = torch.load('digests.pt', weights_only=False)

In [26]:
repairs = digests.loc[digests['bill_id'].isin(missing_variations)]
repairs['bill'] = repairs['bill_id'].map(bill_id_mapping)
repairs['version'] = repairs['bill_id'].apply(lambda x: x[-5:-3]).astype(int)
repairs = repairs.sort_values('version', ascending=False).groupby('bill').head(1)

In [27]:
de = {k: v.cpu().numpy() for k, v in digest_embeddings.items() if k in repairs['DigestText'].values}

In [28]:
repairs['digest_embedding'] = repairs['DigestText'].map(de)

In [29]:
sample_weights = labels['cluster'].value_counts().to_dict()

def sample_weighted(labs, sample_weights, n):
    lab = labs.copy().sample(frac=1).reset_index(drop=True)
    weights = labs['cluster'].map(sample_weights)
    return lab.sample(n, weights=weights, replace=True)

training_sample = sample_weighted(labels.loc[~labels['bill_id'].isin(missed_bills)], sample_weights, 2000)

In [31]:
t_vars = [k for k, v in bill_id_mapping.items() if v in training_sample['bill_id'].values]
t = digests.loc[digests['bill_id'].isin(t_vars)]
t['bill'] = t['bill_id'].map(bill_id_mapping)
t['version'] = t['bill_id'].apply(lambda x: re.sub(r'\D+', '', x)[-2:]).astype(int)
t = t.sort_values('version', ascending=False).groupby('bill').head(1)

dee = {k: v.cpu().numpy() for k, v in digest_embeddings.items() if k in t['DigestText'].values}
t['digest_embedding'] = t['DigestText'].map(dee)
t = t.loc[t['digest_embedding'].notna()]
t = t.merge(training_sample[['bill_id', 'cluster']], right_on='bill_id', left_on='bill', how='inner')

In [33]:
X = np.stack(t['digest_embedding'].values)
y = t['cluster'].values

In [34]:
from sklearn.model_selection import train_test_split, ParameterGrid
from sklearn.ensemble import RandomForestClassifier

param_grid = {
    'criterion': ['gini'],
    'max_depth': [10, 15, 20],
    'min_samples_split': [2],
    'min_samples_leaf': [2],
    'max_features': ['sqrt', 'log2']
}
rf = RandomForestClassifier(random_state=42, n_jobs=1)
grid = ParameterGrid(param_grid)
best_score, best_params = 0, None
for params in tqdm(grid):
    rf.set_params(**params)
    rf.fit(X, y)
    score = rf.score(X, y)
    if score > best_score:
        best_score = score
        best_params = params
print(f"Best Score: {best_score:.4f}")

100%|██████████| 6/6 [00:47<00:00,  7.88s/it]

Best Score: 0.9865





In [35]:
rf = rf.set_params(**best_params)
rf.fit(X, y)

In [36]:
repairs['label_pred'] = rf.predict(np.vstack(repairs['digest_embedding']))
reps = {k: v for k, v in repairs[['bill', 'label_pred']].values if v is not None}
bbb = bill_labels.copy()
for k, v in reps.items():
    bbb[k] = v

with open('bill_labels_updated.json', 'w') as f:
    json.dump(bbb, f)

In [6]:
with open('bill_labels_updated.json', 'r') as f:
    updated_labels = json.load(f)

In [25]:
labels = pd.read_csv('sampled_labels - sampled_labels.csv')

In [27]:
so = {row['cluster']: row['Label'] for _, row in labels.iterrows()}

In [38]:
corrections = {}

for l in labels.groupby('Label')['cluster'].count().sort_values(ascending=False).loc[lambda x: x > 1].index:
    correction = [k for k, v in so.items() if v == l]
    m = min(correction)
    for c in correction:
        if c != m:
            corrections[c] = m

In [41]:
updated_labels2 = {}

for k, v in updated_labels.items():
    if v in corrections:
        updated_labels2[k] = corrections[v]
    else:
        updated_labels2[k] = v

In [42]:
with open('bill_labels_updated.json', 'w') as f:
    json.dump(updated_labels2, f)

In [12]:
labels = pd.DataFrame.from_dict(updated_labels, orient='index', columns=['cluster']).reset_index(names='bill_id')

In [17]:
text = labels['bill_id'].map(bill_subjects)
labels['subject'] = text

Unnamed: 0,bill_id,cluster,subject
0,200320040SB73,82,national guard
1,200520060AB1484,5,sexually violent predators definition
2,200920100AB2531,32,redevelopment economic development
3,200520060SB1576,102,foster care transitional housing
4,200920100AB383,34,criminal procedure dna evidence
...,...,...,...
46095,201720180ACR25,115,
46096,200920100SJR35,115,
46097,201720180AB291,34,
46098,201720180ACR6,56,


In [21]:
sample = labels.groupby('cluster').sample(50).reset_index()
sample['count'] = sample.groupby('cluster').cumcount()

In [23]:
sample.pivot(index='cluster', columns='count', values='subject').to_csv('sampled_labels.csv', index=True)

In [20]:
from collections import defaultdict, Counter
from sklearn.feature_extraction.text import CountVectorizer

In [18]:
labels.groupby('cluster').sample(50, replace=True).sort_values('cluster').to_csv('sampled_labels.csv', index=False)

In [133]:
def _clean(text):
    if not isinstance(text, str):
        return ""
    text = text.lower().strip()
    text = re.sub(r"\b(state|bill|act|law|code|section|chapter|california|month)\b", " ", text, flags=re.I)
    text = re.sub(r"\s+", " ", text).lower()
    return text

def text_cluster(label, ngram_range=(2, 3), max_features=125):
    section = labels.loc[labels['cluster'] == label, 'subject'].values
    cleaned_texts = [_clean(text) for text in section if text and isinstance(text, str)]

    if not cleaned_texts:
        return []

    vectorizer = CountVectorizer(
        ngram_range=ngram_range,
        max_features=max_features,
        stop_words='english',
        min_df=5,
        lowercase=True
    )

    try:
        count_matrix = vectorizer.fit_transform(cleaned_texts)
        feature_names = vectorizer.get_feature_names_out()

        phrase_counts = count_matrix.sum(axis=0).A1

        phrase_count_pairs = list(zip(feature_names, phrase_counts))
        phrase_count_pairs.sort(key=lambda x: x[1], reverse=True)

        return phrase_count_pairs[:3]

    except ValueError:
        return []

In [134]:
cluster_phrases = {}
for label in labels['cluster'].unique():
    cluster_phrases[label] = text_cluster(label)

In [135]:
cp = pd.DataFrame.from_dict(cluster_phrases, orient='index', columns=['phrase1', 'phrase2', 'phrase3']).reset_index(names='cluster').sort_values('cluster')

In [None]:
sl = pd.read_csv('sampled_labels - sampled_labels.csv')
big_labels = {row['cluster']: row['Unnamed: 1'] for _, row in sl.iterrows()}

In [23]:
labels['label'] = labels['cluster'].map(big_labels)