# Bayes by Backprop


In [None]:
%matplotlib inline
import math
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import csv

from tensorboardX import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm, trange

# TensorBoard summary writer, constructed with default arguments.
# NOTE(review): `writer` is not referenced again in the visible cells — confirm it is still needed.
writer = SummaryWriter()
# Global seaborn theme for every figure drawn later in the notebook.
sns.set()
sns.set_style("dark")
sns.set_palette("muted")
sns.set_color_codes("muted")

In [None]:
# Use the MPS backend when PyTorch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("device", DEVICE)
# Extra keyword arguments forwarded to the DataLoaders; only non-empty when MPS is available.
# NOTE(review): pin_memory is normally a CUDA optimisation — confirm it has any effect with MPS.
LOADER_KWARGS = {'num_workers': 1, 'pin_memory': True} if torch.backends.mps.is_available() else {}
print(torch.backends.mps.is_available())

## Data Preparation

In [None]:
# Mini-batch sizes used by the training and test DataLoaders.
BATCH_SIZE = 120
TEST_BATCH_SIZE = 1000

# MNIST training split, downloaded to ./mnist if missing, reshuffled on every pass.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './mnist', train=True, download=True,
        transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, **LOADER_KWARGS)
# MNIST test split, iterated in a fixed order.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './mnist', train=False, download=True,
        transform=transforms.ToTensor()),
    batch_size=TEST_BATCH_SIZE, shuffle=False, **LOADER_KWARGS)

# Dataset sizes and number of mini-batches per pass, derived from the loaders.
TRAIN_SIZE = len(train_loader.dataset)
TEST_SIZE = len(test_loader.dataset)
NUM_BATCHES = len(train_loader)
NUM_TEST_BATCHES = len(test_loader)

CLASSES = 10
# NOTE(review): the training cell runs 300 epochs and does not read TRAIN_EPOCHS — confirm which is intended.
TRAIN_EPOCHS = 600
# Number of weight samples drawn per training step (default of sample_elbo) and per test-time ensemble.
SAMPLES = 2
TEST_SAMPLES = 10

# Guard against a ragged final batch; code that sizes buffers by BATCH_SIZE / TEST_BATCH_SIZE relies on this.
assert (TRAIN_SIZE % BATCH_SIZE) == 0
assert (TEST_SIZE % TEST_BATCH_SIZE) == 0

In [None]:
class Gaussian(object):
    """Diagonal Gaussian parameterised by a mean `mu` and an unconstrained scale `rho`.

    The standard deviation is obtained as ``sigma = softplus(rho) = log(1 + exp(rho))``,
    which keeps it positive for any real `rho`.
    """

    def __init__(self, mu, rho):
        super().__init__()
        self.mu = mu
        self.rho = rho
        # Retained so that code reading `.normal` keeps working; sampling no longer needs it.
        self.normal = torch.distributions.Normal(0,1)

    @property
    def sigma(self):
        """Return softplus(rho).

        F.softplus is used instead of log1p(exp(rho)) because the latter overflows
        to inf once exp(rho) exceeds the floating-point range.
        """
        return F.softplus(self.rho)

    def sample(self):
        """Draw one reparameterised sample ``mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        The noise is created with the shape, dtype and device of `rho`, so the
        result stays on whatever device the parameters live on.
        """
        epsilon = torch.randn_like(self.rho)
        return self.mu + self.sigma * epsilon

    def log_prob(self, input):
        """Return the log-density of `input` under this Gaussian, summed over all elements."""
        return (-math.log(math.sqrt(2 * math.pi))
                - torch.log(self.sigma)
                - ((input - self.mu) ** 2) / (2 * self.sigma ** 2)).sum()

In [None]:
class ScaleMixtureGaussian(object):
    """Zero-mean mixture of two Gaussians: ``pi * N(0, sigma1^2) + (1 - pi) * N(0, sigma2^2)``."""

    def __init__(self, pi, sigma1, sigma2):
        super().__init__()
        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.gaussian1 = torch.distributions.Normal(0,sigma1)
        self.gaussian2 = torch.distributions.Normal(0,sigma2)

    def log_prob(self, input):
        """Return the mixture log-density of `input`, summed over all elements.

        The mixture is evaluated in log space with logsumexp. Exponentiating the
        component log-densities first underflows to 0 for values far from zero,
        which turns the result into log(0) = -inf.
        """
        pi = torch.as_tensor(self.pi, dtype=input.dtype, device=input.device)
        log_component1 = torch.log(pi) + self.gaussian1.log_prob(input)
        log_component2 = torch.log1p(-pi) + self.gaussian2.log_prob(input)
        stacked = torch.stack([log_component1, log_component2])
        return torch.logsumexp(stacked, dim=0).sum()



class ScaleGaussian(object):
    """Zero-mean Gaussian prior with standard deviation `sigma1`.

    `pi` is stored but not used; it is accepted so the constructor mirrors the
    mixture prior's leading arguments.
    """

    def __init__(self, pi, sigma1):
        super().__init__()
        self.pi = pi
        self.sigma1 = sigma1
        self.gaussian1 = torch.distributions.Normal(0,sigma1)

    def log_prob(self, input):
        """Return the log-density of `input`, summed over all elements.

        The log-density is used directly: exponentiating and taking the log again
        underflows to log(0) = -inf for values far from zero.
        """
        return self.gaussian1.log_prob(input).sum()
    


$$\pi = \frac{1}{4}$$
$$-\ln{\sigma_1} = 1$$
$$-\ln{\sigma_2} = 6$$

In [None]:
# Alternative prior setting, kept for reference (pi = 0.5, sigma1 = exp(0), sigma2 = exp(-6)):
#PI = 0.5
#SIGMA_1 = torch.FloatTensor([math.exp(-0)]).to(DEVICE)
#SIGMA_2 = torch.FloatTensor([math.exp(-6)]).to(DEVICE)

# BBB prior for 400 units
# PI is the weight of the first mixture component; SIGMA_1 / SIGMA_2 are the
# standard deviations of the two components, stored as one-element tensors on DEVICE.
PI = 0.25
SIGMA_1 = torch.FloatTensor([math.exp(-1)]).to(DEVICE)
SIGMA_2 = torch.FloatTensor([math.exp(-6)]).to(DEVICE)


In [None]:
class BayesianLinear(nn.Module):
    """Linear layer whose weight matrix and bias vector are Gaussian random variables.

    Each parameter has a learnable mean (`*_mu`) and unconstrained scale (`*_rho`).
    After every forward pass that computes them, `log_prior` and
    `log_variational_posterior` hold the log-densities of the weights that were used.

    Args:
        in_features: size of each input row.
        out_features: size of each output row.
        mixture: truthy selects the two-component scale-mixture prior,
            falsy selects the single-Gaussian prior.
    """

    def __init__(self, in_features, out_features, mixture):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.mixture = mixture

        # Variational posterior over the weights. Means start uniform in
        # [-0.2, 0.2]; rhos start at N(-8, 0.05) (the "BBB" setting; an earlier
        # variant drew rho uniformly from [-5, -4]).
        weight_shape = (out_features, in_features)
        self.weight_mu = nn.Parameter(torch.Tensor(*weight_shape).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.Tensor(*weight_shape).normal_(-8, .05))
        self.weight = Gaussian(self.weight_mu, self.weight_rho)

        # Variational posterior over the biases, initialised the same way.
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).normal_(-8, .05))
        self.bias = Gaussian(self.bias_mu, self.bias_rho)

        # Priors: weights and biases each get their own instance of the chosen family.
        def build_prior():
            if self.mixture:
                return ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2)
            return ScaleGaussian(PI, SIGMA_1)

        self.weight_prior = build_prior()
        self.bias_prior = build_prior()

        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, input, sample=False, calculate_log_probs=False):
        """Apply the layer to `input`.

        Weights are sampled when the module is in training mode or `sample` is
        true; otherwise the posterior means are used. The log-prior and
        log-posterior of the weights in use are stored on the module when in
        training mode or when `calculate_log_probs` is true, and reset to 0 otherwise.
        """
        stochastic = self.training or sample
        weight = self.weight.sample() if stochastic else self.weight.mu
        bias = self.bias.sample() if stochastic else self.bias.mu

        if not (self.training or calculate_log_probs):
            self.log_prior, self.log_variational_posterior = 0, 0
        else:
            self.log_prior = self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
            self.log_variational_posterior = self.weight.log_prob(weight) + self.bias.log_prob(bias)

        return F.linear(input, weight, bias)

In [None]:
class BayesianNetwork(nn.Module):
    """Three-layer Bayesian MLP producing log-softmax class scores.

    Args:
        hidden: width of both hidden layers (default 400).
        in_features: flattened input size (default 28*28).
        out_features: number of output classes (default 10).
    """

    def __init__(self, hidden=400, in_features=28*28, out_features=10):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.l1 = BayesianLinear(in_features, hidden, True)
        self.l2 = BayesianLinear(hidden, hidden, True)
        self.l3 = BayesianLinear(hidden, out_features, True)

    def forward(self, x, sample=False):
        """Flatten `x` to rows of `in_features` values and return log-probabilities per class."""
        x = x.view(-1, self.in_features)
        x = F.relu(self.l1(x, sample))
        x = F.relu(self.l2(x, sample))
        x = F.log_softmax(self.l3(x, sample), dim=1)
        return x

    def log_prior(self):
        """Sum of the layers' log-prior values recorded by the most recent forward pass."""
        return self.l1.log_prior \
               + self.l2.log_prior \
               + self.l3.log_prior

    def log_variational_posterior(self):
        """Sum of the layers' log-posterior values recorded by the most recent forward pass."""
        return self.l1.log_variational_posterior \
               + self.l2.log_variational_posterior \
               + self.l3.log_variational_posterior

    def sample_elbo(self, input, target, samples=SAMPLES):
        """Estimate the mini-batch loss from `samples` stochastic forward passes.

        Returns:
            (loss, log_prior, log_variational_posterior, negative_log_likelihood),
            where the two log terms are averaged over the samples, the
            likelihood term is the summed NLL of the sample-averaged outputs,
            and ``loss = (log_q - log_prior) / NUM_BATCHES + nll``.
        """
        # Buffers live on the input's device. Outputs are stacked rather than
        # written into a buffer pre-sized by BATCH_SIZE, so any batch size works.
        log_priors = torch.zeros(samples, device=input.device)
        log_variational_posteriors = torch.zeros(samples, device=input.device)
        sampled_outputs = []
        for i in range(samples):
            sampled_outputs.append(self(input, sample=True))
            log_priors[i] = self.log_prior()
            log_variational_posteriors[i] = self.log_variational_posterior()
        outputs = torch.stack(sampled_outputs)
        log_prior = log_priors.mean()
        log_variational_posterior = log_variational_posteriors.mean()
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        negative_log_likelihood = F.nll_loss(outputs.mean(0), target, reduction='sum')
        loss = (log_variational_posterior - log_prior)/NUM_BATCHES + negative_log_likelihood
        return loss, log_prior, log_variational_posterior, negative_log_likelihood

net = BayesianNetwork().to(DEVICE)

## Training

In [None]:
def train(net, optimizer, epoch):
    """Run one pass over `train_loader`, taking one optimizer step per mini-batch.

    `epoch` is accepted for the caller's convenience and is not read here.
    """
    net.train()
    for images, labels in tqdm(train_loader):
        labels = labels.squeeze().long()
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        net.zero_grad()
        # Only the loss drives the update; the other ELBO terms are unpacked and ignored.
        loss, _log_prior, _log_posterior, _nll = net.sample_elbo(images, labels)
        loss.backward()
        optimizer.step()
        

## Evaluation

In [None]:
def evaluate(model, loader, samples):
    """Count how many examples in `loader` the model classifies correctly.

    Args:
        model: module called as ``model(data, sample=...)`` and returning
            per-class scores of shape (batch, classes).
        loader: iterable of ``(data, target)`` batches.
        samples: 1 evaluates a single pass with ``sample=False``; any larger
            value averages that many passes with ``sample=True``.

    Returns:
        The number of correct predictions (an int), not a fraction.

    Raises:
        ValueError: if `samples` is smaller than 1.
    """
    if samples < 1:
        raise ValueError("samples must be >= 1")

    # Run on whatever device the model's parameters are on.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    # Evaluate in eval mode so sample=False really uses the posterior means
    # (layers sample whenever the module is in training mode), then restore
    # the caller's mode.
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                if samples == 1:
                    output = model(data, sample=False)
                else:
                    draws = [model(data, sample=True) for _ in range(samples)]
                    output = torch.stack(draws).mean(0)

                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()
    finally:
        model.train(was_training)
    return correct

In [None]:
import os

#optimizer = optim.Adam(net.parameters())
optimizer = torch.optim.SGD(net.parameters(), lr=1e-4, momentum=0.95)

# Single source of truth for the run length and checkpoint cadence.
NUM_EPOCHS = 300
CHECKPOINT_EVERY = 25

# Create the output directory up front so the first torch.save cannot fail
# on a missing path after an epoch of training has already been spent.
os.makedirs('Results/BBB-mixture', exist_ok=True)

# Per-epoch results. evaluate() returns the number of correct test predictions,
# so these arrays hold counts rather than fractions.
test_accs_ens = np.zeros(NUM_EPOCHS)
test_accs_mean = np.zeros(NUM_EPOCHS)

for epoch in range(NUM_EPOCHS):
    train(net, optimizer, epoch)

    test_acc_ens = evaluate(net, test_loader, samples=10)
    test_acc_mean = evaluate(net, test_loader, samples=1)
    test_accs_ens[epoch] = test_acc_ens
    test_accs_mean[epoch] = test_acc_mean
    print('Epoch: ', epoch)
    print('Test acc ens: ', test_acc_ens)
    print('Test acc mean: ', test_acc_mean)

    if epoch % CHECKPOINT_EVERY == 0:
        path = 'Results/BBB-mixture/BBB_mnist_400_0.0001_ID0_notebook_epoch_' + str(epoch)
        torch.save(net.state_dict(), path + '.pth')

# Always keep the weights of the final epoch as well.
path = 'Results/BBB-mixture/BBB_mnist_400_0.0001_ID0_notebook_epoch_' + str(epoch)
torch.save(net.state_dict(), path + '.pth')


path = 'Results/BBB-mixture/BBB_mnist_400_0.0001_ID0_notebook2'
# The context manager closes (and therefore flushes) the file; newline='' is
# what the csv module expects from the file object it writes to.
with open(path + '.csv', 'w', newline='') as csv_file:
    wr = csv.writer(csv_file, delimiter=',', lineterminator='\n')
    wr.writerow(['epoch', 'test_acc_ens', 'test_acc_mean'])

    for i in range(NUM_EPOCHS):
        wr.writerow((i + 1, test_accs_ens[i],test_accs_mean[i]))


In [None]:
# Re-export of the first 200 epochs to the same CSV path.
# NOTE(review): this overwrites the file written by the training cell with fewer rows — confirm that is intended.
path = 'Results/BBB-mixture/BBB_mnist_400_0.0001_ID0_notebook2'
# The context manager closes (and therefore flushes) the file; newline='' is
# what the csv module expects from the file object it writes to.
with open(path + '.csv', 'w', newline='') as csv_file:
    wr = csv.writer(csv_file, delimiter=',', lineterminator='\n')
    wr.writerow(['epoch', 'test_acc_ens', 'test_acc_mean'])

    for i in range(200):
        wr.writerow((i + 1, test_accs_ens[i],test_accs_mean[i]))

# Weight Pruning

In [None]:
# NOTE(review): HIDDEN is not read in the visible cells, and the checkpoint name mentions
# 1200 while BayesianNetwork() is constructed with no arguments — confirm the layer shapes match.
HIDDEN = 1200
modelpath = "Results/all-in-bbb/BBB_mnist_1200_0.0001_ID0_notebook_epoch_299.pth"

model = BayesianNetwork()

# Load the checkpoint onto the CPU and switch the model to evaluation mode.
model.load_state_dict(torch.load(modelpath, map_location='cpu'))
model.eval()

In [None]:
def getThreshold(model,buckets):
    """Compute signal-to-noise percentiles of the model's weights and plot their distribution.

    The signal-to-noise ratio of a weight is ``|mu| / sigma`` with
    ``sigma = log(1 + exp(rho))``, pooled over the `weight_mu` / `weight_rho`
    entries of layers l1, l2 and l3.

    Args:
        model: module whose state dict contains ``l{1,2,3}.weight_mu`` and
            ``l{1,2,3}.weight_rho``.
        buckets: percentiles (0-100) at which to evaluate the ratio.

    Returns:
        Array with the signal-to-noise value at each requested percentile.

    Side effects:
        Draws a density plot and a CDF plot of the ratio in dB and saves both
        as .png and .eps under ./Results. The red marker is the 75th percentile.
    """
    state = model.state_dict()
    layers = ('l1', 'l2', 'l3')

    def _flat(suffix):
        # Concatenate one parameter kind across all layers into a flat array.
        return np.concatenate(
            [state[layer + suffix].view(-1).cpu().detach().numpy() for layer in layers]
        ).ravel()

    rhos = _flat('.weight_rho')
    mus = _flat('.weight_mu')
    sigmas = np.log(1. + np.exp(rhos))
    sign_to_noise = np.abs(mus) / sigmas
    p = np.percentile(sign_to_noise, buckets)

    # Decibels are 10 * log10(ratio). The marker is computed from the 75th
    # percentile directly, so it no longer depends on `buckets` having 75 at
    # index 4 and stays consistent with the 0.75 line in the CDF plot.
    with np.errstate(divide='ignore'):
        snr_db = 10. * np.log10(sign_to_noise)
        marker_db = 10. * np.log10(np.percentile(sign_to_noise, 75))
    # Weights with mu == 0 map to -inf dB, which np.histogram cannot bin; drop them.
    snr_db = snr_db[np.isfinite(snr_db)]

    hist, bin_edges = np.histogram(snr_db, bins='auto')
    hist = hist / snr_db.size
    X = 0.5 * (bin_edges[:-1] + bin_edges[1:])  # bin centres

    plt.plot(X,hist)
    plt.axvline(x=marker_db, color='red')
    plt.ylabel('Density')
    plt.xlabel('Signal−to−Noise Ratio (dB)')
    plt.savefig('./Results/SignalToNoiseRatioDensity_BBB_mnist_1200_0.0001_ID0_notebook_epoch_299.png')
    plt.savefig('./Results/SignalToNoiseRatioDensity_BBB_mnist_1200_0.0001_ID0_notebook_epoch_299.eps', format='eps', dpi=1000)

    plt.figure(2)
    Y = np.cumsum(hist)
    plt.plot(X, Y)
    plt.axvline(x=marker_db, color='red')
    plt.hlines(y= 0.75, xmin=np.min(snr_db),xmax=np.max(snr_db),colors='red')
    plt.ylabel('CDF')
    plt.xlabel('Signal−to−Noise Ratio (dB)')
    plt.savefig('./Results/SignalToNoiseRatioDensity_CDF_BBB_mnist_1200_0.0001_ID0_notebook_epoch_299.png')
    plt.savefig('./Results/SignalToNoiseRatioDensity_CDF_BBB_mnist_1200_0.0001_ID0_notebook_epoch_299.eps', format='eps', dpi=1000)

    return p

In [None]:
# Percentiles of the signal-to-noise distribution at which pruning thresholds are taken.
buckets = np.asarray([0,10,25,50,75,95,98])
# One threshold per percentile; the call also draws and saves the diagnostic plots.
thresholds = getThreshold(model,buckets)

In [None]:
import copy
from torch.autograd import Variable

model_name = "BBB_mnist_1200_0.0001_ID0_notebook_epoch_299"

# For every percentile, copy the model, zero each weight whose signal-to-noise
# ratio |mu| / sigma is not above the threshold, and save the pruned state dict.
for bucket, threshold in zip(buckets, thresholds):
    print(bucket,'-->',threshold)
    t = Variable(torch.Tensor([threshold]))
    model1 = copy.deepcopy(model)
    for layer_idx in (1, 2, 3):
        rho_key = 'l'+str(layer_idx)+'.weight_rho'
        mu_key = 'l'+str(layer_idx)+'.weight_mu'
        rho = model.state_dict()[rho_key]
        mu = model.state_dict()[mu_key]
        sigma = np.log(1. + np.exp(rho.cpu().numpy()))
        snr = np.abs(mu.cpu().numpy()) / sigma
        # 1.0 where the weight is kept, 0.0 where it is pruned.
        keep_mask = (torch.from_numpy(snr) > t).float() * 1
        # NOTE(review): a pruned entry gets rho = 0, i.e. sigma = log(2) under the
        # softplus mapping — confirm this is acceptable when pruned models are sampled.
        model1.state_dict()[rho_key].data.copy_(rho * keep_mask)
        model1.state_dict()[mu_key].data.copy_(mu * keep_mask)

    torch.save(model1.state_dict(), 'Models/' + model_name + '_Pruned_'+str(bucket)+'.pth')

## Evaluate pruned models

In [None]:
import os
# Evaluate every pruned checkpoint found under the pruned-models directory.
# NOTE(review): BayesianNetwork() is built with its default sizes while the
# files are named BBB_mnist_1200 — confirm the layer shapes match.
for root, dirs, files in os.walk("Models/all-in-bbb-pruned"):
    for file in files:
        if not (file.startswith('BBB_mnist_1200') and file.endswith(".pth")):
            continue
        print(file)
        pruned_model = BayesianNetwork().to(DEVICE)
        # Join with `root` so checkpoints found in sub-directories load too.
        checkpoint = torch.load(os.path.join(root, file), map_location=DEVICE)
        pruned_model.load_state_dict(checkpoint)
        pruned_model.eval()

        correct = 0
        # Slots 0..TEST_SAMPLES-1: individual sampled passes; last slot: posterior-mean pass.
        corrects = np.zeros(TEST_SAMPLES+1, dtype=int)
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(DEVICE), target.to(DEVICE)
                # Stack the passes instead of pre-sizing a buffer by
                # TEST_BATCH_SIZE, so any batch size works.
                passes = [pruned_model(data, sample=True) for _ in range(TEST_SAMPLES)]
                passes.append(pruned_model(data, sample=False))
                outputs = torch.stack(passes)
                # The ensemble averages all rows, including the posterior-mean pass.
                output = outputs.mean(0)
                preds = outputs.max(2, keepdim=True)[1]
                pred = output.max(1, keepdim=True)[1] # index of max log-probability
                corrects += preds.eq(target.view_as(pred)).sum(dim=1).squeeze().cpu().numpy()
                correct += pred.eq(target.view_as(pred)).sum().item()
        for index, num in enumerate(corrects):
            if index < TEST_SAMPLES:
                print('Component {} Accuracy: {}/{}'.format(index, num, TEST_SIZE))
            else:
                print('Posterior Mean Accuracy: {}/{}'.format(num, TEST_SIZE))
        print('Ensemble Accuracy: {}/{}'.format(correct, TEST_SIZE))

## Pruning SGD

In [None]:
import torch.nn.utils.prune as prune
from SGD import *

In [None]:
# Rebuild the MNIST test loader for the SGD pruning experiments, this time
# without the LOADER_KWARGS extras.
TEST_BATCH_SIZE = 1000


test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './mnist', train=False, download=True,
        transform=transforms.ToTensor()),
    batch_size=TEST_BATCH_SIZE, shuffle=False)


In [None]:
# Dropout MLP defined in the SGD module, restored from a checkpoint onto the CPU.
model3 = ModelMLPDropout(400, n_input=28*28, n_ouput=10)
modelpath3 = 'Results/SGD/SGD_mnist_dropout_400_0.001_0.95.pth'
model3.load_state_dict(torch.load(modelpath3, map_location='cpu'))
#model3.eval()

In [None]:

# Weight tensors that are pruned together as one global pool.
parameters_to_prune = (
    (model3.fc0, 'weight'),
    (model3.fc1, 'weight'),
    (model3.fc2, 'weight'),
)


# global_unstructured instantiates `pruning_method` itself with the keyword
# arguments, so it needs the pruning *class* (RandomUnstructured). The
# function random_unstructured(module, name, amount) fails with a TypeError here.
prune.global_unstructured(
    parameters_to_prune,
    #pruning_method=prune.L1Unstructured,
    pruning_method=prune.RandomUnstructured,
    amount=0.1,
)


# Report the fraction of exactly-zero weights per layer after pruning.
for layer_name, (module, param_name) in zip(('fc0', 'fc1', 'fc2'), parameters_to_prune):
    pruned_weight = getattr(module, param_name)
    print(
        "Sparsity in {}.weight: {:.2f}%".format(
            layer_name,
            100. * float(torch.sum(pruned_weight == 0))
            / float(pruned_weight.nelement())
        )
    )

In [None]:
def evaluate(model, loader):
    """Compute the mean cross-entropy loss and the accuracy of `model` on `loader`.

    Args:
        model: module called as ``model(data)`` and returning per-class scores.
        loader: sized iterable of ``(data, target)`` batches.

    Returns:
        ``(loss, accuracy)`` where `loss` is the per-batch mean cross-entropy
        averaged over batches and `accuracy` is the fraction of correctly
        classified examples in [0, 1].
    """
    model.eval()
    loss_sum = 0
    correct = 0
    seen = 0
    # No gradients are needed for evaluation; tensors are used directly
    # (the deprecated Variable wrapper is not required).
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            loss_sum += F.cross_entropy(output, target).item()

            predict = output.max(1)[1]
            correct += predict.eq(target).sum().item()
            seen += target.numel()
    # Accuracy is normalised by the number of examples, not the number of batches.
    return loss_sum / len(loader), correct / seen

# Score the pruned dropout MLP on the MNIST test set.
test_loss, test_acc = evaluate(model3, test_loader)

# histograms

In [None]:
def collect_weights(model, bnn=False, dvi = False, lrt=False):
    """Gather all parameter values of `model` into flat Python lists.

    Parameters are sorted by name: names containing 'mu' are means, names
    containing 'rho' are rhos, and everything else is a plain weight. Unless
    `lrt` is set, names containing 'W' are treated as standard deviations
    (if they also contain 'log') or as means.

    Args:
        model: module to read; with `lrt` the parameters come from `model.net`.
        bnn: return ``[mus, sigmas]`` with sigmas derived from the rhos.
        dvi: return ``[mus, stds]``; takes precedence over `bnn`.
        lrt: read `model.net` and skip the special handling of 'W' names.

    Returns:
        The flat list of plain weights, or one of the two-list forms above.
    """
    mus, rhos, weights, stds = [], [], [], []

    named_params = model.net.named_parameters() if lrt else model.named_parameters()

    for name, param in named_params:
        values = param.flatten().tolist()
        if 'mu' in name:
            target = mus
        elif 'rho' in name:
            target = rhos
        elif 'W' in name and not lrt:
            target = stds if 'log' in name else mus
        else:
            target = weights
        # Extending directly keeps every list flat, so no second flattening pass is needed.
        target.extend(values)

    if bnn:
        weights = [mus, [rho_to_sigma(rho) for rho in rhos]]

    if dvi:
        weights = [mus, stds]

    return weights

def rho_to_sigma(rho):
    """Map an unconstrained `rho` to a positive standard deviation via softplus.

    Computes ``log(1 + exp(rho))`` as ``logaddexp(0, rho)``, which stays finite
    for large `rho` where ``exp(rho)`` itself would overflow.
    """
    return np.logaddexp(0, rho)

def sample_bnn_weights(mu, sigma):
    """Draw from a normal distribution with mean `mu` and standard deviation `sigma`.

    Uses NumPy's global random state.
    """
    draw = np.random.normal(loc=mu, scale=sigma)
    return draw

In [None]:
from SGD import *


# Models whose weight distributions are compared in the histogram below.
models = []


# Bayes-by-Backprop network restored onto the CPU.
model0 = BayesianNetwork()
modelpath0 = 'mnist_400_sgd_constant_xavier_neg2_epoch_400.pth'
model0.load_state_dict(torch.load(modelpath0, map_location='cpu'))
model0.eval()
models.append(model0)


# Plain MLP trained with SGD.
model2 = ModelMLP(400, n_input=28*28, n_ouput=10)
modelpath2 = 'Results/SGD/SGD_mnist_mlp_400_0.001_0.95.pth'
# load_state_dict is what restores the checkpoint. state_dict(...) only
# returns the module's current values, leaving the model at its random init.
# NOTE(review): if the checkpoint was saved from the wrapper rather than from
# `.net`, this should be model2.load_state_dict(...) — confirm the key names.
model2.net.load_state_dict(torch.load(modelpath2, map_location='cpu'))
model2.eval()
models.append(model2)


In [None]:
# Posterior means and standard deviations of the Bayesian network, then one
# random draw per weight so the histogram shows sampled weights.
bnn_mus, bnn_sigmas = collect_weights(model0, bnn=True)
bnn_weights = list(map(sample_bnn_weights, bnn_mus, bnn_sigmas))

# Deterministic weights of the SGD-trained MLP, read from its `.net` sub-module.
mlp_weights = collect_weights(model2, lrt=True)


In [None]:
def plot_histogram(weights_list, labels):
    """Plot a filled kernel-density estimate for each collection of weights.

    Args:
        weights_list: sequence of weight collections, one curve per entry.
        labels: legend labels, matched to `weights_list` by position.
    """
    # Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
    # later removed the old names. Prefer whichever name this installation has,
    # and fall back to the legacy name so a missing style still raises as before.
    style_candidates = ('seaborn-v0_8-colorblind', 'seaborn-colorblind')
    plt.style.use(next(
        (style for style in style_candidates if style in plt.style.available),
        'seaborn-colorblind',
    ))
    plt.figure(figsize=(8, 6))
    colors = ['cornflowerblue',  '#ffb266', '#7f00ff' ]

    for index, (weights, label) in enumerate(zip(weights_list, labels)):
        # Cycle through the palette so more than three curves do not raise IndexError.
        sns.kdeplot(weights, label=label, fill=True, clip=[-0.7, 0.7],
                    color=colors[index % len(colors)])
    plt.xlim(-0.7, 0.7)
    plt.ylabel('Probability Density', fontsize=20)
    plt.xlabel('Weight', fontsize=20)
    plt.yticks(fontsize=20)
    plt.xticks(fontsize=20)
    plt.legend(loc=2, prop={'size': 18})
    #plt.savefig('weight_histogram.png')

In [None]:
# Compare the SGD-trained MLP weights with the sampled Bayesian weights.
plot_histogram(
        [ mlp_weights,  bnn_weights], 
        ['Vanilla SGD',  'BBB']
    )

# evolution

In [None]:
import pandas as pd
from matplotlib import pyplot as plt

columns = ["test_acc"]
columns2 = ["test_acc_ens"]
# Per-epoch results of the three training runs.
df_1 = pd.read_csv("Results/SGD/SGD_mnist_dropout_400_0.001_0.95.csv", usecols=columns)
df_2 = pd.read_csv("Results/SGD/SGD_mnist_mlp_400_0.001_0.95.csv", usecols=columns)
df_3 = pd.read_csv("Results/BBB_mnist_400_ 2 copy.csv", usecols=["test_acc_new"])

colors = ['#7f00ff','cornflowerblue',  '#ffb266' ]
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later removed the old names. Prefer whichever name this installation has,
# and fall back to the legacy name so a missing style still raises as before.
plt.style.use(next(
    (style for style in ('seaborn-v0_8-colorblind', 'seaborn-colorblind')
     if style in plt.style.available),
    'seaborn-colorblind',
))
fig = plt.figure(figsize=(8, 6))

# NOTE(review): the SGD columns are plotted as-is under a "Test error" axis while the
# BBB column is converted with 1 - x — confirm the SGD CSVs store error rather than accuracy.
plt.plot(df_2.test_acc*100, label = 'Vanilla SGD', linewidth=2, color = 'cornflowerblue')
plt.plot(df_1.test_acc*100, label = 'Dropout SGD', linewidth=2, color = '#ffb266')
plt.plot((1-df_3.test_acc_new)*100, label = 'BBB', linewidth=2, color = '#7f00ff')

plt.ylabel('Test error (%)', fontsize=20)

plt.xlabel('Epoch', fontsize=20)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
#plt.ylim((1,3))
plt.legend(loc=1, prop={'size': 18})
plt.savefig('epoch evolution.png')
#plt.show()