#### Imports

In [1]:
import os
import hydra
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
import lightning.pytorch as L
import rootutils
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger, WandbLogger

# Locate the project root by searching upward from the notebook's working
# directory (`os.path.abspath('')`) for a `.git` or `pyproject.toml` marker,
# and put that root on the Python path so the `src.*` imports below resolve.
rootutils.setup_root(os.path.abspath(''), indicator=['.git', 'pyproject.toml'], pythonpath=True)

from src.utils import (
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    set_precision,
)

from src.utils import rich_utils

#### Initialize config

In [None]:
# Compose the training config by hand (this is a notebook, so there is no
# `@hydra.main` entry point) with the notebook-specific overrides applied.
with initialize(version_base='1.3', config_path='../configs'):
    cfg = compose(
        config_name='train.yaml',
        return_hydra_config=True,
        overrides=[
            'experiment=train_seat_cls',
            'logger=many_loggers',
            'callbacks.model_summary=None',
            'paths.log_dir=../logs',
            'paths.output_dir=../logs',
        ],
    )
    # Register the composed config on the global HydraConfig singleton.
    # NOTE(review): presumably needed so `${hydra:...}` interpolations resolve
    # outside of a regular Hydra run - confirm against the configs.
    HydraConfig.instance().set_config(cfg)

# Pretty-print the resolved config tree for a visual sanity check.
rich_utils.print_config_tree(cfg)

#### Initialize modules

In [None]:
# Seed all RNGs when a seed is configured. Compare against None rather than
# relying on truthiness so that an explicit `seed: 0` is honoured instead of
# being silently treated as "no seed".
if cfg.get('seed') is not None:
    L.seed_everything(cfg.seed, workers=True)

# Optionally apply the configured float32 matmul precision setting.
if cfg.precision.get('float32_matmul'):
    set_precision(cfg.precision.float32_matmul)

# Instantiate loggers first: callback instantiation is told whether a W&B
# logger is active via the `has_wandb` flag.
loggers: list[Logger] = instantiate_loggers(cfg.get('logger'))
has_wandb = any(isinstance(logger, WandbLogger) for logger in loggers)
callbacks: list[Callback] = instantiate_callbacks(cfg.get('callbacks'), has_wandb=has_wandb)

#### Custom datamodule

In [4]:
from pathlib import Path
from typing import Any, Optional

from sklearn.model_selection import KFold

from torch.utils.data import Subset
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose

from src.data.components.preprocessing.preproc_pipeline_manager import PreprocessingPipeline
from src.data.components.utils import clear_directory


class ClassificationDataModule(LightningDataModule):
    """Image-classification datamodule with k-fold cross-validation.

    The preprocessed `train` and `val` image folders are concatenated into one
    pool, which is split into `k_folds` folds; `current_fold` selects which
    fold is held out for validation. The `test` folder is served separately.
    """

    def __init__(
        self,
        data_dir: str = 'data/',
        preprocessing_pipeline: PreprocessingPipeline = None,
        overwrite_data: bool = False,
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        train_transforms: Compose = None,
        val_test_transforms: Compose = None,
        save_predict_images: bool = False,
        num_classes: int = 2,
        k_folds: int = 5,
        current_fold: int = 0,
    ) -> None:
        """Initialize a `ClassificationDataModule`.

        Args:
            data_dir (str, optional): The data directory. Defaults to 'data/'.
            preprocessing_pipeline (PreprocessingPipeline, optional): Custom preprocessing pipeline. Defaults to None.
            overwrite_data (bool, optional): Re-run preprocessing even if the processed
                directory already exists. Defaults to False.
            batch_size (int, optional): Batch size. Defaults to 64.
            num_workers (int, optional): Number of workers. Defaults to 0.
            pin_memory (bool, optional): Whether to pin memory. Defaults to False.
            train_transforms (Compose, optional): Train split transformations. Defaults to None.
            val_test_transforms (Compose, optional): Validation and test split transformations. Defaults to None.
            save_predict_images (bool, optional): Save images in predict mode? Defaults to False.
            num_classes (int, optional): Number of classes in the dataset. Defaults to 2.
            k_folds (int, optional): Number of cross-validation folds. Defaults to 5.
            current_fold (int, optional): Index of the fold held out for validation. Defaults to 0.
        """
        super().__init__()

        # Expose every init argument through `self.hparams` without sending
        # them to the logger.
        self.save_hyperparameters(logger=False)

        self.dataset = None
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        # Mapping returned by the preprocessing pipeline; `setup` reads the
        # 'train', 'val' and 'test' entries as ImageFolder roots.
        self.preprocessed_data: dict[str, Path] = {}

        self.k_folds = k_folds
        self.current_fold = current_fold
        self.kfold = None
        self.indices = None

    @property
    def num_classes(self) -> int:
        """Get the number of classes.

        Returns:
            int: The number of classes, as passed to the constructor.
        """
        return self.hparams.num_classes

    def prepare_data(self) -> None:
        """Run (or reuse) preprocessing and record where the processed splits live.

        The processed output is expected next to `data_dir`, in a sibling
        directory named `<data_dir name>_processed`. If it already exists it is
        reused unless `overwrite_data` is set, in which case it is removed and
        the pipeline is run again.
        """
        data_path = Path(self.hparams.data_dir)
        output_path = data_path.parent / f'{data_path.name}_processed'

        pipeline = self.hparams.preprocessing_pipeline
        initial_data = {'initial_data': self.hparams.data_dir}

        if output_path.exists() and not self.hparams.overwrite_data:
            # Processed data is already on disk: only look up its location.
            self.preprocessed_data = pipeline.get_processed_data_path(initial_data)
            return

        if output_path.exists():
            # Overwrite requested: drop the stale output before re-running.
            clear_directory(output_path)
            output_path.rmdir()

        self.preprocessed_data = pipeline.run(initial_data)

    def setup(self, stage: Optional[str] = None) -> None:
        """Datamodule setup step.

        For `"fit"`/`"validate"` the merged train+val pool is split with a
        seeded `KFold` and the subsets for `current_fold` are created. For
        `"test"` the test folder is loaded. Other stages set up nothing.

        Args:
            stage (Optional[str], optional): The stage to setup. Either `"fit"`,
            `"validate"`, `"test"`, or `"predict"`. Defaults to None.
        """
        if stage in {'fit', 'validate'}:
            self._ensure_preprocessed_data()

            # Two views over the same files: one with training augmentation and
            # one with evaluation transforms. Both views index the same
            # directories, so a fold index selects the same image in either one.
            # TODO confirm: relies on ImageFolder listing files in a
            # deterministic (sorted) order.
            train_view = self._build_fold_pool(self.hparams.train_transforms)

            # Fall back to the train transforms when no evaluation transforms
            # were supplied, so the validation fold is never left untransformed.
            eval_transforms = self.hparams.val_test_transforms
            if eval_transforms is None:
                eval_transforms = self.hparams.train_transforms
            eval_view = self._build_fold_pool(eval_transforms)

            self.dataset = train_view

            # Fixed random_state keeps the fold assignment identical across
            # datamodule instances, so fold i never overlaps fold j.
            self.kfold = KFold(n_splits=self.k_folds, shuffle=True, random_state=42)
            self.indices = list(range(len(self.dataset)))
            folds = list(self.kfold.split(self.indices))

            # Get train and validation indices for the current fold
            train_idx, val_idx = folds[self.current_fold]
            self.data_train = Subset(train_view, train_idx)
            # The held-out fold is read through the evaluation transforms so
            # validation metrics are not computed on augmented images.
            self.data_val = Subset(eval_view, val_idx)

        if stage == 'test':
            self._ensure_preprocessed_data()
            self.data_test = ImageFolder(
                root=self.preprocessed_data['test'],
                transform=self.hparams.val_test_transforms,
            )

    def train_dataloader(self) -> DataLoader[Any]:
        """Create and return the train dataloader.

        Returns:
            DataLoader[Any]: The train dataloader.
        """
        return self._default_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.

        Returns:
            DataLoader[Any]: The validation dataloader.
        """
        return self._default_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Create and return the test dataloader.

        Returns:
            DataLoader[Any]: The test dataloader.
        """
        return self._default_dataloader(self.data_test, shuffle=False)

    def predict_dataloader(self) -> DataLoader[Any]:
        """Create and return the predict dataloader.

        Not implemented: this currently returns None.

        Returns:
            DataLoader[Any]: The predict dataloader.
        """
        pass

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`.

        Args:
            stage (Optional[str], optional): The stage being torn down. Either `"fit"`,
            `"validate"`, `"test"`, or `"predict"`. Defaults to None.
        """
        pass

    def state_dict(self) -> dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        Returns:
            Dict[Any, Any]: A dictionary containing the datamodule state that you want to save.
        """
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
        `state_dict()`.

        Args:
            state_dict (Dict[str, Any]): The datamodule state returned by `self.state_dict()`.
        """
        pass

    def _ensure_preprocessed_data(self) -> None:
        """Resolve the processed-data locations if `prepare_data` did not run here.

        Lightning calls `prepare_data` on a single process only, so state
        assigned there can be missing when `setup` runs on other processes.
        NOTE(review): assumes `get_processed_data_path` returns the same mapping
        that `run` produced once the processed data is on disk - confirm.
        """
        if not self.preprocessed_data:
            self.preprocessed_data = (
                self.hparams.preprocessing_pipeline.get_processed_data_path(
                    {'initial_data': self.hparams.data_dir}
                )
            )

    def _build_fold_pool(self, transform: Optional[Compose]) -> ConcatDataset:
        """Concatenate the preprocessed train and val folders under one transform.

        Args:
            transform (Optional[Compose]): Transform applied to every image in the pool.

        Returns:
            ConcatDataset: The train folder followed by the val folder.
        """
        return ConcatDataset(
            [
                ImageFolder(root=self.preprocessed_data[split], transform=transform)
                for split in ('train', 'val')
            ]
        )

    def _default_dataloader(
        self, dataset: Dataset, shuffle: bool = False
    ) -> DataLoader[Any]:
        """Create and return a dataloader.

        Args:
            dataset (Dataset): The dataset to use.
            shuffle (bool, optional): Flag for shuffling data. Defaults to False.

        Returns:
            DataLoader[Any]: Pytorch dataloader.
        """
        num_workers = self.hparams.num_workers
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=num_workers,
            pin_memory=self.hparams.pin_memory,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0 (the default here), so enable it only when
            # worker processes actually exist.
            persistent_workers=num_workers > 0,
            shuffle=shuffle,
        )

In [5]:
# Number of cross-validation folds to train.
num_folds = 5

# Build the fold-independent data components once from the Hydra config, in
# the same order as before: preprocessing pipeline, then train transforms,
# then validation/test transforms.
preprocessing_pipeline, train_transforms, val_test_transforms = (
    hydra.utils.instantiate(getattr(cfg.data, component))
    for component in ('preprocessing_pipeline', 'train_transforms', 'val_test_transforms')
)

#### Train k-fold cross-validation

In [None]:
# Metrics of every fold, in fold order, so results can be compared or averaged
# once the loop finishes.
fold_metrics = []

for fold in range(num_folds):
    # Fresh model for each fold so no weights leak between folds.
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    # Fresh callbacks for each fold as well: callback objects may keep state
    # between runs (e.g. early-stopping patience counters or best checkpoint
    # scores), and reusing one instance would carry it into the next fold.
    fold_callbacks: list[Callback] = instantiate_callbacks(
        cfg.get('callbacks'), has_wandb=has_wandb
    )
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=fold_callbacks, logger=loggers
    )

    datamodule = ClassificationDataModule(
        data_dir=cfg.data.data_dir,
        preprocessing_pipeline=preprocessing_pipeline,
        overwrite_data=cfg.data.overwrite_data,
        batch_size=cfg.data.batch_size,
        num_workers=cfg.data.num_workers,
        pin_memory=cfg.data.pin_memory,
        train_transforms=train_transforms,
        val_test_transforms=val_test_transforms,
        save_predict_images=cfg.data.save_predict_images,
        num_classes=cfg.data.num_classes,
        k_folds=num_folds,
        current_fold=fold,
    )

    object_dict = {
        'cfg': cfg,
        'datamodule': datamodule,
        'model': model,
        'callbacks': fold_callbacks,
        'logger': loggers,
        'trainer': trainer,
    }

    if loggers:
        log_hyperparameters(object_dict)

    if cfg.get('train'):
        trainer.fit(model, datamodule=datamodule)

    # `train_metrics` keeps its original meaning (metrics of the most recent
    # fold); a shallow copy is also stored per fold so that earlier folds are
    # not lost when the name is rebound on the next iteration.
    train_metrics = trainer.callback_metrics
    fold_metrics.append(dict(train_metrics))