# Subset selection for supervised learning

In this tutorial, we walk through an example of integrating the subset-selection-based dataloaders in CORDS with a typical supervised learning training loop.

### Cloning CORDS repository

In [None]:
!git clone https://github.com/decile-team/cords.git
%cd cords/
%ls

Cloning into 'cords'...
remote: Enumerating objects: 5113, done.
remote: Counting objects: 100% (428/428), done.
remote: Compressing objects: 100% (174/174), done.
remote: Total 5113 (delta 317), reused 310 (delta 246), pack-reused 4685
Receiving objects: 100% (5113/5113), 58.46 MiB | 17.22 MiB/s, done.
Resolving deltas: 100% (3169/3169), done.
/content/cords
benchmarks/   examples/      requirements/  train_ssl.py
CITATION.CFF  gradio_hpo.py  setup.py       transformers_train_sl.py
configs/      gradio_sl.py   tests/         tutorial/
cords/        LICENSE.txt    train_hpo.py
docs/         README.md      train_sl.py


### Install prerequisite libraries of CORDS

In [None]:
!pip install dotmap
!pip install apricot-select
!pip install ray[default]
!pip install ray[tune]
!pip install datasets
!pip install transformers
!pip install sentence-transformers
!pip install scikit-learn
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dotmap
  Downloading dotmap-1.3.30-py3-none-any.whl (11 kB)
Installing collected packages: dotmap
Successfully installed dotmap-1.3.30
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting apricot-select
  Downloading apricot-select-0.6.1.tar.gz (28 kB)
  Preparing metadata (setup.py) ... done
Collecting nose
  Downloading nose-1.3.7-py3-none-any.whl (154 kB)
Building wheels for collected packages: apricot-select
  Building wheel for apricot-select (setup.py) ... done
  Created wheel for apricot-select: filename=apricot_select-0.6.1-py3-none-any.whl size=48786 sha256=a2c5763cead9d961c8a5b1b319f9fc3f815b310cea31c3646888056b81afda5b
  Stored in directory: /root/.cache/pip/wheels/31/

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorboardX>=1.9
  Downloading tensorboardX-2.5.1-py2.py3-none-any.whl (125 kB)
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.5.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.9.0-py3-none-any.whl (462 kB)
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting huggingface-hub<1.0.0,>=0.2.0
  Downloading huggingface_hub-0.12.0-py3-none-any.whl (190 kB)


### Install Submodlib

In [None]:
%cd ..
!git clone https://github.com/decile-team/submodlib.git
%cd submodlib
!pip install .
%cd ../cords

/content
Cloning into 'submodlib'...
remote: Enumerating objects: 2563, done.
remote: Counting objects: 100% (4/4), done.
remote: Compressing objects: 100% (4/4), done.
remote: Total 2563 (delta 0), reused 0 (delta 0), pack-reused 2559
Receiving objects: 100% (2563/2563), 30.56 MiB | 22.10 MiB/s, done.
Resolving deltas: 100% (1909/1909), done.
/content/submodlib
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing /content/submodlib
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting numpy==1.20.1
  Downloading numpy-1.20.1-cp38-cp38-manylinux2010_x86_64.whl (15.4 MB)
Collecting sklearn
  Downloading sklearn-0.0.post1.tar.gz (3.6 kB)
  Preparing metadata (setup.py) 

### Import necessary libraries

In [None]:
import time
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
from cords.utils.data.datasets.SL import gen_dataset
from torch.utils.data import Subset
from cords.utils.config_utils import load_config_data
import os.path as osp
from cords.utils.data.data_utils import WeightedSubset
from ray import tune

### Loading the CIFAR10 dataset

Since CIFAR10 is a predefined dataset in the CORDS repository, you can load it with the `gen_dataset` function.

**Input parameters of gen_dataset function:**

***datadir :*** Directory containing the data. If the data is not already present, it is downloaded automatically into this directory.

***dset_name :*** Dataset name.

***feature :*** If "classimb", the dataset is made class-imbalanced.
          If "noise", the dataset labels are made noisy.
          If None, the standard dataset is returned.

***isnumpy :*** If True, return dataset in the numpy array format.
          If False, return dataset in torch dataset format.
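To make the `feature` options more concrete, the sketch below shows one common way to make labels noisy: replacing a fixed fraction of labels with a different, randomly chosen class. This is an illustration under stated assumptions, not CORDS's implementation. The helper `add_symmetric_label_noise` and its `noise_ratio` parameter are hypothetical, and the noise mechanism and ratio used by `gen_dataset` may differ.

```python
import numpy as np

def add_symmetric_label_noise(labels, num_classes, noise_ratio, seed=0):
    """Hypothetical helper: move `noise_ratio` of the labels to a different class."""
    rng = np.random.default_rng(seed)
    noisy = np.array(labels, copy=True)
    n_noisy = int(noise_ratio * len(noisy))
    # Choose which samples to corrupt, without replacement.
    idxs = rng.choice(len(noisy), size=n_noisy, replace=False)
    # A non-zero offset modulo num_classes guarantees the new label differs.
    offsets = rng.integers(1, num_classes, size=n_noisy)
    noisy[idxs] = (noisy[idxs] + offsets) % num_classes
    return noisy, idxs

clean = np.repeat(np.arange(10), 100)  # 1000 labels, 100 per class
noisy, flipped = add_symmetric_label_noise(clean, num_classes=10, noise_ratio=0.2)
print((noisy != clean).mean())  # 0.2
```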




In [None]:
trainset, validset, testset, num_cls = gen_dataset('data/', 'cifar10', None, isnumpy=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz



Extracting data/cifar-10-python.tar.gz to data/
Files already downloaded and verified


### Create Train, Validation and Test dataloaders

In [None]:
trn_batch_size = 128
val_batch_size = 128
tst_batch_size = 1000

# Creating the Data Loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=trn_batch_size,
                                          shuffle=False, pin_memory=True)

valloader = torch.utils.data.DataLoader(validset, batch_size=val_batch_size,
                                        shuffle=False, pin_memory=True)

testloader = torch.utils.data.DataLoader(testset, batch_size=tst_batch_size,
                                          shuffle=False, pin_memory=True)


### Defining Model

CORDS has a set of predefined models built into the `cords.utils.models` module. You can import them directly.

In [None]:
from cords.utils.models import ResNet18
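numclasses = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Device argument
model = ResNet18(numclasses)
# CIFAR-style stem for 32x32 inputs: 3x3 stride-1 first convolution and no max-pooling
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()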
model = model.to(device)

### Defining Loss Functions

In [None]:
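# Mean-reduced loss, used for evaluation
criterion = nn.CrossEntropyLoss()
# Per-sample losses, needed for weighted subset training and data selection
criterion_nored = nn.CrossEntropyLoss(reduction='none')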

### Checkpoint Utility functions

In [None]:
def save_ckpt(state, ckpt_path):
    torch.save(state, ckpt_path)


def load_ckpt(ckpt_path, model, optimizer):
    checkpoint = torch.load(ckpt_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # .get() avoids a KeyError for checkpoints saved without a 'loss' entry
    loss = checkpoint.get('loss')
    metrics = checkpoint['metrics']
    return start_epoch, model, optimizer, loss, metrics
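A save/load round trip on a tiny model is a quick way to check the checkpoint helpers. The sketch repeats the two helpers so it runs on its own and keeps the same dictionary keys (`epoch`, `state_dict`, `optimizer`, `loss`, `metrics`). Reading `loss` with `.get()`, the small `nn.Linear` model, and the temporary file path are assumptions made for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

def save_ckpt(state, ckpt_path):
    torch.save(state, ckpt_path)

def load_ckpt(ckpt_path, model, optimizer):
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # .get() tolerates checkpoints written without a 'loss' entry
    return (checkpoint['epoch'], model, optimizer,
            checkpoint.get('loss'), checkpoint['metrics'])

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One optimization step so the optimizer has momentum buffers worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), 'model.pth')
save_ckpt({'epoch': 20, 'state_dict': model.state_dict(),
           'optimizer': optimizer.state_dict(),
           'loss': loss.item(), 'metrics': {'tst_acc': [0.1, 0.5]}}, path)

# Restore into freshly initialized objects.
fresh_model = nn.Linear(4, 2)
fresh_opt = optim.SGD(fresh_model.parameters(), lr=0.1, momentum=0.9)
start_epoch, fresh_model, fresh_opt, saved_loss, metrics = load_ckpt(path, fresh_model, fresh_opt)
print(start_epoch, torch.equal(fresh_model.weight, model.weight))  # 20 True
```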


### Cumulative time calculation

In [None]:
def generate_cumulative_timing(mod_timing):
    tmp = 0
    mod_cum_timing = np.zeros(len(mod_timing))
    for i in range(len(mod_timing)):
        tmp += mod_timing[i]
        mod_cum_timing[i] = tmp
    return mod_cum_timing
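The loop in `generate_cumulative_timing` computes a running sum, which NumPy provides directly as `np.cumsum`. A drop-in sketch, assuming the per-epoch timings are plain numbers (the `timing` values here are made up):

```python
import numpy as np

def generate_cumulative_timing(mod_timing):
    # Running sum of per-epoch times, equivalent to the explicit loop above.
    return np.cumsum(np.asarray(mod_timing, dtype=float))

timing = [0, 5.2, 5.1, 5.3]
print(generate_cumulative_timing(timing))
```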


### Defining Optimizers and schedulers

In [None]:
optimizer = optim.SGD(model.parameters(), lr=5e-2,
                      momentum=0.9,
                      weight_decay=5e-4,
                      nesterov=True)

#T_max is the maximum number of scheduler steps. Here we are using the number of epochs as the maximum number of scheduler steps.

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=200) 
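With `eta_min` left at its default of 0, `CosineAnnealingLR` follows the closed form `lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2`, so with `T_max=200` the learning rate falls from 0.05 to half that value at epoch 100. The sketch below checks this against the scheduler, using a single dummy parameter instead of the ResNet; `cosine_lr` is a hypothetical helper.

```python
import math

import torch

# A single dummy parameter is enough to drive the optimizer/scheduler pair.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=5e-2, momentum=0.9,
                            weight_decay=5e-4, nesterov=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

def cosine_lr(t, base_lr=5e-2, t_max=200, eta_min=0.0):
    # Closed form of cosine annealing without warm restarts.
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

lrs = []
for epoch in range(200):
    lrs.append(optimizer.param_groups[0]['lr'])  # lr used during this epoch
    optimizer.step()   # no gradients here; called first to keep the expected step order
    scheduler.step()

print(lrs[0], lrs[100])  # approximately 0.05 and 0.025
```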


### Get logger object for logging

In [None]:
import logging
import os
import sys


def __get_logger(results_dir):
    os.makedirs(results_dir, exist_ok=True)
    # Set up a logger that writes to both stdout and results_dir/results.log
    plain_formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                                        datefmt="%m/%d %H:%M:%S")
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    s_handler = logging.StreamHandler(stream=sys.stdout)
    s_handler.setFormatter(plain_formatter)
    s_handler.setLevel(logging.INFO)
    logger.addHandler(s_handler)
    f_handler = logging.FileHandler(os.path.join(results_dir, "results.log"))
    f_handler.setFormatter(plain_formatter)
    f_handler.setLevel(logging.INFO)
    logger.addHandler(f_handler)
    logger.propagate = False
    return logger
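Each call to `__get_logger` attaches new handlers to the same named logger, so re-running the cell in a notebook makes every message appear more than once. The sketch below is a hypothetical variant, `get_logger`, that returns early when the logger already has handlers; the logger name `cords_tutorial` and the temporary results directory are assumptions for illustration.

```python
import logging
import os
import sys
import tempfile

def get_logger(results_dir, name="cords_tutorial"):
    """Hypothetical variant of __get_logger that is safe to call repeatedly."""
    os.makedirs(results_dir, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    if logger.handlers:  # already configured by an earlier call
        return logger
    fmt = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                            datefmt="%m/%d %H:%M:%S")
    for handler in (logging.StreamHandler(stream=sys.stdout),
                    logging.FileHandler(os.path.join(results_dir, "results.log"))):
        handler.setFormatter(fmt)
        handler.setLevel(logging.INFO)
        logger.addHandler(handler)
    return logger

results_dir = tempfile.mkdtemp()
logger = get_logger(results_dir)
logger = get_logger(results_dir)  # second call adds no new handlers
logger.info("hello")              # emitted once per handler, not twice
```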



### Instantiating logger file for logging the information

In [None]:
import logging
import os
import os.path as osp
import sys

#Results logging directory
results_dir = osp.abspath(osp.expanduser('results'))
logger = __get_logger(results_dir)

In [None]:
logger.info("hello")

[02/07 04:35:50] __main__ INFO: hello


### Instantiating subset selection dataloaders
We instantiate subset dataloaders that can be used to train models on adaptively selected subsets.

Each subset dataloader takes the data selection strategy arguments as a DotMap dictionary, a logger, and dataloader-specific arguments such as batch size and shuffle.

Here we instantiate the CRAIG dataloader without warm start, but any dataloader can be instantiated in the same way by passing the required arguments.

**Data loader instantiation :**
```python
dataloader = <SubsetDataLoader>(trainloader, valloader, cfg.dss_args, logger, batch_size=cfg.dataloader.batch_size, shuffle=cfg.dataloader.shuffle, pin_memory=cfg.dataloader.pin_memory, collate_fn = cfg.dss_args.collate_fn)
```

Implemented SL strategies:


1.   GLISTER
2.   GradMatch
3.   CRAIG
4.   Random
5.   Submodular function based selection strategies    
    *   Facility Location
    *   GraphCut
    *   Sum Redundancy
    *   Saturated Coverage

In [None]:
from cords.utils.data.dataloader.SL.adaptive import GLISTERDataLoader, AdaptiveRandomDataLoader, \
    CRAIGDataLoader, GradMatchDataLoader, RandomDataLoader, MILODataLoader, StochasticGreedyDataLoader, \
    WeightedRandomDataLoader
from dotmap import DotMap

selection_strategy = 'CRAIG'
dss_args = dict(model=model,
                loss=criterion_nored,
                eta=0.01,
                num_classes=10,
                num_epochs=100,
                device=device,
                type="CRAIG",
                fraction=0.1,
                select_every=20,
                lam=0.5,
                selection_type='PerClass',
                v1=True,
                valid=False,
                kappa=0,
                eps=1e-100,
                linear_layer=True,
                optimizer='lazy',
                if_convex=False)
dss_args = DotMap(dss_args)

dataloader = CRAIGDataLoader(trainloader, valloader, dss_args, logger, 
                                  batch_size=20, 
                                  shuffle=True,
                                  pin_memory=False)



[02/07 04:35:52] __main__ INFO: CRAIG dataloader initialized. 
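The `fraction`, `select_every`, and `batch_size` arguments together determine how much work each epoch does. The back-of-the-envelope sketch below assumes the subset size is `fraction` times the training-set size rounded down, and that a new subset is selected whenever `epoch % select_every == 0` (consistent with the "Epoch: 20, requires subset selection" line in the training log). `subset_schedule` is a hypothetical helper, and `n_train=45000` is only an illustrative value; substitute `len(trainset)`.

```python
import math

def subset_schedule(n_train, fraction, select_every, batch_size, num_epochs):
    """Rough cost model for an adaptive subset dataloader (assumptions in the text)."""
    subset_size = int(fraction * n_train)
    batches_per_epoch = math.ceil(subset_size / batch_size)
    selection_epochs = [e for e in range(num_epochs) if e % select_every == 0]
    return subset_size, batches_per_epoch, selection_epochs

size, batches, rounds = subset_schedule(n_train=45000, fraction=0.1,
                                        select_every=20, batch_size=20,
                                        num_epochs=200)
print(size, batches, len(rounds))  # 4500 225 10
```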


### Additional arguments for training, evaluation and checkpointing

In [None]:
#Training Arguments
num_epochs = 200

#Arguments for results logging
print_every = 10
print_args = ["val_loss", "val_acc", "tst_loss", "tst_acc", "trn_loss", "trn_acc", "time"]

#Arguments for checkpointing
save_every = 20
is_save = True

#Evaluation Metrics
trn_losses = list()
val_losses = list()
tst_losses = list()
subtrn_losses = list()
timing = [0]
trn_acc = list()
best_acc = list()
curr_best_acc = 0
val_acc = list()  
tst_acc = list()  
subtrn_acc = list()


### Evaluation Function

In [None]:
def evaluate_model(curr_best_acc):
    """
    ################################################# Evaluation Loop #################################################
    """
    trn_loss = 0
    trn_correct = 0
    trn_total = 0
    val_loss = 0
    val_correct = 0
    val_total = 0
    tst_correct = 0
    tst_total = 0
    tst_loss = 0
    model.eval()
    logger_dict = {}
    if ("trn_loss" in print_args) or ("trn_acc" in print_args):
        samples=0
    
        with torch.no_grad():
            for _, data in enumerate(trainloader):
                inputs, targets = data

                inputs, targets = inputs.to(device), \
                                  targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # Weight by the actual batch size; the last batch may be smaller
                trn_loss += loss.item() * targets.size(0)
                samples += targets.shape[0]
                if "trn_acc" in print_args:
                    _, predicted = outputs.max(1)
                    trn_total += targets.size(0)
                    trn_correct += predicted.eq(targets).sum().item()
            trn_loss = trn_loss/samples
            trn_losses.append(trn_loss)
            logger_dict['trn_loss'] = trn_loss
        if "trn_acc" in print_args:
            trn_acc.append(trn_correct / trn_total)
            logger_dict['trn_acc'] = trn_correct / trn_total

    if ("val_loss" in print_args) or ("val_acc" in print_args):
        samples =0
        with torch.no_grad():
            for _, data in enumerate(valloader):
                inputs, targets = data
                inputs, targets = inputs.to(device), \
                                  targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # Weight by the actual batch size; the last batch may be smaller
                val_loss += loss.item() * targets.size(0)
                samples += targets.shape[0]
                if "val_acc" in print_args:
                    
                    _, predicted = outputs.max(1)
                    val_total += targets.size(0)
                    val_correct += predicted.eq(targets).sum().item()
            val_loss = val_loss/samples
            val_losses.append(val_loss)
            logger_dict['val_loss'] = val_loss

        if "val_acc" in print_args:
            val_acc.append(val_correct / val_total)
            logger_dict['val_acc'] = val_correct / val_total

    if ("tst_loss" in print_args) or ("tst_acc" in print_args):
        samples =0
        with torch.no_grad():
            for _, data in enumerate(testloader):
                inputs, targets = data

                inputs, targets = inputs.to(device), \
                                  targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # Weight by the actual batch size; the last batch may be smaller
                tst_loss += loss.item() * targets.size(0)
                samples += targets.shape[0]
                if "tst_acc" in print_args:
                    _, predicted = outputs.max(1)
                    tst_total += targets.size(0)
                    tst_correct += predicted.eq(targets).sum().item()
            tst_loss = tst_loss/samples
            tst_losses.append(tst_loss)
            logger_dict['tst_loss'] = tst_loss

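        if "tst_acc" in print_args:
            curr_tst_acc = tst_correct / tst_total
            # curr_best_acc is passed in by value, so carry the running best
            # forward from the best_acc history before comparing.
            if best_acc:
                curr_best_acc = max(curr_best_acc, best_acc[-1])
            curr_best_acc = max(curr_best_acc, curr_tst_acc)
            tst_acc.append(curr_tst_acc)
            best_acc.append(curr_best_acc)
            logger_dict['tst_acc'] = curr_tst_acc
            logger_dict['best_acc'] = curr_best_acc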

    if "subtrn_acc" in print_args:
        if epoch == 0:
            subtrn_acc.append(0)
            logger_dict['subtrn_acc'] = 0
        else:    
            subtrn_acc.append(subtrn_correct / subtrn_total)
            logger_dict['subtrn_acc'] = subtrn_correct / subtrn_total

    if "subtrn_loss" in print_args:
        if epoch == 0:
            subtrn_losses.append(0)
            logger_dict['subtrn_loss'] = 0
        else: 
            subtrn_losses.append(subtrn_loss)
            logger_dict['subtrn_loss'] = subtrn_loss

    print_str = "Epoch: " + str(epoch)
    logger_dict['Epoch'] = epoch
    logger_dict['Time'] = train_time

    """
    ################################################# Results Printing #################################################
    """

    for arg in print_args:
        if arg == "val_loss":
            print_str += " , " + "Validation Loss: " + str(val_losses[-1])

        if arg == "val_acc":
            print_str += " , " + "Validation Accuracy: " + str(val_acc[-1])

        if arg == "tst_loss":
            print_str += " , " + "Test Loss: " + str(tst_losses[-1])

        if arg == "tst_acc":
            print_str += " , " + "Test Accuracy: " + str(tst_acc[-1])
            print_str += " , " + "Best Accuracy: " + str(best_acc[-1])

        if arg == "trn_loss":
            print_str += " , " + "Training Loss: " + str(trn_losses[-1])

        if arg == "trn_acc":
            print_str += " , " + "Training Accuracy: " + str(trn_acc[-1])

        if arg == "subtrn_loss":
            print_str += " , " + "Subset Loss: " + str(subtrn_losses[-1])

        if arg == "subtrn_acc":
            print_str += " , " + "Subset Accuracy: " + str(subtrn_acc[-1])

        if arg == "time":
            print_str += " , " + "Timing: " + str(timing[-1])

    logger.info(print_str)
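When per-batch mean losses are accumulated into a dataset-level average, each batch should be weighted by the number of samples it actually contains (`targets.size(0)`) rather than the loader's nominal `batch_size`, because the final batch is usually smaller. A small numeric sketch with made-up loss values; `dataset_mean_loss` is a hypothetical helper:

```python
def dataset_mean_loss(batch_mean_losses, batch_sizes):
    """Turn per-batch mean losses into a per-sample mean over the whole dataset."""
    total = sum(loss * n for loss, n in zip(batch_mean_losses, batch_sizes))
    return total / sum(batch_sizes)

# 1,000 samples in batches of 128: seven full batches and one batch of 104.
sizes = [128] * 7 + [104]
losses = [1.0] * 7 + [2.0]
correct = dataset_mean_loss(losses, sizes)
# Weighting every batch by the nominal batch_size over-counts the short last batch.
inflated = sum(loss * 128 for loss in losses) / sum(sizes)
print(round(correct, 4), round(inflated, 4))  # 1.104 1.152
```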


### Custom Training loop with evaluation

The subset dataloader returns data samples, their labels, and a weight associated with each sample. To incorporate these weights into the training loop, we use a **loss function** with **reduction='none'** to get per-sample loss values, and then compute the weighted average of the batch losses with the following code snippet:

`loss = torch.dot(losses, weights/(weights.sum()))`
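The snippet below is a self-contained check of this weighting on random logits; the batch, targets, and weights are made up for illustration. It shows that `torch.dot(losses, weights / weights.sum())` is exactly a weighted mean of the per-sample losses, and that with uniform weights it reduces to the default `reduction='mean'` loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
outputs = torch.randn(4, 10)                   # logits: batch of 4 samples, 10 classes
targets = torch.tensor([3, 1, 0, 7])
weights = torch.tensor([2.0, 1.0, 1.0, 0.5])   # illustrative per-sample subset weights

criterion_nored = nn.CrossEntropyLoss(reduction='none')
losses = criterion_nored(outputs, targets)           # shape (4,): one loss per sample
loss = torch.dot(losses, weights / weights.sum())    # weighted average of the batch

# The same value written as an explicit weighted mean.
manual = (losses * weights).sum() / weights.sum()
# With uniform weights the weighted loss equals the usual mean reduction.
uniform = torch.dot(losses, torch.ones(4) / 4)
print(torch.allclose(loss, manual),
      torch.allclose(uniform, nn.CrossEntropyLoss()(outputs, targets)))  # True True
```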

---
***NOTE***

### If you want to implement a custom training loop, note that the subset dataloaders also return an additional weight for each data sample.
---

In [None]:
"""
################################################# Training Loop #################################################
"""
train_time = 0
for epoch in range(0, num_epochs+1):
        
    # Evaluating the Model at Regular Intervals
    if (epoch % print_every == 0) or (epoch == num_epochs) or (epoch == 0):
        evaluate_model(curr_best_acc)
        
    subtrn_loss = 0
    subtrn_correct = 0
    subtrn_total = 0
    model.train()
    start_time = time.time()

    """
    ################################################# Mini-batch SGD #################################################
    """
    for _, (inputs, targets, weights) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        losses = criterion_nored(outputs, targets)
        loss = torch.dot(losses, weights / (weights.sum()))
        loss.backward()
        subtrn_loss += loss.item()
        optimizer.step()
        _, predicted = outputs.max(1)
        subtrn_total += targets.size(0)
        subtrn_correct += predicted.eq(targets).sum().item()
    epoch_time = time.time() - start_time
    scheduler.step()
    timing.append(epoch_time)
    train_time += epoch_time
    

    """
    ################################################# Checkpoint Saving #################################################
    """

    if ((epoch + 1) % save_every == 0):

        metric_dict = {}

        for arg in print_args:
            if arg == "val_loss":
                metric_dict['val_loss'] = val_losses
            if arg == "val_acc":
                metric_dict['val_acc'] = val_acc
            if arg == "tst_loss":
                metric_dict['tst_loss'] = tst_losses
            if arg == "tst_acc":
                metric_dict['tst_acc'] = tst_acc
                metric_dict['best_acc'] = best_acc
            if arg == "trn_loss":
                metric_dict['trn_loss'] = trn_losses
            if arg == "trn_acc":
                metric_dict['trn_acc'] = trn_acc
            if arg == "subtrn_loss":
                metric_dict['subtrn_loss'] = subtrn_losses
            if arg == "subtrn_acc":
                metric_dict['subtrn_acc'] = subtrn_acc
            if arg == "time":
                metric_dict['time'] = timing

        ckpt_state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': subtrn_loss,
            'metrics': metric_dict
        }

        # save checkpoint
        save_ckpt(ckpt_state, 'model.pth')
        logger.info("Model checkpoint saved at epoch: {0:d}".format(epoch + 1))

[02/07 04:36:23] __main__ INFO: Epoch: 0 , Validation Loss: 3.669420697021484 , Validation Accuracy: 0.1002 , Test Loss: 3.4174654483795166 , Test Accuracy: 0.1 , Best Accuracy: 0.1 , Training Loss: 3.5758021952311196 , Training Accuracy: 0.09997777777777778 , Timing: 0
[02/07 04:37:39] __main__ INFO: Epoch: 10 , Validation Loss: 1.5599272918701172 , Validation Accuracy: 0.4554 , Test Loss: 1.4468954086303711 , Test Accuracy: 0.4723 , Best Accuracy: 0.4723 , Training Loss: 1.5095771036783854 , Training Accuracy: 0.4458888888888889 , Timing: 5.218559503555298
[02/07 04:38:31] __main__ INFO: Model checkpoint saved at epoch: 20
[02/07 04:38:54] __main__ INFO: Epoch: 20 , Validation Loss: 1.2052583801269532 , Validation Accuracy: 0.5768 , Test Loss: 1.2199726343154906 , Test Accuracy: 0.5796 , Best Accuracy: 0.5796 , Training Loss: 1.1699570122612848 , Training Accuracy: 0.5855777777777778 , Timing: 5.142061948776245
[02/07 04:38:54] __main__ INFO: Epoch: 20, requires subset selection. 

### Results Summary Logging

In [None]:
"""
################################################# Results Summary #################################################
"""
original_idxs = set(range(len(trainset)))
encountered_idxs = []
for key in dataloader.selected_idxs.keys():
    encountered_idxs.extend(dataloader.selected_idxs[key])
encountered_idxs = set(encountered_idxs)
rem_idxs = original_idxs.difference(encountered_idxs)
encountered_percentage = len(encountered_idxs)/len(original_idxs)

logger.info("Selected Indices: ") 
logger.info(dataloader.selected_idxs)
logger.info("Percentages of data samples encountered during training: %.2f", encountered_percentage)
logger.info("Not Selected Indices: ")
logger.info(rem_idxs)                
logger.info("CRAIG Selection Run---------------------------------")
logger.info("Final subset training loss: {0:f}".format(subtrn_loss))
if "val_loss" in print_args:
    if "val_acc" in print_args:
        logger.info("Validation Loss: %.2f , Validation Accuracy: %.2f", val_losses[-1], val_acc[-1])
    else:
        logger.info("Validation Loss: %.2f", val_losses[-1])

if "tst_loss" in print_args:
    if "tst_acc" in print_args:
        logger.info("Test Loss: %.2f, Test Accuracy: %.2f, Best Accuracy: %.2f", tst_losses[-1], tst_acc[-1], best_acc[-1])
    else:
        logger.info("Test Data Loss: %f", tst_losses[-1])
logger.info('---------------------------------------------------------------------')
logger.info("CRAIG")
logger.info('---------------------------------------------------------------------')

"""
################################################# Final Results Logging #################################################
"""

if "val_acc" in print_args:
    logger.info("Validation Accuracy: " + " , ".join(str(val) for val in val_acc))

if "tst_acc" in print_args:
    logger.info("Test Accuracy: " + " , ".join(str(tst) for tst in tst_acc))
    logger.info("Best Accuracy: " + " , ".join(str(tst) for tst in best_acc))

if "time" in print_args:
    logger.info("Time: " + " , ".join(str(t) for t in timing))

omp_timing = np.array(timing)
omp_cum_timing = list(generate_cumulative_timing(omp_timing))
logger.info("Total time taken by %s = %.4f ", "CRAIG", omp_cum_timing[-1])

[02/07 05:13:53] __main__ INFO: Selected Indices: 
[02/07 05:13:53] __main__ INFO: {0: array([24930, 35950, 37158, ..., 35421,  3668, 39294]), 1: [36762, 16755, 16243, 13202, 1808, 35056, 9903, 11444, 3943, 24863, 6054, 41488, 27226, 30547, 9389, 32682, 29234, 6036, 20695, 18791, 37811, 892, 20619, 18154, 8788, 26259, 11842, 5157, 29461, 25312, 32581, 39990, 29135, 4486, 6362, 19546, 28309, 41709, 19501, 5953, 12792, 31602, 3887, 26392, 19646, 37994, 30875, 26317, 28418, 41602, 19879, 28491, 31735, 25841, 40596, 39887, 16324, 33253, 3647, 35485, 2232, 36685, 41905, 35552, 11773, 15304, 15277, 16252, 29002, 9765, 12300, 25213, 22075, 26035, 10603, 17559, 13299, 43276, 9526, 2471, 4183, 28433, 5895, 32529, 1942, 10147, 3652, 11988, 16766, 42648, 13005, 35721, 16103, 37228, 5852, 31320, 44682, 21247, 171, 14319, 29942, 24070, 26381, 14493, 18798, 5896, 14760, 9411, 40238, 31496, 30081, 31625, 12663, 34422, 14422, 20514, 36343, 30756, 34165, 12625, 6592, 41312, 40439, 22466, 21752, 22749, 
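The coverage bookkeeping above reduces to a set union over all selection rounds. A minimal sketch, assuming `dataloader.selected_idxs` maps each selection round to an array or list of training indices as in the logged output; `coverage` is a hypothetical helper and the toy dictionary stands in for the real selections:

```python
import numpy as np

def coverage(selected_idxs, n_train):
    """Fraction of training samples that appeared in at least one selected subset."""
    seen = set()
    for idxs in selected_idxs.values():
        seen.update(int(i) for i in idxs)  # handles NumPy arrays and plain lists
    not_selected = set(range(n_train)) - seen
    return len(seen) / n_train, not_selected

# Toy stand-in for dataloader.selected_idxs: selection round -> indices.
toy = {0: np.array([0, 1, 2]), 1: [2, 3, 4]}
frac, missing = coverage(toy, n_train=10)
print(frac, sorted(missing))  # 0.5 [5, 6, 7, 8, 9]
```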

# GLISTER run using default SL training loop directly

We have incorporated the above training loop into the train_sl.py file of CORDS. It can be used by importing the TrainClassifier class from the train_sl module as follows:

```python
from train_sl import TrainClassifier
```

Import the default supervised learning training loop:

In [None]:
from train_sl import TrainClassifier

### Loading default GLISTER config file for CIFAR10 dataset

Other subset selection strategies for CIFAR10, such as CRAIG, GradMatch, and Random, can be used by loading their respective config files.

Here we give an example of instantiating an SL training loop using the GLISTER config file.

In [None]:
fraction = 0.1
glister_config_file = '/content/cords/configs/SL/config_glister_cifar10.py'

from cords.utils.config_utils import load_config_data

cfg = load_config_data(glister_config_file)
glister_trn = TrainClassifier(cfg)

### Default config args can be modified in the following manner

We can modify the default arguments of the config file by simply assigning new values to them.

In [None]:
glister_trn.cfg.scheduler.T_max = 200

glister_trn.cfg.dss_args.fraction = fraction
glister_trn.cfg.dss_args.select_every = 20

glister_trn.cfg.train_args.device = 'cuda'
glister_trn.cfg.train_args.print_every = 10
glister_trn.cfg.train_args.num_epochs = 200

### Start the training process

In [None]:
glister_trn.train()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz



Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified
[02/07 05:14:24] train_sl  23/02/07 05:13:54 INFO: Epoch: 0 , Training Loss: 0.14156688164605036 , Training Accuracy: 0.0982 , Validation Loss: 0.14074729614257814 , Validation Accuracy: 0.101 , Test Loss: 0.1383069610595703 , Test Accuracy: 0.1011 , Best Accuracy: 0.1011 , Timing: 0
[02/07 05:15:12] train_sl  23/02/07 05:13:54 INFO: Epoch: 10 , Training Loss: 0.06779271850585937 , Training Accuracy: 0.5241777777777777 , Validation Loss: 0.0699295654296875 , Validation Accuracy: 0.5044 , Test Loss: 0.06966692962646484 , Test Accuracy: 0.5144 , Best Accuracy: 0.5144 , Timing: 25.90374517440796
[02/07 05:15:38] train_sl  23/02/07 05:13:54 INFO: Model checkpoint saved at epoch: 20
[02/07 05:16:01] train_sl  23/02/07 05:13:54 INFO: Epoch: 20 , Training Loss: 0.06793738844129775 , Training Accuracy: 0.5711111111111111 , Validation Loss: 0.07054898376464844 , Validation Accuracy: 0.5506 , Test Loss: 0.

([0.0982,
  0.5241777777777777,
  0.5711111111111111,
  0.6098444444444444,
  0.6232888888888889,
  0.7459777777777777,
  0.6512,
  0.7579333333333333,
  0.7676666666666667,
  0.7859111111111111,
  0.8220888888888889,
  0.8385555555555556,
  0.8183333333333334,
  0.8780888888888889,
  0.8739555555555556,
  0.8959111111111111,
  0.8984,
  0.949,
  0.9471555555555555,
  0.9693333333333334,
  0.9678],
 [0.101,
  0.5044,
  0.5506,
  0.6022,
  0.6172,
  0.7166,
  0.6168,
  0.7308,
  0.7348,
  0.758,
  0.785,
  0.7956,
  0.7772,
  0.8226,
  0.8184,
  0.8288,
  0.8362,
  0.8838,
  0.871,
  0.8918,
  0.8824],
 [0.1011,
  0.5144,
  0.5861,
  0.6058,
  0.6177,
  0.7355,
  0.6302,
  0.7289,
  0.7464,
  0.7602,
  0.7873,
  0.7966,
  0.7769,
  0.8322,
  0.8311,
  0.8332,
  0.8356,
  0.8833,
  0.8754,
  0.891,
  0.8884],
 [0.1011,
  0.5144,
  0.5861,
  0.6058,
  0.6177,
  0.7355,
  0.7355,
  0.7355,
  0.7464,
  0.7602,
  0.7873,
  0.7966,
  0.7966,
  0.8322,
  0.8322,
  0.8332,
  0.8356,
  0.8833,
 