# Bayes by Backprop

Application to the MedMNIST datasets (https://medmnist.com).


In [None]:
%matplotlib inline
import math
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tensorboardX import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm, trange


# Default plot look for the notebook: start from seaborn's defaults, then use the
# "dark" axes style and the "muted" palette, and map matplotlib's single-letter
# colour codes ("b", "r", ...) onto that palette. Later cells override the style locally.
sns.set()
sns.set_style("dark")
sns.set_palette("muted")
sns.set_color_codes("muted")

In [None]:
import torch.utils.data as data
import medmnist
from medmnist import INFO, Evaluator
print(f"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}")

In [None]:
# Compute device: CUDA if available, else Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print("device", DEVICE)
# Extra DataLoader options. pin_memory is aimed at host-to-CUDA copies, so it is
# only enabled for CUDA (it is not supported on MPS).
# NOTE(review): these kwargs are not passed to the DataLoaders built in this notebook.
LOADER_KWARGS = {'num_workers': 1, 'pin_memory': True} if DEVICE.type == "cuda" else {}
print(torch.backends.mps.is_available())

## Data Preparation

In [None]:

# MedMNIST subset to use, and whether to download it if it is missing locally.
data_flag = 'dermamnist'
download = True

# NUM_EPOCHS: intended number of training epochs; lr: intended learning rate.
# NOTE(review): confirm the training cell's optimizer actually uses `lr`.
NUM_EPOCHS = 200
# The batch sizes are required (checked further down) to divide the train/test
# split sizes exactly, so every batch is full-sized.
BATCH_SIZE = 91
TEST_BATCH_SIZE = 401
lr = 0.001

# Dataset metadata from MedMNIST's INFO table.
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

# Dataset class for this subset, looked up by the class name stored in INFO.
DataClass = getattr(medmnist, info['python_class'])


# preprocessing: ToTensor scales pixels to [0, 1]; Normalize with mean .5 / std .5
# then maps them to [-1, 1] (the single-element mean/std is applied to every channel).
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

# load the data (downloaded on first use when `download` is True)
train_dataset = DataClass(split='train', transform=data_transform, download=download)
test_dataset = DataClass(split='test', transform=data_transform, download=download)


# encapsulate data into dataloader form; `data` here is torch.utils.data (imported above).
# Only the training loader shuffles.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

In [None]:
# Show the dataset summaries, train first, separated by a rule line.
print(train_dataset, "===================", test_dataset, sep="\n")

In [None]:
# Split and loader sizes. NUM_BATCHES scales the KL term of the ELBO.
TRAIN_SIZE = len(train_dataset)
TEST_SIZE = len(test_dataset)
NUM_BATCHES = len(train_loader)
NUM_TEST_BATCHES = len(test_loader)

print(TRAIN_SIZE)
print(TEST_SIZE)

# Number of output classes, taken from the dataset metadata rather than hard-coded.
CLASSES = n_classes
SAMPLES = 2        # weight samples per training step for the ELBO estimate
TEST_SAMPLES = 10  # weight samples per test batch when ensembling

# The batch sizes must divide the split sizes exactly so that every batch is
# full-sized. Raise instead of assert: asserts are stripped under `python -O`.
if TRAIN_SIZE % BATCH_SIZE != 0:
    raise ValueError(
        f"BATCH_SIZE={BATCH_SIZE} must divide the training set size {TRAIN_SIZE}"
    )
if TEST_SIZE % TEST_BATCH_SIZE != 0:
    raise ValueError(
        f"TEST_BATCH_SIZE={TEST_BATCH_SIZE} must divide the test set size {TEST_SIZE}"
    )

In [None]:
class Gaussian(object):
    """Diagonal Gaussian variational distribution parameterised by (mu, rho).

    The standard deviation is sigma = softplus(rho) = log(1 + exp(rho)),
    which stays positive while rho is unconstrained.
    """

    def __init__(self, mu, rho):
        super().__init__()
        self.mu = mu
        self.rho = rho

    @property
    def sigma(self):
        # Same value as log1p(exp(rho)), but F.softplus does not overflow to inf
        # for large rho.
        return F.softplus(self.rho)

    def sample(self):
        """Draw one reparameterised sample ``mu + sigma * eps`` with eps ~ N(0, 1).

        eps is created directly on rho's device and dtype, so no global device
        is needed, and gradients flow to mu and rho through sigma.
        """
        epsilon = torch.randn_like(self.rho)
        return self.mu + self.sigma * epsilon

    def log_prob(self, input):
        """Return the log-density of ``input`` under N(mu, sigma^2), summed over all elements."""
        sigma = self.sigma
        return (-math.log(math.sqrt(2 * math.pi))
                - torch.log(sigma)
                - ((input - self.mu) ** 2) / (2 * sigma ** 2)).sum()

In [None]:
class ScaleMixtureGaussian(object):
    """Zero-mean mixture of two Gaussians, used as the weight/bias prior.

    density(w) = pi * N(w; 0, sigma1^2) + (1 - pi) * N(w; 0, sigma2^2)
    """

    def __init__(self, pi, sigma1, sigma2):
        super().__init__()
        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.gaussian1 = torch.distributions.Normal(0, sigma1)
        self.gaussian2 = torch.distributions.Normal(0, sigma2)

    def log_prob(self, input):
        """Return the mixture log-density of ``input``, summed over all elements.

        The mixture is evaluated in log space with logsumexp. Exponentiating
        each component first underflows to 0 for values that are large
        relative to sigma, which gives log(0) = -inf and NaN gradients.
        """
        pi = torch.as_tensor(self.pi, dtype=input.dtype, device=input.device)
        component_log_probs = torch.stack([
            torch.log(pi) + self.gaussian1.log_prob(input),
            torch.log1p(-pi) + self.gaussian2.log_prob(input),
        ])
        return torch.logsumexp(component_log_probs, dim=0).sum()

In [None]:
# Scale-mixture prior hyper-parameters ("BBB prior for 400 units"):
# PI weights the wider component (std exp(-1)); 1 - PI weights the narrow one (std exp(-6)).
# An alternative setting tried earlier: PI = 0.5, SIGMA_1 = exp(-0), SIGMA_2 = exp(-6).
PI = 0.25
SIGMA_1, SIGMA_2 = (
    torch.tensor([math.exp(-k)], dtype=torch.float32, device=DEVICE) for k in (1, 6)
)



In [None]:
class BayesianLinear(nn.Module):
    """Linear layer whose weights and biases are random variables.

    Each weight and bias has a Gaussian variational posterior parameterised by
    (mu, rho) (see ``Gaussian``), and a scale-mixture prior built from the
    module-level ``PI``, ``SIGMA_1`` and ``SIGMA_2``. After each forward pass the
    layer stores ``log_prior`` and ``log_variational_posterior`` for the weights
    it used, or 0 when they were not requested.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        weight_shape = (out_features, in_features)

        # Posterior means start in U(-0.2, 0.2). rho starts around -8 (std 0.05),
        # giving a very small initial sigma. An earlier variant drew rho from U(-5, -4).
        # Parameters are created in the order weight_mu, weight_rho, bias_mu, bias_rho.
        self.weight_mu = nn.Parameter(torch.empty(weight_shape).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.empty(weight_shape).normal_(-8, .05))
        self.weight = Gaussian(self.weight_mu, self.weight_rho)

        self.bias_mu = nn.Parameter(torch.empty(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(torch.empty(out_features).normal_(-8, .05))
        self.bias = Gaussian(self.bias_mu, self.bias_rho)

        # Weights and biases get identical, separate prior objects.
        self.weight_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2)
        self.bias_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2)
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, input, sample=False, calculate_log_probs=False):
        """Apply the layer to ``input``.

        Weights are sampled when the module is training or ``sample`` is true;
        otherwise the posterior means are used. The log-densities of the weights
        used are recorded when training or when ``calculate_log_probs`` is true.
        """
        use_samples = self.training or sample
        weight = self.weight.sample() if use_samples else self.weight.mu
        bias = self.bias.sample() if use_samples else self.bias.mu

        want_log_probs = self.training or calculate_log_probs
        if not want_log_probs:
            self.log_prior, self.log_variational_posterior = 0, 0
        else:
            self.log_prior = (self.weight_prior.log_prob(weight)
                              + self.bias_prior.log_prob(bias))
            self.log_variational_posterior = (self.weight.log_prob(weight)
                                              + self.bias.log_prob(bias))

        return F.linear(input, weight, bias)

In [None]:
class BayesianNetwork(nn.Module):
    """Three-layer Bayesian MLP classifier built from BayesianLinear layers.

    The defaults reproduce the original architecture: flattened 3x28x28 inputs,
    two hidden layers of 400 units with ReLU, and 7 log-softmax outputs. Layer
    names (l1, l2, l3) are unchanged, so existing checkpoints still load.
    """

    def __init__(self, in_features=3 * 28 * 28, hidden_features=400, out_features=7):
        super().__init__()
        self.in_features = in_features
        self.l1 = BayesianLinear(in_features, hidden_features)
        self.l2 = BayesianLinear(hidden_features, hidden_features)
        self.l3 = BayesianLinear(hidden_features, out_features)

    def forward(self, x, sample=False):
        """Return per-class log-probabilities of shape (batch, out_features).

        With ``sample=True`` every layer draws fresh weights. Otherwise the
        layers use their posterior means, unless the module is in training mode.
        """
        x = x.reshape(-1, self.in_features)
        x = F.relu(self.l1(x, sample))
        x = F.relu(self.l2(x, sample))
        return F.log_softmax(self.l3(x, sample), dim=1)

    def log_prior(self):
        """Sum of the layers' prior log-densities recorded by the last forward pass."""
        return self.l1.log_prior \
               + self.l2.log_prior \
               + self.l3.log_prior

    def log_variational_posterior(self):
        """Sum of the layers' posterior log-densities recorded by the last forward pass."""
        return self.l1.log_variational_posterior \
               + self.l2.log_variational_posterior \
               + self.l3.log_variational_posterior

    def sample_elbo(self, input, target, samples=SAMPLES):
        """Monte-Carlo estimate of the negative ELBO for one mini-batch.

        Args:
            input: batch of images.
            target: class indices of shape (batch,).
            samples: number of weight draws to average.

        Returns:
            (loss, log_prior, log_variational_posterior, negative_log_likelihood).
            The complexity term (log q - log p) is divided by NUM_BATCHES, so that
            summing the loss over one epoch counts it once.
        """
        log_priors = torch.zeros(samples, device=input.device)
        log_variational_posteriors = torch.zeros(samples, device=input.device)
        outputs = []
        for i in range(samples):
            outputs.append(self(input, sample=True))
            # Read the log-densities recorded by the forward pass just made.
            log_priors[i] = self.log_prior()
            log_variational_posteriors[i] = self.log_variational_posterior()
        log_prior = log_priors.mean()
        log_variational_posterior = log_variational_posteriors.mean()
        # The batch size is taken from the data, so a short final batch also works.
        mean_log_probs = torch.stack(outputs).mean(0)
        negative_log_likelihood = F.nll_loss(mean_log_probs, target, reduction='sum')
        loss = (log_variational_posterior - log_prior) / NUM_BATCHES + negative_log_likelihood
        return loss, log_prior, log_variational_posterior, negative_log_likelihood

net = BayesianNetwork().to(DEVICE)

## Training

In [None]:
def train(net, optimizer, epoch):
    """Run one epoch of Bayes-by-Backprop training over ``train_loader``.

    Args:
        net: network providing ``sample_elbo`` (a BayesianNetwork).
        optimizer: optimizer over ``net``'s parameters.
        epoch: current epoch index, shown in the progress bar.

    Returns:
        Mean per-batch loss (negative ELBO estimate) over the epoch, as a float.
        Returns 0.0 if the loader is empty.
    """
    net.train()
    total_loss = 0.0
    n_batches = 0
    # `images` rather than `data`, so the torch.utils.data alias is not shadowed.
    for images, target in tqdm(train_loader, desc=f"epoch {epoch}"):
        # Flatten labels to (batch,) class indices. Unlike squeeze(), reshape(-1)
        # keeps a 1-element batch one-dimensional.
        target = target.reshape(-1).long()
        images, target = images.to(DEVICE), target.to(DEVICE)
        net.zero_grad()
        loss, _, _, _ = net.sample_elbo(images, target)
        loss.backward()
        optimizer.step()
        # Accumulate detached tensors to avoid a host sync on every step.
        total_loss = total_loss + loss.detach()
        n_batches += 1
    return float(total_loss) / n_batches if n_batches else 0.0

In [None]:
def evaluate(model, loader, samples):
    """Count correct top-1 predictions of ``model`` over ``loader``.

    Args:
        model: network whose forward accepts ``sample=`` and returns
            log-probabilities of shape (batch, classes).
        loader: iterable of (inputs, targets) batches.
        samples: 1 gives a single deterministic pass (``sample=False``).
            Larger values average the outputs of that many sampled passes.

    Returns:
        Number of correctly classified examples, as an int.

    The model is switched to eval mode for the duration of the call and its
    previous mode is restored afterwards. For BayesianLinear-based networks this
    matters: in training mode those layers sample weights even when
    ``sample=False``. Gradients are disabled. Inputs are moved to the device of
    the model's parameters.
    """
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for inputs, target in loader:
                inputs, target = inputs.to(device), target.to(device)
                if samples == 1:
                    output = model(inputs, sample=False)
                else:
                    passes = [model(inputs, sample=True) for _ in range(samples)]
                    output = torch.stack(passes).mean(0)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()
    finally:
        model.train(was_training)
    return correct

In [None]:
import csv
import os

# SGD with momentum. The learning rate is set here directly. (Adam with default
# settings, `optim.Adam(net.parameters())`, was an alternative tried in this notebook.)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-4, momentum=0.95)

EVAL_EVERY = 5         # evaluate on the test set every EVAL_EVERY epochs
CHECKPOINT_EVERY = 50  # save an intermediate checkpoint every CHECKPOINT_EVERY epochs
CHECKPOINT_PREFIX = 'Results/med/BBB_derma_400_0.0001_ID0_notebook_epoch_'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CHECKPOINT_PREFIX), exist_ok=True)

# Correct-prediction counts on the test set, indexed by epoch. Epochs that are not
# evaluated stay NaN, so they cannot be mistaken for a measured 0.
test_accs_ens = np.full(NUM_EPOCHS, np.nan)
test_accs_mean = np.full(NUM_EPOCHS, np.nan)

for epoch in range(NUM_EPOCHS):
    train(net, optimizer, epoch)
    if epoch % EVAL_EVERY == 0:
        # 10-sample ensemble vs. a single pass that requests the mean weights (samples=1).
        test_acc_ens = evaluate(net, test_loader, samples=10)
        test_acc_mean = evaluate(net, test_loader, samples=1)
        test_accs_ens[epoch] = test_acc_ens
        test_accs_mean[epoch] = test_acc_mean
        print('Epoch: ', epoch)
        print('Test acc ens: ', test_acc_ens)
        print('Test acc mean: ', test_acc_mean)

    if epoch % CHECKPOINT_EVERY == 0:
        path = CHECKPOINT_PREFIX + str(epoch)
        torch.save(net.state_dict(), path + '.pth')

# Final checkpoint, named with the last epoch index (NUM_EPOCHS - 1).
path = CHECKPOINT_PREFIX + str(epoch)
torch.save(net.state_dict(), path + '.pth')



In [None]:
# Write per-epoch test accuracies (correct counts out of TEST_SIZE) to CSV.
# Row i holds the value recorded after training epoch index i, labelled i + 1
# (the number of completed epochs). The `with` block guarantees the file is
# flushed and closed.
path = 'Results/med/BBB_derma_400_0.0001_ID0_notebook'
with open(path + '.csv', 'w', newline='') as csv_file:
    wr = csv.writer(csv_file, delimiter=',', lineterminator='\n')
    wr.writerow(['epoch', 'test_acc_ens', 'test_acc_mean'])
    wr.writerows(
        (i + 1, acc_ens, acc_mean)
        for i, (acc_ens, acc_mean) in enumerate(zip(test_accs_ens, test_accs_mean))
    )

## Evaluation

### Model Ensemble

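We build an ensemble of 10 components. Each component is a separate draw of the weights and biases from the variational posterior (the weights are resampled for every component), and we report the accuracy of each component individually. The posterior-mean accuracy is obtained without resampling: the network uses the mean of every weight and bias, so it behaves like an ordinary deterministic network with the learned means. The ensemble accuracy is obtained by averaging the components' output log-probabilities and taking the arg-max; the code below also includes the posterior-mean network in this average.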



In [None]:
# Reload the final checkpoint written by the training cell (epoch index 199 for a
# 200-epoch run).
modelpath = "Results/med/BBB_derma_400_0.0001_ID0_notebook_epoch_199.pth"

model = BayesianNetwork().to(DEVICE)

# Tensors are loaded onto the CPU first; load_state_dict then copies them into
# `model`'s parameters, which live on DEVICE.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust
# (newer PyTorch versions accept weights_only=True for plain state dicts).
model.load_state_dict(torch.load(modelpath, map_location='cpu'))

In [None]:
def test_ensemble(model):
    """Print per-component, posterior-mean and ensemble test accuracy of ``model``.

    For every test batch, the network is run TEST_SAMPLES times with freshly
    sampled weights, plus once with the posterior means (``sample=False``).
    The following accuracies are printed:
      * "Component i": the i-th sampled network,
      * "Posterior Mean": the mean-weight network,
      * "Ensemble": arg-max of the average log-probabilities of all
        TEST_SAMPLES + 1 passes (the mean-weight pass is included).
    """
    model.eval()
    correct = 0
    corrects = np.zeros(TEST_SAMPLES + 1, dtype=int)
    with torch.no_grad():
        for images, target in test_loader:
            images, target = images.to(DEVICE), target.to(DEVICE)
            passes = [model(images, sample=True) for _ in range(TEST_SAMPLES)]
            passes.append(model(images, sample=False))
            outputs = torch.stack(passes)               # (TEST_SAMPLES + 1, batch, classes)
            output = outputs.mean(0)
            preds = outputs.max(2, keepdim=True)[1]     # per-pass predictions
            pred = output.max(1, keepdim=True)[1]       # ensemble prediction (index of max log-probability)
            labels = target.view_as(pred)
            # view(-1) keeps one count per pass even when there is only a single pass.
            corrects += preds.eq(labels).sum(dim=1).view(-1).cpu().numpy()
            correct += pred.eq(labels).sum().item()
    for index, num in enumerate(corrects[:-1]):
        print('Component {} Accuracy: {}/{}'.format(index, num, TEST_SIZE))
    print('Posterior Mean Accuracy: {}/{}'.format(corrects[-1], TEST_SIZE))
    print('Ensemble Accuracy: {}/{}'.format(correct, TEST_SIZE))

test_ensemble(model)

#### Evaluating ECE

In [None]:

def get_one_hot(targets, nb_classes):
    """One-hot encode integer ``targets`` (a NumPy array).

    Returns a float array of shape ``targets.shape + (nb_classes,)``.
    """
    identity = np.eye(nb_classes)
    rows = identity[np.array(targets).reshape(-1)]
    return rows.reshape(*targets.shape, nb_classes)

def forward_ece(probs, labels, n_bins=10):
    """Compute an expected calibration error (ECE) over all class probabilities.

    Every entry of ``probs`` (one per example and class) is treated as a
    confidence value. An entry counts as correct when its class is both the
    example's true label and its predicted (arg-max) class. Confidences are
    grouped into ``n_bins`` equal-width, right-closed bins (lo, hi] on [0, 1];
    exact zeros and values above 1 fall into no bin. The ECE is the
    count-weighted mean of |mean confidence - accuracy| over the non-empty bins.

    Args:
        probs: array of shape (n_examples, n_classes) with class probabilities.
        labels: integer class labels, one per example (any shape, e.g. (n, 1)).
        n_bins: number of confidence bins.

    Returns:
        (ece, bin_centers, bin_acc): the ECE as a float, plus the centres and
        accuracies of the non-empty bins only. The ECE is 0.0 when no entry falls
        in any bin.
    """
    probs = np.asarray(probs)
    n_classes = probs.shape[1]
    pred_class = np.argmax(probs, axis=1)

    # One confidence per (example, class) entry, flattened row-major.
    # NOTE(review): abs() is a no-op for probabilities. It only affects negative
    # inputs (e.g. log-probabilities passed by mistake); callers should pass probabilities.
    confidences = np.abs(probs.reshape(-1))

    identity = np.eye(n_classes)
    pred_one_hot = identity[pred_class].reshape(-1)
    target_one_hot = identity[np.asarray(labels).reshape(-1)].reshape(-1)
    # 1 only at an example's true class, and only if that class was also predicted.
    correct = target_one_hot * (pred_one_hot == target_one_hot)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_centers = (edges[:-1] + edges[1:]) / 2
    bin_idxs = np.digitize(confidences, edges, right=True) - 1
    in_range = (bin_idxs >= 0) & (bin_idxs < n_bins)
    idx = bin_idxs[in_range]

    bin_counts = np.bincount(idx, minlength=n_bins)
    bin_corrects = np.bincount(idx, weights=correct[in_range], minlength=n_bins)
    bin_conf_sums = np.bincount(idx, weights=confidences[in_range], minlength=n_bins)

    # Restrict every per-bin array to the same non-empty bins, so the
    # confidence/accuracy/count entries stay aligned.
    have_data = bin_counts > 0
    counts = bin_counts[have_data]
    bin_acc = bin_corrects[have_data] / counts
    bin_confidence = bin_conf_sums[have_data] / counts

    total = bin_counts.sum()
    if total == 0:
        return 0.0, bin_centers[have_data], bin_acc
    ece = float(np.sum(np.abs(bin_confidence - bin_acc) * counts) / total)
    return ece, bin_centers[have_data], bin_acc

In [None]:
eces = []
confidence_lists = []
accuracy_lists = []

TEST_SAMPLES = 20

# Eval mode so that sample=False really uses the posterior means.
model.eval()
with torch.no_grad():
    for images, target in test_loader:
        images, target = images.to(DEVICE), target.to(DEVICE)
        # TEST_SAMPLES sampled-weight passes plus one posterior-mean pass.
        passes = [model(images, sample=True) for _ in range(TEST_SAMPLES)]
        passes.append(model(images, sample=False))
        # The network returns log-probabilities, but calibration needs probabilities,
        # so average the per-pass class probabilities (ensemble predictive distribution).
        probs = torch.stack(passes).exp().mean(0)
        ece, confidence_list, accuracy_list = forward_ece(probs.cpu().numpy(), target.cpu().numpy())
        eces.append(ece)
        confidence_lists.append(confidence_list)
        accuracy_lists.append(accuracy_list)

# NOTE(review): this assumes every batch fills the same bins; ragged lists cannot be averaged.
mean_acc = np.array(accuracy_lists).mean(0)
mean_ece = float(np.mean(eces))
print('Mean ECE over test batches:', mean_ece)

In [None]:
# Reliability diagram: mean per-bin accuracy (from the ECE cell) against bin confidence.
# Matplotlib 3.6 renamed the 'seaborn-colorblind' style to 'seaborn-v0_8-colorblind',
# and recent releases no longer accept the old name, so use whichever one exists.
_style = 'seaborn-v0_8-colorblind'
if _style not in plt.style.available:
    _style = 'seaborn-colorblind'
plt.style.use(_style)
fig = plt.figure(figsize=(8, 6))

# Centres of equal-width bins on [0, 1]; for 10 bins this is 0.05, 0.15, ..., 0.95.
bin_centers = (np.arange(len(mean_acc)) + 0.5) / len(mean_acc)
plt.plot(bin_centers, mean_acc, marker='o', linewidth=2, color="#4C0099")
plt.plot([0.05, 0.95], [0.05, 0.95], '--', linewidth=2, label='Ideal', color='red')
plt.legend(loc=2, prop={'size': 15})
plt.xlabel('Confidence', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.yticks(fontsize=10)
plt.xticks(bin_centers, fontsize=10)
plt.savefig('reliability_diagram.png', bbox_inches='tight')

### Model Uncertainty

In [None]:
# Load a checkpoint for the uncertainty visualisation below.
# NOTE(review): this file name ("1200", "ID1", "new", "epoch_299") does not match what
# the training cell in this notebook writes, and BayesianNetwork() is built with its
# default layer sizes (400 hidden units). If this checkpoint was trained with a
# different width, load_state_dict will raise a size-mismatch error. Confirm the
# intended path and architecture.
modelpath = "Results/med/BBB_derma_1200_0.0001_ID1_notebook_new_epoch_299.pth"

model = BayesianNetwork().to(DEVICE)

# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(modelpath, map_location='cpu'))

In [None]:
def show(img):
    """Display a channels-first image tensor with matplotlib, without axes."""
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last, interpolation='nearest')
    plt.axis('off')

In [None]:
# Take images 10-14 of the first test batch and display them in a single row.
dataiter = iter(test_loader)
sample = next(dataiter)
print(sample[1][10:15])  # labels of the selected images

sample = sample[0][10:15].to(DEVICE)


sns.set_style("white")
# make_grid's keyword is `nrow` (images per row), so nrow = number of images gives
# one row. The images are undone from Normalize(mean=.5, std=.5) so imshow gets [0, 1].
Grid = make_grid(sample.cpu() * 0.5 + 0.5, padding=5, nrow=sample.size(0))
show(Grid)

In [None]:

# DermaMNIST class labels:
# {'0': 'actinic keratoses and intraepithelial carcinoma', '1': 'basal cell carcinoma',
#  '2': 'benign keratosis-like lesions', '3': 'dermatofibroma', '4': 'melanoma',
#  '5': 'melanocytic nevi', '6': 'vascular lesions'}

# Draw N_DRAWS weight samples and record each sampled network's predicted class for
# the selected images. The spread of predictions shows the model's uncertainty.
N_DRAWS = 1000
model.eval()  # the network evaluated here is `model`, not the training `net`
with torch.no_grad():
    draws = [model(sample, True).max(1, keepdim=True)[1].cpu().numpy() for _ in range(N_DRAWS)]
outputs = np.concatenate(draws, axis=1)  # (n_images, N_DRAWS) predicted labels

sns.set_style("darkgrid")
n_images = outputs.shape[0]
fig, ax = plt.subplots(n_images, 1, figsize=(4, 3), squeeze=False)

for i, axis in enumerate(ax[:, 0]):
    # One histogram bin per class, centred on the integer label.
    axis.hist(outputs[i], np.arange(-0.5, CLASSES, 1), color="#4C0099")
    axis.set_xlabel("Label")
    if i == n_images // 2:
        axis.set_ylabel("Confidence")
    axis.set_yticks([])
    axis.set_xticks(range(CLASSES))
    axis.set_xticklabels([str(c) for c in range(CLASSES)])
