In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
import matplotlib.pyplot as plt

# Set a seed for reproduceability
seed = 42
import yfinance as yf

import yfinance as yf
import pandas as pd
import numpy as np
import ta  # Technical Analysis library
from ta.momentum import RSIIndicator
from ta.trend import MACD
from torch import nn, optim
from tqdm import tqdm

In [2]:
# Market

# --- Download SPY and VIX hourly data ---
spy = yf.Ticker("ES=F").history(period='730d', interval='1h')
vix = yf.Ticker("^VIX").history(period='730d', interval='1h')

# --- Rename datetime index ---
spy = spy.tz_localize(None)
vix = vix.tz_localize(None)

# --- Interpolate VIX onto SPY time index ---
vix_interp = vix['Close'].reindex(spy.index).interpolate(method='time')
vix_interp = vix_interp.interpolate(method='time').fillna(method='bfill').fillna(method='ffill')

# --- VWAP Calculation ---
def compute_vwap(prices, volumes, window):
    return (prices * volumes).rolling(window=window).sum() / volumes.rolling(window=window).sum()

spy['VWAP_24h'] = compute_vwap(spy['Close'], spy['Volume'], window=24)
spy['VWAP_120h'] = compute_vwap(spy['Close'], spy['Volume'], window=120)

# --- Compute % difference of price from VWAPs ---
spy['pct_above_vwap_24h'] = (spy['Close'] - spy['VWAP_24h']) / spy['VWAP_24h']
spy['pct_above_vwap_120h'] = (spy['Close'] - spy['VWAP_120h']) / spy['VWAP_120h']

# --- RSI ---
rsi_calc = RSIIndicator(close=spy['Close'], window=14)
spy['RSI'] = rsi_calc.rsi()

# --- MACD Histogram ---
macd_calc = MACD(close=spy['Close'], window_slow=26, window_fast=12, window_sign=9)
spy['MACD_hist'] = macd_calc.macd_diff()

# --- Normalize and transform features ---
spy['spy_close_pct'] = spy['Close'].pct_change()
spy['spy_open_pct'] = (spy['Close'] - spy['Open'])/spy['Open']
spy['spy_volume_log'] = np.log1p(spy['Volume'])

# --- Final embedding DataFrame ---
embed_df = pd.DataFrame({
    'spy_close': spy['spy_close_pct'],
    'spy_open': spy['spy_open_pct'],
    'spy_volume': spy['spy_volume_log'],
    'vix_close': vix_interp,
    'pct_above_vwap_24h': spy['pct_above_vwap_24h'],
    'pct_above_vwap_120h': spy['pct_above_vwap_120h'],
    'rsi': spy['RSI'],
    'macd_hist': spy['MACD_hist'],
})

# --- Drop rows with any NaNs caused by rolling calculations or pct_change ---
embed_df = embed_df.dropna()

  vix_interp = vix_interp.interpolate(method='time').fillna(method='bfill').fillna(method='ffill')


In [3]:
embed_df

Unnamed: 0_level_0,spy_close,spy_open,spy_volume,vix_close,pct_above_vwap_24h,pct_above_vwap_120h,rsi,macd_hist
Datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2022-11-29 18:00:00,-0.000946,-0.000946,0.000000,21.826667,-0.001298,-0.007022,37.257193,-0.399690
2022-11-29 19:00:00,0.001263,0.001263,9.016634,21.820834,-0.000036,-0.005761,42.185776,0.102505
2022-11-29 20:00:00,0.000442,0.000378,9.093694,21.815001,0.000416,-0.005321,43.848321,0.595475
2022-11-29 21:00:00,0.000063,0.000126,8.145260,21.809167,0.000498,-0.005267,44.095648,0.959736
2022-11-29 22:00:00,0.000252,0.000252,8.193124,21.803334,0.000761,-0.005045,45.136612,1.269906
...,...,...,...,...,...,...,...,...
2025-04-17 12:00:00,0.006068,0.006021,12.061636,30.389999,0.006044,-0.001809,51.705124,1.122160
2025-04-17 13:00:00,-0.000701,-0.000701,11.363938,30.160000,0.005532,-0.002422,50.626905,2.099053
2025-04-17 14:00:00,-0.001965,-0.001965,11.584762,30.030001,0.002791,-0.004386,47.631792,2.011514
2025-04-17 15:00:00,-0.003985,-0.004032,12.108305,29.650000,-0.001900,-0.008391,42.191624,0.570774


In [4]:
# Bond info

# --- Download data ---
ten_year = yf.Ticker("10Y=F").history(period='730d', interval='1h')
two_year = yf.Ticker("2YY=F").history(period='730d', interval='1h')

# Remove timezone
ten_year = ten_year.tz_localize(None)
two_year = two_year.tz_localize(None)

# Reindex to match SPY (hourly)
ten_year = ten_year.reindex(spy.index)
two_year = two_year.reindex(spy.index)

# Interpolate missing values based on time
ten_year = ten_year.interpolate(method='time')
two_year = two_year.interpolate(method='time')

ten_year = ten_year.fillna(method='bfill').fillna(method='ffill')
two_year = two_year.fillna(method='bfill').fillna(method='ffill')

# --- VWAP Calculation ---
def compute_vwap(prices, volumes, window):
    return (prices * volumes).rolling(window=window).sum() / volumes.rolling(window=window).sum()

# --- 10Y ---
ten_year['VWAP_24h'] = compute_vwap(ten_year['Close'], ten_year['Volume'], window=24)
ten_year['VWAP_120h'] = compute_vwap(ten_year['Close'], ten_year['Volume'], window=120)
ten_year['pct_close'] = ten_year['Close'].pct_change()
ten_year['log_volume'] = np.log1p(ten_year['Volume'])
ten_year['pct_above_vwap_24h'] = (ten_year['Close'] - ten_year['VWAP_24h']) / ten_year['VWAP_24h']
ten_year['pct_above_vwap_120h'] = (ten_year['Close'] - ten_year['VWAP_120h']) / ten_year['VWAP_120h']

# --- 2Y ---
two_year['VWAP_24h'] = compute_vwap(two_year['Close'], two_year['Volume'], window=24)
two_year['VWAP_120h'] = compute_vwap(two_year['Close'], two_year['Volume'], window=120)
two_year['pct_close'] = two_year['Close'].pct_change()
two_year['log_volume'] = np.log1p(two_year['Volume'])
two_year['pct_above_vwap_24h'] = (two_year['Close'] - two_year['VWAP_24h']) / two_year['VWAP_24h']
two_year['pct_above_vwap_120h'] = (two_year['Close'] - two_year['VWAP_120h']) / two_year['VWAP_120h']

# --- Yield Curve Spread ---
yield_spread = (ten_year['Close'] - two_year['Close'])  # raw difference, not pct

# --- Drop NaNs caused by pct_change, rolling VWAPs ---
valid_index = spy.index.intersection(
    ten_year.dropna().index
).intersection(
    two_year.dropna().index
)

# Align everything
embed_df = embed_df.loc[valid_index]
ten_year = ten_year.loc[valid_index]
two_year = two_year.loc[valid_index]
yield_spread = yield_spread.loc[valid_index]

# --- Add new columns to embed_df ---
embed_df['ten_yr_close'] = ten_year['pct_close']
embed_df['ten_yr_volume'] = ten_year['log_volume']
embed_df['ten_yr_vwap_24h'] = ten_year['pct_above_vwap_24h']
embed_df['ten_yr_vwap_120h'] = ten_year['pct_above_vwap_120h']

embed_df['two_yr_close'] = two_year['pct_close']
embed_df['two_yr_volume'] = two_year['log_volume']
embed_df['two_yr_vwap_24h'] = two_year['pct_above_vwap_24h']
embed_df['yield_spread_10y_2y'] = yield_spread  # not transformed — already a relative signal

# Final clean-up
embed_df = embed_df.dropna()

  ten_year = ten_year.fillna(method='bfill').fillna(method='ffill')
  two_year = two_year.fillna(method='bfill').fillna(method='ffill')


In [5]:
embed_df

Unnamed: 0_level_0,spy_close,spy_open,spy_volume,vix_close,pct_above_vwap_24h,pct_above_vwap_120h,rsi,macd_hist,ten_yr_close,ten_yr_volume,ten_yr_vwap_24h,ten_yr_vwap_120h,two_yr_close,two_yr_volume,two_yr_vwap_24h,yield_spread_10y_2y
Datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
2022-11-29 18:00:00,-0.000946,-0.000946,0.000000,21.826667,-0.001298,-0.007022,37.257193,-0.399690,-0.006926,0.000000,0.000338,-0.002921,0.000030,1.280934,0.003309,-0.761133
2022-11-29 19:00:00,0.001263,0.001263,9.016634,21.820834,-0.000036,-0.005761,42.185776,0.102505,0.005097,2.397895,0.005416,0.002159,0.000015,1.223775,0.003268,-0.742200
2022-11-29 20:00:00,0.000442,0.000378,9.093694,21.815001,0.000416,-0.005321,43.848321,0.595475,0.000534,2.890372,0.005910,0.002696,0.000015,1.163151,0.003239,-0.740266
2022-11-29 21:00:00,0.000063,0.000126,8.145260,21.809167,0.000498,-0.005267,44.095648,0.959736,0.000267,2.397895,0.006045,0.002995,0.000015,1.098612,0.003220,-0.739333
2022-11-29 22:00:00,0.000252,0.000252,8.193124,21.803334,0.000761,-0.005045,45.136612,1.269906,0.000267,0.693147,0.006274,0.003353,0.000015,1.029619,0.003210,-0.738400
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2025-04-17 05:00:00,0.000842,0.000795,9.610927,31.530001,0.004089,-0.002047,49.869042,6.337219,-0.001159,2.302585,0.000905,-0.016496,-0.000578,0.000000,-0.012086,0.552304
2025-04-17 06:00:00,-0.003551,-0.003458,10.190657,31.450001,0.000675,-0.005536,44.099980,4.656170,-0.001856,3.218876,-0.000666,-0.018552,-0.000578,0.000000,-0.012437,0.546478
2025-04-17 07:00:00,-0.000281,-0.000281,9.840495,30.990000,0.000519,-0.005732,43.670459,3.367148,-0.001627,3.891820,-0.002210,-0.020185,-0.000579,0.000000,-0.012788,0.541652
2025-04-17 08:00:00,0.000844,0.000844,10.709338,31.080000,0.001467,-0.004806,45.388893,2.766646,-0.002561,5.347108,-0.004229,-0.022562,-0.000579,0.000000,-0.013138,0.532826


In [6]:
# Commodites info
# --- Download Gold and Oil futures data ---
gold = yf.Ticker("GC=F").history(period='730d', interval='1h')
oil = yf.Ticker("CL=F").history(period='730d', interval='1h')

# --- Clean timezone ---
gold = gold.tz_localize(None)
oil = oil.tz_localize(None)

# --- Reindex to SPY timestamps ---
gold = gold.reindex(spy.index).interpolate(method='time').fillna(method='bfill').fillna(method='ffill')
oil = oil.reindex(spy.index).interpolate(method='time').fillna(method='bfill').fillna(method='ffill')

# --- VWAP function (already defined above) ---
# def compute_vwap(prices, volumes, window):
#     return (prices * volumes).rolling(window=window).sum() / volumes.rolling(window=window).sum()

# --- Gold features ---
gold['VWAP_24h'] = compute_vwap(gold['Close'], gold['Volume'], window=24)
gold['VWAP_120h'] = compute_vwap(gold['Close'], gold['Volume'], window=120)
gold['pct_close'] = gold['Close'].pct_change()
gold['log_volume'] = np.log1p(gold['Volume'])
gold['pct_above_vwap_24h'] = (gold['Close'] - gold['VWAP_24h']) / gold['VWAP_24h']
gold['pct_above_vwap_120h'] = (gold['Close'] - gold['VWAP_120h']) / gold['VWAP_120h']

# --- Oil features ---
oil['VWAP_24h'] = compute_vwap(oil['Close'], oil['Volume'], window=24)
oil['VWAP_120h'] = compute_vwap(oil['Close'], oil['Volume'], window=120)
oil['pct_close'] = oil['Close'].pct_change()
oil['log_volume'] = np.log1p(oil['Volume'])
oil['pct_above_vwap_24h'] = (oil['Close'] - oil['VWAP_24h']) / oil['VWAP_24h']
oil['pct_above_vwap_120h'] = (oil['Close'] - oil['VWAP_120h']) / oil['VWAP_120h']

# --- Align and clean ---
valid_index = valid_index.intersection(gold.dropna().index).intersection(oil.dropna().index)
embed_df = embed_df.loc[valid_index]
gold = gold.loc[valid_index]
oil = oil.loc[valid_index]

# --- Add to embed_df ---
embed_df['gold_close'] = gold['pct_close']
embed_df['gold_volume'] = gold['log_volume']
embed_df['gold_vwap_24h'] = gold['pct_above_vwap_24h']
embed_df['gold_vwap_120h'] = gold['pct_above_vwap_120h']

embed_df['oil_close'] = oil['pct_close']
embed_df['oil_volume'] = oil['log_volume']
embed_df['oil_vwap_24h'] = oil['pct_above_vwap_24h']
embed_df['oil_vwap_120h'] = oil['pct_above_vwap_120h']

# Final NaN cleanup just in case
embed_df = embed_df.dropna()

  gold = gold.reindex(spy.index).interpolate(method='time').fillna(method='bfill').fillna(method='ffill')
  oil = oil.reindex(spy.index).interpolate(method='time').fillna(method='bfill').fillna(method='ffill')


In [7]:
embed_df

Unnamed: 0_level_0,spy_close,spy_open,spy_volume,vix_close,pct_above_vwap_24h,pct_above_vwap_120h,rsi,macd_hist,ten_yr_close,ten_yr_volume,...,two_yr_vwap_24h,yield_spread_10y_2y,gold_close,gold_volume,gold_vwap_24h,gold_vwap_120h,oil_close,oil_volume,oil_vwap_24h,oil_vwap_120h
Datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2022-11-29 18:00:00,-0.000946,-0.000946,0.000000,21.826667,-0.001298,-0.007022,37.257193,-0.399690,-0.006926,0.000000,...,0.003309,-0.761133,-0.000624,0.000000,-0.002462,0.008161,-0.002281,0.000000,0.001233,0.004530
2022-11-29 19:00:00,0.001263,0.001263,9.016634,21.820834,-0.000036,-0.005761,42.185776,0.102505,0.005097,2.397895,...,0.003268,-0.742200,0.001305,7.596894,-0.001142,0.009452,0.000762,6.825460,0.001991,0.005292
2022-11-29 20:00:00,0.000442,0.000378,9.093694,21.815001,0.000416,-0.005321,43.848321,0.595475,0.000534,2.890372,...,0.003239,-0.740266,0.000737,7.906179,-0.000465,0.010156,-0.001396,8.260493,0.000429,0.003879
2022-11-29 21:00:00,0.000063,0.000126,8.145260,21.809167,0.000498,-0.005267,44.095648,0.959736,0.000267,2.397895,...,0.003220,-0.739333,0.000227,7.633370,-0.000292,0.010339,0.004321,7.823246,0.004417,0.008205
2022-11-29 22:00:00,0.000252,0.000252,8.193124,21.803334,0.000761,-0.005045,45.136612,1.269906,0.000267,0.693147,...,0.003210,-0.738400,0.000793,7.366445,0.000443,0.011089,-0.000126,7.469654,0.004108,0.008069
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2025-04-17 05:00:00,0.000842,0.000795,9.610927,31.530001,0.004089,-0.002047,49.869042,6.337219,-0.001159,2.302585,...,-0.012086,0.552304,-0.000749,8.611230,-0.001063,0.025540,-0.000158,6.963190,0.010273,0.034779
2025-04-17 06:00:00,-0.003551,-0.003458,10.190657,31.450001,0.000675,-0.005536,44.099980,4.656170,-0.001856,3.218876,...,-0.012437,0.546478,0.001799,8.761550,0.000507,0.026971,0.001585,7.631432,0.011398,0.036380
2025-04-17 07:00:00,-0.000281,-0.000281,9.840495,30.990000,0.000519,-0.005732,43.670459,3.367148,-0.001627,3.891820,...,-0.012788,0.541652,0.000030,8.869961,0.000393,0.026649,0.001107,7.341484,0.012067,0.037562
2025-04-17 08:00:00,0.000844,0.000844,10.709338,31.080000,0.001467,-0.004806,45.388893,2.766646,-0.002561,5.347108,...,-0.013138,0.532826,0.000090,9.521128,0.000167,0.025992,0.002845,9.003193,0.013982,0.040165


In [8]:
# --- Download FX Data ---
eur_usd = yf.Ticker("EURUSD=X").history(period='730d', interval='1h')
usd_jpy = yf.Ticker("USDJPY=X").history(period='730d', interval='1h')

eur_usd = eur_usd.tz_localize(None)
usd_jpy = usd_jpy.tz_localize(None)

# --- Reindex and interpolate ---
eur_usd = eur_usd.reindex(spy.index).interpolate(method='time').fillna(method='bfill').fillna(method='ffill')
usd_jpy = usd_jpy.reindex(spy.index).interpolate(method='time').fillna(method='bfill').fillna(method='ffill')

# --- FX Proxy Features ---
for df in [eur_usd, usd_jpy]:
    df['pct_close'] = df['Close'].pct_change()
    df['pseudo_volume'] = (df['Close'] - df['Open']) / df['Open']
    df['sma_24h'] = df['Close'].rolling(window=24).mean()
    df['sma_120h'] = df['Close'].rolling(window=120).mean()
    df['pct_above_sma_24h'] = (df['Close'] - df['sma_24h']) / df['sma_24h']
    df['pct_above_sma_120h'] = (df['Close'] - df['sma_120h']) / df['sma_120h']

# --- Clean up and align ---
valid_index = valid_index \
    .intersection(eur_usd.dropna().index) \
    .intersection(usd_jpy.dropna().index)

eur_usd = eur_usd.loc[valid_index]
usd_jpy = usd_jpy.loc[valid_index]
embed_df = embed_df.loc[valid_index]

# --- Add features to embed_df ---
embed_df['eurusd_close'] = eur_usd['pct_close']
embed_df['eurusd_pseudo_volume'] = eur_usd['pseudo_volume']
embed_df['eurusd_sma_24h'] = eur_usd['pct_above_sma_24h']
embed_df['eurusd_sma_120h'] = eur_usd['pct_above_sma_120h']

embed_df['usdjpy_close'] = usd_jpy['pct_close']
embed_df['usdjpy_pseudo_volume'] = usd_jpy['pseudo_volume']
embed_df['usdjpy_sma_24h'] = usd_jpy['pct_above_sma_24h']
embed_df['usdjpy_sma_120h'] = usd_jpy['pct_above_sma_120h']

# Final NaN check
embed_df = embed_df.dropna()





  eur_usd = eur_usd.reindex(spy.index).interpolate(method='time').fillna(method='bfill').fillna(method='ffill')
  usd_jpy = usd_jpy.reindex(spy.index).interpolate(method='time').fillna(method='bfill').fillna(method='ffill')


In [9]:
embed_df.shape

(13582, 32)

In [10]:
class MarketContextWindowReturnDataset(Dataset):
    def __init__(self, embed_df, raw_close_prices, context_length=512, return_horizon=24):
        """
        embed_df: cleaned + normalized features
        raw_close_prices: Series of raw close prices aligned with embed_df
        """
        self.context_length = context_length
        self.return_horizon = return_horizon
        self.features = torch.tensor(embed_df.values, dtype=torch.float32)
        self.prices = torch.tensor(raw_close_prices.reset_index(drop=True).values, dtype=torch.float32)
        self.total_sequences = len(self.features) - context_length - return_horizon

    def __len__(self):
        return self.total_sequences

    def __getitem__(self, idx):
        x = self.features[idx : idx + self.context_length]  # [512, 8]

        # Compute 24h return for every hour in the context window
        current_prices = self.prices[idx : idx + self.context_length]
        future_prices = self.prices[idx + self.return_horizon : idx + self.return_horizon + self.context_length]

        y = (future_prices - current_prices) / current_prices  # [512]
        y = y.unsqueeze(-1)  # [512, 1]

        return x, y

raw_close_prices = spy['Close'].loc[embed_df.index]
dataset = MarketContextWindowReturnDataset(embed_df, raw_close_prices)

x, y = dataset[0]
print(x.shape)  # torch.Size([512, 8])
print(y.shape)  # torch.Size([512, 1])

# Wrap in DataLoader if needed
from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=32, shuffle=True)

torch.Size([512, 32])
torch.Size([512, 1])


In [11]:


class PositionalEncoding(nn.Module):
    def __init__(self, context_length, d_model):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, context_length, d_model))

    def forward(self, x):
        return x + self.pos_embedding
    

class FeedForward(nn.Module):
    def __init__(self, d_model, expansion=4, dropout=0.1):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * expansion),
            nn.ReLU(),
            nn.Linear(d_model * expansion, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ff(x)


class MultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.n_heads = num_heads
        self.h_dim = embed_dim // num_heads

        self.q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim, bias=False)

        self.proj_out = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        # (B, T, C) -> (B, T, n_heads, h_dim) -> (B, n_heads, T, h_dim)
        q = self.q(x).view(B, T, self.n_heads, self.h_dim).transpose(1, 2)
        k = self.k(x).view(B, T, self.n_heads, self.h_dim).transpose(1, 2)
        v = self.v(x).view(B, T, self.n_heads, self.h_dim).transpose(1, 2)

        x = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # (B, n_heads, T, h_dim) -> (B, T, n_heads, h_dim) -> (B, T, C)
        x = x.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj_out(x)



class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads=2, ff_expansion=4, dropout=0.1):
        super().__init__()
        self.attn = MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout)
        self.ff = FeedForward(d_model, expansion=ff_expansion, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention block
        x = x + self.attn(self.norm1(x))

        # Feedforward block
        x = x + self.ff(self.norm2(x))

        return x

class SimpleTransformer(nn.Module):
    def __init__(self, context_length=512, d_model=8, num_layers=4, n_heads=4, ff_expansion=4):
        super().__init__()
        self.pos_encoder = PositionalEncoding(context_length, d_model)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model=d_model, n_heads=n_heads, ff_expansion=ff_expansion)
            for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(d_model, 1)  # Predicting 1 value per timestep (e.g. 24h future return)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model] — [B, 512, 8]
        x = self.pos_encoder(x)
        for layer in self.layers:
            x = layer(x)

        out = self.output_layer(x)  # [B, 512, 1]
        return out

In [12]:
model = SimpleTransformer(context_length=512, d_model=32)
dummy_input = torch.randn(32, 512, 32)  # [batch, seq_len, features]
out = model(dummy_input)
print(out.shape)  # torch.Size([32, 512, 1])

torch.Size([32, 512, 1])


In [13]:


def train_model(model, dataset, epochs=10, batch_size=128, lr=1e-4, weight_decay=1e-4, device='mps'):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    print (device)
    model = model.to(device)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            x, y = batch  # x: [B, 512, 8], y: [B, 512, 1]
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            output = model(x)  # [B, 512, 1]
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1} — Avg MSE Loss: {avg_loss:.6f}")

In [14]:
model = SimpleTransformer(context_length=512, d_model=32)
num_params = sum([p.numel() for p in model.parameters()])/1.0e3
num_data = embed_df.size/1.0e3
print (f'This model has {num_params}k parameters, vs. {num_data} items of data. The saftey factor is {num_data/num_params:.2f}')

train_model(model, dataset, epochs=30, batch_size=128)

This model has 66.721k parameters, vs. 434.624 items of data. The saftey factor is 6.51
mps


Epoch 1/30: 100%|██████████| 101/101 [00:14<00:00,  6.93it/s]


Epoch 1 — Avg MSE Loss: 68.249312


Epoch 2/30: 100%|██████████| 101/101 [00:14<00:00,  7.18it/s]


Epoch 2 — Avg MSE Loss: 5.045169


Epoch 3/30: 100%|██████████| 101/101 [00:14<00:00,  7.15it/s]


Epoch 3 — Avg MSE Loss: 2.676895


Epoch 4/30: 100%|██████████| 101/101 [00:14<00:00,  7.13it/s]


Epoch 4 — Avg MSE Loss: 1.712132


Epoch 5/30: 100%|██████████| 101/101 [00:14<00:00,  7.19it/s]


Epoch 5 — Avg MSE Loss: 1.179763


Epoch 6/30: 100%|██████████| 101/101 [00:14<00:00,  7.12it/s]


Epoch 6 — Avg MSE Loss: 0.821447


Epoch 7/30: 100%|██████████| 101/101 [00:14<00:00,  7.19it/s]


Epoch 7 — Avg MSE Loss: 0.552434


Epoch 8/30: 100%|██████████| 101/101 [00:14<00:00,  7.16it/s]


Epoch 8 — Avg MSE Loss: 0.365586


Epoch 9/30: 100%|██████████| 101/101 [00:14<00:00,  7.21it/s]


Epoch 9 — Avg MSE Loss: 0.251282


Epoch 10/30: 100%|██████████| 101/101 [00:13<00:00,  7.23it/s]


Epoch 10 — Avg MSE Loss: 0.179174


Epoch 11/30: 100%|██████████| 101/101 [00:14<00:00,  7.19it/s]


Epoch 11 — Avg MSE Loss: 0.135097


Epoch 12/30: 100%|██████████| 101/101 [00:14<00:00,  7.21it/s]


Epoch 12 — Avg MSE Loss: 0.106727


Epoch 13/30: 100%|██████████| 101/101 [00:14<00:00,  7.21it/s]


Epoch 13 — Avg MSE Loss: 0.086131


Epoch 14/30: 100%|██████████| 101/101 [00:14<00:00,  7.20it/s]


Epoch 14 — Avg MSE Loss: 0.070768


Epoch 15/30: 100%|██████████| 101/101 [00:14<00:00,  7.16it/s]


Epoch 15 — Avg MSE Loss: 0.058923


Epoch 16/30: 100%|██████████| 101/101 [00:14<00:00,  7.13it/s]


Epoch 16 — Avg MSE Loss: 0.049925


Epoch 17/30: 100%|██████████| 101/101 [00:14<00:00,  7.15it/s]


Epoch 17 — Avg MSE Loss: 0.042997


Epoch 18/30: 100%|██████████| 101/101 [00:14<00:00,  7.14it/s]


Epoch 18 — Avg MSE Loss: 0.037701


Epoch 19/30: 100%|██████████| 101/101 [00:14<00:00,  7.20it/s]


Epoch 19 — Avg MSE Loss: 0.033549


Epoch 20/30: 100%|██████████| 101/101 [00:14<00:00,  7.19it/s]


Epoch 20 — Avg MSE Loss: 0.030220


Epoch 21/30: 100%|██████████| 101/101 [00:14<00:00,  7.20it/s]


Epoch 21 — Avg MSE Loss: 0.027470


Epoch 22/30: 100%|██████████| 101/101 [00:14<00:00,  7.13it/s]


Epoch 22 — Avg MSE Loss: 0.025242


Epoch 23/30: 100%|██████████| 101/101 [00:14<00:00,  7.17it/s]


Epoch 23 — Avg MSE Loss: 0.023338


Epoch 24/30: 100%|██████████| 101/101 [00:14<00:00,  7.09it/s]


Epoch 24 — Avg MSE Loss: 0.021583


Epoch 25/30: 100%|██████████| 101/101 [00:14<00:00,  7.09it/s]


Epoch 25 — Avg MSE Loss: 0.020064


Epoch 26/30: 100%|██████████| 101/101 [00:14<00:00,  7.09it/s]


Epoch 26 — Avg MSE Loss: 0.018774


Epoch 27/30: 100%|██████████| 101/101 [00:14<00:00,  7.09it/s]


Epoch 27 — Avg MSE Loss: 0.017617


Epoch 28/30: 100%|██████████| 101/101 [00:14<00:00,  7.11it/s]


Epoch 28 — Avg MSE Loss: 0.016577


Epoch 29/30: 100%|██████████| 101/101 [00:14<00:00,  7.09it/s]


Epoch 29 — Avg MSE Loss: 0.015589


Epoch 30/30: 100%|██████████| 101/101 [00:14<00:00,  7.15it/s]

Epoch 30 — Avg MSE Loss: 0.014675





In [15]:
x, y = next(iter(loader))
device = 'mps'
x, y = x.to(device), y.to(device)

x.requires_grad = True
output = model(x)
criterion = nn.MSELoss()

loss = criterion(output, y)
loss.backward()

# x.grad has the gradients w.r.t. each input value
feature_importance = x.grad.abs().mean(dim=[0, 1])  # Mean over batch and time

In [16]:
imps = feature_importance/feature_importance.max()

for column, imp in zip(embed_df.columns, imps.tolist()):
    
    print (f'{column:20s}:\t {imp:.3f}')
    
#feature_importance/feature_importance.max()

spy_close           :	 0.446
spy_open            :	 0.486
spy_volume          :	 0.147
vix_close           :	 0.323
pct_above_vwap_24h  :	 0.290
pct_above_vwap_120h :	 0.507
rsi                 :	 0.162
macd_hist           :	 0.226
ten_yr_close        :	 0.347
ten_yr_volume       :	 0.134
ten_yr_vwap_24h     :	 0.392
ten_yr_vwap_120h    :	 0.822
two_yr_close        :	 0.411
two_yr_volume       :	 0.240
two_yr_vwap_24h     :	 0.635
yield_spread_10y_2y :	 0.512
gold_close          :	 0.492
gold_volume         :	 0.289
gold_vwap_24h       :	 0.740
gold_vwap_120h      :	 0.583
oil_close           :	 0.429
oil_volume          :	 0.160
oil_vwap_24h        :	 1.000
oil_vwap_120h       :	 0.579
eurusd_close        :	 0.717
eurusd_pseudo_volume:	 0.452
eurusd_sma_24h      :	 0.636
eurusd_sma_120h     :	 0.590
usdjpy_close        :	 0.588
usdjpy_pseudo_volume:	 0.645
usdjpy_sma_24h      :	 0.359
usdjpy_sma_120h     :	 0.681


In [17]:
def compute_feature_impact(model, dataset, feature_index, device):
    model.eval()
    model.to(device)

    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=True)
    criterion = nn.MSELoss()

    total_loss = 0.0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            x[:, :, feature_index] = 0.0  # zero out one feature
            output = model(x)
            loss = criterion(output, y)
            total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    return avg_loss

for i in range(16):
    loss = compute_feature_impact(model, dataset, i, device)
    print(f"Feature {embed_df.columns[i]} zeroed out → Loss: {loss:.6f}")

Feature spy_close zeroed out → Loss: 0.010153
Feature spy_open zeroed out → Loss: 0.010153
Feature spy_volume zeroed out → Loss: 0.013096
Feature vix_close zeroed out → Loss: 0.062055
Feature pct_above_vwap_24h zeroed out → Loss: 0.010153
Feature pct_above_vwap_120h zeroed out → Loss: 0.010152
Feature rsi zeroed out → Loss: 0.072310
Feature macd_hist zeroed out → Loss: 0.010433
Feature ten_yr_close zeroed out → Loss: 0.010154
Feature ten_yr_volume zeroed out → Loss: 0.010321
Feature ten_yr_vwap_24h zeroed out → Loss: 0.010154
Feature ten_yr_vwap_120h zeroed out → Loss: 0.010151
Feature two_yr_close zeroed out → Loss: 0.010154
Feature two_yr_volume zeroed out → Loss: 0.010087
Feature two_yr_vwap_24h zeroed out → Loss: 0.010153
Feature yield_spread_10y_2y zeroed out → Loss: 0.010192


In [18]:
embed_df.columns[0]



'spy_close'

In [19]:
two_year

Unnamed: 0_level_0,Open,High,Low,Close,Volume,Dividends,Stock Splits,VWAP_24h,VWAP_120h,pct_close,log_volume,pct_above_vwap_24h,pct_above_vwap_120h
Datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
2022-11-29 18:00:00,4.476133,4.489133,4.476133,4.489133,2.6,0.0,0.0,4.474327,4.484715,0.000030,1.280934,0.003309,0.000985
2022-11-29 19:00:00,4.477200,4.489200,4.477200,4.489200,2.4,0.0,0.0,4.474577,4.484719,0.000015,1.223775,0.003268,0.000999
2022-11-29 20:00:00,4.478267,4.489266,4.478267,4.489266,2.2,0.0,0.0,4.474774,4.484722,0.000015,1.163151,0.003239,0.001013
2022-11-29 21:00:00,4.479333,4.489333,4.479333,4.489333,2.0,0.0,0.0,4.474924,4.484726,0.000015,1.098612,0.003220,0.001027
2022-11-29 22:00:00,4.480400,4.489400,4.480400,4.489400,1.8,0.0,0.0,4.475034,4.484474,0.000015,1.029619,0.003210,0.001098
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2025-04-17 05:00:00,3.758696,3.758696,3.758696,3.758696,0.0,0.0,0.0,3.804681,3.841787,-0.000578,0.000000,-0.012086,-0.021628
2025-04-17 06:00:00,3.756522,3.756522,3.756522,3.756522,0.0,0.0,0.0,3.803830,3.841597,-0.000578,0.000000,-0.012437,-0.022146
2025-04-17 07:00:00,3.754348,3.754348,3.754348,3.754348,0.0,0.0,0.0,3.802979,3.841411,-0.000579,0.000000,-0.012788,-0.022664
2025-04-17 08:00:00,3.752174,3.752174,3.752174,3.752174,0.0,0.0,0.0,3.802128,3.841229,-0.000579,0.000000,-0.013138,-0.023184


In [20]:
oil = yf.Ticker("EURUSD=X").history(period='730d', interval='1h')
oil
oil = yf.Ticker("USDJPY=X").history(period='730d', interval='1h')
oil

Unnamed: 0_level_0,Open,High,Low,Close,Volume,Dividends,Stock Splits
Datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2022-07-04 00:00:00+01:00,135.141006,135.255005,134.960007,135.001999,0,0.0,0.0
2022-07-04 01:00:00+01:00,135.003006,135.108994,134.850006,134.916000,0,0.0,0.0
2022-07-04 02:00:00+01:00,134.919006,135.136993,134.759995,135.000000,0,0.0,0.0
2022-07-04 03:00:00+01:00,135.011002,135.072998,134.899994,135.009995,0,0.0,0.0
2022-07-04 04:00:00+01:00,135.009995,135.179993,134.979996,135.169006,0,0.0,0.0
...,...,...,...,...,...,...,...
2025-04-18 14:00:00+01:00,142.324997,142.356995,142.283005,142.347000,0,0.0,0.0
2025-04-18 15:00:00+01:00,142.343994,142.356003,142.227005,142.289993,0,0.0,0.0
2025-04-18 16:00:00+01:00,142.289001,142.306000,142.195007,142.246002,0,0.0,0.0
2025-04-18 17:00:00+01:00,142.238007,142.251007,142.104996,142.203995,0,0.0,0.0
