# Tutorial notebook

This notebook runs the full training pipeline, including:
* training the irreducible loss model on the holdout set
* training the target model

The dataset is CIFAR-10, both target model and irreducible loss model have a Resnet-18 architecture.

Note: Before you can run this, you need to install the dependencies and activate the environment (see readme).

In [1]:
from typing import List, Optional
import os

import numpy as np

import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
import torch
import torch.nn as nn

from src.utils import utils
from src.models.OneModel import OneModel

log = utils.get_logger(__name__)

import pdb

cwd: /users/btech/rohanb21/data/rohanb21/RHO-Loss-main/src/datamodules


# Irreducible loss model training

## Create config

We manage config files with Hydra. For this tutorial, we just create the whole config here as a dictionary.

In [2]:
# where to download the datasets
data_dir = "./dataset/"

# where to upload the weights and biases logs
my_project = "CS772-CIFAR100"
my_entity = "rohanb21-indian-institute-of-technology-kanpur"

In [3]:
config = {
    "trainer": {
        "_target_": "pytorch_lightning.Trainer",
        "gpus": 1,
        "min_epochs": 1,
        "max_epochs": 120,
        "weights_summary": None,
        "progress_bar_refresh_rate": 20,
    },
    "model": {
        "_target_": "src.models.OneModel.OneModel",
        "model": {
            "_target_": "src.models.modules.resnet_cifar.ResNet18"
        },
    },
    "optimizer": {
        "_target_": "torch.optim.AdamW",
        "lr": 0.01
    },
    "datamodule": {
        "_target_": "src.datamodules.datamodules.CIFAR100DataModule",
        "data_dir": data_dir,
        "batch_size": 320,
        "num_workers": 4,
        "pin_memory": True,
        "shuffle": True,
        "trainset_data_aug": False,
        # This is the irreducible loss model training, so we train on the
        # holdout set (we call this set the "valset" in the global terminology for the dataset
        # splits). Thus, we need augmentation on the valset
        "valset_data_aug": True,
    },
    "callbacks": {
        # We want to save that irreducible loss model with the lowest validation
        # loss (we validate on the "trainset", in global terminology for the
        # dataset splits).
        "model_checkpoint": {
            "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
            "monitor": "val_loss_epoch",
            "mode": "min",
            "save_top_k": 1,
            "save_last": True,
            "verbose": False,
            "dirpath": os.path.join("tutorial_outputs", "irreducible_loss_model"),
            "filename": "epoch_{epoch:03d}",
            "auto_insert_metric_name": False,
        },
    },
    "logger": {
        # Log with wandb, you could choose a different logger
        "wandb": {
            "_target_": "pytorch_lightning.loggers.wandb.WandbLogger",
            "project": my_project,
            "save_dir": ".",
            "entity": my_entity,
            "job_type": "train",
        }
    },
    "seed": 12,
    "debug": False,
    "ignore_warnings": True,
    "test_after_training": True,
    "base_outdir": "logs",
}

In [4]:
# convert config to OmegaConf structured dict (default for Hydra), and pretty-print
config = OmegaConf.create(config)
utils.print_config(
    config,
    fields=(
        "trainer",
        "model",
        "datamodule",
        "callbacks",
        "logger",
        "seed",
        "optimizer",
        "scheduler",
    ),
    resolve=True,
)

## Training

We have split all datasets into a "trainset", a "valset", and a "testset".
* The "valset" is what we call the "holdout set" in the paper; it is used for training of the irreducible loss model.
* The "trainset" is what we call the training set in the paper. It is used for training the target model, and as a validation set for the irreducible loss model, to find the epoch with the lowest loss in irreducible loss model training.
* The "testset" is used to evaluate target model performance.

(For earlier experiments, we used the "valset" as the validation set for the target model. However, this was rarely used, as we did not tune any hyperparameters of the target model.)

In [5]:
# Set seed for random number generators in pytorch, numpy and python.random
if "seed" in config:
    seed_everything(config.seed, workers=True)

# Init lightning datamodule
print(f"Instantiating datamodule <{config.datamodule._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
datamodule.setup()

# Init lightning model
print(f"Instantiating model <{config.model._target_}>")
pl_model: LightningModule = hydra.utils.instantiate(
    config=config.model,
    optimizer_config=utils.mask_config(
        config.get("optimizer", None)
    ),  # When initialising the optimiser, you need to pass it the model parameters. As we haven't initialised the model yet, we cannot initialise the optimizer here. Thus, we need to pass-through the optimizer-config, to initialise it later. However, hydra.utils.instantiate will instatiate everything that looks like a config (if _recursive_==True, which is required here bc OneModel expects a model argument). Thus, we "mask" the optimizer config from hydra, by modifying the dict so that hydra no longer recognises it as a config.
    scheduler_config=utils.mask_config(
        config.get("scheduler", None)
    ),  # see line above
    datamodule=datamodule,
    _convert_="partial",
)

# Init lightning callbacks. Here, we only use one callback: saving the model
# with the lowest validation set loss. 
callbacks: List[Callback] = []
if "callbacks" in config:
    for _, cb_conf in config.callbacks.items():
        if "_target_" in cb_conf:
            print(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

# Init lightning loggers. Here, we use wandb.
logger: List[LightningLoggerBase] = []
if "logger" in config:
    for _, lg_conf in config.logger.items():
        if "_target_" in lg_conf:
            print(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

# Init lightning trainer
print(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
    config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)

# Send config to all lightning loggers
print("Logging hyperparameters!")
trainer.logger.log_hyperparams(config)

# Train the model.
print("Starting training!")
trainer.fit(
    pl_model,
    train_dataloaders=datamodule.val_dataloader(), # see Markdown comment above
    val_dataloaders=datamodule.train_dataloader(), # see Markdown comment above
)

# Evaluate model on test set, using the best model achieved during training
if config.get("test_after_training") and not config.trainer.get("fast_dev_run"):
    print("Starting testing!")
    trainer.test(test_dataloaders=datamodule.test_dataloader())

def evaluate_and_save_model_from_checkpoint_path(checkpoint_path, name):
    """Compute irreducible loss for the whole trainset with the best model"""

    # load best model
    model = OneModel.load_from_checkpoint(checkpoint_path)
    
    # compute irreducible losses
    model.eval()
    irreducible_loss_and_checks = utils.compute_losses_with_sanity_checks(
        dataloader=datamodule.train_dataloader(), model=model
    )

    # save irred losses in same directory as model checkpoint
    path = os.path.join(
        os.path.dirname(trainer.checkpoint_callback.best_model_path),
        name,
    )
    torch.save(irreducible_loss_and_checks, path)

    return path

saved_path = evaluate_and_save_model_from_checkpoint_path(
    trainer.checkpoint_callback.best_model_path, "irred_losses_and_checks.pt"
)

print(f"Using monitor: {trainer.checkpoint_callback.monitor}")

# Print path to best checkpoint
print(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")
print(f"Best checkpoint irred_losses_path:\n{saved_path}")


# Make sure everything closed properly
log.info("Finalizing!")
utils.finish(
    config=config,
    model=pl_model,
    datamodule=datamodule,
    trainer=trainer,
    callbacks=callbacks,
    logger=logger,
)

Global seed set to 12


Instantiating datamodule <src.datamodules.datamodules.CIFAR100DataModule>
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Instantiating model <src.models.OneModel.OneModel>


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Instantiating callback <pytorch_lightning.callbacks.ModelCheckpoint>
Instantiating logger <pytorch_lightning.loggers.wandb.WandbLogger>
Instantiating trainer <pytorch_lightning.Trainer>
Logging hyperparameters!


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrohanb21-indian-institute-of-technology-kanpur[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.19.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Starting training!


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
Global seed set to 12


Training: -1it [00:00, ?it/s]

  if isinstance(self.datamodule, CINIC10RelevanceDataModule):


Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Starting testing!


  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc_epoch': 0.45980000495910645, 'test_loss_epoch': 2.2732691764831543}
--------------------------------------------------------------------------------
Computing irreducible loss full training dataset.




Using monitor: val_loss_epoch
Best checkpoint path:
/users/btech/rohanb21/data/rohanb21/RHO-Loss-main/tutorial_outputs/irreducible_loss_model/epoch_020.ckpt
Best checkpoint irred_losses_path:
/users/btech/rohanb21/data/rohanb21/RHO-Loss-main/tutorial_outputs/irreducible_loss_model/irred_losses_and_checks.pt


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss_step,0.05316
train_acc_step,0.98438
epoch,119.0
trainer/global_step,9480.0
_runtime,2038.0
_timestamp,1746181836.0
_step,429.0
val_loss_epoch,3.63661
val_acc_epoch,0.5086
train_loss_epoch,0.06155


0,1
train_loss_step,█▇▇▅▅▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc_step,▁▁▂▃▃▄▅▅▆▇▇▇▇███████████████████████████
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss_epoch,█▆▅▃▁▂▁▁▂▃▃▃▄▄▄▅▅▆▅▅▆▅▅▆▆▆▆▆▆▅▅▆▆▆▆▆▆▆▆▅
val_acc_epoch,▁▂▃▅▆▆▇▇▇▇▇▇█████▇██████████▇███████████
train_loss_epoch,█▇▆▅▄▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁


# Target model training

## Create config

In [3]:
# where to download the datasets
data_dir = "./dataset/"

# where to upload the weights and biases logs
my_project = "CS772-CIFAR100"
my_entity = "rohanb21-indian-institute-of-technology-kanpur"

# You could choose any selection method here
# ("src.curricula.selection_methods.xyz"). We have implemented:
# reducible_loss_selection, uniform_selection, ce_loss_selection,
# irreducible_loss_selection, gradnorm_ub_selection, and thers
selection_method = "reducible_loss_selection"

# Path to irreducible losses. Transferred from irreducible loss model training.
# You can replace this with the path if you want to run target model training
# without rerunning irreducible loss model training.
path_to_irreducible_losses = "./tutorial_outputs/irreducible_loss_model/irred_losses_and_checks.pt"

In [4]:
config = {
    "model": {
        "_target_": "src.models.MultiModels.MultiModels",
        "large_model": {
            "_target_": "src.models.modules.resnet_cifar.ResNet18"
        },
        "percent_train": 0.1,
    },
    "optimizer": {
        "_target_": "torch.optim.AdamW",
        "lr": 0.001
    },
    "trainer": {
        "_target_": "pytorch_lightning.Trainer",
        "gpus": 1,
        "min_epochs": 1,
        "max_epochs": 175,
        "weights_summary": None,
        "progress_bar_refresh_rate": 20,
    },
    "datamodule": {
        "_target_": "src.datamodules.datamodules.CIFAR100DataModule",
        "data_dir": data_dir,
        "batch_size": 320,
        "num_workers": 4,
        "pin_memory": True,
        "shuffle": True,
        "trainset_data_aug": True,
        "valset_data_aug": False,
    },
    "selection_method": {
        "_target_": "src.curricula.selection_methods." + selection_method
    },
    "callbacks": {
        "model_checkpoint": {
            "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
            "monitor": "val_acc_epoch",
            "mode": "max",
            "save_top_k": 1,
            "save_last": True,
            "verbose": False,
            "dirpath": os.path.join("tutorial_outputs", "target_model"),
            "filename": "epoch_{epoch:03d}",
            "auto_insert_metric_name": False,
        },
    },
    "logger": {
        "wandb": {
            "_target_": "pytorch_lightning.loggers.wandb.WandbLogger",
            "project": my_project,
            "save_dir": ".",
            "entity": my_entity,
            "job_type": "train",
        }
    },
    "irreducible_loss_generator": {
        "_target_": "torch.load",
        "f": path_to_irreducible_losses,
    },


    "debug": False,
    "ignore_warnings": True,
    "test_after_training": True,
    "seed": 12,
    "eval_set": "val", # set to test if you want to evaluate on the test set
}



In [5]:
# convert config to OmegaConf structured dict (default for Hydra), and pretty-print
config = OmegaConf.create(config)
utils.print_config(
    config,
    fields=(
        "trainer",
        "selection_method",
        "model",
        "irreducible_loss_generator",
        "datamodule",
        "callbacks",
        "logger",
        "seed",
        "optimizer",
    ),
    resolve=True,
)

## Training

In [6]:
# Set seed for random number generators in pytorch, numpy and python.random
if "seed" in config:
    seed_everything(config.seed, workers=True)

Global seed set to 12


In [7]:
# init irreducible loss generator (precomputed losses, or irreducible loss
# model)
irreducible_loss_generator = hydra.utils.instantiate(
    config.irreducible_loss_generator
)

In [8]:
# If precomputed losses are used, verify that the sorting
# of the precomputes losses matches the dataset
if type(irreducible_loss_generator) is dict:
    # instantiate a separate datamodule, so that the main datamodule is
    # instantiated with the same random seed whether or not the precomputed
    # losses are used
    datamodule_temp = hydra.utils.instantiate(config.datamodule)
    datamodule_temp.setup()
    utils.verify_correct_dataset_order(
        dataloader=datamodule_temp.train_dataloader(),
        sorted_target=irreducible_loss_generator["sorted_targets"],
        idx_of_control_images=irreducible_loss_generator["idx_of_control_images"],
        control_images=irreducible_loss_generator["control_images"],
        dont_compare_control_images=config.datamodule.get(
            "trainset_data_aug", False
        ),  # cannot compare images from irreducible loss model training run with those of the current run if there is trainset augmentation
    )

    del datamodule_temp

    irreducible_loss_generator = irreducible_loss_generator["irreducible_losses"]

    # Set seed again, so that the main datamodule is instantiated with the
    # same random seed whether or not the precomputed losses are used
    if "seed" in config:
        seed_everything(config.seed, workers=True)

# Init lightning datamodule
print(f"Instantiating datamodule <{config.datamodule._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
datamodule.setup()

# init selection method
print(f"Instantiating selection method <{config.selection_method._target_}>")
selection_method = hydra.utils.instantiate(config.selection_method)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Verifying that the dataset order is compatible with the order of the precomputed losses.


Global seed set to 12


Instantiating datamodule <src.datamodules.datamodules.CIFAR100DataModule>
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Instantiating selection method <src.curricula.selection_methods.reducible_loss_selection>


In [12]:

# Init lightning model
print(f"Instantiating models")
pl_model: LightningModule = hydra.utils.instantiate(
    config.model,
    selection_method=selection_method,
    irreducible_loss_generator=irreducible_loss_generator,
    datamodule=datamodule,
    optimizer_config=utils.mask_config(
        config.get("optimizer", None)
    ),  # When initialising the optimiser, you need to pass it the model parameters. As we haven't initialised the model yet, we cannot initialise the optimizer here. Thus, we need to pass-through the optimizer-config, to initialise it later. However, hydra.utils.instantiate will instatiate everything that looks like a config (if _recursive_==True, which is required here bc OneModel expects a model argument). Thus, we "mask" the optimizer config from hydra, by modifying the dict so that hydra no longer recognises it as a config.
    _convert_="partial",
)

# Init lightning callbacks
callbacks: List[Callback] = []
if "callbacks" in config:
    for _, cb_conf in config.callbacks.items():
        if "_target_" in cb_conf:
            print(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

# Init lightning loggers
logger: List[LightningLoggerBase] = []
if "logger" in config:
    for _, lg_conf in config.logger.items():
        if "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

# Init lightning trainer
print(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
    config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)

# Send config to all lightning loggers
print("Logging hyperparameters!")
trainer.logger.log_hyperparams(config)

# create eval set
if config.eval_set == "val":
    val_dataloader = datamodule.val_dataloader()
elif config.eval_set == "test":
    val_dataloader = datamodule.test_dataloader()
    print(
        "Using the test set as the validation dataloader. This is for final figures in the paper"
    )

# Train the model
print("Starting training!")
trainer.fit(
    pl_model,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=val_dataloader, # we pass the eval set as the validation set to trainer.fit because we want to know the eval set accuracy after each epoch
)

# Make sure everything closed properly
print("Finalizing!")
utils.finish(
    config=config,
    model=pl_model,
    datamodule=datamodule,
    trainer=trainer,
    callbacks=callbacks,
    logger=logger,
)

# Print path to best checkpoint
print(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")

Instantiating models


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Instantiating callback <pytorch_lightning.callbacks.ModelCheckpoint>
Instantiating trainer <pytorch_lightning.Trainer>
Logging hyperparameters!


[34m[1mwandb[0m: wandb version 0.19.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Starting training!


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 12


Training: -1it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Finalizing!


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss_step,0.1662
train_acc_step,0.9375
selected_irreducible_loss_2.5_step,9e-05
selected_irreducible_loss_25_step,0.00033
selected_irreducible_loss_50_step,0.00285
selected_irreducible_loss_75_step,0.00987
selected_irreducible_loss_97.5_step,1.64466
not_selected_irreducible_loss_2.5_step,0.01384
not_selected_irreducible_loss_25_step,0.38919
not_selected_irreducible_loss_50_step,1.71552


0,1
train_loss_step,██▇▇█▆▅▅▅▄▄▄▄▃▃▃▃▂▃▂▃▁▂▁▂▁▂▁▁▂▂▂▂▂▂▂▁▁▁▁
train_acc_step,▁▁▁▁▁▁▂▃▄▄▄▄▅▅▄▆▅▆▅▇▅█▇█▆▇▇██▇▇▇▇▇█▇█▇▇█
selected_irreducible_loss_2.5_step,▁▁▅▂▁█▃▅▂▄▅▃▃▃▁▂▂▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
selected_irreducible_loss_25_step,▂▁▄▃▄█▅▄▃▄▄█▄▅▄▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
selected_irreducible_loss_50_step,▂▂▅▄█▇▆▅▅▆▇▇▆▅▇▂▇▃▄▂▅▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
selected_irreducible_loss_75_step,▂▂▃▃█▅▅▅▆▅▆▅▇▅▇▅▅▆▆▂▅▁▃▁▃▁▂▁▁▂▄▂▂▁▂▁▁▁▁▁
selected_irreducible_loss_97.5_step,▁▂▂▃▅▃▄▅▅▄▃▄▅▅▄▅▅▅█▅▅▂▄▂▅▄▃▁▂▅▇▅▆▅▅▅▂▂▁▂
not_selected_irreducible_loss_2.5_step,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▂▂▁▃▁▃▄▁▄▃▁▄▄█▃▂▃▃
not_selected_irreducible_loss_25_step,▇▅▂▃▁▂▃▂▄▃▃▃▂▂▁▅▂▆▃▂▄▄▅▃▄▄▃▂▄▃▄▅▃▇▄█▄▅▆▄
not_selected_irreducible_loss_50_step,▇▅▅▃▁▃▄▁▄▄▄▃▆▄▃▇▆▅▅▂▅▃▆▂▄▃▄▁▂▃▃▄▃▆▄▅▅▄█▄


Best checkpoint path:
/users/btech/rohanb21/data/rohanb21/RHO-Loss-main/tutorial_outputs/target_model/epoch_172.ckpt


## Retreiving model after training

In [9]:
# Instantiate args from config (but not the model itself)
model_args = dict(
    selection_method=selection_method,
    irreducible_loss_generator=irreducible_loss_generator,
    datamodule=datamodule,
    optimizer_config=utils.mask_config(config.get("optimizer", None)),
)

ModelClass = hydra.utils.get_class(config.model._target_)

# Now load from checkpoint
pl_model = ModelClass.load_from_checkpoint(
    checkpoint_path="./tutorial_outputs/target_model/epoch_172.ckpt",
    **model_args
)

In [10]:
pl_model

MultiModels(
  (large_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   

In [11]:
# Init lightning callbacks
callbacks: List[Callback] = []
if "callbacks" in config:
    for _, cb_conf in config.callbacks.items():
        if "_target_" in cb_conf:
            print(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

# Init lightning loggers
logger: List[LightningLoggerBase] = []
if "logger" in config:
    for _, lg_conf in config.logger.items():
        if "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

# Init lightning trainer
print(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
    config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Instantiating callback <pytorch_lightning.callbacks.ModelCheckpoint>
Instantiating trainer <pytorch_lightning.Trainer>


In [17]:
# Evaluate model on test set using the best model achieved during training
if config.get("test_after_training") and not config.trainer.get("fast_dev_run"):
    print("Starting testing!")
    trainer.test(model=pl_model, dataloaders=datamodule.test_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Starting testing!


[34m[1mwandb[0m: wandb version 0.19.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Testing: 0it [00:00, ?it/s]



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.5778999924659729, 'test_loss': 1.618593692779541}
--------------------------------------------------------------------------------
Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7cf6f4dbf550>> (for post_run_cell), with arguments args (<ExecutionResult object at 7cf79ff8b220, execution_count=17 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 7cf7a9f443d0, raw_cell="# Evaluate model on test set using the best model .." store_history=True silent=False shell_futures=True cell_id=None> result=None>,),kwargs {}:


TypeError: _pause_backend() takes 1 positional argument but 2 were given

In [13]:
if config.get("test_after_training") and not config.trainer.get("fast_dev_run"):
    print("Starting testing!")
    trainer.test(model=pl_model, dataloaders=datamodule.test_dataloader())

    # After testing
    probs = pl_model.test_outputs["probs"]
    labels = pl_model.test_outputs["labels"]
    # You can now use probs and labels for further analysis

  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Starting testing!


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrohanb21-indian-institute-of-technology-kanpur[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.19.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Testing: 0it [00:00, ?it/s]



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.5778999924659729, 'test_loss': 1.618593692779541}
--------------------------------------------------------------------------------


AttributeError: 'MultiModels' object has no attribute 'test_outputs'

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7a0190073430>> (for post_run_cell), with arguments args (<ExecutionResult object at 7a019008b5e0, execution_count=13 error_before_exec=None error_in_exec='MultiModels' object has no attribute 'test_outputs' info=<ExecutionInfo object at 7a019008bfa0, raw_cell="if config.get("test_after_training") and not confi.." store_history=True silent=False shell_futures=True cell_id=None> result=None>,),kwargs {}:


TypeError: _pause_backend() takes 1 positional argument but 2 were given

In [None]:
import torch.nn.functional as F

# Use log probabilities for NLL
log_probs = torch.log(all_probs + 1e-12)  # for numerical stability

# Compute Negative Log-Likelihood
nll = F.nll_loss(log_probs, true_labels)
print(f"NLL (Single Model): {nll:.4f}")