# Preface

The places that require configuration for your experiment are marked with comments in capital letters.

# Setup

## Installations

In [None]:
!pip install sphinxcontrib-napoleon
!pip install sphinxcontrib-bibtex
!pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ submodlib

!git clone https://github.com/decile-team/distil.git
!git clone https://github.com/circulosmeos/gdown.pl.git
!git clone https://github.com/owruby/shake-shake_pytorch.git

!mv shake-shake_pytorch/models .

import sys
sys.path.append("/content/distil/")

**Experiment-Specific Imports**

In [None]:
from distil.utils.models.resnet import ResNet18                                 # IMPORT YOUR MODEL HERE

## Main Imports

In [None]:
import pandas as pd 
import numpy as np
import copy
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
import torch.nn.functional as F
from torch import nn
from torchvision import transforms
from torchvision import datasets
from PIL import Image
import torch
import torch.optim as optim
from torch.autograd import Variable
import sys
sys.path.append('../')
import matplotlib.pyplot as plt
import time
import math
import random
import os
import pickle

from numpy.linalg import cond
from numpy.linalg import inv
from numpy.linalg import norm
from scipy import sparse as sp
from scipy.linalg import lstsq
from scipy.linalg import solve
from scipy.optimize import nnls

from distil.active_learning_strategies.badge import BADGE
from distil.active_learning_strategies.glister import GLISTER
from distil.active_learning_strategies.margin_sampling import MarginSampling
from distil.active_learning_strategies.entropy_sampling import EntropySampling
from distil.active_learning_strategies.random_sampling import RandomSampling
from distil.active_learning_strategies.gradmatch_active import GradMatchActive
from distil.active_learning_strategies.fass import FASS
from distil.active_learning_strategies.adversarial_bim import AdversarialBIM
from distil.active_learning_strategies.adversarial_deepfool import AdversarialDeepFool
from distil.active_learning_strategies.core_set import CoreSet
from distil.active_learning_strategies.least_confidence_sampling import LeastConfidenceSampling
from distil.active_learning_strategies.margin_sampling import MarginSampling
from distil.active_learning_strategies.bayesian_active_learning_disagreement_dropout import BALDDropout
from distil.utils.utils import LabeledToUnlabeledDataset

from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

from models.shakeshake import ShakeShake
from models.shakeshake import Shortcut

from google.colab import drive
import warnings
warnings.filterwarnings("ignore")

## Checkpointing and Logs

In [None]:
class Checkpoint:
    """
    Snapshot of an experiment that can be written to and restored from disk.

    A checkpoint holds the accuracy list, the list of used indices, the model
    ``state_dict`` and the experiment name. Every save writes two pickled
    copies; a restore only accepts a pair whose two copies can both be read
    and compare equal, so a save that was interrupted part-way is skipped and
    the next older pair is tried instead.

    Parameters
    ----------
    acc_list: list, optional
        Accuracies recorded so far. Required when ``path`` is not given.
    indices: list, optional
        Indices recorded so far. Required when ``path`` is not given.
    state_dict: dict, optional
        Model state to store. Required when ``path`` is not given.
    experiment_name: str
        Name identifying the experiment. Always required.
    path: str, optional
        Directory to restore from. When given, the other values are loaded
        from the most recent valid checkpoint of ``experiment_name`` found
        there (or set to None if there is none).

    Raises
    ------
    ValueError
        If a required value is None.
    """

    def __init__(self, acc_list=None, indices=None, state_dict=None, experiment_name=None, path=None):

        # If a path is supplied, load a checkpoint from there.
        if path is not None:
            if experiment_name is None:
                raise ValueError("Checkpoint contains None value for experiment_name")
            self.load_checkpoint(path, experiment_name)
            return

        # Otherwise every field must be provided explicitly.
        required_fields = (
            ("acc_list", acc_list),
            ("indices", indices),
            ("state_dict", state_dict),
            ("experiment_name", experiment_name),
        )
        for field_name, field_value in required_fields:
            if field_value is None:
                raise ValueError(F"Checkpoint contains None value for {field_name}")

        self.acc_list = acc_list
        self.indices = indices
        self.state_dict = state_dict
        self.experiment_name = experiment_name

    def __eq__(self, other):
        """
        Compare accuracy lists, indices and experiment names.

        The state_dict is deliberately left out of the comparison, as in the
        original implementation.
        """
        if not isinstance(other, Checkpoint):
            return NotImplemented

        return (
            self.acc_list == other.acc_list
            and self.indices == other.indices
            and self.experiment_name == other.experiment_name
        )

    # Instances are mutable and define __eq__, so they are explicitly unhashable.
    __hash__ = None

    def save_checkpoint(self, path):
        """
        Pickle this checkpoint twice into ``path`` (created if necessary).

        The two files are named ``c<timestamp>1`` and ``c<timestamp>2`` so that
        sorting the file names orders checkpoints by recency and the two
        copies of one save share everything but the last character.
        """

        # Get current time to use in file timestamp
        timestamp = time.time_ns()

        # Create the path supplied
        os.makedirs(path, exist_ok=True)

        # Write the same content to both save locations
        for copy_number in (1, 2):
            save_path = os.path.join(path, F"c{timestamp}{copy_number}")
            with open(save_path, 'wb') as save_file:
                pickle.dump(self, save_file)

    @staticmethod
    def _read_checkpoint_file(file_path):
        """Return the unpickled content of ``file_path``, or None if it cannot be read."""

        # NOTE(security): pickle can execute arbitrary code while loading. Only
        # point checkpoint directories at locations you trust.
        try:
            with open(file_path, 'rb') as load_file:
                return pickle.load(load_file)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError, IndexError):
            # A truncated or otherwise unreadable copy means the save did not finish.
            return None

    def load_checkpoint(self, path, experiment_name):
        """
        Populate this object from the most recent valid checkpoint in ``path``.

        ``acc_list``, ``indices`` and ``state_dict`` are set to None when the
        directory does not exist, or when it holds no intact pair of copies
        belonging to ``experiment_name``.
        """

        # Start from the "nothing to restore" state; overwritten on success.
        self.acc_list = None
        self.indices = None
        self.state_dict = None
        self.experiment_name = experiment_name

        if not os.path.isdir(path):
            return

        # Most recent first: file names embed the save timestamp.
        file_names = sorted(
            (f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))),
            reverse=True,
        )

        position = 0
        while position + 1 < len(file_names):

            first_file = file_names[position]
            second_file = file_names[position + 1]

            # Timestamps match if the removal of the "1" or "2" results in equal names.
            # If they do not, the first file has no partner; move on by one file.
            if first_file[:-1] != second_file[:-1]:
                position += 1
                continue
            position += 2

            checkpoint = self._read_checkpoint_file(os.path.join(path, first_file))
            checkpoint_copy = self._read_checkpoint_file(os.path.join(path, second_file))

            # An unreadable copy means this save was interrupted; try an older pair.
            if checkpoint is None or checkpoint_copy is None:
                continue

            # Do not use this checkpoint if it is not from the experiment to restore.
            if getattr(checkpoint, "experiment_name", None) != experiment_name:
                continue

            # Equal copies mean the save operation finished correctly.
            if checkpoint == checkpoint_copy:
                self.acc_list = checkpoint.acc_list
                self.indices = checkpoint.indices
                self.state_dict = checkpoint.state_dict
                return

    def get_saved_values(self):
        """Return ``(acc_list, indices, state_dict)``; all None if nothing was restored."""

        return (self.acc_list, self.indices, self.state_dict)

def delete_checkpoints(checkpoint_directory, experiment_name):
    """
    Delete every checkpoint file in ``checkpoint_directory`` that belongs to ``experiment_name``.

    Each regular file in the directory is unpickled and removed when the
    loaded object's ``experiment_name`` attribute equals ``experiment_name``.
    Files that cannot be unpickled (for example a copy left behind by an
    interrupted save) or that carry no ``experiment_name`` are left in place,
    so one bad file does not abort the cleanup of the others. A directory
    that does not exist is treated as having nothing to delete.

    Parameters
    ----------
    checkpoint_directory: str
        Directory that holds the pickled checkpoints.
    experiment_name: str
        Name of the experiment whose checkpoints should be removed.
    """

    if not os.path.isdir(checkpoint_directory):
        return

    for file_name in os.listdir(checkpoint_directory):

        file_path = os.path.join(checkpoint_directory, file_name)
        if not os.path.isfile(file_path):
            continue

        # NOTE(security): pickle can execute arbitrary code while loading. Only
        # run this on checkpoint directories you trust.
        try:
            with open(file_path, "rb") as load_file:
                stored_object = pickle.load(load_file)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError, IndexError):
            # Unreadable file: its experiment cannot be determined, so keep it.
            continue

        # Delete this file only if the experiment name matched. The file is
        # already closed here, which matters on platforms that lock open files.
        if hasattr(stored_object, "experiment_name") and stored_object.experiment_name == experiment_name:
            os.remove(file_path)

#Logs
def write_logs(logs, save_directory, rd):
  """
  Append one round's log entries to the run log file.

  The file is ``save_directory + 'run_.txt'`` (plain string concatenation, so
  ``save_directory`` is expected to end with a path separator). The entry
  starts with a 'Round <rd>' header between two dashed lines. The value
  stored under the key 'Training' is written one item per line; every other
  entry is written as '<key> - <value>'.
  """
  divider = '---------------------\n'
  target = save_directory + 'run_' + '.txt'

  with open(target, 'a') as log_file:
    entry_lines = [divider, 'Round ' + str(rd) + '\n', divider]

    for name, value in logs.items():
      if key_is_training := (name == 'Training'):
        # Per-epoch records: a heading line followed by one line per record
        entry_lines.append(str(name) + '\n')
        entry_lines.extend(str(record) + '\n' for record in value)
      if not key_is_training:
        entry_lines.append(str(name) + ' - ' + str(value) + '\n')

    log_file.writelines(entry_lines)

## Modified Training Class (SWA)

In [None]:
def init_weights(m):
    """
    Initialise a linear layer: Xavier-uniform weights and a constant 0.01 bias.

    Intended to be passed to ``Module.apply``. Only modules whose type is
    exactly ``nn.Linear`` are touched (subclasses are left alone, as before);
    every other module is ignored. Layers created with ``bias=False`` only get
    their weights initialised.
    """
    if type(m) is nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        # A layer built with bias=False has m.bias set to None.
        if m.bias is not None:
            m.bias.data.fill_(0.01)

class AddIndexDataset(Dataset):
    """
    Dataset wrapper that returns each sample's index alongside the sample.

    The wrapped dataset must yield ``(data, label)`` pairs; this wrapper
    yields ``(data, label, index)`` and reports the wrapped dataset's length.
    """

    def __init__(self, wrapped_dataset):
        # Only a reference is kept; samples are fetched from the wrapped dataset on demand.
        self.wrapped_dataset = wrapped_dataset

    def __len__(self):
        return len(self.wrapped_dataset)

    def __getitem__(self, index):
        sample, target = self.wrapped_dataset[index]
        return sample, target, index

#custom training
class data_train:

    """
    Provides a configurable training loop for AL with stochastic weight averaging (SWA).

    Parameters
    ----------
    training_dataset: torch.utils.data.Dataset
        The training dataset to use
    net: torch.nn.Module
        The model to train
    args: dict
        Additional arguments to control the training loop. Optional keys that
        are missing are filled in with their defaults on this same dict.

        `n_epoch` - Epoch limit. The loop runs while the epoch counter, which starts at 1, is below this value, i.e. at most `n_epoch - 1` passes (int, required by train())
        `lr` - The learning rate; also used as the SWA learning rate (float, required by train())
        `batch_size` - The size of each training batch (int, optional)
        `islogs`- Whether to return training metadata (bool, optional)
        `optimizer`- The choice of optimizer. Must be one of 'sgd' or 'adam' (string, optional)
        `isverbose`- Whether to print more messages about the training (bool, optional)
        `isreset`- Whether to reset the model before training (bool, optional)
        `max_accuracy`- The training accuracy cutoff by which to stop training (float, optional)
        `min_diff_acc`- The minimum difference in accuracy to measure in the window of monitored accuracies. If all differences are less than the minimum, stop training (float, optional)
        `window_size`- The size of the window for monitoring accuracies. If all differences are less than 'min_diff_acc', then stop training (int, optional)
        `criterion`- The criterion to use for training (typing.Callable[], optional)
        `device`- The device to use for training (string, optional)
    """

    # Training accuracy above which weight averaging starts.
    _SWA_START_ACCURACY = 0.95
    # Every this many epochs the model is re-initialised if it is still below _RESET_ACCURACY.
    _RESET_INTERVAL = 50
    _RESET_ACCURACY = 0.2

    def __init__(self, training_dataset, net, args):

        self.training_dataset = AddIndexDataset(training_dataset)
        self.net = net
        self.args = args

        self.n_pool = len(training_dataset)

        defaults = {
            'islogs': False,
            'optimizer': 'sgd',
            'isverbose': False,
            'isreset': True,
            'max_accuracy': 0.95,
            'min_diff_acc': 0.001,   # Threshold to monitor for
            'window_size': 10,       # Window for monitoring accuracies
        }
        for key, value in defaults.items():
            self.args.setdefault(key, value)

        if 'criterion' not in args:
            self.args['criterion'] = nn.CrossEntropyLoss()

        if 'device' not in args:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = args['device']

    @staticmethod
    def _weight_reset(m):
        """Re-initialise a module's parameters when it offers `reset_parameters`."""
        if hasattr(m, 'reset_parameters'):
            m.reset_parameters()

    def _build_optimizer(self, n_epoch):
        """
        Create the optimizer for `self.clf` together with its schedulers.

        Returns
        -------
        (optimizer, lr_sched, swa_sched)
            `lr_sched` is a cosine annealing schedule for 'sgd' and None for
            'adam'. `swa_sched` is the SWA schedule bound to the same optimizer.

        Raises
        ------
        ValueError
            If `args['optimizer']` is neither 'sgd' nor 'adam'.
        """
        optimizer_name = self.args['optimizer']

        if optimizer_name == 'sgd':
            optimizer = optim.SGD(self.clf.parameters(), lr = self.args['lr'], momentum=0.9, weight_decay=5e-4)
            lr_sched = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epoch)
        elif optimizer_name == 'adam':
            optimizer = optim.Adam(self.clf.parameters(), lr = self.args['lr'], weight_decay=0)
            lr_sched = None
        else:
            raise ValueError("optimizer must be one of 'sgd' or 'adam', got " + repr(optimizer_name))

        swa_sched = SWALR(optimizer, anneal_strategy="cos", swa_lr=self.args['lr'])
        return optimizer, lr_sched, swa_sched

    def update_index(self, idxs_lb):
        """Store the given labeled-index information on this object."""
        self.idxs_lb = idxs_lb

    def update_data(self, new_training_dataset):
        """
        Updates the training dataset with the provided new training dataset

        Parameters
        ----------
        new_training_dataset: torch.utils.data.Dataset
            The new training dataset
        """
        self.training_dataset = AddIndexDataset(new_training_dataset)

    def get_acc_on_set(self, test_dataset):

        """
        Calculates and returns the accuracy on the given dataset to test

        Parameters
        ----------
        test_dataset: torch.utils.data.Dataset
            The dataset to test
        Returns
        -------
        accFinal: float
            The fraction of data points whose predictions by the current model match their targets
        """

        # Before the first call to train() there is no trained model yet; use the supplied one.
        if not hasattr(self, 'clf'):
            self.clf = self.net

        if test_dataset is None:
            raise ValueError("Test data not present")

        batch_size = self.args.get('batch_size', 1)

        loader_te = DataLoader(test_dataset, shuffle=False, pin_memory=True, batch_size=batch_size)
        self.clf.eval()
        accFinal = 0.

        with torch.no_grad():
            self.clf = self.clf.to(device=self.device)
            for batch_id, (x,y) in enumerate(loader_te):
                x, y = x.to(device=self.device), y.to(device=self.device)
                out = self.clf(x)
                accFinal += torch.sum(1.0*(torch.max(out,1)[1] == y)).item()

        return accFinal / len(test_dataset)

    def _train_weighted(self, epoch, loader_tr, optimizer, gradient_weights):
        """
        Run one epoch in which every sample's loss is scaled by its entry in `gradient_weights`.

        `gradient_weights` is indexed with the sample indices yielded by the loader.
        Returns the epoch's training accuracy and the weighted loss of the last batch.
        """
        self.clf.train()
        accFinal = 0.
        criterion = self.args['criterion']
        # Per-sample losses are needed so that they can be weighted before reduction.
        criterion.reduction = "none"

        # The weights do not change during the epoch, so move them to the device once.
        gradient_weights = gradient_weights.to(device=self.device)

        for batch_id, (x, y, idxs) in enumerate(loader_tr):
            x, y = x.to(device=self.device), y.to(device=self.device)

            optimizer.zero_grad()
            out = self.clf(x)

            # Modify the loss function to apply weights before reducing to a mean
            loss = criterion(out, y.long())

            # Perform a dot product with the loss vector and the weight vector, then divide by batch size.
            weighted_loss = torch.dot(loss, gradient_weights[idxs])
            weighted_loss = torch.div(weighted_loss, len(idxs))

            accFinal += torch.sum(torch.eq(torch.max(out,1)[1],y)).item()

            # Backward now does so on the weighted loss, not the regular mean loss
            weighted_loss.backward()

            optimizer.step()
        return accFinal / len(loader_tr.dataset), weighted_loss

    def _train(self, epoch, loader_tr, optimizer):
        """
        Run one epoch with the criterion reduced to a mean.

        Returns the epoch's training accuracy and the loss of the last batch.
        """
        self.clf.train()
        accFinal = 0.
        criterion = self.args['criterion']
        criterion.reduction = "mean"

        for batch_id, (x, y, idxs) in enumerate(loader_tr):
            x, y = x.to(device=self.device), y.to(device=self.device)

            optimizer.zero_grad()
            out = self.clf(x)
            loss = criterion(out, y.long())
            accFinal += torch.sum((torch.max(out,1)[1] == y).float()).item()
            loss.backward()

            optimizer.step()
        return accFinal / len(loader_tr.dataset), loss

    def check_saturation(self, acc_monitor):
        """
        Return True when no later accuracy in `acc_monitor` exceeds an earlier
        one by at least `args['min_diff_acc']`, i.e. training has stopped improving.
        """
        threshold = self.args['min_diff_acc']

        for i in range(len(acc_monitor)):
            for j in range(i+1, len(acc_monitor)):
                if acc_monitor[j] - acc_monitor[i] >= threshold:
                    return False

        return True

    def train(self, gradient_weights=None):
        """
        Train the model and return it.

        Training stops when the training accuracy reaches `max_accuracy`, the
        epoch limit is hit, or the monitored accuracies saturate. Once the
        training accuracy exceeds 0.95, the weights are averaged after every
        epoch (SWA). The averaged model replaces the trained one only if
        averaging actually took place; otherwise the trained model is kept.

        Parameters
        ----------
        gradient_weights: torch.Tensor, optional
            Per-sample loss weights. If None, the regular mean loss is used.

        Returns
        -------
        The trained model, or `(model, train_logs)` when `args['islogs']` is set.

        Raises
        ------
        ValueError
            If `args['optimizer']` is neither 'sgd' nor 'adam'.
        """

        print('Training..')

        train_logs = []
        n_epoch = self.args['n_epoch']

        # Re-initialise when requested, or when no model has been set up yet.
        if self.args['isreset'] or not hasattr(self, 'clf'):
            self.clf = self.net.apply(self._weight_reset).to(device=self.device)

        optimizer, lr_sched, swa_sched = self._build_optimizer(n_epoch)

        # ADD stochastic weight averaging
        swa_model = AveragedModel(self.clf).to(device=self.device)

        batch_size = self.args.get('batch_size', 1)

        # Set shuffle to true to encourage stochastic behavior for SGD
        loader_tr = DataLoader(self.training_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
        epoch = 1
        accCurrent = 0
        is_saturated = False
        acc_monitor = []
        start_swa = False

        while (accCurrent < self.args['max_accuracy']) and (epoch < n_epoch) and (not is_saturated):

            if gradient_weights is None:
                accCurrent, lossCurrent = self._train(epoch, loader_tr, optimizer)
            else:
                accCurrent, lossCurrent = self._train_weighted(epoch, loader_tr, optimizer, gradient_weights)

            acc_monitor.append(accCurrent)

            if accCurrent > self._SWA_START_ACCURACY:
                start_swa = True

            if start_swa:
                swa_model.update_parameters(self.clf)
                swa_sched.step()
            elif lr_sched is not None:
                lr_sched.step()

            epoch += 1
            if(self.args['isverbose']):
                if epoch % 50 == 0:
                    print(str(epoch) + ' training accuracy: ' + str(accCurrent), flush=True)

            #Stop training if not converging
            if len(acc_monitor) >= self.args['window_size']:

                is_saturated = self.check_saturation(acc_monitor)
                del acc_monitor[0]

            log_string = 'Epoch:' + str(epoch) + '- training accuracy:'+str(accCurrent)+'- training loss:'+str(lossCurrent)
            train_logs.append(log_string)

            if (epoch % self._RESET_INTERVAL == 0) and (accCurrent < self._RESET_ACCURACY): # reset if not converging
                self.clf = self.net.apply(self._weight_reset).to(device=self.device)

                # Everything tied to the old weights or the old optimizer is rebuilt,
                # so that a later SWA phase anneals the optimizer actually in use.
                optimizer, lr_sched, swa_sched = self._build_optimizer(n_epoch)
                swa_model = AveragedModel(self.clf).to(device=self.device)
                start_swa = False

        print('Epoch:', str(epoch), 'Training accuracy:', round(accCurrent, 3), flush=True)

        # AveragedModel starts as a copy of the model as it was before training.
        # Only adopt it when averaging happened; otherwise the training done in
        # this call would be thrown away.
        if start_swa:
            # Update batch normalization and set averaged model to clf
            update_bn(loader_tr, swa_model, device=self.device)
            self.clf = swa_model.module

        if self.args['islogs']:
            return self.clf, train_logs
        else:
            return self.clf

## Shake-Shake Model

In [None]:
class ShakeBlock(nn.Module):
    """
    Residual block with two parallel convolutional branches.

    The outputs of the two branches are combined by ``ShakeShake.apply`` and
    added to a skip connection. The skip connection is the input itself when
    the input and output channel counts are equal, and ``Shortcut(x)``
    otherwise.

    Parameters
    ----------
    in_ch: int
        Number of input channels
    out_ch: int
        Number of output channels
    stride: int, optional
        Stride of the first convolution of each branch and of the shortcut
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super(ShakeBlock, self).__init__()
        self.equal_io = in_ch == out_ch

        # NOTE(review): this used to read `self.equal_io and None or Shortcut(...)`.
        # Because None is falsy, that expression always evaluated to the Shortcut,
        # so the module has always been created, even when forward() never uses it.
        # It is kept that way on purpose: making it conditional would change the
        # parameter set and the state_dict keys of saved models.
        self.shortcut = Shortcut(in_ch, out_ch, stride=stride)

        self.branch1 = self._make_branch(in_ch, out_ch, stride)
        self.branch2 = self._make_branch(in_ch, out_ch, stride)

    def forward(self, x):
        branch_out1 = self.branch1(x)
        branch_out2 = self.branch2(x)
        # The training flag is forwarded so the combination can differ between train and eval mode.
        mixed = ShakeShake.apply(branch_out1, branch_out2, self.training)
        skip = x if self.equal_io else self.shortcut(x)
        return mixed + skip

    def _make_branch(self, in_ch, out_ch, stride=1):
        """Build one branch: ReLU -> 3x3 conv (strided) -> BN -> ReLU -> 3x3 conv -> BN."""
        return nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=False),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch))


class ShakeResNet(nn.Module):
    """
    Network made of an input convolution, three stacks of ShakeBlocks and a linear classifier.

    Parameters
    ----------
    depth: int
        Controls the number of blocks per stack: int((depth - 2) / 6)
    w_base: int
        Channel width of the first stack; the second and third use 2x and 4x
    label: int
        Number of outputs of the final linear layer
    """

    def __init__(self, depth, w_base, label):
        super(ShakeResNet, self).__init__()
        blocks_per_stack = (depth - 2) / 6

        self.in_chs = [16, w_base, w_base * 2, w_base * 4]
        stem_ch, stack1_ch, stack2_ch, stack3_ch = self.in_chs

        self.c_in = nn.Conv2d(3, stem_ch, 3, padding=1)
        self.layer1 = self._make_layer(blocks_per_stack, stem_ch, stack1_ch)
        self.layer2 = self._make_layer(blocks_per_stack, stack1_ch, stack2_ch, 2)
        self.layer3 = self._make_layer(blocks_per_stack, stack2_ch, stack3_ch, 2)
        self.fc_out = nn.Linear(stack3_ch, label)

        # Initialize parameters
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                # Only the bias is set here; the linear weights keep their default initialization
                module.bias.data.zero_()

    def forward(self, x, last=False):
        """
        Return the classifier output; with ``last=True`` return ``(output, embedding)``,
        where the embedding is the input to the final linear layer.
        """
        features = self.c_in(x)
        for stack in (self.layer1, self.layer2, self.layer3):
            features = stack(features)

        pooled = F.avg_pool2d(F.relu(features), 8)
        embedding = pooled.view(-1, self.in_chs[3])
        logits = self.fc_out(embedding)

        if not last:
            return logits
        return logits, embedding

    def get_embedding_dim(self):
        """Return the size of the embedding fed to the final linear layer."""
        return self.fc_out.in_features

    def _make_layer(self, n_units, in_ch, out_ch, stride=1):
        """Stack int(n_units) ShakeBlocks; only the first changes channels and applies ``stride``."""
        blocks = []
        block_in, block_stride = in_ch, stride
        for _ in range(int(n_units)):
            blocks.append(ShakeBlock(block_in, out_ch, stride=block_stride))
            block_in, block_stride = out_ch, 1
        return nn.Sequential(*blocks)

## AL Loop

In [None]:
def train_one(full_train_dataset, initial_train_indices, test_dataset, net, n_rounds, budget, args, nclasses, strategy, save_directory, checkpoint_directory, experiment_name):
    """
    Run the active-learning loop for one strategy, with checkpointing after every round.

    If a checkpoint for `experiment_name` exists in `checkpoint_directory`,
    the accuracies, the already-selected points and the model weights are
    restored from it and the loop continues with the next round. Otherwise
    the initial test accuracy is recorded and the loop starts at round 1.

    Parameters
    ----------
    full_train_dataset: torch.utils.data.Dataset
        The complete training pool. Its `transform` attribute is temporarily
        replaced by the test dataset's transform while points are selected.
    initial_train_indices: list
        Indices of `full_train_dataset` forming the initial labeled set
    test_dataset: torch.utils.data.Dataset
        The dataset used to measure accuracy after every round
    net: torch.nn.Module
        The model to train
    n_rounds: int
        The number of selection rounds
    budget: int
        The number of points requested from the strategy per round
    args: dict
        Training arguments passed to `data_train`; 'batch_size' and 'device'
        (and 'lr' for "glister") are also passed to the strategy
    nclasses: int
        The number of classes
    strategy: str
        Name of the strategy, e.g. "random", "entropy", "margin", "badge".
        An already constructed strategy object is used as it is.
    save_directory: str
        Where the round logs are written
    checkpoint_directory: str
        Where checkpoints are saved and looked up
    experiment_name: str
        Name used to match checkpoints to this experiment

    Returns
    -------
    acc: numpy.ndarray
        Test accuracy per round; entry 0 is the accuracy before the first round

    Raises
    ------
    ValueError
        If `strategy` is a string that names no known strategy.
    """

    # Split the full training dataset into an initial training dataset and an unlabeled dataset.
    # The unlabeled indices are sorted explicitly: index_map below relies on ascending order.
    train_dataset = Subset(full_train_dataset, initial_train_indices)
    initial_unlabeled_indices = sorted(set(range(len(full_train_dataset))) - set(initial_train_indices))
    unlabeled_dataset = Subset(full_train_dataset, initial_unlabeled_indices)

    # Set up the AL strategy
    strategy = _build_strategy(strategy, train_dataset, unlabeled_dataset, net, nclasses, args)

    # Define acc initially
    acc = np.zeros(n_rounds+1)

    initial_unlabeled_size = len(unlabeled_dataset)

    initial_round = 1

    # index_map[i] is the position, in the initial unlabeled set, of the i-th point still unlabeled
    index_map = np.arange(initial_unlabeled_size)

    # Attempt to load a checkpoint. If one exists, then the experiment crashed.
    training_checkpoint = Checkpoint(experiment_name=experiment_name, path=checkpoint_directory)
    rec_acc, rec_indices, rec_state_dict = training_checkpoint.get_saved_values()

    # Check if there are values to recover
    if rec_acc is not None:

        # Restore the accuracy list
        for i, recovered_accuracy in enumerate(rec_acc):
            acc[i] = recovered_accuracy

        # Restore the indices list and shift those unlabeled points to the labeled set.
        index_map = np.delete(index_map, rec_indices)

        # Record initial size of the training dataset (before the restored points are added)
        initial_seed_size = len(train_dataset)

        restored_unlabeled_points = Subset(unlabeled_dataset, rec_indices)
        train_dataset = ConcatDataset([train_dataset, restored_unlabeled_points])

        remaining_unlabeled_indices = sorted(set(range(len(unlabeled_dataset))) - set(rec_indices))
        unlabeled_dataset = Subset(unlabeled_dataset, remaining_unlabeled_indices)

        # Restore the model
        net.load_state_dict(rec_state_dict)

        # Fix the initial round: one round was completed per `budget` restored points
        initial_round = (len(train_dataset) - initial_seed_size) // budget + 1

        # Ensure loaded model is moved to GPU
        if torch.cuda.is_available():
            net = net.cuda()

        strategy.update_model(net)
        strategy.update_data(train_dataset, LabeledToUnlabeledDataset(unlabeled_dataset))

        dt = data_train(train_dataset, net, args)

    else:

        if torch.cuda.is_available():
            net = net.cuda()

        dt = data_train(train_dataset, net, args)

        acc[0] = dt.get_acc_on_set(test_dataset)
        print('Initial Testing accuracy:', round(acc[0]*100, 2), flush=True)

        logs = {}
        logs['Training Points'] = len(train_dataset)
        logs['Test Accuracy'] =  str(round(acc[0]*100, 2))
        write_logs(logs, save_directory, 0)

        #Updating the trained model in strategy class
        strategy.update_model(net)

    # Record the training transform and test transform for disabling purposes
    train_transform = full_train_dataset.transform
    test_transform = test_dataset.transform

    ##User Controlled Loop
    for rd in range(initial_round, n_rounds+1):
        print('-------------------------------------------------')
        print('Round', rd)
        print('-------------------------------------------------')

        sel_time = time.time()
        full_train_dataset.transform = test_transform # Disable any augmentation while selecting points
        idx = strategy.select(budget)
        full_train_dataset.transform = train_transform # Re-enable any augmentation done during training
        sel_time = time.time() - sel_time
        print("Selection Time:", sel_time)

        selected_unlabeled_points = Subset(unlabeled_dataset, idx)
        train_dataset = ConcatDataset([train_dataset, selected_unlabeled_points])

        remaining_unlabeled_indices = sorted(set(range(len(unlabeled_dataset))) - set(idx))
        unlabeled_dataset = Subset(unlabeled_dataset, remaining_unlabeled_indices)

        # Update the index map
        index_map = np.delete(index_map, idx, axis = 0)

        print('Number of training points -', len(train_dataset))

        # Start training
        strategy.update_data(train_dataset, LabeledToUnlabeledDataset(unlabeled_dataset))
        dt.update_data(train_dataset)
        t1 = time.time()
        train_result = dt.train(None)
        t2 = time.time()

        # train() returns (model, logs) only when args['islogs'] is set
        if isinstance(train_result, tuple):
            clf, train_logs = train_result
        else:
            clf, train_logs = train_result, []

        acc[rd] = dt.get_acc_on_set(test_dataset)
        logs = {}
        logs['Training Points'] = len(train_dataset)
        logs['Test Accuracy'] =  str(round(acc[rd]*100, 2))
        logs['Selection Time'] = str(sel_time)
        logs['Trainining Time'] = str(t2 - t1)
        logs['Training'] = train_logs
        write_logs(logs, save_directory, rd)
        strategy.update_model(clf)
        print('Testing accuracy:', round(acc[rd]*100, 2), flush=True)

        # Create a checkpoint: the used indices are those no longer present in index_map
        used_indices = np.delete(np.arange(initial_unlabeled_size), index_map).tolist()

        round_checkpoint = Checkpoint(acc.tolist(), used_indices, clf.state_dict(), experiment_name=experiment_name)
        round_checkpoint.save_checkpoint(checkpoint_directory)

    print('Training Completed')
    return acc


def _build_strategy(strategy, train_dataset, unlabeled_dataset, net, nclasses, args):
    """Construct the AL strategy named by `strategy`; non-string values are returned unchanged."""

    # A strategy object supplied by the caller is used directly.
    if not isinstance(strategy, str):
        return strategy

    unlabeled_view = LabeledToUnlabeledDataset(unlabeled_dataset)

    if strategy == "glister":
        strategy_args = {'batch_size' : args['batch_size'], 'lr': args['lr'], 'device':args['device']}
        return GLISTER(train_dataset, unlabeled_view, net, nclasses, strategy_args, typeOf='rand', lam=0.1)

    # Every other strategy is built from the same arguments.
    strategy_classes = {
        "random": RandomSampling,
        "entropy": EntropySampling,
        "margin": MarginSampling,
        "least_confidence": LeastConfidenceSampling,
        "badge": BADGE,
        "coreset": CoreSet,
        "fass": FASS,
        "adversarial_bim": AdversarialBIM,
        "adversarial_deepfool": AdversarialDeepFool,
        "bald": BALDDropout,
    }

    if strategy not in strategy_classes:
        raise ValueError("Unknown strategy: " + repr(strategy))

    strategy_args = {'batch_size' : args['batch_size'], 'device':args['device']}
    return strategy_classes[strategy](train_dataset, unlabeled_view, net, nclasses, strategy_args)

# CIFAR10

## Parameter Definitions

Parameters related to the specific experiment are placed here. You should examine each and modify them as needed.

In [None]:
# Name of the dataset; the loading cell selects the dataset by comparing against this value.
data_set_name = "CIFAR10" # DSET NAME HERE
# Root directory passed to the torchvision dataset constructors (data is downloaded there).
dataset_root_path = '../downloaded_data/'
net = ShakeResNet(depth=14, w_base=32, label=10) # MODEL HERE

# MODIFY AS NECESSARY
# Output locations on the mounted drive. Each ends with '/'; the log path is
# formed by plain string concatenation and relies on that.
logs_directory = '/content/gdrive/MyDrive/colab_storage/logs/'
checkpoint_directory = '/content/gdrive/MyDrive/colab_storage/check/'
model_directory = "/content/gdrive/MyDrive/colab_storage/model/"

# Name stored in every checkpoint; it is used to find this experiment's checkpoints again.
experiment_name = "CIFAR10 SWA SS"

initial_seed_size = 1000 # INIT SEED SIZE HERE
training_size_cap = 25000 # TRAIN SIZE CAP HERE

budget = 3000 # BUDGET HERE

# CHANGE ARGS AS NECESSARY
# 'n_epoch' and 'lr' are required by the training loop; the remaining keys override its defaults.
args = {'n_epoch':300, 'lr':float(0.01), 'batch_size':20, 'max_accuracy':float(0.99), 'islogs':True, 'isreset':True, 'isverbose':True, 'device':'cuda'} 

# Number of selection rounds that fit between the seed size and the training
# size cap, given the budget constraints (integer division, so the cap is never exceeded)
n_rounds = (training_size_cap - initial_seed_size) // budget

## Initial Loading and Training

You may choose to train a new initial model or to continue to load a specific model. If this notebook is being executed in Colab, you should consider whether or not you need the gdown line.

In [None]:
# Mount the drive that holds the logs, checkpoints and any previously saved model.
colab_model_storage_mount = "/content/gdrive"
drive.mount(colab_model_storage_mount)

# Make sure every output directory exists before anything is written to it.
os.makedirs(logs_directory, exist_ok = True)
os.makedirs(checkpoint_directory, exist_ok = True)
os.makedirs(model_directory, exist_ok = True)

# From here on, model_directory is the path of the saved-model FILE (named after
# the dataset) inside the model directory, not the directory itself.
# NOTE(review): this rebinding is not idempotent. Re-running this cell without
# re-running the parameter cell appends the dataset name a second time, and the
# os.makedirs(model_directory) call above then targets the previously saved
# model file and fails.
model_directory = F"{model_directory}/{data_set_name}"
# Optionally retrieve a saved model from a download link and store it on the drive.
#!/content/gdown.pl/gdown.pl "INSERT SHARABLE LINK HERE" "INSERT DOWNLOAD LOCATION HERE (ideally, same as model_directory)" # MAY NOT NEED THIS LINE IF NOT CLONING MODEL FROM COLAB

# Load the dataset: each branch defines train_transform, test_transform,
# full_train_dataset, test_dataset and nclasses for one supported dataset.
if data_set_name == "CIFAR10":

    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    full_train_dataset = datasets.CIFAR10(dataset_root_path, download=True, train=True, transform=train_transform, target_transform=torch.tensor)
    test_dataset = datasets.CIFAR10(dataset_root_path, download=True, train=False, transform=test_transform, target_transform=torch.tensor)

    nclasses = 10 # NUM CLASSES HERE

elif data_set_name == "CIFAR100":

    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))])
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))])

    full_train_dataset = datasets.CIFAR100(dataset_root_path, download=True, train=True, transform=train_transform, target_transform=torch.tensor)
    test_dataset = datasets.CIFAR100(dataset_root_path, download=True, train=False, transform=test_transform, target_transform=torch.tensor)

    nclasses = 100 # NUM CLASSES HERE

elif data_set_name == "MNIST":

    image_dim=28
    train_transform = transforms.Compose([transforms.RandomCrop(image_dim, padding=4), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    test_transform = transforms.Compose([transforms.Resize((image_dim, image_dim)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    full_train_dataset = datasets.MNIST(dataset_root_path, download=True, train=True, transform=train_transform, target_transform=torch.tensor)
    test_dataset = datasets.MNIST(dataset_root_path, download=True, train=False, transform=test_transform, target_transform=torch.tensor)

    nclasses = 10 # NUM CLASSES HERE

elif data_set_name == "FashionMNIST":

    train_transform = transforms.Compose([transforms.RandomCrop(28, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # Use mean/std of MNIST

    full_train_dataset = datasets.FashionMNIST(dataset_root_path, download=True, train=True, transform=train_transform, target_transform=torch.tensor)
    test_dataset = datasets.FashionMNIST(dataset_root_path, download=True, train=False, transform=test_transform, target_transform=torch.tensor)

    nclasses = 10 # NUM CLASSES HERE

elif data_set_name == "SVHN":

    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) # ImageNet mean/std

    full_train_dataset = datasets.SVHN(dataset_root_path, split='train', download=True, transform=train_transform, target_transform=torch.tensor)
    test_dataset = datasets.SVHN(dataset_root_path, split='test', download=True, transform=test_transform, target_transform=torch.tensor)

    nclasses = 10 # NUM CLASSES HERE

elif data_set_name == "ImageNet":

    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) # ImageNet mean/std

    # Note: Not automatically downloaded due to size restrictions. Notebook needs to be adapted to run on local device.
    full_train_dataset = datasets.ImageNet(dataset_root_path, download=False, split='train', transform=train_transform, target_transform=torch.tensor)
    test_dataset = datasets.ImageNet(dataset_root_path, download=False, split='val', transform=test_transform, target_transform=torch.tensor)

    nclasses = 1000 # NUM CLASSES HERE

else:
    # Fail here with a clear message rather than with a NameError on nclasses below.
    raise ValueError(F"Unsupported data_set_name: {data_set_name!r}")

args['nclasses'] = nclasses

# Shape of a single (transformed) training input.
dim = full_train_dataset[0][0].shape

# Seed the random number generator for reproducibility and create the initial seed set
np.random.seed(42)
initial_train_indices = np.random.choice(len(full_train_dataset), replace=False, size=initial_seed_size)

# COMMENT OUT ONE OR THE OTHER IF YOU WANT TO TRAIN A NEW INITIAL MODEL
load_model = False
#load_model = True

# Either load previously saved weights into net, or train net on the initial
# seed set and save the resulting weights for later runs.
if load_model:
    # map_location lets weights saved on one device be loaded when args['device'] is another.
    net.load_state_dict(torch.load(model_directory, map_location=args['device']))
    initial_model = net
else:
    dt = data_train(Subset(full_train_dataset, initial_train_indices), net, args)
    initial_model, _ = dt.train(None)
    torch.save(initial_model.state_dict(), model_directory)

print("Training for", n_rounds, "rounds with budget", budget, "on unlabeled set size", training_size_cap)

## Random Sampling

In [None]:
# Run train_one with the "random" strategy; logs go to <logs_directory><dataset>/<strategy>/.
strategy = "random"
strat_logs = F"{logs_directory}{data_set_name}/{strategy}/"
os.makedirs(strat_logs, exist_ok=True)
train_one(
    full_train_dataset,
    initial_train_indices,
    test_dataset,
    copy.deepcopy(initial_model),  # pass a copy so initial_model itself is left untouched
    n_rounds,
    budget,
    args,
    nclasses,
    strategy,
    strat_logs,
    checkpoint_directory,
    F"{experiment_name}_{strategy}",
)

## Entropy

In [None]:
# Run train_one with the "entropy" strategy; logs go to <logs_directory><dataset>/<strategy>/.
strategy = "entropy"
strat_logs = F"{logs_directory}{data_set_name}/{strategy}/"
os.makedirs(strat_logs, exist_ok=True)
train_one(
    full_train_dataset,
    initial_train_indices,
    test_dataset,
    copy.deepcopy(initial_model),  # pass a copy so initial_model itself is left untouched
    n_rounds,
    budget,
    args,
    nclasses,
    strategy,
    strat_logs,
    checkpoint_directory,
    F"{experiment_name}_{strategy}",
)

## GLISTER

In [None]:
# Run train_one with the "glister" strategy; logs go to <logs_directory><dataset>/<strategy>/.
strategy = "glister"
strat_logs = F"{logs_directory}{data_set_name}/{strategy}/"
os.makedirs(strat_logs, exist_ok=True)
train_one(
    full_train_dataset,
    initial_train_indices,
    test_dataset,
    copy.deepcopy(initial_model),  # pass a copy so initial_model itself is left untouched
    n_rounds,
    budget,
    args,
    nclasses,
    strategy,
    strat_logs,
    checkpoint_directory,
    F"{experiment_name}_{strategy}",
)

## FASS

In [None]:
# Run train_one with the "fass" strategy; logs go to <logs_directory><dataset>/<strategy>/.
strategy = "fass"
strat_logs = F"{logs_directory}{data_set_name}/{strategy}/"
os.makedirs(strat_logs, exist_ok=True)
train_one(
    full_train_dataset,
    initial_train_indices,
    test_dataset,
    copy.deepcopy(initial_model),  # pass a copy so initial_model itself is left untouched
    n_rounds,
    budget,
    args,
    nclasses,
    strategy,
    strat_logs,
    checkpoint_directory,
    F"{experiment_name}_{strategy}",
)

## BADGE

In [None]:
# Run train_one with the "badge" strategy; logs go to <logs_directory><dataset>/<strategy>/.
strategy = "badge"
strat_logs = F"{logs_directory}{data_set_name}/{strategy}/"
os.makedirs(strat_logs, exist_ok=True)
train_one(
    full_train_dataset,
    initial_train_indices,
    test_dataset,
    copy.deepcopy(initial_model),  # pass a copy so initial_model itself is left untouched
    n_rounds,
    budget,
    args,
    nclasses,
    strategy,
    strat_logs,
    checkpoint_directory,
    F"{experiment_name}_{strategy}",
)

## CoreSet

In [None]:
# Run train_one with the "coreset" strategy; logs go to <logs_directory><dataset>/<strategy>/.
strategy = "coreset"
strat_logs = F"{logs_directory}{data_set_name}/{strategy}/"
os.makedirs(strat_logs, exist_ok=True)
train_one(
    full_train_dataset,
    initial_train_indices,
    test_dataset,
    copy.deepcopy(initial_model),  # pass a copy so initial_model itself is left untouched
    n_rounds,
    budget,
    args,
    nclasses,
    strategy,
    strat_logs,
    checkpoint_directory,
    F"{experiment_name}_{strategy}",
)

## Least Confidence

In [None]:
# Run train_one with the "least_confidence" strategy; logs go to <logs_directory><dataset>/<strategy>/.
strategy = "least_confidence"
strat_logs = F"{logs_directory}{data_set_name}/{strategy}/"
os.makedirs(strat_logs, exist_ok=True)
train_one(
    full_train_dataset,
    initial_train_indices,
    test_dataset,
    copy.deepcopy(initial_model),  # pass a copy so initial_model itself is left untouched
    n_rounds,
    budget,
    args,
    nclasses,
    strategy,
    strat_logs,
    checkpoint_directory,
    F"{experiment_name}_{strategy}",
)

## Margin

In [None]:
# Run train_one with the "margin" strategy; logs go to <logs_directory><dataset>/<strategy>/.
strategy = "margin"
strat_logs = F"{logs_directory}{data_set_name}/{strategy}/"
os.makedirs(strat_logs, exist_ok=True)
train_one(
    full_train_dataset,
    initial_train_indices,
    test_dataset,
    copy.deepcopy(initial_model),  # pass a copy so initial_model itself is left untouched
    n_rounds,
    budget,
    args,
    nclasses,
    strategy,
    strat_logs,
    checkpoint_directory,
    F"{experiment_name}_{strategy}",
)