In [None]:
from deepuq.models.models import model_setup_DE
from deepuq.data.data import DataPreparation
from deepuq.train import train
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# Render matplotlib figures inline, directly beneath the cell that produces them.
# Recent Jupyter/IPython kernels already default to this backend; the explicit
# magic keeps older kernels behaving the same way.
# (The training cell passes plot_inline=False to train.train_DE, so it presumably
# does not display figures itself.)
%matplotlib inline

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Loss used for the deep-ensemble (DE) models.
loss_type = 'bnll_loss'

# Instantiate the DE model and its loss function for this loss on DEVICE.
# NOTE(review): `model` and `lossFn` are not referenced by the training cell,
# which hands the loss name to train.train_DE directly; this call may only act
# as a setup sanity check — confirm before removing it.
model, lossFn = model_setup_DE(loss_type, DEVICE)

In [None]:
# ---- Data generation / preprocessing ---------------------------------------
uniform = True   # `uniform` flag forwarded to DataPreparation.generate_df
norm = False     # flag forwarded to DataPreparation.normalize
verbose = False  # verbosity flag forwarded to DataPreparation.generate_df
val_prop = 0.1   # fraction of samples held out for validation

# ---- Random states: prior sampling, uniform sampling, train/val split -------
rs_prior = rs_uniform = rs_train_val = 42

# ---- Training ----------------------------------------------------------------
BATCH_SIZE = 100  # mini-batch size of the training DataLoader
lr = 1e-3         # learning rate passed to train.train_DE
n_models = 16     # number of ensemble models passed to train.train_DE
n_epochs = 100    # epochs per model (EPOCHS argument of train.train_DE)

# Output directory for train.train_DE (presumably where checkpoints are saved).
out_dir = '../DeepUQResources/'

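Define the configuration of every experiment in the paper: each entry maps an experiment name (e.g. "0D, input, low") to its dataset size (`size_df`), noise level, data dimensionality, and noise-injection location (input or output).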

In [None]:
# All 2 (dim) x 2 (injection) x 3 (noise) = 12 experiment configurations.
# Despite the `_df` suffix this is a plain dict keyed by "<dim>, <injection>, <noise>".
# Insertion order: dimension, then injection location, then noise level.
# Each value lists its fields in the order size_df, noise, dim, injection.
# 0D experiments use 100000 samples; 2D experiments use 5000.
experiments_df = {
    f"{dim}, {injection}, {noise}": {
        "size_df": size_df,
        "noise": noise,
        "dim": dim,
        "injection": injection,
    }
    for dim, size_df in (("0D", 100000), ("2D", 5000))
    for injection in ("input", "output")
    for noise in ("low", "medium", "high")
}

In [None]:
# Train one deep ensemble per experiment configuration.
# For each configuration: regenerate the synthetic data, pass it through
# data.normalize (controlled by `norm`), split off a validation set, and hand
# the training DataLoader plus the validation arrays to train.train_DE.

# Ensemble random seeds 42, 43, ..., 42 + n_models - 1. Derived from n_models
# so the list length always tracks the ensemble size (presumably one seed per
# ensemble member — confirm in train.train_DE).
rs_list = list(range(42, 42 + n_models))

# NOTE(review): presumably the beta exponent of the beta-NLL loss selected by
# loss_type == 'bnll_loss' — confirm against train.train_DE.
BETA = 0.5

for experiment_name, params in experiments_df.items():
    print(f"Experiment: {experiment_name}")
    for key, value in params.items():
        print(f"  {key}: {value}")

    # Generate this configuration's dataset with the fixed random states.
    data = DataPreparation()
    model_inputs, model_outputs = data.generate_df(
        params["size_df"], params["noise"],
        params["dim"], params["injection"], uniform, verbose,
        rs_prior=rs_prior, rs_uniform=rs_uniform)
    model_inputs, model_outputs, norm_params = data.normalize(
        model_inputs, model_outputs, norm
    )
    x_train, x_val, y_train, y_val = data.train_val_split(
        model_inputs,
        model_outputs,
        val_proportion=val_prop,
        random_state=rs_train_val,
    )

    # Only the training split is batched; the validation arrays go to
    # train_DE unchanged.
    trainData = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))
    trainDataLoader = DataLoader(
        trainData, batch_size=BATCH_SIZE, shuffle=True
    )
    train.train_DE(
        trainDataLoader,
        x_val,
        y_val,
        lr,
        DEVICE,
        loss_type,  # the loss configured in the setup cell, not a duplicated literal
        n_models,
        norm_params,
        model_name='DE',
        BETA=BETA,
        EPOCHS=n_epochs,
        out_dir=out_dir,
        inject_type=params["injection"],
        data_dim=params["dim"],
        noise_level=params["noise"],
        save_all_checkpoints=True,
        save_final_checkpoint=True,
        overwrite_model=False,
        plot_inline=False,
        plot_savefig=False,
        rs_list=rs_list,
        set_and_save_rs=True,
    )