# Preface

The places that require configuration for your experiment are marked with comments written in CAPITAL LETTERS.

# Setup

**Installations**

In [None]:
!pip install apricot-select
!pip install sphinxcontrib-napoleon
!pip install sphinxcontrib-bibtex

!git clone https://github.com/decile-team/distil.git
!git clone https://github.com/circulosmeos/gdown.pl.git

!mv distil asdf
!mv asdf/distil .

**Experiment-Specific Imports**

In [None]:
from distil.utils.data_handler import DataHandler_CIFAR100, DataHandler_Points # IMPORT YOUR DATAHANDLER HERE
from distil.utils.models.resnet import ResNet18                                 # IMPORT YOUR MODEL HERE

**Imports, Training Class Definition, Experiment Procedure Definition**

Nothing in this code block needs to be modified unless you are changing the experimental procedure itself.

In [None]:
import pandas as pd 
import numpy as np
import copy
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Subset
import torch.nn.functional as F
from torch import nn
from torchvision import transforms
from torchvision import datasets
from PIL import Image
import torch
import torch.optim as optim
from torch.autograd import Variable
import sys
sys.path.append('../')
import matplotlib.pyplot as plt
import time
import math
import random
import os
import pickle

from numpy.linalg import cond
from numpy.linalg import inv
from numpy.linalg import norm
from scipy import sparse as sp
from scipy.linalg import lstsq
from scipy.linalg import solve
from scipy.optimize import nnls

from distil.active_learning_strategies.badge import BADGE
from distil.active_learning_strategies.glister import GLISTER
from distil.active_learning_strategies.margin_sampling import MarginSampling
from distil.active_learning_strategies.entropy_sampling import EntropySampling
from distil.active_learning_strategies.random_sampling import RandomSampling
from distil.active_learning_strategies.gradmatch_active import GradMatchActive
from distil.active_learning_strategies.craig_active import CRAIGActive
from distil.active_learning_strategies.fass import FASS
from distil.active_learning_strategies.adversarial_bim import AdversarialBIM
from distil.active_learning_strategies.adversarial_deepfool import AdversarialDeepFool
from distil.active_learning_strategies.core_set import CoreSet
from distil.active_learning_strategies.least_confidence import LeastConfidence
from distil.active_learning_strategies.margin_sampling import MarginSampling
from distil.active_learning_strategies.bayesian_active_learning_disagreement_dropout import BALDDropout
from distil.utils.dataset import get_dataset
from distil.utils.train_helper import data_train

from google.colab import drive
import warnings
warnings.filterwarnings("ignore")

class Checkpoint:
    """Snapshot of an experiment's progress that can be saved to and restored from disk.

    A checkpoint stores the list of accuracies, the list of indices, the
    model state dict and the name of the experiment it belongs to.

    ``save_checkpoint`` writes two identical pickle files per save, named
    ``c<timestamp>1`` and ``c<timestamp>2``. ``load_checkpoint`` only accepts a
    save when both copies can be read and compare equal, so a save that was
    interrupted part-way through is skipped in favour of an older one.

    NOTE(security): checkpoints are (de)serialized with ``pickle``. Only load
    from directories whose contents you trust.
    """

    def __init__(self, acc_list=None, indices=None, state_dict=None, experiment_name=None, path=None):
        """Create a checkpoint from explicit values, or restore one from ``path``.

        Args:
            acc_list: Accuracy values to store (required unless ``path`` is given).
            indices: Indices to store (required unless ``path`` is given).
            state_dict: Model state dict to store (required unless ``path`` is given).
            experiment_name: Name of the owning experiment (always required).
            path: Directory to restore from. When given, the other values are
                ignored and the most recent valid checkpoint for
                ``experiment_name`` is loaded instead.

        Raises:
            ValueError: If a required value is ``None``.
        """
        # If a path is supplied, load a checkpoint from there.
        if path is not None:
            if experiment_name is None:
                raise ValueError("Checkpoint contains None value for experiment_name")
            self.load_checkpoint(path, experiment_name)
            return

        required_fields = (
            ("acc_list", acc_list),
            ("indices", indices),
            ("state_dict", state_dict),
            ("experiment_name", experiment_name),
        )
        for field_name, value in required_fields:
            if value is None:
                raise ValueError(f"Checkpoint contains None value for {field_name}")

        self.acc_list = acc_list
        self.indices = indices
        self.state_dict = state_dict
        self.experiment_name = experiment_name

    def __eq__(self, other):
        """Compare accuracy lists, indices and experiment names.

        The state dict is deliberately left out of the comparison, exactly as
        in the original implementation.
        """
        if not isinstance(other, Checkpoint):
            return NotImplemented

        return (
            self.acc_list == other.acc_list
            and self.indices == other.indices
            and self.experiment_name == other.experiment_name
        )

    def save_checkpoint(self, path):
        """Write this checkpoint twice into directory ``path`` (created if missing)."""
        # Get current time to use in file timestamp
        timestamp = time.time_ns()

        # Create the path supplied
        os.makedirs(path, exist_ok=True)

        # Serialize once so both copies are guaranteed to hold the same bytes,
        # and so a pickling failure cannot leave a half-written file behind.
        payload = pickle.dumps(self)

        # Name saved files using timestamp to add recency information; the
        # trailing "1"/"2" distinguishes the two copies of the same save.
        for copy_number in ("1", "2"):
            save_path = os.path.join(path, F"c{timestamp}{copy_number}")
            with open(save_path, 'wb') as save_file:
                save_file.write(payload)

    @staticmethod
    def _read_checkpoint_file(file_path):
        """Return the Checkpoint pickled at ``file_path``, or None if it is unreadable."""
        # Exceptions the pickle documentation lists for damaged or foreign data.
        unreadable = (
            pickle.UnpicklingError,
            AttributeError,
            EOFError,
            ImportError,
            IndexError,
            ValueError,
        )
        try:
            with open(file_path, 'rb') as load_file:
                loaded = pickle.load(load_file)
        except unreadable:
            return None

        return loaded if isinstance(loaded, Checkpoint) else None

    def load_checkpoint(self, path, experiment_name):
        """Populate this object from the most recent valid checkpoint in ``path``.

        A save is valid when both of its copies exist, can be unpickled,
        belong to ``experiment_name`` and compare equal. If no valid save is
        found (including when ``path`` does not exist), ``acc_list``,
        ``indices`` and ``state_dict`` are set to ``None``.
        """
        # Start from the "nothing restored" state; overwritten on success.
        self.acc_list = None
        self.indices = None
        self.state_dict = None
        self.experiment_name = experiment_name

        # A missing directory simply means that nothing has been saved yet.
        if not os.path.isdir(path):
            return

        # Sort the file names in descending order so that the most recent
        # timestamp comes first and the two copies of a save are adjacent.
        file_names = sorted(
            (f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))),
            reverse=True,
        )

        position = 0
        while position + 1 < len(file_names):
            first_file = file_names[position]
            second_file = file_names[position + 1]

            # Timestamps match if the removal of the "1" or "2" results in equal names.
            # Otherwise slide forward by one file and try the next neighbours.
            if first_file[:-1] != second_file[:-1]:
                position += 1
                continue
            position += 2

            checkpoint = self._read_checkpoint_file(os.path.join(path, first_file))
            checkpoint_copy = self._read_checkpoint_file(os.path.join(path, second_file))

            # An unreadable copy means the save did not finish; try an older one.
            if checkpoint is None or checkpoint_copy is None:
                continue

            # Do not check this experiment if it is not the one we need to restore
            if checkpoint.experiment_name != experiment_name:
                continue

            # Equal copies mean that the save operation finished correctly.
            if checkpoint == checkpoint_copy:
                self.acc_list = checkpoint.acc_list
                self.indices = checkpoint.indices
                self.state_dict = checkpoint.state_dict
                return

    def get_saved_values(self):
        """Return ``(acc_list, indices, state_dict)``; all ``None`` if nothing was restored."""
        return (self.acc_list, self.indices, self.state_dict)

def delete_checkpoints(checkpoint_directory, experiment_name):
    """Delete every checkpoint file in a directory that belongs to one experiment.

    Each regular file in ``checkpoint_directory`` is unpickled and removed when
    the loaded object's ``experiment_name`` equals ``experiment_name``. Files
    that cannot be unpickled, or whose content has no ``experiment_name``, are
    left untouched because their owner cannot be determined. A missing
    directory is treated as "nothing to delete".

    NOTE(security): files are read with ``pickle``; only use this on
    directories whose contents you trust.

    Args:
        checkpoint_directory: Directory that holds the checkpoint files.
        experiment_name: Name of the experiment whose checkpoints are removed.
    """
    if not os.path.isdir(checkpoint_directory):
        return

    # Exceptions the pickle documentation lists for damaged or foreign data;
    # AttributeError also covers objects without an ``experiment_name``.
    unreadable = (
        pickle.UnpicklingError,
        AttributeError,
        EOFError,
        ImportError,
        IndexError,
        ValueError,
    )

    for file_name in os.listdir(checkpoint_directory):

        # Get file location
        file_path = os.path.join(checkpoint_directory, file_name)

        if not os.path.isfile(file_path):
            continue

        # Unpickle the checkpoint and see if its experiment name matches
        try:
            with open(file_path, "rb") as load_file:
                name_matches = pickle.load(load_file).experiment_name == experiment_name
        except unreadable:
            continue

        # Delete this file only if the experiment name matched. The file is
        # closed at this point, so the removal also works on platforms that
        # refuse to delete open files.
        if name_matches:
            os.remove(file_path)

#Logs
def write_logs(logs, save_directory, rd, run):
    """Append the log entries of one round to the text log of a run.

    The log file is ``save_directory + 'run_<run>.txt'``. The path is built by
    plain string concatenation, so ``save_directory`` has to end with a path
    separator when it names a directory.

    Args:
        logs: Mapping of entry name to value. The value stored under the key
            ``'Training'`` is iterated and written one item per line; every
            other entry is written as ``<name> - <value>``.
        save_directory: Prefix of the log file path.
        rd: Round number printed in the section header.
        run: Run number used in the log file name.
    """
    rule = '---------------------\n'
    log_path = save_directory + 'run_' + str(run) + '.txt'

    with open(log_path, 'a') as log_file:
        # Section header for this round.
        log_file.write(rule)
        log_file.write(f'Round {rd!s}\n')
        log_file.write(rule)

        for name, entry in logs.items():
            if name != 'Training':
                log_file.write(f'{name!s} - {entry!s}\n')
                continue

            # Per-epoch training logs: the key first, then one line per epoch.
            log_file.write(f'{name!s}\n')
            for epoch_entry in entry:
                log_file.write(f'{epoch_entry!s}\n')

def train_one(X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, strategy, save_directory, run, checkpoint_directory, experiment_name):
    """Run one active-learning experiment, resuming from a checkpoint when one exists.

    Each round asks ``strategy`` to select ``budget`` unlabeled points, moves
    them (with their labels) to the labeled set, retrains through ``dt``,
    evaluates on the test set, appends to the run log and saves a checkpoint.

    Args:
        X_tr, y_tr: Labeled data and labels (NumPy-indexable along axis 0).
        X_test, y_test: Test data and labels passed to ``dt.get_acc_on_set``.
        X_unlabeled, y_unlabeled: Unlabeled pool and the labels revealed on selection.
        dt: Training helper; ``update_data``, ``train`` and ``get_acc_on_set`` are used.
        net: Model; receives the recovered state dict when resuming.
        n_rounds: Number of selection rounds.
        budget: Number of points selected per round.
        args: Not used by this function.
        nclasses: Not used by this function.
        strategy: Selection strategy; ``select``, ``update_data`` and
            ``update_model`` are used.
        save_directory: Log path prefix handed to ``write_logs``.
        run: Run number handed to ``write_logs``.
        checkpoint_directory: Directory used to load and save checkpoints.
        experiment_name: Name used to match checkpoints to this experiment.

    Returns:
        NumPy array of length ``n_rounds + 1`` with the test accuracy before
        the first round (index 0) and after each round.
    """
    # Define acc initially
    acc = np.zeros(n_rounds+1)

    initial_unlabeled_size = X_unlabeled.shape[0]

    initial_round = 1

    # Define an index map: entry i is the original pool index of the i-th
    # point that is still unlabeled. arange keeps an integer dtype even for
    # an empty pool, which np.delete needs further down.
    index_map = np.arange(initial_unlabeled_size)

    # Attempt to load a checkpoint. If one exists, then the experiment crashed.
    training_checkpoint = Checkpoint(experiment_name=experiment_name, path=checkpoint_directory)
    rec_acc, rec_indices, rec_state_dict = training_checkpoint.get_saved_values()

    # Check if there are values to recover
    if rec_acc is not None:

        # Restore the accuracy list
        for i, recovered_acc in enumerate(rec_acc):
            acc[i] = recovered_acc

        # Restore the indices list and shift those unlabeled points to the labeled set.
        index_map = np.delete(index_map, rec_indices)

        # Record initial size of X_tr. (The original assigned a misspelled
        # name here, which made the round computation below raise NameError.)
        initial_seed_size = X_tr.shape[0]

        X_tr = np.concatenate((X_tr, X_unlabeled[rec_indices]), axis=0)
        X_unlabeled = np.delete(X_unlabeled, rec_indices, axis = 0)

        y_tr = np.concatenate((y_tr, y_unlabeled[rec_indices]), axis = 0)
        y_unlabeled = np.delete(y_unlabeled, rec_indices, axis = 0)

        # Restore the model
        net.load_state_dict(rec_state_dict)

        # Fix the initial round: every completed round added `budget` points.
        initial_round = (X_tr.shape[0] - initial_seed_size) // budget + 1

        # Ensure loaded model is moved to GPU
        if torch.cuda.is_available():
            net = net.cuda()

        strategy.update_model(net)
        strategy.update_data(X_tr, y_tr, X_unlabeled)

    else:

        if torch.cuda.is_available():
            net = net.cuda()

        acc[0] = dt.get_acc_on_set(X_test, y_test)
        print('Initial Testing accuracy:', round(acc[0]*100, 2), flush=True)

        logs = {}
        logs['Training Points'] = X_tr.shape[0]
        logs['Test Accuracy'] =  str(round(acc[0]*100, 2))
        write_logs(logs, save_directory, 0, run)

        #Updating the trained model in strategy class
        strategy.update_model(net)

    ##User Controlled Loop
    for rd in range(initial_round, n_rounds+1):
        print('-------------------------------------------------')
        print('Round', rd)
        print('-------------------------------------------------')

        sel_time = time.time()
        idx = strategy.select(budget)
        sel_time = time.time() - sel_time
        print("Selection Time:", sel_time)

        #Saving state of model, since labeling new points might take time
        # strategy.save_state()

        #Adding new points to training set
        X_tr = np.concatenate((X_tr, X_unlabeled[idx]), axis=0)
        X_unlabeled = np.delete(X_unlabeled, idx, axis = 0)

        #Human In Loop, Assuming user adds new labels here
        y_tr = np.concatenate((y_tr, y_unlabeled[idx]), axis = 0)
        y_unlabeled = np.delete(y_unlabeled, idx, axis = 0)

        # Update the index map
        index_map = np.delete(index_map, idx, axis = 0)

        print('Number of training points -',X_tr.shape[0])

        #Reload state and start training
        # strategy.load_state()
        strategy.update_data(X_tr, y_tr, X_unlabeled)
        dt.update_data(X_tr, y_tr)
        t1 = time.time()
        clf, train_logs = dt.train(None)
        t2 = time.time()
        acc[rd] = dt.get_acc_on_set(X_test, y_test)
        logs = {}
        logs['Training Points'] = X_tr.shape[0]
        logs['Test Accuracy'] =  str(round(acc[rd]*100, 2))
        logs['Selection Time'] = str(sel_time)
        # NOTE(review): the key spelling is kept unchanged so that new log
        # entries stay consistent with log files that were already written.
        logs['Trainining Time'] = str(t2 - t1)
        logs['Training'] = train_logs
        write_logs(logs, save_directory, rd, run)
        strategy.update_model(clf)
        print('Testing accuracy:', round(acc[rd]*100, 2), flush=True)

        # Create a checkpoint. Removing the still-unlabeled original indices
        # from the full index range leaves the indices labeled so far.
        used_indices = np.arange(initial_unlabeled_size)
        used_indices = np.delete(used_indices, index_map).tolist()

        round_checkpoint = Checkpoint(acc.tolist(), used_indices, clf.state_dict(), experiment_name=experiment_name)
        round_checkpoint.save_checkpoint(checkpoint_directory)

    print('Training Completed')
    return acc


# Define a function to perform experiments in bulk and return the mean accuracies
def BADGE_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` BADGE experiments, plot them, and return the mean test accuracy.

    Every repetition deep-copies the data, ``dt`` and ``net``, builds a new
    ``BADGE`` strategy, runs ``train_one`` and then deletes the checkpoints
    stored for ``experiment_name``. The accuracy curve of each repetition and
    the mean curve are plotted against the labeled-set size.

    Relies on a module-level ``handler`` that is defined outside this function.

    Returns:
        Array of length ``n_rounds + 1`` holding the accuracies averaged over
        the ``n_exp`` repetitions.
    """
    plt.figure(figsize=(8,6), dpi=160)

    # Labeled-set size before the first round and after every round (x axis).
    seed_size = np.shape(X_tr)[0]
    labeled_sizes = [seed_size + budget * step for step in range(n_rounds + 1)]

    run_accuracies = []
    for run_no in range(n_exp):
        # Private copies keep one repetition from altering the base data and model.
        (X_tr_run, y_tr_run, X_unlabeled_run, y_unlabeled_run,
         X_test_run, y_test_run, dt_run, net_run) = (
            copy.deepcopy(item)
            for item in (X_tr, y_tr, X_unlabeled, y_unlabeled, X_test, y_test, dt, net)
        )

        # Build the selection strategy for this repetition.
        selection_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = BADGE(X_tr, y_tr, X_unlabeled, net, handler, nclasses, selection_args)

        run_acc = train_one(
            X_tr_run, y_tr_run, X_test_run, y_test_run, X_unlabeled_run, y_unlabeled_run,
            dt_run, net_run, n_rounds, budget, args, nclasses, strategy,
            save_directory, run_no, checkpoint_directory, experiment_name,
        )
        run_accuracies.append(run_acc)
        plt.plot(labeled_sizes, run_acc, label=str(run_no))
        print("EXPERIMENT", run_no, run_acc)

        # The repetition finished, so its checkpoints are no longer needed.
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the per-round accuracies over all repetitions.
    accumulated = np.zeros(n_rounds + 1)
    for run_acc in run_accuracies:
        accumulated = accumulated + run_acc
    mean_test_acc = accumulated / n_exp

    plt.plot(labeled_sizes, mean_test_acc, label="Mean")
    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc


# Define a function to perform experiments in bulk and return the mean accuracies
def random_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` random-sampling experiments, plot them, and return the mean test accuracy.

    Every repetition deep-copies the data, ``dt`` and ``net``, builds a new
    ``RandomSampling`` strategy, runs ``train_one`` and then deletes the
    checkpoints stored for ``experiment_name``. The accuracy curve of each
    repetition and the mean curve are plotted against the labeled-set size.

    Relies on a module-level ``handler`` that is defined outside this function.

    Returns:
        Array of length ``n_rounds + 1`` holding the accuracies averaged over
        the ``n_exp`` repetitions.
    """
    plt.figure(figsize=(8,6), dpi=160)

    # Labeled-set size before the first round and after every round (x axis).
    seed_size = np.shape(X_tr)[0]
    labeled_sizes = [seed_size + budget * step for step in range(n_rounds + 1)]

    run_accuracies = []
    for run_no in range(n_exp):
        # Private copies keep one repetition from altering the base data and model.
        (X_tr_run, y_tr_run, X_unlabeled_run, y_unlabeled_run,
         X_test_run, y_test_run, dt_run, net_run) = (
            copy.deepcopy(item)
            for item in (X_tr, y_tr, X_unlabeled, y_unlabeled, X_test, y_test, dt, net)
        )

        # Build the selection strategy for this repetition.
        selection_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = RandomSampling(X_tr, y_tr, X_unlabeled, net, handler, nclasses, selection_args)

        run_acc = train_one(
            X_tr_run, y_tr_run, X_test_run, y_test_run, X_unlabeled_run, y_unlabeled_run,
            dt_run, net_run, n_rounds, budget, args, nclasses, strategy,
            save_directory, run_no, checkpoint_directory, experiment_name,
        )
        run_accuracies.append(run_acc)
        plt.plot(labeled_sizes, run_acc, label=str(run_no))
        print("EXPERIMENT", run_no, run_acc)

        # The repetition finished, so its checkpoints are no longer needed.
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the per-round accuracies over all repetitions.
    accumulated = np.zeros(n_rounds + 1)
    for run_acc in run_accuracies:
        accumulated = accumulated + run_acc
    mean_test_acc = accumulated / n_exp

    plt.plot(labeled_sizes, mean_test_acc, label="Mean")
    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc

# Define a function to perform experiments in bulk and return the mean accuracies
def entropy_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` entropy-sampling experiments, plot them, and return the mean test accuracy.

    Every repetition deep-copies the data, ``dt`` and ``net``, builds a new
    ``EntropySampling`` strategy, runs ``train_one`` and then deletes the
    checkpoints stored for ``experiment_name``. The accuracy curve of each
    repetition and the mean curve are plotted against the labeled-set size.

    Relies on a module-level ``handler`` that is defined outside this function.

    Returns:
        Array of length ``n_rounds + 1`` holding the accuracies averaged over
        the ``n_exp`` repetitions.
    """
    plt.figure(figsize=(8,6), dpi=160)

    # Labeled-set size before the first round and after every round (x axis).
    seed_size = np.shape(X_tr)[0]
    labeled_sizes = [seed_size + budget * step for step in range(n_rounds + 1)]

    run_accuracies = []
    for run_no in range(n_exp):
        # Private copies keep one repetition from altering the base data and model.
        (X_tr_run, y_tr_run, X_unlabeled_run, y_unlabeled_run,
         X_test_run, y_test_run, dt_run, net_run) = (
            copy.deepcopy(item)
            for item in (X_tr, y_tr, X_unlabeled, y_unlabeled, X_test, y_test, dt, net)
        )

        # Build the selection strategy for this repetition.
        selection_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = EntropySampling(X_tr, y_tr, X_unlabeled, net, handler, nclasses, selection_args)

        run_acc = train_one(
            X_tr_run, y_tr_run, X_test_run, y_test_run, X_unlabeled_run, y_unlabeled_run,
            dt_run, net_run, n_rounds, budget, args, nclasses, strategy,
            save_directory, run_no, checkpoint_directory, experiment_name,
        )
        run_accuracies.append(run_acc)
        plt.plot(labeled_sizes, run_acc, label=str(run_no))
        print("EXPERIMENT", run_no, run_acc)

        # The repetition finished, so its checkpoints are no longer needed.
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the per-round accuracies over all repetitions.
    accumulated = np.zeros(n_rounds + 1)
    for run_acc in run_accuracies:
        accumulated = accumulated + run_acc
    mean_test_acc = accumulated / n_exp

    plt.plot(labeled_sizes, mean_test_acc, label="Mean")
    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc

# Define a function to perform experiments in bulk and return the mean accuracies
def GLISTER_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` GLISTER experiments, plot them, and return the mean test accuracy.

    Every repetition deep-copies the data, ``dt`` and ``net``, builds a new
    ``GLISTER`` strategy (``valid=False``, ``typeOf='rand'``, ``lam=0.1``),
    runs ``train_one`` and then deletes the checkpoints stored for
    ``experiment_name``. The accuracy curve of each repetition and the mean
    curve are plotted against the labeled-set size.

    Relies on a module-level ``handler`` that is defined outside this function.

    Returns:
        Array of length ``n_rounds + 1`` holding the accuracies averaged over
        the ``n_exp`` repetitions.
    """
    plt.figure(figsize=(8,6), dpi=160)

    # Labeled-set size before the first round and after every round (x axis).
    seed_size = np.shape(X_tr)[0]
    labeled_sizes = [seed_size + budget * step for step in range(n_rounds + 1)]

    run_accuracies = []
    for run_no in range(n_exp):
        # Private copies keep one repetition from altering the base data and model.
        (X_tr_run, y_tr_run, X_unlabeled_run, y_unlabeled_run,
         X_test_run, y_test_run, dt_run, net_run) = (
            copy.deepcopy(item)
            for item in (X_tr, y_tr, X_unlabeled, y_unlabeled, X_test, y_test, dt, net)
        )

        # Build the selection strategy for this repetition; unlike the other
        # strategies in this file, it is also handed the learning rate.
        selection_args = {'batch_size' : args['batch_size'], 'lr': args['lr'], 'device':args['device']}
        strategy = GLISTER(X_tr, y_tr, X_unlabeled, net, handler, nclasses, selection_args,valid=False, typeOf='rand', lam=0.1)

        run_acc = train_one(
            X_tr_run, y_tr_run, X_test_run, y_test_run, X_unlabeled_run, y_unlabeled_run,
            dt_run, net_run, n_rounds, budget, args, nclasses, strategy,
            save_directory, run_no, checkpoint_directory, experiment_name,
        )
        run_accuracies.append(run_acc)
        plt.plot(labeled_sizes, run_acc, label=str(run_no))
        print("EXPERIMENT", run_no, run_acc)

        # The repetition finished, so its checkpoints are no longer needed.
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the per-round accuracies over all repetitions.
    accumulated = np.zeros(n_rounds + 1)
    for run_acc in run_accuracies:
        accumulated = accumulated + run_acc
    mean_test_acc = accumulated / n_exp

    plt.plot(labeled_sizes, mean_test_acc, label="Mean")
    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc

# Define a function to perform experiments in bulk and return the mean accuracies
def FASS_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` FASS experiments, plot them, and return the mean test accuracy.

    Every repetition deep-copies the data, ``dt`` and ``net``, builds a new
    ``FASS`` strategy, runs ``train_one`` and then deletes the checkpoints
    stored for ``experiment_name``. The accuracy curve of each repetition and
    the mean curve are plotted against the labeled-set size.

    Relies on a module-level ``handler`` that is defined outside this function.

    Returns:
        Array of length ``n_rounds + 1`` holding the accuracies averaged over
        the ``n_exp`` repetitions.
    """
    plt.figure(figsize=(8,6), dpi=160)

    # Labeled-set size before the first round and after every round (x axis).
    seed_size = np.shape(X_tr)[0]
    labeled_sizes = [seed_size + budget * step for step in range(n_rounds + 1)]

    run_accuracies = []
    for run_no in range(n_exp):
        # Private copies keep one repetition from altering the base data and model.
        (X_tr_run, y_tr_run, X_unlabeled_run, y_unlabeled_run,
         X_test_run, y_test_run, dt_run, net_run) = (
            copy.deepcopy(item)
            for item in (X_tr, y_tr, X_unlabeled, y_unlabeled, X_test, y_test, dt, net)
        )

        # Build the selection strategy for this repetition.
        selection_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = FASS(X_tr, y_tr, X_unlabeled, net, handler, nclasses, selection_args)

        run_acc = train_one(
            X_tr_run, y_tr_run, X_test_run, y_test_run, X_unlabeled_run, y_unlabeled_run,
            dt_run, net_run, n_rounds, budget, args, nclasses, strategy,
            save_directory, run_no, checkpoint_directory, experiment_name,
        )
        run_accuracies.append(run_acc)
        plt.plot(labeled_sizes, run_acc, label=str(run_no))
        print("EXPERIMENT", run_no, run_acc)

        # The repetition finished, so its checkpoints are no longer needed.
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the per-round accuracies over all repetitions.
    accumulated = np.zeros(n_rounds + 1)
    for run_acc in run_accuracies:
        accumulated = accumulated + run_acc
    mean_test_acc = accumulated / n_exp

    plt.plot(labeled_sizes, mean_test_acc, label="Mean")
    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc


# Define a function to perform experiments in bulk and return the mean accuracies
def adversarial_bim_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` AdversarialBIM experiments, plot them, and return the mean test accuracy.

    Every repetition deep-copies the data, ``dt`` and ``net``, builds a new
    ``AdversarialBIM`` strategy, runs ``train_one`` and then deletes the
    checkpoints stored for ``experiment_name``. The accuracy curve of each
    repetition and the mean curve are plotted against the labeled-set size.

    Relies on a module-level ``handler`` that is defined outside this function.

    Returns:
        Array of length ``n_rounds + 1`` holding the accuracies averaged over
        the ``n_exp`` repetitions.
    """
    plt.figure(figsize=(8,6), dpi=160)

    # Labeled-set size before the first round and after every round (x axis).
    seed_size = np.shape(X_tr)[0]
    labeled_sizes = [seed_size + budget * step for step in range(n_rounds + 1)]

    run_accuracies = []
    for run_no in range(n_exp):
        # Private copies keep one repetition from altering the base data and model.
        (X_tr_run, y_tr_run, X_unlabeled_run, y_unlabeled_run,
         X_test_run, y_test_run, dt_run, net_run) = (
            copy.deepcopy(item)
            for item in (X_tr, y_tr, X_unlabeled, y_unlabeled, X_test, y_test, dt, net)
        )

        # Build the selection strategy for this repetition.
        selection_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = AdversarialBIM(X_tr, y_tr, X_unlabeled, net, handler, nclasses, selection_args)

        run_acc = train_one(
            X_tr_run, y_tr_run, X_test_run, y_test_run, X_unlabeled_run, y_unlabeled_run,
            dt_run, net_run, n_rounds, budget, args, nclasses, strategy,
            save_directory, run_no, checkpoint_directory, experiment_name,
        )
        run_accuracies.append(run_acc)
        plt.plot(labeled_sizes, run_acc, label=str(run_no))
        print("EXPERIMENT", run_no, run_acc)

        # The repetition finished, so its checkpoints are no longer needed.
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the per-round accuracies over all repetitions.
    accumulated = np.zeros(n_rounds + 1)
    for run_acc in run_accuracies:
        accumulated = accumulated + run_acc
    mean_test_acc = accumulated / n_exp

    plt.plot(labeled_sizes, mean_test_acc, label="Mean")
    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc

# Define a function to perform experiments in bulk and return the mean accuracies
def adversarial_deepfool_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` AdversarialDeepFool experiments, plot them, and return the mean test accuracy.

    Every repetition deep-copies the data, ``dt`` and ``net``, builds a new
    ``AdversarialDeepFool`` strategy, runs ``train_one`` and then deletes the
    checkpoints stored for ``experiment_name``. The accuracy curve of each
    repetition and the mean curve are plotted against the labeled-set size.

    Relies on a module-level ``handler`` that is defined outside this function.

    Returns:
        Array of length ``n_rounds + 1`` holding the accuracies averaged over
        the ``n_exp`` repetitions.
    """
    plt.figure(figsize=(8,6), dpi=160)

    # Labeled-set size before the first round and after every round (x axis).
    seed_size = np.shape(X_tr)[0]
    labeled_sizes = [seed_size + budget * step for step in range(n_rounds + 1)]

    run_accuracies = []
    for run_no in range(n_exp):
        # Private copies keep one repetition from altering the base data and model.
        (X_tr_run, y_tr_run, X_unlabeled_run, y_unlabeled_run,
         X_test_run, y_test_run, dt_run, net_run) = (
            copy.deepcopy(item)
            for item in (X_tr, y_tr, X_unlabeled, y_unlabeled, X_test, y_test, dt, net)
        )

        # Build the selection strategy for this repetition.
        selection_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = AdversarialDeepFool(X_tr, y_tr, X_unlabeled, net, handler, nclasses, selection_args)

        run_acc = train_one(
            X_tr_run, y_tr_run, X_test_run, y_test_run, X_unlabeled_run, y_unlabeled_run,
            dt_run, net_run, n_rounds, budget, args, nclasses, strategy,
            save_directory, run_no, checkpoint_directory, experiment_name,
        )
        run_accuracies.append(run_acc)
        plt.plot(labeled_sizes, run_acc, label=str(run_no))
        print("EXPERIMENT", run_no, run_acc)

        # The repetition finished, so its checkpoints are no longer needed.
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the per-round accuracies over all repetitions.
    accumulated = np.zeros(n_rounds + 1)
    for run_acc in run_accuracies:
        accumulated = accumulated + run_acc
    mean_test_acc = accumulated / n_exp

    plt.plot(labeled_sizes, mean_test_acc, label="Mean")
    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc

# Define a function to perform experiments in bulk and return the mean accuracies
def coreset_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` CoreSet experiments, plot them, and return the mean test accuracy.

    Every repetition deep-copies the data, ``dt`` and ``net``, builds a new
    ``CoreSet`` strategy, runs ``train_one`` and then deletes the checkpoints
    stored for ``experiment_name``. The accuracy curve of each repetition and
    the mean curve are plotted against the labeled-set size.

    Relies on a module-level ``handler`` that is defined outside this function.

    Returns:
        Array of length ``n_rounds + 1`` holding the accuracies averaged over
        the ``n_exp`` repetitions.
    """
    plt.figure(figsize=(8,6), dpi=160)

    # Labeled-set size before the first round and after every round (x axis).
    seed_size = np.shape(X_tr)[0]
    labeled_sizes = [seed_size + budget * step for step in range(n_rounds + 1)]

    run_accuracies = []
    for run_no in range(n_exp):
        # Private copies keep one repetition from altering the base data and model.
        (X_tr_run, y_tr_run, X_unlabeled_run, y_unlabeled_run,
         X_test_run, y_test_run, dt_run, net_run) = (
            copy.deepcopy(item)
            for item in (X_tr, y_tr, X_unlabeled, y_unlabeled, X_test, y_test, dt, net)
        )

        # Build the selection strategy for this repetition.
        selection_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = CoreSet(X_tr, y_tr, X_unlabeled, net, handler, nclasses, selection_args)

        run_acc = train_one(
            X_tr_run, y_tr_run, X_test_run, y_test_run, X_unlabeled_run, y_unlabeled_run,
            dt_run, net_run, n_rounds, budget, args, nclasses, strategy,
            save_directory, run_no, checkpoint_directory, experiment_name,
        )
        run_accuracies.append(run_acc)
        plt.plot(labeled_sizes, run_acc, label=str(run_no))
        print("EXPERIMENT", run_no, run_acc)

        # The repetition finished, so its checkpoints are no longer needed.
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the per-round accuracies over all repetitions.
    accumulated = np.zeros(n_rounds + 1)
    for run_acc in run_accuracies:
        accumulated = accumulated + run_acc
    mean_test_acc = accumulated / n_exp

    plt.plot(labeled_sizes, mean_test_acc, label="Mean")
    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc

# Define a function to perform experiments in bulk and return the mean accuracies
def least_confidence_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` least-confidence experiments, plot them, and return the mean test accuracy.

    Every repetition deep-copies the data, ``dt`` and ``net``, builds a new
    ``LeastConfidence`` strategy, runs ``train_one`` and then deletes the
    checkpoints stored for ``experiment_name``. The accuracy curve of each
    repetition and the mean curve are plotted against the labeled-set size.

    Relies on a module-level ``handler`` that is defined outside this function.

    Returns:
        Array of length ``n_rounds + 1`` holding the accuracies averaged over
        the ``n_exp`` repetitions.
    """
    plt.figure(figsize=(8,6), dpi=160)

    # Labeled-set size before the first round and after every round (x axis).
    seed_size = np.shape(X_tr)[0]
    labeled_sizes = [seed_size + budget * step for step in range(n_rounds + 1)]

    run_accuracies = []
    for run_no in range(n_exp):
        # Private copies keep one repetition from altering the base data and model.
        (X_tr_run, y_tr_run, X_unlabeled_run, y_unlabeled_run,
         X_test_run, y_test_run, dt_run, net_run) = (
            copy.deepcopy(item)
            for item in (X_tr, y_tr, X_unlabeled, y_unlabeled, X_test, y_test, dt, net)
        )

        # Build the selection strategy for this repetition.
        selection_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = LeastConfidence(X_tr, y_tr, X_unlabeled, net, handler, nclasses, selection_args)

        run_acc = train_one(
            X_tr_run, y_tr_run, X_test_run, y_test_run, X_unlabeled_run, y_unlabeled_run,
            dt_run, net_run, n_rounds, budget, args, nclasses, strategy,
            save_directory, run_no, checkpoint_directory, experiment_name,
        )
        run_accuracies.append(run_acc)
        plt.plot(labeled_sizes, run_acc, label=str(run_no))
        print("EXPERIMENT", run_no, run_acc)

        # The repetition finished, so its checkpoints are no longer needed.
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the per-round accuracies over all repetitions.
    accumulated = np.zeros(n_rounds + 1)
    for run_acc in run_accuracies:
        accumulated = accumulated + run_acc
    mean_test_acc = accumulated / n_exp

    plt.plot(labeled_sizes, mean_test_acc, label="Mean")
    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc

# Define a function to perform experiments in bulk and return the mean accuracies
def margin_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` margin-sampling experiments and return their mean accuracy curve.

    Every repetition deep-copies the data, the trainer and the model, builds a
    ``MarginSampling`` strategy on those copies, and hands everything to
    ``train_one``. Each returned accuracy curve is plotted, and the element-wise
    mean over all repetitions is plotted, printed and returned.

    Note:
        The data handler is read from the module-level name ``handler``; it is
        not a parameter of this function.

    Args:
        n_exp: Number of repetitions to run.
        X_tr, y_tr: Initial labeled data and labels. ``np.shape(X_tr)[0]`` is
            the first value on the plot's x axis.
        X_test, y_test: Test data and labels, forwarded (as copies) to ``train_one``.
        X_unlabeled, y_unlabeled: Unlabeled pool and its labels.
        dt: Trainer object; a deep copy is forwarded to ``train_one``.
        net: Model; a deep copy is used for the strategy and for ``train_one``.
        n_rounds: Number of rounds; the accuracy curve has ``n_rounds + 1`` points.
        budget: Step between consecutive x-axis values (labeled points added per round).
        args: Dict forwarded to ``train_one``; its ``'batch_size'`` and
            ``'device'`` entries are also passed to the strategy.
        nclasses: Number of classes, forwarded to the strategy and ``train_one``.
        save_directory: Forwarded to ``train_one``.
        checkpoint_directory: Forwarded to ``train_one`` and ``delete_checkpoints``.
        experiment_name: Forwarded to ``train_one`` and ``delete_checkpoints``.

    Returns:
        Element-wise mean of the per-repetition accuracy curves.
    """
    test_acc_list = []
    # Open the figure that all per-run curves and the mean curve are drawn on.
    plt.figure(figsize=(8,6), dpi=160)
    x_axis = [np.shape(X_tr)[0] + budget * x for x in range(n_rounds + 1)]

    for i in range(n_exp):
        # Copy data and model to ensure that experiments do not override base versions
        X_tr_copy = copy.deepcopy(X_tr)
        y_tr_copy = copy.deepcopy(y_tr)
        X_unlabeled_copy = copy.deepcopy(X_unlabeled)
        y_unlabeled_copy = copy.deepcopy(y_unlabeled)
        X_test_copy = copy.deepcopy(X_test)
        y_test_copy = copy.deepcopy(y_test)
        dt_copy = copy.deepcopy(dt)
        clf_copy = copy.deepcopy(net)

        # Build the strategy on this run's copies, so that neither the shared
        # base model nor the base arrays are referenced by the strategy.
        strategy_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = MarginSampling(X_tr_copy, y_tr_copy, X_unlabeled_copy, clf_copy, handler, nclasses, strategy_args)

        test_acc = train_one(X_tr_copy, y_tr_copy, X_test_copy, y_test_copy, X_unlabeled_copy, y_unlabeled_copy, dt_copy, clf_copy, n_rounds, budget, args, nclasses, strategy, save_directory, i, checkpoint_directory, experiment_name)
        test_acc_list.append(test_acc)
        plt.plot(x_axis, test_acc, label=str(i))
        print("EXPERIMENT", i, test_acc)

        # Experiment complete; delete all checkpoints related to this experiment
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the curves point by point across repetitions.
    mean_test_acc = np.zeros(n_rounds + 1)

    for test_acc in test_acc_list:
        mean_test_acc = mean_test_acc + test_acc

    mean_test_acc = mean_test_acc / n_exp
    plt.plot(x_axis, mean_test_acc, label="Mean")

    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc


# Define a function to perform experiments in bulk and return the mean accuracies
def bald_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, net, n_rounds, budget, args, nclasses, save_directory, checkpoint_directory, experiment_name):
    """Run ``n_exp`` BALD-dropout experiments and return their mean accuracy curve.

    Every repetition deep-copies the data, the trainer and the model, builds a
    ``BALDDropout`` strategy on those copies, and hands everything to
    ``train_one``. Each returned accuracy curve is plotted, and the element-wise
    mean over all repetitions is plotted, printed and returned.

    Note:
        The data handler is read from the module-level name ``handler``; it is
        not a parameter of this function.

    Args:
        n_exp: Number of repetitions to run.
        X_tr, y_tr: Initial labeled data and labels. ``np.shape(X_tr)[0]`` is
            the first value on the plot's x axis.
        X_test, y_test: Test data and labels, forwarded (as copies) to ``train_one``.
        X_unlabeled, y_unlabeled: Unlabeled pool and its labels.
        dt: Trainer object; a deep copy is forwarded to ``train_one``.
        net: Model; a deep copy is used for the strategy and for ``train_one``.
        n_rounds: Number of rounds; the accuracy curve has ``n_rounds + 1`` points.
        budget: Step between consecutive x-axis values (labeled points added per round).
        args: Dict forwarded to ``train_one``; its ``'batch_size'`` and
            ``'device'`` entries are also passed to the strategy.
        nclasses: Number of classes, forwarded to the strategy and ``train_one``.
        save_directory: Forwarded to ``train_one``.
        checkpoint_directory: Forwarded to ``train_one`` and ``delete_checkpoints``.
        experiment_name: Forwarded to ``train_one`` and ``delete_checkpoints``.

    Returns:
        Element-wise mean of the per-repetition accuracy curves.
    """
    test_acc_list = []
    # Open the figure that all per-run curves and the mean curve are drawn on.
    plt.figure(figsize=(8,6), dpi=160)
    x_axis = [np.shape(X_tr)[0] + budget * x for x in range(n_rounds + 1)]

    for i in range(n_exp):
        # Copy data and model to ensure that experiments do not override base versions
        X_tr_copy = copy.deepcopy(X_tr)
        y_tr_copy = copy.deepcopy(y_tr)
        X_unlabeled_copy = copy.deepcopy(X_unlabeled)
        y_unlabeled_copy = copy.deepcopy(y_unlabeled)
        X_test_copy = copy.deepcopy(X_test)
        y_test_copy = copy.deepcopy(y_test)
        dt_copy = copy.deepcopy(dt)
        clf_copy = copy.deepcopy(net)

        # Build the strategy on this run's copies, so that neither the shared
        # base model nor the base arrays are referenced by the strategy.
        strategy_args = {'batch_size' : args['batch_size'], 'device':args['device']}
        strategy = BALDDropout(X_tr_copy, y_tr_copy, X_unlabeled_copy, clf_copy, handler, nclasses, strategy_args)

        test_acc = train_one(X_tr_copy, y_tr_copy, X_test_copy, y_test_copy, X_unlabeled_copy, y_unlabeled_copy, dt_copy, clf_copy, n_rounds, budget, args, nclasses, strategy, save_directory, i, checkpoint_directory, experiment_name)
        test_acc_list.append(test_acc)
        plt.plot(x_axis, test_acc, label=str(i))
        print("EXPERIMENT", i, test_acc)

        # Experiment complete; delete all checkpoints related to this experiment
        delete_checkpoints(checkpoint_directory, experiment_name)

    # Average the curves point by point across repetitions.
    mean_test_acc = np.zeros(n_rounds + 1)

    for test_acc in test_acc_list:
        mean_test_acc = mean_test_acc + test_acc

    mean_test_acc = mean_test_acc / n_exp
    plt.plot(x_axis, mean_test_acc, label="Mean")

    plt.xlabel("Labeled Set Size")
    plt.ylabel("Test Acc")
    plt.legend()
    plt.show()

    print("MEAN TEST ACC", mean_test_acc)

    return mean_test_acc

# CIFAR100

**Parameter Definitions**

Parameters related to the specific experiment are placed here. You should examine each and modify them as needed.

In [None]:
# Dataset and model under test.
data_set_name = 'CIFAR100'
download_path = '../downloaded_data/'
handler = DataHandler_CIFAR100  # PUT DATAHANDLER HERE
net = ResNet18(100)             # MODEL HERE

# Storage locations. MODIFY AS NECESSARY
logs_directory = '/content/gdrive/MyDrive/colab_storage/logs/'
checkpoint_directory = '/content/gdrive/MyDrive/colab_storage/check/'
model_directory = "/content/gdrive/MyDrive/colab_storage/model/"
initial_model = data_set_name

experiment_name = "CIFAR100 BASELINE"

# Labeling schedule.
initial_seed_size = 5000   # INIT SEED SIZE HERE
training_size_cap = 50000  # TRAIN SIZE CAP HERE
budget = 5000              # BUDGET HERE
nclasses = 100             # NUM CLASSES HERE

# Training options. CHANGE ARGS AS NECESSARY
args = {
    'n_epoch': 300,
    'lr': 0.01,
    'batch_size': 20,
    'max_accuracy': 0.99,
    'num_classes': nclasses,
    'islogs': True,
    'isreset': True,
    'isverbose': True,
    'device': 'cuda',
}

# Train on approximately the full dataset given the budget constraints:
# the number of whole budgets that fit between the seed size and the cap.
n_rounds = (training_size_cap - initial_seed_size) // budget

# SET N EXP TO RUN (>1 for repeat)
n_exp = 1

**Initial Loading and Training**

You may either train a new initial model or load a previously saved one (see `load_model` below). If this notebook is being executed in Colab, you should consider whether or not you need the gdown line.

In [None]:
# Mount drive containing possible saved model and define file path.
colab_model_storage_mount = "/content/gdrive"
# NOTE(review): `drive` is not bound in this cell; presumably it comes from
# `from google.colab import drive` executed earlier in the notebook — confirm.
drive.mount(colab_model_storage_mount)

# Make sure the log, checkpoint and model folders exist before anything is written to them.
os.makedirs(logs_directory, exist_ok = True)
os.makedirs(checkpoint_directory, exist_ok = True)
os.makedirs(model_directory, exist_ok = True)
# From here on model_directory names the saved-weights file for this dataset,
# not the folder that contains it.
# NOTE(review): re-running this cell without re-running the parameter cell first
# appends data_set_name a second time, and the os.makedirs call above would then
# target the previously saved file path.
model_directory = F"{model_directory}/{data_set_name}"
# Optional: retrieve a pre-trained model from a shared link and save it to the drive.
#!/content/gdown.pl/gdown.pl "clone link" "clone location" # MAY NOT NEED THIS LINE IF NOT CLONING MODEL FROM COLAB

X, y, X_test, y_test = get_dataset(data_set_name, download_path)
# Shape of X without its leading axis (presumably the per-sample shape).
dim = np.shape(X)[1:]

# The first initial_seed_size entries form the labeled seed set; the rest form the
# unlabeled pool. Labels are converted with .numpy() (presumably from torch tensors).
X_tr = X[:initial_seed_size]
y_tr = y[:initial_seed_size].numpy()
X_unlabeled = X[initial_seed_size:]
y_unlabeled = y[initial_seed_size:].numpy()

# Self-assignment is a no-op; only the test labels are converted.
X_test = X_test
y_test = y_test.numpy()

# COMMENT OUT ONE OR THE OTHER IF YOU WANT TO TRAIN A NEW INITIAL MODEL
load_model = False
#load_model = True

# Only train a new model if one does not exist.
if load_model:
    # Restore saved weights into net, then build the trainer around it.
    net.load_state_dict(torch.load(model_directory))
    dt = data_train(X_tr, y_tr, net, handler, args)
    clf = net
else:
    # Train on the seed set; the first item returned by dt.train is kept as the
    # initial model and its weights are saved for later runs.
    dt = data_train(X_tr, y_tr, net, handler, args)
    clf, _ = dt.train(None)
    torch.save(clf.state_dict(), model_directory)

# Note: the value printed after "unlabeled set size" is training_size_cap.
print("Training for", n_rounds, "rounds with budget", budget, "on unlabeled set size", training_size_cap)

**Random Sampling**

In [None]:
# Random sampling: create this strategy's log folder under the dataset's log
# directory, then run the batch of experiments starting from the initial model clf.
# The result is stored in mean_test_acc_random.
strat_logs = logs_directory+F'{data_set_name}/random_sampling/'
os.makedirs(strat_logs, exist_ok = True)
mean_test_acc_random = random_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, clf, n_rounds, budget, args, nclasses, strat_logs, checkpoint_directory, F"{experiment_name}_random")

**Entropy (Uncertainty) Sampling**

In [None]:
# Entropy sampling: create this strategy's log folder under the dataset's log
# directory, then run the batch of experiments starting from the initial model clf.
# The result is stored in mean_test_acc_entropy.
strat_logs = logs_directory+F'{data_set_name}/entropy_sampling/'
os.makedirs(strat_logs, exist_ok = True)
mean_test_acc_entropy = entropy_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, clf, n_rounds, budget, args, nclasses, strat_logs, checkpoint_directory, F"{experiment_name}_entropy")

**GLISTER**

In [None]:
# GLISTER: create this strategy's log folder under the dataset's log directory,
# then run the batch of experiments starting from the initial model clf.
# The result is stored in mean_test_acc_glister.
strat_logs = logs_directory+F'{data_set_name}/glister/'
os.makedirs(strat_logs, exist_ok = True)
mean_test_acc_glister = GLISTER_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, clf, n_rounds, budget, args, nclasses, strat_logs, checkpoint_directory, F"{experiment_name}_glister")

**FASS**

In [None]:
# FASS: create this strategy's log folder under the dataset's log directory,
# then run the batch of experiments starting from the initial model clf.
# The result is stored in mean_test_acc_fass.
strat_logs = logs_directory+F'{data_set_name}/fass/'
os.makedirs(strat_logs, exist_ok = True)
mean_test_acc_fass = FASS_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, clf, n_rounds, budget, args, nclasses, strat_logs, checkpoint_directory, F"{experiment_name}_fass")

**BADGE**

In [None]:
# BADGE: create this strategy's log folder under the dataset's log directory,
# then run the batch of experiments starting from the initial model clf.
# The result is stored in mean_test_acc_badge.
strat_logs = logs_directory+F'{data_set_name}/badge/'
os.makedirs(strat_logs, exist_ok = True)
mean_test_acc_badge = BADGE_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, clf, n_rounds, budget, args, nclasses, strat_logs, checkpoint_directory, F"{experiment_name}_badge")

**CoreSet**

In [None]:
# CoreSet: create this strategy's log folder under the dataset's log directory,
# then run the batch of experiments starting from the initial model clf.
# The result is stored in mean_test_acc_coreset.
strat_logs = logs_directory+F'{data_set_name}/coreset/'
os.makedirs(strat_logs, exist_ok = True)
mean_test_acc_coreset = coreset_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, clf, n_rounds, budget, args, nclasses, strat_logs, checkpoint_directory, F"{experiment_name}_coreset")

**Least Confidence**

In [None]:
# Least confidence: create this strategy's log folder under the dataset's log
# directory, then run the batch of experiments starting from the initial model clf.
# The returned mean test-accuracy curve is stored in mean_test_acc_least_confidence.
strat_logs = logs_directory+F'{data_set_name}/least_confidence/'
os.makedirs(strat_logs, exist_ok = True)
mean_test_acc_least_confidence = least_confidence_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, clf, n_rounds, budget, args, nclasses, strat_logs, checkpoint_directory, F"{experiment_name}_least_conf")

**Margin**

In [None]:
# Margin sampling: create this strategy's log folder under the dataset's log
# directory, then run the batch of experiments starting from the initial model clf.
# The returned mean test-accuracy curve is stored in mean_test_acc_margin.
strat_logs = logs_directory+F'{data_set_name}/margin_sampling/'
os.makedirs(strat_logs, exist_ok = True)
mean_test_acc_margin = margin_experiment_batch(n_exp, X_tr, y_tr, X_test, y_test, X_unlabeled, y_unlabeled, dt, clf, n_rounds, budget, args, nclasses, strat_logs, checkpoint_directory, F"{experiment_name}_margin")