In [1]:
%matplotlib notebook
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
sns.set()
from tqdm import tqdm
from torch.utils.data.sampler import SubsetRandomSampler
from data_generator_boneage import BoneAgeDataset
from models import BreastPathQModel
from uce import uceloss
from calibration_plots import plot_uncert, plot_frequency
from utils import nll_criterion_gaussian
from glob import glob
from test import AuxModel, train_aux
%matplotlib notebook

In [2]:
base_model = 'densenet201'

In [3]:
assert base_model in ['resnet101', 'densenet201', 'efficientnetb4']
device = torch.device("cuda:1")
resize_to = (256, 256)

In [4]:
batch_size = 16

data_set = BoneAgeDataset(data_dir='/media/fastdata/laves/rsna-bone-age/', augment=False, resize_to=resize_to)
assert len(data_set) > 0

calib_indices = torch.load('./data_indices/boneage_valid_indices.pth')
test_indices = torch.load('./data_indices/boneage_test_indices.pth')

print(calib_indices.shape)
print(test_indices.shape)

calib_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size,
                                           sampler=SubsetRandomSampler(calib_indices))
test_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size,
                                          sampler=SubsetRandomSampler(test_indices))

torch.Size([2000])
torch.Size([4000])


In [5]:
model = BreastPathQModel(base_model, in_channels=1, out_channels=1).to(device)

checkpoint_path = glob(f"/media/fastdata/laves/regression_snapshots/{base_model}_gaussian_boneage_2.pth.tar")[0]

checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
print("Loading previous weights at epoch " + str(checkpoint['epoch']))

Loading previous weights at epoch 485


In [6]:
model.eval()
y_p_calib = []
vars_calib = []
logvars_calib = []
targets_calib = []

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tqdm(calib_loader)):
        data, target = data.to(device), target.to(device)

        y_p, logvar, var_bayesian = model(data, dropout=True, mc_dropout=True, test=True)

        y_p_calib.append(y_p.detach())
        vars_calib.append(var_bayesian.detach())
        logvars_calib.append(logvar.detach())
        targets_calib.append(target.detach())

100%|██████████| 125/125 [04:36<00:00,  2.21s/it]


In [7]:
y_p_calib = torch.cat(y_p_calib, dim=1).clamp(0, 1).permute(1,0,2)
mu_calib = y_p_calib.mean(dim=1)
var_calib = torch.cat(vars_calib, dim=0)
logvars_calib = torch.cat(logvars_calib, dim=1).permute(1,0,2)
logvar_calib = logvars_calib.mean(dim=1)
target_calib = torch.cat(targets_calib, dim=0)

In [8]:
err_calib = (target_calib-mu_calib).pow(2).mean(dim=1, keepdim=True).sqrt()
errvar_calib = (y_p_calib-target_calib.unsqueeze(1).repeat(1,25,1)).pow(2).mean(dim=(1,2)).unsqueeze(-1)

uncertainty = 'total'

uncert_calib_aleatoric = logvar_calib.exp().mean(dim=1, keepdim=True)
uncert_calib_epistemic = var_calib.mean(dim=1, keepdim=True)

if uncertainty == 'aleatoric':
    uncert_calib = uncert_calib_aleatoric.sqrt().clamp(0, 1)
elif uncertainty == 'epistemic':
    uncert_calib = uncert_calib_epistemic.sqrt().clamp(0, 1)
else:
    uncert_calib = (uncert_calib_aleatoric + uncert_calib_epistemic).sqrt().clamp(0, 1)  # total

In [9]:
print((err_calib**2).mean())
print(errvar_calib.mean())
print((uncert_calib**2).mean())
print(uncert_calib_aleatoric.sqrt().mean())
print(uncert_calib_epistemic.sqrt().mean())

err_calib = errvar_calib.sqrt()

tensor(0.0035, device='cuda:1')
tensor(0.0039, device='cuda:1')
tensor(0.0031, device='cuda:1')
tensor(0.0510, device='cuda:1')
tensor(0.0199, device='cuda:1')


In [10]:
fig, ax = plt.subplots(1)
ax.plot(uncert_calib.cpu()**2, err_calib.cpu()[:,0]**2, '.')

max_val = max((err_calib**2).max().item(), (uncert_calib**2).max().item())
ax.plot([0, max_val], [0, max_val], '--')
ax.set_xlabel('uncert')
ax.set_ylabel('err')
plt.show()

<IPython.core.display.Javascript object>

In [11]:
# calculate optimal T
S = (err_calib**2 / uncert_calib**2).mean().sqrt()
print(S)

tensor(1.1150, device='cuda:1')


In [12]:
class Scaler(torch.nn.Module):
    def __init__(self, init_S=1.0):
        super().__init__()
        self.S = torch.nn.Parameter(torch.tensor([init_S]))

    def forward(self, x):
        return self.S.mul(x)

In [13]:
# find optimal S
scaler = Scaler(init_S=S).to(device)
s_opt = torch.optim.LBFGS([scaler.S], lr=3e-2, max_iter=2000)

def closure():
    s_opt.zero_grad()

    loss = nll_criterion_gaussian(mu_calib, scaler(uncert_calib).pow(2).log(), target_calib)

    loss.backward()
    return loss

s_opt.step(closure)
print(scaler.S.item())

1.0570523738861084


In [14]:
aux = AuxModel(1).to(device)
loss = train_aux(aux, nll_criterion_gaussian, mu_calib, uncert_calib, target_calib)
print(loss)

-4.724762916564941


In [15]:
aux.train()
print(nll_criterion_gaussian(mu_calib, uncert_calib.pow(2).log(), target_calib).item())
print(nll_criterion_gaussian(mu_calib, (S*uncert_calib).pow(2).log(), target_calib).item())
print(nll_criterion_gaussian(mu_calib, scaler(uncert_calib).pow(2).log(), target_calib).item())
print(nll_criterion_gaussian(mu_calib, aux(uncert_calib), target_calib).item())
aux.eval()

-4.722522258758545
-4.722490310668945
-4.728422164916992
-4.724762916564941


AuxModel(
  (linear1): Linear(in_features=1, out_features=16, bias=True)
  (fc): Linear(in_features=16, out_features=1, bias=True)
)

In [16]:
print(torch.nn.functional.mse_loss(uncert_calib**2, err_calib**2, reduction='sum').item())
print(torch.nn.functional.mse_loss((S*uncert_calib)**2, err_calib**2, reduction='sum').item())
print(torch.nn.functional.mse_loss(scaler(uncert_calib)**2, err_calib**2, reduction='sum').item())
print(torch.nn.functional.mse_loss(aux(uncert_calib)**2, err_calib**2, reduction='sum').item())

0.10146330296993256
0.09993088990449905
0.10039717704057693
0.10099232196807861


In [17]:
uce, err_in_bin, avg_sigma_in_bin, freq_in_bin = uceloss(err_calib**2, uncert_calib**2)
plot_uncert(err_in_bin.cpu(), avg_sigma_in_bin.cpu())
print(uce)
plt.show()

<IPython.core.display.Javascript object>

tensor([0.0009], device='cuda:1')


In [18]:
uce, err_in_bin, avg_sigma_in_bin, freq_in_bin = uceloss(err_calib**2, (S*uncert_calib)**2)
plot_uncert(err_in_bin.cpu(), avg_sigma_in_bin.cpu())
print(uce.item())
plt.show()

<IPython.core.display.Javascript object>

0.0004252669750712812


In [19]:
uce, err_in_bin, avg_uncert_in_bin, freq_in_bin = uceloss(err_calib**2, scaler(uncert_calib)**2)
plot_uncert(err_in_bin.cpu(), avg_uncert_in_bin.cpu())
plt.show()
print(uce.item())
#fig, ax = plot_frequency(scaler(uncert_calib).cpu(), freq_in_bin.cpu())
#fig.show()

<IPython.core.display.Javascript object>

0.0006010259967297316


In [20]:
uce, err_in_bin, avg_uncert_in_bin, freq_in_bin = uceloss(err_calib**2, aux(uncert_calib)**2)
plot_uncert(err_in_bin.cpu(), avg_uncert_in_bin.cpu())
plt.show()
print(uce.item())

<IPython.core.display.Javascript object>

0.0006580170011147857


In [22]:
y_p_test = []
mus_test = []
vars_test = []
logvars_test = []
targets_test = []

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tqdm(test_loader)):
        data, target = data.to(device), target.to(device)

        y_p, logvar, var_bayesian = model(data, dropout=True, mc_dropout=True, test=True)

        y_p_test.append(y_p.detach())
        vars_test.append(var_bayesian.detach())
        logvars_test.append(logvar.detach())
        targets_test.append(target.detach())

    y_p_test = torch.cat(y_p_test, dim=1).clamp(0, 1).permute(1,0,2)
    mu_test = y_p_test.mean(dim=1)
    var_test = torch.cat(vars_test, dim=0)
    logvars_test = torch.cat(logvars_test, dim=1).permute(1,0,2)
    logvar_test = logvars_test.mean(dim=1)
    target_test = torch.cat(targets_test, dim=0)

    y_p_test_list.append(y_p_test)
    mu_test_list.append(mu_test)
    var_test_list.append(var_test)
    logvars_test_list.append(logvars_test)
    logvar_test_list.append(logvar_test)
    target_test_list.append(target_test)


  0%|          | 0/250 [00:00<?, ?it/s][A
  0%|          | 1/250 [00:02<09:02,  2.18s/it][A
  1%|          | 2/250 [00:04<09:00,  2.18s/it][A
  1%|          | 3/250 [00:06<08:56,  2.17s/it][A
  2%|▏         | 4/250 [00:08<08:50,  2.16s/it][A
  2%|▏         | 5/250 [00:10<08:47,  2.15s/it][A
  2%|▏         | 6/250 [00:12<08:41,  2.14s/it][A
  3%|▎         | 7/250 [00:14<08:35,  2.12s/it][A
  3%|▎         | 8/250 [00:17<08:32,  2.12s/it][A
  4%|▎         | 9/250 [00:19<08:37,  2.15s/it][A
  4%|▍         | 10/250 [00:21<08:32,  2.14s/it][A
  4%|▍         | 11/250 [00:23<08:30,  2.14s/it][A
  5%|▍         | 12/250 [00:25<08:25,  2.12s/it][A
  5%|▌         | 13/250 [00:27<08:19,  2.11s/it][A
  6%|▌         | 14/250 [00:29<08:17,  2.11s/it][A
  6%|▌         | 15/250 [00:31<08:15,  2.11s/it][A
  6%|▋         | 16/250 [00:33<08:09,  2.09s/it][A
  7%|▋         | 17/250 [00:36<08:05,  2.08s/it][A
  7%|▋         | 18/250 [00:38<08:07,  2.10s/it][A
  8%|▊         | 19/250 [00:4

 62%|██████▏   | 156/250 [05:36<03:20,  2.14s/it][A
 63%|██████▎   | 157/250 [05:39<03:23,  2.18s/it][A
 63%|██████▎   | 158/250 [05:41<03:19,  2.17s/it][A
 64%|██████▎   | 159/250 [05:43<03:14,  2.13s/it][A
 64%|██████▍   | 160/250 [05:45<03:13,  2.15s/it][A
 64%|██████▍   | 161/250 [05:47<03:11,  2.16s/it][A
 65%|██████▍   | 162/250 [05:49<03:11,  2.18s/it][A
 65%|██████▌   | 163/250 [05:52<03:09,  2.18s/it][A
 66%|██████▌   | 164/250 [05:54<03:08,  2.19s/it][A
 66%|██████▌   | 165/250 [05:56<03:07,  2.21s/it][A
 66%|██████▋   | 166/250 [05:58<03:02,  2.17s/it][A
 67%|██████▋   | 167/250 [06:00<02:59,  2.16s/it][A
 67%|██████▋   | 168/250 [06:02<02:55,  2.14s/it][A
 68%|██████▊   | 169/250 [06:05<02:53,  2.14s/it][A
 68%|██████▊   | 170/250 [06:07<02:49,  2.12s/it][A
 68%|██████▊   | 171/250 [06:09<02:48,  2.14s/it][A
 69%|██████▉   | 172/250 [06:11<02:51,  2.20s/it][A
 69%|██████▉   | 173/250 [06:13<02:52,  2.23s/it][A
 70%|██████▉   | 174/250 [06:16<02:47,  2.21s/

NameError: name 'y_p_test_list' is not defined

In [None]:
err_test = (target_test-mu_test).pow(2).mean(dim=1, keepdim=True).sqrt()
errvar_test = (y_p_test-target_test.unsqueeze(1).repeat(1,25,1)).pow(2).mean(dim=(1,2)).unsqueeze(-1)

uncert_aleatoric_test = logvar_test.exp().mean(dim=1, keepdim=True)
uncert_epistemic_test = var_test.mean(dim=1, keepdim=True)

if uncertainty == 'aleatoric':
    uncert_test = uncert_aleatoric_test.sqrt().clamp(0, 1)
elif uncertainty == 'epistemic':
    uncert_test = uncert_epistemic_test.sqrt().clamp(0, 1)
else:
    uncert_test = (u_a_t + u_e_t).sqrt().clamp(0, 1)

In [None]:
print(err_test_s.mean())
print(uncert_test_s.mean())

In [None]:
aux.train()
print(nll_criterion_gaussian(mu_test_s, uncert_test_s.pow(2).log(), target_test_s).item())
print(nll_criterion_gaussian(mu_test_s, (S*uncert_test_s).pow(2).log(), target_test_s).item())
print(nll_criterion_gaussian(mu_test_s, scaler(uncert_test_s).pow(2).log(), target_test_s).item())
print(nll_criterion_gaussian(mu_test_s, aux(uncert_test_s), target_test_s).item())
aux.eval()

In [None]:
print(torch.nn.functional.mse_loss(uncert_test_s**2, err_test_s**2, reduction='sum').item())
print(torch.nn.functional.mse_loss((S*uncert_test_s)**2, err_test_s**2, reduction='sum').item())
print(torch.nn.functional.mse_loss(scaler(uncert_test_s)**2, err_test_s**2, reduction='sum').item())
print(torch.nn.functional.mse_loss(aux(uncert_test_s)**2, err_test_s**2, reduction='sum').item())

In [None]:
n_bins = 15
uce_uncal, _, _, _ = uceloss(err_test_s**2, uncert_test_s**2, n_bins=n_bins)
_, err_uncal, sigma_uncal, _ = uceloss(err_test_s**2, uncert_test_s**2, n_bins=n_bins, range=[0.0, 0.0075])
plot_uncert(err_uncal.cpu(), sigma_uncal.cpu())
print(uce_uncal.item())

In [None]:
uce, _, _, _ = uceloss(err_test**2, (S*uncert_test)**2, n_bins=n_bins)
_, err_in_bin, avg_sigma_in_bin, _ = uceloss(err_test**2, (S*uncert_test)**2, n_bins=n_bins, range=[0.0, 0.0075])
plot_uncert(err_in_bin.cpu(), avg_sigma_in_bin.cpu())
print(uce.item())

In [None]:
uce_cal, _, _, _ = uceloss(err_test**2, scaler(uncert_test)**2, n_bins=n_bins)
_, err_cal, sigma_cal, in_bin = uceloss(err_test**2, scaler(uncert_test)**2, n_bins=n_bins, range=[0.0, 0.0075])
plot_uncert(err_cal.cpu(), sigma_cal.cpu())
plt.show()
print(uce_cal.item())
fig, ax = plot_frequency(scaler(uncert_test).cpu(), in_bin.cpu(), n_bins=n_bins)
fig.show()

In [None]:
uce_aux, _, _, _ = uceloss(err_test**2, aux(uncert_test)**2, n_bins=n_bins)
_, err_aux, sigma_aux, freq_in_bin = uceloss(err_test**2, aux(uncert_test)**2, n_bins=n_bins, range=[0.0, 0.0075])
plot_uncert(err_aux.cpu(), sigma_aux.cpu())
plt.show()
print(uce_aux.item())

# Unreliable Predictions

The subsequent figure shows the mean MSE after rejecting all predictions, where uncert > uncert_max. The shadow width visualizes the percentage of rejected samples. On the very right side of the plot, the width represents 100 % of the test set samples.

In [None]:
import seaborn as sns

sns.set()
matplotlib.rcParams['text.usetex'] = True
matplotlib.rcParams['font.size'] = 8
matplotlib.rcParams['text.latex.preamble'] = [
    r'\usepackage{bm}']

In [None]:
e_s_list = []
u_s_list = []
t_s_list = []
num_rejected_s = []
uncert_s = scaler(uncert_test)**2

for thresh in np.linspace(uncert_s.max().item(), uncert_s.min().item(), 100)[:-1]:
    e = (err_test**2)[torch.where(uncert_s < thresh)]
    u = uncert_s[torch.where(uncert_s < thresh)]

    t_s_list.append(thresh)
    e_s_list.append(e.mean().item())
    u_s_list.append(u.mean().item())
    num_rejected_s.append((err_test.shape[0]-e.shape[0])/err_test.shape[0])

t_s_list = np.array(t_s_list)
e_s_list = np.array(e_s_list)
u_s_list = np.array(u_s_list)
num_rejected_s = np.array(num_rejected_s)

In [None]:
fig, ax = plt.subplots(1, figsize=(3.0, 2.0))

ax.plot([0.0008, e_s_list.max()], [0.0008, e_s_list.max()], 'k--')
ax.plot(t_s_list, e_s_list)
ax.fill_between(t_s_list, e_s_list-num_rejected_s/1000, e_s_list+num_rejected_s/1000, alpha=0.5)
ax.set_xlim(
    #ax[0].get_xlim()[1],
    0.0065,
    0.0005)
ax.set_xlabel(r'uncertainty threshold $ \Sigma^2_{\mathsf{max}} $')
ax.set_ylabel(r'MSE')
ax.set_title(r'$ \sigma $ scaling')

fig.tight_layout()
fig.show()
fig.savefig(f"rejection_s_{base_model}.pdf", bbox_inches='tight', pad_inches=0.01)

# OOD Detection

In [None]:
mus_ood = []
uncerts_ood = []
errs_ood = []
repetitions = 1

with torch.no_grad():
    for s in [0.0, 0.1, 0.2]:
        mu_ood = []
        uncert_ood = []
        err_ood = []
        for batch_idx, (data, target) in enumerate(tqdm(test_loader)):
                data, target = data.to(device), target.to(device)
                img = data + torch.randn_like(data)*s + s
                y_p, logvar, var_bayesian = model(img, dropout=True, mc_dropout=True)
                
                

                mu_ood.append(mu.cpu().squeeze().numpy())
                uncert = (uncert_aleatoric_test + uncert_epistemic_test).clamp(0, 1)
                uncert_ood.append(uncert.cpu().squeeze().numpy())
                err_ood.append(err.cpu().squeeze().numpy())

        mus_ood.append(mu_ood)
        uncerts_ood.append(uncert_ood)
        errs_ood.append(err_ood)
        
y_p_test_list = []
mu_test_list = []
var_test_list = []
logvars_test_list = []
logvar_test_list = []
target_test_list = []

for s in [0.0, 0.1, 0.2]:
    y_p_test = []
    mus_test = []
    vars_test = []
    logvars_test = []
    targets_test = []

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(tqdm(test_loader)):
            data, target = data.to(device), target.to(device)

            y_p, logvar, var_bayesian = model(data, dropout=True, mc_dropout=True, test=True)

            y_p_test.append(y_p.detach())
            vars_test.append(var_bayesian.detach())
            logvars_test.append(logvar.detach())
            targets_test.append(target.detach())

        y_p_test = torch.cat(y_p_test, dim=1).clamp(0, 1).permute(1,0,2)
        mu_test = y_p_test.mean(dim=1)
        var_test = torch.cat(vars_test, dim=0)
        logvars_test = torch.cat(logvars_test, dim=1).permute(1,0,2)
        logvar_test = logvars_test.mean(dim=1)
        target_test = torch.cat(targets_test, dim=0)

        y_p_test_list.append(y_p_test)
        mu_test_list.append(mu_test)
        var_test_list.append(var_test)
        logvars_test_list.append(logvars_test)
        logvar_test_list.append(logvar_test)
        target_test_list.append(target_test)

In [None]:
err_test = [(target_test-mu_test).pow(2).mean(dim=1, keepdim=True).sqrt() for target_test, mu_test in zip(target_test_list, mu_test_list)]
errvar_test = [(y_p_test-target_test.unsqueeze(1).repeat(1,25,1)).pow(2).mean(dim=(1,2)).unsqueeze(-1) for target_test, y_p_test in zip(target_test_list, y_p_test_list)]

uncert_aleatoric_test = [logvar_test.exp().mean(dim=1, keepdim=True) for logvar_test in logvar_test_list]
uncert_epistemic_test = [var_test.mean(dim=1, keepdim=True) for var_test in var_test_list]

if uncertainty == 'aleatoric':
    uncert_test = uncert_aleatoric_test.sqrt().clamp(0, 1)
elif uncertainty == 'epistemic':
    uncert_test = uncert_epistemic_test.sqrt().clamp(0, 1)
else:
    uncert_test = [(u_a_t + u_e_t).sqrt().clamp(0, 1) for u_a_t, u_e_t in zip(uncert_aleatoric_test, uncert_epistemic_test)]
    
uncerts_ood.append(uncert_test.detach().cpu().numpy())
errs_ood.append(errvar_test.detach().cpu().numpy())

In [None]:
uncerts_ood_np = np.array(uncerts_ood).reshape(3,-1)
errs_ood_np = np.array(errs_ood).reshape(3,-1)

In [None]:
fig, ax = plt.subplots(figsize=(3.3, 2.0))

labels = [r'$c=0.0$', r'$c=0.1$', r'$c=0.2$']

for i in range(3):
    data = uncerts_ood_np[i][np.where(uncerts_ood_np[i] < 0.01)]
    sns.distplot(data, hist_kws={'stacked': True}, kde=True, hist=False, norm_hist=True, label=labels[i], ax=ax)

ax.legend(prop={'size': 10})
#ax.set_ylim([-30, 430])
#ax.set_xticks([0, 0.005, 0.01])
ax.set_xlabel(r'uncertainty')
ax.set_ylabel(f'frequency')
ax.set_title(f'$ \sigma $ scaling')
fig.tight_layout()
fig.show()
fig.savefig(f"ood_boneage_s_{base_model}.pdf", bbox_inches='tight', pad_inches=0.01)