# Histopathology OOD Classification - DINOv2 + TTA Transform

In [1]:
import h5py
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torch.nn as nn
import random
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import balanced_accuracy_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Chemins
TRAIN_IMAGES_PATH = '/kaggle/input/mva-dlmi-2025-histopathology-ood-classification/train.h5'
VAL_IMAGES_PATH = '/kaggle/input/mva-dlmi-2025-histopathology-ood-classification/val.h5'
TEST_IMAGES_PATH = '/kaggle/input/mva-dlmi-2025-histopathology-ood-classification/test.h5'

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x79a21a54dc70>

In [2]:
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.5),
    transforms.Resize((98, 98), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

transform_val = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((98, 98), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

In [3]:
class BaselineDataset(Dataset):
    def __init__(self, dataset_path, transform, mode):
        self.dataset_path = dataset_path
        self.transform = transform
        self.mode = mode
        with h5py.File(self.dataset_path, 'r') as hdf:
            self.ids = list(hdf.keys())

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        with h5py.File(self.dataset_path, 'r') as hdf:
            img_id = self.ids[idx]
            img = np.array(hdf[img_id]['img'], dtype=np.float32)
            label = np.array(hdf[img_id]['label']) if self.mode == 'train' else None
        if img.shape[0] != 3:  # (H, W, C) → (C, H, W)
            img = torch.tensor(img).permute(2, 0, 1)
        else:
            img = torch.tensor(img) 
        img = self.transform(img)
        return img, label

In [None]:
BATCH_SIZE = 32
train_dataset = BaselineDataset(TRAIN_IMAGES_PATH, transform_train, 'train')
val_dataset = BaselineDataset(VAL_IMAGES_PATH, transform_val, 'train')

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)


In [5]:
feature_extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)
feature_extractor.eval()

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth
100%|██████████| 84.2M/84.2M [00:00<00:00, 224MB/s] 


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (n

In [6]:
def precompute(dataloader, model, device):
    xs, ys = [], []
    for x, y in tqdm(dataloader):
        with torch.no_grad():
            feats = model(x.to(device)).detach().cpu()
        xs.append(feats)
        ys.append(torch.tensor(y))
    return torch.cat(xs), torch.cat(ys)

x_train, y_train = precompute(train_dataloader, feature_extractor, device)
x_val, y_val = precompute(val_dataloader, feature_extractor, device)

  0%|          | 0/3125 [00:00<?, ?it/s]

  ys.append(torch.tensor(y))


  0%|          | 0/1091 [00:00<?, ?it/s]

In [9]:
class PrecomputedDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels.unsqueeze(1).float()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

train_ds = PrecomputedDataset(x_train, y_train)
val_ds = PrecomputedDataset(x_val, y_val)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
class SimpleLinearProbeSmall(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

model = SimpleLinearProbeSmall(x_train.shape[1]).to(device)

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.BCELoss()

best_loss = float('inf')
best_epoch = 0
PATIENCE = 5
NUM_EPOCHS = 30

for epoch in range(NUM_EPOCHS):
    model.train()
    losses, preds, targets = [], [], []
    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        output = model(xb)
        loss = criterion(output, yb)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        preds.extend(output.detach().cpu().numpy())
        targets.extend(yb.cpu().numpy())
    acc = balanced_accuracy_score(targets, np.array(preds) > 0.5)
    print(f"[Train] Epoch {epoch+1} - Loss: {np.mean(losses):.4f} - BalAcc: {acc:.4f}")

    # Validation
    model.eval()
    losses, preds, targets = [], [], []
    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            output = model(xb)
            loss = criterion(output, yb)
            losses.append(loss.item())
            preds.extend(output.cpu().numpy())
            targets.extend(yb.cpu().numpy())
    val_acc = balanced_accuracy_score(targets, np.array(preds) > 0.5)
    val_loss = np.mean(losses)
    print(f"[Val]   Epoch {epoch+1} - Loss: {val_loss:.4f} - BalAcc: {val_acc:.4f}")

    if val_loss < best_loss:
        best_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), 'best_model.pth')

    if epoch - best_epoch >= PATIENCE:
        print("Early stopping.")
        break

[Train] Epoch 1 - Loss: 0.1474 - BalAcc: 0.9429
[Val]   Epoch 1 - Loss: 0.2998 - BalAcc: 0.8875
[Train] Epoch 2 - Loss: 0.1199 - BalAcc: 0.9555
[Val]   Epoch 2 - Loss: 0.2411 - BalAcc: 0.8983
[Train] Epoch 3 - Loss: 0.1110 - BalAcc: 0.9583
[Val]   Epoch 3 - Loss: 0.2844 - BalAcc: 0.8980
[Train] Epoch 4 - Loss: 0.1065 - BalAcc: 0.9600
[Val]   Epoch 4 - Loss: 0.2960 - BalAcc: 0.8966
[Train] Epoch 5 - Loss: 0.1027 - BalAcc: 0.9617
[Val]   Epoch 5 - Loss: 0.2758 - BalAcc: 0.8959
[Train] Epoch 6 - Loss: 0.0999 - BalAcc: 0.9623
[Val]   Epoch 6 - Loss: 0.2586 - BalAcc: 0.8997
[Train] Epoch 7 - Loss: 0.0973 - BalAcc: 0.9631
[Val]   Epoch 7 - Loss: 0.2836 - BalAcc: 0.8929
Early stopping.


In [15]:
'''model.load_state_dict(torch.load('best_model.pth'))
model.eval()

submission = {'ID': [], 'Pred': []}
with h5py.File(TEST_IMAGES_PATH, 'r') as hdf:
    for img_id in tqdm(hdf.keys()):
        img = torch.tensor(np.array(hdf[img_id]['img'], dtype=np.float32))  # plus de permute
        img = transform_val(img).unsqueeze(0).to(device)
        with torch.no_grad():
            pred = model(feature_extractor(img)).item()
        submission['ID'].append(int(img_id))
        submission['Pred'].append(int(pred > 0.5))

submission_df = pd.DataFrame(submission).set_index('ID')
submission_df.to_csv("submission.csv")
print("Fichier submission.csv généré.") '''

  model.load_state_dict(torch.load('best_model.pth'))


  0%|          | 0/85054 [00:00<?, ?it/s]

✅ Fichier submission.csv généré.


## 🧪 Test-Time Augmentation (TTA)

In [12]:
tta_transforms = [
    transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((98, 98), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ]),
    transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((98, 98), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ]),
    transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((98, 98), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomVerticalFlip(p=1.0),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ]),
    transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((98, 98), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
]

In [13]:
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

submission = {'ID': [], 'Pred': []}
with h5py.File(TEST_IMAGES_PATH, 'r') as hdf:
    for img_id in tqdm(hdf.keys()):
        img_raw = np.array(hdf[img_id]['img'], dtype=np.float32)
        if img_raw.shape[0] != 3:
            img_raw = torch.tensor(img_raw).permute(2, 0, 1)
        else:
            img_raw = torch.tensor(img_raw)

        tta_imgs = torch.stack([t(img_raw) for t in tta_transforms]).to(device)

        with torch.no_grad():
            features = feature_extractor(tta_imgs)
            output = model(features).squeeze(1)  # shape: (N_TTA,)
            pred = output.mean().item()

        submission['ID'].append(int(img_id))
        submission['Pred'].append(int(pred > 0.5))


submission_df = pd.DataFrame(submission).set_index('ID')
submission_df.to_csv("submission.csv")
print("Fichier submission.csv généré avec TTA.")

  model.load_state_dict(torch.load('best_model.pth'))


  0%|          | 0/85054 [00:00<?, ?it/s]

✅ Fichier submission.csv généré avec TTA.


In [None]:
sumbission_df.to_csv("sub_TTA.csv")