# Out-of-Distribution (OOD) Data Analysis with PyTorch Lightning

In this notebook, we evaluate a trained uncertainty model on both in-domain and out-of-distribution (OOD) data. We use PyTorch Lightning, which handles device placement, precision settings, and the validation loop that would otherwise be written by hand in plain PyTorch.

In [None]:
import torch
import torch.multiprocessing
from pytorch_lightning import Trainer
from src.datasets.ood_dataset import OodDataset
from src.datamodules.vienna_datamodule import ViennaDataModule
from src.datamodules.ebrain_datamodule import EbrainDataModule
from src.models.uncertainty_module import UncertaintyModule
from src.models.slides_module import ViT8

# Share tensors between DataLoader worker processes through the file system
# instead of file descriptors; this avoids "too many open files" errors when
# many workers pass tensors back to the main process.
torch.multiprocessing.set_sharing_strategy("file_system")


## Loading the Model Checkpoint

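Here we load the trained model from a Lightning checkpoint. `load_from_checkpoint` restores the weights saved at epoch 31 of the training run (the `epoch_031.ckpt` file), together with any hyperparameters the module stored via `save_hyperparameters()`.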


In [None]:
# Checkpoint from the 2023-05-24 training run (absolute path on the lab cluster)
checkpoint_path = "/n/data2/hms/dbmi/kyu/lab/raa006/pathology_uncertainty/logs/train/runs/2023-05-24_16-22-58/checkpoints/epoch_031.ckpt"
model = UncertaintyModule.load_from_checkpoint(checkpoint_path)
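The checkpoint filename in this run follows an `epoch_NNN.ckpt` pattern. When a run directory holds several checkpoints, a small helper like the hypothetical `latest_epoch_checkpoint` below can pick the one with the highest epoch. It is a sketch that assumes checkpoint names match that pattern and simply skips anything else (such as a `last.ckpt` written by Lightning's `ModelCheckpoint` with `save_last=True`).

```python
import re
from pathlib import PurePosixPath
from typing import List, Optional

# Matches names like "epoch_031.ckpt" (the pattern used by this run's checkpoints).
EPOCH_PATTERN = re.compile(r"^epoch_(\d+)\.ckpt$")


def epoch_of(path: str) -> Optional[int]:
    """Return the epoch encoded in a checkpoint filename, or None if it does not match."""
    match = EPOCH_PATTERN.match(PurePosixPath(path).name)
    return int(match.group(1)) if match else None


def latest_epoch_checkpoint(paths: List[str]) -> Optional[str]:
    """Pick the checkpoint with the highest epoch; non-matching files are skipped."""
    candidates = [(epoch_of(p), p) for p in paths if epoch_of(p) is not None]
    if not candidates:
        return None
    # Tuples compare by epoch first, so max() returns the latest epoch's path.
    return max(candidates)[1]


print(latest_epoch_checkpoint([
    "checkpoints/epoch_009.ckpt",
    "checkpoints/epoch_031.ckpt",
    "checkpoints/last.ckpt",
]))
# → checkpoints/epoch_031.ckpt
```

The returned path could then be passed to `UncertaintyModule.load_from_checkpoint` in place of the hard-coded one.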

## Data Module Initialization

We prepare one data module for in-domain data and one for out-of-distribution data. `ViennaDataModule` supplies the in-domain data, and `EbrainDataModule` supplies the OOD data; passing `extra_ood=True` includes its extra OOD samples.


In [None]:
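# In-domain data (Vienna)
vienna_datamodule = ViennaDataModule(batch_size=32, num_workers=6)
vienna_datamodule.setup()

# Out-of-distribution data (Ebrain), including the extra OOD samples
ood_datamodule = EbrainDataModule(batch_size=1, num_workers=6, extra_ood=True)
ood_datamodule.setup()

# Attach the OOD data module to the model. trainer.validate below only receives
# the Vienna data module, so the model presumably evaluates OOD data through this attribute.
model.ood_datamodule = ood_datamodule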
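Comparing in-domain and OOD data usually comes down to asking whether the model's uncertainty is higher on OOD samples. One common summary, shown here as an assumption rather than what `UncertaintyModule` necessarily computes, is the AUROC of a per-sample uncertainty score used as an OOD detector: the probability that a randomly chosen OOD sample scores higher than a randomly chosen in-domain sample, with ties counted as half. The function name `ood_auroc` and the score lists are hypothetical toy values for illustration.

```python
from typing import Sequence


def ood_auroc(in_scores: Sequence[float], ood_scores: Sequence[float]) -> float:
    """AUROC of an uncertainty score for separating OOD (positive) from in-domain (negative).

    Computed from the pairwise definition: the fraction of (in-domain, OOD)
    pairs in which the OOD sample has the higher score, ties counting 0.5.
    This is O(n * m), which is fine for a small sketch.
    """
    if not in_scores or not ood_scores:
        raise ValueError("both score lists must be non-empty")
    wins = 0.0
    for o in ood_scores:
        for i in in_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))


# Toy uncertainty scores (e.g. predictive entropy); not real results.
in_domain = [0.10, 0.20, 0.35, 0.40]
ood = [0.30, 0.55, 0.70]
print(ood_auroc(in_domain, ood))
# → 0.8333333333333334  (10 of 12 pairs ranked correctly)
```

A value of 0.5 means the score does not separate the two distributions, and 1.0 means every OOD sample scores higher than every in-domain sample. For large evaluation sets, a rank-based implementation such as scikit-learn's `roc_auc_score` avoids the quadratic loop.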


## PyTorch Lightning Trainer Setup

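We set up the Lightning `Trainer` on a single GPU with 16-bit mixed precision, which reduces memory use and speeds up evaluation. No training happens in this notebook; the trainer is only used for validation.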


In [None]:
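# Single GPU, 16-bit mixed precision. The older `gpus=1` argument was removed in
# Lightning 2.0; `accelerator`/`devices` also works in recent 1.x releases.
trainer = Trainer(accelerator="gpu", devices=1, precision=16)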
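Setting `precision=16` runs evaluation in mixed half precision. Python's standard `struct` module can round-trip a value through IEEE 754 binary16 (the `'e'` format code), which shows what that format gives up: roughly three significant decimal digits, a largest finite value of 65504, and underflow to zero for very small magnitudes. That underflow is why mixed-precision *training* uses gradient scaling; evaluation computes no gradients, so it does not need a scaler. The helper name `to_float16` is illustrative.

```python
import struct


def to_float16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision (binary16)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


print(to_float16(0.1))      # → 0.0999755859375 (nearest float16 to 0.1)
print(to_float16(65504.0))  # → 65504.0 (largest finite float16)
print(to_float16(1e-8))     # → 0.0 (below half the smallest subnormal, ~6e-8)

try:
    to_float16(1e5)
except OverflowError as err:
    # struct refuses values that are too large for binary16
    print("overflow:", err)
```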

## Model Evaluation

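We run validation once per seed so that seed-dependent results can be compared for consistency. As written, the seed list contains only seed 1; add more seeds to compare runs. `trainer.validate` receives the in-domain Vienna data module, and the OOD evaluation presumably runs through the `ood_datamodule` attached to the model earlier.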


In [None]:
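# The Trainer places the model on the GPU itself during validate().
for seed in [1]:  # extend this list, e.g. [1, 2, 3], to compare seeds
    # Store the seed both on the model and in its hyperparameters, where the
    # module's evaluation code presumably reads it.
    model.seed = seed
    model.hparams.seed = seed

    trainer.validate(model, vienna_datamodule)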
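To actually assess consistency across seeds, the loop needs more than one seed and a way to summarize the per-seed metrics. `trainer.validate` returns a list of dictionaries of logged metrics, so one option is to store a chosen metric per seed and report its mean and standard deviation. The sketch below assumes such a dictionary has already been collected; the function name `summarize_across_seeds` and the metric values are hypothetical placeholders, not results.

```python
import statistics
from typing import Dict, Tuple


def summarize_across_seeds(metric_by_seed: Dict[int, float]) -> Tuple[float, float]:
    """Return (mean, sample standard deviation) of one metric over seeds.

    The sample standard deviation needs at least two seeds; with a single
    seed it is reported as 0.0.
    """
    values = list(metric_by_seed.values())
    if not values:
        raise ValueError("no seeds to summarize")
    mean = statistics.mean(values)
    std = statistics.stdev(values) if len(values) > 1 else 0.0
    return mean, std


# Hypothetical per-seed values of some validation metric; not real results.
val_metric = {1: 0.80, 2: 0.82, 3: 0.81}
mean, std = summarize_across_seeds(val_metric)
print(f"mean={mean:.3f} std={std:.3f}")
# → mean=0.810 std=0.010
```

Inside the loop, each entry could come from `results = trainer.validate(model, vienna_datamodule)` as `results[0][metric_name]`, where `metric_name` is whichever key the module logs.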
