# Bayesian NN

This notebook implements a simple, fully connected Bayesian neural network using `pyro` and `torch`.


The network is trained on the _MNIST_ dataset. Network parameters are exported during training.<br>
After training, sampled outputs are exported (for different training stages) to be used in an external visualization tool.

## Import dependecies

Required modules:
- `torch` for neural net functionality (including tensors)
- `torchvision` to obtain the MNIST dataset and enable some tensor transformations
- `pyro` for variational inference – this makes the network Bayesian

Additional modules:
- `time` for timing operations (could be raplayed by `%time` in notebook)
- `numpy` and `matplotlib.pyplot` to post-process and plot the data
- `os` to parse directory names and list files
- `csv` and/or `json` for data export

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions.constraints as constraints

import torchvision.datasets as datasets
import torchvision.transforms as transforms

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

import numpy as np
import matplotlib.pyplot as plt

import time
import os
import csv
import json

## Define a neural network class and instantiate

Fully connected, one hidden layer, relu activation function

In [17]:
class NN(nn.Module):

    def __init__(self, input_size, hidden_size, output_size, use_cuda=False):
        super(NN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        
        if use_cuda:
            self.cuda()
            
        self.use_cuda = use_cuda

    def forward(self, x):
        output = self.fc1(x)
        output = F.relu(output)
        output = self.out(output)
        return output

In [18]:
net = NN(28*28, 1024, 10, use_cuda=True)

## Define some activation functions

- log softmax
- softplus

In [19]:
log_softmax = nn.LogSoftmax(dim=1)

softplus = torch.nn.Softplus()

## Define priors ond place over NN module

In [20]:
def model(x_data, y_data):

    fc1w_prior = dist.Normal(
        loc=torch.zeros_like(net.fc1.weight),
        scale=torch.ones_like(net.fc1.weight),
    )
    fc1b_prior = dist.Normal(
        loc=torch.zeros_like(net.fc1.bias),
        scale=torch.ones_like(net.fc1.bias),
    )

    outw_prior = dist.Normal(
        loc=torch.zeros_like(net.out.weight),
        scale=torch.ones_like(net.out.weight),
    )
    outb_prior = dist.Normal(
        loc=torch.zeros_like(net.out.bias),
        scale=torch.ones_like(net.out.bias),
    )

    priors = {
        'fc1.weight': fc1w_prior,
        'fc1.bias'  : fc1b_prior,
        'out.weight': outw_prior,
        'out.bias'  : outb_prior
    }

    # lift module parameters to random variables sampled from the priors
    lifted_model = pyro.random_module('module', net, priors)
    # sample a regressor (which also samples w and b)
    lifted_reg_model = lifted_model()

    lhat = lifted_reg_model(x_data)

    pyro.sample('obs', dist.Categorical(logits=lhat), obs=y_data)

## Define guide and initialize randomly

_TODO:_ find out what `softplus` accomplishes here

In [25]:
def guide(x_data, y_data):
    
    # First layer weight distribution priors
    fc1w_mean = torch.randn_like(net.fc1.weight)
    fc1w_std = torch.abs(torch.randn_like(net.fc1.weight))
    fc1w_mean_param = pyro.param('fc1w_mean', fc1w_mean)
    fc1w_std_param = pyro.param('fc1w_std', fc1w_std, constraint=constraints.positive)
    fc1w_prior = dist.Normal(
        loc=fc1w_mean_param,
        scale=fc1w_std_param
    )

    # First layer bias distribution priors
    fc1b_mean = torch.randn_like(net.fc1.bias)
    fc1b_std = torch.abs(torch.randn_like(net.fc1.bias))
    fc1b_mean_param = pyro.param('fc1b_mean', fc1b_mean)
    fc1b_std_param = pyro.param('fc1b_std', fc1b_std, constraint=constraints.positive)
    fc1b_prior = dist.Normal(
        loc=fc1b_mean_param,
        scale=fc1b_std_param
    )

    # Output layer weight distribution priors
    outw_mean = torch.randn_like(net.out.weight)
    outw_std = torch.abs(torch.randn_like(net.out.weight))
    outw_mean_param = pyro.param('outw_mean', outw_mean)
    outw_std_param = pyro.param('outw_std', outw_std, constraint=constraints.positive)
    outw_prior = dist.Normal(
        loc=outw_mean_param,
        scale=outw_std_param
    )

    # Output layer bias distribution priors
    outb_mean = torch.randn_like(net.out.bias)
    outb_std = torch.abs(torch.randn_like(net.out.bias))
    outb_mean_param = pyro.param('outb_mean', outb_mean)
    outb_std_param = pyro.param('outb_std', outb_std, constraint=constraints.positive)
    outb_prior = dist.Normal(
        loc=outb_mean_param,
        scale=outb_std_param
    )

    priors = {
        'fc1.weight': fc1w_prior,
        'fc1.bias'  : fc1b_prior,
        'out.weight': outw_prior,
        'out.bias'  : outb_prior
    }

    lifted_module = pyro.random_module('module', net, priors)

    return lifted_module()

## Choose optimization parameters

- Adam optimizer ([Kingma & Ba, 2014](https://arxiv.org/abs/1412.6980))
    - Learning rate: 0.001
- Stochastic variational inference
    - Expectation lower bound (ELBO) as loss function

In [26]:
optim = pyro.optim.Adam({'lr': 0.001})
svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO())

## Construct loaders for training and test data

- Dataset: MNIST
- Batch size: 128

In [27]:
batch_size = 128
use_cuda = True

kwargs = {'num_workers': 1, 'pin_memory': use_cuda, 'batch_size': batch_size}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './mnist-data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ])
    ),
    shuffle=True,
    **kwargs
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './mnist-data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ])
    ),
    shuffle=False,
    **kwargs
)

## Perform training

If experiments are rerun, the parameter store has to be cleared first (uncomment first line).<br>
Possibly uncomment lines for exporting network data and adapt export granularity.<br>
Set number of epochs via `num_iterations`.

In [29]:
# pyro.clear_param_store()
# pyro.get_param_store().load('saved-params-bs128/params_ep6_batch00450')

num_iterations = 15
loss = 0

start_time = time.time()
for j in range(num_iterations):
    loss = 0
    for batch_id, data in enumerate(train_loader):
        # calculate the loss and take a gradient step
        images = data[0].view(-1,28*28)
        labels = data[1]
        
        if use_cuda:
            images = images.cuda()
            labels = labels.cuda()
        loss += svi.step(images, labels)
        if batch_id % 30 == 0:
            print('.', end='')
    
    #filename = 'saved-params-bs{}_NEW/params_ep{:02}'.format(batch_size, j)
    #pyro.get_param_store().save(filename)
    #print('saved parameter store to {}'.format(filename))
    
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
        
    print('\nEpoch ', j, ' Loss ', total_epoch_loss_train)
    print('Time taken: {} s'.format(time.time()-start_time))
    
#filename = 'saved-params-bs%d_NEW/params_after_ep%d' % (batch_size, 11 + num_iterations - 1)
#pyro.get_param_store().save(filename)

................
Epoch  0  Loss  6493.183268421574
Time taken: 10.154435157775879 s
................
Epoch  1  Loss  4713.582906363217
Time taken: 19.771262645721436 s
................
Epoch  2  Loss  3484.526981917675
Time taken: 29.354207038879395 s
................
Epoch  3  Loss  2584.2780761197646
Time taken: 39.05123591423035 s
................
Epoch  4  Loss  1923.5260506713312
Time taken: 48.690467834472656 s
................
Epoch  5  Loss  1436.5318848463216
Time taken: 58.78489065170288 s
................
Epoch  6  Loss  1080.2607213371198
Time taken: 69.09088921546936 s
................
Epoch  7  Loss  818.5022287346601
Time taken: 78.94313597679138 s
................
Epoch  8  Loss  626.825169472154
Time taken: 88.77480125427246 s
................
Epoch  9  Loss  486.18120333329836
Time taken: 98.71620869636536 s
................
Epoch  10  Loss  382.3653835990429
Time taken: 108.69900894165039 s
................
Epoch  11  Loss  305.231566237402
Time taken: 119.3854873180

## Test/Predict

Prediction is performed by simply choosing the class with the highest score (from the score average over drawn samples).<br>
`num_samples` governs the number of samples to be drawn for prediction.

In [33]:
def sample_outputs(x, num_samples=10):
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    yhats = [model(x).data for model in sampled_models]
    return torch.stack(yhats)

# average logits, then take argmax
def predict(x, num_samples=10):
    yhats = sample_outputs(x, num_samples=num_samples)
    mean = torch.mean(yhats, 0)
    if use_cuda:
        mean = mean.cpu()
    return np.argmax(mean.numpy(), axis=1)

Perform test while forcing the network to predict:

In [34]:
num_samples = 20

print('Prediction when network is forced to predict')
start_time = time.time()
correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    if use_cuda:
        images = images.view(-1,28*28).cuda()
    else:
        images = images.view(-1,28*28)
    outputs = predict(images, num_samples=num_samples)
    total += labels.size(0)
    correct += (np.asarray(outputs) == np.asarray(labels)).sum().item()
    if j % 3 == 0: print('.', end='')

print('\nlength of test set: {}'.format(total))
print('accuracy for num_samples = %d: %d %%' % (num_samples, 100 * correct / total))
print('Time taken for predictions: {}'.format(time.time()-start_time))

Prediction when network is forced to predict
...........................
length of test set: 10000
accuracy for num_samples = 20: 91 %
Time taken for predictions: 5.1039416790008545


***

Alternatively, predict by averaging softmaxes:

In [45]:
num_samples = 100

start_time = time.time()

for j , data in enumerate(test_loader):
    images, labels = data
    if use_cuda:
        images = images.view(-1,28*28).cuda()
    else:
        images = images.view(-1,28*28)
    yhats = sample_outputs(images, num_samples=num_samples)
    softmax = nn.Softmax(dim=2)(yhats)
    mean = torch.mean(softmax, 0)
    if use_cuda:
        mean = mean.cpu()
    prediction = np.argmax(mean.numpy(), axis=1)
    total += labels.size(0)
    correct += (np.asarray(prediction) == np.asarray(labels)).sum().item()
    if j % 3 == 0: print('.', end='')
        
print('\nlength of test set: {}'.format(total))
print('accuracy for num_samples = %d: %d %%' % (num_samples, 100 * correct / total))
print('Time taken for predictions: {}'.format(time.time()-start_time))

...........................
length of test set: 89952
accuracy for num_samples = 100: 64 %
Time taken for predictions: 23.941030502319336


## Obtain "nice" lists of trianing and test examples

Define some custom data loaders for obtaining single training and test instances:

In [2]:
single_train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './mnist-data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ])
    ),
    batch_size=1,
    shuffle=False
)

single_test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './mnist-data',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ])
    ),
    batch_size=1,
    shuffle=False
)

For each class, obtain a certain number of training and test examples:

In [60]:
samples_per_label = {
    'train': 1000,
    'test': 500
}

In [61]:
training_images = [[] for _ in range(10)]
index_counts = [0 for _ in range(10)]

for index, batch in enumerate(single_train_loader):
    image, label = batch
    if index_counts[label] < samples_per_label['train']:
        # sample_indices[label][index_counts[label]] = index
        training_images[label].append(image)
        index_counts[label] += 1

test_images = [[] for _ in range(10)]
index_counts = [0 for _ in range(10)]

for index, batch in enumerate(single_test_loader):
    image, label = batch
    if index_counts[label] < samples_per_label['test']:
        # sample_indices[label][index_counts[label]] = index
        test_images[label].append(image)
        index_counts[label] += 1

Reshape the training and test image lists. After this step, they each have shape `(10 * samples_per_label['train'|'test'], 28*28)`:

In [62]:
training_images = torch.stack([image.view(-1,28*28) for digit_class in training_images for image in digit_class])

test_images = torch.stack([image.view(-1,28*28) for digit_class in test_images for image in digit_class])

## Sample some ouputs

Sample some outputs for the previously obtained training and test images:

In [74]:
num_samples = 20

start_time = time.time()

training_outputs = []
for _ in range(num_samples):
    model = guide(None, None)
    training_outputs.append(model(training_images))
training_outputs = torch.stack(training_outputs).transpose(0,1).flatten(start_dim=2)

print('\nElapsed time: {} s'.format(time.time() - start_time))

start_time = time.time()

test_outputs = []
for _ in range(num_samples):
    model = guide(None, None)
    test_outputs.append(model(test_images))
test_outputs = torch.stack(test_outputs).transpose(0,1).flatten(start_dim=2)

print('\nElapsed time: {} s'.format(time.time() - start_time))


Elapsed time: 2.370715618133545 s

Elapsed time: 1.735379695892334 s


## Obtain epoch-wise data

Select directory with saved parameters (possibly adapt `batch_size` before) and select only those at end of epochs.

In [78]:
directory = '.\saved-params-bs' + str(batch_size) + '_NEW'
parsed_dir = os.fsencode(directory)

# epoch_endstate_files = [s for s in os.listdir(parsed_dir) if '450' in s.decode()]
epoch_endstate_files = [s for s in os.listdir(parsed_dir) if 'before' in s.decode()]

In [79]:
epoch_endstate_files

[b'params_before_ep00',
 b'params_before_ep01',
 b'params_before_ep02',
 b'params_before_ep03',
 b'params_before_ep04',
 b'params_before_ep05',
 b'params_before_ep06',
 b'params_before_ep07',
 b'params_before_ep08',
 b'params_before_ep09',
 b'params_before_ep10',
 b'params_before_ep11',
 b'params_before_ep12',
 b'params_before_ep13',
 b'params_before_ep14',
 b'params_before_ep15']

For each saved network state in `epoch_endstate_files`, pass all `training_images` through the Bayesian NN, each time sampling `num_samples` outputs.

In [80]:
num_samples = 20

start_time = time.time()
epoch_training_outputs = []
for index, file in enumerate(epoch_endstate_files):
    filename = os.fsdecode(file)
    pyro.get_param_store().load(directory + '/' + filename)
    yhats = sample_outputs(training_images, num_samples=num_samples)
    epoch_training_outputs.append(yhats)
    print('.', end='')
print('\nElapsed time: {} s'.format(time.time() - start_time))

................
Elapsed time: 42.52355217933655 s


Reshape epoch-wise data.<br>
New shape is `N_images * N_epochs * num_samples * dim_output`.

In [81]:
epoch_training_outputs = torch.stack(epoch_training_outputs)

In [82]:
epoch_training_outputs = epoch_training_outputs.transpose(2,0).transpose(2,1).flatten(start_dim=3)

## Export everything

Save the outputs using `torch.save`.

In [75]:
with open('exported-data/training_outputs_NEW.npy', 'wb') as file:
    np.save(file, training_outputs.detach().numpy())

In [77]:
with open('exported-data/test_outputs_NEW.npy', 'wb') as file:
    np.save(file, test_outputs.detach().numpy())

In [83]:
with open('exported-data/training_outputs_epoch_NEW.npy', 'wb') as file:
    np.save(file, epoch_training_outputs.detach().numpy())

In [12]:
with open('exported-data/training_inputs.npy', 'wb') as file:
    np.save(file, training_images.detach().numpy())

In [13]:
with open('exported-data/test_inputs.npy', 'wb') as file:
    np.save(file, test_images.detach().numpy())

In [285]:
epoch_training_data = np.load('training_outputs_epoch.npy')

In [286]:
epoch_training_data_for_json = []
for index, image in enumerate(epoch_training_data):
    entry = dict()
    entry['index'] = index
    # entry['image'] = training_images[index].tolist()
    entry['data'] = []
    for epoch, outputs in enumerate(image):
        subentry = dict()
        subentry['epoch'] = epoch
        subentry['outputs'] = outputs.tolist()
        entry['data'].append(subentry)
    epoch_training_data_for_json.append(entry)

In [12]:
with open('NEW_complete_training_data.json', 'w') as outfile:
    json.dump(epoch_training_data_for_json, outfile)