# Example Experiment
> Experiment using Repeated MNIST and BatchBALD vs BALD vs random sampling

This notebook ties everything together and runs an AL loop.

In [1]:
# experiment

#import blackhc.project.script
from tqdm.auto import tqdm

In [33]:
# experiment

import math
import time
import torch
from torch import nn as nn
from torch.nn import functional as F

from batchbald_redux import (
    active_learning,
    batchbald,
    consistent_mc_dropout,
    joint_entropy,
    repeated_mnist,
)

Let's define our Bayesian CNN model that we will use to train MNIST.

In [8]:
# experiment


class BayesianCNN(consistent_mc_dropout.BayesianModule):
    def __init__(self, num_classes=10):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = consistent_mc_dropout.ConsistentMCDropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = consistent_mc_dropout.ConsistentMCDropout2d()
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = consistent_mc_dropout.ConsistentMCDropout()
        self.fc2 = nn.Linear(128, num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        input = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        input = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(input)), 2))
        input = input.view(-1, 1024)
        input = F.relu(self.fc1_drop(self.fc1(input)))
        input = self.fc2(input)
        input = F.log_softmax(input, dim=1)

        return input

In [40]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import wandb

In [21]:
TRAIN_SIZE = 20
VAL_SIZE = 300
MONITOR_SIZE=300
POOL_SIZE = 60000 - TRAIN_SIZE - VAL_SIZE -MONITOR_SIZE


transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

mnist_train = MNIST(".", train=True, download=True, transform=transform)
mnist_test = MNIST(".", train=False, download=True, transform=transform)


train_set, val_set, pool_set,monitor_set = random_split(
    mnist_train, [TRAIN_SIZE, VAL_SIZE, POOL_SIZE,MONITOR_SIZE]
)
train_loader = DataLoader(
    dataset=train_set, batch_size=TRAIN_SIZE, shuffle=True
)
val_loader = DataLoader(dataset=val_set, batch_size=VAL_SIZE, shuffle=True)
pool_loader = DataLoader(
    dataset=pool_set, batch_size=POOL_SIZE, shuffle=True
)

monitor_loader = DataLoader(
    dataset=pool_set, batch_size=MONITOR_SIZE, shuffle=True
)
test_loader = DataLoader(
    dataset=mnist_test, batch_size=312, shuffle=True
)

Grab our dataset, we'll use Repeated-MNIST. We will acquire to samples for each class for our initial training set.

In [41]:
# experiment

train_dataset, test_dataset = repeated_mnist.create_repeated_MNIST_dataset(num_repetitions=1, add_noise=False)
train_dataset,monitor_set = random_split(
    train_dataset, [59900,100]
)
num_initial_samples = 20
num_classes = 10

initial_samples = active_learning.get_balanced_sample_indices(
    repeated_mnist.get_targets(train_dataset), num_classes=num_classes, n_per_digit=num_initial_samples / num_classes
)

In [42]:
len(train_dataset)

59900

For this example, we are going to take two shortcuts that will reduce the performance:
* we discard most of the training set and only keep 20k samples; and
* we don't implement early stopping or epoch-wise training.

Instead, we always train by drawing 24576 many samples from the training set. This will overfit in the beginning and underfit later, but it still is sufficient to achieve 90% accuracy with 105 samples in the training set. 

In [43]:
# experiment
max_training_samples = 150
acquisition_batch_size = 5
num_inference_samples = 100
num_test_inference_samples = 5
num_samples = 100000

test_batch_size = 512
batch_size = 64
scoring_batch_size = 128
training_iterations = 4096 * 6

use_cuda = torch.cuda.is_available()

print(f"use_cuda: {use_cuda}")

device = "cuda" if use_cuda else "cpu"

kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)

active_learning_data = active_learning.ActiveLearningData(train_dataset)

# Split off the initial samples first.
active_learning_data.acquire(initial_samples)

# THIS REMOVES MOST OF THE POOL DATA. UNCOMMENT THIS TO TAKE ALL UNLABELLED DATA INTO ACCOUNT!
#active_learning_data.extract_dataset_from_pool(40000)

train_loader = torch.utils.data.DataLoader(
    active_learning_data.training_dataset,
    sampler=active_learning.RandomFixedLengthSampler(active_learning_data.training_dataset, training_iterations),
    batch_size=batch_size,
    **kwargs,
)

pool_loader = torch.utils.data.DataLoader(
    active_learning_data.pool_dataset, batch_size=scoring_batch_size, shuffle=False, **kwargs
)


use_cuda: True


In [None]:
wandb.init(project="ActiveLearning_1124")
wandb.run.name = 'BALD_MNIST_200001124'
wandb.run.save()


# Run experiment
test_accs = []
test_loss = []
added_indices = []

pbar = tqdm(initial=len(active_learning_data.training_dataset), total=max_training_samples, desc="Training Set Size")
epoch=0
model = BayesianCNN(num_classes).to(device=device)
optimizer = torch.optim.Adam(model.parameters())
while True:

    epoch+=1
    model.train()
    train_loss=0
    # Train
    for data, target in tqdm(train_loader, desc="Training", leave=False):
        data = data.to(device=device)
        target = target.to(device=device)

        optimizer.zero_grad()

        prediction = model(data, 1).squeeze(1)
        loss = F.nll_loss(prediction, target)
        train_loss+=loss.item()
        loss.backward()
        optimizer.step()
                                                   
    train_loss /= len(train_loader)
    # Test
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc="Testing", leave=False):
            data = data.to(device=device)
            target = target.to(device=device)

            prediction = torch.logsumexp(model(data, num_test_inference_samples), dim=1) - math.log(
                num_test_inference_samples
            )
            loss += F.nll_loss(prediction, target, reduction="sum")
        
            prediction = prediction.max(1)[1]
            correct += prediction.eq(target.view_as(prediction)).sum().item()

    loss /= len(test_loader.dataset)

    percentage_correct = 100.0 * correct / len(test_loader.dataset)


                                                   
    if len(active_learning_data.training_dataset) >= max_training_samples:
        break

    # Acquire pool predictions
                                                   
    acquir_start=time.time()
    N = len(active_learning_data.pool_dataset)
    logits_N_K_C = torch.empty((N, num_inference_samples, num_classes), dtype=torch.double, pin_memory=use_cuda)
    
    with torch.no_grad():
        model.eval()
        
        for i, (data, _) in enumerate(tqdm(pool_loader, desc="Evaluating Acquisition Set", leave=False)):
            data = data.to(device=device)

            lower = i * pool_loader.batch_size
            upper = min(lower + pool_loader.batch_size, N)
            logits_N_K_C[lower:upper].copy_(model(data, num_inference_samples).double(), non_blocking=True)

            
    with torch.no_grad():
#         candidate_batch = batchbald.get_batchbald_batch(
#             logits_N_K_C, acquisition_batch_size, num_samples, dtype=torch.double, device=device
#         )
        candidate_batch = batchbald.get_bald_batch(
            logits_N_K_C, acquisition_batch_size, dtype=torch.double, device=device
        )
                                                   
    acquire_end=acquir_start-time.time()
    wandb.log({"Epoch":epoch,
               "Num_Pool":len(pool_loader)
               "Num_Train":len(training_dataset)
               "Test Loss": loss,
               'Train Loss': train_loss,
              "Test Accuracy":percentage_correct,
               "Query Time":-acquire_end})

                                                   
    targets = repeated_mnist.get_targets(active_learning_data.pool_dataset)
    dataset_indices = active_learning_data.get_dataset_indices(candidate_batch.indices)

    print("Dataset indices: ", dataset_indices)
    print("Scores: ", candidate_batch.scores)
    print("Labels: ", targets[candidate_batch.indices])

    active_learning_data.acquire(candidate_batch.indices)
    added_indices.append(dataset_indices)
    pbar.update(len(dataset_indices))

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.12.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




Training Set Size:  13%|#3        | 20/150 [00:00<?, ?it/s]

Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/156 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19880 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19880 [00:00<?, ?it/s]

Dataset indices:  [11504 16265 45267 22626 58655]
Scores:  [1.3371315741670893, 1.2840991833492765, 1.259585630032177, 1.247872067496476, 1.2430764680138355]
Labels:  tensor([2, 6, 5, 6, 9])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/156 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19875 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19875 [00:00<?, ?it/s]

Dataset indices:  [47390 10059 40329  1000 25137]
Scores:  [1.3126765715210704, 1.2899330328994778, 1.2839473608201857, 1.271654106883802, 1.267077879773238]
Labels:  tensor([7, 3, 3, 3, 2])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/156 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19870 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19870 [00:00<?, ?it/s]

Dataset indices:  [51687 20636 13466 18147 42797]
Scores:  [1.3697880689368809, 1.3375052365825908, 1.329340756522325, 1.322733408558638, 1.3198637128780903]
Labels:  tensor([5, 2, 3, 3, 3])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/156 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19865 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19865 [00:00<?, ?it/s]

Dataset indices:  [55396 13127 47377  4079 19955]
Scores:  [1.3310214270691543, 1.3019635570317005, 1.28756920713807, 1.279090520821545, 1.275343704621409]
Labels:  tensor([6, 0, 0, 4, 5])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/156 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19860 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19860 [00:00<?, ?it/s]

Dataset indices:  [53486 55131 20808 56443  7731]
Scores:  [1.330598704211786, 1.303957533518792, 1.302388199960565, 1.2810615329865356, 1.27600770829741]
Labels:  tensor([5, 8, 5, 3, 8])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/156 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19855 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19855 [00:00<?, ?it/s]

Dataset indices:  [25693 16074   252 17257 54913]
Scores:  [1.3971341488951907, 1.3190735995908967, 1.3125291360637168, 1.3120757085992214, 1.2892533574609386]
Labels:  tensor([9, 2, 4, 8, 0])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/156 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19850 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19850 [00:00<?, ?it/s]

Dataset indices:  [22420  4098 42433 41469  3580]
Scores:  [1.3753962463348717, 1.3592637069532345, 1.3115099819548446, 1.308221824362246, 1.2973480618696858]
Labels:  tensor([0, 9, 0, 4, 0])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/156 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19845 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19845 [00:00<?, ?it/s]

Dataset indices:  [50678 32447  6296 53170 22518]
Scores:  [1.3373361807478488, 1.278140804499416, 1.2682258135163482, 1.2578451633772576, 1.2462278535485924]
Labels:  tensor([2, 2, 9, 8, 4])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/155 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19840 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19840 [00:00<?, ?it/s]

Dataset indices:  [54765 32899  5526 57516 45945]
Scores:  [1.2704354553099289, 1.2627667336384292, 1.2458695930591912, 1.2411978507607926, 1.2019321650581771]
Labels:  tensor([4, 5, 5, 8, 8])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/155 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19835 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19835 [00:00<?, ?it/s]

Dataset indices:  [42854 23377 39908 50881 25810]
Scores:  [1.3131297043289765, 1.2673920631047837, 1.2667417101898082, 1.2580927065328944, 1.256854244353568]
Labels:  tensor([6, 7, 8, 1, 6])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluating Acquisition Set:   0%|          | 0/155 [00:01<?, ?it/s]

Conditional Entropy:   0%|          | 0/19830 [00:00<?, ?it/s]

Entropy:   0%|          | 0/19830 [00:00<?, ?it/s]

Dataset indices:  [26075 42067  6337 23716 46436]
Scores:  [1.3621845313869982, 1.3410256819863273, 1.3219267223569886, 1.3177197742707314, 1.2987949702224544]
Labels:  tensor([0, 5, 3, 5, 8])


Training:   0%|          | 0/384 [00:01<?, ?it/s]

Testing:   0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# experiment
import matplotlib.pyplot as plt
import numpy as np

plt.plot(np.arange(num_initial_samples, max_training_samples + 1, acquisition_batch_size), test_accs)
plt.xlabel("# training samples")
plt.ylabel("Accuracy")
plt.hlines(90, 20, 150, linestyles="dashed", color="r")
plt.show()

In [38]:
model = BayesianCNN(num_classes).to(device=device)

with torch.no_grad():
    model.eval()

    for i, (data, _) in enumerate(tqdm(pool_loader, desc="Evaluating Acquisition Set", leave=False)):
        data = data.to(device=device)

        lower = i * pool_loader.batch_size
        upper = min(lower + pool_loader.batch_size, N)
        logits_N_K_C[lower:upper].copy_(model(data, num_inference_samples).double(), non_blocking=True)

Evaluating Acquisition Set:   0%|          | 0/156 [00:01<?, ?it/s]

NameError: name 'N' is not defined