### Configuration

In [None]:

# Notebook-wide settings. Later cells read these module-level names directly.
# The trailing `#@param` annotations are Colab form markers; keep each one on
# the same line as the value it describes.
CSV_PATH = "mnq_complete_dataset.csv"  #@param {type:"string"}

# Features to use (OHLCV + 7 TA indicators; all lowercase)
INDICATORS_TO_USE = [
    'open', 'high', 'low', 'close', 'volume',
    'atr_14', 'adx_14', 'ema_9', 'ema_21', 'vwap', 'rsi_21', 'stochk_14_3_3'
]  #@param {type:"raw"}

# Normalization: 'StandardScaler' or 'MinMaxScaler'
NORMALIZATION_TYPE = "StandardScaler"  #@param ["StandardScaler", "MinMaxScaler"]

# Chronological split ratios (must sum to 1.0)
TRAIN_VALID_TEST_SPLIT = [0.7, 0.15, 0.15]  #@param {type:"raw"}

# Sequence/window length for model input
SEQ_LEN = 512  #@param {type:"integer"}

# Step size between windows (1 = full overlap)
WINDOW_STRIDE = 1  #@param {type:"integer"}

# TimesBlock
TOP_K_PERIODS = 3  #@param {type:"integer"}  # k in paper
EMBED_DIM = 192     # 2D backbone output channels
# Spatial size (H x W) the backbone output is pooled to before fusion.
EMBED_H = 8
EMBED_W = 8
DROPOUT_RATE = 0.1  #@param {type:"number"}

# Training
NUM_EPOCHS = 50  #@param {type:"integer"}
BATCH_SIZE = 2048  #@param {type:"integer"}
LEARNING_RATE = 1e-3  #@param {type:"number"}
WEIGHT_DECAY = 1e-4  #@param {type:"number"}
# Early-stopping patience, in epochs without validation improvement.
PATIENCE = 7  #@param {type:"integer"}
# Number of micro-batches whose gradients are accumulated per optimizer step.
GRAD_ACCUM_STEPS = 1  #@param {type:"integer"}

# DataLoader performance
DATALOADER_WORKERS = 4  #@param {type:"integer"}
PIN_MEMORY = True  #@param {type:"boolean"}

# Precision & memory format (A100)
USE_BF16 = True  #@param {type:"boolean"}
CHANNELS_LAST = True  #@param {type:"boolean"}
PRINT_GPU_MEM = True  #@param {type:"boolean"}

# Checkpointing
CHECKPOINT_PATH = "/content/drive/MyDrive/timesnet_mnq/checkpoints/best.pt"  #@param {type:"string"}

# Resume training if checkpoint exists
RESUME_TRAINING = True  #@param {type:"boolean"}

# Device
import torch
# Use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Echo the effective configuration so each run's log records its settings.
print(f"Config -> seq_len={SEQ_LEN}, batch_size={BATCH_SIZE}, k={TOP_K_PERIODS}, embed_dim={EMBED_DIM}, HxW={EMBED_H}x{EMBED_W}, dropout={DROPOUT_RATE}")
print(f"Optim -> AdamW lr={LEARNING_RATE}, wd={WEIGHT_DECAY}, patience={PATIENCE}, grad_accum={GRAD_ACCUM_STEPS}")
print(f"Loader -> workers={DATALOADER_WORKERS}, pin_memory={PIN_MEMORY}")
print(f"Device -> {DEVICE} | bf16={USE_BF16} | channels_last={CHANNELS_LAST}")


### Setup & Imports

In [None]:
import os
import math
import json
import sys, subprocess
import time
import numpy as np
import pandas as pd
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from datetime import datetime

# Ensure pandas-ta is available early
try:
    import pandas_ta as ta
except Exception:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'pandas-ta'])
    import pandas_ta as ta

import contextlib

def _fmt_gb(bytes_val):
    """Render a byte count as decimal gigabytes, e.g. ``"1.50 GB"``.

    Values that cannot be divided/formatted as a number are returned via
    ``str()`` instead of raising.
    """
    try:
        gigabytes = bytes_val / 1e9
        text = f"{gigabytes:.2f} GB"
    except Exception:
        text = str(bytes_val)
    return text

# Report the name and total memory of CUDA device 0 when a GPU is visible.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU -> {torch.cuda.get_device_name(0)} | VRAM total: {_fmt_gb(props.total_memory)}")

# Best-effort performance switches: enable the cuDNN autotuner and 'high'
# float32 matmul precision. Any failure is ignored on purpose (presumably so
# the notebook still runs on builds that lack these settings).
try:
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision('high')
except Exception:
    pass


### Mount Google Drive

In [None]:
# Mount Google Drive (Colab only) and make sure the checkpoint directory exists
# before training tries to write into it.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
# exist_ok=True: re-running the cell must not fail when the folder is already there.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
print(f"Checkpoints dir -> {os.path.dirname(CHECKPOINT_PATH)}")
print(f"Checkpoint file -> {CHECKPOINT_PATH}")


### Data Loading & Chronological Splits (no leakage)



In [None]:
def _find_datetime_column(df: pd.DataFrame) -> Optional[str]:
    """Return the first column whose lowercased name is a known timestamp label.

    Recognised labels are 'datetime', 'date', 'time', 'timestamp' and 'ts'.
    Columns are scanned in frame order; ``None`` is returned when none match.
    """
    known_labels = ('datetime', 'date', 'time', 'timestamp', 'ts')
    return next((col for col in df.columns if col.lower() in known_labels), None)

def _find_col_ci(df: pd.DataFrame, name: str) -> Optional[str]:
    """Return the actual spelling of the first column equal to *name* ignoring case, or None."""
    return next((col for col in df.columns if col.lower() == name.lower()), None)

def _has_col_ci(df: pd.DataFrame, name: str) -> bool:
    """Tell whether *df* has a column named *name*, compared case-insensitively."""
    for col in df.columns:
        if col.lower() == name.lower():
            return True
    return False

def _append_vwap(df: pd.DataFrame) -> None:
    """Add a 'vwap' / pandas-ta VWAP column to *df* in place.

    pandas-ta anchors VWAP to calendar periods of the index, so it needs a
    DatetimeIndex. When *df* is not datetime-indexed, a datetime column located
    by ``_find_datetime_column`` is used as a temporary index for the
    computation and the values are written back positionally.

    Raises:
        ValueError: if no datetime index/column is available, or pandas-ta
            returns no result.
    """
    if isinstance(df.index, pd.DatetimeIndex):
        df.ta.vwap(high='high', low='low', close='close', volume='volume', append=True)
        return

    time_col = _find_datetime_column(df)
    if time_col is None or not pd.api.types.is_datetime64_any_dtype(df[time_col]):
        raise ValueError(
            "VWAP needs a DatetimeIndex or a parsed datetime column; "
            "provide a precomputed 'vwap' column or a valid timestamp column."
        )
    by_time = df.set_index(time_col)
    vwap = by_time.ta.vwap(high='high', low='low', close='close', volume='volume')
    if vwap is None:
        raise ValueError("pandas-ta could not compute VWAP for the given data.")
    # Row order is unchanged by set_index, so a positional write-back is safe.
    df['vwap'] = vwap.to_numpy()


def add_ta_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with lowercase OHLCV names and the TA feature columns.

    Any of atr_14, adx_14, ema_9, ema_21, vwap, rsi_21 and stochk_14_3_3 that
    is not already present (case-insensitively) is computed with pandas-ta.
    Computed or pre-existing columns are then renamed to those lowercase
    canonical names. The input frame is not modified.

    Raises:
        ValueError: if high/low/close/volume are missing, or VWAP has to be
            computed but no datetime information is available.
    """
    # Normalize base column names to lowercase for consistency.
    rename_map = {}
    for nm in ['Open', 'High', 'Low', 'Close', 'Volume']:
        c = _find_col_ci(df, nm)
        if c is not None:
            rename_map[c] = nm.lower()
    # rename() always returns a new frame, so the `append=True` calls below
    # never write into the caller's DataFrame (even when nothing is renamed).
    df = df.rename(columns=rename_map)

    required = ['high', 'low', 'close', 'volume']
    if any(not _has_col_ci(df, r) for r in required):
        raise ValueError('Missing OHLCV columns (high, low, close, volume) to compute pandas-ta indicators.')

    # Compute indicators only if a (case-insensitive) target is not already present.
    if not _has_col_ci(df, 'atr_14'):
        df.ta.atr(high='high', low='low', close='close', length=14, append=True)
    if not _has_col_ci(df, 'adx_14'):
        df.ta.adx(high='high', low='low', close='close', length=14, append=True)
    if not _has_col_ci(df, 'ema_9'):
        df.ta.ema(close='close', length=9, append=True)
    if not _has_col_ci(df, 'ema_21'):
        df.ta.ema(close='close', length=21, append=True)
    if not _has_col_ci(df, 'vwap'):
        _append_vwap(df)
    if not _has_col_ci(df, 'rsi_21'):
        df.ta.rsi(close='close', length=21, append=True)
    if not _has_col_ci(df, 'stochk_14_3_3'):
        df.ta.stoch(high='high', low='low', close='close', k=14, d=3, smooth_k=3, append=True)

    # Map whatever spelling is present onto the lowercase canonical names.
    # Besides plain case differences (ATR_14, STOCHk_14_3_3, ...), pandas-ta's
    # default output names carry a suffix for two indicators: ATR with the
    # default RMA smoothing is 'ATRr_14' and VWAP with the default daily anchor
    # is 'VWAP_D' (NOTE(review): confirm against the installed pandas-ta version).
    extra_aliases = {'atr_14': ('atrr_14',), 'vwap': ('vwap_d',)}
    canonical = ('atr_14', 'adx_14', 'ema_9', 'ema_21', 'vwap', 'rsi_21', 'stochk_14_3_3')
    lower_map = {}
    for canon in canonical:
        if canon in df.columns:
            continue
        accepted = (canon,) + extra_aliases.get(canon, ())
        source = next((c for c in df.columns if c.lower() in accepted), None)
        if source is not None:
            lower_map[source] = canon
    if lower_map:
        df = df.rename(columns=lower_map)
    return df

def load_mnq_csv(csv_path: str, indicators: List[str]) -> pd.DataFrame:
    """Load the CSV, order it by time, add TA features and return clean numeric features.

    Args:
        csv_path: path of the CSV file to read.
        indicators: lowercase column names to keep, in output order.

    Returns:
        A DataFrame containing exactly *indicators*, coerced to numeric, with
        every row holding a NaN or +/-inf removed.

    Raises:
        ValueError: if a requested column is still missing after indicator creation.
    """
    df = pd.read_csv(csv_path)
    time_col = _find_datetime_column(df)
    if time_col is None:
        # The chronological split downstream relies on row order, so say so
        # explicitly instead of silently trusting the file.
        print("[warn] No datetime column found; assuming rows are already in chronological order.")
    else:
        try:
            df[time_col] = pd.to_datetime(df[time_col])
            df = df.sort_values(by=time_col, ascending=True)
        except Exception as exc:
            # Still best-effort (file order is kept), but no longer silent.
            print(f"[warn] Could not parse/sort by '{time_col}' ({exc!r}); keeping file order.")
    df = df.reset_index(drop=True)

    # Create TA indicators and ensure lowercase canon names
    df = add_ta_indicators(df)
    # Select requested features (lowercase)
    missing = [c for c in indicators if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns after TA creation: {missing}. Available: {list(df.columns)}")
    xdf = df[indicators].copy()
    # Clean: numeric, remove inf, drop NaNs to avoid leakage via backward fill
    for c in xdf.columns:
        xdf[c] = pd.to_numeric(xdf[c], errors='coerce')
    xdf = xdf.replace([np.inf, -np.inf], np.nan)
    xdf = xdf.dropna()
    return xdf

def chronological_split(arr: np.ndarray, ratios: List[float]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Cut *arr* into consecutive train / validation / test slices (no shuffling).

    Args:
        arr: time-ordered array, split along its first axis.
        ratios: train, validation and test fractions; must be non-negative and sum to 1.0.

    Returns:
        ``(train, val, test)``. Train and validation sizes are truncated to
        whole rows; the test slice receives every remaining row.

    Raises:
        ValueError: if the ratios are negative or do not sum to 1.0.
    """
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if any(r < 0 for r in ratios):
        raise ValueError("TRAIN_VALID_TEST_SPLIT ratios must be non-negative")
    if abs(sum(ratios) - 1.0) >= 1e-6:
        raise ValueError("TRAIN_VALID_TEST_SPLIT must sum to 1.0")

    total = len(arr)
    n_train = int(total * ratios[0])
    val_end = n_train + int(total * ratios[1])
    return arr[:n_train], arr[n_train:val_end], arr[val_end:]

def make_scaler(norm_type: str):
    """Build a fresh, unfitted scaler for the given normalization name.

    Returns a ``StandardScaler`` for "StandardScaler" and a ``MinMaxScaler``
    for "MinMaxScaler"; any other value raises ``ValueError``.
    """
    if norm_type != "StandardScaler" and norm_type != "MinMaxScaler":
        raise ValueError("NORMALIZATION_TYPE must be 'StandardScaler' or 'MinMaxScaler'.")
    return StandardScaler() if norm_type == "StandardScaler" else MinMaxScaler()

def fit_transform_splits(train_arr: np.ndarray, val_arr: np.ndarray, test_arr: np.ndarray, scaler):
    """Fit *scaler* on the training split only, then transform all three splits.

    Returns ``(train_scaled, val_scaled, test_scaled)``; validation and test
    data never influence the fitted statistics.
    """
    scaler.fit(train_arr)
    return tuple(scaler.transform(part) for part in (train_arr, val_arr, test_arr))

# Load CSV
# NOTE: `df` here is the cleaned feature frame returned by load_mnq_csv
# (only the INDICATORS_TO_USE columns, rows with NaN/inf already dropped).
print(f'Loading CSV: {CSV_PATH}')
df = load_mnq_csv(CSV_PATH, INDICATORS_TO_USE)
print(f'Features selected: {len(INDICATORS_TO_USE)} | {INDICATORS_TO_USE}')
# Cast to float32 once; the dataset later hands out float32 tensors.
data = df.values.astype(np.float32)
print(f'Total rows after TA + cleanup: {len(df)} | Feature dim: {data.shape[1]}')

# Chronological split (no shuffling)
train_raw, val_raw, test_raw = chronological_split(data, TRAIN_VALID_TEST_SPLIT)
print(f'Split -> train: {train_raw.shape}, val: {val_raw.shape}, test: {test_raw.shape}')

# Train-only normalization
scaler = make_scaler(NORMALIZATION_TYPE)
train_scaled, val_scaled, test_scaled = fit_transform_splits(train_raw, val_raw, test_raw, scaler)
# Print a one-line summary of the fitted statistics for a quick sanity check.
if NORMALIZATION_TYPE == 'StandardScaler':
    means = scaler.mean_
    # Prefer scale_; fall back to sqrt(var_) if the attribute is absent.
    stds = scaler.scale_ if hasattr(scaler, 'scale_') else np.sqrt(scaler.var_)
    print(f'Scaler(Standard) -> mean range [{means.min():.4f}, {means.max():.4f}] | std range [{stds.min():.4f}, {stds.max():.4f}]')
else:
    mins = scaler.data_min_
    maxs = scaler.data_max_
    print(f'Scaler(MinMax) -> min range [{mins.min():.4f}, {mins.max():.4f}] | max range [{maxs.min():.4f}, {maxs.max():.4f}]')

def _count_windows(n, L, s):
    """Number of length-``L`` windows that fit in ``n`` steps when advancing by ``s`` (never negative)."""
    full_steps, _ = divmod(n - L, s)
    count = full_steps + 1
    return count if count > 0 else 0
# Expected number of sliding windows per split, printed for a quick sanity check.
nw_train = _count_windows(len(train_scaled), SEQ_LEN, WINDOW_STRIDE)
nw_val = _count_windows(len(val_scaled), SEQ_LEN, WINDOW_STRIDE)
nw_test = _count_windows(len(test_scaled), SEQ_LEN, WINDOW_STRIDE)
print(f'Windows -> train: {nw_train}, val: {nw_val}, test: {nw_test} | stride={WINDOW_STRIDE}, seq_len={SEQ_LEN}')

# Bare expression: shown as the cell's output when run in a notebook.
train_scaled.shape, val_scaled.shape, test_scaled.shape


### Dataset & DataLoaders (windowed, no shuffle)



In [None]:
class MNQ_Dataset(Dataset):
    """Sliding-window view over a 2-D ``[T, C]`` array for self-supervised training.

    Item ``i`` is the float32 tensor ``arr_2d[start:start + seq_len]`` with
    ``start = i * stride``; the training target is the window itself, so no
    separate label is returned.

    Args:
        arr_2d: array of shape ``[T, C]`` (time steps by features).
        seq_len: window length, at least 1.
        stride: distance between consecutive window starts, at least 1.

    Raises:
        ValueError: if ``seq_len`` or ``stride`` is below 1, or the array has
            fewer than ``seq_len`` rows.
    """

    def __init__(self, arr_2d: np.ndarray, seq_len: int, stride: int = 1):
        super().__init__()
        self.x = arr_2d
        self.seq_len = int(seq_len)
        self.stride = int(stride)
        if self.seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {self.seq_len}")
        if self.stride < 1:
            raise ValueError(f"stride must be >= 1, got {self.stride}")
        self.T = len(arr_2d)
        self.C = arr_2d.shape[1]
        if self.T < self.seq_len:
            raise ValueError(f"Not enough timesteps ({self.T}) for seq_len={self.seq_len}")
        # Window start offsets. A range supports len() and indexing like a list
        # but takes constant memory, which matters with millions of windows and
        # several DataLoader workers each holding a copy of the dataset.
        self.idxs = range(0, self.T - self.seq_len + 1, self.stride)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, idx):
        start = self.idxs[idx]
        window = self.x[start:start + self.seq_len]  # shape [seq_len, C]
        # model expects [L, C], training target == input window (self-supervised)
        return torch.from_numpy(window).float()

def make_loader(arr: np.ndarray, seq_len: int, stride: int, batch_size: int,
                shuffle: bool = False, drop_last: bool = True) -> DataLoader:
    """Wrap *arr* in an ``MNQ_Dataset`` and return a DataLoader over its windows.

    Args:
        arr: scaled ``[T, C]`` array for one split.
        seq_len: window length.
        stride: step between window starts.
        batch_size: windows per batch.
        shuffle: whether to shuffle windows; defaults to False (chronological order).
        drop_last: whether to drop a trailing partial batch; defaults to True.

    Worker count and pinned-memory use come from the module-level
    ``DATALOADER_WORKERS`` and ``PIN_MEMORY`` settings.
    """
    ds = MNQ_Dataset(arr, seq_len=seq_len, stride=stride)
    if drop_last and len(ds) < batch_size:
        # With drop_last the loader would yield zero batches, and a loss
        # averaged over nothing comes out as 0.0 — keep the single partial batch.
        print(f"[warn] Only {len(ds)} windows for batch_size={batch_size}; keeping the partial batch (drop_last=False).")
        drop_last = False
    # No shuffling by default to avoid any perceived leakage; data is chronologically windowed already
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                      num_workers=DATALOADER_WORKERS, pin_memory=PIN_MEMORY)

# One loader per split; all three share the same window length, stride and batch size.
train_loader = make_loader(train_scaled, SEQ_LEN, WINDOW_STRIDE, BATCH_SIZE)
val_loader = make_loader(val_scaled, SEQ_LEN, WINDOW_STRIDE, BATCH_SIZE)
test_loader = make_loader(test_scaled, SEQ_LEN, WINDOW_STRIDE, BATCH_SIZE)

# tw/vw/tew = window counts, tb/vb/teb = batch counts (train/val/test).
tw, vw, tew = len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset)
tb, vb, teb = len(train_loader), len(val_loader), len(test_loader)
print(f'DataLoaders -> windows (train/val/test): {tw}/{vw}/{tew} | batches: {tb}/{vb}/{teb} | batch_size={BATCH_SIZE}')
print(f'Window shape -> L={SEQ_LEN}, C={train_scaled.shape[1]}')


### Model: TimesBlock + Inception + Feature Extractor + Decoder



In [None]:
class InceptionBlock2D(nn.Module):
    """Four-branch Inception-style 2-D block (1x1, 3x3, 5x5, max-pool + 1x1).

    Branch outputs are concatenated along the channel axis and passed through
    dropout. Every convolution keeps the spatial size, so an input of shape
    ``[B, in_channels, H, W]`` produces ``[B, out_channels, H, W]``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: total channels after concatenation (must be >= 4 so each
            branch gets at least one channel).
        dropout: dropout probability applied to the concatenated output.

    Raises:
        ValueError: if ``out_channels`` is smaller than 4.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.1):
        super().__init__()
        if out_channels < 4:
            raise ValueError(f"out_channels must be >= 4 to feed four branches, got {out_channels}")
        # Split out_channels across branches so the widths sum to out_channels
        # exactly. The remainder must be taken after all FOUR equal shares;
        # subtracting only three shares would hand the first branch an extra
        # `b` channels and make the concatenation wider than requested.
        b = out_channels // 4
        r = out_channels - 4 * b  # leftover channels (0..3) go to the first branch
        b1 = b + r
        b2 = b
        b3 = b
        b4 = b

        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, b1, kernel_size=1, bias=False),
            nn.BatchNorm2d(b1),
            nn.ReLU(inplace=True),
        )

        # 1x1 bottleneck before the 3x3 conv (at least 8 channels wide).
        red3 = max(in_channels // 2, 8)
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, red3, kernel_size=1, bias=False),
            nn.BatchNorm2d(red3),
            nn.ReLU(inplace=True),
            nn.Conv2d(red3, b2, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(b2),
            nn.ReLU(inplace=True),
        )

        # 1x1 bottleneck before the 5x5 conv (at least 8 channels wide).
        red5 = max(in_channels // 2, 8)
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, red5, kernel_size=1, bias=False),
            nn.BatchNorm2d(red5),
            nn.ReLU(inplace=True),
            nn.Conv2d(red5, b3, kernel_size=5, padding=2, bias=False),
            nn.BatchNorm2d(b3),
            nn.ReLU(inplace=True),
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, b4, kernel_size=1, bias=False),
            nn.BatchNorm2d(b4),
            nn.ReLU(inplace=True),
        )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run all branches on ``x`` and return their channel-wise concatenation after dropout."""
        o1 = self.branch1x1(x)
        o2 = self.branch3x3(x)
        o3 = self.branch5x5(x)
        o4 = self.branch_pool(x)
        out = torch.cat([o1, o2, o3, o4], dim=1)
        return self.dropout(out)

class TimesBlock(nn.Module):
    """
    Period-aware 2-D encoder for a multivariate series.

    Pipeline: pick the dominant periods from the FFT amplitude spectrum, fold
    the series into a 2-D grid for each period, run the shared Inception
    backbone, pool every result to ``embed_h x embed_w`` and add them up using
    softmax-of-amplitude weights.

    Input:  x [B, L, C]
    Output: 4-D feature map [B, backbone channels, embed_h, embed_w]
            (the backbone is constructed with ``embed_dim`` output channels).
    """
    def __init__(self, in_channels: int, k: int, embed_dim: int, embed_h: int, embed_w: int, dropout: float = 0.1):
        super().__init__()
        self.in_channels, self.k = in_channels, k
        self.embed_dim = embed_dim
        self.embed_h, self.embed_w = embed_h, embed_w
        self.backbone = InceptionBlock2D(in_channels, embed_dim, dropout=dropout)

    @torch.no_grad()
    def _find_topk_periods(self, x_bc_l: torch.Tensor, k: int) -> Tuple[List[int], torch.Tensor]:
        """Return up to ``k`` dominant periods of ``x_bc_l`` ([B, C, L]) and their softmax weights.

        The amplitude spectrum is averaged over batch and channels, so one
        shared set of periods is chosen for the whole batch. Runs without
        gradient tracking.
        """
        _, _, seq_len = x_bc_l.shape
        spectrum = torch.fft.rfft(x_bc_l, dim=-1)      # [B, C, L//2 + 1]
        amp = spectrum.abs().mean(dim=(0, 1))          # [L//2 + 1]
        n_bins = amp.shape[0]
        if n_bins <= 1:
            # Only the DC bin exists: treat the whole series as one period.
            return [seq_len], torch.tensor([1.0], device=x_bc_l.device)

        amp[0] = 0.0  # the DC component carries no periodicity
        top_vals, top_bins = torch.topk(amp, k=min(k, n_bins - 1), largest=True, sorted=True)
        # Frequency bin -> period length, never shorter than 2 steps.
        periods = [max(int(round(seq_len / max(b, 1))), 2) for b in top_bins.tolist()]
        return periods, torch.softmax(top_vals, dim=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` ([B, L, C]) into the fused, pooled 2-D feature map."""
        batch, seq_len, chans = x.shape
        series = x.permute(0, 2, 1).contiguous()  # [B, C, L]

        periods, weights = self._find_topk_periods(series, self.k)
        target_hw = (self.embed_h, self.embed_w)
        fused = None
        for i, period in enumerate(periods):
            # Zero-pad on the right so the length is a whole number of periods.
            pad_len = (-seq_len) % period
            padded = F.pad(series, (0, pad_len), mode='constant', value=0.0)  # [B, C, Lp]
            n_cycles = padded.shape[-1] // period
            # Fold to [B, C, period, n_cycles]: one column per cycle.
            grid = padded.view(batch, chans, n_cycles, period).transpose(2, 3).contiguous()

            z = F.adaptive_avg_pool2d(self.backbone(grid), target_hw)
            z = z * weights[i].view(1, 1, 1, 1)  # scale by this period's weight

            fused = z if fused is None else (fused + z)

        return fused

class TimesNetFeatureExtractor(nn.Module):
    """A single ``TimesBlock`` followed by dropout; maps ``[B, L, C]`` input to the block's 4-D feature map."""

    def __init__(self, in_channels: int, k: int, embed_dim: int, embed_h: int, embed_w: int, dropout: float=0.1):
        super().__init__()
        self.block = TimesBlock(
            in_channels,
            k,
            embed_dim,
            embed_h,
            embed_w,
            dropout=dropout,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` ([B, L, C]) and apply dropout to the resulting feature map."""
        features = self.block(x)
        features = self.dropout(features)
        return features

class ModelWithDecoder(nn.Module):
    """
    Feature extractor plus an MLP reconstruction head for self-supervised training.

    ``forward`` returns both the reconstructed window and the embedding. Once
    training is done, call ``strip_decoder()`` and use ``self.extractor``
    directly; ``forward`` is not usable after the decoder has been removed.
    """
    def __init__(self, in_channels: int, seq_len: int, k: int, embed_dim: int, embed_h: int, embed_w: int, dropout: float=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.seq_len = seq_len

        extractor_kwargs = dict(
            in_channels=in_channels,
            k=k,
            embed_dim=embed_dim,
            embed_h=embed_h,
            embed_w=embed_w,
            dropout=dropout,
        )
        self.extractor = TimesNetFeatureExtractor(**extractor_kwargs)

        # The head flattens an [embed_dim, embed_h, embed_w] map and projects it
        # back to one value per (time step, feature) of the input window.
        flat_size = embed_dim * embed_h * embed_w
        hidden = max(512, flat_size // 2)
        head_layers = [
            nn.Flatten(),
            nn.Linear(flat_size, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden, seq_len * in_channels),
        ]
        self.decoder = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor):
        """Return ``(reconstruction [B, seq_len, in_channels], embedding)`` for ``x`` ([B, L, C])."""
        emb = self.extractor(x)
        flat_rec = self.decoder(emb)
        rec = flat_rec.view(x.shape[0], self.seq_len, self.in_channels)
        return rec, emb

    def strip_decoder(self):
        """Drop the reconstruction head so only the extractor remains."""
        self.decoder = None


### Training Utilities: EarlyStopping, Checkpointing, Train/Eval

In [None]:
from contextlib import nullcontext
# bf16 autocast is enabled only when CUDA is available, DEVICE is 'cuda',
# USE_BF16 is set, and torch.cuda reports bf16 support (a missing
# `is_bf16_supported` probe is treated as "not supported").
USE_AMP = bool(torch.cuda.is_available() and (DEVICE == 'cuda') and USE_BF16 and getattr(torch.cuda, 'is_bf16_supported', lambda: False)())
# Autocast dtype used by the train/eval loops; None when autocast is disabled.
AMP_DTYPE = torch.bfloat16 if USE_AMP else None

class EarlyStopping:
    """Track a metric that should decrease and signal when it has stalled.

    A value counts as an improvement when it is lower than the best seen so
    far by more than ``min_delta`` (the first value always counts). ``step``
    returns True once ``patience`` consecutive non-improving values were seen.
    """

    def __init__(self, patience: int = 7, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None   # lowest value observed so far
        self.num_bad = 0   # consecutive steps without improvement

    def step(self, value: float) -> bool:
        """Record ``value`` and return True when training should stop."""
        improved = self.best is None or value < self.best - self.min_delta
        if not improved:
            self.num_bad += 1
            return self.num_bad >= self.patience
        self.best = value
        self.num_bad = 0
        return False

def save_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer, epoch: int, best_val: float, extra: Optional[dict] = None):
    """Write model/optimizer state plus bookkeeping to *path*.

    The stored dict has the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``best_val_loss``, plus ``extra`` when a
    non-empty *extra* dict is given.

    The file is written to ``<path>.tmp`` first and then moved over *path*, so
    an interrupted save cannot leave a truncated checkpoint in place of the
    previous good one.
    """
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_val_loss": best_val,
    }
    if extra:
        state["extra"] = extra

    tmp_path = f"{path}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up a partial file left behind by a failed or interrupted save.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def load_checkpoint_if_any(path: str, model: nn.Module, optimizer: torch.optim.Optimizer, resume: bool):
    """Restore model and optimizer from *path* when resuming is requested and the file exists.

    Returns ``(start_epoch, best_val)``: the epoch after the stored one and the
    stored best validation loss, or ``(1, inf)`` when nothing was loaded.
    """
    if not (resume and os.path.exists(path)):
        return 1, float("inf")

    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    last_epoch = ckpt.get("epoch", 1)
    best_val = ckpt.get("best_val_loss", float("inf"))
    start_epoch = last_epoch + 1
    print(f"Resuming from epoch {start_epoch-1} with best_val={best_val:.6f}")
    return start_epoch, best_val

def train_one_epoch(model, loader, optimizer, device):
    """Run one reconstruction-training pass over *loader* and return the mean loss.

    Each batch is moved to *device*, reconstructed by the model (under bf16
    autocast when ``USE_AMP`` is set) and scored with MSE against itself.
    Gradients are accumulated over ``GRAD_ACCUM_STEPS`` batches per optimizer
    step. If the epoch ends in the middle of an accumulation group, the
    gradients gathered so far are still applied instead of being thrown away.

    Returns:
        The sample-weighted mean of the unscaled per-batch MSE
        (0.0 if the loader yields no batches).
    """
    model.train()
    mse = nn.MSELoss()
    accum = max(int(GRAD_ACCUM_STEPS), 1)
    total_loss = 0.0
    n = 0
    pending = 0  # micro-batches back-propagated since the last optimizer step
    optimizer.zero_grad(set_to_none=True)
    for xb in loader:
        xb = xb.to(device, non_blocking=True)
        ctx = torch.autocast(device_type='cuda', dtype=AMP_DTYPE) if USE_AMP else nullcontext()
        with ctx:
            rec, _ = model(xb)
            # Scale so the accumulated gradient is the mean over the group.
            loss = mse(rec, xb) / accum
        loss.backward()
        pending += 1
        if pending == accum:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            pending = 0
        # Accumulate reporting with true scale
        total_loss += (loss.item() * accum) * xb.size(0)
        n += xb.size(0)

    if pending:
        # Flush the trailing partial group; otherwise its gradients would be
        # discarded by the next epoch's zero_grad().
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    return total_loss / max(n, 1)

@torch.no_grad()
def evaluate(model, loader, device):
    """Return the sample-weighted mean reconstruction MSE of *model* over *loader*.

    The model is switched to eval mode and no gradients are tracked; bf16
    autocast is used when ``USE_AMP`` is set. A loader that yields no batches
    results in 0.0.
    """
    model.eval()
    criterion = nn.MSELoss()
    loss_sum, seen = 0.0, 0
    for batch in loader:
        batch = batch.to(device, non_blocking=True)
        if USE_AMP:
            amp_ctx = torch.autocast(device_type='cuda', dtype=AMP_DTYPE)
        else:
            amp_ctx = nullcontext()
        with amp_ctx:
            rec, emb = model(batch)
            batch_loss = criterion(rec, batch)
        batch_size = batch.size(0)
        loss_sum += batch_loss.item() * batch_size
        seen += batch_size
    return loss_sum / max(seen, 1)


### Initialize Model, Optimizer, and Train

In [None]:

# Build the model from the notebook configuration; the feature count comes
# from the scaled training array.
in_channels = train_scaled.shape[1]
print(f'In-channels (features): {in_channels}')
model = ModelWithDecoder(
    in_channels=in_channels,
    seq_len=SEQ_LEN,
    k=TOP_K_PERIODS,
    embed_dim=EMBED_DIM,
    embed_h=EMBED_H,
    embed_w=EMBED_W,
    dropout=DROPOUT_RATE
).to(DEVICE)
if CHANNELS_LAST and (DEVICE == 'cuda'):
    model = model.to(memory_format=torch.channels_last)
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Model params: {params/1e6:.2f}M')

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

start_epoch, best_val = load_checkpoint_if_any(CHECKPOINT_PATH, model, optimizer, resume=RESUME_TRAINING)
early = EarlyStopping(patience=PATIENCE, min_delta=0.0)
if math.isfinite(best_val):
    # Resumed run: seed early stopping with the checkpointed best so patience
    # is measured against it, matching the criterion used for saving, rather
    # than against whatever the first post-resume epoch happens to score.
    early.best = best_val

tb, vb = len(train_loader), len(val_loader)
tw, vw = len(train_loader.dataset), len(val_loader.dataset)
# Optimizer steps per epoch = ceil(batches / accumulation steps).
est_opt_steps = (tb + max(int(GRAD_ACCUM_STEPS),1) - 1) // max(int(GRAD_ACCUM_STEPS),1)
print(f"Starting training at epoch {start_epoch} on {DEVICE} | train windows={tw}, batches/epoch={tb}, grad_accum={GRAD_ACCUM_STEPS}, est_opt_steps/epoch={est_opt_steps}")
# Main training loop: one train pass and one validation pass per epoch,
# checkpointing on improvement and stopping early when validation stalls.
for epoch in range(start_epoch, NUM_EPOCHS + 1):
    if DEVICE == 'cuda':
        # Reset so the peak-memory figure printed below is per epoch.
        torch.cuda.reset_peak_memory_stats()
    ep_start = time.perf_counter()
    train_loss = train_one_epoch(model, train_loader, optimizer, DEVICE)
    val_loss = evaluate(model, val_loader, DEVICE)
    # ep_time covers training plus validation; throughput is train windows per second of that.
    ep_time = time.perf_counter() - ep_start
    throughput = len(train_loader.dataset) / max(ep_time, 1e-9)

    # Tiny tolerance so a numerically identical loss does not count as an improvement.
    improved = val_loss < best_val - 1e-12
    if improved:
        best_val = val_loss
        # Store the settings needed to rebuild the model alongside the weights.
        save_checkpoint(
            CHECKPOINT_PATH, model, optimizer, epoch, best_val,
            extra={
                "INDICATORS_TO_USE": INDICATORS_TO_USE,
                "NORMALIZATION_TYPE": NORMALIZATION_TYPE,
                "SEQ_LEN": SEQ_LEN,
                "TOP_K_PERIODS": TOP_K_PERIODS,
                "EMBED_DIM": EMBED_DIM,
                "EMBED_H": EMBED_H,
                "EMBED_W": EMBED_W,
            }
        )

    stop = early.step(val_loss)
    # Assemble a single status line for this epoch.
    lr = optimizer.param_groups[0].get('lr', None)
    msg = f"Epoch {epoch:03d} | train {train_loss:.6f} | val {val_loss:.6f} | {throughput:.0f} samp/s"
    if lr is not None:
        msg += f" | lr={lr:.2e}"
    if improved:
        msg += " | [saved]"
    if PRINT_GPU_MEM and (DEVICE == 'cuda'):
        peak = torch.cuda.max_memory_allocated() / 1e9
        msg += f" | GPU peak mem: {peak:.2f} GB"
    print(msg + f" | patience {early.num_bad}/{PATIENCE}")
    if stop:
        print(f"Early stopping at epoch {epoch}. Best val: {best_val:.6f}")
        break


### Verification: Embedding-only forward

In [None]:

# Load best checkpoint
# Rebuild the model on CPU, load the best checkpoint's weights, then move it to DEVICE.
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
model = ModelWithDecoder(
    in_channels=in_channels,
    seq_len=SEQ_LEN,
    k=TOP_K_PERIODS,
    embed_dim=EMBED_DIM,
    embed_h=EMBED_H,
    embed_w=EMBED_W,
    dropout=DROPOUT_RATE
)
# strict=True: the checkpoint must match this architecture exactly.
model.load_state_dict(ckpt["model_state_dict"], strict=True)
model = model.to(DEVICE)
model.eval()

# Remove temporary reconstruction decoder
model.strip_decoder()
assert model.decoder is None

# Take a sample batch from test set and compute embeddings
xb = next(iter(test_loader))
xb = xb.to(DEVICE, non_blocking=True)
with torch.no_grad():
    # Use autocast only if the training-utilities cell defined a non-None AMP_DTYPE.
    ctx = torch.autocast(device_type='cuda', dtype=AMP_DTYPE) if ('AMP_DTYPE' in globals() and AMP_DTYPE is not None) else nullcontext()
    with ctx:
        # Use the extractor directly for inference-only embeddings
        emb = model.extractor(xb)  # [B, E, H, W]

print("Embedding tensor shape:", tuple(emb.shape))
