<a href="https://colab.research.google.com/github/ayyucedemirbas/dinov2_nuclei_seg/blob/main/dinov2_pannuke.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from datasets import load_dataset
from transformers import Dinov2Config
import matplotlib.pyplot as plt
from accelerate import Accelerator

In [3]:
class PanNukeSegmentationDataset(Dataset):
    """PanNuke split exposed as (image, semantic mask) pairs.

    Each sample is:
      * ``image``: the RGB image resized to ``image_size`` x ``image_size``,
        converted to a tensor and normalized with ImageNet mean/std;
      * ``mask``: a ``torch.long`` tensor of shape ``(image_size, image_size)``
        where 0 is background and every pixel covered by a nucleus instance
        holds ``category + 1`` (categories are shifted to 1-5).

    Args:
        split: split name of the ``RationAI/PanNuke`` Hugging Face dataset.
            NOTE(review): the download log only shows ``fold1``/``fold2``/``fold3``
            splits, so the default ``"train"`` may not exist -- pass a fold explicitly.
        image_size: side length of the returned image and mask. It must be a
            multiple of ``patch_size`` so the ViT patch grid tiles the image
            exactly. The default 252 = 18 x 14 is the nearest multiple of 14
            below 256.
        patch_size: ViT patch size of the backbone (14 for DINOv2 ``*14`` models).

    Raises:
        ValueError: if ``image_size``/``patch_size`` are not positive or
            ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, split="train", image_size=252, patch_size=14):
        # Validate before load_dataset so a bad config fails without downloading.
        if patch_size <= 0 or image_size <= 0 or image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be a positive multiple of "
                f"patch_size ({patch_size})"
            )
        self.dataset = load_dataset("RationAI/PanNuke", split=split)
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_classes = 6  # 5 nuclei types + background

        # Loop-invariant: build the image pipeline once instead of per sample.
        self.image_transform = T.Compose([
            T.Resize((self.image_size, self.image_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = self.image_transform(item["image"].convert("RGB"))

        size = (self.image_size, self.image_size)
        mask = torch.zeros(size, dtype=torch.long)

        for instance, category in zip(item["instances"], item["categories"]):
            # NEAREST keeps the instance mask binary after resizing.
            instance_mask = T.functional.resize(
                T.functional.pil_to_tensor(instance),
                size,
                interpolation=T.InterpolationMode.NEAREST,
            )
            # Drop the leading channel dim -> (H, W) boolean pixel selector.
            # Where instances overlap, the later instance wins.
            mask[instance_mask.squeeze(0).bool()] = category + 1  # Shift categories to 1-5

        return image, mask

In [4]:
class DinoV2ForSegmentation(nn.Module):
    """Frozen DINOv2 feature extractor followed by a small convolutional decoder.

    The backbone's normalized patch tokens are reshaped into a 2-D feature map
    and decoded by ``seg_head`` into per-pixel class logits of shape
    ``(B, num_classes, output_size, output_size)`` (or ``output_size`` if a tuple).

    Args:
        num_classes: number of output channels (classes) of the final 1x1 conv.
        output_size: target spatial size of the final ``nn.Upsample``. The
            default 252 matches ``PanNukeSegmentationDataset``'s default image size.
        backbone: optional pre-built backbone. It must expose
            ``forward_features(x)`` returning a dict with ``"x_norm_patchtokens"``.
            When ``None`` (default), ``dinov2_vitb14`` is loaded from ``torch.hub``.

    Raises:
        ValueError (in ``forward``): if the number of patch tokens does not match
            the input's patch grid.
    """

    def __init__(self, num_classes, output_size=252, backbone=None):
        super().__init__()
        if backbone is None:
            backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
        self.dinov2 = backbone

        # The backbone is only a fixed feature extractor (forward runs it under
        # no_grad). Freezing it explicitly keeps it out of trainable-parameter
        # filters and avoids DDP errors about params that never get gradients.
        self.dinov2.requires_grad_(False)
        self.dinov2.eval()

        # 768 is the ViT-B/14 width; use the backbone's own value when it exposes one.
        embed_dim = getattr(self.dinov2, "embed_dim", 768)

        # Layer order is unchanged so existing state_dict keys (seg_head.0/3/6/9) still load.
        self.seg_head = nn.Sequential(
            nn.Conv2d(embed_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(size=output_size, mode='bilinear'),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

    def train(self, mode=True):
        """Switch the head between train/eval mode; the frozen backbone always stays in eval mode."""
        super().train(mode)
        self.dinov2.eval()
        return self

    def forward(self, x):
        """Return per-pixel class logits for an image batch ``x`` of shape ``(B, C, H, W)``."""
        with torch.no_grad():
            features = self.dinov2.forward_features(x)

        patch_embeddings = features['x_norm_patchtokens']
        batch_size, num_patches, hidden_dim = patch_embeddings.shape

        # Derive the patch grid from the input size instead of sqrt(num_patches),
        # so non-square inputs are reshaped correctly.
        patch_size = getattr(self.dinov2, "patch_size", 14)
        h, w = x.shape[-2] // patch_size, x.shape[-1] // patch_size
        if h * w != num_patches:
            raise ValueError(
                f"expected {h}x{w}={h * w} patch tokens for input of size "
                f"{tuple(x.shape[-2:])}, got {num_patches}"
            )

        # (B, N, D) -> (B, D, h, w) feature map for the conv head.
        embeddings = patch_embeddings.transpose(1, 2).reshape(batch_size, hidden_dim, h, w)

        return self.seg_head(embeddings)

In [5]:
# Accelerate inspects the available hardware once; `accelerator.prepare(...)`
# (later) relies on this object for device placement.
accelerator = Accelerator()
device = accelerator.device  # not referenced in the cells shown; kept for ad-hoc use

In [6]:
# Training hyper-parameters.
batch_size = 16
num_epochs = 20
learning_rate = 1e-4

# Seed torch's RNGs (all devices) before the model and data loaders are built,
# so head initialisation and training-set shuffling repeat across runs.
# GPU kernels may still be non-deterministic.
SEED = 42
torch.manual_seed(SEED)

In [7]:
# fold1 is used for training and fold2 for validation; fold3 is not used here.
train_dataset, val_dataset = (
    PanNukeSegmentationDataset(split=fold) for fold in ("fold1", "fold2")
)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

fold1-00000-of-00001.parquet:   0%|          | 0.00/280M [00:00<?, ?B/s]

fold2-00000-of-00001.parquet:   0%|          | 0.00/264M [00:00<?, ?B/s]

fold3-00000-of-00001.parquet:   0%|          | 0.00/289M [00:00<?, ?B/s]

Generating fold1 split:   0%|          | 0/2656 [00:00<?, ? examples/s]

Generating fold2 split:   0%|          | 0/2523 [00:00<?, ? examples/s]

Generating fold3 split:   0%|          | 0/2722 [00:00<?, ? examples/s]

In [8]:
# Take the class count from the dataset so model and labels cannot drift apart.
model = DinoV2ForSegmentation(num_classes=train_dataset.num_classes)

# Per-pixel cross-entropy over the class logits (mask value 0 = background).
criterion = nn.CrossEntropyLoss()

# Optimise only trainable parameters. The DINOv2 backbone runs under no_grad
# and never receives gradients, so AdamW has nothing to update there.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=learning_rate
)

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth
100%|██████████| 330M/330M [00:02<00:00, 155MB/s]


In [9]:
# Hand every training object to Accelerate, which places them on
# `accelerator.device` and wraps/shards them when running distributed.
# Rebind the original names to the returned (possibly wrapped) objects.
model, optimizer, train_loader, val_loader = accelerator.prepare(model, optimizer, train_loader, val_loader)

In [10]:
best_val_loss = float('inf')
for epoch in range(num_epochs):
    # ---- training --------------------------------------------------------
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
        optimizer.zero_grad()

        outputs = model(images)  # per-pixel class logits
        loss = criterion(outputs, masks)

        accelerator.backward(loss)
        optimizer.step()

        train_loss += loss.item()
    # Mean of per-batch losses on this process (logging only).
    train_loss /= len(train_loader)

    # ---- validation ------------------------------------------------------
    model.eval()
    val_losses = []
    with torch.no_grad():
        for images, masks in val_loader:
            outputs = model(images)
            loss = criterion(outputs, masks)
            # Repeat the batch-mean loss once per sample so the final mean is
            # sample-weighted (last batch is smaller). gather_for_metrics also
            # collects all processes and drops duplicate padding samples.
            val_losses.append(accelerator.gather_for_metrics(loss.repeat(images.size(0))))
    val_loss = torch.cat(val_losses).mean().item()

    # val_loss is gathered, so every process takes the same branch here.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        accelerator.wait_for_everyone()
        # Unwrap so keys have no DDP "module." prefix; Accelerate writes from the main process only.
        accelerator.save(accelerator.unwrap_model(model).state_dict(), "best_dinov2_pannuke.pth")

    accelerator.print(f"Epoch {epoch+1}/{num_epochs}")
    accelerator.print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

Epoch 1/20
Train Loss: 0.5006 | Val Loss: 0.3619
Epoch 2/20
Train Loss: 0.3481 | Val Loss: 0.3370
Epoch 3/20
Train Loss: 0.3239 | Val Loss: 0.3121
Epoch 4/20
Train Loss: 0.3099 | Val Loss: 0.3082
Epoch 5/20
Train Loss: 0.2940 | Val Loss: 0.3018
Epoch 6/20
Train Loss: 0.2885 | Val Loss: 0.2872
Epoch 7/20
Train Loss: 0.2817 | Val Loss: 0.2840
Epoch 8/20
Train Loss: 0.2753 | Val Loss: 0.2805
Epoch 9/20
Train Loss: 0.2695 | Val Loss: 0.2785
Epoch 10/20
Train Loss: 0.2614 | Val Loss: 0.2756
Epoch 11/20
Train Loss: 0.2582 | Val Loss: 0.2732
Epoch 12/20
Train Loss: 0.2532 | Val Loss: 0.2799
Epoch 13/20
Train Loss: 0.2483 | Val Loss: 0.2782
Epoch 14/20
Train Loss: 0.2452 | Val Loss: 0.2736
Epoch 15/20
Train Loss: 0.2411 | Val Loss: 0.2721
Epoch 16/20
Train Loss: 0.2354 | Val Loss: 0.2699
Epoch 17/20
Train Loss: 0.2316 | Val Loss: 0.2744
Epoch 18/20
Train Loss: 0.2320 | Val Loss: 0.2675
Epoch 19/20
Train Loss: 0.2302 | Val Loss: 0.2673
Epoch 20/20
Train Loss: 0.2265 | Val Loss: 0.2661


In [11]:
# Save the final-epoch weights. Unwrap first so the state_dict keys carry no
# distributed-wrapper prefix (e.g. DDP's "module."), and let Accelerate
# write the file from the main process only.
accelerator.wait_for_everyone()
accelerator.save(accelerator.unwrap_model(model).state_dict(), "dinov2_pannuke_segmentation.pth")