In [1]:
import torch
import numpy as np

import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets
from skorch import NeuralNetClassifier

from modAL.models import ActiveLearner

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# build class for the skorch API
class Torch_Model(nn.Module):
    def __init__(self,):
        super(Torch_Model, self).__init__()
        self.convs = nn.Sequential(
                                nn.Conv2d(1,32,3),
                                nn.ReLU(),
                                nn.Conv2d(32,64,3),
                                nn.ReLU(),
                                nn.MaxPool2d(2),
                                nn.Dropout(0.25)
        )
        self.fcs = nn.Sequential(
                                nn.Linear(12*12*64,128),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(128,10),
        )

    def forward(self, x):
        out = x
        out = self.convs(out)
        out = out.view(-1,12*12*64)
        out = self.fcs(out)
        return out


In [4]:
dataroot = '../data/'
data = datasets.MNIST(root=dataroot, download=True, transform=ToTensor())

In [5]:
from sklearn.model_selection import train_test_split

train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

In [6]:
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=48000)
test_dataloader  = DataLoader(test_data , shuffle=False, batch_size=12000)
X_train, y_train = next(iter(train_dataloader))
X_test , y_test  = next(iter(test_dataloader))

In [7]:
X_train = X_train.reshape(48000, 1, 28, 28)
X_test = X_test.reshape(12000, 1, 28, 28)

In [8]:
X_train = X_train.detach().cpu().numpy()
y_train = y_train.detach().cpu().numpy()
X_test = X_test.detach().cpu().numpy()
y_test = y_test.detach().cpu().numpy()

In [9]:
initial_idx = np.array([],dtype=np.int)
for i in range(10):
    idx = np.random.choice(np.where(y_train==i)[0], size=2, replace=False)
    initial_idx = np.concatenate((initial_idx, idx))

X_initial = X_train[initial_idx]
y_initial = y_train[initial_idx]

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  initial_idx = np.array([],dtype=np.int)


In [10]:
X_pool = np.delete(X_train, initial_idx, axis=0)
y_pool = np.delete(y_train, initial_idx, axis=0)

In [12]:
def uniform(learner, X, n_instances=1):
    query_idx = np.random.choice(range(len(X)), size=n_instances, replace=False)
    return query_idx, X[query_idx]

def max_entropy(learner, X, n_instances=1, T=100):
    random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    with torch.no_grad():
        outputs = np.stack([torch.softmax(learner.estimator.forward(X[random_subset], training=True),dim=-1).cpu().numpy()
                            for t in range(100)])
    pc = outputs.mean(axis=0)
    acquisition = (-pc*np.log(pc + 1e-10)).sum(axis=-1)
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]

def bald(learner, X, n_instances=1, T=100):
    random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    with torch.no_grad():
        outputs = np.stack([torch.softmax(learner.estimator.forward(X[random_subset], training=True),dim=-1).cpu().numpy()
                            for t in range(100)])
    pc = outputs.mean(axis=0)
    H   = (-pc*np.log(pc + 1e-10)).sum(axis=-1)
    E_H = - np.mean(np.sum(outputs * np.log(outputs + 1e-10), axis=-1), axis=0)  # [batch size]
    acquisition = H - E_H
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]    


In [13]:
def active_learning_procedure(query_strategy,
                              X_test,
                              y_test,
                              X_pool,
                              y_pool,
                              X_initial,
                              y_initial,
                              estimator,
                              n_queries=100,
                              n_instances=10):
    learner = ActiveLearner(estimator=estimator,
                            X_training=X_initial,
                            y_training=y_initial,
                            query_strategy=query_strategy,
                           )
    perf_hist = [learner.score(X_test, y_test)]
    for index in range(n_queries):
        query_idx, query_instance = learner.query(X_pool, n_instances)
        learner.teach(X_pool[query_idx], y_pool[query_idx])
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)
        model_accuracy = learner.score(X_test, y_test)
        print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
        perf_hist.append(model_accuracy)
    return perf_hist

In [14]:
estimator = NeuralNetClassifier(Torch_Model,
                                max_epochs=50,
                                batch_size=128,
                                lr=0.001,
                                optimizer=torch.optim.Adam,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
entropy_perf_hist = active_learning_procedure(max_entropy,
                                              X_test,
                                              y_test,
                                              X_pool,
                                              y_pool,
                                              X_initial,
                                              y_initial,
                                              estimator,)

Accuracy after query 1: 0.6726
Accuracy after query 2: 0.7317
Accuracy after query 3: 0.7124
Accuracy after query 4: 0.7044
Accuracy after query 5: 0.7244
Accuracy after query 6: 0.7442
Accuracy after query 7: 0.7347
Accuracy after query 8: 0.7467
Accuracy after query 9: 0.7799
Accuracy after query 10: 0.8181
Accuracy after query 11: 0.8027
Accuracy after query 12: 0.7867
Accuracy after query 13: 0.7995
Accuracy after query 14: 0.8219
Accuracy after query 15: 0.8281
Accuracy after query 16: 0.8350
Accuracy after query 17: 0.8485
Accuracy after query 18: 0.8391
Accuracy after query 19: 0.8699
Accuracy after query 20: 0.8691
Accuracy after query 21: 0.8634
Accuracy after query 22: 0.8548
Accuracy after query 23: 0.8749
Accuracy after query 24: 0.8624
Accuracy after query 25: 0.8767
Accuracy after query 26: 0.8801
Accuracy after query 27: 0.8660
Accuracy after query 28: 0.8962
Accuracy after query 29: 0.9092
Accuracy after query 30: 0.9195
Accuracy after query 31: 0.9091
Accuracy after qu