## Required Packages

In [None]:
%matplotlib inline

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from scipy import stats
from modAL.models import ActiveLearner
from skorch import NeuralNetClassifier

## Class: Convolutional Neural Network Architecture

In [None]:
class ConvNN(nn.Module):
    """Small CNN classifier for single-channel images with 10 output classes.

    Layer stack (see ``forward``): conv -> ReLU -> conv -> ReLU -> max-pool ->
    dropout(0.25) -> flatten -> dense -> ReLU -> dropout(0.5) -> dense(10) ->
    log-softmax.
    """

    def __init__(
        self,
        num_filters: int = 32,
        kernel_size: int = 4,
        dense_layer: int = 128,
        img_rows: int = 28,
        img_cols: int = 28,
        maxpool: int = 2,
    ):
        """
        Basic architecture of the CNN proposed in the paper, with the final
        softmax replaced by log-softmax.

        Args:
            num_filters: Number of filters (output channels) of both conv layers.
            kernel_size: Kernel size of both convolutions.
            dense_layer: Number of units in the hidden dense layer.
            img_rows: Height of the input image.
            img_cols: Width of the input image.
            maxpool: Window size (and stride) of the max-pooling layer.
        """
        super().__init__()
        self.maxpool = maxpool
        self.conv1 = nn.Conv2d(1, num_filters, kernel_size, 1)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Each stride-1, unpadded convolution shrinks a side by
        # (kernel_size - 1); the pooling layer then floor-divides it by the
        # pooling window.
        pooled_rows = (img_rows - 2 * kernel_size + 2) // maxpool
        pooled_cols = (img_cols - 2 * kernel_size + 2) // maxpool
        self.fc1 = nn.Linear(num_filters * pooled_rows * pooled_cols, dense_layer)
        self.fc2 = nn.Linear(dense_layer, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Input batch; the first conv layer expects one input channel.

        Returns:
            Tensor with one row of 10 log-probabilities per input sample.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Use the configured pooling window so the flattened size matches fc1.
        x = F.max_pool2d(x, self.maxpool)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

## Class: LoadData - Download and Split Datasets

In [1]:
class LoadData:
    """Download MNIST and split it into train, validation, pool and test sets.

    On construction the dataset is downloaded (if needed), the official
    training split is randomly partitioned into train / validation / pool
    subsets, and a small class-balanced initial training set is drawn from
    the train subset.
    """

    def __init__(self):
        self.mnist_train, self.mnist_test = self.download_dataset()
        (
            self.X_train_All,
            self.y_train_All,
            self.X_val,
            self.y_val,
            self.X_pool,
            self.y_pool,
            self.X_test,
            self.y_test,
        ) = self.split_and_load_dataset()
        self.X_init, self.y_init = self.preprocess_training_data()

    def tensor_to_np(self, tensor_data: torch.Tensor) -> np.ndarray:
        """Convert a tensor to a ``numpy.ndarray`` (the format handed to skorch).

        Args:
            tensor_data: Tensor to convert.

        Returns:
            The tensor's data as a NumPy array.
        """
        # .cpu() is a no-op for CPU tensors and makes the conversion safe for
        # tensors living on another device.
        return tensor_data.detach().cpu().numpy()

    def check_MNIST_folder(self) -> bool:
        """Return True when the dataset still has to be downloaded.

        The check looks for a ``MNIST`` directory in the current working
        directory.
        NOTE(review): assumes torchvision stores the files under
        ``<root>/MNIST`` with ``root="."`` -- confirm against torchvision.
        """
        return not os.path.isdir("MNIST")

    def download_dataset(self):
        """Load the MNIST training and test datasets.

        Images are converted to tensors and normalised with mean 0.5 and
        standard deviation 0.5.

        Returns:
            Tuple ``(mnist_train, mnist_test)``.
        """
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
        )
        download = self.check_MNIST_folder()
        mnist_train = MNIST(".", train=True, download=download, transform=transform)
        mnist_test = MNIST(".", train=False, download=download, transform=transform)
        return mnist_train, mnist_test

    def split_and_load_dataset(
        self,
        train_size: int = 10000,
        val_size: int = 5000,
        pool_size: int = 45000,
    ):
        """Split the training data into train / validation / pool sets and load them.

        Every subset is materialised as a single shuffled batch.

        Args:
            train_size: Number of samples in the training subset.
            val_size: Number of samples in the validation subset.
            pool_size: Number of samples in the pool subset.

        Returns:
            ``(X_train_All, y_train_All, X_val, y_val, X_pool, y_pool,
            X_test, y_test)``.
        """
        train_set, val_set, pool_set = random_split(
            self.mnist_train, [train_size, val_size, pool_size]
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=train_size, shuffle=True
        )
        val_loader = DataLoader(dataset=val_set, batch_size=val_size, shuffle=True)
        pool_loader = DataLoader(dataset=pool_set, batch_size=pool_size, shuffle=True)
        # One batch covering the whole test set, whatever its size.
        test_loader = DataLoader(
            dataset=self.mnist_test, batch_size=len(self.mnist_test), shuffle=True
        )
        X_train_All, y_train_All = next(iter(train_loader))
        X_val, y_val = next(iter(val_loader))
        X_pool, y_pool = next(iter(pool_loader))
        X_test, y_test = next(iter(test_loader))
        return X_train_All, y_train_All, X_val, y_val, X_pool, y_pool, X_test, y_test

    def preprocess_training_data(self):
        """Draw a random but balanced initial training set of 20 data points.

        Two samples are drawn without replacement for each of the classes
        0-9 from ``self.X_train_All`` / ``self.y_train_All``.

        Returns:
            Tuple ``(X_init, y_init)``.

        Raises:
            ValueError: If a class has fewer than two samples available.
        """
        labels = np.asarray(self.y_train_All)
        # np.int was removed from NumPy; use an explicit fixed-width integer.
        initial_idx = np.array([], dtype=np.int64)
        for i in range(10):
            idx = np.random.choice(
                np.flatnonzero(labels == i), size=2, replace=False
            )
            initial_idx = np.concatenate((initial_idx, idx))
        X_init = self.X_train_All[initial_idx]
        y_init = self.y_train_All[initial_idx]
        print(f"Initial training data points: {X_init.shape[0]}")
        print(f"Data distribution for each class: {np.bincount(np.asarray(y_init))}")
        return X_init, y_init

    def load_all(self):
        """Return all splits as NumPy arrays.

        Returns:
            ``(X_init, y_init, X_val, y_val, X_pool, y_pool, X_test, y_test)``.
        """
        return (
            self.tensor_to_np(self.X_init),
            self.tensor_to_np(self.y_init),
            self.tensor_to_np(self.X_val),
            self.tensor_to_np(self.y_val),
            self.tensor_to_np(self.X_pool),
            self.tensor_to_np(self.y_pool),
            self.tensor_to_np(self.X_test),
            self.tensor_to_np(self.y_test),
        )

## Acquisition Functions

### Uniform (Baseline)

In [None]:
def uniform(model, X_pool: np.ndarray, n_query: int = 10):
    """Baseline acquisition: pick pool points uniformly at random.

    Equivalent to a(x) = unif(), where unif() draws from a uniform
    distribution over [0, 1]; maximising it amounts to choosing points
    uniformly at random from the pool.

    Attributes:
        model: Unused; present so all acquisition functions share a signature,
        X_pool: Pool set to draw from,
        n_query: Number of points drawn (without replacement) from the pool

    Returns:
        The chosen indices and the corresponding rows of X_pool.
    """
    candidates = np.arange(len(X_pool))
    query_idx = np.random.choice(candidates, size=n_query, replace=False)
    selected = X_pool[query_idx]
    return query_idx, selected

### Max Entropy

In [None]:
def max_entropy(model, X_pool: np.ndarray, n_query: int = 10, T: int = 100):
    """Select the pool points with the highest predictive (Shannon) entropy.

    Attributes:
        model: Trained model whose uncertainty is measured,
        X_pool: Pool set to select from,
        n_query: Number of points with the largest entropy to return,
        T: Number of stochastic forward passes used for the estimate

    Returns:
        The chosen pool indices and the corresponding rows of X_pool.
    """
    entropy, candidates = shannon_entropy_function(model, X_pool, T)
    # Sort descending by entropy and keep the n_query most uncertain points.
    ranked = np.argsort(-entropy)[:n_query]
    query_idx = candidates[ranked]
    return query_idx, X_pool[query_idx]

### BALD

In [None]:
def bald(model, X_pool: np.ndarray, n_query: int = 10, T: int = 100):
    """Select pool points that maximise the BALD acquisition function.

    BALD picks points expected to maximise the information gained about the
    model parameters, i.e. the mutual information between predictions and
    the model posterior:
    I[y,w|x,D_train] = H[y|x,D_train] - E_{p(w|D_train)}[H[y|x,w]]
    with w the model parameters (H[y|x,w] is the entropy of y given w).
    Points that maximise it are those on which the model is uncertain on
    average, while individual parameter draws produce confident but
    disagreeing predictions. This is equivalent to points with high variance
    in the input to the softmax layer.

    Attributes:
        model: Trained model whose uncertainty is measured,
        X_pool: Pool set to select from,
        n_query: Number of points with the largest BALD score to return,
        T: Number of stochastic forward passes used for the estimate

    Returns:
        The chosen pool indices and the corresponding rows of X_pool.
    """
    entropy, expected_entropy, candidates = shannon_entropy_function(
        model, X_pool, T, E_H=True
    )
    mutual_information = entropy - expected_entropy
    ranked = np.argsort(-mutual_information)[:n_query]
    query_idx = candidates[ranked]
    return query_idx, X_pool[query_idx]

### Var Ratios

In [None]:
def var_ratios(model, X_pool: np.ndarray, n_query: int = 10, T: int = 100):
    """Select pool points with the highest variation ratio.

    Like Max Entropy, but the variation ratio measures lack of confidence:
    variational_ratio[x] := 1 - max_{y} p(y|x,D_{train}).
    It is estimated as 1 - f_x / T, where f_x is how often the most frequent
    predicted class occurs among the T stochastic forward passes.

    Attributes:
        model: Trained model whose uncertainty is measured,
        X_pool: Pool set to select from,
        n_query: Number of points with the largest variation ratio to return,
        T: Number of stochastic forward passes used for the estimate

    Returns:
        The chosen pool indices and the corresponding rows of X_pool.
    """
    outputs, random_subset = predictions_from_pool(model, X_pool, T)
    # outputs is stacked over the forward passes, so axis 0 of preds indexes
    # the passes and axis 1 the sampled pool points.
    preds = np.argmax(outputs, axis=2)
    _, count = stats.mode(preds, axis=0)
    # Normalise the mode count by the number of passes (axis 0), not by the
    # number of pool points. reshape copes with both the (1, n) and (n,)
    # result shapes of stats.mode.
    acquisition = (1 - count / preds.shape[0]).reshape((-1,))
    idx = (-acquisition).argsort()[:n_query]
    query_idx = random_subset[idx]
    return query_idx, X_pool[query_idx]

## Functions

In [None]:
def predictions_from_pool(
    model, X_pool: np.ndarray, T: int = 100, subset_size: int = 2000
):
    """Run T stochastic forward passes on a random subset of the pool.

    Attributes:
        model: Learner exposing ``estimator.forward(X, training=True)``,
        X_pool: Pool set to sample the subset from,
        T: Number of forward passes (MC dropout iterations),
        subset_size: Maximum number of pool points to evaluate; clamped to
            the pool size so small pools do not raise

    Returns:
        outputs: Array of softmax probabilities stacked over the T passes
            (first axis indexes the pass),
        random_subset: Indices into X_pool of the evaluated points
    """
    # Sampling without replacement cannot draw more points than the pool holds.
    n_samples = min(subset_size, len(X_pool))
    random_subset = np.random.choice(len(X_pool), size=n_samples, replace=False)
    X_subset = X_pool[random_subset]
    with torch.no_grad():
        outputs = np.stack(
            [
                torch.softmax(
                    # training=True presumably keeps dropout active so that
                    # every pass is a different stochastic sample.
                    model.estimator.forward(X_subset, training=True),
                    dim=-1,
                )
                # Detach defensively in case the forward pass ran with
                # autograd enabled; numpy() rejects tensors requiring grad.
                .detach()
                .cpu()
                .numpy()
                for _ in range(T)
            ]
        )
    return outputs, random_subset

def shannon_entropy_function(
    model, X_pool: np.ndarray, T: int = 100, E_H: bool = False
):
    """Compute the predictive entropy of a random pool subset.

    H[y|x,D_train] := - sum_{c} p(y=c|x,D_train) log p(y=c|x,D_train),
    where p is the mean of the probabilities over the T forward passes.

    Attributes:
        model: Trained model whose uncertainty is measured,
        X_pool: Pool set to select from,
        T: Number of stochastic forward passes used for the estimate,
        E_H: If True, also return the expected per-pass entropy needed
            for BALD (default: False)

    Returns:
        ``(H, random_subset)`` or, when E_H is True, ``(H, E, random_subset)``.
    """
    mc_probs, random_subset = predictions_from_pool(model, X_pool, T)
    mean_probs = mc_probs.mean(axis=0)
    # The small constant guards against log(0).
    H = (-mean_probs * np.log(mean_probs + 1e-10)).sum(axis=-1)
    if not E_H:
        return H, random_subset
    per_pass = np.sum(mc_probs * np.log(mc_probs + 1e-10), axis=-1)
    E = -np.mean(per_pass, axis=0)
    return H, E, random_subset

def load_CNN_model(lr, batch_size, epochs, device):
    """Build a fresh skorch classifier wrapping a new ConvNN.

    A new model is created on every call so each acquisition function and
    each experiment starts from newly initialised weights.

    Attributes:
        lr: Learning rate,
        batch_size: Mini-batch size,
        epochs: Maximum number of epochs per fit,
        device: Device the network is moved to and trained on
    """
    network = ConvNN()
    network = network.to(device)
    classifier_options = dict(
        module=network,
        lr=lr,
        batch_size=batch_size,
        max_epochs=epochs,
        criterion=nn.CrossEntropyLoss,
        optimizer=torch.optim.Adam,
        train_split=None,
        verbose=0,
        device=device,
    )
    return NeuralNetClassifier(**classifier_options)

def select_acq_function(acq_func: int = 0) -> list:
    """Map an integer code to the list of acquisition functions to run.

    Attributes:
        acq_func: 0-all (uniform, max_entropy, bald, var_ratios), 1-uniform,
                  2-max_entropy, 3-bald, 4-var_ratios

    Raises:
        KeyError: If acq_func is not one of the codes above.
    """
    every_function = [uniform, max_entropy, bald, var_ratios]
    choices = {0: every_function}
    # Codes 1..4 each select a single function, in the order listed above.
    for code, func in enumerate(every_function, start=1):
        choices[code] = [func]
    return choices[acq_func]

def active_learning_procedure(
    query_strategy,
    X_val: np.ndarray,
    y_val: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    X_pool: np.ndarray,
    y_pool: np.ndarray,
    X_init: np.ndarray,
    y_init: np.ndarray,
    estimator,
    T: int = 100,
    n_query: int = 10,
) -> tuple:
    """Active Learning Procedure

    Trains on the initial set, then repeats T acquisition rounds: query
    n_query points from the pool, teach them to the learner, remove them
    from the pool and record the validation accuracy.

    Attributes:
        query_strategy: Acquisition function, e.g. uniform, max_entropy, bald,
        X_val, y_val: Validation dataset,
        X_test, y_test: Test dataset,
        X_pool, y_pool: Query pool set,
        X_init, y_init: Initial training set data points,
        estimator: Neural Network architecture, e.g. CNN,
        T: Number of acquisition rounds. Only n_query is forwarded to the
           query strategy, so the strategy uses its own default number of
           MC dropout passes,
        n_query: Number of points to query from X_pool per round

    Returns:
        perf_hist: Validation accuracy after the initial fit and after each
            round (T + 1 entries),
        model_accuracy_test: Test accuracy after the final round
    """
    learner = ActiveLearner(
        estimator=estimator,
        X_training=X_init,
        y_training=y_init,
        query_strategy=query_strategy,
    )
    # Score the starting point on the validation set so that every entry of
    # the history is measured on the same data.
    perf_hist = [learner.score(X_val, y_val)]
    for index in range(T):
        query_idx, _ = learner.query(X_pool, n_query)
        learner.teach(X_pool[query_idx], y_pool[query_idx], only_new=True)
        # Remove the queried points so they cannot be selected again.
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)
        model_accuracy_val = learner.score(X_val, y_val)
        if (index + 1) % 5 == 0:
            print(f"Val Accuracy after query {index+1}: {model_accuracy_val:0.4f}")
        perf_hist.append(model_accuracy_val)
    model_accuracy_test = learner.score(X_test, y_test)
    print(f"********** Test Accuracy per experiment: {model_accuracy_test} **********")
    return perf_hist, model_accuracy_test

def plot_results(data: dict):
    """Plot one labelled curve per entry of data and show the figure.

    Attributes:
        data: Maps a curve label to the sequence of values to plot
    """
    for label, series in data.items():
        plt.plot(series, label=label)
    plt.legend()
    plt.show()
    
    
def train_active_learning(lr, batch_size, epochs, acq_funcs, experiments, dropout_iter, query, device, datasets):
    """Run the active-learning experiments and plot the averaged curves.

    For every selected acquisition function the procedure is repeated
    `experiments` times with a freshly built model; the accuracy histories
    and final test scores are averaged over the repetitions.

    Attributes:
        lr: Learning rate,
        batch_size: Mini-batch size,
        epochs: Maximum number of epochs per fit,
        acq_funcs: Integer code passed to select_acq_function,
        experiments: Number of repetitions per acquisition function,
        dropout_iter: Forwarded as T (number of acquisition rounds) to
            active_learning_procedure,
        query: Number of points queried from the pool per round,
        device: Device to train on,
        datasets: Dict with keys X_init, y_init, X_val, y_val, X_pool,
            y_pool, X_test, y_test

    Returns:
        Dict mapping each acquisition function's name to its averaged
        accuracy history (a list).
    """
    acq_functions = select_acq_function(acq_funcs)
    results = dict()
    for acq_func in acq_functions:
        histories = []
        test_scores = []
        # Same name the repr-based parsing yields for plain functions, but
        # without depending on the repr format.
        acq_func_name = getattr(acq_func, "__name__", str(acq_func))
        print(f"\n---------- Start {acq_func_name} training! ----------")
        for e in range(experiments):
            estimator = load_CNN_model(lr, batch_size, epochs, device)
            print(
                f"********** Experiment Iterations: {e+1}/{experiments} **********"
            )
            training_hist, test_score = active_learning_procedure(
                acq_func,
                datasets["X_val"],
                datasets["y_val"],
                datasets["X_test"],
                datasets["y_test"],
                datasets["X_pool"],
                datasets["y_pool"],
                datasets["X_init"],
                datasets["y_init"],
                estimator,
                dropout_iter,
                query,
            )
            histories.append(training_hist)
            test_scores.append(test_score)
        avg_hist = np.average(np.array(histories), axis=0)
        avg_test = sum(test_scores) / len(test_scores)
        print(f"Average Test score for {acq_func_name}: {avg_test}")
        results[acq_func_name] = avg_hist.tolist()
    print("--------------- Done Training! ---------------")
    plot_results(results)
    # Hand the averaged histories back so callers can keep or reuse them.
    return results

## 1. Define variables

In [None]:
# Hyper-parameters of the experiment.
batch_size = 128  # mini-batch size used when fitting the network
epochs = 50  # maximum number of epochs per fit
lr = 1e-3  # learning rate
seed = 369  # seed for the random number generators
experiments = 3  # repetitions per acquisition function
dropout_iter = 100  # passed as `dropout_iter` to train_active_learning
query = 10  # number of pool points queried per round
acq_funcs = 0 # 0-all, 1-uniform, 2-max_entropy, 3-bald, 4-var_ratios

# The initial training set and the pool subsets are drawn with np.random,
# so NumPy has to be seeded as well as torch for reproducible runs.
np.random.seed(seed)
torch.manual_seed(seed)

## 2. Check device for training (CPU/GPU)

In [None]:
# Prefer the first CUDA device when one is available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 3. Load Datasets

In [None]:
datasets = dict()
# Do not name this variable `DataLoader`: that would overwrite the imported
# torch.utils.data.DataLoader class that LoadData itself calls, so creating
# a second LoadData() (e.g. re-running this cell) would fail.
data_loader = LoadData()
(
    datasets["X_init"],
    datasets["y_init"],
    datasets["X_val"],
    datasets["y_val"],
    datasets["X_pool"],
    datasets["y_pool"],
    datasets["X_test"],
    datasets["y_test"],
) = data_loader.load_all()

## 4. Start training

In [None]:
# Run every configured experiment with the settings defined above.
results = train_active_learning(
    lr=lr,
    batch_size=batch_size,
    epochs=epochs,
    acq_funcs=acq_funcs,
    experiments=experiments,
    dropout_iter=dropout_iter,
    query=query,
    device=device,
    datasets=datasets,
)