In [1]:
from functools import partial
import os
import numpy as np

import torch
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning_uq_box.datamodules import UCIRegressionDatamodule
from lightning_uq_box.datasets import UCIConcrete, UCIEnergy, UCIYacht
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import MVERegression, DeepEnsembleRegression

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Which UCI dataset to evaluate, plus the on-disk locations (relative to this
# notebook) of the raw UCI data and of the pre-trained ensemble checkpoints.
experiment = "concrete"
root = "../../data/uci/"
# Keep the trailing "/": checkpoint paths are built by plain string
# concatenation (f"{ckpt_path}{experiment}_....ckpt").
ckpt_path = "../../results/robustness/checkpoints/"

In [7]:
# Datamodule for the selected UCI dataset. train_distortion=0 presumably means
# the training split is left undistorted; train_size=0.9 presumably gives a
# 90/10 train/test split (TODO confirm against UCIRegressionDatamodule).
dm = UCIRegressionDatamodule(
    dataset_name=experiment,
    root=root,
    train_distortion=0,
    train_size=0.9,
    batch_size=500,
)

# A single training batch is drawn only to read off the number of input features.
sample = next(iter(dm.train_dataloader()))
n_input = sample["input"].shape[-1]

# NOTE(review): `trainer` is not used by any of the cells shown here; its
# construction is what prints the accelerator summary under this cell.
trainer = Trainer()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [211]:
# Mean-variance-estimation network: an MLP with three ReLU hidden layers and
# two outputs per sample (later cells read output 0 as the mean and output 1
# as the log-variance).
mlp = MLP(
    n_inputs=n_input,
    n_hidden=[256, 512, 1024],
    n_outputs=2,
    activation_fn=nn.ReLU(),
)
# NOTE(review): optimizer / burn-in settings only affect training; the
# ensemble below is built from checkpoints, so they look inert here — confirm.
model = MVERegression(
    mlp,
    optimizer=partial(torch.optim.Adam, lr=1e-2),
    burnin_epochs=20,
)

In [212]:
# Three ensemble members that share one MVE architecture (`model`) but point at
# different checkpoints: two clean runs and one trained on distorted data.
# Presumably DeepEnsembleRegression restores each member's weights from its
# ckpt_path into the shared base model — TODO confirm.
ensemble = [
    {"base_model": model, "ckpt_path": member_ckpt}
    for member_ckpt in (
        f"{ckpt_path}{experiment}_0.ckpt",
        f"{ckpt_path}{experiment}_1.ckpt",
        f"{ckpt_path}{experiment}_distorted_2.ckpt",
    )
]

In [213]:
# Wrap the checkpointed members in a deep ensemble for prediction.
deep_ens_nll = DeepEnsembleRegression(ensemble)

# Collect the *whole* test split into one batch dict. Taking only
# `next(iter(dm.test_dataloader()))` returns just the first batch, so every
# metric below would silently cover a subset of the test set whenever it holds
# more than `batch_size` rows. When the split fits in one batch, `torch.cat`
# over a single tensor gives the same values as before.
# Only the keys used by the cells below ("input", "target") are kept.
test_batches = list(dm.test_dataloader())
item = {
    key: torch.cat([batch[key] for batch in test_batches])
    for key in ("input", "target")
}

In [216]:
# Map the ensemble prediction and the ground truth back through the dataset's
# target scaler so that the error below is expressed in the original units.
# The prediction is detached first, presumably so the scaler can convert it to
# an array.
pred = dm.uci_ds.target_scaler.inverse_transform(
    deep_ens_nll.predict_step(item["input"])["pred"].detach()
)
target = dm.uci_ds.target_scaler.inverse_transform(item["target"])

In [217]:
# Root-mean-squared error of the ensemble prediction, in original target units.
np.sqrt(np.square(pred - target).mean())

np.float64(5.637772853174725)

In [218]:
# Predictive standard deviation of every ensemble member. These raw member
# outputs are not inverse-transformed, so presumably they live in the *scaled*
# target space and are not directly comparable with the RMSE above.
# NOTE: this rebinds `pred` (previously the inverse-transformed ensemble
# prediction) to the raw member samples; re-run the prediction cell before
# recomputing the RMSE.
pred = deep_ens_nll.predict_step(item["input"])["samples"].detach()
log_sigma_2 = pred[:, 1]  # second output of each member, treated as log-variance
eps = torch.full_like(log_sigma_2, 1e-6)  # keeps sqrt away from exactly zero
# torch.exp rather than np.exp: a NumPy ufunc on a torch tensor emits a warning
# and cannot handle tensors that live on a GPU.
std = torch.sqrt(eps + torch.exp(log_sigma_2))
# Average over the batch dimension, leaving one value per remaining axis
# (presumably one per ensemble member).
std.mean(dim=0)

  std = torch.sqrt(eps + np.exp(log_sigma_2))


tensor([0.0033, 0.0033, 0.0692])

In [220]:
# Predicted mean (output 0) averaged over the first two ensemble members only.
# Presumably these are the `_0` and `_1` checkpoints, in the order listed in
# `ensemble`, which leaves out the distorted third member.
pred[:, 0, :2].mean(dim=1)

tensor([-0.7029, -1.1830,  1.4391, -0.5282,  1.0318,  0.5841,  0.6093,  0.6669,
         0.6980,  0.1360, -1.5733,  1.9835, -1.0913,  1.9996,  1.9056,  0.4169,
        -1.1498,  0.5451, -0.8140, -0.9725, -0.9618,  1.0318, -1.1560,  1.6428,
        -1.5965, -0.7835, -0.9430, -0.8765, -1.1554,  1.7323, -0.9333, -0.9431,
         1.9445, -0.9308,  0.6257,  1.4289, -0.7856,  1.4113, -0.4723, -1.1802,
        -1.0730,  1.7552, -0.7634,  0.3810, -0.3412,  0.1939, -1.1414, -0.9568,
         0.6545,  0.9987,  1.0161, -0.9291,  1.0824, -0.7346,  1.4318, -1.1859,
        -0.6343,  0.3931,  1.6592, -0.7521, -0.7390, -0.9741,  0.9946, -1.1995,
        -0.9556, -1.0089,  1.6793, -1.1092,  0.3545,  1.7113,  1.0369, -1.1532,
         0.1209, -0.5311, -0.6876,  0.3360,  1.0462])