In [1]:
import os
import glob
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
import random

In [21]:
# -----------------------
# Configuration
# -----------------------
SEG_DIR      = 'Kvasir-SEG'          # Kvasir-SEG root (images/, masks/)
CLS_DIR      = 'Kvasir-dataset-v2'           # Kvasir-v2 classification (subdirs per class)
CVC_DIR      = 'CVC-ClinicDB'         # external evaluation dataset for segmentation
IMG_SIZE     = 256
BATCH_SIZE   = 8
LR           = 1e-4
EPOCHS       = 25
PATIENCE     = 3
SEED         = 42
ENCODER      = 'resnet34'
ENC_WEIGHTS  = 'imagenet'
NUM_CLASSES  = len(next(os.walk(CLS_DIR))[1])  # auto-count subfolders
DEVICE       = 'cuda' if torch.cuda.is_available() else 'cpu'
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

print(f"Configuration: IMG_SIZE={IMG_SIZE}, BATCH_SIZE={BATCH_SIZE}, EPOCHS={EPOCHS}, Patience: {PATIENCE}")
print(f"Number of classes: {NUM_CLASSES}, Device: {DEVICE}")

Configuration: IMG_SIZE=256, BATCH_SIZE=8, EPOCHS=25, Patience: 3
Number of classes: 8, Device: cuda


In [22]:
class SegDataset(Dataset):
    def __init__(self, imgs, masks, augment=None):
        self.imgs, self.masks, self.aug = imgs, masks, augment
    def __len__(self): return len(self.imgs)
    def __getitem__(self, i):
        image = np.array(Image.open(self.imgs[i]).convert('RGB').resize((IMG_SIZE,IMG_SIZE)))
        mask  = np.array(Image.open(self.masks[i]).convert('L').resize((IMG_SIZE,IMG_SIZE)))
        mask = (mask > 127).astype('float32')
        if self.aug:
            a = self.aug(image=image, mask=mask)
            image, mask = a['image'], a['mask']
        else:
            image = ToTensorV2()(image=image)['image']
            mask  = torch.from_numpy(mask).unsqueeze(0)
        return image, mask

class ClassDataset(Dataset):
    def __init__(self, root, transform=None):
        self.transform = transform
        self.samples = []
        for cls in sorted(os.listdir(root)):
            p = os.path.join(root, cls)
            if os.path.isdir(p):
                for f in glob.glob(os.path.join(p, '*.jpg')):
                    self.samples.append((f, cls))
        self.cls2idx = {c:i for i,c in enumerate(sorted({c for _,c in self.samples}))}
        
    def __len__(self): return len(self.samples)
    
    def __getitem__(self, i):
        fp, cls = self.samples[i]
        img = Image.open(fp).convert('RGB').resize((IMG_SIZE,IMG_SIZE))
        if self.transform: img = self.transform(img)
        return img, self.cls2idx[cls]

In [23]:
# -----------------------
# Augmentations
# -----------------------
train_aug = A.Compose([
    A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(0.1,0.1,15,p=0.5), A.ColorJitter(0.2,0.2,p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
    A.GaussianBlur(blur_limit=(3,7), p=0.3),
    A.Normalize(), ToTensorV2()
])
val_aug = A.Compose([A.Normalize(), ToTensorV2()])

In [24]:
# Segmentation paths
seg_imgs = sorted(glob.glob(os.path.join(SEG_DIR, 'images', '*.jpg')))
seg_msks = sorted(glob.glob(os.path.join(SEG_DIR, 'masks', '*.jpg')))
train_si, val_si, train_sm, val_sm = train_test_split(seg_imgs, seg_msks, test_size=0.2, random_state=42)
seg_train = SegDataset(train_si, train_sm, augment=train_aug)
seg_val   = SegDataset(val_si, val_sm, augment=val_aug)
seg_loader    = DataLoader(seg_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
seg_val_loader= DataLoader(seg_val,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Classification loader
cls_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
class_ds    = ClassDataset(CLS_DIR, transform=cls_transform)
cls_train, cls_val = train_test_split(class_ds, test_size=0.2, random_state=42)
cls_loader     = DataLoader(cls_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
cls_val_loader = DataLoader(cls_val,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
print(f"Found {len(seg_imgs)} segmentation images and {len(seg_msks)} masks.")

RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 786432 bytes.

In [None]:
# -----------------------
# Multi-Task Model
# -----------------------a
class MultiTaskModel(nn.Module):
    def __init__(self, encoder, weights, n_classes):
        super().__init__()
        self.unet = smp.Unet(encoder_name=encoder, encoder_weights=weights, in_channels=3, classes=1)
        
        ch = self.unet.encoder.out_channels[-1]
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(ch, n_classes)
        )

    def forward(self, x):
        features = self.unet.encoder(x)
        dec = self.unet.decoder(features)
        seg = self.unet.segmentation_head(dec)
        cls = self.classifier(features[-1])
        return seg, cls

model = MultiTaskModel(ENCODER, ENC_WEIGHTS, NUM_CLASSES).to(DEVICE)
print(f"Initialized MultiTaskModel with encoder={ENCODER} and {NUM_CLASSES} classes.")

In [25]:
# -----------------------
# Losses & Optimizer
# -----------------------
dice = smp.losses.DiceLoss(mode='binary')
bce  = nn.BCEWithLogitsLoss()
cross= nn.CrossEntropyLoss()
opt  = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=5, verbose=True)



In [15]:
best_loss = float('inf')
epochs_no_improve = 0

for epoch in range(EPOCHS):
    model.train()
    total_seg_loss = 0.0
    total_cls_loss = 0.0

    # Segmentation training
    for imgs, msks in seg_loader:
        imgs, msks = imgs.to(DEVICE), msks.to(DEVICE)
        seg_out, _ = model(imgs)
        loss_seg = bce(seg_out, msks.unsqueeze(1)) + dice(seg_out, msks.unsqueeze(1))

        opt.zero_grad()
        loss_seg.backward()
        opt.step()

        total_seg_loss += loss_seg.item()

    # Classification training
    for imgs, labels in cls_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        _, logits = model(imgs)
        loss_cls = cross(logits, labels)

        opt.zero_grad()
        loss_cls.backward()
        opt.step()

        total_cls_loss += loss_cls.item()

    avg_seg = total_seg_loss / len(seg_loader)
    avg_cls = total_cls_loss / len(cls_loader)

    # Scheduler step on training segmentation loss
    scheduler.step(avg_seg)

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Seg: {avg_seg:.4f} | Train Cls: {avg_cls:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

    # Early stopping based on training seg loss
    if avg_seg < best_loss:
        best_loss = avg_seg
        epochs_no_improve = 0
        torch.save(model.state_dict(), 'colonoscopy_unet_model.pth')
        print(f"New best train seg loss: {best_loss:.4f}")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

Epoch 1/25 | Train Seg: 1.5727 | Train Cls: 2.1639 | Val Seg: 1.5432 | IoU: 0.1643 | Dice: 0.2667 | Val Cls Acc: 0.1512 | Prec: 0.2320 | Rec: 0.1501 | F1: 0.1034 | LR: 1.00e-08
New best val seg loss: 1.5432
Epoch 2/25 | Train Seg: 1.5717 | Train Cls: 2.1461 | Val Seg: 1.5431 | IoU: 0.1645 | Dice: 0.2670 | Val Cls Acc: 0.1550 | Prec: 0.1360 | Rec: 0.1537 | F1: 0.1030 | LR: 1.00e-08
New best val seg loss: 1.5431
Epoch 3/25 | Train Seg: 1.5695 | Train Cls: 2.1353 | Val Seg: 1.5461 | IoU: 0.1650 | Dice: 0.2675 | Val Cls Acc: 0.1725 | Prec: 0.1928 | Rec: 0.1725 | F1: 0.1253 | LR: 1.00e-08
Epoch 4/25 | Train Seg: 1.5785 | Train Cls: 2.1188 | Val Seg: 1.5476 | IoU: 0.1671 | Dice: 0.2704 | Val Cls Acc: 0.1781 | Prec: 0.2330 | Rec: 0.1785 | F1: 0.1262 | LR: 1.00e-08
Epoch 5/25 | Train Seg: 1.5735 | Train Cls: 2.1098 | Val Seg: 1.5453 | IoU: 0.1652 | Dice: 0.2679 | Val Cls Acc: 0.1812 | Prec: 0.2046 | Rec: 0.1814 | F1: 0.1350 | LR: 1.00e-08
Early stopping at epoch 5


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [ ]:
model.eval()
val_seg_loss = 0.0
val_iou = 0.0
val_dice = 0.0
val_cls_correct = 0
val_cls_total = 0
all_labels = []
all_preds = []

with torch.no_grad():
    # Segmentation validation
    for imgs, msks in seg_val_loader:
        imgs, msks = imgs.to(DEVICE), msks.to(DEVICE)
        seg_out, _ = model(imgs)
        msk = msks.unsqueeze(1)
        val_seg_loss += (bce(seg_out, msk) + dice(seg_out, msk)).item()

        prob = torch.sigmoid(seg_out)
        pred_mask = (prob > 0.5).float()
        intersection = (pred_mask * msk).sum(dim=[1,2,3])
        union = ((pred_mask + msk) >= 1).sum(dim=[1,2,3])
        val_iou += (intersection / union).mean().item()
        val_dice += (2 * intersection / (pred_mask.sum([1,2,3]) + msk.sum([1,2,3]))).mean().item()

    # Classification validation
    for imgs, labels in cls_val_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        _, logits = model(imgs)
        preds = torch.argmax(logits, dim=1)
        val_cls_correct += (preds == labels).sum().item()
        val_cls_total += labels.size(0)
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

# Compute final metrics
avg_val_seg = val_seg_loss / len(seg_val_loader)
avg_iou = val_iou / len(seg_val_loader)
avg_dice = val_dice / len(seg_val_loader)
val_acc = val_cls_correct / val_cls_total
val_precision = precision_score(all_labels, all_preds, average='macro')
val_recall = recall_score(all_labels, all_preds, average='macro')
val_f1 = f1_score(all_labels, all_preds, average='macro')

# Print final validation results
print("\n=== Final Validation Metrics ===")
print(f"Seg Loss: {avg_val_seg:.4f} | IoU: {avg_iou:.4f} | Dice: {avg_dice:.4f}")
print(f"Cls Acc: {val_acc:.4f} | Prec: {val_precision:.4f} | Rec: {val_recall:.4f} | F1: {val_f1:.4f}")


In [12]:
cvc_imgs = sorted(glob.glob(os.path.join(CVC_DIR,'Original','*.png')))

cvc_msks = sorted(glob.glob(os.path.join(CVC_DIR,'Ground Truth','*.png')))
if not cvc_imgs:
    print("No CVC-ClinicDB images found in", CVC_DIR)
else:
    cvc_loader = DataLoader(SegDataset(cvc_imgs, cvc_msks, augment=val_aug), batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    cvc_iou = cvc_dice = 0.0
    with torch.no_grad():
        for imgs, msks in cvc_loader:
            imgs, msks = imgs.to(DEVICE), msks.to(DEVICE)
            so, _ = model(imgs); m = msks.unsqueeze(1)
            prob = torch.sigmoid(so) > 0.5
            inter = (prob * m).sum(dim=[1,2,3])
            uni = ((prob + m) >= 1).sum(dim=[1,2,3])
            cvc_iou += (inter / uni).mean().item()
            cvc_dice += (2 * inter / (prob.sum([1,2,3]) + m.sum([1,2,3]))).mean().item()
    n = len(cvc_loader)
    print(f"CVC-ClinicDB results. on {n} batches: IoU = {cvc_iou/n:.4f}, Dice = {cvc_dice/n:.4f}")

CVC-ClinicDB results on 77 batches: IoU = 0.3498, Dice = 0.4573
