# DuAT Model Training for Medical Image Segmentation

This notebook trains the DuAT model on a custom dataset organized as follows:
- train/images (.jpg)
- train/masks (.png with 0-255 values; each mask shares its image's base filename)
- val/images (.jpg)
- val/masks (.png with 0-255 values; each mask shares its image's base filename)
- test/images (.jpg, no masks)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import os
import argparse
from datetime import datetime
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings('ignore')

# Import custom modules
from lib.DuAT import DuAT
from utils.utils import clip_gradient, AvgMeter

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
# Every later cell moves the model and its input batches to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Dataset Configuration
Set your dataset paths and training hyperparameters here:

In [None]:
# Dataset paths - Update these paths according to your data location
TRAIN_IMAGES_PATH = "path/to/your/train/images"
TRAIN_MASKS_PATH = "path/to/your/train/masks"
VAL_IMAGES_PATH = "path/to/your/val/images"
VAL_MASKS_PATH = "path/to/your/val/masks"
TEST_IMAGES_PATH = "path/to/your/test/images"
OUTPUT_PATH = "predictions"  # directory that receives the predicted test masks

# Training parameters
BATCH_SIZE = 8  # batch size shared by the train, validation and test loaders
IMAGE_SIZE = 352  # images and masks are resized to IMAGE_SIZE x IMAGE_SIZE
LEARNING_RATE = 1e-4  # initial learning rate handed to the AdamW optimizer
NUM_EPOCHS = 100  # number of passes over the training set
CLIP_GRADIENT = 0.5  # passed to clip_gradient() after every backward pass

# Create output directory
# exist_ok=True makes re-running this cell harmless when the folders already exist
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs("model_checkpoints", exist_ok=True)

## Custom Dataset Class

In [None]:
class CustomSegmentationDataset(Dataset):
    """Paired image/mask dataset for binary segmentation.

    Images are discovered in ``images_path`` (``.jpg``, ``.jpeg`` or ``.png``,
    matched case-insensitively) and processed in sorted filename order. The
    mask for an image is looked up in ``masks_path`` under the same base name
    with a ``.png`` extension.

    Args:
        images_path: Directory containing the input images.
        masks_path: Directory containing the grayscale masks (0-255 values).
        image_size: Side length images and masks are resized to.
        is_training: When True, random flips and brightness/contrast
            augmentation are applied in addition to resize + normalization.
    """

    def __init__(self, images_path, masks_path, image_size=352, is_training=True):
        self.images_path = images_path
        self.masks_path = masks_path
        self.image_size = image_size
        self.is_training = is_training

        # Get all image files (sorted so that indexing is deterministic)
        self.image_files = sorted(
            f for f in os.listdir(images_path)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

        # Define transforms
        # NOTE(review): INTER_NEAREST is also used for the RGB image here; it is
        # kept unchanged so that train/val/test preprocessing stay consistent.
        if is_training:
            self.transform = A.Compose([
                A.Resize(image_size, image_size, interpolation=cv2.INTER_NEAREST),
                A.HorizontalFlip(p=0.3),
                A.VerticalFlip(p=0.3),
                A.RandomBrightnessContrast(p=0.2),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
        else:
            self.transform = A.Compose([
                A.Resize(image_size, image_size, interpolation=cv2.INTER_NEAREST),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])

    @staticmethod
    def _mask_name_for(img_name):
        """Return the mask filename for ``img_name``: same base name, ``.png``.

        Uses ``os.path.splitext`` so that upper-case extensions such as
        ``.JPG`` (which the case-insensitive listing accepts) map correctly.
        """
        stem, _ = os.path.splitext(img_name)
        return stem + '.png'

    def __len__(self):
        """Return the number of images found in ``images_path``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, transform and return the ``(image, mask)`` pair at ``idx``.

        Returns:
            The transformed image tensor and the transformed mask tensor with
            a leading channel dimension added.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        # Load image
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_path, img_name)
        image = cv2.imread(img_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load mask
        mask_path = os.path.join(self.masks_path, self._mask_name_for(img_name))
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask for '{img_name}': {mask_path}")
        mask = mask.astype(np.float32) / 255.0  # Normalize to 0-1

        # Apply transforms (albumentations applies spatial ops to image and mask together)
        transformed = self.transform(image=image, mask=mask)

        return transformed['image'], transformed['mask'].unsqueeze(0)

class TestDataset(Dataset):
    """Unlabelled image dataset used for inference.

    Lists every ``.jpg``/``.jpeg``/``.png`` file (case-insensitive) found in
    ``images_path`` in sorted order. Each item is resized to
    ``image_size`` x ``image_size``, normalized and converted to a tensor.
    """

    def __init__(self, images_path, image_size=352):
        self.images_path = images_path
        self.image_size = image_size

        # Collect the image files in a stable, sorted order
        accepted_suffixes = ('.jpg', '.jpeg', '.png')
        self.image_files = sorted(
            entry for entry in os.listdir(images_path)
            if entry.lower().endswith(accepted_suffixes)
        )

        # Deterministic preprocessing only - no augmentation at test time
        preprocessing_steps = [
            A.Resize(image_size, image_size, interpolation=cv2.INTER_NEAREST),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
        self.transform = A.Compose(preprocessing_steps)

    def __len__(self):
        """Return how many test images were found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, file_name, original_size)`` for item ``idx``.

        ``original_size`` is the first two entries of the loaded array's shape,
        captured before resizing so predictions can be scaled back later.
        """
        file_name = self.image_files[idx]
        bgr_pixels = cv2.imread(os.path.join(self.images_path, file_name))
        rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)

        # Remember the pre-resize dimensions
        size_before_resize = rgb_pixels.shape[:2]

        image_tensor = self.transform(image=rgb_pixels)['image']
        return image_tensor, file_name, size_before_resize

## Loss Function

In [None]:
def structure_loss(pred, mask):
    """
    Weighted BCE + weighted IoU loss, averaged over the batch.

    `pred` holds raw logits and `mask` the target map; both are indexed as
    4-D tensors (the reductions run over dims 2 and 3). Each pixel is weighted
    by how much the mask differs from its 31x31 local average, so pixels near
    mask boundaries contribute more to both terms.
    """
    spatial_dims = (2, 3)

    # Per-pixel weight: 1 everywhere, up to 6 where the mask deviates most
    # from its local neighbourhood mean.
    local_mean = F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)
    pixel_weight = 1 + 5 * torch.abs(local_mean - mask)

    # Weighted binary cross-entropy computed on the logits
    bce_map = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    weighted_bce = (pixel_weight * bce_map).sum(dim=spatial_dims) / pixel_weight.sum(dim=spatial_dims)

    # Weighted soft IoU computed on probabilities; the +1 terms smooth the ratio
    prob = torch.sigmoid(pred)
    overlap = ((prob * mask) * pixel_weight).sum(dim=spatial_dims)
    combined = ((prob + mask) * pixel_weight).sum(dim=spatial_dims)
    weighted_iou = 1 - (overlap + 1) / (combined - overlap + 1)

    return (weighted_bce + weighted_iou).mean()

def dice_coefficient(pred, target, smooth=1):
    """
    Calculate the soft Dice coefficient between logits and a target mask.

    Args:
        pred: Raw logits; a sigmoid is applied here (no thresholding).
        target: Target mask with the same number of elements as `pred`.
        smooth: Added to numerator and denominator so that two empty masks
            score 1 instead of dividing by zero.

    Returns:
        A scalar tensor computed over all elements of the inputs at once
        (i.e. over the whole batch, not per sample).
    """
    pred = torch.sigmoid(pred)
    # reshape (unlike view) also works for non-contiguous tensors such as
    # transposed or sliced masks.
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)
    intersection = (pred_flat * target_flat).sum()
    return (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)

## Validation Function

In [None]:
def validate_model(model, val_loader, device):
    """
    Validate the model and return the average Dice score.

    The model is switched to eval mode and run without gradient tracking.
    For every batch the two model outputs are summed, a batch-level Dice
    score is computed with `dice_coefficient`, and the scores are averaged
    weighted by batch size.

    Args:
        model: Callable returning two logit tensors for a batch of images.
        val_loader: Iterable yielding `(images, masks)` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The sample-weighted mean of the per-batch Dice scores as a float.

    Raises:
        ValueError: If `val_loader` yields no samples.
    """
    model.eval()
    total_dice = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            output1, output2 = model(images)

            # Combine outputs (logits are summed; sigmoid is applied inside dice_coefficient)
            combined_output = output1 + output2

            # Calculate dice coefficient, weighting the batch by its size so a
            # smaller final batch does not skew the average
            batch_size = images.size(0)
            dice = dice_coefficient(combined_output, masks)
            total_dice += dice.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError
        raise ValueError("Validation loader yielded no samples; cannot compute Dice score.")

    return total_dice / total_samples

## Model Initialization and Data Loading

In [None]:
# Initialize model
# DuAT is constructed with its default arguments and moved to the selected device
model = DuAT().to(device)
print("Model initialized successfully!")

# Create datasets
# Training split: resize + random flips / brightness-contrast augmentation
train_dataset = CustomSegmentationDataset(
    TRAIN_IMAGES_PATH, TRAIN_MASKS_PATH, 
    image_size=IMAGE_SIZE, is_training=True
)

# Validation split: resize + normalization only, so scores are comparable across epochs
val_dataset = CustomSegmentationDataset(
    VAL_IMAGES_PATH, VAL_MASKS_PATH, 
    image_size=IMAGE_SIZE, is_training=False
)

# Create data loaders
# Only the training loader shuffles; validation keeps the sorted file order
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, 
    shuffle=True, num_workers=4, pin_memory=True
)

val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, 
    shuffle=False, num_workers=4, pin_memory=True
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

# Learning rate scheduler
# Multiplies the learning rate by 0.1 after every 30 scheduler steps
# (the training loop steps the scheduler once per epoch)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

## Training Loop

In [None]:
# Training history
train_losses = []  # average training loss per epoch
val_dice_scores = []  # validation Dice score per epoch
best_dice = 0.0  # best validation Dice seen so far; gates saving best_model.pth

print("Starting training...")
print("=" * 50)

for epoch in range(NUM_EPOCHS):
    # validate_model() switches to eval mode, so re-enable train mode every epoch
    model.train()
    
    # Training metrics
    # NOTE(review): AvgMeter comes from utils.utils; presumably a running
    # average exposing update(value, n) and .avg - confirm against that module
    loss_meter = AvgMeter()
    
    for batch_idx, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        # The model returns two outputs; each one is supervised separately below
        output1, output2 = model(images)
        
        # Calculate losses
        loss1 = structure_loss(output1, masks)
        loss2 = structure_loss(output2, masks)
        total_loss = loss1 + loss2
        
        # Backward pass
        total_loss.backward()
        # NOTE(review): clip_gradient comes from utils.utils; presumably limits
        # gradient magnitude using CLIP_GRADIENT - confirm against that module
        clip_gradient(optimizer, CLIP_GRADIENT)
        optimizer.step()
        
        # Update metrics
        loss_meter.update(total_loss.item(), images.size(0))
        
        # Print progress
        # Reports every 20 batches
        if (batch_idx + 1) % 20 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss_meter.avg:.4f}')
    
    # Validation
    val_dice = validate_model(model, val_loader, device)
    
    # Update learning rate
    # Stepped once per epoch, after this epoch's optimizer updates
    scheduler.step()
    
    # Save metrics
    train_losses.append(loss_meter.avg)
    val_dice_scores.append(val_dice)
    
    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] - Train Loss: {loss_meter.avg:.4f}, Val Dice: {val_dice:.4f}')
    
    # Save best model
    # Strict '>' means best_model.pth is only written once validation Dice exceeds 0.0
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), 'model_checkpoints/best_model.pth')
        print(f'New best model saved! Dice: {best_dice:.4f}')
    
    # Save checkpoint every 10 epochs
    # Only the model weights are stored (no optimizer or scheduler state)
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f'model_checkpoints/checkpoint_epoch_{epoch+1}.pth')
    
    print("-" * 50)

print("Training completed!")
print(f"Best validation Dice score: {best_dice:.4f}")

## Plot Training History

In [None]:
# Plot training history: one panel per recorded per-epoch series
plt.figure(figsize=(12, 4))

history_panels = (
    (train_losses, 'Training Loss', 'Loss'),
    (val_dice_scores, 'Validation Dice Score', 'Dice Score'),
)
for panel_index, (series, panel_title, y_label) in enumerate(history_panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.plot(series)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.grid(True)

# Save the figure to disk before showing it
plt.tight_layout()
plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
plt.show()

## Test Data Prediction

In [None]:
def predict_test_data(model, test_loader, output_path, device):
    """
    Predict on test data and save one grayscale PNG mask per input image.

    For every batch the two model outputs are summed and passed through a
    sigmoid. Each probability map is resized back to the image's original
    size, scaled to 0-255 and written to `output_path` under the input's base
    filename with a `.png` extension.

    Args:
        model: Callable returning two logit tensors for a batch of images.
        test_loader: Iterable yielding `(images, filenames, original_sizes)`
            batches, where `original_sizes[0]` holds the heights and
            `original_sizes[1]` the widths of the batch.
        output_path: Existing directory the masks are written to.
        device: Device the image batches are moved to.
    """
    model.eval()

    print("Starting prediction on test data...")

    with torch.no_grad():
        for batch_idx, (images, filenames, original_sizes) in enumerate(test_loader):
            images = images.to(device)

            # Forward pass
            output1, output2 = model(images)

            # Combine outputs and apply sigmoid
            combined_output = torch.sigmoid(output1 + output2)

            # Process each image in the batch
            for i in range(images.size(0)):
                pred_mask = combined_output[i, 0].cpu().numpy()
                filename = filenames[i]
                original_h, original_w = original_sizes[0][i].item(), original_sizes[1][i].item()

                # Resize prediction to original size (cv2 expects (width, height))
                pred_mask_resized = cv2.resize(pred_mask, (original_w, original_h), interpolation=cv2.INTER_CUBIC)

                # Bicubic interpolation can overshoot outside [0, 1]; clip first,
                # otherwise the uint8 cast below wraps around and produces
                # speckle artifacts at mask edges.
                pred_mask_resized = np.clip(pred_mask_resized, 0.0, 1.0)

                # Convert to 0-255 range
                pred_mask_uint8 = (pred_mask_resized * 255).astype(np.uint8)

                # Save prediction
                # splitext handles any extension casing (e.g. '.JPG'), so the
                # mask is always written as a lossless PNG.
                stem, _ = os.path.splitext(filename)
                output_filepath = os.path.join(output_path, stem + '.png')
                cv2.imwrite(output_filepath, pred_mask_uint8)

            if (batch_idx + 1) % 10 == 0:
                print(f'Processed {batch_idx + 1}/{len(test_loader)} batches')

    print(f"Predictions saved to: {output_path}")

# Load best model for testing
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load('model_checkpoints/best_model.pth', map_location=device))

# Create test dataset and loader
# shuffle=False keeps predictions in sorted filename order
test_dataset = TestDataset(TEST_IMAGES_PATH, image_size=IMAGE_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

print(f"Test samples: {len(test_dataset)}")

# Run prediction
predict_test_data(model, test_loader, OUTPUT_PATH, device)

## Visualize Sample Predictions

In [None]:
def visualize_predictions(model, val_loader, device, num_samples=4):
    """
    Plot image / ground truth / prediction triplets for validation samples.

    Draws up to `num_samples` rows of three panels, taking samples from
    `val_loader` in order. The figure is saved to 'sample_predictions.png'
    and then shown.
    """
    model.eval()

    # squeeze=False always yields a 2-D grid of axes, including num_samples == 1
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples), squeeze=False)

    # Same statistics the datasets pass to A.Normalize; used to undo it for display
    norm_std = np.array([0.229, 0.224, 0.225])
    norm_mean = np.array([0.485, 0.456, 0.406])

    rows_filled = 0

    with torch.no_grad():
        for images, masks in val_loader:
            if rows_filled >= num_samples:
                break

            images = images.to(device)
            masks = masks.to(device)

            # Forward pass: sum both outputs, then squash to probabilities
            first_output, second_output = model(images)
            predictions = torch.sigmoid(first_output + second_output)

            # Never draw more rows than the figure has left
            rows_available = num_samples - rows_filled
            for i in range(min(images.size(0), rows_available)):
                # Channels-last layout for matplotlib, then denormalize
                rgb = images[i].cpu().permute(1, 2, 0).numpy()
                rgb = np.clip(rgb * norm_std + norm_mean, 0, 1)

                row_panels = (
                    (rgb, 'Original Image', None),
                    (masks[i, 0].cpu().numpy(), 'Ground Truth', 'gray'),
                    (predictions[i, 0].cpu().numpy(), 'Prediction', 'gray'),
                )
                for ax, (content, caption, colormap) in zip(axes[rows_filled], row_panels):
                    ax.imshow(content, cmap=colormap)
                    ax.set_title(caption)
                    ax.axis('off')

                rows_filled += 1

    plt.tight_layout()
    plt.savefig('sample_predictions.png', dpi=150, bbox_inches='tight')
    plt.show()

# Visualize four validation samples; the figure is saved as
# sample_predictions.png and then displayed
visualize_predictions(model, val_loader, device, num_samples=4)

## Summary
Model training is complete! Here's what was accomplished:

1. **Data Loading**: Custom dataset classes for your specific data structure
2. **Training**: Full training loop with validation monitoring
3. **Model Saving**: Best model saved based on validation Dice score
4. **Testing**: Predictions generated for test data and saved to output folder
5. **Visualization**: Training history plots and sample predictions

### Files Generated:
- `model_checkpoints/best_model.pth`: Best performing model
- `model_checkpoints/checkpoint_epoch_*.pth`: Regular checkpoints
- `predictions/`: Folder containing test predictions
- `training_history.png`: Training loss and validation Dice score curves
- `sample_predictions.png`: Visual comparison of predictions

### Next Steps:
- Evaluate predictions using your preferred metrics
- Fine-tune hyperparameters if needed
- Apply post-processing to predictions if required