In [1]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.loader import RandomNodeLoader
from LeGNN_output import model, load_and_preprocess_data

In [2]:
device = torch.device('mps')
data = load_and_preprocess_data(path="data2.pt")
loader = RandomNodeLoader(data, num_parts=5)

In [4]:
emb_buf = {nt: [] for nt in data.node_types}
topic_buf = {}
topic_scores = []
infl_buf = {}
controversy_buf = []

In [5]:
for batch in loader:
    batch = batch.to(device)
    with torch.no_grad():
        out = model(batch)
    for nt, z in out['node_embeddings'].items():
        emb_buf[nt].append(z.cpu().numpy())
    for nt, z in out['topic_dists'].items():
        if nt not in topic_buf:
            topic_buf[nt] = []
        topic_buf[nt].append(z.cpu().numpy())
    topic_scores.append(out['flat_topic_scores'].cpu().numpy())
    for nt, z in out['influence_scores'].items():
        if nt not in infl_buf:
            infl_buf[nt] = []
        infl_buf[nt].append(z.cpu().numpy())
    controversy_buf.append(out['controversy_pred'].cpu().numpy())

In [None]:
topic_scores

[array([[[-1.16113579e+00, -1.13318288e+00, -6.32049024e-01,
          -1.93492389e+00, -7.21119702e-01, -6.67694271e-01,
          -1.36913502e+00, -9.00331140e-01, -9.03632641e-01,
          -7.92881668e-01, -1.49271083e+00, -2.03171277e+00,
          -2.09852839e+00, -1.24164796e+00, -1.14825773e+00,
          -1.72784984e+00, -1.57075572e+00,  3.25068772e-01,
          -1.75707030e+00, -8.20628047e-01, -1.72449791e+00,
          -1.14499331e+00, -2.41973710e+00, -1.33374286e+00,
          -2.73111057e+00, -6.84924066e-01, -1.41038048e+00,
          -1.36757684e+00, -1.48583555e+00,  9.25570393e+00,
          -1.62852120e+00, -1.91472733e+00, -1.56874347e+00,
          -2.07205415e+00, -1.20465195e+00, -1.53928232e+00,
          -1.61521876e+00, -2.87198126e-02,  8.45560879e-02,
          -1.59583247e+00, -1.89680004e+00, -1.28264606e+00,
          -1.20861709e+00,  7.23410249e-02, -8.06509078e-01,
          -1.36090553e+00, -1.13307691e+00, -1.91737726e-01,
          -1.40019083e+0

: 

In [None]:
# Concatenate
for nt in emb_buf:
    full = np.vstack(emb_buf[nt])
    pd.DataFrame(full).to_parquet(f"GNN/{nt}_embeddings.parquet", index=True)
tops = []
for nt in topic_buf:
    full = np.vstack(topic_buf[nt])
    tops.append(full)
pd.DataFrame(np.vstack(tops)).to_parquet("GNN/topic_distributions.parquet", index=True)
pd.DataFrame(np.vstack(topic_scores)).to_parquet("GNN/bill_topic_scores.parquet", index=True)
for nt in infl_buf:
    full = np.vstack(infl_buf[nt])
    pd.DataFrame(full).to_parquet(f"GNN/{nt}_influence_scores.parquet", index=True)
pd.DataFrame(np.vstack(controversy_buf), columns=["controversy"]).to_parquet("GNN/bill_version_controversy.parquet", index=True)