In [None]:
# Re-import edited project modules (models/, storage/) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Move up one directory, presumably to the repo root so that `models`, `storage` and the
# relative `out/...` paths below resolve -- TODO confirm.
# NOTE(review): `%cd ../` is not idempotent; re-running this cell moves up again.
%cd ../

In [None]:
import datetime
import itertools
import os
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import distributions as tdist, Tensor
from tqdm.auto import tqdm

from models.gaussian import GaussianRegressor, GaussianRegressorSplit, train_gaussian_regressor_custom
from storage.data import get_loaders, get_tensors, get_train_calib_split

In [None]:
# Model / training configuration shared by every cell below.
INPUT_DIM = 101    # width of each input row fed to the models
Y_DIM = 24         # outputs per row; plot code maps them onto 24 hourly timestamps
MAX_EPOCHS = 500
BATCH_SIZE = 256
SHUFFLE = False    # forwarded to get_tensors(shuffle=...)
SEEDS = range(10)

# Checkpoints for shuffled data live in their own directory.
out_dir = 'out/storage_gaussian_custom_shuffle/' if SHUFFLE else 'out/storage_gaussian_custom/'

In [None]:
def expand_dates_to_hourly(dates: np.ndarray) -> pd.DatetimeIndex:
    """Expand each date into its hourly timestamps over [date, date + 1 day).

    Args:
        dates: array of dates, each usable as the `start` of `pd.date_range`.

    Returns:
        A DatetimeIndex with the hourly timestamps of every date, concatenated
        in input order (24 per date for tz-naive dates). Empty input gives an
        empty index.
    """
    one_day = pd.Timedelta(days=1)
    per_day_ranges = [
        pd.date_range(start=day, end=day + one_day, freq='h', inclusive='left')
        for day in dates
    ]
    return pd.DatetimeIndex([stamp for day_range in per_day_ranges for stamp in day_range])

In [None]:
# Load every split. With log_prices=False the second value is a (mean, std) pair,
# later passed as `unstandardize=` to map standardized values back to the original scale.
tensors, y_info = get_tensors(shuffle=SHUFFLE, log_prices=False)
assert isinstance(y_info, tuple)
(y_mean, y_std) = y_info

In [None]:
def _date_slice_bound(bound: str | datetime.date) -> str | datetime.date:
    """Turn one ``date_range`` endpoint into a ``.loc`` slice bound.

    pandas treats a plain ``datetime.date`` slice bound as midnight, so as an end
    bound it would keep only the 00:00 hour of that day. Returning the ISO date
    string instead makes pandas use partial-string indexing, which covers the whole
    day. Strings and datetimes/Timestamps are returned unchanged.
    """
    if isinstance(bound, datetime.date) and not isinstance(bound, datetime.datetime):
        return bound.isoformat()
    return bound


def plot_model_preds(
    model: GaussianRegressorSplit,
    tensors: dict[str, Tensor | np.ndarray],
    unstandardize: tuple[np.ndarray, np.ndarray] | Literal['log'] | None,
    plot_std: bool = False,
    num_samples: int = 0,
    split: str = 'traincalib',
    date_range: tuple[str, str] | tuple[datetime.date, datetime.date] | None = ('2014-01-01', '2014-04-15'),
) -> plt.Axes:
    """Plot a model's hourly predictions against the true targets for one split.

    Args:
        model: module whose forward returns ``(loc, scale_tril)`` of a multivariate
            normal over the outputs of each input row.
        tensors: dict holding ``X_{split}`` / ``Y_{split}`` tensors and a
            ``date_{split}`` numpy array with one date per row.
        unstandardize: how to map outputs back to the original scale: ``None``
            (plot as-is), ``'log'`` (exponentiate), or a ``(mean, std)`` pair of
            numpy arrays or tensors applied as ``y * std + mean``.
        plot_std: if True, shade the band mean +/- 2 marginal std.
        num_samples: number of samples from the predictive distribution to
            overlay (0 disables sampling).
        split: key suffix selecting the split inside ``tensors``.
        date_range: inclusive ``(start, end)`` bounds. ISO date strings and plain
            dates cover their whole day. ``None`` plots every row.

    Returns:
        The matplotlib Axes containing the plot.

    Raises:
        ValueError: if ``unstandardize`` is a string other than ``'log'``.
    """
    X, Y, dates = tensors[f'X_{split}'], tensors[f'Y_{split}'], tensors[f'date_{split}']
    assert isinstance(X, Tensor)
    assert isinstance(Y, Tensor)
    assert isinstance(dates, np.ndarray)

    # 24 hourly timestamps per row; these line up with the flattened outputs only when y_dim == 24.
    datetimes = expand_dates_to_hourly(dates)

    model.eval()
    with torch.no_grad():
        loc, scale_tril = model(X)
    pred_dist = tdist.MultivariateNormal(loc=loc, scale_tril=scale_tril)
    pred_std = pred_dist.stddev  # per-output marginal std, same shape as loc

    samples = None
    if num_samples > 0:
        samples = pred_dist.sample((num_samples,))  # shape [num_samples, N, y_dim]

    if unstandardize is None:
        true_y = Y
        pred_mean = loc
        pred_lo = pred_mean - 2 * pred_std
        pred_hi = pred_mean + 2 * pred_std
    elif isinstance(unstandardize, str):
        if unstandardize != 'log':
            raise ValueError(
                f"unstandardize must be None, 'log' or a (mean, std) pair, got {unstandardize!r}")
        # Outputs live in log space: build the +/- 2 std band there, then exponentiate.
        true_y = torch.exp(Y)
        pred_mean = torch.exp(loc)
        pred_lo = torch.exp(loc - 2 * pred_std)
        pred_hi = torch.exp(loc + 2 * pred_std)
        if samples is not None:
            samples = torch.exp(samples)
    else:
        # as_tensor accepts numpy arrays and tensors alike (from_numpy rejects tensors).
        shift = torch.as_tensor(unstandardize[0])
        scale = torch.as_tensor(unstandardize[1])
        true_y = Y * scale + shift
        pred_mean = loc * scale + shift
        scaled_std = pred_std * scale
        pred_lo = pred_mean - 2 * scaled_std
        pred_hi = pred_mean + 2 * scaled_std
        if samples is not None:
            samples = samples * scale + shift

    # Flatten [N, y_dim] -> [N * y_dim] so each value lines up with one hourly timestamp.
    pred_df = pd.DataFrame(index=datetimes, data={
        'y': true_y.reshape(-1).numpy(),
        'ypred_mean': pred_mean.reshape(-1).numpy(),
        'ypred_lo': pred_lo.reshape(-1).numpy(),
        'ypred_hi': pred_hi.reshape(-1).numpy(),
    })
    if samples is not None:
        # One flattened row per drawn sample, whatever num_samples is.
        flat_samples = samples.reshape(num_samples, -1).numpy()
        for i in range(num_samples):
            pred_df[f'sample_{i}'] = flat_samples[i]

    # Sort so date slicing is valid and lines are drawn in time order even for shuffled rows.
    pred_df = pred_df.sort_index()
    if date_range is None:
        subdf = pred_df
    else:
        start, end = (_date_slice_bound(bound) for bound in date_range)
        subdf = pred_df.loc[start:end]

    _, ax = plt.subplots(figsize=(20, 4), tight_layout=True)
    ax.plot(subdf.index, subdf['y'], label='true', color='black')
    ax.plot(subdf.index, subdf['ypred_mean'], label='mean', alpha=0.7)
    for i in range(num_samples):
        ax.plot(subdf.index, subdf[f'sample_{i}'], alpha=0.3)
    if plot_std:
        ax.fill_between(subdf.index, subdf['ypred_lo'], subdf['ypred_hi'],
                        label=r'$\pm$ 2 std', alpha=0.3)
    ax.set(xlabel='time', title=f'{split} predictions')
    ax.legend()
    return ax

## Plot pre-trained models

In [None]:
split = 'traincalib'
X = tensors[f'X_{split}']
Y = tensors[f'Y_{split}']
dates = tensors[f'date_{split}']

# Narrow the dict's union value type (and fail fast on a malformed split).
assert isinstance(X, Tensor) and isinstance(Y, Tensor)
assert isinstance(dates, np.ndarray)

In [None]:
datetimes = expand_dates_to_hourly(dates)
# Flatten the [N, Y_DIM] targets row by row so each value lines up with one hourly timestamp.
true_y = Y.reshape(-1).numpy()
pred_df = pd.DataFrame({'y': true_y}, index=datetimes)

In [None]:
for seed in SEEDS:
    ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}.pt')
    model = GaussianRegressor(input_dim=INPUT_DIM, y_dim=Y_DIM)
    model.load_state_dict(torch.load(ckpt_path, weights_only=True))
    model.eval()

    with torch.no_grad():
        loc, scale_tril = model(X)
    pred_dist = tdist.MultivariateNormal(loc=loc, scale_tril=scale_tril)

    # Flatten [N, Y_DIM] -> [N * Y_DIM] to match pred_df's hourly index.
    pred_df[f'ypred_mean_s{seed}'] = loc.reshape(-1).numpy()
    pred_df[f'ypred_std_s{seed}'] = pred_dist.stddev.reshape(-1).numpy()

    nll = -pred_dist.log_prob(Y).mean().item()
    print(f'seed {seed} nll {nll}')

In [None]:
# String bounds use pandas partial-string indexing, so the whole end day is included
# (datetime.date bounds are treated as midnight and would keep only its 00:00 hour).
start_date, end_date = '2014-01-01', '2014-04-15'
# Sort first: slicing needs a monotonic index when rows are shuffled.
subdf = pred_df.sort_index().loc[start_date:end_date]
# Use `subdf = pred_df` to plot every date instead.

fig, ax = plt.subplots(figsize=(12, 4), tight_layout=True)
ax.plot(subdf.index, subdf['y'], label='true', color='black')
for seed in [2]:
    ax.plot(subdf.index, subdf[f'ypred_mean_s{seed}'], label=f's{seed} mean', alpha=0.3)
    ax.plot(subdf.index, subdf[f'ypred_std_s{seed}'], label=f's{seed} std', alpha=0.3)
ax.set(xlabel='time', title='Pre-trained GaussianRegressor predictions')
ax.legend()
plt.show()

## Plot scratch

In [None]:
seed = 1
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
# with torch.no_grad():
#     model.diag_net.weight.fill_(0.)
#     model.diag_net.bias.fill_(1.)
#     model.loc_net.weight.fill_(0.)
#     model.loc_net.bias.fill_(0.)
# NOTE(review): the pre-trained loop above loads this same file into GaussianRegressor;
# the two classes may not share a state-dict layout -- confirm which one wrote it.
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}.pt')
model.load_state_dict(torch.load(ckpt_path, weights_only=True))
# plot_model_preds has no `plot_samples` parameter (passing it raises TypeError);
# num_samples=0 is how sample overlays are disabled.
ax = plot_model_preds(
    model=model, tensors=tensors, unstandardize=y_info,
    num_samples=0, plot_std=True, split='traincalib',
    date_range=('2014-01-01', '2014-02-15')
)
# ax.set(ylim=(-0, 400))

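## Train Smart

Candidate multi-stage training schemes. The section headings further down do not map one-to-one onto this list. The "Option B" section continues Option A (step 2, retraining the mean). The "Option C" section first trains the mean with MSE. The "Option D" section implements Option C from this list.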

Option A
1. train with NLL
2. Fix (embed, covariance), train mean
3. Fix (embed, mean), train covariance

Option B
1. train with NLL-diag
2. Fix (embed, mean), train covariance
3. Fine tune NLL

Option C
1. train with NLL-diag
2. Fine tune NLL

In [None]:
# Learning rates in half-decade steps: 1e-4, 10**-3.5, ..., 10**-1.5 (arange stops before -1.4).
lrs = np.power(10., np.arange(-4, -1.4, 0.5))
l2regs = [1e-4]
# l2regs = [0, 1e-4, 1e-3, 1e-2]

### Option A
1. train with NLL
2. Fix (embed, covariance), train mean
3. Fix (embed, mean), train covariance

In [None]:
device = 'cpu'

seed = 1
tensors_cv, _ = get_train_calib_split(tensors, seed=seed)
loaders = get_loaders(tensors_cv, batch_size=BATCH_SIZE)

best_model = None
best_hp = None
best_val_loss = np.inf

# One (lr, l2reg, seed, val_loss) record per run; failed runs are recorded as NaN.
losses = []
pbar = tqdm(itertools.product(lrs, l2regs), total=len(lrs) * len(l2regs))
for lr, l2reg in pbar:
    try:
        model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
        result = train_gaussian_regressor_custom(
            model, loaders, loss_name='nll', max_epochs=MAX_EPOCHS, lr=lr, l2reg=l2reg,
            return_best_model=True, device=device, show_pbar=True)
    except Exception as e:
        # Best-effort sweep: log the failure and move on to the next setting.
        tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}, seed {seed}) failed: {e}')
        losses.append((lr, l2reg, seed, np.nan))
        continue

    losses.append((lr, l2reg, seed, result['val_loss']))
    if result['val_loss'] < best_val_loss:
        best_val_loss = result['val_loss']
        best_hp = (lr, l2reg)
        # Presumably `model` now holds its best-epoch weights (return_best_model=True) -- TODO confirm.
        best_model = model
    tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}) best epoch: {result["best_epoch"]}, '
               f'val_loss: {result["val_loss"]:.3f}, ')

assert best_model is not None, 'every hyperparameter setting failed'
# Saving is disabled; the Option B cells load this checkpoint, so uncomment to produce it.
# ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nll.pt')
# torch.save(best_model.cpu().state_dict(), ckpt_path)
# print(f'Saved best model to {ckpt_path}')
print(f'Best hp: lr={best_hp[0]}, l2reg={best_hp[1]}, val_loss={best_val_loss:.3f}')

In [None]:
seed = 1
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}.pt')
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
model.load_state_dict(torch.load(ckpt_path, weights_only=True))

# Predictions in the original scale (via y_info) from 2012-01-01 to 2012-09-01.
ax = plot_model_preds(
    model=model,
    tensors=tensors,
    unstandardize=y_info,
    plot_std=True,
    split='traincalib',
    date_range=('2012-01-01', '2012-09-01'),
)
# ax.set(ylim=(-10, 2000))

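### Option B

Continuation of Option A, starting from its step-1 checkpoint (`gaussian_regressor_s{seed}_nll.pt`):
1. train with NLL (done in Option A)
2. Fix (embed, covariance), train mean. Below, this is done once with the NLL loss and once with MSE.
3. Fix (embed, mean), train covariance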

In [None]:
# Fall back to CPU so the cell still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

seed = 1
tensors_cv, _ = get_train_calib_split(tensors, seed=seed)
loaders = get_loaders(tensors_cv, batch_size=BATCH_SIZE)

best_model = None
best_hp = None
best_val_loss = np.inf

# Starting point: the Option A step-1 checkpoint (written only if its save lines are enabled).
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nll.pt')

# One (lr, l2reg, seed, val_loss) record per run; failed runs are recorded as NaN.
losses = []
l2reg = 0
finetune_lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
for lr in tqdm(finetune_lrs):
    try:
        model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
        model.load_state_dict(torch.load(ckpt_path, weights_only=True))
        # Freeze the embedding and covariance sub-networks so only the mean is retrained.
        result = train_gaussian_regressor_custom(
            model, loaders, loss_name='nll', max_epochs=MAX_EPOCHS, lr=lr, l2reg=l2reg,
            return_best_model=True, device=device, show_pbar=True,
            freeze=('embed', 'diag_net', 'scale_tril_net'))
    except Exception as e:
        tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}, seed {seed}) failed: {e}')
        losses.append((lr, l2reg, seed, np.nan))
        continue

    losses.append((lr, l2reg, seed, result['val_loss']))
    if result['val_loss'] < best_val_loss:
        best_val_loss = result['val_loss']
        best_hp = (lr, l2reg)
        best_model = model
    tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}) best epoch: {result["best_epoch"]}, '
               f'val_loss: {result["val_loss"]:.3f}, ')

assert best_model is not None, 'every learning rate failed'
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nll_mean.pt')
torch.save(best_model.cpu().state_dict(), ckpt_path)
print(f'Saved best model to {ckpt_path}')
print(f'Best hp: lr={best_hp[0]}, l2reg={best_hp[1]}, val_loss={best_val_loss:.3f}')

In [None]:
seed = 1
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nll_mean.pt')
model.load_state_dict(torch.load(ckpt_path, weights_only=True))
# Un-standardize with the (mean, std) pair returned by get_tensors.
plot_model_preds(
    model=model, tensors=tensors, unstandardize=y_info, plot_std=True, split='traincalib',
    date_range=('2014-01-01', '2014-04-15')
);

In [None]:
# Fall back to CPU so the cell still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

seed = 1
tensors_cv, _ = get_train_calib_split(tensors, seed=seed)
loaders = get_loaders(tensors_cv, batch_size=BATCH_SIZE)

best_model = None
best_hp = None
best_val_loss = np.inf

# Starting point: the Option A step-1 checkpoint (written only if its save lines are enabled).
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nll.pt')

# One (lr, l2reg, seed, val_loss) record per run; failed runs are recorded as NaN.
losses = []
l2reg = 0
finetune_lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
for lr in tqdm(finetune_lrs):
    try:
        model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
        model.load_state_dict(torch.load(ckpt_path, weights_only=True))
        # Freeze the embedding and covariance sub-networks; retrain the mean with MSE.
        result = train_gaussian_regressor_custom(
            model, loaders, loss_name='mse', max_epochs=MAX_EPOCHS, lr=lr, l2reg=l2reg,
            return_best_model=True, device=device, show_pbar=True,
            freeze=('embed', 'diag_net', 'scale_tril_net'))
    except Exception as e:
        tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}, seed {seed}) failed: {e}')
        losses.append((lr, l2reg, seed, np.nan))
        continue

    losses.append((lr, l2reg, seed, result['val_loss']))
    if result['val_loss'] < best_val_loss:
        best_val_loss = result['val_loss']
        best_hp = (lr, l2reg)
        best_model = model
    tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}) best epoch: {result["best_epoch"]}, '
               f'val_loss: {result["val_loss"]:.3f}, ')

assert best_model is not None, 'every learning rate failed'
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nll_mse.pt')
torch.save(best_model.cpu().state_dict(), ckpt_path)
print(f'Saved best model to {ckpt_path}')
print(f'Best hp: lr={best_hp[0]}, l2reg={best_hp[1]}, val_loss={best_val_loss:.3f}')

In [None]:
seed = 1
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nll_mse.pt')
model.load_state_dict(torch.load(ckpt_path, weights_only=True))
# Un-standardize with the (mean, std) pair returned by get_tensors.
plot_model_preds(
    model=model, tensors=tensors, unstandardize=y_info, plot_std=True, split='traincalib',
    date_range=('2014-01-01', '2014-04-15')
);

In [None]:
# Mean NLL of the currently loaded model on the calibration split.
# eval() + no_grad(): deterministic inference without building an autograd graph.
model.eval()
with torch.no_grad():
    loc, scale_tril = model(tensors_cv['X_calib'])
    calib_nll = -tdist.MultivariateNormal(loc=loc, scale_tril=scale_tril).log_prob(tensors_cv['Y_calib']).mean().item()
calib_nll

### Option C

1. train mean with MSE
2. Fix (embed, mean), train covariance
3. Optionally fine tune

In [None]:
# 'device' is not a valid torch device (every run failed); use CUDA when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

seed = 1
tensors_cv, _ = get_train_calib_split(tensors, seed=seed)
loaders = get_loaders(tensors_cv, batch_size=BATCH_SIZE)

best_model = None
best_hp = None
best_val_loss = np.inf

# One (lr, l2reg, seed, val_loss) record per run; failed runs are recorded as NaN.
losses = []
pbar = tqdm(itertools.product(lrs, l2regs), total=len(lrs) * len(l2regs))
for lr, l2reg in pbar:
    try:
        model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
        result = train_gaussian_regressor_custom(
            model, loaders, loss_name='mse', max_epochs=MAX_EPOCHS, lr=lr, l2reg=l2reg,
            return_best_model=True, device=device)
    except Exception as e:
        tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}, seed {seed}) failed: {e}')
        losses.append((lr, l2reg, seed, np.nan))
        continue

    losses.append((lr, l2reg, seed, result['val_loss']))
    if result['val_loss'] < best_val_loss:
        best_val_loss = result['val_loss']
        best_hp = (lr, l2reg)
        best_model = model
    tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}) best epoch: {result["best_epoch"]}, '
               f'val_loss: {result["val_loss"]:.3f}, ')

assert best_model is not None, 'every hyperparameter setting failed'
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_mse.pt')
torch.save(best_model.cpu().state_dict(), ckpt_path)
print(f'Saved best model to {ckpt_path}')
print(f'Best hp: lr={best_hp[0]}, l2reg={best_hp[1]}, val_loss={best_val_loss:.3f}')

In [None]:
seed = 1
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_mse.pt')
model.load_state_dict(torch.load(ckpt_path, weights_only=True))
# Un-standardize with the (mean, std) pair returned by get_tensors.
plot_model_preds(
    model=model, tensors=tensors, unstandardize=y_info,
    plot_std=False, split='traincalib',
    date_range=('2014-01-01', '2014-04-01')
);

In [None]:
# Mean predictions of the MSE-trained model on the training split.
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_mse.pt')
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
model.load_state_dict(torch.load(ckpt_path, weights_only=True))
model.eval()
# Pure inference: skip building an autograd graph over the whole training split.
with torch.no_grad():
    loc, _ = model(tensors_cv['X_train'])

In [None]:
# Fall back to CPU so the cell still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

seed = 1
tensors_cv, _ = get_train_calib_split(tensors, seed=seed)
loaders = get_loaders(tensors_cv, batch_size=BATCH_SIZE)

best_model = None
best_hp = None
best_val_loss = np.inf

# NOTE(review): with the load below commented out, each run trains from scratch with
# 'mse_nll' rather than starting from the MSE checkpoint -- confirm this is intended.
# ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_mse.pt')

# One (lr, l2reg, seed, val_loss) record per run; failed runs are recorded as NaN.
losses = []
l2reg = 1e-4
finetune_lrs = [1e-5, 1e-4, 1e-3, 1e-2]
for lr in tqdm(finetune_lrs):
    try:
        model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
        # model.load_state_dict(torch.load(ckpt_path, weights_only=True))
        # model.initialize_diag_bias(4.)
        # with torch.no_grad():
        #     torch.nn.init.normal_(model.scale_tril_net.weight, mean=0., std=5e-3)
        #     torch.nn.init.normal_(model.scale_tril_net.bias, mean=0., std=5e-3)
        result = train_gaussian_regressor_custom(
            model, loaders, loss_name='mse_nll', max_epochs=MAX_EPOCHS, lr=lr, l2reg=l2reg,
            return_best_model=True, device=device, show_pbar=True)
    except Exception as e:
        tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}, seed {seed}) failed: {e}')
        losses.append((lr, l2reg, seed, np.nan))
        continue

    losses.append((lr, l2reg, seed, result['val_loss']))
    if result['val_loss'] < best_val_loss:
        best_val_loss = result['val_loss']
        best_hp = (lr, l2reg)
        best_model = model
    tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}) best epoch: {result["best_epoch"]}, '
               f'val_loss: {result["val_loss"]:.3f}, ')

assert best_model is not None, 'every learning rate failed'
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_mse_nll.pt')
torch.save(best_model.cpu().state_dict(), ckpt_path)
print(f'Saved best model to {ckpt_path}')
print(f'Best hp: lr={best_hp[0]}, l2reg={best_hp[1]}, val_loss={best_val_loss:.3f}')

In [None]:
seed = 1
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_mse_nll.pt')
model.load_state_dict(torch.load(ckpt_path, weights_only=True))
# Un-standardize with the (mean, std) pair returned by get_tensors.
plot_model_preds(
    model=model, tensors=tensors, unstandardize=y_info,
    plot_std=True, split='traincalib',
    date_range=('2014-01-01', '2014-04-01')
);

### Option D
- train with NLL diag
- fine tune with NLL

In [None]:
device = 'cpu'

seed = 1
tensors_cv, _ = get_train_calib_split(tensors, seed=seed)
loaders = get_loaders(tensors_cv, batch_size=BATCH_SIZE)

best_model = None
best_hp = None
best_val_loss = np.inf

# One (lr, l2reg, seed, val_loss) record per run; failed runs are recorded as NaN.
losses = []
pbar = tqdm(itertools.product(lrs, l2regs), total=len(lrs) * len(l2regs))
for lr, l2reg in pbar:
    try:
        model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
        result = train_gaussian_regressor_custom(
            model, loaders, loss_name='nll_diag', max_epochs=MAX_EPOCHS, lr=lr, l2reg=l2reg,
            return_best_model=True, device=device)
    except Exception as e:
        tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}, seed {seed}) failed: {e}')
        losses.append((lr, l2reg, seed, np.nan))
        continue

    losses.append((lr, l2reg, seed, result['val_loss']))
    if result['val_loss'] < best_val_loss:
        best_val_loss = result['val_loss']
        best_hp = (lr, l2reg)
        best_model = model
    tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}) best epoch: {result["best_epoch"]}, '
               f'val_loss: {result["val_loss"]:.3f}, ')

assert best_model is not None, 'every hyperparameter setting failed'
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nlldiag.pt')
torch.save(best_model.cpu().state_dict(), ckpt_path)
print(f'Saved best model to {ckpt_path}')
print(f'Best hp: lr={best_hp[0]}, l2reg={best_hp[1]}, val_loss={best_val_loss:.3f}')

In [None]:
seed = 1
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nlldiag.pt')
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
model.load_state_dict(torch.load(ckpt_path, weights_only=True))

# Zero scale_tril_net in memory only (the checkpoint on disk is untouched); presumably this
# leaves just the diagonal covariance learned with 'nll_diag' -- TODO confirm.
with torch.no_grad():
    for param in (model.scale_tril_net.weight, model.scale_tril_net.bias):
        param.fill_(0.)

ax = plot_model_preds(
    model=model,
    tensors=tensors,
    unstandardize=y_info,
    plot_std=True,
    split='traincalib',
    date_range=('2014-03-01', '2014-03-15'),
)
ax.set(ylim=(-10, 200))

In [None]:
# Rewrite every seed's NLL-diag checkpoint in place with scale_tril_net zeroed.
# NOTE(review): this overwrites files on disk; presumably harmless to re-run since
# zeroing again yields the same weights -- confirm zero_scale_tril_net only zeroes.
for seed in SEEDS:
    ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nlldiag.pt')
    model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
    model.load_state_dict(torch.load(ckpt_path, weights_only=True))
    model.zero_scale_tril_net()
    torch.save(model.state_dict(), ckpt_path)

In [None]:
device = 'cpu'

seed = 3
tensors_cv, _ = get_train_calib_split(tensors, seed=seed)
# batch_size=-1 presumably means full-batch training -- TODO confirm in get_loaders.
loaders = get_loaders(tensors_cv, batch_size=-1)

ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nlldiag.pt')

best_model = None
best_hp = None
best_val_loss = np.inf

# One (lr, l2reg, seed, val_loss) record per run; failed runs are recorded as NaN.
losses = []
pbar = tqdm(itertools.product(lrs, l2regs), total=len(lrs) * len(l2regs))
for lr, l2reg in pbar:
    try:
        model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
        model.load_state_dict(torch.load(ckpt_path, weights_only=True))
        model.zero_scale_tril_net()
        result = train_gaussian_regressor_custom(
            model, loaders, loss_name='nll', max_epochs=MAX_EPOCHS, lr=lr, l2reg=l2reg,
            return_best_model=True, device=device, show_pbar=True)
    except Exception as e:
        tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}, seed {seed}) failed: {e}')
        losses.append((lr, l2reg, seed, np.nan))
        continue

    losses.append((lr, l2reg, seed, result['val_loss']))
    if result['val_loss'] < best_val_loss:
        best_val_loss = result['val_loss']
        best_hp = (lr, l2reg)
        best_model = model
    tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}) best epoch: {result["best_epoch"]}, '
               f'val_loss: {result["val_loss"]:.3f}, ')

assert best_model is not None, 'every hyperparameter setting failed'
# Saving is disabled; later cells load `_nlldiag_nll.pt`, so uncomment to produce it.
# ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nlldiag_nll.pt')
# torch.save(best_model.cpu().state_dict(), ckpt_path)
# print(f'Saved best model to {ckpt_path}')
print(f'Best hp: lr={best_hp[0]}, l2reg={best_hp[1]}, val_loss={best_val_loss:.3f}')

In [None]:
# Inspect the scale_tril_net bias of the best model from the sweep above.
scale_tril_bias = best_model.scale_tril_net.bias
scale_tril_bias

In [None]:
seed = 1
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nlldiag_nll.pt')
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
model.load_state_dict(torch.load(ckpt_path, weights_only=True))

# NOTE(review): the NLL fine-tuning cell above uses seed 3 and has its save commented out,
# so this seed-1 checkpoint must come from an earlier run.
ax = plot_model_preds(
    model=model,
    tensors=tensors,
    unstandardize=y_info,
    plot_std=True,
    split='traincalib',
    date_range=('2014-03-01', '2014-03-15'),
)
ax.set(ylim=(-10, 200))

In [None]:
device = 'cpu'

seed = 1
tensors_cv, _ = get_train_calib_split(tensors, seed=seed)
loaders = get_loaders(tensors_cv, batch_size=1024)

# Start from the NLL-diag -> NLL checkpoint and fine-tune it further.
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nlldiag_nll.pt')

best_model = None
best_hp = None
best_val_loss = np.inf

# One (lr, l2reg, seed, val_loss) record per run; failed runs are recorded as NaN.
losses = []
pbar = tqdm(itertools.product(lrs, l2regs), total=len(lrs) * len(l2regs))
for lr, l2reg in pbar:
    try:
        model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
        model.load_state_dict(torch.load(ckpt_path, weights_only=True))
        result = train_gaussian_regressor_custom(
            model, loaders, loss_name='nll', max_epochs=MAX_EPOCHS, lr=lr, l2reg=l2reg,
            return_best_model=True, device=device, show_pbar=True, cutoff=50)
    except Exception as e:
        tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}, seed {seed}) failed: {e}')
        losses.append((lr, l2reg, seed, np.nan))
        continue

    losses.append((lr, l2reg, seed, result['val_loss']))
    if result['val_loss'] < best_val_loss:
        best_val_loss = result['val_loss']
        best_hp = (lr, l2reg)
        best_model = model
    tqdm.write(f'(lr {lr:.3g}, l2reg {l2reg:.3g}) best epoch: {result["best_epoch"]}, '
               f'val_loss: {result["val_loss"]:.3f}, ')

assert best_model is not None, 'every hyperparameter setting failed'
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nlldiag_nll_finetune.pt')
torch.save(best_model.cpu().state_dict(), ckpt_path)
print(f'Saved best model to {ckpt_path}')
print(f'Best hp: lr={best_hp[0]}, l2reg={best_hp[1]}, val_loss={best_val_loss:.3f}')

In [None]:
seed = 1
ckpt_path = os.path.join(out_dir, f'gaussian_regressor_s{seed}_nlldiag_nll_finetune.pt')
model = GaussianRegressorSplit(input_dim=INPUT_DIM, y_dim=Y_DIM)
model.load_state_dict(torch.load(ckpt_path, weights_only=True))

# February 2014 predictions of the fine-tuned model, with the +/- 2 std band.
ax = plot_model_preds(
    model=model,
    tensors=tensors,
    unstandardize=y_info,
    plot_std=True,
    split='traincalib',
    date_range=('2014-02-01', '2014-03-01'),
)
ax.set(ylim=(-10, 250))