# Subset selection for supervised learning

In this tutorial, we look at an example that shows how to integrate various subset-selection-based dataloaders with a typical supervised learning training loop.

# Install Submodlib

### Import necessary libraries

In [1]:
import sys
sys.path.append('../')
import time
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
from cords.utils.data.datasets.SL import gen_dataset
from torch.utils.data import Subset
from cords.utils.config_utils import load_config_data
import os.path as osp
from cords.utils.data.data_utils import WeightedSubset
from ray import tune



### Loading the CIFAR10 dataset

Since CIFAR10 is a predefined dataset in the CORDS repository, you can use the gen_dataset function to load it.

**Input parameters of gen_dataset function:**

***datadir :*** Directory containing the data. If data is not downloaded, then data will be automatically downloaded into the mentioned directory path.

***dset_name :*** Dataset Name

***feature :*** If "classimb", we make the dataset inherently imbalanced.
          If "noise", we make the dataset labels noisy.
          If None, we return the standard datasets.

***isnumpy :*** If True, return dataset in the numpy array format.
          If False, return dataset in torch dataset format.




In [3]:
trainset, validset, testset, num_cls = gen_dataset('../data/', 'cifar10', None, isnumpy=False)


Files already downloaded and verified
Files already downloaded and verified


### Create Train, Validation and Test dataloaders

In [4]:
trn_batch_size = 128
val_batch_size = 128
tst_batch_size = 1000

# Creating the Data Loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=trn_batch_size,
                                          shuffle=False, pin_memory=True)

valloader = torch.utils.data.DataLoader(validset, batch_size=val_batch_size,
                                        shuffle=False, pin_memory=True)

testloader = torch.utils.data.DataLoader(testset, batch_size=tst_batch_size,
                                          shuffle=False, pin_memory=True)


### Defining Model

CORDS has a set of predefined models built into its utils folder. You can import them directly.

In [7]:
from cords.utils.models import ResNet18
numclasses = 10
device = 'cuda' #Device Argument
model = ResNet18(numclasses)
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()
model = model.to(device)

In [8]:
print(model.get_embedding_dim())

512


### Defining Loss Functions

In [6]:
criterion = nn.CrossEntropyLoss()
criterion_nored = nn.CrossEntropyLoss(reduction='none')

### Checkpoint Utility functions

In [7]:
def save_ckpt(state, ckpt_path):
    torch.save(state, ckpt_path)


def load_ckpt(ckpt_path, model, optimizer):
    checkpoint = torch.load(ckpt_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # 'loss' is optional: the checkpoint written by the training loop in this tutorial does not store it
    loss = checkpoint.get('loss')
    metrics = checkpoint['metrics']
    return start_epoch, model, optimizer, loss, metrics


### Cumulative time calculation

In [8]:
def generate_cumulative_timing(mod_timing):
    tmp = 0
    mod_cum_timing = np.zeros(len(mod_timing))
    for i in range(len(mod_timing)):
        tmp += mod_timing[i]
        mod_cum_timing[i] = tmp
    return mod_cum_timing


### Defining Optimizers and schedulers

In [9]:
optimizer = optim.SGD(model.parameters(), lr=5e-2,
                                  momentum=0.9,
                                  weight_decay=5e-4,
                                  nesterov=True)

#T_max is the maximum number of scheduler steps. Here we are using the number of epochs as the maximum number of scheduler steps.

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=200) 
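
To make the role of `T_max` concrete, the sketch below evaluates the closed-form cosine annealing expression given in the PyTorch documentation, using the base learning rate (5e-2) and `T_max` (200) from the cell above. The helper name `cosine_annealing_lr` is hypothetical, and `eta_min=0.0` is assumed because the scheduler above does not set it.

```python
import math

def cosine_annealing_lr(step, base_lr=5e-2, t_max=200, eta_min=0.0):
    """Closed-form cosine annealing value after `step` scheduler steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

# Learning rate at a few points of the schedule: it starts at base_lr,
# reaches half of it midway, and decays to eta_min at step == t_max.
for step in (0, 50, 100, 150, 200):
    print(step, round(cosine_annealing_lr(step), 5))
```

With one scheduler step per epoch, the learning rate reaches its minimum exactly when the epoch count equals `T_max`, which is why the two are set to the same value here.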


### Get logger object for logging

In [10]:
def __get_logger(results_dir):
  os.makedirs(results_dir, exist_ok=True)
  # setup logger
  plain_formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                                      datefmt="%m/%d %H:%M:%S")
  logger = logging.getLogger(__name__)
  logger.setLevel(logging.INFO)
  s_handler = logging.StreamHandler(stream=sys.stdout)
  s_handler.setFormatter(plain_formatter)
  s_handler.setLevel(logging.INFO)
  logger.addHandler(s_handler)
  f_handler = logging.FileHandler(os.path.join(results_dir, "results.log"))
  f_handler.setFormatter(plain_formatter)
  f_handler.setLevel(logging.INFO)
  logger.addHandler(f_handler)
  logger.propagate = False
  return logger



### Instantiating logger file for logging the information

In [None]:
import logging
import os
import os.path as osp
import sys

#Results logging directory
results_dir = osp.abspath(osp.expanduser('results'))
logger = __get_logger(results_dir)
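
One practical point when working in a notebook: `logging.getLogger(name)` returns the same logger object for the same name, so re-running a cell that adds handlers attaches a second pair and every message is then emitted more than once. The sketch below shows one way to guard against this. The function name `get_logger_once` and the logger name `cords_tutorial` are hypothetical; the formatter and handler setup mirror the function above.

```python
import logging
import os
import sys
import tempfile

def get_logger_once(results_dir, name="cords_tutorial"):
    """Hypothetical variant of the logger factory that attaches handlers only once."""
    os.makedirs(results_dir, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    if not logger.handlers:  # skip if a previous call already configured this logger
        formatter = logging.Formatter(
            "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
            datefmt="%m/%d %H:%M:%S")
        stream_handler = logging.StreamHandler(stream=sys.stdout)
        file_handler = logging.FileHandler(os.path.join(results_dir, "results.log"))
        for handler in (stream_handler, file_handler):
            handler.setFormatter(formatter)
            handler.setLevel(logging.INFO)
            logger.addHandler(handler)
    return logger

demo_dir = tempfile.mkdtemp()            # throwaway directory for the demo
first = get_logger_once(demo_dir)
second = get_logger_once(demo_dir)       # simulates re-running the notebook cell
second.info("logged once")
```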

In [None]:
logger.info("hello")

### Instantiating CRAIG subset selection dataloaders
We instantiate subset dataloaders that can be used for training the models with adaptive subsets.

Each subset dataloader needs data selection strategy arguments in the form of a DotMap dictionary, a logger, and dataloader-specific arguments such as batch size and shuffle.

We instantiate the CRAIG dataloader here with no warm start, but any other dataloader can be instantiated in the same way by passing the required arguments.

**Data loader instantiation :**
```python
dataloader = <SubsetDataLoader>(trainloader, valloader, cfg.dss_args, logger, batch_size=cfg.dataloader.batch_size, shuffle=cfg.dataloader.shuffle, pin_memory=cfg.dataloader.pin_memory, collate_fn = cfg.dss_args.collate_fn)
```

Implemented SL strategies:


1.   GLISTER
2.   GradMatch
3.   CRAIG
4.   Random
5.   Submodular function based selection strategies    
    *   Facility Location
    *   GraphCut
    *   Sum Redundancy
    *   Saturated Coverage

In [None]:
from cords.utils.data.dataloader.SL.adaptive import GLISTERDataLoader, AdaptiveRandomDataLoader, \
    CRAIGDataLoader, GradMatchDataLoader, RandomDataLoader, MILODataLoader, StochasticGreedyDataLoader, \
    WeightedRandomDataLoader
from dotmap import DotMap

selection_strategy = 'CRAIG'
dss_args = dict(model=model,
                loss=criterion_nored,
                eta=0.01,
                num_classes=10,
                num_epochs=100,
                device='cuda',
                type="CRAIG",
                fraction=0.1,
                select_every=20,
                lam=0.5,
                selection_type='PerClass',
                v1=True,
                valid=False,
                kappa=0,
                eps=1e-100,
                linear_layer=True,
                optimizer='lazy',
                if_convex=False)
dss_args = DotMap(dss_args)

dataloader = CRAIGDataLoader(trainloader, valloader, dss_args, logger, 
                                  batch_size=20, 
                                  shuffle=True,
                                  pin_memory=False)
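
To get a feel for what `fraction` and `select_every` imply, the following sketch works through the arithmetic. It is a simplification rather than the CORDS implementation: it assumes the subset size is `int(fraction * num_train)` and that a new subset is chosen at epochs 0, `select_every`, 2 × `select_every`, and so on. The function name `selection_schedule` and the training-set size of 45,000 are hypothetical values chosen for illustration.

```python
def selection_schedule(num_train, fraction, select_every, num_epochs):
    """Return (subset size, epochs at which a new subset is selected).

    Assumption: the budget is int(fraction * num_train) and selection is
    triggered at epochs 0, select_every, 2 * select_every, ...
    """
    budget = int(fraction * num_train)
    selection_epochs = list(range(0, num_epochs, select_every))
    return budget, selection_epochs

# Hypothetical training-set size; fraction and select_every match the cell above.
budget, selection_epochs = selection_schedule(
    num_train=45000, fraction=0.1, select_every=20, num_epochs=200)
print(budget, len(selection_epochs))
```

Under these assumptions each epoch touches only a tenth of the data, and the comparatively expensive selection step runs once every 20 epochs rather than every epoch.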



### Additional arguments for training, evaluation and checkpointing

In [None]:
#Training Arguments
num_epochs = 200

#Arguments for results logging
print_every = 10
print_args = ["val_loss", "val_acc", "tst_loss", "tst_acc", "trn_loss", "trn_acc", "time"]

#Arguments for checkpointing
save_every = 20
is_save = True

#Evaluation Metrics
trn_losses = list()
val_losses = list()
tst_losses = list()
subtrn_losses = list()
timing = [0]
trn_acc = list()
best_acc = list()
curr_best_acc = 0
val_acc = list()  
tst_acc = list()  
subtrn_acc = list()


# Evaluation Function

In [None]:
def evaluate_model(curr_best_acc):
    """
    ################################################# Evaluation Loop #################################################
    """
    trn_loss = 0
    trn_correct = 0
    trn_total = 0
    val_loss = 0
    val_correct = 0
    val_total = 0
    tst_correct = 0
    tst_total = 0
    tst_loss = 0
    model.eval()
    logger_dict = {}
    if ("trn_loss" in print_args) or ("trn_acc" in print_args):
        samples=0
    
        with torch.no_grad():
            for _, data in enumerate(trainloader):
                inputs, targets = data

                inputs, targets = inputs.to(device), \
                                  targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # Scale by the actual number of samples; the last batch may be smaller than batch_size
                trn_loss += (loss.item() * targets.shape[0])
                samples += targets.shape[0]
                if "trn_acc" in print_args:
                    _, predicted = outputs.max(1)
                    trn_total += targets.size(0)
                    trn_correct += predicted.eq(targets).sum().item()
            trn_loss = trn_loss/samples
            trn_losses.append(trn_loss)
            logger_dict['trn_loss'] = trn_loss
        if "trn_acc" in print_args:
            trn_acc.append(trn_correct / trn_total)
            logger_dict['trn_acc'] = trn_correct / trn_total

    if ("val_loss" in print_args) or ("val_acc" in print_args):
        samples =0
        with torch.no_grad():
            for _, data in enumerate(valloader):
                inputs, targets = data
                inputs, targets = inputs.to(device), \
                                  targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # Scale by the actual number of samples; the last batch may be smaller than batch_size
                val_loss += (loss.item() * targets.shape[0])
                samples += targets.shape[0]
                if "val_acc" in print_args:
                    
                    _, predicted = outputs.max(1)
                    val_total += targets.size(0)
                    val_correct += predicted.eq(targets).sum().item()
            val_loss = val_loss/samples
            val_losses.append(val_loss)
            logger_dict['val_loss'] = val_loss

        if "val_acc" in print_args:
            val_acc.append(val_correct / val_total)
            logger_dict['val_acc'] = val_correct / val_total

    if ("tst_loss" in print_args) or ("tst_acc" in print_args):
        samples =0
        with torch.no_grad():
            for _, data in enumerate(testloader):
                inputs, targets = data

                inputs, targets = inputs.to(device), \
                                  targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # Scale by the actual number of samples; the last batch may be smaller than batch_size
                tst_loss += (loss.item() * targets.shape[0])
                samples += targets.shape[0]
                if "tst_acc" in print_args:
                    _, predicted = outputs.max(1)
                    tst_total += targets.size(0)
                    tst_correct += predicted.eq(targets).sum().item()
            tst_loss = tst_loss/samples
            tst_losses.append(tst_loss)
            logger_dict['tst_loss'] = tst_loss

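        # curr_best_acc is a local copy of the caller's value, so recover the running best from the history
        if best_acc:
            curr_best_acc = max(curr_best_acc, best_acc[-1])
        if tst_total > 0 and (tst_correct/tst_total) > curr_best_acc:
            curr_best_acc = (tst_correct/tst_total)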

        if "tst_acc" in print_args:
            tst_acc.append(tst_correct / tst_total)
            best_acc.append(curr_best_acc)
            logger_dict['tst_acc'] = tst_correct / tst_total
            logger_dict['best_acc'] = curr_best_acc

    if "subtrn_acc" in print_args:
        if epoch == 0:
            subtrn_acc.append(0)
            logger_dict['subtrn_acc'] = 0
        else:    
            subtrn_acc.append(subtrn_correct / subtrn_total)
            logger_dict['subtrn_acc'] = subtrn_correct / subtrn_total

    if "subtrn_loss" in print_args:
        if epoch == 0:
            subtrn_losses.append(0)
            logger_dict['subtrn_loss'] = 0
        else: 
            subtrn_losses.append(subtrn_loss)
            logger_dict['subtrn_loss'] = subtrn_loss

    print_str = "Epoch: " + str(epoch)
    logger_dict['Epoch'] = epoch
    logger_dict['Time'] = train_time

    """
    ################################################# Results Printing #################################################
    """

    for arg in print_args:
        if arg == "val_loss":
            print_str += " , " + "Validation Loss: " + str(val_losses[-1])

        if arg == "val_acc":
            print_str += " , " + "Validation Accuracy: " + str(val_acc[-1])

        if arg == "tst_loss":
            print_str += " , " + "Test Loss: " + str(tst_losses[-1])

        if arg == "tst_acc":
            print_str += " , " + "Test Accuracy: " + str(tst_acc[-1])
            print_str += " , " + "Best Accuracy: " + str(best_acc[-1])

        if arg == "trn_loss":
            print_str += " , " + "Training Loss: " + str(trn_losses[-1])

        if arg == "trn_acc":
            print_str += " , " + "Training Accuracy: " + str(trn_acc[-1])

        if arg == "subtrn_loss":
            print_str += " , " + "Subset Loss: " + str(subtrn_losses[-1])

        if arg == "subtrn_acc":
            print_str += " , " + "Subset Accuracy: " + str(subtrn_acc[-1])

        if arg == "time":
            print_str += " , " + "Timing: " + str(timing[-1])

    logger.info(print_str)
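
`nn.CrossEntropyLoss` with its default reduction returns the mean loss over a batch, so a dataset-level mean is obtained by weighting each batch mean by the number of samples that batch actually contains. The sketch below illustrates the arithmetic in plain Python with made-up numbers; the helper name `dataset_mean_loss` and all values are hypothetical.

```python
def dataset_mean_loss(batch_mean_losses, batch_sizes):
    """Combine per-batch mean losses into a per-sample mean over the whole dataset."""
    total = sum(loss * size for loss, size in zip(batch_mean_losses, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical numbers: 300 samples split into batches of 128, 128 and 44.
batch_mean_losses = [1.0, 1.0, 4.0]
batch_sizes = [128, 128, 44]

exact = dataset_mean_loss(batch_mean_losses, batch_sizes)
# Scaling every batch by the nominal batch size over-weights the short last batch.
nominal = sum(loss * 128 for loss in batch_mean_losses) / sum(batch_sizes)
print(exact, nominal)
```

The two results agree only when the dataset size is an exact multiple of the batch size, so using the per-batch sample count is the safer choice.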


### Custom Training loop with evaluation

Subset dataloaders return data samples, labels, and a weight associated with each data sample. Hence, in order to incorporate these weights into the training loop, we use a **loss function** with **reduction='none'** to get per-sample loss values. We then calculate the weighted average of the per-sample losses in each batch using the following code snippet:

`loss = torch.dot(losses, weights/(weights.sum()))`
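
As a worked example of this expression, the sketch below reproduces the same arithmetic in plain Python so that each step is visible. The function name `weighted_batch_loss` and the numbers are hypothetical; in the training loop the equivalent computation runs on tensors.

```python
def weighted_batch_loss(losses, weights):
    """Weighted average of per-sample losses; mirrors dot(losses, weights / weights.sum())."""
    total_weight = sum(weights)
    return sum(loss_i * w_i / total_weight for loss_i, w_i in zip(losses, weights))

# Hypothetical per-sample losses and subset weights for a batch of four samples.
losses = [0.5, 1.5, 1.0, 1.0]
weights = [2.0, 1.0, 0.5, 0.5]

weighted = weighted_batch_loss(losses, weights)            # first sample counts the most
uniform = weighted_batch_loss(losses, [1.0] * len(losses)) # equal weights give the plain mean
print(weighted, uniform)
```

Because the weights are normalized by their sum, the result is invariant to rescaling all weights by a constant, and equal weights reduce to the ordinary mean loss.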

---
***NOTE***

If you want to implement a custom training loop, note that the subset dataloaders also return an additional weight for each data sample.

---

In [None]:
"""
################################################# Training Loop #################################################
"""
train_time = 0
for epoch in range(0, num_epochs+1):
        
    # Evaluating the Model at Regular Intervals
    if (epoch % print_every == 0) or (epoch == num_epochs) or (epoch == 0):
        evaluate_model(curr_best_acc)
        
    subtrn_loss = 0
    subtrn_correct = 0
    subtrn_total = 0
    model.train()
    start_time = time.time()

    """
    ################################################# Mini-batch SGD #################################################
    """
    for _, (inputs, targets, weights) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        losses = criterion_nored(outputs, targets)
        loss = torch.dot(losses, weights / (weights.sum()))
        loss.backward()
        subtrn_loss += loss.item()
        optimizer.step()
        _, predicted = outputs.max(1)
        subtrn_total += targets.size(0)
        subtrn_correct += predicted.eq(targets).sum().item()
    epoch_time = time.time() - start_time
    scheduler.step()
    timing.append(epoch_time)
    train_time += epoch_time
    

    """
    ################################################# Checkpoint Saving #################################################
    """

    if ((epoch + 1) % save_every == 0):

        metric_dict = {}

        for arg in print_args:
            if arg == "val_loss":
                metric_dict['val_loss'] = val_losses
            if arg == "val_acc":
                metric_dict['val_acc'] = val_acc
            if arg == "tst_loss":
                metric_dict['tst_loss'] = tst_losses
            if arg == "tst_acc":
                metric_dict['tst_acc'] = tst_acc
                metric_dict['best_acc'] = best_acc
            if arg == "trn_loss":
                metric_dict['trn_loss'] = trn_losses
            if arg == "trn_acc":
                metric_dict['trn_acc'] = trn_acc
            if arg == "subtrn_loss":
                metric_dict['subtrn_loss'] = subtrn_losses
            if arg == "subtrn_acc":
                metric_dict['subtrn_acc'] = subtrn_acc
            if arg == "time":
                metric_dict['time'] = timing

        ckpt_state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'metrics': metric_dict
        }

        # save checkpoint
        save_ckpt(ckpt_state, 'model.pth')
        logger.info("Model checkpoint saved at epoch: {0:d}".format(epoch + 1))

# Results Summary Logging

In [None]:
"""
################################################# Results Summary #################################################
"""
original_idxs = set(range(len(trainset)))
encountered_idxs = []
for key in dataloader.selected_idxs.keys():
    encountered_idxs.extend(dataloader.selected_idxs[key])
encountered_idxs = set(encountered_idxs)
rem_idxs = original_idxs.difference(encountered_idxs)
encountered_percentage = 100 * len(encountered_idxs)/len(original_idxs)

logger.info("Selected Indices: ") 
logger.info(dataloader.selected_idxs)
logger.info("Percentage of data samples encountered during training: %.2f", encountered_percentage)
logger.info("Not Selected Indices: ")
logger.info(rem_idxs)                
logger.info("CRAIG Selection Run---------------------------------")
logger.info("Final SubsetTrn: {0:f}".format(subtrn_loss))
if "val_loss" in print_args:
    if "val_acc" in print_args:
        logger.info("Validation Loss: %.2f , Validation Accuracy: %.2f", val_losses[-1], val_acc[-1])
    else:
        logger.info("Validation Loss: %.2f", val_losses[-1])

if "tst_loss" in print_args:
    if "tst_acc" in print_args:
        logger.info("Test Loss: %.2f, Test Accuracy: %.2f, Best Accuracy: %.2f", tst_losses[-1], tst_acc[-1], best_acc[-1])
    else:
        logger.info("Test Data Loss: %f", tst_losses[-1])
logger.info('---------------------------------------------------------------------')
logger.info("CRAIG")
logger.info('---------------------------------------------------------------------')

"""
################################################# Final Results Logging #################################################
"""

if "val_acc" in print_args:
    logger.info("Validation Accuracy: " + " , ".join(str(val) for val in val_acc))

if "tst_acc" in print_args:
    logger.info("Test Accuracy: " + " , ".join(str(tst) for tst in tst_acc))
    logger.info("Best Accuracy: " + " , ".join(str(tst) for tst in best_acc))

if "time" in print_args:
    logger.info("Time: " + " , ".join(str(t) for t in timing))

omp_timing = np.array(timing)
omp_cum_timing = list(generate_cumulative_timing(omp_timing))
logger.info("Total time taken by %s = %.4f ", "CRAIG", omp_cum_timing[-1])
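
The coverage computation at the start of the results summary can be tried out on a toy example. The sketch below assumes, as the loop over `dataloader.selected_idxs.keys()` suggests, that the selection history is a dictionary whose values are lists of selected indices; the keys, the indices, and the dataset size of 10 are hypothetical.

```python
# Hypothetical selection history: selection round -> indices chosen in that round.
selected_idxs = {0: [0, 1, 2, 3], 20: [2, 3, 4, 5], 40: [0, 5, 6]}
num_train = 10

# Union of all indices that were part of at least one subset.
encountered = set()
for idxs in selected_idxs.values():
    encountered.update(idxs)

never_selected = set(range(num_train)) - encountered
coverage = len(encountered) / num_train  # fraction of the training set seen at least once
print(sorted(never_selected), coverage)
```

A coverage well below 1.0 indicates that the strategy keeps returning to the same samples, while a value near 1.0 means most of the training set was used at some point even though each epoch only sees a small fraction.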

# GLISTER run using default SL training loop directly

We have incorporated the above training loop in the train_sl.py file of CORDS. It can be used directly by importing the TrainClassifier class from the train_sl module as follows:

```
from train_sl import TrainClassifier
```

Importing Supervised learning default training loop

In [2]:
from train_sl import TrainClassifier

### Loading default GLISTER config file for CIFAR10 dataset

We can load other subset selection strategies like CRAIG, GradMatch, Random for CIFAR10 dataset by loading their respective config files.

Here we give an example of instantiating an SL training loop using the GLISTER config file.

In [4]:
fraction = 0.1
glister_config_file = '../configs/SL/config_glister_cifar10.py'

from cords.utils.config_utils import load_config_data

cfg = load_config_data(glister_config_file)
glister_trn = TrainClassifier(cfg)

### Default config args can be modified in the following manner

We can modify the default arguments of the config file by simply assigning them new values.

In [5]:
glister_trn.cfg.scheduler.T_max = 200

glister_trn.cfg.dss_args.fraction = fraction
glister_trn.cfg.dss_args.select_every = 20

glister_trn.cfg.train_args.device = 'cuda'
glister_trn.cfg.train_args.print_every = 10
glister_trn.cfg.train_args.num_epochs = 200

### Start the training process

In [6]:
glister_trn.train()

Files already downloaded and verified
Files already downloaded and verified
[05/01 15:23:36] train_sl  23/05/01 15:23:08 INFO: Epoch: 0 , Training Loss: 0.19857515055338543 , Training Accuracy: 0.10066666666666667 , Validation Loss: 0.2005577392578125 , Validation Accuracy: 0.0942 , Test Loss: 0.18879930114746094 , Test Accuracy: 0.0999 , Best Accuracy: 0.0999 , Timing: 0
[05/01 15:24:37] train_sl  23/05/01 15:23:08 INFO: Epoch: 10 , Training Loss: 0.11375550537109375 , Training Accuracy: 0.2823555555555556 , Validation Loss: 0.11480399169921875 , Validation Accuracy: 0.276 , Test Loss: 0.10950443420410157 , Test Accuracy: 0.3053 , Best Accuracy: 0.3053 , Timing: 42.41166567802429
[05/01 15:25:20] train_sl  23/05/01 15:23:08 INFO: Model checkpoint saved at epoch: 20
[05/01 15:25:39] train_sl  23/05/01 15:23:08 INFO: Epoch: 20 , Training Loss: 0.06841008877224393 , Training Accuracy: 0.5405333333333333 , Validation Loss: 0.07114786682128907 , Validation Accuracy: 0.5322 , Test Loss: 0.0

([0.10066666666666667,
  0.2823555555555556,
  0.5405333333333333,
  0.5054444444444445,
  0.5988,
  0.6266444444444444,
  0.6966222222222223,
  0.7020666666666666,
  0.7195777777777778,
  0.7641333333333333,
  0.7544666666666666,
  0.7444666666666667,
  0.8006888888888889,
  0.8414444444444444,
  0.8531333333333333,
  0.8742666666666666,
  0.8758,
  0.9251777777777778,
  0.9248888888888889,
  0.9526,
  0.954],
 [0.0942,
  0.276,
  0.5322,
  0.4942,
  0.5846,
  0.6062,
  0.6714,
  0.679,
  0.69,
  0.7296,
  0.7178,
  0.7084,
  0.7508,
  0.8038,
  0.811,
  0.8116,
  0.8118,
  0.8572,
  0.8646,
  0.8738,
  0.8762],
 [0.0999,
  0.3053,
  0.5468,
  0.4915,
  0.5836,
  0.6263,
  0.6928,
  0.6787,
  0.6919,
  0.7326,
  0.7222,
  0.7155,
  0.7649,
  0.7952,
  0.8073,
  0.825,
  0.8229,
  0.8583,
  0.8572,
  0.8795,
  0.879],
 [0.0999,
  0.3053,
  0.5468,
  0.5468,
  0.5836,
  0.6263,
  0.6928,
  0.6928,
  0.6928,
  0.7326,
  0.7326,
  0.7326,
  0.7649,
  0.7952,
  0.8073,
  0.825,
  0.825,
  