In [6]:
#!pip install optuna
#!pip -q install einops

In [7]:
import numpy as np
import torch
import optuna
from torch.utils.data import DataLoader
from set_data import load_real_weather, WeatherPairs, split_windows_train_val_test
from visualization import (plot_complete_data, 
                           plot_weatherpair_panel,
                           visualize_loss, plot_ctx_and_scenarios_panel)
from utils import *
from diffusion_model import DiffusionSchedule, TinyCondUNet1D
from diffusion_model_train import train_one_epoch, eval_epoch, EarlyStopping
from diffusion_model_generate import p_sample_loop
import error_metrics as em
np.set_printoptions(suppress=True)

VARS = 6
EPOCHS = 100
feature_names = ["T2M_MIN", "T2M_MAX", "PRECTOTCORR", "ALLSKY_SFC_SW_DWN", "RH2M", "WS2M"]

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)


device: cpu


In [8]:
X, columns = load_real_weather("../../data/data_20150630_to_20250630.csv", 
                      start_date="2015-06-30", end_date="2024-06-30")
print("Real data shape:", X.shape)  # (days, number of features)
CTX = 60
HORIZON = 60
dataset = WeatherPairs(X, ctx=CTX, horizon=HORIZON, stride=1)
train_ds, val_ds, test_ds = split_windows_train_val_test(dataset, test_frac=0.1, val_frac=0.1, train_frac=0.8)
mean, std = fit_standardizer(train_ds)
mean_t = torch.tensor(mean, device=device)
std_t = torch.tensor(std, device=device)
print("mean:", mean, "\nstd:", std)


Real data shape: (3289, 6)
mean: [ 8.279258  23.20681    1.2304815 18.979977  65.8404     1.5848776] 
std: [ 3.7783682   7.657927    4.024448    9.812458   17.534721    0.55773836]


In [9]:

def evaluate_crps_on_val(model, sched, val_ds, device, mean_t, std_t,
                         n_eval=64, n_scen=30):
    model.eval()
    idxs = np.random.choice(len(val_ds), size=min(n_eval, len(val_ds)), replace=False)
    crps_list = []
    for i in idxs:
        ctx, tgt = val_ds[i]
        ctx = ctx.unsqueeze(0).to(device)  # (1,CTX,D)
        y = tgt.detach().cpu().numpy()     # (H,D)

        scens = p_sample_loop(ctx, model, device, sched, HORIZON, VARS,
                              mean_t, std_t, n_scenarios=n_scen).detach().cpu().numpy()  # (S,H,D)

        # your function: returns (H,D) if average=False, then average
        crps_hd = em.crps_ensemble_snd(scens, y, average=False)
        crps_list.append(crps_hd.mean())
    return float(np.mean(crps_list))

def objective(trial):
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
    T = trial.suggest_categorical("T", [100, 250, 500])

    lr = trial.suggest_float("lr", 1e-5, 5e-4, log=True)
    wd = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    sched = DiffusionSchedule(T=T)
    model = TinyCondUNet1D(in_vars=VARS, ctx_vars=VARS).to(device)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=wd
    )

    early = EarlyStopping(patience=15, min_delta=1e-4, mode="min")

    for e in range(1, 101):
        tr = train_one_epoch(opt, model, sched, device, train_loader, mean_t, std_t)
        va = eval_epoch(model, device, val_loader, sched, mean_t, std_t)

        if early.step(va, model, e):
            break

    early.restore_best(model)

    val_crps = evaluate_crps_on_val(
        model, sched, val_ds, device, mean_t, std_t,
        n_eval=64, n_scen=30
    )

    return val_crps


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=25)

print("Best:", study.best_params, "CRPS:", study.best_value)


[32m[I 2026-01-29 16:59:19,823][0m A new study created in memory with name: no-name-b0a5f0b7-4f1f-417c-b74f-d6b8107df1a2[0m
[33m[W 2026-01-29 17:01:44,941][0m Trial 0 failed with parameters: {'batch_size': 64, 'T': 250, 'lr': 0.00011373636100895947, 'weight_decay': 0.00023505238711996748} because of the following error: KeyboardInterrupt().[0m
Traceback (most recent call last):
  File "C:\Users\fperezg\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\optuna\study\_optimize.py", line 206, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "C:\Users\fperezg\AppData\Local\Temp\ipykernel_10300\1958643744.py", line 49, in objective
    val_crps = evaluate_crps_on_val(
               ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\fperezg\AppData\Local\Temp\ipykernel_10300\1958643744.py", line 11, in evaluate_crps_on_val
    scens = p_sample_loop(ctx, model, device, sched, HORIZON, V

KeyboardInterrupt: 