# DISTIL Usage Example: CIFAR-10 with Rare Classes

Not all datasets have an even spread of classes among their labels. A dataset might contain only a handful of examples for a particular class. Such classes are considered *rare*, and extra work is required to achieve good model performance on them. The typical fix is to label more rare-class data; however, that data is usually buried in a massive unlabeled pool. Here, we show how to use DISTIL's implementation of [SIMILAR](https://arxiv.org/abs/2107.00717) to mine these rare examples for labeling.

# Preparation

## Installation and Imports

In [1]:
# Get DISTIL
!git clone https://github.com/decile-team/distil.git
!pip install -r distil/requirements/requirements.txt

import copy
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import torch
import torch.optim as optim
import torch.nn.functional as F

from torch import nn
from torch.autograd import Variable
from torch.utils.data import Dataset, Subset, ConcatDataset, DataLoader
from torchvision import transforms
from torchvision.datasets import cifar

sys.path.append('distil/')
from distil.active_learning_strategies import SMI, GLISTER, BADGE, EntropySampling, RandomSampling  # All active learning strategies showcased in this example
from distil.utils.models.resnet import ResNet18                                                     # The model used in our image classification example
from distil.utils.train_helper import data_train                                                    # A utility training class provided by DISTIL
from distil.utils.utils import LabeledToUnlabeledDataset                                            # A utility wrapper class that removes labels from labeled PyTorch dataset objects


## Preparing CIFAR-10 with Rare Classes

The CIFAR-10 dataset contains 60,000 32x32 color images in 10 classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. There are 6,000 images of each class. The training set contains 50,000 images, and the test set contains 10,000 images. Here, we do a simple setup of CIFAR-10 for this example. More importantly, we split CIFAR-10's training set into an initial labeled seed set and an unlabeled set, and we impose an artificial imbalance among the classes to simulate a rare-class scenario.

In [2]:
# Define the name of the dataset and the path that PyTorch should use when downloading the data
data_set_name = 'CIFAR10'
download_path = '.'

# Define the number of classes in CIFAR10
nclasses = 10

# Define transforms on the dataset splits of CIFAR10. Here, we use random crops and horizontal flips for training augmentations.
# Both the train and test sets are converted to PyTorch tensors and are normalized around the mean/std of CIFAR-10.
cifar_training_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
cifar_test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

# Get the dataset objects from PyTorch. Here, CIFAR10 is downloaded, and the transform is applied when points 
# are retrieved.
cifar10_full_train = cifar.CIFAR10(download_path, train=True, download=True, transform=cifar_training_transform)
cifar10_test = cifar.CIFAR10(download_path, train=False, download=True, transform=cifar_test_transform)

# Get the dimension of the images. Here, we simply take the very first image of CIFAR10 
# and query its dimension.
dim = np.shape(cifar10_full_train[0][0])

# We now define a train-unlabeled split for the sake of the experiment. Here, we specify the ratio of common classes to rare classes
# for both the initial training seed set and the unlabeled set. We use a ratio of 10:1.
num_examples_per_common_class_seed = 400
num_examples_per_rare_class_seed = 40

num_examples_per_common_class_unlabeled = 3000
num_examples_per_rare_class_unlabeled = 300

# We create the imbalance on classes 5,6,7,8,9.
rare_classes = [5,6,7,8,9]

# Create the imbalance by choosing which indices of the full training set to assign to the initial labeled seed set and the unlabeled set
train_idx = []
unlabeled_idx = []

for class_idx in range(nclasses):

    # Retrieve all the indices of the elements in CIFAR-10 whose label matches class_idx
    full_idx_class = torch.where(torch.Tensor(cifar10_full_train.targets) == class_idx)[0].cpu().numpy()

    # Determine how many points to add for this class depending on the rarity of the class
    if class_idx in rare_classes:
        class_num_training_examples_to_add = num_examples_per_rare_class_seed
        class_num_unlabeled_examples_to_add = num_examples_per_rare_class_unlabeled
    else:
        class_num_training_examples_to_add = num_examples_per_common_class_seed
        class_num_unlabeled_examples_to_add = num_examples_per_common_class_unlabeled

    # Choose randomly a subset of these indices. These will be added to the initial training seed set.
    class_train_idx = np.random.choice(full_idx_class, size=class_num_training_examples_to_add, replace=False)

    # Choose randomly a subset of the remaining indices. These will be added to the unlabeled set.
    remaining_class_idx = np.array(list(set(full_idx_class) - set(class_train_idx)))
    class_unlabeled_idx = np.random.choice(remaining_class_idx, size=class_num_unlabeled_examples_to_add, replace=False)

    # Add the chosen indices to the growing subset lists
    train_idx.extend(class_train_idx)
    unlabeled_idx.extend(class_unlabeled_idx)

# Create the train and unlabeled subsets based on the index lists above. While the unlabeled set constructed here technically has labels, they 
# are only used when querying for labels. Hence, they only exist here for the sake of experimental design.
cifar10_train = Subset(cifar10_full_train, train_idx)
cifar10_unlabeled = Subset(cifar10_full_train, unlabeled_idx)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified
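
As a sanity check on the split sizes, the following minimal sketch repeats the per-class sampling on synthetic labels (5,000 per class, matching CIFAR-10's training set) rather than on the real dataset. The function name `split_seed_and_unlabeled` and its parameters are hypothetical and are not part of DISTIL.

```python
import random

def split_seed_and_unlabeled(targets, rare_classes, n_classes,
                             seed_common, seed_rare,
                             unlabeled_common, unlabeled_rare, rng):
    """Return (seed_idx, unlabeled_idx) with the requested per-class sizes."""
    seed_idx, unlabeled_idx = [], []
    for class_idx in range(n_classes):
        class_positions = [i for i, t in enumerate(targets) if t == class_idx]
        is_rare = class_idx in rare_classes
        n_seed = seed_rare if is_rare else seed_common
        n_unlabeled = unlabeled_rare if is_rare else unlabeled_common
        # Sample both groups in one draw so the two sets never overlap.
        picked = rng.sample(class_positions, n_seed + n_unlabeled)
        seed_idx.extend(picked[:n_seed])
        unlabeled_idx.extend(picked[n_seed:])
    return seed_idx, unlabeled_idx

# Synthetic stand-in for cifar10_full_train.targets: 5,000 labels per class.
targets = [c for c in range(10) for _ in range(5000)]
seed_idx, unlabeled_idx = split_seed_and_unlabeled(
    targets, rare_classes={5, 6, 7, 8, 9}, n_classes=10,
    seed_common=400, seed_rare=40,
    unlabeled_common=3000, unlabeled_rare=300,
    rng=random.Random(0))

print(len(seed_idx), len(unlabeled_idx))  # 2200 16500
```

With the sizes used in this example, the seed set holds 2,200 points and the unlabeled pool holds 16,500, of which 1,500 (about 9%) belong to rare classes.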


## Preparing the Model

Here, we use DISTIL's provided implementation of the [ResNet-18](https://arxiv.org/abs/1512.03385) architecture.

In [3]:
net = ResNet18()

# Initial Round

## Training

Now that we have prepared the training data and the model, we can begin training an initial model. We use DISTIL's provided [training loop](https://github.com/decile-team/distil/blob/main/distil/utils/train_helper.py) to do training.

In [4]:
# Define the training arguments to use.
args = {'n_epoch':      300,    # Stop training after 300 epochs.
        'lr':           0.01,   # Use a learning rate of 0.01
        'batch_size':   20,     # Update the parameters using training batches of size 20
        'max_accuracy': 0.99,   # Stop training once the training accuracy has exceeded 0.99
        'optimizer':    'sgd',  # Use the stochastic gradient descent optimizer
        'device':       "cuda" if torch.cuda.is_available() else "cpu"  # Use a GPU if one is available
        }

# Create the training loop using our training dataset, provided model, and training arguments.
# Train an initial model.
dt = data_train(cifar10_train, net, args)
trained_model = dt.train()

Training..
Epoch: 105 Training accuracy: 0.991
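
The log shows training ending at epoch 105 rather than 300, which is what one would expect if the `max_accuracy` threshold fired before the epoch limit. A minimal sketch of such a stopping rule follows. It is not DISTIL's actual training loop; `run_training` and `train_one_epoch` are hypothetical names, and the toy accuracy curve merely stands in for a real epoch of optimization.

```python
def run_training(train_one_epoch, n_epoch, max_accuracy):
    """Run up to n_epoch epochs, stopping once accuracy exceeds max_accuracy.

    train_one_epoch(epoch) is a hypothetical callable that performs one epoch
    of training and returns the training accuracy measured afterwards.
    """
    accuracy = 0.0
    epochs_run = 0
    for epoch in range(1, n_epoch + 1):
        accuracy = train_one_epoch(epoch)
        epochs_run = epoch
        if accuracy > max_accuracy:
            break
    return epochs_run, accuracy

# Toy stand-in: accuracy climbs by one percentage point per epoch, capped at 1.0.
def fake_epoch(epoch):
    return min(1.0, epoch / 100)

epochs_run, final_accuracy = run_training(fake_epoch, n_epoch=300, max_accuracy=0.99)
print(epochs_run, final_accuracy)
```

In this toy run the loop stops well before the 300-epoch limit, because the accuracy threshold is reached first.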


## Evaluation

How does our initial model do on the test set? Furthermore, how does our initial model do on the *rare classes* of the test set? Luckily, the training loop provided by DISTIL also provides a way to measure the accuracy of the model on a given dataset. We measure both accuracies here.

In [5]:
# Get the full test accuracy
full_test_accuracy = dt.get_acc_on_set(cifar10_test)
print(F"Full Test Accuracy: {full_test_accuracy}")

# Get the per-class test accuracies
rare_indices = []
for class_idx in range(nclasses):

    # Get the indices of the test set corresponding to this class
    test_rare_class_subset_idx = torch.where(torch.Tensor(cifar10_test.targets) == class_idx)[0].cpu().numpy()
    
    if class_idx in rare_classes:
        rare_indices.extend(test_rare_class_subset_idx)

    # Get the accuracy on this class subset
    cifar10_test_class_subset = Subset(cifar10_test, test_rare_class_subset_idx)
    rare_class_test_accuracy = dt.get_acc_on_set(cifar10_test_class_subset)
    print(F"Class {class_idx} Test Accuracy: {rare_class_test_accuracy}")

# Get accuracy on all rare points
cifar10_test_rare_subset = Subset(cifar10_test, rare_indices)
rare_test_accuracy = dt.get_acc_on_set(cifar10_test_rare_subset)
print(F"Rare Test Accuracy: {rare_test_accuracy}")

Full Test Accuracy: 0.5509
Class 0 Test Accuracy: 0.839
Class 1 Test Accuracy: 0.894
Class 2 Test Accuracy: 0.689
Class 3 Test Accuracy: 0.738
Class 4 Test Accuracy: 0.735
Class 5 Test Accuracy: 0.142
Class 6 Test Accuracy: 0.286
Class 7 Test Accuracy: 0.36
Class 8 Test Accuracy: 0.35
Class 9 Test Accuracy: 0.476
Rare Test Accuracy: 0.3228


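As we can see, the model performs far worse on the rare classes: the rare-class test accuracy is 0.3228, while the accuracies on the common classes (0 through 4) range from 0.689 to 0.894. Can we rectify this issue by adding rare-class examples?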
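
Because CIFAR-10's test set has 1,000 images per class, the aggregate accuracies are plain averages of the per-class numbers. The following check uses only the values printed above to recover both aggregates, along with the corresponding average over the common classes.

```python
# Per-class test accuracies copied from the output above (classes 0 through 9).
per_class_accuracy = [0.839, 0.894, 0.689, 0.738, 0.735,
                      0.142, 0.286, 0.36, 0.35, 0.476]
rare_classes = [5, 6, 7, 8, 9]
common_classes = [c for c in range(10) if c not in rare_classes]

def mean_accuracy(classes):
    """Average the per-class accuracies; valid because test classes are equally sized."""
    return sum(per_class_accuracy[c] for c in classes) / len(classes)

print(f"Full:   {mean_accuracy(range(10)):.4f}")      # Full:   0.5509
print(f"Rare:   {mean_accuracy(rare_classes):.4f}")   # Rare:   0.3228
print(f"Common: {mean_accuracy(common_classes):.4f}") # Common: 0.7790
```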

# Mining Rare Classes

## Preparing a Query Set

In this example, we know that there are rare classes in the unlabeled dataset that can help us improve the accuracy of the model. How do we select these points if we do not know *where* they are in the unlabeled dataset? Here, we use [SIMILAR](https://arxiv.org/abs/2107.00717) to select rare-class examples from the unlabeled set. SIMILAR requires access to a query set of points that it uses to choose similar unlabeled points. Hence, we must first prepare this query set.

Where do we get the query set? Luckily, the labeled seed set already contains 40 examples of each rare class (200 in total) that we can use.

In [6]:
# Go over the training dataset, getting the indices of all the rare-class examples
rare_training_example_indices = []
for index, (_, label) in enumerate(cifar10_train):
    if label in rare_classes:
        rare_training_example_indices.append(index)

# Create a query set that contains only the rare-class examples of the training dataset
rare_class_query_set = Subset(cifar10_train, rare_training_example_indices)
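
Iterating over `cifar10_train` as above loads and augments every image just to read its label. A possible alternative, sketched here on toy data, reads the labels directly. It assumes the labels are available as a list (as in `cifar10_full_train.targets`) and that the subset's indices into the full dataset are available (as in `cifar10_train.indices`). The helper name `rare_positions_in_subset` is hypothetical.

```python
def rare_positions_in_subset(targets, subset_indices, rare_classes):
    """Return the positions within the subset whose labels are rare classes."""
    rare = set(rare_classes)
    return [pos for pos, full_idx in enumerate(subset_indices)
            if targets[full_idx] in rare]

# Toy stand-ins for cifar10_full_train.targets and cifar10_train.indices.
targets = [0, 5, 1, 9, 2, 6, 3]
subset_indices = [6, 1, 4, 3]   # labels of these points: 3, 5, 2, 9

print(rare_positions_in_subset(targets, subset_indices, [5, 6, 7, 8, 9]))  # [1, 3]
```

The returned positions index into the subset, not the full dataset, so they can be passed to `Subset(cifar10_train, ...)` in the same way as `rare_training_example_indices`.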

## Using DISTIL's SIMILAR Implementation: SMI

Now that we have the query set, we are ready to use DISTIL's implementation of SIMILAR. In particular, we use the submodular mutual information [strategy](https://github.com/decile-team/distil/blob/main/distil/active_learning_strategies/smi.py) that is detailed in [SIMILAR](https://arxiv.org/abs/2107.00717). This will allow us to select a set of points to label within a specified budget $k$.

Specifically, the strategy attempts to maximize the [submodular mutual information](https://arxiv.org/abs/2006.15412) between a subset $\mathcal{A}$ of size no greater than $k$ of the unlabeled dataset $\mathcal{U}$ and the query set $\mathcal{Q}$:

\begin{align}
\text{argmax}_{\mathcal{A} \subseteq \mathcal{U}, |\mathcal{A}|\leq k} I_F(\mathcal{A};\mathcal{Q})
\end{align}

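where $F$ is a submodular set function and $I_F(\mathcal{A};\mathcal{Q}) = F(\mathcal{A}) + F(\mathcal{Q}) - F(\mathcal{A} \cup \mathcal{Q})$.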
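
To make the objective concrete, here is a toy sketch that greedily maximizes $I_F(\mathcal{A};\mathcal{Q}) = F(\mathcal{A}) + F(\mathcal{Q}) - F(\mathcal{A} \cup \mathcal{Q})$ with a plain facility-location function $F$ over 2-D vectors. It rests on several assumptions: the names `facility_location`, `smi`, and `greedy_smi` are illustrative, the optimizer is a naive greedy loop instead of lazy greedy, the vectors are hand-picked and not model-derived features, and the function is not necessarily the exact variant that `'fl2mi'` selects in DISTIL.

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def facility_location(subset, ground, sim):
    """F(A): for each ground-set point, similarity to its closest member of A."""
    if not subset:
        return 0.0
    return sum(max(sim(g, a) for a in subset) for g in ground)

def smi(chosen, query, ground, sim):
    """I_F(A; Q) = F(A) + F(Q) - F(A union Q)."""
    return (facility_location(chosen, ground, sim)
            + facility_location(query, ground, sim)
            - facility_location(chosen + query, ground, sim))

def greedy_smi(unlabeled, query, k, sim=cosine):
    """Pick up to k unlabeled indices, each time adding the largest marginal gain."""
    ground = unlabeled + query
    selected = []
    for _ in range(min(k, len(unlabeled))):
        chosen = [unlabeled[i] for i in selected]
        base = smi(chosen, query, ground, sim)
        best_idx, best_gain = None, -math.inf
        for i, x in enumerate(unlabeled):
            if i in selected:
                continue
            gain = smi(chosen + [x], query, ground, sim) - base
            if gain > best_gain:
                best_idx, best_gain = i, gain
        selected.append(best_idx)
    return selected

# Points 0 and 2 point in roughly the same direction as the query; the rest do not.
unlabeled = [(1.0, 0.1), (0.1, 1.0), (0.9, 0.2), (0.0, 1.0), (0.2, 0.9)]
query = [(1.0, 0.0), (0.95, 0.15)]

selected = greedy_smi(unlabeled, query, k=2)
print(sorted(selected))  # [0, 2]
```

With a budget of two, the greedy loop picks the two unlabeled points that resemble the query set and ignores the three that do not, which is the behavior the rare-class mining relies on.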

In [7]:
# Define arguments for SMI
selection_strategy_args = {'device':        args['device'],       # Use the device used in training
                           'batch_size':    args['batch_size'],   # Use the batch size used in training
                           'smi_function':  'fl2mi',              # Use a facility location function, which captures representation information
                           'metric':        'cosine',             # Use cosine similarity when determining the likeness of two data points
                           'optimizer':     'LazyGreedy'          # When doing submodular maximization, use the lazy greedy optimizer
                          }

# Create the SMI selection strategy. Note: We remove the labels from the unlabeled portion of CIFAR-10 that we created earlier.
# In a practical application, one would not have these labels a priori.
selection_strategy = SMI(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), rare_class_query_set, trained_model, nclasses, selection_strategy_args)

# Disable the augmentations used in the training dataset. Since all augmentations come from the cifar10_full_train object, we set its transform to the 
# transform used by the test set.
cifar10_full_train.transform = cifar_test_transform

# Do the selection, which will return the indices of the selected points with respect to the unlabeled dataset.
budget = 750
selected_idx = selection_strategy.select(budget)

# Re-enable augmentations
cifar10_full_train.transform = cifar_training_transform
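
The swap-and-restore pattern around `select` is easy to get wrong if the selection raises an exception, since the training augmentations would then stay disabled. One way to make it safer is a small context manager, sketched below. The name `temporary_transform` is hypothetical; the only assumption about the dataset is that it exposes a writable `transform` attribute, as `cifar10_full_train` does in this example.

```python
from contextlib import contextmanager

@contextmanager
def temporary_transform(dataset, transform):
    """Temporarily replace dataset.transform, restoring the original on exit."""
    original = dataset.transform
    dataset.transform = transform
    try:
        yield dataset
    finally:
        # Runs even if the body raises, so augmentations are always re-enabled.
        dataset.transform = original

# Toy stand-in for a dataset object with a `transform` attribute.
class ToyDataset:
    def __init__(self, transform):
        self.transform = transform

toy = ToyDataset(transform="train_augmentations")
with temporary_transform(toy, "test_transform"):
    inside = toy.transform
after = toy.transform

print(inside, after)  # test_transform train_augmentations
```

In this example, the selection would then read `with temporary_transform(cifar10_full_train, cifar_test_transform): selected_idx = selection_strategy.select(budget)`.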

## Labeling the Points

Now that we know which points should be labeled, we can present them to human labelers for annotation. Here, we can do so automatically since we already know their labels for the sake of the example.

In [8]:
# Form a labeled subset of the unlabeled dataset. Again, we already have the labels, 
# so we simply take a subset. Note, however, that the selection was done without the 
# use of the labels and that we would normally not have these labels. Hence, the 
# following statement would usually require human effort to complete.
smi_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

## Characterizing the Selection

Now that we have selected and labeled these new points, we can add them to the training dataset and retrain our model. Before that, how many points did we select that actually were rare-class points?

In [9]:
# Go over the newly labeled dataset, tallying the number of points seen in each class.
smi_class_counts = [0 for x in range(nclasses)]
for _, label in smi_human_labeled_dataset:
    smi_class_counts[label] += 1

# Print each class count
for class_idx, class_count in enumerate(smi_class_counts):
    print(F"Class {class_idx} count: {class_count}")

# Print total rare count
total_rare_count = sum([smi_class_counts[i] for i in rare_classes])
print(F"Total Rare Count: {total_rare_count}")

Class 0 count: 47
Class 1 count: 40
Class 2 count: 43
Class 3 count: 118
Class 4 count: 81
Class 5 count: 45
Class 6 count: 65
Class 7 count: 108
Class 8 count: 74
Class 9 count: 129
Total Rare Count: 421


Of the 750 selected points, 421 (about 56%) are rare-class points. For comparison's sake, how many rare points would we get using some of DISTIL's other strategies?

**BADGE**

In [10]:
# Repeat the previous steps
selection_strategy = BADGE(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), trained_model, nclasses, selection_strategy_args)

cifar10_full_train.transform = cifar_test_transform
budget = 750
selected_idx = selection_strategy.select(budget)
cifar10_full_train.transform = cifar_training_transform

badge_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

badge_class_counts = [0 for x in range(nclasses)]
for _, label in badge_human_labeled_dataset:
    badge_class_counts[label] += 1

for class_idx, class_count in enumerate(badge_class_counts):
    print(F"Class {class_idx} count: {class_count}")

total_rare_count = sum([badge_class_counts[i] for i in rare_classes])
print(F"Total Rare Count: {total_rare_count}")

Class 0 count: 114
Class 1 count: 92
Class 2 count: 139
Class 3 count: 158
Class 4 count: 125
Class 5 count: 16
Class 6 count: 29
Class 7 count: 28
Class 8 count: 19
Class 9 count: 30
Total Rare Count: 122


**Random**

In [11]:
# Repeat the previous steps
selection_strategy = RandomSampling(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), trained_model, nclasses, selection_strategy_args)

cifar10_full_train.transform = cifar_test_transform
budget = 750
selected_idx = selection_strategy.select(budget)
cifar10_full_train.transform = cifar_training_transform

random_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

random_class_counts = [0 for x in range(nclasses)]
for _, label in random_human_labeled_dataset:
    random_class_counts[label] += 1

for class_idx, class_count in enumerate(random_class_counts):
    print(F"Class {class_idx} count: {class_count}")

total_rare_count = sum([random_class_counts[i] for i in rare_classes])
print(F"Total Rare Count: {total_rare_count}")

Class 0 count: 146
Class 1 count: 128
Class 2 count: 124
Class 3 count: 142
Class 4 count: 143
Class 5 count: 16
Class 6 count: 9
Class 7 count: 14
Class 8 count: 11
Class 9 count: 17
Total Rare Count: 67


**Entropy**

In [12]:
# Repeat the previous steps
selection_strategy = EntropySampling(cifar10_train, LabeledToUnlabeledDataset(cifar10_unlabeled), trained_model, nclasses, selection_strategy_args)

cifar10_full_train.transform = cifar_test_transform
budget = 750
selected_idx = selection_strategy.select(budget)
cifar10_full_train.transform = cifar_training_transform

entropy_human_labeled_dataset = Subset(cifar10_unlabeled, selected_idx)

entropy_class_counts = [0 for x in range(nclasses)]
for _, label in entropy_human_labeled_dataset:
    entropy_class_counts[label] += 1

for class_idx, class_count in enumerate(entropy_class_counts):
    print(F"Class {class_idx} count: {class_count}")

total_rare_count = sum([entropy_class_counts[i] for i in rare_classes])
print(F"Total Rare Count: {total_rare_count}")

Class 0 count: 96
Class 1 count: 54
Class 2 count: 126
Class 3 count: 168
Class 4 count: 172
Class 5 count: 22
Class 6 count: 34
Class 7 count: 34
Class 8 count: 22
Class 9 count: 22
Total Rare Count: 134


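Hence, we can see that SMI does much better at selecting rare instances: it finds 421 rare-class points, more than three times as many as the next-best strategy (entropy sampling with 134), while random sampling finds only 67.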
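
To put the counts on a common scale, the following short calculation converts the totals printed above into the share of each 750-point selection that is rare-class and compares it with the share of rare-class points in the unlabeled pool. All numbers come from the outputs and split sizes in this example.

```python
budget = 750
# Total rare counts copied from the outputs above.
rare_counts = {"SMI": 421, "BADGE": 122, "Random": 67, "Entropy": 134}

# Rare share of the unlabeled pool: 5 rare classes x 300 points
# out of 5 x 3000 common + 5 x 300 rare points.
pool_rare_fraction = (5 * 300) / (5 * 3000 + 5 * 300)

rare_fractions = {name: count / budget for name, count in rare_counts.items()}
for name, fraction in sorted(rare_fractions.items(), key=lambda kv: -kv[1]):
    print(f"{name:8s} {fraction:.1%} of the selection is rare-class")
print(f"Pool     {pool_rare_fraction:.1%} of the unlabeled pool is rare-class")
```

Random sampling lands close to the pool's base rate of roughly 9%, as one would expect, while SMI's selection is rare-class more than half of the time.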

# Improving Performance

## Re-Training

Let us re-train our model using the newly selected points.

In [13]:
# Create a new training dataset by concatenating what we have with the newly labeled points
new_training_dataset = ConcatDataset([cifar10_train, smi_human_labeled_dataset])
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

Training..
Epoch: 102 Training accuracy: 0.992


## Evaluation

Now, let us see the accuracy improvement.

In [14]:
# Get the full test accuracy
full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

# Get the per-class test accuracies
rare_indices = []
for class_idx in range(nclasses):

    # Get the indices of the test set corresponding to this class
    test_rare_class_subset_idx = torch.where(torch.Tensor(cifar10_test.targets) == class_idx)[0].cpu().numpy()
    
    if class_idx in rare_classes:
        rare_indices.extend(test_rare_class_subset_idx)

    # Get the accuracy on this class subset
    cifar10_test_class_subset = Subset(cifar10_test, test_rare_class_subset_idx)

    rare_class_test_accuracy_before = dt.get_acc_on_set(cifar10_test_class_subset)
    rare_class_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test_class_subset)
    print(F"Class {class_idx} Test Accuracy: {rare_class_test_accuracy_before} to {rare_class_test_accuracy_after}")

# Get accuracy on all rare points
cifar10_test_rare_subset = Subset(cifar10_test, rare_indices)
rare_test_accuracy_before = dt.get_acc_on_set(cifar10_test_rare_subset)
rare_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test_rare_subset)
print(F"Rare Test Accuracy: {rare_test_accuracy_before} to {rare_test_accuracy_after}")

Full Test Accuracy Improvement: 0.5509 to 0.6512
Class 0 Test Accuracy: 0.839 to 0.843
Class 1 Test Accuracy: 0.894 to 0.918
Class 2 Test Accuracy: 0.689 to 0.649
Class 3 Test Accuracy: 0.738 to 0.769
Class 4 Test Accuracy: 0.735 to 0.812
Class 5 Test Accuracy: 0.142 to 0.283
Class 6 Test Accuracy: 0.286 to 0.441
Class 7 Test Accuracy: 0.36 to 0.537
Class 8 Test Accuracy: 0.35 to 0.621
Class 9 Test Accuracy: 0.476 to 0.639
Rare Test Accuracy: 0.3228 to 0.5042


## Comparison

What would the accuracy improvement look like if we had used the other methods?

**BADGE**

In [15]:
# Repeat the process
new_training_dataset = ConcatDataset([cifar10_train, badge_human_labeled_dataset])
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

rare_indices = []
for class_idx in range(nclasses):

    test_rare_class_subset_idx = torch.where(torch.Tensor(cifar10_test.targets) == class_idx)[0].cpu().numpy()
    
    if class_idx in rare_classes:
        rare_indices.extend(test_rare_class_subset_idx)

    cifar10_test_class_subset = Subset(cifar10_test, test_rare_class_subset_idx)

    rare_class_test_accuracy_before = dt.get_acc_on_set(cifar10_test_class_subset)
    rare_class_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test_class_subset)
    print(F"Class {class_idx} Test Accuracy: {rare_class_test_accuracy_before} to {rare_class_test_accuracy_after}")

cifar10_test_rare_subset = Subset(cifar10_test, rare_indices)
rare_test_accuracy_before = dt.get_acc_on_set(cifar10_test_rare_subset)
rare_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test_rare_subset)
print(F"Rare Test Accuracy: {rare_test_accuracy_before} to {rare_test_accuracy_after}")

Training..
Epoch: 122 Training accuracy: 0.994
Full Test Accuracy Improvement: 0.5509 to 0.623
Class 0 Test Accuracy: 0.839 to 0.883
Class 1 Test Accuracy: 0.894 to 0.941
Class 2 Test Accuracy: 0.689 to 0.725
Class 3 Test Accuracy: 0.738 to 0.748
Class 4 Test Accuracy: 0.735 to 0.797
Class 5 Test Accuracy: 0.142 to 0.276
Class 6 Test Accuracy: 0.286 to 0.312
Class 7 Test Accuracy: 0.36 to 0.492
Class 8 Test Accuracy: 0.35 to 0.441
Class 9 Test Accuracy: 0.476 to 0.615
Rare Test Accuracy: 0.3228 to 0.4272


**Random**

In [16]:
# Repeat the process
new_training_dataset = ConcatDataset([cifar10_train, random_human_labeled_dataset])
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

rare_indices = []
for class_idx in range(nclasses):

    test_rare_class_subset_idx = torch.where(torch.Tensor(cifar10_test.targets) == class_idx)[0].cpu().numpy()
    
    if class_idx in rare_classes:
        rare_indices.extend(test_rare_class_subset_idx)

    cifar10_test_class_subset = Subset(cifar10_test, test_rare_class_subset_idx)

    rare_class_test_accuracy_before = dt.get_acc_on_set(cifar10_test_class_subset)
    rare_class_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test_class_subset)
    print(F"Class {class_idx} Test Accuracy: {rare_class_test_accuracy_before} to {rare_class_test_accuracy_after}")

cifar10_test_rare_subset = Subset(cifar10_test, rare_indices)
rare_test_accuracy_before = dt.get_acc_on_set(cifar10_test_rare_subset)
rare_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test_rare_subset)
print(F"Rare Test Accuracy: {rare_test_accuracy_before} to {rare_test_accuracy_after}")

Training..
Epoch: 118 Training accuracy: 0.991
Full Test Accuracy Improvement: 0.5509 to 0.5915
Class 0 Test Accuracy: 0.839 to 0.771
Class 1 Test Accuracy: 0.894 to 0.932
Class 2 Test Accuracy: 0.689 to 0.758
Class 3 Test Accuracy: 0.738 to 0.612
Class 4 Test Accuracy: 0.735 to 0.806
Class 5 Test Accuracy: 0.142 to 0.232
Class 6 Test Accuracy: 0.286 to 0.425
Class 7 Test Accuracy: 0.36 to 0.341
Class 8 Test Accuracy: 0.35 to 0.499
Class 9 Test Accuracy: 0.476 to 0.539
Rare Test Accuracy: 0.3228 to 0.4072


**Entropy**

In [17]:
# Repeat the process
new_training_dataset = ConcatDataset([cifar10_train, entropy_human_labeled_dataset])
new_dt = data_train(new_training_dataset, copy.deepcopy(net), args)
new_trained_model = new_dt.train()

full_test_accuracy_before = dt.get_acc_on_set(cifar10_test)
full_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test)

print(F"Full Test Accuracy Improvement: {full_test_accuracy_before} to {full_test_accuracy_after}")

rare_indices = []
for class_idx in range(nclasses):

    test_rare_class_subset_idx = torch.where(torch.Tensor(cifar10_test.targets) == class_idx)[0].cpu().numpy()
    
    if class_idx in rare_classes:
        rare_indices.extend(test_rare_class_subset_idx)

    cifar10_test_class_subset = Subset(cifar10_test, test_rare_class_subset_idx)

    rare_class_test_accuracy_before = dt.get_acc_on_set(cifar10_test_class_subset)
    rare_class_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test_class_subset)
    print(F"Class {class_idx} Test Accuracy: {rare_class_test_accuracy_before} to {rare_class_test_accuracy_after}")

cifar10_test_rare_subset = Subset(cifar10_test, rare_indices)
rare_test_accuracy_before = dt.get_acc_on_set(cifar10_test_rare_subset)
rare_test_accuracy_after = new_dt.get_acc_on_set(cifar10_test_rare_subset)
print(F"Rare Test Accuracy: {rare_test_accuracy_before} to {rare_test_accuracy_after}")

Training..
Epoch: 116 Training accuracy: 0.993
Full Test Accuracy Improvement: 0.5509 to 0.6483
Class 0 Test Accuracy: 0.839 to 0.879
Class 1 Test Accuracy: 0.894 to 0.906
Class 2 Test Accuracy: 0.689 to 0.699
Class 3 Test Accuracy: 0.738 to 0.769
Class 4 Test Accuracy: 0.735 to 0.829
Class 5 Test Accuracy: 0.142 to 0.182
Class 6 Test Accuracy: 0.286 to 0.456
Class 7 Test Accuracy: 0.36 to 0.457
Class 8 Test Accuracy: 0.35 to 0.623
Class 9 Test Accuracy: 0.476 to 0.683
Rare Test Accuracy: 0.3228 to 0.4802
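
To compare the four strategies side by side, the short script below collects the full and rare test accuracies printed above and ranks the strategies by rare-class gain. These numbers come from a single run without fixed random seeds, so small differences between strategies may partly reflect run-to-run variation.

```python
# Accuracies copied from the outputs above.
baseline = {"full": 0.5509, "rare": 0.3228}
after = {
    "SMI":     {"full": 0.6512, "rare": 0.5042},
    "BADGE":   {"full": 0.623,  "rare": 0.4272},
    "Random":  {"full": 0.5915, "rare": 0.4072},
    "Entropy": {"full": 0.6483, "rare": 0.4802},
}

# Gain of each strategy over the initial model, for both metrics.
gains = {name: {key: acc[key] - baseline[key] for key in baseline}
         for name, acc in after.items()}

# Rank strategies by their improvement on the rare classes.
ranked = sorted(gains, key=lambda name: gains[name]["rare"], reverse=True)
for name in ranked:
    print(f"{name:8s} full +{gains[name]['full']:.4f}  rare +{gains[name]['rare']:.4f}")
```

In this run, SMI gives the largest gain on both the rare classes (+0.1814) and the full test set (+0.1003), followed by entropy sampling, BADGE, and random sampling.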
