In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.serialization
from pathlib import Path

from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader
from torch_geometric.nn import HeteroConv, SAGEConv, global_mean_pool

# ----------------------------------------------------------------------
# 0. SAFE LOADING FOR PyTorch 2.6+
# ----------------------------------------------------------------------
# Register HeteroData with torch's safe-globals allow-list, which is what
# torch.load consults when unpickling with weights_only=True.
# NOTE(review): load_graphs() in this file calls torch.load with
# weights_only=False, which presumably bypasses this allow-list — confirm
# whether the registration is still required.
torch.serialization.add_safe_globals([HeteroData])


# ----------------------------------------------------------------------
# 1. LOAD GRAPHS
# ----------------------------------------------------------------------
def load_graphs(graph_dir):
    """
    Load every ``graph_*.pt`` file located directly inside ``graph_dir``.

    Args:
        graph_dir: Directory to scan (anything ``pathlib.Path`` accepts).

    Returns:
        A list of the loaded objects, ordered by sorted file path.

    NOTE: ``weights_only=False`` makes ``torch.load`` perform full
    unpickling, so this should only be pointed at trusted files.
    """
    paths = sorted(Path(graph_dir).glob("graph_*.pt"))
    return [torch.load(path, weights_only=False) for path in paths]


# ----------------------------------------------------------------------
# 2. AGGREGATE EDGE ATTRIBUTES -> NODE FEATURES (STABLE)
# ----------------------------------------------------------------------
def _edge_store_with_attr(data, edge_type):
    """Return the edge store for ``edge_type`` if it carries ``edge_attr``, else None."""
    if edge_type not in data.edge_types:
        return None
    store = data[edge_type]
    if getattr(store, "edge_attr", None) is None:
        return None
    return store


def _mean_by_index(values, index, num_rows):
    """Average the rows of 2-D ``values`` that share an ``index``; unindexed rows stay 0."""
    totals = values.new_zeros(num_rows, values.size(-1))
    counts = values.new_zeros(num_rows, 1)
    totals.index_add_(0, index, values)
    counts.index_add_(0, index, values.new_ones(values.size(0), 1))
    # Clamp so rows that received no edges divide by 1 instead of 0.
    return totals / counts.clamp(min=1.0)


def compute_speech_edge_features(data: HeteroData, device=None):
    """
    Aggregate edge attributes into per-speech node features.

    Returns a tensor [num_speech, 3] with:
       [ mean_topic_score, mean_lag, mean_decay ]
    computed from:
       - ('speech','mentions','topic').edge_attr = [[score]]
         (averaged over the source speech of each edge)
       - ('day','references','speech').edge_attr = [[lag, decay]]
         (averaged over the destination speech of each edge)

    Speech nodes without any edge of a given type, and edge types that are
    missing or have no ``edge_attr``, contribute zeros for those columns.

    Args:
        data: Graph exposing ``data["speech"].x``, ``data.edge_types`` and
            the two edge stores listed above.
        device: Device for the result; defaults to the device of
            ``data["speech"].x``.

    Returns:
        Tensor of shape [num_speech, 3] in torch's default floating dtype.
    """
    if device is None:
        device = data["speech"].x.device

    num_speech = data["speech"].x.size(0)
    if num_speech == 0:
        return torch.zeros((0, 3), device=device)

    # Accumulate in the default dtype; edge attributes are cast to it so
    # float64 / integer attributes do not make index_add_ fail.
    out_dtype = torch.get_default_dtype()

    # --- 1) mean topic score per speech (from speech -> topic edges) ---
    score_mean = torch.zeros(num_speech, 1, device=device)
    store = _edge_store_with_attr(data, ("speech", "mentions", "topic"))
    if store is not None:
        scores = store.edge_attr.reshape(-1, 1).to(device=device, dtype=out_dtype)
        # Index tensors must live on the same device as the accumulators.
        src = store.edge_index[0].to(device)  # speech indices
        score_mean = _mean_by_index(scores, src, num_speech)

    # --- 2) mean lag & decay per speech (from day -> speech edges) ---
    lag_decay_mean = torch.zeros(num_speech, 2, device=device)
    store = _edge_store_with_attr(data, ("day", "references", "speech"))
    if store is not None:
        attrs = store.edge_attr.to(device=device, dtype=out_dtype)  # [E, 2]
        dst = store.edge_index[1].to(device)  # speech indices
        lag_decay_mean = _mean_by_index(attrs, dst, num_speech)

    # Concatenate: [score_mean, lag, decay] -> [num_speech, 3]
    return torch.cat([score_mean, lag_decay_mean], dim=-1)


# ----------------------------------------------------------------------
# 3. GraphSAGE model with edge-aggregate features
# ----------------------------------------------------------------------
class GraphSAGEEdgeAggRegressor(nn.Module):
    """
    Heterogeneous GraphSAGE regressor producing one scalar per graph.

    Pipeline: per-type linear projection of node features (speech nodes
    additionally get 3 edge-aggregated features appended) -> one HeteroConv
    layer of SAGEConv operators -> mean-pool of the "day" node embeddings
    per graph -> 2-layer MLP.
    """

    def __init__(self, example_graph: HeteroData, hidden_dim=64):
        """
        Args:
            example_graph: Graph used to read the node types, edge types and
                per-type input feature width (``x.size(-1)``).
            hidden_dim: Width of every hidden representation.
        """
        super().__init__()

        self.node_types = example_graph.node_types
        self.edge_types = example_graph.edge_types

        # Number of extra input columns appended per node type; only speech
        # nodes receive them (mean_topic_score, mean_lag, mean_decay).
        self.extra_dims = {"speech": 3, "author": 0, "topic": 0, "day": 0}

        # 1) Input projections: [x || extra] -> hidden_dim, one per node type.
        self.node_lin = nn.ModuleDict(
            {
                ntype: nn.Linear(
                    example_graph[ntype].x.size(-1) + self.extra_dims.get(ntype, 0),
                    hidden_dim,
                )
                for ntype in self.node_types
            }
        )

        # 2) One SAGEConv per edge type; results for a shared destination
        #    type are summed. Edge attributes are not passed to the convs.
        self.convs = HeteroConv(
            {
                etype: SAGEConv(
                    in_channels=(hidden_dim, hidden_dim),
                    out_channels=hidden_dim,
                )
                for etype in self.edge_types
            },
            aggr="sum",
        )

        # 3) Regression head applied to the pooled "day" embedding.
        self.mlp_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, data: HeteroData):
        """
        Predict one value per graph contained in ``data``.

        Returns:
            1-D tensor with one entry per graph (``[num_graphs]``).
        """
        device = data["day"].x.device

        # Edge-derived extras for speech nodes: [Ns, 3] (or [0, 3]).
        speech_extra = compute_speech_edge_features(data, device=device)

        # Project every node type into the shared hidden space.
        x_dict = {}
        for ntype in self.node_types:
            feats = data[ntype].x.to(device)
            if ntype == "speech":
                if feats.size(0) == speech_extra.size(0):
                    extra = speech_extra
                else:
                    # Safety fallback; row counts are not expected to differ.
                    extra = torch.zeros(
                        feats.size(0), self.extra_dims["speech"], device=device
                    )
                feats = torch.cat([feats, extra], dim=-1)
            x_dict[ntype] = self.node_lin[ntype](feats)

        # Heterogeneous GraphSAGE message passing.
        x_dict = self.convs(x_dict, data.edge_index_dict)

        # The "day" node embedding serves as the graph representation.
        day_x = x_dict["day"]
        day_store = data["day"]
        day_batch = (
            day_store.batch
            if hasattr(day_store, "batch")
            # No batch vector: treat all day nodes as belonging to graph 0.
            else torch.zeros(day_x.size(0), dtype=torch.long, device=device)
        )

        pooled = global_mean_pool(day_x, day_batch)  # [num_graphs, hidden_dim]
        return self.mlp_out(pooled).view(-1)         # [num_graphs]


# ----------------------------------------------------------------------
# 4. TRAIN / EVAL LOOPS
# ----------------------------------------------------------------------
def train_epoch(model, loader, optimizer, device):
    """
    Run one optimisation pass over ``loader`` and return the mean MSE.

    Args:
        model: Callable module mapping a batch to one prediction per graph.
        loader: Iterable of batches exposing ``.to(device)``, ``.y`` and
            ``.num_graphs``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches and targets are moved to.

    Returns:
        MSE averaged over all graphs seen (each batch's loss is weighted by
        its ``num_graphs``), or NaN when ``loader`` yields no graphs.
    """
    model.train()
    total_loss = 0.0
    n = 0

    for batch in loader:
        batch = batch.to(device)
        target = batch.y.view(-1).to(device)

        optimizer.zero_grad()
        pred = model(batch)
        loss = F.mse_loss(pred, target)
        loss.backward()
        optimizer.step()

        # Batch loss is a per-graph mean; re-weight so the epoch value is a
        # true per-graph average even when the last batch is smaller.
        total_loss += loss.item() * batch.num_graphs
        n += batch.num_graphs

    if n == 0:
        # Empty loader: report "no measurement" instead of dividing by zero.
        return float("nan")
    return total_loss / n


@torch.no_grad()
def eval_epoch(model, loader, device):
    """
    Evaluate ``model`` over ``loader`` without gradients; return the mean MSE.

    Args:
        model: Callable module mapping a batch to one prediction per graph.
        loader: Iterable of batches exposing ``.to(device)``, ``.y`` and
            ``.num_graphs``.
        device: Device the batches and targets are moved to.

    Returns:
        MSE averaged over all graphs seen (each batch's loss is weighted by
        its ``num_graphs``), or NaN when ``loader`` yields no graphs, e.g.
        for an empty validation split.
    """
    model.eval()
    total_loss = 0.0
    n = 0

    for batch in loader:
        batch = batch.to(device)
        target = batch.y.view(-1).to(device)

        pred = model(batch)
        loss = F.mse_loss(pred, target)

        # Weight by graph count so a smaller final batch is not over-counted.
        total_loss += loss.item() * batch.num_graphs
        n += batch.num_graphs

    if n == 0:
        # Empty loader: report "no measurement" instead of dividing by zero.
        return float("nan")
    return total_loss / n


# ----------------------------------------------------------------------
# 5. MAIN TRAINING FUNCTION
# ----------------------------------------------------------------------
def train_graphsage_edgeagg(
    graphs,
    hidden_dim=64,
    batch_size=8,
    epochs=50,
    train_frac=0.8,
):
    """
    Train a GraphSAGEEdgeAggRegressor and return it with its best weights.

    Graphs without speech nodes are dropped; the remainder is split in list
    order into the first ``train_frac`` share (training, shuffled per epoch)
    and the rest (validation). After training, the weights from the epoch
    with the lowest validation MSE are restored.

    Args:
        graphs: Sequence of graphs, expected to be in chronological order.
        hidden_dim: Hidden width passed to the model.
        batch_size: Graphs per batch for both loaders.
        epochs: Number of training epochs.
        train_frac: Fraction of graphs used for training.

    Returns:
        The trained model, loaded with the best-validation weights when at
        least one epoch improved on the initial ``inf``.
    """

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Filter out graphs with no speech nodes (no information)
    graphs = [g for g in graphs if g["speech"].x.size(0) > 0]

    # Chronological split (graphs already sorted by date)
    n = len(graphs)
    n_train = int(train_frac * n)
    train_graphs = graphs[:n_train]
    val_graphs = graphs[n_train:]

    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)

    example = graphs[0]
    model = GraphSAGEEdgeAggRegressor(example, hidden_dim=hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print("Training on:", device)
    print("Num graphs:", n, "| Train:", len(train_graphs), "| Val:", len(val_graphs))
    print("Node types:", example.node_types)
    print("Edge types:", example.edge_types)

    best_val = float("inf")
    best_state = None

    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, device)
        val_loss = eval_epoch(model, val_loader, device)
        print(f"Epoch {epoch:03d} | Train MSE={train_loss:.4f} | Val MSE={val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            # state_dict() hands back references to the live parameter
            # tensors, which the optimizer keeps updating in place. Clone
            # them so the snapshot really holds this epoch's weights.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

    if best_state is not None:
        model.load_state_dict(best_state)
        print("Best Val MSE:", best_val)

    return model


# ----------------------------------------------------------------------
# 6. ENTRYPOINT
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Load every graph_*.pt file from the "graphs_ffr_delta" directory.
    graphs = load_graphs("graphs_ffr_delta")

    # Train with the hyperparameters spelled out explicitly (they match the
    # function's defaults); the trained model is bound to `model`.
    model = train_graphsage_edgeagg(
        graphs,
        hidden_dim=64,
        batch_size=8,
        epochs=50,
        train_frac=0.8,
    )


Training on: cpu
Num graphs: 647 | Train: 517 | Val: 130
Node types: ['author', 'speech', 'topic', 'day']
Edge types: [('author', 'gives', 'speech'), ('speech', 'mentions', 'topic'), ('day', 'references', 'speech'), ('speech', 'rev_gives', 'author'), ('topic', 'rev_mentions', 'speech'), ('speech', 'rev_references', 'day')]
Epoch 001 | Train MSE=15.4588 | Val MSE=0.3891
Epoch 002 | Train MSE=0.0334 | Val MSE=0.0024
Epoch 003 | Train MSE=0.0039 | Val MSE=0.0167
Epoch 004 | Train MSE=0.0033 | Val MSE=0.0033
Epoch 005 | Train MSE=0.0031 | Val MSE=0.0140
Epoch 006 | Train MSE=0.0031 | Val MSE=0.0039
Epoch 007 | Train MSE=0.0031 | Val MSE=0.0073
Epoch 008 | Train MSE=0.0040 | Val MSE=0.0062
Epoch 009 | Train MSE=0.0033 | Val MSE=0.0125
Epoch 010 | Train MSE=0.0042 | Val MSE=0.0055
Epoch 011 | Train MSE=0.0031 | Val MSE=0.0140
Epoch 012 | Train MSE=0.0033 | Val MSE=0.0062
Epoch 013 | Train MSE=0.0028 | Val MSE=0.0024
Epoch 014 | Train MSE=0.0029 | Val MSE=0.0062
Epoch 015 | Train MSE=0.0031 |