In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torchmetrics.classification import (
    BinaryJaccardIndex, BinaryF1Score, BinaryAccuracy,
    BinaryRecall, BinaryPrecision, BinaryAUROC
)
from tqdm import tqdm
import timm

# Detect device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, 2)
        self.conv1 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.up(x)
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        return x

class SkipUpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, 2)
        self.conv1 = nn.Conv2d(out_ch + skip_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, skip):
        x = self.up(x)
        if x.shape[2:] != skip.shape[2:]:
            skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        return x

class DepthwiseConvTransformer(nn.Module):
    def __init__(self, in_ch, out_ch, expansion=4):
        super().__init__()
        hidden = in_ch * expansion
        self.proj1 = nn.Conv2d(in_ch, hidden, 1)
        self.dw = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
        self.bn = nn.BatchNorm2d(hidden)
        self.act = nn.GELU()
        self.proj2 = nn.Conv2d(hidden, out_ch, 1)

    def forward(self, x):
        x = self.proj1(x)
        x = self.dw(x)
        x = self.bn(x)
        x = self.act(x)
        return self.proj2(x)

class ResBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.BatchNorm2d(ch),
        )

    def forward(self, x):
        return x + self.block(x)

# Cell 3: Encoders
class EfficientNetB4Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model("efficientnet_b4", pretrained=True, features_only=True)
        self.out_ch = [24, 32, 56, 160, 448]

    def forward(self, x):
        return self.encoder(x)

class DenseNet121Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model("densenet121", pretrained=True, features_only=True)

    def forward(self, x):
        return self.encoder(x)

# Cell 4: Model Factory
class DualEncoderModel(nn.Module):
    def __init__(self, variant):
        super().__init__()
        self.variant = variant

        # --- Encoder 1 + optional Transformer ---
        self.enc1 = EfficientNetB4Encoder()
        bottleneck_ch = self.enc1.out_ch[-1]              # 448 for B4
        self.use_dct = (variant not in ['a1'])  # include in full (a5)

        self.trans = (
            DepthwiseConvTransformer(bottleneck_ch, bottleneck_ch)
            if self.use_dct else
            nn.Identity()
        )

        # --- Decoder 1 (skip connections) ---
        rev_ch = self.enc1.out_ch[::-1]   # [448,160,56,32,24]
        skip_ch = rev_ch[1:]              # [160,56,32,24]
        dec_chs  = [256, 128, 64, 32]
        self.dec1 = nn.ModuleList([
            SkipUpBlock(bottleneck_ch, skip_ch[0], dec_chs[0]),
            SkipUpBlock(dec_chs[0], skip_ch[1], dec_chs[1]),
            SkipUpBlock(dec_chs[1], skip_ch[2], dec_chs[2]),
            SkipUpBlock(dec_chs[2], skip_ch[3], dec_chs[3]),
            UpBlock(dec_chs[3], 16)
        ])
        self.head1 = nn.Conv2d(16, 1, kernel_size=1)

        # --- Optional Encoder 2 branch ---
        self.use_encoder2 = (variant not in ['a3'])  # include in full (a5)
        if self.use_encoder2:
            self.enc2 = DenseNet121Encoder()

            # VSS vs ResBlock
            self.use_vss = (variant not in ['a2'])  # include in full (a5)
            if self.use_vss:
                self.vss = nn.Sequential(
                    nn.BatchNorm2d(1024),
                    nn.Conv2d(1024, 1024, 1),
                    nn.Conv2d(1024, 1024, 3, padding=1, groups=1024),
                    nn.GELU(), nn.Dropout2d(0.1),
                    nn.Conv2d(1024, 512, 1), nn.BatchNorm2d(512),
                )
                dec2_in_ch = 512
            else:
                self.vss = ResBlock(1024)
                dec2_in_ch = 1024

            # Decoder 2
            self.dec2 = nn.ModuleList([
                UpBlock(dec2_in_ch, 256),
                UpBlock(256, 128),
                UpBlock(128, 64),
                UpBlock(64, 32),
            ])
            self.head2 = nn.Conv2d(32, 1, 1)

        # --- Fusion or direct sum ---
        self.fuse = nn.Conv2d(2, 1, 3, padding=1) if variant not in ['a4'] else None

    def forward(self, x):
        # Stage 1
        f1 = self.enc1(x)
        x1 = self.trans(f1[-1])                  # now always bottleneck_ch
        for i, block in enumerate(self.dec1):
            x1 = block(x1, f1[-2 - i]) if i < 4 else block(x1)
        logit1 = self.head1(x1)

        if not self.use_encoder2:
            return logit1

        # Stage 2
        out1_prob = torch.sigmoid(logit1)
        f2 = self.enc2(x * out1_prob)
        x2 = self.vss(f2[-1])
        for block in self.dec2:
            x2 = block(x2)
        logit2 = F.interpolate(self.head2(x2), size=logit1.shape[2:], mode='bilinear')

        # Fuse or direct
        if self.fuse:
            return self.fuse(torch.cat([logit1, logit2], dim=1))
        return logit1 + logit2


In [3]:
# 3: Encoders
class EfficientNetB4Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model("efficientnet_b4", pretrained=True, features_only=True)
        self.out_ch = [24, 32, 56, 160, 448]

    def forward(self, x):
        return self.encoder(x)

class DenseNet121Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model("densenet121", pretrained=True, features_only=True)

    def forward(self, x):
        return self.encoder(x)

# Additional cells truncated for brevity...

In [4]:
# Cell 5: Dataset Loader
def get_dataset(name):
    root = "/kaggle/input"
    if name == "kvasir":
        image_dir = f"{root}/kvasirseg/Kvasir-SEG/Kvasir-SEG/images"
        mask_dir = f"{root}/kvasirseg/Kvasir-SEG/Kvasir-SEG/masks"
    elif name == "cvc":
        image_dir = f"{root}/cvcclinicdb/PNG/Original" 
        mask_dir = f"{root}/cvcclinicdb/PNG/Ground Truth" 
    elif name == "bkai":
        image_dir = f"{root}/bkai-igh-neopolyp/train/train"
        mask_dir = f"{root}/bkai-igh-neopolyp/train_gt/train_gt"
    else:
        raise ValueError("Unknown dataset")

    img_paths = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir)])
    mask_paths = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir)])

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    class SegDataset(Dataset):
        def __init__(self, imgs, masks):
            self.imgs = imgs
            self.masks = masks
        def __len__(self): return len(self.imgs)
        def __getitem__(self, i):
            img = transform(Image.open(self.imgs[i]).convert("RGB"))
            mask = transform(Image.open(self.masks[i]).convert("L"))
            return img, (mask > 0).float()

    X_train, X_val, y_train, y_val = train_test_split(img_paths, mask_paths, test_size=0.2, random_state=42)
    return DataLoader(SegDataset(X_train, y_train), batch_size=8, shuffle=True), DataLoader(SegDataset(X_val, y_val), batch_size=8)


In [5]:

# Cell 6: Loss, Metrics, Trainer
class DiceLoss(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, y_pred, y_true):
        y_pred = y_pred.contiguous()
        y_true = y_true.contiguous()
        intersection = (y_pred * y_true).sum(dim=(2, 3))
        union = y_pred.sum(dim=(2, 3)) + y_true.sum(dim=(2, 3))
        dice_score = (2. * intersection + 1) / (union + 1)
        return 1 - dice_score.mean()

def evaluate(model, loader, device):
    metrics = {
        "IoU": BinaryJaccardIndex(threshold=0.5).to(device),
        "Dice": BinaryF1Score(threshold=0.5).to(device),
        "Accuracy": BinaryAccuracy(threshold=0.5).to(device),
        "Recall": BinaryRecall(threshold=0.5).to(device),
        "Precision": BinaryPrecision(threshold=0.5).to(device),
        "AUROC": BinaryAUROC().to(device)
    }
    model.eval()
    with torch.no_grad():
        for metric in metrics.values(): metric.reset()
        for images, masks in loader:
            images, masks_f, masks_i = images.to(device), masks.to(device).float(), masks.to(device).int()
            logits = model(images)
            probs = torch.sigmoid(logits)
            for name, metric in metrics.items():
                metric.update(probs, masks_i)
    return {k: float(m.compute()) for k, m in metrics.items()}

def train_one(model, train_loader, val_loader, device):
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', patience=3, factor=0.5)
    bce, dice = nn.BCEWithLogitsLoss(), DiceLoss()
    best_loss = float('inf')
    for epoch in range(10):
        model.train()
        epoch_loss = 0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optim.zero_grad()
            logits = model(images)
            probs = torch.sigmoid(logits)
            loss = 0.5 * bce(logits, masks) + 0.5 * dice(probs, masks)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
        scheduler.step(epoch_loss / len(train_loader))
    return evaluate(model, val_loader, device)


In [6]:
variants = ['a5']
datasets = ['cvc', 'kvasir', 'bkai']
results = []

for variant in variants:
    for dataset in datasets:
        print(f"\n🚀 Running Variant: {variant.upper()} on Dataset: {dataset.upper()}")
        model = DualEncoderModel(variant)
        train_loader, val_loader = get_dataset(dataset)
        metrics = train_one(model, train_loader, val_loader, device)
        result_row = {"Variant": variant, "Dataset": dataset, **metrics}
        results.append(result_row)

# Export Results
results_df = pd.DataFrame(results)
results_df.to_csv("ablation_results.csv", index=False)
print("\n✅ Ablation Study Complete. Results saved to ablation_results.csv")
results_df



🚀 Running Variant: A5 on Dataset: CVC


model.safetensors:   0%|          | 0.00/77.9M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/32.3M [00:00<?, ?B/s]


🚀 Running Variant: A5 on Dataset: KVASIR

🚀 Running Variant: A5 on Dataset: BKAI

✅ Ablation Study Complete. Results saved to ablation_results.csv


Unnamed: 0,Variant,Dataset,IoU,Dice,Accuracy,Recall,Precision,AUROC
0,a5,cvc,0.782259,0.877829,0.978087,0.893318,0.862868,0.991297
1,a5,kvasir,0.770625,0.870456,0.957073,0.849817,0.892121,0.982306
2,a5,bkai,0.800258,0.889048,0.985748,0.884227,0.893922,0.993936
