In [None]:
import sys

# numpy to manipulate and make computations on arrays
import numpy as np

# libraries to process datasets
#from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset

# importing evaluation metrics and splitting function from sklearn 
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, confusion_matrix, classification_report

# importing PyTorch libraries
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import v2

# wrapper allowing to use PyTorch with sklearn
from skorch import NeuralNetClassifier

# Active learning library
from collections.abc import Mapping
from typing import Callable
from sklearn.base import BaseEstimator
from modAL.utils.data import modALinput
from modAL.utils.selection import multi_argmax, shuffled_argmax
from modAL.dropout import default_logits_adaptor, set_dropout_mode, _bald_divergence
from modAL.models import DeepActiveLearner

# libraries for plotting
import matplotlib.pyplot as plt
import seaborn as sns

# import wandb for experiment tracking and login
# instructions to wandb login here: <https://docs.wandb.ai/quickstart#:~:text=Sign%20up%20for%20a%20free,Python%203%20environment%20using%20pip%20>
import wandb
wandb.login()


## Dataset

In [None]:
# Load the four KMNIST archives (train/test images and labels).
# The array inside each .npz file is read from the key 'arr_0'.
# Adjust the paths below if the dataset is stored somewhere else on your machine.
train_data, train_labels, test_data, test_labels = (
    np.load(archive_path)['arr_0']
    for archive_path in (
        "kmnist-master/dataset/kmnist-train-imgs.npz",
        "kmnist-master/dataset/kmnist-train-labels.npz",
        "kmnist-master/dataset/kmnist-test-imgs.npz",
        "kmnist-master/dataset/kmnist-test-labels.npz",
    )
)

# number of distinct classes found in the training labels
num_class = len(np.unique(train_labels))

In [None]:
# Hold out 20% of the training samples as a validation set (fixed seed, so the split is repeatable).
# NOTE: this cell rebinds train_data / train_labels, so running it a second time splits the
# already-reduced training set again.
train_data, val_data, train_labels, val_labels = train_test_split(
    train_data,
    train_labels,
    test_size=0.2,
    random_state=42,
)

In [None]:
# Double-check the normalization statistics (they are also published for this dataset).
# Raw pixel values range from 0 (black) to 255 (white), so dividing the mean and the
# standard deviation by 255 expresses them on the [0, 1] scale used after tensor conversion.
mean = train_data.mean()
std = train_data.std()

print(train_data.shape)
print(mean/255)
print(std/255)

In [None]:
# Show one randomly chosen training image together with its label.
# np.random.randint excludes its upper bound, so the bound must be len(train_data)
# (not len(train_data) - 1) for the last sample to be selectable as well.
ind = np.random.randint(0, len(train_data))
plt.imshow(train_data[ind].reshape((28, 28)), cmap='gray')
plt.title("Label: {}".format(train_labels[ind]))
plt.show()

In [None]:
# dataset class for handling KMNIST, consisting of images and corresponding labels, with optional transformations

class KMNISTDataset(Dataset):
    """Map-style dataset pairing KMNIST images with their labels.

    Args:
        images: indexable collection of images; each item is reshaped to 28x28.
        labels: indexable collection of labels aligned with `images`.
        transform: optional callable applied to the 28x28 uint8 image.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        # the dataset size is defined by the number of images
        return len(self.images)

    def __getitem__(self, n):
        """Return the `(image, label)` pair stored at position `n`."""
        pixels = self.images[n].reshape((28, 28)).astype(np.uint8)
        target = self.labels[n]
        # the transform is optional; without it the raw uint8 array is returned
        return (self.transform(pixels) if self.transform else pixels, target)


In [None]:
classes = range(10)

# Normalization statistics on the [0, 1] pixel scale, shared by every split.
# Validation and test data must be normalized with the same mean/std as the training data;
# otherwise the model is evaluated on inputs scaled differently from those it was trained on.
# (The validation/test transform previously used std=0.3082 while training used 0.3483.)
KMNIST_MEAN = [0.1917]
KMNIST_STD = [0.3483]

# Define transformations that will convert and normalize data

# Training data is additionally augmented (random affine warp and brightness/contrast jitter)
train_transform = v2.Compose([
    v2.ToPILImage(),
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
    v2.RandomAffine(degrees=20, translate=(0.1,0.1), scale=(0.9, 1.1)),
    v2.ColorJitter(brightness=0.2, contrast=0.2),
    v2.Normalize(mean=KMNIST_MEAN, std=KMNIST_STD)
])

# Validation data: conversion and normalization only, no augmentation
val_normalization = v2.Compose([
    v2.ToPILImage(),
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=KMNIST_MEAN, std=KMNIST_STD)
])

# Test data goes through exactly the same pipeline as validation data
test_normalization = val_normalization

# Apply transformations and load datasets.
# The batch sizes are at least as large as each split, so a single batch holds the whole split.

train_dataset = KMNISTDataset(train_data, train_labels, transform = train_transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=48000, shuffle = True)

val_dataset = KMNISTDataset(val_data, val_labels, transform = val_normalization)

val_loader = DataLoader(dataset=val_dataset, batch_size=12000,shuffle = True)

test_dataset = KMNISTDataset(test_data, test_labels, transform = test_normalization)

test_loader = DataLoader(dataset=test_dataset, batch_size=12000,shuffle = True)

### A series of sanity checks will follow. This is to make sure data has been preprocessed well. 

Also, `next(iter(loader))` is called to fetch the first batch from each loader; the batch sizes are chosen so that this single batch contains the whole split.

Normalization should give a mean close to 0 and a standard deviation close to 1

In [None]:
# Sanity check on the training split: fetch one batch, then print its value range,
# its mean/std (after normalization these should be around 0 and close to 1)
# and display the first sample with its label
X_train, y_train = next(iter(train_loader))

print(X_train.data.min(), X_train.data.max(), sep="\n")
print("Mean:", X_train.data.mean())
print("Std:", X_train.data.std())
print("Label:", classes[y_train[0]])
plt.imshow(X_train[0].data.reshape(28, 28), cmap="gray")

In [None]:
# Sanity check on the validation split: fetch one batch, then print its value range
# and mean/std, and display the first sample with its label
X_val, y_val = next(iter(val_loader))

print(X_val.data.min(), X_val.data.max(), sep="\n")
print("Mean:", X_val.data.mean())
print("Std:", X_val.data.std())
print("Label:", classes[y_val[0]])
plt.imshow(X_val[0].data.reshape(28, 28), cmap="gray")

In [None]:
# Sanity check on the test split: fetch one batch, then print its value range
# and mean/std, and display the first sample with its label
X_test, y_test = next(iter(test_loader))

print(X_test.data.min(), X_test.data.max(), sep="\n")
print("Mean:", X_test.data.mean())
print("Std:", X_test.data.std())
print("Label:", classes[y_test[0]])
plt.imshow(X_test[0].data.reshape(28, 28), cmap="gray")

Another sanity check: show 5 samples from each of the three sub-datasets (training, validation and test)

In [None]:
# Function to show data
def show_images(images, labels, title="Images"):
    """Plot up to the first five images side by side, each titled with its label.

    Args:
        images: indexable collection of images; every item is squeezed before plotting.
        labels: indexable collection of labels aligned with `images`.
        title: title placed above the whole figure.
    """
    plt.figure(figsize=(12, 6))
    n_panels = min(5, len(images))
    # subplot positions are 1-based, sample indices are 0-based
    for column in range(1, n_panels + 1):
        plt.subplot(1, 5, column)
        plt.imshow(images[column - 1].squeeze(), cmap='gray')
        plt.title(f"Label: {labels[column - 1]}")
        plt.axis('off')
    plt.suptitle(title)
    plt.show()

# Tensors -> numpy arrays, one per split, for plotting
X_train_numpy, X_val_numpy, X_test_numpy = (
    split.numpy() for split in (X_train, X_val, X_test)
)

# Plot the first samples of the training, validation and test sets (in that order)
show_images(X_train_numpy, y_train, title="Training Set Samples")
show_images(X_val_numpy, y_val, title="Validation Set Samples")
show_images(X_test_numpy, y_test, title="Test Set Samples")


### Assemble initial data and generate pool

In [None]:
# Draw n_initial distinct training indices at random; the samples at those
# positions form the initial labelled set
n_initial = 1000
initial_idx = np.random.choice(len(X_train), size=n_initial, replace=False)
X_initial, y_initial = X_train[initial_idx], y_train[initial_idx]

In [None]:
# generate the pool
# remove the initial data from the training dataset:
# np.delete drops the rows listed in initial_idx along axis 0, so the pool holds every
# training sample that was not drawn into the initial labelled set
# NOTE(review): X_train / y_train come out of a DataLoader batch, so they are presumably torch
# tensors rather than numpy arrays — confirm that np.delete returns the type expected by the
# later torch.cat calls on rows of X_pool / y_pool
X_pool = np.delete(X_train, initial_idx, axis=0)
y_pool = np.delete(y_train, initial_idx, axis=0)

## Model

In [None]:
class CNNmodel(nn.Module):
    """Convolutional classifier for single-channel 28x28 images with 10 output classes.

    Layout: dropout on the input (p=0.7), two three-convolution stages (32 and 64 channels),
    one single-convolution stage (128 channels) and a linear layer mapping the 128 flattened
    features to 10 outputs. Every convolution is followed by ReLU and batch normalisation;
    Dropout2d (p=0.4) is used throughout the convolutional stages.
    `forward` returns log-probabilities (log-softmax over the class dimension).
    """

    @staticmethod
    def _conv_unit(in_channels, out_channels, stride=1):
        """Layers of one unit: 3x3 convolution (padding 1) -> ReLU -> BatchNorm2d."""
        return [
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]

    @classmethod
    def _triple_conv_stage(cls, in_channels, out_channels):
        """Stage made of three conv units; the last one is strided and followed by max-pooling."""
        return nn.Sequential(
            *cls._conv_unit(in_channels, out_channels),
            nn.Dropout2d(0.4),
            *cls._conv_unit(out_channels, out_channels),
            nn.Dropout2d(0.4),
            *cls._conv_unit(out_channels, out_channels, stride=2),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.4),
        )

    def __init__(self):
        super().__init__()

        # sub-modules are created in the same order as before (conv1, conv2, conv3, fc,
        # input_dropout) so module indices and parameter names stay unchanged
        self.conv1 = self._triple_conv_stage(1, 32)
        self.conv2 = self._triple_conv_stage(32, 64)
        self.conv3 = nn.Sequential(
            *self._conv_unit(64, 128),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.4),
        )
        self.fc = nn.Sequential(nn.Linear(128, 10))
        self.input_dropout = nn.Dropout(0.7)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        features = self.input_dropout(x)
        for stage in (self.conv1, self.conv2, self.conv3):
            features = stage(features)

        # flatten everything except the batch dimension before the linear layer
        flat = features.view(features.size(0), -1)
        return F.log_softmax(self.fc(flat), dim=1)

Create an instance of model

In [None]:
# instantiate the network; this same object is later passed to the skorch NeuralNetClassifier
model = CNNmodel()

In [None]:
# get overview of model
# torchsummary's `summary` defaults to device="cuda". The model has only just been created
# and has not been moved to any device, so a CPU dummy input is requested explicitly;
# otherwise the call fails with a device mismatch on machines where CUDA is available.
from torchsummary import summary
summary(model, (1, 28, 28), device="cpu")

Set device to gpu

In [None]:
# Select the compute device: CUDA (NVIDIA GPUs) if present, otherwise MPS
# (the GPU backend of Apple-silicon chips), otherwise the CPU.
# This way the notebook runs unchanged on Apple and non-Apple machines.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

NNClassifier to wrap PyTorch model and set hyperparameters

In [None]:
# Wrap the CNN in a skorch classifier so it can be driven through the scikit-learn API.
# train_split=None: no internal validation split, all data passed to fit is used for training.
classifier = NeuralNetClassifier(
    model,
    criterion=nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    lr=0.0001,
    max_epochs=80,
    train_split=None,
    device=device,
    verbose=1,
)

## Query Strategy

In [None]:
def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
                    random_tie_break: bool = False, dropout_layer_indexes: list = [],
                    num_cycles: int = 50, sample_per_forward_pass: int = 1000,
                    logits_adaptor: Callable[[
                        torch.tensor, modALinput], torch.tensor] = default_logits_adaptor,
                    **mc_dropout_kwargs,) -> np.ndarray:
    """
        MC-dropout BALD (Bayesian Active Learning by Disagreement) query strategy.

        Runs `num_cycles` stochastic forward passes with dropout switched on, scores every
        pool sample with the BALD divergence of those passes and selects the `n_instances`
        samples with the largest score.

        Based on the work of:
            Deep Bayesian Active Learning with Image Data.
            (Yarin Gal, Riashat Islam, and Zoubin Ghahramani. 2017.)
            Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning.
            (Yarin Gal and Zoubin Ghahramani. 2016.)
            Bayesian Active Learning for Classification and Preference Learning.
            (Neil Houlsby, Ferenc Huszár, Zoubin Ghahramani, and Máté Lengyel. 2011.)

        Args:
            classifier: The classifier for which the labels are to be queried.
            X: The pool of samples to query from.
            n_instances: Number of samples to be queried.
            random_tie_break: If True, the selection is made with `shuffled_argmax`, which
                randomizes the order of equal scores; otherwise `multi_argmax` is used.
            dropout_layer_indexes: Indexes of the dropout layers which should be activated.
                Choose indices from: list(torch_model.modules())
            num_cycles: Number of forward passes with activated dropout.
            sample_per_forward_pass: Max. number of samples per forward pass.
                Memory use mainly depends on this value (small number -> small allocation).
            logits_adaptor: Callable which adapts the output of a forward pass to the
                vector format required by the vectorised metric functions.
            **mc_dropout_kwargs: Accepted for call compatibility; not used in this function.

        Returns:
            Whatever the selection helper (`multi_argmax` / `shuffled_argmax`) returns for the
            BALD scores: the indices of the chosen instances and, per the original
            documentation of this strategy, their metric values.
    """
    # one array of class probabilities per stochastic forward pass
    mc_samples = get_predictions(
        classifier,
        X,
        dropout_layer_indexes,
        num_cycles,
        sample_per_forward_pass,
        logits_adaptor,
    )

    # BALD score (Bayesian active learning by disagreement) for every pool sample
    bald_scores = _bald_divergence(mc_samples)

    pick_top = shuffled_argmax if random_tie_break else multi_argmax
    return pick_top(bald_scores, n_instances=n_instances)

In [None]:
def _softmax_ignoring_nan(prediction):
    """Row-wise softmax over the last dimension that keeps NaN entries as NaN.

    NaN entries are excluded from the normalisation of their row: they are replaced by
    -inf (which contributes exp(-inf) = 0 to the softmax denominator) and restored afterwards.
    """
    nan_mask = prediction.isnan()
    probabilities = prediction.masked_fill(nan_mask, float('-inf')).softmax(-1)
    probabilities[nan_mask] = float('nan')
    return probabilities


def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_indexes: list = [],
                    num_predictions: int = 50, sample_per_forward_pass: int = 1000,
                    logits_adaptor: Callable[[torch.tensor, modALinput], torch.tensor] = default_logits_adaptor):
    """
        Runs the classifier `num_predictions` times on the input X with the selected dropout
        layers in train mode and collects the class probabilities of every run.

        The softmax is applied per sample (over the last dimension). Indexing with a boolean
        mask flattens the tensor, so the previous `prediction[mask].softmax(-1)` normalised over
        all entries of a forward-pass chunk at once; the resulting "probabilities" then depended
        on the chunk size and on the other samples in the chunk.

        Args:
            classifier: The classifier for which the labels are to be queried.
            X: The pool of samples to query from (a tensor, or a mapping of tensors).
            dropout_layer_indexes: Indexes of the dropout layers which should be activated
                Choose indices from : list(torch_model.modules())
                (the default list is only read, never modified).
            num_predictions: Number of predictions which should be made
            sample_per_forward_pass: max. sample number for each forward pass.
                The allocated RAM does mainly depend on this.
                Small number --> small RAM allocation
            logits_adaptor: Callable which can be used to adapt the output of a forward pass
                to the required vector format for the vectorised metric functions
        Return:
            predictions: list with one numpy array of probabilities per prediction run
        Raises:
            AssertionError: if num_predictions or sample_per_forward_pass is not positive.
            RuntimeError: if X is neither a mapping nor a tensor.
    """

    assert num_predictions > 0, 'num_predictions must be larger than zero'
    assert sample_per_forward_pass > 0, 'sample_per_forward_pass must be larger than zero'

    # Split the input into forward-pass sized chunks first, so that unsupported input is
    # rejected before the dropout mode of the model has been touched.
    if isinstance(X, Mapping):  # check for dict
        split_args = []
        for k, v in X.items():
            # create sub-dictionary split for each forward pass with same keys&values
            for split_idx, split in enumerate(torch.split(v.detach(), sample_per_forward_pass)):
                if len(split_args) <= split_idx:
                    split_args.append({})
                split_args[split_idx][k] = split
    elif torch.is_tensor(X):  # check for tensor
        split_args = torch.split(X.detach(), sample_per_forward_pass)
    else:
        raise RuntimeError(
            "Error in model data type, only dict or tensors supported")

    predictions = []

    # set dropout layers to train mode
    set_dropout_mode(classifier.estimator.module_,
                     dropout_layer_indexes, train_mode=True)
    try:
        for _ in range(num_predictions):

            probas = []

            for samples in split_args:
                # call Skorch infer function to perform model forward pass
                # In comparison to: predict(), predict_proba() the infer()
                # does not change train/eval mode of other layers
                with torch.no_grad():
                    logits = classifier.estimator.infer(samples)
                    prediction = logits_adaptor(logits, samples)
                    probas.append(_softmax_ignoring_nan(prediction))

            # Tensors on the mps device cannot be converted to numpy directly,
            # so the result is moved to the CPU before the conversion.
            predictions.append(torch.cat(probas).cpu().detach().numpy())
    finally:
        # set dropout layers back to eval, also when a forward pass fails
        set_dropout_mode(classifier.estimator.module_,
                         dropout_layer_indexes, train_mode=False)

    return predictions

## Active Learning Loop

In [None]:
# Build the active learner around the skorch-wrapped PyTorch model.
# Query strategies listed as options:
#     entropy_sampling
#     margin_sampling
#     mc_dropout_bald  (the one selected here)
learner = DeepActiveLearner(estimator=classifier, query_strategy=mc_dropout_bald)

# initial training on the randomly drawn labelled set
model.train()
learner.teach(X_initial, y_initial)

# score of the initially trained model on the unlabelled pool
pool_score = learner.score(X_pool, y_pool)
print("Score from sklearn: {}".format(pool_score))

In [None]:
# number of query rounds performed in the active learning loop
n_queries = 10

# the labelled set starts out as the initial sample; queried instances are appended to it
X_teach, y_teach = X_initial, y_initial

In [None]:
# Confusion matrix on the labelled set, before querying and after the initial training.
# `labels=` pins the matrix to one row/column per class even if a class happens to be
# absent from the labelled set, so position i always corresponds to class index i.
before_train_predictions = learner.predict(X_teach)
cm = confusion_matrix(y_teach, before_train_predictions, labels=list(range(num_class)))
print("Before training - Confusion Matrix")
print(cm)

# One tick label per class (class index i is displayed as "Class i+1").
# Passing a label for every row/column lets seaborn place the ticks itself instead of
# repositioning a partial label list by hand.
class_names = [f"Class {i + 1}" for i in range(num_class)]

# plotting the confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names)

# keep the original label orientation: vertical on the x axis, horizontal on the y axis
plt.xticks(rotation=90)
plt.yticks(rotation=0)

plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Before training - Confusion Matrix")
plt.show()

In [None]:
def _evaluate_split(split_name, y_true, y_pred, step=None):
    """Compute, print and log the metrics of one data split.

    Args:
        split_name: prefix used in the printed lines and wandb keys ("Train", "Val", "Test").
        y_true: ground-truth labels.
        y_pred: predicted labels.
        step: wandb step to log at; None lets wandb choose the step itself.

    Returns:
        Tuple (accuracy, f1, precision, recall), each in percent; F1, precision and
        recall are weighted averages over the classes.
    """
    accuracy = accuracy_score(y_true, y_pred)*100
    f1 = f1_score(y_true, y_pred, average='weighted')*100
    precision = precision_score(y_true, y_pred, average='weighted')*100
    recall = recall_score(y_true, y_pred, average='weighted')*100

    print("{} Accuracy: {}".format(split_name, accuracy))
    print("{} F1 Score: {:.4f}".format(split_name, f1))
    print("{} Precision: {:.4f}".format(split_name, precision))
    print("{} Recall: {:.4f}".format(split_name, recall))

    wandb.log({
        "{} Accuracy".format(split_name): accuracy,
        "{} F1 Score".format(split_name): f1,
        "{} Precision".format(split_name): precision,
        "{} Recall".format(split_name): recall,
    }, step=step)

    return accuracy, f1, precision, recall


def _plot_confusion_matrix(matrix, title):
    """Draw a confusion matrix as an annotated heatmap with one tick label per class."""
    # class index i is displayed as "Class i+1"
    class_names = [f"Class {i + 1}" for i in range(matrix.shape[0])]

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names,
                yticklabels=class_names)

    # vertical labels on the x axis, horizontal labels on the y axis
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)

    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(title)
    plt.show()


# Experiment settings, defined once and used both in the wandb config and in the query
# call, so that what is logged always matches what is actually run.
n_runs = 10
n_instances = 100               # number of samples queried per round
num_cycles = 50                 # dropout forward passes performed per query
sample_per_forward_pass = 250   # max. number of samples per forward pass

# NOTE(review): learner, X_teach / y_teach and X_pool / y_pool are not reset at the start of
# a run, so every run after the first continues from the model, labelled set and pool left
# behind by the previous run — confirm whether the runs are meant to be independent.
for run_number in range(n_runs):
    # initialize wandb tracking, set configuration
    # NOTE(review): "batch_size": 60000 does not match any DataLoader batch size in this
    # notebook — confirm what it is meant to record.
    wandb.init(
        project="MC_BALD_KMNIST",
        config={
            "model": "CNNmodel",
            "dataset": "KMNIST",
            "split": "48k/12k/10k",
            "batch_size": 60000,
            "loss": "CrossEntropyLoss",
            "optimization": "Adam",
            "learning_rate": 0.0001,
            "dropout": "0.4/0.7",
            "epochs": 80,
            "n_queries": n_queries,
            "initial": n_initial,
            "n_instances": n_instances,
            "num_cycles": num_cycles
        },
    )
    wandb_step = 1

    for idx in range(n_queries):
        print('Query no. %d' % (idx + 1))

        # Train/eval mode depends on the strategy and experimental setup:
        # - for non-MC-dropout strategies, call model.eval() before querying;
        # - mc_dropout_bald switches the dropout layers to train mode itself before it
        #   computes the uncertainties (and back to eval mode afterwards).
        # NB: for non-mc_dropout_bald strategies, remove "sample_per_forward_pass, num_cycles".
        # NOTE(review): depending on the modAL version, the second returned value may be the
        # queried instances rather than metric values — confirm before using it.
        query_idx, metric_values = learner.query(
            X_pool,
            n_instances=n_instances,
            sample_per_forward_pass=sample_per_forward_pass,
            num_cycles=num_cycles,
        )

        # Add queried instances
        X_teach = torch.cat((X_teach, X_pool[query_idx]))
        y_teach = torch.cat((y_teach, y_pool[query_idx]))

        # set training mode when updating the model with labelled instances
        model.train()
        learner.teach(X_teach, y_teach)

        # remove queried instance from pool
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)

        # Evaluate and log training metrics (dropout off)
        model.eval()
        train_predictions = learner.predict(X_teach)
        train_accuracy, train_f1, train_precision, train_recall = _evaluate_split(
            "Train", y_teach, train_predictions, step=wandb_step)

        # Evaluate and log validation metrics at the same wandb step
        val_predictions = learner.predict(X_val)
        val_accuracy, val_f1, val_precision, val_recall = _evaluate_split(
            "Val", y_val, val_predictions, step=wandb_step)

        wandb_step += 1

    # evaluate on test dataset with dropout off, as they did in Gal et al. (2017)
    model.eval()
    # Evaluate and log test metrics after the loop
    test_predictions = learner.predict(X_test)
    test_accuracy, test_f1, test_precision, test_recall = _evaluate_split(
        "Test", y_test, test_predictions)

    wandb.finish()

    # After running the AL loop, calculate confusion matrix and print report on test set.
    # `labels=` keeps one row/column per class, so position i corresponds to class index i.
    cm = confusion_matrix(y_test, test_predictions, labels=list(range(num_class)))
    print(f"Confusion Matrix - Run {run_number + 1}:")
    print(cm)

    # Plotting the confusion matrix
    _plot_confusion_matrix(cm, f"Run {run_number + 1} - Confusion Matrix")

    # Print classification report
    print(f"Classification Report - Run {run_number + 1}:")
    print(classification_report(y_test, test_predictions))