In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import matplotlib
from ece import eceloss, uceloss
from tqdm import tqdm
from utils import accuracy, nentr
from models import BayesianNet
from matplotlib import pyplot as plt
import seaborn as sns
# Figure styling: seaborn's default theme first, then LaTeX text rendering
# at a compact base font size (these figures are saved as small PDFs below).
sns.set()
matplotlib.rcParams.update({'text.usetex': True, 'font.size': 8})

In [None]:
# Backbone to calibrate: 'resnet18' selects CIFAR-10 in the cells below, any other value selects CIFAR-100.
model = "densenet169"

In [None]:
# Data: the validation split is a subset of the training set (selected by the saved
# valid_indices) and is used only to fit the temperature T; the test split feeds the
# reliability diagrams further down.
batch_size = 128
valid_size = 5000  # NOTE(review): not used in this cell — presumably len(valid_indices); confirm

if model == 'resnet18':
    # CIFAR-10 per-channel statistics.
    mean = [0.4914, 0.48216, 0.44653]
    std = [0.2470, 0.2435, 0.26159]
    dataset_cls = datasets.CIFAR10
    indices_file = './valid_indices_cifar10.pth'
else:
    # These match the commonly used ImageNet statistics; presumably what this
    # checkpoint was trained with — confirm before changing.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    dataset_cls = datasets.CIFAR100
    indices_file = './valid_indices_cifar100.pth'

# Deterministic preprocessing shared by both splits. The temperature must be fit on
# the same input distribution it is evaluated on, so the calibration (validation)
# split intentionally does NOT use training-time augmentation (random crop / flip).
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)])

valid_set = dataset_cls('../data', train=True, download=True, transform=eval_transform)
# download=False relies on the train-split download above having populated ../data.
test_set = dataset_cls('../data', train=False, download=False, transform=eval_transform)
valid_indices = torch.load(indices_file)

valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, pin_memory=True,
                                           num_workers=4,
                                           sampler=SubsetRandomSampler(valid_indices))
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, pin_memory=True, num_workers=4)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Output width must match the dataset chosen in the data cell:
# CIFAR-10 for 'resnet18', CIFAR-100 otherwise.
num_classes = 10 if model == 'resnet18' else 100
net = BayesianNet(num_classes=num_classes, model=model).to(device)

# Restore the trained weights from the final training snapshot.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(f'../snapshots/{model}_499.pth.tar', map_location=device)
print("Loading previous weights at epoch " + str(checkpoint['epoch']))
net.load_state_dict(checkpoint['state_dict'])
print('T =', net.T.item())

# Calibration optimizes T only. Freeze every other parameter so backward() stops at T
# instead of back-propagating through the whole backbone and accumulating .grad in
# parameters that optimizer_temp never uses or clears. T's own gradient is unchanged.
for param in net.parameters():
    param.requires_grad_(False)
net.T.requires_grad_(True)

optimizer_temp = optim.Adam([net.T], lr=1e-2)
# Decays the LR by 10x when the epoch loss has not improved for 5 epochs.
lr_scheduler_temp = optim.lr_scheduler.ReduceLROnPlateau(optimizer_temp, patience=5, factor=0.1)

In [None]:
# Fit the temperature on the validation split. The network is put in eval mode once,
# and optimizer_temp only holds net.T, so T is the only quantity updated here.
train_losses = []
train_accuracies = []
epochs = 30
net.eval()

for e in range(epochs):
    print("lr =", optimizer_temp.param_groups[0]['lr'])

    epoch_train_loss = []
    epoch_train_acc = []

    for data, target in tqdm(valid_loader):
        data, target = data.to(device), target.to(device)
        optimizer_temp.zero_grad()
        output = net(data, temp_scale=True, bayesian=False)
        loss = F.cross_entropy(output, target)
        loss.backward()
        # Record metrics for the logits produced with the current (pre-step) T.
        epoch_train_loss.append(loss.item())
        epoch_train_acc.append(accuracy(output, target))
        optimizer_temp.step()

    # Unweighted mean over batches: a smaller final batch counts as much as a full one.
    epoch_train_loss = np.mean(epoch_train_loss)
    epoch_train_acc = np.mean(epoch_train_acc)
    # ReduceLROnPlateau (default mode 'min') lowers the LR when the epoch loss stalls.
    lr_scheduler_temp.step(epoch_train_loss)

    # Per-epoch history of the calibration objective and accuracy.
    train_losses.append(epoch_train_loss)
    train_accuracies.append(epoch_train_acc)

    print("Epoch {:d}: loss: {:.4f}, acc: {:.4f}".format(e, epoch_train_loss, epoch_train_acc))
    print('T =', net.T.item())

In [None]:
def test(temp_scale, bayesian):
    logits = []
    labels = []
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(tqdm(test_loader)):
            data, target = data.to(device), target.to(device)
            output = net(data, temp_scale=temp_scale, bayesian=bayesian)
            logits.append(output.detach())
            labels.append(target.detach())
    return torch.cat(logits, dim=0), torch.cat(labels, dim=0)

In [None]:
def plot_reliability(acc, conf, err, entr, ece=None, uce=None):
    """Draw side-by-side reliability diagrams for confidence and uncertainty.

    Left panel: per-bin accuracy vs. confidence (the ``acc``/``conf`` returned by
    ``eceloss``). Right panel: per-bin error vs. uncertainty (the ``err``/``entr``
    returned by ``uceloss``). The dashed diagonal marks perfect calibration.

    Args:
        acc, conf: per-bin accuracy and confidence values of equal length.
        err, entr: per-bin error and uncertainty values of equal length.
            Each may be a torch tensor (any device) or anything ``np.asarray`` accepts.
        ece: optional scalar (float or 0-d tensor). If given, it is shown in percent
            in a box in the top-left corner of the left panel.
        uce: optional scalar (float or 0-d tensor). If given, it is shown in percent
            in a box in the bottom-right corner of the right panel.

    Returns:
        (fig, ax): the Figure and the array of its two Axes, so callers can add
        further annotations before showing or saving.
    """
    def _to_numpy(values):
        # Tensors may live on the GPU and/or carry autograd history; other inputs pass through.
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def _add_score_box(axis, label, x, y, valign, halign):
        # Position is in axes-fraction coordinates so it is independent of the data range.
        axis.text(x, y, label, transform=axis.transAxes, fontsize=14,
                  verticalalignment=valign, horizontalalignment=halign,
                  bbox=dict(boxstyle='round', facecolor='white', alpha=0.75))

    ticks = np.arange(0, 1.1, step=0.2)
    fig, ax = plt.subplots(1, 2, figsize=(4.5, 2.25), sharey=True)

    ax[0].plot([0, 1], [0, 1], 'k--')
    ax[0].plot(_to_numpy(conf), _to_numpy(acc), marker='.')
    ax[0].set_xlabel(r'Confidence')
    ax[0].set_ylabel(r'Accuracy')
    ax[0].set_xticks(ticks)
    ax[0].set_yticks(ticks)  # sharey=True propagates these to the right panel

    ax[1].plot([0, 1], [0, 1], 'k--')
    ax[1].plot(_to_numpy(entr), _to_numpy(err), marker='.')
    ax[1].set_xticks(ticks)
    ax[1].set_ylabel(r'Error')
    ax[1].set_xlabel(r'Uncertainty')

    if ece is not None:
        _add_score_box(ax[0], r'ECE\,=\,{:.2f}'.format(float(ece) * 100),
                       0.075, 0.925, 'top', 'left')
    if uce is not None:
        _add_score_box(ax[1], r'UCE\,=\,{:.2f}'.format(float(uce) * 100),
                       0.925, 0.075, 'bottom', 'right')

    return fig, ax

In [None]:
# Baseline: evaluate on the test split with temperature scaling off (bayesian=False).
logits, labels = test(temp_scale=False, bayesian=False)
print("acc =", accuracy(logits, labels))

ece, acc, conf = eceloss(logits, labels)
uce, err, entr = uceloss(logits, labels)
print('ece =', ece.item())
print("uce =", uce.item())

fig1, ax = plot_reliability(acc, conf, err, entr)

# Scores in percent: ECE top-left on the confidence panel, UCE bottom-right on the
# uncertainty panel. Each entry: (axis, label, x, y, vertical align, horizontal align).
box_style = dict(boxstyle='round', facecolor='white', alpha=0.75)
annotations = [
    (ax[0], r'ECE\,=\,{:.2f}'.format(ece.item() * 100), 0.075, 0.925, 'top', 'left'),
    (ax[1], r'UCE\,=\,{:.2f}'.format(uce.item() * 100), 0.925, 0.075, 'bottom', 'right'),
]
for axis, label, x_pos, y_pos, v_align, h_align in annotations:
    axis.text(x_pos, y_pos, label, transform=axis.transAxes, fontsize=14,
              verticalalignment=v_align, horizontalalignment=h_align,
              bbox=dict(box_style))  # fresh copy per text box
fig1.tight_layout()
fig1.show()

In [None]:
# Write the uncalibrated reliability diagram to the current working directory.
uncalib_path = f'{model}_uncalib.pdf'
fig1.savefig(uncalib_path, dpi=300)

In [None]:
# Re-evaluate the test split with temperature scaling on (bayesian=False).
logits, labels = test(temp_scale=True, bayesian=False)
print("acc =", accuracy(logits, labels))

ece, acc, conf = eceloss(logits, labels)
uce, err, entr = uceloss(logits, labels)
print('ece =', ece.item())
print("uce =", uce.item())

fig2, ax = plot_reliability(acc, conf, err, entr)


def _boxed_label(axis, label, x, y, valign, halign):
    """Place ``label`` at axes-fraction (x, y) inside a rounded white box."""
    axis.text(x, y, label, transform=axis.transAxes, fontsize=14,
              verticalalignment=valign, horizontalalignment=halign,
              bbox=dict(boxstyle='round', facecolor='white', alpha=0.75))


# Scores in percent: ECE top-left on the confidence panel,
# UCE bottom-right on the uncertainty panel.
_boxed_label(ax[0], r'ECE\,=\,{:.2f}'.format(ece.item() * 100), 0.075, 0.925, 'top', 'left')
_boxed_label(ax[1], r'UCE\,=\,{:.2f}'.format(uce.item() * 100), 0.925, 0.075, 'bottom', 'right')
fig2.tight_layout()
fig2.show()

In [None]:
# Write the temperature-scaled reliability diagram to the current working directory.
calib_path = f'{model}_calib.pdf'
fig2.savefig(calib_path, dpi=300)