In [4]:
import torch.nn as nn
from torch_geometric.nn import HGTConv

SPEECH_EMB_DIM = 768


class SpeechHeteroGNN(nn.Module):
    def __init__(self, hidden_dim=64, num_heads=2):
        super().__init__()
        self.hidden_dim = hidden_dim

        # --- MLP for speech embeddings ---
        self.speech_mlp = nn.Sequential(
            nn.Linear(SPEECH_EMB_DIM, 256),
            nn.ReLU(),
            nn.Linear(256, hidden_dim),
            nn.ReLU()
        )

        # --- Linear projections for other node types ---
        self.node_lin = nn.ModuleDict({
            "author": nn.Linear(1, hidden_dim),
            "topic": nn.Linear(1, hidden_dim),
            "day": nn.Linear(1, hidden_dim),
        })

        # --- HGTConv ---
        self.hgt = HGTConv(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            metadata=(
                ["author", "speech", "topic", "day"],
                [
                    ("author", "gives", "speech"),
                    ("speech", "rev_gives", "author"),
                    ("speech", "mentions", "topic"),
                    ("topic", "rev_mentions", "speech"),
                    ("day", "references", "speech"),
                    ("speech", "rev_references", "day"),
                ]
            ),
            heads=num_heads
        )

    def forward(self, x_dict, edge_index_dict):

        # ---- Apply projection per node type ----
        out_dict = {}

        for ntype, x in x_dict.items():
            if ntype == "speech":
                out_dict[ntype] = self.speech_mlp(x.float())  # <-- IMPORTANT FIX
            else:
                out_dict[ntype] = self.node_lin[ntype](x.float())

        # ---- HGT message passing ----
        out_dict = self.hgt(out_dict, edge_index_dict)

        return out_dict


class TemporalPredictor(nn.Module):
    def __init__(self, embed_dim=64, hidden_dim=64):
        super().__init__()
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, seq_embeddings):
        out, _ = self.gru(seq_embeddings)
        return self.fc(out[:, -1])  # last timestep


class FedSpeechModel(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.gnn = SpeechHeteroGNN(hidden_dim)
        self.temporal = TemporalPredictor(hidden_dim, hidden_dim)

    def forward(self, graph_seq):
        day_embs = []

        for g in graph_seq:
            x_dict = self.gnn(g.x_dict, g.edge_index_dict)
            z = x_dict["day"]  # shape [1, hidden_dim]
            day_embs.append(z)

        day_embs = torch.stack(day_embs, dim=1)  # [1, seq_len, hidden_dim]
        return self.temporal(day_embs)

In [5]:


import torch
from torch.optim import Adam
import torch.nn as nn
import analysis_utils

from pathlib import Path
from torch_geometric.data import HeteroData

model = FedSpeechModel(hidden_dim=64)
optimizer = Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

torch.serialization.add_safe_globals([HeteroData])

def load_graphs(graph_dir):
    graph_dir = Path(graph_dir)
    graphs = []
    for f in sorted(graph_dir.glob("graph_*.pt")):
        g = torch.load(f, weights_only=False)  # full load
        graphs.append(g)
    return graphs

graphs = load_graphs("graphs_ffr_delta") 
L = 30  # sequence length
EPOCHS = 40  # number of epochs

for epoch in range(1, EPOCHS + 1):
    total_loss = 0.0
    count = 0

    for t in range(L, len(graphs)):
        seq = graphs[t - L:t]
        target = graphs[t].y.float()

        pred = model(seq)
        loss = loss_fn(pred.squeeze(), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1

    print(f"Epoch {epoch}/{EPOCHS}  MSE={total_loss / count:.6f}")


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/40  MSE=0.005178
Epoch 2/40  MSE=0.003572
Epoch 3/40  MSE=0.003483
Epoch 4/40  MSE=0.003532
Epoch 5/40  MSE=0.003548
Epoch 6/40  MSE=0.003547
Epoch 7/40  MSE=0.003507
Epoch 8/40  MSE=0.003522
Epoch 9/40  MSE=0.003526
Epoch 10/40  MSE=0.003525
Epoch 11/40  MSE=0.003518
Epoch 12/40  MSE=0.003573
Epoch 13/40  MSE=0.003585
Epoch 14/40  MSE=0.003539
Epoch 15/40  MSE=0.003519
Epoch 16/40  MSE=0.003520
Epoch 17/40  MSE=0.003512
Epoch 18/40  MSE=0.003514
Epoch 19/40  MSE=0.003522
Epoch 20/40  MSE=0.003531
Epoch 21/40  MSE=0.003510
Epoch 22/40  MSE=0.003517
Epoch 23/40  MSE=0.003519
Epoch 24/40  MSE=0.003517
Epoch 25/40  MSE=0.003504
Epoch 26/40  MSE=0.003511
Epoch 27/40  MSE=0.003506
Epoch 28/40  MSE=0.003493
Epoch 29/40  MSE=0.003509
Epoch 30/40  MSE=0.003511
Epoch 31/40  MSE=0.003525
Epoch 32/40  MSE=0.003504
Epoch 33/40  MSE=0.003511
Epoch 34/40  MSE=0.003595
Epoch 35/40  MSE=0.003550
Epoch 36/40  MSE=0.003513
Epoch 37/40  MSE=0.003716
Epoch 38/40  MSE=0.003494
Epoch 39/40  MSE=0.00