# DISTIL Usage Example: CIFAR-10 with Redundancy

Occasionally, a labeled dataset may not have enough data to yield good model performance. The easiest fix is to add more data by selecting and labeling instances from a large pool of unlabeled data. However, because unlabeled data is relatively easy to acquire in some settings, it is also easy to accrue massive amounts of redundant examples. Adding redundant examples to a labeled dataset yields little performance benefit for the added computational cost. So, how do we select examples from the unlabeled pool that are not redundant with respect to our labeled dataset? Here, we show how to use DISTIL's implementation of [SIMILAR](https://arxiv.org/abs/2107.00717) to avoid selecting redundant data.

# Preparation

## Installation and Imports

In [1]:
# Get DISTIL
!git clone https://github.com/decile-team/distil.git
!pip install -r distil/requirements/requirements.txt

import copy
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import torch
import torch.optim as optim
import torch.nn.functional as F

from torch import nn
from torch.utils.data import Dataset, Subset, ConcatDataset, DataLoader
from torchvision import transforms
from torchvision.datasets import cifar

sys.path.append('distil/')
from distil.active_learning_strategies import SCG, BADGE, EntropySampling, RandomSampling  # All active learning strategies showcased in this example
from distil.utils.models.resnet import ResNet18                                                     # The model used in our image classification example
from distil.utils.train_helper import data_train                                                    # A utility training class provided by DISTIL
from distil.utils.utils import LabeledToUnlabeledDataset                                            # A utility wrapper class that removes labels from labeled PyTorch dataset objects


## Preparing a Redundant CIFAR-10

The CIFAR-10 dataset contains 60,000 32x32 color images in 10 classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. There are 6,000 images per class. The training set contains 50,000 images, and the test set contains 10,000 images. Here, we do a simple setup of the CIFAR-10 dataset that we will use in this example. More importantly, we split CIFAR-10's training set into an initial labeled seed set and an unlabeled set, and we impose an artificial source of redundancy on the unlabeled set to simulate a redundant unlabeled pool.

In [2]:
# Define the name of the dataset and the path that PyTorch should use when downloading the data
data_set_name = 'CIFAR10'
download_path = '.'

# Define the number of classes in CIFAR10
nclasses = 10

# Define transforms on the dataset splits of CIFAR10. Here, we use random crops and horizontal flips for training augmentations.
# Both the train and test sets are converted to PyTorch tensors and are normalized around the mean/std of CIFAR-10.
cifar_training_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
cifar_test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

# Get the dataset objects from PyTorch. Here, CIFAR10 is downloaded, and the transform is applied when points 
# are retrieved.
cifar10_full_train = cifar.CIFAR10(download_path, train=True, download=True, transform=cifar_training_transform)
cifar10_test = cifar.CIFAR10(download_path, train=False, download=True, transform=cifar_test_transform)

# Get the dimension of the images. Here, we simply take the very first image of CIFAR10 
# and query its dimension.
dim = np.shape(cifar10_full_train[0][0])

# We now define a train-unlabeled split for the sake of the experiment. Here, we randomly initialize a 
# seed set with 100 points.
initial_seed_size = 100

# We then create an unlabeled dataset with duplicated points. Here, we assume the setting where the initial 
# seed set was drawn from a redundant unlabeled set, so some of the points in the unlabeled dataset are those 
# already selected for the seed set. We also add some other examples that are redundant but were not selected 
# for the seed set. We duplicate everything 10 times, so there is quite a bit of redundancy.
other_duplicated_examples_in_unlabeled_dataset = 1000
duplication_factor = 10

# Select indices for the train dataset
index_bank = list(range(len(cifar10_full_train)))
train_idx = list(np.random.choice(index_bank, size=initial_seed_size, replace=False))
index_bank = list(set(index_bank) - set(train_idx))

# Select indices for the additional examples (disjoint from the seed set) that will also be duplicated in the unlabeled dataset
other_duplicated_unlabeled_idx = list(np.random.choice(index_bank, size=other_duplicated_examples_in_unlabeled_dataset, replace=False))

# Create unlabeled_idx by repeating both the seed-set indices and the additional indices duplication_factor times
unlabeled_idx = train_idx * duplication_factor + other_duplicated_unlabeled_idx * duplication_factor

# Create the train and unlabeled subsets based on the index lists above. While the unlabeled set constructed here technically has labels, they 
# are only used when querying for labels. Hence, they only exist here for the sake of experimental design.
cifar10_train = Subset(cifar10_full_train, train_idx)
cifar10_unlabeled = Subset(cifar10_full_train, unlabeled_idx)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz

Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified

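To make the redundancy concrete, the sketch below rebuilds the same kind of index lists with only the standard library and counts what ends up in the unlabeled pool. It mirrors the variable names from the cell above, but it uses `random.Random(seed).sample` in place of `np.random.choice` (an assumption made so the sketch runs without NumPy), so the exact indices differ from the notebook run; only the counts are meant to match.

```python
import random

def build_redundant_split(n_train_full=50_000, initial_seed_size=100,
                          other_duplicated=1_000, duplication_factor=10, seed=0):
    """Rebuild seed/unlabeled index lists like the ones above (stdlib-only sketch)."""
    rng = random.Random(seed)
    index_bank = list(range(n_train_full))
    train_idx = rng.sample(index_bank, initial_seed_size)
    # Additional points are drawn only from indices not already in the seed set.
    remaining = sorted(set(index_bank) - set(train_idx))
    other_idx = rng.sample(remaining, other_duplicated)
    # Every seed index and every additional index appears duplication_factor times.
    unlabeled_idx = train_idx * duplication_factor + other_idx * duplication_factor
    return train_idx, unlabeled_idx

train_idx, unlabeled_idx = build_redundant_split()
pool_size = len(unlabeled_idx)                              # entries in the pool
distinct = len(set(unlabeled_idx))                          # distinct images in the pool
new_distinct = len(set(unlabeled_idx) - set(train_idx))     # distinct images not in the seed set
print(pool_size, distinct, new_distinct)  # 11000 1100 1000
```

In other words, only 1,000 distinct images in the 11,000-entry pool are new relative to the seed set, and each of them appears ten times.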

## Preparing the Model

Here, we use DISTIL's provided implementation of the [ResNet-18](https://arxiv.org/abs/1512.03385) architecture.

In [3]:
net = ResNet18()

# Initial Round

## Training

Now that we have prepared the training data and the model, we can begin training an initial model. We use DISTIL's provided [training loop](https://github.com/decile-team/distil/blob/main/distil/utils/train_helper.py) to do training.

In [4]:
# Define the training arguments to use.
args = {'n_epoch':      300,    # Stop training after 300 epochs.
        'lr':           0.01,   # Use a learning rate of 0.01
        'batch_size':   20,     # Update the parameters using training batches of size 20
        'max_accuracy': 0.99,   # Stop training once the training accuracy has exceeded 0.99
        'optimizer':    'sgd',  # Use the stochastic gradient descent optimizer
        'device':       "cuda" if torch.cuda.is_available() else "cpu"  # Use a GPU if one is available
        }

# Create the training loop using our training dataset, provided model, and training arguments.
# Train an initial model.
dt = data_train(cifar10_train, net, args)
trained_model = dt.train()

Training..
Epoch: 85 Training accuracy: 0.99

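The two stopping criteria in `args`, `n_epoch` and `max_accuracy`, act as "whichever comes first," which is consistent with the run above stopping at epoch 85 rather than 300. The sketch below shows that rule in isolation, assuming training accuracy is checked once per epoch; `train_one_epoch` is a hypothetical stand-in for a real training pass, and DISTIL's `data_train` may differ in detail.

```python
def train_until(max_accuracy, n_epoch, train_one_epoch):
    """Sketch of an accuracy-based early stopping rule.

    `train_one_epoch` is a hypothetical callable that runs one pass over the
    training data and returns the training accuracy measured for that epoch.
    """
    epoch, accuracy = 0, 0.0
    # Stop at the epoch cap, or as soon as the accuracy target is reached.
    while epoch < n_epoch and accuracy < max_accuracy:
        epoch += 1
        accuracy = train_one_epoch(epoch)
    return epoch, accuracy

# Toy accuracy curve that climbs by 0.02 per epoch and saturates at 1.0.
epochs_run, final_acc = train_until(0.99, 300, lambda e: min(1.0, 0.02 * e))
print(epochs_run, final_acc)
```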

## Evaluation

How does our initial model do on the test set? Conveniently, DISTIL's training loop also provides a way to measure the accuracy of the model on a given dataset.
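Test accuracy here is presumably the standard definition: the predicted class of each image is its highest-scoring output, and accuracy is the fraction of predictions that match the true label. A minimal sketch of that computation on toy scores (it assumes nothing about how `get_acc_on_set` batches or moves data to the device):

```python
def argmax(scores):
    """Index of the largest score, i.e. the predicted class for one example."""
    return max(range(len(scores)), key=scores.__getitem__)

def accuracy(predicted_labels, true_labels):
    """Fraction of examples whose predicted class matches the true class."""
    if len(predicted_labels) != len(true_labels):
        raise ValueError("prediction and label lists must have the same length")
    correct = sum(p == t for p, t in zip(predicted_labels, true_labels))
    return correct / len(true_labels)

# Three examples, three classes: the model's raw scores (logits) per example.
logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 3.0, 0.5]]
labels = [0, 1, 1]
preds = [argmax(row) for row in logits]
print(preds, accuracy(preds, labels))  # [0, 2, 1] 0.6666666666666666
```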

In [5]:
# Get the full test accuracy
full_test_accuracy = dt.get_acc_on_set(cifar10_test)
print(F"Full Test Accuracy: {full_test_accuracy}")

Full Test Accuracy: 0.2582


As we can see, the test performance could use some improvement. Let's add some new training examples to our training dataset.

# Selecting Non-Redundant Examples

## Preparing a Conditioning (Private) Set

In this example, we know that the unlabeled dataset can help in improving our model performance if we add the right examples to our training dataset; however, the unlabeled dataset is plagued with redundancy. Here, we use [SIMILAR](https://arxiv.org/abs/2107.00717) to select non-redundant examples from the unlabeled set. SIMILAR requires access to a conditioning (private) set of points that it uses to determine which unlabeled points are non-redundant. Hence, we must first prepare this conditioning set.

Where do we get the conditioning set? Quite simply, we can use our current training dataset as the conditioning set! This will encourage the selection of non-redundant points.

In [6]:
conditioning_set = cifar10_train

## Using DISTIL's SIMILAR Implementation: SCG

Now that we have the conditioning set, we are ready to use DISTIL's implementation of SIMILAR. In particular, we use the submodular conditional gain (SCG) [strategy](https://github.com/decile-team/distil/blob/main/distil/active_learning_strategies/scg.py) that is detailed in [SIMILAR](https://arxiv.org/abs/2107.00717). This allows us to select a set of points to label within a specified budget $k$.

Specifically, the strategy attempts to maximize the [submodular conditional gain](https://arxiv.org/abs/2006.15412) of a subset $\mathcal{A}$ of the unlabeled dataset $\mathcal{U}$, with size no greater than $k$, given the conditioning set $\mathcal{P}$:

\begin{align}
\text{argmax}_{\mathcal{A} \subseteq \mathcal{U}, |\mathcal{A}|\leq k} H_F(\mathcal{A} | \mathcal{P})
\end{align}

where $F$ is a submodular set function.
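For intuition, the conditional gain is $H_F(\mathcal{A} | \mathcal{P}) = F(\mathcal{A} \cup \mathcal{P}) - F(\mathcal{P})$. Assuming the facility location function $F(\mathcal{S}) = \sum_{i \in \mathcal{U}} \max_{j \in \mathcal{S}} s_{ij}$ with nonnegative similarities $s_{ij}$, this reduces to $\sum_{i \in \mathcal{U}} \max\left(\max_{j \in \mathcal{A}} s_{ij} - \max_{p \in \mathcal{P}} s_{ip}, 0\right)$, so a point that is a copy of something in $\mathcal{P}$ contributes nothing. The brute-force sketch below implements that formula with plain greedy maximization on toy 2-D vectors. It is not DISTIL's implementation: DISTIL presumably computes similarities on embeddings derived from the trained model and uses the lazy greedy optimizer configured in the next cell.

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def flcg_value(selected, unlabeled, private):
    """Facility location conditional gain H_F(A | P) = F(A u P) - F(P).

    With nonnegative similarities this equals
    sum over unlabeled x of max(best similarity to A - best similarity to P, 0).
    """
    total = 0.0
    for x in unlabeled:
        best_private = max(cosine(x, p) for p in private)
        best_selected = max((cosine(x, unlabeled[j]) for j in selected), default=0.0)
        total += max(best_selected - best_private, 0.0)
    return total

def greedy_flcg(unlabeled, private, budget):
    """Plain (non-lazy) greedy maximization of FLCG under a cardinality budget."""
    selected = []
    for _ in range(budget):
        current = flcg_value(selected, unlabeled, private)
        best_gain, best_idx = 1e-12, None  # ignore gains that are zero up to rounding
        for j in range(len(unlabeled)):
            if j in selected:
                continue
            gain = flcg_value(selected + [j], unlabeled, private) - current
            if gain > best_gain:
                best_gain, best_idx = gain, j
        if best_idx is None:  # no remaining point adds anything beyond P
            break
        selected.append(best_idx)
    return selected

# Toy 2-D "embeddings" (nonnegative, so cosine similarities are nonnegative).
private = [(1.0, 0.0)]                                # the conditioning set P
unlabeled = [(1.0, 0.0), (1.0, 0.0), (2.0, 0.0),      # 0-2: copies of the point in P
             (0.0, 1.0), (0.0, 1.0),                  # 3-4: a new point and its duplicate
             (0.1, 1.0)]                              # 5: a near-duplicate of point 3
print(greedy_flcg(unlabeled, private, budget=3))  # [3, 5]
```

Even with a budget of 3, greedy stops after two picks: the copies of the conditioning point (indices 0 to 2) and the duplicate at index 4 add zero gain.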

In [7]:
# Define arguments for SCG
selection_strategy_args = {'device':        args['device'],       # Use the device used in training
                           'batch_size':    args['batch_size'],   # Use the batch size used in training
                           'scg_function':  'flcg',               # Use the facility location conditional gain, which favors points dissimilar to the conditioning set
                           'metric':        'cosine',             # Use cosine similarity when determining the likeness of two data points
                           'optimizer':     'LazyGreedy'          # When doing submodular maximization, use the lazy greedy optimizer
                          }

# Create the SCG selection strategy. Note: We remove the labels from the unlabeled portion of CIFAR-10 that we created earlier.
# In a practical application, one would not have these labels a priori.
selection_strategy = SCG(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), conditioning_set, trained_model, nclasses, selection_strategy_args)

# Disable the augmentations used in the training dataset. Since all augmentations come from the cifar10_full_train object, we set its transform to the 
# transform used by the test set.
cifar10_full_train.transform = cifar_test_transform

# Do the selection, which will return the indices of the selected points with respect to the unlabeled dataset.
budget = 1000
selected_idx = selection_strategy.select(budget)

# Re-enable augmentations
cifar10_full_train.transform = cifar_training_transform

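The cell above switches `cifar10_full_train.transform` to the test transform and back by hand, so an exception inside `select` would leave augmentations disabled. A small context manager keeps the swap and the restore together; `temporarily_set` is a hypothetical helper built only on the standard library's `contextlib`, shown here on a stand-in object.

```python
from contextlib import contextmanager

@contextmanager
def temporarily_set(obj, attr, value):
    """Set obj.attr to value for the duration of a with-block, then restore it.

    The finally clause restores the original value even if the body raises.
    """
    original = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield obj
    finally:
        setattr(obj, attr, original)

# Stand-in for a dataset object; in the notebook this would be cifar10_full_train.
class FakeDataset:
    transform = "train_augmentations"

ds = FakeDataset()
with temporarily_set(ds, "transform", "test_transform"):
    inside = ds.transform  # augmentations are disabled inside the block
after = ds.transform       # and restored afterwards
```

In the notebook, the selection would then read `with temporarily_set(cifar10_full_train, "transform", cifar_test_transform): selected_idx = selection_strategy.select(budget)`.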
## Labeling the Points

Now that we know which points should be labeled, we can present them to human labelers for annotation. Here, we can do so automatically since we already know their labels for the sake of the example.

In [8]:
# Form a labeled subset of the unlabeled dataset. Again, we already have the labels, 
# so we simply take a subset. Note, however, that the selection was done without the 
# use of the labels and that we would normally not have these labels. Hence, the 
# following statement would usually require human effort to complete.
scg_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

## Characterizing the Selection

Now that we have selected and labeled these new points, we can add them to the training dataset and retrain our model. Before that, how many of the selected points are unique, that is, neither copies of one another nor copies of points already in the seed set?

In [9]:
# selected_idx contains indices of the unlabeled dataset. For measurement purposes, we want to get the full dataset indices 
# that correspond to the indices of the unlabeled dataset.
full_indices = list(np.array(unlabeled_idx)[selected_idx])

# Get the unique points in the list that also are not in train_idx
unique_points = len(set(full_indices) - set(train_idx))
print("Unique Points Selected:", unique_points)
print("Unique Fraction:", unique_points / budget)

Unique Points Selected: 998
Unique Fraction: 0.998


Nearly all of the points selected by SCG (998 of 1,000) are unique. For comparison's sake, how many unique points would we get using some of DISTIL's other strategies?
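Each comparison below repeats the same counting. It could be factored into a small helper like the hypothetical `unique_fraction` sketched here, assuming `selected_idx` is a sequence of integer positions into the unlabeled pool and `unlabeled_idx` maps those positions back to indices of the full training set.

```python
def unique_fraction(selected_idx, unlabeled_idx, train_idx, budget):
    """Count selected points that are new: distinct and not already in the seed set.

    selected_idx indexes into the unlabeled pool; unlabeled_idx maps pool
    positions back to indices of the full training set.
    """
    full_indices = [unlabeled_idx[int(i)] for i in selected_idx]
    unique_points = len(set(full_indices) - set(train_idx))
    return unique_points, unique_points / budget

# Toy pool: seed point 7 and new points 3 and 9, each duplicated twice.
toy_unlabeled_idx = [7, 7, 3, 3, 9, 9]
toy_train_idx = [7]
count, fraction = unique_fraction([0, 2, 3, 4], toy_unlabeled_idx, toy_train_idx, budget=4)
print(count, fraction)  # 2 0.5
```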

**BADGE**

In [10]:
# Repeat the previous steps
selection_strategy = BADGE(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), trained_model, nclasses, selection_strategy_args)

cifar10_full_train.transform = cifar_test_transform
budget = 1000
selected_idx = selection_strategy.select(budget)
cifar10_full_train.transform = cifar_training_transform

badge_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

full_indices = list(np.array(unlabeled_idx)[selected_idx])

# Get the unique points in the list that also are not in train_idx
unique_points = len(set(full_indices) - set(train_idx))
print("Unique Points Selected:", unique_points)
print("Unique Fraction:", unique_points / budget)

Unique Points Selected: 940
Unique Fraction: 0.94


**Random**

In [11]:
# Repeat the previous steps
selection_strategy = RandomSampling(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), trained_model, nclasses, selection_strategy_args)

cifar10_full_train.transform = cifar_test_transform
budget = 1000
selected_idx = selection_strategy.select(budget)
cifar10_full_train.transform = cifar_training_transform

random_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

full_indices = list(np.array(unlabeled_idx)[selected_idx])

# Get the unique points in the list that also are not in train_idx
unique_points = len(set(full_indices) - set(train_idx))
print("Unique Points Selected:", unique_points)
print("Unique Fraction:", unique_points / budget)

Unique Points Selected: 600
Unique Fraction: 0.6

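The random baseline's 0.6 is close to what one would expect analytically. If 1,000 entries are drawn uniformly without replacement from the 11,000-entry pool, each of the 1,000 new images (10 copies each) is missed only when none of its copies is drawn. The sketch below computes that expectation with `math.comb`; copies of seed-set points never count as unique, which is already reflected in using only the 1,000 new images.

```python
from math import comb

def expected_unique_random(pool_size, n_new_points, copies_per_point, budget):
    """Expected number of distinct new points hit when `budget` pool entries are
    drawn uniformly without replacement.

    A new point is missed only if none of its copies is drawn, which has
    probability C(pool_size - copies, budget) / C(pool_size, budget).
    """
    p_missed = comb(pool_size - copies_per_point, budget) / comb(pool_size, budget)
    return n_new_points * (1 - p_missed)

expected = expected_unique_random(pool_size=11_000, n_new_points=1_000,
                                  copies_per_point=10, budget=1_000)
print(round(expected), round(expected / 1_000, 3))
```

The expected unique fraction comes out at roughly 0.61, so random sampling wastes about 40% of the budget on copies in this setup.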

**Entropy**

In [12]:
# Repeat the previous steps
selection_strategy = EntropySampling(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), trained_model, nclasses, selection_strategy_args)

cifar10_full_train.transform = cifar_test_transform
budget = 1000
selected_idx = selection_strategy.select(budget)
cifar10_full_train.transform = cifar_training_transform

entropy_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

full_indices = list(np.array(unlabeled_idx)[selected_idx])

# Get the unique points in the list that also are not in train_idx
unique_points = len(set(full_indices) - set(train_idx))
print("Unique Points Selected:", unique_points)
print("Unique Fraction:", unique_points / budget)

Unique Points Selected: 99
Unique Fraction: 0.099


Hence, SCG selects the largest fraction of non-redundant examples (0.998), followed by BADGE (0.94), random sampling (0.6), and entropy sampling (0.099).
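A likely reason for entropy sampling's low fraction: with augmentations disabled during selection, every copy of an image gets identical predictions and therefore an identical score, so a strategy that ranks points independently and takes the top `budget` picks all copies of its favorite points. With 10 copies per image, 1,000 picks cover roughly 100 distinct images, in line with the 99 observed. The toy sketch below assumes entropy sampling is a plain top-k over per-point entropy.

```python
import math

def entropy(probs):
    """Shannon entropy (natural log) of a predicted class distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def top_k_by_score(pool, k, score):
    """Pick the k pool positions with the highest score, each scored independently."""
    ranked = sorted(range(len(pool)), key=lambda i: score(pool[i]), reverse=True)
    return ranked[:k]

# Toy pool: three distinct predictions, each duplicated three times.
confident = (0.9, 0.05, 0.05)
unsure = (0.34, 0.33, 0.33)
middle = (0.6, 0.3, 0.1)
pool = [confident, unsure, middle] * 3
picked = top_k_by_score(pool, 3, entropy)
print(picked, len({pool[i] for i in picked}))  # [1, 4, 7] 1
```

All three picks are copies of the single most uncertain prediction, which is exactly the redundancy that SCG's conditioning avoids.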

# Improving Performance

## Re-Training

Let us re-train our model using the newly selected points.

In [13]:
# Create a new training dataset by concatenating what we have with the newly labeled points
new_training_dataset = ConcatDataset([cifar10_train, scg_human_labeled_dataset])
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

Training..
Epoch: 115 Training accuracy: 0.995


## Evaluation

Now, let us see the accuracy improvement.

In [14]:
# Get the full test accuracy
full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

Full Test Accuracy Improvement: 0.2582 to 0.557


## Comparison

What would the accuracy improvement look like if we had used the other methods?

**BADGE**

In [15]:
# Repeat the process
new_training_dataset = ConcatDataset([cifar10_train, badge_human_labeled_dataset])
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

Training..
Epoch: 109 Training accuracy: 0.992
Full Test Accuracy Improvement: 0.2582 to 0.5466


**Random**

In [16]:
# Repeat the process
new_training_dataset = ConcatDataset([cifar10_train, random_human_labeled_dataset])
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

Training..
Epoch: 92 Training accuracy: 0.991
Full Test Accuracy Improvement: 0.2582 to 0.4695


**Entropy**

In [17]:
# Repeat the process
new_training_dataset = ConcatDataset([cifar10_train, entropy_human_labeled_dataset])
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

Training..
Epoch: 37 Training accuracy: 0.995
Full Test Accuracy Improvement: 0.2582 to 0.2928
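To see the comparison in one place, the snippet below simply collects the numbers printed above. In this run, the strategies rank identically by unique fraction and by final test accuracy. Since each strategy was run once without a fixed random seed, the ordering is suggestive rather than statistically established.

```python
# Numbers copied from the printed outputs above (a single run of each strategy).
baseline_accuracy = 0.2582
results = {
    # strategy: (unique fraction of the 1,000 selected points, test accuracy after retraining)
    "SCG": (0.998, 0.557),
    "BADGE": (0.94, 0.5466),
    "Random": (0.6, 0.4695),
    "Entropy": (0.099, 0.2928),
}

by_uniqueness = sorted(results, key=lambda name: results[name][0], reverse=True)
by_accuracy = sorted(results, key=lambda name: results[name][1], reverse=True)

for name in by_accuracy:
    unique_frac, acc = results[name]
    gain = acc - baseline_accuracy
    print(f"{name:8s} unique={unique_frac:.3f} accuracy={acc:.4f} gain={gain:+.4f}")
```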
