# DISTIL Usage Example: CIFAR-10 with Out-of-Distribution Points

Occasionally, a labeled dataset may not have enough data to yield good model performance. The easiest fix is to add more data by selecting and labeling instances from a large pool of unlabeled data. However, because unlabeled data is often cheap to collect, the pool can easily accumulate out-of-distribution (OOD) examples. Inadvertently selecting OOD points from the unlabeled data wastes labeling effort, since these points are discarded rather than added to the labeled dataset. So, how do we select examples from the unlabeled pool that are not OOD? Here, we show how to use DISTIL's implementation of [SIMILAR](https://arxiv.org/abs/2107.00717) to avoid selecting OOD data.

# Preparation

## Installation and Imports

In [1]:
# Get DISTIL
!git clone https://github.com/decile-team/distil.git
!pip install -r distil/requirements/requirements.txt

import copy
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import torch
import torch.optim as optim
import torch.nn.functional as F

from torch import nn
from torch.autograd import Variable
from torch.utils.data import Dataset, Subset, ConcatDataset, DataLoader
from torchvision import transforms
from torchvision.datasets import cifar

sys.path.append('distil/')
from distil.active_learning_strategies import SCMI, GLISTER, BADGE, EntropySampling, RandomSampling   # All active learning strategies showcased in this example
from distil.utils.models.resnet import ResNet18                                                       # The model used in our image classification example
from distil.utils.train_helper import data_train                                                      # A utility training class provided by DISTIL
from distil.utils.utils import LabeledToUnlabeledDataset                                              # A utility wrapper class that removes labels from labeled PyTorch dataset objects



## Preparing CIFAR-10 with OOD Data

The CIFAR-10 dataset contains 60,000 32x32 color images in 10 different classes. The 10 classes represent airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks, with 6,000 images per class. The training set contains 50,000 images, and the test set contains 10,000 images. Here, we do a simple setup of the CIFAR-10 dataset that we will use in this example. More importantly, we split CIFAR-10's training set into an initial labeled seed set and an unlabeled set. For the purposes of this example, we treat the non-animal classes (0, 1, 8, and 9) as OOD classes and the six animal classes (2 through 7) as in-distribution (ID) classes.

In [2]:
# Define the name of the dataset and the path that PyTorch should use when downloading the data
data_set_name = 'CIFAR10'
download_path = '.'

# Define the number of classes in our modified CIFAR10, which is 6. We also define our ID classes
nclasses = 6
id_classes = [2,3,4,5,6,7]

# Define transforms on the dataset splits of CIFAR10. Here, we use random crops and horizontal flips for training augmentations.
# Both the train and test sets are converted to PyTorch tensors and are normalized around the mean/std of CIFAR-10.
cifar_training_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
cifar_test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

# To adapt to our setting, we must also fix the nature of the labels. Here, we remap ID labels to [0,1,2,...]
# and OOD labels to len(id_classes).
label_map = {id_classes[i]: i for i in range(len(id_classes))}
cifar_label_transform = lambda x: label_map[x] if x in label_map else len(id_classes)

# Get the dataset objects from PyTorch. Here, CIFAR10 is downloaded, and the transform is applied when points 
# are retrieved.
cifar10_full_train = cifar.CIFAR10(download_path, train=True, download=True, transform=cifar_training_transform, target_transform=cifar_label_transform)
cifar10_test = cifar.CIFAR10(download_path, train=False, download=True, transform=cifar_test_transform, target_transform=cifar_label_transform)

# Get the dimension of the images. Here, we simply take the very first image of CIFAR10 
# and query its dimension.
dim = np.shape(cifar10_full_train[0][0])

# We now define a train-unlabeled split for the sake of the experiment. Here, the initial train dataset
# contains only in-distribution points
num_train_examples_id_per_class = 200

# The unlabeled dataset will be composed of both in-distribution points and OOD points.
num_unlabeled_examples_id_per_class = 250
num_unlabeled_examples_ood_per_class = 750

# We assume that we have access to some OOD points initially. If we do not, then we can accrue them from the 
# unlabeled dataset iteratively after discovering them.
num_held_out_ood_per_class = 10

# Create the dataset splits using the above configuration
train_idx = []
unlabeled_idx = []
test_idx = []
held_out_ood_idx = []

for class_idx in range(10):

    # Retrieve all the indices of the elements in CIFAR-10 whose label matches class_idx
    full_idx_class = torch.where(torch.Tensor(cifar10_full_train.targets) == class_idx)[0].cpu().numpy()

    # Determine how many points to add for this class depending on the OOD flag of the class
    if class_idx not in id_classes:
        
        # Choose indices for the unlabeled OOD points and the held-out OOD points for this class
        class_unlabeled_idx = np.random.choice(full_idx_class, size=num_unlabeled_examples_ood_per_class, replace=False)
        remaining_class_idx = np.array(list(set(full_idx_class) - set(class_unlabeled_idx)))
        held_out_class_idx = np.random.choice(remaining_class_idx, size=num_held_out_ood_per_class, replace=False)

        unlabeled_idx.extend(class_unlabeled_idx)
        held_out_ood_idx.extend(held_out_class_idx)

    else:

        # Choose indices for the labeled ID points and the unlabeled ID points for this class
        class_train_idx = np.random.choice(full_idx_class, size=num_train_examples_id_per_class, replace=False)
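        # Exclude the indices just chosen for the labeled set so that the unlabeled ID points do not overlap with it
        remaining_class_idx = np.array(list(set(full_idx_class) - set(class_train_idx)))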
        class_unlabeled_idx = np.random.choice(remaining_class_idx, size=num_unlabeled_examples_id_per_class, replace=False)

        train_idx.extend(class_train_idx)
        unlabeled_idx.extend(class_unlabeled_idx)

    # Lastly, we need to update the test dataset as well since we only want to evaluate on ID points.
    if class_idx in id_classes:
        test_idx_class = torch.where(torch.Tensor(cifar10_test.targets) == class_idx)[0].cpu().numpy()
        test_idx.extend(test_idx_class)

# Create the train and unlabeled subsets based on the index lists above. While the unlabeled set constructed here technically has labels, they 
# are only used when querying for labels. Hence, they only exist here for the sake of experimental design.
cifar10_train = Subset(cifar10_full_train, train_idx)
cifar10_unlabeled = Subset(cifar10_full_train, unlabeled_idx)
cifar10_ood = Subset(cifar10_full_train, held_out_ood_idx)
cifar10_test = Subset(cifar10_test, test_idx)

Files already downloaded and verified
Files already downloaded and verified
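
As a sanity check on the configuration above, the following sketch recomputes the split sizes and illustrates drawing disjoint labeled and unlabeled index sets for a single class. It uses only the Python standard library. `split_class_indices` is a hypothetical helper that is not part of DISTIL, and the 5,000 integers merely stand in for the indices of one CIFAR-10 training class.

```python
import random

# Configuration values copied from the cell above.
id_classes = [2, 3, 4, 5, 6, 7]
ood_classes = [0, 1, 8, 9]
num_train_id, num_unlabeled_id = 200, 250
num_unlabeled_ood, num_held_out_ood = 750, 10

# Sizes of the resulting splits.
train_size = len(id_classes) * num_train_id
unlabeled_size = (len(id_classes) * num_unlabeled_id
                  + len(ood_classes) * num_unlabeled_ood)
held_out_size = len(ood_classes) * num_held_out_ood
id_fraction = len(id_classes) * num_unlabeled_id / unlabeled_size


def split_class_indices(indices, first_size, second_size, rng):
    """Hypothetical helper: draw two disjoint samples from one class's indices."""
    first = rng.sample(indices, first_size)
    # Remove the first sample before drawing the second one.
    remaining = sorted(set(indices) - set(first))
    second = rng.sample(remaining, second_size)
    return first, second


rng = random.Random(0)
class_indices = list(range(5000))  # stand-in for one class's indices
labeled, unlabeled = split_class_indices(
    class_indices, num_train_id, num_unlabeled_id, rng)

print(train_size, unlabeled_size, held_out_size)
print(round(id_fraction, 4))
print(set(labeled).isdisjoint(unlabeled))
```

With this configuration, the seed set has 1,200 labeled points, the unlabeled pool has 4,500 points of which one third are in-distribution, and 40 OOD points are held out.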


## Preparing the Model

Here, we use DISTIL's provided implementation of the [ResNet-18](https://arxiv.org/abs/1512.03385) architecture. We also create a model directory to store trained models in this example.

In [3]:
# Use nclasses + 1. We care about the first 6 classes; the last one is not trained since we add no OOD training examples.
# We use nclasses + 1 to avoid issues in the embedding computation used by SIMILAR.
net = ResNet18(num_classes=nclasses + 1)  

# Initial Round

## Training

Now that we have prepared the training data and the model, we can begin training an initial model. We use DISTIL's provided [training loop](https://github.com/decile-team/distil/blob/main/distil/utils/train_helper.py) to do training.

In [4]:
# Define the training arguments to use.
args = {'n_epoch':      300,    # Stop training after 300 epochs.
        'lr':           0.01,   # Use a learning rate of 0.01
        'batch_size':   20,     # Update the parameters using training batches of size 20
        'max_accuracy': 0.99,   # Stop training once the training accuracy has exceeded 0.99
        'optimizer':    'sgd',  # Use the stochastic gradient descent optimizer
        'device':       "cuda" if torch.cuda.is_available() else "cpu"  # Use a GPU if one is available
        }

# Create the training loop using our training dataset, provided model, and training arguments.
# Train an initial model.
dt = data_train(cifar10_train, net, args)
trained_model = dt.train()

Training..
Epoch: 120 Training accuracy: 0.991
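
The output above shows training ending at epoch 120, well before the 300-epoch limit, because the training accuracy passed `max_accuracy`. The sketch below illustrates how two such stopping criteria can interact. It is an assumption about the general shape of the rule, not DISTIL's actual training loop; `epochs_until_stop` and the accuracy trace are hypothetical.

```python
def epochs_until_stop(accuracy_trace, n_epoch, max_accuracy):
    """Count epochs run before either stopping criterion fires.

    accuracy_trace: hypothetical per-epoch training accuracies.
    """
    epochs_run = 0
    for accuracy in accuracy_trace:
        if epochs_run >= n_epoch:      # epoch budget exhausted
            break
        epochs_run += 1
        if accuracy > max_accuracy:    # accuracy threshold exceeded
            break
    return epochs_run


trace = [0.50, 0.80, 0.95, 0.991, 0.999]
print(epochs_until_stop(trace, n_epoch=300, max_accuracy=0.99))
print(epochs_until_stop(trace, n_epoch=2, max_accuracy=0.99))
```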


## Evaluation

How does our initial model do on the test set?

In [5]:
# Get the full test accuracy
full_test_accuracy = dt.get_acc_on_set(cifar10_test)
print(F"Full Test Accuracy: {full_test_accuracy}")

Full Test Accuracy: 0.6238333333333334


At roughly 62% accuracy, the test performance leaves room for improvement. Can we add in-distribution points to help the learned model generalize better?

# Selecting In-Distribution Points

## Preparing a Query Set and a Conditioning Set

In this example, we know that the unlabeled dataset contains in-distribution examples that can help improve the accuracy of the model, mixed with OOD examples that cannot. How do we avoid selecting the OOD points if we do not know *where* they are in the unlabeled dataset? Here, we use [SIMILAR](https://arxiv.org/abs/2107.00717) to select in-distribution examples from the unlabeled set. SIMILAR takes a query set and favors unlabeled points that are similar to it. It also takes a conditioning (private) set and favors unlabeled points that are dissimilar to it. By using in-distribution points as the query set and OOD points as the conditioning set, we can steer the selection toward in-distribution points.

Where do we get the query set? Luckily, our training dataset already consists of labeled in-distribution points, so we use it as the query set. We use the held-out OOD points as the conditioning set.

In [6]:
# Create a query set that contains the in-distribution examples of the training dataset
query_set = cifar10_train

# Use the held-out points for the private set
private_set = cifar10_ood

## Using DISTIL's SIMILAR Implementation: SCMI

Now that we have the query and private sets, we are ready to use DISTIL's implementation of SIMILAR. In particular, we use the submodular conditional mutual information [strategy](https://github.com/decile-team/distil/blob/main/distil/active_learning_strategies/smi.py) that is detailed in [SIMILAR](https://arxiv.org/abs/2107.00717). This will allow us to select a set of points to label within a specified budget $k$.

Specifically, the strategy attempts to maximize the [submodular conditional mutual information](https://arxiv.org/abs/2006.15412) between a subset $\mathcal{A}$ of size no greater than $k$ of the unlabeled dataset $\mathcal{U}$ and the query set $\mathcal{Q}$ given the conditioning set $\mathcal{P}$:

\begin{align}
\text{argmax}_{\mathcal{A} \subseteq \mathcal{U}, |\mathcal{A}|\leq k} I_F(\mathcal{A};\mathcal{Q}|\mathcal{P})
\end{align}

where $F$ is a submodular set function.
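
For a submodular function $F$, the conditional mutual information is defined as $I_F(\mathcal{A};\mathcal{Q}|\mathcal{P}) = F(\mathcal{A} \cup \mathcal{P}) + F(\mathcal{Q} \cup \mathcal{P}) - F(\mathcal{A} \cup \mathcal{Q} \cup \mathcal{P}) - F(\mathcal{P})$. The sketch below evaluates this definition directly with a facility location function on a toy similarity matrix and maximizes it with a naive greedy loop. It is an illustration only and not DISTIL's implementation, which relies on optimized routines; the function names and the five toy vectors are hypothetical.

```python
def facility_location(selected, sim):
    """F(S) = sum over all points i of max_{j in S} sim[i][j], with F(empty) = 0."""
    if not selected:
        return 0.0
    return sum(max(row[j] for j in selected) for row in sim)


def cmi(A, Q, P, sim):
    """I_F(A; Q | P) = F(A u P) + F(Q u P) - F(A u Q u P) - F(P)."""
    A, Q, P = set(A), set(Q), set(P)
    return (facility_location(A | P, sim) + facility_location(Q | P, sim)
            - facility_location(A | Q | P, sim) - facility_location(P, sim))


def greedy_select(unlabeled, Q, P, sim, budget):
    """Naive greedy: repeatedly add the unlabeled point that maximizes the CMI."""
    selected = []
    for _ in range(budget):
        candidates = [u for u in unlabeled if u not in selected]
        if not candidates:
            break
        best = max(candidates, key=lambda u: cmi(selected + [u], Q, P, sim))
        selected.append(best)
    return selected


# Toy ground set: points 0-2 are unlabeled, 3 is the query, 4 is the private point.
vectors = [(1.0, 0.0), (0.0, 1.0), (0.8, 0.6), (1.0, 0.0), (0.0, 1.0)]
sim = [[a[0] * b[0] + a[1] * b[1] for b in vectors] for a in vectors]

for u in [0, 1, 2]:
    print(u, round(cmi([u], [3], [4], sim), 4))
print(greedy_select([0, 1, 2], Q=[3], P=[4], sim=sim, budget=1))
```

In this toy setup, point 0 matches the query and is orthogonal to the private point, so it scores highest; point 1 matches the private point and scores zero; point 2 lies in between.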

In [7]:
# Define arguments for SCMI
selection_strategy_args = {'device':        args['device'],       # Use the device used in training
                           'batch_size':    args['batch_size'],   # Use the batch size used in training
                           'scmi_function':  'flcmi',              # Use a facility location function, which captures representation information
                           'metric':        'cosine',             # Use cosine similarity when determining the likeness of two data points
                           'optimizer':     'LazyGreedy'          # When doing submodular maximization, use the lazy greedy optimizer
                          }

# Create the SCMI selection strategy. Note: We remove the labels from the unlabeled portion of CIFAR-10 that we created earlier.
# In a practical application, one would not have these labels a priori. Note: Since the OOD private set has class len(id_classes), we need to add 1 to nclasses.
selection_strategy = SCMI(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), query_set, private_set, trained_model, nclasses + 1, selection_strategy_args)

# Disable the augmentations used in the training dataset. Since all augmentations come from the cifar10_full_train object, we set its transform to the 
# transform used by the test set.
cifar10_full_train.transform = cifar_test_transform

# Do the selection, which will return the indices of the selected points with respect to the unlabeled dataset.
budget = 400
selected_idx = selection_strategy.select(budget)

# Re-enable augmentations
cifar10_full_train.transform = cifar_training_transform

## Labeling the Points

Now that we know which points should be labeled, we can present them to human labelers for annotation. Here, we can do so automatically since we already know their labels for the sake of the example.

In [8]:
# Form a labeled subset of the unlabeled dataset. Again, we already have the labels, 
# so we simply take a subset. Note, however, that the selection was done without the 
# use of the labels and that we would normally not have these labels. Hence, the 
# following statement would usually require human effort to complete.
scmi_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)
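
Because the selected indices are relative to the unlabeled subset, which is itself a subset of the full training set, the labeled dataset is a subset of a subset. The sketch below shows how such nested indices compose. `ListSubset` is a minimal stand-in written for this illustration (it mirrors the `dataset` and `indices` attributes of `torch.utils.data.Subset`), and the toy data is hypothetical.

```python
class ListSubset:
    """Minimal stand-in for torch.utils.data.Subset: a view of `dataset` at `indices`."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # Translate a position in the subset into a position in the parent.
        return self.dataset[self.indices[i]]


# Toy "full training set" of (image, label) pairs.
full_train = [(f"img{i}", i % 7) for i in range(10)]

unlabeled_idx = [9, 4, 7, 1]                  # positions in full_train
unlabeled = ListSubset(full_train, unlabeled_idx)

selected_idx = [2, 0]                         # positions in `unlabeled`
labeled = ListSubset(unlabeled, selected_idx)

# Map the selection back to positions in the full training set.
full_positions = [unlabeled_idx[i] for i in selected_idx]

print([labeled[i] for i in range(len(labeled))])
print(full_positions)
```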

## Characterizing the Selection

Now that we have selected and labeled these new points, we can add them to the training dataset and retrain our model. Before that, how many points did we select that actually were in-distribution points?

In [9]:
# Go over the newly labeled dataset, counting the points whose label is not the OOD label.
in_distribution_points = 0
for _, label in scmi_human_labeled_dataset:
    if label != len(id_classes):
        in_distribution_points += 1

# Print total ID count
print(F"Total In-Distribution Points: {in_distribution_points}")
print(F"Fraction of Budget: {in_distribution_points / budget}")

Total In-Distribution Points: 208
Fraction of Budget: 0.52


SCMI spent 208 of its 400 selections (52%) on in-distribution points. For comparison, how many in-distribution points would we get using some of DISTIL's other strategies?

**BADGE**

In [10]:
# Repeat the previous steps
selection_strategy = BADGE(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), trained_model, nclasses + 1, selection_strategy_args)

cifar10_full_train.transform = cifar_test_transform
budget = 400
selected_idx = selection_strategy.select(budget)
cifar10_full_train.transform = cifar_training_transform

badge_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

in_distribution_points = 0
for _, label in badge_human_labeled_dataset:
    if label != len(id_classes):
        in_distribution_points += 1

# Print total ID count
print(F"Total In-Distribution Points: {in_distribution_points}")
print(F"Fraction of Budget: {in_distribution_points / budget}")

Total In-Distribution Points: 103
Fraction of Budget: 0.2575


**Random**

In [11]:
# Repeat the previous steps
selection_strategy = RandomSampling(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), trained_model, nclasses + 1, selection_strategy_args)

cifar10_full_train.transform = cifar_test_transform
budget = 400
selected_idx = selection_strategy.select(budget)
cifar10_full_train.transform = cifar_training_transform

random_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

in_distribution_points = 0
for _, label in random_human_labeled_dataset:
    if label != len(id_classes):
        in_distribution_points += 1

# Print total ID count
print(F"Total In-Distribution Points: {in_distribution_points}")
print(F"Fraction of Budget: {in_distribution_points / budget}")

Total In-Distribution Points: 145
Fraction of Budget: 0.3625


**Entropy**

In [12]:
# Repeat the previous steps
selection_strategy = EntropySampling(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), trained_model, nclasses + 1, selection_strategy_args)

cifar10_full_train.transform = cifar_test_transform
budget = 400
selected_idx = selection_strategy.select(budget)
cifar10_full_train.transform = cifar_training_transform

entropy_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

in_distribution_points = 0
for _, label in entropy_human_labeled_dataset:
    if label != len(id_classes):
        in_distribution_points += 1

# Print total ID count
print(F"Total In-Distribution Points: {in_distribution_points}")
print(F"Fraction of Budget: {in_distribution_points / budget}")

Total In-Distribution Points: 80
Fraction of Budget: 0.2


Hence, we can see that SCMI selects in-distribution instances at the highest rate: 52% of its budget, compared with 36.25% for random sampling, 25.75% for BADGE, and 20% for entropy sampling.
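
To put these fractions in context, the sketch below compares each strategy's in-distribution rate against the base rate of in-distribution points in the unlabeled pool. The counts are copied from the outputs above, and the pool composition follows from the split configuration (6 ID classes with 250 points each, 4 OOD classes with 750 points each). Only the variable names are introduced here.

```python
budget = 400

# In-distribution counts reported by the cells above.
id_counts = {"SCMI": 208, "BADGE": 103, "Random": 145, "Entropy": 80}

# Composition of the unlabeled pool from the split configuration.
pool_id = 6 * 250
pool_ood = 4 * 750
base_rate = pool_id / (pool_id + pool_ood)

print(f"Pool base rate: {base_rate:.4f}")
for name, count in sorted(id_counts.items(), key=lambda kv: kv[1], reverse=True):
    fraction = count / budget
    print(f"{name:8s} {fraction:.4f}  ({fraction / base_rate:.2f}x the base rate)")
```

Random sampling lands close to the one-third base rate, as expected, while the uncertainty-driven strategies (BADGE and entropy sampling) fall below it in this run.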

# Improving Performance

## Re-Training

Let us re-train our model using the newly selected points. Note, however, that we only want to add the in-distribution points.

In [13]:
# Create a new training dataset by concatenating what we have with the newly labeled in-distribution points.
in_distribution_selected_idx = []
for index, (_, label) in enumerate(scmi_human_labeled_dataset):
    if label != len(id_classes):
        in_distribution_selected_idx.append(index)
in_distribution_scmi_human_labeled_dataset = Subset(scmi_human_labeled_dataset, in_distribution_selected_idx)

new_training_dataset = ConcatDataset([cifar10_train, in_distribution_scmi_human_labeled_dataset])
print("New Training Dataset Length:", len(new_training_dataset))
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

New Training Dataset Length: 1408
Training..
Epoch: 123 Training accuracy: 0.991


## Evaluation

Now, let us see the accuracy improvement.

In [14]:
# Get the full test accuracy
full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

Full Test Accuracy Improvement: 0.6238333333333334 to 0.6513333333333333


## Comparison

What would the accuracy improvement look like if we had used the other methods?

**BADGE**

In [15]:
# Repeat the process
in_distribution_selected_idx = []
for index, (_, label) in enumerate(badge_human_labeled_dataset):
    if label != len(id_classes):
        in_distribution_selected_idx.append(index)
in_distribution_badge_human_labeled_dataset = Subset(badge_human_labeled_dataset, in_distribution_selected_idx)

new_training_dataset = ConcatDataset([cifar10_train, in_distribution_badge_human_labeled_dataset])
print("New Training Dataset Length:", len(new_training_dataset))
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

New Training Dataset Length: 1303
Training..
Epoch: 145 Training accuracy: 0.992
Full Test Accuracy Improvement: 0.6238333333333334 to 0.6311666666666667


**Random**

In [16]:
# Repeat the process
in_distribution_selected_idx = []
for index, (_, label) in enumerate(random_human_labeled_dataset):
    if label != len(id_classes):
        in_distribution_selected_idx.append(index)
in_distribution_random_human_labeled_dataset = Subset(random_human_labeled_dataset, in_distribution_selected_idx)

new_training_dataset = ConcatDataset([cifar10_train, in_distribution_random_human_labeled_dataset])
print("New Training Dataset Length:", len(new_training_dataset))
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

New Training Dataset Length: 1345
Training..
Epoch: 130 Training accuracy: 0.99
Full Test Accuracy Improvement: 0.6238333333333334 to 0.6201666666666666


**Entropy**

In [17]:
# Repeat the process
in_distribution_selected_idx = []
for index, (_, label) in enumerate(entropy_human_labeled_dataset):
    if label != len(id_classes):
        in_distribution_selected_idx.append(index)
in_distribution_entropy_human_labeled_dataset = Subset(entropy_human_labeled_dataset, in_distribution_selected_idx)

new_training_dataset = ConcatDataset([cifar10_train, in_distribution_entropy_human_labeled_dataset])
print("New Training Dataset Length:", len(new_training_dataset))
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

New Training Dataset Length: 1280
Training..
Epoch: 131 Training accuracy: 0.993
Full Test Accuracy Improvement: 0.6238333333333334 to 0.6231666666666666
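
The sketch below collects the accuracies printed above and reports each strategy's change relative to the initial model, in percentage points. The numbers are copied from the cell outputs and come from a single run of each strategy; only the variable names are introduced here.

```python
# Test accuracy of the initial model, from the outputs above.
baseline = 0.6238333333333334

# Test accuracy after re-training with each strategy's in-distribution selections.
after = {
    "SCMI": 0.6513333333333333,
    "BADGE": 0.6311666666666667,
    "Random": 0.6201666666666666,
    "Entropy": 0.6231666666666666,
}

# Change in percentage points relative to the initial model.
deltas = {name: (acc - baseline) * 100 for name, acc in after.items()}

for name, delta in sorted(deltas.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:8s} {delta:+.2f} points")
```

In this run, SCMI yields the largest gain (about 2.75 points), BADGE a smaller one, and random and entropy sampling end slightly below the initial accuracy.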
