In [3]:
%pip install hydra-core
%pip install pytorch-lightning

Collecting pytorch-lightning
  Using cached pytorch_lightning-2.3.3-py3-none-any.whl (812 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Using cached torchmetrics-1.4.0.post0-py3-none-any.whl (868 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Using cached lightning_utilities-0.11.3.post0-py3-none-any.whl (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2.0.0->pytorch-lightning)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2.0.0->pytorch-lightning)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2.0.0->pytorch-lightning)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=2.0.0->pytorch-lightning)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-ma

In [7]:
import os
import pdb
from typing import List, Optional

import numpy as np

import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import Logger
import torch
import torch.nn as nn

# Project-local modules: the repository root (the directory containing `src/`)
# must be on sys.path / be the working directory for these to resolve.
from src.utils import utils
from src.models.OneModel import OneModel

log = utils.get_logger(__name__)


In [8]:
# Scratch cell: quick check that the kernel executes code. Has no effect on
# the rest of the notebook.
print("Hi")

Hi


In [9]:
# User-specific settings. Each value can be overridden through an environment
# variable so the notebook does not have to be edited per machine; the
# fallbacks are the original placeholder values.

# where to download the datasets (override with the DATA_DIR env var)
data_dir = os.environ.get("DATA_DIR", "/path/to/dir/")

# where to upload the weights and biases logs
# (override with the WANDB_PROJECT / WANDB_ENTITY env vars)
my_project = os.environ.get("WANDB_PROJECT", "tutorial_notebook")
my_entity = os.environ.get("WANDB_ENTITY", "xyz")

In [10]:
# Full experiment configuration in Hydra style: every sub-dict carrying a
# "_target_" key is later turned into an object by hydra.utils.instantiate.
config = {
    "trainer": {
        "_target_": "pytorch_lightning.Trainer",
        # `gpus`, `weights_summary` and `progress_bar_refresh_rate` are no longer
        # accepted by the Lightning 2.x Trainer. Hardware is selected through
        # accelerator/devices, the summary through enable_model_summary, and the
        # refresh rate through the TQDMProgressBar callback configured below.
        "accelerator": "gpu",
        "devices": 1,
        "min_epochs": 1,
        "max_epochs": 100,
        "enable_model_summary": False,
    },
    "model": {
        "_target_": "src.models.OneModel.OneModel",
        "model": {
            "_target_": "src.models.modules.resnet_cifar.ResNet18"
        },
    },
    "optimizer": {
        "_target_": "torch.optim.AdamW",
        "lr": 0.001
    },
    "datamodule": {
        "_target_": "src.datamodules.datamodules.CIFAR10DataModule",
        "data_dir": data_dir,
        "batch_size": 320,
        "num_workers": 4,
        "pin_memory": True,
        "shuffle": True,
        "trainset_data_aug": False,
        # This is the irreducible loss model training, so we train on the
        # holdout set (we call this set the "valset" in the global terminology for the dataset
        # splits). Thus, we need augmentation on the valset
        "valset_data_aug": True,
    },
    "callbacks": {
        # We want to save that irreducible loss model with the lowest validation
        # loss (we validate on the "trainset", in global terminology for the
        # dataset splits).
        "model_checkpoint": {
            "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
            "monitor": "val_loss_epoch",
            "mode": "min",
            "save_top_k": 1,
            "save_last": True,
            "verbose": False,
            "dirpath": os.path.join("tutorial_outputs", "irreducible_loss_model"),
            "filename": "epoch_{epoch:03d}",
            "auto_insert_metric_name": False,
        },
        # Keeps the progress-bar refresh rate of 20 that used to be passed to
        # the Trainer as `progress_bar_refresh_rate`.
        "progress_bar": {
            "_target_": "pytorch_lightning.callbacks.TQDMProgressBar",
            "refresh_rate": 20,
        },
    },
    "logger": {
        # Log with wandb, you could choose a different logger
        "wandb": {
            "_target_": "pytorch_lightning.loggers.wandb.WandbLogger",
            "project": my_project,
            "save_dir": ".",
            "entity": my_entity,
            "job_type": "train",
        }
    },
    "seed": 12,
    "debug": False,
    "ignore_warnings": True,
    "test_after_training": True,
    "base_outdir": "logs",
}

In [12]:
# Convert the plain dict into an OmegaConf DictConfig (the config type Hydra
# works with) and pretty-print it as YAML so the effective settings are visible.
config = OmegaConf.create(config)
print(OmegaConf.to_yaml(config))

In [13]:
# Set seed for random number generators in pytorch, numpy and python.random
if "seed" in config:
    seed_everything(config.seed, workers=True)

# Init lightning datamodule
print(f"Instantiating datamodule <{config.datamodule._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
datamodule.setup()

# Init lightning model.
# The optimizer can only be built once the model parameters exist, so the
# optimizer/scheduler configs are passed through and instantiated later.
# hydra.utils.instantiate would otherwise instantiate everything that looks
# like a config (recursive instantiation is required here because OneModel
# expects an instantiated `model` argument), so utils.mask_config "masks" these
# configs by modifying them such that Hydra no longer recognises them as configs.
print(f"Instantiating model <{config.model._target_}>")
pl_model: LightningModule = hydra.utils.instantiate(
    config=config.model,
    optimizer_config=utils.mask_config(config.get("optimizer", None)),
    scheduler_config=utils.mask_config(config.get("scheduler", None)),
    datamodule=datamodule,
    _convert_="partial",
)

# Init lightning callbacks: every entry under config.callbacks that has a
# "_target_" (e.g. the checkpoint callback that keeps the model with the
# lowest validation loss).
callbacks: List[Callback] = []
if "callbacks" in config:
    for _, cb_conf in config.callbacks.items():
        if "_target_" in cb_conf:
            print(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

# Init lightning loggers. Here, we use wandb.
# `Logger` is the class imported at the top of the notebook; the previously
# used name `LightningLoggerBase` is not imported and raised a NameError.
logger: List[Logger] = []
if "logger" in config:
    for _, lg_conf in config.logger.items():
        if "_target_" in lg_conf:
            print(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

# Init lightning trainer
print(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
    config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)

# Send config to all lightning loggers. `trainer.logger` is only the first
# logger, so iterate over `trainer.loggers`; the DictConfig is converted to a
# plain (resolved) dict so every logger receives standard Python containers.
print("Logging hyperparameters!")
hparams = OmegaConf.to_container(config, resolve=True)
for lightning_logger in trainer.loggers:
    lightning_logger.log_hyperparams(hparams)

# Train the model.
# The irreducible loss model is trained on the holdout set (called the
# "valset" in the global dataset-split terminology) and validated on the
# "trainset", hence the deliberately swapped dataloaders.
print("Starting training!")
trainer.fit(
    pl_model,
    train_dataloaders=datamodule.val_dataloader(),
    val_dataloaders=datamodule.train_dataloader(),
)

# Evaluate model on test set, using the best model achieved during training
if config.get("test_after_training") and not config.trainer.get("fast_dev_run"):
    print("Starting testing!")
    # `test_dataloaders=` is no longer a valid keyword; ckpt_path="best" makes
    # the use of the best checkpoint explicit.
    trainer.test(dataloaders=datamodule.test_dataloader(), ckpt_path="best")

def evaluate_and_save_model_from_checkpoint_path(checkpoint_path, name):
    """Compute the irreducible losses on the trainset with a saved model.

    Loads a ``OneModel`` from ``checkpoint_path``, runs
    ``utils.compute_losses_with_sanity_checks`` over
    ``datamodule.train_dataloader()`` (``datamodule`` is the notebook-level
    datamodule) and stores the result with ``torch.save``.

    Args:
        checkpoint_path: Path of the model checkpoint to load.
        name: File name under which the computed losses are saved.

    Returns:
        Path of the saved file, located in the same directory as
        ``checkpoint_path``.
    """

    # load best model
    model = OneModel.load_from_checkpoint(checkpoint_path)

    # compute irreducible losses
    model.eval()
    irreducible_loss_and_checks = utils.compute_losses_with_sanity_checks(
        dataloader=datamodule.train_dataloader(), model=model
    )

    # Save the losses next to the checkpoint that was actually loaded. Deriving
    # the directory from the argument (rather than from the global trainer's
    # best_model_path) keeps the result correct for any checkpoint passed in.
    path = os.path.join(os.path.dirname(checkpoint_path), name)
    torch.save(irreducible_loss_and_checks, path)

    return path

# Compute the irreducible losses with the best checkpoint of this run and
# store them alongside that checkpoint.
checkpoint_cb = trainer.checkpoint_callback
saved_path = evaluate_and_save_model_from_checkpoint_path(
    checkpoint_cb.best_model_path,
    "irred_losses_and_checks.pt",
)

# Report which metric selected the checkpoint and where the artifacts live.
print(f"Using monitor: {checkpoint_cb.monitor}")
print(f"Best checkpoint path:\n{checkpoint_cb.best_model_path}")
print(f"Best checkpoint irred_losses_path:\n{saved_path}")


# Hand every object created for this run to utils.finish for teardown.
log.info("Finalizing!")
utils.finish(
    config=config,
    model=pl_model,
    datamodule=datamodule,
    trainer=trainer,
    callbacks=callbacks,
    logger=logger,
)

INFO:lightning_fabric.utilities.seed:Seed set to 12


Instantiating datamodule <src.datamodules.datamodules.CIFAR10DataModule>


InstantiationException: Error locating target 'src.datamodules.datamodules.CIFAR10DataModule', set env var HYDRA_FULL_ERROR=1 to see chained exception.
full_key: datamodule