# Uncertainty estimation using ensembles of partly independent MLP models

**Model description:**

-- Network outputs two values in the final layer, corresponding to the predicted **mean** and **variance** by treating the observed value as a sample from a Gaussian distribution

-- Ensemble members are trained on different bootstrap samples of the original training set; the mean and variance of a mixture are given by $$\mu_{*}(\mathbf{x})=M^{-1} \sum_{m} \mu_{\theta_{m}}(\mathbf{x}),$$  $$\sigma_{*}^{2}(\mathbf{x})=M^{-1} \sum_{m}\left(\sigma_{\theta_{m}}^{2}(\mathbf{x})+\mu_{\theta_{m}}^{2}(\mathbf{x})\right)-\mu_{*}^{2}(\mathbf{x}),$$ respectively.



In [None]:
import sys, os, glob
sys.path.append('..')
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from pathlib import Path
from tqdm import tqdm
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
from astropy.io import fits
import numpy.ma as ma
from inverse_problem import SpectrumDataset, PregenSpectrumDataset, make_loader
from inverse_problem.nn_inversion.models_mlp import MlpPartlyIndepNet
from inverse_problem.nn_inversion.transforms import normalize_output
from inverse_problem.nn_inversion.posthoc import compute_metrics, open_param_file, plot_params, plot_pred_vs_refer
from inverse_problem.nn_inversion.transforms import transform_dist
from inverse_problem.nn_inversion.posthoc import plot_hist_params_comparison
from inverse_problem.nn_inversion.posthoc import plot_analysis_hist2d_unc
from inverse_problem.nn_inversion.posthoc import plot_spectrum, plot_model_spectrum, read_spectrum_for_refer
from inverse_problem.nn_inversion import mlp_transform_rescale, normalize_spectrum

### Define ensemble size

In [None]:
ensemble_size = 6

### Load data

In [None]:
filename = '../data/parameters_base_new.fits'
transform = None
sobj = SpectrumDataset(param_path=filename, source='database', transform=transform)
sample = sobj[1]

In [None]:
line_type = ['I','Q','U','V']
line_arg = 1000 * (np.linspace(6302.0692255, 6303.2544205, 56)) - 6302.5
fig, ax = plt.subplots(2,2, figsize = (10,5))
for i in range(4):
    ax[i//2][i%2].plot(line_arg, sample['X'][0][:,i]); ax[i//2][i%2].set_title(f'Spectral line {line_type[i]}')
fig.set_tight_layout(tight = True)

### Prepare data for training

Options:
-- angle transformation
-- log transformation

In [None]:
params = fits.open(filename)[0].data
def params_masked_rows(pars_arr):
    max_par_values = np.array([par_arr.max() for par_arr in pars_arr.T])
    min_par_values = np.array([par_arr.min() for par_arr in pars_arr.T])
    bool_arr = (min_par_values + 1e-3 < pars_arr) & (pars_arr < max_par_values - 1e-3)
    return np.all(bool_arr, axis=1)


def create_masked_array(pars_arr):
    rows_mask = params_masked_rows(pars_arr)
    array_mask = rows_mask[:, np.newaxis] | np.zeros_like(pars_arr, dtype=bool)
    return ma.masked_array(pars_arr, mask=~array_mask)
rows_mask_params = params_masked_rows(params)
filtered_params = params[rows_mask_params, :]

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device.type)

In [None]:
factors, cont_scale = [1, 1000, 1000, 1000], 40000
angle_transformation, logB = True, True

transform_name = "conv1d_transform_rescale"

batch_size = 128
num_workers = 1 if 'cuda' in device.type else 0

In [None]:
train_loader, val_loader = make_loader(data_arr=filtered_params, transform_name=transform_name,
                                       factors=factors, cont_scale=cont_scale,
                                       logB=logB, angle_transformation=angle_transformation,
                                       batch_size=batch_size, num_workers=num_workers)

sample_batch = next(iter(train_loader))

print('Size of spectrum batch: ', sample_batch['X'][0].shape)
print('Size of cont batch: ', sample_batch['X'][1].shape)
print('Size of true params batch: ', sample_batch['Y'].shape)

print(f'\nNumber of batches for train: {len(train_loader)}, for validation: {len(val_loader)}')

### Create path for saving

In [None]:
model_name = 'conv_ens'
current_time = str(datetime.now().strftime('%m-%d_%H-%M'))
save_path = '../' + model_name + '_' + current_time + '/'
Path(save_path).mkdir(parents=True, exist_ok=True)

save_path

### Define ensemble

output_dim=22 for uncertainty estimation

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ensemble = [MlpPartlyIndepNet(input_dim=224, output_dim=22, hidden_dims=[200, 200, 200],
                              activation='elu', batch_norm=True, dropout=0.05, number_readout_layers=2).to(device) for _ in range(ensemble_size)]

In [None]:
criterion = nn.MSELoss()
optimizers = [torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99)) for model in ensemble]

In [None]:
def mdn_cost(mu, sigma, y):
    dist = torch.distributions.Normal(mu, sigma)
    return torch.mean(-dist.log_prob(y))


def fit_step(model, optimizer, dataloader, max_steps=None):
    train_loss = 0.0
    train_it = 0
    if max_steps is None:
        max_steps = float('inf')
    total = min(max_steps, len(dataloader))

    with tqdm(desc="batch", total=total, position=0, leave=True) as pbar_outer:
        for i, inputs in enumerate(dataloader):
            if i == total:
                break
            x = [inputs['X'][0].to(device), inputs['X'][1].to(device)]
            y = inputs['Y'].to(device)
            outputs = model(x)
            outputs_mean = outputs[:, :11]
            outputs_sigma = torch.exp(outputs[:, 11:])
            optimizer.zero_grad()

            losses = [mdn_cost(outputs_mean[:, ind], outputs_sigma[:, ind], y[:, ind])
                      for ind in range(11)]
            loss = torch.stack(losses).mean()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_it += 1
            if train_it % 10 == 0:
                pbar_outer.update(10)
        return train_loss / train_it


def eval_step(model, dataloader, max_steps = None):
    if max_steps is None:
        max_steps = float('inf')
    total = min(max_steps,len(dataloader))
    model.eval()
    val_loss = 0.0
    val_it = 0
    for i, inputs in enumerate(dataloader):
        if i==total:
            break
        x = [inputs['X'][0].to(device), inputs['X'][1].to(device)]
        y = inputs['Y'].to(device)

        with torch.no_grad():
            outputs = model(x)
            outputs_mean = outputs[:, :11]
            outputs_sigma = torch.exp(outputs[:, 11:])
            losses = [mdn_cost(outputs_mean[:, ind], outputs_sigma[:, ind], y[:, ind])
                      for ind in range(11)]
            loss = torch.stack(losses).msean()
            val_loss += loss.item()
            val_it += 1
    return val_loss / val_it

In [None]:
def save_model(model, optimizer, epoch, loss, path='../'):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss},
        path + f'ep{epoch}.pt')

### Train model

In [19]:
%%time
max_train_steps = None
max_val_steps = None
best_valid_loss = float('inf')
history = []
loss_history = []
log_template = "\nEpoch {ep:03d} train_loss: {t_loss:0.4f} val_loss {v_loss:0.4f}"
n_epochs = 3
path_to_save = save_path
log_dir=path_to_save
with tqdm(desc="epoch", total = n_epochs, position=0, leave=True) as pbar_outer:
    for epoch in range(n_epochs):
        train_loss = fit_step(train_loader, max_train_steps)
        val_loss = eval_step(val_loader, max_val_steps)
        history.append((train_loss, val_loss))
        if val_loss < best_valid_loss:
            best_valid_loss = val_loss
            save_model(path_to_save, epoch, val_loss)
        pbar_outer.update(1)
        tqdm.write(log_template.format(ep=epoch + 1, t_loss=train_loss,
                                               v_loss=val_loss))

batch: 100%|█████████▉| 31240/31243 [3:51:52<00:01,  2.25it/s]  
batch:   0%|          | 0/31243 [00:00<?, ?it/s]461.27s/it]


Epoch 001 train_loss: -0.1612 val_loss -1.1058


batch: 100%|█████████▉| 31240/31243 [3:29:15<00:01,  2.49it/s]  
batch:   0%|          | 0/31243 [00:00<?, ?it/s]663.19s/it]


Epoch 002 train_loss: -2.3502 val_loss -2.5064


batch: 100%|█████████▉| 31240/31243 [3:27:31<00:01,  2.51it/s]  
epoch: 100%|██████████| 3/3 [11:15:37<00:00, 13512.61s/it] 


Epoch 003 train_loss: -2.5075 val_loss -2.5039
CPU times: user 2d 6min 8s, sys: 11h 44min 16s, total: 2d 11h 50min 24s
Wall time: 11h 15min 37s



