# Integrating subset selection dataloaders with custom SSL training loop

In this tutorial, we will look at an example showing how to integrate RETRIEVEDataLoader with a custom SSL training loop.

### Cloning CORDS repository

In [1]:
# Clone the CORDS repository and make it the working directory: later cells
# use paths relative to it ('results/', 'data/') and import train_ssl from it.
!git clone https://github.com/decile-team/cords.git
%cd cords/
%ls

Cloning into 'cords'...
remote: Enumerating objects: 3920, done.[K
remote: Counting objects: 100% (2542/2542), done.[K
remote: Compressing objects: 100% (1155/1155), done.[K
remote: Total 3920 (delta 1654), reused 2178 (delta 1349), pack-reused 1378[K
Receiving objects: 100% (3920/3920), 54.62 MiB | 12.55 MiB/s, done.
Resolving deltas: 100% (2391/2391), done.
/content/cords
[0m[01;34mbenchmarks[0m/  [01;34mdocs[0m/        README.md      [01;34mtests[0m/        train_ssl.py
[01;34mconfigs[0m/     [01;34mexamples[0m/    [01;34mrequirements[0m/  train_hpo.py
[01;34mcords[0m/       LICENSE.txt  setup.py       train_sl.py


### Install prerequisite libraries of CORDS

In [None]:
!pip install dotmap
!pip install apricot-select
!pip install ray[default]
!pip install ray[tune]
!pip install datasets

Collecting dotmap
  Downloading dotmap-1.3.26-py3-none-any.whl (11 kB)


### Import necessary libraries

In [None]:
import logging
import numpy, random, time, json, copy
import numpy as np
import os.path as osp
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from cords.utils.data.data_utils import WeightedSubset
from cords.utils.models import WideResNet, ShakeNet, CNN13, CNN
from cords.utils.data.datasets.SSL import utils as dataset_utils
from cords.selectionstrategies.helpers.ssl_lib.algs import utils as alg_utils
from cords.utils.models import utils as model_utils
from cords.utils.data.datasets.SSL import gen_dataset
from cords.selectionstrategies.helpers.ssl_lib.param_scheduler import scheduler
from cords.selectionstrategies.helpers.ssl_lib.misc.meter import Meter
from cords.utils.config_utils import load_config_data
import time
import os
import sys

### Get logger object for logging

In [None]:
def __get_logger(results_dir):
  """Create (or reconfigure) the notebook logger.

  The logger writes to stdout and to ``<results_dir>/results.log``.
  ``results_dir`` is created if it does not exist.

  Calling this function again (e.g. when the notebook cell is re-run) replaces
  the handlers installed by the previous call instead of stacking new ones, so
  each record is emitted once per destination.

  Parameters
  ----------
  results_dir: str
      Directory in which ``results.log`` is written.

  Returns
  -------
  logging.Logger
      Logger named after the current module, with propagation disabled.
  """
  os.makedirs(results_dir, exist_ok=True)
  # setup logger
  plain_formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                                      datefmt="%m/%d %H:%M:%S")
  logger = logging.getLogger(__name__)
  # NOTE: the logger level is INFO, so DEBUG records are dropped before they
  # reach any handler, even though the file handler below accepts DEBUG.
  logger.setLevel(logging.INFO)

  # logging.getLogger returns the same object for the same name, so handlers
  # from an earlier call are still attached; remove and close them to avoid
  # duplicated log lines and leaked file handles.
  for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

  s_handler = logging.StreamHandler(stream=sys.stdout)
  s_handler.setFormatter(plain_formatter)
  s_handler.setLevel(logging.INFO)
  logger.addHandler(s_handler)

  f_handler = logging.FileHandler(os.path.join(results_dir, "results.log"))
  f_handler.setFormatter(plain_formatter)
  f_handler.setLevel(logging.DEBUG)
  logger.addHandler(f_handler)

  # Keep records out of the root logger so they are not printed a second time.
  logger.propagate = False
  return logger



### Defining the results directory and getting the results logger object

In [None]:
# Directory (relative to the current working directory) that receives results.log;
# __get_logger creates it if it does not exist.
results_dir = 'results/'
logger = __get_logger(results_dir)


### Loading configuration file with predefined arguments:

We have a set of predefined configuration files added to CORDS for SSL under cords/configs/SSL/ which can be used directly by loading them as a dotmap object. 

An example of predefined configuration for CIFAR10 using VAT as SSL algorithm and RETRIEVE as subset selection strategy can be found below:

```Python3

# Learning setting
config = dict(setting="SSL",
              dataset=dict(name="cifar10",
                           root="../data",
                           feature="dss",
                           type="pre-defined",
                           num_labels=4000,
                           val_ratio=0.1,
                           ood_ratio=0.5,
                           random_split=False,
                           whiten=False,
                           zca=True,
                           labeled_aug='WA',
                           unlabeled_aug='WA',
                           wa='t.t.f',
                           strong_aug=False),

              dataloader=dict(shuffle=True,
                              pin_memory=True,
                              num_workers=8,
                              l_batch_size=50,
                              ul_batch_size=50),

              model=dict(architecture='wrn',
                         type='pre-defined',
                         numclasses=10),

              ckpt=dict(is_load=False,
                        is_save=True,
                        checkpoint_model='model.ckpt',
                        checkpoint_optimizer='optimizer.ckpt',
                        start_iter=None,
                        checkpoint=10000),

              loss=dict(type='CrossEntropyLoss',
                        use_sigmoid=False),

              optimizer=dict(type="sgd",
                             momentum=0.9,
                             lr=0.03,
                             weight_decay=0,
                             nesterov=True,
                             tsa=False,
                             tsa_schedule='linear'),

              scheduler=dict(lr_decay="cos",
                             warmup_iter=0),

              ssl_args=dict(alg='vat',
                            coef=0.3,
                            ema_teacher=False,
                            ema_teacher_warmup=False,
                            ema_teacher_factor=0.999,
                            ema_apply_wd=False,
                            em=0,
                            threshold=None,
                            sharpen=None,
                            temp_softmax=None,
                            consis='ce',
                            eps=6,
                            xi=1e-6,
                            vat_iter=1
                            ),

              ssl_eval_args=dict(weight_average=False,
                                 wa_ema_factor=0.999,
                                 wa_apply_wd=False),

              dss_args=dict(type="RETRIEVE-Warm",
                            fraction=0.1,
                            select_every=20,
                            kappa=0.5,
                            linear_layer=False,
                            selection_type='Supervised',
                            greedy='Stochastic',
                            valid=True),

              train_args=dict(iteration=500000,
                              max_iter=-1,
                              device="cuda",
                              results_dir='results/',
                              disp=256,
                              seed=96)
              )

```

Please find detailed documentation explaining the available configuration parameters on the CORDS readthedocs [page](https://cords.readthedocs.io/en/latest/).

***Loading the predefined configuration file directly using the load_config_data function in CORDS***

In [None]:
from cords.utils.config_utils import load_config_data

# The path is relative to the repository root, which is the working directory
# after the cloning cell's `%cd cords/`; this avoids hard-coding Colab's
# /content prefix so the notebook also runs outside Colab.
cfg = load_config_data('configs/SSL/config_retrieve-warm_vat_cifar10.py')

### Loading the CIFAR10 dataset for SSL

Since CIFAR10 is a predefined dataset in the CORDS repository for SSL, you can use the gen_dataset function in cords/utils/data/datasets/SSL/builder.py to load it.

**Input parameters of gen_dataset function:**

Parameters
-----------
    root: str
        root directory in which data is present or needs to be downloaded
    dataset: str
        dataset name,
        Existing dataset choices: ['cifar10', 'cifar100', 'svhn', 'stl10', 'cifarOOD', 'mnistOOD', 'cifarImbalance']
    validation_split: bool
        if True, return validation loader.
        We use 10% random split of training data as validation data
    cfg: argparse.Namespace or dict
        Dictionary containing necessary arguments for generating the dataset
    logger: logging.Logger
        Logger class for logging the information


In [None]:
# Positional arguments, in the order documented for gen_dataset:
# root directory, dataset name, validation_split flag, config, logger.
# validation_split is False here, so no validation loader is requested.
lt_data, ult_data, test_data, num_classes, img_size = gen_dataset('data/', 'cifar10',
                                                                  False, cfg, logger)


### Defining Model

CORDS has a set of predefined models built into its utils folder. You can import them directly and instantiate them by passing the corresponding set of required arguments for the model.

In this notebook, we are going to use a WideResNet model that takes in the following arguments:

```
WideResNet Parameters
-----------
  num_classes: int
      number of classes
  filters: int
      number of filters
  scales: int
      number of scales
  repeat: int
      number of residual blocks per scale
  dropout: float
      dropout ratio (None indicates dropout is unused)

```

We have numclasses which is a part of model arguments in the config file and can be accessed by cfg.model.numclasses

***Note: Instead of as dictionary objects, we load config files as dotmap objects. Hence, we can use dot notation (e.g., cfg.model) or original dictionary notation (e.g., cfg['model']) to access the elements. However, we suggest the usage of dot notation for consistency purposes***


      

In [None]:
from cords.utils.models import WideResNet

# Number of scales passed to the network: ceil(log2(img_size)) as a Python int.
scale = int(np.ceil(np.log2(img_size)))

# Build the WideResNet first, then move it to the device named by
# train_args.device in the config file. The positional arguments follow the
# parameter list documented in this notebook (num_classes, filters, scales, repeat).
model = WideResNet(cfg.model.numclasses, 32, scale, 4)
model = model.to(cfg.train_args.device)

### Defining Teacher Model

Some SSL algorithms use a teacher model to estimate the consistency loss. The argument cfg.ssl_args.ema_teacher in the config file is a boolean indicator for the usage of the teacher model. In our example we use the VAT algorithm, which does not use a teacher model, so we set cfg.ssl_args.ema_teacher to False.

In cases where we use teacher model, we may need to mention additional arguments like cfg.ssl_args.ema_teacher_warmup and cfg.ssl_args.ema_teacher_factor which are specifically required for calculating the teacher model properties using exponential moving average.

In [None]:
# build teacher model
# Start from "no teacher"; one is only created when the config enables the EMA teacher.
teacher_model = None
scale = int(np.ceil(np.log2(img_size)))
if cfg.ssl_args.ema_teacher:
    # Same architecture and device as the student, initialised with the
    # student's current weights.
    teacher_model = WideResNet(cfg.model.numclasses, 32, scale, 4).to(cfg.train_args.device)
    teacher_model.load_state_dict(model.state_dict())

### Defining Evaluation Model

We can evaluate SSL algorithms on an exponential moving average model or just on the model itself. The argument cfg.ssl_eval_args.weight_average in the config file is a boolean indicator for the usage of the exponential weight average model for evaluation. In our example,
we will not be using weight averaging for evaluation, so we set cfg.ssl_eval_args.weight_average to False.

In cases where we use the weight-averaged model, we may need to set additional arguments like cfg.ssl_eval_args.wa_ema_factor and cfg.ssl_eval_args.wa_apply_wd, which are specifically required for calculating the evaluation model's parameters using exponential moving average.

In [None]:
# for evaluation
# Start from "no averaged model"; one is only created when the config enables
# weight averaging for evaluation.
average_model = None
scale = int(np.ceil(np.log2(img_size)))
if cfg.ssl_eval_args.weight_average:
    # Same architecture and device as the trained model, initialised with its
    # current weights.
    average_model = WideResNet(cfg.model.numclasses, 32, scale, 4).to(cfg.train_args.device)
    average_model.load_state_dict(model.state_dict())


### Get SSL consistency loss functions 

The gen_consistency function is implemented in the file 'cords/selectionstrategies/helpers/ssl_lib/consistency/builder.py' and can be imported as follows:
```
from cords.selectionstrategies.helpers.ssl_lib.consistency.builder import gen_consistency
```
Existing Consistency loss functions are:
1.   Cross-Entropy Loss
2.   Squared Loss

**Note that we generate two versions of the loss function: one with mean reduction and one without. The loss function without mean reduction is used for data subset selection, as most subset selection strategies need individual loss gradients, and a loss function without reduction makes it possible to calculate them.**

We will be using ssl_args configuration arguments for generating the consistency function.



In [None]:
from cords.selectionstrategies.helpers.ssl_lib.consistency.builder import gen_consistency

# Loss used for training, selected by the `consis` entry of ssl_args.
consistency = gen_consistency(cfg.ssl_args.consis, cfg)
# Same loss requested with the '_red' suffix; this notebook uses it as the
# variant without mean reduction, which subset selection needs.
consistency_nored = gen_consistency(cfg.ssl_args.consis + '_red', cfg)

### Defining SSL algorithm

We integrated various consistency based SSL algorithms implemented in this awesome [repository](https://github.com/perrying/pytorch-consistency-regularization) with cords. These SSL algorithms can be imported by using gen_ssl_alg function implemented in cords.selectionstrategies.helpers.ssl_lib.algs.builder which can be imported as follows:

```
from cords.selectionstrategies.helpers.ssl_lib.algs.builder import gen_ssl_alg
```

In our example, we will be using VAT as SSL algorithm.

In [None]:
from cords.selectionstrategies.helpers.ssl_lib.algs.builder import gen_ssl_alg

# Build the SSL algorithm named by cfg.ssl_args.alg ('vat' in the example config).
ssl_alg = gen_ssl_alg(cfg.ssl_args.alg, cfg)

In [None]:
# Iteration budget for subset training: the configured number of iterations
# scaled by the subset fraction and truncated to an int.
max_iteration = int(cfg.train_args.iteration * cfg.dss_args.fraction)

### Create unlabeled, labeled and test dataloaders

In [None]:
#Creating full unlabeled data loader with shuffle set to be False
ult_seq_loader = DataLoader(ult_data, batch_size=cfg.dataloader.ul_batch_size,
                                    shuffle=False, pin_memory=True)

#Creating labeled data loader with shuffle set to be False
lt_seq_loader = DataLoader(lt_data, batch_size=cfg.dataloader.l_batch_size,
                            shuffle=False, pin_memory=True)

#Creating test data loader with shuffle set to be False
test_loader = DataLoader(
    test_data,
    1,
    shuffle=False,
    drop_last=False,
    num_workers=cfg.dataloader.num_workers
)


### Instantiating RETRIEVE subset selection dataloader for unlabeled data

We instantiate subset dataloaders that can be used for training the models with adaptive subsets.

Each subset dataloader needs data selection strategy arguments in the form of a dotmap dictionary, logger and dataloader specific arguments like batch size, shuffle etc. We will be using dss_args in config file along with some additional arguments required for RETRIEVE.

Additional arguments required for RETRIEVEDataLoader on top of dss_args in the config file are:

* model
* teacher_model
* ssl_alg
* consistency_nored
* num_classes
* max_iteration
* learning rate
* device

We are instantiating RETRIEVE dataloader here with warm start. But any dataloader can be instantiated in the same way by passing the required arguments



In [None]:
from cords.utils.data.dataloader.SSL.adaptive import RETRIEVEDataLoader
from dotmap import DotMap

# Attach the runtime objects RETRIEVE needs to the dss_args section of the config.
cfg.dss_args.model = model
cfg.dss_args.tea_model = teacher_model  # None when cfg.ssl_args.ema_teacher is False
cfg.dss_args.ssl_alg = ssl_alg
cfg.dss_args.loss = consistency_nored  # consistency loss built with the '_red' suffix
cfg.dss_args.num_classes = num_classes
cfg.dss_args.num_iters = max_iteration
cfg.dss_args.eta = cfg.optimizer.lr  # learning rate taken from the optimizer section
cfg.dss_args.device = cfg.train_args.device

# Wrap the sequential unlabeled and labeled loaders; the remaining keyword
# arguments are the dataloader-specific settings taken from the config.
ult_loader = RETRIEVEDataLoader(ult_seq_loader, lt_seq_loader, cfg.dss_args, logger=logger,
                                batch_size=cfg.dataloader.ul_batch_size,
                                pin_memory=cfg.dataloader.pin_memory,
                                num_workers=cfg.dataloader.num_workers)

### Get Optimizer

We store optimizer related arguments in the optimizer option of the configuration file. In our example, we will be using "sgd" optimizer with Nesterov momentum without any weight decay. The config.optimizer arguments in our example are as follows:

```
optimizer=dict(type="sgd",
                momentum=0.9,
                lr=0.03,
                weight_decay=0,
                nesterov=True,
                tsa=False,
                tsa_schedule='linear')
```

In [None]:
# Build the optimizer selected by cfg.optimizer.type.
if cfg.optimizer.type == "sgd":
    optimizer = optim.SGD(
        model.parameters(),
        lr=cfg.optimizer.lr,
        momentum=cfg.optimizer.momentum,
        weight_decay=cfg.optimizer.weight_decay,
        nesterov=cfg.optimizer.nesterov,
    )
elif cfg.optimizer.type == "adam":
    # The configured momentum is used as Adam's first beta coefficient.
    optimizer = optim.Adam(
        model.parameters(),
        lr=cfg.optimizer.lr,
        betas=(cfg.optimizer.momentum, 0.999),
        weight_decay=cfg.optimizer.weight_decay,
    )
else:
    raise NotImplementedError


### Get Scheduler

We store scheduler related arguments in the scheduler option of the configuration file. In our example, we will be using cosine-annealing scheduler. The config.scheduler arguments in our example are as follows:

```
scheduler=dict(lr_decay="cos",
              warmup_iter=0),

```

In [None]:
# set lr scheduler
if cfg.scheduler.lr_decay == "cos":
    if cfg.dss_args.type == 'Full':
        lr_scheduler = scheduler.CosineAnnealingLR(optimizer, max_iteration)
    else:
        lr_scheduler = scheduler.CosineAnnealingLR(optimizer,
                                                    cfg.train_args.iteration * cfg.dss_args.fraction)
elif cfg.scheduler.lr_decay == "step":
    # TODO: fixed milestones
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [400000, ], cfg.scheduler.lr_decay_rate)
else:
    raise NotImplementedError


### SSL Model Parameters Update function

In [None]:
"""
############################## Model Parameters Update ##############################
"""

def param_update(cfg,
                cur_iteration,
                model,
                teacher_model,
                optimizer,
                ssl_alg,
                consistency,
                labeled_data,
                ul_weak_data,
                ul_strong_data,
                labels,
                average_model,
                weights=None,
                ood=False
                ):
    """Run one optimisation step of the SSL model and return its metrics.

    The labeled batch and both augmented unlabeled batches are concatenated
    and passed through ``model`` in a single forward pass. The total loss is
    the supervised cross-entropy (optionally annealed with TSA) plus the
    warmed-up coefficient times the consistency loss, plus an optional
    entropy term when ``cfg.ssl_args.em > 0``. After the optimizer step the
    teacher and evaluation models are updated by exponential moving average
    when the corresponding config flags are set.

    Parameters
    ----------
    cfg:
        Configuration object; the ``optimizer``, ``ssl_args``,
        ``ssl_eval_args``, ``scheduler`` and ``train_args`` sections are read.
    cur_iteration: int
        Current iteration counter; ``cur_iteration + 1`` is passed to the TSA
        and coefficient warm-up schedules.
    model:
        Student model being trained.
    teacher_model:
        Teacher model used to produce consistency targets, or None to use the
        student's own weak-augmentation logits as targets.
    optimizer:
        Optimizer holding the parameters of ``model``.
    ssl_alg:
        Callable returning ``(y, targets, mask)`` for the consistency loss.
    consistency:
        Consistency loss called as ``consistency(y, targets, mask, weak_prediction=...)``.
    labeled_data, ul_weak_data, ul_strong_data:
        Labeled batch and weakly / strongly augmented unlabeled batches,
        concatenated along dimension 0.
    labels:
        Targets for ``labeled_data``.
    average_model:
        Model updated by EMA when ``cfg.ssl_eval_args.weight_average`` is set.
    weights:
        Optional per-sample weights multiplied into the consistency mask.
    ood:
        Accepted for interface compatibility; not referenced in this function.

    Returns
    -------
    dict
        Keys ``"acc"`` (a tensor, not converted with ``.item()``), ``"loss"``,
        ``"sup loss"``, ``"ssl loss"``, ``"mask"``, ``"coef"`` and ``"sec/iter"``.
    """
    start_time = time.time()
    # Concatenate labeled data, weakly augmented, and strongly augmented unlabeled data
    all_data = torch.cat([labeled_data, ul_weak_data, ul_strong_data], 0)
    forward_func = model.forward
    stu_logits = forward_func(all_data)
    labeled_preds = stu_logits[:labeled_data.shape[0]]

    # Separate weak unlabeled logits, and strong unlabeled logits
    # NOTE(review): the offset uses labels.shape[0] while labeled_preds uses
    # labeled_data.shape[0]; these are assumed to be equal. The even split also
    # assumes both unlabeled batches have the same size -- confirm with callers.
    stu_unlabeled_weak_logits, stu_unlabeled_strong_logits = torch.chunk(stu_logits[labels.shape[0]:], 2, dim=0)
    
    # Use training signal annealing (TSA)
    if cfg.optimizer.tsa:
        none_reduced_loss = F.cross_entropy(labeled_preds, labels, reduction="none")
        L_supervised = alg_utils.anneal_loss(
            labeled_preds, labels, none_reduced_loss, cur_iteration + 1,
            cfg.train_args.iteration, labeled_preds.shape[1], cfg.optimizer.tsa_schedule)
    else:
        L_supervised = F.cross_entropy(labeled_preds, labels)

    # IF SSL coefficient is greater than zero, calculate the consistency loss
    if cfg.ssl_args.coef > 0:
        # get target values
        if teacher_model is not None:  # get target values from teacher model
            t_forward_func = teacher_model.forward
            tea_logits = t_forward_func(all_data)
            tea_unlabeled_weak_logits, _ = torch.chunk(tea_logits[labels.shape[0]:], 2, dim=0)
        else:
            # No teacher: the student's weak-augmentation logits act as targets.
            t_forward_func = forward_func
            tea_unlabeled_weak_logits = stu_unlabeled_weak_logits

        # calculate consistency loss
        # presumably stops batch statistics from being updated by any forward
        # passes ssl_alg performs -- TODO confirm against the model class
        model.update_batch_stats(False)
        y, targets, mask = ssl_alg(
            stu_preds=stu_unlabeled_strong_logits,
            tea_logits=tea_unlabeled_weak_logits.detach(),
            w_data=ul_strong_data,
            subset=False,
            stu_forward=forward_func,
            tea_forward=t_forward_func
        )
        model.update_batch_stats(True)

        # calculate weighted consistency loss
        if weights is None:
            L_consistency = consistency(y, targets, mask, weak_prediction=tea_unlabeled_weak_logits.softmax(1))
        else:
            L_consistency = consistency(y, targets, mask * weights,
                                        weak_prediction=tea_unlabeled_weak_logits.softmax(1))
    else:
        # Consistency disabled: use a zero loss so the reporting below still works.
        L_consistency = torch.zeros_like(L_supervised)
        mask = None

    # calculate total loss
    coef = scheduler.exp_warmup(cfg.ssl_args.coef, int(cfg.scheduler.warmup_iter), cur_iteration + 1)
    loss = L_supervised + coef * L_consistency
    if cfg.ssl_args.em > 0:
        # Subtracting sum(p * log p) adds the prediction entropy on the weakly
        # augmented unlabeled data, scaled by cfg.ssl_args.em, to the loss.
        loss -= cfg.ssl_args.em * \
                (stu_unlabeled_weak_logits.softmax(1) * F.log_softmax(stu_unlabeled_weak_logits, 1)).sum(1).mean()

    # update parameters
    cur_lr = optimizer.param_groups[0]["lr"]
    optimizer.zero_grad()
    loss.backward()
    if cfg.optimizer.weight_decay > 0:
        # Weight decay is applied through the helper with a coefficient scaled
        # by the current learning rate, before the optimizer step.
        decay_coeff = cfg.optimizer.weight_decay * cur_lr
        model_utils.apply_weight_decay(model.modules(), decay_coeff)
    optimizer.step()

    # update teacher parameters by exponential moving average
    if cfg.ssl_args.ema_teacher:
        model_utils.ema_update(
            teacher_model, model, cfg.ssl_args.ema_teacher_factor,
            cfg.optimizer.weight_decay * cur_lr if cfg.ssl_args.ema_apply_wd else None,
            cur_iteration if cfg.ssl_args.ema_teacher_warmup else None)
    
    # update evaluation model's parameters by exponential moving average
    if cfg.ssl_eval_args.weight_average:
        model_utils.ema_update(
            average_model, model, cfg.ssl_eval_args.wa_ema_factor,
            cfg.optimizer.weight_decay * cur_lr if cfg.ssl_eval_args.wa_apply_wd else None)

    # calculate accuracy for labeled data
    acc = (labeled_preds.max(1)[1] == labels).float().mean()

    return {
        "acc": acc,
        "loss": loss.item(),
        "sup loss": L_supervised.item(),
        "ssl loss": L_consistency.item(),
        # Reported as 1 when no consistency mask was computed.
        "mask": mask.float().mean().item() if mask is not None else 1,
        "coef": coef,
        "sec/iter": (time.time() - start_time)
    }


### SSL model evaluation function

Function that evaluates the raw SSL model and EMA evaluation model if any on test dataloader to calculate accuracy and loss metrics

In [None]:
def evaluation(raw_model, eval_model, loader, device):
    """Evaluate the raw model and the evaluation model on ``loader``.

    Accuracies and the loss are averaged over samples, so the result is
    correct when batches have different sizes (for example a smaller final
    batch). For a loader with one sample per batch this is the same as
    averaging per-batch values.

    Parameters
    ----------
    raw_model: torch.nn.Module
        Model being trained; its accuracy is reported as the raw accuracy.
    eval_model: torch.nn.Module
        Model used for the reported accuracy and loss. It may be the same
        object as ``raw_model``, in which case one forward pass is shared.
    loader: iterable
        Yields ``(data, labels)`` batches.
    device: str or torch.device
        Device the batches are moved to before the forward pass.

    Returns
    -------
    tuple of float
        ``(mean_raw_acc, mean_acc, mean_loss)``; the loss is the cross-entropy
        of ``eval_model``.

    Raises
    ------
    ValueError
        If ``loader`` yields no samples.
    """
    raw_model.eval()
    eval_model.eval()
    shares_forward = eval_model is raw_model
    total_samples = 0
    sum_raw_correct = sum_correct = sum_loss = 0.0
    try:
        with torch.no_grad():
            for (data, labels) in loader:
                data, labels = data.to(device), labels.to(device)
                preds = eval_model(data)
                # Avoid a second, identical forward pass when both arguments
                # are the same model.
                raw_preds = preds if shares_forward else raw_model(data)
                # Accumulate per-sample totals so each sample carries equal weight.
                sum_loss += F.cross_entropy(preds, labels, reduction="sum").item()
                sum_correct += (preds.max(1)[1] == labels).sum().item()
                sum_raw_correct += (raw_preds.max(1)[1] == labels).sum().item()
                total_samples += labels.shape[0]
    finally:
        # Both models are always switched back to training mode, including
        # when evaluation fails part-way through.
        raw_model.train()
        eval_model.train()
    if total_samples == 0:
        raise ValueError("evaluation() received a loader with no samples")
    mean_raw_acc = sum_raw_correct / total_samples
    mean_acc = sum_correct / total_samples
    mean_loss = sum_loss / total_samples
    return mean_raw_acc, mean_acc, mean_loss


### SSL Training loop

In SSL training loop, we iterate over batches of labeled and unlabeled data subset selected. We can do this by iterating over labeled and RETRIEVEDataloader as follows:

```
for batch_idx, (l_data, ul_data) in enumerate(zip(lt_loader, ult_loader)):
  # ult_loader is an object of RETRIEVEDataloader class
```

In [None]:
model.train()
logger.info(model)

# init meter for metrics logging
metric_meter = Meter()
test_acc_list = []
raw_acc_list = []
logger.info("training")

iter_count = 1
subset_selection_time = 0
training_time = 0

# The OOD flag depends only on the configuration, so it is computed once.
ood = cfg.dataset.feature in ['ood', 'classimb']

# Start training until maximum number of iterations are reached
while iter_count <= max_iteration:
    # Materialise the unlabeled batch sampler once per pass; its length sizes
    # the labeled sampler so that both loaders yield the same number of batches.
    num_ul_batches = len(list(ult_loader.batch_sampler))
    lt_loader = DataLoader(
        lt_data,
        cfg.dataloader.l_batch_size,
        sampler=dataset_utils.InfiniteSampler(len(lt_data), num_ul_batches * cfg.dataloader.l_batch_size),
        num_workers=cfg.dataloader.num_workers
    )

    logger.debug("Data loader iteration count is: {0:d}".format(num_ul_batches))
    # Enumerate on batches of labeled and unlabeled data.
    # Note that the ult_loader enumerates only on subsets of unlabeled data selected by RETRIEVE
    for batch_idx, (l_data, ul_data) in enumerate(zip(lt_loader, ult_loader)):
        batch_start_time = time.time()
        if iter_count > max_iteration:
            break
        l_aug, labels = l_data
        # Each unlabeled batch carries the weak and strong augmentations, an
        # unused third entry, and the per-sample weights of the selected subset.
        ul_w_aug, ul_s_aug, _, weights = ul_data
        params = param_update(
                cfg, iter_count, model, teacher_model, optimizer, ssl_alg,
                consistency, l_aug.to(cfg.train_args.device), ul_w_aug.to(cfg.train_args.device),
                ul_s_aug.to(cfg.train_args.device), labels.to(cfg.train_args.device),
                average_model, weights=weights.to(cfg.train_args.device), ood=ood)
        training_time += (time.time() - batch_start_time)

        # moving average for reporting losses and accuracy
        metric_meter.add(params, ignores=["coef"])

        # display losses every cfg.disp iterations
        if ((iter_count + 1) % cfg.train_args.disp) == 0:
            state = metric_meter.state(
                header=f'[{iter_count + 1}/{max_iteration}]',
                footer=f'ssl coef {params["coef"]:.4g} | lr {optimizer.param_groups[0]["lr"]:.4g}'
            )
            logger.info(state)
        lr_scheduler.step()

        # Checkpoint model at regular intervals
        if ((iter_count + 1) % cfg.ckpt.checkpoint) == 0 or (iter_count + 1) == max_iteration:
            with torch.no_grad():
                if cfg.ssl_eval_args.weight_average:
                    eval_model = average_model
                else:
                    eval_model = model
                logger.info("test")
                mean_raw_acc, mean_test_acc, mean_test_loss = evaluation(model, eval_model, test_loader,
                                                                              cfg.train_args.device)
                logger.info("test loss %f | test acc. %f | raw acc. %f", mean_test_loss, mean_test_acc,
                            mean_raw_acc)
                test_acc_list.append(mean_test_acc)
                raw_acc_list.append(mean_raw_acc)
            # Save into the notebook's results directory (created by __get_logger).
            # cfg.train_args has no out_dir entry, so the previous
            # cfg.train_args.out_dir lookup did not yield a usable path.
            torch.save(model.state_dict(), os.path.join(results_dir, "model_checkpoint.pth"))
            torch.save(optimizer.state_dict(),
                        os.path.join(results_dir, "optimizer_checkpoint.pth"))
        iter_count += 1


# Using default SSL training loop directly

We have incorporated the above training loop in the train_ssl.py file of CORDS, which can be used directly by importing the TrainClassifier class from the train_ssl module as follows:

```
from train_ssl import TrainClassifier
```

Importing Semi-Supervised learning default training loop

In [None]:
from train_ssl import TrainClassifier

### Loading default RETRIEVE config file for CIFAR10 dataset

We can load other subset selection strategies like CRAIG, GradMatch, Random for CIFAR10 dataset by loading their respective config files.

Here we give an example of instantiating a SSL training loop using RETRIEVE config file

In [None]:
fraction = 0.1
# Relative to the repository root, which is the working directory after the
# cloning cell's `%cd cords/`; this avoids hard-coding Colab's /content prefix.
retrieve_config_file = 'configs/SSL/config_retrieve-warm_vat_cifar10.py'

from cords.utils.config_utils import load_config_data

# NOTE: this rebinds `cfg` to a freshly loaded configuration, replacing the
# object that the custom training loop modified earlier in the notebook.
cfg = load_config_data(retrieve_config_file)
retrieve_trn = TrainClassifier(cfg)

### Default config args can be modified in the following manner

We can modify the default arguments of the config file by simply assigning them new values

In [None]:
# Override selected entries on the trainer's own copy of the configuration.
retrieve_trn.cfg.train_args.disp = 256
retrieve_trn.cfg.train_args.device = 'cuda'
# Subset fraction defined together with the config path in the previous code cell.
retrieve_trn.cfg.dss_args.fraction = fraction

### Start the training process

In [None]:
# Run the default SSL training loop with the configuration set above.
retrieve_trn.train()