# Bayesian NN with Gaussian variational learning on MNIST

In this example we train an ordinary NN and a Bayesian NN on the MNIST dataset, compare their performance, and evaluate pruning effectiveness and the models' uncertainty.

In [34]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

from classification_nn import ClassificationNet
from metrics import log_metrics
from variational_gaussian.bayes_nn import MakeModuleBayessian
from variational_gaussian.bayessian_loss import MakeLossBayessian

from torchvision.datasets import MNIST
from torchvision.transforms.functional import pil_to_tensor

In [36]:
import yaml

# Load the experiment's config from params.yaml into `exp_params`.
# The encoding is given explicitly so that parsing does not depend on the
# platform's default locale encoding.
# NOTE(review): yaml.full_load can build more than plain scalars/lists/dicts;
# if params.yaml only holds plain values, yaml.safe_load is the safer choice.
with open("params.yaml", encoding="utf-8") as f:
    exp_params = yaml.full_load(f)

In [37]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Device:", device)

Device: cpu


## Loading MNIST dataset

In [38]:
# load dataset

def img_transform(img):
    """Turn a PIL image into a float tensor.

    The image is passed through ``pil_to_tensor`` and the result is cast
    with ``.float()``; no other scaling or normalization is applied here.
    """
    tensor = pil_to_tensor(img)
    return tensor.float()

# Build the train split first, then the test split; both live under ./data
# (downloaded on demand) and share the same per-image transform.
train_dataset, test_dataset = (
    MNIST("./data", train=is_train, download=True, transform=img_transform)
    for is_train in (True, False)
)

In [39]:
# Both loaders use the batch size from the experiment config.
# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=exp_params["batch_size"],
    shuffle=True,
)
# The test split is iterated in a fixed order.
test_loader = DataLoader(
    test_dataset,
    batch_size=exp_params["batch_size"],
    shuffle=False,
)

In [40]:
# Number of target classes.
NUM_CLASSES = 10

# Shape of a single input image, read off the first training sample.
IMG_SIZE = train_dataset[0][0].shape
print(IMG_SIZE)

torch.Size([1, 28, 28])


## Train the basic NN

In [42]:
# Baseline classifier, configured by the `num_layers` and `hidden_size`
# entries of the experiment config, cast to float32 and moved to `device`.
basic_nn = ClassificationNet(
    IMG_SIZE,
    NUM_CLASSES,
    exp_params["num_layers"],
    exp_params["hidden_size"],
).float().to(device)

In [43]:
# Optimization settings for the basic network.
N_EPOCHS = 10
LR = 0.001
# Weight decay is passed to Adam to add L2 regularization.
WEIGHT_DECAY = 0.001

optimizer = optim.Adam(
    params=basic_nn.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

In [44]:
# TensorBoard writer created with default settings; it receives the training
# loss via add_scalar and is handed to log_metrics during evaluation.
writer = SummaryWriter()

In [None]:
for epoch in tqdm(range(N_EPOCHS), desc="Epochs"):
    # ---- training pass over the whole training loader ----
    basic_nn.train()
    losses = []

    for imgs, targets in tqdm(train_loader, desc="Train batchs", leave=True):
        imgs, targets = imgs.to(device), targets.to(device)

        # Drop gradients from the previous step before computing new ones.
        optimizer.zero_grad()

        logits = basic_nn(imgs)
        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()

        # Keep the batch loss as a plain Python float for the epoch average.
        losses.append(loss.item())

    # Log the mean of the per-batch losses for this epoch.
    epoch_mean_loss = torch.Tensor(losses).mean()
    writer.add_scalar("Train/Cross_entropy", epoch_mean_loss, epoch)

    # ---- evaluation on the test loader, gradients disabled ----
    basic_nn.eval()
    with torch.no_grad():
        # presumably computes test metrics and writes them to `writer` — see metrics.log_metrics
        log_metrics(epoch, basic_nn, test_loader, writer)