In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from PIL import Image
import numpy as np
import os
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
import warnings
warnings.filterwarnings('ignore')

# Report the runtime environment so a run can be tied to a PyTorch build and device.
print(f"PyTorch Version: {torch.__version__}")
cuda_available = torch.cuda.is_available()
print(f"CUDA Available: {cuda_available}")
if cuda_available:
    # Only query the device name when a CUDA device actually exists.
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")

In [None]:
class Config:
    """Dataset locations and training hyperparameters, kept together as class attributes."""
    # Dataset root. Defaults to the Kaggle mount point; set the CVC_CLINICDB_DIR
    # environment variable to point elsewhere without editing this cell.
    BASE_DIR = os.environ.get("CVC_CLINICDB_DIR", "/kaggle/input/cvcclinicdb/PNG")
    IMAGE_DIR = os.path.join(BASE_DIR, "Original")
    MASK_DIR = os.path.join(BASE_DIR, "Ground Truth")

    # Training hyperparameters
    IMG_SIZE = 512  # images and masks are resized to IMG_SIZE x IMG_SIZE
    BATCH_SIZE = 8
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 50
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Loss function weights
    DICE_BCE_ALPHA = 0.5  # Weight for Dice vs BCE loss

    # Random seed for reproducibility
    SEED = 42


config = Config()

# Seed both RNGs the notebook imports, not just torch, so that any NumPy-based
# randomness is repeatable as well.
torch.manual_seed(config.SEED)
np.random.seed(config.SEED)
print(f"Using device: {config.DEVICE}")

In [None]:
class CVCClinicDBDataset(Dataset):
    """
    Map-style dataset yielding (image, binary mask) pairs for CVC-ClinicDB polyp segmentation.

    Args:
        image_paths (list): List of paths to input images
        mask_paths (list): List of paths to segmentation masks; ``mask_paths[i]``
            must be the mask that belongs to ``image_paths[i]``
        transform: Optional callable applied to the RGB PIL image
        mask_transform: Optional callable applied to the grayscale PIL mask

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform=None, mask_transform=None):
        # Pairing is purely positional, so a length mismatch means at least one
        # image would be matched with the wrong (or no) mask.
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths and mask_paths must have the same length, "
                f"got {len(image_paths)} and {len(mask_paths)}"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and binarize the pair at position ``idx``.

        Returns:
            Tuple ``(image, mask)``; ``mask`` is a float tensor holding only 0.0 and 1.0.
        """
        # The context managers release the underlying file handles right away;
        # convert() hands back an independent, converted copy of the pixel data.
        with Image.open(self.image_paths[idx]) as image_file:
            image = image_file.convert('RGB')
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = mask_file.convert('L')

        # Apply transformations
        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        # Without a tensor-producing mask_transform the mask is still a PIL image,
        # which cannot be compared with `>`; convert it and add a leading channel axis.
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.array(mask)).unsqueeze(0)

        # Binarize mask: every non-zero pixel becomes foreground (1.0)
        mask = (mask > 0).float()

        return image, mask

In [None]:
class DiceBCELoss(nn.Module):
    """
    Weighted sum of soft Dice loss and binary cross-entropy for binary segmentation.

    The loss is ``alpha * (1 - dice) + (1 - alpha) * bce``. Both terms are computed
    from raw logits; the Dice term uses sigmoid probabilities aggregated over the
    whole batch.

    Args:
        alpha (float): Weight of the Dice term in [0, 1]; the BCE term gets ``1 - alpha``
        smooth (float): Additive constant that keeps the Dice ratio defined when
            both prediction and target are empty

    Raises:
        ValueError: If ``alpha`` lies outside [0, 1].
    """

    def __init__(self, alpha=0.5, smooth=1e-5):
        super().__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.alpha = alpha
        self.smooth = smooth
        # The logits variant fuses the sigmoid into the loss for numerical stability.
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        """
        Compute the combined loss.

        Args:
            logits: Raw model outputs (no sigmoid applied)
            targets: Ground-truth masks with values in {0, 1}, same shape as ``logits``

        Returns:
            Scalar loss tensor
        """
        targets = targets.float()
        bce_loss = self.bce(logits, targets)

        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        return self.alpha * (1.0 - dice) + (1.0 - self.alpha) * bce_loss

In [None]:
def calculate_iou(pred, target, threshold=0.5):
    """
    Intersection over Union (Jaccard index) of thresholded predictions vs. targets.

    The logits are passed through a sigmoid and binarized with ``threshold``; the
    score is then computed over all elements of the tensors at once.

    Args:
        pred: Model predictions (logits)
        target: Ground truth masks
        threshold: Probability above which a pixel counts as foreground

    Returns:
        IoU score as a Python float
    """
    eps = 1e-5  # keeps the ratio defined when prediction and target are both empty
    binary_pred = torch.sigmoid(pred).gt(threshold).float()
    overlap = torch.sum(binary_pred * target)
    combined = binary_pred.sum() + target.sum() - overlap
    return ((overlap + eps) / (combined + eps)).item()


def calculate_dice(pred, target, threshold=0.5):
    """
    Dice coefficient of thresholded predictions vs. targets.

    Args:
        pred: Model predictions (logits)
        target: Ground truth masks
        threshold: Probability above which a pixel counts as foreground

    Returns:
        Dice score as a Python float
    """
    eps = 1e-5  # keeps the ratio defined when prediction and target are both empty
    binary_pred = torch.sigmoid(pred).gt(threshold).float()
    overlap = torch.sum(binary_pred * target)
    score = (2. * overlap + eps) / (binary_pred.sum() + target.sum() + eps)
    return score.item()


In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``dataloader``.

    The model is switched to training mode, every batch is moved to ``device``,
    and the loss computed on the model's ``'out'`` head is back-propagated.

    Returns:
        Tuple of (average_loss, average_iou, average_dice), each averaged over batches
    """
    model.train()
    loss_total, iou_total, dice_total = 0.0, 0.0, 0.0

    batches = tqdm(dataloader, desc="Training")
    for batch_images, batch_masks in batches:
        batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)

        # Forward pass: only the main segmentation head is used
        logits = model(batch_images)['out']
        batch_loss = criterion(logits, batch_masks)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Per-batch metrics
        iou_value = calculate_iou(logits, batch_masks)
        dice_value = calculate_dice(logits, batch_masks)
        loss_value = batch_loss.item()

        loss_total += loss_value
        iou_total += iou_value
        dice_total += dice_value

        # Live feedback in the progress bar
        batches.set_postfix({
            'loss': f'{loss_value:.4f}',
            'dice': f'{dice_value:.4f}'
        })

    # Epoch averages are plain means of the per-batch values
    num_batches = len(dataloader)
    return loss_total / num_batches, iou_total / num_batches, dice_total / num_batches

In [None]:
def validate(model, dataloader, criterion, device):
    """
    Evaluate the model on ``dataloader`` without updating any parameters.

    The model is switched to evaluation mode and all forward passes run with
    gradient tracking disabled.

    Returns:
        Tuple of (average_loss, average_iou, average_dice), each averaged over batches
    """
    model.eval()
    totals = {'loss': 0.0, 'iou': 0.0, 'dice': 0.0}

    with torch.no_grad():
        batches = tqdm(dataloader, desc="Validation")
        for batch_images, batch_masks in batches:
            batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)

            # Only the main segmentation head is scored
            logits = model(batch_images)['out']
            loss_value = criterion(logits, batch_masks).item()

            totals['loss'] += loss_value
            totals['iou'] += calculate_iou(logits, batch_masks)
            totals['dice'] += calculate_dice(logits, batch_masks)

            batches.set_postfix({'loss': f'{loss_value:.4f}'})

    # Epoch averages are plain means of the per-batch values
    num_batches = len(dataloader)
    return (
        totals['loss'] / num_batches,
        totals['iou'] / num_batches,
        totals['dice'] / num_batches,
    )

In [None]:
def visualize_predictions(model, dataset, device, num_samples=5, save_path='predictions.png'):
    """
    Generate and save a grid of (image, ground truth, prediction) panels.

    The first samples of ``dataset`` are shown, one per row. The model is put
    into evaluation mode and left there.

    Args:
        model: Segmentation model whose forward pass returns a dict with an ``'out'`` entry
        dataset: Indexable dataset of (image, mask) tensor pairs that supports ``len()``
        device: Device on which inference is run
        num_samples: Maximum number of samples to show; capped at ``len(dataset)``
        save_path: File the figure is written to

    Raises:
        ValueError: If there is nothing to show (empty dataset or ``num_samples`` < 1).
    """
    # Never request more rows than the dataset can supply; indexing past the end
    # would otherwise fail half-way through drawing the figure.
    num_rows = min(num_samples, len(dataset))
    if num_rows < 1:
        raise ValueError(
            f"Nothing to visualize: num_samples={num_samples}, dataset size={len(dataset)}"
        )

    model.eval()
    # squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
    fig, axes = plt.subplots(num_rows, 3, figsize=(15, 5 * num_rows), squeeze=False)

    # Statistics used to undo the input normalization for display
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for i in range(num_rows):
        image, mask = dataset[i]

        # Generate prediction (foreground probability map)
        with torch.no_grad():
            image_input = image.unsqueeze(0).to(device)
            output = model(image_input)['out']
            pred = torch.sigmoid(output).cpu().squeeze().numpy()

        # Denormalize image for visualization: channels-last, then back to [0, 1]
        image_np = image.cpu().numpy().transpose(1, 2, 0)
        image_np = std * image_np + mean
        image_np = np.clip(image_np, 0, 1)

        # Plot original image, ground truth, and prediction
        axes[i, 0].imshow(image_np)
        axes[i, 0].set_title('Original Image', fontsize=12, fontweight='bold')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(mask.squeeze(), cmap='gray')
        axes[i, 1].set_title('Ground Truth', fontsize=12, fontweight='bold')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred, cmap='gray')
        axes[i, 2].set_title('Prediction', fontsize=12, fontweight='bold')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Predictions saved to {save_path}")


def plot_training_curves(train_losses, val_losses, train_dices, val_dices, save_path='training_curves.png'):
    """
    Plot loss and Dice score per epoch (train vs. validation) side by side and save the figure.

    Args:
        train_losses: Per-epoch training losses; its length defines the epoch axis
        val_losses: Per-epoch validation losses
        train_dices: Per-epoch training Dice scores
        val_dices: Per-epoch validation Dice scores
        save_path: File the figure is written to
    """
    epochs = range(1, len(train_losses) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # One entry per panel:
    # (train series, val series, train legend, val legend, y-axis label, title)
    panels = [
        (train_losses, val_losses, 'Train Loss', 'Val Loss', 'Loss', 'Loss Curve'),
        (train_dices, val_dices, 'Train Dice', 'Val Dice', 'Dice Score', 'Dice Score Curve'),
    ]

    for ax, (train_series, val_series, train_label, val_label, y_label, title) in zip(axes, panels):
        ax.plot(epochs, train_series, 'b-', label=train_label, linewidth=2, marker='o', markersize=4)
        ax.plot(epochs, val_series, 'r-', label=val_label, linewidth=2, marker='s', markersize=4)
        ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
        ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Training curves saved to {save_path}")

In [None]:
print("Loading dataset...")
print(f"Image directory: {config.IMAGE_DIR}")
print(f"Mask directory: {config.MASK_DIR}")

# Get image and mask paths. Both lists are sorted so that position i in one list
# pairs with position i in the other.
image_files = sorted(glob.glob(os.path.join(config.IMAGE_DIR, "*.png")))
mask_files = sorted(glob.glob(os.path.join(config.MASK_DIR, "*.png")))

# Verify dataset
assert len(image_files) == len(mask_files), \
    f"Mismatch: {len(image_files)} images, {len(mask_files)} masks"
assert len(image_files) > 0, "No images found! Check your dataset path."

print(f"\n✓ Found {len(image_files)} images and {len(mask_files)} masks")

# Display sample filenames so the image/mask pairing can be checked by eye
print("\nSample files:")
for i in range(min(3, len(image_files))):
    print(f"  Image: {os.path.basename(image_files[i])}")
    print(f"  Mask:  {os.path.basename(mask_files[i])}")

In [None]:
# Split dataset: 80% train, 10% val, 10% test.
# First hold out 20% of the pairs, then cut that hold-out in half.
train_imgs, temp_imgs, train_masks, temp_masks = train_test_split(
    image_files,
    mask_files,
    test_size=0.2,
    random_state=config.SEED,
)
val_imgs, test_imgs, val_masks, test_masks = train_test_split(
    temp_imgs,
    temp_masks,
    test_size=0.5,
    random_state=config.SEED,
)

# Report each subset's size and its share of the full dataset
print("\nDataset split:")
for split_label, split_imgs in (
    ("Training:", train_imgs),
    ("Validation:", val_imgs),
    ("Test:", test_imgs),
):
    print(f"  {split_label:<11} {len(split_imgs)} samples ({len(split_imgs)/len(image_files)*100:.1f}%)")

In [None]:
# Training transformations (photometric augmentation only).
# NOTE: random flips/rotations must not live in this pipeline. Images and masks are
# transformed by two independent pipelines, so a flip or rotation sampled here would
# not be replayed on the mask and the training pair would no longer line up.
# Geometric augmentation has to be applied jointly to image and mask (for example
# inside the dataset's __getitem__) if it is needed.
train_transform = transforms.Compose([
    transforms.Resize((config.IMG_SIZE, config.IMG_SIZE)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/Test transformations (no augmentation)
val_transform = transforms.Compose([
    transforms.Resize((config.IMG_SIZE, config.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Mask transformations (same for train/val/test).
# Nearest-neighbour resizing keeps the mask strictly two-valued. The default bilinear
# filter produces intermediate grey values along the boundary, which the dataset's
# `mask > 0` binarization would then count as foreground.
mask_transform = transforms.Compose([
    transforms.Resize(
        (config.IMG_SIZE, config.IMG_SIZE),
        interpolation=transforms.InterpolationMode.NEAREST,
    ),
    transforms.ToTensor()
])

print("✓ Data transformations configured")

In [None]:
# Create datasets (only the training set uses the augmenting image transform)
train_dataset = CVCClinicDBDataset(train_imgs, train_masks, train_transform, mask_transform)
val_dataset = CVCClinicDBDataset(val_imgs, val_masks, val_transform, mask_transform)
test_dataset = CVCClinicDBDataset(test_imgs, test_masks, val_transform, mask_transform)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    # Drop an incomplete final batch. If the training set size leaves a remainder
    # of one sample, that single-sample batch reaches the BatchNorm layer in
    # DeepLabV3's image-pooling branch, which raises in training mode when given
    # only one value per channel. With shuffling, a different sample is skipped
    # each epoch.
    drop_last=True,
)
# Evaluation loaders keep every sample; in eval mode BatchNorm uses running statistics.
val_loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print("✓ DataLoaders created")
print(f"  Training batches:   {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Test batches:       {len(test_loader)}")


In [None]:
print("Initializing DeepLabV3 model with ResNet50 backbone...")

# Load pretrained DeepLabV3
model = deeplabv3_resnet50(pretrained=True)

# Modify classifier for binary segmentation (1 output channel).
# Both the main and the auxiliary head get a fresh 1x1 convolution producing one logit map.
model.classifier[4] = nn.Conv2d(256, 1, kernel_size=1)
model.aux_classifier[4] = nn.Conv2d(256, 1, kernel_size=1)

# Move model to device
model = model.to(config.DEVICE)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("✓ Model loaded successfully")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

# Loss function
criterion = DiceBCELoss(alpha=config.DICE_BCE_ALPHA)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

# Learning rate scheduler: halve the learning rate after 5 epochs without
# improvement of the monitored value (mode='min').
# `verbose=True` is intentionally not passed: the argument is deprecated in recent
# PyTorch releases, and the training loop already prints the current learning rate
# every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=5,
    factor=0.5,
)

print("✓ Loss function, optimizer, and scheduler configured")

In [None]:
# Best validation Dice seen so far; a checkpoint is written whenever it improves.
best_val_dice = 0.0
# Per-epoch history, consumed later by the curve plots
train_losses, val_losses = [], []
train_dices, val_dices = [], []

print(f"\n{'='*70}")
print(f"Starting training for {config.NUM_EPOCHS} epochs")
print(f"{'='*70}\n")

for epoch in range(config.NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{config.NUM_EPOCHS}")
    print("-" * 70)

    # Train
    train_loss, train_iou, train_dice = train_one_epoch(
        model, train_loader, criterion, optimizer, config.DEVICE
    )

    # Validate
    val_loss, val_iou, val_dice = validate(
        model, val_loader, criterion, config.DEVICE
    )

    # Update learning rate (the scheduler monitors the validation loss)
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_dices.append(train_dice)
    val_dices.append(val_dice)

    # Print results
    print("\nResults:")
    print(f"  Train → Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f}")
    print(f"  Val   → Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f}")
    print(f"  Learning Rate: {current_lr:.6f}")

    # Save best model (selected by validation Dice, not by loss)
    if val_dice > best_val_dice:
        best_val_dice = val_dice
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_dice': val_dice,
        }, 'best_deeplabv3_model.pth')
        print(f"  ✓ New best model saved! Val Dice: {best_val_dice:.4f}")

In [None]:
print("\nEvaluating best model on test set...")
print("="*70)

# Load best model. map_location places the tensors on the configured device, so a
# checkpoint written on a GPU can also be restored in a CPU-only session.
checkpoint = torch.load('best_deeplabv3_model.pth', map_location=config.DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

# Test (reuses the validation routine on the held-out test loader)
test_loss, test_iou, test_dice = validate(model, test_loader, criterion, config.DEVICE)

print("\n📊 Final Test Results:")
print(f"  Test Loss: {test_loss:.4f}")
print(f"  Test IoU:  {test_iou:.4f}")
print(f"  Test Dice: {test_dice:.4f}")
print(f"\n  Best Val Dice: {best_val_dice:.4f}")

In [None]:
# Show image / ground truth / prediction panels for the first five test samples.
print("\nGenerating prediction visualizations...")
visualize_predictions(
    model=model,
    dataset=test_dataset,
    device=config.DEVICE,
    num_samples=5,
)

In [None]:
# Plot the per-epoch loss and Dice history collected during training.
print("\nPlotting training curves...")
plot_training_curves(
    train_losses=train_losses,
    val_losses=val_losses,
    train_dices=train_dices,
    val_dices=val_dices,
)

In [None]:
# Final recap of the run: configuration, best validation score and test metrics.
print("TRAINING SUMMARY")
print("Dataset: CVC-ClinicDB")
print("Model: DeepLabV3 with ResNet50 backbone")
print(f"Total Epochs: {config.NUM_EPOCHS}")
print(f"\nBest Validation Dice: {best_val_dice:.4f}")
print(f"Final Test Dice:      {test_dice:.4f}")
print(f"Final Test IoU:       {test_iou:.4f}")
print("="*70)
print("\n✓ All done! Model saved as 'best_deeplabv3_model.pth'")