In [70]:
from pathlib import Path
from datetime import timedelta

import numpy as np
import torch
from torch_geometric.data import HeteroData
import analysis_utils
import importlib
importlib.reload(analysis_utils)

LOOKBACK_DAYS = 30 
TARGET_COLUMN = "Rate_Change"  


In [71]:

def build_global_indices(speeches, topic_scores, rates_df):
    
    # 1) Authors
    author_names = sorted({v["author"] for v in speeches.values()})
    author2idx = {name: i for i, name in enumerate(author_names)}

    # 2) Topics
    topic_names = set()
    for sid, topics in topic_scores.items():
        for tname in topics.keys():
            topic_names.add(tname)
    topic_names = sorted(topic_names)
    topic2idx = {name: i for i, name in enumerate(topic_names)}

    # 3) Speech ids
    speech_ids = sorted(speeches.keys())
    speech2idx = {sid: i for i, sid in enumerate(speech_ids)}

    # 4) Dates
    all_dates = sorted(set(rates_df.index))  # dates where we have rates
    date2idx = {d: i for i, d in enumerate(all_dates)}

    return {
        "author2idx": author2idx,
        "topic2idx": topic2idx,
        "speech2idx": speech2idx,
        "date2idx": date2idx,
        "dates": all_dates,
    }

In [72]:

def get_speeches_in_window(target_date, lookback_days, speeches_by_date):
    """
    Returns a list of speech IDs whose date is in [target_date - lookback_days + 1, target_date].
    """
    start_date = target_date - timedelta(days=lookback_days - 1)
    cur_date = start_date
    selected = []

    while cur_date <= target_date:
        if cur_date in speeches_by_date:
            selected.extend(speeches_by_date[cur_date])
        cur_date += timedelta(days=1)

    return selected



In [73]:

def build_all_graphs(
    speeches,
    speeches_with_embeddings,
    topic_scores,
    rates_df,
    out_dir="graphs",
    lookback_days=LOOKBACK_DAYS,
    target_column=TARGET_COLUMN
):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    global_idx = build_global_indices(speeches, topic_scores, rates_df)
    speeches_by_date = analysis_utils.group_speeches_by_date(speeches)
    print(len(speeches), len(speeches_by_date))
    graphs = []
    dates = global_idx["dates"]
    print(len(dates))
    for d in dates:
        g = build_graph_for_date(
            d,
            speeches,
            speeches_with_embeddings,
            topic_scores,
            rates_df,
            speeches_by_date,
            global_idx,
            lookback_days=lookback_days,
            target_column=target_column,
        )
        if g is None:
            continue
        graphs.append(g)

    for i, g in enumerate(graphs):
        torch.save(g, out_dir / f"graph_{i:04d}.pt")

    print(f"Built {len(graphs)} graphs and saved to {out_dir}")
    return graphs


In [86]:
# build_graphs.py (part 5)

def build_graph_for_date(
    d,
    speeches,
    speeches_with_embeddings,
    topic_scores,
    rates_df,
    speeches_by_date,
    global_idx,
    lookback_days=LOOKBACK_DAYS,
    target_column=TARGET_COLUMN
):
    """
    Build HeteroData graph snapshot for date d.
    Speech nodes store embeddings.
    Topic-speech edges store topic score features.
    """

    date2idx = global_idx["date2idx"]
    author2idx = global_idx["author2idx"]
    topic2idx = global_idx["topic2idx"]

    speech_ids_window = get_speeches_in_window(d, lookback_days, speeches_by_date)
    
    
    print(len(speech_ids_window))
    
    # --- Build local indices ---
    local_speech_ids = sorted(set(speech_ids_window))
    speech_i2sid= {i: sid for i, sid in enumerate(local_speech_ids)}
    speech_sid2i = {sid: i for i, sid in enumerate(local_speech_ids)}

    # --- Local authors ---
    author_names_window = sorted({speeches[sid]["author"] for sid in local_speech_ids})
    author_i2name = {i: name for i, name in enumerate(author_names_window)}
    author_name2i = {name: i for i, name in author_i2name.items()}
    
    topic_names_window = set()
    for sid in local_speech_ids:
        topic_names_window |= set(topic_scores[sid].keys())

    topic_names_window = sorted(topic_names_window)
    topic_i2name = {i: name for i, name in enumerate(topic_names_window)}
    topic_name2i = {name: i for i, name in topic_i2name.items()}
    
    num_authors = len(author_i2name)
    num_speeches = len(speech_i2sid)
    num_topics = len(topic_i2name)

    data = HeteroData()
    
    # --- AUTHOR NODES ---
    author_feats = []
    for i in range(num_authors):
        name = author_i2name[i]
        author_feats.append([author2idx.get(name, -1)])
    data["author"].x = torch.tensor(author_feats, dtype=torch.float32)
   
    # --- SPEECH NODES: ACTUAL EMBEDDINGS ---
    speech_features = []
    for i in range(num_speeches):
        
        sid = speech_i2sid[i]

        if "embedding" not in speeches_with_embeddings[sid]:
            raise ValueError(f"Speech {sid} has no embedding: add speeches[sid]['embedding']")

        emb = speeches_with_embeddings[sid]["embedding"]
        speech_features.append(emb)

    data["speech"].x = torch.tensor(speech_features, dtype=torch.float32)
   
    # ================================
    # 3) TOPIC NODES
    # ================================
    topic_feats = []
    for i in range(num_topics):
        name = topic_i2name[i]
        topic_feats.append([topic2idx.get(name, -1)])
    data["topic"].x = torch.tensor(topic_feats, dtype=torch.float32)

    # ================================
    # 4) DAY NODE
    # ================================
    # today_rate = rates_df.loc[d, target_column]
    today_idx = date2idx[d]
    data["day"].x = torch.tensor([[today_idx]], dtype=torch.float32)
   
     # ---- AUTHOR -> SPEECH ----
    author_src, speech_dst = [], []
    for sid in local_speech_ids:
        a_name = speeches[sid]["author"]
        a_i = author_name2i[a_name]
        s_i = speech_sid2i[sid]

        author_src.append(a_i)
        speech_dst.append(s_i)

    data["author", "gives", "speech"].edge_index = torch.tensor(
        [author_src, speech_dst], dtype=torch.long
    )
  
    # ---- SPEECH -> TOPIC (edge_attr = topic score) ----
    
    st_src, st_dst, st_attr = [], [], []
    
    for sid in local_speech_ids:
        
        s_i = speech_sid2i[sid]

        for tname, score in topic_scores[sid].items():
            t_i = topic_name2i[tname]

            st_src.append(s_i)
            st_dst.append(t_i)
            st_attr.append([float(score)])  # score â†’ edge feature
    
    data["speech", "mentions", "topic"].edge_index = torch.tensor(
            [st_src, st_dst], dtype=torch.long
        )
    data["speech", "mentions", "topic"].edge_attr = torch.tensor(
            st_attr, dtype=torch.float32
        )
  
    # ---- DAY -> SPEECH (recency edges) ---- 
    day_src, day_dst, day_attr = [], [], []
    for sid in local_speech_ids:
        
        s_i = speech_sid2i[sid]
        sdate = speeches[sid]["date"]
        lag = (d - sdate).days
        decay = np.exp(-lag / 10.0)

        day_src.append(0)  # only one day node
        day_dst.append(s_i)
        day_attr.append([lag, decay])

    # reverse edge 
    # ????
    data["day", "references", "speech"].edge_index = torch.tensor(
        [day_src, day_dst], dtype=torch.long
    )
    data["day", "references", "speech"].edge_attr = torch.tensor(day_attr, dtype=torch.float32)
    
    data["speech", "rev_gives", "author"].edge_index = torch.tensor(
        [speech_dst, author_src], dtype=torch.long
    )

    data["topic", "rev_mentions", "speech"].edge_index = torch.tensor(
        [st_dst, st_src], dtype=torch.long
    )

    data["speech", "rev_references", "day"].edge_index = torch.tensor(
        [day_dst, day_src], dtype=torch.long
    )

    # author - topic edge is missing now 
    print(num_authors, num_speeches, num_topics)
    
    # ================================
    # 7) TARGET for prediction
    # ================================
    y = float(rates_df.loc[d, target_column])
    data.y = torch.tensor([y], dtype=torch.float32)

    data.date = torch.tensor([today_idx], dtype=torch.long)
    return data

In [87]:

speeches = analysis_utils.load_speeches()
topic_scores = analysis_utils.load_topic_scores_by_sid()
rates_df = analysis_utils.load_rates()
speeches_with_embeddings = analysis_utils.load_speeches_with_embeddings()

graphs = build_all_graphs(
        speeches,
        speeches_with_embeddings,
        topic_scores,
        rates_df,
        out_dir="graphs_ffr_delta",   # change as you like
        lookback_days=LOOKBACK_DAYS,
        target_column=TARGET_COLUMN,
    )


977 647
647
1
1 1 6
2
2 2 6
3
3 3 6
5
5 5 6
5
4 5 6
5
3 5 6
3
2 3 6
2
2 2 6
3
3 3 6
3
3 3 6
4
3 4 6
5
4 5 6
6
4 6 6
5
3 5 6
6
4 6 6
8
4 8 6
9
4 9 6
12
6 12 6
13
7 13 6
13
7 13 6
12
7 12 6
13
7 13 6
14
7 14 6
15
8 15 6
17
9 17 6
18
9 18 6
10
8 10 6
11
9 11 6
12
9 12 6
11
9 11 6
7
6 7 6
8
7 8 6
9
8 9 6
10
8 10 6
12
8 12 6
13
8 13 6
14
8 14 6
15
8 15 6
2
2 2 6
3
3 3 6
5
5 5 6
6
6 6 6
7
7 7 6
9
8 9 6
11
9 11 6
9
7 9 6
10
8 10 6
12
8 12 6
14
8 14 6
15
8 15 6
14
7 14 6
18
9 18 6
19
9 19 6
23
9 23 6
25
10 25 6
25
10 25 6
24
10 24 6
23
10 23 6
24
10 24 6
24
10 24 6
14
9 14 6
16
10 16 6
19
11 19 6
20
11 20 6
18
10 18 6
17
10 17 6
15
9 15 6
15
8 15 6
17
8 17 6
18
8 18 6
20
9 20 6
13
9 13 6
12
9 12 6
14
10 14 6
15
11 15 6
15
12 15 6
16
11 16 6
14
10 14 6
15
10 15 6
16
10 16 6
15
8 15 6
18
8 18 6
19
9 19 6
20
9 20 6
22
10 22 6
23
10 23 6
22
9 22 6
24
9 24 6
24
9 24 6
23
8 23 6
17
7 17 6
12
8 12 6
13
9 13 6
16
10 16 6
17
10 17 6
15
10 15 6
16
11 16 6
12
10 12 6
13
10 13 6
13
9 13 6
16
10 16 6
17
10

In [94]:
import torch.nn as nn
from torch_geometric.nn import HGTConv

SPEECH_EMB_DIM = 768
class SpeechHeteroGNN(nn.Module):
    def __init__(self, num_authors, num_topics, num_days, hidden_dim=64, num_heads=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        
        # --- MLP for speech embeddings (this part is fine) ---
        self.speech_mlp = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, hidden_dim),
            nn.ReLU()
        )

        # --- CORRECT: Embedding layers for categorical nodes ---
        # Note: We use embedding dimension = hidden_dim to match HGT input requirements
        self.author_emb = nn.Embedding(num_authors + 1, hidden_dim)
        self.topic_emb = nn.Embedding(num_topics + 1, hidden_dim)
        self.day_emb = nn.Embedding(num_days + 1, hidden_dim)
        
        # --- HGTConv ---
        self.hgt = HGTConv(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            metadata=(
                ["author", "speech", "topic", "day"],
                [
                    ("author", "gives", "speech"),
                    ("speech", "rev_gives", "author"),
                    ("speech", "mentions", "topic"),
                    ("topic", "rev_mentions", "speech"),
                    ("day", "references", "speech"),
                    ("speech", "rev_references", "day"),
                ]
            ),
            heads=num_heads
        )

    def forward(self, x_dict, edge_index_dict):
        out_dict = {}

        # Process each node type appropriately
        for ntype, x in x_dict.items():
            if ntype == "speech":
                # Speech features are already floats (embeddings), use MLP
                out_dict[ntype] = self.speech_mlp(x.float())
            elif ntype == "author":
                # x is [num_nodes, 1], we need [num_nodes] for Embedding
                out_dict[ntype] = self.author_emb(x.long().squeeze())
            elif ntype == "topic":
                out_dict[ntype] = self.topic_emb(x.long().squeeze())
            elif ntype == "day":
                out_dict[ntype] = self.day_emb(x.long().squeeze())

        out_dict = self.hgt(out_dict, edge_index_dict)
        return out_dict


In [89]:
class TemporalPredictor(nn.Module):
    def __init__(self, embed_dim=64, hidden_dim=64):
        super().__init__()
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
    
    def forward(self, seq_embeddings):
        out, _ = self.gru(seq_embeddings)
        return self.fc(out[:, -1])  # last timestep


In [90]:
class FedSpeechModel(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.gnn = SpeechHeteroGNN(hidden_dim)
        self.temporal = TemporalPredictor(hidden_dim, hidden_dim)
    
    def forward(self, graph_seq):
        day_embs = []
        
        for g in graph_seq:
            x_dict = self.gnn(g.x_dict, g.edge_index_dict)
            z = x_dict["day"]  # shape [1, hidden_dim]
            day_embs.append(z)
        
        day_embs = torch.stack(day_embs, dim=1)  # [1, seq_len, hidden_dim]
        return self.temporal(day_embs)


In [95]:
import torch
from torch.optim import Adam
import torch.nn as nn

model = FedSpeechModel(hidden_dim=64)
optimizer = Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

L = 30        # sequence length
EPOCHS = 40   # number of epochs

for epoch in range(1, EPOCHS+1):
    
    total_loss = 0.0
    count = 0

    for t in range(L, len(graphs)):
        seq = graphs[t-L:t]
        target = graphs[t].y.float()

        pred = model(seq)
        loss = loss_fn(pred.squeeze(), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1
    
    print(f"Epoch {epoch}/{EPOCHS}  MSE={total_loss/count:.6f}")


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/40  MSE=0.004394
Epoch 2/40  MSE=0.003549
Epoch 3/40  MSE=0.003473
Epoch 4/40  MSE=0.003502
Epoch 5/40  MSE=0.003511
Epoch 6/40  MSE=0.003675
Epoch 7/40  MSE=0.003464
Epoch 8/40  MSE=0.003463
Epoch 9/40  MSE=0.003508
Epoch 10/40  MSE=0.003467


KeyboardInterrupt: 

In [96]:
print(0.003464**0.5)

0.05885575587824864
