In [None]:
import os
from PIL import Image
from glob import glob

# Input paths
image_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/train/images"
mask_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/train/masks"

# Output paths
out_image_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc_augmented/train/images"
out_mask_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc_augmented/train/masks"

os.makedirs(out_image_dir, exist_ok=True)
os.makedirs(out_mask_dir, exist_ok=True)

# Supported extensions
exts = ["*.png", "*.jpg", "*.jpeg"]

image_files = []
for ext in exts:
    image_files.extend(glob(os.path.join(image_dir, ext)))

print(f"Found {len(image_files)} images.")

for img_path in image_files:
    fname = os.path.basename(img_path)
    name, ext = os.path.splitext(fname)

    # Load image and mask
    mask_path = os.path.join(mask_dir, fname)
    if not os.path.exists(mask_path):
        print(f"Mask not found for {fname}, skipping.")
        continue

    img = Image.open(img_path)
    mask = Image.open(mask_path)

    # Save original
    img.save(os.path.join(out_image_dir, fname))
    mask.save(os.path.join(out_mask_dir, fname))

    # Horizontal flip
    img_h = img.transpose(Image.FLIP_LEFT_RIGHT)
    mask_h = mask.transpose(Image.FLIP_LEFT_RIGHT)
    img_h.save(os.path.join(out_image_dir, f"{name}_hflip{ext}"))
    mask_h.save(os.path.join(out_mask_dir, f"{name}_hflip{ext}"))

    # Vertical flip
    img_v = img.transpose(Image.FLIP_TOP_BOTTOM)
    mask_v = mask.transpose(Image.FLIP_TOP_BOTTOM)
    img_v.save(os.path.join(out_image_dir, f"{name}_vflip{ext}"))
    mask_v.save(os.path.join(out_mask_dir, f"{name}_vflip{ext}"))

print("Augmentation completed. Files saved in:")
print(out_image_dir)


In [None]:
#Generate Features for SAM 
import os
import torch
from PIL import Image
from tqdm import tqdm
from transformers import SamModel, SamProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc_augmented/train/images"  
sam_output_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/sam_train_features"
os.makedirs(sam_output_dir, exist_ok=True)

# Load SAM model for prediction
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# === Load image paths ===
image_paths = sorted([
    os.path.join(image_dir, fname)
    for fname in os.listdir(image_dir)
    if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
])

# === Inference loop ===
for path in tqdm(image_paths, desc="Extracting SAM masks"):
    img = Image.open(path).convert("RGB")
    base_name = os.path.splitext(os.path.basename(path))[0]

    inputs = sam_processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = sam_model(**inputs)
        masks = outputs.pred_masks  # [B, num_masks, H, W], logits
        iou_scores = outputs.iou_scores  # [B, num_masks]

        # Pick best mask per image 
        best_mask_idx = iou_scores.argmax(dim=1)[0]  # scalar
        best_mask_logits = masks[0, best_mask_idx]    # [H, W], logits

        # Convert logits to soft mask 
        soft_mask = torch.sigmoid(best_mask_logits).cpu()

    # Save soft mask 
    torch.save(soft_mask, os.path.join(sam_output_dir, f"{base_name}.pt"))


In [None]:
!pip uninstall transformers -y
!pip install transformers==4.40.0

In [2]:
#Generate Features for DINOv2 
import os
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc_augmented/train/images"  
dino_output_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/dino_train_features"
os.makedirs(dino_output_dir, exist_ok=True)

# Load DINOv2 ViT model and processor 
model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

# Load image paths
image_paths = sorted([
    os.path.join(image_dir, fname)
    for fname in os.listdir(image_dir)
    if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
])

# Inference loop
for path in tqdm(image_paths, desc="Extracting DINOv2 features"):
    img = Image.open(path).convert("RGB")
    base_name = os.path.splitext(os.path.basename(path))[0]

    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        last_hidden_state = outputs.last_hidden_state[0]  # [num_tokens, C]

        # Remove CLS token
        spatial_tokens = last_hidden_state[1:]  # [num_patches, C]
        C = spatial_tokens.shape[-1]
        num_patches = spatial_tokens.shape[0]
        H = W = int(num_patches ** 0.5)

        if H * W != num_patches:
            raise ValueError(f"Expected square feature map but got {num_patches} tokens.")

        features = spatial_tokens.reshape(H, W, C).permute(2, 0, 1).contiguous()  # [C, H, W]

    # Save feature map
    torch.save(features.cpu(), os.path.join(dino_output_dir, f"{base_name}.pt"))


  from .autonotebook import tqdm as notebook_tqdm
    Found GPU1 NVIDIA GeForce GT 710 which is of cuda capability 3.5.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability supported by this library is 3.7.
    
Extracting DINOv2 features: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5800/5800 [06:50<00:00, 14.14it/s]


In [1]:
!python --version

Python 3.9.22


In [3]:
#Generate Features for SAM for test
import os
import torch
from PIL import Image
from tqdm import tqdm
from transformers import SamModel, SamProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_kvasir/images"  # <-- update this
sam_output_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/sam_kvasir_test_features"
os.makedirs(sam_output_dir, exist_ok=True)

# Load SAM model for mask prediction 
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# === Load image paths ===
image_paths = sorted([
    os.path.join(image_dir, fname)
    for fname in os.listdir(image_dir)
    if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
])

# === Inference loop ===
for path in tqdm(image_paths, desc="Extracting SAM masks"):
    img = Image.open(path).convert("RGB")
    base_name = os.path.splitext(os.path.basename(path))[0]

    inputs = sam_processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = sam_model(**inputs)
        masks = outputs.pred_masks  # [B, num_masks, H, W], logits
        iou_scores = outputs.iou_scores  # [B, num_masks]

        # Pick best mask per image 
        best_mask_idx = iou_scores.argmax(dim=1)[0]  # scalar
        best_mask_logits = masks[0, best_mask_idx]    # [H, W], logits

        # Convert logits to soft mask
        soft_mask = torch.sigmoid(best_mask_logits).cpu()

    # Save soft mask tensor
    torch.save(soft_mask, os.path.join(sam_output_dir, f"{base_name}.pt"))


Extracting SAM masks: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 100/100 [01:21<00:00,  1.23it/s]


In [4]:
#Generate Features for SAM 
import os
import torch
from PIL import Image
from tqdm import tqdm
from transformers import SamModel, SamProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_cvc/images"  # <-- update this
sam_output_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/sam_cvc_test_features"
os.makedirs(sam_output_dir, exist_ok=True)

# Load SAM model for mask prediction instead of just embeddings
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# === Load image paths ===
image_paths = sorted([
    os.path.join(image_dir, fname)
    for fname in os.listdir(image_dir)
    if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
])

# === Inference loop ===
for path in tqdm(image_paths, desc="Extracting SAM masks"):
    img = Image.open(path).convert("RGB")
    base_name = os.path.splitext(os.path.basename(path))[0]

    inputs = sam_processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = sam_model(**inputs)
        masks = outputs.pred_masks  # [B, num_masks, H, W], logits
        iou_scores = outputs.iou_scores  # [B, num_masks]

        # Pick best mask per image (assuming batch size = 1)
        best_mask_idx = iou_scores.argmax(dim=1)[0]  # scalar
        best_mask_logits = masks[0, best_mask_idx]    # [H, W], logits

        # Convert logits to soft mask (sigmoid probabilities)
        soft_mask = torch.sigmoid(best_mask_logits).cpu()

    # Save soft mask tensor
    torch.save(soft_mask, os.path.join(sam_output_dir, f"{base_name}.pt"))


Extracting SAM masks: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 62/62 [00:49<00:00,  1.24it/s]


In [5]:
#Generate Dinov2 Features for test
import os
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_kvasir/images"  # Same image dir
dino_output_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/dino_kvasir_test_features"
os.makedirs(dino_output_dir, exist_ok=True)

# Load DINOv2 ViT model and processor
model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

# === Load image paths ===
image_paths = sorted([
    os.path.join(image_dir, fname)
    for fname in os.listdir(image_dir)
    if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
])

# === Inference loop ===
for path in tqdm(image_paths, desc="Extracting DINOv2 features"):
    img = Image.open(path).convert("RGB")
    base_name = os.path.splitext(os.path.basename(path))[0]

    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        last_hidden_state = outputs.last_hidden_state[0]  # [num_tokens, C]
        
        # Remove CLS token
        spatial_tokens = last_hidden_state[1:]  # [num_patches, C]
        C = spatial_tokens.shape[-1]
        num_patches = spatial_tokens.shape[0]
        H = W = int(num_patches ** 0.5)
    
        if H * W != num_patches:
            raise ValueError(f"Expected square feature map but got {num_patches} tokens.")
    
        features = spatial_tokens.reshape(H, W, C).permute(2, 0, 1).contiguous()  # [C, H, W]

    # Save feature map
    torch.save(features.cpu(), os.path.join(dino_output_dir, f"{base_name}.pt"))


Extracting DINOv2 features: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 100/100 [00:04<00:00, 23.44it/s]


In [6]:
#Generate Dinov2 Features for test
import os
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_cvc/images"  # Same image dir
dino_output_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/dino_cvc_test_features"
os.makedirs(dino_output_dir, exist_ok=True)

# Load DINOv2 ViT model and processor
model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

# === Load image paths ===
image_paths = sorted([
    os.path.join(image_dir, fname)
    for fname in os.listdir(image_dir)
    if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
])

# === Inference loop ===
for path in tqdm(image_paths, desc="Extracting DINOv2 features"):
    img = Image.open(path).convert("RGB")
    base_name = os.path.splitext(os.path.basename(path))[0]

    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        last_hidden_state = outputs.last_hidden_state[0]  # [num_tokens, C]
        
        # Remove CLS token
        spatial_tokens = last_hidden_state[1:]  # [num_patches, C]
        C = spatial_tokens.shape[-1]
        num_patches = spatial_tokens.shape[0]
        H = W = int(num_patches ** 0.5)
    
        if H * W != num_patches:
            raise ValueError(f"Expected square feature map but got {num_patches} tokens.")
    
        features = spatial_tokens.reshape(H, W, C).permute(2, 0, 1).contiguous()  # [C, H, W]

    # Save feature map
    torch.save(features.cpu(), os.path.join(dino_output_dir, f"{base_name}.pt"))


Extracting DINOv2 features: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 62/62 [00:01<00:00, 32.45it/s]


## In similar manner generate test features for unseen dataset (cvc-300, etis, cvc-colondb) for SAM and DINOv2..............

In [None]:
#Generate OneFormer Features
import os
import torch
from PIL import Image
from tqdm import tqdm
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_dir = "/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc_augmented/train/images"  # update path
oneformer_output_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/oneformer_train_features"
os.makedirs(oneformer_output_dir, exist_ok=True)

# Load OneFormer (COCO semantic segmentation variant)
processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_coco_swin_large").to(device)

task_type = "semantic"  # can also be "instance" or "panoptic"

image_paths = sorted([
    os.path.join(image_dir, fname)
    for fname in os.listdir(image_dir)
    if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
])

for path in tqdm(image_paths, desc="Extracting OneFormer dense features"):
    img = Image.open(path).convert("RGB")
    inputs = processor(images=img, task_inputs=[task_type], return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Dense spatial feature map 
    dense_features = outputs.pixel_decoder_hidden_states[-1]  # shape: [1, C, H, W]
    dense_features = dense_features.squeeze(0).cpu()  # remove batch dim

    base_name = os.path.splitext(os.path.basename(path))[0]
    torch.save(dense_features, os.path.join(oneformer_output_dir, f"{base_name}.pt"))


## Similarly generate all the test features for Oneformer.....................


In [1]:
#Unet++ 3F 
import torch
from torch import nn
import torch.nn.functional as F
import torch.fft

__all__ = ['VGGBlock', 'UNet', 'NestedUNet']


class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out


class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False,
                 f1_channels=9, f2_channels=768, f3_channels=256):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]
        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # --- Standard UNet++ blocks ---
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        # --- Latent space projection layers (L1 / L2) ---
        self.linear1 = nn.Sequential(
            nn.Conv2d(nb_filter[4], 256, 1),
            nn.ReLU()
        )
        self.linear2 = nn.Conv2d(256, 256, 1)

        # --- Fusion layer for external features ---
        in_channels_total = f1_channels + f2_channels + f3_channels
        self.linear = nn.Sequential(
            nn.Conv2d(in_channels_total, 256, kernel_size=1),
            nn.ReLU(inplace=True)
        )

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input, f1, f2, f3):
        # --- Encoder path ---
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))  # bottleneck

        # --- Latent projections ---
        L1 = self.linear1(x4_0)
        L2 = self.linear2(L1)

        # --- Feature alignment & fusion ---
        target_size = f1.shape[2:]
        if f2.shape[2:] != target_size:
            f2 = F.interpolate(f2, size=target_size, mode='bilinear', align_corners=False)
        if f3.shape[2:] != target_size:
            f3 = F.interpolate(f3, size=target_size, mode='bilinear', align_corners=False)

        f = torch.cat([f1, f2, f3], dim=1)
        f = self.linear(f)  # -> (B,256,H,W)

        # --- Frequency filtering ---
        f_freq = torch.fft.fft2(f, norm="ortho")
        f_freq_shifted = torch.fft.fftshift(f_freq)

        B, C, H, W = f_freq_shifted.shape
        low_mask = torch.zeros_like(f_freq_shifted)
        high_mask = torch.ones_like(f_freq_shifted)

        center_h = H // 2
        center_w = W // 2
        radius = min(H, W) // 6

        low_mask[:, :, center_h-radius:center_h+radius, center_w-radius:center_w+radius] = 1
        high_mask = 1 - low_mask

        f_low = torch.real(torch.fft.ifft2(torch.fft.ifftshift(f_freq_shifted * low_mask), norm="ortho"))
        f_high = torch.real(torch.fft.ifft2(torch.fft.ifftshift(f_freq_shifted * high_mask), norm="ortho"))

        if f_low.shape[2:] != L1.shape[2:]:
            f_low = F.interpolate(f_low, size=L1.shape[2:], mode="bilinear", align_corners=False)
        if f_high.shape[2:] != L2.shape[2:]:
            f_high = F.interpolate(f_high, size=L2.shape[2:], mode="bilinear", align_corners=False)

        distillation1 = F.mse_loss(f_low, L1)
        distillation2 = F.mse_loss(f_high, L2)

        # --- Decoder path ---
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4], distillation1, distillation2
        else:
            output = self.final(x0_4)
            return output, distillation1, distillation2


# --- Quick test ---
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = NestedUNet(num_classes=1, input_channels=3).to(device)
    x = torch.randn(1, 3, 352, 352).to(device)
    f1 = torch.randn(1, 9, 32, 32).to(device)       # SAM
    f2 = torch.randn(1, 768, 32, 32).to(device)     # DINOv2
    f3 = torch.randn(1, 256, 32, 32).to(device)     # OneFormer

    out, d1, d2 = model(x, f1, f2, f3)
    print("Output:", out.shape)
    print("Distillation1:", d1.item())
    print("Distillation2:", d2.item())


Output: torch.Size([1, 1, 352, 352])
Distillation1: 0.06858216971158981
Distillation2: 0.06587187945842743


In [2]:
import os
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset

class UNetPPWithSAMDINOneFormerDataset(Dataset):
    def __init__(self, image_dir, mask_dir, sam_feature_dir, dino_feature_dir, oneformer_feature_dir, transform=None, feature_size=(64, 64)):
        """
        feature_size: tuple (H, W) â†’ all SAM/DINO/OneFormer features will be resized to this shape
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.sam_feature_dir = sam_feature_dir
        self.dino_feature_dir = dino_feature_dir
        self.oneformer_feature_dir = oneformer_feature_dir
        self.transform = transform
        self.feature_size = feature_size
        self.image_names = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        base_name, _ = os.path.splitext(image_name)

        image_path = os.path.join(self.image_dir, image_name)

        # Find mask file
        mask_path = None
        for ext in [".jpg", ".png"]:
            candidate = os.path.join(self.mask_dir, base_name + ext)
            if os.path.exists(candidate):
                mask_path = candidate
                break
        if mask_path is None:
            raise FileNotFoundError(f"No mask found for {base_name}")

        # Feature paths
        f1_path = os.path.join(self.sam_feature_dir, base_name + ".pt")        # SAM
        f2_path = os.path.join(self.dino_feature_dir, base_name + ".pt")       # DINO
        f3_path = os.path.join(self.oneformer_feature_dir, base_name + ".pt")  # OneFormer

        # Load data
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        f1 = torch.load(f1_path)  # SAM features [N, C, H, W] or [C, H, W]
        f2 = torch.load(f2_path)  # DINO features
        f3 = torch.load(f3_path)  # OneFormer features

        # Ensure 4D shape for interpolation
        if f1.dim() == 3:
            f1 = f1.unsqueeze(0)  # [1, C, H, W]
        if f2.dim() == 3:
            f2 = f2.unsqueeze(0)
        if f3.dim() == 3:
            f3 = f3.unsqueeze(0)

        # Resize features to the same size
        f1 = F.interpolate(f1, size=self.feature_size, mode="bilinear", align_corners=False).squeeze(0)
        f2 = F.interpolate(f2, size=self.feature_size, mode="bilinear", align_corners=False).squeeze(0)
        f3 = F.interpolate(f3, size=self.feature_size, mode="bilinear", align_corners=False).squeeze(0)

        # Apply transforms to image and mask
        if self.transform:
            transformed = self.transform(image=np.array(image), mask=np.array(mask))
            image = transformed['image']
            mask = transformed['mask']

        return {
            "pixel_values": image,
            "ground_truth_mask": mask.float() / 255.0,
            "f1": f1,
            "f2": f2,
            "f3": f3,
            "image_name": image_name
        }


In [4]:
# import albumentations as A
# import os
# from albumentations.pytorch import ToTensorV2
# from torch.utils.data import DataLoader

# # Transform for both train/valid
# transform = A.Compose([
#     A.Resize(352, 352),
#     A.Normalize(),
#     ToTensorV2()
# ])

# train_dataset = UNetWithSAMAndDINODataset(
#     image_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc_augmented/train/images",
#     mask_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc_augmented/train/masks",
#     sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_train_features",
#     dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_train_features",
#     transform=transform
# )

# test_kvasir_dataset = UNetWithSAMAndDINODataset(
#     image_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_kvasir/images",
#     mask_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_kvasir/masks",
#     sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_kvasir_test_features",
#     dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_kvasir_test_features",
#     transform=transform
# )

# test_cvc_dataset = UNetWithSAMAndDINODataset(
#     image_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_cvc/images",
#     mask_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_cvc/masks",
#     sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_cvc_test_features",
#     dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_cvc_test_features",
#     transform=transform
# )

# test_cvc_300_dataset = UNetWithSAMAndDINODataset(
#     image_dir="/home/deepak1010/Shivanshu Code/CVC-300/images",
#     mask_dir="/home/deepak1010/Shivanshu Code/CVC-300/masks",
#     sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_cvc_300_test_features",
#     dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_cvc_300_test_features",
#     transform=transform
# )

# test_etis_dataset = UNetWithSAMAndDINODataset(
#     image_dir="/home/deepak1010/Shivanshu Code/ETIS/images",
#     mask_dir="/home/deepak1010/Shivanshu Code/ETIS/masks",
#     sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_etis_test_features",
#     dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_etis_test_features",
#     transform=transform
# )

# test_cvc_colondb_dataset = UNetWithSAMAndDINODataset(
#     image_dir="/home/deepak1010/Shivanshu Code/CVC-ColonDB/images",
#     mask_dir="/home/deepak1010/Shivanshu Code/CVC-ColonDB/masks",
#     sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_cvc_colondb_test_features",
#     dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_cvc_colondb_test_features",
#     transform=transform
# )


# train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
# test_kvasir_dataloader = DataLoader(test_kvasir_dataset, batch_size=4, shuffle=False)
# test_cvc_dataloader = DataLoader(test_cvc_dataset, batch_size=4, shuffle=False)
# test_cvc_300_dataloader = DataLoader(test_cvc_300_dataset, batch_size=4, shuffle=False)
# test_etis_dataloader = DataLoader(test_etis_dataset, batch_size=4, shuffle=False)
# test_cvc_colondb_dataloader = DataLoader(test_cvc_colondb_dataset, batch_size=4, shuffle=False)


# print(len(train_dataloader))
# print(len(test_kvasir_dataloader))
# print(len(test_cvc_dataloader))
# print(len(test_cvc_300_dataloader))
# print(len(test_etis_dataloader))
# print(len(test_cvc_colondb_dataloader))

1450
25
16
15
49
95


  check_for_updates()


In [3]:
import albumentations as A
import os
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

# Transform for both train/valid
transform = A.Compose([
    A.Resize(352, 352),
    A.Normalize(),
    ToTensorV2()
])

train_dataset = UNetPPWithSAMDINOneFormerDataset(
    image_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc_augmented/train/images",
    mask_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc_augmented/train/masks",
    sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_train_features",
    dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_train_features",
    oneformer_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/oneformer_train_features",
    transform=transform
)

test_kvasir_dataset = UNetPPWithSAMDINOneFormerDataset(
    image_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_kvasir/images",
    mask_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_kvasir/masks",
    sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_kvasir_test_features",
    dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_kvasir_test_features",
    oneformer_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/oneformer_kvasir_test_features",
    transform=transform
)

test_cvc_dataset = UNetPPWithSAMDINOneFormerDataset(
    image_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_cvc/images",
    mask_dir="/home/deepak1010/Shivanshu Code/mixed_ds_kvasir_cvc/test_cvc/masks",
    sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_cvc_test_features",
    dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_cvc_test_features",
    oneformer_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/oneformer_cvc_test_features",
    transform=transform
)

test_cvc_300_dataset = UNetPPWithSAMDINOneFormerDataset(
    image_dir="/home/deepak1010/Shivanshu Code/CVC-300/images",
    mask_dir="/home/deepak1010/Shivanshu Code/CVC-300/masks",
    sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_cvc_300_test_features",
    dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_cvc_300_test_features",
    oneformer_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/oneformer_cvc_300_test_features",
    transform=transform
)

test_etis_dataset = UNetPPWithSAMDINOneFormerDataset(
    image_dir="/home/deepak1010/Shivanshu Code/ETIS/images",
    mask_dir="/home/deepak1010/Shivanshu Code/ETIS/masks",
    sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_etis_test_features",
    dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_etis_test_features",
    oneformer_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/oneformer_etis_test_features",
    transform=transform
)

test_cvc_colondb_dataset = UNetPPWithSAMDINOneFormerDataset(
    image_dir="/home/deepak1010/Shivanshu Code/CVC-ColonDB/images",
    mask_dir="/home/deepak1010/Shivanshu Code/CVC-ColonDB/masks",
    sam_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/sam_cvc_colondb_test_features",
    dino_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/dino_cvc_colondb_test_features",
    oneformer_feature_dir="/home/deepak1010/Shivanshu Code/features_sam_clip/oneformer_cvc_colondb_test_features",
    transform=transform
)

# DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=6, shuffle=True)
test_kvasir_dataloader = DataLoader(test_kvasir_dataset, batch_size=6, shuffle=False)
test_cvc_dataloader = DataLoader(test_cvc_dataset, batch_size=6, shuffle=False)
test_cvc_300_dataloader = DataLoader(test_cvc_300_dataset, batch_size=6, shuffle=False)
test_etis_dataloader = DataLoader(test_etis_dataset, batch_size=6, shuffle=False)
test_cvc_colondb_dataloader = DataLoader(test_cvc_colondb_dataset, batch_size=6, shuffle=False)

# Quick check
print(len(train_dataloader))
print(len(test_kvasir_dataloader))
print(len(test_cvc_dataloader))
print(len(test_cvc_300_dataloader))
print(len(test_etis_dataloader))
print(len(test_cvc_colondb_dataloader))


967
17
11
10
33
64


In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
from statistics import mean
from PIL import Image

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        preds = torch.sigmoid(preds)
        preds = preds.clamp(min=1e-7, max=1 - 1e-7)
        preds = preds.contiguous()
        targets = targets.contiguous()

        intersection = (preds * targets).sum(dim=(2, 3))
        union = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

# === Device, model, optimizer ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NestedUNet(num_classes=1, input_channels=3).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
seg_loss = DiceLoss()
criteria = nn.BCEWithLogitsLoss()

# === Training settings ===
num_epochs = 75
best_train_loss = float("inf")
patience = 10
early_stop_counter = 0
min_delta = 1e-4

for epoch in range(num_epochs):
    model.train()
    epoch_losses, seg_losses, distil_losses = [], [], []

    if epoch == 51:
        print("Freezing encoder layers (conv0_0 to conv4_0)...")
        for name, param in model.named_parameters():
            if any(enc in name for enc in ['conv0_0', 'conv1_0', 'conv2_0', 'conv3_0', 'conv4_0']):
                param.requires_grad = False
                print(f"Frozen: {name}")

    tq = tqdm(train_dataloader, desc=f"[Train] Epoch {epoch}")
    for batch in tq:
        x = batch["pixel_values"].to(device)                # Input image
        mask = batch["ground_truth_mask"].unsqueeze(1).to(device)

        # SAM features
        f1 = batch["f1"].to(device)                         # Shape: [B, N, C, H, W]
        B, N, C, H, W = f1.shape
        f1 = f1.view(B, N * C, H, W)                         # â†’ [B, 512, H, W] if N=1, C=512

        # DINOv2 features
        f2 = batch["f2"].to(device)                         # Shape: [B, 512, H, W] 
        #Oneformer Features
        f3 = batch["f3"].to(device)
        # print(f1.shape)
        # print(f2.shape)
        # print(f3.shape)
        logits, dist1, dist2 = model(x, f1, f2, f3)

        # Resize mask to match logits
        mask_resized = nn.functional.interpolate(mask, size=logits.shape[2:], mode="bilinear", align_corners=False)

        segmentation_loss = seg_loss(logits, mask_resized) + criteria(logits, mask_resized)

        if epoch < 25:
            loss = segmentation_loss
        else:
            loss = 0.6 * segmentation_loss + 0.1 * dist1 + 0.1 * dist2

        if not torch.isnan(loss):
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())
            seg_losses.append(segmentation_loss.item())
            distil_losses.append((dist1.item() + dist2.item()) / 2)

        tq.set_postfix({
            "Loss": loss.item(),
            "Seg": segmentation_loss.item(),
            "Dist": (dist1.item() + dist2.item()) / 2
        })

    avg_train_loss = mean(epoch_losses)
    print(f"[Train] Epoch {epoch:03}: Loss = {avg_train_loss:.4f}, Seg = {mean(seg_losses):.4f}, Dist = {mean(distil_losses):.4f}")

    # === Early Stopping Check ===
    if best_train_loss - avg_train_loss > min_delta:
        best_train_loss = avg_train_loss
        early_stop_counter = 0
        save_path = "/home/deepak1010/Shivanshu Code/features_sam_clip/unetplusplus_sam_dino_oneformer.pth"
        torch.save(model.state_dict(), save_path)
        print(f"âœ… Saved new best model at epoch {epoch} with train_loss = {best_train_loss:.4f}")
    else:
        early_stop_counter += 1
        print(f"Early stopping counter: {early_stop_counter}/{patience}")
        if early_stop_counter >= patience:
            print(" Early stopping triggered. Stopping training.")
            break


[Train] Epoch 0: 100%|â–ˆ| 967/967 [24:56<00:00,  1.55s/it, Loss=0.834, Seg=0.834, Dist=0.181]


[Train] Epoch 000: Loss = 1.0143, Seg = 1.0143, Dist = 0.1807
âœ… Saved new best model at epoch 0 with train_loss = 1.0143


[Train] Epoch 1: 100%|â–ˆ| 967/967 [22:10<00:00,  1.38s/it, Loss=0.652, Seg=0.652, Dist=0.179]


[Train] Epoch 001: Loss = 0.6308, Seg = 0.6308, Dist = 0.1778
âœ… Saved new best model at epoch 1 with train_loss = 0.6308


[Train] Epoch 2: 100%|â–ˆâ–ˆâ–ˆâ–ˆ| 967/967 [22:00<00:00,  1.37s/it, Loss=0.28, Seg=0.28, Dist=0.18]


[Train] Epoch 002: Loss = 0.4594, Seg = 0.4594, Dist = 0.1770
âœ… Saved new best model at epoch 2 with train_loss = 0.4594


[Train] Epoch 3: 100%|â–ˆâ–ˆ| 967/967 [21:54<00:00,  1.36s/it, Loss=0.192, Seg=0.192, Dist=0.18]


[Train] Epoch 003: Loss = 0.3826, Seg = 0.3826, Dist = 0.1768
âœ… Saved new best model at epoch 3 with train_loss = 0.3826


[Train] Epoch 4: 100%|â–ˆ| 967/967 [22:04<00:00,  1.37s/it, Loss=0.264, Seg=0.264, Dist=0.176]


[Train] Epoch 004: Loss = 0.3343, Seg = 0.3343, Dist = 0.1764
âœ… Saved new best model at epoch 4 with train_loss = 0.3343


[Train] Epoch 5: 100%|â–ˆ| 967/967 [22:03<00:00,  1.37s/it, Loss=0.185, Seg=0.185, Dist=0.184]


[Train] Epoch 005: Loss = 0.2966, Seg = 0.2966, Dist = 0.1759
âœ… Saved new best model at epoch 5 with train_loss = 0.2966


[Train] Epoch 6: 100%|â–ˆ| 967/967 [22:01<00:00,  1.37s/it, Loss=0.365, Seg=0.365, Dist=0.171]


[Train] Epoch 006: Loss = 0.2679, Seg = 0.2679, Dist = 0.1758
âœ… Saved new best model at epoch 6 with train_loss = 0.2679


[Train] Epoch 7: 100%|â–ˆ| 967/967 [22:09<00:00,  1.37s/it, Loss=0.471, Seg=0.471, Dist=0.172]


[Train] Epoch 007: Loss = 0.2424, Seg = 0.2424, Dist = 0.1756
âœ… Saved new best model at epoch 7 with train_loss = 0.2424


[Train] Epoch 8: 100%|â–ˆ| 967/967 [22:08<00:00,  1.37s/it, Loss=0.231, Seg=0.231, Dist=0.174]


[Train] Epoch 008: Loss = 0.2201, Seg = 0.2201, Dist = 0.1754
âœ… Saved new best model at epoch 8 with train_loss = 0.2201


[Train] Epoch 9: 100%|â–ˆ| 967/967 [22:24<00:00,  1.39s/it, Loss=0.187, Seg=0.187, Dist=0.177]


[Train] Epoch 009: Loss = 0.2017, Seg = 0.2017, Dist = 0.1754
âœ… Saved new best model at epoch 9 with train_loss = 0.2017


[Train] Epoch 10: 100%|â–ˆ| 967/967 [22:16<00:00,  1.38s/it, Loss=0.115, Seg=0.115, Dist=0.173


[Train] Epoch 010: Loss = 0.1884, Seg = 0.1884, Dist = 0.1752
âœ… Saved new best model at epoch 10 with train_loss = 0.1884


[Train] Epoch 11: 100%|â–ˆ| 967/967 [22:22<00:00,  1.39s/it, Loss=0.333, Seg=0.333, Dist=0.175


[Train] Epoch 011: Loss = 0.1724, Seg = 0.1724, Dist = 0.1752
âœ… Saved new best model at epoch 11 with train_loss = 0.1724


[Train] Epoch 12: 100%|â–ˆ| 967/967 [22:22<00:00,  1.39s/it, Loss=0.159, Seg=0.159, Dist=0.178


[Train] Epoch 012: Loss = 0.1540, Seg = 0.1540, Dist = 0.1751
âœ… Saved new best model at epoch 12 with train_loss = 0.1540


[Train] Epoch 13: 100%|â–ˆ| 967/967 [22:09<00:00,  1.38s/it, Loss=0.239, Seg=0.239, Dist=0.173


[Train] Epoch 013: Loss = 0.1467, Seg = 0.1467, Dist = 0.1754
âœ… Saved new best model at epoch 13 with train_loss = 0.1467


[Train] Epoch 14: 100%|â–ˆ| 967/967 [21:53<00:00,  1.36s/it, Loss=0.149, Seg=0.149, Dist=0.178


[Train] Epoch 014: Loss = 0.1380, Seg = 0.1380, Dist = 0.1754
âœ… Saved new best model at epoch 14 with train_loss = 0.1380


[Train] Epoch 15: 100%|â–ˆ| 967/967 [21:48<00:00,  1.35s/it, Loss=0.153, Seg=0.153, Dist=0.178


[Train] Epoch 015: Loss = 0.1231, Seg = 0.1231, Dist = 0.1752
âœ… Saved new best model at epoch 15 with train_loss = 0.1231


[Train] Epoch 16: 100%|â–ˆ| 967/967 [21:50<00:00,  1.36s/it, Loss=0.107, Seg=0.107, Dist=0.181


[Train] Epoch 016: Loss = 0.1160, Seg = 0.1160, Dist = 0.1750
âœ… Saved new best model at epoch 16 with train_loss = 0.1160


[Train] Epoch 17: 100%|â–ˆ| 967/967 [21:58<00:00,  1.36s/it, Loss=0.0831, Seg=0.0831, Dist=0.1


[Train] Epoch 017: Loss = 0.1099, Seg = 0.1099, Dist = 0.1752
âœ… Saved new best model at epoch 17 with train_loss = 0.1099


[Train] Epoch 18: 100%|â–ˆ| 967/967 [21:51<00:00,  1.36s/it, Loss=0.0865, Seg=0.0865, Dist=0.1


[Train] Epoch 018: Loss = 0.1021, Seg = 0.1021, Dist = 0.1753
âœ… Saved new best model at epoch 18 with train_loss = 0.1021


[Train] Epoch 19: 100%|â–ˆ| 967/967 [21:47<00:00,  1.35s/it, Loss=0.0735, Seg=0.0735, Dist=0.1


[Train] Epoch 019: Loss = 0.1062, Seg = 0.1062, Dist = 0.1749
Early stopping counter: 1/10


[Train] Epoch 20: 100%|â–ˆ| 967/967 [21:45<00:00,  1.35s/it, Loss=0.0607, Seg=0.0607, Dist=0.1


[Train] Epoch 020: Loss = 0.0925, Seg = 0.0925, Dist = 0.1751
âœ… Saved new best model at epoch 20 with train_loss = 0.0925


[Train] Epoch 21: 100%|â–ˆ| 967/967 [21:33<00:00,  1.34s/it, Loss=0.0932, Seg=0.0932, Dist=0.1


[Train] Epoch 021: Loss = 0.0894, Seg = 0.0894, Dist = 0.1752
âœ… Saved new best model at epoch 21 with train_loss = 0.0894


[Train] Epoch 22: 100%|â–ˆ| 967/967 [21:27<00:00,  1.33s/it, Loss=0.0588, Seg=0.0588, Dist=0.1


[Train] Epoch 022: Loss = 0.0832, Seg = 0.0832, Dist = 0.1751
âœ… Saved new best model at epoch 22 with train_loss = 0.0832


[Train] Epoch 23: 100%|â–ˆ| 967/967 [21:27<00:00,  1.33s/it, Loss=0.0678, Seg=0.0678, Dist=0.1


[Train] Epoch 023: Loss = 0.0805, Seg = 0.0805, Dist = 0.1751
âœ… Saved new best model at epoch 23 with train_loss = 0.0805


[Train] Epoch 24: 100%|â–ˆ| 967/967 [21:14<00:00,  1.32s/it, Loss=0.0457, Seg=0.0457, Dist=0.1


[Train] Epoch 024: Loss = 0.0768, Seg = 0.0768, Dist = 0.1750
âœ… Saved new best model at epoch 24 with train_loss = 0.0768


[Train] Epoch 25: 100%|â–ˆ| 967/967 [21:32<00:00,  1.34s/it, Loss=0.0344, Seg=0.0573, Dist=5.2


[Train] Epoch 025: Loss = 0.0384, Seg = 0.0632, Dist = 0.0024
âœ… Saved new best model at epoch 25 with train_loss = 0.0384


[Train] Epoch 26: 100%|â–ˆ| 967/967 [21:42<00:00,  1.35s/it, Loss=0.042, Seg=0.0699, Dist=1.98


[Train] Epoch 026: Loss = 0.0386, Seg = 0.0643, Dist = 0.0000
Early stopping counter: 1/10


[Train] Epoch 27: 100%|â–ˆ| 967/967 [22:34<00:00,  1.40s/it, Loss=0.0368, Seg=0.0614, Dist=9.4


[Train] Epoch 027: Loss = 0.0496, Seg = 0.0827, Dist = 0.0000
Early stopping counter: 2/10


[Train] Epoch 28: 100%|â–ˆ| 967/967 [22:57<00:00,  1.42s/it, Loss=0.0297, Seg=0.0495, Dist=5.3


[Train] Epoch 028: Loss = 0.0379, Seg = 0.0632, Dist = 0.0000
âœ… Saved new best model at epoch 28 with train_loss = 0.0379


[Train] Epoch 29: 100%|â–ˆ| 967/967 [23:01<00:00,  1.43s/it, Loss=0.0293, Seg=0.0488, Dist=1.9


[Train] Epoch 029: Loss = 0.0423, Seg = 0.0706, Dist = 0.0000
Early stopping counter: 1/10


[Train] Epoch 30: 100%|â–ˆ| 967/967 [23:02<00:00,  1.43s/it, Loss=0.0322, Seg=0.0537, Dist=1.1


[Train] Epoch 030: Loss = 0.0368, Seg = 0.0614, Dist = 0.0000
âœ… Saved new best model at epoch 30 with train_loss = 0.0368


[Train] Epoch 31: 100%|â–ˆ| 967/967 [23:08<00:00,  1.44s/it, Loss=0.0452, Seg=0.0754, Dist=1.1


[Train] Epoch 031: Loss = 0.0388, Seg = 0.0646, Dist = 0.0000
Early stopping counter: 1/10


[Train] Epoch 32: 100%|â–ˆ| 967/967 [23:07<00:00,  1.43s/it, Loss=0.0425, Seg=0.0709, Dist=1.6


[Train] Epoch 032: Loss = 0.0448, Seg = 0.0747, Dist = 0.0000
Early stopping counter: 2/10


[Train] Epoch 33: 100%|â–ˆ| 967/967 [23:00<00:00,  1.43s/it, Loss=0.0365, Seg=0.0608, Dist=1.2


[Train] Epoch 033: Loss = 0.0326, Seg = 0.0543, Dist = 0.0000
âœ… Saved new best model at epoch 33 with train_loss = 0.0326


[Train] Epoch 34: 100%|â–ˆ| 967/967 [23:00<00:00,  1.43s/it, Loss=0.0237, Seg=0.0396, Dist=1.7


[Train] Epoch 034: Loss = 0.0303, Seg = 0.0505, Dist = 0.0000
âœ… Saved new best model at epoch 34 with train_loss = 0.0303


[Train] Epoch 35: 100%|â–ˆ| 967/967 [22:59<00:00,  1.43s/it, Loss=0.0561, Seg=0.0935, Dist=1.3


[Train] Epoch 035: Loss = 0.0478, Seg = 0.0796, Dist = 0.0000
Early stopping counter: 1/10


[Train] Epoch 36: 100%|â–ˆ| 967/967 [22:59<00:00,  1.43s/it, Loss=0.0279, Seg=0.0465, Dist=4.0


[Train] Epoch 036: Loss = 0.0303, Seg = 0.0505, Dist = 0.0000
Early stopping counter: 2/10


[Train] Epoch 37: 100%|â–ˆ| 967/967 [23:01<00:00,  1.43s/it, Loss=0.0407, Seg=0.0678, Dist=1.3


[Train] Epoch 037: Loss = 0.0353, Seg = 0.0589, Dist = 0.0000
Early stopping counter: 3/10


[Train] Epoch 38: 100%|â–ˆ| 967/967 [23:00<00:00,  1.43s/it, Loss=0.0249, Seg=0.0414, Dist=6.2


[Train] Epoch 038: Loss = 0.0323, Seg = 0.0538, Dist = 0.0000
Early stopping counter: 4/10


[Train] Epoch 39: 100%|â–ˆ| 967/967 [23:00<00:00,  1.43s/it, Loss=0.0532, Seg=0.0887, Dist=5.9


[Train] Epoch 039: Loss = 0.0313, Seg = 0.0522, Dist = 0.0000
Early stopping counter: 5/10


[Train] Epoch 40: 100%|â–ˆ| 967/967 [23:01<00:00,  1.43s/it, Loss=0.0325, Seg=0.0541, Dist=1.8


[Train] Epoch 040: Loss = 0.0369, Seg = 0.0615, Dist = 0.0000
Early stopping counter: 6/10


[Train] Epoch 41: 100%|â–ˆ| 967/967 [22:57<00:00,  1.42s/it, Loss=0.0333, Seg=0.0556, Dist=1.7


[Train] Epoch 041: Loss = 0.0258, Seg = 0.0429, Dist = 0.0000
âœ… Saved new best model at epoch 41 with train_loss = 0.0258


[Train] Epoch 42: 100%|â–ˆ| 967/967 [22:56<00:00,  1.42s/it, Loss=0.0286, Seg=0.0477, Dist=4.4


[Train] Epoch 042: Loss = 0.0246, Seg = 0.0409, Dist = 0.0000
âœ… Saved new best model at epoch 42 with train_loss = 0.0246


[Train] Epoch 43: 100%|â–ˆ| 967/967 [22:54<00:00,  1.42s/it, Loss=0.0258, Seg=0.043, Dist=7.51


[Train] Epoch 043: Loss = 0.0398, Seg = 0.0663, Dist = 0.0000
Early stopping counter: 1/10


[Train] Epoch 44: 100%|â–ˆ| 967/967 [22:50<00:00,  1.42s/it, Loss=0.0198, Seg=0.0329, Dist=3.9


[Train] Epoch 044: Loss = 0.0282, Seg = 0.0469, Dist = 0.0000
Early stopping counter: 2/10


[Train] Epoch 45: 100%|â–ˆ| 967/967 [22:52<00:00,  1.42s/it, Loss=0.0335, Seg=0.0558, Dist=6.7


[Train] Epoch 045: Loss = 0.0270, Seg = 0.0449, Dist = 0.0000
Early stopping counter: 3/10


[Train] Epoch 46: 100%|â–ˆ| 967/967 [22:44<00:00,  1.41s/it, Loss=0.0263, Seg=0.0438, Dist=3.9


[Train] Epoch 046: Loss = 0.0312, Seg = 0.0521, Dist = 0.0000
Early stopping counter: 4/10


[Train] Epoch 47: 100%|â–ˆ| 967/967 [22:44<00:00,  1.41s/it, Loss=0.0455, Seg=0.0759, Dist=1.0


[Train] Epoch 047: Loss = 0.0270, Seg = 0.0450, Dist = 0.0000
Early stopping counter: 5/10


[Train] Epoch 48: 100%|â–ˆ| 967/967 [22:40<00:00,  1.41s/it, Loss=0.0217, Seg=0.0361, Dist=8.1


[Train] Epoch 048: Loss = 0.0300, Seg = 0.0500, Dist = 0.0000
Early stopping counter: 6/10


[Train] Epoch 49: 100%|â–ˆ| 967/967 [22:38<00:00,  1.41s/it, Loss=0.0179, Seg=0.0298, Dist=3.5


[Train] Epoch 049: Loss = 0.0222, Seg = 0.0370, Dist = 0.0000
âœ… Saved new best model at epoch 49 with train_loss = 0.0222


[Train] Epoch 50: 100%|â–ˆ| 967/967 [22:35<00:00,  1.40s/it, Loss=0.0334, Seg=0.0557, Dist=1.3


[Train] Epoch 050: Loss = 0.0215, Seg = 0.0359, Dist = 0.0000
âœ… Saved new best model at epoch 50 with train_loss = 0.0215
ðŸ”’ Freezing encoder layers (conv0_0 to conv4_0)...
âœ… Frozen: conv0_0.conv1.weight
âœ… Frozen: conv0_0.conv1.bias
âœ… Frozen: conv0_0.bn1.weight
âœ… Frozen: conv0_0.bn1.bias
âœ… Frozen: conv0_0.conv2.weight
âœ… Frozen: conv0_0.conv2.bias
âœ… Frozen: conv0_0.bn2.weight
âœ… Frozen: conv0_0.bn2.bias
âœ… Frozen: conv1_0.conv1.weight
âœ… Frozen: conv1_0.conv1.bias
âœ… Frozen: conv1_0.bn1.weight
âœ… Frozen: conv1_0.bn1.bias
âœ… Frozen: conv1_0.conv2.weight
âœ… Frozen: conv1_0.conv2.bias
âœ… Frozen: conv1_0.bn2.weight
âœ… Frozen: conv1_0.bn2.bias
âœ… Frozen: conv2_0.conv1.weight
âœ… Frozen: conv2_0.conv1.bias
âœ… Frozen: conv2_0.bn1.weight
âœ… Frozen: conv2_0.bn1.bias
âœ… Frozen: conv2_0.conv2.weight
âœ… Frozen: conv2_0.conv2.bias
âœ… Frozen: conv2_0.bn2.weight
âœ… Frozen: conv2_0.bn2.bias
âœ… Frozen: conv3_0.conv1.weight
âœ… Frozen: conv3_0.conv1.bias
âœ… Frozen: con

[Train] Epoch 51: 100%|â–ˆ| 967/967 [19:00<00:00,  1.18s/it, Loss=0.0187, Seg=0.0312, Dist=1.1


[Train] Epoch 051: Loss = 0.0206, Seg = 0.0344, Dist = 0.0000
âœ… Saved new best model at epoch 51 with train_loss = 0.0206


[Train] Epoch 52: 100%|â–ˆ| 967/967 [19:03<00:00,  1.18s/it, Loss=0.0217, Seg=0.0361, Dist=7e-


[Train] Epoch 052: Loss = 0.0197, Seg = 0.0328, Dist = 0.0000
âœ… Saved new best model at epoch 52 with train_loss = 0.0197


[Train] Epoch 53: 100%|â–ˆ| 967/967 [19:01<00:00,  1.18s/it, Loss=0.0182, Seg=0.0304, Dist=4.6


[Train] Epoch 053: Loss = 0.0194, Seg = 0.0323, Dist = 0.0000
âœ… Saved new best model at epoch 53 with train_loss = 0.0194


[Train] Epoch 54: 100%|â–ˆ| 967/967 [18:59<00:00,  1.18s/it, Loss=0.0172, Seg=0.0287, Dist=7.2


[Train] Epoch 054: Loss = 0.0189, Seg = 0.0314, Dist = 0.0000
âœ… Saved new best model at epoch 54 with train_loss = 0.0189


[Train] Epoch 55: 100%|â–ˆ| 967/967 [18:59<00:00,  1.18s/it, Loss=0.0157, Seg=0.0261, Dist=5.6


[Train] Epoch 055: Loss = 0.0185, Seg = 0.0309, Dist = 0.0000
âœ… Saved new best model at epoch 55 with train_loss = 0.0185


[Train] Epoch 56: 100%|â–ˆ| 967/967 [19:00<00:00,  1.18s/it, Loss=0.0183, Seg=0.0305, Dist=5.5


[Train] Epoch 056: Loss = 0.0182, Seg = 0.0303, Dist = 0.0000
âœ… Saved new best model at epoch 56 with train_loss = 0.0182


[Train] Epoch 57: 100%|â–ˆ| 967/967 [18:53<00:00,  1.17s/it, Loss=0.0144, Seg=0.024, Dist=4.12


[Train] Epoch 057: Loss = 0.0179, Seg = 0.0298, Dist = 0.0000
âœ… Saved new best model at epoch 57 with train_loss = 0.0179


[Train] Epoch 58: 100%|â–ˆ| 967/967 [19:02<00:00,  1.18s/it, Loss=0.0195, Seg=0.0326, Dist=1.2


[Train] Epoch 058: Loss = 0.0175, Seg = 0.0292, Dist = 0.0000
âœ… Saved new best model at epoch 58 with train_loss = 0.0175


[Train] Epoch 59: 100%|â–ˆ| 967/967 [18:58<00:00,  1.18s/it, Loss=0.0211, Seg=0.0351, Dist=9.9


[Train] Epoch 059: Loss = 0.0174, Seg = 0.0289, Dist = 0.0000
âœ… Saved new best model at epoch 59 with train_loss = 0.0174


[Train] Epoch 60: 100%|â–ˆ| 967/967 [18:57<00:00,  1.18s/it, Loss=0.013, Seg=0.0217, Dist=5.11


[Train] Epoch 060: Loss = 0.0170, Seg = 0.0283, Dist = 0.0000
âœ… Saved new best model at epoch 60 with train_loss = 0.0170


[Train] Epoch 61: 100%|â–ˆ| 967/967 [18:56<00:00,  1.18s/it, Loss=0.0149, Seg=0.0248, Dist=1.0


[Train] Epoch 061: Loss = 0.0167, Seg = 0.0278, Dist = 0.0000
âœ… Saved new best model at epoch 61 with train_loss = 0.0167


[Train] Epoch 62: 100%|â–ˆ| 967/967 [18:52<00:00,  1.17s/it, Loss=0.0162, Seg=0.027, Dist=3.65


[Train] Epoch 062: Loss = 0.0166, Seg = 0.0277, Dist = 0.0000
Early stopping counter: 1/10


[Train] Epoch 63: 100%|â–ˆ| 967/967 [18:38<00:00,  1.16s/it, Loss=0.0171, Seg=0.0284, Dist=9.7


[Train] Epoch 063: Loss = 0.0162, Seg = 0.0271, Dist = 0.0000
âœ… Saved new best model at epoch 63 with train_loss = 0.0162


[Train] Epoch 64: 100%|â–ˆ| 967/967 [17:36<00:00,  1.09s/it, Loss=0.0179, Seg=0.0298, Dist=6.3


[Train] Epoch 064: Loss = 0.0160, Seg = 0.0266, Dist = 0.0000
âœ… Saved new best model at epoch 64 with train_loss = 0.0160


[Train] Epoch 65: 100%|â–ˆ| 967/967 [17:11<00:00,  1.07s/it, Loss=0.0135, Seg=0.0226, Dist=2.2


[Train] Epoch 065: Loss = 0.0158, Seg = 0.0264, Dist = 0.0000
âœ… Saved new best model at epoch 65 with train_loss = 0.0158


[Train] Epoch 66: 100%|â–ˆ| 967/967 [16:58<00:00,  1.05s/it, Loss=0.0152, Seg=0.0254, Dist=9.3


[Train] Epoch 066: Loss = 0.0158, Seg = 0.0263, Dist = 0.0000
Early stopping counter: 1/10


[Train] Epoch 67: 100%|â–ˆ| 967/967 [16:49<00:00,  1.04s/it, Loss=0.017, Seg=0.0283, Dist=5.17


[Train] Epoch 067: Loss = 0.0154, Seg = 0.0257, Dist = 0.0000
âœ… Saved new best model at epoch 67 with train_loss = 0.0154


[Train] Epoch 68: 100%|â–ˆ| 967/967 [16:50<00:00,  1.05s/it, Loss=0.0165, Seg=0.0275, Dist=1.5


[Train] Epoch 068: Loss = 0.0152, Seg = 0.0253, Dist = 0.0000
âœ… Saved new best model at epoch 68 with train_loss = 0.0152


[Train] Epoch 69: 100%|â–ˆ| 967/967 [16:48<00:00,  1.04s/it, Loss=0.0158, Seg=0.0263, Dist=1.1


[Train] Epoch 069: Loss = 0.0150, Seg = 0.0251, Dist = 0.0000
âœ… Saved new best model at epoch 69 with train_loss = 0.0150


[Train] Epoch 70: 100%|â–ˆ| 967/967 [16:47<00:00,  1.04s/it, Loss=0.018, Seg=0.03, Dist=1.53e-


[Train] Epoch 070: Loss = 0.0148, Seg = 0.0247, Dist = 0.0000
âœ… Saved new best model at epoch 70 with train_loss = 0.0148


[Train] Epoch 71: 100%|â–ˆ| 967/967 [16:54<00:00,  1.05s/it, Loss=0.0125, Seg=0.0208, Dist=1.0


[Train] Epoch 071: Loss = 0.0146, Seg = 0.0243, Dist = 0.0000
âœ… Saved new best model at epoch 71 with train_loss = 0.0146


[Train] Epoch 72: 100%|â–ˆ| 967/967 [17:00<00:00,  1.06s/it, Loss=0.0137, Seg=0.0229, Dist=7.6


[Train] Epoch 072: Loss = 0.0149, Seg = 0.0249, Dist = 0.0000
Early stopping counter: 1/10


[Train] Epoch 73: 100%|â–ˆ| 967/967 [17:01<00:00,  1.06s/it, Loss=0.0127, Seg=0.0211, Dist=4.3


[Train] Epoch 073: Loss = 0.0142, Seg = 0.0237, Dist = 0.0000
âœ… Saved new best model at epoch 73 with train_loss = 0.0142


[Train] Epoch 74:   7%| | 70/967 [01:13<16:09,  1.08s/it, Loss=0.0144, Seg=0.0239, Dist=5.12

In [4]:
# === Testing for Kvasir with SAM + DINO + OneFormer Features ===
import torch
import torch.nn as nn
from tqdm import tqdm
import os
from statistics import mean
from torchvision.utils import save_image

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        preds = torch.sigmoid(preds)
        preds = preds.clamp(min=1e-7, max=1 - 1e-7)
        preds = preds.contiguous()
        targets = targets.contiguous()

        intersection = (preds * targets).sum(dim=(2, 3))
        union = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

def compute_metrics(preds, targets, smooth=1e-6):
    preds = preds.view(-1)
    targets = targets.view(-1)

    tp = (preds * targets).sum().item()
    fp = (preds * (1 - targets)).sum().item()
    fn = ((1 - preds) * targets).sum().item()

    dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    iou = (tp + smooth) / (tp + fp + fn + smooth)
    precision = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)

    return dice, iou, precision, recall

# ---- DEVICE ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- LOAD MODEL ----
model = NestedUNet(num_classes=1, input_channels=3).to(device)
model.load_state_dict(torch.load(
    "/home/deepak1010/Shivanshu Code/features_sam_clip/unetplusplus_sam_dino_oneformer.pth",
    map_location=device
))
model.eval()

# ---- LOSSES ----
seg_loss = DiceLoss()
criteria = nn.BCEWithLogitsLoss()

# ---- TEST LOOP ----
test_losses, test_seg_losses, test_distil_losses = [], [], []
all_dice, all_iou, all_precision, all_recall = [], [], [], []

save_pred_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_kvasir_unetplusplus"
os.makedirs(save_pred_dir, exist_ok=True)

with torch.no_grad():
    tq = tqdm(test_kvasir_dataloader, desc="[Test]")
    for i, batch in enumerate(tq):
        # Input image & GT
        x = batch["pixel_values"].to(device)
        mask = batch["ground_truth_mask"].unsqueeze(1).to(device)

        # SAM features
        f1 = batch["f1"].to(device)                         # [B, N, C, H, W]
        B, N, C, H, W = f1.shape
        f1 = f1.view(B, N * C, H, W)                        # [B, N*C, H, W]

        # DINOv2 features
        f2 = batch["f2"].to(device)                         # [B, 512, H, W]

        # OneFormer features
        f3 = batch["f3"].to(device)                         # shape depends on extractor

        # Forward pass
        logits, dist1, dist2 = model(x, f1, f2, f3)

        # Resize mask
        mask_resized = nn.functional.interpolate(mask, size=logits.shape[2:], mode="bilinear", align_corners=False)

        # Loss calculation (same as training)
        seg_loss_value = seg_loss(logits, mask_resized) + criteria(logits, mask_resized)
        loss = 0.6 * seg_loss_value + 0.1 * dist1 + 0.1 * dist2

        test_losses.append(loss.item())
        test_seg_losses.append(seg_loss_value.item())
        test_distil_losses.append((dist1.item() + dist2.item()) / 2)

        # Predictions -> binary
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()

        # Metrics
        dice, iou, precision, recall = compute_metrics(preds.cpu(), mask_resized.cpu())
        all_dice.append(dice)
        all_iou.append(iou)
        all_precision.append(precision)
        all_recall.append(recall)

        # Save first 20 predictions
        if i < 20:
            save_image(preds, os.path.join(save_pred_dir, f"pred_{i}.png"))
            save_image(mask_resized.float(), os.path.join(save_pred_dir, f"gt_{i}.png"))

# ---- FINAL RESULTS ----
print(f"[Test Results] Loss = {mean(test_losses):.4f}, Seg = {mean(test_seg_losses):.4f}, Dist = {mean(test_distil_losses):.4f}")
print(f"[Metrics] mDice = {mean(all_dice):.4f}, mIoU = {mean(all_iou):.4f}, Precision = {mean(all_precision):.4f}, Recall = {mean(all_recall):.4f}")
print(f"âœ… Predictions saved in {save_pred_dir}")


[Test]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 17/17 [00:20<00:00,  1.22s/it]

[Test Results] Loss = 0.3000, Seg = 0.5000, Dist = 0.0000
[Metrics] mDice = 0.8586, mIoU = 0.7584, Precision = 0.8831, Recall = 0.8467
âœ… Predictions saved in /home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_kvasir_unetplusplus





In [5]:
# === Testing for Kvasir with SAM + DINO + OneFormer Features ===
import torch
import torch.nn as nn
from tqdm import tqdm
import os
from statistics import mean
from torchvision.utils import save_image

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        preds = torch.sigmoid(preds)
        preds = preds.clamp(min=1e-7, max=1 - 1e-7)
        preds = preds.contiguous()
        targets = targets.contiguous()

        intersection = (preds * targets).sum(dim=(2, 3))
        union = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

def compute_metrics(preds, targets, smooth=1e-6):
    preds = preds.view(-1)
    targets = targets.view(-1)

    tp = (preds * targets).sum().item()
    fp = (preds * (1 - targets)).sum().item()
    fn = ((1 - preds) * targets).sum().item()

    dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    iou = (tp + smooth) / (tp + fp + fn + smooth)
    precision = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)

    return dice, iou, precision, recall

# ---- DEVICE ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- LOAD MODEL ----
model = NestedUNet(num_classes=1, input_channels=3).to(device)
model.load_state_dict(torch.load(
    "/home/deepak1010/Shivanshu Code/features_sam_clip/unetplusplus_sam_dino_oneformer.pth",
    map_location=device
))
model.eval()

# ---- LOSSES ----
seg_loss = DiceLoss()
criteria = nn.BCEWithLogitsLoss()

# ---- TEST LOOP ----
test_losses, test_seg_losses, test_distil_losses = [], [], []
all_dice, all_iou, all_precision, all_recall = [], [], [], []

save_pred_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_cvc_unetplusplus"
os.makedirs(save_pred_dir, exist_ok=True)

with torch.no_grad():
    tq = tqdm(test_cvc_dataloader, desc="[Test]")
    for i, batch in enumerate(tq):
        # Input image & GT
        x = batch["pixel_values"].to(device)
        mask = batch["ground_truth_mask"].unsqueeze(1).to(device)

        # SAM features
        f1 = batch["f1"].to(device)                         # [B, N, C, H, W]
        B, N, C, H, W = f1.shape
        f1 = f1.view(B, N * C, H, W)                        # [B, N*C, H, W]

        # DINOv2 features
        f2 = batch["f2"].to(device)                         # [B, 512, H, W]

        # OneFormer features
        f3 = batch["f3"].to(device)                         # shape depends on extractor

        # Forward pass
        logits, dist1, dist2 = model(x, f1, f2, f3)

        # Resize mask
        mask_resized = nn.functional.interpolate(mask, size=logits.shape[2:], mode="bilinear", align_corners=False)

        # Loss calculation (same as training)
        seg_loss_value = seg_loss(logits, mask_resized) + criteria(logits, mask_resized)
        loss = 0.6 * seg_loss_value + 0.1 * dist1 + 0.1 * dist2

        test_losses.append(loss.item())
        test_seg_losses.append(seg_loss_value.item())
        test_distil_losses.append((dist1.item() + dist2.item()) / 2)

        # Predictions -> binary
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()

        # Metrics
        dice, iou, precision, recall = compute_metrics(preds.cpu(), mask_resized.cpu())
        all_dice.append(dice)
        all_iou.append(iou)
        all_precision.append(precision)
        all_recall.append(recall)

        # Save first 20 predictions
        if i < 20:
            save_image(preds, os.path.join(save_pred_dir, f"pred_{i}.png"))
            save_image(mask_resized.float(), os.path.join(save_pred_dir, f"gt_{i}.png"))

# ---- FINAL RESULTS ----
print(f"[Test Results] Loss = {mean(test_losses):.4f}, Seg = {mean(test_seg_losses):.4f}, Dist = {mean(test_distil_losses):.4f}")
print(f"[Metrics] mDice = {mean(all_dice):.4f}, mIoU = {mean(all_iou):.4f}, Precision = {mean(all_precision):.4f}, Recall = {mean(all_recall):.4f}")
print(f"âœ… Predictions saved in {save_pred_dir}")


[Test]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 11/11 [00:12<00:00,  1.09s/it]

[Test Results] Loss = 0.0684, Seg = 0.1139, Dist = 0.0000
[Metrics] mDice = 0.9471, mIoU = 0.9000, Precision = 0.9563, Recall = 0.9387
âœ… Predictions saved in /home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_cvc_unetplusplus





In [6]:
# === Testing for Kvasir with SAM + DINO + OneFormer Features ===
import torch
import torch.nn as nn
from tqdm import tqdm
import os
from statistics import mean
from torchvision.utils import save_image

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        preds = torch.sigmoid(preds)
        preds = preds.clamp(min=1e-7, max=1 - 1e-7)
        preds = preds.contiguous()
        targets = targets.contiguous()

        intersection = (preds * targets).sum(dim=(2, 3))
        union = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

def compute_metrics(preds, targets, smooth=1e-6):
    preds = preds.view(-1)
    targets = targets.view(-1)

    tp = (preds * targets).sum().item()
    fp = (preds * (1 - targets)).sum().item()
    fn = ((1 - preds) * targets).sum().item()

    dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    iou = (tp + smooth) / (tp + fp + fn + smooth)
    precision = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)

    return dice, iou, precision, recall

# ---- DEVICE ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- LOAD MODEL ----
model = NestedUNet(num_classes=1, input_channels=3).to(device)
model.load_state_dict(torch.load(
    "/home/deepak1010/Shivanshu Code/features_sam_clip/unetplusplus_sam_dino_oneformer.pth",
    map_location=device
))
model.eval()

# ---- LOSSES ----
seg_loss = DiceLoss()
criteria = nn.BCEWithLogitsLoss()

# ---- TEST LOOP ----
test_losses, test_seg_losses, test_distil_losses = [], [], []
all_dice, all_iou, all_precision, all_recall = [], [], [], []

save_pred_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_cvc_300_unetplusplus"
os.makedirs(save_pred_dir, exist_ok=True)

with torch.no_grad():
    tq = tqdm(test_cvc_300_dataloader, desc="[Test]")
    for i, batch in enumerate(tq):
        # Input image & GT
        x = batch["pixel_values"].to(device)
        mask = batch["ground_truth_mask"].unsqueeze(1).to(device)

        # SAM features
        f1 = batch["f1"].to(device)                         # [B, N, C, H, W]
        B, N, C, H, W = f1.shape
        f1 = f1.view(B, N * C, H, W)                        # [B, N*C, H, W]

        # DINOv2 features
        f2 = batch["f2"].to(device)                         # [B, 512, H, W]

        # OneFormer features
        f3 = batch["f3"].to(device)                         # shape depends on extractor

        # Forward pass
        logits, dist1, dist2 = model(x, f1, f2, f3)

        # Resize mask
        mask_resized = nn.functional.interpolate(mask, size=logits.shape[2:], mode="bilinear", align_corners=False)

        # Loss calculation (same as training)
        seg_loss_value = seg_loss(logits, mask_resized) + criteria(logits, mask_resized)
        loss = 0.6 * seg_loss_value + 0.1 * dist1 + 0.1 * dist2

        test_losses.append(loss.item())
        test_seg_losses.append(seg_loss_value.item())
        test_distil_losses.append((dist1.item() + dist2.item()) / 2)

        # Predictions -> binary
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()

        # Metrics
        dice, iou, precision, recall = compute_metrics(preds.cpu(), mask_resized.cpu())
        all_dice.append(dice)
        all_iou.append(iou)
        all_precision.append(precision)
        all_recall.append(recall)

        # Save first 20 predictions
        if i < 20:
            save_image(preds, os.path.join(save_pred_dir, f"pred_{i}.png"))
            save_image(mask_resized.float(), os.path.join(save_pred_dir, f"gt_{i}.png"))

# ---- FINAL RESULTS ----
print(f"[Test Results] Loss = {mean(test_losses):.4f}, Seg = {mean(test_seg_losses):.4f}, Dist = {mean(test_distil_losses):.4f}")
print(f"[Metrics] mDice = {mean(all_dice):.4f}, mIoU = {mean(all_iou):.4f}, Precision = {mean(all_precision):.4f}, Recall = {mean(all_recall):.4f}")
print(f"âœ… Predictions saved in {save_pred_dir}")


[Test]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:10<00:00,  1.09s/it]

[Test Results] Loss = 0.1512, Seg = 0.2520, Dist = 0.0000
[Metrics] mDice = 0.8416, mIoU = 0.7447, Precision = 0.8565, Recall = 0.8399
âœ… Predictions saved in /home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_cvc_300_unetplusplus





In [7]:
# === Testing for Kvasir with SAM + DINO + OneFormer Features ===
import torch
import torch.nn as nn
from tqdm import tqdm
import os
from statistics import mean
from torchvision.utils import save_image

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        preds = torch.sigmoid(preds)
        preds = preds.clamp(min=1e-7, max=1 - 1e-7)
        preds = preds.contiguous()
        targets = targets.contiguous()

        intersection = (preds * targets).sum(dim=(2, 3))
        union = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

def compute_metrics(preds, targets, smooth=1e-6):
    preds = preds.view(-1)
    targets = targets.view(-1)

    tp = (preds * targets).sum().item()
    fp = (preds * (1 - targets)).sum().item()
    fn = ((1 - preds) * targets).sum().item()

    dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    iou = (tp + smooth) / (tp + fp + fn + smooth)
    precision = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)

    return dice, iou, precision, recall

# ---- DEVICE ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- LOAD MODEL ----
model = NestedUNet(num_classes=1, input_channels=3).to(device)
model.load_state_dict(torch.load(
    "/home/deepak1010/Shivanshu Code/features_sam_clip/unetplusplus_sam_dino_oneformer.pth",
    map_location=device
))
model.eval()

# ---- LOSSES ----
seg_loss = DiceLoss()
criteria = nn.BCEWithLogitsLoss()

# ---- TEST LOOP ----
test_losses, test_seg_losses, test_distil_losses = [], [], []
all_dice, all_iou, all_precision, all_recall = [], [], [], []

save_pred_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_cvc_colondb_unetplusplus"
os.makedirs(save_pred_dir, exist_ok=True)

with torch.no_grad():
    tq = tqdm(test_cvc_colondb_dataloader, desc="[Test]")
    for i, batch in enumerate(tq):
        # Input image & GT
        x = batch["pixel_values"].to(device)
        mask = batch["ground_truth_mask"].unsqueeze(1).to(device)

        # SAM features
        f1 = batch["f1"].to(device)                         # [B, N, C, H, W]
        B, N, C, H, W = f1.shape
        f1 = f1.view(B, N * C, H, W)                        # [B, N*C, H, W]

        # DINOv2 features
        f2 = batch["f2"].to(device)                         # [B, 512, H, W]

        # OneFormer features
        f3 = batch["f3"].to(device)                         # shape depends on extractor

        # Forward pass
        logits, dist1, dist2 = model(x, f1, f2, f3)

        # Resize mask
        mask_resized = nn.functional.interpolate(mask, size=logits.shape[2:], mode="bilinear", align_corners=False)

        # Loss calculation (same as training)
        seg_loss_value = seg_loss(logits, mask_resized) + criteria(logits, mask_resized)
        loss = 0.6 * seg_loss_value + 0.1 * dist1 + 0.1 * dist2

        test_losses.append(loss.item())
        test_seg_losses.append(seg_loss_value.item())
        test_distil_losses.append((dist1.item() + dist2.item()) / 2)

        # Predictions -> binary
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()

        # Metrics
        dice, iou, precision, recall = compute_metrics(preds.cpu(), mask_resized.cpu())
        all_dice.append(dice)
        all_iou.append(iou)
        all_precision.append(precision)
        all_recall.append(recall)

        # Save first 20 predictions
        if i < 20:
            save_image(preds, os.path.join(save_pred_dir, f"pred_{i}.png"))
            save_image(mask_resized.float(), os.path.join(save_pred_dir, f"gt_{i}.png"))

# ---- FINAL RESULTS ----
print(f"[Test Results] Loss = {mean(test_losses):.4f}, Seg = {mean(test_seg_losses):.4f}, Dist = {mean(test_distil_losses):.4f}")
print(f"[Metrics] mDice = {mean(all_dice):.4f}, mIoU = {mean(all_iou):.4f}, Precision = {mean(all_precision):.4f}, Recall = {mean(all_recall):.4f}")
print(f"âœ… Predictions saved in {save_pred_dir}")


[Test]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 64/64 [01:12<00:00,  1.14s/it]

[Test Results] Loss = 0.5373, Seg = 0.8955, Dist = 0.0000
[Metrics] mDice = 0.7115, mIoU = 0.5987, Precision = 0.8333, Recall = 0.6787
âœ… Predictions saved in /home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_cvc_colondb_unetplusplus





In [8]:
# === Testing for Kvasir with SAM + DINO + OneFormer Features ===
import torch
import torch.nn as nn
from tqdm import tqdm
import os
from statistics import mean
from torchvision.utils import save_image

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        preds = torch.sigmoid(preds)
        preds = preds.clamp(min=1e-7, max=1 - 1e-7)
        preds = preds.contiguous()
        targets = targets.contiguous()

        intersection = (preds * targets).sum(dim=(2, 3))
        union = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

def compute_metrics(preds, targets, smooth=1e-6):
    preds = preds.view(-1)
    targets = targets.view(-1)

    tp = (preds * targets).sum().item()
    fp = (preds * (1 - targets)).sum().item()
    fn = ((1 - preds) * targets).sum().item()

    dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    iou = (tp + smooth) / (tp + fp + fn + smooth)
    precision = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)

    return dice, iou, precision, recall

# ---- DEVICE ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- LOAD MODEL ----
model = NestedUNet(num_classes=1, input_channels=3).to(device)
model.load_state_dict(torch.load(
    "/home/deepak1010/Shivanshu Code/features_sam_clip/unetplusplus_sam_dino_oneformer.pth",
    map_location=device
))
model.eval()

# ---- LOSSES ----
seg_loss = DiceLoss()
criteria = nn.BCEWithLogitsLoss()

# ---- TEST LOOP ----
test_losses, test_seg_losses, test_distil_losses = [], [], []
all_dice, all_iou, all_precision, all_recall = [], [], [], []

save_pred_dir = "/home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_etis_unetplusplus"
os.makedirs(save_pred_dir, exist_ok=True)

with torch.no_grad():
    tq = tqdm(test_etis_dataloader, desc="[Test]")
    for i, batch in enumerate(tq):
        # Input image & GT
        x = batch["pixel_values"].to(device)
        mask = batch["ground_truth_mask"].unsqueeze(1).to(device)

        # SAM features
        f1 = batch["f1"].to(device)                         # [B, N, C, H, W]
        B, N, C, H, W = f1.shape
        f1 = f1.view(B, N * C, H, W)                        # [B, N*C, H, W]

        # DINOv2 features
        f2 = batch["f2"].to(device)                         # [B, 512, H, W]

        # OneFormer features
        f3 = batch["f3"].to(device)                         # shape depends on extractor

        # Forward pass
        logits, dist1, dist2 = model(x, f1, f2, f3)

        # Resize mask
        mask_resized = nn.functional.interpolate(mask, size=logits.shape[2:], mode="bilinear", align_corners=False)

        # Loss calculation (same as training)
        seg_loss_value = seg_loss(logits, mask_resized) + criteria(logits, mask_resized)
        loss = 0.6 * seg_loss_value + 0.1 * dist1 + 0.1 * dist2

        test_losses.append(loss.item())
        test_seg_losses.append(seg_loss_value.item())
        test_distil_losses.append((dist1.item() + dist2.item()) / 2)

        # Predictions -> binary
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()

        # Metrics
        dice, iou, precision, recall = compute_metrics(preds.cpu(), mask_resized.cpu())
        all_dice.append(dice)
        all_iou.append(iou)
        all_precision.append(precision)
        all_recall.append(recall)

        # Save first 20 predictions
        if i < 20:
            save_image(preds, os.path.join(save_pred_dir, f"pred_{i}.png"))
            save_image(mask_resized.float(), os.path.join(save_pred_dir, f"gt_{i}.png"))

# ---- FINAL RESULTS ----
print(f"[Test Results] Loss = {mean(test_losses):.4f}, Seg = {mean(test_seg_losses):.4f}, Dist = {mean(test_distil_losses):.4f}")
print(f"[Metrics] mDice = {mean(all_dice):.4f}, mIoU = {mean(all_iou):.4f}, Precision = {mean(all_precision):.4f}, Recall = {mean(all_recall):.4f}")
print(f"âœ… Predictions saved in {save_pred_dir}")


[Test]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 33/33 [00:52<00:00,  1.60s/it]

[Test Results] Loss = 0.4775, Seg = 0.7959, Dist = 0.0000
[Metrics] mDice = 0.5188, mIoU = 0.3992, Precision = 0.6649, Recall = 0.5040
âœ… Predictions saved in /home/deepak1010/Shivanshu Code/features_sam_clip/test_predictions_etis_unetplusplus



