# Deep Trading Model — PatchTST / TCN (PyTorch • Apple Silicon MPS) - Personal
**Date:** 2025-09-10

End-to-end **daily swing** pipeline with **deep models**: **PatchTST Transformer** or **TCN**. Includes:
- Heavier **feature set** (momentum/vol/trend, VWAP-like proxies, cross-sectional ranks, rolling betas/correlations, Garman–Klass & Yang–Zhang vol).
- **Triple-barrier meta-label** (trade/no-trade) + **return regression** head.
- **Masked-patch pretraining** (optional) to make use of spare compute, followed by supervised fine-tuning.
- **Gradient checkpointing**, **AMP**, **EMA**, **Cosine LR**, **torch.compile** (safe try).
- **Cost-aware backtest** with your risk rules (prob ≥ 0.55, top-K, ATR TP/SL, –2% day, –6% week pause).

> Default data: end-of-day (EOD) prices from Yahoo Finance for the ETF universe.


In [2]:
# Setup
import json
import os
import platform
import random

import numpy as np
import torch

# Report the runtime so saved outputs can be tied to a specific environment.
print('Python:', platform.python_version())
print('PyTorch:', torch.__version__)

# Device preference order: Apple MPS, then CUDA, then CPU.
has_mps = torch.backends.mps.is_available() and torch.backends.mps.is_built()
has_cuda = torch.cuda.is_available()
if has_mps:
    device = torch.device('mps')
elif has_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

# Best effort: the call is wrapped so that a failure only prints a notice.
try:
    torch.set_float32_matmul_precision('medium')
except Exception as e:
    print('matmul precision setting skipped:', e)

import warnings
warnings.filterwarnings('ignore')
np.set_printoptions(suppress=True, linewidth=120)

def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (plus all CUDA devices if present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# Fix all RNGs before any data or model code runs.
set_seed(seed=42)

# Directory that receives checkpoints, config and scaler dumps.
SAVE_DIR = 'artifacts_pro'
os.makedirs(name=SAVE_DIR, exist_ok=True)


Python: 3.12.7
PyTorch: 2.8.0+cu128
Device: cpu


In [3]:
# Single source of settings for the notebook: data range, labels, features,
# model architecture, training schedule and backtest rules.
CONFIG = {
    # Data
    'UNIVERSE': ['SPY','QQQ','IWM','DIA','XLK','XLF','XLE','XLV','XLY','XLP','XLI','XLB','XLRE','XLU','GLD','TLT','HYG'],
    'START_DATE': '2010-01-01',
    # Rows dated before VAL_START are training data; [VAL_START, TEST_START) is validation.
    'VAL_START':  '2023-01-01',
    'TEST_START': '2024-01-01',
    'END_DATE':   None,
    'DATA_SOURCE': 'yfinance',         # 'yfinance' or 'local'
    'LOCAL_PATH': 'TODO/path/to/csvs', # if local

    # Labeling / trading
    'RET_HORIZON': 5,           # k-day forward regression label
    'TB_ATR_TP':   2.5,         # triple-barrier TP multiple of ATR%
    'TB_ATR_SL':   1.75,        # triple-barrier SL multiple of ATR%
    'TB_TIMEOUT':  5,           # days

    # Features
    'RSI_PERIODS': [3,14],
    'MA_WINDOWS':  [20,50,200],
    'ATR_WINDOW':  14,
    'RET_LAGS':    [1,3,5,10,20],
    'BETA_WINDOW': 60,
    'CORR_WINDOW': 60,

    # Sequence (for deep models)
    'SEQ_LEN':     384, # total timesteps per sample
    'PATCH_LEN':   16, # for PatchTST
    'PATCH_STRIDE': 8,

    # Model choice
    'MODEL_TYPE':  'patchtst',  # 'patchtst' or 'tcn'

    # PatchTST
    'DMODEL':      320,
    'N_HEADS':     8,
    'N_LAYERS':    6,
    'DROPOUT':     0.15,

    # TCN
    # One entry per residual block; block i uses dilation 2**i.
    'TCN_CHANNELS': [128,128,256,256,256,256,256,256],
    'TCN_KERNEL':   7,
    'TCN_DROPOUT':  0.1,

    # Train
    'BATCH_SIZE':  128,
    'LR':          1e-3,
    'WEIGHT_DECAY':1e-4,
    'EPOCHS':      40,
    'PATIENCE':    8,
    'GRAD_CLIP':   1.0,
    'ACCUM_STEPS': 1, # gradient accumulation

    # Pretraining
    'DO_PRETRAIN': True,
    'PRE_EPOCHS':  15,
    'MASK_PROB':   0.3, # fraction of patches masked for reconstruction

    # Loss mixing
    # Total supervised loss = LOSS_META_W * BCE(meta) + LOSS_RET_W * Huber(return).
    'LOSS_META_W': 0.6,
    'LOSS_RET_W':  0.4,

    # Backtest rules
    # NOTE(review): MAX_NET_EXP, COMMISSION_PER_SHARE and COMMISSION_MIN do not
    # appear to be read by the backtest code in this notebook — confirm before relying on them.
    'PROB_CUTOFF': 0.55,
    'TOP_K':       5,
    'HOLD_MAX_DAYS': 5,
    'ATR_MULT_STOP': 1.75,
    'ATR_MULT_TP':   2.5,
    'START_EQUITY':  25_000.0,
    'MAX_GROSS_LEV': 1.5,
    'PER_NAME_CAP':  0.125,
    'MAX_NET_EXP':   0.80,
    'DAILY_LOSS_CAP': -0.02,
    'WEEKLY_LOSS_CAP': -0.06,
    'PAUSE_DAYS_ON_WEEKLY_BREACH': 3,
    'SLIPPAGE_BPS': 3,
    'FEE_BPS_ROUNDTRIP': 1,
    'COMMISSION_PER_SHARE': 0.0035,
    'COMMISSION_MIN': 0.35,
}
# Echo the settings that determine model size and training length.
print(json.dumps({k:CONFIG[k] for k in ['MODEL_TYPE','SEQ_LEN','PATCH_LEN','PATCH_STRIDE','DMODEL','N_LAYERS','BATCH_SIZE','EPOCHS','DO_PRETRAIN']}, indent=2))


{
  "MODEL_TYPE": "patchtst",
  "SEQ_LEN": 384,
  "PATCH_LEN": 16,
  "PATCH_STRIDE": 8,
  "DMODEL": 320,
  "N_LAYERS": 6,
  "BATCH_SIZE": 128,
  "EPOCHS": 40,
  "DO_PRETRAIN": true
}


In [4]:
# Imports for data & modeling
import math, time
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, mean_absolute_error, mean_squared_error, accuracy_score, f1_score
from sklearn.calibration import calibration_curve

try:
    import yfinance as yf
except Exception as e:
    print('Install yfinance if using DATA_SOURCE=yfinance: pip install yfinance')

from collections import defaultdict, deque
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR
import torch.utils.checkpoint as cp


In [5]:
# Data Ingest
def load_data_yf(symbols, start, end=None):
    """Download adjusted daily bars from Yahoo Finance for `symbols`.

    Returns one long frame indexed by ['Date', 'Symbol'] and sorted by index,
    with the per-ticker columns capitalised (Open, High, Low, Close, Volume).
    """
    raw = yf.download(symbols, start=start, end=end, auto_adjust=True, progress=False, group_by='ticker')

    def _long_frame(sym):
        # Slice one ticker out of the ticker-grouped columns and tag it.
        per_symbol = raw[sym].copy()
        per_symbol.columns = [name.capitalize() for name in per_symbol.columns]  # Open, High, Low, Close, Volume
        per_symbol['Symbol'] = sym
        return per_symbol.reset_index().set_index(['Date','Symbol'])

    stacked = pd.concat([_long_frame(sym) for sym in symbols])
    return stacked.sort_index()

def load_data_local(folder, symbols):
    """Read `<folder>/<SYMBOL>.csv` for every symbol into one long OHLCV frame.

    Each CSV must provide Date, Open, High, Low, Close and Volume columns;
    any other columns are dropped. The result is indexed by ['Date', 'Symbol']
    and sorted by index.
    """
    import os

    wanted = ['Date','Symbol','Open','High','Low','Close','Volume']

    def _read_one(sym):
        table = pd.read_csv(os.path.join(folder, f"{sym}.csv"), parse_dates=['Date'])
        table['Symbol'] = sym
        return table[wanted].set_index(['Date','Symbol'])

    combined = pd.concat([_read_one(sym) for sym in symbols])
    return combined.sort_index()

# Pick the loader from the configured data source; anything other than
# 'yfinance' is treated as a local CSV folder.
ohlcv = (
    load_data_yf(CONFIG['UNIVERSE'], CONFIG['START_DATE'], CONFIG['END_DATE'])
    if CONFIG['DATA_SOURCE'] == 'yfinance'
    else load_data_local(CONFIG['LOCAL_PATH'], CONFIG['UNIVERSE'])
)
print('OHLCV shape:', ohlcv.shape)
display(ohlcv.tail(3))


OHLCV shape: (67269, 5)


Unnamed: 0_level_0,Unnamed: 1_level_0,Open,High,Low,Close,Volume
Date,Symbol,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2025-09-25,XLU,86.169998,86.349998,85.419998,85.419998,8971346.0
2025-09-25,XLV,136.125,136.179993,133.725006,134.130005,14085817.0
2025-09-25,XLY,238.110001,238.195007,235.160004,236.220001,5717212.0


In [12]:
# Feature Engineering (heavy but leakage-safe)
import numpy as np
import pandas as pd

def rsi(series, window=14):
    """Relative Strength Index using simple rolling means of gains and losses.

    Returns a Series on a 0-100 scale; entries without a full window are NaN.
    """
    change = series.diff()
    gains = change.clip(lower=0.0)
    losses = (-change).clip(lower=0.0)
    avg_gain = gains.rolling(window).mean()
    avg_loss = losses.rolling(window).mean()
    # The small epsilon keeps the ratio finite when there were no losses.
    rs = avg_gain / (avg_loss + 1e-12)
    return 100 - 100/(1+rs)

def atr(high, low, close, window=14):
    """Average True Range: simple rolling mean of the true range, in price units."""
    last_close = close.shift(1)
    range_candidates = [
        high - low,
        (high - last_close).abs(),
        (low - last_close).abs(),
    ]
    # Row-wise maximum of the three candidates (NaNs are skipped by `max`).
    true_range = pd.concat(range_candidates, axis=1).max(axis=1)
    return true_range.rolling(window).mean()

def garman_klass(o,h,l,c, window=20):
    """Garman-Klass style volatility proxy.

    Computes the per-bar term 0.5*ln(H/L)^2 - (2*ln2 - 1)*ln(C/O)^2 and returns
    the square root of its rolling *sum* over `window` bars, aligned to
    `c.index`. A negative rolling sum yields NaN.
    """
    log_range = np.log(h/l)
    log_body = np.log(c/o)
    per_bar = 0.5*(log_range)**2 - (2*np.log(2)-1)*(log_body)**2
    aligned = pd.Series(per_bar, index=c.index)
    windowed = aligned.rolling(window).sum()
    return windowed.pow(0.5)

def yang_zhang(o,h,l,c, window=20):
    """Yang-Zhang rolling volatility estimate (per-bar, not annualised).

    Combines three rolling components over `window` bars:
      * variance of the overnight return ln(O_t / C_{t-1}),
      * variance of the open-to-close return ln(C_t / O_t), weighted by k,
      * the Rogers-Satchell variance, weighted by (1 - k),
    with k = 0.34 / (1.34 + (window + 1) / (window - 1)).

    Parameters
    ----------
    o, h, l, c : pandas.Series
        Open, high, low and close prices sharing one index.
    window : int
        Rolling window length; must be at least 2.

    Returns
    -------
    pandas.Series
        Square root of the combined variance; NaN until a full window of
        overnight returns is available.
    """
    k = 0.34/(1.34+(window+1)/(window-1))
    # Overnight gap: today's open against YESTERDAY's close.
    overnight = np.log(o/c.shift(1))
    open_to_close = np.log(c/o)
    # Rogers-Satchell per-bar variance term (drift independent).
    rogers_satchell = np.log(h/c)*np.log(h/o) + np.log(l/c)*np.log(l/o)

    var_overnight = overnight.rolling(window).var()
    var_open_close = open_to_close.rolling(window).var()
    var_rs = rogers_satchell.rolling(window).mean()

    variance = var_overnight + k*var_open_close + (1-k)*var_rs
    # Guard against tiny negative values from floating-point round-off.
    return variance.clip(lower=0.0).pow(0.5)

def realized_vol(ret, window=10):
    """Rolling standard deviation of `ret`, scaled by sqrt(252)."""
    rolling_std = ret.rolling(window).std()
    return rolling_std*np.sqrt(252)

def build_features(df, cfg):
    """Build the per-(Date, Symbol) feature table from OHLCV bars.

    Parameters
    ----------
    df : pandas.DataFrame
        MultiIndex ['Date', 'Symbol'] with Open, High, Low, Close columns
        (other columns are carried through). Must contain the symbol 'SPY',
        which is used as the market reference.
    cfg : dict
        Reads RET_LAGS, RSI_PERIODS, ATR_WINDOW, MA_WINDOWS and BETA_WINDOW.
        CORR_WINDOW is used for the rolling correlation when present and
        falls back to BETA_WINDOW otherwise.

    Returns
    -------
    pandas.DataFrame
        Input columns plus engineered features, sorted by index, with every
        row that contains a NaN removed.
    """
    # Expect MultiIndex index: ['Date','Symbol'] and columns: Open, High, Low, Close
    x = df.copy().sort_index()

    # 1) Simple returns and lags
    x['Ret1'] = x['Close'].groupby(level='Symbol').pct_change()
    for L in cfg['RET_LAGS']:
        x[f'Ret_{L}'] = x['Close'].groupby(level='Symbol').pct_change(L)

    # 2) RSI family
    for R in cfg['RSI_PERIODS']:
        x[f'RSI_{R}'] = (
            x.groupby(level='Symbol')['Close']
              .transform(lambda s: rsi(s, R))
        )

    # 3) ATR in price units (not divided by Close)
    x['ATR'] = (
        x.groupby(level='Symbol')
          .apply(lambda g: atr(g['High'], g['Low'], g['Close'], cfg['ATR_WINDOW']))
          .droplevel(0)  # drop Symbol level added by groupby-apply
    )

    # 4) Moving-average distances & slopes
    for W in cfg['MA_WINDOWS']:
        ma = x.groupby(level='Symbol')['Close'].transform(lambda s: s.rolling(W).mean())
        x[f'DistMA_{W}']  = (x['Close'] - ma) / (ma + 1e-12)
        x[f'SlopeMA_{W}'] = ma.groupby(level='Symbol').diff()

    # 5) GK & YZ volatility (window=20)
    x['GK20'] = (
        x.groupby(level='Symbol')
         .apply(lambda g: garman_klass(g['Open'], g['High'], g['Low'], g['Close'], 20))
         .reset_index(level=0, drop=True)
    )
    x['YZ20'] = (
        x.groupby(level='Symbol')
         .apply(lambda g: yang_zhang(g['Open'], g['High'], g['Low'], g['Close'], 20))
         .reset_index(level=0, drop=True)
    )

    # 6) Realized vol on Ret1
    x['RV10'] = x.groupby(level='Symbol')['Ret1'].transform(lambda s: realized_vol(s, 10))

    # 7) Day of week dummies
    xi = x.reset_index()
    xi['DOW'] = xi['Date'].dt.dayofweek
    dummies = pd.get_dummies(xi['DOW'], prefix='DOW', drop_first=True)
    xi = pd.concat([xi.drop(columns=['DOW']), dummies], axis=1).set_index(['Date','Symbol']).sort_index()
    x = xi

    # 8) Market regime from SPY (join on Date to all symbols)
    spy = x.xs('SPY', level='Symbol', drop_level=False).copy()
    spy_ma50 = spy['Close'].rolling(50).mean()
    regime = (spy['Close'] >= spy_ma50).astype(int).rename('SPY_Regime')
    # regime index is (Date, Symbol='SPY'); reduce to Date, then join
    regime_by_date = regime.reset_index('Symbol', drop=True)
    x = x.join(regime_by_date, on='Date')

    # 9) Rolling (time-series) beta & correlation of each symbol's daily return vs SPY
    close = x['Close'].unstack('Symbol')
    ret = close.pct_change()
    spy_ret = ret['SPY']
    beta_window = cfg['BETA_WINDOW']
    # Honour CORR_WINDOW for the correlation; older configs without it keep
    # the previous behaviour of sharing BETA_WINDOW.
    corr_window = cfg.get('CORR_WINDOW', beta_window)

    def roll_corr_beta(s, ref, beta_win, corr_win):
        cov = s.rolling(beta_win).cov(ref)
        var = ref.rolling(beta_win).var()
        beta = cov/(var + 1e-12)
        corr = s.rolling(corr_win).corr(ref)
        return beta, corr

    betas = pd.DataFrame(index=ret.index, columns=ret.columns, dtype=float)
    corrs = pd.DataFrame(index=ret.index, columns=ret.columns, dtype=float)
    for sym in ret.columns:
        b, c = roll_corr_beta(ret[sym], spy_ret, beta_window, corr_window)
        betas[sym] = b; corrs[sym] = c

    beta_long = betas.stack().rename('BetaSPY')
    corr_long = corrs.stack().rename('CorrSPY')
    x = x.join(beta_long, how='left').join(corr_long, how='left')

    # 10) Cross-sectional ranks per day (ensure the columns exist given cfg)
    def add_cs_ranks(frame, cols):
        frame = frame.copy()
        for col in cols:
            if col not in frame.columns:
                continue
            pivot = frame[col].unstack('Symbol')
            ranks = pivot.rank(axis=1, pct=True)
            frame[f'{col}_Rank'] = ranks.stack()
        return frame

    # Example rank set — adjust to your cfg contents
    rank_cols = []
    if any(L == 10 for L in cfg['RET_LAGS']): rank_cols.append('Ret_10')
    rank_cols += ['RV10']
    if any(W == 20 for W in cfg['MA_WINDOWS']): rank_cols.append('DistMA_20')
    x = add_cs_ranks(x, rank_cols)

    # 11) Final clean: drop every row with any NaN (warm-up windows, missing bars)
    x = x.dropna().sort_index()
    return x

# Example usage (unchanged)
# Build the full feature table once and preview its last rows.
features = build_features(df=ohlcv, cfg=CONFIG)
print('Features shape:', features.shape)
display(features.iloc[-3:])


Features shape: (62418, 33)


Unnamed: 0_level_0,Unnamed: 1_level_0,Open,High,Low,Close,Volume,Ret1,Ret_1,Ret_3,Ret_5,Ret_10,...,DOW_1,DOW_2,DOW_3,DOW_4,SPY_Regime,BetaSPY,CorrSPY,Ret_10_Rank,RV10_Rank,DistMA_20_Rank
Date,Symbol,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
2025-09-25,XLU,86.169998,86.349998,85.419998,85.419998,8971346.0,-0.009623,-0.009623,0.002582,0.018777,0.010873,...,False,False,True,False,1,0.300225,0.215991,0.764706,0.823529,0.823529
2025-09-25,XLV,136.125,136.179993,133.725006,134.130005,14085817.0,-0.01657,-0.01657,-0.018298,-0.019834,-0.035618,...,False,False,True,False,1,0.504096,0.296005,0.117647,0.588235,0.176471
2025-09-25,XLY,238.110001,238.195007,235.160004,236.220001,5717212.0,-0.014107,-0.014107,-0.016488,-0.015982,-0.003215,...,False,False,True,False,1,1.261185,0.739777,0.529412,0.705882,0.352941


In [17]:
def triple_barrier_labels(df, cfg):
    """
    Build the two training targets: a triple-barrier class label and a
    forward log return.

    df: MultiIndex [Date, Symbol] with columns at least: Open, High, Low, Close.
        Optionally has ATR or ATR_Pct.
    cfg: expects keys TB_ATR_TP, TB_ATR_SL, TB_TIMEOUT, RET_HORIZON, ATR_WINDOW

    Returns (y_tb, y_ret):
      y_tb  - Series named 'y_meta': 1.0 if the take-profit barrier is touched
              before the stop-loss within TB_TIMEOUT bars, 0.0 if the stop is
              touched first; if neither is touched, 1.0 when the last bar of
              the window closes above the entry close, else 0.0. NaN when the
              entry/ATR is missing or no future bars exist.
      y_ret - Series named 'y_ret': log(Close[t+RET_HORIZON] / Close[t]) per symbol.
    """
    import pandas as pd
    import numpy as np

    # Ensure sorted for forward slices
    df = df.sort_index()

    # Ensure ATR_Pct exists (added to the local copy only; the caller's frame is untouched)
    if 'ATR_Pct' not in df.columns:
        if 'ATR' in df.columns:
            atr_pct = df['ATR'] / (df['Close'] + 1e-12)
        else:
            # compute ATR with window from cfg (fallback to 14)
            win = cfg.get('ATR_WINDOW', 14)
            atr_series = (
                df.groupby(level=1)
                  .apply(lambda g: atr(g['High'], g['Low'], g['Close'], win))
                  .droplevel(0)
            )
            atr_pct = atr_series / (df['Close'] + 1e-12)
        df = df.assign(ATR_Pct=atr_pct)

    # k-day forward log return target
    close = df['Close']
    fwd_close = close.groupby(level=1).shift(-cfg['RET_HORIZON'])
    y_ret = np.log(fwd_close / (close + 1e-12)).rename('y_ret')

    tp_mult, sl_mult, timeout = cfg['TB_ATR_TP'], cfg['TB_ATR_SL'], cfg['TB_TIMEOUT']

    def per_symbol(g):
        # Label every bar of one symbol by scanning its next `timeout` bars.
        res = pd.Series(index=g.index, dtype=float)
        atrp = g['ATR_Pct']
        # iterate forward to check hits within timeout window
        for i, (dt, row) in enumerate(g.iterrows()):
            entry = row['Close']
            # NOTE: this local `atr` shadows the module-level atr() helper inside per_symbol only.
            atr = atrp.iloc[i]
            if pd.isna(entry) or pd.isna(atr):
                res.iloc[i] = np.nan
                continue

            # Barriers are set relative to the entry close, scaled by ATR as a fraction of price.
            tp = entry * (1.0 + tp_mult * atr)
            sl = entry * (1.0 - sl_mult * atr)

            fut = g.iloc[i+1 : i+1+timeout]
            outcome = np.nan
            for _, r in fut.iterrows():
                # The stop is tested first, so a bar touching both barriers counts as a loss.
                if r['Low'] <= sl:
                    outcome = 0.0
                    break
                if r['High'] >= tp:
                    outcome = 1.0
                    break

            # Neither barrier touched: fall back to the sign of the close at the end of the window.
            if np.isnan(outcome):
                if len(fut) == 0:
                    outcome = np.nan
                else:
                    outcome = 1.0 if fut.iloc[-1]['Close'] > entry else 0.0

            res.iloc[i] = outcome
        return res

    # BUGFIX: use df here (not outer 'features')
    y_tb = df.groupby(level=1, group_keys=False).apply(per_symbol).rename('y_meta')
    return y_tb, y_ret

# Re-run labels
# Attach both targets to the feature table and keep only fully labelled rows.
y_meta, y_ret = triple_barrier_labels(features, CONFIG)
_labelled = features.join([y_meta, y_ret], how='left')
data = _labelled.dropna().copy()
print('Labeled rows:', len(data))
display(data[['y_meta','y_ret']].head(3))


Labeled rows: 62333


Unnamed: 0_level_0,Unnamed: 1_level_0,y_meta,y_ret
Date,Symbol,Unnamed: 2_level_1,Unnamed: 3_level_1
2010-10-19,DIA,1.0,0.017329
2010-10-19,GLD,1.0,0.005901
2010-10-19,HYG,1.0,0.011526


In [None]:
# Time-based splits & scaling
def time_split(df_index, val_start, test_start):
    """Return boolean masks (train, val, test) over `df_index` by date.

    Level 0 of `df_index` holds the dates. Train is everything before
    `val_start`, validation is [val_start, test_start), test is the rest.
    """
    val_begin = pd.to_datetime(val_start)
    test_begin = pd.to_datetime(test_start)
    when = df_index.get_level_values(0)

    in_train = when < val_begin
    in_test = when >= test_begin
    in_val = (when >= val_begin) & (when < test_begin)
    return in_train, in_val, in_test

train_m, val_m, test_m = time_split(data.index, CONFIG['VAL_START'], CONFIG['TEST_START'])

# Everything that is not a target column is treated as a model input.
Y_cols = ['y_meta','y_ret']
X_cols = [c for c in data.columns if c not in Y_cols]
X = data[X_cols].astype(float)
Ym = data['y_meta'].astype(int)
Yr = data['y_ret'].astype(float)

X_train, X_val, X_test = X[train_m], X[val_m], X[test_m]
Ym_train, Ym_val, Ym_test = Ym[train_m], Ym[val_m], Ym[test_m]
Yr_train, Yr_val, Yr_test = Yr[train_m], Yr[val_m], Yr[test_m]

# The scaler is fitted on the training rows only and then applied to val/test.
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_s = scaler.fit_transform(X_train)
X_val_s, X_test_s = scaler.transform(X_val), scaler.transform(X_test)

print('Shapes:')
print('Train:', X_train_s.shape, 'Val:', X_val_s.shape, 'Test:', X_test_s.shape)


In [None]:
# Sequence Dataset (sliding windows over each symbol)
from typing import List, Tuple

def build_panel_arrays(X_df, Ym_s, Yr_s, seq_len):
    """Split a long (Date, Symbol) table into per-symbol NumPy arrays.

    Returns {symbol: {'X': array[T, F], 'Ym': array[T], 'Yr': array[T],
    'dates': index}}. `seq_len` is accepted for call-site symmetry but is not
    used here.
    """
    def _aligned(labels, dates, sym):
        # Label series may be keyed by (Date, Symbol) or by Date alone.
        if isinstance(labels.index, pd.MultiIndex):
            return labels.loc[(dates, sym)].values
        return labels.loc[dates].values

    panel = {}
    for sym, block in X_df.groupby(level=1):
        by_date = block.droplevel(1)
        dates = by_date.index
        meta_vals = _aligned(Ym_s, dates, sym)
        ret_vals = _aligned(Yr_s, dates, sym)
        panel[sym] = {'X': by_date.values, 'Ym': meta_vals, 'Yr': ret_vals, 'dates': dates}
    return panel

def make_sequences(panel, seq_len):
    """Cut every symbol's history into sliding windows of `seq_len` rows.

    Parameters
    ----------
    panel : dict
        {symbol: {'X': array[T, F], 'Ym': array[T], 'Yr': array[T], 'dates': index}}.
    seq_len : int
        Window length in rows; must be >= 1.

    Returns
    -------
    list of (symbol, date, x, ym, yr)
        `x` is the window of features ending at `date` (inclusive); `ym` and
        `yr` are the labels of that last row. Symbols with fewer than
        `seq_len` rows contribute nothing.

    Raises
    ------
    ValueError
        If `seq_len` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    seqs = []
    for sym, d in panel.items():
        Xv, Ymv, Yrv, dates = d['X'], d['Ym'], d['Yr'], d['dates']
        T = len(Xv)
        # `end` is exclusive, so it runs up to T inclusive: the final row of
        # each symbol is a valid prediction target too.
        for end in range(seq_len, T + 1):
            last = end - 1  # predict at the last step of the window
            seqs.append((sym, dates[last], Xv[end - seq_len:end], Ymv[last], Yrv[last]))
    return seqs

# Build windows over the FULL history and assign each one to a split by the
# date of its last (target) row. Building panels per split would stop
# validation/test windows from reaching back into earlier data, and any split
# shorter than SEQ_LEN rows per symbol would yield no sequences at all.
# NOTE(review): training targets in the last RET_HORIZON days before VAL_START
# still carry labels computed from validation-period prices — consider an embargo.
full_panel = build_panel_arrays(X, Ym, Yr, CONFIG['SEQ_LEN'])
all_seqs = make_sequences(full_panel, CONFIG['SEQ_LEN'])

_val_start = pd.to_datetime(CONFIG['VAL_START'])
_test_start = pd.to_datetime(CONFIG['TEST_START'])

# Each sequence tuple is (symbol, target_date, x, ym, yr).
train_seqs = [s for s in all_seqs if s[1] < _val_start]
val_seqs   = [s for s in all_seqs if _val_start <= s[1] < _test_start]
test_seqs  = [s for s in all_seqs if s[1] >= _test_start]

print('Num sequences:', len(train_seqs), len(val_seqs), len(test_seqs))

class SeqDataset(Dataset):
    """Torch dataset over pre-cut (symbol, date, x, ym, yr) sequence tuples.

    Each item is `(x, ym, yr, symbol, date)` where `x` is a float32 tensor of
    shape (L, F) scaled with `scaler`, `ym` / `yr` are float32 tensors of
    shape (1,), and `date` is an ISO-8601 string.

    The date is returned as a string because the DataLoader's default collate
    function cannot batch `pandas.Timestamp` objects; strings are passed
    through as a list and can be parsed back with `pd.Timestamp(...)`.
    """

    def __init__(self, seqs, scaler: "StandardScaler"):
        self.seqs = seqs
        self.scaler = scaler

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, i):
        sym, dt, x, ym, yr = self.seqs[i]
        x_s = self.scaler.transform(x) if x.ndim == 2 else x  # (L,F)
        # Timestamps and datetimes expose isoformat(); anything else is stringified.
        dt_str = dt.isoformat() if hasattr(dt, 'isoformat') else str(dt)
        return (
            torch.tensor(x_s, dtype=torch.float32),
            torch.tensor([ym], dtype=torch.float32),
            torch.tensor([yr], dtype=torch.float32),
            sym,
            dt_str,
        )

train_ds = SeqDataset(train_seqs, scaler)
val_ds   = SeqDataset(val_seqs,   scaler)
test_ds  = SeqDataset(test_seqs,  scaler)

import multiprocessing as mp
from torch.utils.data import DataLoader

# Worker processes started with 'spawn'/'forkserver' (the macOS and Windows
# defaults) must pickle the dataset, which fails for a class defined inside a
# notebook. Only use workers when they are forked; otherwise load in-process.
_can_fork = mp.get_start_method() == 'fork'
_train_workers = 4 if _can_fork else 0
_eval_workers = 2 if _can_fork else 0
# Pinned host memory only speeds up host-to-CUDA copies; it is not used by MPS.
_pin = (device.type == 'cuda')

train_loader = DataLoader(train_ds, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=_train_workers, pin_memory=_pin, persistent_workers=False)
val_loader   = DataLoader(val_ds,   batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=_eval_workers, pin_memory=_pin)
test_loader  = DataLoader(test_ds,  batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=_eval_workers, pin_memory=_pin)


In [None]:
# Models: PatchTST & TCN (multi-task heads)
import torch.nn as nn
import torch.utils.checkpoint as cp
import torch

class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention then feed-forward.

    Parameters
    ----------
    d_model : int
        Token embedding width.
    n_heads : int
        Number of attention heads.
    p : float
        Dropout probability used in attention, the feed-forward layer and on
        both residual branches.
    ff_mult : int
        Hidden width of the feed-forward layer as a multiple of `d_model`.

    Input and output are (batch, tokens, d_model).
    """

    def __init__(self, d_model, n_heads, p=0.1, ff_mult=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=p, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_mult*d_model),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(ff_mult*d_model, d_model),
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(p)

    def _attend(self, x):
        """Self-attention output without the attention weights."""
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

    def forward(self, x):
        # Activation checkpointing only pays off when a backward pass follows.
        # The non-reentrant variant is requested explicitly: it does not need
        # an input with requires_grad=True for gradients to reach the weights.
        if self.training and torch.is_grad_enabled():
            a = cp.checkpoint(self._attend, x, use_reentrant=False)
        else:
            a = self._attend(x)
        x = self.ln1(x + self.drop(a))
        f = self.ff(x)
        x = self.ln2(x + self.drop(f))
        return x

class PatchTST(nn.Module):
    """Patch-based Transformer encoder with trade / return heads.

    The (B, L, F) input is cut into overlapping patches of `PATCH_LEN` steps
    taken every `PATCH_STRIDE` steps. Each patch is flattened, linearly
    embedded, given a fixed sinusoidal position code and passed through
    `N_LAYERS` Transformer blocks. The last patch's representation feeds two
    linear heads.

    `forward` returns `(meta_logit, ret, recon_loss)`:
      * meta_logit : (B,) logit of the trade / no-trade label,
      * ret        : (B,) predicted forward return,
      * recon_loss : None unless `pretrain_mask` is given, otherwise the MSE
                     between reconstructed and original masked patches.

    The position code has no parameters, so the state_dict layout matches a
    model built without it; the computed function does differ.
    """

    def __init__(self, in_dim, cfg):
        super().__init__()
        self.patch_len = cfg['PATCH_LEN']; self.stride = cfg['PATCH_STRIDE']
        self.d_model = cfg['DMODEL']; self.n_layers = cfg['N_LAYERS']; self.n_heads = cfg['N_HEADS']
        self.proj = nn.Linear(in_dim*self.patch_len, self.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(self.d_model, self.n_heads, p=cfg['DROPOUT']) for _ in range(self.n_layers)])
        self.norm = nn.LayerNorm(self.d_model)
        self.head_meta = nn.Sequential(nn.Linear(self.d_model, 1))
        self.head_ret  = nn.Sequential(nn.Linear(self.d_model, 1))
        self.recon = nn.Linear(self.d_model, in_dim*self.patch_len)

    @staticmethod
    def _sinusoidal_positions(n, d, device, dtype):
        """Return an (n, d) table of fixed sine/cosine position codes."""
        import math

        half = (d + 1) // 2
        pos = torch.arange(n, device=device, dtype=torch.float32).unsqueeze(1)
        freq = torch.exp(
            torch.arange(half, device=device, dtype=torch.float32)
            * (-math.log(10000.0) / max(half, 1))
        )
        angles = pos * freq  # (n, half)
        table = torch.zeros(n, d, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles[:, : d // 2])  # one column fewer when d is odd
        return table.to(dtype)

    def patchify(self, x):  # x: (B, L, F)
        """Cut `x` into flattened patches of shape (B, n_patches, F*patch_len).

        Within a patch the values are ordered time-major (all features of
        step 0, then step 1, ...).
        """
        B, L, _ = x.shape
        if L < self.patch_len:
            raise ValueError(
                f"sequence length {L} is shorter than patch length {self.patch_len}"
            )
        windows = x.unfold(1, self.patch_len, self.stride)  # (B, n, F, patch_len)
        return windows.permute(0, 1, 3, 2).reshape(B, windows.size(1), -1)

    def forward(self, x, pretrain_mask=None):
        patches = self.patchify(x)
        z = self.proj(patches)
        if pretrain_mask is not None:
            # Hide the masked patches from the encoder; otherwise the
            # reconstruction target is visible in the input and the task is trivial.
            z = z.masked_fill(pretrain_mask.unsqueeze(-1), 0.0)
        # Self-attention is order-agnostic, so patch order must be injected explicitly.
        z = z + self._sinusoidal_positions(z.size(1), z.size(2), z.device, z.dtype)
        for blk in self.blocks:
            z = blk(z)
        h = self.norm(z)[:, -1]
        meta_logit = self.head_meta(h).squeeze(-1)
        ret = self.head_ret(h).squeeze(-1)

        recon_loss = None
        if pretrain_mask is not None:
            if bool(pretrain_mask.any()):
                pred_vec = self.recon(z[pretrain_mask])
                target_vec = patches.detach()[pretrain_mask]
                recon_loss = nn.functional.mse_loss(pred_vec, target_vec)
            else:
                # Nothing was masked in this batch: a zero loss (still attached
                # to the graph) avoids the NaN that MSE over no elements gives.
                recon_loss = z.sum() * 0.0
        return meta_logit, ret, recon_loss

class TCNBlock(nn.Module):
    """Residual block of two dilated, causal 1-D convolutions.

    Parameters
    ----------
    c_in, c_out : int
        Input and output channel counts; a 1x1 convolution adapts the
        residual path when they differ.
    k : int
        Kernel size.
    d : int
        Dilation.
    p : float
        Dropout probability after each ReLU.

    Input is (batch, c_in, length); output is (batch, c_out, length). Each
    convolution pads by (k-1)*d and the output is cropped to the first
    `length` positions, so position t only depends on inputs at or before t.
    """

    def __init__(self, c_in, c_out, k=7, d=1, p=0.1):
        super().__init__()
        pad = (k-1)*d
        self.net = nn.Sequential(
            nn.utils.weight_norm(nn.Conv1d(c_in, c_out, k, padding=pad, dilation=d)),
            nn.ReLU(), nn.Dropout(p),
            nn.utils.weight_norm(nn.Conv1d(c_out, c_out, k, padding=pad, dilation=d)),
            nn.ReLU(), nn.Dropout(p),
        )
        self.res = nn.Conv1d(c_in, c_out, 1) if c_in!=c_out else nn.Identity()

    def forward(self, x):
        # Non-reentrant checkpointing is required here: the first block sees
        # raw input data (requires_grad=False), and the reentrant variant
        # would then produce an output with no gradient path to the weights.
        if self.training and torch.is_grad_enabled():
            out = cp.checkpoint(self.net, x, use_reentrant=False)
        else:
            out = self.net(x)
        # Crop the extra right-hand positions introduced by the padding.
        if out.size(-1) != x.size(-1):
            out = out[..., :x.size(-1)]
        return out + self.res(x)

class TCN(nn.Module):
    """Stack of TCNBlocks with doubling dilation and two linear heads.

    `forward` takes (B, L, F) input and returns `(meta_logit, ret, None)`,
    both heads reading the features at the last time step. `pretrain_mask`
    is accepted for interface parity with PatchTST and ignored.
    """

    def __init__(self, in_dim, cfg):
        super().__init__()
        widths = cfg['TCN_CHANNELS']
        kernel = cfg['TCN_KERNEL']
        dropout = cfg['TCN_DROPOUT']

        blocks = []
        prev = in_dim
        for depth, width in enumerate(widths):
            # Dilation doubles with depth: 1, 2, 4, ...
            blocks.append(TCNBlock(prev, width, k=kernel, d=2**depth, p=dropout))
            prev = width

        self.tcn = nn.Sequential(*blocks)
        self.head_meta = nn.Linear(prev, 1)
        self.head_ret  = nn.Linear(prev, 1)

    def forward(self, x, pretrain_mask=None):
        # Conv1d expects channels first: (B, L, F) -> (B, F, L).
        feats = self.tcn(x.transpose(1,2))
        last_step = feats[:,:,-1]
        meta_logit = self.head_meta(last_step).squeeze(-1)
        ret = self.head_ret(last_step).squeeze(-1)
        return meta_logit, ret, None


In [None]:
# Train / Validate with AMP, checkpointing, EMA, early stopping
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict
from torch.optim.swa_utils import AveragedModel
import numpy as np
import torch

def make_model(in_dim, cfg):
    """Instantiate the configured network on the global `device`.

    Parameters
    ----------
    in_dim : int
        Number of input features per time step.
    cfg : dict
        `cfg['MODEL_TYPE'] == 'patchtst'` selects PatchTST; any other value
        selects TCN.

    Returns
    -------
    nn.Module
        The model itself (never a wrapper), optionally compiled in place.
    """
    if cfg['MODEL_TYPE']=='patchtst':
        model = PatchTST(in_dim, cfg)
    else:
        model = TCN(in_dim, cfg)
    model = model.to(device)
    try:
        # Compile in place rather than via torch.compile(model): the wrapper
        # returned by torch.compile is not an instance of PatchTST/TCN (which
        # defeats isinstance checks) and prefixes state_dict keys with '_orig_mod.'.
        if hasattr(model, 'compile'):
            model.compile()
            print('torch.compile enabled')
        else:
            print('torch.compile skipped:', 'nn.Module.compile is not available in this PyTorch version')
    except Exception as e:
        # NOTE: compilation is lazy, so backend errors can still surface on the first forward pass.
        print('torch.compile skipped:', e)
    return model

def run_epoch(model, loader, optimizers=None, cfg=None, pretrain=False):
    """Run one pass over `loader`, training when `optimizers` is given.

    Parameters
    ----------
    model : nn.Module
        Called as `model(x, pretrain_mask=...)`; returns
        `(meta_logit, ret_pred, recon_loss)`.
    loader : DataLoader
        Yields `(x, y_meta, y_ret, symbol, date)` batches.
    optimizers : dict or None
        `{'opt': optimizer, 'sched': scheduler-or-None}`. None means
        evaluation: no gradients are tracked and no weights change.
    cfg : dict or None
        Settings to use; falls back to the global CONFIG when None. Reads
        PATCH_LEN, PATCH_STRIDE, MASK_PROB, LOSS_META_W, LOSS_RET_W,
        GRAD_CLIP and (optionally) ACCUM_STEPS.
    pretrain : bool
        When True and the model is a PatchTST, random patches are masked and
        the reconstruction loss is optimised instead of the supervised loss.

    Returns
    -------
    dict
        `{'loss': mean loss per sample}` plus `'AUC'`, `'MAE'`, `'RMSE'` when
        `pretrain` is False. Metrics are NaN when they cannot be computed.

    Notes
    -----
    The scheduler is stepped once per call (i.e. per epoch), which matches a
    scheduler whose horizon is expressed in epochs.
    """
    cfg = CONFIG if cfg is None else cfg
    is_train = optimizers is not None
    model.train(mode=is_train)

    # A torch.compile wrapper keeps the real module under `_orig_mod`.
    base_model = getattr(model, '_orig_mod', model)
    mask_patches = pretrain and isinstance(base_model, PatchTST)

    bce = nn.BCEWithLogitsLoss()
    huber = nn.HuberLoss(delta=0.01)
    opt = optimizers['opt'] if is_train else None
    sched = optimizers.get('sched') if is_train else None
    accum = max(1, int(cfg.get('ACCUM_STEPS', 1)))
    n_batches = len(loader)

    use_amp = device.type in ['cuda','mps']
    amp_dtype = torch.float16
    if device.type == 'cuda' and torch.cuda.is_bf16_supported():
        # bfloat16 keeps float32's exponent range, so no loss scaling is needed.
        amp_dtype = torch.bfloat16

    total_loss, n_seen = 0.0, 0
    all_meta, all_meta_p = [], []
    all_ret, all_ret_p = [], []

    if is_train:
        opt.zero_grad(set_to_none=True)

    for step, (xb, ymb, yrb, _, _) in enumerate(loader):
        xb = xb.to(device)    # (B, L, F)
        ymb = ymb.to(device).squeeze(-1).float()
        yrb = yrb.to(device).squeeze(-1).float()

        pre_mask = None
        if mask_patches:
            B, L, _ = xb.shape
            N = 1 + (L - cfg['PATCH_LEN'])//cfg['PATCH_STRIDE']
            pre_mask = torch.rand((B, N), device=device) < cfg['MASK_PROB']

        # Evaluation must not build an autograd graph.
        with torch.set_grad_enabled(is_train):
            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                meta_logit, ret_pred, recon_loss = model(xb, pretrain_mask=pre_mask)
                if pretrain and recon_loss is not None:
                    loss = recon_loss
                else:
                    loss_meta = bce(meta_logit, ymb)
                    loss_ret  = huber(ret_pred, yrb)
                    loss = cfg['LOSS_META_W']*loss_meta + cfg['LOSS_RET_W']*loss_ret

            if is_train:
                # Scale so that accumulated gradients average over the micro-batches.
                (loss / accum).backward()
                if (step + 1) % accum == 0 or (step + 1) == n_batches:
                    nn.utils.clip_grad_norm_(model.parameters(), cfg['GRAD_CLIP'])
                    opt.step()
                    opt.zero_grad(set_to_none=True)

        batch_size = xb.size(0)
        total_loss += loss.item()*batch_size
        n_seen += batch_size

        if not pretrain:
            # Cast to float32 first: NumPy has no bfloat16 dtype.
            all_meta.append(ymb.detach().float().cpu().numpy())
            all_meta_p.append(torch.sigmoid(meta_logit).detach().float().cpu().numpy())
            all_ret.append(yrb.detach().float().cpu().numpy())
            all_ret_p.append(ret_pred.detach().float().cpu().numpy())

    if is_train and sched is not None:
        sched.step()

    avg_loss = total_loss / n_seen if n_seen > 0 else float('nan')
    out = {'loss': avg_loss}
    if not pretrain:
        ym = np.concatenate(all_meta) if len(all_meta)>0 else np.array([])
        yp = np.concatenate(all_meta_p) if len(all_meta_p)>0 else np.array([])
        yr = np.concatenate(all_ret) if len(all_ret)>0 else np.array([])
        yrp= np.concatenate(all_ret_p) if len(all_ret_p)>0 else np.array([])
        # AUC is undefined unless both classes are present.
        if len(ym)>0 and len(np.unique(ym))==2:
            try:
                auc = roc_auc_score(ym, yp)
            except ValueError:
                auc = float('nan')
        else:
            auc = float('nan')
        mae = mean_absolute_error(yr, yrp) if len(yr)>0 else float('nan')
        rmse = (mean_squared_error(yr, yrp)**0.5) if len(yr)>0 else float('nan')
        out.update({'AUC': auc, 'MAE': mae, 'RMSE': rmse})
    return out

import copy

in_dim = X_train.shape[1]
model = make_model(in_dim, CONFIG)
opt = optim.AdamW(model.parameters(), lr=CONFIG['LR'], weight_decay=CONFIG['WEIGHT_DECAY'])
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=CONFIG['EPOCHS'])
# NOTE(review): AveragedModel without `avg_fn` keeps an equal-weight running
# mean of the weights rather than an exponential one, and `ema` is not used
# for evaluation or checkpointing below.
ema = AveragedModel(model)

# A torch.compile wrapper is not an instance of the wrapped class; look through it.
_base_model = getattr(model, '_orig_mod', model)

# Optional pretraining
if CONFIG['DO_PRETRAIN'] and isinstance(_base_model, PatchTST):
    print('Pretraining (masked patches):', CONFIG['PRE_EPOCHS'], 'epochs')
    for ep in range(1, CONFIG['PRE_EPOCHS']+1):
        tr = run_epoch(model, train_loader, optimizers={'opt':opt, 'sched':None}, cfg=CONFIG, pretrain=True)
        vl = run_epoch(model, val_loader, optimizers=None, cfg=CONFIG, pretrain=True)
        print(f'[PRE] Ep {ep:02d} | train_loss={tr["loss"]:.4f} | val_loss={vl["loss"]:.4f}')

best_val, best_state, patience = -np.inf, None, CONFIG['PATIENCE']
history = defaultdict(list)

# Supervised fine-tune
for ep in range(1, CONFIG['EPOCHS']+1):
    tr = run_epoch(model, train_loader, optimizers={'opt':opt, 'sched':sched}, cfg=CONFIG, pretrain=False)
    vl = run_epoch(model, val_loader, optimizers=None, cfg=CONFIG, pretrain=False)
    ema.update_parameters(model)
    history['train_loss'].append(tr['loss']); history['val_loss'].append(vl['loss'])
    history['train_auc'].append(tr.get('AUC', np.nan)); history['val_auc'].append(vl.get('AUC', np.nan))
    print(f'Ep {ep:02d} | train_loss={tr["loss"]:.4f} AUC={tr.get("AUC",np.nan):.4f} | val_loss={vl["loss"]:.4f} AUC={vl.get("AUC",np.nan):.4f}')

    # Select on validation AUC; when it is undefined (e.g. a single class in
    # the validation set) fall back to the negated validation loss so that
    # model selection still works.
    score = vl.get('AUC', float('nan'))
    if not np.isfinite(score):
        score = -vl['loss']

    if score > best_val + 1e-4:
        # state_dict() returns references to the live tensors; without a copy
        # the "best" snapshot would silently follow later training steps.
        best_val = score; best_state = copy.deepcopy(model.state_dict()); patience = CONFIG['PATIENCE']
        torch.save(best_state, os.path.join(SAVE_DIR, 'best_pro.pt'))
    else:
        patience -= 1
        if patience <= 0:
            print('Early stopping.'); break

if best_state is not None:
    model.load_state_dict(best_state)


In [None]:
# Plot training curves
# One figure per metric, each with the train and validation series.
for _title, _train_key, _val_key in (('Loss', 'train_loss', 'val_loss'), ('AUC', 'train_auc', 'val_auc')):
    plt.figure()
    plt.plot(history[_train_key], label='train')
    plt.plot(history[_val_key], label='val')
    plt.title(_title)
    plt.legend()
    plt.show()


In [None]:
# Test evaluation
def predict_loader(model, loader):
    """Run `model` over `loader` in eval mode and collect predictions.

    Parameters
    ----------
    model : nn.Module
        Called as `model(x, pretrain_mask=None)`; returns
        `(meta_logit, ret_pred, _)`.
    loader : iterable
        Yields `(x, y_meta, y_ret, symbols, dates)` batches. Dates may be
        anything `pd.Timestamp` accepts (strings or timestamps).

    Returns
    -------
    pandas.DataFrame
        Indexed by ['Date', 'Symbol'] with columns `p_trade` (sigmoid of the
        meta logit), `y_meta`, `y_ret` and `ret_pred`.
    """
    model.eval()
    # Run on whatever device the model lives on instead of relying on a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    all_meta_p, all_ret_p, metas, rets, idx = [], [], [], [], []
    with torch.no_grad():
        for xb, ymb, yrb, sym, dt in loader:
            xb = xb.to(run_device)
            ml, rp, _ = model(xb, pretrain_mask=None)
            # reshape(-1) keeps a 1-D array even for a batch of one, where
            # squeeze() would collapse to 0-D and break np.concatenate.
            all_meta_p.append(torch.sigmoid(ml).float().cpu().numpy().reshape(-1))
            all_ret_p.append(rp.float().cpu().numpy().reshape(-1))
            metas.append(ymb.numpy().reshape(-1))
            rets.append(yrb.numpy().reshape(-1))
            idx.extend((pd.Timestamp(d), s) for s, d in zip(sym, dt))
    meta_p = np.concatenate(all_meta_p); ret_p = np.concatenate(all_ret_p)
    meta = np.concatenate(metas); ret = np.concatenate(rets)
    idx = pd.MultiIndex.from_tuples(idx, names=['Date','Symbol'])
    return pd.DataFrame({'p_trade': meta_p, 'y_meta': meta, 'y_ret': ret, 'ret_pred': ret_p}, index=idx)

# Score the held-out test predictions.
df_test_pred = predict_loader(model, test_loader)

_true_ret, _pred_ret = df_test_pred['y_ret'], df_test_pred['ret_pred']
try:
    auc_test = roc_auc_score(df_test_pred['y_meta'], df_test_pred['p_trade'])
except Exception:
    # AUC cannot be computed for every label set; report NaN instead of failing.
    auc_test = float('nan')
mae_test = mean_absolute_error(_true_ret, _pred_ret)
rmse_test = mean_squared_error(_true_ret, _pred_ret)**0.5

_summary = {
    'AUC_test': round(auc_test,4),
    'MAE_test': round(mae_test,6),
    'RMSE_test': round(rmse_test,6),
}
print(_summary)
display(df_test_pred.head(3))


In [None]:
# Backtest (uses p_trade and ATR TP/SL)
def daily_vol(series, window=10):
    """Rolling standard deviation of the simple returns of `series`."""
    simple_returns = series.pct_change()
    return simple_returns.rolling(window).std()

def select_positions(prob_df, date, top_k=CONFIG['TOP_K'], cutoff=CONFIG['PROB_CUTOFF']):
    """Return up to `top_k` symbols for `date`, best `p_trade` first.

    Only rows with `p_trade >= cutoff` are eligible. An unknown date gives an
    empty list. The defaults are read from CONFIG once, when the function is
    defined.
    """
    known_dates = prob_df.index.get_level_values(0)
    if date not in known_dates:
        return []
    todays = prob_df.xs(date, level=0, drop_level=False)
    eligible = todays[todays['p_trade'] >= cutoff]
    ranked = eligible.sort_values('p_trade', ascending=False)
    chosen = ranked.head(top_k)
    return list(chosen.index.get_level_values(1))

def simulate_trade_path(df, symbol, start_date):
    """Simulate one long trade entered at the close of `start_date`.

    The trade exits at the first of: stop-loss touched, take-profit touched,
    or the close of the last bar within `HOLD_MAX_DAYS`. Barrier distances
    are ATR multiples (`ATR_MULT_STOP`, `ATR_MULT_TP`) of the entry price.
    Within one bar the stop is tested first. When a bar opens beyond a
    barrier, the fill is the open price rather than the barrier price, since
    the barrier price never traded.

    Parameters
    ----------
    df : pandas.DataFrame
        MultiIndex [Date, Symbol] with Low, High, Close and either `ATR_Pct`
        or `ATR` (price units, converted by dividing by Close). `Open` is
        used for gap fills when present.
    symbol : str
    start_date : timestamp-like key of the entry bar.

    Returns
    -------
    (float, date)
        Net return after slippage and fees, and the exit date. `(0.0,
        start_date)` when the trade cannot be simulated (unknown symbol or
        date, missing price/ATR, or no future bars).
    """
    hold_max = CONFIG['HOLD_MAX_DAYS']; tp_mult=CONFIG['ATR_MULT_TP']; sl_mult=CONFIG['ATR_MULT_STOP']
    slippage_bps = CONFIG['SLIPPAGE_BPS']; fee_bps = CONFIG['FEE_BPS_ROUNDTRIP']
    try:
        sub = df.xs(symbol, level=1)
    except KeyError:
        return 0.0, start_date
    if start_date not in sub.index:
        return 0.0, start_date

    entry_px = sub.loc[start_date,'Close']
    # The feature table carries ATR in price units; only the labelling step
    # derives ATR_Pct, on its own private copy. Accept either column.
    if 'ATR_Pct' in sub.columns:
        atr_pct = sub.loc[start_date,'ATR_Pct']
    elif 'ATR' in sub.columns:
        atr_pct = sub.loc[start_date,'ATR'] / (entry_px + 1e-12)
    else:
        return 0.0, start_date
    if np.isnan(entry_px) or np.isnan(atr_pct):
        return 0.0, start_date

    tp = entry_px*(1+tp_mult*atr_pct); sl = entry_px*(1-sl_mult*atr_pct)
    fut = sub.loc[start_date:].iloc[1:1+hold_max]
    if len(fut)==0:
        return 0.0, start_date

    has_open = 'Open' in fut.columns
    exit_date = start_date; pnl = None
    for idx,row in fut.iterrows():
        exit_date = idx
        open_px = row['Open'] if has_open else np.nan
        if row['Low'] <= sl:
            # Gap below the stop: the first tradable price is the open.
            fill = sl if np.isnan(open_px) else min(sl, open_px)
            pnl = (fill/entry_px - 1.0); break
        if row['High'] >= tp:
            # Gap above the target: the limit order fills at the better open.
            fill = tp if np.isnan(open_px) else max(tp, open_px)
            pnl = (fill/entry_px - 1.0); break
    if pnl is None:
        # Time exit at the close of the last bar in the holding window.
        pnl = (fut.iloc[-1]['Close']/entry_px - 1.0)
    cost = (2*slippage_bps + fee_bps)/10_000.0
    return pnl - cost, exit_date

def backtest(df_feats, preds):
    """Daily long-only backtest driven by `preds['p_trade']`.

    On each prediction date the top-ranked symbols above the probability
    cutoff are bought at equal capped weights, and each trade's full path
    return (see `simulate_trade_path`) is booked on its entry date.

    Risk rules, all read from CONFIG:
      * DAILY_LOSS_CAP: a day's return is floored at the cap and the next
        trading day is skipped.
      * WEEKLY_LOSS_CAP: when the summed returns of the ISO week that just
        ended are at or below the cap, trading pauses for
        PAUSE_DAYS_ON_WEEKLY_BREACH days.
      * Position count and gross target are halved when SPY is below its
        50-day mean and its 10-day volatility is in its own top decile so far.

    Parameters
    ----------
    df_feats : pandas.DataFrame
        MultiIndex [Date, Symbol]; must contain 'SPY' and a Close column.
    preds : pandas.DataFrame
        MultiIndex [Date, Symbol] with a `p_trade` column.

    Returns
    -------
    (curve, daily_pnl) : (pandas.Series, pandas.Series)
        Equity curve starting from START_EQUITY and per-day returns.
    """
    import math

    dates = sorted(preds.index.get_level_values(0).unique())
    daily_pnl = pd.Series(0.0, index=dates)
    pause_days = 0
    week_returns = []
    current_week = None

    spy_close = df_feats.xs('SPY', level=1)['Close']
    spy_ma50  = spy_close.rolling(50).mean()
    spy_vol10 = daily_vol(spy_close, 10)

    for d in dates:
        # Detect a new week from the ISO (year, week) pair rather than
        # "is Monday": many market holidays fall on Mondays, and short weeks
        # never reach five trading days.
        week_key = tuple(d.isocalendar())[:2]
        if week_key != current_week:
            if week_returns and sum(week_returns) <= CONFIG['WEEKLY_LOSS_CAP']:
                pause_days = CONFIG['PAUSE_DAYS_ON_WEEKLY_BREACH']
            week_returns = []
            current_week = week_key

        day_ret = 0.0
        if pause_days>0:
            pause_days -= 1
            daily_pnl.loc[d] = 0.0
            continue

        vol_throttle = 1.0
        if d in spy_close.index:
            bear = spy_close.loc[d] < spy_ma50.loc[d] if not math.isnan(spy_ma50.loc[d]) else False
            # Only history up to and including `d` is used for the volatility quantile.
            hist = spy_vol10.loc[:d].dropna()
            high_vol = (len(hist)>20) and (hist.iloc[-1] >= hist.quantile(0.9))
            if bear and high_vol: vol_throttle = 0.5

        picks = select_positions(preds, d, top_k=int(CONFIG['TOP_K']*vol_throttle))
        if len(picks)>0:
            gross_target = min(1.0, CONFIG['MAX_GROSS_LEV'])*vol_throttle
            w = min(1.0/len(picks), CONFIG['PER_NAME_CAP'], gross_target/len(picks))
            for sym in picks:
                pnl,_ = simulate_trade_path(df_feats, sym, d)
                day_ret += w * pnl

        if day_ret <= CONFIG['DAILY_LOSS_CAP']:
            day_ret = CONFIG['DAILY_LOSS_CAP']; pause_days = 1

        daily_pnl.loc[d] = day_ret; week_returns.append(day_ret)

    curve = pd.Series(np.cumprod(1+daily_pnl.values)*CONFIG['START_EQUITY'], index=daily_pnl.index, name='equity')
    return curve, daily_pnl

# Only the trade probabilities are needed to drive position selection.
curve, dret = backtest(df_feats=data, preds=df_test_pred[['p_trade']])
def kpis(curve, dret):
    """Summary statistics for a backtest.

    Parameters
    ----------
    curve : pandas.Series
        Equity curve.
    dret : pandas.Series
        Per-day simple returns.

    Returns
    -------
    dict
        'Sharpe'       - mean / std of daily returns, annualised by sqrt(252).
        'Sortino'      - mean / downside deviation, annualised by sqrt(252);
                         downside deviation is sqrt(mean(min(0, r)^2)).
        'MaxDD'        - most negative drawdown of `curve` from its running peak.
        'ProfitFactor' - sum of positive returns / sum of absolute negative returns.
    """
    rets = dret.values
    mean_ret = rets.mean()
    sharpe = (mean_ret/(rets.std()+1e-12))*np.sqrt(252)
    # Downside deviation is measured around zero, not around the mean of the
    # clipped returns, and is annualised the same way as the Sharpe ratio.
    downside_dev = np.sqrt(np.mean(np.minimum(0, rets)**2))
    sortino = (mean_ret/(downside_dev+1e-12))*np.sqrt(252)
    roll_max = curve.cummax(); maxdd = (curve/roll_max - 1.0).min()
    wins = rets[rets>0].sum(); losses = -rets[rets<0].sum()
    pf = wins/(losses+1e-12)
    return {'Sharpe':sharpe, 'Sortino':sortino, 'MaxDD':maxdd, 'ProfitFactor':pf}

# Print rounded KPIs and plot the equity curve normalised to its first value.
k = kpis(curve, dret)
print({k_: round(v,3) for k_,v in k.items()})

plt.figure(figsize=(10,4))
plt.plot(curve/curve.iloc[0])
plt.title('Equity Curve (Test)')
plt.show()


In [None]:
# Save artifacts
# Persist everything needed to reload the model for inference:
# weights, the config, the feature column order and the fitted scaler.
import joblib

_weights_path = os.path.join(SAVE_DIR, 'best_pro.pt')
_config_path = os.path.join(SAVE_DIR, 'config.json')
_columns_path = os.path.join(SAVE_DIR, 'feature_columns.json')
_scaler_path = os.path.join(SAVE_DIR, 'scaler.joblib')

torch.save(model.state_dict(), _weights_path)
with open(_config_path,'w') as f:
    json.dump(CONFIG, f, indent=2)
with open(_columns_path,'w') as f:
    json.dump(X_cols, f, indent=2)
joblib.dump(scaler, _scaler_path)
print('Saved artifacts to', SAVE_DIR)
