In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet50

# Albumentations for advanced augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Utilities
from sklearn.model_selection import train_test_split

# Display versions
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")

In [None]:
class Config:
    """Central container for every tunable setting used by this notebook.

    All values are class attributes, so they are read as ``Config.NAME``
    without instantiating the class.
    """

    # ---- Dataset / output locations ----
    DATA_PATH = '/kaggle/input/kvasirseg/Kvasir-SEG'
    IMAGES_PATH = os.path.join(DATA_PATH, 'images')
    MASKS_PATH = os.path.join(DATA_PATH, 'masks')
    OUTPUT_PATH = '/kaggle/working'

    # ---- Training hyper-parameters ----
    IMG_SIZE = 352              # Input image size (square, pixels)
    BATCH_SIZE = 20             # Batch size for training
    EPOCHS = 100                # Number of training epochs

    BASE_LR = 0.0001           # Initial learning rate
    LR_POWER = 0.9             # Polynomial decay power

    # ---- ASPP settings ----
    ATROUS_RATES = [1, 6, 12, 18]  # Atrous convolution rates
    ASPP_FILTERS = 256              # Number of filters in ASPP

    # ---- Runtime / DataLoader settings ----
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_WORKERS = 2
    # Pinned host memory only speeds up host-to-GPU copies; requesting it on a
    # CPU-only machine just triggers a DataLoader warning, so follow the device.
    PIN_MEMORY = torch.cuda.is_available()

    # ---- Dataset split fractions (expected to sum to 1.0) ----
    TRAIN_RATIO = 0.7
    VAL_RATIO = 0.15
    TEST_RATIO = 0.15

    RANDOM_SEED = 42

# Create output directory if it doesn't exist
os.makedirs(Config.OUTPUT_PATH, exist_ok=True)

print("âœ“ Configuration loaded successfully")
print(f"  Device: {Config.DEVICE}")
print(f"  Image Size: {Config.IMG_SIZE}x{Config.IMG_SIZE}")
print(f"  Batch Size: {Config.BATCH_SIZE}")
print(f"  Epochs: {Config.EPOCHS}")


In [None]:
class PolypDataset(Dataset):
    """
    Custom Dataset for Polyp Segmentation.

    Loads an RGB image and its binary mask from disk, optionally runs both
    through an Albumentations pipeline, and returns them as a pair.

    Args:
        image_paths (list): List of paths to input images
        mask_paths (list): List of paths to segmentation masks, aligned
            index-by-index with ``image_paths``
        transform: Albumentations transform pipeline (optional)

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
        AssertionError: If one of the first five image/mask files is missing.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Mismatch: {len(image_paths)} images vs {len(mask_paths)} masks"
            )

        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

        # Cheap sanity check: only the first few pairs are probed on disk.
        # Unreadable files further down the list are reported by __getitem__.
        for img_path, mask_path in zip(image_paths[:5], mask_paths[:5]):
            assert os.path.exists(img_path), f"Image not found: {img_path}"
            assert os.path.exists(mask_path), f"Mask not found: {mask_path}"

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, binarize and (optionally) augment the pair at ``idx``.

        Returns:
            tuple: ``(image, mask)`` where ``mask`` is a float tensor with a
            leading channel dimension. ``image`` is whatever the transform
            produces; without a transform it is a CHW float tensor in [0, 1].

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(self.image_paths[idx])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {self.image_paths[idx]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {self.mask_paths[idx]}")
        mask = (mask > 127).astype(np.uint8)  # Binary threshold

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        else:
            # No pipeline supplied: convert HWC uint8 to a CHW float tensor so
            # the sample can still be batched and fed to a conv net.
            image = torch.from_numpy(np.ascontiguousarray(image))
            image = image.permute(2, 0, 1).float() / 255.0

        # as_tensor accepts both the tensor produced by the transform and the
        # raw NumPy mask, so the channel dimension can always be added.
        mask = torch.as_tensor(mask).unsqueeze(0).float()

        return image, mask

print("âœ“ PolypDataset class defined")

In [None]:
def get_train_transforms():
    """
    Build the Albumentations pipeline used for training samples.

    Order of operations: resize to a Config.IMG_SIZE square, random geometric
    transforms, random colour transforms, random Gaussian noise, then
    normalization with ImageNet statistics and conversion to a tensor.
    """
    side = Config.IMG_SIZE

    # Spatial transforms
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.Rotate(limit=15, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    ]

    # Colour transforms
    photometric = [
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.3),
    ]

    # Sensor-style noise
    noise = [
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
    ]

    # ImageNet normalization followed by tensor conversion
    finalize = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    steps = [A.Resize(side, side)] + geometric + photometric + noise + finalize
    return A.Compose(steps)


def get_val_transforms():
    """
    Build the deterministic pipeline used for validation and test samples.

    Resizes to a Config.IMG_SIZE square, normalizes with ImageNet statistics
    and converts to a tensor; no random augmentation is applied.
    """
    side = Config.IMG_SIZE
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

print("âœ“ Augmentation pipelines defined")
print("  Train: Geometric + Color augmentations + Noise")
print("  Val/Test: Resize + Normalization only")

In [None]:
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling.

    Runs the input through several parallel branches (a 1x1 convolution, one
    dilated 3x3 convolution per entry in ``atrous_rates``, and a global
    average-pooling branch), concatenates the results along the channel axis
    and fuses them with a 1x1 projection followed by dropout.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of channels produced by every branch and
            by the final projection
        atrous_rates (list): Dilation rates, one dilated branch per entry
    """

    def __init__(self, in_channels, out_channels, atrous_rates):
        super().__init__()

        def norm_act():
            # Fresh BatchNorm + ReLU pair appended after every convolution.
            return [nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]

        # Pointwise branch
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            *norm_act()
        )

        # Dilated 3x3 branches; padding == dilation keeps the spatial size.
        self.atrous_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3,
                          padding=rate, dilation=rate, bias=False),
                *norm_act()
            )
            for rate in atrous_rates
        ])

        # Image-level branch: squeeze to 1x1, then pointwise convolution.
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            *norm_act()
        )

        # Fuse: one pointwise branch + N dilated branches + one pooled branch.
        branch_count = len(atrous_rates) + 2
        self.project = nn.Sequential(
            nn.Conv2d(out_channels * branch_count, out_channels, 1, bias=False),
            *norm_act(),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        spatial = x.shape[2:]

        # Branch outputs, collected in concatenation order.
        branches = [self.conv1(x)]
        branches.extend(branch(x) for branch in self.atrous_convs)

        # The pooled branch is 1x1 spatially; stretch it back to the input size.
        pooled = F.interpolate(self.global_pool(x), size=spatial,
                               mode='bilinear', align_corners=False)
        branches.append(pooled)

        return self.project(torch.cat(branches, dim=1))

print("âœ“ ASPP module defined")
print(f"  Atrous rates: {Config.ATROUS_RATES}")
print(f"  Output channels: {Config.ASPP_FILTERS}")

In [None]:
class DeepLabV3Plus(nn.Module):
    """
    DeepLabV3+ architecture for semantic segmentation.

    Architecture components:
    - Encoder: ImageNet-pretrained ResNet50 whose last stage uses dilated
      (atrous) convolutions instead of striding, giving an output stride of 16
    - ASPP: Multi-scale feature extraction on the encoder output
    - Decoder: Fuses upsampled ASPP features with low-level (layer1) features
    - Classifier: 1x1 convolution producing per-pixel logits

    Args:
        num_classes (int): Number of output classes (1 for binary segmentation)
        backbone (str): Backbone network. Only 'resnet50' is implemented.

    Raises:
        ValueError: If ``backbone`` is anything other than 'resnet50'.
    """

    def __init__(self, num_classes=1, backbone='resnet50'):
        super(DeepLabV3Plus, self).__init__()

        # The argument used to be silently ignored; fail loudly instead of
        # building a ResNet50 when the caller asked for something else.
        if backbone != 'resnet50':
            raise ValueError(
                f"Unsupported backbone '{backbone}': only 'resnet50' is implemented"
            )

        # ========== Backbone (ResNet50) ==========
        # Replace the stride of layer4 with dilation so the encoder output stride
        # is 16 rather than 32. With a stride-32 map (11x11 for a 352 px input)
        # the 3x3 ASPP branches with rates 12 and 18 only ever see zero padding
        # around their centre tap and degenerate into 1x1 convolutions.
        # Parameter names and shapes are unchanged, so the pretrained weights
        # (and previously saved state_dicts) still load.
        # NOTE: layer4 activations are 4x larger than before; lower the batch
        # size if GPU memory becomes a problem.
        resnet = resnet50(
            pretrained=True,
            replace_stride_with_dilation=[False, False, True]
        )

        # Encoder stages
        self.layer0 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool
        )
        self.layer1 = resnet.layer1  # Output: 256 channels (low-level features)
        self.layer2 = resnet.layer2  # Output: 512 channels
        self.layer3 = resnet.layer3  # Output: 1024 channels
        self.layer4 = resnet.layer4  # Output: 2048 channels (high-level features)

        # ========== ASPP Module ==========
        self.aspp = ASPP(2048, Config.ASPP_FILTERS, Config.ATROUS_RATES)

        # ========== Low-level Feature Projection ==========
        # Reduce channels from 256 to 48 for efficient concatenation
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(256, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # ========== Decoder ==========
        # Combines ASPP output (Config.ASPP_FILTERS channels) and the projected
        # low-level features (48 channels)
        self.decoder = nn.Sequential(
            nn.Conv2d(Config.ASPP_FILTERS + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        )

        # ========== Segmentation Head ==========
        self.classifier = nn.Conv2d(256, num_classes, 1)

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as the input ``x``."""
        input_size = x.shape[2:]

        # ========== Encoder ==========
        x = self.layer0(x)
        low_level_feat = self.layer1(x)  # Save for decoder skip connection
        x = self.layer2(low_level_feat)
        x = self.layer3(x)
        x = self.layer4(x)

        # ========== ASPP ==========
        x = self.aspp(x)

        # Upsample to match low-level feature size
        x = F.interpolate(x, size=low_level_feat.shape[2:],
                          mode='bilinear', align_corners=False)

        # ========== Process Low-level Features ==========
        low_level_feat = self.low_level_conv(low_level_feat)

        # ========== Decoder ==========
        # Concatenate high-level and low-level features
        x = torch.cat([x, low_level_feat], dim=1)
        x = self.decoder(x)

        # ========== Classification ==========
        x = self.classifier(x)

        # Upsample to original input size
        x = F.interpolate(x, size=input_size,
                          mode='bilinear', align_corners=False)

        return x

print("âœ“ DeepLabV3+ model defined")
print("  Backbone: ResNet50 (pretrained)")
print("  ASPP with atrous rates:", Config.ATROUS_RATES)
print("  Decoder: Skip connection with low-level features")

In [None]:
class DiceLoss(nn.Module):
    """
    Dice Loss for segmentation tasks.

    Applies a sigmoid to the logits, flattens prediction and target across
    every dimension (so the overlap is computed over the whole batch at once)
    and returns ``1 - dice``.

    Args:
        smooth (float): Smoothing factor added to numerator and denominator
            to avoid division by zero
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """
        Args:
            pred: Raw logits.
            target: Ground-truth mask with the same number of elements as
                ``pred``.

        Returns:
            Scalar tensor ``1 - dice_coefficient``.
        """
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
        # result of a permute/transpose or a channel slice.
        pred = torch.sigmoid(pred).reshape(-1)
        target = target.reshape(-1)

        intersection = (pred * target).sum()
        dice_coeff = (2. * intersection + self.smooth) / \
                     (pred.sum() + target.sum() + self.smooth)

        return 1 - dice_coeff


class CombinedLoss(nn.Module):
    """
    Weighted sum of Binary Cross-Entropy (on logits) and Dice loss.

    The result is ``alpha * BCE + (1 - alpha) * Dice``; both terms receive
    the same raw logits and the same target.

    Args:
        alpha (float): Weight for BCE loss (1-alpha for Dice loss)
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        bce_weight = self.alpha
        dice_weight = 1 - self.alpha

        bce_term = self.bce(pred, target)
        dice_term = self.dice(pred, target)

        return bce_weight * bce_term + dice_weight * dice_term

print("âœ“ Loss functions defined")
print("  Combined Loss = Î± * BCE + (1-Î±) * Dice")
print(f"  Î± = 0.5 (equal weighting)")

In [None]:
def calculate_iou(pred, target, threshold=0.5):
    """
    Intersection over Union (Jaccard index) of a thresholded prediction.

    The logits are passed through a sigmoid, binarized at ``threshold`` and
    compared with ``target`` over all elements at once.

    Args:
        pred: Model predictions (logits)
        target: Ground truth masks
        threshold: Probability cut-off used for binarization

    Returns:
        IoU score (float)
    """
    eps = 1e-7  # keeps the ratio defined when both masks are empty

    binary = torch.sigmoid(pred).gt(threshold).float()
    overlap = torch.sum(binary * target)
    combined = torch.sum(binary) + torch.sum(target) - overlap

    score = (overlap + eps) / (combined + eps)
    return score.item()


def calculate_dice(pred, target, threshold=0.5):
    """
    Dice coefficient (segmentation F1) of a thresholded prediction.

    The logits are passed through a sigmoid, binarized at ``threshold`` and
    compared with ``target`` over all elements at once.

    Args:
        pred: Model predictions (logits)
        target: Ground truth masks
        threshold: Probability cut-off used for binarization

    Returns:
        Dice score (float)
    """
    eps = 1e-7  # keeps the ratio defined when both masks are empty

    binary = torch.sigmoid(pred).gt(threshold).float()
    overlap = torch.sum(binary * target)
    total = torch.sum(binary) + torch.sum(target)

    score = (2. * overlap + eps) / (total + eps)
    return score.item()

print("âœ“ Evaluation metrics defined")
print("  - IoU (Intersection over Union)")
print("  - Dice Coefficient")

In [None]:
class PolynomialLR:
    """
    Polynomial learning rate decay scheduler.

    Implements: lr = base_lr * (1 - iter/max_iter)^power

    The base learning rate of each parameter group is captured from the
    optimizer at construction time. Once ``max_iterations`` steps have been
    taken the learning rate stays at 0 instead of becoming invalid.

    Args:
        optimizer: PyTorch optimizer
        max_iterations: Total number of training iterations (must be > 0)
        power: Polynomial power (default: 0.9)

    Raises:
        ValueError: If ``max_iterations`` is not positive.
    """

    def __init__(self, optimizer, max_iterations, power=0.9):
        if max_iterations <= 0:
            raise ValueError(
                f"max_iterations must be positive, got {max_iterations}"
            )

        self.optimizer = optimizer
        self.max_iterations = max_iterations
        self.power = power
        # Decay each group from the rate the optimizer was actually built with
        # rather than from a global constant, so per-group rates are preserved.
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        self.base_lr = self.base_lrs[0]
        self.current_iter = 0

    def step(self):
        """Update learning rate and increment iteration counter.

        Returns:
            float: The learning rate now set on the first parameter group.
        """
        # Clamp at zero: a negative base raised to a fractional power yields a
        # complex number in Python, which would corrupt the optimizer state.
        remaining = max(0.0, 1.0 - self.current_iter / self.max_iterations)
        factor = remaining ** self.power

        for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
            param_group['lr'] = base_lr * factor

        self.current_iter += 1
        return self.optimizer.param_groups[0]['lr']

    def get_lr(self):
        """Get current learning rate."""
        return self.optimizer.param_groups[0]['lr']

print("âœ“ Polynomial LR scheduler defined")
print(f"  Base LR: {Config.BASE_LR}")
print(f"  Power: {Config.LR_POWER}")

In [None]:
def train_epoch(model, loader, criterion, optimizer, scheduler, device):
    """
    Train the model for one epoch.

    For every batch: forward pass, loss, backward pass, optimizer step and one
    scheduler step. Loss, IoU and Dice are averaged over the epoch, weighted
    by batch size so that batches of different sizes contribute fairly.

    Args:
        model: PyTorch model
        loader: Training DataLoader
        criterion: Loss function
        optimizer: PyTorch optimizer
        scheduler: Learning rate scheduler (stepped once per batch; its
            ``step()`` must return the current learning rate)
        device: Device to run on (CPU/GPU)

    Returns:
        Tuple of (avg_loss, avg_iou, avg_dice)

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_iou = 0.0
    total_dice = 0.0
    total_samples = 0

    pbar = tqdm(loader, desc='Training', leave=False)

    for images, masks in pbar:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update learning rate
        current_lr = scheduler.step()

        # Metrics are for reporting only; keep them out of autograd.
        with torch.no_grad():
            batch_iou = calculate_iou(outputs, masks)
            batch_dice = calculate_dice(outputs, masks)

        # Accumulate metrics, weighted by the number of samples in the batch
        batch_loss = loss.item()
        batch_size = images.size(0)
        total_loss += batch_loss * batch_size
        total_iou += batch_iou * batch_size
        total_dice += batch_dice * batch_size
        total_samples += batch_size

        # Update progress bar
        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'iou': f'{batch_iou:.4f}',
            'dice': f'{batch_dice:.4f}',
            'lr': f'{current_lr:.6f}'
        })

    # An empty loader (e.g. drop_last=True with fewer samples than one batch)
    # would otherwise end in an unexplained ZeroDivisionError.
    if total_samples == 0:
        raise ValueError(
            "train_epoch received no batches; check the dataset size against "
            "batch_size and drop_last"
        )

    # Calculate epoch averages
    avg_loss = total_loss / total_samples
    avg_iou = total_iou / total_samples
    avg_dice = total_dice / total_samples

    return avg_loss, avg_iou, avg_dice

print("âœ“ Training function defined")

In [None]:
def validate(model, loader, criterion, device):
    """
    Validate the model.

    Runs the model in eval mode without gradients and averages loss, IoU and
    Dice over the loader. The averages are weighted by batch size, so a
    smaller final batch does not count as much as a full one.

    Args:
        model: PyTorch model
        loader: Validation DataLoader
        criterion: Loss function
        device: Device to run on (CPU/GPU)

    Returns:
        Tuple of (avg_loss, avg_iou, avg_dice)

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_iou = 0.0
    total_dice = 0.0
    total_samples = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc='Validation', leave=False)

        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Calculate metrics
            batch_iou = calculate_iou(outputs, masks)
            batch_dice = calculate_dice(outputs, masks)

            # Accumulate metrics, weighted by the number of samples in the batch
            batch_loss = loss.item()
            batch_size = images.size(0)
            total_loss += batch_loss * batch_size
            total_iou += batch_iou * batch_size
            total_dice += batch_dice * batch_size
            total_samples += batch_size

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'iou': f'{batch_iou:.4f}',
                'dice': f'{batch_dice:.4f}'
            })

    # Fail with a clear message instead of a bare ZeroDivisionError.
    if total_samples == 0:
        raise ValueError("validate received no batches; the loader is empty")

    # Calculate averages
    avg_loss = total_loss / total_samples
    avg_iou = total_iou / total_samples
    avg_dice = total_dice / total_samples

    return avg_loss, avg_iou, avg_dice

print("âœ“ Validation function defined")

In [None]:
def plot_training_history(history, save_path):
    """
    Plot training history with loss, IoU, and Dice metrics.

    Draws a 2x2 figure: train/validation curves for loss, IoU and Dice, plus
    a text panel summarizing the epoch with the highest validation IoU. The
    figure is written to ``save_path`` (when given) and then displayed.

    Args:
        history (dict): Dictionary containing training metrics under the keys
            'train_loss', 'val_loss', 'train_iou', 'val_iou', 'train_dice'
            and 'val_dice' (one value per epoch)
        save_path (str): Path to save the plot; a falsy value skips saving
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # (axis, history key suffix, y-label / legend label, panel title)
    panels = [
        (axes[0, 0], 'loss', 'Loss', 'Loss over Epochs'),
        (axes[0, 1], 'iou', 'IoU', 'IoU over Epochs'),
        (axes[1, 0], 'dice', 'Dice', 'Dice Coefficient over Epochs'),
    ]

    for ax, key, label, title in panels:
        ax.plot(history[f'train_{key}'], label=f'Train {label}',
                linewidth=2, marker='o', markersize=3)
        ax.plot(history[f'val_{key}'], label=f'Val {label}',
                linewidth=2, marker='s', markersize=3)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(label, fontsize=12)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    # Summary panel
    summary_ax = axes[1, 1]
    summary_ax.axis('off')

    if len(history['val_iou']) > 0:
        # np.argmax raises on an empty sequence, hence the guard above.
        best_epoch = int(np.argmax(history['val_iou']))
        summary_lines = [
            'Best epoch (by validation IoU)',
            '',
            f'Epoch:      {best_epoch + 1} of {len(history["val_iou"])}',
            f'Val IoU:    {history["val_iou"][best_epoch]:.4f}',
            f'Val Dice:   {history["val_dice"][best_epoch]:.4f}',
            f'Val Loss:   {history["val_loss"][best_epoch]:.4f}',
            f'Train IoU:  {history["train_iou"][best_epoch]:.4f}',
            f'Train Dice: {history["train_dice"][best_epoch]:.4f}',
            f'Train Loss: {history["train_loss"][best_epoch]:.4f}',
        ]
    else:
        summary_lines = ['No epochs recorded']

    summary_ax.text(0.05, 0.95, '\n'.join(summary_lines),
                    transform=summary_ax.transAxes,
                    va='top', ha='left', fontsize=13, family='monospace')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training history saved to {save_path}")

    plt.show()

In [None]:
def visualize_predictions(model, loader, device, num_samples=5, save_path=None):
    """
    Visualize model predictions alongside ground truth.

    Collects up to ``num_samples`` samples from the loader, thresholds the
    sigmoid of the model output at 0.5 and draws one row per sample with the
    de-normalized image, the ground-truth mask and the predicted mask. If the
    loader holds fewer samples than requested, only those are shown.

    Args:
        model: Trained PyTorch model
        loader: DataLoader for test/validation set
        device: Device to run on
        num_samples: Number of samples to visualize (must be positive)
        save_path: Path to save the figure

    Raises:
        ValueError: If ``num_samples`` is not positive or the loader yields
            no samples.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    model.eval()
    images_list, masks_list, preds_list = [], [], []
    collected = 0

    # Collect predictions
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            outputs = model(images)
            preds = torch.sigmoid(outputs) > 0.5

            images_list.append(images.cpu())
            masks_list.append(masks.cpu())
            preds_list.append(preds.cpu())

            # Count actual samples; batches need not all be the same size.
            collected += images.size(0)
            if collected >= num_samples:
                break

    if collected == 0:
        raise ValueError("visualize_predictions received an empty loader")

    # Concatenate and select samples
    images_list = torch.cat(images_list)[:num_samples]
    masks_list = torch.cat(masks_list)[:num_samples]
    preds_list = torch.cat(preds_list)[:num_samples]

    # The loader may hold fewer samples than requested.
    shown = images_list.size(0)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(shown, 3, figsize=(12, 4 * shown), squeeze=False)

    # ImageNet statistics used by the normalization step of the transforms
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for i in range(shown):
        # Denormalize image (CHW -> HWC, undo mean/std, clamp to displayable range)
        img = images_list[i].permute(1, 2, 0).numpy()
        img = img * std + mean
        img = np.clip(img, 0, 1)

        # Plot original image
        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Original Image', fontsize=12, fontweight='bold')
        axes[i, 0].axis('off')

        # Plot ground truth
        axes[i, 1].imshow(masks_list[i].squeeze(), cmap='gray')
        axes[i, 1].set_title('Ground Truth', fontsize=12, fontweight='bold')
        axes[i, 1].axis('off')

        # Plot prediction
        axes[i, 2].imshow(preds_list[i].squeeze(), cmap='gray')
        axes[i, 2].set_title('Prediction', fontsize=12, fontweight='bold')
        axes[i, 2].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"âœ“ Predictions saved to {save_path}")

    plt.show()

print("âœ“ Visualization functions defined")

In [None]:
print("="*70)
print("DATASET PREPARATION")
print("="*70)

# Set random seeds for reproducibility
torch.manual_seed(Config.RANDOM_SEED)
np.random.seed(Config.RANDOM_SEED)

# Load image and mask paths (sorted so that index i of both lists is one pair)
print("\nLoading dataset...")
image_paths = sorted(glob(os.path.join(Config.IMAGES_PATH, '*.jpg')))
mask_paths = sorted(glob(os.path.join(Config.MASKS_PATH, '*.jpg')))

print(f"âœ“ Found {len(image_paths)} images")
print(f"âœ“ Found {len(mask_paths)} masks")

# Verify dataset integrity
assert len(image_paths) == len(mask_paths), \
    f"Mismatch: {len(image_paths)} images vs {len(mask_paths)} masks"
assert len(image_paths) > 0, "No images found! Check your dataset path."

# Equal counts do not prove the lists are aligned: compare file names pairwise
# so that an image is never trained against another image's mask.
unpaired = [
    (os.path.basename(img), os.path.basename(msk))
    for img, msk in zip(image_paths, mask_paths)
    if os.path.splitext(os.path.basename(img))[0] != os.path.splitext(os.path.basename(msk))[0]
]
assert not unpaired, \
    f"{len(unpaired)} image/mask pairs have different file names, e.g. {unpaired[0]}"

# Display sample filenames
print("\nSample files:")
for i in range(min(3, len(image_paths))):
    print(f"  Image {i+1}: {os.path.basename(image_paths[i])}")
    print(f"  Mask  {i+1}: {os.path.basename(mask_paths[i])}")

print("\n" + "="*70)
print("DATASET SPLITTING")
print("="*70)

# Work with integer sample counts. Passing the float (1 - TRAIN_RATIO) is
# fragile: 1 - 0.7 evaluates to 0.30000000000000004, and train_test_split
# rounds a fractional test_size *up*, so 1000 samples would yield a 301-sample
# hold-out (699/150/151) instead of the intended 700/150/150.
n_total = len(image_paths)
n_val = int(round(n_total * Config.VAL_RATIO))
n_test = int(round(n_total * Config.TEST_RATIO))
assert n_val > 0 and n_test > 0 and n_val + n_test < n_total, \
    f"Dataset of {n_total} samples is too small for the configured split ratios"

# First split: training set vs. everything held out (validation + test)
train_imgs, temp_imgs, train_masks, temp_masks = train_test_split(
    image_paths, mask_paths,
    test_size=n_val + n_test,
    random_state=Config.RANDOM_SEED
)

# Second split: divide the held-out samples into validation and test sets
val_imgs, test_imgs, val_masks, test_masks = train_test_split(
    temp_imgs, temp_masks,
    test_size=n_test,
    random_state=Config.RANDOM_SEED
)

print(f"\nDataset split:")
print(f"  Training:   {len(train_imgs):4d} samples ({len(train_imgs)/len(image_paths)*100:.1f}%)")
print(f"  Validation: {len(val_imgs):4d} samples ({len(val_imgs)/len(image_paths)*100:.1f}%)")
print(f"  Test:       {len(test_imgs):4d} samples ({len(test_imgs)/len(image_paths)*100:.1f}%)")
print(f"  Total:      {len(image_paths):4d} samples")

In [None]:
# Create datasets: augmentation for training only; validation and test share
# the deterministic resize + normalize pipeline.
train_dataset = PolypDataset(train_imgs, train_masks, get_train_transforms())
val_dataset = PolypDataset(val_imgs, val_masks, get_val_transforms())
test_dataset = PolypDataset(test_imgs, test_masks, get_val_transforms())

print(f"\nâœ“ Datasets created")
print(f"  Training dataset:   {len(train_dataset)} samples")
print(f"  Validation dataset: {len(val_dataset)} samples")
# The closing parenthesis was missing here, which made the whole cell a SyntaxError.
print(f"  Test dataset:       {len(test_dataset)} samples")

In [None]:
# Build the three DataLoaders. Batch size, worker count and memory pinning are
# identical for all of them; only shuffling / drop_last differ.
_shared_loader_args = {
    'batch_size': Config.BATCH_SIZE,
    'num_workers': Config.NUM_WORKERS,
    'pin_memory': Config.PIN_MEMORY,
}

# Training: reshuffled every epoch, incomplete final batch discarded
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True,
                          **_shared_loader_args)

# Validation / test: fixed order, every sample kept
val_loader = DataLoader(val_dataset, shuffle=False, **_shared_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **_shared_loader_args)

print(f"\nâœ“ DataLoaders created")
print(f"  Training batches:   {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Test batches:       {len(test_loader)}")


In [None]:
# ---------------------------------------------------------------------------
# Model construction, optimizer/scheduler setup and the main training loop.
# Relies on names created by earlier cells: Config, DeepLabV3Plus,
# CombinedLoss, PolynomialLR, train_epoch, validate, train_loader, val_loader.
# ---------------------------------------------------------------------------

# Create model
print("\nBuilding DeepLabV3+ model...")
model = DeepLabV3Plus(num_classes=1).to(Config.DEVICE)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"âœ“ Model created successfully")
print(f"  Device: {Config.DEVICE}")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

# Loss function
criterion = CombinedLoss(alpha=0.5)
print(f"\nâœ“ Loss function: Combined BCE + Dice (Î±=0.5)")

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=Config.BASE_LR)
print(f"âœ“ Optimizer: Adam (lr={Config.BASE_LR})")

# Learning rate scheduler
# train_epoch steps the scheduler once per batch, so the decay horizon is
# (batches per epoch) x (number of epochs).
max_iterations = len(train_loader) * Config.EPOCHS
scheduler = PolynomialLR(optimizer, max_iterations, Config.LR_POWER)
print(f"âœ“ Scheduler: Polynomial LR (power={Config.LR_POWER})")
print(f"  Total iterations: {max_iterations:,}")

# Training history: one entry per epoch is appended to every list
history = {
    'train_loss': [], 'train_iou': [], 'train_dice': [],
    'val_loss': [], 'val_iou': [], 'val_dice': []
}

# Best validation IoU seen so far and the (1-based) epoch it occurred in
best_val_iou = 0.0
best_epoch = 0

print(f"\nStarting training for {Config.EPOCHS} epochs...")
print("="*70)

for epoch in range(Config.EPOCHS):
    print(f"\n{'='*70}")
    print(f"Epoch {epoch+1}/{Config.EPOCHS}")
    print(f"{'='*70}")
    
    # Train
    train_loss, train_iou, train_dice = train_epoch(
        model, train_loader, criterion, optimizer, scheduler, Config.DEVICE
    )
    
    # Validate
    val_loss, val_iou, val_dice = validate(
        model, val_loader, criterion, Config.DEVICE
    )
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_iou'].append(train_iou)
    history['train_dice'].append(train_dice)
    history['val_loss'].append(val_loss)
    history['val_iou'].append(val_iou)
    history['val_dice'].append(val_dice)
    
    # Print epoch summary
    print(f"\nðŸ“Š Epoch {epoch+1} Summary:")
    print(f"   Train â†’ Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Dice: {train_dice:.4f}")
    print(f"   Val   â†’ Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Dice: {val_dice:.4f}")
    
    # Save best model
    # Strict '>' means a tie keeps the earlier checkpoint; the file is
    # overwritten in place, so only the best-so-far weights are on disk.
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        best_epoch = epoch + 1
        
        # Save model checkpoint
        # Besides the weights, the checkpoint carries the optimizer state and
        # the metric history accumulated up to this epoch.
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_iou': val_iou,
            'val_dice': val_dice,
            'history': history
        }
        torch.save(checkpoint, os.path.join(Config.OUTPUT_PATH, 'best_model.pth'))
        print(f"   âœ“ NEW BEST MODEL SAVED! (IoU: {best_val_iou:.4f})")
# Printed once, after the last epoch has finished
print(f"Best validation IoU: {best_val_iou:.4f} (Epoch {best_epoch})")

In [None]:
plot_training_history(
    history, 
    os.path.join(Config.OUTPUT_PATH, 'training_history.png')
)