<a href="https://colab.research.google.com/github/finardi/tutos/blob/master/Neural_Network_Uncertaincy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

This notebook delves deeply into neural network uncertainty. It shows how networks tend to be overconfident and how this can be mitigated with well-known methods such as Temperature Scaling. For further study, I recommend [this](https://arxiv.org/pdf/1706.04599.pdf) paper, which intuitively explains the background.


In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Use the GPU provided by Google Colab when one is available; otherwise fall
# back to the CPU so the notebook still runs (slowly) without a GPU runtime.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Allow reproducibility
torch.manual_seed(0)
np.random.seed(0)

In [None]:
# Normalize the images by the imagenet mean/std since the nets are pretrained
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                       download=True, transform=transform)

# Hold out a validation split: it is used for per-epoch evaluation during
# training and later for fitting the temperature-scaling parameter.
VAL_SIZE = 5000
train_set, val_set = torch.utils.data.random_split(
    dataset, [len(dataset) - VAL_SIZE, VAL_SIZE])

train_loader = torch.utils.data.DataLoader(train_set, batch_size=128,
                                           shuffle=True, num_workers=2)

# Evaluation results do not depend on sample order, so the validation
# loader is not shuffled.
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128,
                                         shuffle=False, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128,
                                          shuffle=False, num_workers=2)


In [None]:
import torchvision.models as models
import torch.nn as nn
num_classes = 10
net = models.resnet101(pretrained=True)
# Replace the ImageNet classification head with a CIFAR-10 head. The input
# width is read from the existing layer instead of hard-coding 2048.
net.fc = nn.Linear(net.fc.in_features, num_classes)
net = net.to(device)

In [None]:
# Training loop
import torch.optim as optim
from tqdm.notebook import tqdm as tqdm

NUM_EPOCHS = 2

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)

for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times

    # ---- Training pass ----
    net.train()
    running_loss = 0.0
    print(f'Training iteration {epoch}')
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean is exact even when the
        # last batch is smaller than the others.
        running_loss += loss.item() * inputs.size(0)

    train_loss = running_loss / len(train_set)

    # ---- Validation pass ----
    net.eval()
    classified_right = 0
    val_loss = 0.0
    print('Evaluating on validation set')
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            pred_classes = outputs.argmax(dim=1)

            val_loss += criterion(outputs, labels).item() * inputs.size(0)
            classified_right += (pred_classes == labels).sum().item()

    acc = classified_right / len(val_set)
    val_loss /= len(val_set)

    print(f'Epoch {epoch}  Acc: {acc}')
    print(f'Epoch {epoch}  Train loss: {train_loss:.4f}  Val loss: {val_loss:.4f}')


In [None]:
################### Testing ######################

from tqdm.notebook import tqdm

# Use kwargs for calibration method specific parameters
def test(calibration_method=None, **kwargs):
  """Run the network over the test set and collect softmax probabilities.

  Args:
    calibration_method: optional callable ``f(logits, kwargs_dict)`` applied
      to the raw logits before the softmax (e.g. ``T_scaling``).
    **kwargs: parameters forwarded to ``calibration_method`` as one dict.

  Returns:
    ``(preds, labels_oneh)``: two flat 1-D numpy arrays of length
    ``len(test_set) * num_classes``. ``preds`` holds every class probability
    (not only the top-1 confidence) and ``labels_oneh`` holds the matching
    one-hot label entries.
  """
  prob_batches = []
  label_batches = []
  correct = 0
  net.eval()
  with torch.no_grad():
    for images, labels in tqdm(test_loader):
      images, labels = images.to(device), labels.to(device)

      logits = net(images)
      if calibration_method is not None:
        logits = calibration_method(logits, kwargs)

      # Softmax values for the network input and resulting class predictions
      probs = torch.softmax(logits, dim=1)
      predicted_cl = probs.argmax(dim=1)

      # Count correctly classified samples for accuracy (tensor-side sum
      # instead of iterating the GPU tensor element by element in Python)
      correct += (predicted_cl == labels).sum().item()

      # Convert labels to one hot encoding
      label_oneh = torch.nn.functional.one_hot(labels, num_classes=num_classes)

      prob_batches.append(probs.cpu().numpy())
      label_batches.append(label_oneh.cpu().numpy())

  preds = np.concatenate(prob_batches).flatten()
  labels_oneh = np.concatenate(label_batches).flatten()

  correct_perc = correct / len(test_set)
  print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct_perc))
  print(correct_perc)

  return preds, labels_oneh

preds, labels_oneh = test()

In [None]:
def calc_bins(preds, labels=None, num_bins=10):
  """Assign predictions to equal-width confidence bins and summarize each bin.

  Args:
    preds: 1-D numpy array of predicted probabilities in [0, 1].
    labels: 1-D numpy array of 0/1 targets aligned with ``preds``. When
      omitted, the module-level ``labels_oneh`` is used.
    num_bins: number of equal-width bins covering [0, 1].

  Returns:
    ``(bins, binned, bin_accs, bin_confs, bin_sizes)``:
    - ``bins``: the upper edge of every bin.
    - ``binned``: the bin index of each prediction.
    - ``bin_accs``, ``bin_confs``, ``bin_sizes``: the per-bin accuracy, mean
      confidence and sample count. Accuracy and confidence are 0 for empty
      bins.
  """
  if labels is None:
    labels = labels_oneh

  # Upper bin edges 1/M, 2/M, ..., 1. np.digitize maps p in [k/M, (k+1)/M)
  # to k, but maps p == 1.0 to M, which is outside the loop below. Clamp it
  # into the last bin so saturated softmax outputs are not silently dropped.
  bins = np.linspace(1 / num_bins, 1, num_bins)
  binned = np.minimum(np.digitize(preds, bins), num_bins - 1)

  # Save the accuracy, confidence and size of each bin
  bin_accs = np.zeros(num_bins)
  bin_confs = np.zeros(num_bins)
  bin_sizes = np.zeros(num_bins)

  for b in range(num_bins):
    in_bin = binned == b
    bin_sizes[b] = np.count_nonzero(in_bin)
    if bin_sizes[b] > 0:
      bin_accs[b] = labels[in_bin].sum() / bin_sizes[b]
      bin_confs[b] = preds[in_bin].sum() / bin_sizes[b]

  return bins, binned, bin_accs, bin_confs, bin_sizes



# Visualizations and metrics

The reliability diagram shown below intuitively visualizes the relation between each bin's accuracy and its mean confidence:

$acc(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbf{1}(\hat{y}_i = y_i)$

$conf(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \hat{p}_i$

For the figure I chose $M=10$, the number of separate bins. Each prediction is put into its bin based on its magnitude ($B_1 = [0.0, 0.1), B_2 = [0.1, 0.2), \dots$).

A perfectly calibrated model is defined by $P(\hat{Y} = Y \mid \hat{P} = p) = p, \forall p \in [0, 1]$. For example, given 100 predictions each with a confidence of 0.8, we expect 80 to be classified correctly. Bars below the identity line (accuracy lower than confidence) indicate overconfident behavior, while bars above it (accuracy higher than confidence) signal underconfidence.


The Expected Calibration Error (ECE) is the weighted average of the absolute differences between each bin's accuracy and confidence, where $n$ is the total number of samples.

$ECE = \sum_{m=1}^{M} \frac{|B_m|}{n} |acc(B_m) - conf(B_m)|$

The Maximum Calibration Error (MCE) targets high-risk applications, where the worst-case accuracy/confidence gap matters more than the average.

$MCE = \max_m |acc(B_m) - conf(B_m)|$

In [None]:
def get_metrics(preds):
  """Return ``(ECE, MCE)`` for ``preds`` using the bins from ``calc_bins``.

  ECE is the sample-weighted mean of |acc(B_m) - conf(B_m)| over all bins.
  MCE is the largest such gap; empty bins contribute a gap of 0. Both
  values are NaN when ``preds`` is empty, since the errors are undefined
  then.
  """
  _, _, bin_accs, bin_confs, bin_sizes = calc_bins(preds)

  gaps = np.abs(bin_accs - bin_confs)
  total = bin_sizes.sum()
  if total == 0:
    # No predictions at all: avoid a 0/0 division.
    return float('nan'), float('nan')

  ECE = float(np.sum(bin_sizes / total * gaps))
  MCE = float(gaps.max())
  return ECE, MCE

In [None]:
import matplotlib.patches as mpatches

def draw_reliability_graph(preds, filename='calibrated_network.png'):
  """Plot a reliability diagram for ``preds`` and save it to ``filename``.

  - Blue bars show the accuracy of each confidence bin.
  - Hatched red bars show the accuracy a perfectly calibrated model would
    reach at each bin's centre.
  - The legend reports ECE and MCE in percent.
  """
  ECE, MCE = get_metrics(preds)
  bins, _, bin_accs, _, _ = calc_bins(preds)

  # calc_bins returns the upper bin edges. Centre every bar on its own bin
  # so that, e.g., the bar for [0.0, 0.1) actually spans [0.0, 0.1).
  width = 1 / len(bins)
  centers = bins - width / 2

  fig, ax = plt.subplots(figsize=(8, 8))

  # x/y limits
  ax.set_xlim(0, 1)
  ax.set_ylim(0, 1)

  # x/y labels
  ax.set_xlabel('Confidence')
  ax.set_ylabel('Accuracy')

  # Create grid
  ax.set_axisbelow(True)
  ax.grid(color='gray', linestyle='dashed')

  # Error bars: expected accuracy of a perfectly calibrated model per bin
  ax.bar(centers, centers, width=width, alpha=0.3, edgecolor='black', color='r', hatch='\\')

  # Draw bars and identity line
  ax.bar(centers, bin_accs, width=width, alpha=1, edgecolor='black', color='b')
  ax.plot([0, 1], [0, 1], '--', color='gray', linewidth=2)

  # Equally spaced axes
  ax.set_aspect('equal', adjustable='box')

  # ECE and MCE legend
  ECE_patch = mpatches.Patch(color='green', label='ECE = {:.2f}%'.format(ECE*100))
  MCE_patch = mpatches.Patch(color='red', label='MCE = {:.2f}%'.format(MCE*100))
  ax.legend(handles=[ECE_patch, MCE_patch])

  fig.savefig(filename, bbox_inches='tight')


In [None]:
def T_scaling(logits, args):
  """Divide ``logits`` by the temperature stored in ``args['temperature']``.

  ``args`` is the kwargs dict forwarded by ``test()``. The temperature may be
  a Python number or a tensor broadcastable against ``logits`` (such as a
  learnable ``nn.Parameter`` of shape (1,)).

  Raises:
    ValueError: if ``args`` provides no temperature.
  """
  temperature = args.get('temperature')
  if temperature is None:
    # Fail with a clear message instead of an opaque torch.div(..., None) error.
    raise ValueError("T_scaling requires a 'temperature' argument")
  return torch.div(logits, temperature)

# Temperature Scaling

Temperature Scaling is a parametric calibration approach that is fitted on the validation set by minimizing the Negative Log-Likelihood (NLL) loss. It learns a single parameter $T$, shared by all classes, and updates the confidences to $\hat{q}_i = \max_k \sigma_{SM}(z_i/T)^{(k)}$.



More sample code can be found in [this](https://github.com/gpleiss/temperature_scaling) excellent GitHub repository by "gpleiss".

In [None]:
# Learnable scalar temperature, initialised to 1 (i.e. the uncalibrated model)
temperature = nn.Parameter(torch.ones(1, device=device))
args = {'temperature': temperature}
criterion = nn.CrossEntropyLoss()

# Removing strong_wolfe line search results in jump after 50 epochs
optimizer = optim.LBFGS([temperature], lr=0.001, max_iter=10000, line_search_fn='strong_wolfe')

logits_list = []
labels_list = []
temps = []
losses = []

# Collect the fixed validation logits once; only the temperature is optimized.
net.eval()
with torch.no_grad():
  for images, labels in tqdm(val_loader):
    logits_list.append(net(images.to(device)))
    labels_list.append(labels.to(device))

# Create tensors (already on `device`)
logits_list = torch.cat(logits_list)
labels_list = torch.cat(labels_list)

def _eval():
  # LBFGS may evaluate the closure several times per step and does not clear
  # gradients itself; without this they accumulate across evaluations.
  optimizer.zero_grad()
  loss = criterion(T_scaling(logits_list, args), labels_list)
  loss.backward()
  temps.append(temperature.item())
  # Store a plain float: keeping the tensor would retain its autograd graph
  # and matplotlib cannot plot CUDA tensors that require grad.
  losses.append(loss.item())
  return loss


optimizer.step(_eval)

print('Final T_scaling factor: {:.2f}'.format(temperature.item()))

plt.subplot(121)
plt.plot(list(range(len(temps))), temps)

plt.subplot(122)
plt.plot(list(range(len(losses))), losses)
plt.show()



In [None]:
# Collect test-set probabilities without and with temperature scaling,
# then draw one reliability diagram for each, in that order.
preds_original, _ = test()
preds_calibrated, _ = test(T_scaling, temperature=temperature)

for preds_to_plot in (preds_original, preds_calibrated):
  draw_reliability_graph(preds_to_plot)