# EMCAD Training for Binary Segmentation

This notebook demonstrates how to use the EMCAD architecture for binary segmentation (masks with pixel values 0 and 255).

## Setup and Imports

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import time
import cv2
from glob import glob

# Import necessary modules from EMCAD
from lib.networks import EMCADNet
from utils.utils import DiceLoss, AvgMeter

# Seed every random number generator the notebook relies on so runs are repeatable
seed = 2222
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Ask cuDNN for deterministic behaviour and switch its autotuner off
torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = False, True

## Dataset Class for Binary Segmentation

In [None]:
class BinarySegmentationDataset(Dataset):
    """Image dataset, optionally paired with masks, for binary segmentation.

    Every entry of ``image_root`` (and ``mask_root``) is collected with a
    ``*`` glob and sorted. Images and masks are paired purely by that sorted
    order; only the two counts are compared, so matching file names are the
    caller's responsibility.

    Args:
        image_root: Directory holding the input images.
        mask_root: Directory holding the masks, or ``None`` for an unlabeled
            (test) set.
        trainsize: Side length, in pixels, that every sample is resized to.
        is_train: Whether this dataset is the training split.
        augmentations: Enable random rotation and flips. Only takes effect
            when ``is_train`` is also true.

    Items are ``(image, mask, filename)`` when masks are available and
    ``(image, filename)`` otherwise. ``image`` is a normalized float tensor;
    ``mask`` is a float tensor containing only 0.0 and 1.0.
    """

    def __init__(self, image_root, mask_root=None, trainsize=224, is_train=True, augmentations=True):
        self.trainsize = trainsize
        self.is_train = is_train
        self.augmentations = augmentations

        # Get image paths
        self.images = sorted(glob(os.path.join(image_root, '*')))

        # Get mask paths if this dataset is labeled
        if mask_root is not None:
            self.masks = sorted(glob(os.path.join(mask_root, '*')))
            assert len(self.images) == len(self.masks), "Number of images and masks don't match"
        else:
            self.masks = None

        print(f"Found {len(self.images)} images")
        if self.masks:
            print(f"Found {len(self.masks)} masks")

        # Augmentation is applied only to the training split
        augment = bool(self.augmentations and is_train)
        print('Using data augmentations' if augment else 'No data augmentation')

        # Image and mask share the geometric steps; only the image is normalized
        self.img_transform = self._build_transform(augment, normalize=True)
        self.mask_transform = self._build_transform(augment, normalize=False)

    def _build_transform(self, augment, normalize):
        """Compose resize + to-tensor, optionally preceded by augmentation and followed by normalization."""
        steps = []
        if augment:
            steps += [
                transforms.RandomRotation(30),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ]
        steps += [
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
        ]
        if normalize:
            steps.append(transforms.Normalize([0.485, 0.456, 0.406],
                                              [0.229, 0.224, 0.225]))
        return transforms.Compose(steps)

    def __len__(self):
        """Number of images found in ``image_root``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load sample ``index`` and apply identical random transforms to image and mask."""
        img_path = self.images[index]
        img = Image.open(img_path).convert('RGB')
        filename = os.path.basename(img_path)

        # Test set without masks: nothing to synchronize
        if self.masks is None:
            return self.img_transform(img), filename

        mask = Image.open(self.masks[index]).convert('L')  # Convert to grayscale

        # Draw the shared seed from torch's RNG rather than NumPy's.
        # DataLoader gives each worker its own torch seed, but NumPy's global
        # state can be duplicated across workers, which would make them
        # produce the same "random" augmentations.
        seed = int(torch.randint(0, 2147483647, (1,)).item())

        # Re-seeding before each pipeline makes both sample the same random
        # rotation and flips. Python's `random` is seeded too because older
        # torchvision releases drew their parameters from it.
        random.seed(seed)
        torch.manual_seed(seed)
        img = self.img_transform(img)

        random.seed(seed)
        torch.manual_seed(seed)
        mask = self.mask_transform(mask)

        # Convert mask to binary - threshold at 0.5 (128/255)
        mask = (mask > 0.5).float()

        return img, mask, filename

## Configuration

In [None]:
# Define parameters
img_size = 224  # EMCAD default is 224
batch_size = 8
num_classes = 2  # Binary segmentation: background and foreground
data_dir = '/path/to/your/data'  # Replace with your data path
train_img_dir = os.path.join(data_dir, 'train/images')
train_mask_dir = os.path.join(data_dir, 'train/masks')
test_img_dir = os.path.join(data_dir, 'test')
result_dir = './results'
model_dir = './model_checkpoints'

os.makedirs(result_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Training parameters
num_epochs = 100
learning_rate = 1e-4

# Model parameters
encoder = 'pvt_v2_b2'  # Default EMCAD encoder
kernel_sizes = [1, 3, 5]  # Default EMCAD kernel sizes
expansion_factor = 2
lgag_ks = 3
activation_mscb = 'relu6'
supervision = 'mutation'
pretrained_dir = './pretrained_pth/pvt/'

# Split data into train and validation.
# Images and masks are paired by sorted order, so both directories must sort identically.
all_images = sorted(glob(os.path.join(train_img_dir, '*')))
all_masks = sorted(glob(os.path.join(train_mask_dir, '*')))

# Use 85% data for training, 15% for validation (first 85% in sorted order)
train_size = int(0.85 * len(all_images))
val_size = len(all_images) - train_size

train_img_list = all_images[:train_size]
val_img_list = all_images[train_size:]
train_mask_list = all_masks[:train_size]
val_mask_list = all_masks[train_size:]

print(f"Training samples: {len(train_img_list)}")
print(f"Validation samples: {len(val_img_list)}")

import shutil


def copy_split_files(paths, dest_dir):
    """Copy every file in ``paths`` into ``dest_dir`` and return how many were copied.

    ``dest_dir`` is created when missing. A file whose name already exists in
    ``dest_dir`` is left untouched, so re-running the cell is cheap.
    """
    os.makedirs(dest_dir, exist_ok=True)
    copied = 0
    for src_path in paths:
        dst_path = os.path.join(dest_dir, os.path.basename(src_path))
        if not os.path.exists(dst_path):
            shutil.copy(src_path, dst_path)
            copied += 1
    return copied


# Materialise the split on disk: the dataset cells read from ./tmp, so these
# directories must actually contain the files.
# NOTE: files left in ./tmp by an earlier run with a different split are not
# removed here; delete ./tmp first if the dataset or the split ratio changed.
copy_split_files(train_img_list, './tmp/train/images')
copy_split_files(train_mask_list, './tmp/train/masks')
copy_split_files(val_img_list, './tmp/val/images')
copy_split_files(val_mask_list, './tmp/val/masks')

## Create Data Loaders

In [None]:
# Datasets: augmented training split, plain validation split, unlabeled test set
train_dataset = BinarySegmentationDataset(
    image_root='./tmp/train/images', mask_root='./tmp/train/masks',
    trainsize=img_size, is_train=True, augmentations=True,
)
val_dataset = BinarySegmentationDataset(
    image_root='./tmp/val/images', mask_root='./tmp/val/masks',
    trainsize=img_size, is_train=False, augmentations=False,
)
test_dataset = BinarySegmentationDataset(
    image_root=test_img_dir, mask_root=None,
    trainsize=img_size, is_train=False, augmentations=False,
)

# Loaders: same settings for every split; only the training split is reshuffled
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle, num_workers=4, pin_memory=True)
    for split, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

## Visualization Functions

In [None]:
def visualize_batch(images, masks, predictions=None, num_samples=4):
    """Plot a grid with one row per sample: image, ground-truth mask, optional prediction.

    At most ``num_samples`` rows are drawn (fewer when the batch is smaller).
    Images are de-normalized with the mean/std constants below and clipped to
    [0, 1]; masks and predictions are shown in grayscale. The third column is
    left empty when ``predictions`` is None.
    """
    rows = min(num_samples, len(images))
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    fig = plt.figure(figsize=(15, 5 * rows))

    def draw(panel, picture, title, **imshow_kwargs):
        # One borderless panel of the rows x 3 grid (panel index is 1-based)
        ax = fig.add_subplot(rows, 3, panel)
        ax.imshow(picture, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    for row in range(rows):
        first_panel = row * 3 + 1

        # CHW tensor -> HWC array, then undo the normalization for display
        rgb = images[row].cpu().permute(1, 2, 0).numpy()
        rgb = np.clip(rgb * channel_std + channel_mean, 0, 1)
        truth = masks[row].cpu().squeeze().numpy()

        draw(first_panel, rgb, f"Image {row+1}")
        draw(first_panel + 1, truth, f"Ground Truth {row+1}", cmap='gray')

        if predictions is not None:
            predicted = predictions[row].cpu().squeeze().numpy()
            draw(first_panel + 2, predicted, f"Prediction {row+1}", cmap='gray')

    plt.tight_layout()
    plt.show()

# Smoke-test the plotting helper on the first validation batch (images and masks only)
val_batch = next(iter(val_loader))
visualize_batch(*val_batch[:2])

## Initialize the EMCAD Model

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Build the EMCAD network from the configuration values and place it on the device
model = EMCADNet(
    num_classes=num_classes,
    kernel_sizes=kernel_sizes,
    expansion_factor=expansion_factor,
    dw_parallel=True,  # Default setting
    add=True,  # Default setting: use addition instead of concatenation
    lgag_ks=lgag_ks,
    activation=activation_mscb,
    encoder=encoder,
    pretrain=True,  # Use pretrained weights
    pretrained_dir=pretrained_dir,
).to(device)

# Losses combined by the training and validation functions
dice_loss = DiceLoss(num_classes)
ce_loss = nn.CrossEntropyLoss()

# AdamW with a small weight decay
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

# Cut the learning rate by 10x after 10 epochs without improvement;
# mode='max' because the training loop steps it with the validation Dice score
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.1,
    patience=10,
    verbose=True,
)

print("Model initialized successfully.")

## Training Loop

In [None]:
def train_epoch(model, train_loader, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    The model may return either a single logits tensor or a list of logits
    tensors. In the list case the weighted cross-entropy + Dice loss is
    averaged over all of them, and the last entry is used for the reported
    Dice score.

    Args:
        model: Network called as ``model(images, mode='train')``.
        train_loader: Iterable yielding ``(images, masks, names)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean loss per batch, dice_meter.avg)``.

    Relies on the notebook-level ``ce_loss`` and ``dice_loss`` objects.
    """
    model.train()
    total_loss = 0.0
    dice_meter = AvgMeter()

    # Loss weights, defined up front so the single-output path can use them
    # too (they used to exist only inside the multi-output branch).
    w_ce, w_dice = 0.3, 0.7

    progress_bar = tqdm(train_loader, desc="Training")

    for images, masks, _ in progress_bar:
        images = images.to(device)
        masks = masks.to(device)

        # Cross entropy needs class indices; assumes masks are (B, 1, H, W) with values 0/1
        targets = masks.squeeze(1).long()

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images, mode='train')

        # Treat a single output as a one-element list so both cases share one code path
        heads = outputs if isinstance(outputs, list) else [outputs]

        loss = 0.0
        for output in heads:
            loss_ce = ce_loss(output, targets)
            loss_dice = dice_loss(output, targets, softmax=True)
            loss += w_ce * loss_ce + w_dice * loss_dice
        loss = loss / len(heads)  # Average over outputs

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # The Dice score is only monitored, so no graph is needed for it
        with torch.no_grad():
            probs = torch.softmax(heads[-1], dim=1)
            preds = torch.argmax(probs, dim=1).unsqueeze(1)
            dice = 1 - dice_loss._dice_loss(preds.float(), masks)
        dice_meter.update(dice.item())

        # Update progress bar
        progress_bar.set_postfix(loss=loss.item(), dice=dice_meter.avg)

    return total_loss / len(train_loader), dice_meter.avg

def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    When the model returns a list of outputs, only the last one is scored.
    The loss is 0.3 * cross-entropy + 0.7 * Dice, using the notebook-level
    ``ce_loss`` and ``dice_loss`` objects.

    Returns:
        Tuple ``(mean loss per batch, dice_meter.avg)``.
    """
    model.eval()
    running_loss = 0.0
    dice_meter = AvgMeter()

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validation")

        for images, masks, _ in progress_bar:
            images, masks = images.to(device), masks.to(device)

            # Keep only the final prediction of a multi-output model
            logits = model(images)
            if isinstance(logits, list):
                logits = logits[-1]

            # Weighted cross-entropy + Dice on class-index targets
            targets = masks.squeeze(1).long()
            batch_loss = 0.3 * ce_loss(logits, targets) + 0.7 * dice_loss(logits, targets, softmax=True)
            batch_loss_value = batch_loss.item()
            running_loss += batch_loss_value

            # Hard predictions, kept as (B, 1, H, W) to line up with the masks
            preds = torch.argmax(torch.softmax(logits, dim=1), dim=1).unsqueeze(1)
            batch_dice = 1 - dice_loss._dice_loss(preds.float(), masks)
            dice_meter.update(batch_dice.item())

            progress_bar.set_postfix(loss=batch_loss_value, dice=dice_meter.avg)

    return running_loss / len(val_loader), dice_meter.avg

In [None]:
# Training loop: train, validate, adapt the learning rate and checkpoint once per epoch.
# best_dice tracks the highest validation Dice seen so far; the four lists
# collect one value per epoch for the plots drawn after the loop.
best_dice = 0.0
train_losses = []
val_losses = []
train_dices = []
val_dices = []

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    # Train for one epoch
    train_loss, train_dice = train_epoch(model, train_loader, optimizer, device)
    train_losses.append(train_loss)
    train_dices.append(train_dice)
    
    # Validate
    val_loss, val_dice = validate(model, val_loader, device)
    val_losses.append(val_loss)
    val_dices.append(val_dice)
    
    print(f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}")
    
    # Update learning rate based on validation dice score
    # (the scheduler was created with mode='max', so a higher value counts as an improvement)
    scheduler.step(val_dice)
    
    # Save model if validation dice improves
    if val_dice > best_dice:
        best_dice = val_dice
        print(f"New best validation Dice: {best_dice:.4f}")
        # Save model (weights only; this file is overwritten on every improvement)
        torch.save(model.state_dict(), os.path.join(model_dir, 'best_model.pth'))
        print("Saved best model checkpoint")
    
    # Save model at each epoch: a full checkpoint with optimizer state and metrics.
    # Note the stored 'epoch' is 0-based while the file name is 1-based, and the
    # scheduler state is not part of the checkpoint.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_dice': train_dice,
        'val_dice': val_dice,
    }, os.path.join(model_dir, f'checkpoint_epoch_{epoch+1}.pth'))

# Learning curves: loss on the left, Dice score on the right
fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

dice_ax.plot(train_dices, label='Train Dice')
dice_ax.plot(val_dices, label='Val Dice')
dice_ax.set_xlabel('Epoch')
dice_ax.set_ylabel('Dice Score')
dice_ax.legend()

# Save the figure alongside the other results, then display it
fig.tight_layout()
fig.savefig(os.path.join(result_dir, 'training_metrics.png'))
plt.show()

## Evaluate on Validation Set

In [None]:
# Load best model.
# map_location keeps the load working when the checkpoint was written on a
# different device than the one in use now (e.g. trained on GPU, evaluated on CPU).
best_model_path = os.path.join(model_dir, 'best_model.pth')
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

# Take the first validation batch for a visual check
# (the validation loader is not shuffled, so this is always the same batch)
val_iter = iter(val_loader)
val_batch = next(val_iter)
images, masks, _ = val_batch
images = images.to(device)

# Get predictions
with torch.no_grad():
    outputs = model(images)
    # A multi-output model returns a list; the last entry is the final prediction
    if isinstance(outputs, list):
        outputs = outputs[-1]

    probs = torch.softmax(outputs, dim=1)
    preds = torch.argmax(probs, dim=1).unsqueeze(1)

# Visualize results
visualize_batch(images, masks, preds)

# Calculate metrics on the entire validation set
val_loss, val_dice = validate(model, val_loader, device)
print(f"Validation Loss: {val_loss:.4f}, Validation Dice: {val_dice:.4f}")

## Generate Predictions for Test Set

In [None]:
def generate_predictions(model, test_loader, device, output_dir):
    """Run ``model`` over ``test_loader`` and write one 0/255 mask per input file.

    ``test_loader`` must yield ``(images, filenames)`` batches. Each mask is
    the per-pixel argmax class multiplied by 255, stored as uint8 under the
    input's own file name inside ``output_dir`` (created when missing).

    NOTE(review): because the input file name is reused, its extension picks
    the encoding OpenCV writes with; a ``.jpg`` input would be stored lossily.
    Confirm that is acceptable for your data.
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for images, filenames in tqdm(test_loader, desc="Generating predictions"):
            logits = model(images.to(device))

            # Keep only the final prediction of a multi-output model
            if isinstance(logits, list):
                logits = logits[-1]

            # Per-pixel class index (0 = background, 1 = foreground)
            class_map = torch.argmax(torch.softmax(logits, dim=1), dim=1)

            for filename, sample_pred in zip(filenames, class_map):
                pred_mask = sample_pred.cpu().numpy() * 255  # Convert back to 0-255 range
                target_path = os.path.join(output_dir, filename)
                cv2.imwrite(target_path, pred_mask.astype(np.uint8))

# Test-set masks go into their own sub-folder of the results directory
test_output_dir = os.path.join(result_dir, 'test_predictions')
os.makedirs(test_output_dir, exist_ok=True)

# Write one predicted mask per test image
generate_predictions(
    model=model,
    test_loader=test_loader,
    device=device,
    output_dir=test_output_dir,
)
print(f"Test predictions saved to {test_output_dir}")

# Prepare for Kaggle submission if needed
def prepare_kaggle_submission(test_output_dir, submission_file):
    """Write an example submission CSV with one ``id,predicted`` row per saved mask.

    This is a template: the real format depends on the competition, so adjust
    the per-mask processing to its requirements.

    Args:
        test_output_dir: Directory containing the predicted mask images.
        submission_file: Path of the CSV file to write (overwritten if present).

    Files are processed in sorted name order so the CSV is reproducible.
    ``id`` is the file name without its final extension, and ``predicted`` is
    1 when the mean pixel value of the mask is above 127, otherwise 0.
    Sub-directories are ignored; files OpenCV cannot decode are skipped with
    a printed message.
    """
    import csv

    with open(submission_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['id', 'predicted'])  # Header

        # os.listdir gives no ordering guarantee, so sort for a stable row order
        for filename in sorted(os.listdir(test_output_dir)):
            pred_path = os.path.join(test_output_dir, filename)
            if not os.path.isfile(pred_path):
                continue

            # cv2.imread returns None instead of raising when it cannot decode a file
            pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
            if pred is None:
                print(f"Skipping unreadable file: {filename}")
                continue

            # Process according to submission requirements
            # This is just an example - adjust based on your specific needs
            pred_value = 1 if np.mean(pred) > 127 else 0

            # splitext drops only the final extension, so names containing
            # dots (e.g. "scan.v2.png") keep their full stem as the id
            writer.writerow([os.path.splitext(filename)[0], pred_value])

# Path of the Kaggle submission CSV, placed in the results directory.
# Only the path is defined here; no file is written unless the call below is enabled.
submission_file = os.path.join(result_dir, 'submission.csv')
# Uncomment the line below if you want to create the submission file
# prepare_kaggle_submission(test_output_dir, submission_file)

## Clean Up

The configuration cell creates temporary directories under `./tmp` for the train/validation split. Once you no longer need them, you can remove them with the cell below.

In [None]:
# Clean up temporary directories
import shutil

# Uncomment if you want to remove the temporary directories
# shutil.rmtree('./tmp', ignore_errors=True)

## Conclusion

You have successfully:
1. Prepared your data for training with EMCAD
2. Trained an EMCAD model for binary segmentation
3. Evaluated the model on validation data
4. Generated predictions for test data that can be submitted to Kaggle

To improve results, consider:
- Experimenting with different data augmentations
- Trying different hyperparameters (learning rate, batch size)
- Using different encoders (e.g., pvt_v2_b0 or resnet34)
- Implementing cross-validation to get more stable results