# Imports

In [1]:
!pip install -q -U transformers datasets segments-ai evaluate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m506.3/506.3 kB[0m [31m19.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.8/42.8 MB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dep

In [2]:
!pip install -U wandb -q

In [3]:
!pip install -U datasets -q

In [4]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
hf_username = "PushkarA07"

# 1. Choose a dataset

In [6]:
hf_dataset_identifier = "PushkarA07/batch2-tiles_W5"

# Load and prepare the Hugging Face dataset for training

In [7]:
from datasets import load_dataset

ds = load_dataset(hf_dataset_identifier)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/419 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/11.2M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/180 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/45 [00:00<?, ? examples/s]

In [8]:
ds = ds.shuffle(seed=48)
ds = ds["train"].train_test_split(test_size=0.2)
train_ds = ds["train"]
test_ds = ds["test"]

In [9]:
import json
from huggingface_hub import hf_hub_download

filename = "id2label.json"
id2label = {0: 'normal', 1:'abnormality'}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)
print("Id2label:", id2label)

Id2label: {0: 'normal', 1: 'abnormality'}


## Image processor & data augmentation

In [10]:
def remap_labels(labels):
  labels = labels.copy()
  mask1 = (labels >= 0) & (labels <= 227)
  labels[mask1] = 0
  mask2 = (labels >= 228) & (labels <= 255)
  labels[mask2] = 1
  return labels

In [11]:
import numpy as np
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2
import PIL.Image as PILImage
from transformers import SegformerImageProcessor

processor = SegformerImageProcessor()

# --- Albumentations pipeline ---
train_augmentations = A.Compose([
    # Geometric transforms
    A.HorizontalFlip(p=0.5),   # Brain is mostly symmetric, so flipping helps
    A.VerticalFlip(p=0.2),     # less common but useful for robustness
    A.RandomRotate90(p=0.3),   # Intensity transforms
    A.Affine(
        translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
        scale=1.0,
        rotate=0,
        fit_output=False,
        border_mode=cv2.BORDER_WRAP,  # <-- wrap-around
        p=0.3
    ),  # This approximates rolling across edges
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8,8), p=0.3),  # histogram equalization
    A.GaussianBlur(blur_limit=(3,5), p=0.2),
    A.GaussNoise(p=0.3),
    A.RandomGamma(gamma_limit=(80, 120), p=0.3), # gamma adjustments
    A.Sharpen(p=0.2), # sharpening
    ToTensorV2()
])

val_augmentations = A.Compose([
    ToTensorV2()
])

def train_transforms(example_batch):
    images = []
    labels = []
    for img, lbl in zip(example_batch['pixel_values'], example_batch['label']):
        img = np.array(PILImage.fromarray(np.uint8(img)).convert("RGB"))
        lbl = np.array(PILImage.fromarray(np.uint8(lbl)).convert("L"))

        augmented = train_augmentations(image=img, mask=lbl)
        aug_img, aug_lbl = augmented["image"], augmented["mask"]

        labels.append(remap_labels(aug_lbl.numpy()))
        images.append(aug_img.numpy().transpose(1,2,0))  # back to HWC for processor

    inputs = processor(
        images, labels,
        return_tensors="pt",
        do_resize=True,
        do_normalize=True
    )
    return inputs

def val_transforms(example_batch):
    images = []
    labels = []
    for img, lbl in zip(example_batch['pixel_values'], example_batch['label']):
        img = np.array(PILImage.fromarray(np.uint8(img)).convert("RGB"))
        lbl = np.array(PILImage.fromarray(np.uint8(lbl)).convert("L"))

        augmented = val_augmentations(image=img, mask=lbl)
        aug_img, aug_lbl = augmented["image"], augmented["mask"]

        labels.append(remap_labels(aug_lbl.numpy()))
        images.append(aug_img.numpy().transpose(1,2,0))  # back to HWC

    inputs = processor(
        images, labels,
        return_tensors="pt",
        do_resize=True,
        do_normalize=True
    )
    return inputs

# Set transforms
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

# 3. Fine-tune a SegFormer model

In [12]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_ds, batch_size=4, shuffle=True)

## Load the model to fine-tune

In [20]:
from transformers import SegformerForSemanticSegmentation

pretrained_model_name = "PushkarA07/segformer-b0-finetuned-net-15Oct"
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name,
    num_labels=2,
    id2label=id2label,
    label2id=label2id
)

model.safetensors:   0%|          | 0.00/14.9M [00:00<?, ?B/s]

In [None]:
from transformers import TrainingArguments

epochs = 20000
lr = 0.000001
batch_size = 8
hub_model_id = "segformer-b0-finetuned-net-15Oct"

training_args = TrainingArguments(
    "segformer-b0-finetuned-net-outputs",
    learning_rate=lr,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=10,
    eval_steps=10,
    logging_steps=1,
    eval_accumulation_steps=5,
    load_best_model_at_end=True,
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="end",
    report_to= "none"
    # "wandb"
)

In [22]:
import torch
from torch import nn
import evaluate
import multiprocessing

metric = evaluate.load("mean_iou")

def compute_metrics(eval_pred):
  with torch.no_grad():
    logits, labels = eval_pred
    logits_tensor = torch.from_numpy(logits)
    # scale the logits to the size of the label
    logits_tensor = nn.functional.interpolate(
        logits_tensor,
        size=labels.shape[-2:],
        mode="bilinear",
        align_corners=False,
    ).argmax(dim=1)

    pred_labels = logits_tensor.detach().cpu().numpy()
    metrics = metric._compute(
            predictions=pred_labels,
            references=labels,
            num_labels=len(id2label),
            ignore_index=255,
            reduce_labels=processor.do_reduce_labels,
        )

    # add per category metrics as individual key-value pairs
    per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
    per_category_iou = metrics.pop("per_category_iou").tolist()

    for i in range(1, len(per_category_accuracy)):
        metrics.update({f"accuracy_{id2label[i]}": per_category_accuracy[i]})
        metrics.update({f"iou_{id2label[i]}": per_category_iou[i]})

    return metrics

In [23]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

In [24]:
trainer.train()

Step,Training Loss,Validation Loss,Mean Iou,Mean Accuracy,Overall Accuracy,Accuracy Abnormality,Iou Abnormality
10,0.0055,0.005872,0.866137,0.901968,0.997757,0.804677,0.734532
20,0.0073,0.005875,0.866391,0.90299,0.997757,0.806738,0.735039
30,0.0023,0.005887,0.866802,0.904436,0.997759,0.809651,0.73586
40,0.0045,0.005874,0.866814,0.904389,0.997759,0.809555,0.735883
50,0.0026,0.005861,0.867144,0.905371,0.997761,0.811534,0.736541
60,0.0059,0.005881,0.866826,0.904362,0.99776,0.8095,0.735907
70,0.0037,0.005884,0.866675,0.90351,0.997761,0.807783,0.735604
80,0.0052,0.005872,0.866721,0.90383,0.99776,0.808428,0.735695
90,0.007,0.005886,0.867372,0.90653,0.99776,0.81387,0.736997
100,0.0064,0.005886,0.866915,0.904614,0.99776,0.810009,0.736084


TrainOutput(global_step=180, training_loss=0.006007259568044295, metrics={'train_runtime': 136.4123, 'train_samples_per_second': 10.556, 'train_steps_per_second': 1.32, 'total_flos': 2.524025595101184e+16, 'train_loss': 0.006007259568044295, 'epoch': 10.0})

In [25]:
kwargs = {
    "tags": ["vision", "image-segmentation"],
    "finetuned_from": pretrained_model_name,
    "dataset": hf_dataset_identifier,
}

processor.push_to_hub(hub_model_id)
trainer.push_to_hub(**kwargs)

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...outputs/model.safetensors:   4%|3         |  572kB / 14.9MB            

  ...outputs/training_args.bin:   4%|3         |   219B / 5.84kB            

CommitInfo(commit_url='https://huggingface.co/PushkarA07/segformer-b0-finetuned-net-15Oct/commit/9a1b9303217c6322fc96e39dfd09c2ab1a0fd444', commit_message='End of training', commit_description='', oid='9a1b9303217c6322fc96e39dfd09c2ab1a0fd444', pr_url=None, repo_url=RepoUrl('https://huggingface.co/PushkarA07/segformer-b0-finetuned-net-15Oct', endpoint='https://huggingface.co', repo_type='model', repo_id='PushkarA07/segformer-b0-finetuned-net-15Oct'), pr_revision=None, pr_num=None)

# 4. Inference (test)

In [26]:
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

processor = SegformerImageProcessor.from_pretrained("PushkarA07/segformer-b0-finetuned-net-15Oct")
model = SegformerForSemanticSegmentation.from_pretrained(f"PushkarA07/segformer-b0-finetuned-net-15Oct")

preprocessor_config.json:   0%|          | 0.00/372 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/14.9M [00:00<?, ?B/s]

In [88]:
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import binary_dilation, binary_erosion, label as ndi_label, generate_binary_structure
from tqdm import tqdm

# --- Cluster labeling with dilation + erosion ---
def count_and_label(mask, dilation_iters: int = 1):
    m = mask.astype(bool)
    if dilation_iters > 0:
        struct = generate_binary_structure(2, 1)
        for _ in range(dilation_iters):
            m = binary_dilation(m, structure=struct)
    labels, num = ndi_label(m)
    return labels, num

def compute_metrics(pred_mask, gt_mask, dilation_iters=1):
    pred_labels, _ = count_and_label(pred_mask, dilation_iters)
    gt_labels, _ = count_and_label(gt_mask, dilation_iters)
    pred_clusters = np.unique(pred_labels)
    gt_clusters = np.unique(gt_labels)
    pred_clusters = pred_clusters[pred_clusters != 0]
    gt_clusters = gt_clusters[gt_clusters != 0]
    matched_pairs = 0
    used_pred = set()
    used_gt = set()

    for gt_id in gt_clusters:
        gt_cluster_mask = (gt_labels == gt_id)
        for pred_id in pred_clusters:
            if pred_id in used_pred:
                continue
            pred_cluster_mask = (pred_labels == pred_id)
            intersection = np.logical_and(gt_cluster_mask, pred_cluster_mask).sum()
            union = np.logical_or(gt_cluster_mask, pred_cluster_mask).sum()
            iou = intersection / (union + 1e-8)
            if iou > 0.1:
                matched_pairs += 1
                used_pred.add(pred_id)
                used_gt.add(gt_id)
                break

    total_clusters = len(pred_clusters) + len(gt_clusters) - matched_pairs
    cluster_iou = matched_pairs / (total_clusters + 1e-8)
    cluster_dice = 2 * matched_pairs / (len(pred_clusters) + len(gt_clusters) + 1e-8)

    return cluster_dice, cluster_iou

# --- Dataset loop ---
dice_scores = []
cluster_ious = []

for idx in tqdm(range(len(test_ds))): # seed=48
    # 1) Load & normalize image
    sample = test_ds[idx]
    img_t = sample['pixel_values']
    img_t = (img_t - img_t.min()) / (img_t.max() - img_t.min())
    img_np = (img_t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    # 2) Model prediction
    pil = Image.fromarray(img_np)
    inputs = processor(images=pil, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs).logits
        up = F.interpolate(out, size=sample['labels'].shape,
                           mode="bilinear", align_corners=False)
        pred = up.argmax(dim=1)[0].cpu().numpy()

    # 3) Ground truth
    gt = sample['labels'].numpy()

    # 4) Compute metrics
    dice, cluster_iou = compute_metrics(pred, gt, dilation_iters=3)
    dice_scores.append(dice)
    cluster_ious.append(cluster_iou)

# final averages
print("\n=== Averages ===")
print(f"Mean Dice:        {np.mean(dice_scores):.4f}")
print(f"Mean Cluster-IoU: {np.mean(cluster_ious):.4f}") # Batch 2

100%|██████████| 36/36 [00:22<00:00,  1.60it/s]


=== Averages ===
Mean Dice:        0.8812
Mean Cluster-IoU: 0.8229





In [None]:
from PIL import Image
import numpy as np
import torch
from torch import nn
import evaluate

image = test_ds[1]['pixel_values']
image = (image - image.min()) / (image.max() - image.min())
gt_seg = test_ds[1]['labels']

image = torch.tensor(image)
image_pil = Image.fromarray((image.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
inputs = processor(images=image_pil, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
upsampled_logits = nn.functional.interpolate(
    logits,
    size=gt_seg.shape[-2:],
    mode="bilinear",
    align_corners=False
)

# Get predicted segmentation map
pred_seg = upsampled_logits.argmax(dim=1)[0].detach().cpu().numpy()

# Display the image, predicted label, and ground truth
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.imshow(image_pil)
plt.title("Image")
plt.subplot(1, 3, 2)
plt.imshow(pred_seg, cmap='binary', interpolation='nearest')
plt.title("Predicted Segmentation")
plt.subplot(1, 3, 3)
plt.imshow(gt_seg, cmap='binary', interpolation='nearest')
plt.title("Ground Truth Segmentation")
plt.show()

## prev

In [87]:
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import binary_dilation, binary_erosion, label as ndi_label, generate_binary_structure
from tqdm import tqdm

# --- Cluster labeling with dilation + erosion ---
def count_and_label(mask, dilation_iters: int = 1):
    m = mask.astype(bool)
    if dilation_iters > 0:
        struct = generate_binary_structure(2, 1)
        for _ in range(dilation_iters):
            m = binary_dilation(m, structure=struct)
    labels, num = ndi_label(m)
    return labels, num

def compute_metrics(pred_mask, gt_mask, dilation_iters=1):
    pred_labels, _ = count_and_label(pred_mask, dilation_iters)
    gt_labels, _ = count_and_label(gt_mask, dilation_iters)
    pred_clusters = np.unique(pred_labels)
    gt_clusters = np.unique(gt_labels)
    pred_clusters = pred_clusters[pred_clusters != 0]
    gt_clusters = gt_clusters[gt_clusters != 0]
    matched_pairs = 0
    used_pred = set()
    used_gt = set()

    for gt_id in gt_clusters:
        gt_cluster_mask = (gt_labels == gt_id)
        for pred_id in pred_clusters:
            if pred_id in used_pred:
                continue
            pred_cluster_mask = (pred_labels == pred_id)
            intersection = np.logical_and(gt_cluster_mask, pred_cluster_mask).sum()
            union = np.logical_or(gt_cluster_mask, pred_cluster_mask).sum()
            iou = intersection / (union + 1e-8)
            if iou > 0.1:
                matched_pairs += 1
                used_pred.add(pred_id)
                used_gt.add(gt_id)
                break

    total_clusters = len(pred_clusters) + len(gt_clusters) - matched_pairs
    cluster_iou = matched_pairs / (total_clusters + 1e-8)
    cluster_dice = 2 * matched_pairs / (len(pred_clusters) + len(gt_clusters) + 1e-8)

    return cluster_dice, cluster_iou

# --- Dataset loop ---
dice_scores = []
cluster_ious = []

for idx in tqdm(range(len(test_ds))): # seed=48
    # 1) Load & normalize image
    sample = test_ds[idx]
    img_t = sample['pixel_values']
    img_t = (img_t - img_t.min()) / (img_t.max() - img_t.min())
    img_np = (img_t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    # 2) Model prediction
    pil = Image.fromarray(img_np)
    inputs = processor(images=pil, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs).logits
        up = F.interpolate(out, size=sample['labels'].shape,
                           mode="bilinear", align_corners=False)
        pred = up.argmax(dim=1)[0].cpu().numpy()

    # 3) Ground truth
    gt = sample['labels'].numpy()

    # 4) Compute metrics
    dice, cluster_iou = compute_metrics(pred, gt, dilation_iters=3)
    dice_scores.append(dice)
    cluster_ious.append(cluster_iou)

# final averages
print("\n=== Averages ===")
print(f"Mean Dice:        {np.mean(dice_scores):.4f}")
print(f"Mean Cluster-IoU: {np.mean(cluster_ious):.4f}") # Batch 2

100%|██████████| 36/36 [00:22<00:00,  1.62it/s]


=== Averages ===
Mean Dice:        0.8812
Mean Cluster-IoU: 0.8229





In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import binary_dilation, binary_erosion, label as ndi_label, generate_binary_structure
from tqdm import tqdm

# --- Cluster labeling with dilation + erosion ---
def count_and_label(mask, dilation_iters: int = 1):
    m = mask.astype(bool)
    if dilation_iters > 0:
        struct = generate_binary_structure(2, 1)
        for _ in range(dilation_iters):
            m = binary_dilation(m, structure=struct)
    labels, num = ndi_label(m)
    return labels, num

def compute_metrics(pred_mask, gt_mask, dilation_iters=1):
    pred_labels, _ = count_and_label(pred_mask, dilation_iters)
    gt_labels, _ = count_and_label(gt_mask, dilation_iters)
    pred_clusters = np.unique(pred_labels)
    gt_clusters = np.unique(gt_labels)
    pred_clusters = pred_clusters[pred_clusters != 0]
    gt_clusters = gt_clusters[gt_clusters != 0]
    matched_pairs = 0
    used_pred = set()
    used_gt = set()

    for gt_id in gt_clusters:
        gt_cluster_mask = (gt_labels == gt_id)
        for pred_id in pred_clusters:
            if pred_id in used_pred:
                continue
            pred_cluster_mask = (pred_labels == pred_id)
            intersection = np.logical_and(gt_cluster_mask, pred_cluster_mask).sum()
            union = np.logical_or(gt_cluster_mask, pred_cluster_mask).sum()
            iou = intersection / (union + 1e-8)
            if iou > 0.1:
                matched_pairs += 1
                used_pred.add(pred_id)
                used_gt.add(gt_id)
                break

    total_clusters = len(pred_clusters) + len(gt_clusters) - matched_pairs
    cluster_iou = matched_pairs / (total_clusters + 1e-8)
    cluster_dice = 2 * matched_pairs / (len(pred_clusters) + len(gt_clusters) + 1e-8)

    return cluster_dice, cluster_iou

# --- Dataset loop ---
dice_scores = []
cluster_ious = []

for idx in tqdm(range(len(test_ds))):
    # 1) Load & normalize image
    sample = test_ds[idx]
    img_t = sample['pixel_values']
    img_t = (img_t - img_t.min()) / (img_t.max() - img_t.min())
    img_np = (img_t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    # 2) Model prediction
    pil = Image.fromarray(img_np)
    inputs = processor(images=pil, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs).logits
        up = F.interpolate(out, size=sample['labels'].shape,
                           mode="bilinear", align_corners=False)
        pred = up.argmax(dim=1)[0].cpu().numpy()

    # 3) Ground truth
    gt = sample['labels'].numpy()

    # 4) Compute metrics
    dice, cluster_iou = compute_metrics(pred, gt, dilation_iters=3)
    dice_scores.append(dice)
    cluster_ious.append(cluster_iou)

# final averages
print("\n=== Averages ===")
print(f"Mean Dice:        {np.mean(dice_scores):.4f}")
print(f"Mean Cluster-IoU: {np.mean(cluster_ious):.4f}") # Batch 2

100%|██████████| 36/36 [00:21<00:00,  1.69it/s]


=== Averages ===
Mean Dice:        0.8640
Mean Cluster-IoU: 0.8050





In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import binary_dilation, binary_erosion, label as ndi_label, generate_binary_structure
from tqdm import tqdm

# --- Cluster labeling with dilation + erosion ---
def count_and_label(mask, dilation_iters: int = 1):
    m = mask.astype(bool)
    if dilation_iters > 0:
        struct = generate_binary_structure(2, 1)
        for _ in range(dilation_iters):
            m = binary_dilation(m, structure=struct)
    labels, num = ndi_label(m)
    return labels, num

def compute_metrics(pred_mask, gt_mask, dilation_iters=1):
    pred_labels, _ = count_and_label(pred_mask, dilation_iters)
    gt_labels, _ = count_and_label(gt_mask, dilation_iters)
    pred_clusters = np.unique(pred_labels)
    gt_clusters = np.unique(gt_labels)
    pred_clusters = pred_clusters[pred_clusters != 0]
    gt_clusters = gt_clusters[gt_clusters != 0]
    matched_pairs = 0
    used_pred = set()
    used_gt = set()

    for gt_id in gt_clusters:
        gt_cluster_mask = (gt_labels == gt_id)
        for pred_id in pred_clusters:
            if pred_id in used_pred:
                continue
            pred_cluster_mask = (pred_labels == pred_id)
            intersection = np.logical_and(gt_cluster_mask, pred_cluster_mask).sum()
            union = np.logical_or(gt_cluster_mask, pred_cluster_mask).sum()
            iou = intersection / (union + 1e-8)
            if iou > 0.1:
                matched_pairs += 1
                used_pred.add(pred_id)
                used_gt.add(gt_id)
                break

    total_clusters = len(pred_clusters) + len(gt_clusters) - matched_pairs
    cluster_iou = matched_pairs / (total_clusters + 1e-8)
    cluster_dice = 2 * matched_pairs / (len(pred_clusters) + len(gt_clusters) + 1e-8)

    return cluster_dice, cluster_iou

# --- Dataset loop ---
dice_scores = []
cluster_ious = []

for idx in tqdm(range(len(test_ds))):
    # 1) Load & normalize image
    sample = test_ds[idx]
    img_t = sample['pixel_values']
    img_t = (img_t - img_t.min()) / (img_t.max() - img_t.min())
    img_np = (img_t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    # 2) Model prediction
    pil = Image.fromarray(img_np)
    inputs = processor(images=pil, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs).logits
        up = F.interpolate(out, size=sample['labels'].shape,
                           mode="bilinear", align_corners=False)
        pred = up.argmax(dim=1)[0].cpu().numpy()

    # 3) Ground truth
    gt = sample['labels'].numpy()

    # 4) Compute metrics
    dice, cluster_iou = compute_metrics(pred, gt, dilation_iters=3)
    dice_scores.append(dice)
    cluster_ious.append(cluster_iou)

# final averages
print("\n=== Averages ===")
print(f"Mean Dice:        {np.mean(dice_scores):.4f}")
print(f"Mean Cluster-IoU: {np.mean(cluster_ious):.4f}") # Batch 3

100%|██████████| 22/22 [00:13<00:00,  1.57it/s]


=== Averages ===
Mean Dice:        0.7992
Mean Cluster-IoU: 0.7288





In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import binary_dilation, binary_erosion, label as ndi_label, generate_binary_structure
from tqdm import tqdm

# --- Cluster labeling with dilation + erosion ---
def count_and_label(mask, dilation_iters: int = 1):
    m = mask.astype(bool)
    if dilation_iters > 0:
        struct = generate_binary_structure(2, 1)
        for _ in range(dilation_iters):
            m = binary_dilation(m, structure=struct)
    labels, num = ndi_label(m)
    return labels, num

def compute_metrics(pred_mask, gt_mask, dilation_iters=1):
    pred_labels, _ = count_and_label(pred_mask, dilation_iters)
    gt_labels, _ = count_and_label(gt_mask, dilation_iters)
    pred_clusters = np.unique(pred_labels)
    gt_clusters = np.unique(gt_labels)
    pred_clusters = pred_clusters[pred_clusters != 0]
    gt_clusters = gt_clusters[gt_clusters != 0]
    matched_pairs = 0
    used_pred = set()
    used_gt = set()

    for gt_id in gt_clusters:
        gt_cluster_mask = (gt_labels == gt_id)
        for pred_id in pred_clusters:
            if pred_id in used_pred:
                continue
            pred_cluster_mask = (pred_labels == pred_id)
            intersection = np.logical_and(gt_cluster_mask, pred_cluster_mask).sum()
            union = np.logical_or(gt_cluster_mask, pred_cluster_mask).sum()
            iou = intersection / (union + 1e-8)
            if iou > 0.1:
                matched_pairs += 1
                used_pred.add(pred_id)
                used_gt.add(gt_id)
                break

    total_clusters = len(pred_clusters) + len(gt_clusters) - matched_pairs
    cluster_iou = matched_pairs / (total_clusters + 1e-8)
    cluster_dice = 2 * matched_pairs / (len(pred_clusters) + len(gt_clusters) + 1e-8)

    return cluster_dice, cluster_iou

# --- Dataset loop ---
dice_scores = []
cluster_ious = []

for idx in tqdm(range(len(test_ds))):
    # 1) Load & normalize image
    sample = test_ds[idx]
    img_t = sample['pixel_values']
    img_t = (img_t - img_t.min()) / (img_t.max() - img_t.min())
    img_np = (img_t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    # 2) Model prediction
    pil = Image.fromarray(img_np)
    inputs = processor(images=pil, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs).logits
        up = F.interpolate(out, size=sample['labels'].shape,
                           mode="bilinear", align_corners=False)
        pred = up.argmax(dim=1)[0].cpu().numpy()

    # 3) Ground truth
    gt = sample['labels'].numpy()

    # 4) Compute metrics
    dice, cluster_iou = compute_metrics(pred, gt, dilation_iters=3)
    dice_scores.append(dice)
    cluster_ious.append(cluster_iou)

# final averages
print("\n=== Averages ===")
print(f"Mean Dice:        {np.mean(dice_scores):.4f}")
print(f"Mean Cluster-IoU: {np.mean(cluster_ious):.4f}") # Batch 2

100%|██████████| 36/36 [00:22<00:00,  1.61it/s]


=== Averages ===
Mean Dice:        0.7889
Mean Cluster-IoU: 0.7058





## Visualization

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import binary_dilation, label as ndi_label, generate_binary_structure

# --- cluster-count helper ---
def count_and_label(mask: np.ndarray, dilation_iters: int = 1):
    m = mask.astype(bool)
    if dilation_iters > 0:
        struct = generate_binary_structure(2, 1)  # 3×3 connectivity
        for _ in range(dilation_iters):
            m = binary_dilation(m, structure=struct)
    labels, num = ndi_label(m)
    return labels, num

# --- Cluster-based IoU calculation ---
def cluster_iou(pred_lbls, pred_n, gt_lbls, gt_n):
    intersection_clusters = 0
    pred_clusters = set(np.unique(pred_lbls)) - {0}
    gt_clusters = set(np.unique(gt_lbls)) - {0}

    for pred_id in pred_clusters:
        pred_cluster_mask = (pred_lbls == pred_id)
        for gt_id in gt_clusters:
            gt_cluster_mask = (gt_lbls == gt_id)
            if np.logical_and(pred_cluster_mask, gt_cluster_mask).any():
                intersection_clusters += 1
                break  # one-to-one matching

    union_clusters = len(pred_clusters.union(gt_clusters))
    cluster_iou = intersection_clusters / (union_clusters + 1e-8)

    return cluster_iou

# --- visualization function ---
def visualize_clusters(idx: int, dilation_iters: int = 1):
    sample = test_ds[idx]

    # 1) load & normalize RGB tile
    img_t = sample['pixel_values']
    img_t = (img_t - img_t.min())/(img_t.max()-img_t.min())
    img_np = (img_t.permute(1,2,0).numpy()*255).astype(np.uint8)

    # 2) get prediction mask
    pil    = Image.fromarray(img_np)
    inputs = processor(images=pil, return_tensors="pt")
    with torch.no_grad():
        out  = model(**inputs).logits
        up   = F.interpolate(out, size=sample['labels'].shape,
                             mode="bilinear", align_corners=False)
        pred = up.argmax(dim=1)[0].cpu().numpy()  # ints {0,1}

    # 3) count & label clusters for pred and GT
    pred_lbls, pred_n = count_and_label(pred == 1, dilation_iters=dilation_iters)
    gt_mask           = sample['labels'].numpy() == 1
    gt_lbls,  gt_n    = count_and_label(gt_mask,     dilation_iters=dilation_iters)

    # 4) calculate cluster IoU
    cluster_iou_score = cluster_iou(pred_lbls, pred_n, gt_lbls, gt_n)

    # 5) plot side by side
    fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(15,5))

    ax1.imshow(img_np)
    ax1.axis("off")
    ax1.set_title(f"Sample #{idx}\nRGB Tile")

    ax2.imshow(pred, cmap="gray")
    ax2.contour(pred_lbls, levels=np.arange(0.5, pred_n+0.5),
                colors="red", linewidths=0.8)
    ax2.axis("off")
    ax2.set_title(f"Pred Clusters: {pred_n}")

    ax3.imshow(gt_mask, cmap="gray")
    ax3.contour(gt_lbls, levels=np.arange(0.5, gt_n+0.5),
                colors="blue", linewidths=0.8)
    ax3.axis("off")
    ax3.set_title(f"GT Clusters: {gt_n}")

    fig.suptitle(f"Cluster IoU: {cluster_iou_score:.3f}", fontsize=16)
    plt.tight_layout()
    plt.show()

# Example usage
for i in range(len(test_ds)):
    visualize_clusters(idx=i, dilation_iters=3)

### extra

In [None]:
# import wandb
# api = wandb.Api()
# team, project, run_id = "pushkar-ambastha", "huggingface", "2dtqfc7n"

In [None]:
# run = api.run(f"{team}/{project}/{run_id}")

# run.display(height=720)