In [None]:
import re, torch, json, pickle, itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
import networkx as nx
from community import community_louvain
from sklearn.mixture import GaussianMixture
from sentence_transformers import SentenceTransformer, util
from collections import Counter, defaultdict
import warnings

warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)

In [65]:
with open('bill_subjects.json', 'r') as f:
    bill_subjects = json.load(f)

In [66]:
subject_originals = pickle.load(open('subjects_original.pkl', 'rb'))

In [67]:
so = {k: subject_originals[v] for k, v in bill_subjects.items() if v in subject_originals}

In [68]:
from rapidfuzz import process, fuzz

def canonical(sub):
    txt = sub.split(':')[0].lower()
    txt = re.sub(r'[^a-z\s]', ' ', txt)
    txt = re.sub(r'(?:california|state|bill|law|act|amendment|proposition|measure|initiative|program|act|code|section|chapter|month|awareness|prevention)', '', txt)
    txt = re.sub(r'\s+', ' ', txt).strip()
    return txt

def fuzzy_bucket(subjects, thresh=87):
    canons = {s: canonical(s) for s in subjects}
    canon_list = list(set(canons.values()))
    canon_to_subjects = defaultdict(list)
    for s, c in canons.items():
        canon_to_subjects[c].append(s)
    buckets, assigned = [], set()
    with tqdm(total=len(canon_list), desc='Fuzzy Buckets') as pbar:
        for i, c in enumerate(canon_list):
            if c in assigned:
                pbar.update(1)
                continue
            group = []
            for j in range(i, len(canon_list)):
                d = canon_list[j]
                if d in assigned:
                    continue
                if fuzz.token_sort_ratio(c, d) >= thresh:
                    group.extend(canon_to_subjects[d])
                    assigned.add(d)
            if group:
                buckets.append(group)
            pbar.update(1)
    return buckets

def embed_subjects(subjs):
    model = SentenceTransformer('all-MiniLM-L6-v2')
    def sub_clean(s):
        s = s.lower().strip()
        s = re.sub(r'(?:california|state|bill|law|act|amendment|proposition|measure|initiative|program|act|code|section|chapter|month|awareness|prevention)', '', s)
        s = re.sub(r'[^a-z\s]', ' ', s)
        s = re.sub(r'\s+', ' ', s).strip()
        return s
    subjs = [sub_clean(s) for s in subjs]
    subjs = [s for s in subjs if s]
    embs = model.encode(subjs, normalize_embeddings=True, batch_size=128)
    return np.asarray(embs)

def gmm_soft_labels(embs, min_k=150, max_k=325):
    best_bic, best_gmm = np.inf, None
    for k in tqdm(np.linspace(min_k, max_k, 8, dtype=int)):
        gmm = GaussianMixture(k, covariance_type='diag', random_state=0).fit(embs)
        if gmm.bic(embs) < best_bic:
            best_bic, best_gmm = gmm.bic(embs), gmm
    probs = best_gmm.predict_proba(embs)
    return probs

def build_graph(subjects, fuzzy_buckets, soft_proba, alpha=0.35):
    idx = {s:i for i,s in enumerate(subjects)}
    N = len(subjects)
    fuzzy_adj = np.zeros((N,N), dtype=float)
    for bucket in fuzzy_buckets:
        for a,b in itertools.combinations(bucket, 2):
            fuzzy_adj[idx[a], idx[b]] = 1
            fuzzy_adj[idx[b], idx[a]] = 1
    embed_adj = soft_proba @ soft_proba.T
    W = alpha * fuzzy_adj + (1-alpha) * embed_adj
    G = nx.Graph()
    for i,s in enumerate(subjects):
        G.add_node(i, label=s)
    rows, cols = np.where(W > 0.01)
    for i,j in zip(rows, cols):
        if i < j:
            G.add_edge(i, j, weight=float(W[i,j]))
    return G

In [69]:
subs = list(set(list(so.values())))
buckets = fuzzy_bucket(subs, thresh=87)

Fuzzy Buckets: 100%|██████████| 13480/13480 [01:20<00:00, 167.54it/s] 


In [70]:
embs = embed_subjects(subs)

In [71]:
soft_proba = gmm_soft_labels(embs)

  0%|          | 0/8 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 8/8 [02:06<00:00, 15.83s/it]


In [72]:
G = build_graph(subs, buckets, soft_proba, alpha=0.35)

In [73]:
def consensus_clusters(G):
    part = community_louvain.best_partition(G, weight='weight', resolution=1.0)
    clusters = defaultdict(list)
    for node, cid in part.items():
        clusters[cid].append(G.nodes[node]['label'])
    return clusters

clusters = consensus_clusters(G)

In [74]:
def community_scale(G, r):
    part = community_louvain.best_partition(G, weight='weight', resolution=r)
    clusters = defaultdict(list)
    for node, cid in part.items():
        clusters[cid].append(G.nodes[node]['label'])
    return clusters

In [75]:
import igraph as ig, leidenalg as la
from sklearn.metrics.pairwise import cosine_similarity

cluster_ids = list(clusters.keys())
model = SentenceTransformer('all-MiniLM-L6-v2')

def cluster_embedding(texts, model):
    vecs = model.encode(texts, batch_size=128,
                        normalize_embeddings=True,
                        device='mps', show_progress_bar=True)
    return vecs.mean(0)

emb = {cid: cluster_embedding(clusters[cid], model)
       for cid in cluster_ids}
X = np.vstack([emb[cid] for cid in cluster_ids])

k = 15
sim = cosine_similarity(X)
knn_ix = np.argsort(-sim, axis=1)[:, 1:k+1]

edges, wts = [], []
for i, nbrs in enumerate(knn_ix):
    for j in nbrs:
        edges.append((i, j))
        wts.append(float(sim[i, j]))

g = ig.Graph(edges=edges, directed=False)
g.vs['name'] = cluster_ids
g.es['weight'] = wts

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/12 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/5 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/8 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/8 [00:00<?, ?it/s]

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/5 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/5 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [76]:
def leiden_cpm(gamma):
    part = la.find_partition(
        g,
        la.CPMVertexPartition,
        weights='weight',
        resolution_parameter=gamma,
        n_iterations=-1
    )
    return {g.vs[i]['name']: cid for i, cid in enumerate(part.membership)}

coarse05 = leiden_cpm(0.7)

In [77]:
labels = pd.DataFrame({
    'bill_id': list(so.keys()),
    'subject': list(so.values())
})

In [79]:
clust_reverse = {v: k for k, val in clusters.items() for v in val}

In [80]:
labels['cluster'] = labels['subject'].map(clust_reverse)
labels['coarse05'] = labels['cluster'].map(coarse05)

In [81]:
with open('bill_ids.txt', 'r') as f:
    bill_ids = f.read().splitlines()

In [82]:
with open('missed_bills.txt', 'r') as f:
    missed_bills = f.read().splitlines()

for b in labels.loc[~labels['bill_id'].isin(missed_bills) & (labels['cluster'].isna()), 'bill_id'].unique().tolist():
    missed_bills.append(b)

In [83]:
bill_id_mapping = pickle.load(open('bill_id_mapping.pkl', 'rb'))

In [84]:
missing_variations = [k for k, v in bill_id_mapping.items() if v in missed_bills]

In [85]:
bill_labels = {}
for _, row in labels.loc[labels['cluster'].notna(), ['bill_id', 'cluster']].drop_duplicates().iterrows():
    bill_labels[row['bill_id']] = row['cluster']

In [86]:
with open('bill_labels.json', 'w') as f:
    json.dump(bill_labels, f)

In [87]:
digests = pd.read_csv('ca_leg/legislation_data/digest.csv')

In [88]:
digest_embeddings = torch.load('digests.pt', weights_only=False)

In [89]:
repairs = digests.loc[digests['bill_id'].isin(missing_variations)]
repairs['bill'] = repairs['bill_id'].map(bill_id_mapping)
repairs['version'] = repairs['bill_id'].apply(lambda x: x[-5:-3] if 'VETO' not in x else re.search(r'\d{2}(?=VETO)', x).group()).astype(int)
repairs = repairs.sort_values('version', ascending=False).groupby('bill').head(1)

In [90]:
digest_embedding_lookup = {
    key: tensor.cpu().numpy()
    for key, tensor in digest_embeddings.items()
}

In [91]:
repairs['digest_embedding'] = repairs['DigestText'].map(digest_embedding_lookup)
repairs = repairs.loc[repairs['digest_embedding'].notna()].reset_index(drop=True)

In [None]:
import numpy as np, pandas as pd, re, json, ast
from collections import Counter

def extract_version(x):
    if not isinstance(x, str):
        return -1
    m = re.search(r'(\d{2})(?=VETO)', x)
    if m:
        return int(m.group(1))
    m = re.search(r'(\d{2})(?!.*\d)', x)
    return int(m.group(1)) if m else -1

def coerce_vec(v):
    if v is None or (isinstance(v, float) and np.isnan(v)):
        return None
    if isinstance(v, np.ndarray):
        arr = v
    elif isinstance(v, (list, tuple)):
        arr = np.asarray(v)
    elif isinstance(v, str):
        try:
            arr = np.asarray(json.loads(v))
        except Exception:
            try:
                arr = np.asarray(ast.literal_eval(v))
            except Exception:
                return None
    else:
        try:
            arr = np.asarray(v)
        except Exception:
            return None
    arr = np.asarray(arr, dtype=np.float32)
    if arr.size == 0:
        return None
    return arr.reshape(-1)

def normalize_embeddings(matrix):
    matrix = np.asarray(matrix, dtype=np.float32)
    if matrix.size == 0:
        return matrix
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return matrix / norms

# -- rebuild labeled_digests robustly
labeled = labels[['bill_id', 'cluster']].drop_duplicates().rename(columns={'bill_id': 'bill'})
labeled_variations = [k for k, v in bill_id_mapping.items() if v in labeled['bill'].values]

labeled_digests = digests.loc[digests['bill_id'].isin(labeled_variations)].copy()
labeled_digests['bill'] = labeled_digests['bill_id'].map(bill_id_mapping)
labeled_digests['version'] = labeled_digests['bill_id'].apply(extract_version).astype(int)

labeled_digests = (
    labeled_digests.sort_values('version', ascending=False)
    .groupby('bill', as_index=False)
    .head(1)
    .reset_index(drop=True)
)

labeled_digests['digest_embedding'] = labeled_digests['DigestText'].map(digest_embedding_lookup)
labeled_digests = labeled_digests.merge(labeled, on='bill', how='inner').reset_index(drop=True)

# -- clean embeddings and record dims
def attach_clean_and_dim(df, col='digest_embedding'):
    df = df.copy()
    df['__vec'] = df[col].apply(coerce_vec)
    df = df[df['__vec'].notna()].reset_index(drop=True)
    df['__dim'] = df['__vec'].apply(lambda x: int(x.shape[0]))
    return df

labeled_digests = attach_clean_and_dim(labeled_digests, 'digest_embedding')
repairs = attach_clean_and_dim(repairs, 'digest_embedding') if 'repairs' in globals() and not repairs.empty else pd.DataFrame(columns=['__vec','__dim'])

# -- choose a single target dimension (majority across both)
dim_counts = Counter()
dim_counts.update(labeled_digests['__dim'].tolist())
dim_counts.update(repairs['__dim'].tolist())
target_dim = dim_counts.most_common(1)[0][0] if len(dim_counts) else 0

# -- filter both sets to the chosen dimension
if target_dim > 0:
    labeled_use = labeled_digests[labeled_digests['__dim'] == target_dim].reset_index(drop=True)
    repairs_use = repairs[repairs['__dim'] == target_dim].reset_index(drop=True)
else:
    labeled_use = labeled_digests.iloc[0:0].copy()
    repairs_use = repairs.iloc[0:0].copy()

# -- build matrices
if labeled_use.empty:
    train_matrix = np.empty((0, 0), dtype=np.float32)
    train_labels = np.array([], dtype=int)
else:
    train_matrix = normalize_embeddings(np.vstack(labeled_use['__vec'].to_numpy()))
    train_labels = labeled_use['cluster'].to_numpy()

if repairs_use.empty:
    query_matrix = np.empty((0, 0), dtype=np.float32)
else:
    query_matrix = normalize_embeddings(np.vstack(repairs_use['__vec'].to_numpy()))


In [None]:
from sklearn.neighbors import NearestNeighbors

if len(train_matrix) and len(query_matrix):
    n_neighbors = min(15, len(train_matrix))
    nn = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine', algorithm='brute')
    nn.fit(train_matrix)
    distances, indices = nn.kneighbors(query_matrix, return_distance=True)

    def weighted_vote(neighbor_indices, neighbor_distances):
        sims = 1 - neighbor_distances
        scores = defaultdict(float)
        for idx, sim in zip(neighbor_indices, sims):
            scores[train_labels[idx]] += float(sim)
        return max(scores.items(), key=lambda item: item[1])[0]

    predicted_clusters = [
        weighted_vote(ind, dist) for ind, dist in zip(indices, distances)
    ]
    repairs['label_pred'] = predicted_clusters
else:
    repairs['label_pred'] = np.nan

In [None]:
reps = {k: v for k, v in repairs[['bill', 'label_pred']].values if pd.notna(v)}
bbb = bill_labels.copy()
bbb.update(reps)

with open('bill_labels_updated.json', 'w') as f:
    json.dump(bbb, f)