# MambaU-Lite Training for Image Segmentation

This notebook trains the MambaU-Lite model for a binary segmentation task. It expects the following directory layout under the data root:
- train/images (.jpg)
- train/masks (.png, binary masks with pixel values 0 or 255)
- val_images (.jpg)
- val_masks (.png, binary masks with pixel values 0 or 255)
- test/images (.jpg, test data without masks)

After training, it predicts on test data and saves the results to a new folder.

In [None]:
# Import required libraries
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from glob import glob
from tqdm.notebook import tqdm

# Choose the compute device: CUDA when a GPU is visible, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Import the model and metrics
from models.mamba_ulite import ULite
from metric import dice_score, iou_score, DiceLoss, dice_tversky_loss

## Data Preparation

First, we'll create a dataset class to load images and masks from the specified directory structure.

In [None]:
class SegmentationDataset(Dataset):
    """Dataset of RGB images with optional binary segmentation masks.

    Every image is opened with PIL, converted to RGB, resized to ``img_size``
    and returned as a float32 tensor in channel-first layout with values
    scaled to [0, 1]. Outside test mode the mask at the same index is opened
    as grayscale, resized, and binarised (pixel value > 128 becomes 1.0) into
    a float32 tensor with a leading channel axis.

    Args:
        img_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths, paired with ``img_paths`` by
            index. Required unless ``test_mode`` is True.
        transform: Used only as a boolean flag (it is never called). When
            truthy, a random horizontal flip and a random vertical flip, each
            with probability 0.5, are applied jointly to image and mask.
        test_mode: When True no mask is loaded and ``__getitem__`` returns
            ``(image, filename)`` instead of ``(image, mask)``.
        img_size: ``(width, height)`` handed to ``PIL.Image.resize``.
            Defaults to ``(256, 256)``.

    Raises:
        ValueError: If ``test_mode`` is False and ``mask_paths`` is missing or
            its length differs from ``img_paths``.
    """

    def __init__(self, img_paths, mask_paths=None, transform=None, test_mode=False,
                 img_size=(256, 256)):
        # Fail at construction time rather than with a TypeError/IndexError
        # deep inside a DataLoader worker.
        if not test_mode:
            if mask_paths is None:
                raise ValueError("mask_paths is required unless test_mode=True")
            if len(mask_paths) != len(img_paths):
                raise ValueError(
                    f"Got {len(img_paths)} images but {len(mask_paths)} masks"
                )

        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.test_mode = test_mode
        self.img_size = tuple(img_size)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, filename)`` in test mode, where ``filename`` is the
            basename of the image path; otherwise ``(image, mask)``.
        """
        # Load image
        img_path = self.img_paths[idx]
        image = Image.open(img_path).convert('RGB')

        # Resize to the model input size
        image = image.resize(self.img_size)

        # Scale to [0, 1] and move channels first (H, W, 3) -> (3, H, W)
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.from_numpy(image.transpose(2, 0, 1))

        if self.test_mode:
            return image, os.path.basename(img_path)

        # Load mask
        mask_path = self.mask_paths[idx]
        mask = Image.open(mask_path).convert('L')
        mask = mask.resize(self.img_size)

        # Binarise: thresholding also cleans up intermediate values that
        # resizing introduces along mask edges.
        mask = np.array(mask)
        mask = (mask > 128).astype(np.float32)  # Convert from 0-255 to 0-1
        mask = torch.from_numpy(mask).unsqueeze(0)  # Add channel dimension

        if self.transform:
            # Draw from torch's RNG rather than NumPy's: DataLoader seeds
            # torch separately in each worker, whereas NumPy's global state
            # can be duplicated across worker processes, which would make
            # every worker apply the same sequence of flips.
            if torch.rand(1).item() > 0.5:
                image = torch.flip(image, dims=[2])  # Horizontal flip
                mask = torch.flip(mask, dims=[2])

            if torch.rand(1).item() > 0.5:
                image = torch.flip(image, dims=[1])  # Vertical flip
                mask = torch.flip(mask, dims=[1])

        return image, mask

In [None]:
# Set up data paths
# The user types the data root; strip stray whitespace from copy/paste
data_root = input("Enter path to data directory: ").strip()

# Training data
# Images and masks are paired by position, so this relies on both folders
# sorting into the same order.
train_img_paths = sorted(glob(os.path.join(data_root, 'train/images/*.jpg')))
train_mask_paths = sorted(glob(os.path.join(data_root, 'train/masks/*.png')))

# Validation data
val_img_paths = sorted(glob(os.path.join(data_root, 'val_images/*.jpg')))
val_mask_paths = sorted(glob(os.path.join(data_root, 'val_masks/*.png')))

# Test data
test_img_paths = sorted(glob(os.path.join(data_root, 'test/images/*.jpg')))

print(f"Train images: {len(train_img_paths)}")
print(f"Train masks: {len(train_mask_paths)}")
print(f"Validation images: {len(val_img_paths)}")
print(f"Validation masks: {len(val_mask_paths)}")
print(f"Test images: {len(test_img_paths)}")

# Verify paths match for training and validation.
# Raised explicitly (not asserted) so the checks still run under `python -O`.
if len(train_img_paths) != len(train_mask_paths):
    raise ValueError("Number of training images and masks don't match")
if len(val_img_paths) != len(val_mask_paths):
    raise ValueError("Number of validation images and masks don't match")
if not train_img_paths:
    # Without this, the random sample pick below fails with an obscure
    # "low >= high" error from np.random.randint(0, 0).
    raise ValueError(
        f"No training images found under {os.path.join(data_root, 'train/images')}"
    )

# Create datasets
train_dataset = SegmentationDataset(train_img_paths, train_mask_paths, transform=True)
val_dataset = SegmentationDataset(val_img_paths, val_mask_paths)
test_dataset = SegmentationDataset(test_img_paths, test_mode=True)

# Create data loaders
batch_size = 4  # Adjust based on your GPU memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)
# One image per batch at test time so each prediction maps to one filename
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=4)

# Show a sample from the training dataset
sample_idx = np.random.randint(0, len(train_dataset))
sample_img, sample_mask = train_dataset[sample_idx]

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
# imshow expects channels last, so move (C, H, W) -> (H, W, C)
plt.imshow(sample_img.permute(1, 2, 0))
plt.title("Sample Image")
plt.subplot(1, 2, 2)
plt.imshow(sample_mask.squeeze(), cmap='gray')
plt.title("Sample Mask")
plt.show()

## Model Definition

Now we'll set up the model for training using PyTorch Lightning.

In [None]:
class SegmentationModel(pl.LightningModule):
    """Lightning wrapper that trains a ``ULite`` network for segmentation.

    The loss is ``dice_tversky_loss``; ``dice_score`` and ``iou_score`` are
    logged alongside it for the train, validation and test loops.

    Args:
        learning_rate: Initial learning rate for the Adam optimizer.
    """

    def __init__(self, learning_rate=1e-3):
        super().__init__()
        # Persist learning_rate in checkpoints so load_from_checkpoint
        # rebuilds the module with the value used for training.
        self.save_hyperparameters()
        self.model = ULite()
        self.learning_rate = learning_rate

    def forward(self, x):
        """Run the wrapped network on a batch of images."""
        return self.model(x)

    def _step(self, batch):
        """Compute loss, Dice and IoU for one ``(images, masks)`` batch."""
        images, masks = batch
        # Images are passed through as-is; no manual permute is done here.
        outputs = self.model(images)
        # NOTE(review): raw network outputs go straight into the loss and
        # metrics - confirm that the `metric` helpers expect un-activated
        # outputs (inference elsewhere in this notebook applies a sigmoid).
        loss = dice_tversky_loss(outputs, masks)
        dice = dice_score(outputs, masks)
        iou = iou_score(outputs, masks)
        return loss, dice, iou

    def training_step(self, batch, batch_idx):
        """Log epoch-level training metrics and return the loss to optimize."""
        loss, dice, iou = self._step(batch)
        metrics = {"loss": loss, "train_dice": dice, "train_iou": iou}
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation metrics; ``val_dice`` drives the LR scheduler."""
        loss, dice, iou = self._step(batch)
        metrics = {"val_loss": loss, "val_dice": dice, "val_iou": iou}
        self.log_dict(metrics, prog_bar=True)
        return metrics

    def test_step(self, batch, batch_idx):
        """Log test metrics; requires batches that include ground-truth masks."""
        loss, dice, iou = self._step(batch)
        metrics = {"test_loss": loss, "test_dice": dice, "test_iou": iou}
        self.log_dict(metrics, prog_bar=True)
        return metrics

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler monitoring ``val_dice``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # mode="max" because a higher Dice score is better. The `verbose`
        # argument is deliberately not passed: it is deprecated in recent
        # PyTorch releases.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, patience=5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_dice",
            },
        }

## Training the Model

Now we'll train the model using PyTorch Lightning.

In [None]:
# Set up model
model = SegmentationModel(learning_rate=1e-3)

# Set up callbacks
checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# Keep only the single checkpoint with the highest validation Dice
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename='mamba_ulite-{epoch:02d}-{val_dice:.4f}',
    monitor='val_dice',
    mode='max',
    save_top_k=1,
    verbose=True,
)

# Stop once validation Dice has not improved for 10 validation checks
early_stop_callback = EarlyStopping(
    monitor='val_dice',
    patience=10,
    mode='max',
    verbose=True
)

# Set up trainer
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, early_stop_callback],
    # 16-bit mixed precision targets CUDA; use full precision on the CPU
    # fallback so the precision setting matches the selected accelerator.
    precision=16 if use_gpu else 32,
    accelerator='gpu' if use_gpu else 'cpu',
    devices=1,
    log_every_n_steps=10,
)

# Train the model
trainer.fit(model, train_loader, val_loader)

# Print best model path
print(f"Best model path: {checkpoint_callback.best_model_path}")
# best_model_score stays None when no checkpoint was ever saved, and None
# cannot be formatted with ':.4f'.
best_score = checkpoint_callback.best_model_score
if best_score is not None:
    print(f"Best validation dice score: {best_score:.4f}")
else:
    print("Best validation dice score: n/a (no checkpoint was saved)")

## Prediction on Test Data

Now we'll predict on the test data and save the results to a new folder.

In [None]:
# Restore the best checkpoint and put it in inference mode on the target device
best_model_path = checkpoint_callback.best_model_path
model = SegmentationModel.load_from_checkpoint(best_model_path)
model.eval()
model.to(device)

# Folder that receives the predicted masks (created if missing)
output_dir = os.path.join(os.getcwd(), 'predictions')
os.makedirs(output_dir, exist_ok=True)

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    for images, filenames in tqdm(test_loader, desc="Predicting"):
        # Images are fed as loaded (no permute); only the device changes
        images = images.to(device)
        outputs = model(images)

        # Sigmoid turns outputs into probabilities; threshold at 0.5
        predictions = torch.sigmoid(outputs) > 0.5

        # Write one PNG per input file
        for prediction, filename in zip(predictions, filenames):
            # Boolean mask -> uint8 image with values 0 / 255
            pred_mask = prediction.cpu().numpy().squeeze().astype(np.uint8) * 255

            output_path = os.path.join(output_dir, f"{os.path.splitext(filename)[0]}_pred.png")
            cv2.imwrite(output_path, pred_mask)

print(f"Predictions saved to {output_dir}")

## Visualize Some Predictions

Let's visualize a few test predictions alongside the original images.

In [None]:
# Get a few test images and their predictions
test_samples = min(5, len(test_img_paths))

if test_samples == 0:
    # With no test images the figure would have zero height and no axes,
    # so skip plotting entirely.
    print("No test images to visualize")
else:
    plt.figure(figsize=(15, test_samples*5))

    for i, img_path in enumerate(test_img_paths[:test_samples]):
        img_name = os.path.basename(img_path)

        # Load original image and its saved prediction
        img = cv2.imread(img_path)
        pred_path = os.path.join(output_dir, f"{os.path.splitext(img_name)[0]}_pred.png")
        pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread returns None instead of raising when a file is missing
        # or unreadable; skip that row rather than crash in cvtColor/imshow.
        if img is None or pred is None:
            print(f"Skipping {img_name}: image or prediction could not be read")
            continue

        # OpenCV loads BGR; matplotlib expects RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Display: original on the left, prediction on the right
        plt.subplot(test_samples, 2, i*2+1)
        plt.imshow(img)
        plt.title(f"Original Image: {img_name}")
        plt.axis('off')

        plt.subplot(test_samples, 2, i*2+2)
        plt.imshow(pred, cmap='gray')
        plt.title("Prediction")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

## Summary

In this notebook, we:
1. Set up a dataset to load images and masks from the specified directory structure
2. Created a PyTorch Lightning module for the MambaU-Lite model
3. Trained the model on the training data and validated it
4. Predicted on the test data and saved the results to a new folder
5. Visualized some test predictions

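The predictions are saved in the `predictions` folder (created in the current working directory) with the naming convention `{original_filename_without_extension}_pred.png`; for example, `img_001.jpg` produces `img_001_pred.png`.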