In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import LEConv, GATv2Conv, GCNConv, HGTConv, LEConv, GENConv, SAGPooling, HeteroConv

from sklearn.preprocessing import StandardScaler
import time, pickle
from tqdm import tqdm

In [2]:
data = torch.load('gnn_clean.pt', weights_only=False)
metadata = (data.node_types, data.edge_types)
device = torch.device('mps')

In [12]:
data

HeteroData(
  bill_version={
    x=[142952, 389],
    num_nodes=142952,
  },
  bill={
    x=[43937, 769],
    num_nodes=43937,
  },
  legislator={
    x=[508, 385],
    num_nodes=508,
  },
  legislator_term={
    x=[1448, 3],
    num_nodes=1448,
  },
  committee={
    x=[1707, 385],
    num_nodes=1707,
  },
  lobby_firm={
    x=[1206, 384],
    num_nodes=1206,
  },
  donor={
    x=[429, 384],
    num_nodes=429,
  },
  (bill_version, Version, bill)={
    edge_index=[2, 146100],
    edge_attr={ order=[146100, 1] },
  },
  (bill_version, nextVersion, bill_version)={
    edge_index=[2, 102093],
    edge_attr={},
  },
  (legislator_term, samePerson, legislator)={
    edge_index=[2, 1448],
    edge_attr={},
  },
  (committee, member, legislator_term)={
    edge_index=[2, 17633],
    edge_attr={ position=[17633, 1] },
  },
  (legislator_term, lobbying, lobby_firm)={
    edge_index=[2, 65280],
    edge_attr={
      amount=[65280, 1],
      date=[65280, 1],
      expn_dscr=[65127, 384],
    },


In [3]:
with open('input_dims.pkl', 'rb') as f:
    input_dims = pickle.load(f)

In [4]:
def preprocess(data):
    def date_to_float(t):
        if t.dtype == torch.float32:
            return t
        return (t.float() / 86_400)

    for rel in data.edge_types:
        attrs = data[rel].edge_attr
        if attrs is None:
            continue
        for k,v in attrs.items():
            if k == "date":
                attrs[k] = date_to_float(v)

    scalers = {}
    for rel in data.edge_types:
        attrs = data[rel].edge_attr
        if attrs is None:
            continue
        num_cols = {k:v for k,v in attrs.items() if v.dim()==2 and v.size(1)==1}
        if num_cols:
            M = torch.cat(list(num_cols.values()), dim=0).cpu().numpy()
            scaler = StandardScaler().fit(M)
            scalers[rel] = scaler
            for k,v in num_cols.items():
                attrs[k] = torch.as_tensor(
                    scaler.transform(v.cpu().numpy()), dtype=torch.float32)
    return scalers
scalers = preprocess(data)

In [5]:
data = data.to(device)

In [6]:
class EdgeAttrEncoder(nn.Module):
    def __init__(self, attr_dict_sample: dict[str, torch.Tensor],
                 bottleneck=128, out_dim=256):
        super().__init__()
        mods = []
        for k, v in attr_dict_sample.items():
            if v.dim() == 2 and v.size(1) == 1:
                mods.append((k, nn.Linear(1, bottleneck, bias=False)))
            elif v.size(1) == 384:
                mods.append((k, nn.Linear(384, bottleneck, bias=False)))
            else:
                continue
        self.feat_proj = nn.ModuleDict(mods)
        self.mlp       = nn.Sequential(
            nn.ReLU(),
            nn.Linear(len(mods)*bottleneck, out_dim)
        )

    def forward(self, attr_dict: dict[str, torch.Tensor]) -> torch.Tensor:
        if not attr_dict:
            raise ValueError("EdgeAttrEncoder got empty attr_dict")

        max_len = max(v.size(0) for v in attr_dict.values())
        device  = next(self.parameters()).device

        feats = []
        for k, proj in self.feat_proj.items():
            col = attr_dict[k].to(device)
            if col.size(0) < max_len:
                pad_rows = max_len - col.size(0)
                col = F.pad(col, (0, 0, 0, pad_rows))
            feats.append(proj(col))

        return self.mlp(torch.cat(feats, dim=1))

In [None]:
_conv_dict = {
    "LEConv": lambda h, md=None: LEConv(h, h),
    "GATv2Conv": lambda h, md=None: GATv2Conv(
        (-1, -1), h // 4, heads=4, add_self_loops=False),
    "GCNConv": lambda h, md=None: GCNConv((-1, -1), h, add_self_loops=False),
    "HGTConv": lambda h, md: HGTConv(
        in_channels=h, out_channels=h, metadata=md),
    "GENConv": lambda h, md=None: GENConv(h, h, aggr='mean',
                                              t=1.0, learn_t=True, num_layers=2),
}

class StackedEncoder(nn.Module):
    def __init__(self, metadata, in_dims,
                hidden=256,
                conv_names=None,
                pool_ratio=0.3,
                add_pool=False):
        super().__init__()
        self.lin_in = nn.ModuleDict({
            n: nn.Sequential(
                nn.Linear(in_dims[n], hidden),
                nn.ReLU(),
                nn.HeteroLayerNorm(hidden)
            ) for n in metadata[0]
        })

        conv_names = conv_names or ["GATv2Conv", "HGTConv", "GENConv"]
        self.convs = nn.ModuleList()

        for name in conv_names:
            ctor = _conv_dict[name]
            self.convs.append(
                HeteroConv({rel: ctor(hidden, metadata) if name == "HGTConv"
                          else ctor(hidden)
                          for rel in metadata[1]},
                         aggr="mean")
            )

        self.use_pool = add_pool
        if add_pool:
            self.pool = SAGPooling(hidden, ratio=pool_ratio)

        self.bn_pool = nn.BatchNorm1d(hidden)
        self.hidden = hidden

        self.layer_norms = nn.ModuleDict({
            nt: nn.HeteroLayerNorm(hidden) for nt in metadata[0]
        })

    def _merge(self, prev, new):
        out = {}
        for nt in prev.keys() | new.keys():
            if nt in prev and nt in new:
                out[nt] = F.relu(prev[nt] + new[nt])
                out[nt] = self.layer_norms[nt](out[nt])
            else:
                out[nt] = F.relu(new.get(nt, prev[nt]))
        return out

    def forward(self, x_dict, edge_index_dict):
        h = {nt: self.lin_in[nt](x) for nt, x in x_dict.items()}
        for conv in self.convs:
            h = self._merge(h, conv(h, edge_index_dict))

        if self.use_pool:
            bv = h['bill_version']
            bill_edges = None
            edge_key = ('bill_version', 'nextVersion', 'bill_version')
            if edge_key in edge_index_dict:
                bill_edges = edge_index_dict[edge_key]

            if bill_edges is None:
                idx = torch.arange(bv.size(0), device=bv.device)
                bill_edges = torch.stack([idx, idx])

            bv, *rest = self.pool(bv, bill_edges)
            h['bill_version'] = self.bn_pool(bv)

        for nt in h:
            h[nt] = F.normalize(h[nt], p=2, dim=1)

        return h


In [None]:
class LinkRecon(nn.Module):
    def __init__(self, edge_attr_encoders: nn.ModuleDict, weight=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.mse = nn.MSELoss()
        self.encoders = edge_attr_encoders
        self.weight = weight

    def forward(self, z, data, num_neg=1):
        loss = torch.tensor(0., device=z[next(iter(z))].device)
        for rel, ei in data.edge_index_dict.items():
            s_t, _, d_t = rel
            src, dst = ei
            pos_log = (z[s_t][src] * z[d_t][dst]).sum(-1)
            pos_loss = self.bce(pos_log, torch.ones_like(pos_log))
            attr_dict = data[rel].edge_attr
            if str(rel) in self.encoders and attr_dict and len(attr_dict) > 0:
                enc = self.encoders[str(rel)]
                tgt = enc(attr_dict)
                proj = nn.Linear(z[s_t].size(1), tgt.size(1), bias=False, device=z[s_t].device)
                pred = proj(z[s_t][src] * z[d_t][dst])
                pos_loss = pos_loss + self.mse(pred, tgt)

            neg_dst = torch.randint(0, data[d_t].num_nodes, (src.size(0)*num_neg,), device=src.device)
            neg_src = src.repeat(num_neg)
            neg_log = (z[s_t][neg_src] * z[d_t][neg_dst]).sum(-1)
            neg_loss = self.bce(neg_log, torch.zeros_like(neg_log))
            loss = loss + pos_loss + neg_loss
        return loss * self.weight

In [None]:
class NextVersionContrastive(nn.Module):
    def __init__(self, margin=0.25, weight=0.25):
        super().__init__()
        self.margin = margin
        self.weight = weight

    def forward(self, z, data):
        rel = ('bill_version', 'nextVersion', 'bill_version')
        if rel not in data.edge_index_dict:
            return torch.tensor(0., device=z[next(iter(z))].device)
        src, dst = data.edge_index_dict[rel]
        pos_dist = F.pairwise_distance(z['bill_version'][src], z['bill_version'][dst])

        neg_dst  = dst[torch.randperm(dst.size(0))]
        neg_dist = F.pairwise_distance(z['bill_version'][src], z['bill_version'][neg_dst])

        zeros = torch.zeros_like(pos_dist)
        loss  = torch.mean(torch.maximum(
            zeros, self.margin + pos_dist - neg_dist))
        return loss * self.weight

In [None]:
class StakeholderContrastive(nn.Module):
    def __init__(self, temperature=0.2, weight=0.5):
        super().__init__()
        self.temp = temperature
        self.weight = weight
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, z, data, sample=1000):
        rel = ('lobby_firm', 'Lobbying', 'legislator_term')
        if rel not in data.edge_index_dict:
            return torch.tensor(0., device=z[next(iter(z))].device)
        src, dst = data.edge_index_dict[rel]
        if src.size(0) > sample:
            idx = torch.randperm(src.size(0))[:sample]
            src, dst = src[idx], dst[idx]

        h_dst = z['legislator_term'][dst]
        h_src = z['lobby_firm'][src]
        logits = (h_src @ h_dst.T) / self.temp
        labels = torch.arange(src.size(0), device=device)
        loss = F.cross_entropy(logits, labels)
        return loss * self.weight

In [None]:
class GraphStructureLearning(nn.Module):
    def __init__(self, weight=0.4):
        super().__init__()
        self.weight = weight
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, z, data):
        loss = torch.tensor(0., device=z[next(iter(z))].device)

        key_relations = [
            ('bill_version', 'nextVersion', 'bill_version'),
            ('bill', 'Version',  'bill_version'),
            ('lobby_firm', 'Lobbying', 'legislator_term'),
            ('donor', 'CampaignContribution', 'legislator_term'),
            ('legislator_term', 'Member', 'committee'),
            ('legislator_term', 'Author', 'bill_version')
        ]

        for rel in key_relations:
            if rel not in data.edge_index_dict:
                continue

            src_type, edge_type, dst_type = rel
            src, dst = data.edge_index_dict[rel]
            pos_score = (z[src_type][src] * z[dst_type][dst]).sum(dim=1)
            neg_dst = dst[torch.randperm(dst.size(0))]
            neg_score = (z[src_type][src] * z[dst_type][neg_dst]).sum(dim=1)

            loss += self.criterion(pos_score, torch.ones_like(pos_score))
            loss += self.criterion(neg_score, torch.zeros_like(neg_score))

        return loss * self.weight

In [None]:
def run_epoch(encoder, tasks, data, opt, device):
    encoder.train()
    z = encoder(data.x_dict, data.edge_index_dict)

    loss_total = torch.tensor(0., device=device)

    for task in tasks:
        loss = task(z, data)
        loss_total += loss

    opt.zero_grad(set_to_none=True)
    loss_total.backward()
    opt.step()
    return float(loss_total)

encoder = StackedEncoder(
    data.metadata(), input_dims,
    hidden=256,
    conv_names=["GATv2Conv", "HGTConv", "GENConv"],
    pool_ratio=0.3,
    add_pool=True
).to(device)

edge_attr_enc = nn.ModuleDict()
for rel in data.edge_types:
    attrs = data[rel].edge_attr
    if attrs is None:
        continue
    usable = {k: v for k, v in attrs.items()
              if (v.dim() == 2 and (v.size(1) == 1 or v.size(1) == 384)) or (v.dim() == 1 and v.size(0) != 0)}
    if usable:
        edge_attr_enc[str(rel)] = EdgeAttrEncoder(usable)
edge_attr_enc = edge_attr_enc.to(device)


tasks = [
    LinkRecon(edge_attr_enc, weight=1.0),
    StakeholderContrastive(weight=0.25),
    NextVersionContrastive(weight=0.5),
    GraphStructureLearning(weight=0.4)
]



In [None]:
optimizer = torch.optim.AdamW(
    encoder.parameters(),
    lr=1e-5
)

In [14]:
import numpy as np

EPOCHS = 100
for epoch in tqdm(range(1, EPOCHS+1), desc="Epochs"):
    t0   = time.time()
    loss = run_epoch(encoder, tasks, data, optimizer, device)
    dt   = time.time() - t0
    print(f"[{epoch:02d}/{EPOCHS}] "
            f"loss={loss:.4f}  time={dt:.1f}s")
    if loss == np.nan or loss == None:
        break

    encoder.eval()
    with torch.inference_mode():
        z = encoder(data.x_dict, data.edge_index_dict)
    torch.save(z, "node_embeddings_2.pt")

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

[01/100] loss=10009.3574  time=16.4s


Epochs:   1%|          | 1/100 [00:28<46:41, 28.30s/it]

[02/100] loss=nan  time=22.6s


Epochs:   2%|▏         | 2/100 [01:05<54:17, 33.24s/it]

[03/100] loss=nan  time=19.1s


Epochs:   2%|▏         | 2/100 [01:27<1:11:32, 43.80s/it]


KeyboardInterrupt: 

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def visualize_embeddings(z, labels=None, title="Embedding Visualization"):
    tsne = TSNE(n_components=2, perplexity=50, random_state=42)
    z_2d = tsne.fit_transform(z.cpu().detach().numpy())

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(z_2d[:, 0], z_2d[:, 1], c=labels, cmap='tab10', alpha=0.7)
    if labels is not None:
        plt.legend(*scatter.legend_elements(), title="Labels")
    plt.title(title)
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.grid(True)
    plt.show()

In [None]:
d = torch.load('gnn_data.pt', weights_only=False)
visualize_embeddings(z['legislator'], labels=d['legislator'].x['party'].tolist(), title="Legislators by Party")

In [None]:
california_topics = [
    "Public health",
    "Mental health services",
    "Medi-Cal and health insurance",
    "Substance abuse and harm reduction",
    "Child welfare and foster care",
    "Developmental and disability services",
    "Elder care and long-term support",

    "K-12 education funding",
    "Curriculum and instruction",
    "Higher education and UC/CSU systems",
    "Community colleges",
    "School construction and facilities",
    "Special education",
    "Charter schools and school choice",

    "Climate change and carbon reduction",
    "Air quality and pollution control",
    "Water supply and drought",
    "Coastal protection",
    "Wildfire prevention and forestry",
    "Environmental justice",
    "Recycling and waste management",

    "Criminal justice reform",
    "Police oversight and accountability",
    "Firearms and gun control",
    "Corrections and parole",
    "Emergency services and disaster response",
    "Human trafficking prevention",

    "State budget and fiscal policy",
    "Personal and corporate income taxes",
    "Sales and use taxes",
    "Proposition 13 and property tax",
    "State bonds and financing",
    "Local government finance",

    "Workplace safety and Cal/OSHA",
    "Paid family leave",
    "Minimum wage and wage theft",
    "Public employee unions",
    "Employment discrimination and DEI",
    "Workforce development",

    "Redistricting and electoral reform",
    "Voter access and registration",
    "Campaign finance and lobbying",
    "Open meetings and transparency (Brown Act)",
    "Public records and data access",
    "Government agency operations",

    "Roads and highways (Caltrans)",
    "Public transit and rail",
    "High-speed rail",
    "Ports and logistics",
    "Vehicle emissions and EV policy",
    "Infrastructure resilience",

    "Affordable housing development",
    "Zoning and local control",
    "Tenant protections and rent control",
    "Homelessness and supportive housing",
    "CEQA and environmental permitting",
    "Redevelopment and gentrification",

    "Electric grid and reliability",
    "Renewable energy incentives",
    "Utility regulation (CPUC)",
    "Natural gas and oil regulation",
    "Wildfire liability (PG&E)",
    "Broadband and digital equity",

    "Immigration and sanctuary laws",
    "LGBTQ+ rights",
    "Gender equity and reproductive health",
    "Racial equity and anti-discrimination",
    "Food insecurity and public benefits",
    "Language access and cultural inclusion",

    "Water rights and agriculture",
    "Pesticide regulation",
    "Farmworker labor conditions",
    "Fisheries and marine policy",
    "Wildlife conservation",
    "Timber and land management",

    "Small business support",
    "Technology and innovation",
    "Cannabis regulation",
    "Insurance industry oversight",
    "Economic stimulus and recovery",
    "Licensing and regulation (e.g., BAR, ABC)",

    "Financial services and predatory lending",
    "Data privacy and cybersecurity",
    "Product safety and recalls",
    "Housing scams and fraud",
    "Telemarketing and spam regulation",

    "Freedom of speech and assembly",
    "Facial recognition and surveillance",
    "Disability rights",
    "Due process protections",
    "First Amendment in schools/public spaces",

    "State-local relations",
    "Tribal affairs",
    "Military and veterans issues",
    "COVID-19 response and recovery",
    "Technology in government"
]

In [None]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2', truncate_dim=256)
topic_embs = model.encode(california_topics, normalize_embeddings=True, convert_to_tensor=True, show_progress_bar=True)

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm

z_text_all = torch.cat([
    z[0]['bill'],
    z[0]['bill_version'],
    z[0]['committee'],
    z[0]['legislator'],
    z[0]['donor'],
    z[0]['lobby_firm'],
], dim=0)

k = 25
kmeans = KMeans(n_clusters=k, random_state=40)
cluster_ids = kmeans.fit_predict(z_text_all.cpu())
score = silhouette_score(z_text_all.cpu(), kmeans.labels_)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

centroids = kmeans.cluster_centers_
sims = cosine_similarity(centroids, topic_embs.cpu().numpy())
top_topic_ids = sims.argmax(axis=1)
cluster_topic_labels = [california_topics[i] for i in top_topic_ids]

In [None]:
import pandas as pd

df = pd.DataFrame({
    'cluster': list(range(k)),
    'predicted_topic': cluster_topic_labels,
    'most_similar_score': sims.max(axis=1)
})