In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv
from torch_geometric.data import HeteroData
import numpy as np
from pathlib import Path

# ==========================================
# 1. CORRECTED GNN MODEL
# ==========================================
class SpeechHeteroGNN(nn.Module):
    def __init__(self, global_idx, hidden_dim=64, num_heads=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        
        # --- A. Embeddings for ID-based nodes ---
        # We add +1 to max index to be safe, or len(idx) if indices are 0..N-1
        num_authors = len(global_idx["author2idx"]) + 1
        num_topics = len(global_idx["topic2idx"]) + 1
        num_days = len(global_idx["date2idx"]) + 1

        self.emb_dict = nn.ModuleDict({
            "author": nn.Embedding(num_authors, hidden_dim),
            "topic":  nn.Embedding(num_topics, hidden_dim),
            "day":    nn.Embedding(num_days, hidden_dim),
        })

        # --- B. Projection for Speech Nodes ---
        # Input: 768 (BERT) + 1 (Topic Score) + 2 (Lag, Decay) = 771
        self.speech_lin = nn.Linear(768 + 3, hidden_dim)

        # --- C. HGT Convolution ---
        # Metadata must match the edge types in the graph exactly
        self.hgt = HGTConv(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            metadata=(
                ["author", "speech", "topic", "day"],
                [
                    ("author", "gives", "speech"),
                    ("speech", "rev_gives", "author"),
                    ("speech", "mentions", "topic"),
                    ("topic", "rev_mentions", "speech"),
                    ("day", "references", "speech"),
                    ("speech", "rev_references", "day"),
                ]
            ),
            heads=num_heads
        )

    def compute_edge_aggregates(self, data, device):
        """
        Manually aggregates edge attributes into speech nodes because 
        standard HGTConv does not use edge_attr.
        """
        num_speech = data["speech"].x.size(0)
        # [mean_topic_score, mean_lag, mean_decay]
        extras = torch.zeros(num_speech, 3, device=device)
        counts = torch.zeros(num_speech, 3, device=device)

        # 1. Topic Scores (speech -> mentions -> topic)
        if ("speech", "mentions", "topic") in data.edge_types:
            store = data["speech", "mentions", "topic"]
            if hasattr(store, "edge_attr") and store.edge_attr is not None:
                # edge_attr is [score]
                src_idx = store.edge_index[0]
                scores = store.edge_attr.view(-1).to(device)
                
                # Accumulate
                extras[:, 0].index_add_(0, src_idx, scores)
                counts[:, 0].index_add_(0, src_idx, torch.ones_like(scores))

        # 2. Time Lags (day -> references -> speech)
        if ("day", "references", "speech") in data.edge_types:
            store = data["day", "references", "speech"]
            if hasattr(store, "edge_attr") and store.edge_attr is not None:
                # edge_attr is [lag, decay]
                dst_idx = store.edge_index[1] # speech is destination here
                attrs = store.edge_attr.to(device)
                
                # Accumulate
                extras[:, 1:].index_add_(0, dst_idx, attrs)
                counts[:, 1:].index_add_(0, dst_idx, torch.ones_like(attrs))

        # Average
        counts = torch.clamp(counts, min=1.0)
        return extras / counts

    def forward(self, data):
        x_dict = data.x_dict
        edge_index_dict = data.edge_index_dict
        device = data["speech"].x.device
        
        # --- Step 1: Embeddings for ID nodes ---
        out_dict = {}
        for ntype in ["author", "topic", "day"]:
            # Flatten/squeeze input indices: [N, 1] -> [N]
            idx_tensor = x_dict[ntype].long().view(-1)
            out_dict[ntype] = self.emb_dict[ntype](idx_tensor)

        # --- Step 2: Speech Features (Emb + Edges) ---
        speech_base = x_dict["speech"] # [N, 768]
        edge_stats = self.compute_edge_aggregates(data, device) # [N, 3]
        
        # Concatenate and Project
        speech_full = torch.cat([speech_base, edge_stats], dim=-1)
        out_dict["speech"] = self.speech_lin(speech_full)

        # --- Step 3: Graph Convolution ---
        out_dict = self.hgt(out_dict, edge_index_dict)

        return out_dict

# ==========================================
# 2. TEMPORAL PREDICTOR WRAPPER
# ==========================================
class TemporalPredictor(nn.Module):
    def __init__(self, input_dim=64, hidden_dim=64):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
    
    def forward(self, seq_embeddings):
        # seq_embeddings: [Batch=1, Seq_Len, Hidden_Dim]
        out, _ = self.gru(seq_embeddings)
        return self.fc(out[:, -1])  # Predict from last step

class FedSpeechModel(nn.Module):
    def __init__(self, global_idx, hidden_dim=64):
        super().__init__()
        self.gnn = SpeechHeteroGNN(global_idx, hidden_dim)
        self.temporal = TemporalPredictor(hidden_dim, hidden_dim)
    
    def forward(self, graph_seq):
        day_embs = []
        for g in graph_seq:
            # Pass the whole graph object 'g' to access edge_attrs
            x_dict = self.gnn(g)
            
            # We use the 'day' node embedding as the graph representation
            z = x_dict["day"] # [1, hidden_dim]
            day_embs.append(z)
        
        # Stack time steps: [1, Seq_Len, Hidden_Dim]
        day_embs = torch.stack(day_embs, dim=1)
        return self.temporal(day_embs)

# ==========================================
# 3. CORRECTED GRAPH BUILDER FUNCTION
# ==========================================
def build_graph_for_date_corrected(
    d, speeches, speeches_with_embeddings, topic_scores, rates_df,
    speeches_by_date, global_idx, lookback_days, target_column
):
    """
    Re-implementation of your notebook function, but ensures
    ID-based nodes (Author, Topic, Day) use torch.long for Embeddings.
    """
    from datetime import timedelta
    
    date2idx = global_idx["date2idx"]
    author2idx = global_idx["author2idx"]
    topic2idx = global_idx["topic2idx"]

    # --- Helper to get window ---
    start_date = d - timedelta(days=lookback_days - 1)
    speech_ids_window = []
    cur = start_date
    while cur <= d:
        if cur in speeches_by_date:
            speech_ids_window.extend(speeches_by_date[cur])
        cur += timedelta(days=1)
    
    # Local mappings
    local_sids = sorted(set(speech_ids_window))
    sid2i = {sid: i for i, sid in enumerate(local_sids)}
    
    # Identify active Authors/Topics
    active_authors = sorted({speeches[sid]["author"] for sid in local_sids})
    active_topics = set()
    for sid in local_sids:
        active_topics |= set(topic_scores[sid].keys())
    active_topics = sorted(active_topics)
    
    topic_name2i = {t: i for i, t in enumerate(active_topics)}
    author_name2i = {a: i for i, a in enumerate(active_authors)}

    data = HeteroData()

    # --- NODE FEATURES (CORRECTED DTYPES) ---
    
    # 1. Author (Long)
    author_ids = [author2idx.get(a, 0) for a in active_authors]
    data["author"].x = torch.tensor(author_ids, dtype=torch.long).view(-1, 1)

    # 2. Topic (Long)
    topic_ids = [topic2idx.get(t, 0) for t in active_topics]
    data["topic"].x = torch.tensor(topic_ids, dtype=torch.long).view(-1, 1)

    # 3. Day (Long)
    day_id = date2idx.get(d, 0)
    data["day"].x = torch.tensor([[day_id]], dtype=torch.long)

    # 4. Speech (Float - Embeddings)
    speech_feats = []
    for sid in local_sids:
        speech_feats.append(speeches_with_embeddings[sid]["embedding"])
    
    if not speech_feats: # Handle empty graph case if necessary
        return None
        
    data["speech"].x = torch.tensor(speech_feats, dtype=torch.float32)

    # --- EDGE INDICES & ATTRIBUTES (Preserved from your notebook) ---
    
    # Author -> Speech
    a_src, s_dst = [], []
    for sid in local_sids:
        a_name = speeches[sid]["author"]
        a_src.append(author_name2i[a_name])
        s_dst.append(sid2i[sid])
    
    data["author", "gives", "speech"].edge_index = torch.tensor([a_src, s_dst], dtype=torch.long)
    data["speech", "rev_gives", "author"].edge_index = torch.tensor([s_dst, a_src], dtype=torch.long)

    # Speech -> Topic (with Scores)
    s_src, t_dst, scores = [], [], []
    for sid in local_sids:
        for tname, sc in topic_scores[sid].items():
            s_src.append(sid2i[sid])
            t_dst.append(topic_name2i[tname])
            scores.append([float(sc)])
            
    data["speech", "mentions", "topic"].edge_index = torch.tensor([s_src, t_dst], dtype=torch.long)
    data["speech", "mentions", "topic"].edge_attr = torch.tensor(scores, dtype=torch.float32)
    # Reverse edge (no attr needed for HGT usually, but good for connectivity)
    data["topic", "rev_mentions", "speech"].edge_index = torch.tensor([t_dst, s_src], dtype=torch.long)

    # Day -> Speech (with Lag/Decay)
    d_src, s_dst_day, lags = [], [], []
    for sid in local_sids:
        lag = (d - speeches[sid]["date"]).days
        decay = np.exp(-lag / 10.0)
        d_src.append(0)
        s_dst_day.append(sid2i[sid])
        lags.append([lag, decay])

    data["day", "references", "speech"].edge_index = torch.tensor([d_src, s_dst_day], dtype=torch.long)
    data["day", "references", "speech"].edge_attr = torch.tensor(lags, dtype=torch.float32)
    data["speech", "rev_references", "day"].edge_index = torch.tensor([s_dst_day, d_src], dtype=torch.long)

    # Target
    y = float(rates_df.loc[d, target_column])
    data.y = torch.tensor([y], dtype=torch.float32)
    
    return data

# ============================================================================
# 0. SAFE LOADING FOR PyTorch 2.6+
# ============================================================================

torch.serialization.add_safe_globals([HeteroData])


# ============================================================================
# 1. LOAD GRAPHS
# ============================================================================

def load_graphs(graph_dir):
    graph_dir = Path(graph_dir)
    graphs = []
    for f in sorted(graph_dir.glob("graph_*.pt")):
        g = torch.load(f, weights_only=False)
        graphs.append(g)
    return graphs

# ==========================================
# 4. MAIN EXECUTION (Mock)
# ==========================================
if __name__ == "__main__":
    # ASSUMPTION: You have loaded 'speeches', 'rates_df', etc. using analysis_utils
    # and generated 'global_idx' as per your notebook.
    
    # 1. Build Graphs (using corrected function)
    # graphs = [build_graph_for_date_corrected(d, ...) for d in dates]
    
    # 2. Initialize Model
    # We need global_idx from your notebook
    # mock_global_idx provided for instantiation example:
    mock_global_idx = {
        "author2idx": {"A":0, "B":1}, 
        "topic2idx": {"Eco":0}, 
        "date2idx": {0:0, 1:1}
    }
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = FedSpeechModel(global_idx=mock_global_idx, hidden_dim=64).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    print("Model initialized on:", device)
    
    # 3. Training Loop
    # Assuming 'graphs' is a list of HeteroData objects
    # L = Sequence Length for LSTM/GRU
    L = 30 
    EPOCHS = 50
    
    # Placeholder for graphs list to prevent NameError in this script
    #raphs = [] 
    graphs = load_graphs("graphs_ffr_delta")


    if len(graphs) > L:
        model.train()
        for epoch in range(1, EPOCHS + 1):
            total_loss = 0.0
            count = 0
            
            # Simple sliding window training
            for t in range(L, len(graphs)):
                # Get sequence of L graphs
                seq_graphs = graphs[t-L : t]
                
                # Move to device
                seq_graphs = [g.to(device) for g in seq_graphs]
                target = graphs[t].y.to(device) # .view(1) if needed

                # Forward
                pred = model(seq_graphs) # Output shape [1]
                
                loss = loss_fn(pred.view(-1), target.view(-1))
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                count += 1
            
            avg_loss = total_loss / count if count > 0 else 0
            print(f"Epoch {epoch:03d} | MSE Loss: {avg_loss:.6f}")

Model initialized on: cpu


IndexError: index out of range in self