In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.serialization
from pathlib import Path

from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader
from torch_geometric.nn import HGTConv


# ============================================================================
# 0. SAFE LOADING FOR PyTorch 2.6+
# ============================================================================

# Allow-list HeteroData for torch.load's restricted (weights_only=True)
# unpickler.
# NOTE(review): load_graphs() in this file calls torch.load with
# weights_only=False, so this allow-list is presumably not consulted on that
# path — confirm whether it is still needed.
torch.serialization.add_safe_globals([HeteroData])


# ============================================================================
# 1. LOAD GRAPHS
# ============================================================================

def load_graphs(graph_dir):
    """Load every ``graph_*.pt`` file found directly inside *graph_dir*.

    Files are returned in natural (numeric-aware) filename order, so
    ``graph_2.pt`` comes before ``graph_10.pt``. Plain lexicographic sorting
    would put ``graph_10.pt`` first, which silently scrambles any
    order-dependent use of the result (e.g. a positional train/val split).
    Consistently zero-padded names keep the order they had before.

    Args:
        graph_dir: Directory (``str`` or ``Path``) holding the graph files.

    Returns:
        list: One object per matching file, as returned by ``torch.load``.
        Empty when no file matches.
    """
    import re  # function-scope import: only the sort key below needs it

    def _natural_key(path):
        # re.split with a capturing group alternates text and digit runs, so
        # odd positions are always digit runs and compare as integers.
        pieces = re.split(r"([0-9]+)", path.name)
        numeric_aware = [int(p) if i % 2 else p for i, p in enumerate(pieces)]
        # The raw name breaks ties such as "graph_1.pt" vs "graph_01.pt".
        return numeric_aware, path.name

    graph_dir = Path(graph_dir)
    files = sorted(graph_dir.glob("graph_*.pt"), key=_natural_key)

    # SECURITY NOTE: weights_only=False runs the full pickle machinery and can
    # execute arbitrary code. Only point this at graph files you created
    # yourself or otherwise trust.
    return [torch.load(f, weights_only=False) for f in files]



# ============================================================================
# 2. MODEL COMPONENT: SpeechHeteroGNN (HGTConv GNN)
# ============================================================================

# Width of each "speech" node feature vector; used as the input size of the
# speech projection MLP in SpeechHeteroGNN.
SPEECH_EMB_DIM = 768


class SpeechHeteroGNN(nn.Module):
    """Per-node-type input projections followed by a single HGTConv layer.

    "speech" features go through a two-layer MLP; "author", "topic" and "day"
    features (one value per node) each go through their own linear layer.
    All node types end up ``hidden_dim`` wide before message passing.

    Args:
        hidden_dim: Width of the projected features and of the HGTConv output.
        num_heads: Number of attention heads passed to HGTConv.
    """

    def __init__(self, hidden_dim=64, num_heads=2):
        super().__init__()

        # Speech branch: SPEECH_EMB_DIM -> 256 -> hidden_dim, ReLU after each.
        speech_layers = [
            nn.Linear(SPEECH_EMB_DIM, 256),
            nn.ReLU(),
            nn.Linear(256, hidden_dim),
            nn.ReLU(),
        ]
        self.speech_mlp = nn.Sequential(*speech_layers)

        # Scalar-feature node types: one Linear(1 -> hidden_dim) apiece.
        self.node_lin = nn.ModuleDict(
            {ntype: nn.Linear(1, hidden_dim) for ntype in ("author", "topic", "day")}
        )

        # Graph schema handed to HGTConv: node types, then the
        # (source, relation, destination) triples including reverse relations.
        node_types = ["author", "speech", "topic", "day"]
        edge_types = [
            ("author", "gives", "speech"),
            ("speech", "rev_gives", "author"),
            ("speech", "mentions", "topic"),
            ("topic", "rev_mentions", "speech"),
            ("day", "references", "speech"),
            ("speech", "rev_references", "day"),
        ]
        self.hgt = HGTConv(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            metadata=(node_types, edge_types),
            heads=num_heads,
        )

    def forward(self, g: HeteroData):
        """Project every node type in *g* and run one round of HGTConv.

        Args:
            g: Graph whose ``x_dict`` and ``edge_index_dict`` are consumed.

        Returns:
            The dictionary produced by the HGTConv layer.
        """
        projected = {}
        for ntype, feats in g.x_dict.items():
            # Speech uses the MLP; any other type must have an entry in node_lin.
            encoder = self.speech_mlp if ntype == "speech" else self.node_lin[ntype]
            projected[ntype] = encoder(feats.float())

        return self.hgt(projected, g.edge_index_dict)



# ============================================================================
# 3. FINAL MODEL: Predict from "day" node embedding
# ============================================================================

class FedSpeechModel(nn.Module):
    """Regression head on top of SpeechHeteroGNN that reads only "day" nodes.

    Args:
        hidden_dim: Embedding width shared by the encoder and the linear head.
    """

    def __init__(self, hidden_dim=64):
        super().__init__()
        # Encoder is built with its default number of attention heads.
        self.gnn = SpeechHeteroGNN(hidden_dim)
        # One output value per "day" node (the original author's note says
        # this predicts Rate_Change directly).
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, g: HeteroData):
        """Return the head's prediction for the "day" node(s) of *g*.

        Args:
            g: Input graph (or batch of graphs) passed to the encoder.

        Returns:
            The linear head's output with all size-1 dimensions removed.
        """
        day_embeddings = self.gnn(g)["day"]
        scores = self.fc(day_embeddings)
        # squeeze() drops every size-1 dimension: N day nodes give shape [N],
        # but a single day node gives a 0-dim tensor rather than shape [1].
        return scores.squeeze()



# ============================================================================
# 4. TRAINING UTILITIES
# ============================================================================

def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over *loader* and return the epoch's MSE.

    Predictions and targets are flattened before the loss is computed, so a
    ``[B]`` prediction against a ``[B, 1]`` target (or a 0-dim prediction
    against a ``[1]`` target) is compared element-wise instead of being
    silently broadcast by ``F.mse_loss``. The returned value is the mean over
    all targets seen, so a short final batch is not over-weighted.

    Args:
        model: Module called as ``model(g)`` for each batch.
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        float: Mean squared error over every target in the epoch, or ``nan``
        when *loader* yields no targets.
    """
    model.train()
    sq_err_sum = 0.0
    n_targets = 0

    for g in loader:
        g = g.to(device)
        target = g.y.float().to(device).reshape(-1)

        pred = model(g).reshape(-1)
        loss = F.mse_loss(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a per-batch mean; rescale to a sum so the epoch figure is a
        # true per-target average.
        batch_targets = target.numel()
        sq_err_sum += loss.item() * batch_targets
        n_targets += batch_targets

    if n_targets == 0:
        # Nothing to average; report "no value" rather than divide by zero.
        return float("nan")
    return sq_err_sum / n_targets


@torch.no_grad()
def eval_epoch(model, loader, device):
    """Evaluate *model* over *loader* without gradients and return the MSE.

    Predictions and targets are flattened before the loss is computed, so a
    ``[B]`` prediction against a ``[B, 1]`` target (or a 0-dim prediction
    against a ``[1]`` target) is compared element-wise instead of being
    silently broadcast by ``F.mse_loss``. The returned value is the mean over
    all targets seen, so a short final batch is not over-weighted.

    Args:
        model: Module called as ``model(g)`` for each batch.
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        device: Device the batches are moved to.

    Returns:
        float: Mean squared error over every target in *loader*, or ``nan``
        when *loader* yields no targets.
    """
    model.eval()
    sq_err_sum = 0.0
    n_targets = 0

    for g in loader:
        g = g.to(device)
        target = g.y.float().to(device).reshape(-1)

        pred = model(g).reshape(-1)
        loss = F.mse_loss(pred, target)

        # loss is a per-batch mean; rescale to a sum so the final figure is a
        # true per-target average.
        batch_targets = target.numel()
        sq_err_sum += loss.item() * batch_targets
        n_targets += batch_targets

    if n_targets == 0:
        # Nothing to average; report "no value" rather than divide by zero.
        return float("nan")
    return sq_err_sum / n_targets



# ============================================================================
# 5. MAIN TRAINING FUNCTION
# ============================================================================

def train_fedspeech_hgt(
    graphs,
    hidden_dim=64,
    batch_size=8,
    epochs=40,
    train_frac=0.8,
):
    """Train a FedSpeechModel on *graphs* and return it with its best weights.

    Graphs without any "speech" node are dropped. The remaining list is split
    positionally: the first ``train_frac`` share is used for training and the
    rest for validation, so *graphs* must already be in the intended
    (chronological) order. After training, the weights from the epoch with
    the lowest validation MSE are restored.

    Args:
        graphs: Sequence of HeteroData graphs, each with a "speech" node type.
        hidden_dim: Hidden width passed to FedSpeechModel.
        batch_size: Batch size for both data loaders.
        epochs: Number of training epochs.
        train_frac: Fraction of the (filtered) graphs used for training.

    Returns:
        FedSpeechModel: The trained model, on CUDA when available, else CPU.

    Raises:
        ValueError: If the training split is empty after filtering.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Remove graphs with 0 speech nodes (empty windows)
    graphs = [g for g in graphs if g["speech"].x.size(0) > 0]

    # Chronological split: relies purely on the order of the input list.
    n = len(graphs)
    n_train = int(train_frac * n)
    train_graphs = graphs[:n_train]
    val_graphs = graphs[n_train:]

    if not train_graphs:
        raise ValueError(
            f"No training graphs: {n} usable graph(s) with train_frac={train_frac}"
        )

    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)

    # Build model
    example = graphs[0]
    print("Speech embedding dimension:", example["speech"].x.shape[-1])

    model = FedSpeechModel(hidden_dim=hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    print(f"Training on {device}")
    print(f"Num graphs = {n}, Train = {len(train_graphs)}, Val = {len(val_graphs)}\n")

    best_val = float("inf")
    best_state = None

    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, device)
        val_loss = eval_epoch(model, val_loader, device)

        print(f"Epoch {epoch:03d} | Train MSE={train_loss:.4f} | Val MSE={val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            # state_dict() hands back references to the live parameters, which
            # later optimiser steps overwrite in place. Clone them so this
            # really is a snapshot of the best epoch.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

    # Load best checkpoint
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"\nBest Val MSE = {best_val:.4f}")

    return model



# ============================================================================
# 6. SCRIPT ENTRYPOINT
# ============================================================================

if __name__ == "__main__":
    # Hyper-parameters for this run, gathered in one place.
    _run_kwargs = {
        "hidden_dim": 64,
        "batch_size": 8,
        "epochs": 40,
        "train_frac": 0.8,
    }

    # Read the saved graphs from disk, then train on them.
    graphs = load_graphs("graphs_ffr_delta")
    model = train_fedspeech_hgt(graphs, **_run_kwargs)


  from .autonotebook import tqdm as notebook_tqdm


Speech embedding dimension: 768
Training on cpu
Num graphs = 647, Train = 517, Val = 130

Epoch 001 | Train MSE=14.2220 | Val MSE=0.1198
Epoch 002 | Train MSE=0.0235 | Val MSE=0.0598
Epoch 003 | Train MSE=0.0092 | Val MSE=0.0435
Epoch 004 | Train MSE=0.0079 | Val MSE=0.0056
Epoch 005 | Train MSE=0.0039 | Val MSE=0.0047
Epoch 006 | Train MSE=0.0029 | Val MSE=0.0033
Epoch 007 | Train MSE=0.0024 | Val MSE=0.0052
Epoch 008 | Train MSE=0.0026 | Val MSE=0.0032
Epoch 009 | Train MSE=0.0026 | Val MSE=0.0048
Epoch 010 | Train MSE=0.0025 | Val MSE=0.0022
Epoch 011 | Train MSE=0.0026 | Val MSE=0.0028
Epoch 012 | Train MSE=0.0029 | Val MSE=0.0025
Epoch 013 | Train MSE=0.0023 | Val MSE=0.0025
Epoch 014 | Train MSE=0.0026 | Val MSE=0.0045
Epoch 015 | Train MSE=0.0030 | Val MSE=0.0174
Epoch 016 | Train MSE=0.0026 | Val MSE=0.0069
Epoch 017 | Train MSE=0.0022 | Val MSE=0.0052
Epoch 018 | Train MSE=0.0030 | Val MSE=0.0079
Epoch 019 | Train MSE=0.0032 | Val MSE=0.0041
Epoch 020 | Train MSE=0.0048 | Val 