In [1]:
%matplotlib notebook
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
sns.set()
from tqdm import tqdm
from torch.utils.data.sampler import SubsetRandomSampler
from data_generator_oct import OCTDataset
from models import BreastPathQModel
from uce import uceloss
from calibration_plots import plot_uncert, plot_frequency
from utils import nll_criterion_laplacian

matplotlib.rcParams['font.size'] = 8

In [2]:
base_model = 'densenet121'

In [3]:
assert base_model in ['resnet50', 'resnet101', 'densenet121', 'densenet201', 'efficientnetb0', 'efficientnetb4']
device = torch.device("cuda:1")
resize_to = (256, 256)

In [4]:
batch_size = 16
resize_to = (256, 256)

data_dir = '/media/fastdata/laves/oct_data_needle/data'
data_set = OCTDataset(data_dir=data_dir, augment=False, resize_to=resize_to)
assert len(data_set) > 0

calib_indices = torch.load('./oct_valid_indices.pth')
test_indices = torch.load('./oct_test_indices.pth')

#calib_indices = test_indices[:len(test_indices)//2]
#test_indices = test_indices[len(test_indices)//2:]

print(calib_indices.shape)
print(test_indices.shape)

calib_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size,
                                           sampler=SubsetRandomSampler(calib_indices))
test_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size,
                                          sampler=SubsetRandomSampler(test_indices))

torch.Size([850])
torch.Size([850])


In [5]:
from glob import glob
model = BreastPathQModel(base_model, out_channels=6).to(device)

checkpoint_path = glob(f"/media/fastdata/laves/regression_snapshots/{base_model}_laplacian_oct_399.pth.tar")[0]

checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
print("Loading previous weights at epoch " + str(checkpoint['epoch']) + " from\n" + checkpoint_path)

Loading previous weights at epoch 399 from
/media/fastdata/laves/regression_snapshots/densenet121_laplacian_oct_399.pth.tar


In [6]:
model.eval()
mus_calib = []
vars_calib = []
logvars_calib = []
targets_calib = []

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tqdm(calib_loader)):
        data, target = data.to(device), target.to(device)

        mu, logvar, var_bayesian = model(data, dropout=True, mc_dropout=True)

        mus_calib.append(mu.detach())
        vars_calib.append(var_bayesian.detach())
        logvars_calib.append(logvar.detach())
        targets_calib.append(target.detach())

mu_calib = torch.cat(mus_calib, dim=0).clamp(0, 1)
var_calib = torch.cat(vars_calib, dim=0)
logvar_calib = torch.cat(logvars_calib, dim=0)
target_calib = torch.cat(targets_calib, dim=0)

print("l1 =", torch.nn.functional.l1_loss(mu_calib, target_calib).item())

100%|██████████| 54/54 [01:15<00:00,  1.13it/s]

l1 = 0.03203185275197029





In [7]:
err_calib = torch.pow(target_calib-mu_calib, 2).sqrt()

uncertainty = 'total'

uncert_calib_aleatoric = logvar_calib.exp()
uncert_calib_epistemic = var_calib.sqrt()

if uncertainty == 'aleatoric':
    uncert_calib = uncert_calib_aleatoric.clamp(0, 1)
elif uncertainty == 'epistemic':
    uncert_calib = uncert_calib_epistemic.clamp(0, 1)
else:
    uncert_calib = (uncert_calib_aleatoric + uncert_calib_epistemic).clamp(0, 1)  # total

In [8]:
fig, ax = plt.subplots(1)
ax.plot(uncert_calib.cpu(), err_calib.cpu()[:,0], '.')

max_val = max(err_calib.max().item(), uncert_calib.max().item())
ax.plot([0, max_val], [0, max_val], '--')
ax.set_xlabel('uncert')
ax.set_ylabel('err')
plt.show()

<IPython.core.display.Javascript object>

In [9]:
# calculate optimal T
S = err_calib.sum() / uncert_calib.sum()
print(S)

tensor(0.5856, device='cuda:1')


In [10]:
class Scaler(torch.nn.Module):
    def __init__(self, init_S=1.0):
        super().__init__()
        self.S = torch.nn.Parameter(torch.tensor([init_S]))
        
    def forward(self, x):
        return self.S.mul(x)

In [11]:
# find optimal S
scaler = Scaler(init_S=S).to(device)
s_opt = torch.optim.LBFGS([scaler.S], lr=3e-4, max_iter=100)

def closure():
    s_opt.zero_grad()
    
    loss = nll_criterion_laplacian(mu_calib, scaler(uncert_calib).log(), target_calib)

    loss.backward()
    return loss

s_opt.step(closure)
print(scaler.S.item())

0.585617184638977


In [12]:
print(nll_criterion_laplacian(mu_calib, uncert_calib.log(), target_calib).item())
print(nll_criterion_laplacian(mu_calib, (S*uncert_calib).log(), target_calib).item())
print(nll_criterion_laplacian(mu_calib, scaler(uncert_calib).log(), target_calib).item())

-2.323625326156616
-2.4293508529663086
-2.4293532371520996


In [13]:
print(torch.nn.functional.l1_loss(uncert_calib, err_calib, reduction='sum').item())
print(torch.nn.functional.l1_loss((S*uncert_calib), err_calib, reduction='sum').item())
print(torch.nn.functional.l1_loss(scaler(uncert_calib), err_calib, reduction='sum').item())

160.08973693847656
107.79620361328125
107.79857635498047


In [14]:
outlier = 0.025
n_bins = 51

In [15]:
uce, err_in_bin, avg_sigma_in_bin, prop_in_bin = uceloss(err_calib, uncert_calib)
plot_uncert(err_in_bin.cpu(), avg_sigma_in_bin.cpu())
print(uce.item())
plt.show()

<IPython.core.display.Javascript object>

0.02266969531774521


In [16]:
uce, err_in_bin, avg_sigma_in_bin, num_in_bin = uceloss(err_calib, (S*uncert_calib))
plot_uncert(err_in_bin.cpu(), avg_sigma_in_bin.cpu())
print(uce.item())
plt.show()

<IPython.core.display.Javascript object>

0.0019648924935609102


In [17]:
uce, err_in_bin, avg_uncert_in_bin, in_bin = uceloss(err_calib, scaler(uncert_calib))
plot_uncert(err_in_bin.cpu(), avg_uncert_in_bin.cpu())
plt.show()
print(uce.item())
fig, ax = plot_frequency(scaler(uncert_calib).cpu(), in_bin.cpu())
fig.show()

<IPython.core.display.Javascript object>

0.0019635853823274374


<IPython.core.display.Javascript object>

In [18]:
mus_test = []
vars_test = []
logvars_test = []
targets_test = []

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tqdm(test_loader)):
        data, target = data.to(device), target.to(device)

        mu, logvar, var_bayesian = model(data, dropout=True, mc_dropout=True)

        mus_test.append(mu.detach())
        vars_test.append(var_bayesian.detach())
        logvars_test.append(logvar.detach())
        targets_test.append(target.detach())

mu_test = torch.cat(mus_test, dim=0).clamp(0, 1)
var_test = torch.cat(vars_test, dim=0)
logvar_test = torch.cat(logvars_test, dim=0)
target_test = torch.cat(targets_test, dim=0)

print("l1 =", torch.nn.functional.l1_loss(mu_test, target_test).item())

100%|██████████| 54/54 [00:58<00:00,  1.15it/s]

l1 = 0.03335777297616005





In [19]:
err_test = torch.pow(target_test-mu_test, 2).sqrt()

uncert_aleatoric_test = logvar_test.exp()
uncert_epistemic_test = var_test.sqrt()

if uncertainty == 'aleatoric':
    uncert_test = uncert_aleatoric_test.clamp(0, 1)
elif uncertainty == 'epistemic':
    uncert_test = uncert_epistemic_test.clamp(0, 1)
else:
    uncert_test = (uncert_aleatoric_test + uncert_epistemic_test).clamp(0, 1)  # total

In [20]:
print(nll_criterion_laplacian(mu_test, uncert_test.log(), target_test).item())
print(nll_criterion_laplacian(mu_test, S*uncert_test.log(), target_test).item())
print(nll_criterion_laplacian(mu_test, scaler(uncert_test).log(), target_test).item())

-2.299947738647461
-1.5266326665878296
-2.396505355834961


In [21]:
print(torch.nn.functional.l1_loss(uncert_test, err_test, reduction='sum').item())
print(torch.nn.functional.l1_loss((S*uncert_test), err_test, reduction='sum').item())
print(torch.nn.functional.l1_loss((scaler.S*uncert_test), err_test, reduction='sum').item())

161.80563354492188
110.91647338867188
110.91874694824219


In [22]:
outlier = 0.0
n_bins = 21

In [23]:
uce, err_in_bin, avg_sigma_in_bin, _ = uceloss(err_test, uncert_test, outlier=outlier, n_bins=n_bins)
plot_uncert(err_in_bin.cpu(), avg_sigma_in_bin.cpu())
print(uce)

<IPython.core.display.Javascript object>

tensor([0.0226], device='cuda:1')


In [24]:
uce, err_in_bin, avg_sigma_in_bin, _ = uceloss(err_test, S*uncert_test,  outlier=outlier, n_bins=n_bins)
plot_uncert(err_in_bin.cpu(), avg_sigma_in_bin.cpu())
print(uce)

<IPython.core.display.Javascript object>

tensor([0.0024], device='cuda:1')


In [25]:
uce, err_in_bin, uncert_in_bin, in_bin = uceloss(err_test, scaler(uncert_test), outlier=outlier, n_bins=n_bins)
plot_uncert(err_in_bin.cpu(), uncert_in_bin.cpu())
plt.show()
print(uce)
fig, ax = plot_frequency(scaler(uncert_test).cpu(), in_bin.cpu(), n_bins=n_bins)
fig.show()

<IPython.core.display.Javascript object>

tensor([0.0024], device='cuda:1', grad_fn=<AddBackward0>)


<IPython.core.display.Javascript object>