In [1]:
import math, gc, warnings, json, datetime, torch
from collections import defaultdict
from pathlib import Path
import numpy as np
import pandas as pd
import polars as pl
from torch.nn import functional as F
from torch_geometric.transforms import ToUndirected, RemoveIsolatedNodes

In [2]:
def herfindahl(shares):
    return float(np.square(shares).sum())

def entropy(shares, eps=1e-9):
    shares = shares.clip(min=eps)
    return float(-(shares * np.log(shares)).sum())

def jsd(p, q, eps=1e-9):
    m = 0.5 * (p + q)
    return 0.5 * (
        F.kl_div(m.log(), p, reduction='none').sum(-1) +
        F.kl_div(m.log(), q, reduction='none').sum(-1))

def slope(y):
    if y.size < 2:
        return 0.0
    x = np.arange(y.size, dtype=np.float32)
    return float(np.polyfit(x, y, 1)[0])

def safe_ratio(n, d):
    return 0.0 if d == 0 else n / d

In [3]:
def compute_controversiality(data):
    edge_type = ('legislator_term', 'voted_on', 'bill_version')
    if edge_type not in data.edge_index_dict:
        raise ValueError("Missing 'voted_on' edges in data.")

    ei = data[edge_type].edge_index
    ea = data[edge_type].edge_attr

    vote_signal = ea[:, 0]

    src_nodes = ei[0]
    tgt_nodes = ei[1]

    num_bills = data['bill_version'].num_nodes
    device = tgt_nodes.device

    yes_votes = torch.zeros(num_bills, device=device)
    no_votes = torch.zeros(num_bills, device=device)

    yes_votes.index_add_(0, tgt_nodes, (vote_signal > 0).float())
    no_votes.index_add_(0, tgt_nodes, (vote_signal < 0).float())

    total_votes = yes_votes + no_votes + 1e-6

    yes_ratio = yes_votes / total_votes
    no_ratio = no_votes / total_votes

    controversy = 4 * yes_ratio * no_ratio
    controversy = controversy.clamp(0, 1)
    data['bill_version'].controversy = controversy

    return data

def safe_normalize_timestamps(timestamps: torch.Tensor) -> torch.Tensor:
    timestamps = torch.nan_to_num(timestamps, nan=0.0, posinf=1e4, neginf=-1e4)
    min_time = timestamps.min()
    max_time = timestamps.max()
    if (max_time - min_time) < 1e-4:
        return torch.zeros_like(timestamps)
    return (timestamps - min_time) / (max_time - min_time)

def safe_standardize_time_format(time_data) -> torch.Tensor:
    times = []
    for t in time_data:
        try:
            if isinstance(t, (int, float)) and 1900 <= t  and t <= 2100:
                td = datetime.datetime(int(t), 6, 15).timestamp()
            elif (isinstance(t, str) or (isinstance(t, float))) and (float(t) < 2100 and float(t) > 1900):
                td = datetime.datetime(int(float(t)), 6, 15).timestamp()
            elif float(t) > 0 and float(t) < 1990:
                td = t
            elif float(t) > 17000000.0:
                td = float(t)
                while td > 10:
                    td = td / 10
            elif float(t) < 0:
                td = -float(t)
            else:
                td = t.timestamp()
        except:
            td = datetime.datetime(2000, 6, 15).timestamp()
        times.append(td)
    return torch.tensor(times, dtype=torch.float32)

def pull_timestamps(data):
    timestamp_edges = [
        ('donor', 'donated_to', 'legislator_term'),
        ('legislator_term', 'rev_donated_to', 'donor'),
        ('lobby_firm', 'lobbied', 'legislator_term'),
        ('lobby_firm', 'lobbied', 'committee'),
        ('committee', 'rev_lobbied', 'lobby_firm'),
        ('legislator_term', 'rev_lobbied', 'lobby_firm'),
        ('bill_version', 'rev_voted_on', 'legislator_term'),
        ('legislator_term', 'voted_on', 'bill_version'),
    ]
    timestamp_nodes = ['legislator_term', 'bill_version', 'bill', 'committee']

    for et in timestamp_edges:
        if hasattr(data[et], 'edge_attr') and data[et].edge_attr is not None and len(data[et].edge_attr.size()) > 1:
            if data[et].edge_attr.size(1) > 1:
                edge_attr = data[et].edge_attr
                ts_col = edge_attr[:, -1]
                if ts_col.abs().max() > 1e8 or ts_col.min() < 0:
                    ts_col = safe_standardize_time_format(ts_col.tolist()).to(edge_attr.device)
                data[et].timestamp = safe_normalize_timestamps(ts_col)
                data[et].edge_attr = edge_attr[:, :-1]

    for nt in timestamp_nodes:
        if hasattr(data[nt], 'x') and data[nt].x is not None:
            try:
                if len(data[nt].x.size()) > 1:
                    if data[nt].x.size(1) > 1:
                        x = data[nt].x
                        ts_col = x[:, -1]
                        if ts_col.abs().max() > 1e8 or ts_col.min() < 0:
                            ts_col = safe_standardize_time_format(ts_col.tolist()).to(x.device)
                        if nt in timestamp_nodes or ts_col.abs().max() > 1e6:
                            data[nt].timestamp = safe_normalize_timestamps(ts_col)
                            data[nt].x = x[:, :-1]
            except:
                pass
    return data
def clean_features(data):
    for nt in data.node_types:
        x = data[nt].x
        if not isinstance(x, torch.Tensor) or x.numel() == 0:
            data[nt].x = torch.from_numpy(np.vstack(x)).float()
            x = data[nt].x
        x = torch.nan_to_num(x.float(), nan=0.0, posinf=1e4, neginf=-1e4)
        if x.size(0) < 2 or torch.all(x == x[0]):
            mean = x.clone()
            std = torch.ones_like(x)
            x_clean = x.clone()
        else:
            mean = x.mean(dim=0, keepdim=True)
            std = x.std(dim=0, keepdim=True).clamp(min=1e-5)
            x_clean = (x - mean) / std
            x_clean = x_clean.clamp(-10.0, 10.0)
        data[nt].x = x_clean
        data[nt].x_mean = mean
        data[nt].x_std = std
    data = pull_timestamps(data)
    return data

def load_and_preprocess_data(path='../../../GNN/data2.pt'):
    full_data = torch.load(path, weights_only=False)
    for nt in full_data.node_types:
        if hasattr(full_data[nt], 'x') and full_data[nt].x is not None:
            full = torch.from_numpy(full_data[nt].x)
            s = full.size()
            full = torch.flatten(full, start_dim=1, end_dim=-1)
            full_data[nt].x = full
            full_data[nt].num_nodes = full.size(0)

    # Check and fix edge indices before transformation
    for edge_type, edge_index in full_data.edge_index_dict.items():
        src_type, _, dst_type = edge_type

        # Get max node indices
        max_src_idx = edge_index[0].max().item() if edge_index.size(1) > 0 else -1
        max_dst_idx = edge_index[1].max().item() if edge_index.size(1) > 0 else -1

        # Ensure node counts are sufficient
        if max_src_idx >= full_data[src_type].num_nodes:
            print(f"Fixing {src_type} node count: {full_data[src_type].num_nodes} -> {max_src_idx + 1}")
            full_data[src_type].num_nodes = max_src_idx + 1

        if max_dst_idx >= full_data[dst_type].num_nodes:
            print(f"Fixing {dst_type} node count: {full_data[dst_type].num_nodes} -> {max_dst_idx + 1}")
            full_data[dst_type].num_nodes = max_dst_idx + 1

    data = ToUndirected(merge=False)(full_data)
    del full_data
    gc.collect()
    data = RemoveIsolatedNodes()(data)
    data = compute_controversiality(clean_features(data))
    for store in data.stores:
        for key, value in store.items():
            if isinstance(value, torch.Tensor) and value.dtype == torch.float64:
                store[key] = value.float()
    return data

data = load_and_preprocess_data()

Fixing bill node count: 13164 -> 45350
Fixing legislator node count: 478 -> 508


In [None]:
back_data_dir = "../data"
bills_kpis = pd.read_parquet(f"{back_data_dir}/bill_kpis.parquet")
legislator_kpis = pd.read_parquet(f"{back_data_dir}/legislator_kpis.parquet")
committee_kpis = pd.read_parquet(f"{back_data_dir}/committee_kpis.parquet")
donor_kpis = pd.read_parquet(f"{back_data_dir}/donor_kpis.parquet")
lobby_firm_kpis = pd.read_parquet(f"{back_data_dir}/lobby_firm_kpis.parquet")
topic_snapshot = pd.read_parquet(f"{back_data_dir}/topic_snapshot.parquet")
policy_embs = torch.load("../../../GNN/policy_embeddings.pt", map_location='cpu', weights_only=False)
T = policy_embs.size(0)

In [5]:
v2b = data[('bill_version','is_version','bill')].edge_index.numpy()
bv_ids, bill_ids = v2b[0], v2b[1]

cont_df = pl.DataFrame({
        'bill_id': bill_ids,
        'controversy': data['bill_version'].controversy[bv_ids].cpu().numpy()
        }).group_by('bill_id').agg(
    pl.col('controversy').mean().alias('avg_controversy')
)

In [7]:
# amendment count
pv = data[('bill_version','priorVersion','bill_version')].edge_index[0].cpu().numpy()
amend_df = pl.DataFrame({'bv': pv}).with_columns(
	pl.col('bv').map_elements(lambda x: v2b[1][list(v2b[0]).index(x)] if x in v2b[0] else None, return_dtype=int).alias('bill_id')
).drop_nulls().group_by('bill_id').agg(pl.count().alias('amendment_count'))

  ).drop_nulls().group_by('bill_id').agg(pl.count().alias('amendment_count'))


In [None]:
# vote count
vote_idx = data[('legislator_term','voted_on','bill_version')].edge_index[1].cpu().numpy()
vote_df = pl.DataFrame({'bv': vote_idx}).with_columns(
	pl.col('bv').map_elements(lambda x: v2b[1][list(v2b[0]).index(x)] if x in v2b[0] else None, return_dtype=int).alias('bill_id')
).drop_nulls().groupby('bill_id').agg(pl.count().alias('vote_count'))
# avg vote margin
vote_attr = data[('legislator_term','voted_on','bill_version')].edge_attr.cpu().numpy()
margins = vote_attr[:, 0]
vote_margin_df = pl.DataFrame({
	'bv': vote_idx,
	'margin': margins
}).with_columns(
	pl.col('bv').map_elements(lambda x: v2b[1][list(v2b[0]).index(x)] if x in v2b[0] else None, return_dtype=int).alias('bill_id')
).drop_nulls().groupby('bill_id').agg(pl.col('margin').mean().alias('avg_vote_margin'))
# join all
bill_kpis = (cont_df
	.join(amend_df, on='bill_id', how='outer')
	.join(vote_df, on='bill_id', how='outer')
	.join(vote_margin_df, on='bill_id', how='outer')
	.fill_null(0)
)

In [None]:
def _safe_ratio(n, d):
    return 0.0 if d == 0 else n / d

lt2leg = dict(zip(*data[('legislator','samePerson','legislator_term')].edge_index.cpu().numpy()))
# donation events
d_ei = data[('donor','donated_to','legislator_term')].edge_index.cpu().numpy()
d_attr = data[('donor','donated_to','legislator_term')].edge_attr.cpu().numpy()
don_df = pl.DataFrame({
    'lt': d_ei[1],
    'amt': d_attr[:,0],
    'ts': d_attr[:,1]
}).with_columns(
    pl.col('lt').apply(lambda x: lt2leg.get(int(x))).alias('leg_id'),
    pl.col('ts').cast(pl.Datetime(time_unit='s')).dt.year().alias('year')
).drop_nulls()
don_tot = don_df.groupby(['leg_id','year']).agg(pl.col('amt').sum().alias('donation_total'))
don_cnt = don_df.groupby(['leg_id','year']).agg(pl.count().alias('donation_count'))
# lobbying events
l_ei = data[('lobby_firm','lobbied','legislator_term')].edge_index.cpu().numpy()
l_attr = data[('lobby_firm','lobbied','legislator_term')].edge_attr.cpu().numpy()
lobby_df = pl.DataFrame({
    'lf': l_ei[0],
    'lt': l_ei[1],
    'amt': l_attr[:,0],
    'ts': l_attr[:,1]
}).with_columns(
    pl.col('lt').apply(lambda x: lt2leg.get(int(x))).alias('leg_id'),
    pl.col('ts').cast(pl.Datetime(time_unit='s')).dt.year().alias('year')
).drop_nulls()
lobby_tot = lobby_df.groupby(['leg_id','year']).agg(pl.col('amt').sum().alias('lobby_total'))
lobby_cnt = lobby_df.groupby(['leg_id','year']).agg(pl.count().alias('lobby_count'))
# vote events
v_ei = data[('legislator_term','voted_on','bill_version')].edge_index.cpu().numpy()
v_attr = data[('legislator_term','voted_on','bill_version')].edge_attr.cpu().numpy()
vote_df = pl.DataFrame({
    'lt': v_ei[0],
    'ts': v_attr[:,1]
}).with_columns(
    pl.col('lt').apply(lambda x: lt2leg.get(int(x))).alias('leg_id'),
    pl.col('ts').cast(pl.Datetime(time_unit='s')).dt.year().alias('year')
).drop_nulls()
vote_cnt = vote_df.groupby(['leg_id','year']).agg(pl.count().alias('votes_cast'))
# committee membership
m_ei = data[('legislator_term','member_of','committee')].edge_index.cpu().numpy()
m_attr = data[('legislator_term','member_of','committee')].edge_attr.cpu().numpy()
mem_df = pl.DataFrame({
    'lt': m_ei[0],
    'ts': m_attr[:,1]
}).with_columns(
    pl.col('lt').apply(lambda x: lt2leg.get(int(x))).alias('leg_id'),
    pl.col('ts').cast(pl.Datetime(time_unit='s')).dt.year().alias('year')
).drop_nulls()
mem_cnt = mem_df.groupby(['leg_id','year']).agg(pl.count().alias('committee_memberships'))
# authored bills
a_ei = data[('legislator_term','wrote','bill_version')].edge_index.cpu().numpy()
a_attr = data[('legislator_term','wrote','bill_version')].edge_attr.cpu().numpy()
auth_df = pl.DataFrame({
    'lt': a_ei[0],
    'ts': a_attr[:,1]
}).with_columns(
    pl.col('lt').apply(lambda x: lt2leg.get(int(x))).alias('leg_id'),
    pl.col('ts').cast(pl.Datetime(time_unit='s')).dt.year().alias('year')
).drop_nulls()
auth_cnt = auth_df.groupby(['leg_id','year']).agg(pl.count().alias('bills_authored'))
# combine all per-leg, per-year
leg_kpis = (don_tot
    .join(don_cnt, on=['leg_id','year'], how='outer')
    .join(lobby_tot, on=['leg_id','year'], how='outer')
    .join(lobby_cnt, on=['leg_id','year'], how='outer')
    .join(vote_cnt, on=['leg_id','year'], how='outer')
    .join(mem_cnt, on=['leg_id','year'], how='outer')
    .join(auth_cnt, on=['leg_id','year'], how='outer')
    .fill_null(0)
    .with_columns((pl.col('donation_total')/pl.col('donation_count')).apply(lambda x: _safe_ratio(x,1.0)).alias('avg_donation'))
    .with_columns((pl.col('lobby_total')/pl.col('lobby_count')).apply(lambda x: _safe_ratio(x,1.0)).alias('avg_lobby'))
)

In [None]:
pv_ei = data[('bill_version', 'priorVersion', 'bill_version')].edge_index.numpy()

am_cnt, am_gap = defaultdict(int), defaultdict(float)
gap_tmp  = defaultdict(list)
for src_bv, dst_bv in zip(*pv_ei):
    b = v2b.get(int(src_bv))
    if b is None:
        continue
    am_cnt[b] += 1
    ts = data['bill_version'].timestamp[int(src_bv)].item()
    gap_tmp[b].append(float(ts))

for b, times in gap_tmp.items():
    times = np.sort(times)
    am_gap[b] = np.diff(times).mean()/86400 if len(times) > 1 else np.nan

bill_kpis['amendment_avg_gap_days'] = bill_kpis['bill_id'].map(am_gap)

In [None]:
topic_series = bill_kpis.set_index('bill_id')['dominant_topic'].to_dict()
am_topic_ent = {
    b: entropy(np.bincount([topic_series.get(b, 0)], minlength=T) / 1)
    for b in bill_kpis['bill_id']
}
bill_kpis['amendment_topic_entropy'] = bill_kpis['bill_id'].map(am_topic_ent)

In [13]:
lt2comm = defaultdict(set)
for lt,c in zip(*data[('legislator_term','member_of','committee')].edge_index.numpy()):
    lt2comm[int(lt)].add(int(c))
leg_inf = dict(zip(legislator_kpis["node_id"],legislator_kpis["influence"]))
comm_lever=defaultdict(float)
for lt,comms in lt2comm.items():
    inf=leg_inf.get(lt2leg.get(lt),0)
    for c in comms: comm_lever[c]+=inf
committee_kpis["member_leverage_sum"]=committee_kpis["node_id"].map(comm_lever).fillna(0)

In [None]:
avg_pol = bill_kpis.groupby("dominant_topic")["polarisation_score"].mean().to_dict()
topic_snapshot["avg_polarisation"]=topic_snapshot["topic_id"].map(avg_pol)

In [16]:
bill_kpis.to_parquet(f"{back_data_dir}/bill_kpis.parquet",compression='zstd',index=False)
legislator_kpis.to_parquet(f"{back_data_dir}/legislator_kpis.parquet",compression='zstd',index=False)
committee_kpis.to_parquet(f"{back_data_dir}/committee_kpis.parquet",compression='zstd',index=False)
donor_kpis.to_parquet(f"{back_data_dir}/donor_kpis.parquet",compression='zstd',index=False)
lobby_firm_kpis.to_parquet(f"{back_data_dir}/lobby_firm_kpis.parquet",compression='zstd',index=False)
topic_snapshot.to_parquet(f"{back_data_dir}/topic_snapshot.parquet",compression='zstd',index=False)