In [1]:
from src.priors import *
from src.base_net import *
import torch.nn as nn
import torch.nn.functional as F
tod = torch.distributions
import copy
import numpy as np

### Entire structure of BNN with Bayes Backprop

In [2]:
def isotropic_gauss_loglike(x, mu, sigma, do_sum=True):
    """
    Log-density of ``x`` under an element-wise normal N(mu, sigma^2).

    Args:
        x: sample values.
        mu: means, broadcast-compatible with ``x``.
        sigma: standard deviations (a tensor, since ``torch.log`` is applied).
        do_sum: if True, return the scalar sum over all elements; otherwise
            return the element-wise log-densities.
    """
    z = (x - mu) / sigma
    # log N(x | mu, sigma) = -0.5*log(2*pi) - log(sigma) - 0.5*z^2
    log_density = (-(0.5) * np.log(2 * np.pi)) + (-torch.log(sigma)) + (-(0.5) * (z ** 2))
    return log_density.sum() if do_sum else log_density

"Sample according to parameterized posterior distribution of weights"
"posterior is assumed to be element-wise normal distribuiton for the weights"
def sample_weights(W_mu, b_mu, W_p, b_p):
    """
    Draw one reparameterised sample of a layer's weights and (optional) biases.

    The approximate posterior is an element-wise normal distribution:
        W = W_mu + std_W * eps,  eps ~ N(0, 1) element-wise,
        std_W = 1e-6 + softplus(W_p).
    softplus keeps the std strictly positive while leaving ``W_p`` itself
    unconstrained for the optimiser; the 1e-6 floor keeps the std non-zero.
    The bias is sampled the same way from ``b_mu``/``b_p``, unless ``b_mu``
    is None (layer without bias), in which case ``b`` is None.

    Returns:
        (W, b): W has the shape of W_mu; b has the shape of b_mu or is None.
    """
    def _reparam_draw(mean, raw_std):
        # Standard-normal noise with the same dtype/device/shape as `mean`.
        noise = mean.new(mean.shape).normal_()
        std = 1e-6 + F.softplus(raw_std, beta=1, threshold=20)
        return mean + std * noise

    # Weight noise is drawn before bias noise.
    W = _reparam_draw(W_mu, W_p)
    b = None if b_mu is None else _reparam_draw(b_mu, b_p)
    return W, b


class BayesLinear_Normalq(torch.nn.Module):
    """
    Linear layer whose weights and biases follow a fully factorised normal
    approximate posterior q(w) = N(mu, (1e-6 + softplus(p))^2) with learnable
    mu and p.

    In training mode, or when ``sample=True``, each forward pass draws one
    weight/bias sample (shared by the whole minibatch). It also returns that
    sample's log-likelihood under q and under the prior, so the caller can
    estimate the KL term of the ELBO. In eval mode with ``sample=False`` the
    posterior means are used and both log-likelihoods are returned as 0.

    Args:
        n_in: number of input features.
        n_out: number of output features.
        prior_class: prior object exposing ``loglike(tensor)`` (an instance,
            despite the name).
    """

    def __init__(self, n_in, n_out, prior_class):
        super(BayesLinear_Normalq, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.prior = prior_class

        # Learnable variational parameters. The registration order
        # (W_mu, W_p, b_mu, b_p) also fixes the state_dict key order.
        weight_shape = (self.n_in, self.n_out)
        self.W_mu = nn.Parameter(torch.Tensor(*weight_shape).uniform_(-0.1, 0.1))
        self.W_p = nn.Parameter(torch.Tensor(*weight_shape).uniform_(-3, -2))
        self.b_mu = nn.Parameter(torch.Tensor(self.n_out).uniform_(-0.1, 0.1))
        self.b_p = nn.Parameter(torch.Tensor(self.n_out).uniform_(-3, -2))

        # Not updated by forward(), which returns the log-likelihoods instead.
        self.lpw = 0
        self.lqw = 0

    def forward(self, X, sample=False):
        """
        Args:
            X: input of shape (batch_size, n_in).
            sample: draw weights even in eval mode. nn.Module starts in
                training mode, where weights are always sampled.

        Returns:
            (output, lqw, lpw): output of shape (batch_size, n_out), and the
            summed log-likelihood of the sampled weights/biases under the
            approximate posterior (lqw) and the prior (lpw).
        """
        use_posterior_mean = not (self.training or sample)
        if use_posterior_mean:
            return X @ self.W_mu + self.b_mu.unsqueeze(0), 0, 0

        # Reparameterisation: the randomness is standard-normal noise
        # (weights first, then biases). It is shaped by the approximate
        # posterior only, not by the prior.
        noise_W = self.W_mu.new(self.W_mu.shape).normal_()
        noise_b = self.b_mu.new(self.b_mu.shape).normal_()

        sigma_W = 1e-6 + F.softplus(self.W_p, beta=1, threshold=20)
        sigma_b = 1e-6 + F.softplus(self.b_p, beta=1, threshold=20)

        W = self.W_mu + sigma_W * noise_W
        b = self.b_mu + sigma_b * noise_b

        output = X @ W + b.unsqueeze(0)  # (batch_size, n_out)

        # Log-likelihood of the sample under the approximate posterior q ...
        lqw = isotropic_gauss_loglike(W, self.W_mu, sigma_W) + isotropic_gauss_loglike(b, self.b_mu, sigma_b)
        # ... and under the prior p.
        lpw = self.prior.loglike(W) + self.prior.loglike(b)
        return output, lqw, lpw
        
class bayes_linear_2L(nn.Module):
    """
    Bayes-by-Backprop (variational inference) MLP with two hidden layers:
    input_dim -> n_hid -> n_hid -> output_dim, with a non-linearity between layers.

    All three layers are ``BayesLinear_Normalq`` sharing ``prior_instance``.
    ``forward`` returns the output plus the log-likelihoods of the sampled
    weights under the approximate posterior (``tlqw``) and the prior
    (``tlpw``), each summed over the three layers.
    """

    def __init__(self, input_dim, output_dim, n_hid, prior_instance):
        super(bayes_linear_2L, self).__init__()

        # Example priors:
        #   isotropic_gauss_prior(mu=0, sigma=0.1)
        #   spike_slab_2GMM(mu1=0, mu2=0, sigma1=0.135, sigma2=0.001, pi=0.5)
        self.prior_instance = prior_instance

        self.input_dim = input_dim
        self.output_dim = output_dim
        # n_hid is the width of both hidden layers.
        self.bfc1 = BayesLinear_Normalq(input_dim, n_hid, self.prior_instance)
        self.bfc2 = BayesLinear_Normalq(n_hid, n_hid, self.prior_instance)
        self.bfc3 = BayesLinear_Normalq(n_hid, output_dim, self.prior_instance)

        # Non-linearity. Alternatives: nn.Tanh(), nn.Sigmoid(),
        # nn.ELU(inplace=True), nn.SELU(inplace=True).
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, sample=False):
        """
        Args:
            x: input whose elements are regrouped into rows of ``input_dim``
                features, e.g. (batch, channels, side, side) -> (batch, input_dim).
            sample: forwarded to each layer (forces weight sampling in eval mode).

        Returns:
            (y, tlqw, tlpw): output of shape (batch, output_dim) and the summed
            posterior / prior log-likelihoods (0 when no weights are sampled).
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.input_dim)
        tlqw = 0
        tlpw = 0

        # Hidden layers: Bayesian linear map followed by the non-linearity.
        for layer in (self.bfc1, self.bfc2):
            x, lqw, lpw = layer(x, sample)
            tlqw = tlqw + lqw
            tlpw = tlpw + lpw
            x = self.act(x)

        # Output layer: no non-linearity.
        y, lqw, lpw = self.bfc3(x, sample)
        tlqw = tlqw + lqw
        tlpw = tlpw + lpw
        return y, tlqw, tlpw

    def sample_predict(self, x, Nsamples):
        """
        Monte Carlo predictions: run ``Nsamples`` forward passes, each with a
        fresh weight sample from the approximate posterior (``sample=True``).
        Averaging these outputs approximately marginalises over the weights.

        Returns:
            predictions: tensor (Nsamples, x.shape[0], output_dim).
            tlqw_vec, tlpw_vec: numpy arrays (Nsamples,) holding each sample's
                summed posterior / prior log-likelihood.
        """
        # Same dtype/device as x; every slot is overwritten below.
        predictions = x.new(Nsamples, x.shape[0], self.output_dim)
        tlqw_vec = np.zeros(Nsamples)
        tlpw_vec = np.zeros(Nsamples)

        for i in range(Nsamples):
            y, tlqw, tlpw = self.forward(x, sample=True)
            predictions[i] = y
            # float() accepts plain numbers as well as 0-d tensors that require
            # grad or live on the GPU; numpy cannot take those directly.
            tlqw_vec[i] = float(tlqw)
            tlpw_vec[i] = float(tlpw)

        return predictions, tlqw_vec, tlpw_vec
    

class BaseNet(object):
    """
    Small base class with helpers shared by network wrappers.

    Subclasses are expected to provide ``self.model`` (an ``nn.Module``),
    ``self.optimizer``, ``self.lr``, ``self.epoch`` and ``self.schedule``
    before these helpers are used.
    """

    def __init__(self):
        cprint('c', '\nNet:')

    def get_nb_parameters(self):
        """Return the total number of scalar parameters in ``self.model``."""
        return sum(p.numel() for p in self.model.parameters())

    def set_mode_train(self, train=True):
        """Switch ``self.model`` to training mode (``train=True``) or eval mode."""
        if train:
            self.model.train()
        else:
            self.model.eval()

    def update_lr(self, epoch, gamma=0.99):
        """
        Increment ``self.epoch``. If a schedule is set, multiply the learning
        rate by ``gamma`` when the schedule is empty (decay every call) or
        ``epoch`` is one of its milestones. The new rate is written into every
        optimizer param group.
        """
        self.epoch += 1
        if self.schedule is None:
            return
        if len(self.schedule) == 0 or epoch in self.schedule:
            self.lr *= gamma
            # Both values must be in the tuple: `% self.lr, epoch` raised
            # "not enough arguments for format string".
            print('learning rate: %f  (%d)\n' % (self.lr, epoch))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr

    def save(self, filename):
        """Pickle the epoch, learning rate, whole model and optimizer to ``filename``."""
        cprint('c', 'Writting %s\n' % filename)
        torch.save({
            'epoch': self.epoch,
            'lr': self.lr,
            'model': self.model,
            'optimizer': self.optimizer}, filename)

    def load(self, filename):
        """
        Restore the state written by ``save`` and return the stored epoch.

        torch.load unpickles arbitrary objects; only load trusted files.
        NOTE(review): recent PyTorch releases default torch.load to
        weights_only=True, which cannot restore a pickled model/optimizer.
        This call may need weights_only=False there; confirm for your version.
        """
        cprint('c', 'Reading %s\n' % filename)
        state_dict = torch.load(filename)
        self.epoch = state_dict['epoch']
        self.lr = state_dict['lr']
        self.model = state_dict['model']
        self.optimizer = state_dict['optimizer']
        print('  restoring epoch: %d, lr: %f' % (self.epoch, self.lr))
        return self.epoch
    
    
    
class BBP_Bayes_Net(BaseNet):
    """
    Full network wrapper for Bayes By Backprop nets with methods for training,
    prediction and weight pruning.

    Holds a ``bayes_linear_2L`` model (``self.model``) trained with plain SGD
    (``self.optimizer``).
    """
    eps = 1e-6

    # Variational parameters each Bayesian layer of ``self.model`` exposes.
    _BAYES_PARAM_NAMES = ('W_mu', 'W_p', 'b_mu', 'b_p')

    def __init__(self, lr=1e-3, channels_in=5, side_in=1, cuda=True, classes=5, batch_size=5, Nbatches=1,
                 nhid=10, prior_instance=None):
        """
        Args:
            lr: SGD learning rate.
            channels_in: number of input channels.
            side_in: side length of the square input; the model sees
                channels_in * side_in * side_in features.
            cuda: seed CUDA and move the model to the GPU.
            classes: number of output classes.
            batch_size: size of each minibatch.
            Nbatches: number of minibatches per epoch; each minibatch's KL
                term is divided by it (dataset size = batch_size * Nbatches).
            nhid: width of both hidden layers.
            prior_instance: weight prior. ``None`` (default) builds a fresh
                ``laplace_prior(mu=0, b=0.1)`` for this net instead of one
                instance created at import time and shared by every net.
        """
        super(BBP_Bayes_Net, self).__init__()
        cprint('y', ' Creating Net!! ')
        if prior_instance is None:
            prior_instance = laplace_prior(mu=0, b=0.1)
        self.lr = lr
        self.schedule = None  # [] #[50,200,400,600]
        self.cuda = cuda
        self.channels_in = channels_in
        self.classes = classes
        self.batch_size = batch_size
        self.Nbatches = Nbatches
        self.prior_instance = prior_instance
        self.nhid = nhid
        self.side_in = side_in
        self.create_net()
        self.create_opt()
        self.epoch = 0

        self.test = False

    def create_net(self):
        """
        Seed the RNGs with 42, build the ``bayes_linear_2L`` model and move it
        to the GPU if ``self.cuda``.

        The input dimension is channels_in * side_in * side_in because each
        image is flattened into one feature vector.
        """
        torch.manual_seed(42)
        if self.cuda:
            torch.cuda.manual_seed(42)
        self.model = bayes_linear_2L(input_dim=self.channels_in * self.side_in * self.side_in,
                                     output_dim=self.classes, n_hid=self.nhid, prior_instance=self.prior_instance)
        if self.cuda:
            self.model.cuda()
            # cudnn.benchmark = True

        print('    Total params: %.2fM' % (self.get_nb_parameters() / 1000000.0))

    def create_opt(self):
        """Create a plain SGD optimizer (no momentum) over all model parameters."""
        # Alternatives: Adam(betas=(0.9, 0.999), eps=1e-08), SGD(momentum=0.9).
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0)

    def fit(self, x, y, samples=1):
        """
        Take one optimisation step on the minibatch (x, y), minimising the
        negative ELBO ``Edkl + mlpdw``:

        * ``mlpdw``: cross-entropy summed over the batch. F.cross_entropy
          applies log-softmax to the logits ``out`` (batch, classes) and
          takes -log p(y).
        * ``Edkl``: (log q(w) - log p(w)) / Nbatches, a single-sample KL
          estimate scaled so that summing over all minibatches gives the
          full KL.

        Both terms are averaged over ``samples`` Monte Carlo weight samples.

        Returns:
            (Edkl, mlpdw, err): detached loss terms and the number of
            misclassified examples, taken from the last sample's output.

        Raises:
            ValueError: if ``samples < 1``.
        """
        if samples < 1:
            raise ValueError('samples must be >= 1, got %r' % (samples,))
        x, y = to_variable(var=(x, y.long()), cuda=self.cuda)

        self.optimizer.zero_grad()

        if samples == 1:
            # sample=False, but in training mode the layers still draw weights.
            out, tlqw, tlpw = self.model(x)
            mlpdw = F.cross_entropy(out, y, reduction='sum')
            Edkl = (tlqw - tlpw) / self.Nbatches
        else:
            mlpdw_cum = 0
            Edkl_cum = 0
            for _ in range(samples):
                out, tlqw, tlpw = self.model(x, sample=True)
                mlpdw_cum = mlpdw_cum + F.cross_entropy(out, y, reduction='sum')
                Edkl_cum = Edkl_cum + (tlqw - tlpw) / self.Nbatches
            mlpdw = mlpdw_cum / samples
            Edkl = Edkl_cum / samples

        # Negative ELBO.
        loss = Edkl + mlpdw
        loss.backward()
        self.optimizer.step()

        pred = out.data.max(dim=1, keepdim=False)[1]  # index of the max logit
        err = pred.ne(y.data).sum()

        return Edkl.data, mlpdw.data, err

    def eval(self, x, y, train=False):
        """
        Evaluate (x, y) with one forward pass of the model in its current
        mode. In eval mode this uses the posterior means.

        Returns:
            (loss, err, probs): summed cross-entropy, number of errors and
            softmax probabilities on the CPU. ``train`` is unused.
        """
        x, y = to_variable(var=(x, y.long()), cuda=self.cuda)

        out, _, _ = self.model(x)

        loss = F.cross_entropy(out, y, reduction='sum')

        probs = F.softmax(out, dim=1).data.cpu()

        pred = out.data.max(dim=1, keepdim=False)[1]  # index of the max logit
        err = pred.ne(y.data).sum()

        return loss.data, err, probs

    def sample_eval(self, x, y, Nsamples, logits=True, train=False):
        """
        Prediction with the weights marginalised over ``Nsamples`` posterior
        samples.

        With ``logits=True`` the logits are averaged over samples and then
        passed through softmax/cross-entropy. With ``logits=False`` the
        per-sample softmax probabilities are averaged, and the loss is the NLL
        of their log (F.nll_loss expects log-probabilities).

        Returns:
            (loss, err, probs): summed loss, number of errors and CPU
            probabilities. ``train`` is unused.
        """
        x, y = to_variable(var=(x, y.long()), cuda=self.cuda)

        out, _, _ = self.model.sample_predict(x, Nsamples)

        if logits:
            mean_out = out.mean(dim=0, keepdim=False)
            loss = F.cross_entropy(mean_out, y, reduction='sum')
            probs = F.softmax(mean_out, dim=1).data.cpu()
        else:
            mean_out = F.softmax(out, dim=2).mean(dim=0, keepdim=False)
            probs = mean_out.data.cpu()

            log_mean_probs_out = torch.log(mean_out)
            loss = F.nll_loss(log_mean_probs_out, y, reduction='sum')

        pred = mean_out.data.max(dim=1, keepdim=False)[1]  # index of the max score
        err = pred.ne(y.data).sum()

        return loss.data, err, probs

    def all_sample_eval(self, x, y, Nsamples):
        """Return softmax probabilities for each MC sample: (Nsamples, batch, classes)."""
        x, y = to_variable(var=(x, y.long()), cuda=self.cuda)

        out, _, _ = self.model.sample_predict(x, Nsamples)

        prob_out = F.softmax(out, dim=2)
        return prob_out.data

    def _bayes_layer_params(self):
        """
        Group ``self.model.state_dict()`` by layer.

        Keys such as ``'bfc1.W_mu'`` are split on '.' into a layer name and a
        parameter name. Returns an insertion-ordered
        ``{layer_name: {param_name: tensor}}`` restricted to layers that have
        all of W_mu, W_p, b_mu and b_p. The tensors share storage with the
        model's parameters.
        """
        layers = {}
        for key, value in self.model.state_dict().items():
            parts = key.split('.')
            layers.setdefault(parts[0], {})[parts[1]] = value.data
        return {name: params for name, params in layers.items()
                if all(p in params for p in self._BAYES_PARAM_NAMES)}

    @staticmethod
    def _flat_numpy(chunks):
        """Concatenate 1-D CPU tensors into one numpy array (empty array if none)."""
        if not chunks:
            return np.array([])
        return torch.cat(chunks).numpy()

    def get_weight_samples(self, Nsamples=10):
        """
        Draw ``Nsamples`` weight samples from the approximate posterior.

        Returns one flat 1-D numpy array of all sampled weights, ordered
        sample by sample and then layer by layer. Biases are sampled but not
        recorded.
        """
        cpu_layers = [{name: t.cpu() for name, t in params.items()}
                      for params in self._bayes_layer_params().values()]
        chunks = []
        for _ in range(Nsamples):
            for p in cpu_layers:
                W, _b = sample_weights(W_mu=p['W_mu'], b_mu=p['b_mu'], W_p=p['W_p'], b_p=p['b_p'])
                chunks.append(W.reshape(-1))
        return self._flat_numpy(chunks)

    def get_weight_SNR(self, thresh=None):
        """
        Element-wise signal-to-noise ratio |mu| / std of the approximate
        posterior, with std = 1e-6 + softplus(p).

        Returns:
            With ``thresh``: dict ``{'<layer>.W': mask, '<layer>.b': mask}``
            of boolean tensors marking entries with SNR > thresh.
            Without: 1-D numpy array of all SNR values (weights, then biases,
            layer by layer).
        """
        mask_dict = {}
        chunks = []
        for layer_name, p in self._bayes_layer_params().items():
            sig_W = 1e-6 + F.softplus(p['W_p'], beta=1, threshold=20)
            sig_b = 1e-6 + F.softplus(p['b_p'], beta=1, threshold=20)
            W_snr = torch.abs(p['W_mu']) / sig_W
            b_snr = torch.abs(p['b_mu']) / sig_b
            if thresh is not None:
                mask_dict[layer_name + '.W'] = W_snr > thresh
                mask_dict[layer_name + '.b'] = b_snr > thresh
            else:
                chunks.append(W_snr.cpu().reshape(-1))
                chunks.append(b_snr.cpu().reshape(-1))

        if thresh is not None:
            return mask_dict
        return self._flat_numpy(chunks)

    def get_weight_KLD(self, Nsamples=20, thresh=None):
        """
        Monte Carlo estimate, over ``Nsamples`` independent posterior samples,
        of the element-wise log q(w) - log p(w). This is each parameter's
        contribution to the KL divergence between the approximate posterior q
        and the prior p.

        Requires ``self.model.prior_instance.loglike(..., do_sum=False)``. The
        original code notes this does not work with the spike-and-slab prior.

        Returns:
            With ``thresh``: dict ``{'<layer>.W': mask, '<layer>.b': mask}``
            of boolean tensors marking entries whose KL estimate > thresh.
            Without: 1-D numpy array of all KL estimates (weights, then
            biases, layer by layer).
        """
        prior = self.model.prior_instance
        mask_dict = {}
        chunks = []
        for layer_name, p in self._bayes_layer_params().items():
            W_mu, W_p, b_mu, b_p = p['W_mu'], p['W_p'], p['b_mu'], p['b_p']
            std_W = 1e-6 + F.softplus(W_p, beta=1, threshold=20)
            std_b = 1e-6 + F.softplus(b_p, beta=1, threshold=20)

            KL_W = W_mu.new(W_mu.shape).zero_()
            KL_b = b_mu.new(b_mu.shape).zero_()
            for _ in range(Nsamples):
                W, b = sample_weights(W_mu=W_mu, b_mu=b_mu, W_p=W_p, b_p=b_p)
                # Posterior element-wise log-likelihood minus prior element-wise log-likelihood.
                KL_W += isotropic_gauss_loglike(W, W_mu, std_W, do_sum=False) - prior.loglike(W, do_sum=False)
                KL_b += isotropic_gauss_loglike(b, b_mu, std_b, do_sum=False) - prior.loglike(b, do_sum=False)
            KL_W /= Nsamples
            KL_b /= Nsamples

            if thresh is not None:
                mask_dict[layer_name + '.W'] = KL_W > thresh
                mask_dict[layer_name + '.b'] = KL_b > thresh
            else:
                chunks.append(KL_W.cpu().reshape(-1))
                chunks.append(KL_b.cpu().reshape(-1))

        if thresh is not None:
            return mask_dict
        return self._flat_numpy(chunks)

    def mask_model(self, Nsamples=0, thresh=0):
        """
        Prune the model in place. Parameters whose SNR (``Nsamples == 0``) or
        KL estimate (``Nsamples > 0``, using that many samples) does not
        exceed ``thresh`` get posterior mean 0 and p = -1000. That gives
        std = 1e-6 + softplus(-1000) ~= 1e-6, i.e. effectively a constant 0.

        Returns:
            (original_state_dict, n_unmasked): a deep copy of the state
            before masking, and the number of kept (unmasked) parameters.
        """
        original_state_dict = copy.deepcopy(self.model.state_dict())
        # state_dict() tensors share storage with the parameters, so the
        # in-place writes below modify the model itself.
        state_dict = self.model.state_dict()

        if Nsamples == 0:
            mask_dict = self.get_weight_SNR(thresh=thresh)
        else:
            mask_dict = self.get_weight_KLD(Nsamples=Nsamples, thresh=thresh)

        n_unmasked = 0
        for layer_name in self._bayes_layer_params():
            keep_W = mask_dict[layer_name + '.W']
            keep_b = mask_dict[layer_name + '.b']
            state_dict[layer_name + '.W_mu'][~keep_W] = 0
            state_dict[layer_name + '.W_p'][~keep_W] = -1000
            state_dict[layer_name + '.b_mu'][~keep_b] = 0
            state_dict[layer_name + '.b_p'][~keep_b] = -1000
            n_unmasked += keep_W.sum()
            n_unmasked += keep_b.sum()

        return original_state_dict, n_unmasked

### Training on MNIST dataset

In [3]:
import time
import torch.utils.data
from torchvision import transforms, datasets
import matplotlib

In [25]:
# Output folders for this experiment's checkpoints and result arrays/plots.
_run_tag = 'weight_uncertainty_MC_MNIST_gaussian'
models_dir = 'test/models_' + _run_tag
results_dir = 'test/results_' + _run_tag

"""
create folder(s) in current location named with element in paths, 
after converting to list object
"""
def mkdir(paths):
    """
    Create each directory in ``paths``, including missing parents.

    Args:
        paths: a single path, or a list/tuple of paths. Directories that
            already exist are left untouched.
    """
    # Imported locally: the notebook never imports `os` explicitly.
    import os

    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(path, exist_ok=True)
            
# Make sure both output folders exist before training starts.
mkdir([models_dir, results_dir])

In [26]:
# train config
NTrainPointsMNIST = 60000  # number of MNIST training images
batch_size, nb_epochs = 100, 160
log_interval = 1  # not referenced in the visible cells

# Epochs after which a deep copy of the model's state_dict is kept in save_dicts.
savemodel_its = [20, 50, 80, 120]
save_dicts = []

The MNIST dataset consists of 70,000 images of handwritten digits together with their labels.

It is split into 60,000 training images and 10,000 test images, each 28 × 28 pixels.

In [46]:
cprint('c', '\nData:')

# load data

# FIX for data loading issue (HTTP 503 from the original MNIST host):
# download from the PyTorch S3 mirror instead.
new_mirror = 'https://ossci-datasets.s3.amazonaws.com/mnist'
if hasattr(datasets.MNIST, 'mirrors'):
    # torchvision versions with `MNIST.mirrors` store bare file names in
    # `resources` and build each URL as mirror + file name. Rewriting
    # `resources` into full URLs would break the download there, so the S3
    # mirror is moved to the front of the mirror list instead.
    _s3_mirror = new_mirror + '/'
    datasets.MNIST.mirrors = [_s3_mirror] + [m for m in datasets.MNIST.mirrors if m != _s3_mirror]
else:
    # Older torchvision: `resources` holds full URLs; repoint them at the mirror.
    datasets.MNIST.resources = [
        ('/'.join([new_mirror, url.split('/')[-1]]), md5)
        for url, md5 in datasets.MNIST.resources
    ]

# Preprocessing only (no augmentation): convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,))
])


trainset = datasets.MNIST("./data", train=True, download=True, transform=transform_train)
valset = datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

use_cuda = torch.cuda.is_available()

# Pinned host memory only helps when batches are copied to a GPU.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=use_cuda,
                                          num_workers=3)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, pin_memory=use_cuda,
                                        num_workers=3)

[36m
Data:[0m


### Create BNN Network

In [52]:
# net dims
cprint('c', '\nNetwork:')

lr = 1e-3
nsamples = 3  # MC weight samples per training step (passed to net.fit)
########################################################################################
net = BBP_Bayes_Net(
    lr=lr,
    channels_in=1,
    side_in=28,
    cuda=use_cuda,
    classes=10,
    batch_size=batch_size,
    # Minibatches per epoch; each minibatch's KL term is divided by this.
    Nbatches=NTrainPointsMNIST / batch_size,
)

epoch = 0  # first epoch of the training loop

[36m
Network:[0m
[36m
Net:[0m
[33m Creating Net!! [0m
    Total params: 0.02M


## Training

In [103]:
# train
cprint('c', '\nTrain:')

print('  init cost variables:')
kl_cost_train = np.zeros(nb_epochs)
pred_cost_train = np.zeros(nb_epochs)
err_train = np.zeros(nb_epochs)

cost_dev = np.zeros(nb_epochs)
err_dev = np.zeros(nb_epochs)
# best_cost = np.inf
best_err = np.inf

nb_its_dev = 1  # validate every `nb_its_dev` epochs
tic0 = time.time()


for i in range(epoch, nb_epochs):
    # Training mode: the Bayesian layers sample fresh weights on every forward pass.
    net.set_mode_train(True)

    tic = time.time()
    nb_samples = 0

    for x, y in trainloader:
        # x: (batch_size, channels_in, side_in, side_in); bayes_linear_2L
        # flattens it to (batch_size, input_dim). y: (batch_size,).
        # One optimisation step averaging the loss over `nsamples` MC weight samples.
        cost_dkl, cost_pred, err = net.fit(x, y, samples=nsamples)

        # float() converts the returned 0-d tensors (possibly on the GPU) to
        # Python numbers before accumulating them into the numpy arrays.
        err_train[i] += float(err)
        kl_cost_train[i] += float(cost_dkl)
        pred_cost_train[i] += float(cost_pred)
        nb_samples += len(x)

    # Per-example averages over the epoch.
    kl_cost_train[i] /= nb_samples
    pred_cost_train[i] /= nb_samples
    err_train[i] /= nb_samples

    toc = time.time()
    net.epoch = i

    # ---- print
    print("it %d/%d, Jtr_KL = %f, Jtr_pred = %f, err = %f, " % (i, nb_epochs, kl_cost_train[i], pred_cost_train[i], err_train[i]), end="")
    cprint('r', '   time: %f seconds\n' % (toc - tic))

    # Save state dict
    if i in savemodel_its:
        save_dicts.append(copy.deepcopy(net.model.state_dict()))

    # ---- dev: validation cost/error; checkpoint whenever the error improves.
    if i % nb_its_dev == 0:
        net.set_mode_train(False)
        nb_samples = 0
        # Validation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for x, y in valloader:
                cost, err, probs = net.eval(x, y)
                cost_dev[i] += float(cost)
                err_dev[i] += float(err)
                nb_samples += len(x)

        cost_dev[i] /= nb_samples
        err_dev[i] /= nb_samples

        cprint('g', '    Jdev = %f, err = %f\n' % (cost_dev[i], err_dev[i]))

        if err_dev[i] < best_err:
            best_err = err_dev[i]
            cprint('b', 'best test error')
            net.save(models_dir+'/theta_best.dat')

toc0 = time.time()
# Average over the epochs actually run (the loop starts at `epoch`, not 0).
runtime_per_it = (toc0 - tic0) / float(max(nb_epochs - epoch, 1))
cprint('r', '   average time: %f seconds\n' % runtime_per_it)

net.save(models_dir+'/theta_last.dat')

## Result

In [None]:
## ---------------------------------------------------------------------------------------------------------------------
# results
# `plt` is used below; import it explicitly rather than relying on a wildcard import.
import matplotlib.pyplot as plt

cprint('c', '\nRESULTS:')
nb_parameters = net.get_nb_parameters()
best_cost_dev = np.min(cost_dev)
best_cost_train = np.min(pred_cost_train)
err_dev_min = err_dev[::nb_its_dev].min()

print('  cost_dev: %f (cost_train %f)' % (best_cost_dev, best_cost_train))
print('  err_dev: %f' % (err_dev_min))
print('  nb_parameters: %d (%s)' % (nb_parameters, humansize(nb_parameters)))
print('  time_per_it: %fs\n' % (runtime_per_it))


## Save results for plots
# np.save('results/test_predictions.npy', test_predictions)
# The KL and prediction costs are saved to separate files. Both used to go
# to cost_train.npy, so the KL curve was silently overwritten.
np.save(results_dir + '/kl_cost_train.npy', kl_cost_train)
np.save(results_dir + '/cost_train.npy', pred_cost_train)
np.save(results_dir + '/cost_dev.npy', cost_dev)
np.save(results_dir + '/err_train.npy', err_train)
np.save(results_dir + '/err_dev.npy', err_dev)

## ---------------------------------------------------------------------------------------------------------------------
# fig cost vs its

textsize = 15
marker = 5


def _style_text(ax):
    """Apply the common font size/weight to an axes' title, axis labels and tick labels."""
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(textsize)
        item.set_weight('normal')


def _add_grid(ax):
    """Solid major / dashed minor grid. `True` is passed positionally because
    the `b=` keyword was renamed to `visible=` in newer matplotlib."""
    ax.grid(True, which='major', color='k', linestyle='-')
    ax.grid(True, which='minor', color='k', linestyle='--')


# Cross-entropy per example: train (red dashed) vs validation (blue).
fig, ax1 = plt.subplots(dpi=100)
ax1.plot(pred_cost_train, 'r--', label='train cost')
ax1.plot(range(0, nb_epochs, nb_its_dev), cost_dev[::nb_its_dev], 'b-', label='test cost')
ax1.set_ylabel('Cross Entropy')
ax1.set_xlabel('epoch')
_add_grid(ax1)
# Labels are attached to the lines, so the legend always matches plot order.
lgd = ax1.legend(markerscale=marker, prop={'size': textsize, 'weight': 'normal'})
ax1.set_title('classification costs')
_style_text(ax1)
fig.savefig(results_dir + '/cost.png', bbox_extra_artists=(lgd,), bbox_inches='tight')

# KL term per training example.
fig_kl, ax_kl = plt.subplots()
ax_kl.plot(kl_cost_train, 'r')
ax_kl.set_ylabel('nats?')
ax_kl.set_xlabel('epoch')
_add_grid(ax_kl)
ax_kl.set_title('DKL (per sample)')
_style_text(ax_kl)

# Classification error (%) on a log scale: validation (blue) vs train (red dashed).
fig2, ax2 = plt.subplots(dpi=100)
ax2.set_ylabel('% error')
ax2.semilogy(range(0, nb_epochs, nb_its_dev), 100 * err_dev[::nb_its_dev], 'b-', label='test error')
ax2.semilogy(100 * err_train, 'r--', label='train error')
ax2.set_xlabel('epoch')
_add_grid(ax2)
ax2.get_yaxis().set_minor_formatter(matplotlib.ticker.ScalarFormatter())
ax2.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
lgd = ax2.legend(markerscale=marker, prop={'size': textsize, 'weight': 'normal'})
_style_text(ax2)
# `bbox_inches` (the old call misspelled it as `box_inches`).
fig2.savefig(results_dir + '/err.png', bbox_extra_artists=(lgd,), bbox_inches='tight')