# PN networks for comparison with LB

In [3]:
# import Python packages
import numpy as np
import torch
import torchvision
from torch import nn, optim, autograd
from torch.nn import functional as F
import numpy as np
from sklearn.metrics import roc_auc_score
import scipy
from utils.LB_utils import * 
from utils.load_not_MNIST import notMNIST
import os
import time
import matplotlib.pyplot as plt
from laplace import Laplace
import utils.scoring as scoring
from matplotlib import pyplot as plt
%matplotlib inline

# import libraries that we implemented
#from utils import data, measures, models, plot, run

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cuda_status = torch.cuda.is_available()
print("device: ", device)
print("cuda status: ", cuda_status)

s = 1
np.random.seed(s)
torch.manual_seed(s)
torch.cuda.manual_seed(s)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device:  cuda
cuda status:  True


In [5]:
BATCH_SIZE_TRAIN_MNIST = 128
BATCH_SIZE_TEST_MNIST = 128
MAX_ITER_MNIST = 6
LR_TRAIN_MNIST = 10e-6

In [6]:
MNIST_transform = torchvision.transforms.ToTensor()

MNIST_train = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=True,
        download=True,
        transform=MNIST_transform)

MNIST_train_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_train,
    batch_size=BATCH_SIZE_TRAIN_MNIST,
    shuffle=True
)


MNIST_test = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform)

MNIST_test_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_test,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False,
)

In [7]:
### define network
class ConvNet(nn.Module):
    
    def __init__(self, num_classes=10, alpha_0=1.):
        super(ConvNet, self).__init__()
        
        self.alpha_0 = alpha_0
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Conv2d(16, 32, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Flatten(),
            torch.nn.Linear(4 * 4 * 32, num_classes)
        )
    def forward(self, x):
        logits = self.net(x)
        
        assert_no_nan_no_inf(logits)

        concentrations = torch.exp(logits) + 1
        assert_no_nan_no_inf(concentrations)

        mean = concentrations / concentrations.sum(dim=1).unsqueeze(dim=1)
        assert_no_nan_no_inf(mean)

        precision = torch.sum(concentrations)
        assert_no_nan_no_inf(precision)

        y_pred = F.softmax(concentrations / self.alpha_0, dim=1)
        assert_no_nan_no_inf(y_pred)

        model_outputs = {
            'logits': logits,
            'mean': mean,
            'concentrations': concentrations,
            'precision': precision,
            'y_pred': y_pred
        }
        return model_outputs

In [8]:
import numpy as np
import torch
from torch.distributions import Categorical, Dirichlet
from torch.distributions.kl import _kl_dirichlet_dirichlet
from torch.nn import NLLLoss
from scipy import special


def assert_no_nan_no_inf(x):
    assert not torch.isnan(x).any()
    assert not torch.isinf(x).any()


def kl_divergence(model_concentrations,
                  target_concentrations,
                  mode='reverse'):
    """
    Input: Model concentrations, target concentrations parameters.
    Output: Average of the KL between the two Dirichlet.
    """
    assert torch.all(model_concentrations > 0)
    assert torch.all(target_concentrations > 0)

    target_dirichlet = Dirichlet(target_concentrations)
    model_dirichlet = Dirichlet(model_concentrations)
    kl_divergences = _kl_dirichlet_dirichlet(
        p=target_dirichlet if mode == 'forward' else model_dirichlet,
        q=model_dirichlet if mode == 'forward' else target_dirichlet)
    assert_no_nan_no_inf(kl_divergences)
    mean_kl = torch.mean(kl_divergences)
    assert_no_nan_no_inf(mean_kl)
    return mean_kl


def kl_loss_fn(loss_input,
               mode='reverse'):

    model_concentrations = loss_input['model_outputs']['concentrations']
    target_concentrations = loss_input['y_concentrations_batch']
    loss = kl_divergence(
        model_concentrations=model_concentrations,
        target_concentrations=target_concentrations,
        mode=mode)
    assert_no_nan_no_inf(loss)
    return loss


def neg_log_likelihood(input,
                       target):

    nll_fn = NLLLoss()
    nll = nll_fn(input=input, target=target)
    assert_no_nan_no_inf(nll)
    return nll


def nll_loss_fn(loss_inputs):
    y_pred_batch = loss_inputs['model_outputs']['y_pred']
    y_batch = loss_inputs['y_batch']
    loss = neg_log_likelihood(
        input=y_pred_batch,
        target=y_batch)
    assert_no_nan_no_inf(loss)
    return loss


def entropy_categorical(categorical_parameters):
    entropy = Categorical(categorical_parameters).entropy()
    # TODO: discuss whether we want numpy in these functions
    assert_no_nan_no_inf(entropy)
    entropy = entropy.detach().numpy()
    return entropy


def entropy_dirichlet(dirichlet_concentrations):
    entropy = Dirichlet(dirichlet_concentrations).entropy()
    # TODO: discuss whether we want numpy in these functions
    entropy = entropy.detach().numpy()
    assert_no_nan_no_inf(entropy)
    return entropy


def mutual_info_dirichlet(dirichlet_concentrations):
    # TODO: discuss whether we want numpy in these functions
    dirichlet_concentrations = dirichlet_concentrations.detach().numpy()
    dirichlet_concentrations_sum = dirichlet_concentrations.sum()
    res = (1.0/dirichlet_concentrations_sum)*dirichlet_concentrations*(np.log(dirichlet_concentrations*1.0/dirichlet_concentrations_sum)-special.digamma(dirichlet_concentrations+1)+special.digamma(dirichlet_concentrations_sum+1))
    final_res = res.sum() * (-1.0)
    #assert_no_nan_no_inf(final_res)
    return final_res


def create_loss_fn(loss_fn_str,
                   args):

    if loss_fn_str == 'nll':
        loss_fn = nll_loss_fn
    elif loss_fn_str == 'kl':
        loss_fn = kl_loss_fn
    else:
        raise NotImplementedError('Loss function {} not implemented!'.format(loss_fn_str))
    return loss_fn

In [9]:
mnist_model = ConvNet().to(device)

mnist_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=1e-3, weight_decay=5e-4)
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer_PN_s{}.pth".format(s)
loss_fn = create_loss_fn(loss_fn_str='kl', args={})

In [10]:
#Training routine
def get_accuracy(output, targets):
    """Helper function to print the accuracy"""
    predictions = output.argmax(dim=1, keepdim=True).view_as(targets)
    return predictions.eq(targets).float().mean().item()

def assert_no_nan_no_inf(x):
    assert not torch.isnan(x).any()
    assert not torch.isinf(x).any()
    
def concentrations_from_labels(y, num_classes=10):
    len_ = y.size(0)
    #baseline 1
    base = torch.ones((len_, num_classes))
    
    #add onehot vectors from labels
    onehots = torch.zeros((len_, num_classes))
    rows = np.arange(len_)
    onehots[rows, y] = 1
    return(base + onehots)
    

def train(model, train_loader, optimizer, max_iter, path, verbose=True):
    max_len = len(train_loader)
    model.train()

    for iter in range(max_iter):
        for batch_idx, (x, y) in enumerate(train_loader):
            
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            model_outputs = model(x)
            y_concentrations_batch = concentrations_from_labels(y, num_classes=10).to(device)
            #print(y)
            #print(y_concentrations_batch)
            
            loss_inputs = {
                'model_outputs': model_outputs,
                'x_batch': x,
                'y_batch': y,
                'y_concentrations_batch': y_concentrations_batch,
            }

            batch_loss = loss_fn(loss_inputs)
            assert_no_nan_no_inf(batch_loss)
            batch_loss.backward()
            optimizer.step()

            m = nn.Softmax(dim=1)
            accuracy = get_accuracy(m(model_outputs["logits"]), y)

            if verbose and batch_idx % 50 == 0:
                print(
                    "Iteration {}; {}/{} \t".format(iter, batch_idx, max_len) +
                    "Accuracy %.0f" % (accuracy * 100) + "%"
                )

    print("saving model at: {}".format(path))
    torch.save(mnist_model.state_dict(), path)

In [11]:
#train(mnist_model, MNIST_train_loader, mnist_optimizer, MAX_ITER_MNIST, MNIST_PATH, verbose=True)

In [12]:
#predict in distribution
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer_PN_s{}.pth".format(s)

mnist_model = ConvNet().to(device)
print("loading model from: {}".format(MNIST_PATH))
mnist_model.load_state_dict(torch.load(MNIST_PATH))
mnist_model.eval()

acc = []

max_len = len(MNIST_test_loader)
for batch_idx, (x, y) in enumerate(MNIST_test_loader):

    x, y = x.to(device), y.to(device)
    model_outputs = mnist_model(x)

    m = nn.Softmax(dim=1)
    accuracy = get_accuracy(m(model_outputs["logits"]), y)
    
    if batch_idx % 10 == 0:
        print(
            "Batch {}/{} \t".format(batch_idx, max_len) + 
            "Accuracy %.0f" % (accuracy * 100) + "%"
        )
    acc.append(accuracy)

avg_acc = np.mean(acc)
print('overall test accuracy on MNIST: {:.02f} %'.format(avg_acc * 100))


loading model from: pretrained_weights/MNIST_pretrained_10_classes_last_layer_PN_s1.pth
Batch 0/79 	Accuracy 99%
Batch 10/79 	Accuracy 98%
Batch 20/79 	Accuracy 98%
Batch 30/79 	Accuracy 95%
Batch 40/79 	Accuracy 100%
Batch 50/79 	Accuracy 99%
Batch 60/79 	Accuracy 100%
Batch 70/79 	Accuracy 96%
overall test accuracy on MNIST: 98.56 %


### OOD data

In [13]:
BATCH_SIZE_TEST_FMNIST = 128
BATCH_SIZE_TEST_KMNIST = 128

In [14]:
FMNIST_test = torchvision.datasets.FashionMNIST(
        '~/data/fmnist', train=False, download=True,
        transform=MNIST_transform)   #torchvision.transforms.ToTensor())

FMNIST_test_loader = torch.utils.data.DataLoader(
    FMNIST_test,
    batch_size=BATCH_SIZE_TEST_FMNIST, shuffle=False)

In [15]:
KMNIST_test = torchvision.datasets.KMNIST(
        '~/data/kmnist', train=False, download=True,
        transform=MNIST_transform)

KMNIST_test_loader = torch.utils.data.DataLoader(
    KMNIST_test,
    batch_size=BATCH_SIZE_TEST_KMNIST, shuffle=False)

In [16]:
#root = os.path.abspath('~/data')
root = os.path.expanduser('~/data')

# Instantiating the notMNIST dataset class we created
notMNIST_test = notMNIST(root=os.path.join(root, 'notMNIST_small'),
                               transform=MNIST_transform)

# Creating a dataloader
notMNIST_test_loader = torch.utils.data.dataloader.DataLoader(
                            dataset=notMNIST_test,
                            batch_size=BATCH_SIZE_TEST_KMNIST,
                            shuffle=False)

File F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png is broken
File A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png is broken


# Predictions for PN network

In [17]:
targets = MNIST_test.targets.numpy()
targets_FMNIST = FMNIST_test.targets.numpy()
targets_notMNIST = notMNIST_test.targets.numpy().astype(int)
targets_KMNIST = KMNIST_test.targets.numpy()

In [18]:
@torch.no_grad()
def predict_PN(model, test_loader, device='cuda'):
    py = []

    for batch_idx, (x, y) in enumerate(test_loader):

        x, y = x.to(device), y.to(device)

        py_ = torch.softmax(model(x)["logits"], 1)

        py.append(py_)
    return torch.cat(py, dim=0)

In [19]:
MNIST_test_in_PN = predict_PN(mnist_model, MNIST_test_loader, device=device).cpu().numpy()
MNIST_test_out_FMNIST_PN = predict_PN(mnist_model, FMNIST_test_loader, device=device).cpu().numpy()
MNIST_test_out_notMNIST_PN = predict_PN(mnist_model, notMNIST_test_loader, device=device).cpu().numpy()
MNIST_test_out_KMNIST_PN = predict_PN(mnist_model, KMNIST_test_loader, device=device).cpu().numpy()

In [20]:
# compute average log-likelihood for Diag
MNIST_LLH_in_PN = -torch.distributions.Categorical(torch.tensor(MNIST_test_in_PN)).log_prob(torch.tensor(targets)).mean().item()
MNIST_LLH_out_FMNIST_PN = -torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_PN)).log_prob(torch.tensor(targets_FMNIST)).mean().item()
MNIST_LLH_out_notMNIST_PN = -torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_PN)).log_prob(torch.tensor(targets_notMNIST)).mean().item()
MNIST_LLH_out_KMNIST_PN = -torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_PN)).log_prob(torch.tensor(targets_KMNIST)).mean().item()

print(MNIST_LLH_in_PN)
print(MNIST_LLH_out_FMNIST_PN)
print(MNIST_LLH_out_notMNIST_PN)
print(MNIST_LLH_out_KMNIST_PN)

0.24606139957904816
2.659008264541626
2.962023973464966
3.141713857650757


In [21]:
#compute the Expected confidence estimate
MNIST_ECE_in_PN = scoring.expected_calibration_error(targets, MNIST_test_in_PN)
MNIST_ECE_out_FMNIST_PN = scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_PN)
MNIST_ECE_out_notMNIST_PN = scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_PN)
MNIST_ECE_out_KMNIST_PN = scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_PN)
print(MNIST_ECE_in_PN)
print(MNIST_ECE_out_FMNIST_PN)
print(MNIST_ECE_out_notMNIST_PN)
print(MNIST_ECE_out_KMNIST_PN)

0.18369689898989894
0.21209345454545458
0.3142469341999357
0.26133215151515155


In [22]:
## Brier score
MNIST_brier_in_PN = get_brier(MNIST_test_in_PN, targets, n_classes=10)
MNIST_brier_out_FMNIST_PN = get_brier(MNIST_test_out_FMNIST_PN, targets_FMNIST, n_classes=10)
MNIST_brier_out_notMNIST_PN = get_brier(MNIST_test_out_notMNIST_PN, targets_notMNIST, n_classes=10)
MNIST_brier_out_KMNIST_PN = get_brier(MNIST_test_out_KMNIST_PN, targets_KMNIST, n_classes=10)
print(MNIST_brier_in_PN)
print(MNIST_brier_out_FMNIST_PN)
print(MNIST_brier_out_notMNIST_PN)
print(MNIST_brier_out_KMNIST_PN)

0.007529875263571739
0.09756391495466232
0.10524796694517136
0.10414869338274002


In [23]:
acc_in_PN, prob_correct_in_PN, ent_in_PN, MMC_in_PN = get_in_dist_values(MNIST_test_in_PN, targets)
acc_out_FMNIST_PN, prob_correct_out_FMNIST_PN, ent_out_FMNIST_PN, MMC_out_FMNIST_PN, auroc_out_FMNIST_PN = get_out_dist_values(MNIST_test_in_PN, MNIST_test_out_FMNIST_PN, targets_FMNIST)
acc_out_notMNIST_PN, prob_correct_out_notMNIST_PN, ent_out_notMNIST_PN, MMC_out_notMNIST_PN, auroc_out_notMNIST_PN = get_out_dist_values(MNIST_test_in_PN, MNIST_test_out_notMNIST_PN, targets_notMNIST)
acc_out_KMNIST_PN, prob_correct_out_KMNIST_PN, ent_out_KMNIST_PN, MMC_out_KMNIST_PN, auroc_out_KMNIST_PN = get_out_dist_values(MNIST_test_in_PN, MNIST_test_out_KMNIST_PN, targets_KMNIST)

In [24]:
print_in_dist_values(acc_in_PN, prob_correct_in_PN, ent_in_PN, MMC_in_PN, 'MNIST', 'PN')
print_out_dist_values(acc_out_FMNIST_PN, prob_correct_out_FMNIST_PN, ent_out_FMNIST_PN, MMC_out_FMNIST_PN, auroc_out_FMNIST_PN, 'MNIST', test='FMNIST', method='PN')
print_out_dist_values(acc_out_notMNIST_PN, prob_correct_out_notMNIST_PN, ent_out_notMNIST_PN, MMC_out_notMNIST_PN, auroc_out_notMNIST_PN, 'MNIST', test='notMNIST', method='PN')
print_out_dist_values(acc_out_KMNIST_PN, prob_correct_out_KMNIST_PN, ent_out_KMNIST_PN, MMC_out_KMNIST_PN, auroc_out_KMNIST_PN, 'MNIST', test='KMNIST', method='PN')

[In, PN, MNIST] Accuracy: 0.985; average entropy: 0.785;     MMC: 0.802; Prob @ correct: 0.100
[Out-FMNIST, PN, MNIST] Accuracy: 0.063; Average entropy: 2.014;    MMC: 0.273; AUROC: 0.995; Prob @ correct: 0.100
[Out-notMNIST, PN, MNIST] Accuracy: 0.164; Average entropy: 1.600;    MMC: 0.447; AUROC: 0.938; Prob @ correct: 0.100
[Out-KMNIST, PN, MNIST] Accuracy: 0.111; Average entropy: 1.757;    MMC: 0.372; AUROC: 0.976; Prob @ correct: 0.100


### create a table

In [29]:
import pandas as pd
pd.options.display.float_format = '{:,.3f}'.format

In [28]:
MMC_PN = [MMC_in_PN, MMC_out_FMNIST_PN, MMC_out_notMNIST_PN, MMC_out_KMNIST_PN]
AUROC_PN = [0, auroc_out_FMNIST_PN, auroc_out_notMNIST_PN, auroc_out_KMNIST_PN]
ECE_PN = [MNIST_ECE_in_PN, MNIST_ECE_out_FMNIST_PN, MNIST_ECE_out_notMNIST_PN, MNIST_ECE_out_KMNIST_PN]
LLH_PN = [MNIST_LLH_in_PN, MNIST_LLH_out_FMNIST_PN, MNIST_LLH_out_notMNIST_PN, MNIST_LLH_out_KMNIST_PN]
Brier_PN = [MNIST_brier_in_PN, MNIST_brier_out_FMNIST_PN, MNIST_brier_out_notMNIST_PN, MNIST_brier_out_KMNIST_PN]

In [32]:
df_PN = pd.DataFrame({
    "MMC":MMC_PN,
    "AUROC":AUROC_PN,
    "ECE":ECE_PN,
    "LLH":LLH_PN,
    "Brier":Brier_PN
})

In [33]:
print(df_PN.to_latex(index=False))

\begin{tabular}{rrrrr}
\toprule
  MMC &  AUROC &   ECE &   LLH &  Brier \\
\midrule
0.802 &  0.000 & 0.184 & 0.246 &  0.008 \\
0.273 &  0.995 & 0.212 & 2.659 &  0.098 \\
0.447 &  0.938 & 0.314 & 2.962 &  0.105 \\
0.372 &  0.976 & 0.261 & 3.142 &  0.104 \\
\bottomrule
\end{tabular}

