# EMCAD Segmentation Model Training

This notebook demonstrates how to train an EMCAD (Efficient Multi-scale Context Aggregation Decoder) model for image segmentation with a custom dataset structured as:

- train/images - Contains training images
- train/masks - Contains corresponding segmentation masks
- test - Contains test images for submission (no masks)

## 1. Import Required Libraries

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from glob import glob

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid

from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

from lib.networks import EMCADNet

## 2. Set Configuration

In [None]:
# Seed every random number generator used in this notebook
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        seed: Integer seed handed to every generator (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # All generators receive the same seed value.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN kernels and switch off the auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything()

# Experiment settings read by the cells below
CONFIG = {
    # Dataset layout; the three sub-paths are joined onto data_root
    'data_root': './',  # Update this to your data directory
    'train_image_path': 'train/images',
    'train_mask_path': 'train/masks',
    'test_image_path': 'test',

    # Input pipeline
    'img_size': 512,     # images are resized to img_size x img_size
    'batch_size': 8,
    'num_workers': 4,

    # Optimisation
    'lr': 1e-4,
    'epochs': 50,
    'device': torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),

    # Model
    'num_classes': 1,    # Binary segmentation (foreground vs background)
    'encoder': 'resnet34',  # Use resnet34 as backbone
    'pretrain': True,

    # 'deep_supervision' supervises each model output on its own,
    # 'mutation' supervises every non-empty combination of outputs;
    # any other value makes the training loop use only the last output.
    'supervision': 'deep_supervision',
}

print(f"Using device: {CONFIG['device']}")

## 3. Create Dataset and DataLoader

In [None]:
class SegmentationDataset(Dataset):
    """Dataset of RGB images and, outside test mode, their grayscale masks.

    Each item is a dict with the keys ``'image'``, ``'filename'`` and, when
    ``is_test`` is false, ``'mask'``.

    Args:
        img_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths aligned index-by-index with
            ``img_paths``. Required unless ``is_test`` is true.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            (or ``transform(image=...)`` in test mode) that returns a dict
            with the same keys.
        is_test: When true, no masks are loaded or returned.

    Raises:
        ValueError: If masks are required but missing, or their count differs
            from the number of images.
    """

    def __init__(self, img_paths, mask_paths=None, transform=None, is_test=False):
        if not is_test:
            if mask_paths is None:
                raise ValueError("mask_paths is required when is_test is False")
            if len(mask_paths) != len(img_paths):
                raise ValueError(
                    f"Got {len(img_paths)} images but {len(mask_paths)} masks"
                )
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.is_test = is_test

    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _read(path, flags):
        """Read ``path`` with OpenCV, raising instead of returning ``None``."""
        data = cv2.imread(path, flags)
        if data is None:
            # cv2.imread signals failure with None rather than an exception.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    @staticmethod
    def _prepare_mask(mask):
        """Convert a mask to a float tensor with a leading channel dimension.

        Accepts either a NumPy array or a tensor (the latter is what the
        transform returns when it ends with a to-tensor step). Masks whose
        maximum exceeds 1 are treated as 8-bit images and divided by 255 so
        the values are valid targets for the binary cross-entropy loss.
        """
        if isinstance(mask, np.ndarray):
            # Flips/rotations can hand back views that from_numpy rejects.
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        mask = mask.float()
        if mask.numel() > 0 and mask.max() > 1:
            mask = mask / 255.0
        return mask.unsqueeze(0)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        filename = os.path.basename(img_path)

        # OpenCV loads BGR; the rest of the pipeline works on RGB.
        img = cv2.cvtColor(self._read(img_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

        if self.is_test:
            if self.transform:
                img = self.transform(image=img)['image']
            return {'image': img, 'filename': filename}

        mask = self._read(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)

        # Image and mask go through the transform together so geometric
        # augmentations stay aligned.
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        return {
            'image': img,
            'mask': self._prepare_mask(mask),
            'filename': filename,
        }

# Augmentation / preprocessing pipelines
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
_side = CONFIG['img_size']

# Training: resize, random geometric and photometric augmentation,
# then normalise and convert to a tensor.
train_transform = A.Compose([
    A.Resize(_side, _side),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])

# Validation: deterministic resize + normalise only.
val_transform = A.Compose([
    A.Resize(_side, _side),
    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])

# Test: same steps as validation, built as its own pipeline object.
test_transform = A.Compose([
    A.Resize(_side, _side),
    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])

# Set up datasets and dataloaders
def _list_files(directory):
    """Return the sorted paths of the regular files directly inside ``directory``."""
    pattern = os.path.join(directory, '*')
    # glob('*') also matches sub-directories, which cannot be read as images.
    return sorted(path for path in glob(pattern) if os.path.isfile(path))


def get_data_loaders(val_split=0.1):
    """Build the train, validation and test dataloaders described by CONFIG.

    Training images and masks are paired by their position in the sorted
    directory listings, then randomly split into train and validation parts
    using NumPy's global RNG.

    Args:
        val_split: Fraction of the training pairs held out for validation.
            At least one pair is always held out so the validation loader is
            never empty.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``val_split`` is not strictly between 0 and 1, if the
            image and mask counts differ, or if fewer than two training
            pairs are found.
    """
    if not 0 < val_split < 1:
        raise ValueError(f"val_split must be between 0 and 1, got {val_split}")

    root = CONFIG['data_root']
    train_img_paths = _list_files(os.path.join(root, CONFIG['train_image_path']))
    train_mask_paths = _list_files(os.path.join(root, CONFIG['train_mask_path']))
    test_img_paths = _list_files(os.path.join(root, CONFIG['test_image_path']))

    # Explicit raise instead of assert: asserts vanish under `python -O`.
    if len(train_img_paths) != len(train_mask_paths):
        raise ValueError("Number of training images and masks don't match!")
    print(f"Found {len(train_img_paths)} training images and {len(test_img_paths)} test images")

    n_pairs = len(train_img_paths)
    if n_pairs < 2:
        raise ValueError(
            "Need at least 2 training image/mask pairs to build a train/validation split"
        )

    # Split train/val; small datasets still get one validation sample.
    indices = np.arange(n_pairs)
    np.random.shuffle(indices)
    val_size = max(1, int(n_pairs * val_split))
    val_indices = indices[:val_size]
    train_indices = indices[val_size:]

    train_dataset = SegmentationDataset(
        [train_img_paths[i] for i in train_indices],
        [train_mask_paths[i] for i in train_indices],
        transform=train_transform,
    )
    val_dataset = SegmentationDataset(
        [train_img_paths[i] for i in val_indices],
        [train_mask_paths[i] for i in val_indices],
        transform=val_transform,
    )
    test_dataset = SegmentationDataset(
        test_img_paths,
        transform=test_transform,
        is_test=True,
    )

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = CONFIG['device'].type == 'cuda'

    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=CONFIG['num_workers'],
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=CONFIG['num_workers'],
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=CONFIG['num_workers'],
    )

    return train_loader, val_loader, test_loader


train_loader, val_loader, test_loader = get_data_loaders()

# Preview the first samples of a dataloader
def visualize_batch(dataloader):
    """Plot up to four samples from the first batch of ``dataloader``.

    Batches that contain a ``'mask'`` entry are drawn as image/mask pairs on
    a 4x2 grid; batches without masks are drawn on a 2x2 grid of images.
    """
    std = np.array([0.229, 0.224, 0.225])
    mean = np.array([0.485, 0.456, 0.406])

    def to_displayable(chw_tensor):
        # Undo the normalisation and move channels last for imshow.
        hwc = chw_tensor.permute(1, 2, 0).numpy()
        return np.clip(hwc * std + mean, 0, 1)

    batch = next(iter(dataloader))
    images = batch['image']
    names = batch['filename']
    n_shown = min(4, len(images))

    if 'mask' not in batch:
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        axes = axes.flatten()
        for k in range(n_shown):
            axes[k].imshow(to_displayable(images[k]))
            axes[k].set_title(f"Test Image: {names[k]}")
        plt.tight_layout()
        plt.show()
        return

    masks = batch['mask']
    fig, axes = plt.subplots(4, 2, figsize=(12, 15))
    for k in range(n_shown):
        image_ax, mask_ax = axes[k, 0], axes[k, 1]

        image_ax.imshow(to_displayable(images[k]))
        image_ax.set_title(f"Image: {names[k]}")

        # First (only) channel of the mask
        mask_ax.imshow(masks[k][0].numpy(), cmap='gray')
        mask_ax.set_title("Mask")
    plt.tight_layout()
    plt.show()


print("Visualizing training samples:")
visualize_batch(train_loader)

## 4. Define Loss Functions and Metrics

In [None]:
# Soft Dice loss computed from raw logits
class DiceLoss(nn.Module):
    """Dice loss over all elements of the batch, taken together.

    Args:
        smooth: Constant added to numerator and denominator of the Dice
            ratio (default 1.0).
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - dice`` for logits ``pred`` against ``target``."""
        probs = torch.sigmoid(pred).reshape(-1)
        truth = target.reshape(-1)

        overlap = torch.sum(probs * truth)
        numerator = 2. * overlap + self.smooth
        denominator = probs.sum() + truth.sum() + self.smooth

        return 1 - numerator / denominator

# Define IoU (Jaccard) metric
def iou_score(pred, target, threshold=0.5, eps=1e-7):
    """Return the IoU between thresholded predictions and ``target``.

    The score is computed over all elements of both tensors at once.

    Args:
        pred: Tensor of raw logits; a sigmoid is applied before thresholding.
        target: Tensor with the same number of elements as ``pred``.
        threshold: Probability above which an element counts as foreground.
        eps: Small constant added to numerator and denominator so that two
            empty masks score 1 instead of dividing by zero.

    Returns:
        Scalar tensor holding the IoU.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose or permute.
    pred = (torch.sigmoid(pred) > threshold).reshape(-1).float()
    target = target.reshape(-1).float()

    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection

    return (intersection + eps) / (union + eps)

# Weighted sum of BCE-with-logits and Dice losses
class BCEDiceLoss(nn.Module):
    """Combine binary cross-entropy (on logits) with Dice loss.

    Args:
        weight_bce: Multiplier applied to the BCE term (default 0.5).
        weight_dice: Multiplier applied to the Dice term (default 0.5).
    """

    def __init__(self, weight_bce=0.5, weight_dice=0.5):
        super().__init__()
        self.weight_bce = weight_bce
        self.weight_dice = weight_dice
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        """Return ``weight_bce * BCE + weight_dice * Dice`` for logits ``pred``."""
        bce_term = self.weight_bce * self.bce(pred, target)
        dice_term = self.weight_dice * self.dice(pred, target)
        return bce_term + dice_term

# Enumerate every non-empty combination of the model's output indices
def powerset(iterable):
    """Return all non-empty subsets of ``iterable`` as a list of tuples.

    Subsets are ordered by size, smallest first; within one size they follow
    the order produced by ``itertools.combinations``.
    """
    from itertools import combinations

    items = list(iterable)
    return [
        subset
        for size in range(1, len(items) + 1)
        for subset in combinations(items, size)
    ]

## 5. Initialize Model

In [None]:
# Initialize the EMCAD model with the backbone chosen in CONFIG and move it
# to the configured device
model = EMCADNet(
    num_classes=CONFIG['num_classes'],
    encoder=CONFIG['encoder'],
    pretrain=CONFIG['pretrain'],
    kernel_sizes=[1, 3, 5],
    expansion_factor=2,
    dw_parallel=True,
    add=True
).to(CONFIG['device'])

# Define optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=1e-4)
# Halve the learning rate once the monitored value has stopped decreasing
# for more than 5 epochs. The deprecated `verbose` argument is not passed;
# the training loop prints the current learning rate after every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Define loss function: equal-weight sum of BCE and Dice
criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)

## 6. Training Function

In [None]:
def _supervision_index_sets(n_outs, strategy):
    """Return the groups of output indices that are summed and supervised.

    ``'mutation'`` yields every non-empty combination of the ``n_outs``
    outputs, ``'deep_supervision'`` yields each output on its own, and any
    other value yields only the last output.
    """
    out_idxs = list(range(n_outs))
    if strategy == 'mutation':
        return powerset(out_idxs)
    if strategy == 'deep_supervision':
        return [[idx] for idx in out_idxs]
    return [[-1]]


def _as_output_list(outputs):
    """Wrap a single model output in a list; pass lists/tuples through."""
    if isinstance(outputs, (list, tuple)):
        return outputs
    return [outputs]


def train_model():
    """Train the global ``model`` for ``CONFIG['epochs']`` epochs.

    Uses the module-level ``model``, ``optimizer``, ``scheduler``,
    ``criterion``, ``train_loader`` and ``val_loader``. Every epoch runs one
    training pass and one validation pass, steps the scheduler on the mean
    validation loss, writes ``best_model_emcad.pth`` whenever the mean
    validation IoU improves, and writes ``checkpoint_emcad_epoch{N}.pth``
    every 10 epochs. The loss and IoU curves are plotted at the end.

    Returns:
        Dict with the per-epoch lists ``'train_loss'``, ``'val_loss'`` and
        ``'val_iou'``.

    Raises:
        ValueError: If either loader yields no batches (the per-epoch means
            would otherwise divide by zero).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    device = CONFIG['device']
    best_iou = 0.0
    history = {'train_loss': [], 'val_loss': [], 'val_iou': []}

    # Built from the first training batch, once the number of model outputs
    # is known, and reused for every later batch.
    ss = None

    for epoch in range(CONFIG['epochs']):
        print(f"Epoch {epoch+1}/{CONFIG['epochs']}")

        # Training phase
        model.train()
        train_loss = 0.0

        train_pbar = tqdm(train_loader, total=len(train_loader), desc="Training")
        for batch in train_pbar:
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)

            # Forward pass
            outputs = _as_output_list(model(images, mode='train'))

            if ss is None:
                ss = _supervision_index_sets(len(outputs), CONFIG['supervision'])
                print(f"Using supervision strategy: {CONFIG['supervision']}")
                print(f"Output indices: {ss}")

            # One loss term per index group; the outputs in a group are
            # summed before the loss is applied.
            loss = 0.0
            for group in ss:
                combined = sum(outputs[idx] for idx in group)
                loss = loss + criterion(combined, masks)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_train_loss = train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_iou = 0.0

        with torch.no_grad():
            val_pbar = tqdm(val_loader, total=len(val_loader), desc="Validation")
            for batch in val_pbar:
                images = batch['image'].to(device)
                masks = batch['mask'].to(device)

                # Forward pass
                outputs = _as_output_list(model(images, mode='test'))

                # NOTE(review): index 0 is treated as the final prediction
                # here, while the fallback supervision strategy uses index
                # -1 — confirm which output EMCADNet intends as final.
                final_output = outputs[0]

                loss = criterion(final_output, masks)
                val_loss += loss.item()

                batch_iou = iou_score(final_output, masks)
                val_iou += batch_iou.item()

                val_pbar.set_postfix({'loss': f"{loss.item():.4f}", 'iou': f"{batch_iou.item():.4f}"})

        # Means over batches (not over individual samples)
        avg_val_loss = val_loss / len(val_loader)
        avg_val_iou = val_iou / len(val_loader)
        history['val_loss'].append(avg_val_loss)
        history['val_iou'].append(avg_val_iou)

        # Update learning rate
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val IoU: {avg_val_iou:.4f}, LR: {current_lr:.6f}")

        # Save best model
        if avg_val_iou > best_iou:
            best_iou = avg_val_iou
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_iou': best_iou,
            }, 'best_model_emcad.pth')
            print(f"Best model saved with IoU: {best_iou:.4f}")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_iou': avg_val_iou,
            }, f'checkpoint_emcad_epoch{epoch+1}.pth')

    # Plot training history
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.legend()
    plt.title('Loss History')

    plt.subplot(1, 2, 2)
    plt.plot(history['val_iou'], label='Val IoU')
    plt.legend()
    plt.title('IoU History')

    plt.tight_layout()
    plt.show()

    return history


# Train the model
history = train_model()

## 7. Inference and Kaggle Submission

In [None]:
# Load best model for inference
def load_best_model():
    """Restore the weights saved in ``best_model_emcad.pth`` into ``model``.

    The checkpoint is mapped onto ``CONFIG['device']`` while loading, so a
    file written on a GPU machine can also be opened on a CPU-only one.

    Returns:
        The module-level ``model`` with the checkpoint weights loaded.
    """
    checkpoint = torch.load('best_model_emcad.pth', map_location=CONFIG['device'])
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model with validation IoU: {checkpoint['val_iou']:.4f}")
    return model


model = load_best_model()
model.eval()

# Run-length encode a binary mask for the Kaggle submission file
def rle_encode(mask):
    """Return the run-length encoding of ``mask`` as a space-separated string.

    The mask is flattened in row-major order. The output alternates 1-based
    start positions and run lengths of the non-zero runs; an all-zero mask
    gives an empty string.
    """
    flat = mask.flatten()
    # Zero padding on both sides makes runs touching either end show up as
    # value changes.
    padded = np.concatenate([[0], flat, [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Odd entries are run ends; subtracting the starts leaves run lengths.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

# Run the model over the test loader and write the submission file
def predict_test_set():
    """Predict a mask for every test image and save ``submission.csv``.

    Uses the module-level ``model`` and ``test_loader``. Each prediction is
    the first model output passed through a sigmoid, thresholded at 0.5 and
    run-length encoded.

    Returns:
        DataFrame with the columns ``'id'`` (file name) and ``'rle_mask'``.
    """
    rows = []

    with torch.no_grad():
        progress = tqdm(test_loader, total=len(test_loader), desc="Generating predictions")
        for batch in progress:
            inputs = batch['image'].to(CONFIG['device'])
            logits = model(inputs, mode='test')[0]
            binary = torch.sigmoid(logits) > 0.5

            # NOTE(review): masks are encoded at the network's output
            # resolution; nothing here resizes them back to the original
            # image size — confirm that is what the competition expects.
            for i, name in enumerate(batch['filename']):
                mask_np = binary[i].squeeze().cpu().numpy().astype(np.uint8)
                rows.append([name, rle_encode(mask_np)])

    # Create submission CSV
    import pandas as pd
    submission_df = pd.DataFrame(rows, columns=['id', 'rle_mask'])
    submission_df.to_csv('submission.csv', index=False)
    print("Submission file created: submission.csv")

    return submission_df


# Generate predictions and create submission
submission_df = predict_test_set()

# Show the first test batch next to its predicted masks
def visualize_test_predictions():
    """Plot each image of the first test batch beside its thresholded prediction."""
    std = np.array([0.229, 0.224, 0.225])
    mean = np.array([0.485, 0.456, 0.406])

    model.eval()
    with torch.no_grad():
        batch = next(iter(test_loader))
        names = batch['filename']
        images = batch['image'].to(CONFIG['device'])

        logits = model(images, mode='test')[0]
        preds = torch.sigmoid(logits) > 0.5

        n_rows = len(images)
        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(n_rows, 2, figsize=(12, 4*n_rows), squeeze=False)

        for row in range(n_rows):
            image_ax, pred_ax = axes[row, 0], axes[row, 1]

            # Undo the input normalisation for display
            rgb = images[row].cpu().permute(1, 2, 0).numpy()
            rgb = np.clip(rgb * std + mean, 0, 1)
            image_ax.imshow(rgb)
            image_ax.set_title(f"Image: {names[row]}")

            pred_ax.imshow(preds[row].squeeze().cpu().numpy(), cmap='gray')
            pred_ax.set_title("Prediction")

        plt.tight_layout()
        plt.show()


visualize_test_predictions()

## 8. Conclusion

We've successfully:
1. Loaded and preprocessed the custom segmentation dataset
2. Trained the EMCAD model with a configurable supervision strategy (deep supervision by default)
3. Evaluated the model using IoU metric
4. Generated predictions for the test set and created a submission file for Kaggle

The best model has been saved as `best_model_emcad.pth` and can be used for future inference.