In [2]:
import torch
from pathlib import Path
from torch_geometric.data import HeteroData

#######################################################
# 1. Load graphs
#######################################################


# Allowlist HeteroData for restricted unpickling, i.e. `torch.load(..., weights_only=True)`.
# NOTE(review): load_graphs below calls torch.load with weights_only=False (a full,
# unrestricted pickle load), so this registration has no effect on it. It only
# matters if graphs are ever loaded with weights_only=True.
torch.serialization.add_safe_globals([HeteroData])

def load_graphs(graph_dir):
    """Load every ``graph_*.pt`` file in ``graph_dir`` and return them as a list.

    Files are ordered by *natural* filename order, with digit runs compared as
    integers, so ``graph_2.pt`` precedes ``graph_10.pt``. A plain ``sorted()``
    would put ``graph_10.pt`` first. The trainer splits this list
    chronologically by position, so the order matters. For zero-padded names
    the result is identical to lexicographic order.

    Args:
        graph_dir: directory to scan (str or Path).

    Returns:
        List of loaded objects in natural filename order; empty if nothing matches.
    """
    import re

    graph_dir = Path(graph_dir)

    def _natural_key(path):
        # re.split with a capturing group alternates text / digit runs, so odd
        # positions are always digit runs and compare as ints at the same index.
        # The raw name is a tie-breaker (e.g. "graph_01" vs "graph_1").
        parts = re.split(r"(\d+)", path.name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], path.name

    files = sorted(graph_dir.glob("graph_*.pt"), key=_natural_key)
    # SECURITY: weights_only=False performs a full pickle load, which can execute
    # arbitrary code embedded in the file. Only use on trusted, locally produced graphs.
    return [torch.load(f, weights_only=False) for f in files]

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
GRAPH_DIR = Path("graphs_ffr_delta")
graphs = load_graphs(GRAPH_DIR)
# Fail fast: an empty list (wrong working directory, typo in the path) would
# otherwise only surface later, inside train_gnn_baseline.
assert graphs, f"no graph_*.pt files found in {GRAPH_DIR.resolve()}"
print(len(graphs))

647


In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.serialization
from pathlib import Path

from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv, global_mean_pool
from torch_geometric.loader import DataLoader


# ----------------------------------------------------------------------
# SAFE LOADING (PyTorch 2.6+)
# ----------------------------------------------------------------------
# Allowlists HeteroData for restricted unpickling (`torch.load(..., weights_only=True)`).
# NOTE(review): load_graphs below passes weights_only=False, which bypasses this
# allowlist entirely (a full pickle load that can execute code from the file).
# Only load trusted files.
torch.serialization.add_safe_globals([HeteroData])


# ----------------------------------------------------------------------
# LOAD GRAPHS
# ----------------------------------------------------------------------
def load_graphs(graph_dir):
    """Load every ``graph_*.pt`` file in ``graph_dir`` and return them as a list.

    Files are ordered by *natural* filename order, with digit runs compared as
    integers, so ``graph_2.pt`` precedes ``graph_10.pt``. A plain ``sorted()``
    would put ``graph_10.pt`` first. ``train_gnn_baseline`` splits this list
    chronologically by position, so the order matters. For zero-padded names
    the result is identical to lexicographic order.

    Args:
        graph_dir: directory to scan (str or Path).

    Returns:
        List of loaded objects in natural filename order; empty if nothing matches.
    """
    import re

    graph_dir = Path(graph_dir)

    def _natural_key(path):
        # re.split with a capturing group alternates text / digit runs, so odd
        # positions are always digit runs and compare as ints at the same index.
        # The raw name is a tie-breaker (e.g. "graph_01" vs "graph_1").
        parts = re.split(r"(\d+)", path.name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], path.name

    files = sorted(graph_dir.glob("graph_*.pt"), key=_natural_key)
    # SECURITY: weights_only=False performs a full pickle load, which can execute
    # arbitrary code embedded in the file. Only use on trusted, locally produced graphs.
    return [torch.load(f, weights_only=False) for f in files]


# ----------------------------------------------------------------------
# BASELINE HeteroGraphSAGE (yesterday's version)
# ----------------------------------------------------------------------
class HeteroGraphSAGERegressor(nn.Module):
    """Graph-level regressor over heterogeneous graphs.

    Pipeline:
    1. A per-node-type ``nn.Linear`` encoder.
    2. One ``HeteroConv`` layer holding a ``SAGEConv`` per edge type; per-relation
       outputs are summed.
    3. Mean-pooling of the ``"day"`` node embeddings per graph.
    4. A two-layer MLP producing one scalar per pooled row.

    Edge attributes are not used.

    Args:
        metadata: ``(node_types, edge_types)`` pair.
        example_graph: graph consulted only for each node type's input width,
            ``example_graph[nt].x.size(-1)``.
        hidden_dim: width of the encoders, the convolutions and the MLP hidden layer.
    """

    def __init__(self, metadata, example_graph, hidden_dim=64):
        super().__init__()
        node_types, edge_types = metadata

        # Input projections, one per node type (feature widths can differ by type).
        self.node_lin = nn.ModuleDict({
            node_type: nn.Linear(example_graph[node_type].x.size(-1), hidden_dim)
            for node_type in node_types
        })

        # One bipartite SAGEConv per relation; no edge attributes are passed.
        per_relation = {
            edge_type: SAGEConv(
                in_channels=(hidden_dim, hidden_dim),
                out_channels=hidden_dim,
            )
            for edge_type in edge_types
        }
        self.convs = HeteroConv(per_relation, aggr="sum")

        # Regression head: hidden -> hidden -> scalar.
        self.mlp_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, data):
        """Return a flat tensor of predictions, one per pooled ``"day"`` group."""
        encoded = {
            node_type: self.node_lin[node_type](data[node_type].x)
            for node_type in data.node_types
        }
        encoded = self.convs(encoded, data.edge_index_dict)

        day_h = encoded["day"]
        day_store = data["day"]
        # An unbatched graph has no `batch` vector: treat every day node as graph 0.
        if not hasattr(day_store, "batch"):
            graph_index = torch.zeros(day_h.size(0), dtype=torch.long, device=day_h.device)
        else:
            graph_index = day_store.batch

        graph_h = global_mean_pool(day_h, graph_index)
        return self.mlp_out(graph_h).view(-1)


# ----------------------------------------------------------------------
# TRAINING UTILITIES
# ----------------------------------------------------------------------
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the epoch's training MSE.

    Each batch's MSE is weighted by ``batch.num_graphs`` before averaging, so the
    result is a batch-size-weighted mean. This is a per-graph mean if each graph
    carries one target (assumed, not checked here). An empty loader raises
    ZeroDivisionError.
    """
    model.train()
    history = []  # (batch MSE, number of graphs in the batch)

    for data in loader:
        data = data.to(device)
        target = data.y.view(-1).to(device)

        optimizer.zero_grad()
        batch_loss = F.mse_loss(model(data), target)
        batch_loss.backward()
        optimizer.step()

        history.append((batch_loss.item(), data.num_graphs))

    weighted_sum = sum(mse * n_graphs for mse, n_graphs in history)
    graphs_seen = sum(n_graphs for _, n_graphs in history)
    return weighted_sum / graphs_seen


@torch.no_grad()
def eval_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradients; return the weighted MSE.

    Each batch's MSE is weighted by ``batch.num_graphs``, mirroring train_epoch.
    An empty loader raises ZeroDivisionError.
    """
    model.eval()
    per_batch = []  # (batch MSE, number of graphs in the batch)

    for data in loader:
        data = data.to(device)
        target = data.y.view(-1).to(device)
        batch_mse = F.mse_loss(model(data), target).item()
        per_batch.append((batch_mse, data.num_graphs))

    weighted_sum = sum(mse * n_graphs for mse, n_graphs in per_batch)
    graphs_seen = sum(n_graphs for _, n_graphs in per_batch)
    return weighted_sum / graphs_seen


# ----------------------------------------------------------------------
# MAIN TRAINER (yesterday's stable version)
# ----------------------------------------------------------------------
def train_gnn_baseline(
    graphs,
    hidden_dim=64,
    batch_size=16,
    epochs=50,
    train_frac=0.8,
    lr=1e-3,
    seed=None,
):
    """Train a HeteroGraphSAGERegressor with a chronological train/validation split.

    The first ``int(train_frac * len(graphs))`` graphs form the training set and
    the remainder the validation set, so ``graphs`` must already be in time
    order. Before returning, the weights from the epoch with the lowest
    validation MSE are restored.

    Args:
        graphs: sequence of HeteroData graphs. ``graphs[0]`` defines the model's
            node/edge types and per-type input feature widths.
        hidden_dim: hidden width passed to HeteroGraphSAGERegressor.
        batch_size: graphs per mini-batch (training batches are shuffled).
        epochs: number of training epochs.
        train_frac: fraction of graphs, taken from the start, used for training.
        lr: Adam learning rate.
        seed: if not None, passed to ``torch.manual_seed`` before the model and
            loaders are built, which seeds weight init and training shuffling.

    Returns:
        The trained model, loaded with the best-validation weights.

    Raises:
        ValueError: if ``graphs`` is empty or ``train_frac`` leaves the training
            or validation split empty.
    """
    n = len(graphs)
    if n == 0:
        raise ValueError("train_gnn_baseline: `graphs` is empty")
    n_train = int(train_frac * n)
    # Both splits must be non-empty: eval_epoch divides by the number of graphs seen.
    if not 0 < n_train < n:
        raise ValueError(
            f"train_gnn_baseline: train_frac={train_frac} gives {n_train} train / "
            f"{n - n_train} val graphs; both splits must be non-empty"
        )
    if seed is not None:
        torch.manual_seed(seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Chronological split (no shuffling across the boundary).
    train_graphs = graphs[:n_train]
    val_graphs = graphs[n_train:]

    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)

    example = graphs[0]
    metadata = (example.node_types, example.edge_types)

    model = HeteroGraphSAGERegressor(metadata, example, hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print("Training baseline GNN on", device)
    print("Num graphs:", n)
    print("Node types:", example.node_types)
    print("Edge types:", example.edge_types)

    best_val = float("inf")
    best_state = None

    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, device)
        val_loss = eval_epoch(model, val_loader, device)

        print(f"Epoch {epoch:03d} | Train MSE={train_loss:.4f} | Val MSE={val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            # state_dict() returns references to the live parameter tensors, and the
            # optimizer updates those in place. Without a clone, the "best" snapshot
            # would silently track the final weights.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    print("Best Val MSE:", best_val)
    return model


# ----------------------------------------------------------------------
# ENTRYPOINT
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Settings for this baseline run (batch_size=8 overrides the trainer default of 16).
    run_kwargs = {
        "hidden_dim": 64,
        "batch_size": 8,
        "epochs": 50,
        "train_frac": 0.8,
    }
    # Re-read from disk so this cell does not depend on earlier cells' `graphs`.
    graphs = load_graphs("graphs_ffr_delta")
    model = train_gnn_baseline(graphs, **run_kwargs)


Training baseline GNN on cpu
Num graphs: 647
Node types: ['author', 'speech', 'topic', 'day']
Edge types: [('author', 'gives', 'speech'), ('speech', 'mentions', 'topic'), ('day', 'references', 'speech'), ('speech', 'rev_gives', 'author'), ('topic', 'rev_mentions', 'speech'), ('speech', 'rev_references', 'day')]
Epoch 001 | Train MSE=17.7629 | Val MSE=0.0523
Epoch 002 | Train MSE=0.0105 | Val MSE=0.0031
Epoch 003 | Train MSE=0.0098 | Val MSE=0.0168
Epoch 004 | Train MSE=0.0164 | Val MSE=0.0184
Epoch 005 | Train MSE=0.0058 | Val MSE=0.0062
Epoch 006 | Train MSE=0.0044 | Val MSE=0.0029
Epoch 007 | Train MSE=0.0051 | Val MSE=0.0132
Epoch 008 | Train MSE=0.0059 | Val MSE=0.0030
Epoch 009 | Train MSE=0.0036 | Val MSE=0.0027
Epoch 010 | Train MSE=0.1581 | Val MSE=0.0915
Epoch 011 | Train MSE=0.0335 | Val MSE=0.0680
Epoch 012 | Train MSE=0.0120 | Val MSE=0.0048
Epoch 013 | Train MSE=0.0780 | Val MSE=0.0043
Epoch 014 | Train MSE=0.0331 | Val MSE=0.1917
Epoch 015 | Train MSE=0.0104 | Val MSE=0.0