# Minimal example of lightning-crossval

In [1]:
%load_ext watermark
%watermark -vp lightning_cv,torch,lightning,pydantic,sklearn

Python implementation: CPython
Python version       : 3.10.14
IPython version      : 8.24.0

lightning_cv: 0.5.0
torch       : 2.2.2
lightning   : 2.2.4
pydantic    : 2.0.3
sklearn     : 1.4.2



In [2]:
# seed everything for reproducibility
from lightning import seed_everything
seed_everything(111)

Seed set to 111


111

`lightning-crossval` was concurrently developed with the [`Protein Set Transformer`](https://github.com/AnantharamanLab/protein_set_transformer/), so the most thorough example is in that repository. 

This package is very opinionated and integrated with `PyTorch-Lightning` (more specifically, the `Lightning Fabric` subset). It requires the following:

1. All models used are subclasses of `lightning.LightningModule`
2. All models take only a single argument to their `__init__` method called `config`, which is an instance of a `pydantic.BaseModel` subclass. You can get familiar with Pydantic [here](https://docs.pydantic.dev/2.0/usage/models/).
    - This object should hold all the arguments needed to set up the model. [Here](https://github.com/AnantharamanLab/protein_set_transformer/blob/main/src/pst/nn/config.py) is a real example from the PST model.
    - One of the fields of this config must be called `fabric`, which holds the reference to the `lightning.Fabric` object that is used to do all the magic.
        - We have provided a simple Pydantic model called `BaseModelConfig` that can be subclassed to provide this field.

-----

Let's start with a simple toy example. We are going to create a simple binary classifier that uses scaled dot product multihead self-attention.

## Model Config

First, we need to create the model config that has all the fields needed to set up the underlying components, such as the size of the feed-forward linear layers, the number of layers, and the number of attention heads.

NOTE: If you want to customize individual fields of the config, you can assign each one a `pydantic.Field`, which accepts many customization arguments, including default values, min/max values, etc.

In [3]:
from lightning_cv import BaseModelConfig

In [4]:
class ModelConfig(BaseModelConfig):
    num_layers: int
    hidden_size: int
    num_heads: int
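
As a minimal sketch of the `pydantic.Field` customization mentioned in the note above, the config below adds defaults and bounds to the same three fields. The class name `ConstrainedModelConfig` and the specific defaults and bounds are illustrative assumptions. It subclasses plain `pydantic.BaseModel` only so the snippet runs without `lightning_cv` installed; in this notebook you would subclass `BaseModelConfig` instead.

```python
from pydantic import BaseModel, Field, ValidationError


# Assumption: plain BaseModel stands in for lightning_cv.BaseModelConfig here.
class ConstrainedModelConfig(BaseModel):
    # ge / gt add lower bounds that are checked when the config is created
    num_layers: int = Field(default=2, ge=1, description="number of encoder layers")
    hidden_size: int = Field(default=128, gt=0, description="embedding dimension")
    num_heads: int = Field(default=4, gt=0, description="attention heads per layer")


# Unspecified fields fall back to their defaults.
cfg = ConstrainedModelConfig(hidden_size=64)
print(cfg.num_layers, cfg.hidden_size, cfg.num_heads)  # 2 64 4

# Values outside the declared bounds are rejected at construction time.
try:
    ConstrainedModelConfig(num_layers=0)
    rejected = False
except ValidationError:
    rejected = True
print(rejected)  # True
```

Validating at construction time means a bad hyperparameter fails before any model or fold is set up.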

## Model definition

Now let's define our simple classifier, which MUST subclass `lightning.LightningModule`. See the lightning [tutorial](https://lightning.ai/docs/pytorch/LTS/common/lightning_module.html) for all the ways to customize your model.

We provide a mixin class called `CrossValModuleMixin` that handles some internal details related to PyTorch Lightning and Lightning Fabric.

In [5]:
import torch
from lightning import LightningModule
from lightning_cv import CrossValModuleMixin

In [6]:
class BinaryClassifier(CrossValModuleMixin, LightningModule):
    def __init__(self, config: ModelConfig):
        # need to init separately since CrossValModuleMixin needs the config
        LightningModule.__init__(self)
        CrossValModuleMixin.__init__(self, config=config)

        self.config = config

        ####### NN parts
        self.encoder = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                batch_first=True,
            ),
            num_layers=config.num_layers,
        )
        self.classifier = torch.nn.Linear(config.hidden_size, 1)
        ####### NN parts

        # binary cross entropy loss for binary classification
        self.objective_fn = torch.nn.BCEWithLogitsLoss()

    # this is required by pytorch lightning
    def configure_optimizers(self):
        # can define however you want
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward_step(self, x: torch.Tensor) -> torch.Tensor:
        out = self.encoder(x)
        return self.classifier(out).squeeze()

    def forward(self, batch: dict[str, torch.Tensor], stage: str):
        x, y = batch["x"], batch["y"]
        pred = self.forward_step(x)

        loss = self.objective_fn(pred, y.float().squeeze())
        
        if self.config.fabric is None:
            self.log(
                f"{stage}_loss", 
                loss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )

        return loss

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
        return self(batch, stage="train")
    
    def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
        return self(batch, stage="val")

## Datamodule

To load data, you will need a custom subclass of `lightning.LightningDataModule` that implements your cross validation strategy. Additionally, your datamodule **MUST** have a `train_val_dataloaders` method that returns an iterator of tuples, each holding 2 `torch.utils.data.DataLoader` objects: the training and validation dataloaders, respectively.
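
To make the `train_val_dataloaders` contract concrete, here is a rough sketch of a generator that yields one `(train, val)` pair of dataloaders per fold. It is not how `CrossValidationDataModule` is implemented: the free function, the `Fold` alias, and the hard-coded fold indices are assumptions for illustration, and it uses `torch.utils.data.Subset` rather than the index-based `collate_fn` used later in this notebook.

```python
from typing import Iterator, Sequence

import torch
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset

# Hypothetical alias: one (train_indices, val_indices) pair per fold.
Fold = tuple[Sequence[int], Sequence[int]]


def train_val_dataloaders(
    dataset: Dataset, folds: Sequence[Fold], batch_size: int
) -> Iterator[tuple[DataLoader, DataLoader]]:
    """Yield a (train_loader, val_loader) tuple for every fold."""
    for train_idx, val_idx in folds:
        train_loader = DataLoader(
            Subset(dataset, train_idx), batch_size=batch_size, shuffle=True
        )
        val_loader = DataLoader(
            Subset(dataset, val_idx), batch_size=batch_size, shuffle=False
        )
        yield train_loader, val_loader


# Toy data: 6 samples split into 3 folds of 2 validation samples each.
toy = TensorDataset(torch.randn(6, 4), torch.tensor([0, 0, 0, 1, 1, 1]))
folds = [([2, 3, 4, 5], [0, 1]), ([0, 1, 4, 5], [2, 3]), ([0, 1, 2, 3], [4, 5])]

fold_sizes = [
    (len(train.dataset), len(val.dataset))
    for train, val in train_val_dataloaders(toy, folds, batch_size=2)
]
print(fold_sizes)  # [(4, 2), (4, 2), (4, 2)]
```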

We have provided a base class `CrossValidationDataModule` that handles these details and takes as input a cross validation object, such as those provided by the `scikit-learn` library.

We are going to use a grouped k-folds cross validation scheme, but you can define your own custom cross validation splitters that have a `.split` method. Alternatively, you can use pre-existing splitters from `scikit-learn`, which is what we do here.

In [7]:
from sklearn.model_selection import GroupKFold
from lightning_cv import CrossValidationDataModule
from torch.utils.data import DataLoader, Dataset
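
If none of the `scikit-learn` splitters fit, a custom splitter could look like the sketch below. It assumes the datamodule only needs a `.split(X, y, groups)` method that yields `(train_indices, val_indices)` pairs, mirroring the `scikit-learn` convention. The class name `LeaveOneGroupOutSplitter` is hypothetical, and the group labels are copied from the `groups` tensor printed further down in this notebook.

```python
from collections import defaultdict
from typing import Iterator, Optional, Sequence


class LeaveOneGroupOutSplitter:
    """Hypothetical splitter: each distinct group is the validation fold once."""

    def split(
        self,
        X: Sequence,
        y: Optional[Sequence] = None,
        groups: Optional[Sequence[int]] = None,
    ) -> Iterator[tuple[list[int], list[int]]]:
        if groups is None:
            raise ValueError("this splitter requires group labels")

        # collect the sample indices that belong to each group
        members: dict[int, list[int]] = defaultdict(list)
        for idx, group in enumerate(groups):
            members[group].append(idx)

        # hold out one group at a time; train on everything else
        for group in sorted(members):
            val_idx = members[group]
            held_out = set(val_idx)
            train_idx = [i for i in range(len(groups)) if i not in held_out]
            yield train_idx, val_idx


groups = [0, 2, 1, 2, 0, 0, 0, 1, 0, 0, 2, 0, 1, 2, 0, 1]
splits = list(LeaveOneGroupOutSplitter().split(range(len(groups)), groups=groups))
fold_sizes = [(len(train), len(val)) for train, val in splits]
print(fold_sizes)  # [(8, 8), (12, 4), (12, 4)]
```

Because whole groups are held out together, no group contributes samples to both the training and validation side of the same fold.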

### Dataset

First, we need a `torch.utils.data.Dataset`. We are going to use a simple setup where the entire dataset is stored in a single tensor of shape `[Num individual datapoints, embedding dim]`, along with a separate tensor for the binary labels.

In [8]:
class SimpleDataset(Dataset):
    def __init__(self, x: torch.Tensor, y: torch.Tensor, groups: torch.Tensor | None = None):
        self.x = x
        self.y = y
        self.groups = groups

        assert x.shape[0] == y.shape[0]
        assert y.ndim == 1

        if groups is not None:
            assert x.shape[0] == groups.shape[0]
            assert groups.ndim == 1

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return {
            "x": self.x[idx],
            "y": self.y[idx],
            "groups": self.groups[idx] if self.groups is not None else None,
        }
    
    # this is needed since all CV splitters will return indices of each fold
    # rather than the actual data itself
    def collate_fn(self, batch: list[torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor]:
        if isinstance(batch, list):
            batch = torch.tensor(batch)

        return self[batch]

In [9]:
batch_size = 16
hidden_dim = 128
x = torch.randn(batch_size, hidden_dim)
y = torch.cat((
    torch.zeros(batch_size // 2),
    torch.ones(batch_size // 2),
)).long()
groups = torch.randint(0, 3, (batch_size,))

dataset = SimpleDataset(x, y, groups=groups)

In [10]:
y

tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])

In [11]:
groups

tensor([0, 2, 1, 2, 0, 0, 0, 1, 0, 0, 2, 0, 1, 2, 0, 1])

You can use the provided `CrossValidationDataModule` and be generally fine, but you can also subclass this and implement more custom functionality specific to your workflow.

In [12]:
data_module = CrossValidationDataModule(
    dataset=dataset,
    batch_size=4,
    cross_validator=GroupKFold,
    cross_validator_config={"n_splits": groups.max().item() + 1},
    group_attr="groups",
    collate_fn=dataset.collate_fn,
)

## CV Trainer
Now let's put it all together. You can perform cross validation with a `CrossValidationTrainer`. It accepts 2 arguments as input:

1. `model_type`: your model's type, NOT an instance. A separate model needs to be set up for each fold.
2. `config`: a `CrossValidationTrainerConfig`, whose arguments are mostly passed to `lightning.Fabric`.

NOTE: The strategy of this package is to load separate models for each fold (meaning all of the models are in memory at once) to keep each fold synchronized at the epoch level.
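
The epoch-level synchronization described in the note can be pictured as an outer loop over epochs and an inner loop over folds, so every fold finishes epoch `n` before any fold starts epoch `n + 1`. The sketch below only illustrates that loop order with a toy one-parameter "model". `ToyFold`, its update rule, and the targets are made up for the example and say nothing about how `CrossValidationTrainer` is written internally.

```python
from dataclasses import dataclass, field


@dataclass
class ToyFold:
    """Hypothetical stand-in for one fold's model plus its loss history."""

    weight: float = 0.0
    history: list[float] = field(default_factory=list)

    def train_one_epoch(self, target: float, lr: float = 0.25) -> float:
        # one gradient step on the squared error (weight - target) ** 2
        self.weight -= lr * 2 * (self.weight - target)
        loss = (self.weight - target) ** 2
        self.history.append(loss)
        return loss


# All fold models are created up front and stay in memory together.
folds = {fold_idx: ToyFold() for fold_idx in range(3)}
targets = {0: 1.0, 1: 2.0, 2: 3.0}

event_log = []
for epoch in range(2):  # outer loop: epochs
    for fold_idx, fold in folds.items():  # inner loop: folds
        fold.train_one_epoch(targets[fold_idx])
        event_log.append((epoch, fold_idx))

print(event_log)
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```

In a loop of this shape, every fold has a loss for the same epoch at the same time, at the cost of holding one model per fold in memory.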

In [13]:
from lightning_cv import CrossValidationTrainer, CrossValidationTrainerConfig

In [14]:
trainer_cfg = CrossValidationTrainerConfig(max_epochs=5)
cv_trainer = CrossValidationTrainer(
    model_type=BinaryClassifier,
    config=trainer_cfg,
)

Let's create a model config so that the model for each fold has the same setup.

In [15]:
model_cfg = ModelConfig(
    num_layers=2,
    hidden_size=hidden_dim,
    num_heads=4,
)

Then you just need to call the `.train_with_cross_validation` method with the model config and cross validation data module.

This will train a model for each fold and automatically save the PyTorch checkpoints and the logged train/val performance metrics to `./checkpoints`.

In [16]:
cv_trainer.train_with_cross_validation(data_module, model_config=model_cfg)

  | Name            | Type             | Params
-----------------------------------------------------
0 | _forward_module | BinaryClassifier | 1.2 M 
-----------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.745     Total estimated model params size (MB)


COMPLETE: Training 5 epochs completed.



The trainer still holds the trained models, which you can access through the `.fold_manager` dictionary.

For this toy demo, we can compute each fold model's accuracy on the full dataset. Since that includes the samples each model was trained on, we expect a relatively high accuracy for this fake classification problem.

In [17]:
all_data = dataset[:]

per_fold_accuracy = torch.zeros(len(cv_trainer.fold_manager))
with torch.no_grad():
    for fold_idx, fold in cv_trainer.fold_manager.items():
        model: BinaryClassifier = fold.model # type: ignore
        model.eval()

        logits = model.forward_step(all_data["x"])
        pred = (torch.sigmoid(logits) >= 0.5).long()
        accuracy = (pred == all_data["y"]).float().mean().item()
        per_fold_accuracy[fold_idx] = accuracy

per_fold_accuracy, per_fold_accuracy.mean()

(tensor([0.8125, 0.8750, 0.8125]), tensor(0.8333))
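
When reporting cross validation results, it can be useful to quote the spread across folds alongside the mean. Here is a small sketch using only the standard library, with the three per-fold accuracies copied from the output above; the choice of the sample standard deviation as the spread measure is an assumption, not something the package prescribes.

```python
from statistics import mean, stdev

# per-fold accuracies copied from the notebook output above
per_fold_accuracy = [0.8125, 0.8750, 0.8125]

avg = mean(per_fold_accuracy)
spread = stdev(per_fold_accuracy)  # sample standard deviation across folds

print(f"{avg:.4f} +/- {spread:.4f}")  # 0.8333 +/- 0.0361
```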