# 🧠 Histopathology OOD Classification - DINOv2 + Scenario 2 Inspired Augmentations

In [1]:
import h5py
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torch.nn as nn
import random
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import balanced_accuracy_score

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths to the competition HDF5 files (Kaggle input mount).
TRAIN_IMAGES_PATH = '/kaggle/input/mva-dlmi-2025-histopathology-ood-classification/train.h5'
VAL_IMAGES_PATH = '/kaggle/input/mva-dlmi-2025-histopathology-ood-classification/val.h5'
TEST_IMAGES_PATH = '/kaggle/input/mva-dlmi-2025-histopathology-ood-classification/test.h5'

# Seed every RNG this notebook imports (Python, NumPy, torch) so that a
# Restart & Run All is repeatable. `torch.manual_seed` stays last so the
# cell keeps displaying the returned Generator, as before.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7ae832b49c30>

In [2]:
# Training pipeline: with probability 0.5 apply colour jitter followed by a
# 3x3 Gaussian blur (both or neither), then resize and normalise.
# 98 = 7 x 14, i.e. a whole number of 14-pixel ViT patches per side.
# NOTE(review): ToPILImage on float tensors presumably expects values in
# [0, 1] -- confirm the value range stored in the HDF5 files.
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomApply(
        [
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
            transforms.GaussianBlur(kernel_size=3),
        ],
        p=0.5,
    ),
    transforms.Resize(size=(98, 98), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Evaluation pipeline: the same resize + normalisation, without any
# random augmentation.
transform_val = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=(98, 98), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [3]:
class BaselineDataset(Dataset):
    """Single-view dataset reading images (and optionally labels) from an HDF5 file.

    Each top-level key of the file is one sample holding an ``'img'`` entry
    and, when ``mode == 'train'``, a ``'label'`` entry.

    Args:
        dataset_path: Path of the HDF5 file.
        transform: Callable applied to the channel-first image tensor.
        mode: ``'train'`` to also read labels; any other value yields
            ``None`` as label.
            NOTE(review): a ``None`` label presumably cannot be batched by
            the default DataLoader collate -- confirm before using another mode
            with a DataLoader.
    """

    def __init__(self, dataset_path, transform, mode):
        self.dataset_path = dataset_path
        self.transform = transform
        self.mode = mode
        # Only the sample keys are cached; the file is reopened per item.
        with h5py.File(self.dataset_path, 'r') as hdf:
            self.ids = [key for key in hdf.keys()]

    def __len__(self):
        """Number of samples (top-level keys) in the file."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the sample at ``idx``."""
        with h5py.File(self.dataset_path, 'r') as hdf:
            entry = hdf[self.ids[idx]]
            pixels = np.array(entry['img'], dtype=np.float32)
            label = None
            if self.mode == 'train':
                label = np.array(entry['label'])

        img = torch.tensor(pixels)
        # Anything whose first axis is not 3 is treated as (H, W, C) and
        # moved to channel-first (C, H, W).
        if img.shape[0] != 3:
            img = img.permute(2, 0, 1)
        return self.transform(img), label

In [4]:
class DoubleInputDataset(Dataset):
    """Two-view dataset: a random 64x64 local crop plus the whole image as context.

    Each top-level key of the HDF5 file is one sample holding an ``'img'``
    entry and, when ``mode == 'train'``, a ``'label'`` entry.

    Args:
        dataset_path: Path of the HDF5 file.
        transform: Callable applied independently to both views.
        mode: ``'train'`` to also read labels; any other value yields
            ``None`` as label.
    """

    def __init__(self, dataset_path, transform, mode):
        self.dataset_path = dataset_path
        self.transform = transform
        self.mode = mode
        # Only the sample keys are cached; the file is reopened per item.
        with h5py.File(self.dataset_path, 'r') as hdf:
            self.ids = [key for key in hdf.keys()]

    def __len__(self):
        """Number of samples (top-level keys) in the file."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``((local_view, context_view), label)`` for the sample at ``idx``."""
        with h5py.File(self.dataset_path, 'r') as hdf:
            entry = hdf[self.ids[idx]]
            pixels = np.array(entry['img'], dtype=np.float32)
            label = np.array(entry['label']) if self.mode == 'train' else None

        img = torch.tensor(pixels)
        # Anything whose first axis is not 3 is treated as (H, W, C) and
        # moved to channel-first (C, H, W).
        if img.shape[0] != 3:
            img = img.permute(2, 0, 1)

        # Local view: a 64x64 crop at a random position (random in every
        # mode, not only for training).
        crop_local = transforms.RandomCrop((64, 64))
        local_view = crop_local(img)
        # Context view: the full image resized to 98x98.
        resize_context = transforms.Resize((98, 98), interpolation=InterpolationMode.BICUBIC)
        context_view = resize_context(img)

        # The transform runs on the local view first, then on the context view.
        local_view = self.transform(local_view)
        context_view = self.transform(context_view)
        return (local_view, context_view), label

In [5]:
BATCH_SIZE = 32

# Single-view alternative, currently disabled:
# train_dataset = BaselineDataset(TRAIN_IMAGES_PATH, transform_train, 'train')
# val_dataset = BaselineDataset(VAL_IMAGES_PATH, transform_val, 'train')

# Both splits are opened with mode='train': that is the only mode in which
# the dataset returns labels, and validation needs them too.
train_dataset = DoubleInputDataset(
    dataset_path=TRAIN_IMAGES_PATH,
    transform=transform_train,
    mode='train',
)
val_dataset = DoubleInputDataset(
    dataset_path=VAL_IMAGES_PATH,
    transform=transform_val,
    mode='train',
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [6]:
# Pretrained DINOv2 ViT-S/14 backbone fetched through torch.hub, moved to the
# selected device and switched to eval mode; it is only used to extract features.
feature_extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
feature_extractor = feature_extractor.to(device)
feature_extractor.eval()

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth
100%|██████████| 84.2M/84.2M [00:01<00:00, 55.6MB/s]


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (n

In [7]:
def precompute(dataloader, model, device):
    """Run ``model`` over every batch and collect outputs and labels on the CPU.

    Args:
        dataloader: Iterable yielding ``(x, y)`` batches, where ``x`` is a
            tensor and ``y`` is a tensor or anything ``torch.as_tensor`` accepts.
        model: Callable mapping a batch on ``device`` to a feature tensor.
        device: Device the inputs are moved to before calling ``model``.

    Returns:
        Tuple ``(features, labels)``: the per-batch results concatenated
        along dimension 0, in iteration order.

    Raises:
        RuntimeError: From ``torch.cat`` if ``dataloader`` yields no batch.
    """
    xs, ys = [], []
    with torch.no_grad():
        for x, y in tqdm(dataloader):
            xs.append(model(x.to(device)).detach().cpu())
            # Labels normally arrive already collated into a tensor;
            # `torch.tensor(tensor)` would emit the "copy construct"
            # UserWarning on every batch, `as_tensor` reuses it as-is.
            ys.append(torch.as_tensor(y))
    return torch.cat(xs), torch.cat(ys)

def precompute_double(dataloader, model, device):
    """Extract and concatenate features of the local and context views of every batch.

    Args:
        dataloader: Iterable yielding ``((x_local, x_context), y)`` batches.
        model: Callable mapping a batch on ``device`` to a 2-D feature
            tensor (batch first).
        device: Device the inputs are moved to before calling ``model``.

    Returns:
        Tuple ``(features, labels)``. Each feature row is the local-view
        features followed by the context-view features (joined along
        dimension 1); batches are concatenated along dimension 0.

    Raises:
        RuntimeError: From ``torch.cat`` if ``dataloader`` yields no batch.
    """
    xs, ys = [], []
    with torch.no_grad():
        for (x_local, x_context), y in tqdm(dataloader):
            feats_local = model(x_local.to(device)).detach().cpu()
            feats_context = model(x_context.to(device)).detach().cpu()
            xs.append(torch.cat([feats_local, feats_context], dim=1))
            # Labels normally arrive already collated into a tensor;
            # `torch.tensor(tensor)` would emit the "copy construct"
            # UserWarning on every batch, `as_tensor` reuses it as-is.
            ys.append(torch.as_tensor(y))
    return torch.cat(xs), torch.cat(ys)

# Extract the frozen-backbone features once for each split; the classifier
# below is trained on these tensors instead of on images.
x_train, y_train = precompute_double(
    dataloader=train_dataloader,
    model=feature_extractor,
    device=device,
)
x_val, y_val = precompute_double(
    dataloader=val_dataloader,
    model=feature_extractor,
    device=device,
)

  0%|          | 0/3125 [00:00<?, ?it/s]

  ys.append(torch.tensor(y))


  0%|          | 0/1091 [00:00<?, ?it/s]

In [8]:
class PrecomputedDataset(Dataset):
    """In-memory dataset pairing precomputed feature rows with float targets.

    Args:
        features: Tensor indexed along its first dimension, one row per sample.
        labels: Tensor of labels; stored as float with an extra dimension
            inserted at position 1.
    """

    def __init__(self, features, labels):
        self.features = features
        # Float targets with a trailing unit axis, the layout a
        # single-output classifier trained with BCE compares against.
        self.labels = labels.float().unsqueeze(dim=1)

    def __len__(self):
        """Number of samples, taken from the label tensor."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target

# Wrap the precomputed features so the classifier can be trained from memory.
train_ds = PrecomputedDataset(features=x_train, labels=y_train)
val_ds = PrecomputedDataset(features=x_val, labels=y_val)

# Only the training loader is shuffled.
train_dl = DataLoader(dataset=train_ds, shuffle=True, batch_size=BATCH_SIZE)
val_dl = DataLoader(dataset=val_ds, shuffle=False, batch_size=BATCH_SIZE)

In [9]:
class SimpleLinearProbeSmall(nn.Module):
    """Small binary classification head: one 32-unit hidden layer, sigmoid output.

    Despite the name this is not purely linear: it is
    Linear -> ReLU -> Dropout(0.2) -> Linear -> Sigmoid, so the output is a
    probability in (0, 1) with one column per sample.

    Args:
        input_dim: Size of each input feature vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Layer order and the attribute name `model` are kept so that
        # state_dict keys ("model.0.*", "model.3.*") stay the same.
        layers = [
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of feature vectors to probabilities of shape (batch, 1)."""
        probabilities = self.model(x)
        return probabilities

# Size the classifier input from the width of the precomputed feature matrix.
model = SimpleLinearProbeSmall(input_dim=x_train.shape[1])
model = model.to(device)

In [10]:
import torch.optim as optim

# Train the classifier head on the precomputed features.
# Adam with L2 weight decay; BCELoss is applied to probabilities, since the
# model already ends with a Sigmoid.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.BCELoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Reduce LR every 10 epochs by a factor of 0.1

# Checkpointing / early-stopping state. `best_loss` is only used by the
# disabled loss-based criterion further down.
best_loss = float('inf')
best_acc = float('-inf')
best_epoch = 0
PATIENCE = 10
NUM_EPOCHS = 50

for epoch in range(NUM_EPOCHS):
    # Training pass. Predictions are collected while dropout is active, so
    # the reported train accuracy is measured in train mode.
    model.train()
    losses, preds, targets = [], [], []
    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        output = model(xb)
        loss = criterion(output, yb)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        preds.extend(output.detach().cpu().numpy())
        targets.extend(yb.cpu().numpy())
    # Probabilities are thresholded at 0.5 to get hard predictions.
    acc = balanced_accuracy_score(targets, np.array(preds) > 0.5)
    print(f"[Train] Epoch {epoch+1} - Loss: {np.mean(losses):.4f} - BalAcc: {acc:.4f}")

    # Validation
    model.eval()
    # The accumulators are reused, so the training values are discarded here.
    losses, preds, targets = [], [], []
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            output = model(xb)
            loss = criterion(output, yb)
            losses.append(loss.item())
            preds.extend(output.cpu().numpy())
            targets.extend(yb.cpu().numpy())
    val_acc = balanced_accuracy_score(targets, np.array(preds) > 0.5)
    val_loss = np.mean(losses)
    print(f"[Val]   Epoch {epoch+1} - Loss: {val_loss:.4f} - BalAcc: {val_acc:.4f}")

    # Alternative checkpoint criterion (lowest validation loss), currently disabled:
    # if val_loss < best_loss:
    #     best_loss = val_loss
    #     best_epoch = epoch
    #     torch.save(model.state_dict(), 'best_model.pth')

    # Keep the weights with the best validation balanced accuracy so far
    # (strict improvement only, so ties keep the earlier epoch).
    if val_acc > best_acc:
        best_acc = val_acc
        best_epoch = epoch
        torch.save(model.state_dict(), 'best_model.pth')

    # Stop once PATIENCE epochs have passed without a new best; the break
    # happens before the scheduler is stepped for this epoch.
    if epoch - best_epoch >= PATIENCE:
        print("Early stopping.")
        break
    scheduler.step()

[Train] Epoch 1 - Loss: 0.1440 - BalAcc: 0.9455
[Val]   Epoch 1 - Loss: 0.3075 - BalAcc: 0.8714
[Train] Epoch 2 - Loss: 0.1151 - BalAcc: 0.9572
[Val]   Epoch 2 - Loss: 0.2434 - BalAcc: 0.9027
[Train] Epoch 3 - Loss: 0.1057 - BalAcc: 0.9604
[Val]   Epoch 3 - Loss: 0.2828 - BalAcc: 0.8961
[Train] Epoch 4 - Loss: 0.0999 - BalAcc: 0.9628
[Val]   Epoch 4 - Loss: 0.2660 - BalAcc: 0.9042
[Train] Epoch 5 - Loss: 0.0963 - BalAcc: 0.9641
[Val]   Epoch 5 - Loss: 0.3260 - BalAcc: 0.8840
[Train] Epoch 6 - Loss: 0.0945 - BalAcc: 0.9646
[Val]   Epoch 6 - Loss: 0.2646 - BalAcc: 0.8962
[Train] Epoch 7 - Loss: 0.0913 - BalAcc: 0.9656
[Val]   Epoch 7 - Loss: 0.2579 - BalAcc: 0.8953
[Train] Epoch 8 - Loss: 0.0899 - BalAcc: 0.9662
[Val]   Epoch 8 - Loss: 0.2389 - BalAcc: 0.9052
[Train] Epoch 9 - Loss: 0.0884 - BalAcc: 0.9669
[Val]   Epoch 9 - Loss: 0.2698 - BalAcc: 0.8991
[Train] Epoch 10 - Loss: 0.0875 - BalAcc: 0.9676
[Val]   Epoch 10 - Loss: 0.2759 - BalAcc: 0.9022
[Train] Epoch 11 - Loss: 0.0729 - BalA

In [11]:
# Disabled single-view submission cell: the code below is wrapped in a
# triple-quoted string, so nothing in it runs; the cell only evaluates to
# (and displays) that string.
'''model.load_state_dict(torch.load('best_model.pth'))
model.eval()

submission = {'ID': [], 'Pred': []}
with h5py.File(TEST_IMAGES_PATH, 'r') as hdf:
    for img_id in tqdm(hdf.keys()):
        img = torch.tensor(np.array(hdf[img_id]['img'], dtype=np.float32))  # plus de permute
        img = transform_val(img).unsqueeze(0).to(device)
        with torch.no_grad():
            pred = model(feature_extractor(img)).item()
        submission['ID'].append(int(img_id))
        submission['Pred'].append(int(pred > 0.5))

submission_df = pd.DataFrame(submission).set_index('ID')
submission_df.to_csv("submission.csv")
print("✅ Fichier submission.csv généré.") '''

'model.load_state_dict(torch.load(\'best_model.pth\'))\nmodel.eval()\n\nsubmission = {\'ID\': [], \'Pred\': []}\nwith h5py.File(TEST_IMAGES_PATH, \'r\') as hdf:\n    for img_id in tqdm(hdf.keys()):\n        img = torch.tensor(np.array(hdf[img_id][\'img\'], dtype=np.float32))  # plus de permute\n        img = transform_val(img).unsqueeze(0).to(device)\n        with torch.no_grad():\n            pred = model(feature_extractor(img)).item()\n        submission[\'ID\'].append(int(img_id))\n        submission[\'Pred\'].append(int(pred > 0.5))\n\nsubmission_df = pd.DataFrame(submission).set_index(\'ID\')\nsubmission_df.to_csv("submission.csv")\nprint("✅ Fichier submission.csv généré.") '

## 🧪 Test-Time Augmentation (TTA)

In [12]:
# The four test-time views differ only by the step inserted between the
# resize and the tensor conversion.
_TTA_EXTRA_STEPS = [
    [],                                        # plain resized image
    [transforms.RandomHorizontalFlip(p=1.0)],  # p=1.0: always flipped left-right
    [transforms.RandomVerticalFlip(p=1.0)],    # p=1.0: always flipped top-bottom
    [transforms.RandomRotation(15)],           # random angle in [-15, 15] degrees on every call
]

# Each pipeline gets its own transform instances, as in a hand-written list.
tta_transforms = [
    transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((98, 98), interpolation=InterpolationMode.BICUBIC),
        *extra_steps,
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])
    for extra_steps in _TTA_EXTRA_STEPS
]

In [13]:
def precompute_double_tta(dataloader, model, device, tta_transforms):
    """Extract two-view features averaged over several test-time augmentations.

    Every transform is applied to each sample of the batch individually and
    the results are re-stacked. Image-level pipelines (for example ones that
    start with ``ToPILImage``) only accept a single 2-D/3-D image, so feeding
    them the whole 4-D batch fails.

    Args:
        dataloader: Iterable yielding ``((x_local, x_context), y)`` batches
            of batch-first tensors.
        model: Callable mapping a batch on ``device`` to a 2-D feature
            tensor (batch first).
        device: Device the inputs are moved to before calling ``model``.
        tta_transforms: Non-empty sequence of callables, each mapping one
            sample tensor to a tensor; all outputs must share one shape.
            NOTE(review): pipelines starting with ``ToPILImage`` presumably
            expect un-normalised images -- confirm that the dataset feeding
            this function has not already normalised them.

    Returns:
        Tuple ``(features, labels)``. Each feature row is the TTA-averaged
        local features followed by the TTA-averaged context features; batches
        are concatenated along dimension 0.
    """
    xs, ys = [], []
    with torch.no_grad():
        for (x_local, x_context), y in tqdm(dataloader):
            local_feats, context_feats = [], []
            for tta_transform in tta_transforms:
                # Per-sample application, then back to a batch.
                x_local_tta = torch.stack([tta_transform(img) for img in x_local])
                x_context_tta = torch.stack([tta_transform(img) for img in x_context])
                local_feats.append(model(x_local_tta.to(device)).detach().cpu())
                context_feats.append(model(x_context_tta.to(device)).detach().cpu())

            # Average over the TTA views, separately for each input.
            f_local = torch.stack(local_feats).mean(dim=0)
            f_context = torch.stack(context_feats).mean(dim=0)
            xs.append(torch.cat([f_local, f_context], dim=1))
            # `as_tensor` avoids the "copy construct" UserWarning that
            # `torch.tensor` raises for labels that are already tensors.
            ys.append(torch.as_tensor(y))

    return torch.cat(xs), torch.cat(ys)


In [13]:
# model.load_state_dict(torch.load('best_model.pth'))
# model.eval()

# submission = {'ID': [], 'Pred': []}
# with h5py.File(TEST_IMAGES_PATH, 'r') as hdf:
#     for img_id in tqdm(hdf.keys()):
#         img_raw = np.array(hdf[img_id]['img'], dtype=np.float32)
#         if img_raw.shape[0] != 3:
#             img_raw = torch.tensor(img_raw).permute(2, 0, 1)
#         else:
#             img_raw = torch.tensor(img_raw)

#         tta_imgs = torch.stack([t(img_raw) for t in tta_transforms]).to(device)

#         with torch.no_grad():
#             features = feature_extractor(tta_imgs)
#             output = model(features).squeeze(1)  # shape: (N_TTA,)
#             pred = output.mean().item()

#         submission['ID'].append(int(img_id))
#         submission['Pred'].append(int(pred > 0.5))


# submission_df = pd.DataFrame(submission).set_index('ID')
# submission_df.to_csv("submission_doubleinput.csv")
# print("✅ Fichier submission.csv généré avec TTA.")

In [None]:
# Restore the checkpoint written by the training loop.
# `weights_only=True` limits unpickling to tensors and plain containers (and
# avoids the torch.load FutureWarning); `map_location` lets the checkpoint
# load onto whatever device this session uses.
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model.eval()

# Test-time views, built once outside the loop: a centre crop stands in for
# the random crop used when the features were precomputed.
center_crop = transforms.CenterCrop((64, 64))
context_resize = transforms.Resize((98, 98), interpolation=InterpolationMode.BICUBIC)

submission = {'ID': [], 'Pred': []}
with h5py.File(TEST_IMAGES_PATH, 'r') as hdf:
    for img_id in tqdm(hdf.keys()):
        img_raw = torch.tensor(np.array(hdf[img_id]['img'], dtype=np.float32))
        if img_raw.shape[0] != 3:  # (H, W, C) -> (C, H, W)
            img_raw = img_raw.permute(2, 0, 1)

        # Build the local and context views.
        local_patch = center_crop(img_raw)
        context_patch = context_resize(img_raw)

        # One stacked batch of TTA variants per view.
        local_ttas = torch.stack([t(local_patch) for t in tta_transforms]).to(device)
        context_ttas = torch.stack([t(context_patch) for t in tta_transforms]).to(device)

        with torch.no_grad():
            # Average the backbone features over the TTA variants.
            feat_local_mean = feature_extractor(local_ttas).mean(dim=0)
            feat_context_mean = feature_extractor(context_ttas).mean(dim=0)
            # Same layout as the precomputed features: local first, then context.
            features = torch.cat([feat_local_mean, feat_context_mean]).unsqueeze(0)
            pred = model(features).item()

        submission['ID'].append(int(img_id))
        submission['Pred'].append(int(pred > 0.5))  # 0.5 decision threshold

  model.load_state_dict(torch.load('best_model.pth'))


  0%|          | 0/85054 [00:00<?, ?it/s]

In [None]:
# Build the submission table from the predictions collected in `submission`
# (indexed by image ID) and write it to disk. The previous line referenced a
# misspelled, never-defined `sumbission_df`.
submission_df = pd.DataFrame(submission).set_index('ID')
submission_df.to_csv("sub_TTA.csv")