# Kaggle histopathology introduction
This notebook is an introduction to the data challenge on out-of-distribution classification of histopathology patches. It also serves as a baseline for the code and the model.

If you have any questions, feel free to contact me at [leo.fillioux@centralesupelec.fr](mailto:leo.fillioux@centralesupelec.fr).

In [None]:
import h5py
import torch
import random
import numpy as np
import pandas as pd
import torchmetrics
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
import random
import os

In [None]:
# Locations of the three HDF5 splits, relative to the working directory.
# BASE_PATH = "/kaggle/input/mva-dlmi-2025-histopathology-ood-classification"
TRAIN_IMAGES_PATH = 'data/train.h5'
VAL_IMAGES_PATH = 'data/val.h5'
TEST_IMAGES_PATH = 'data/test.h5'
# Seed handed to the random number generators when they are initialised.
SEED = 0

In [None]:
# Seed every random number generator the notebook draws from so that runs are
# repeatable: Python's `random`, NumPy's global RNG (np.random.choice selects
# the images shown in the preview grids) and PyTorch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## 1. Introduction to the data
The dataset consists of patches of whole slide images which should be classified as either containing tumor or not. The training images come from 3 different centers (i.e. hospitals), while the validation set comes from another center and the test set from yet another center. The visual appearance of the patches differs considerably between centers because of the slightly different staining procedures, conditions, and equipment at each hospital. The objective of the task is to build a classifier that is impacted by this distribution shift as little as possible.

The data is stored in `.h5` files, which are organized like a folder hierarchy with the following structure.
```
├── idx           # index of the image
│   └── img       # image in a tensor format
│   └── label     # binary label of the image
│   └── metadata  # some metadata on the images
```
The metadata is included for completeness but is not necessarily useful. The first element in the metadata corresponds to the center.

The following is a visualization of how different the images look from the different centers.

In [None]:
# One empty slot per (center, label) pair; the slots are filled with the first
# matching image found in each split.
train_images = {center: {label: None for label in (0, 1)} for center in (0, 3, 4)}
val_images = {center: {label: None for label in (0, 1)} for center in (1,)}

In [None]:
# Walk each split until one example image has been stored for every
# (center, label) slot of the corresponding dictionary.
for img_data, data_path in ((train_images, TRAIN_IMAGES_PATH), (val_images, VAL_IMAGES_PATH)):
    with h5py.File(data_path, 'r') as hdf:
        for img_idx in list(hdf.keys()):
            entry = hdf.get(img_idx)
            label = int(np.array(entry.get('label')))
            # The first metadata element identifies the center.
            center = int(np.array(entry.get('metadata'))[0])
            slots = img_data[center]
            if slots[label] is None:
                slots[label] = np.array(entry.get('img'))
            # Stop scanning this file once no slot is left empty.
            slot_missing = any(
                stored is None
                for per_label in img_data.values()
                for stored in per_label.values()
            )
            if not slot_missing:
                break
# Merge both splits into a single center -> label -> image mapping.
all_data = {**train_images, **val_images}

## 2. Building a baseline model
The baseline model consists of extracting DINOv2 embeddings and linear probing.

### 2.1. Baseline dataset
We start by creating the dataset class that reads and processes the data. For this simple model we also use a second dataset class that holds precomputed embeddings, to avoid recomputing the same embeddings each time.

In [None]:
import random
import torchvision.transforms as transforms
import torch
import h5py
import numpy as np
from torch.utils.data import Dataset

class BaselineDataset(Dataset):
    """Dataset serving (image, label) pairs read lazily from an HDF5 file.

    Args:
        dataset_path: path of the `.h5` file; each top-level key is one image.
        preprocessing: optional callable applied to every image tensor.
        mode: "train" reads the stored label and enables augmentations; any
            other value yields a label of -1 and disables augmentations.
        augmentations: optional callable applied after preprocessing in
            "train" mode only.
        num_augmentations: how many times each image appears in the dataset;
            forced to 1 outside "train" mode.
    """

    def __init__(self, dataset_path, preprocessing=None, mode="train", augmentations=None, num_augmentations=1):
        super().__init__()
        self.dataset_path = dataset_path
        self.preprocessing = preprocessing
        self.mode = mode
        self.augmentations = augmentations
        # Repeating an image is only done when training.
        if mode == "train":
            self.num_augmentations = num_augmentations
        else:
            self.num_augmentations = 1

        # Only the keys are kept; pixel data is read on demand in __getitem__.
        with h5py.File(self.dataset_path, 'r') as hdf:
            self.image_ids = list(hdf.keys())

        # One (image position, repetition number) entry per served sample.
        self.expanded_indices = [
            (image_pos, repetition)
            for image_pos, _ in enumerate(self.image_ids)
            for repetition in range(self.num_augmentations)
        ]

    def __len__(self):
        """Return the number of served samples (images x repetitions)."""
        return len(self.expanded_indices)

    def __getitem__(self, idx):
        """Return `(image, label)` as float32 tensors for sample `idx`."""
        image_pos, _ = self.expanded_indices[idx]
        key = self.image_ids[image_pos]
        training = self.mode == 'train'

        # The file is reopened on every access rather than kept open.
        with h5py.File(self.dataset_path, 'r') as hdf:
            record = hdf.get(key)
            img = torch.tensor(np.array(record.get('img')))
            label = np.array(record.get('label')) if training else -1

        if self.preprocessing is not None:
            img = self.preprocessing(img)

        if training and self.augmentations is not None:
            img = self.augmentations(img)

        return img.float(), torch.tensor(label, dtype=torch.float32)


In [None]:
def precompute(dataloader, model, device):
    """Run `model` over every batch of `dataloader` and gather the results.

    Args:
        dataloader: iterable yielding `(inputs, labels)` tensor batches.
        model: callable applied to each input batch after moving it to `device`.
        device: device on which the forward passes are executed.

    Returns:
        A `(features, labels)` pair of tensors: the model outputs stacked
        along the first axis, and the labels concatenated in the same order.
    """
    feature_chunks = []
    label_chunks = []
    for batch_x, batch_y in tqdm(dataloader, leave=False):
        # Gradients are not needed when only extracting features.
        with torch.no_grad():
            outputs = model(batch_x.to(device))
            feature_chunks.append(outputs.detach().cpu().numpy())
        label_chunks.append(batch_y.numpy())
    features = np.vstack(feature_chunks)
    labels = np.hstack(label_chunks)
    return torch.tensor(features), torch.tensor(labels)

In [None]:
class PrecomputedDataset(Dataset):
    """In-memory dataset of precomputed feature vectors and their labels.

    Args:
        features: tensor indexed along its first axis, one row per sample.
        labels: tensor of labels aligned with `features`.
    """

    def __init__(self, features, labels):
        super().__init__()
        self.features = features
        self.labels = labels

    def __len__(self):
        """Return the number of samples (the length of `labels`)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return `(features[idx], labels[idx])` with the label cast to float."""
        return self.features[idx], self.labels[idx].float()

    def save(self, file_path):
        """Save the features and labels to `file_path` with `torch.save`."""
        torch.save({'features': self.features, 'labels': self.labels}, file_path)

    @classmethod
    def load(cls, file_path):
        """Load a dataset written by `save` and return a new instance.

        The file only contains a dict of tensors, so it is read with
        `weights_only=True`; this refuses to unpickle arbitrary objects and
        matches how the model checkpoint is loaded elsewhere in the notebook.
        """
        data = torch.load(file_path, weights_only=True)
        return cls(data['features'], data['labels'])

### 2.1.1 Checking the dataset distribution

In [None]:
class AddGaussianNoise(torch.nn.Module):
    """Transform that adds Gaussian noise and clips the result to [0, 1].

    Args:
        mean: mean of the added noise.
        std: standard deviation of the added noise.
    """

    def __init__(self, mean=0.0, std=0.1):
        super().__init__()
        self.std = std
        self.mean = mean

    def forward(self, tensor):
        """Return `tensor` plus noise of the same shape, clamped to [0, 1]."""
        perturbation = self.mean + self.std * torch.randn_like(tensor)
        return (tensor + perturbation).clamp(0.0, 1.0)

    def __repr__(self):
        return f"{type(self).__name__}(mean={self.mean}, std={self.std})"

import torch.nn.functional as Fnn
import torch

class Sharpen(torch.nn.Module):
    """Sharpen an image with a fixed 3x3 kernel applied to each channel.

    Accepts a single image `(C, H, W)` or a batch `(B, C, H, W)` and returns a
    float tensor of the same shape with values clamped to [0, 1].
    """

    def forward(self, x):
        # Heuristic: any value above 1 means the input is on a 0-255 scale.
        x = x.float() / 255.0 if x.max() > 1 else x.float()

        # Add the batch dimension BEFORE reading the channel count, so the
        # number of kernels always matches the channels and not the batch size.
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)
        channels = x.shape[1]

        kernel = torch.tensor([[[[0, -1,  0],
                                 [-1,  5, -1],
                                 [0, -1,  0]]]], dtype=torch.float32, device=x.device)
        # One copy of the kernel per channel; groups=channels filters every
        # channel independently (depthwise convolution).
        kernel = kernel.repeat(channels, 1, 1, 1)
        x = Fnn.conv2d(x, kernel, padding=1, groups=channels)

        # Only remove the batch dimension if it was added here, so that a
        # batch of size 1 keeps its shape.
        if unbatched:
            x = x.squeeze(0)
        return torch.clamp(x, 0, 1)

In [None]:
# Random augmentations; BaselineDataset applies them after preprocessing and
# only when its mode is "train" and an augmentation callable was provided.
train_augmentations = transforms.Compose([
    # transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.2),
    # transforms.RandomAffine(3, shear=0.05),
    transforms.RandomAdjustSharpness(2),
    transforms.RandomRotation(5),
    # Colour jitter is applied to half of the samples.
    transforms.RandomApply([
        transforms.ColorJitter(brightness=[0.2,1.05], saturation=[0.8,2], contrast=[0.7,1.6], hue=0.1)], p=0.5),
    transforms.RandomGrayscale(p=0.1),
    # AddGaussianNoise clamps its output to [0, 1].
    transforms.RandomApply([AddGaussianNoise(mean=0.0, std=0.1)], p=0.5)
])

# Deterministic preprocessing shared by the training, validation and test
# images: resize to 224x224, then sharpen.
train_preprocessing = transforms.Compose([
    transforms.Resize((224,224)),
    Sharpen()
])

train_set = BaselineDataset(TRAIN_IMAGES_PATH, preprocessing=train_preprocessing,
                                mode='train', augmentations=train_augmentations, num_augmentations=1)

# train_set,_ = random_split(train_set, [25000, len(train_set) - 25000])

# mode='train' is used here so that the stored labels are read (other modes
# return -1); no augmentations are passed, so none are applied.
val_set = BaselineDataset(VAL_IMAGES_PATH, preprocessing=train_preprocessing, mode='train')
# val_set,_ = random_split(val_set, [10000, len(val_set) - 10000])

# print(f"Original training images: {len(train_set.image_ids)}")
print(f"Total augmented training images: {len(train_set)}")  # number of images x num_augmentations (1 here)


In [None]:
# Preview a random 6x6 grid of training samples, each titled with its label.
fig, axs = plt.subplots(6, 6, figsize=(15, 15))
axes = axs.flatten()
sample_indices = np.random.choice(len(train_set), size=len(axes), replace=False)
for ax, i in zip(axes, sample_indices):
    img, label = train_set[i]
    # Min-max rescale to [0, 1] for display; a constant image has zero range
    # and is shown as all zeros instead of dividing by zero.
    value_range = img.max() - img.min()
    img = (img - img.min()) / value_range if value_range > 0 else torch.zeros_like(img)
    # imshow expects channels last, the dataset returns channels first.
    ax.imshow(np.moveaxis(img.numpy(), 0, -1).astype(np.float32))
    ax.set_title(str(label.item()))
    ax.axis("off")
plt.show()

In [None]:
# Preview a random 6x6 grid of validation samples, each titled with its label.
fig, axs = plt.subplots(6, 6, figsize=(15, 15))
axes = axs.flatten()
sample_indices = np.random.choice(len(val_set), size=len(axes), replace=False)
for ax, i in zip(axes, sample_indices):
    img, label = val_set[i]
    # Min-max rescale to [0, 1] for display; a constant image has zero range
    # and is shown as all zeros instead of dividing by zero.
    value_range = img.max() - img.min()
    img = (img - img.min()) / value_range if value_range > 0 else torch.zeros_like(img)
    # imshow expects channels last, the dataset returns channels first.
    ax.imshow(np.moveaxis(img.numpy(), 0, -1).astype(np.float32))
    ax.set_title(str(label.item()))
    ax.axis("off")
plt.show()

### 2.2. Building the models and precomputing the features

In [None]:
# Mini-batch size used by the training and validation DataLoaders.
BATCH_SIZE = 64

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Working on {device}.')

In [None]:
# Fetch the DINOv2 ViT-L/14 backbone through torch.hub and move it to the
# selected device.
feature_extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14').to(device)
# feature_extractor.eval()
# Size of the backbone output; used as the input width of the classifier head.
print(feature_extractor.num_features)

In [None]:
import torch.nn as nn
class TransferModel(nn.Module):
    """Backbone feature extractor followed by an MLP head with sigmoid output.

    Args:
        feature_extractor: backbone module exposing `num_features` and
            `named_parameters()`; its output is fed to the head.
        finetuning: when True (default) the backbone stays trainable except
            for the blocks listed in `frozen_blocks`. When False the whole
            backbone is frozen and only the head is trained.
        frozen_blocks: indices of the backbone blocks whose parameters are
            frozen while finetuning.
    """

    def __init__(self, feature_extractor=feature_extractor, finetuning=True, frozen_blocks=(0, 1)):
        super().__init__()
        self.feature_extractor = feature_extractor

        # Match whole name segments: a plain substring test for "blocks.1"
        # would also hit "blocks.10" ... "blocks.19".
        frozen_markers = tuple(f".blocks.{index}." for index in frozen_blocks)
        for name, param in self.feature_extractor.named_parameters():
            dotted_name = f".{name}"
            in_frozen_block = any(marker in dotted_name for marker in frozen_markers)
            if not finetuning or in_frozen_block:
                param.requires_grad = False

        # Classifier head; the final Sigmoid yields one probability per
        # sample, as expected by BCELoss.
        self.linear_probing = nn.Sequential(
                            nn.Linear(feature_extractor.num_features, 512),
                            nn.ReLU(),
                            nn.BatchNorm1d(512),
                            nn.Dropout(0.5),  # Increased dropout
                            nn.Linear(512, 256),
                            nn.ReLU(),
                            nn.BatchNorm1d(256),
                            nn.Dropout(0.3),
                            nn.Linear(256, 1),
                            nn.Sigmoid()
                        ).to(device)

    def forward(self, x):
        """Return the head's output for the backbone features of `x`."""
        return self.linear_probing(self.feature_extractor(x))

In [None]:
# train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

# train_dataset = PrecomputedDataset(*precompute(train_dataloader, feature_extractor, device))
# val_dataset = PrecomputedDataset(*precompute(val_dataloader, feature_extractor, device))
# train_dataset.save("augmented_train_dataset.pth")
# val_dataset.save("val_dataset.pth")

# train_set = PrecomputedDataset.load("augmented_train_dataset.pth")
# val_set = PrecomputedDataset.load("val_dataset.pth")

In [None]:
# Both loaders share the batch size and worker settings; only the training
# loader shuffles its samples.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **_loader_options)
val_loader = DataLoader(val_set, shuffle=False, **_loader_options)

## 3. Training the model

In [None]:
# Training configuration. The optimizer, loss and metric names are resolved
# with getattr on torch.optim, torch.nn and torchmetrics respectively.
OPTIMIZER = 'Adam'
OPTIMIZER_PARAMS = {'lr': 0.0000005} #, "weight_decay":0.002}
LOSS = 'BCELoss'
METRIC = 'Accuracy'
# Upper bound on the number of epochs.
NUM_EPOCHS = 500
# Training stops once this many epochs have passed since the best validation loss.
PATIENCE = 50


In [None]:
# import torch.nn as nn
# linear_probing = nn.Sequential(
#     nn.Linear(feature_extractor.num_features, 512),
#     nn.ReLU(),
#     nn.BatchNorm1d(512),
#     nn.Dropout(0.5),  # Increased dropout
#     nn.Linear(512, 256),
#     nn.ReLU(),
#     nn.BatchNorm1d(256),
#     nn.Dropout(0.3),
#     nn.Linear(256, 1),
#     nn.Sigmoid()
# ).to(device)

# filters = [32,64,128]
# unet_model = UNet(in_channels=3,filters = filters, depth = len(filters) -1).to(device)


In [None]:
# Wrap the backbone loaded above with the classifier head.
model = TransferModel(feature_extractor = feature_extractor)

In [None]:
# Resolve the optimizer, loss and metric classes from their configured names.
optimizer_cls = getattr(torch.optim, OPTIMIZER)
optimizer = optimizer_cls(model.parameters(), **OPTIMIZER_PARAMS)
# optimizer = getattr(torch.optim, OPTIMIZER)(unet_model.parameters(), **OPTIMIZER_PARAMS)
criterion_cls = getattr(torch.nn, LOSS)
criterion = criterion_cls()
metric_cls = getattr(torchmetrics, METRIC)
metric = metric_cls('binary')
# Early-stopping state: best validation loss so far and the epoch it occurred at.
min_loss = float('inf')
best_epoch = 0

In [None]:
def plot_training_history(history, save_file=None):
    """Plot training and validation curves for the loss and the metric.

    Args:
        history: dict with the list-valued keys 'train_loss', 'val_loss',
            'train_metric' and 'val_metric', one entry per epoch.
        save_file: when truthy, the figure is saved to this path and closed;
            otherwise it is displayed with `plt.show()`.
    """
    train_epochs = range(1, len(history['train_loss']) + 1)
    val_epochs = range(1, len(history['val_loss']) + 1)

    plt.figure(figsize=(12, 5))

    # One entry per panel: history keys, legend labels, y-axis label, title.
    panels = (
        ('train_loss', 'Train Loss', 'val_loss', 'Validation Loss',
         'Loss', 'Training and Validation Loss'),
        ('train_metric', 'Train Metric', 'val_metric', 'Validation Metric',
         'Metric', 'Training and Validation Metric'),
    )
    for position, (train_key, train_label, val_key, val_label, y_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(train_epochs, history[train_key], label=train_label, linestyle='-')
        plt.plot(val_epochs, history[val_key], label=val_label, linestyle='-', color="C1")
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()

    if not save_file:
        plt.show()
        return
    plt.savefig(save_file)
    plt.close()

In [None]:
import torch
import random
import torchvision.transforms as T

# Per-epoch mean loss and metric for both splits; read by plot_training_history.
history = {'train_loss': [], 'train_metric': [], 'val_loss': [], 'val_metric': []}

# Number of mini-batches whose gradients are accumulated before each optimizer step.
accumulation_steps = 5
epoch_bar = tqdm(range(NUM_EPOCHS), leave=True, desc='Training Progress')

for epoch in epoch_bar:
    model.train()
    train_metrics, train_losses = [], []

    # Start the epoch with clean gradients.
    optimizer.zero_grad()

    for i, (train_x, train_y) in enumerate(train_loader):
        # Progress within the epoch, rewritten in place on a single line.
        print(f"{i / len(train_loader) * 100:.2f}%", end="\r")

        # Forward pass
        train_x, train_y = train_x.to(device), train_y.to(device)
        augmented_images = [train_x]  # only the original batch is used at the moment
        original_pred = model(augmented_images[0]).squeeze(1)
        task_loss = criterion(original_pred.float(), train_y.float())
        total_loss = task_loss / accumulation_steps  # scale so the accumulated gradient is an average

        # Backward pass (gradients add up across mini-batches until the next step)
        total_loss.backward()

        # Step every `accumulation_steps` mini-batches, and on the last batch of the epoch
        if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
            optimizer.step()
            optimizer.zero_grad()

        # The batch value is repeated once per sample so that np.mean below is
        # weighted by batch size. The unscaled `task_loss` is the one recorded.
        train_losses.extend([task_loss.item()] * len(train_y))

        train_metric = metric(original_pred.cpu(), train_y.int().cpu())
        train_metrics.extend([train_metric.item()] * len(train_y))

    train_loss_mean = np.mean(train_losses)
    train_metric_mean = np.mean(train_metrics)
    history['train_loss'].append(train_loss_mean)
    history['train_metric'].append(train_metric_mean)

    model.eval()
    val_metrics, val_losses = [], []

    for val_x, val_y in val_loader:
        with torch.no_grad():
            # Only the original batch is evaluated; the test-time augmentation
            # loop below is disabled.
            augmented_images = [val_x.to(device)]

            # Generate augmented images by applying random transformations
            # for _ in range(10):
            #     augmented_images.append(augmentation_transforms(val_x.to(device)))

            # Predict on the original batch and on any augmented copies
            # (`augmented_preds` is empty while the loop above is disabled).
            original_pred = model(augmented_images[0]).squeeze(1)
            augmented_preds = [model(img).squeeze(1) for img in augmented_images[1:]]

            # Average the predictions; with a single entry this equals `original_pred`.
            all_preds = [original_pred] + augmented_preds
            ensemble_pred = torch.mean(torch.stack(all_preds), dim=0)

            # Same criterion as in training, recorded once per sample.
            loss = criterion(ensemble_pred.float(), val_y.to(device).float())
            val_losses.extend([loss.item()] * len(val_y))

            # Same metric object as in training.
            val_metric = metric(ensemble_pred.cpu(), val_y.int().cpu())
            val_metrics.extend([val_metric.item()] * len(val_y))

    val_loss_mean = np.mean(val_losses)
    val_metric_mean = np.mean(val_metrics)

    history['val_loss'].append(val_loss_mean)
    history['val_metric'].append(val_metric_mean)

    epoch_bar.set_postfix({
        'Train Loss': f'{train_loss_mean:.3g}',
        f'Train {METRIC}': f'{train_metric_mean:.3g}',
        'Val Loss': f'{val_loss_mean:.3g}',
        f'Val {METRIC}': f'{val_metric_mean:.3g}',
    })

    # Checkpoint whenever the validation loss improves. `min_loss` and
    # `best_epoch` are initialised outside this loop.
    if val_loss_mean < min_loss:
        min_loss = val_loss_mean
        best_epoch = epoch
        torch.save(model.state_dict(), 'best_model_bis.pth')

    # Early stopping. The break skips the plot below, so the final epoch is
    # not written to output.png.
    if epoch - best_epoch == PATIENCE:
        break

    # Refresh the saved curves after every completed epoch.
    plot_training_history(history, "output.png")


In [None]:
# Display the final curves inline (no save_file, so the figure is shown rather than saved).
plot_training_history(history)

## 4. Making the final prediction

To create a solutions file, you need to generate a CSV with 2 columns.
- **ID**: containing the ID of the image
- **Pred**: with the predicted class (**threshold the prediction to get either 0 or 1**)

In [None]:
# Restore the weights of the best checkpoint written during training.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine; weights_only restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load("best_model_bis.pth", map_location=device, weights_only=True))
model.eval()
model.to(device)
prediction_dict = {}

In [None]:
# test_preprocessing = transforms.Compose([transforms.Resize((98,98)),
#                                          transforms.Normalize(mean =[0.673, 0.483, 0.739],
#                                                               std = [0.194, 0.222, 0.123])])
# test_dataset = BaselineDataset(TEST_IMAGES_PATH, preprocessing = train_preprocessing, mode = "eval")
# test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=256)
# get_channel_histograms(test_dataloader)



In [None]:
# Collect the identifiers (top-level keys) of all test images.
with h5py.File(TEST_IMAGES_PATH, 'r') as hdf:
    test_ids = [*hdf.keys()]

In [None]:
import pandas as pd
import h5py
import torch
import torchvision.transforms as T
from tqdm import tqdm

# Transformations available for test-time augmentation. They are only used by
# the augmentation loop below, which is currently commented out.
augmentation_transforms = T.Compose([
    T.RandomChoice([
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        T.Grayscale(num_output_channels=3),  # Convert to grayscale with 3 channels (RGB)
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
    ])
])

# Number of augmented images to use for ensemble during testing
num_augmented_images = 0

solutions_data = {'ID': [], 'Pred': []}

# Inference must run in eval mode (dropout off, batch-norm statistics fixed;
# BatchNorm1d cannot process a single sample in training mode), whatever the
# order in which earlier cells were executed.
model.eval()

# no_grad avoids building the autograd graph for every forward pass, which
# saves memory and time during inference.
with h5py.File(TEST_IMAGES_PATH, 'r') as hdf, torch.no_grad():
    for test_id in tqdm(test_ids):
        # Same preprocessing as for the training images, plus a batch dimension.
        img = train_preprocessing(torch.tensor(np.array(hdf.get(test_id).get('img')))).unsqueeze(0).float()

        # Views to ensemble over; only the original image for now.
        augmented_images = [img.to(device)]

        # Generate augmented images by applying random transformations
        # for _ in range(num_augmented_images):
        #     augmented_images.append(augmentation_transforms(img.to(device)))

        # Predict on every view and average the probabilities.
        all_preds = [model(view).cpu() for view in augmented_images]
        ensemble_pred = torch.mean(torch.stack(all_preds), dim=0)

        # Threshold the averaged probability to obtain the class.
        final_pred = int(ensemble_pred.item() > 0.5)

        # Store the ID and the final prediction
        solutions_data['ID'].append(int(test_id))
        solutions_data['Pred'].append(final_pred)

# Convert to a DataFrame and save to CSV
solutions_data = pd.DataFrame(solutions_data).set_index('ID')
solutions_data.to_csv('model_test.csv')
