# Training

This notebook contains the training script for the revised encoder-decoder transformer with handpicked input features.

In [None]:
# imports
import time
from pathlib import Path

import numpy as np

# import os
# os.environ["TORCHDYNAMO_VERBOSE"] = "1"
# os.environ["TORCH_LOGS"] = "+dynamo"

import wandb
import torch
import torch.optim as optim
# import torch.multiprocessing as mp
import torchvision.transforms.v2

from torch import nn, GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary

from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

from src.model.loss import ODRMSELoss
from src.model.residual import Transformer
from src.model.scaler import ZTransform

# mp.set_start_method('spawn', force=True)

In [None]:
# paths
# everything the training run writes or reads below the model directory
model_path = Path('../model/')
weight_save_path = model_path / 'models'
ground_truth_scaler_p = model_path / 'z-scaler' / 'ground-truth-scaler-scaled-time.npz'

# preprocessed training data
dataset_path = Path('../data/preprocessed/dataset-scaled-time.npz')

# create the checkpoint directory up front (no error if it already exists)
weight_save_path.mkdir(exist_ok=True, parents=True)

In [None]:
# scaler
# the z-transform that was fitted on the ground truth; it is necessary for
# calculating the OD-RMSE-Loss
ground_truth_scaler = ZTransform.load(ground_truth_scaler_p)

In [None]:
# dataloader constants
DEVICE = 'cuda'

# fraction of the (shuffled) samples used for training; the rest is validation
TRAIN_SPLIT = 0.8
TRAIN_BATCH_SIZE, VALID_BATCH_SIZE = 32, 512

# sigma of the Gaussian noise added to the training inputs / training targets
INPUT_NOISE_STD, OUTPUT_NOISE_STD = 0.01, 0.0005

# DataLoader worker settings
NUM_WORKERS_TRAIN, NUM_WORKERS_VALID = 24, 8
PERSISTENT_WORKERS = True
PREFETCH_FACTOR = 2

# switches for the target / feature preparation cells
RESIDUAL_TRAINING = True
INCLUDE_NLRMSIS_X = True
USE_LOG = True  # requires rescaling before using with OD-RMSE
USE_LOG_X = True if USE_LOG else False

In [None]:
# load data
seed = 0xC0FFEEBABE

# NOTE(review): allow_pickle=True lets numpy unpickle arbitrary objects, so only
# point this at trusted dataset files.
npz_file_contents = np.load(dataset_path, allow_pickle=True)
x = npz_file_contents['x']

# one seeded permutation is applied to all three arrays so their rows stay aligned
rng = np.random.default_rng(seed=seed)
idx = rng.permutation(x.shape[0])
x, y, baseline = x[idx], npz_file_contents['y'][idx], npz_file_contents['nrlm'][idx]

# the first TRAIN_SPLIT fraction of the shuffled rows is the training set,
# everything after the separator is the validation set
total_sims = x.shape[0]
split_seperator = int(total_sims * TRAIN_SPLIT)
x_train, x_valid = np.split(x, [split_seperator])
y_train, y_valid = np.split(y, [split_seperator])
baseline_train, baseline_valid = np.split(baseline, [split_seperator])

In [None]:
# Build the training targets: undo the stored z-transform, optionally move to
# log space, optionally subtract the first baseline value (residual training),
# then z-transform again and attach the baseline as a second channel.
#
# The raw baseline is taken from the loaded file (same permutation as in the
# load cell) instead of from `baseline`, because this cell rebinds `baseline`
# to its logarithm when USE_LOG is set. Reading the raw values keeps the cell
# idempotent: running it again no longer takes the log of an already logged
# baseline.
_baseline_raw = npz_file_contents['nrlm'][idx]

baseline_residual = _baseline_raw[:, 0].reshape(-1, 1) if RESIDUAL_TRAINING else 0.0
y_rescaled = ground_truth_scaler.reverse_z_transform(y) + _baseline_raw

if USE_LOG:
    # The log is taken of y itself and not of the residual: the residual can be
    # negative, which log(x) does not support. The residual is formed afterwards
    # in log space.
    y_rescaled = np.log(y_rescaled) - (np.log(baseline_residual) if RESIDUAL_TRAINING else 0.0)
    baseline = np.log(_baseline_raw)
else:
    y_rescaled -= baseline_residual
    baseline = _baseline_raw

# keep the train/valid baseline views in sync with whatever space `baseline` is in
baseline_train, baseline_valid = baseline[:split_seperator], baseline[split_seperator:]

# channel 0: standardised target, channel 1: baseline
scale_one = ZTransform(y_rescaled, axis=(0, 1))
y_new = np.stack([scale_one.z_transform(y_rescaled), baseline], axis=-1)

y_train, y_valid = y_new[:split_seperator], y_new[split_seperator:]

In [None]:
# Second set of target variables (`*2` names), built from the targets prepared
# in the previous cell.
#
# This cell must NOT repeat the preparation itself: the previous cell has
# already rebound `baseline` to log(baseline) when USE_LOG is set, so running
# the same steps again would add a logged baseline to y, take the log of the
# logged baseline, and overwrite `baseline_train` / `baseline_valid` with the
# result - which the following feature cell then feeds into `x_train`.
# `y_rescaled` and `baseline` are therefore used exactly as the previous cell
# left them.
scale_one2 = ZTransform(y_rescaled, axis=(0, 1))
y2_new = np.stack([scale_one2.z_transform(y_rescaled), baseline], axis=-1)
y2 = y_new
y2_train, y2_valid = y2_new[:split_seperator], y2_new[split_seperator:]

In [None]:
# train with baseline[0] feature
# Appends the first baseline value of every sample as one extra, standardised
# input feature. Mean and std come from the training part only and are reused
# for the validation part.
if INCLUDE_NLRMSIS_X:
    _x_add_train = baseline_train[:, 0]
    _x_add_valid = baseline_valid[:, 0]

    # baseline_* is already in log space when USE_LOG is set, so the log is only
    # taken here when the feature should be logged but the targets are not
    if USE_LOG_X and not USE_LOG:
        _x_add_train = np.log(_x_add_train)
        _x_add_valid = np.log(_x_add_valid)

    _baseline_mean = np.mean(_x_add_train)
    _baseline_std = np.std(_x_add_train)

    rescaled_baseline_train = (_x_add_train - _baseline_mean) / _baseline_std
    rescaled_baseline_valid = (_x_add_valid - _baseline_mean) / _baseline_std

    # Start from the features as loaded (same permutation and split as in the
    # load cell) rather than from x_train / x_valid, which this cell overwrites:
    # re-running the cell then still appends the column exactly once.
    _x_loaded = npz_file_contents['x'][idx]
    _x_loaded_train, _x_loaded_valid = _x_loaded[:split_seperator], _x_loaded[split_seperator:]

    # np.concatenate instead of np.concat: the latter only exists in NumPy >= 2.0
    x_train = np.concatenate([_x_loaded_train, rescaled_baseline_train.reshape(-1, 1)], axis=1)
    x_valid = np.concatenate([_x_loaded_valid, rescaled_baseline_valid.reshape(-1, 1)], axis=1)
    x = np.concatenate([x_train, x_valid], axis=0)

In [None]:
# train with baseline[0] feature
# Second set of feature variables (`x2*` names).
#
# The previous cell has already appended the standardised baseline column to
# x_train / x_valid / x (when INCLUDE_NLRMSIS_X is set). Appending it here again
# would leave the x2 arrays with the same column twice and one feature more than
# the model is built for, and x2_train / x2_valid would not exist at all when
# the flag is off. The x2 names therefore simply refer to the prepared arrays.
# NOTE(review): assumes the x2 arrays are meant to match x - confirm.
x2 = x
x2_train, x2_valid = x_train, x_valid

In [None]:
# dataloader
class CustomDataset(Dataset):
    """In-memory dataset of (input, target) pairs with optional per-sample transforms.

    Both inputs are converted to ``float32`` tensors once at construction time.
    The transforms, if given, are applied each time a sample is fetched.

    Args:
        _x: inputs; anything ``torch.tensor`` accepts. The first dimension indexes samples.
        _y: targets; anything ``torch.tensor`` accepts, indexed in parallel to ``_x``.
        in_transforms: optional callable applied to the selected input.
        out_transforms: optional callable applied to the selected target.
    """

    def __init__(self, _x, _y, in_transforms=None, out_transforms=None):
        self.in_transforms = in_transforms
        self.out_transforms = out_transforms

        self.x = torch.tensor(_x, dtype=torch.float32)
        self.y = torch.tensor(_y, dtype=torch.float32)

    def __len__(self):
        """Number of samples, i.e. the size of the first dimension of the inputs."""
        return self.x.shape[0]

    def __getitem__(self, _index):
        """Return the ``(input, target)`` pair at ``_index`` after applying the transforms."""
        sample, target = self.x[_index], self.y[_index]

        # the input transform runs before the target transform
        if self.in_transforms is not None:
            sample = self.in_transforms(sample)
        if self.out_transforms is not None:
            target = self.out_transforms(target)

        return sample, target


# Gaussian noise is added to inputs and targets of the training set only;
# the validation set is used as is
_input_noise = torchvision.transforms.v2.GaussianNoise(sigma=INPUT_NOISE_STD, clip=False)
_output_noise = torchvision.transforms.v2.GaussianNoise(sigma=OUTPUT_NOISE_STD, clip=False)

train_dataset = CustomDataset(x_train, y_train, in_transforms=_input_noise, out_transforms=_output_noise)
valid_dataset = CustomDataset(x_valid, y_valid)

# settings shared by both loaders; memory is pinned only when running on cuda
_common_loader_kwargs = dict(
    persistent_workers=PERSISTENT_WORKERS,
    drop_last=False,
    prefetch_factor=PREFETCH_FACTOR,
    pin_memory=(DEVICE == 'cuda'),
)

train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS_TRAIN,
    **_common_loader_kwargs,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS_VALID,
    **_common_loader_kwargs,
)

In [None]:
# model constants
MODEL_NAME = "m1200"  # used in checkpoint file names and the W&B run name

# architecture (trailing comments keep the other values noted in this cell)
N_ENCODER_LAYERS = 2  # 1
N_DECODER_LAYERS = 1
N_HEADS = 4
D_MODEL = 112  # 112 # 128 # 96
D_FEED_FORWARD = D_MODEL * 4
DROPOUT_RATE = 0.1

In [None]:
# model setup
# input width and output sequence length are read off the prepared arrays, so
# they follow whatever the data cells produced
_model_kwargs = dict(
    num_encoder_layers=N_ENCODER_LAYERS,
    num_decoder_layers=N_DECODER_LAYERS,
    d_model=D_MODEL,
    num_heads=N_HEADS,
    d_ff=D_FEED_FORWARD,
    input_features=x_train.shape[1],
    output_sequence_length=y_train.shape[1],
    dropout_rate=DROPOUT_RATE,
)
model = Transformer(**_model_kwargs).to(DEVICE)

# compile the model; fullgraph=True asks for a single graph without graph breaks
model.compile(mode='max-autotune', fullgraph=True)

In [None]:
# training constants
EPOCHS = 5000
EARLY_STOPPING_PATIENCE = EPOCHS  # equal to EPOCHS, so early stopping never triggers
SKIP_VALIDATION_STEPS = 19  # epochs skipped between two validation passes

# optimiser
LEARNING_RATE = 5e-5  # 5e-5
WEIGHT_DECAY = 0.01
CLIP_GRAD_NORM = 1000  # 1.0

# learning-rate schedule
WARMUP_PERCENTAGE = 0.01
SCHEDULER_CYCLES = 0.7  # 0.95 # 0.7

In [None]:
# setup training stuff
WITH_ODRMSE = True

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# The schedule length is the number of optimiser steps of the run divided by
# SCHEDULER_CYCLES; the warm-up takes WARMUP_PERCENTAGE of that length.
total_steps = int(len(train_loader) * EPOCHS / SCHEDULER_CYCLES)
warmup_steps = int(total_steps * WARMUP_PERCENTAGE)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_training_steps=total_steps,
    num_warmup_steps=warmup_steps,
)

# 432 time steps which are scaled to seconds (10 minutes each)
_n_time_steps = 432
_seconds_per_step = 10 * 60

criterion_mse = nn.MSELoss().to(DEVICE)
criterion_odrmse = ODRMSELoss(
    total_duration=_n_time_steps * _seconds_per_step,
    times=torch.arange(0, _n_time_steps, device=DEVICE) * _seconds_per_step,
    min_weight_epsilon=1e-5,
    numerical_stability_delta=5e-7,
).to(DEVICE)

# gradient scaling for mixed precision is only active on cuda
scaler = GradScaler(enabled=(DEVICE == 'cuda'))

In [None]:
# wandb constants - for using wandb use `uv run wandb login`
WANDB_ENTITY = "XXX"
WANDB_PROJECT = "XXX"

# the run is named after the model
WANDB_RUN_NAME_PREFIX = f"full-transformer-{MODEL_NAME}"

In [None]:
# wandb setup
# hyperparameters recorded with the run
config = {
    "model": "TimeSeriesTransformer",
    "epochs": EPOCHS,
    "early_stopping_patience": EARLY_STOPPING_PATIENCE,
    "train_batch_size": TRAIN_BATCH_SIZE,
    "valid_batch_size": VALID_BATCH_SIZE,
    "learning_rate": LEARNING_RATE,
    "weight_decay": WEIGHT_DECAY,
    "clip_grad_norm": CLIP_GRAD_NORM,
    "input_noise_std": INPUT_NOISE_STD,
    "output_noise_std": OUTPUT_NOISE_STD,
    "mse_weight": 1.0,
    "odrmse_weight": 1.0 if WITH_ODRMSE else 0.0,
    "d_model": D_MODEL,
    "n_heads": N_HEADS,
    "num_encoder_layers": N_ENCODER_LAYERS,
    "num_decoder_layers": N_DECODER_LAYERS,
    "d_ff": D_FEED_FORWARD,
    "dropout": DROPOUT_RATE,
    "total_duration_sec": 432 * 10 * 60,
    "min_weight_eps": 1e-5,
    "num_stab_delta": 5e-7,
    "epochal_scaling": False,
    "skip_validation_steps": SKIP_VALIDATION_STEPS,
}
run = wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    name=WANDB_RUN_NAME_PREFIX,
    config=config,
)

# kept as the last expression of the cell so the layer/parameter table is displayed
summary(model)

In [None]:
# train helper functions
def _save_model(new_best: bool = True):
    """Write the model weights to disk and log the file as a W&B artifact.

    The file name is built from the notebook globals ``MODEL_NAME``, ``epoch``
    and ``current_val_loss``; the latter two are set by the training loop, so
    this helper can only be called once they exist.

    Args:
        new_best: if true, the file and the artifact get the ``best`` naming;
            otherwise the plain model naming is used.
    """
    if not new_best:
        save_path = weight_save_path / f"{MODEL_NAME}_epoch_{epoch + 1}_valloss_{current_val_loss:.4f}.pt"
        artifact = wandb.Artifact(f'{MODEL_NAME}', type='model')
        print(f"  Saving model to {save_path}...")
    else:
        save_path = weight_save_path / f"{MODEL_NAME}_best_epoch_{epoch + 1}_valloss_{current_val_loss:.4f}.pt"
        artifact = wandb.Artifact(f'{MODEL_NAME}-best', type='model')
        print(f"  Saving new best model to {save_path}...")

    # if the model object exposes `_orig_mod`, the state dict of that module is saved
    model_to_save = getattr(model, '_orig_mod', model)
    torch.save(model_to_save.state_dict(), save_path)

    # Optional: Save to W&B Artifacts
    artifact.add_file(str(save_path))
    wandb.log_artifact(artifact)

In [None]:
# train loop
# Per epoch: one pass over train_loader with mixed-precision forward/backward,
# per-batch W&B logging and a scheduler step after every batch. A validation
# pass (using model.predict) runs only after SKIP_VALIDATION_STEPS epochs have
# been skipped; the epoch summary, best-model checkpoint and early-stopping
# check happen on those validation epochs only.
print(f"\n--- Starting Training: {MODEL_NAME} ---")
wandb.watch(model, log_freq=100)  # Watch model gradients (optional)
best_val_loss = float('inf')
best_epoch = -1
global_step = 0
skipped_validation_steps = 0
# Defined before the loop so that the final `_save_model(False)` does not fail
# with a NameError when the run ends before any validation pass took place.
current_val_loss = float('nan')

model.train()
for epoch in range(EPOCHS):
    epoch_start_time = time.time()

    # --- Training Phase ---
    # sample-weighted sums, turned into per-sample averages after the epoch
    train_loss_mse_accum = 0.0
    train_loss_odrmse_accum = 0.0
    train_loss_total_accum = 0.0
    processed_samples_train = 0

    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{EPOCHS} [Train]", mininterval=0.5)
    for batch_data in train_pbar:
        # Expects: features, targets
        batch_features = batch_data[0].to(DEVICE)
        batch_target = batch_data[1].to(DEVICE)

        # The targets always carry two channels on the last axis (0: standardised
        # target, 1: baseline). Channel 0 is selected unconditionally so the model
        # and the MSE see the same target whether or not OD-RMSE is enabled; the
        # baseline is only handed to the OD-RMSE loss when not training on residuals.
        batch_baseline = None
        if WITH_ODRMSE and not RESIDUAL_TRAINING:
            batch_baseline = batch_target[:, 0, 1].reshape(-1, 1)
        batch_target = batch_target[:, :, 0]

        current_batch_size = batch_data[0].size(0)
        optimizer.zero_grad(set_to_none=True)

        # --- Forward Pass ---
        with autocast(enabled=(DEVICE == 'cuda'), device_type=DEVICE, dtype=torch.float16):
            predictions = model(batch_features, batch_target)

            # --- Loss Calculation ---
            loss_mse = criterion_mse(predictions, batch_target)
            if WITH_ODRMSE:
                loss_odrmse = criterion_odrmse(predictions, batch_target, batch_baseline)
                total_loss = loss_mse + loss_odrmse
            else:
                total_loss = loss_mse

        # --- Backward Pass & Optimization ---
        scaler.scale(total_loss).backward()
        if CLIP_GRAD_NORM > 0:
            # gradients have to be unscaled before their norm can be clipped
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=CLIP_GRAD_NORM)
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()

        # --- Logging ---
        # read every loss value once and reuse the floats below
        batch_loss_total = total_loss.item()
        batch_loss_mse = loss_mse.item()
        batch_loss_odrmse = loss_odrmse.item() if WITH_ODRMSE else 0.0

        lr = scheduler.get_last_lr()[0]
        wandb.log({
            "train/batch_loss_total": batch_loss_total,
            "train/batch_loss_mse": batch_loss_mse,
            "train/batch_loss_odrmse": batch_loss_odrmse,
            "train/learning_rate": lr,
        }, step=global_step)

        train_loss_mse_accum += batch_loss_mse * current_batch_size
        train_loss_odrmse_accum += batch_loss_odrmse * current_batch_size
        train_loss_total_accum += batch_loss_total * current_batch_size
        processed_samples_train += current_batch_size
        global_step += 1

        # Update tqdm progress bar
        train_pbar.set_postfix(loss=f"{batch_loss_total:.4f}", lr=f"{lr:.2e}", refresh=False)

    # validate only after SKIP_VALIDATION_STEPS epochs have been skipped
    if skipped_validation_steps < SKIP_VALIDATION_STEPS:
        skipped_validation_steps += 1
        continue

    avg_train_loss_mse = train_loss_mse_accum / processed_samples_train
    avg_train_loss_odrmse = train_loss_odrmse_accum / processed_samples_train
    avg_train_loss_total = train_loss_total_accum / processed_samples_train

    # --- Validation Phase ---
    model.eval()
    val_loss_mse_accum = 0.0
    val_loss_odrmse_accum = 0.0
    val_loss_total_accum = 0.0
    processed_samples_val = 0

    skipped_validation_steps -= SKIP_VALIDATION_STEPS  # this allows for fractions

    # Wrap val_loader with tqdm
    val_pbar = tqdm(valid_loader, desc=f"Epoch {epoch + 1}/{EPOCHS} [Validate]")
    with torch.no_grad():
        for batch_data in val_pbar:
            batch_features = batch_data[0].to(DEVICE)
            batch_target = batch_data[1].to(DEVICE)

            # same channel handling as in the training phase
            batch_baseline = None
            if WITH_ODRMSE and not RESIDUAL_TRAINING:
                batch_baseline = batch_target[:, 0, 1].reshape(-1, 1)
            batch_target = batch_target[:, :, 0]

            if torch.isnan(batch_features).any():
                print("Nan detected in input")

            if torch.isnan(batch_target).any():
                print("Nan detected in target")

            current_batch_size = batch_target.size(0)

            # --- Forward Pass (Inference Mode) ---
            # model.predict only receives the features, unlike the training forward pass
            with autocast(enabled=(DEVICE == 'cuda'), device_type=DEVICE, dtype=torch.float16):
                predictions_eval = model.predict(batch_features)

                # --- Loss Calculation (for monitoring) ---
                loss_mse = criterion_mse(predictions_eval, batch_target)
                if WITH_ODRMSE:
                    loss_odrmse = criterion_odrmse(predictions_eval, batch_target, batch_baseline)
                    total_loss = loss_mse + loss_odrmse  # Combined for comparison
                else:
                    total_loss = loss_mse

            batch_loss_total = total_loss.item()
            val_loss_mse_accum += loss_mse.item() * current_batch_size
            val_loss_odrmse_accum += loss_odrmse.item() * current_batch_size if WITH_ODRMSE else 0.0
            val_loss_total_accum += batch_loss_total * current_batch_size
            processed_samples_val += current_batch_size
            val_pbar.set_postfix(loss=f"{batch_loss_total:.4f}")

    avg_val_loss_mse = val_loss_mse_accum / processed_samples_val
    avg_val_loss_odrmse = val_loss_odrmse_accum / processed_samples_val
    avg_val_loss_total = val_loss_total_accum / processed_samples_val

    # covers the training pass and the validation pass of this epoch
    epoch_duration = time.time() - epoch_start_time

    model.train()

    # --- W&B Epoch Logging ---
    wandb.log({
        "epoch": epoch + 1,
        "train/epoch_loss_total": avg_train_loss_total,
        "train/epoch_loss_mse": avg_train_loss_mse,
        "train/epoch_loss_odrmse": avg_train_loss_odrmse,
        "val/epoch_loss_total": avg_val_loss_total,
        "val/epoch_loss_mse": avg_val_loss_mse,
        "val/epoch_loss_odrmse": avg_val_loss_odrmse,
        "epoch_duration_sec": epoch_duration,
    }, step=global_step)  # Log against the last global step of the epoch

    print(f"Epoch {epoch + 1}/{EPOCHS} Summary:")
    print(
        f"  Train Loss: Total={avg_train_loss_total:.4f}, MSE={avg_train_loss_mse:.4f}, OD-RMSE={avg_train_loss_odrmse:.4f}")
    print(
        f"  Valid Loss: Total={avg_val_loss_total:.4f}, MSE={avg_val_loss_mse:.4f}, OD-RMSE={avg_val_loss_odrmse:.4f}")
    print(f"  Epoch Duration: {epoch_duration:.2f} seconds")

    # --- Save Best Model ---
    current_val_loss = avg_val_loss_total  # Use combined validation loss

    if current_val_loss < best_val_loss:
        best_epoch = epoch
        best_val_loss = current_val_loss
        _save_model()

    # stop once the best epoch lies more than EARLY_STOPPING_PATIENCE epochs back
    if best_epoch < epoch - EARLY_STOPPING_PATIENCE:
        print(f"  Early stopping at epoch {epoch + 1} due to lack of improvement in validation loss.")
        break

# always keep the weights of the last epoch as well
_save_model(False)
print(f"--- Finished Training: {MODEL_NAME} (Best Val Loss: {best_val_loss:.4f}) ---")

Now do some evaluation...

In [None]:
import matplotlib

# Load a saved state dict for evaluation.
# - map_location=DEVICE: the checkpoint loads onto the device used in this
#   notebook even if it was written on a different one.
# - weights_only=True: torch.load is pickle based; restricting it to tensors and
#   plain containers is enough for a state dict and avoids executing arbitrary
#   pickled code from the file.
model.load_state_dict(torch.load('some-best-model.pt', map_location=DEVICE, weights_only=True))
model.eval()

...