In [14]:
import json
from pathlib import Path
from datetime import datetime, timedelta
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import HeteroData

# ============================================================
# Fed Speech → Temporal Graph → GNN → Fed Funds Prediction
# ============================================================

import json
from pathlib import Path
from datetime import datetime, date, timedelta
from collections import defaultdict

import torch
from torch_geometric.data import HeteroData, Dataset
from torch_geometric.nn import HeteroConv, SAGEConv

import torch.nn as nn
from torch.optim import Adam


DATA_DIR = Path("data")
SPEECH_FOLDER = DATA_DIR / "text_data/"
TOPIC_SCORE_FOLDER = DATA_DIR / "topic_scores/"
RATES_FILE = DATA_DIR / "price_data/2025-10-26 Fed Funds 12M 6M Historical Swap Rates.xlsx"

LOOKBACK_DAYS = 30   # rolling window length
TARGET_COLUMN = "fed_funds"   # or "asw"
PREDICT_DELTA = True         # True = predict Δy, False = level
START_DATE = datetime(2018, 1, 1)


In [15]:
# date_utils.py

from datetime import datetime, date

def parse_date(dstr: str) -> date:
    """
    Try several common date formats and return a datetime.date.
    Adjust/add formats if your data differs.
    """
    dstr = dstr.strip()
    formats = [
        "%Y-%m-%d",       # 2023-08-25
        "%Y/%m/%d",       # 2023/08/25
        "%Y-%m-%dT%H:%M:%S",  # 2023-08-25T00:00:00
        "%B %d, %Y",      # August 25, 2023
        "%b %d, %Y",      # Aug 25, 2023
    ]
    for fmt in formats:
        try:
            return datetime.strptime(dstr, fmt)
        except ValueError:
            continue
    raise ValueError(f"Unrecognized date format: {dstr}")


In [3]:
import glob

def load_speeches(path=SPEECH_FOLDER):
    
    
    json_files = glob.glob(str(path) + "/*.json")

    speeches = {}
    
    for json_file in json_files:
        with open(json_file, "r", encoding="utf-8") as f:
            raw = json.load(f)

        for row in raw:
        
            sid = row["id"]
            date = parse_date(row["date"])
            if date < START_DATE:
                continue 
            
            speeches[sid] = {
                "author": json_file.split("/")[-1].split(".")[0],
                "text": row["text"],
                "date": parse_date(row["date"]),
            }
    return speeches

def load_topic_scores(path=TOPIC_SCORE_FOLDER):
    
    json_files = glob.glob(str(path) + "/*.json")
    scores = {}
    
    for json_file in json_files: 
        with open(json_file, "r", encoding="utf-8") as f:
            raw = json.load(f)
            
        for row in raw:
            sid = row["id"]
            scores[sid] = row["gpt-5"]
    return scores

def load_rates(path=RATES_FILE):
    df = pd.read_excel(path)
    df["Date"] = df["Date"].apply(lambda x: str(x).split(" ")[0])
    df["Date"] = df["Date"].apply(parse_date)
    df = df.set_index("Date").sort_index()
    df = df[["Rate"]]
    return df

In [4]:

def build_global_indices(speeches, topic_scores, rates_df):
    
    # 1) Authors
    author_names = sorted({v["author"] for v in speeches.values()})
    author2idx = {name: i for i, name in enumerate(author_names)}

    # 2) Topics
    topic_names = set()
    for sid, topics in topic_scores.items():
        for tname in topics.keys():
            topic_names.add(tname)
    topic_names = sorted(topic_names)
    topic2idx = {name: i for i, name in enumerate(topic_names)}

    # 3) Speech ids
    speech_ids = sorted(speeches.keys())
    speech2idx = {sid: i for i, sid in enumerate(speech_ids)}

    # 4) Dates
    all_dates = sorted(set(rates_df.index))  # dates where we have rates
    date2idx = {d: i for i, d in enumerate(all_dates)}

    return {
        "author2idx": author2idx,
        "topic2idx": topic2idx,
        "speech2idx": speech2idx,
        "date2idx": date2idx,
        "dates": all_dates,
    }

In [5]:
# build_graphs.py (part 3)

def group_speeches_by_date(speeches):
    speeches_by_date = defaultdict(list)
    for sid, info in speeches.items():
        d = info["date"]
        speeches_by_date[d].append(sid)
    return speeches_by_date


In [6]:

def get_speeches_in_window(target_date, lookback_days, speeches_by_date):
    """
    Returns a list of speech IDs whose date is in [target_date - lookback_days + 1, target_date].
    """
    start_date = target_date - timedelta(days=lookback_days - 1)
    cur_date = start_date
    selected = []

    while cur_date <= target_date:
        if cur_date in speeches_by_date:
            selected.extend(speeches_by_date[cur_date])
        cur_date += timedelta(days=1)

    return selected



In [7]:

def build_all_graphs(
    speeches,
    topic_scores,
    rates_df,
    out_dir="graphs",
    lookback_days=LOOKBACK_DAYS,
    target_column=TARGET_COLUMN,
    predict_delta=PREDICT_DELTA,
):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    global_idx = build_global_indices(speeches, topic_scores, rates_df)
    speeches_by_date = group_speeches_by_date(speeches)

    graphs = []
    dates = global_idx["dates"]

    for d in dates:
        g = build_graph_for_date(
            d,
            speeches,
            topic_scores,
            rates_df,
            speeches_by_date,
            global_idx,
            lookback_days=lookback_days,
            target_column=target_column,
            predict_delta=predict_delta,
        )
        if g is None:
            continue
        graphs.append(g)

    for i, g in enumerate(graphs):
        torch.save(g, out_dir / f"graph_{i:04d}.pt")

    print(f"Built {len(graphs)} graphs and saved to {out_dir}")
    return graphs


In [11]:
# build_graphs.py (part 5)

def build_graph_for_date(
    d,
    speeches,
    topic_scores,
    rates_df,
    speeches_by_date,
    global_idx,
    lookback_days=LOOKBACK_DAYS,
    target_column=TARGET_COLUMN,
    predict_delta=PREDICT_DELTA,
):
    """
    Build HeteroData graph snapshot for date d.
    The target is the next day's rate (or change).
    """
    print(d)
    dates = global_idx["dates"]
    date2idx = global_idx["date2idx"]
    author2idx = global_idx["author2idx"]
    topic2idx = global_idx["topic2idx"]

    # --- Check that we can define a target (need d and next_date) ---
    if d not in date2idx:
        return None  # no market data for this date
    t_idx = date2idx[d]
    if t_idx + 1 >= len(dates):
        return None  # no next day to predict

    next_date = dates[t_idx + 1]
    if next_date not in rates_df.index:
        return None

    # --- Gather speeches in window ---
    speech_ids_window = get_speeches_in_window(d, lookback_days, speeches_by_date)
    if len(speech_ids_window) == 0:
        return None  # no information in window

    # --- Build local indices for this snapshot ---
    # Local speech indices
    local_speech_ids = sorted(set(speech_ids_window))
    speech_local2global = {i: sid for i, sid in enumerate(local_speech_ids)}
    speech_global2local = {sid: i for i, sid in enumerate(local_speech_ids)}

    # Local authors (only those that appear)
    author_names_window = sorted(
        {speeches[sid]["author"] for sid in local_speech_ids}
    )
    author_local2name = {i: name for i, name in enumerate(author_names_window)}
    author_name2local = {name: i for i, name in author_local2name.items()}

    # Local topics (only those that appear)
    topic_names_window = set()
    for sid in local_speech_ids:
        if sid not in topic_scores:
            continue
        for tname in topic_scores[sid].keys():
            topic_names_window.add(tname)
    topic_names_window = sorted(topic_names_window)
    topic_local2name = {i: name for i, name in enumerate(topic_names_window)}
    topic_name2local = {name: i for i, name in topic_local2name.items()}

    num_authors = len(author_local2name)
    num_speeches = len(speech_local2global)
    num_topics = len(topic_local2name)
    num_days = 1  # single day node

    data = HeteroData()

    # ============================
    # 1) NODE FEATURES
    # ============================

    # --- author nodes ---
    # simple feature: global author index (can be embedded later)
    author_global_idx = []
    for i in range(num_authors):
        name = author_local2name[i]
        author_global_idx.append(author2idx.get(name, -1))
    author_global_idx = torch.tensor(author_global_idx, dtype=torch.long).unsqueeze(-1)
    data["author"].x = author_global_idx  # shape [num_authors, 1]

    # --- speech nodes ---
    # feature: days since start + maybe current rate
    speech_features = []
    start_date = dates[0]
    for i in range(num_speeches):
        sid = speech_local2global[i]
        sdate = speeches[sid]["date"]
        days_since_start = (sdate - start_date).days
        speech_features.append([days_since_start])
    data["speech"].x = torch.tensor(speech_features, dtype=torch.float32)  # [S, 1]

    # --- topic nodes ---
    # simple zero or global topic index
    topic_global_idx = []
    for i in range(num_topics):
        name = topic_local2name[i]
        topic_global_idx.append(topic2idx.get(name, -1))
    topic_global_idx = torch.tensor(topic_global_idx, dtype=torch.long).unsqueeze(-1)
    data["topic"].x = topic_global_idx  # [T, 1]

    # --- day node ---
    # You can encode today's rate and maybe date index
    today_rate = rates_df.loc[d, target_column]
    today_idx = date2idx[d]
    day_x = torch.tensor([[today_idx, today_rate]], dtype=torch.float32)  # [1, 2]
    data["day"].x = day_x

    # ============================
    # 2) EDGES
    # ============================

    # --- author -> speech edges (author "gives" speech) ---
    author_src = []
    speech_dst = []
    for sid in local_speech_ids:
        author_name = speeches[sid]["author"]
        if author_name not in author_name2local:
            continue
        a_local = author_name2local[author_name]
        s_local = speech_global2local[sid]
        author_src.append(a_local)
        speech_dst.append(s_local)

    if len(author_src) == 0:
        return None

    data["author", "gives", "speech"].edge_index = torch.tensor(
        [author_src, speech_dst], dtype=torch.long
    )

    # --- speech -> topic edges (with hawk/dove features) ---
    st_src = []
    st_dst = []
    st_attr = []
    for sid in local_speech_ids:
        if sid not in topic_scores:
            continue
        s_local = speech_global2local[sid]
        topics = topic_scores[sid]
        for tname, vals in topics.items():
            if tname not in topic_name2local:
                continue
            t_local = topic_name2local[tname]
            
            net = vals 
            st_src.append(s_local)
            st_dst.append(t_local)
            st_attr.append([net])

    if len(st_src) > 0:
        data["speech", "mentions", "topic"].edge_index = torch.tensor(
            [st_src, st_dst], dtype=torch.long
        )
        data["speech", "mentions", "topic"].edge_attr = torch.tensor(
            st_attr, dtype=torch.float32
        )
    else:
        # still set empty tensors to avoid errors
        data["speech", "mentions", "topic"].edge_index = torch.empty(
            (2, 0), dtype=torch.long
        )
        data["speech", "mentions", "topic"].edge_attr = torch.empty(
            (0, 3), dtype=torch.float32
        )

    # --- day -> speech edges (connect day node to all speeches in window) ---
    day_src = []
    day_dst = []
    day_attr = []

    for sid in local_speech_ids:
        s_local = speech_global2local[sid]
        sdate = speeches[sid]["date"]
        lag_days = (d - sdate).days  # how many days ago
        # simple recency feature: linear lag + exponential decay
        decay = np.exp(-lag_days / 10.0)
        day_src.append(0)        # only one day node, index 0
        day_dst.append(s_local)
        day_attr.append([lag_days, decay])

    data["day", "references", "speech"].edge_index = torch.tensor(
        [day_src, day_dst], dtype=torch.long
    )
    data["day", "references", "speech"].edge_attr = torch.tensor(
        day_attr, dtype=torch.float32
    )
    
    
# Reverse of: author -> speech
    data["speech", "rev_gives", "author"].edge_index = torch.tensor(
    [speech_dst, author_src], dtype=torch.long
)

# Reverse of: speech -> topic
    data["topic", "rev_mentions", "speech"].edge_index = torch.tensor(
    [st_dst, st_src], dtype=torch.long
)

# Reverse of: day -> speech
    data["speech", "rev_references", "day"].edge_index = torch.tensor(
    [day_dst, day_src], dtype=torch.long
) 
    # ============================
    # 3) TARGET y (next-day rate or Δrate)
    # ============================
    today_val = float(rates_df.loc[d, target_column])
    next_val = float(rates_df.loc[next_date, target_column])

    if predict_delta:
        y = next_val - today_val
    else:
        y = next_val

    
    data.y = torch.tensor([y], dtype=torch.float32)
    data.date = torch.tensor([date2idx[d]], dtype=torch.long)  # optional

    data["author"].x = torch.tensor(author_global_idx, dtype=torch.float32)
    data["topic"].x = torch.tensor(topic_global_idx, dtype=torch.float32)
    data["speech"].x = torch.tensor(speech_features, dtype=torch.float32)
    data["day"].x = torch.tensor([[today_idx, today_rate]], dtype=torch.float32)

    return data

In [12]:

speeches = load_speeches()
topic_scores = load_topic_scores()
rates_df = load_rates()

graphs = build_all_graphs(
        speeches,
        topic_scores,
        rates_df,
        out_dir="graphs_ffr_delta",   # change as you like
        lookback_days=30,
        target_column="Rate",
        predict_delta=True,
    )


2018-06-04 00:00:00
2018-06-05 00:00:00
2018-06-06 00:00:00
2018-06-07 00:00:00
2018-06-08 00:00:00
2018-06-11 00:00:00
2018-06-12 00:00:00
2018-06-13 00:00:00
2018-06-14 00:00:00
2018-06-15 00:00:00
2018-06-18 00:00:00
2018-06-19 00:00:00
2018-06-20 00:00:00
2018-06-21 00:00:00
2018-06-22 00:00:00
2018-06-25 00:00:00
2018-06-26 00:00:00
2018-06-27 00:00:00
2018-06-28 00:00:00
2018-06-29 00:00:00
2018-07-02 00:00:00
2018-07-03 00:00:00
2018-07-05 00:00:00
2018-07-06 00:00:00
2018-07-09 00:00:00
2018-07-10 00:00:00
2018-07-11 00:00:00
2018-07-12 00:00:00
2018-07-13 00:00:00
2018-07-16 00:00:00
2018-07-17 00:00:00
2018-07-18 00:00:00
2018-07-19 00:00:00
2018-07-20 00:00:00
2018-07-23 00:00:00
2018-07-24 00:00:00
2018-07-25 00:00:00
2018-07-26 00:00:00
2018-07-27 00:00:00
2018-07-30 00:00:00
2018-07-31 00:00:00
2018-08-01 00:00:00
2018-08-02 00:00:00
2018-08-03 00:00:00
2018-08-06 00:00:00
2018-08-07 00:00:00
2018-08-08 00:00:00
2018-08-09 00:00:00
2018-08-10 00:00:00
2018-08-13 00:00:00


  data["author"].x = torch.tensor(author_global_idx, dtype=torch.float32)
  data["topic"].x = torch.tensor(topic_global_idx, dtype=torch.float32)


2022-02-16 00:00:00
2022-02-17 00:00:00
2022-02-18 00:00:00
2022-02-22 00:00:00
2022-02-23 00:00:00
2022-02-24 00:00:00
2022-02-25 00:00:00
2022-02-28 00:00:00
2022-03-01 00:00:00
2022-03-02 00:00:00
2022-03-03 00:00:00
2022-03-04 00:00:00
2022-03-07 00:00:00
2022-03-08 00:00:00
2022-03-09 00:00:00
2022-03-10 00:00:00
2022-03-11 00:00:00
2022-03-14 00:00:00
2022-03-15 00:00:00
2022-03-16 00:00:00
2022-03-17 00:00:00
2022-03-18 00:00:00
2022-03-21 00:00:00
2022-03-22 00:00:00
2022-03-23 00:00:00
2022-03-24 00:00:00
2022-03-25 00:00:00
2022-03-28 00:00:00
2022-03-29 00:00:00
2022-03-30 00:00:00
2022-03-31 00:00:00
2022-04-01 00:00:00
2022-04-04 00:00:00
2022-04-05 00:00:00
2022-04-06 00:00:00
2022-04-07 00:00:00
2022-04-08 00:00:00
2022-04-11 00:00:00
2022-04-12 00:00:00
2022-04-13 00:00:00
2022-04-14 00:00:00
2022-04-15 00:00:00
2022-04-18 00:00:00
2022-04-19 00:00:00
2022-04-20 00:00:00
2022-04-21 00:00:00
2022-04-22 00:00:00
2022-04-25 00:00:00
2022-04-26 00:00:00
2022-04-27 00:00:00


In [17]:
class FedSpeechDataset(Dataset):
    def __init__(self, graphs):
        self.graphs = graphs

    def len(self):
        return len(self.graphs)

    def get(self, idx):
        return self.graphs[idx]

class SpeechHeteroGNN(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()

        self.conv =self.conv = HeteroConv({
    ('author', 'gives', 'speech'): SAGEConv((-1, -1), hidden_dim),
    ('speech', 'rev_gives', 'author'): SAGEConv((-1, -1), hidden_dim),

    ('speech', 'mentions', 'topic'): SAGEConv((-1, -1), hidden_dim),
    ('topic', 'rev_mentions', 'speech'): SAGEConv((-1, -1), hidden_dim),

    ('day', 'references', 'speech'): SAGEConv((-1, -1), hidden_dim),
    ('speech', 'rev_references', 'day'): SAGEConv((-1, -1), hidden_dim),
})


        self.act = nn.ReLU()

    def forward(self, x_dict, edge_index_dict):
        # hetero conv expects x_dict and edge_index_dict
        x_out = self.conv(x_dict, edge_index_dict)

        # apply activation to all node types
        for k in x_out:
            x_out[k] = self.act(x_out[k])

        return x_out


class TemporalPredictor(nn.Module):
    def __init__(self, embed_dim=64, hidden_dim=64):
        super().__init__()
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, seq_embeddings):
        out, _ = self.gru(seq_embeddings)
        return self.fc(out[:, -1])

class FedSpeechModel(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.gnn = SpeechHeteroGNN(hidden_dim)
        self.temporal = TemporalPredictor(hidden_dim, hidden_dim)

    def forward(self, graph_sequence):
        """
        graph_sequence: list of HeteroData for dates [t-L, ..., t]
        """
        day_embeddings = []
        
        for g in graph_sequence:
            x_dict = self.gnn(g.x_dict, g.edge_index_dict)
            z_t = x_dict["day"]  # shape [1, hidden_dim]
            day_embeddings.append(z_t)

        # Stack into tensor: [batch=1, seq_len, embed_dim]
        day_embeddings = torch.stack(day_embeddings, dim=1)

        pred = self.temporal(day_embeddings)  # shape [1, 1]
        return pred


num_epochs = 50
L = 30  # window size

curve = {}
model = FedSpeechModel(hidden_dim=64)
optimizer = Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for epoch in range(num_epochs):
    total_loss = 0.0
    count = 0

    for t in range(L, len(graphs) - 1):
        seq = graphs[t-L:t]           # sequence of 30 graphs
        target = graphs[t].y.float()  # scalar label

        optimizer.zero_grad()
        pred = model(seq)             # shape [1,1]

        loss = loss_fn(pred.squeeze(), target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1

    avg_loss = total_loss / count
    print(f"Epoch {epoch+1}/{num_epochs}   Loss = {avg_loss:.6f}")
    curve[epoch] = avg_loss


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/50   Loss = 0.004321
Epoch 2/50   Loss = 0.004173
Epoch 3/50   Loss = 0.004121
Epoch 4/50   Loss = 0.004221
Epoch 5/50   Loss = 0.004244
Epoch 6/50   Loss = 0.004201
Epoch 7/50   Loss = 0.004389
Epoch 8/50   Loss = 0.004607
Epoch 9/50   Loss = 0.004896
Epoch 10/50   Loss = 0.004692
Epoch 11/50   Loss = 0.004394
Epoch 12/50   Loss = 0.004362
Epoch 13/50   Loss = 0.004362
Epoch 14/50   Loss = 0.004321
Epoch 15/50   Loss = 0.004260
Epoch 16/50   Loss = 0.004201
Epoch 17/50   Loss = 0.004298
Epoch 18/50   Loss = 0.004339
Epoch 19/50   Loss = 0.004129
Epoch 20/50   Loss = 0.004188
Epoch 21/50   Loss = 0.004100
Epoch 22/50   Loss = 0.004105
Epoch 23/50   Loss = 0.004182
Epoch 24/50   Loss = 0.004292
Epoch 25/50   Loss = 0.004250
Epoch 26/50   Loss = 0.004288
Epoch 27/50   Loss = 0.004320
Epoch 28/50   Loss = 0.004346
Epoch 29/50   Loss = 0.004364
Epoch 30/50   Loss = 0.004265
Epoch 31/50   Loss = 0.004715
Epoch 32/50   Loss = 0.004533
Epoch 33/50   Loss = 0.004299
Epoch 34/50   Loss 

In [19]:
import plotly.graph_objects as go
figure = go.Figure()
figure.add_trace(go.Scatter(x=list(curve.keys()), y=list(curve.values())))
figure.show()