## Hierarchical Reasoning Model (HRM) for Next-Price Prediction

This note adapts the Hierarchical Reasoning Model (HRM) to short-horizon market forecasting (e.g., next-tick or next-bar midprice change). HRM performs latent multi-step reasoning in a single forward pass by coupling two recurrent modules operating at different timescales: a slow high-level planner (H) and a fast low-level solver (L). It emits no chain-of-thought tokens, and it avoids expensive backpropagation through time (BPTT) by using a 1-step gradient approximation. Training adds deep supervision, and adaptive computation time (ACT) decides how many refinement segments to run.

### Why HRM for markets
- **Latent multi-step reasoning**: Markets often require iterative hypothesis–test–refine loops (e.g., infer regime → test microstructure cues → revise). HRM executes these loops internally without emitting token-level thoughts.
- **Depth without BPTT**: HRM achieves large effective compute depth (N cycles × T low-level steps) with O(1) memory backprop, fitting long contexts and large batches.
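- **Adaptive runtime (ACT)**: Difficulty varies across timesteps and instruments; HRM learns when to “ponder” more or halt early based on uncertainty and expected reward.
- **Data efficiency**: The original HRM work reports strong results from small supervised datasets (on the order of 1,000 training examples per task), aided by reusing state across cycles and by deep supervision; whether this carries over to noisy market data has to be verified empirically.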

### Problem definition
- **Goal**: Predict short-horizon next price behavior for execution or alpha generation.
- **Targets** (choose per use case):
  - Regression: next midprice change Δp_{t+1} or log-return r_{t+1}.
  - Classification: direction sign(Δp_{t+1}) or thresholded move (e.g., |Δp| > κ).
  - Optional auxiliary: volatility σ̂_{t+1}, spread class, liquidity regime.
- **Latency budget**: Sub-millisecond to milliseconds depending on venue; ACT lets us trade accuracy for compute at inference.
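
The regression and classification targets above can be computed from a midprice series in a few lines. A minimal NumPy sketch, assuming `mid` is a 1-D array of midprices on the bar or tick grid; `make_targets` and `kappa` (the κ threshold) are illustrative names, not part of any library:

```python
import numpy as np


def make_targets(mid: np.ndarray, kappa: float) -> dict:
    """Next-step targets from a midprice series (hypothetical helper).

    Row t holds the move from t to t + 1, so the last price (which has no future) yields no row.
    """
    dp = mid[1:] - mid[:-1]                 # Δp_{t+1}: next midprice change
    r = np.log(mid[1:] / mid[:-1])          # r_{t+1}: next log-return
    return {
        "dp": dp,
        "logret": r,
        "direction": np.sign(dp),           # sign(Δp_{t+1}) in {-1, 0, +1}
        "big_move": (np.abs(dp) > kappa).astype(np.int8),  # thresholded move |Δp| > κ
    }


mid = np.array([100.0, 100.5, 100.5, 99.0, 99.25])
t = make_targets(mid, kappa=0.75)
print(t["dp"], t["direction"], t["big_move"])
```

Zero moves map to direction 0 here; a binary direction label needs an explicit tie rule (the notebook below uses `y >= 0`).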

### Inputs and feature design
- **Core time-series window**: length L_in (e.g., 256–1024):
  - OHLCV, mid/bid/ask, spread, depth snapshots (top-K levels), imbalance, order flow (initiator, trade size), cancel/modify counts.
  - Returns and signed returns at multiple scales; realized vol, microprice, queue imbalance, Kyle’s λ proxies.
  - Calendar/time-of-day, session flags, volatility/regime embeddings.
  - Instrument meta: tick size, lot size, exchange, sector.
- **Preprocessing**:
  - Robust scaling per instrument (rolling z-score, median/MAD), clip (winsorize) outliers, log-transform volumes.
  - Align asynchronous events; pad/pack sequences; mask missing.
  - Label smoothing for classification; robust log-cosh/Huber losses for regression.
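
A minimal pandas sketch of the per-instrument rolling median/MAD scaling with clipping and log-volume, assuming a long-format frame with hypothetical `instrument`, `ret`, and `volume` columns; the 4-step window is only for the toy example. Both statistics use trailing windows, so no future values enter the scaling:

```python
import numpy as np
import pandas as pd


def rolling_robust_z(s: pd.Series, window: int, clip: float = 5.0) -> pd.Series:
    """Causal robust z-score over the trailing `window` observations (current one included)."""
    med = s.rolling(window, min_periods=window).median()
    # MAD of each window around that window's own median
    mad = s.rolling(window, min_periods=window).apply(
        lambda w: np.median(np.abs(w - np.median(w))), raw=True
    )
    z = (s - med) / (1.4826 * mad + 1e-12)   # 1.4826 * MAD approximates the std for Gaussian data
    return z.clip(-clip, clip)                # clip outliers rather than dropping them


demo = pd.DataFrame({
    "instrument": ["A"] * 6 + ["B"] * 6,
    "ret": [0.1, -0.2, 0.05, 0.0, 3.0, -0.1, 1.0, 2.0, 1.5, 0.5, 1.2, 0.8],
    "volume": [10, 12, 9, 11, 500, 10, 100, 120, 90, 110, 95, 105],
})
demo["log_volume"] = np.log1p(demo["volume"])
# groupby keeps each instrument's statistics separate
demo["ret_z"] = demo.groupby("instrument")["ret"].transform(lambda s: rolling_robust_z(s, window=4))
print(demo[["instrument", "ret", "ret_z"]])
```

The first `window - 1` rows of each instrument are NaN and should be masked or dropped; the 3.0 spike in instrument A is clipped to 5.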

### HRM architecture (market adaptation)
- Input network `fI`: embeds continuous feature vectors at each timestep (linear + RMSNorm + GLU). Optionally include learned positional encodings.
- Low-level recurrent module `fL`: encoder-only Transformer block stack (2–8 blocks) updated every low-level step; receives current H-state, prior L-state, and embedded inputs.
- High-level recurrent module `fH`: encoder-only Transformer block stack (2–8 blocks) updated once per cycle using the final L-state; evolves a slower plan/regime representation.
- Output heads `fO` on `z_H`:
  - Regression head (log-cosh/Huber); classification head (BCE/focal); optional volatility head; confidence head (e.g., temperature or predictive std).
  - Q-head for ACT to decide halt/continue.
- Effective depth: N high-level cycles × T low-level steps per cycle.
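
A PyTorch sketch of the input network `fI` as described (linear + RMSNorm + GLU, plus an optional learned positional encoding). `RMSNorm` is written out by hand instead of relying on `torch.nn.RMSNorm`, which only exists in recent PyTorch releases; the class names and the exact layer order are assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # Divide by the root-mean-square over the feature dim (no mean subtraction, unlike LayerNorm)
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) * self.weight


class InputNetGLU(nn.Module):
    """fI: per-timestep embedding of continuous features (hypothetical layout)."""

    def __init__(self, feature_dim: int, d_model: int, max_len: int = 1024):
        super().__init__()
        self.proj = nn.Linear(feature_dim, d_model)
        self.norm = RMSNorm(d_model)
        self.glu = nn.Linear(d_model, 2 * d_model)   # value and gate halves for the GLU
        self.out = nn.Linear(d_model, d_model)
        self.pos = nn.Embedding(max_len, d_model)     # optional learned positional encoding

    def forward(self, x):                             # x: (B, L, F)
        h = self.norm(self.proj(x))
        h = self.out(F.glu(self.glu(h), dim=-1))      # GLU: a * sigmoid(b)
        positions = torch.arange(x.size(1), device=x.device)
        return h + self.pos(positions)                # (B, L, d_model)


x_feat = torch.randn(2, 16, 12)
emb = InputNetGLU(feature_dim=12, d_model=32)(x_feat)
print(emb.shape)  # torch.Size([2, 16, 32])
```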

### Training mechanics
- **1-step gradient approximation**: Backprop only through the most recent states of L and H, treating earlier steps as constants. O(1) memory, no BPTT unroll.
- **Deep supervision over segments**:
  - Run M segments; after each segment, compute loss on `fO(z_H)`; detach state before next segment.
  - Acts like iterative refinement with frequent feedback; stabilizes training similar to DEQ deep supervision.
- **Adaptive Computation Time (ACT)**:
  - Q-head predicts `Q(halt), Q(continue)` from `z_H` each segment.
  - Choose halt if `Q(halt) ≥ Q(continue)` and minimum segment threshold `M_min` reached; force halt at `M_max`.
  - Train Q-head with episodic Q-learning targets; add optional ponder-cost penalty λ_p for each extra segment.
- **Loss** (example multi-task):
  - L_reg = logcosh(ŷ_reg, y_reg) or Huber; L_dir = BCEWithLogits(ŷ_cls, y_dir); calibration (ECE/temperature) as post-hoc or auxiliary; L_Q = BCE(Q̂, Ĝ). Total: L = w_r L_reg + w_d L_dir + w_q L_Q.
- **Optimization & init**:
  - AdamW, or Adam-atan2 (the scale-invariant Adam variant used in the original HRM work); constant LR with warmup; weight decay.
  - Post-Norm, RMSNorm, GLU FFN, Rotary Positional Embeddings.
  - Truncated LeCun normal initialization; gradient clip (1.0); bf16/AMP.
- **Batching**:
  - Pack sequences by length; randomize instruments/time; stratify by regimes to reduce drift; reset hidden states between samples.
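
Because the Q-head is trained with BCE-with-logits, its targets must be probabilities. A sketch of the episodic targets, loosely following the HRM paper as we read it: the halt target is 1 when the current prediction is correct, and the continue target bootstraps from the sigmoid of the next segment's Q-logits, using only the halt value when that next segment is forced to stop at `M_max`. The `ponder_cost` subtraction is our assumption for implementing λ_p, and `act_q_targets` is a hypothetical helper:

```python
import torch


def act_q_targets(correct: torch.Tensor, q_next_logits: torch.Tensor, m: int, M_max: int,
                  ponder_cost: float = 0.0) -> torch.Tensor:
    """Q-head targets for segment m (1-based), for use with BCE-with-logits.

    correct:       (B,) bool, whether the prediction after segment m is correct
    q_next_logits: (B, 2) Q-head logits [halt, continue] produced at segment m + 1
    """
    q_next = torch.sigmoid(q_next_logits.detach())    # targets live in [0, 1]
    g_halt = correct.float()
    if m + 1 >= M_max:
        g_cont = q_next[:, 0]                          # segment m + 1 is forced to halt
    else:
        g_cont = q_next.max(dim=-1).values             # best action value at m + 1
    g_cont = (g_cont - ponder_cost).clamp(0.0, 1.0)    # optional cost per extra segment
    return torch.stack([g_halt, g_cont], dim=-1)


correct = torch.tensor([True, False])
q_next = torch.tensor([[2.0, -1.0], [0.0, 0.0]])
G = act_q_targets(correct, q_next, m=1, M_max=4)
q_now = torch.zeros(2, 2, requires_grad=True)          # stand-in for the current Q-head logits
loss_q = torch.nn.functional.binary_cross_entropy_with_logits(q_now, G)
```

Bootstrapping from segment m + 1 means the Q-loss for segment m can only be computed after the next forward pass; substituting the current segment's Q-logits for `q_next_logits` is a cheaper approximation.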

### Inference & deployment
- **Inference-time scaling**: Increase `M_max` on hard cases (low confidence, large residuals) for extra accuracy; keep small `M_max` for easy cases.
- **Decision policy**:
  - Predict ŷ and confidence c; if c < τ, raise `M_max` or abstain; route to slower model if needed.
  - Optional ensemble across light augmentations (feature dropout, small jitter) and average.
- **Latency controls**: Cap T and N; set `M_max` per venue; prune heads not used online.
- **Calibration**: Temperature scaling or isotonic on a validation stream; monitor drift.
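
Temperature scaling fits a single scalar T > 0 on held-out data by minimizing the NLL of `logits / T`. A minimal PyTorch sketch for the binary direction head, assuming {0, 1} float labels; parameterizing log T keeps T positive, and `fit_temperature` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F


def fit_temperature(logits: torch.Tensor, labels: torch.Tensor, steps: int = 200) -> float:
    """Fit T by minimizing binary NLL of logits / T on validation data."""
    log_t = torch.zeros(1, requires_grad=True)        # T = exp(log_t), starts at 1
    opt = torch.optim.LBFGS([log_t], lr=1.0, max_iter=steps, line_search_fn="strong_wolfe")

    def closure():
        opt.zero_grad()
        loss = F.binary_cross_entropy_with_logits(logits / log_t.exp(), labels)
        loss.backward()
        return loss

    opt.step(closure)
    return float(log_t.exp())


# Synthetic check: labels drawn from sigmoid(z), but the "model" reports 3x overconfident logits
torch.manual_seed(0)
z = torch.randn(5000) * 2.0
labels = torch.bernoulli(torch.sigmoid(z))
temp = fit_temperature(3.0 * z, labels)
print(round(temp, 2))  # should be close to 3
```

At inference, divide the direction logits by the fitted T before thresholding confidence; T is a single scalar, so the predicted direction itself does not change.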

### Metrics
- Pointwise: MAE/MSE of Δp or r; directional accuracy; AUC/PR for hit-rate; calibration (ECE, NLL).
- Trading: PnL with a simple execution model (crossing the spread or paying half-spread), turnover, drawdown, information ratio (IR); latency SLA.
- Stability: Error by regime, by spread/vol bucket, and across time.
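
Expected calibration error (ECE) for the direction head, sketched in NumPy with equal-width confidence bins. The bin count is an assumption, and "confidence" here is the probability assigned to the predicted class:

```python
import numpy as np


def expected_calibration_error(p_up: np.ndarray, y_up: np.ndarray, n_bins: int = 10) -> float:
    """Binary ECE: bin-weighted mean |accuracy - confidence|."""
    pred = (p_up >= 0.5).astype(int)
    conf = np.where(pred == 1, p_up, 1.0 - p_up)      # confidence of predicted class, in [0.5, 1]
    correct = (pred == y_up).astype(float)
    edges = np.linspace(0.5, 1.0, n_bins + 1)
    bins = np.clip(np.digitize(conf, edges[1:-1]), 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        in_bin = bins == b
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return float(ece)


# 90% confidence but only 50% hit-rate: ECE ≈ 0.4
print(expected_calibration_error(np.full(10, 0.9), np.array([1] * 5 + [0] * 5)))
```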

### Recommended starting configuration
- d_model: 256–512; heads: 8; FFN multiplier: 2–4×; blocks per module: 4.
- N cycles: 2–4; T steps: 4–8; M_max (train): 2–4; M_max (inference): 2 for easy, up to 8 on demand.
- Sequence length L_in: 512; horizon: next tick/bar; batch size: as fits GPU with bf16.
- Loss: log-cosh for regression, optional direction BCE; w_r = 1.0, w_d = 0.25, w_q = 0.1.
- Optimizer: AdamW/Adam-atan2, LR 2e-4 to 1e-3 with 2k warmup steps; weight decay 0.01.
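
The starting configuration can be captured as a dataclass. Field names are hypothetical, and where the text gives a range we pick one value from it. Each segment costs N × T low-level updates (plus N high-level updates), so worst-case compute grows linearly with the number of ACT segments:

```python
from dataclasses import dataclass


@dataclass
class HRMStartConfig:
    d_model: int = 256
    n_heads: int = 8
    ffn_mult: int = 4
    blocks_per_module: int = 4
    n_cycles: int = 2            # N
    t_steps: int = 4             # T
    m_max_train: int = 4
    m_max_infer_easy: int = 2
    m_max_infer_hard: int = 8
    seq_len: int = 512
    w_r: float = 1.0
    w_d: float = 0.25
    w_q: float = 0.1
    lr: float = 2e-4
    warmup_steps: int = 2000
    weight_decay: float = 0.01

    def low_steps_per_segment(self) -> int:
        return self.n_cycles * self.t_steps

    def max_low_steps(self, m_max: int) -> int:
        # worst case: ACT runs all m_max segments
        return m_max * self.low_steps_per_segment()


start_cfg = HRMStartConfig()
print(start_cfg.low_steps_per_segment(), start_cfg.max_low_steps(start_cfg.m_max_infer_hard))  # 8 64
```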

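### Minimal PyTorch-style sketch
```python
import torch
import torch.nn as nn


def encoder_stack(d_model: int, n_heads: int, n_layers: int) -> nn.TransformerEncoder:
    # Post-Norm encoder blocks (PyTorch default norm_first=False)
    layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=2 * d_model, batch_first=True)
    return nn.TransformerEncoder(layer, num_layers=n_layers)


class HRM(nn.Module):
    def __init__(self, feature_dim: int, d_model: int = 256, n_heads: int = 8, n_layers: int = 4):
        super().__init__()
        self.input_net = nn.Linear(feature_dim, d_model)       # fI (minimal: linear embedding)
        self.low = encoder_stack(d_model, n_heads, n_layers)   # fL: fast, updated every step
        self.high = encoder_stack(d_model, n_heads, n_layers)  # fH: slow, updated once per cycle
        self.reg_head = nn.Linear(d_model, 1)
        self.cls_head = nn.Linear(d_model, 1)
        self.q_head = nn.Linear(d_model, 2)                    # [Q(halt), Q(continue)] logits

    def forward_segment(self, zH, zL, x, N: int, T: int):
        # zH, zL: (B, L, d_model); x: (B, L, feature_dim)
        x_emb = self.input_net(x)
        with torch.no_grad():                                  # N*T - 1 steps, no graph kept
            for i in range(N * T - 1):
                zL = self.low(zL + zH + x_emb)                 # inputs injected by addition
                if (i + 1) % T == 0:
                    zH = self.high(zH + zL)                    # H updates at the end of each cycle
        # 1-step grad: only the final L and H updates are differentiated
        zL = self.low(zL + zH + x_emb)
        zH = self.high(zH + zL)
        pooled = zH.mean(dim=1)                                # pool over time for the heads
        reg = self.reg_head(pooled).squeeze(-1)
        cls = self.cls_head(pooled).squeeze(-1)
        q = self.q_head(pooled)
        return (zH, zL), reg, cls, q
```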

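```python
# Training step with deep supervision + ACT (sketch); `model` follows the HRM sketch above
import math

import torch
import torch.nn.functional as F


def log_cosh(pred, target):
    x = pred - target
    return (x + F.softplus(-2.0 * x) - math.log(2.0)).mean()  # stable log(cosh(x))


def train_batch(model, opt, x, y_reg, y_dir, N, T, M_min, M_max,
                w_r=1.0, w_d=0.25, w_q=0.1, use_dir=True):
    # y_reg: (B,) float targets; y_dir: (B,) float labels in {0, 1}
    B, L, _ = x.shape
    zH = torch.zeros(B, L, model.reg_head.in_features, device=x.device)
    zL = torch.zeros_like(zH)
    halted = torch.zeros(B, dtype=torch.bool, device=x.device)
    for m in range(M_max):                   # the loop bound forces a halt at M_max
        (zH, zL), reg, cls, q = model.forward_segment(zH, zL, x, N, T)
        mask = ~halted                       # losses for non-halted samples only
        loss = w_r * log_cosh(reg[mask], y_reg[mask])
        if use_dir:
            loss = loss + w_d * F.binary_cross_entropy_with_logits(cls[mask], y_dir[mask])
        # Q-learning targets Ĝ in [0, 1]: halt -> 1{direction correct};
        # continue -> bootstrapped from the current sigmoid(Q) (a cheap stand-in for segment m + 1)
        with torch.no_grad():
            g_halt = ((cls >= 0).float() == y_dir).float()
            g_cont = torch.sigmoid(q).max(dim=-1).values
            Ghat = torch.stack([g_halt, g_cont], dim=-1)
        loss = loss + w_q * F.binary_cross_entropy_with_logits(q[mask], Ghat[mask])
        opt.zero_grad(); loss.backward(); opt.step()
        # decide halting per sample
        act = (q[:, 0] >= q[:, 1]) & (m + 1 >= M_min)
        halted = halted | act
        # detach state across segments
        zH, zL = zH.detach(), zL.detach()
        if halted.all():
            break
    return loss.item()
```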

### Practical tips & pitfalls
- Prevent leakage: build labels strictly forward-only; respect exchange calendar and microstructure delays.
- Normalize per instrument with rolling stats; re-estimate periodically to track drift.
- Use robust losses (log-cosh/Huber) and clip grads to handle spikes.
- Start without ACT to validate pipeline; then enable ACT with small `M_max` and modest ponder-cost.
- Monitor calibration; miscalibrated confidence can harm ACT decisions.
- Ablate: HRM vs. a same-parameter Transformer (no recursion); HRM without ACT; vary N and T.
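
One easy-to-miss source of leakage is a centered rolling window: `np.convolve(..., mode="same")` averages values on both sides of t, so a feature at t already contains r_{t+1}. A trailing window, built from `mode="full"` truncated to the input length, uses only r_{t-w+1}, ..., r_t. The sketch below perturbs a single future value and checks whether the feature at t changes; the helper names are hypothetical:

```python
import numpy as np


def centered_mean(x, w):
    return np.convolve(x, np.ones(w) / w, mode="same")


def trailing_mean(x, w):
    # full convolution truncated to len(x): out[t] = mean(x[t-w+1 .. t]), zero-padded at the start
    return np.convolve(x, np.ones(w) / w, mode="full")[: len(x)]


def leaks_future(feature_fn, n=64, t=32, w=8, seed=0) -> bool:
    """True if changing x[t + 1] alters feature_fn(x)[t], i.e. the feature peeks ahead."""
    x = np.random.default_rng(seed).normal(size=n)
    x_bumped = x.copy()
    x_bumped[t + 1] += 10.0
    return not np.isclose(feature_fn(x, w)[t], feature_fn(x_bumped, w)[t])


print(leaks_future(centered_mean), leaks_future(trailing_mean))  # True False
```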

### Milestones
1) Data/label pipeline + baseline Transformer.
2) HRM (no ACT) with deep supervision.
3) Enable ACT and halting.
4) Inference-time scaling policy by confidence.
5) Calibration + live metrics (PnL, latency).
6) Hyperparameter sweep and ablations.

References: HRM core ideas adapted from Hierarchical Reasoning Model (2025), with market-specific design choices for time-series forecasting.



```python
# Minimal HRM implementation + synthetic test
# Imports and config
import math
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(42)


@dataclass
class Config:
    # data
    seq_len: int = 256
    feature_dim: int = 16
    train_size: int = 2000
    val_size: int = 500
    batch_size: int = 64

    # model
    d_model: int = 256
    n_cycles: int = 2  # N
    t_steps: int = 4   # T
    n_segments: int = 2  # M (deep supervision segments); set >1 to mimic refinement

    # optimization
    lr: float = 2e-3
    weight_decay: float = 1e-2
    max_epochs: int = 3
    grad_clip: float = 1.0


cfg = Config()
print({"device": str(device)})
```


```python
# Synthetic dataset for next-step prediction
class SyntheticMarketDataset(Dataset):
    """
    Generates synthetic multivariate sequences and next-step return targets.

    Latent AR(2) price process with regime switching and nonlinear microstructure noise.
    Features include lagged returns, rolling stats, and exogenous noise proxies.
    """
    def __init__(self, num_samples: int, seq_len: int, feature_dim: int):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.feature_dim = feature_dim
        self.X, self.y = self._generate()

    def _generate(self):
        X = np.zeros((self.num_samples, self.seq_len, self.feature_dim), dtype=np.float32)
        y = np.zeros((self.num_samples,), dtype=np.float32)

        for i in range(self.num_samples):
            # regime: mean-reverting vs trending
            regime = np.random.choice([0, 1])
            if regime == 0:
                a1, a2 = 0.6, -0.2
            else:
                a1, a2 = 1.1, -0.3

            # one step beyond the observed window is simulated so the target lies strictly in the future
            price = np.zeros(self.seq_len + 3, dtype=np.float32)
            ret = np.zeros_like(price)

            # generate AR(2) returns with heteroscedastic noise
            price[0] = 0.0
            price[1] = np.random.randn() * 0.01
            for t in range(2, self.seq_len + 3):
                eps = np.random.randn() * (0.005 + 0.01 * np.random.beta(2, 5))
                ret[t] = a1 * (price[t - 1] - price[t - 2]) + a2 * (price[t - 2] - price[t - 3] if t >= 3 else 0.0) + eps
                price[t] = price[t - 1] + ret[t]

            # observed window: price[2 .. seq_len + 1]
            seq_price = price[2:self.seq_len + 2]
            seq_ret = np.diff(seq_price, prepend=seq_price[0])

            # features
            feat = np.zeros((self.seq_len, self.feature_dim), dtype=np.float32)
            feat[:, 0] = seq_ret
            # multi-scale returns
            for k, win in enumerate([2, 4, 8, 16], start=1):
                r = np.convolve(seq_ret, np.ones(win, dtype=np.float32) / win, mode="same")
                feat[:, k] = r
            # realized volatility proxy
            rv = np.sqrt(np.convolve(seq_ret ** 2, np.ones(8, dtype=np.float32) / 8, mode="same") + 1e-8)
            feat[:, 5] = rv
            # imbalance / microprice proxies (synthetic)
            feat[:, 6] = np.tanh(np.convolve(seq_ret, np.array([1, -1, 1, -1], dtype=np.float32), mode="same"))
            feat[:, 7] = (rv > np.median(rv)).astype(np.float32)
            # exogenous noise features
            noise = np.random.randn(self.seq_len, self.feature_dim - 8).astype(np.float32) * 0.1
            feat[:, 8:] = noise

            # next-step target (regression): the return right after the window,
            # price[seq_len + 2] - price[seq_len + 1]. ret[seq_len + 1] would leak: it equals
            # the last element of seq_ret, i.e. feature 0 at the final timestep.
            target = ret[self.seq_len + 2]

            X[i] = feat
            y[i] = target

        # robust scaling per feature
        med = np.median(X, axis=(0, 1), keepdims=True)
        mad = np.median(np.abs(X - med), axis=(0, 1), keepdims=True) + 1e-6
        X = np.clip((X - med) / (1.4826 * mad), -5.0, 5.0)

        return X, y

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


# Build loaders
train_ds = SyntheticMarketDataset(cfg.train_size, cfg.seq_len, cfg.feature_dim)
val_ds = SyntheticMarketDataset(cfg.val_size, cfg.seq_len, cfg.feature_dim)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, drop_last=False)

len(train_loader), len(val_loader)
```

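Both synthetic regimes are AR(2) processes on returns, r_t = a1·r_{t-1} + a2·r_{t-2} + ε_t. A quick NumPy check (not part of the original notebook) that both coefficient pairs give a stationary process: the roots of the characteristic polynomial z² − a1·z − a2 must lie strictly inside the unit circle.

```python
import numpy as np

regimes = {"regime_0": (0.6, -0.2), "regime_1": (1.1, -0.3)}  # (a1, a2) as in the dataset above
moduli = {}
for name, (a1, a2) in regimes.items():
    roots = np.roots([1.0, -a1, -a2])           # roots of z^2 - a1 z - a2
    moduli[name] = np.abs(roots)
    print(name, np.round(moduli[name], 3), "stationary:", bool(np.all(moduli[name] < 1)))
```

Regime 0 has complex roots of modulus √0.2 ≈ 0.447 (damped oscillation), while regime 1 has real roots 0.6 and 0.5, so its returns are more persistent.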

```python
# HRM model modules
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1, max_len, d)

    def forward(self, x):
        # x: (B, L, d)
        L = x.size(1)
        return x + self.pe[:, :L, :]


class InputEmbed(nn.Module):
    def __init__(self, feature_dim: int, d_model: int):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(feature_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model)
        )
        self.norm = nn.LayerNorm(d_model)
        self.pos = PositionalEncoding(d_model)

    def forward(self, x):
        # x: (B, L, F)
        h = self.proj(x)
        h = self.norm(h)
        h = self.pos(h)
        return h


class TransformerStack(nn.Module):
    def __init__(self, d_model: int, n_heads: int = 8, n_layers: int = 4, dim_ff: int = None, dropout: float = 0.0):
        super().__init__()
        if dim_ff is None:
            dim_ff = d_model * 2
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

    def forward(self, x):
        # x: (B, L, d)
        return self.encoder(x)


class HRM(nn.Module):
    def __init__(self, feature_dim: int, d_model: int = 256, n_heads: int = 8, n_layers_low: int = 4, n_layers_high: int = 4, dim_ff: int = None):
        super().__init__()
        self.d_model = d_model
        self.input_net = InputEmbed(feature_dim, d_model)
        self.low = TransformerStack(d_model, n_heads=n_heads, n_layers=n_layers_low, dim_ff=dim_ff)
        self.high = TransformerStack(d_model, n_heads=n_heads, n_layers=n_layers_high, dim_ff=dim_ff)
        self.reg_head = nn.Linear(d_model, 1)

    def init_state(self, batch_size: int, seq_len: int, device):
        zH = torch.zeros(batch_size, seq_len, self.d_model, device=device)
        zL = torch.zeros_like(zH)
        return zH, zL

    def forward_segment(self, zH, zL, x, N: int, T: int):
        # x: (B, L, F)
        x_emb = self.input_net(x)
        # No-grad for N*T - 1 steps
        with torch.no_grad():
            for i in range(N * T - 1):
                # Low module updates fast within cycle
                l_in = x_emb + zL + zH
                zL = self.low(l_in)
                # High module updates every T steps
                if (i + 1) % T == 0:
                    h_in = zH + zL
                    zH = self.high(h_in)
        # 1-step grad-enabled update
        l_in = x_emb + zL + zH
        zL = self.low(l_in)
        h_in = zH + zL
        zH = self.high(h_in)
        # Heads on pooled high state
        pooled = zH.mean(dim=1)
        reg = self.reg_head(pooled).squeeze(-1)
        return (zH, zL), reg


model = HRM(feature_dim=cfg.feature_dim, d_model=cfg.d_model).to(device)
model
```

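To see what the 1-step gradient approximation in `forward_segment` does, here is a tiny self-contained example (a toy recurrence, not the notebook's HRM): z ← tanh(w·z + x) runs K − 1 times under `torch.no_grad()` and once with autograd. The gradient with respect to w then equals the gradient of the final step alone, with the incoming state treated as a constant, and no graph is stored for the earlier steps, so memory does not grow with K.

```python
import torch

w_demo = torch.tensor(0.5, requires_grad=True)
x_demo = torch.tensor([0.3, -1.2, 0.8, 2.0])


def one_step_grad(K: int):
    z = torch.zeros(4)
    with torch.no_grad():                        # K - 1 steps: no graph is recorded
        for _ in range(K - 1):
            z = torch.tanh(w_demo * z + x_demo)
    z_last = torch.tanh(w_demo * z + x_demo)     # only this step is differentiated
    (g,) = torch.autograd.grad(z_last.sum(), w_demo)
    # analytic gradient of the final step with z held constant: sum((1 - z_last^2) * z)
    expected = ((1 - z_last.detach() ** 2) * z).sum()
    return g, expected, z.requires_grad


g, expected, state_tracked = one_step_grad(K=16)
print(torch.allclose(g, expected), state_tracked)  # True False
```

Full BPTT would also differentiate through every earlier z; the approximation drops those terms in exchange for O(1) memory.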

```python
# Training loop with deep supervision (fixed segments), no ACT
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)


def log_cosh_loss(pred, target):
    # Numerically stable log(cosh(x)) = x + softplus(-2x) - log(2); torch.cosh overflows for large |x|
    x = pred - target
    return (x + F.softplus(-2.0 * x) - math.log(2.0)).mean()


def train_one_epoch(epoch):
    model.train()
    total = 0.0
    n = 0
    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        # initialize states
        zH, zL = model.init_state(xb.size(0), xb.size(1), device)
        seg_losses = []
        for m in range(cfg.n_segments):
            (zH, zL), reg = model.forward_segment(zH, zL, xb, cfg.n_cycles, cfg.t_steps)
            loss = log_cosh_loss(reg, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            opt.zero_grad(set_to_none=True)
            # detach state before next segment
            zH = zH.detach(); zL = zL.detach()
            seg_losses.append(loss.item())
        total += float(np.mean(seg_losses))
        n += 1
    print(f"epoch {epoch} train loss: {total / max(n,1):.4f}")


def evaluate():
    model.eval()
    losses = []
    maes = []
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            zH, zL = model.init_state(xb.size(0), xb.size(1), device)
            for m in range(cfg.n_segments):
                (zH, zL), reg = model.forward_segment(zH, zL, xb, cfg.n_cycles, cfg.t_steps)
                # no deep supervision updates; just get last segment pred
                zH = zH.detach(); zL = zL.detach()
            loss = log_cosh_loss(reg, yb).item()
            mae = (reg - yb).abs().mean().item()
            losses.append(loss)
            maes.append(mae)
    print(f"val loss: {np.mean(losses):.4f} | val mae: {np.mean(maes):.5f}")


for epoch in range(1, cfg.max_epochs + 1):
    train_one_epoch(epoch)
    evaluate()
```

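Next-step returns are small and centered near zero, so a raw validation MAE is hard to interpret on its own. A small helper (hypothetical, not in the original notebook) that reports the model's MAE next to trivial constant forecasts, assuming 1-D NumPy arrays of predictions and targets. The MAE-optimal constant is the median, so the training-set median is passed in rather than the mean:

```python
import numpy as np


def compare_to_baselines(pred: np.ndarray, y: np.ndarray, y_train_median: float = 0.0) -> dict:
    """MAE of the model vs. a zero forecast and a constant (training-median) forecast."""
    model_mae = float(np.mean(np.abs(pred - y)))
    zero_mae = float(np.mean(np.abs(y)))
    return {
        "model_mae": model_mae,
        "zero_mae": zero_mae,
        "median_mae": float(np.mean(np.abs(y - y_train_median))),
        "skill_vs_zero": 1.0 - model_mae / zero_mae if zero_mae > 0 else float("nan"),
    }


res = compare_to_baselines(np.array([0.5, -0.5, 1.0]), np.array([1.0, -1.0, 2.0]))
print(res)
```

A skill score at or below zero means the model does no better than predicting "no move".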

```python
# BTC-USD data ingestion, feature engineering, sequence datasets, and model reinit
import pandas as pd

# 1) Fetch BTC-USD data (1h)
try:
    import yfinance as yf
except ImportError:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "yfinance"])  # silent install
    import yfinance as yf

symbol = "BTC-USD"
interval = "1h"
# Yahoo Finance serves 1h bars only for roughly the last 730 days; an older start date returns no data
start_date = (pd.Timestamp.now(tz="UTC") - pd.Timedelta(days=700)).strftime("%Y-%m-%d")
print({"fetch": symbol, "interval": interval, "start": start_date})

df = yf.download(symbol, start=start_date, interval=interval, auto_adjust=True, progress=False)
assert not df.empty, "No BTC data fetched. Check internet or adjust date/interval."

# Recent yfinance versions return (field, ticker) MultiIndex columns even for one ticker; flatten them
if isinstance(df.columns, pd.MultiIndex):
    df.columns = df.columns.get_level_values(0)

df = df.dropna().copy()
# Ensure columns
cols = ["Open", "High", "Low", "Close", "Volume"]
for c in cols:
    assert c in df.columns, f"Missing column {c} in BTC data"

# 2) Feature engineering
close = df["Close"].astype(np.float32).values
high = df["High"].astype(np.float32).values
low = df["Low"].astype(np.float32).values
vol = df["Volume"].astype(np.float32).values

# log returns
eps = 1e-12
ret1 = np.log(np.clip(close[1:] / np.clip(close[:-1], eps, None), eps, None)).astype(np.float32)
ret1 = np.concatenate([np.zeros(1, dtype=np.float32), ret1])

# rolling utilities (causal: the value at t uses only observations up to and including t)
def rolling_mean(x, w):
    k = np.ones(w, dtype=np.float32) / float(w)
    # "full" convolution truncated to len(x) gives trailing windows (the first w-1 values are
    # averaged against zero padding); mode="same" would center the window and mix future
    # returns, including the next-step target, into the features
    return np.convolve(x, k, mode="full")[: len(x)]

def rolling_std(x, w):
    m = rolling_mean(x, w)
    m2 = rolling_mean(x * x, w)
    v = np.clip(m2 - m * m, 0.0, None)
    return np.sqrt(v + 1e-8)

# EMAs
def ema(x, span):
    a = 2.0 / (span + 1.0)
    y = np.zeros_like(x, dtype=np.float32)
    y[0] = x[0]
    for i in range(1, len(x)):
        y[i] = a * x[i] + (1 - a) * y[i - 1]
    return y

# RSI
def rsi(x, period=14):
    dx = np.diff(x, prepend=x[0])
    up = np.clip(dx, 0, None)
    dn = np.clip(-dx, 0, None)
    up_ema = ema(up, period)
    dn_ema = ema(dn, period)
    rs = up_ema / (dn_ema + 1e-8)
    r = 100.0 - (100.0 / (1.0 + rs))
    return r.astype(np.float32)

# Features (aim ~16 dims)
feat_list = []
feat_list.append(ret1)                                                    # 0: ret1
for w in [2, 4, 8, 16]:
    feat_list.append(rolling_mean(ret1, w))                               # 1..4: mean returns
rv = rolling_std(ret1, 16)                                                # 5: rolling vol (returns std)
feat_list.append(rv)
feat_list.append(((high - low) / np.clip(close, eps, None)))              # 6: HL range / price
ema8 = ema(close, 8)
ema21 = ema(close, 21)
feat_list.append((close - ema8) / np.clip(close, eps, None))              # 7: deviation from EMA8
feat_list.append((close - ema21) / np.clip(close, eps, None))             # 8: deviation from EMA21
feat_list.append(rsi(close, 14) / 100.0)                                  # 9: RSI scaled
feat_list.append(np.log1p(vol))                                           # 10: log volume
feat_list.append(rolling_std(rv, 16))                                     # 11: vol of vol

# time features (UTC index)
idx = df.index
if isinstance(idx, pd.DatetimeIndex):
    hour = idx.hour.values.astype(np.float32)
    dow = idx.dayofweek.values.astype(np.float32)
else:
    # fallback to zeros
    hour = np.zeros(len(df), dtype=np.float32)
    dow = np.zeros(len(df), dtype=np.float32)
feat_list.append(np.sin(2 * np.pi * hour / 24.0))                         # 12: sin hour
feat_list.append(np.cos(2 * np.pi * hour / 24.0))                         # 13: cos hour
feat_list.append(np.sin(2 * np.pi * dow / 7.0))                           # 14: sin dow
feat_list.append(np.cos(2 * np.pi * dow / 7.0))                           # 15: cos dow

features = np.stack(feat_list, axis=1).astype(np.float32)                 # (T, F)
Fdim = features.shape[1]

# Targets: next-step log return, targets[t] = ret1[t + 1]; the last row has no next return
targets = np.roll(ret1, -1).astype(np.float32)
features = features[:-1]
targets = targets[:-1]

# time-based split point (80/20)
T_total = len(features)
T_train = int(T_total * 0.8)

# robust scale per feature, with statistics from the training span only (no look-ahead into val)
med = np.median(features[:T_train], axis=0, keepdims=True)
mad = np.median(np.abs(features[:T_train] - med), axis=0, keepdims=True) + 1e-6
features = np.clip((features - med) / (1.4826 * mad), -5.0, 5.0).astype(np.float32)

print({"T": len(features), "F": Fdim})

# 3) Build windowed sequences
class WindowedSeqDataset(Dataset):
    def __init__(self, feats: np.ndarray, targs: np.ndarray, seq_len: int):
        self.X = feats
        self.y = targs
        self.seq_len = seq_len
        self.N = max(0, len(self.X) - seq_len)

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        x = self.X[idx: idx + self.seq_len]
        y = self.y[idx + self.seq_len - 1]  # next-step after window end
        return x.astype(np.float32), np.float32(y)

# time-based split at T_train (computed above)
train_feats = features[:T_train]
train_targs = targets[:T_train]
val_feats = features[T_train - cfg.seq_len:]  # ensure enough context
val_targs = targets[T_train - cfg.seq_len:]

btc_train_ds = WindowedSeqDataset(train_feats, train_targs, cfg.seq_len)
btc_val_ds = WindowedSeqDataset(val_feats, val_targs, cfg.seq_len)

train_loader = DataLoader(btc_train_ds, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(btc_val_ds, batch_size=cfg.batch_size, shuffle=False, drop_last=False)

print({"train_batches": len(train_loader), "val_batches": len(val_loader)})

# 4) Reinitialize model and optimizer for BTC features
cfg.feature_dim = Fdim
model = HRM(feature_dim=cfg.feature_dim, d_model=cfg.d_model).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
model, cfg.feature_dim
```

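A quick alignment check for the windowing logic, using a toy series where the feature at row t is t itself and `targets[t]` refers to time t + 1 (as `np.roll(ret1, -1)` does). `ToyWindowedSeq` is a hypothetical, dependency-free copy of the `WindowedSeqDataset` indexing, so the snippet runs on its own:

```python
import numpy as np


class ToyWindowedSeq:
    """Same indexing as WindowedSeqDataset, without the torch Dataset base class."""

    def __init__(self, feats, targs, seq_len):
        self.X, self.y, self.seq_len = feats, targs, seq_len

    def __len__(self):
        return max(0, len(self.X) - self.seq_len)

    def __getitem__(self, idx):
        return self.X[idx: idx + self.seq_len], self.y[idx + self.seq_len - 1]


n, window = 50, 8
t_index = np.arange(n, dtype=np.float32)
toy_ds = ToyWindowedSeq(t_index[:, None], t_index + 1, window)
for i in range(len(toy_ds)):
    x_win, y_next = toy_ds[i]
    assert y_next == x_win[-1, 0] + 1, "target must be the step right after the window's last row"
print(len(toy_ds), "windows aligned")  # 42 windows aligned
```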

```python
# ACT and loss weights configuration
# Extend cfg with ACT params and loss weights (can be tuned)
setattr(cfg, "w_r", 1.0)          # regression weight
setattr(cfg, "w_d", 0.25)         # direction BCE weight
setattr(cfg, "w_q", 0.1)          # Q-head BCE weight
setattr(cfg, "act_Mmax", 4)       # maximum segments per episode
setattr(cfg, "act_eps", 0.2)      # epsilon prob to sample larger Mmin
setattr(cfg, "act_ponder", 0.001) # ponder cost per extra segment
setattr(cfg, "act_use_dir", True) # use direction label for reward

cfg.w_r, cfg.w_d, cfg.w_q, cfg.act_Mmax, cfg.act_eps, cfg.act_ponder, cfg.act_use_dir
```


```python
# Extend HRM with Q-head and direction head
class HRM_ACT(HRM):
    def __init__(self, feature_dim: int, d_model: int = 256, n_heads: int = 8, n_layers_low: int = 4, n_layers_high: int = 4, dim_ff: int = None):
        super().__init__(feature_dim, d_model, n_heads, n_layers_low, n_layers_high, dim_ff)
        self.dir_head = nn.Linear(d_model, 1)  # direction logits
        self.q_head = nn.Linear(d_model, 2)    # Q(halt), Q(continue)

    def forward_segment(self, zH, zL, x, N: int, T: int):
        x_emb = self.input_net(x)
        with torch.no_grad():
            for i in range(N * T - 1):
                zL = self.low(x_emb + zL + zH)
                if (i + 1) % T == 0:
                    zH = self.high(zH + zL)
        # 1-step grad
        zL = self.low(x_emb + zL + zH)
        zH = self.high(zH + zL)
        pooled = zH.mean(dim=1)
        reg = self.reg_head(pooled).squeeze(-1)
        dir_logit = self.dir_head(pooled).squeeze(-1)
        q = self.q_head(pooled)
        return (zH, zL), reg, dir_logit, q


# Reinit model as HRM_ACT
model = HRM_ACT(feature_dim=cfg.feature_dim, d_model=cfg.d_model).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
model
```


```python
# Training with ACT (Q-learning targets and halting)

def bce_logits(logits, target):
    return F.binary_cross_entropy_with_logits(logits, target)


def train_one_epoch_act(epoch):
    model.train()
    total_loss = 0.0
    count = 0
    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        y_dir = (yb >= 0).float()
        B = xb.size(0)

        # Choose M_min stochastically per paper
        if np.random.rand() < cfg.act_eps:
            M_min = np.random.randint(2, cfg.act_Mmax + 1)
        else:
            M_min = 1

        # Initialize hidden states per sample
        zH, zL = model.init_state(B, xb.size(1), device)
        halted = torch.zeros(B, dtype=torch.bool, device=device)

        for m in range(cfg.act_Mmax):
            (zH, zL), reg, dir_logit, q = model.forward_segment(zH, zL, xb, cfg.n_cycles, cfg.t_steps)

            # Deep supervision regression & direction losses on active samples
            mask = ~halted
            L = torch.tensor(0.0, device=device)
            if mask.any():
                L = L + cfg.w_r * log_cosh_loss(reg[mask], yb[mask])
                if cfg.act_use_dir:
                    L = L + cfg.w_d * bce_logits(dir_logit[mask], y_dir[mask])

            # Q-learning targets
            # Episodic reward if halting now: 1{correct direction}, else 0
            with torch.no_grad():
                correct = ((dir_logit >= 0).float() == y_dir).float()
                G_halt = correct  # binary reward

            # Decide actions (halt/continue)
            q_halt = q[:, 0]
            q_cont = q[:, 1]
            choose_halt = (q_halt >= q_cont) & (m + 1 >= M_min)
            # Force halt at last step
            if m == cfg.act_Mmax - 1:
                choose_halt = torch.ones_like(choose_halt, dtype=torch.bool)

            # Build Q targets. q holds logits for BCE-with-logits, so targets must be probabilities in [0, 1].
            # Halt: immediate reward. Continue: bootstrapped from sigmoid(q) of the current segment as a
            # cheap stand-in for segment m + 1, minus the ponder cost so that each extra segment is penalised.
            # (A ponder term added directly to L from the boolean halt decisions would carry no gradient.)
            with torch.no_grad():
                G = torch.zeros_like(q)
                G[:, 0] = G_halt
                G[:, 1] = (torch.sigmoid(q).max(dim=-1).values - cfg.act_ponder).clamp(0.0, 1.0)

            L = L + cfg.w_q * bce_logits(q[mask], G[mask])

            L.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step(); opt.zero_grad(set_to_none=True)

            # Update halting
            halted = halted | choose_halt
            zH = zH.detach(); zL = zL.detach()

            if halted.all():
                break

        total_loss += L.item()  # loss of the final segment run for this batch
        count += 1

    print(f"epoch {epoch} train(ACT) loss: {total_loss / max(count,1):.4f}")


def evaluate_act():
    model.eval()
    losses = []
    maes = []
    accs = []
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            y_dir = (yb >= 0).float()
            zH, zL = model.init_state(xb.size(0), xb.size(1), device)
            halted = torch.zeros(xb.size(0), dtype=torch.bool, device=device)
            reg = torch.zeros_like(yb)
            dir_logit = torch.zeros_like(yb)
            for m in range(cfg.act_Mmax):
                (zH, zL), reg_m, dir_logit_m, q = model.forward_segment(zH, zL, xb, cfg.n_cycles, cfg.t_steps)
                # ACT halting rule at eval (greedy)
                q_halt = q[:, 0]
                q_cont = q[:, 1]
                choose_halt = (q_halt >= q_cont)
                # commit predictions when halting and not already halted
                commit = choose_halt & (~halted)
                reg[commit] = reg_m[commit]
                dir_logit[commit] = dir_logit_m[commit]
                halted = halted | choose_halt
                zH = zH.detach(); zL = zL.detach()
                if halted.all():
                    break
            # fallback for any that never halted
            still = ~halted
            if still.any():
                reg[still] = reg_m[still]
                dir_logit[still] = dir_logit_m[still]
            loss = log_cosh_loss(reg, yb).item()
            mae = (reg - yb).abs().mean().item()
            acc = (((dir_logit >= 0).float() == y_dir).float().mean().item())
            losses.append(loss); maes.append(mae); accs.append(acc)
    print(f"val(ACT) loss: {np.mean(losses):.4f} | MAE: {np.mean(maes):.6f} | Dir Acc: {np.mean(accs):.4f}")


for epoch in range(1, cfg.max_epochs + 1):
    train_one_epoch_act(epoch)
    evaluate_act()
```


```python
# Performance evaluation & visualization (val set)
import matplotlib.pyplot as plt

@torch.no_grad()
def collect_val_predictions(max_batches=None):
    model.eval()
    preds = []
    dirs = []
    trues = []
    steps_used = []
    for bi, (xb, yb) in enumerate(val_loader):
        xb = xb.to(device)
        yb = yb.to(device)
        zH, zL = model.init_state(xb.size(0), xb.size(1), device)
        halted = torch.zeros(xb.size(0), dtype=torch.bool, device=device)
        reg = torch.zeros_like(yb)
        dir_logit = torch.zeros_like(yb)
        steps = torch.zeros_like(yb, dtype=torch.long)
        for m in range(cfg.act_Mmax):
            (zH, zL), reg_m, dir_logit_m, q = model.forward_segment(zH, zL, xb, cfg.n_cycles, cfg.t_steps)
            q_halt = q[:, 0]
            q_cont = q[:, 1]
            choose_halt = (q_halt >= q_cont)
            commit = choose_halt & (~halted)
            reg[commit] = reg_m[commit]
            dir_logit[commit] = dir_logit_m[commit]
            steps[commit] = m + 1
            halted = halted | choose_halt
            zH = zH.detach(); zL = zL.detach()
            if halted.all():
                break
        still = ~halted
        if still.any():
            reg[still] = reg_m[still]
            dir_logit[still] = dir_logit_m[still]
            steps[still] = cfg.act_Mmax
        preds.append(reg.cpu())
        dirs.append(dir_logit.cpu())
        trues.append(yb.cpu())
        steps_used.append(steps.cpu())
        if max_batches is not None and (bi + 1) >= max_batches:
            break
    return torch.cat(preds), torch.cat(dirs), torch.cat(trues), torch.cat(steps_used)


preds, dir_logits, trues, steps_used = collect_val_predictions()
y_dir = (trues >= 0).float()
acc = ((dir_logits >= 0).float() == y_dir).float().mean().item()
mae = (preds - trues).abs().mean().item()
rmse = torch.sqrt(((preds - trues) ** 2).mean()).item()
cor = np.corrcoef(preds.numpy(), trues.numpy())[0,1]

print({"val_dir_acc": round(acc, 4), "val_mae": round(mae, 6), "val_rmse": round(rmse, 6), "val_corr": float(cor)})
print({"mean_steps": float(steps_used.float().mean()), "pct_halt_1": float((steps_used == 1).float().mean())})

# Simple PnL proxy: position = sign(pred), PnL = position * true_return - cost * turnover
# Turnover is |Δposition| per bar, so a flip from +1 to -1 counts as 2 units;
# c is a per-unit cost proxy (very small for crypto hourly)
pos = torch.sign(preds).numpy()
ret = trues.numpy()
turnover = np.abs(np.diff(pos, prepend=0.0))
c = 0.0  # set to small cost if desired, e.g., 1e-5
pnl = pos * ret - c * turnover
cum_pnl = pnl.cumsum()

fig, axs = plt.subplots(3, 1, figsize=(10, 10), constrained_layout=True)
axs[0].plot(trues.numpy(), label="true ret", alpha=0.7)
axs[0].plot(preds.numpy(), label="pred ret", alpha=0.7)
axs[0].legend(); axs[0].set_title("Returns: true vs pred (val)")
axs[1].plot(cum_pnl, color="tab:green"); axs[1].set_title("Cumulative PnL (val)")
axs[2].hist(steps_used.numpy(), bins=np.arange(0.5, cfg.act_Mmax + 1.5), rwidth=0.8)
axs[2].set_xticks(list(range(1, cfg.act_Mmax + 1)))
axs[2].set_title("ACT steps used distribution (val)")
plt.show()
```

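The cumulative PnL curve can be summarized with the trading metrics listed earlier (total, maximum drawdown, annualized ratio, turnover). A sketch, assuming `pnl` and `pos` are per-bar arrays like those in the cell above and that crypto trades 24 × 365 hours per year for annualization (an assumption; adjust for other calendars). The ratio is an information ratio against a zero benchmark, i.e. a Sharpe-style ratio without a risk-free rate:

```python
import numpy as np


def pnl_summary(pnl: np.ndarray, pos: np.ndarray, periods_per_year: int = 24 * 365) -> dict:
    cum = np.concatenate([[0.0], np.cumsum(pnl)])                # equity curve starting at 0
    max_drawdown = float(np.max(np.maximum.accumulate(cum) - cum))  # largest fall from a prior peak
    std = float(np.std(pnl, ddof=1))
    ann_ir = float(np.mean(pnl) / std * np.sqrt(periods_per_year)) if std > 0 else float("nan")
    turnover = float(np.mean(np.abs(np.diff(pos, prepend=0.0))))    # average |Δposition| per bar
    return {"total": float(cum[-1]), "max_drawdown": max_drawdown, "ann_ir": ann_ir, "turnover": turnover}


demo_stats = pnl_summary(np.array([1.0, -2.0, 0.5, 1.0]), np.array([1.0, -1.0, -1.0, 1.0]))
print(demo_stats)
```

With the notebook's arrays, `pnl_summary(pnl, pos)` gives the same statistics for the validation run.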