# Bayesian NN with Gaussian variational learning on MNIST

In this example we train an ordinary NN and a Bayesian NN on the MNIST dataset, compare their performance, evaluate how effective pruning is, and examine the model's uncertainty.

In [1]:
# Automatically re-import edited project modules (classification_nn, metrics,
# variational_gaussian) before each cell runs, so code changes apply without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import shutil
from datetime import datetime
from functools import partial
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt

from classification_nn import ClassificationNet
from metrics import log_metrics
from variational_gaussian.bayes_nn import MakeModuleBayessian
from variational_gaussian.bayessian_loss import MakeLossBayessian

from torchvision.datasets import MNIST
from torchvision.transforms.functional import pil_to_tensor

In [3]:
import yaml

# Load the experiment's config; keys such as "batch_size", "num_layers" and
# "hidden_size" are read by later cells.
# safe_load restricts parsing to standard YAML tags, which is all a plain config
# file needs (full_load resolves a wider, Python-specific tag set).
with open("params.yaml", encoding="utf-8") as f:
    exp_params = yaml.safe_load(f)

# Later cells index exp_params[...] directly, so fail early with a clear message
# if the file is empty or not a key/value mapping.
if not isinstance(exp_params, dict):
    raise TypeError(f"params.yaml must contain a mapping, got {type(exp_params).__name__}")

In [4]:
# Start from a clean working directory: drop TensorBoard logs left by earlier
# runs, then recreate an empty artifacts folder for the saved state_dicts.
artifacts_dir = Path("artifacts/")
for stale_dir in (Path("runs"), artifacts_dir):
    if stale_dir.exists():
        shutil.rmtree(stale_dir)
artifacts_dir.mkdir()

In [5]:
# Global seed: weight initialisation, DataLoader shuffling and the Bayesian
# network's weight sampling all draw from torch's RNG, so seed it once, early,
# to make Restart & Run All reproducible. torch.manual_seed also seeds CUDA.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Device:", device)

Device: cpu


## Load the MNIST dataset

In [6]:
# load dataset

def img_transform(img):
    """Turn a PIL image into a float32 tensor.

    No rescaling or normalisation is applied here; pil_to_tensor keeps the raw
    pixel values and only the dtype is changed.
    """
    as_tensor = pil_to_tensor(img)
    return as_tensor.to(dtype=torch.float32)

# Both splits share the cache directory, download flag and per-sample transform.
mnist_common = {"root": "./data", "download": True, "transform": img_transform}
train_dataset = MNIST(train=True, **mnist_common)
test_dataset = MNIST(train=False, **mnist_common)

In [7]:
# Same batch size for both splits; only the training data is shuffled,
# so evaluation always visits the test set in a fixed order.
loader_batch_size = exp_params["batch_size"]
train_loader = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=loader_batch_size, shuffle=False)

In [8]:
NUM_CLASSES = 10  # digits 0-9

# Input shape of one sample (the image tensor is element 0 of a dataset item).
first_image, _ = train_dataset[0]
IMG_SIZE = first_image.shape
print(IMG_SIZE)

torch.Size([1, 28, 28])


## Train a basic NN

In [9]:
# Deterministic baseline: float32 weights placed on the selected device.
basic_nn = (
    ClassificationNet(IMG_SIZE, NUM_CLASSES, exp_params["num_layers"], exp_params["hidden_size"])
    .float()
    .to(device)
)

# Show the layer stack of the network.
for child in basic_nn.children():
    print(child)

Sequential(
  (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=784, out_features=128, bias=True)
  (3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): Linear(in_features=128, out_features=128, bias=True)
  (5): ReLU()
  (6): Linear(in_features=128, out_features=128, bias=True)
  (7): ReLU()
  (8): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): Linear(in_features=128, out_features=128, bias=True)
  (10): ReLU()
  (11): Linear(in_features=128, out_features=10, bias=True)
)


In [11]:
# Optimiser settings for the baseline network.
N_EPOCHS = 4
LR = 1e-3
# Adam's weight_decay is a classic (coupled) L2 penalty added to the gradient.
WEIGHT_DECAY = 1e-3

optimizer = optim.Adam(params=basic_nn.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

In [12]:
# Log the baseline under a named sub-directory, the same way the Bayesian run
# uses runs/bayes_nn/, so the two runs are easy to tell apart in TensorBoard.
# The timestamp has no spaces or ':' so it is a valid path on every OS.
writer = SummaryWriter(f"runs/basic_nn/{datetime.now():%Y%m%d-%H%M%S}")

In [13]:
global_step = 0

# try/finally so the TensorBoard writer is flushed and closed even when training
# is interrupted (e.g. KeyboardInterrupt from the notebook's stop button).
try:
    for epoch in tqdm(range(N_EPOCHS), desc="Epochs"):
        # training
        basic_nn.train()

        for imgs, targets in tqdm(train_loader, desc="Train batches", leave=False):
            imgs = imgs.to(device, torch.float32)
            targets = targets.to(device)

            logits = basic_nn(imgs)
            loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            writer.add_scalar("Train/Cross_entropy", loss.item(), global_step)
            global_step += 1

        # evaluation on the test split after every epoch
        basic_nn.eval()
        with torch.no_grad():
            log_metrics(global_step, basic_nn, test_loader, device, writer)

        # checkpoint; the same file is overwritten each epoch, so only the latest weights remain
        torch.save(basic_nn.state_dict(), artifacts_dir / "basic_nn.pkl")
finally:
    writer.close()

Epochs:   0%|          | 0/4 [00:00<?, ?it/s]

Train batches:   0%|          | 0/1875 [00:00<?, ?it/s]

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

Train batches:   0%|          | 0/1875 [00:00<?, ?it/s]

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

Train batches:   0%|          | 0/1875 [00:00<?, ?it/s]

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

Train batches:   0%|          | 0/1875 [00:00<?, ?it/s]

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

In [14]:
# Final test-set evaluation of the baseline. Switch to eval mode explicitly so this
# cell does not depend on which cell ran before it (the BatchNorm layers use their
# running statistics only in eval mode).
basic_nn.eval()
with torch.no_grad():
    basic_nn_metrics = log_metrics(0, basic_nn, test_loader, device)
print(basic_nn_metrics)

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

{'Test/Cross_entropy': 0.10904046893119812, 'Test/Accuracy': 0.9693, 'Test/Precision_macro': np.float64(0.9695484847300767), 'Test/Recall_macro': np.float64(0.9691272930732874)}


## Train a Bayesian NN

In [15]:
# Wrap a freshly built ClassificationNet (same architecture as the baseline) in
# its Bayesian counterpart, then move it to `device` like basic_nn. The training
# loop sends every batch to `device`, so without .to(device) the weights would stay
# on the CPU and training would fail with a device mismatch on CUDA.
# NOTE(review): .to() only moves registered parameters/buffers; confirm that
# MakeModuleBayessian does not keep other tensors as plain attributes.
bayes_nn = MakeModuleBayessian(
    ClassificationNet(IMG_SIZE, NUM_CLASSES, exp_params["num_layers"], exp_params["hidden_size"])
).to(device)

for module in bayes_nn.children():
    print(module)

BayessianSequential(
  (bayes_0): BayessianBatchNorm2d()
  (bayes_1): BayessianFlatten()
  (bayes_2): BayessianLinear()
  (bayes_3): BayessianBatchNorm1d()
  (bayes_4): BayessianLinear()
  (bayes_5): BayessianReLU()
  (bayes_6): BayessianLinear()
  (bayes_7): BayessianReLU()
  (bayes_8): BayessianBatchNorm1d()
  (bayes_9): BayessianLinear()
  (bayes_10): BayessianReLU()
  (bayes_11): BayessianLinear()
)


In [16]:
# The data term is plain cross-entropy. The wrapper also receives bayes_nn,
# presumably so it can add a KL term over the variational parameters (its
# weight is passed as alpha_KL in the training loop) — confirm in bayessian_loss.
data_criterion = nn.CrossEntropyLoss()
bayes_loss = MakeLossBayessian(data_criterion, bayes_nn)

In [17]:
# Optimiser settings for the Bayesian network.

# NOTE(review): N_ESTIMATES is not referenced by the training or evaluation cells
# here; confirm whether it should be passed to bayes_nn / bayes_loss.
N_ESTIMATES = 10

N_EPOCHS = 4
LR = 1e-3
# Coupled L2 penalty via Adam's weight_decay, set much lower than the baseline's
# 1e-3, presumably because the KL term already regularises the weights.
WEIGHT_DECAY = 1e-5

optimizer = optim.Adam(params=bayes_nn.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

In [18]:
# str(datetime.now()) contains a space and ':' characters, and ':' is not allowed
# in Windows paths; use a compact, filesystem-safe timestamp for the run directory.
writer = SummaryWriter(f"runs/bayes_nn/{datetime.now():%Y%m%d-%H%M%S}")

In [19]:
global_step = 0
# Constant for the whole run, so compute it once. 1/len(train_dataset) presumably
# spreads the KL term over the training set — confirm in MakeLossBayessian.
kl_weight = 1 / len(train_dataset)

# try/finally so the TensorBoard writer is flushed and closed even when training
# is interrupted (e.g. KeyboardInterrupt from the notebook's stop button).
try:
    for epoch in tqdm(range(N_EPOCHS), desc="Epochs"):
        # training
        bayes_nn.train()

        for imgs, targets in tqdm(train_loader, desc="Train batches", leave=False):
            imgs = imgs.to(device, torch.float32)
            targets = targets.to(device)

            # NOTE(review): N_ESTIMATES is defined but never passed to bayes_nn or
            # bayes_loss; confirm how many weight samples one forward pass draws.
            logits = bayes_nn(imgs)
            loss = bayes_loss(logits, targets, alpha_KL=kl_weight)

            # Clear gradients right before backward (same order as the baseline loop),
            # so a re-run after an interrupted iteration never steps on stale gradients.
            optimizer.zero_grad()
            loss["full_loss"].backward()
            optimizer.step()

            # bayes_loss returns a dict of loss components; log every one of them.
            for loss_name, loss_val in loss.items():
                writer.add_scalar(f"Train/{loss_name}", loss_val.item(), global_step)

            global_step += 1

        # evaluation on the test split after every epoch
        bayes_nn.eval()
        with torch.no_grad():
            log_metrics(global_step, bayes_nn, test_loader, device, writer)

        # checkpoint; the same file is overwritten each epoch, so only the latest weights remain
        torch.save(bayes_nn.state_dict(), artifacts_dir / "bayes_nn.pkl")
finally:
    writer.close()

Epochs:   0%|          | 0/4 [00:00<?, ?it/s]

Train batches:   0%|          | 0/1875 [00:00<?, ?it/s]

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

Train batches:   0%|          | 0/1875 [00:00<?, ?it/s]

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

Train batches:   0%|          | 0/1875 [00:00<?, ?it/s]

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

Train batches:   0%|          | 0/1875 [00:00<?, ?it/s]

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

In [None]:
# Compact per-parameter summary instead of displaying every tensor in full
# (the underlying ClassificationNet's first Linear layer alone has 784x128 weights).
with torch.no_grad():
    for p_name, param in bayes_nn.named_parameters():
        print(
            f"{p_name}: shape={tuple(param.shape)}, "
            f"mean={param.mean().item():.4g}, "
            f"min={param.min().item():.4g}, "
            f"max={param.max().item():.4g}"
        )

In [21]:
# Final test-set evaluation of the Bayesian network. Switch to eval mode explicitly
# so this cell does not depend on which cell ran before it.
# NOTE(review): presumably the Bayessian layers (e.g. BayessianBatchNorm) also
# behave differently in eval mode — confirm in variational_gaussian.bayes_nn.
bayes_nn.eval()
with torch.no_grad():
    bayes_nn_metrics = log_metrics(0, bayes_nn, test_loader, device)
print(bayes_nn_metrics)

Test batchs:   0%|          | 0/313 [00:00<?, ?it/s]

{'Test/Cross_entropy': 0.68235182762146, 'Test/Accuracy': 0.8938, 'Test/Precision_macro': np.float64(0.8940054198161194), 'Test/Recall_macro': np.float64(0.8926892517466392)}


In [22]:
# Extract a point-estimate ("MAP") network from the trained Bayesian model.
# NOTE(review): prune=True presumably drops weights the posterior deems unimportant,
# and the "weight tensor(k)" / "bias tensor(k)" lines printed below come from inside
# get_map_module — they look like per-layer pruned counts; confirm in
# variational_gaussian.bayes_nn.
map_net = bayes_nn.get_map_module(prune=True)

weight tensor(0)
bias tensor(0)
weight tensor(2)
bias tensor(0)
weight tensor(0)
bias tensor(0)
weight tensor(0)
bias tensor(0)
weight tensor(0)
bias tensor(0)
weight tensor(0)
bias tensor(0)
weight tensor(0)
bias tensor(0)
weight tensor(1)
bias tensor(0)
