In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import GATConv, HeteroConv, GATv2Conv, SAGEConv
import time, pickle, inspect
from tqdm import tqdm
import matplotlib.pyplot as plt

In [4]:
data = torch.load('data.pt', weights_only=False)
metadata = (data.node_types, data.edge_types)
device = torch.device('mps')

In [5]:
data

HeteroData(
  bill_version={
    x=[1, 390],
    num_nodes=193262,
  },
  bill={
    x=[1, 770],
    y=[1],
    num_nodes=45357,
  },
  legislator={
    x=[1, 385],
    num_nodes=508,
  },
  legislator_term={
    x=[1, 3],
    num_nodes=1448,
  },
  committee={
    x=[1, 385],
    num_nodes=1707,
  },
  lobby_firm={
    x=[1, 384],
    num_nodes=1325,
  },
  donor={
    x=[1, 384],
    num_nodes=1136,
  },
  (bill_version, is_version, bill)={
    edge_index=[2, 142952],
    edge_attr=[142952, 1],
    num_edges=142952,
  },
  (bill_version, priorVersion, bill_version)={
    edge_index=[2, 100658],
    num_edges=100658,
  },
  (legislator, samePerson, legislator_term)={
    edge_index=[2, 1448],
    num_edges=1448,
  },
  (legislator_term, member_of, committee)={
    edge_index=[2, 17633],
    edge_attr=[17633, 1],
    num_edges=17633,
  },
  (lobby_firm, lobbied, legislator_term)={
    edge_index=[2, 65280],
    edge_attr=[65280, 2],
    num_edges=65280,
  },
  (lobby_firm, lobbied, com

In [3]:
with open('input_dims.pkl', 'rb') as f:
    input_dims = pickle.load(f)

In [4]:
from sklearn.preprocessing import StandardScaler

def preprocess(data):
    def date_to_float(t):
        if t.dtype == torch.float32:
            return t
        return (t.float() / 86_400)

    for rel in data.edge_types:
        attrs = data[rel].edge_attr
        if attrs is None:
            continue
        for k,v in attrs.items():
            if k == "date":
                attrs[k] = date_to_float(v)

    scalers = {}
    for rel in data.edge_types:
        attrs = data[rel].edge_attr
        if attrs is None:
            continue
        num_cols = {k:v for k,v in attrs.items() if v.dim()==2 and v.size(1)==1}
        if num_cols:
            M = torch.cat(list(num_cols.values()), dim=0).cpu().numpy()
            scaler = StandardScaler().fit(M)
            scalers[rel] = scaler
            for k,v in num_cols.items():
                attrs[k] = torch.as_tensor(
                    scaler.transform(v.cpu().numpy()), dtype=torch.float32)
    return scalers
scalers = preprocess(data)
data = data.to(device)

In [5]:
class EdgeAttrEncoder(nn.Module):
    def __init__(self, attr_dict_sample: dict[str, torch.Tensor],
                 bottleneck=128, out_dim=256):
        super().__init__()
        mods = []
        for k, v in attr_dict_sample.items():
            if v.dim() == 2 and v.size(1) == 1:
                mods.append((k, nn.Linear(1, bottleneck, bias=False)))
            elif v.size(1) == 384:
                mods.append((k, nn.Linear(384, bottleneck, bias=False)))
            else:
                continue
        self.feat_proj = nn.ModuleDict(mods)
        self.mlp       = nn.Sequential(
            nn.ReLU(),
            nn.Linear(len(mods)*bottleneck, out_dim)
        )

    def forward(self, attr_dict: dict[str, torch.Tensor]) -> torch.Tensor:
        if not attr_dict:
            raise ValueError("EdgeAttrEncoder got empty attr_dict")

        max_len = max(v.size(0) for v in attr_dict.values())
        device  = next(self.parameters()).device

        feats = []
        for k, proj in self.feat_proj.items():
            col = attr_dict[k].to(device)
            if col.size(0) < max_len:
                pad_rows = max_len - col.size(0)
                col = F.pad(col, (0, 0, 0, pad_rows))
            feats.append(proj(col))

        return self.mlp(torch.cat(feats, dim=1))

class BigEncoder(nn.Module):
    def __init__(self, metadata, in_dims: dict,
                 hidden=256, layers=3):
        super().__init__()
        self.lin = nn.ModuleDict({
            ntype: nn.Linear(in_dim, hidden)
            for ntype, in_dim in in_dims.items()
        })
        self.convs = nn.ModuleList([
            HeteroConv({rel: SAGEConv((-1, -1), hidden)
                        for rel in metadata[1]}, aggr="mean")
            for _ in range(layers)
        ])

    def forward(self, x_dict, edge_index_dict):
        h_dict = {nt: F.relu(self.lin[nt](x)) for nt, x in x_dict.items()}

        for conv in self.convs:
            out_dict = conv(h_dict, edge_index_dict)

            for nt in h_dict:
                if nt not in out_dict:
                    out_dict[nt] = h_dict[nt]

            h_dict = {nt: F.relu(h) for nt, h in out_dict.items()}

        return h_dict


In [None]:
class LinkRecon(nn.Module):
    def __init__(self, edge_attr_encoders: nn.ModuleDict, weight=1.0):
        super().__init__()
        self.bce       = nn.BCEWithLogitsLoss()
        self.mse       = nn.MSELoss()
        self.encoders  = edge_attr_encoders
        self.weight    = weight

    def forward(self, z, data, num_neg=1):
        loss = torch.tensor(0., device=z[next(iter(z))].device)
        for rel, ei in data.edge_index_dict.items():
            s_t, _, d_t = rel
            src, dst = ei
            pos_log  = (z[s_t][src] * z[d_t][dst]).sum(-1)
            pos_loss = self.bce(pos_log, torch.ones_like(pos_log))
            attr_dict = data[rel].edge_attr
            if str(rel) in self.encoders and attr_dict and len(attr_dict) > 0:
                enc  = self.encoders[str(rel)]
                tgt  = enc(attr_dict)
                proj = nn.Linear(z[s_t].size(1), tgt.size(1),
                                bias=False, device=z[s_t].device)
                pred = proj(z[s_t][src] * z[d_t][dst])
                pos_loss = pos_loss + self.mse(pred, tgt)

            neg_dst = torch.randint(0, data[d_t].num_nodes, (src.size(0)*num_neg,), device=src.device)
            neg_src = src.repeat(num_neg)
            neg_log = (z[s_t][neg_src] * z[d_t][neg_dst]).sum(-1)
            neg_loss = self.bce(neg_log,
                                torch.zeros_like(neg_log))
            loss = loss + pos_loss + neg_loss
        return loss * self.weight

class MaskedFeatRecon(nn.Module):
    def __init__(self, mask_prob=0.05, weight=0.3):
        super().__init__()
        self.mask_prob = mask_prob
        self.weight    = weight
        self.mse       = nn.MSELoss()

    def forward(self, encoder, data, z_full):
        device = z_full[next(iter(z_full))].device
        loss   = torch.tensor(0., device=device)

        for nt in data.node_types:
            x = data[nt].x
            node_mask = torch.rand(x.size(0), device=device) < self.mask_prob
            if not node_mask.any():
                continue

            x_dict_masked = {k: v for k, v in data.x_dict.items()}
            x_masked = x.clone()
            x_masked[node_mask] = 0.
            x_dict_masked[nt] = x_masked

            z_masked = encoder(x_dict_masked, data.edge_index_dict)

            loss = loss + self.mse(z_masked[nt][node_mask],
                                   z_full[nt][node_mask])

        return loss * self.weight



In [7]:
def run_epoch(encoder, tasks, data, opt, device):
    encoder.train()
    z_full = encoder(data.x_dict, data.edge_index_dict)

    loss_total = torch.tensor(0., device=device)
    for task in tasks:
        if isinstance(task, MaskedFeatRecon):
            loss = task(encoder, data, z_full)   # pass z_full
        else:
            loss = task(z_full, data)
        loss_total = loss_total + loss

    opt.zero_grad(set_to_none=True)
    loss_total.backward()
    opt.step()
    return float(loss_total)


encoder  = BigEncoder(data.metadata(), input_dims).to(device)
edge_attr_enc = nn.ModuleDict()
for rel in data.edge_types:
    attrs = data[rel].edge_attr
    if attrs is None:
        continue
    usable = {k: v for k, v in attrs.items()
              if (v.dim() == 2 and (v.size(1) == 1 or v.size(1) == 384)) or (v.dim() == 1 and v.size(0) != 0)}
    if usable:
        edge_attr_enc[str(rel)] = EdgeAttrEncoder(usable)
edge_attr_enc = edge_attr_enc.to(device)


tasks = [
    LinkRecon(edge_attr_enc, weight=1.0),
    MaskedFeatRecon(mask_prob=0.05, weight=0.3),
]
opt = torch.optim.AdamW(encoder.parameters(), lr=1e-3, weight_decay=1e-4, amsgrad=True)
EPOCHS = 250
for epoch in range(1, EPOCHS+1):
    t0   = time.time()
    loss = run_epoch(encoder, tasks, data, opt, device)
    dt   = time.time() - t0
    print(f"[{epoch:02d}/{EPOCHS}] "
            f"loss={loss:.4f}  time={dt:.1f}s")

    encoder.eval()
    with torch.inference_mode():
        z = encoder(data.x_dict, data.edge_index_dict)
    torch.save(z, "node_embeddings_3.pt")



[01/250] loss=192.3003  time=65.6s
[02/250] loss=9.8913  time=101.7s
[03/250] loss=9.8773  time=76.1s
[04/250] loss=9.8579  time=81.0s
[05/250] loss=9.8555  time=75.5s
[06/250] loss=9.8325  time=79.7s
[07/250] loss=9.8259  time=54.1s
[08/250] loss=9.8148  time=84.7s
[09/250] loss=9.8001  time=83.7s
[10/250] loss=9.7933  time=88.7s
[11/250] loss=9.7775  time=69.8s
[12/250] loss=9.7711  time=86.0s
[13/250] loss=9.7524  time=64.0s
[14/250] loss=9.7403  time=69.8s
[15/250] loss=9.7153  time=50.0s
[16/250] loss=9.6961  time=67.9s
[17/250] loss=9.6614  time=60.1s
[18/250] loss=9.6251  time=56.2s
[19/250] loss=9.5705  time=78.7s
[20/250] loss=9.5147  time=63.7s
[21/250] loss=9.4460  time=70.5s
[22/250] loss=9.3892  time=90.6s
[23/250] loss=9.3593  time=88.9s
[24/250] loss=9.3467  time=50.3s
[25/250] loss=9.2982  time=67.3s
[26/250] loss=9.2427  time=70.4s
[27/250] loss=9.1961  time=67.7s
[28/250] loss=9.1283  time=69.9s
[29/250] loss=9.0844  time=72.0s
[30/250] loss=9.0211  time=62.7s
[31/250

KeyboardInterrupt: 

In [8]:
california_topics = [
    "Public health",
    "Mental health services",
    "Medi-Cal and health insurance",
    "Substance abuse and harm reduction",
    "Child welfare and foster care",
    "Developmental and disability services",
    "Elder care and long-term support",

    "K-12 education funding",
    "Curriculum and instruction",
    "Higher education and UC/CSU systems",
    "Community colleges",
    "School construction and facilities",
    "Special education",
    "Charter schools and school choice",

    "Climate change and carbon reduction",
    "Air quality and pollution control",
    "Water supply and drought",
    "Coastal protection",
    "Wildfire prevention and forestry",
    "Environmental justice",
    "Recycling and waste management",

    "Criminal justice reform",
    "Police oversight and accountability",
    "Firearms and gun control",
    "Corrections and parole",
    "Emergency services and disaster response",
    "Human trafficking prevention",

    "State budget and fiscal policy",
    "Personal and corporate income taxes",
    "Sales and use taxes",
    "Proposition 13 and property tax",
    "State bonds and financing",
    "Local government finance",

    "Workplace safety and Cal/OSHA",
    "Paid family leave",
    "Minimum wage and wage theft",
    "Public employee unions",
    "Employment discrimination and DEI",
    "Workforce development",

    "Redistricting and electoral reform",
    "Voter access and registration",
    "Campaign finance and lobbying",
    "Open meetings and transparency (Brown Act)",
    "Public records and data access",
    "Government agency operations",

    "Roads and highways (Caltrans)",
    "Public transit and rail",
    "High-speed rail",
    "Ports and logistics",
    "Vehicle emissions and EV policy",
    "Infrastructure resilience",

    "Affordable housing development",
    "Zoning and local control",
    "Tenant protections and rent control",
    "Homelessness and supportive housing",
    "CEQA and environmental permitting",
    "Redevelopment and gentrification",

    "Electric grid and reliability",
    "Renewable energy incentives",
    "Utility regulation (CPUC)",
    "Natural gas and oil regulation",
    "Wildfire liability (PG&E)",
    "Broadband and digital equity",

    "Immigration and sanctuary laws",
    "LGBTQ+ rights",
    "Gender equity and reproductive health",
    "Racial equity and anti-discrimination",
    "Food insecurity and public benefits",
    "Language access and cultural inclusion",

    "Water rights and agriculture",
    "Pesticide regulation",
    "Farmworker labor conditions",
    "Fisheries and marine policy",
    "Wildlife conservation",
    "Timber and land management",

    "Small business support",
    "Technology and innovation",
    "Cannabis regulation",
    "Insurance industry oversight",
    "Economic stimulus and recovery",
    "Licensing and regulation (e.g., BAR, ABC)",

    "Financial services and predatory lending",
    "Data privacy and cybersecurity",
    "Product safety and recalls",
    "Housing scams and fraud",
    "Telemarketing and spam regulation",

    "Freedom of speech and assembly",
    "Facial recognition and surveillance",
    "Disability rights",
    "Due process protections",
    "First Amendment in schools/public spaces",

    "State-local relations",
    "Tribal affairs",
    "Military and veterans issues",
    "COVID-19 response and recovery",
    "Technology in government"
]

In [2]:
d = torch.load('gnn_data.pt', weights_only=False)

FileNotFoundError: [Errno 2] No such file or directory: 'gnn_data.pt'

In [16]:
d

HeteroData(
  bill_version={
    x={
      digest=[142952, 384],
      VoteRequired=[142952, 1],
      LocalProgram=[142952, 1],
      FiscalCommittee=[142952, 1],
      TaxLevy=[142952, 1],
      Urgency=[142952, 1],
    },
  },
  bill={
    x={
      title=[43937, 384],
      subject=[43937, 384],
      measure_type=[43937, 1],
    },
  },
  legislator={
    x={
      party=[508, 1],
      occupation=[508, 384],
    },
  },
  legislator_term={
    x={
      chamber=[1448, 1],
      district=[1448, 1],
      term=[1448, 1],
    },
  },
  committee={
    x={
      name=[1707, 384],
      chamber=[1707, 1],
    },
  },
  lobby_firm={
    x={ name=[1206, 384] },
  },
  donor={
    x={ name=[429, 384] },
  },
  (bill_version, Version, bill)={
    edge_index=[2, 146100],
    edge_attr={ order=[146100, 1] },
  },
  (bill_version, nextVersion, bill_version)={
    edge_index=[2, 102093],
    edge_attr={},
  },
  (legislator_term, samePerson, legislator)={
    edge_index=[2, 1448],
    edge_at

In [9]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2', truncate_dim=256)
topic_embs = model.encode(california_topics, normalize_embeddings=True, convert_to_tensor=True, show_progress_bar=True)

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm

z_text_all = torch.cat([
    z['bill'],
    z['bill_version'],
    z['committee'],
    z['legislator'],
    z['donor'],
    z['lobby_firm'],
], dim=0)

k = 25
kmeans = KMeans(n_clusters=k, random_state=40)
cluster_ids = kmeans.fit_predict(z_text_all.cpu())
score = silhouette_score(z_text_all.cpu(), kmeans.labels_)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


KeyboardInterrupt: 

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

centroids = kmeans.cluster_centers_
sims = cosine_similarity(centroids, topic_embs.cpu().numpy())
top_topic_ids = sims.argmax(axis=1)
cluster_topic_labels = [california_topics[i] for i in top_topic_ids]

In [None]:
import pandas as pd

df = pd.DataFrame({
    'cluster': list(range(k)),
    'predicted_topic': cluster_topic_labels,
    'most_similar_score': sims.max(axis=1)
})