In [1]:
import math, gc, warnings, json, datetime, torch, pickle
from collections import defaultdict
from pathlib import Path
import numpy as np
import pandas as pd
import polars as pl
from torch.nn import functional as F
from torch_geometric.transforms import ToUndirected, RemoveIsolatedNodes
from torch_geometric.utils import to_networkx
import networkx as nx

In [2]:
def herfindahl(shares):
    """Herfindahl–Hirschman index: the sum of squared shares, as a Python float.

    ``shares`` may be any array-like accepted by ``np.square``. For a pandas
    Series the final ``.sum()`` skips NaN.
    """
    squared = np.square(shares)
    return float(squared.sum())

def entropy(shares, eps=1e-9):
    """Shannon entropy (natural log) of a share vector, returned as a float.

    Shares are floored at ``eps`` (via the object's ``.clip(min=...)``) so that
    zero shares do not produce ``log(0)``.
    """
    floored = shares.clip(min=eps)
    log_terms = floored * np.log(floored)
    return float(-log_terms.sum())

def jsd(p, q, eps=1e-9):
    """Jensen–Shannon divergence between distributions along the last dim.

    Parameters
    ----------
    p, q : torch.Tensor
        Probability tensors of matching shape. The divergence is reduced over
        the last dimension.
    eps : float
        Floor applied to the mixture ``m`` before taking its log.

    Returns
    -------
    torch.Tensor
        ``0.5 * (KL(p || m) + KL(q || m))`` with the last dim summed out.
    """
    # Flooring m matters where p and q are both 0: m.log() would be -inf, and
    # kl_div's ``target * input`` term would then be 0 * -inf = NaN.
    log_m = (0.5 * (p + q)).clamp(min=eps).log()
    kl_pm = F.kl_div(log_m, p, reduction='none').sum(-1)
    kl_qm = F.kl_div(log_m, q, reduction='none').sum(-1)
    return 0.5 * (kl_pm + kl_qm)

def slope(y):
    """Least-squares linear trend of ``y`` against its position index.

    Accepts any 1-D array-like (list, ndarray, Series). Non-finite entries
    (NaN/inf) are ignored. The x-axis keeps the original positions, so gaps
    do not compress time. Returns 0.0 when fewer than two finite points remain.
    """
    values = np.asarray(y, dtype=np.float64).ravel()
    positions = np.arange(values.size, dtype=np.float64)
    finite = np.isfinite(values)
    if finite.sum() < 2:
        return 0.0
    return float(np.polyfit(positions[finite], values[finite], 1)[0])

def safe_ratio(n, d):
    """Return ``n / d``, or 0.0 when the denominator equals zero."""
    if d == 0:
        return 0.0
    return n / d

In [3]:
def union_parquet(glob_pat, base_dir='../data'):
    """Concatenate every parquet file under ``base_dir`` matching ``glob_pat``.

    Parameters
    ----------
    glob_pat : str
        Glob pattern relative to ``base_dir`` (e.g. ``'bills_kpis_*.parquet'``).
    base_dir : str or Path
        Directory to search. Defaults to the notebook's ``../data``.

    Returns
    -------
    pd.DataFrame
        All matching files stacked with a fresh RangeIndex, or an empty
        DataFrame when nothing matches.
    """
    # Path.glob returns a generator, which is always truthy. Materialise it so
    # the empty check works, and sort so row order is reproducible across
    # filesystems.
    files = sorted(Path(base_dir).glob(glob_pat))
    if not files:
        return pd.DataFrame()
    return pd.concat([pd.read_parquet(f) for f in files], ignore_index=True)

# One frame per KPI family, each stacked from its sharded parquet files.
(bills, leg_term_kpi, committee_kpi, donor_kpi, lobby_kpi, topic_snapshot) = (
    union_parquet(pattern)
    for pattern in (
        'bills_kpis_*.parquet',
        'legislator_kpis_*.parquet',
        'committee_kpis_*.parquet',
        'donor_kpis_*.parquet',
        'lobby_firm_kpis_*.parquet',
        'topic_snapshot_*.parquet',
    )
)

In [4]:
def explode_topic_probs(df, actor_type):
    """Turn per-node topic probability vectors into a long table.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``node_id`` and ``topic_probs`` (a list/array per row).
    actor_type : str
        Label written into the ``actor_type`` column.

    Returns
    -------
    pd.DataFrame
        Columns ``node_id, topic_id, topic_prob, actor_type``. ``topic_id`` is
        the position inside each row's vector. Rows whose vector is empty (or
        whose probability is null) are dropped. If ``df`` has no
        ``topic_probs`` column, an empty DataFrame is returned.
    """
    if "topic_probs" not in df.columns:
        return pd.DataFrame()
    # Keep one index label per source row through the explode, so the topic
    # position is counted per *row*. Counting per node_id would run past the
    # vector length when the same node appears in several unioned files.
    out = (
        df[["node_id", "topic_probs"]]
        .reset_index(drop=True)
        .explode("topic_probs")
    )
    out["topic_id"] = out.groupby(level=0).cumcount()
    out["topic_prob"] = out["topic_probs"].astype(float)
    out = (
        out.drop(columns="topic_probs")
        .dropna(subset=["topic_prob"])  # empty vectors explode to a NaN row
        .reset_index(drop=True)
    )
    out["actor_type"] = actor_type
    return out
SRC = Path('../data')

# Collect the long topic table for each actor family that has topic files.
topic_prob_dfs = []
for nt in ["legislator", "committee", "donor", "lobby_firm"]:
    df = union_parquet(f"*{nt}_topic_probs_*.parquet")
    if df.empty:
        continue
    topic_prob_dfs.append(explode_topic_probs(df, nt))

if topic_prob_dfs:
    actor_topic_long = pd.concat(topic_prob_dfs, ignore_index=True)
else:
    actor_topic_long = pd.DataFrame(columns=["node_id","topic_id","topic_prob","actor_type"])

actor_topic_long.to_parquet(SRC / "actor_topic_relevance.parquet", index=False)


In [96]:
def compute_controversiality(data):
    """Attach a per-bill_version ``controversy`` score in [0, 1].

    Votes come from ``('legislator_term', 'voted_on', 'bill_version')`` edges.
    The sign of ``edge_attr[:, 0]`` is the vote: positive counts as yes,
    negative as no, zero as neither. The score is ``4 * p_yes * p_no``: 1 for
    an even split, 0 for unanimous bills or bills with no votes.

    Raises
    ------
    ValueError
        If the voted_on edge type is absent.
    """
    vote_et = ('legislator_term', 'voted_on', 'bill_version')
    if vote_et not in data.edge_index_dict:
        raise ValueError("Missing 'voted_on' edges in data.")

    store = data[vote_et]
    bill_idx = store.edge_index[1]
    signal = store.edge_attr[:, 0]
    n_bills = data['bill_version'].num_nodes

    def _tally(mask):
        # Count edges per target bill where the mask holds.
        counts = torch.zeros(n_bills, device=bill_idx.device)
        counts.index_add_(0, bill_idx, mask.float())
        return counts

    yes = _tally(signal > 0)
    no = _tally(signal < 0)
    denom = yes + no + 1e-6  # small constant avoids 0/0 for unvoted bills

    data['bill_version'].controversy = (4 * (yes / denom) * (no / denom)).clamp(0, 1)
    return data

def safe_normalize_timestamps(timestamps: torch.Tensor) -> torch.Tensor:
    """Min-max scale a timestamp tensor into [0, 1].

    NaN becomes 0.0, and +inf / -inf become 1e4 / -1e4, before scaling. If the
    value range is below 1e-4, all zeros are returned. An empty tensor is
    returned unchanged (``.min()`` would otherwise raise on it).
    """
    timestamps = torch.nan_to_num(timestamps, nan=0.0, posinf=1e4, neginf=-1e4)
    if timestamps.numel() == 0:
        return timestamps
    lo = timestamps.min()
    span = timestamps.max() - lo
    if span < 1e-4:
        return torch.zeros_like(timestamps)
    return (timestamps - lo) / span

def safe_standardize_time_format(time_data) -> torch.Tensor:
    """Convert heterogeneous time values into a float32 tensor.

    Per element:
    - ``datetime.datetime``: its ``.timestamp()`` (naive values use local time).
    - anything ``float()`` accepts (numbers, numeric strings, numpy scalars):
      - in [1900, 2100]: treated as a year, mapped to June 15 of that year.
      - in (0, 1900): kept as-is.
      - above 17,000,000: kept as-is (already epoch-like).
      - otherwise: multiplied by 1e9.
    - anything that cannot be converted: June 15, 2000.
    """
    fallback = datetime.datetime(2000, 6, 15).timestamp()
    times = []
    for t in time_data:
        try:
            # Check datetimes first: float(datetime) raises, which previously
            # sent every datetime to the fallback branch.
            if isinstance(t, datetime.datetime):
                td = t.timestamp()
            else:
                # float() also covers numpy ints and numeric strings, so
                # np.int64 years are recognised and strings never leak through.
                v = float(t)
                if 1900 <= v <= 2100:
                    td = datetime.datetime(int(v), 6, 15).timestamp()
                elif 0 < v < 1900:
                    td = v
                elif v > 17000000.0:
                    td = v
                else:
                    td = v * 1e9
        except (TypeError, ValueError, OverflowError, OSError):
            td = fallback
        times.append(td)
    return torch.tensor(times, dtype=torch.float32)

def pull_timestamps(data):
    """Split a trailing timestamp column off selected edge/node feature matrices.

    For each listed edge type (``edge_attr``) and node type (``x``) that is
    present in ``data`` and has a 2-D, non-empty matrix with at least two
    columns:
    - the raw last column is stored as ``.time``;
    - its min-max normalised form is stored as ``.timestamp``;
    - the column is removed from the feature matrix.

    Types absent from ``data`` are skipped *without* indexing them. Indexing a
    PyG HeteroData with an unknown key creates an empty store (library
    behaviour; verify for the installed version).

    Not idempotent: a second call would strip another column.
    """
    timestamp_edges = [
        ('donor', 'donated_to', 'legislator_term'),
        ('legislator_term', 'rev_donated_to', 'donor'),
        ('lobby_firm', 'lobbied', 'legislator_term'),
        ('lobby_firm', 'lobbied', 'committee'),
        ('committee', 'rev_lobbied', 'lobby_firm'),
        ('legislator_term', 'rev_lobbied', 'lobby_firm'),
        ('bill_version', 'rev_voted_on', 'legislator_term'),
        ('legislator_term', 'voted_on', 'bill_version'),
    ]
    timestamp_nodes = ['legislator_term', 'bill_version', 'bill']

    present_edges = set(data.edge_types)
    for et in timestamp_edges:
        if et not in present_edges:
            continue
        store = data[et]
        edge_attr = getattr(store, 'edge_attr', None)
        if (edge_attr is None or edge_attr.dim() < 2
                or edge_attr.size(1) < 2 or edge_attr.size(0) == 0):
            continue
        ts_col = edge_attr[:, -1]
        store.timestamp = safe_normalize_timestamps(ts_col)
        store.time = ts_col
        store.edge_attr = edge_attr[:, :-1]

    present_nodes = set(data.node_types)
    for nt in timestamp_nodes:
        if nt not in present_nodes:
            continue
        store = data[nt]
        x = getattr(store, 'x', None)
        # Non-tensor, 1-D, single-column or empty features are left untouched
        # (previously achieved by swallowing every exception).
        if (not isinstance(x, torch.Tensor) or x.dim() < 2
                or x.size(1) < 2 or x.size(0) == 0):
            continue
        ts_col = x[:, -1]
        store.timestamp = safe_normalize_timestamps(ts_col)
        store.x = x[:, :-1]
        store.time = ts_col
    return data
def clean_features(data):
    """Extract timestamps, then z-score every node type's feature matrix.

    For each node type that has ``x``:
    - non-tensor ``x`` is stacked with ``np.vstack`` into a tensor;
    - NaN becomes 0, and +inf / -inf become 1e4 / -1e4;
    - ``x`` is standardised per column and clamped to [-10, 10];
    - ``x_mean`` / ``x_std`` (both shape (1, F)) are stored so the transform
      can be inverted.

    With fewer than two rows, or all rows identical, ``x`` is left unscaled,
    with ``x_mean`` = the first row and ``x_std`` = ones. Node types without
    ``x`` are skipped.
    """
    data = pull_timestamps(data)
    for nt in data.node_types:
        x = getattr(data[nt], 'x', None)
        if x is None:
            continue
        if not isinstance(x, torch.Tensor):
            x = torch.from_numpy(np.vstack(x))
        x = torch.nan_to_num(x.float(), nan=0.0, posinf=1e4, neginf=-1e4)
        if x.size(0) < 2 or torch.all(x == x[0]):
            # Degenerate: nothing to standardise. Keep (1, F) statistics so
            # shapes match the regular branch.
            mean = x[:1].clone()
            std = torch.ones_like(mean)
            x_clean = x.clone()
        else:
            mean = x.mean(dim=0, keepdim=True)
            std = x.std(dim=0, keepdim=True).clamp(min=1e-5)  # constant columns -> 0
            x_clean = ((x - mean) / std).clamp(-10.0, 10.0)
        data[nt].x = x_clean
        data[nt].x_mean = mean
        data[nt].x_std = std
    return data

def _flatten_node_features(full_data):
    """Make every node type's ``x`` a 2-D tensor (N, F) and sync ``num_nodes``."""
    for nt in full_data.node_types:
        x = getattr(full_data[nt], 'x', None)
        if x is None:
            continue
        # Accept both numpy arrays and tensors (torch.from_numpy rejects tensors).
        feats = x if isinstance(x, torch.Tensor) else torch.as_tensor(x)
        # 1-D features become a single column; higher ranks flatten to (N, F).
        feats = feats.flatten(start_dim=1) if feats.dim() > 1 else feats.unsqueeze(-1)
        full_data[nt].x = feats
        full_data[nt].num_nodes = feats.size(0)


def _fix_node_counts(full_data):
    """Grow ``num_nodes`` so every edge index refers to an existing node."""
    for (src_type, _, dst_type), edge_index in full_data.edge_index_dict.items():
        if edge_index.size(1) == 0:
            continue
        endpoints = ((src_type, edge_index[0].max().item()),
                     (dst_type, edge_index[1].max().item()))
        for node_type, max_idx in endpoints:
            if max_idx >= full_data[node_type].num_nodes:
                print(f"Fixing {node_type} node count: {full_data[node_type].num_nodes} -> {max_idx + 1}")
                full_data[node_type].num_nodes = max_idx + 1


def _downcast_float64(data):
    """Convert every float64 tensor attribute in ``data`` to float32 in place."""
    for store in data.stores:
        for key, value in list(store.items()):
            if isinstance(value, torch.Tensor) and value.dtype == torch.float64:
                store[key] = value.float()


def load_and_preprocess_data(path='../../../GNN/data2.pt'):
    """Load the saved graph and run the full preprocessing pipeline.

    Steps: flatten node features, repair node counts implied by edges, add
    reverse edge types (``ToUndirected(merge=False)``), drop isolated nodes,
    split off timestamps and standardise features (``clean_features``), add
    bill controversy, and downcast float64 tensors to float32.
    """
    # weights_only=False unpickles arbitrary objects: only load trusted files.
    full_data = torch.load(path, weights_only=False)
    _flatten_node_features(full_data)
    _fix_node_counts(full_data)

    data = ToUndirected(merge=False)(full_data)
    del full_data
    gc.collect()
    data = RemoveIsolatedNodes()(data)
    data = compute_controversiality(clean_features(data))
    _downcast_float64(data)
    return data

# Load and preprocess the heterogeneous graph (prints any node-count repairs).
data = load_and_preprocess_data(path='../../../GNN/data2.pt')

Fixing bill node count: 13164 -> 45350
Fixing legislator node count: 478 -> 508


In [122]:
# NOTE(review): these are positions *after* RemoveIsolatedNodes ran during
# preprocessing. If any nodes were dropped, they may no longer line up with
# the indices in node_id_map — confirm.
for nt in data.node_types:
    data[nt].node_id = torch.arange(0, data[nt].num_nodes, dtype=torch.long)

In [109]:
# Find the bill_version -> bill edge type. Fail with a clear message instead
# of a bare StopIteration if the graph has none.
v2b_edge = next((et for et in data.edge_types
                 if et[0] == "bill_version" and et[2] == "bill"), None)
if v2b_edge is None:
    raise KeyError("no bill_version -> bill edge type found in data")

src, dst = data[v2b_edge].edge_index.numpy()
bv_ts  = data["bill_version"].time.numpy()
bv_df  = pd.DataFrame({"bill_version": src, "bill_id": dst, "ts": bv_ts[src]})

# Introduction = earliest version time, last action = latest version time.
# NOTE(review): unit='s' assumes .time holds epoch seconds — confirm.
bill_dates = (
    bv_df.groupby("bill_id")["ts"]
         .agg(intro_date="min", last_action="max")
         .reset_index()
         .assign(
            intro_date=lambda d: pd.to_datetime(d.intro_date, unit='s'),
            last_action=lambda d: pd.to_datetime(d.last_action, unit='s'))
)
bill = bills.merge(bill_dates, on="bill_id", how="left")

In [112]:
# Week bucket (start of the introduction week) and days from intro to last action.
bill["week"] = bill["intro_date"].dt.to_period("W").dt.start_time
bill["bill_velocity_days"] = bill["last_action"].sub(bill["intro_date"]).dt.days

In [10]:
# Per-type mapping from original entity keys to graph node indices.
node_id_map = json.loads(Path("../../../node_id_map.json").read_text())

In [11]:
# pickle can execute arbitrary code while loading; keep this file trusted.
committees = pickle.loads(Path("../../../committees.pkl").read_bytes())

In [123]:
# Invert node_id_map["committee"] (key -> node index) into node index -> key.
com_names = {idx: cname for cname, idx in node_id_map["committee"].items()}
committee_ids = {
    cid: com_names[int(cid)]
    for cid in data['committee'].node_id.numpy()
    if com_names.get(int(cid), None) is not None
}
committee_ids = pd.DataFrame(committee_ids.items(), columns=["committee_id", "committee_name"])
# Split the key on '_': first part is com_id, second is the term.
name_parts = committee_ids['committee_name'].str.split('_')
committee_ids['term'] = name_parts.str[1].astype(int)
committee_ids['com_id'] = name_parts.str[0].astype(int)
committee_ids['name'] = committee_ids['com_id'].map(committees)

In [124]:
# DataFrame.join(on=...) matches the left column against the *index* of the
# right frame. committee_kpi's index is a positional RangeIndex (union_parquet
# concatenates with ignore_index=True), so match on its node_id column instead.
committee = committee_ids.merge(committee_kpi, left_on='committee_id', right_on='node_id', how='right')

In [14]:
# weights_only=False unpickles arbitrary objects: only load trusted files.
policy_embs = torch.load("../../../GNN/policy_embeddings.pt",
                         weights_only=False, map_location='cpu')
T = policy_embs.shape[0]  # number of policy embeddings (first dimension)

In [15]:
# Table of bill keys vs. their node index in node_id_map['bill'].
bill_node_id_rep = (
    pd.DataFrame.from_dict(
        {idx: bill_key for bill_key, idx in node_id_map['bill'].items()},
        orient='index',
    )
    .reset_index()
    .rename(columns={0: 'bill_id', 'index': 'original_node_id'})
    .sort_values('original_node_id')
    .reset_index(names='node_id')
)

In [16]:
from torch_geometric.data import HeteroData

In [125]:
keep_nt = {"bill", "legislator_term", "committee", "donor", "lobby_firm"}
proj = HeteroData()
# Sorted iteration keeps the projected graph's node layout independent of
# set/string-hash ordering between kernel sessions.
for nt in sorted(keep_nt):
    proj[nt].num_nodes = data[nt].num_nodes
proj_edges = [et for et in data.edge_types
              if et[0] in keep_nt and et[2] in keep_nt and hasattr(data[et], 'edge_index')]
for et in proj_edges:
    proj[et].edge_index = data[et].edge_index
G = to_networkx(proj)


def _typed_node_keys(graph):
    """Map each networkx node to a ``(node_type, local_index)`` key.

    Recent PyG versions of ``to_networkx`` label HeteroData nodes with
    consecutive global integers and put the node type in the ``type``
    attribute; each type occupies a contiguous range. Subtracting the
    smallest id of that type recovers the per-type index. Nodes without a
    ``type`` attribute keep their original key.
    """
    first_id = {}
    for node, ntype in graph.nodes(data="type"):
        if ntype is not None and (ntype not in first_id or node < first_id[ntype]):
            first_id[ntype] = node
    return {node: (node if ntype is None else (ntype, node - first_id[ntype]))
            for node, ntype in graph.nodes(data="type")}


# add_centrality / the bill lookups query (node_type, index) keys. Re-key the
# networkx results so those lookups hit instead of always falling back to 0.
_node_key = _typed_node_keys(G)
pagerank = {_node_key[n]: score for n, score in nx.pagerank(G, alpha=0.9).items()}
degree = {_node_key[n]: deg for n, deg in G.degree()}

def add_centrality(df, ntype):
    """Return a copy of ``df`` with ``pagerank`` and ``degree`` columns added.

    Each row's ``node_id`` is looked up as ``(ntype, int(node_id))`` in the
    module-level ``pagerank`` and ``degree`` dicts. Missing nodes get 0.0 and
    0 respectively.
    """
    def _score(i):
        return pagerank.get((ntype, int(i)), 0.0)

    def _deg(i):
        return degree.get((ntype, int(i)), 0)

    return df.assign(pagerank=df.node_id.map(_score), degree=df.node_id.map(_deg))

committee = add_centrality(committee, "committee")
donor, lobby, leg_term = (
    add_centrality(kpi_frame, kind)
    for kpi_frame, kind in ((donor_kpi, "donor"),
                            (lobby_kpi, "lobby_firm"),
                            (leg_term_kpi, "legislator_term"))
)
# Bills are keyed by bill_id rather than node_id.
bill = bill.assign(
    pagerank=bill.bill_id.map(lambda i: pagerank.get(("bill", int(i)), 0.0)),
    degree=bill.bill_id.map(lambda i: degree.get(("bill", int(i)), 0)),
)

In [126]:
src, dst = data[("bill_version","rev_wrote","legislator_term")].edge_index.numpy()
# rev_wrote sources are bill_version indices, but `bill` is keyed by bill node
# id. Translate each version to its parent bill via the version->bill edges
# (bv_df) first, otherwise version ids are joined against unrelated bills.
version_to_bill = bv_df[["bill_version", "bill_id"]].drop_duplicates()
sponsor_map = (pd.DataFrame({"bill_version": src, "node_id": dst})
                 .merge(version_to_bill, on="bill_version", how="inner")
                 [["bill_id", "node_id"]]
                 .drop_duplicates())  # a sponsor of several versions counts once per bill
sponsor_infl = leg_term.set_index("node_id")["influence"]
valid_sponsor_map = sponsor_map[sponsor_map["node_id"].isin(sponsor_infl.index)]
# Mean sponsor influence per bill. A vectorised merge replaces the per-group
# .loc lambda and keeps the same row expansion if node_id repeats in leg_term.
sponsor_power = (valid_sponsor_map
                 .merge(leg_term[["node_id", "influence"]], on="node_id", how="inner")
                 .groupby("bill_id")["influence"].mean()
                 .rename("sponsor_power"))
bill = bill.join(sponsor_power, on="bill_id", how="left")

In [127]:
def attach_topk_topics(df, actor_type, K=3, topics=None):
    """Left-join each actor's top-K topics onto ``df`` as wide columns.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``node_id``.
    actor_type : str
        Value of ``actor_type`` to select from the long topic table.
    K : int
        Number of topics kept per node, ranked by ``topic_prob`` descending.
    topics : pd.DataFrame, optional
        Long table with ``node_id, topic_id, topic_prob, actor_type``.
        Defaults to the module-level ``actor_topic_long``.

    Returns
    -------
    pd.DataFrame
        ``df`` plus ``<column>_<rank>`` columns (rank 1 = most probable) for
        each non-key column of the topic table. ``df`` is returned unchanged
        when no topic rows match.
    """
    table = actor_topic_long if topics is None else topics
    if table.empty:
        return df
    tk = (table[table.actor_type == actor_type]
            .sort_values(["node_id", "topic_prob"], ascending=[True, False])
            .groupby("node_id").head(K))
    if tk.empty:
        # Nothing to attach; unstacking an empty selection yields a frame that
        # cannot be merged cleanly.
        return df
    rank = tk.groupby("node_id").cumcount() + 1
    wide = tk.set_index(["node_id", rank]).unstack(level=1)
    wide.columns = [f"{col}_{k}" for col, k in wide.columns]
    return df.merge(wide.reset_index(), on="node_id", how="left")

committee = attach_topk_topics(committee, "committee")
donor = attach_topk_topics(donor, "donor")
lobby = attach_topk_topics(lobby, "lobby_firm")
# actor_topic_long labels legislator-term rows "legislator" (its source files
# are *legislator_topic_probs_*, like the legislator_kpis_* files that feed
# leg_term). "legislator_term" never matched any row.
leg_term = attach_topk_topics(leg_term, "legislator")

In [128]:
def edge_df(et, graph=None):
    """Tabulate one edge type as ``src, dst, amount``.

    Parameters
    ----------
    et : tuple
        Edge type key, e.g. ``("donor", "donated_to", "legislator_term")``.
    graph : mapping-like, optional
        Object indexed by ``et``. Defaults to the module-level ``data``.

    Returns
    -------
    pd.DataFrame
        ``amount`` is the first ``edge_attr`` column (a 1-D ``edge_attr`` is
        used as-is). If the store has no ``edge_attr`` or it has zero columns,
        an empty frame with the same columns is returned.
    """
    graph = data if graph is None else graph
    store = graph[et]
    ei = store.edge_index.numpy()
    attr = getattr(store, "edge_attr", None)  # PyG stores raise if the attr is unset
    if attr is None:
        return pd.DataFrame(columns=["src", "dst", "amount"])
    ea = attr.numpy()
    if ea.ndim == 1:
        ea = ea[:, None]
    if ea.shape[1] == 0:
        return pd.DataFrame(columns=["src", "dst", "amount"])
    return pd.DataFrame({"src": ei[0], "dst": ei[1], "amount": ea[:, 0]})

# Donation and lobbying edges into legislator terms, tagged by source kind.
don_edge = (edge_df(("donor","donated_to","legislator_term"))
            .assign(donor_id=lambda e: e.src, legterm_id=lambda e: e.dst, type="donor"))
lob_edge = (edge_df(("lobby_firm","lobbied","legislator_term"))
            .assign(lobby_id=lambda e: e.src, legterm_id=lambda e: e.dst, type="lobby"))
fund_edge = pd.concat([don_edge, lob_edge], ignore_index=True)

# session_start is the raw timestamp column split off legislator_term features.
# NOTE(review): it is later matched against bill['session'] (a year) — confirm
# the units agree.
term_ts = data["legislator_term"].time.numpy()
leg_term["session_start"] = term_ts[leg_term.node_id].astype(int)

leg_map = (leg_term.loc[:, ["node_id","top_topic","session_start"]]
           .rename(columns={"node_id":"legterm_id"}))
fund_edge = fund_edge.merge(leg_map, on="legterm_id", how="left")

In [138]:
bill['year'] = bill.intro_date.dt.year
# Sessions start in odd years; an even year belongs to the session that began
# the year before. Vectorised replacement for the former row-wise apply (same
# result; missing years stay <NA>, and an empty frame no longer trips
# apply(axis=1)).
bill['session'] = (bill['year'] - (bill['year'] % 2 == 0).astype(int)).astype('Int64')

In [142]:
# Per (topic, session): bill volume and mean polarisation.
policy_weekly = (
    bill.groupby(["dominant_topic", "session"])
        .agg(n_bills=("bill_id", "size"),
             avg_polar=("polarisation_score", "mean"))
        .reset_index()
        .rename(columns={"dominant_topic": "topic"})
)

# Per (topic, session): total donation + lobbying amount into legislator terms.
fund_weekly = (
    fund_edge.rename(columns={"top_topic": "topic", "session_start": "session"})
             .groupby(["topic", "session"])
             .agg(total_funding=("amount", "sum"))
             .reset_index()
)

policy_weekly = policy_weekly.merge(fund_weekly, on=["topic", "session"], how="left")
policy_weekly["total_funding"] = policy_weekly["total_funding"].fillna(0.0)
# Rolling std of funding over the last 4 sessions (at least 2) within each topic.
policy_weekly["fund_volatility"] = (
    policy_weekly.sort_values("session")
                 .groupby("topic")["total_funding"]
                 .transform(lambda funding: funding.rolling(4, min_periods=2).std())
)

policy_weekly.to_parquet(SRC / "policy_session.parquet", index=False)

In [145]:
# Per-topic bill statistics.
topic_base = (
    bill.groupby("dominant_topic")
        .agg(
            n_bills=("bill_id", "size"),
            avg_success=("success_risk", "mean"),
            avg_polar=("polarisation_score", "mean"),
            avg_velocity=("bill_velocity_days", "mean"),
            avg_sponsor_power=("sponsor_power", "mean"),
        )
        .reset_index()
        .rename(columns={"dominant_topic": "topic"})
)

In [146]:
# Reset the index *before* renaming. After groupby, "top_topic" is the index,
# so renaming it as a column was a no-op and the frame kept "top_topic".
topic_funding = (fund_edge.groupby("top_topic")
                          .agg(total_dollars=("amount", "sum"))
                          .reset_index()
                          .rename(columns={"top_topic": "topic"}))

In [147]:
# Uses the `slope` helper defined with the other metric helpers. Redefining it
# here silently shadowed that version with different dtype handling and return
# type. Sort by session explicitly so each trend runs forward in time
# regardless of upstream row order.
pol_slope = (policy_weekly.sort_values(["topic", "session"])
                          .groupby("topic")["avg_polar"]
                          .apply(slope)
                          .reset_index()
                          .rename(columns={"avg_polar": "polarisation_slope"}))

In [148]:
# Influence of every actor family, keyed by each actor's top topic.
actor_infl = pd.concat([
    donor[["top_topic","influence"]],
    lobby[["top_topic","influence"]],
    committee[["top_topic","influence"]],
    leg_term[["top_topic","influence"]]],
    ignore_index=True).dropna()


def _influence_hhi(influence):
    """Herfindahl index of influence *shares* within one topic (NaN if total <= 0)."""
    total = influence.sum()
    if total <= 0:
        return np.nan
    # herfindahl expects shares summing to 1; raw influence values give a
    # number that is not comparable across topics.
    return herfindahl(influence / total)


# reset_index before rename: "top_topic" is the groupby index, not a column.
power_conc = (actor_infl.groupby("top_topic")
                        .agg(power_concentration=("influence", _influence_hhi))
                        .reset_index()
                        .rename(columns={"top_topic": "topic"}))

In [153]:
# Bipartisan gap in [-1, 1]: 2 * (share of bills with polarisation < 0.25) - 1.
bip_gap = (
    bill.assign(bipart=lambda frame: frame.polarisation_score < 0.25)
        .groupby("dominant_topic")["bipart"]
        .agg(lambda flags: 2 * flags.mean() - 1)
        .reset_index()
        .rename(columns={"dominant_topic": "topic",
                         "bipart": "bipartisan_gap"})
)
# In-place renames kept so the frames themselves end up with a "topic" column.
topic_funding.rename(columns={"top_topic": "topic"}, inplace=True)
power_conc.rename(columns={"top_topic": "topic"}, inplace=True)

topic_summary = topic_base
for extra in (topic_funding, power_conc, bip_gap, pol_slope):
    topic_summary = topic_summary.merge(extra, on="topic", how="left")
topic_summary = (
    topic_summary
    .merge(topic_snapshot[["topic_id", "recent_momentum", "power_balance"]],
           left_on="topic", right_on="topic_id", how="left")
    .drop(columns="topic_id")
    .fillna({"total_dollars": 0.0})
)

topic_summary.to_parquet(SRC / "topic_summary.parquet", index=False)

In [155]:
# Invert node_id_map['legislator_term']: node index -> legislator-term key.
leg_ids = {idx: term_key for term_key, idx in node_id_map['legislator_term'].items()}

In [156]:
# Close the file deterministically (the bare open() leaked the handle).
# pickle can execute arbitrary code while loading; keep this file trusted.
with open('../../../legislators.pkl', 'rb') as fh:
    legislators = pickle.load(fh)

In [157]:
def leg_term_to_name(leg_term_id):
    """Resolve a legislator-term key to a name via the ``legislators`` mapping.

    The text before the first '_' is parsed as the legislator number. Returns
    None for non-string input or an unknown number.
    """
    if not isinstance(leg_term_id, str):
        return None
    legislator_num = int(leg_term_id.partition('_')[0])
    return legislators.get(legislator_num, None)

In [158]:
def legislator_node_matching(node_id):
    """Look up a name via ``node_id_map['legislator'][str(node_id)]`` and then ``legislators``.

    Returns None if either lookup misses.
    """
    legislator_num = node_id_map['legislator'].get(str(node_id), None)
    if legislator_num is None:
        return None
    return legislators.get(legislator_num, None)

In [162]:
# node index -> legislator-term key -> legislator name (None when unknown).
leg_term['name'] = leg_term['node_id'].map(leg_ids).map(leg_term_to_name)

In [163]:
# Persist the enriched actor and bill tables next to the other outputs.
for frame, fname in ((leg_term, "legislator_term.parquet"),
                     (donor, "donor.parquet"),
                     (lobby, "lobby_firm.parquet"),
                     (committee, "committee.parquet"),
                     (bill, "bill.parquet")):
    frame.to_parquet(SRC / fname, index=False)

In [None]:
sponsor_et = ("bill", "sponsored_by", "legislator_term")
# Check membership first. Indexing HeteroData with an unknown edge type
# creates an empty edge store on `data` (PyG behaviour; verify for the
# installed version) and then fails with an opaque AttributeError.
if sponsor_et not in data.edge_types:
    raise KeyError(f"edge type {sponsor_et} not present in data")
src, dst = data[sponsor_et].edge_index.numpy()
# Bill -> sponsoring legislator term, with the term's resolved name.
s_tbl = (pd.DataFrame({"bill_id":src,"node_id":dst})
            .merge(leg_term[["node_id","name"]], on="node_id"))
