### Configuration

In [1]:

CSV_PATH = "mnq_complete_dataset.csv"  #@param {type:"string"}

# Features to use (OHLCV + 7 TA indicators; all lowercase)
INDICATORS_TO_USE = ['open', 'high', 'low', 'close', 'volume','atr_14', 'adx_14', 'ema_9', 'ema_21', 'vwap', 'rsi_21', 'stochk_14_3_3']  #@param {type:"raw"}

# Normalization: 'StandardScaler' or 'MinMaxScaler'
NORMALIZATION_TYPE = "StandardScaler"  #@param ["StandardScaler", "MinMaxScaler"]

# Chronological split ratios (must sum to 1.0)
TRAIN_VALID_TEST_SPLIT = [0.7, 0.15, 0.15]  #@param {type:"raw"}

# Sequence/window length for model input
SEQ_LEN = 512  #@param {type:"integer"}

# Step size between windows (1 = full overlap)
WINDOW_STRIDE = 64  #@param {type:"integer"}

# TimesBlock
TOP_K_PERIODS = 3  #@param {type:"integer"}  # k in paper
EMBED_DIM = 256     # 2D backbone output channels
EMBED_H = 8
EMBED_W = 8
DROPOUT_RATE = 0.2  #@param {type:"number"}

# Training
NUM_EPOCHS = 100  #@param {type:"integer"}
BATCH_SIZE = 2048  #@param {type:"integer"}
LEARNING_RATE = 1e-3  # scheduler max_lr
WEIGHT_DECAY = 1e-4  #@param {type:"number"}
PATIENCE = 30  #@param {type:"integer"}
GRAD_ACCUM_STEPS = 1  #@param {type:"integer"}

# LR Scheduler (OneCycle)
USE_SCHEDULER = True  #@param {type:"boolean"}
ONECYCLE_PCT_START = 0.15  #@param {type:"number"}
ONECYCLE_DIV_FACTOR = 25.0  #@param {type:"number"}
ONECYCLE_FINAL_DIV = 100.0  #@param {type:"number"}

# Gradient clipping
CLIP_GRAD_NORM = 1.0  #@param {type:"number"}

# Train-only augmentation
AUG_NOISE_STD = 0.01  #@param {type:"number"}

# DataLoader performance
DATALOADER_WORKERS = 8  #@param {type:"integer"}
PIN_MEMORY = True  #@param {type:"boolean"}
PERSISTENT_WORKERS = True  #@param {type:"boolean"}
PREFETCH_FACTOR = 2  #@param {type:"integer"}

# Precision & memory format (A100)
USE_BF16 = True  #@param {type:"boolean"}
CHANNELS_LAST = True  #@param {type:"boolean"}
PRINT_GPU_MEM = True  #@param {type:"boolean"}

# Checkpointing
CHECKPOINT_PATH = "/content/drive/MyDrive/timesnet_mnq/checkpoints/best.pt"  #@param {type:"string"}

# Logs e histórico
LOG_DIR = "/content/drive/MyDrive/timesnet_mnq/logs"  #@param {type:"string"}
HISTORY_PATH = LOG_DIR + "/training_history.jsonl"

# Resume training if checkpoint exists
RESUME_TRAINING = True  #@param {type:"boolean"}

# Device
import os
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Config -> seq_len={SEQ_LEN}, batch_size={BATCH_SIZE}, k={TOP_K_PERIODS}, embed_dim={EMBED_DIM}, HxW={EMBED_H}x{EMBED_W}, dropout={DROPOUT_RATE}")
print(f"Optim -> AdamW max_lr={LEARNING_RATE}, wd={WEIGHT_DECAY}, patience={PATIENCE}, grad_accum={GRAD_ACCUM_STEPS}")
print(f"Loader -> workers={DATALOADER_WORKERS}, pin_memory={PIN_MEMORY}, persistent={PERSISTENT_WORKERS}, prefetch={PREFETCH_FACTOR}")
print(f"Device -> {DEVICE} | bf16={USE_BF16} | channels_last={CHANNELS_LAST}")


Config -> seq_len=512, batch_size=2048, k=3, embed_dim=256, HxW=8x8, dropout=0.2
Optim -> AdamW max_lr=0.001, wd=0.0001, patience=30, grad_accum=1
Loader -> workers=8, pin_memory=True, persistent=True, prefetch=2
Device -> cuda | bf16=True | channels_last=True


### Setup & Imports

In [2]:
import os
import math
import json
import sys, subprocess
import time
import numpy as np
import pandas as pd
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from datetime import datetime

# Ensure pandas-ta is available early
try:
    import pandas_ta as ta
except Exception:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'pandas-ta'])
    import pandas_ta as ta

import contextlib

def _fmt_gb(bytes_val):
    try:
        return f"{bytes_val/1e9:.2f} GB"
    except Exception:
        return str(bytes_val)

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU -> {torch.cuda.get_device_name(0)} | VRAM total: {_fmt_gb(props.total_memory)}")

try:
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision('high')
except Exception:
    pass


GPU -> NVIDIA A100-SXM4-40GB | VRAM total: 42.47 GB


### Mount Google Drive

In [3]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
print(f"Checkpoints dir -> {os.path.dirname(CHECKPOINT_PATH)}")
print(f"Checkpoint file -> {CHECKPOINT_PATH}")


Mounted at /content/drive
Checkpoints dir -> /content/drive/MyDrive/timesnet_mnq/checkpoints
Checkpoint file -> /content/drive/MyDrive/timesnet_mnq/checkpoints/best.pt


### Data Loading & Chronological Splits (no leakage)



In [4]:
def _find_datetime_column(df: pd.DataFrame) -> Optional[str]:
    candidates = ['datetime', 'date', 'time', 'timestamp', 'ts']
    cols = [c for c in df.columns]
    for c in cols:
        if c.lower() in candidates:
            return c
    return None

def _find_col_ci(df: pd.DataFrame, name: str) -> Optional[str]:
    for c in df.columns:
        if c.lower() == name.lower():
            return c
    return None

def _has_col_ci(df: pd.DataFrame, name: str) -> bool:
    return any(c.lower() == name.lower() for c in df.columns)

def add_ta_indicators(df: pd.DataFrame) -> pd.DataFrame:
    # Normalize base column names to lowercase for consistency
    rename_map = {}
    for nm in ['Open','High','Low','Close','Volume']:
        c = _find_col_ci(df, nm)
        if c is not None:
            rename_map[c] = nm.lower()
    if rename_map:
        df = df.rename(columns=rename_map)

    required = ['high','low','close','volume']
    if any(not _has_col_ci(df, r) for r in required):
        raise ValueError('Missing OHLCV columns (high, low, close, volume) to compute pandas-ta indicators.')

    # Ensure datetime index for VWAP
    time_col = _find_datetime_column(df)
    if time_col is not None:
        try:
            df[time_col] = pd.to_datetime(df[time_col])
            df = df.set_index(time_col)
        except Exception:
            pass # Keep going if datetime conversion fails, but VWAP might fail

    # Compute indicators only if a lowercase target is not already present
    if not _has_col_ci(df, 'atr_14'):
        df.ta.atr(high='high', low='low', close='close', length=14, append=True)
    if not _has_col_ci(df, 'adx_14'):
        df.ta.adx(high='high', low='low', close='close', length=14, append=True)
    if not _has_col_ci(df, 'ema_9'):
        df.ta.ema(close='close', length=9, append=True)
    if not _has_col_ci(df, 'ema_21'):
        df.ta.ema(close='close', length=21, append=True)
    if not _has_col_ci(df, 'vwap'):
        df.ta.vwap(high='high', low='low', close='close', volume='volume', append=True)
    if not _has_col_ci(df, 'rsi_21'):
        df.ta.rsi(close='close', length=21, append=True)
    if not _has_col_ci(df, 'stochk_14_3_3'):
        df.ta.stoch(high='high', low='low', close='close', k=14, d=3, smooth_k=3, append=True)

    # Rename known pandas-ta outputs to lowercase canonical names if needed
    lower_map = {}
    # ATR - pandas-ta outputs ATRr_14
    if 'ATRr_14' in df.columns and 'atr_14' not in df.columns:
        lower_map['ATRr_14'] = 'atr_14'
    # ADX
    if 'ADX_14' in df.columns and 'adx_14' not in df.columns:
        lower_map['ADX_14'] = 'adx_14'
    # EMAs
    if 'EMA_9' in df.columns and 'ema_9' not in df.columns:
        lower_map['EMA_9'] = 'ema_9'
    if 'EMA_21' in df.columns and 'ema_21' not in df.columns:
        lower_map['EMA_21'] = 'ema_21'
    # VWAP
    if 'VWAP' in df.columns and 'vwap' not in df.columns:
        lower_map['VWAP'] = 'vwap'
    if 'VWAP_D' in df.columns and 'vwap' not in df.columns: # Added to handle VWAP_D
        lower_map['VWAP_D'] = 'vwap'
    # RSI
    if 'RSI_21' in df.columns and 'rsi_21' not in df.columns:
        lower_map['RSI_21'] = 'rsi_21'
    # STOCH K/D (we'll expose K by default)
    if 'STOCHk_14_3_3' in df.columns and 'stochk_14_3_3' not in df.columns:
        lower_map['STOCHk_14_3_3'] = 'stochk_14_3_3'
    if lower_map:
        df = df.rename(columns=lower_map)
    return df

def load_mnq_csv(csv_path: str, indicators: List[str]) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    time_col = _find_datetime_column(df)
    if time_col is not None:
        try:
            df[time_col] = pd.to_datetime(df[time_col])
            df = df.sort_values(by=time_col, ascending=True).reset_index(drop=True)
        except Exception:
            df = df.reset_index(drop=True)
    else:
        df = df.reset_index(drop=True)

    # Create TA indicators and ensure lowercase canon names
    df = add_ta_indicators(df)
    # Select requested features (lowercase)
    missing = [c for c in indicators if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns after TA creation: {missing}. Available: {list(df.columns)}")
    xdf = df[indicators].copy()
    # Clean: numeric, remove inf, drop NaNs to avoid leakage via backward fill
    for c in xdf.columns:
        xdf[c] = pd.to_numeric(xdf[c], errors='coerce')
    xdf = xdf.replace([np.inf, -np.inf], np.nan)
    xdf = xdf.dropna()
    return xdf

def chronological_split(arr: np.ndarray, ratios: List[float]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    assert abs(sum(ratios) - 1.0) < 1e-6, "TRAIN_VALID_TEST_SPLIT must sum to 1.0"
    T = len(arr)
    n_train = int(T * ratios[0])
    n_val = int(T * ratios[1])
    n_test = T - n_train - n_val
    train = arr[:n_train]
    val = arr[n_train:n_train+n_val]
    test = arr[n_train+n_val:]
    return train, val, test

def make_scaler(norm_type: str):
    if norm_type == "StandardScaler":
        return StandardScaler()
    elif norm_type == "MinMaxScaler":
        return MinMaxScaler()
    else:
        raise ValueError("NORMALIZATION_TYPE must be 'StandardScaler' or 'MinMaxScaler'.")

def fit_transform_splits(train_arr: np.ndarray, val_arr: np.ndarray, test_arr: np.ndarray, scaler):
    # Fit only on training data
    scaler.fit(train_arr)
    train_scaled = scaler.transform(train_arr)
    val_scaled = scaler.transform(val_arr)
    test_scaled = scaler.transform(test_arr)
    return train_scaled, val_scaled, test_scaled

# Load CSV
print(f'Loading CSV: {CSV_PATH}')
df = load_mnq_csv(CSV_PATH, INDICATORS_TO_USE)
print(f'Features selected: {len(INDICATORS_TO_USE)} | {INDICATORS_TO_USE}')
data = df.values.astype(np.float32)
print(f'Total rows after TA + cleanup: {len(df)} | Feature dim: {data.shape[1]}')

# Chronological split (no shuffling)
train_raw, val_raw, test_raw = chronological_split(data, TRAIN_VALID_TEST_SPLIT)
print(f'Split -> train: {train_raw.shape}, val: {val_raw.shape}, test: {test_raw.shape}')

# Train-only normalization
scaler = make_scaler(NORMALIZATION_TYPE)
train_scaled, val_scaled, test_scaled = fit_transform_splits(train_raw, val_raw, test_raw, scaler)
if NORMALIZATION_TYPE == 'StandardScaler':
    means = scaler.mean_
    stds = scaler.scale_ if hasattr(scaler, 'scale_') else np.sqrt(scaler.var_)
    print(f'Scaler(Standard) -> mean range [{means.min():.4f}, {means.max():.4f}] | std range [{stds.min():.4f}, {stds.max():.4f}]')
else:
    mins = scaler.data_min_
    maxs = scaler.data_max_
    print(f'Scaler(MinMax) -> min range [{mins.min():.4f}, {mins.max():.4f}] | max range [{maxs.min():.4f}, {maxs.max():.4f}]')

def _count_windows(n, L, s):
    return max(0, (n - L) // s + 1)
nw_train = _count_windows(len(train_scaled), SEQ_LEN, WINDOW_STRIDE)
nw_val = _count_windows(len(val_scaled), SEQ_LEN, WINDOW_STRIDE)
nw_test = _count_windows(len(test_scaled), SEQ_LEN, WINDOW_STRIDE)
print(f'Windows -> train: {nw_train}, val: {nw_val}, test: {nw_test} | stride={WINDOW_STRIDE}, seq_len={SEQ_LEN}')

train_scaled.shape, val_scaled.shape, test_scaled.shape

Loading CSV: mnq_complete_dataset.csv
Features selected: 12 | ['open', 'high', 'low', 'close', 'volume', 'atr_14', 'adx_14', 'ema_9', 'ema_21', 'vwap', 'rsi_21', 'stochk_14_3_3']
Total rows after TA + cleanup: 513211 | Feature dim: 12
Split -> train: (359247, 12), val: (76981, 12), test: (76983, 12)
Scaler(Standard) -> mean range [0.1939, 1653.4210] | std range [0.1121, 2725.6734]
Windows -> train: 5606, val: 1195, test: 1195 | stride=64, seq_len=512


((359247, 12), (76981, 12), (76983, 12))

### Dataset & DataLoaders (windowed, no shuffle)



In [5]:
class MNQ_Dataset(Dataset):
    def __init__(self, arr_2d: np.ndarray, seq_len: int, stride: int = 1):
        super().__init__()
        self.x = arr_2d
        self.seq_len = int(seq_len)
        self.stride = int(stride)
        self.T = len(arr_2d)
        self.C = arr_2d.shape[1]
        if self.T < self.seq_len:
            raise ValueError(f"Not enough timesteps ({self.T}) for seq_len={self.seq_len}")
        # Number of windows using stride
        self.idxs = list(range(0, self.T - self.seq_len + 1, self.stride))

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, idx):
        start = self.idxs[idx]
        end = start + self.seq_len
        window = self.x[start:end]  # shape [seq_len, C]
        # model expects [L, C], training target == input window (self-supervised)
        return torch.from_numpy(window).float()

def make_loader(arr: np.ndarray, seq_len: int, stride: int, batch_size: int) -> DataLoader:
    ds = MNQ_Dataset(arr, seq_len=seq_len, stride=stride)
    # No shuffling to avoid any perceived leakage; data is chronologically windowed already

    # Ajuste automático de batch para evitar época vazia
    bs = int(batch_size)
    nwin = len(ds)
    if nwin < bs:
        print(f'WARN: batch_size {bs} > windows {nwin}; ajustando para evitar época vazia.')
        bs = max(1, nwin)

    return DataLoader(ds, batch_size=bs, shuffle=False, drop_last=False, num_workers=DATALOADER_WORKERS, pin_memory=PIN_MEMORY, persistent_workers=(PERSISTENT_WORKERS and DATALOADER_WORKERS>0), prefetch_factor=(PREFETCH_FACTOR if DATALOADER_WORKERS>0 else 2))


train_loader = make_loader(train_scaled, SEQ_LEN, WINDOW_STRIDE, BATCH_SIZE)
val_loader = make_loader(val_scaled, SEQ_LEN, WINDOW_STRIDE, BATCH_SIZE)
test_loader = make_loader(test_scaled, SEQ_LEN, WINDOW_STRIDE, BATCH_SIZE)

tw, vw, tew = len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset)
tb, vb, teb = len(train_loader), len(val_loader), len(test_loader)
print(f'DataLoaders -> windows (train/val/test): {tw}/{vw}/{tew} | batches: {tb}/{vb}/{teb} | batch_size={BATCH_SIZE}')
print(f'Window shape -> L={SEQ_LEN}, C={train_scaled.shape[1]}')


WARN: batch_size 2048 > windows 1195; ajustando para evitar época vazia.
WARN: batch_size 2048 > windows 1195; ajustando para evitar época vazia.
DataLoaders -> windows (train/val/test): 5606/1195/1195 | batches: 3/1/1 | batch_size=2048
Window shape -> L=512, C=12


### Model: TimesBlock + Inception + Feature Extractor + Decoder



In [6]:
class InceptionBlock2D(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.1):
        super().__init__()
        # Split out_channels across branches
        base, rem = divmod(out_channels, 4)
        branch_sizes = [base + (1 if i < rem else 0) for i in range(4)]
        b1, b2, b3, b4 = branch_sizes

        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, b1, kernel_size=1, bias=False),
            nn.BatchNorm2d(b1),
            nn.ReLU(inplace=True),
        )

        red3 = max(in_channels // 2, 8)
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, red3, kernel_size=1, bias=False),
            nn.BatchNorm2d(red3),
            nn.ReLU(inplace=True),
            nn.Conv2d(red3, b2, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(b2),
            nn.ReLU(inplace=True),
        )

        red5 = max(in_channels // 2, 8)
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, red5, kernel_size=1, bias=False),
            nn.BatchNorm2d(red5),
            nn.ReLU(inplace=True),
            nn.Conv2d(red5, b3, kernel_size=5, padding=2, bias=False),
            nn.BatchNorm2d(b3),
            nn.ReLU(inplace=True),
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, b4, kernel_size=1, bias=False),
            nn.BatchNorm2d(b4),
            nn.ReLU(inplace=True),
        )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        o1 = self.branch1x1(x)
        o2 = self.branch3x3(x)
        o3 = self.branch5x5(x)
        o4 = self.branch_pool(x)
        out = torch.cat([o1, o2, o3, o4], dim=1)
        return self.dropout(out)

class TimesBlock(nn.Module):
    """
    FFT-based period discovery + 1D->2D folding + 2D Inception backbone + weighted fusion.
    Input:  x [B, L, C]
    Output: embedding [B, E, H, W]
    """
    def __init__(self, in_channels: int, k: int, embed_dim: int, embed_h: int, embed_w: int, dropout: float = 0.1):
        super().__init__()
        self.in_channels = in_channels
        self.k = k
        self.embed_dim = embed_dim
        self.embed_h = embed_h
        self.embed_w = embed_w
        self.backbone = InceptionBlock2D(in_channels, embed_dim, dropout=dropout)

    @torch.no_grad()
    def _find_topk_periods(self, x_bc_l: torch.Tensor, k: int) -> Tuple[List[int], torch.Tensor]:
        # x_bc_l: [B, C, L]
        B, C, L = x_bc_l.shape
        xf = torch.fft.rfft(x_bc_l, dim=-1)  # [B, C, L//2 + 1]
        amp = xf.abs().mean(dim=(0, 1))      # [L//2 + 1], averaged over batch & channels
        if amp.shape[0] <= 1:
            return [L], torch.tensor([1.0], device=x_bc_l.device)

        amp[0] = 0.0  # ignore DC
        k_eff = min(k, amp.shape[0]-1)
        vals, idxs = torch.topk(amp, k=k_eff, largest=True, sorted=True)
        periods = []
        for idx in idxs.tolist():
            p = int(round(L / max(idx, 1)))
            p = max(p, 2)
            periods.append(p)
        # Softmax weights from amplitudes
        w = torch.softmax(vals, dim=0)
        return periods, w

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, L, C]
        B, L, C = x.shape
        x_bc_l = x.permute(0, 2, 1).contiguous()  # [B, C, L]

        periods, weights = self._find_topk_periods(x_bc_l, self.k)
        feats = None
        for i, p in enumerate(periods):
            pad_len = (p - (L % p)) % p
            x_pad = F.pad(x_bc_l, (0, pad_len), mode='constant', value=0.0)  # [B,C,Lp]
            Lp = x_pad.shape[-1]
            w_ = Lp // p
            # Fold: [B, C, p, w_]
            x_2d = x_pad.view(B, C, w_, p).transpose(2, 3).contiguous()

            z = self.backbone(x_2d)  # [B, E, h, w]
            z = F.adaptive_avg_pool2d(z, (self.embed_h, self.embed_w))  # [B,E,H,W]
            z = z * weights[i].view(1, 1, 1, 1)  # weight this period

            feats = z if feats is None else (feats + z)

        return feats  # [B, E, H, W]

class TimesNetFeatureExtractor(nn.Module):
    def __init__(self, in_channels: int, k: int, embed_dim: int, embed_h: int, embed_w: int, dropout: float=0.1):
        super().__init__()
        self.block = TimesBlock(in_channels, k, embed_dim, embed_h, embed_w, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, L, C]
        emb = self.block(x)
        return self.dropout(emb)  # [B, E, H, W]

class ModelWithDecoder(nn.Module):
    """
    Wrapper for self-supervised training (reconstruction).
    During inference, use extractor only.
    """
    def __init__(self, in_channels: int, seq_len: int, k: int, embed_dim: int, embed_h: int, embed_w: int, dropout: float=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.seq_len = seq_len

        self.extractor = TimesNetFeatureExtractor(
            in_channels=in_channels,
            k=k,
            embed_dim=embed_dim,
            embed_h=embed_h,
            embed_w=embed_w,
            dropout=dropout
        )
        flat_size = embed_dim * embed_h * embed_w
        hidden = max(512, flat_size // 2)

        # Using Sequential with explicit layers instead of Flatten
        self.decoder = nn.Sequential(
            nn.Linear(flat_size, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden, seq_len * in_channels)
        )

    def forward(self, x: torch.Tensor):
        # x: [B, L, C]
        emb = self.extractor(x)  # [B, E, H, W]
        # Explicitly permute to channels_first before flattening with view
        if emb.dim() == 4:
             # Assuming channels_last if not channels_first or original channels
             emb = emb.permute(0, 1, 2, 3).contiguous() # Ensure channels_first (no-op if already)

        # Debug: print(f"Shape before flattening: {emb.shape}")
        emb_flat = emb.view(emb.size(0), -1) # Explicit flatten
        rec = self.decoder(emb_flat).view(x.shape[0], self.seq_len, self.in_channels)
        return rec, emb

    def strip_decoder(self):
        self.decoder = None

### Training Utilities: EarlyStopping, Checkpointing, Train/Eval

In [11]:
from contextlib import nullcontext
USE_AMP = bool(torch.cuda.is_available() and (DEVICE == 'cuda') and USE_BF16 and getattr(torch.cuda, 'is_bf16_supported', lambda: False)())
AMP_DTYPE = torch.bfloat16 if USE_AMP else None

class EarlyStopping:
    def __init__(self, patience: int = 7, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.num_bad = 0

    def step(self, value: float) -> bool:
        if self.best is None or value < self.best - self.min_delta:
            self.best = value
            self.num_bad = 0
            return False  # no stop
        else:
            self.num_bad += 1
            return self.num_bad >= self.patience

def save_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer, epoch: int, best_val: float, extra: dict=None):
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_val_loss": best_val,
    }
    if extra:
        state["extra"] = extra
    torch.save(state, path)

def load_checkpoint_if_any(path: str, model: nn.Module, optimizer: torch.optim.Optimizer, resume: bool):
    start_epoch = 1
    best_val = float("inf")
    if resume and os.path.exists(path):
        ckpt = torch.load(path, map_location="cpu")
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_epoch = ckpt.get("epoch", 1) + 1
        best_val = ckpt.get("best_val_loss", float("inf"))
        print(f"Resuming from epoch {start_epoch-1} with best_val={best_val:.6f}")
    return start_epoch, best_val

def train_one_epoch(model, loader, optimizer, device, scheduler=None, clip_grad_norm: float=None, aug_noise_std: float=0.0):
    model.train()
    mse = nn.MSELoss()
    total_loss = 0.0
    n = 0
    optimizer.zero_grad(set_to_none=True)
    for i, xb in enumerate(loader):
        xb = xb.to(device, non_blocking=True)
        if aug_noise_std and aug_noise_std > 0:
            noise = torch.randn_like(xb) * float(aug_noise_std)
            xb = xb + noise
        ctx = torch.autocast(device_type='cuda', dtype=AMP_DTYPE) if USE_AMP else nullcontext()
        with ctx:
            rec, emb = model(xb)
            loss = mse(rec, xb) / max(int(GRAD_ACCUM_STEPS), 1)
        loss.backward()
        if (i + 1) % max(int(GRAD_ACCUM_STEPS), 1) == 0:
            if clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(clip_grad_norm))
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        # Accumulate reporting with true scale
        total_loss += (loss.item() * max(int(GRAD_ACCUM_STEPS), 1)) * xb.size(0)
        n += xb.size(0)
    return total_loss / max(n, 1)

@torch.no_grad()
def evaluate(model, loader, device, aug_noise_std: float=0.0):
    model.eval()
    mse = nn.MSELoss()
    total_loss = 0.0
    n = 0
    for xb in loader:
        xb = xb.to(device, non_blocking=True)
        # Augmentation should only be applied during training, remove this line
        # if aug_noise_std and aug_noise_std > 0:
        #     noise = torch.randn_like(xb) * float(aug_noise_std)
        #     xb = xb + noise
        ctx = torch.autocast(device_type='cuda', dtype=AMP_DTYPE) if USE_AMP else nullcontext()
        with ctx:
            rec, emb = model(xb)
            loss = mse(rec, xb)
        total_loss += loss.item() * xb.size(0)
        n += xb.size(0)
    return total_loss / max(n, 1)

### Initialize Model, Optimizer, and Train

In [12]:
in_channels = train_scaled.shape[1]
print(f'In-channels (features): {in_channels}')
model = ModelWithDecoder(
    in_channels=in_channels,
    seq_len=SEQ_LEN,
    k=TOP_K_PERIODS,
    embed_dim=EMBED_DIM,
    embed_h=EMBED_H,
    embed_w=EMBED_W,
    dropout=DROPOUT_RATE
).to(DEVICE)
if CHANNELS_LAST and (DEVICE == 'cuda'):
    model = model.to(memory_format=torch.channels_last)
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Model params: {params/1e6:.2f}M')

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

start_epoch, best_val = load_checkpoint_if_any(CHECKPOINT_PATH, model, optimizer, resume=RESUME_TRAINING)

# Ensure the optimizer's initial LR is set before creating the scheduler when resuming
if RESUME_TRAINING and os.path.exists(CHECKPOINT_PATH):
    for param_group in optimizer.param_groups:
        param_group['lr'] = LEARNING_RATE

scheduler = None
if USE_SCHEDULER:
    steps_per_epoch = max(1, (len(train_loader) + max(int(GRAD_ACCUM_STEPS),1) - 1) // max(int(GRAD_ACCUM_STEPS),1))
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LEARNING_RATE, steps_per_epoch=steps_per_epoch, epochs=NUM_EPOCHS,
        pct_start=ONECYCLE_PCT_START, div_factor=ONECYCLE_DIV_FACTOR, final_div_factor=ONECYCLE_FINAL_DIV,
        anneal_strategy='cos'
    )


os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
training_history = []
early = EarlyStopping(patience=PATIENCE, min_delta=0.0)

tb, vb = len(train_loader), len(val_loader)
tw, vw = len(train_loader.dataset), len(val_loader.dataset)
est_opt_steps = (tb + max(int(GRAD_ACCUM_STEPS),1) - 1) // max(int(GRAD_ACCUM_STEPS),1)
print(f"Starting training at epoch {start_epoch} on {DEVICE} | train windows={tw}, batches/epoch={tb}, grad_accum={GRAD_ACCUM_STEPS}, est_opt_steps/epoch={est_opt_steps}")
for epoch in range(start_epoch, NUM_EPOCHS + 1):
    if DEVICE == 'cuda':
        torch.cuda.reset_peak_memory_stats()
    ep_start = time.perf_counter()
    train_loss = train_one_epoch(model, train_loader, optimizer, DEVICE, scheduler=scheduler, clip_grad_norm=CLIP_GRAD_NORM, aug_noise_std=AUG_NOISE_STD)
    val_loss = evaluate(model, val_loader, DEVICE, aug_noise_std=0.0) # Pass 0.0 for augmentation noise during evaluation
    ep_time = time.perf_counter() - ep_start
    throughput = len(train_loader.dataset) / max(ep_time, 1e-9)

    improved = val_loss < best_val - 1e-12
    if improved:
        best_val = val_loss
        save_checkpoint(
            CHECKPOINT_PATH, model, optimizer, epoch, best_val,
            extra={
                "INDICATORS_TO_USE": INDICATORS_TO_USE,
                "NORMALIZATION_TYPE": NORMALIZATION_TYPE,
                "SEQ_LEN": SEQ_LEN,
                "TOP_K_PERIODS": TOP_K_PERIODS,
                "EMBED_DIM": EMBED_DIM,
                "EMBED_H": EMBED_H,
                "EMBED_W": EMBED_W,
            }
        )

    stop = early.step(val_loss)
    lr = optimizer.param_groups[0].get('lr', None)
    msg = f"Epoch {epoch:03d} | train {train_loss:.6f} | val {val_loss:.6f} | {throughput:.0f} samp/s"
    if lr is not None:
        msg += f" | lr={lr:.2e}"
    if improved:
        msg += " | [saved]"
    if PRINT_GPU_MEM and (DEVICE == 'cuda'):
        peak = torch.cuda.max_memory_allocated() / 1e9
        msg += f" | GPU peak mem: {peak:.2f} GB"
    print(msg + f" | patience {early.num_bad}/{PATIENCE}")
    # Log epoch metrics
    try:
        import json as _json
        rec = {
            'epoch': int(epoch),
            'train_loss': float(train_loss),
            'val_loss': float(val_loss),
            'lr': float(lr) if lr is not None else None,
            'throughput': float(throughput),
            'epoch_time': float(ep_time),
            'gpu_peak_gb': float(peak) if (PRINT_GPU_MEM and (DEVICE=='cuda')) else None,
            'improved': bool(improved),
            'best_val': float(best_val)
        }
        training_history.append(rec)
        with open(HISTORY_PATH, 'a') as f:
            f.write(_json.dumps(rec) + '\n')
    except Exception as e:
        print('WARN: falha ao gravar histórico:', e)
    if stop:
        print(f"Early stopping at epoch {epoch}. Best val: {best_val:.6f}")
        break

In-channels (features): 12
Model params: 184.58M
Resuming from epoch 19 with best_val=0.000000
Starting training at epoch 20 on cuda | train windows=5606, batches/epoch=3, grad_accum=1, est_opt_steps/epoch=3
Epoch 020 | train 1.007655 | val 6.017887 | 797 samp/s | lr=5.10e-05 | GPU peak mem: 15.57 GB | patience 0/30
Epoch 021 | train 1.009290 | val 5.907022 | 10548 samp/s | lr=8.34e-05 | GPU peak mem: 15.57 GB | patience 0/30
Epoch 022 | train 1.007252 | val 5.746515 | 11012 samp/s | lr=1.36e-04 | GPU peak mem: 15.57 GB | patience 0/30
Epoch 023 | train 1.001725 | val 5.484633 | 11449 samp/s | lr=2.06e-04 | GPU peak mem: 15.57 GB | patience 0/30
Epoch 024 | train 0.988819 | val 5.020320 | 11039 samp/s | lr=2.90e-04 | GPU peak mem: 15.57 GB | patience 0/30
Epoch 025 | train 0.964965 | val 4.136224 | 10845 samp/s | lr=3.85e-04 | GPU peak mem: 15.57 GB | patience 0/30
Epoch 026 | train 0.930967 | val 2.563833 | 11413 samp/s | lr=4.86e-04 | GPU peak mem: 15.57 GB | patience 0/30
Epoch 027 

### Verification: Embedding-only forward

In [None]:

# Load best checkpoint
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
model = ModelWithDecoder(
    in_channels=in_channels,
    seq_len=SEQ_LEN,
    k=TOP_K_PERIODS,
    embed_dim=EMBED_DIM,
    embed_h=EMBED_H,
    embed_w=EMBED_W,
    dropout=DROPOUT_RATE
)
model.load_state_dict(ckpt["model_state_dict"], strict=True)
model = model.to(DEVICE)
model.eval()

# Remove temporary reconstruction decoder
model.strip_decoder()
assert model.decoder is None

# Take a sample batch from test set and compute embeddings
xb = next(iter(test_loader))
xb = xb.to(DEVICE, non_blocking=True)
with torch.no_grad():
    ctx = torch.autocast(device_type='cuda', dtype=AMP_DTYPE) if ('AMP_DTYPE' in globals() and AMP_DTYPE is not None) else nullcontext()
    with ctx:
        # Use the extractor directly for inference-only embeddings
        emb = model.extractor(xb)  # [B, E, H, W]

print("Embedding tensor shape:", tuple(emb.shape))


### Probing: Configuração

In [None]:

# === Probing configuration ===
PROBE_HORIZON = 1  # steps ahead for label (adjust as needed)
PROBE_RETURN_COL = 'close'  # which column to compute returns from
PROBE_POOLING = 'avg'  # 'avg' over HxW or 'flatten'
PROBE_REG_C = 1.0
PROBE_MAX_ITER = 1000
PROBE_RANDOM_STATE = 42

# Regime detection windows (based on returns of PROBE_RETURN_COL)
REG_TREND_WIN = 128  # steps for trend proxy
REG_VOL_WIN = 128    # steps for realized volatility


### Probing: Funções Utilitárias


In [None]:

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.stats import spearmanr


def compute_forward_returns(raw_split: np.ndarray, close_idx: int, horizon: int) -> np.ndarray:
    close = raw_split[:, close_idx].astype(np.float64)
    # log returns for stability
    lr = np.zeros_like(close)
    if horizon >= 1:
        lr[:-horizon] = np.log(close[horizon:] / close[:-horizon])
        lr[-horizon:] = np.nan
    return lr


def ds_valid_range(ds, horizon: int, total_len: int):
    # windows whose end + horizon is within array
    val_mask = []
    for s in ds.idxs:
        e = s + ds.seq_len
        val_mask.append(e + horizon <= total_len)
    return np.array(val_mask, dtype=bool)


def window_labels_for_ds(ds, raw_split: np.ndarray, close_idx: int, horizon: int):
    fwd_ret = compute_forward_returns(raw_split, close_idx, horizon)
    mask = ds_valid_range(ds, horizon, len(raw_split))
    labels = []
    rets = []
    for ok, s in zip(mask, ds.idxs):
        if not ok:
            break
        e = s + ds.seq_len
        r = fwd_ret[e-1]  # return immediately after the window end
        labels.append(1 if r > 0 else 0)
        rets.append(r)
    return np.array(labels, dtype=np.int64), np.array(rets, dtype=np.float64), mask.sum()


def extract_embeddings(loader, model, device, pooling: str = 'avg'):
    model.eval()
    vecs = []
    with torch.no_grad():
        ctx = torch.autocast(device_type='cuda', dtype=AMP_DTYPE) if ('AMP_DTYPE' in globals() and AMP_DTYPE is not None) else nullcontext()
        for xb in loader:
            xb = xb.to(device, non_blocking=True)
            with ctx:
                emb = model.extractor(xb)  # [B,E,H,W]
            if pooling == 'avg':
                v = emb.mean(dim=(2,3))  # [B,E]
            else:
                v = emb.view(emb.size(0), -1)
            vecs.append(v.detach().to('cpu').numpy())
    return np.concatenate(vecs, axis=0)


def make_regimes(raw_split: np.ndarray, close_idx: int, trend_win: int, vol_win: int):
    close = raw_split[:, close_idx].astype(np.float64)
    lr = np.zeros_like(close)
    lr[1:] = np.log(close[1:] / close[:-1])
    # trend proxy: rolling mean return
    trend = pd.Series(lr).rolling(trend_win, min_periods=trend_win//2).mean().to_numpy()
    # vol proxy: rolling std of returns
    vol = pd.Series(lr).rolling(vol_win, min_periods=vol_win//2).std(ddof=0).to_numpy()
    # quantile thresholds
    t_lo, t_hi = np.nanquantile(trend, [0.33, 0.67])
    v_lo, v_hi = np.nanquantile(vol, [0.33, 0.67])
    trend_reg = np.where(trend <= t_lo, -1, np.where(trend >= t_hi, 1, 0))  # -1 bear, 0 neutral, 1 bull
    vol_reg = np.where(vol <= v_lo, 0, np.where(vol >= v_hi, 2, 1))  # 0 low,1 mid,2 high
    return trend_reg, vol_reg


def eval_probe(y_true, y_score, fwd_returns):
    out = {}
    out['roc_auc'] = roc_auc_score(y_true, y_score)
    out['pr_auc'] = average_precision_score(y_true, y_score)
    # IC: Spearman between predicted prob and realized return
    vmask = ~np.isnan(fwd_returns)
    ic, _ = spearmanr(y_score[vmask], fwd_returns[vmask])
    out['ic'] = float(ic)
    # ICIR: mean/std of rolling-50 ICs for stability
    if vmask.sum() > 100:
        import numpy as _np
        wins = 50
        ics = []
        for i in range(0, vmask.sum()-wins+1):
            seg = slice(i, i+wins)
            c, _ = spearmanr(y_score[vmask][seg], fwd_returns[vmask][seg])
            if _np.isfinite(c):
                ics.append(c)
        if len(ics) >= 2:
            out['icir'] = float(_np.mean(ics) / ( _np.std(ics) + 1e-12))
        else:
            out['icir'] = float('nan')
    else:
        out['icir'] = float('nan')
    return out


### Probing: Pipeline e Métricas


In [None]:

# === Probing pipeline: fit on val, eval on test ===
from sklearn.metrics import classification_report

# Find column index for PROBE_RETURN_COL
try:
    close_idx = INDICATORS_TO_USE.index(PROBE_RETURN_COL)
except ValueError:
    raise ValueError(f"PROBE_RETURN_COL={PROBE_RETURN_COL} não está em INDICATORS_TO_USE: {INDICATORS_TO_USE}")

# Labels aligned to windows
y_val, r_val, n_val = window_labels_for_ds(val_loader.dataset, val_raw, close_idx, PROBE_HORIZON)
y_test, r_test, n_test = window_labels_for_ds(test_loader.dataset, test_raw, close_idx, PROBE_HORIZON)

# Embeddings extraídos e alinhados (mantendo apenas janelas válidas para o horizonte)
X_val_full = extract_embeddings(val_loader, model, DEVICE, pooling=PROBE_POOLING)
X_test_full = extract_embeddings(test_loader, model, DEVICE, pooling=PROBE_POOLING)
X_val = X_val_full[:n_val]
X_test = X_test_full[:n_test]

print(f"Embeddings -> val: {X_val.shape}, test: {X_test.shape}")
print(f"Labels -> val: {y_val.shape}, test: {y_test.shape}")

# Probe: logistic regression balanceada
probe = Pipeline([
    ('scaler', StandardScaler()),
    ('clf', LogisticRegression(
        C=PROBE_REG_C,
        max_iter=PROBE_MAX_ITER,
        class_weight='balanced',
        solver='lbfgs',
        random_state=PROBE_RANDOM_STATE
    ))
])
probe.fit(X_val, y_val)

# Scores em probabilidade (classe positiva = retorno>0)
proba_test = probe.predict_proba(X_test)[:,1]
metrics_all = eval_probe(y_test, proba_test, r_test)
print('--- Probe (TEST) - métricas gerais ---')
for k, v in metrics_all.items():
    print(f"{k}: {v:.6f}")

# Métricas por regime
trend_reg_test, vol_reg_test = make_regimes(test_raw, close_idx, REG_TREND_WIN, REG_VOL_WIN)
# Alinhar regimes ao final da janela
reg_aligned = trend_reg_test[SEQ_LEN-1:SEQ_LEN-1+n_test]
vol_aligned = vol_reg_test[SEQ_LEN-1:SEQ_LEN-1+n_test]

print('
--- Por regime de tendência (bear=-1, neutral=0, bull=1) ---')
for lbl in [-1,0,1]:
    idx = np.where(reg_aligned == lbl)[0]
    if len(idx) < 200:
        continue
    m = eval_probe(y_test[idx], proba_test[idx], r_test[idx])
    print(f"trend={lbl} | n={len(idx)} | auc={m['roc_auc']:.4f} | pr={m['pr_auc']:.4f} | ic={m['ic']:.4f} | icir={m['icir']:.4f}")

print('
--- Por regime de volatilidade (0=low,1=mid,2=high) ---')
for lbl in [0,1,2]:
    idx = np.where(vol_aligned == lbl)[0]
    if len(idx) < 200:
        continue
    m = eval_probe(y_test[idx], proba_test[idx], r_test[idx])
    print(f"vol={lbl} | n={len(idx)} | auc={m['roc_auc']:.4f} | pr={m['pr_auc']:.4f} | ic={m['ic']:.4f} | icir={m['icir']:.4f}")


### Embeddings: Export para PPO Agent


In [None]:

# === Exportar embeddings para consumo por PPO Agent ===
import os, json as _json
from pathlib import Path as _Path

# Diretório de saída (ajuste se necessário)
EMBED_SAVE_DIR = "/content/drive/MyDrive/timesnet_mnq/embeddings"  #@param {type:"string"}
EMBED_POOLING_EXPORT = 'avg'  #@param ["avg", "flatten"]

os.makedirs(EMBED_SAVE_DIR, exist_ok=True)
print(f"Salvando em: {EMBED_SAVE_DIR}")

# Função local de extração para evitar dependências de ordem de execução
@torch.no_grad()
def _extract_embeddings(loader, model, device, pooling: str = 'avg'):
    model.eval()
    vecs = []
    ctx = torch.autocast(device_type='cuda', dtype=AMP_DTYPE) if ('AMP_DTYPE' in globals() and AMP_DTYPE is not None) else nullcontext()
    for xb in loader:
        xb = xb.to(device, non_blocking=True)
        with ctx:
            emb = model.extractor(xb)  # [B,E,H,W]
        if pooling == 'avg':
            v = emb.mean(dim=(2,3))  # [B,E]
        else:
            v = emb.view(emb.size(0), -1)
        vecs.append(v.detach().to('cpu').numpy())
    import numpy as _np
    return _np.concatenate(vecs, axis=0)

# Helper para alinhar janelas válidas (DataLoader usa drop_last=True)
def _align_by_batches(loader, ds):
    total = len(loader) * loader.batch_size
    idxs = ds.idxs[:total]
    import numpy as _np
    start_idx = _np.array(idxs, dtype=_np.int64)
    end_idx = start_idx + int(SEQ_LEN) - 1
    return start_idx, end_idx

splits = {
    'train': (train_loader, train_raw),
    'val':   (val_loader,   val_raw),
    'test':  (test_loader,  test_raw),
}

saved = {}
for name, (loader, raw) in splits.items():
    X = _extract_embeddings(loader, model, DEVICE, pooling=EMBED_POOLING_EXPORT)
    # Cortar para múltiplo do batch (drop_last)
    N = len(loader) * loader.batch_size
    X = X[:N]
    start_idx, end_idx = _align_by_batches(loader, loader.dataset)
    out_path = os.path.join(EMBED_SAVE_DIR, f"{name}_emb_{EMBED_POOLING_EXPORT}.npz")
    import numpy as _np
    _np.savez_compressed(out_path, X=X, start_idx=start_idx, end_idx=end_idx)
    saved[name] = out_path
    print(f"{name}: salvo {X.shape} -> {out_path}")

# Metadados úteis para consumidores downstream
meta = {
    'seq_len': int(SEQ_LEN),
    'window_stride': int(WINDOW_STRIDE),
    'features': INDICATORS_TO_USE,
    'normalization': NORMALIZATION_TYPE,
    'embed_dim': int(EMBED_DIM),
    'embed_hw': [int(EMBED_H), int(EMBED_W)],
    'pooling': EMBED_POOLING_EXPORT,
    'checkpoint_path': CHECKPOINT_PATH,
    'files': saved,
}
meta_path = os.path.join(EMBED_SAVE_DIR, 'embeddings_meta.json')
with open(meta_path, 'w') as f:
    _json.dump(meta, f, indent=2)
print(f"Meta salvo em: {meta_path}")


### Visualização Dinâmica


In [None]:

# === Visualização dinâmica do histórico de treino ===
import os, json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import ipywidgets as w

HIST_PATH = HISTORY_PATH
assert os.path.exists(HIST_PATH), f"Histórico não encontrado em {HIST_PATH}. Treine primeiro."

# Carrega JSON Lines
rows = []
with open(HIST_PATH, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            rows.append(json.loads(line))
        except Exception:
            pass

df = pd.DataFrame(rows)
if df.empty:
    raise SystemExit('Histórico vazio.')

# Widgets
smoothing = w.IntSlider(description='Smoothing', min=1, max=10, value=1)
show_lr = w.Checkbox(value=True, description='Mostrar LR')
show_throughput = w.Checkbox(value=False, description='Mostrar Throughput')
show_gpu = w.Checkbox(value=False, description='Mostrar GPU Peak')
epoch_range = w.IntRangeSlider(description='Épocas', min=int(df.epoch.min()), max=int(df.epoch.max()), value=[int(df.epoch.min()), int(df.epoch.max())], step=1)
refresh = w.Button(description='Recarregar', button_style='')

out = w.Output()

# Plot function

def _plot(*args):
    with out:
        clear_output(wait=True)
        lo, hi = epoch_range.value
        d = df[(df.epoch>=lo)&(df.epoch<=hi)].copy()
        if smoothing.value>1:
            d['train_s'] = d['train_loss'].rolling(smoothing.value, min_periods=1).mean()
            d['val_s'] = d['val_loss'].rolling(smoothing.value, min_periods=1).mean()
        else:
            d['train_s'] = d['train_loss']
            d['val_s'] = d['val_loss']
        best_ep = int(df.loc[df.val_loss.idxmin(),'epoch'])
        fig, ax1 = plt.subplots(1,1, figsize=(10,5))
        ax1.plot(d['epoch'], d['train_s'], label='train (smoothed)')
        ax1.plot(d['epoch'], d['val_s'], label='val (smoothed)')
        ax1.axvline(best_ep, color='g', linestyle='--', alpha=0.5, label=f'best@{best_ep}')
        ax1.set_xlabel('epoch')
        ax1.set_ylabel('loss')
        ax1.grid(True, alpha=0.2)
        lines, labels = ax1.get_legend_handles_labels()
        # Optional axes
        if show_lr.value:
            ax2 = ax1.twinx()
            ax2.plot(d['epoch'], d['lr'], color='tab:purple', alpha=0.4, label='lr')
            ax2.set_ylabel('lr')
            l2, lab2 = ax2.get_legend_handles_labels()
            lines += l2; labels += lab2
        if show_throughput.value:
            ax3 = ax1.twinx()
            ax3.spines.right.set_position(("axes", 1.1))
            ax3.plot(d['epoch'], d['throughput'], color='tab:orange', alpha=0.4, label='samp/s')
            l3, lab3 = ax3.get_legend_handles_labels()
            lines += l3; labels += lab3
        if show_gpu.value and 'gpu_peak_gb' in d.columns:
            ax4 = ax1.twinx()
            ax4.spines.right.set_position(("axes", 1.2))
            ax4.plot(d['epoch'], d['gpu_peak_gb'], color='tab:red', alpha=0.4, label='gpu peak GB')
            l4, lab4 = ax4.get_legend_handles_labels()
            lines += l4; labels += lab4
        ax1.legend(lines, labels, loc='best')
        plt.show()

# Refresh handler

def _reload(_btn):
    global df
    rows = []
    with open(HIST_PATH, 'r') as f:
        for line in f:
            line=line.strip()
            if not line: continue
            try: rows.append(json.loads(line))
            except Exception: pass
    df = pd.DataFrame(rows)
    epoch_range.max = int(max(epoch_range.max, df.epoch.max()))
    epoch_range.value = [int(df.epoch.min()), int(df.epoch.max())]
    _plot()

refresh.on_click(_reload)
for wdg in [smoothing, show_lr, show_throughput, show_gpu, epoch_range]:
    wdg.observe(_plot, names='value')

controls = w.HBox([smoothing, show_lr, show_throughput, show_gpu])
display(controls, epoch_range, refresh, out)
_plot()
