In [2]:
from pathlib import Path
from datetime import timedelta

import numpy as np
import torch
from torch_geometric.data import HeteroData
import analysis_utils
import importlib
# Re-import the local helper module so that edits made to analysis_utils
# are picked up without restarting the kernel.
importlib.reload(analysis_utils)
from sentence_transformers import SentenceTransformer
# Sentence encoder used by build_graph_for_date to embed topic names
# into topic-node features.
sentence_model = SentenceTransformer("all-mpnet-base-v2")

# Length, in calendar days, of the speech window that ends on (and includes)
# each target date.
LOOKBACK_DAYS = 30 
# Column of the rates table that is read as the prediction label.
TARGET_COLUMN = "Rate_Change"  


In [3]:

def build_global_indices(speeches, topic_scores, rates_df):
    """
    Build the dataset-wide lookup tables shared by every graph snapshot.

    Args:
        speeches: mapping speech id -> record; each record must have an "author" key.
        topic_scores: mapping speech id -> mapping of topic name -> score.
        rates_df: object whose ``.index`` iterates over the dates that have rates.

    Returns:
        dict with keys "author2idx", "topic2idx", "speech2idx", "date2idx"
        (each mapping a value to its rank in sorted order) and "dates"
        (the sorted list of unique rate dates).
    """

    def _rank(sorted_values):
        # Position in sorted order becomes the integer id.
        return {value: position for position, value in enumerate(sorted_values)}

    # Authors: unique names across all speeches.
    unique_authors = {record["author"] for record in speeches.values()}

    # Topics: union of the topic names attached to any speech.
    unique_topics = set().union(*(scores.keys() for scores in topic_scores.values()))

    # Dates: only those for which a rate observation exists.
    rate_dates = sorted(set(rates_df.index))

    return {
        "author2idx": _rank(sorted(unique_authors)),
        "topic2idx": _rank(sorted(unique_topics)),
        "speech2idx": _rank(sorted(speeches.keys())),
        "date2idx": _rank(rate_dates),
        "dates": rate_dates,
    }

In [4]:

def get_speeches_in_window(target_date, lookback_days, speeches_by_date):
    """
    Collect the speech IDs dated within the lookback window ending on target_date.

    The window is inclusive on both ends and spans
    [target_date - lookback_days + 1, target_date]. IDs are returned in
    chronological order of their date; within one date they keep the order
    stored in speeches_by_date. Days without an entry are skipped.
    """
    one_day = timedelta(days=1)

    # Enumerate every calendar day of the window, oldest first.
    window_days = []
    day = target_date - timedelta(days=lookback_days - 1)
    while day <= target_date:
        window_days.append(day)
        day = day + one_day

    return [
        sid
        for day in window_days
        if day in speeches_by_date
        for sid in speeches_by_date[day]
    ]



In [5]:

def build_all_graphs(
    speeches,
    speeches_with_embeddings,
    topic_scores,
    rates_df,
    out_dir="graphs",
    lookback_days=LOOKBACK_DAYS,
    target_column=TARGET_COLUMN
):
    """
    Build one graph snapshot per rate date and write them all to disk.

    Every date in the global date index gets a graph from
    build_graph_for_date. Once all graphs are built, they are saved as
    ``graph_0000.pt``, ``graph_0001.pt``, ... (in date order) inside out_dir,
    which is created if it does not exist.

    Returns:
        The list of built graphs, in date order.
    """
    target_dir = Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    indices = build_global_indices(speeches, topic_scores, rates_df)
    by_date = analysis_utils.group_speeches_by_date(speeches)
    print(len(speeches), len(by_date))

    ordered_dates = sorted(indices["dates"])
    print(len(ordered_dates))

    # Build everything first; nothing is written unless every date succeeds.
    snapshots = []
    for day in ordered_dates:
        print(day)
        snapshot = build_graph_for_date(
            day,
            speeches,
            speeches_with_embeddings,
            topic_scores,
            rates_df,
            by_date,
            indices,
            lookback_days=lookback_days,
            target_column=target_column,
        )
        snapshots.append(snapshot)

    for position, snapshot in enumerate(snapshots):
        torch.save(snapshot, target_dir / f"graph_{position:04d}.pt")

    print(f"Built {len(snapshots)} graphs and saved to {target_dir}")
    return snapshots


In [8]:
def build_graph_for_date(
    d,
    speeches,
    speeches_with_embeddings,
    topic_scores,
    rates_df,
    speeches_by_date,
    global_idx,
    lookback_days=30,
    target_column="ffr_delta"
):
    """
    Build a HeteroData graph snapshot for prediction date d.

    Node types:
        author : one-hot rows over the authors present in the window
        speech : [embedding..., speech date index, lag from d, rate change]
                 where the rate change is rates_df[target_column] on the
                 speech date, or 0.0 when it is not yet known at d
        topic  : sentence embedding of the topic name
        day    : a single node holding the date index of d

    Edge types (each with a "rev_" counterpart in the opposite direction):
        author -gives-> speech
        speech -mentions-> topic      edge_attr = [score]
        day -references-> speech      edge_attr = [lag, exp(-lag / 10)]
        speech -follows-> speech      later speech -> earlier speech in the
                                      window, edge_attr = [lag in days]

    Graph-level attributes:
        data.y    : rates_df.loc[d, target_column] as a float32 tensor
        data.date : date index of d as a long tensor
        data["speech"].date : raw speech dates as strings (for visualization)

    Args:
        d: prediction date; must be a key of global_idx["date2idx"] and a
            label of rates_df.index.
        speeches: mapping speech id -> record with "author" and "date".
        speeches_with_embeddings: mapping speech id -> record with "embedding".
        topic_scores: mapping speech id -> mapping topic name -> score.
        rates_df: rates table indexed by date, containing target_column.
        speeches_by_date: mapping date -> list of speech ids.
        global_idx: output of build_global_indices.
        lookback_days: number of calendar days (inclusive of d) in the window.
        target_column: rates_df column used for both the per-speech
            rate-change feature and the graph label.

    Raises:
        ValueError: if no speech falls inside the lookback window.
        KeyError: if d is missing from the date index or from rates_df.
    """
    date2idx = global_idx["date2idx"]

    # ==========================================================
    # 1. Collect speech IDs in window
    # ==========================================================
    local_speech_ids = sorted(set(get_speeches_in_window(d, lookback_days, speeches_by_date)))
    if not local_speech_ids:
        # Without this guard np.stack below fails with an opaque message.
        raise ValueError(
            f"No speeches found in the {lookback_days}-day window ending on {d}"
        )

    speech_sid2i = {sid: i for i, sid in enumerate(local_speech_ids)}
    speech_date = {sid: speeches[sid]["date"] for sid in local_speech_ids}

    data = HeteroData()

    # ==========================================================
    # 2. AUTHOR NODES
    # ==========================================================
    # NOTE(review): the one-hot is over the authors present in THIS window, so
    # its width and column meaning vary between snapshots. If a stable author
    # identity is wanted, global_idx["author2idx"] is available — confirm intent.
    author_names = sorted({speeches[sid]["author"] for sid in local_speech_ids})
    author_name2i = {name: i for i, name in enumerate(author_names)}
    data["author"].x = torch.tensor(np.eye(len(author_names)), dtype=torch.float32)

    # ==========================================================
    # 3. SPEECH NODES with metadata
    # ==========================================================
    rate_dates = rates_df.index
    speech_feats = []
    speech_dates = []  # for visualization

    for sid in local_speech_ids:
        sdate = speech_date[sid]
        sdate_idx = date2idx.get(sdate, -1)  # -1: speech date has no rate entry

        emb = np.array(speeches_with_embeddings[sid]["embedding"], dtype=np.float32)

        # lag relative to prediction date
        lag = (d - sdate).days

        # Rate change attached to the speech date. It is exposed only when the
        # next rate date is not after d; otherwise (or when the speech date has
        # no rate entry at all) the feature stays at 0.0.
        rate_change = 0.0
        if sdate in rate_dates:
            pos = rate_dates.get_loc(sdate)
            if pos + 1 < len(rate_dates) and rate_dates[pos + 1] <= d:
                # Use the caller-supplied column so the feature and the label
                # always come from the same series.
                rate_change = float(rates_df.loc[sdate, target_column])

        speech_feats.append(
            np.concatenate([
                emb,
                np.array([sdate_idx, lag, rate_change], dtype=np.float32),
            ])
        )
        speech_dates.append(str(sdate))

    data["speech"].x = torch.tensor(np.stack(speech_feats), dtype=torch.float32)
    data["speech"].date = speech_dates  # store raw strings for visualization

    # ==========================================================
    # 4. TOPIC NODES
    # ==========================================================
    topic_names = sorted({t for sid in local_speech_ids for t in topic_scores[sid]})
    topic_name2i = {t: i for i, t in enumerate(topic_names)}

    topic_embeddings = [sentence_model.encode(tname) for tname in topic_names]
    # Convert through a single ndarray: building a tensor straight from a list
    # of ndarrays takes PyTorch's slow element-wise path.
    data["topic"].x = torch.tensor(np.asarray(topic_embeddings), dtype=torch.float32)

    # ==========================================================
    # 5. DAY NODE
    # ==========================================================
    today_idx = date2idx[d]
    data["day"].x = torch.tensor([[today_idx]], dtype=torch.float32)

    # ==========================================================
    # 6. AUTHOR → SPEECH edges
    # ==========================================================
    author_src = [author_name2i[speeches[sid]["author"]] for sid in local_speech_ids]
    speech_dst = [speech_sid2i[sid] for sid in local_speech_ids]

    data["author", "gives", "speech"].edge_index = torch.tensor(
        [author_src, speech_dst], dtype=torch.long
    )

    # ==========================================================
    # 7. SPEECH → TOPIC edges
    # ==========================================================
    st_src, st_dst, st_attr = [], [], []
    for sid in local_speech_ids:
        s_i = speech_sid2i[sid]
        for tname, score in topic_scores[sid].items():
            st_src.append(s_i)
            st_dst.append(topic_name2i[tname])
            st_attr.append([float(score)])

    data["speech", "mentions", "topic"].edge_index = torch.tensor(
        [st_src, st_dst], dtype=torch.long
    )
    # reshape keeps the attribute 2-D, shape (0, 1), even when there are no edges.
    data["speech", "mentions", "topic"].edge_attr = torch.tensor(
        st_attr, dtype=torch.float32
    ).reshape(-1, 1)

    # ==========================================================
    # 8. DAY → SPEECH recency edges
    # ==========================================================
    day_src, day_dst, day_attr = [], [], []
    for sid in local_speech_ids:
        lag = (d - speech_date[sid]).days
        decay = np.exp(-lag / 10)

        day_src.append(0)  # the snapshot has exactly one day node
        day_dst.append(speech_sid2i[sid])
        day_attr.append([lag, decay])

    data["day", "references", "speech"].edge_index = torch.tensor(
        [day_src, day_dst], dtype=torch.long
    )
    data["day", "references", "speech"].edge_attr = torch.tensor(
        day_attr, dtype=torch.float32
    ).reshape(-1, 2)

    # ==========================================================
    # 9. SPEECH → SPEECH temporal edges ("follows")
    # ==========================================================
    follow_src, follow_dst, follow_attr = [], [], []

    for sid_curr in local_speech_ids:
        i_curr = speech_sid2i[sid_curr]
        sdate_curr = speech_date[sid_curr]

        for sid_past in local_speech_ids:
            sdate_past = speech_date[sid_past]
            if sdate_past >= sdate_curr:
                continue  # only strictly earlier speeches

            lag = (sdate_curr - sdate_past).days
            if lag <= lookback_days:
                follow_src.append(i_curr)                   # current speech
                follow_dst.append(speech_sid2i[sid_past])   # past speech
                follow_attr.append([lag])                   # temporal lag

    data["speech", "follows", "speech"].edge_index = torch.tensor(
        [follow_src, follow_dst], dtype=torch.long
    )
    # All speeches may share one date, leaving no "follows" edges; reshape keeps
    # the attribute at shape (0, 1) in that case instead of (0,).
    data["speech", "follows", "speech"].edge_attr = torch.tensor(
        follow_attr, dtype=torch.float32
    ).reshape(-1, 1)

    data["speech", "rev_follows", "speech"].edge_index = torch.tensor(
        [follow_dst, follow_src], dtype=torch.long
    )

    # ==========================================================
    # 10. REVERSE edges for other types
    # ==========================================================
    data["speech", "rev_gives", "author"].edge_index = torch.tensor(
        [speech_dst, author_src], dtype=torch.long
    )
    data["topic", "rev_mentions", "speech"].edge_index = torch.tensor(
        [st_dst, st_src], dtype=torch.long
    )
    data["speech", "rev_references", "day"].edge_index = torch.tensor(
        [day_dst, day_src], dtype=torch.long
    )

    # ==========================================================
    # 11. TARGET LABEL
    # ==========================================================
    y = float(rates_df.loc[d, target_column])
    data.y = torch.tensor([y], dtype=torch.float32)
    data.date = torch.tensor([today_idx], dtype=torch.long)

    return data


In [9]:

# Load every input through the project helpers.
speeches = analysis_utils.load_speeches()
topic_scores = analysis_utils.load_topic_scores_by_sid()
rates_df = analysis_utils.load_rates()
speeches_with_embeddings = analysis_utils.load_speeches_with_embeddings()

# Build one graph per rate date and persist them under the output directory.
graphs = build_all_graphs(
    speeches=speeches,
    speeches_with_embeddings=speeches_with_embeddings,
    topic_scores=topic_scores,
    rates_df=rates_df,
    out_dir="graphs_ffr_delta",  # output directory; change as you like
    lookback_days=LOOKBACK_DAYS,
    target_column=TARGET_COLUMN,
)


977 647
647
2018-06-18 00:00:00
2018-06-19 00:00:00
2018-06-20 00:00:00
2018-06-27 00:00:00
2018-07-18 00:00:00
2018-07-19 00:00:00
2018-08-08 00:00:00
2018-08-21 00:00:00
2018-08-24 00:00:00
2018-09-07 00:00:00
2018-09-08 00:00:00
2018-09-12 00:00:00
2018-09-14 00:00:00
2018-09-27 00:00:00
2018-09-28 00:00:00
2018-10-01 00:00:00
2018-10-02 00:00:00
2018-10-03 00:00:00
2018-10-04 00:00:00
2018-10-09 00:00:00
2018-10-15 00:00:00
2018-10-17 00:00:00
2018-10-18 00:00:00
2018-10-24 00:00:00
2018-10-25 00:00:00
2018-10-26 00:00:00
2018-11-09 00:00:00
2018-11-12 00:00:00
2018-11-13 00:00:00
2018-11-16 00:00:00
2018-11-27 00:00:00
2018-11-28 00:00:00
2018-11-29 00:00:00
2018-11-30 00:00:00
2018-12-03 00:00:00
2018-12-05 00:00:00
2018-12-06 00:00:00
2018-12-07 00:00:00
2019-01-09 00:00:00
2019-01-10 00:00:00
2019-01-18 00:00:00
2019-02-01 00:00:00
2019-02-04 00:00:00
2019-02-05 00:00:00
2019-02-06 00:00:00
2019-02-10 00:00:00
2019-02-11 00:00:00
2019-02-12 00:00:00
2019-02-13 00:00:00
2019-02-