# Mamba-Sea Training Notebook for Custom Segmentation Dataset

This notebook demonstrates how to train the Mamba-Sea model on your custom dataset with the following structure:
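- train/images (.jpg)
- train/masks (.png, grayscale 0-255; pixels above 127 are treated as foreground, and each mask must share its image's file name stem)
- val/images (.jpg)
- val/masks (.png, same format and naming as the training masks)
- test/images (.jpg, no masks)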

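After training, the best checkpoint (highest validation Dice) is reloaded, used to predict on the test images, and the predicted masks are saved to a separate folder (`./predictions` by default).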

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings('ignore')

# Set device
# Run on CUDA whenever PyTorch reports a usable GPU; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Configuration and Hyperparameters

In [None]:
# Configuration
# Every run setting lives in this one dict; the later cells read from it.
CONFIG = dict(
    data_root='./data',  # Path to your data folder
    image_size=512,
    batch_size=8,
    num_epochs=100,
    learning_rate=1e-4,
    num_workers=4,
    save_path='./checkpoints',
    prediction_path='./predictions',
    model_type='VMUNet',  # or 'VMUnet_enhance'
)

# Create directories
# Both output folders are created up front (no error if they already exist).
for _output_key in ('save_path', 'prediction_path'):
    os.makedirs(CONFIG[_output_key], exist_ok=True)

## Custom Dataset Class

In [None]:
class CustomSegmentationDataset(Dataset):
    """Binary-segmentation dataset read from ``<data_root>/<split>/images``.

    Images are ``*.jpg`` files. For every split except ``'test'`` a mask
    named ``<image stem>.png`` is expected in ``<data_root>/<split>/masks``;
    it is read as grayscale and binarised with the rule ``pixel > 127``.

    Args:
        data_root: Folder that contains the ``train`` / ``val`` / ``test``
            sub-folders.
        split: Name of the sub-folder to read. ``'test'`` has no masks.
        image_size: Side length (pixels) that images and masks are resized to.
        augment: Request random augmentation. It only takes effect when
            ``split == 'train'``.

    Items:
        ``(image, filename)`` for the test split, otherwise
        ``(image, mask, filename)`` where ``mask`` is a long tensor of
        class indices (0 or 1).
    """

    def __init__(self, data_root, split='train', image_size=512, augment=True):
        self.data_root = Path(data_root)
        self.split = split
        self.image_size = image_size
        # Augmentation is never applied outside the training split.
        self.augment = augment and split == 'train'

        # Get image paths
        self.image_dir = self.data_root / split / 'images'
        # The test split ships without masks.
        self.mask_dir = None if split == 'test' else self.data_root / split / 'masks'
        self.image_paths = self._list_images(self.image_dir)

        print(f"{split} dataset: {len(self.image_paths)} images")

        # Define transforms
        if self.augment:
            self.transform = A.Compose([
                A.Resize(image_size, image_size),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomRotate90(p=0.5),
                A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
                A.RandomBrightnessContrast(p=0.3),
                # Channel statistics commonly used for ImageNet-pretrained backbones.
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
        else:
            self.transform = A.Compose([
                A.Resize(image_size, image_size),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])

        # Used for masks on the non-augmented path: nearest-neighbour resize
        # so label values are never blended, and no normalisation.
        self.mask_transform = A.Compose([
            A.Resize(image_size, image_size, interpolation=cv2.INTER_NEAREST),
            ToTensorV2()
        ])

    @staticmethod
    def _list_images(image_dir):
        """Return the ``*.jpg`` files directly inside ``image_dir``, sorted.

        ``Path.glob`` yields entries in an unspecified order; sorting makes
        the index -> file mapping reproducible across runs and machines.
        """
        return sorted(Path(image_dir).glob('*.jpg'))

    def __len__(self):
        """Return the number of images found for this split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and return the sample at position ``idx``.

        Raises:
            OSError: If the image (or its mask) exists but cannot be decoded.
            FileNotFoundError: If the mask file for the image is missing.
        """
        image_path = self.image_paths[idx]

        # Load image
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Failed to read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.split == 'test':
            # For test data, only return image and filename
            transformed = self.transform(image=image)
            return transformed['image'], image_path.name

        # Load mask
        mask_name = image_path.stem + '.png'
        mask_path = self.mask_dir / mask_name

        if not mask_path.exists():
            raise FileNotFoundError(f"Mask not found: {mask_path}")

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Failed to read mask: {mask_path}")

        # Convert mask to binary (0 or 1)
        mask = (mask > 127).astype(np.uint8)

        # Apply transforms
        if self.augment:
            # Image and mask go through one call so the random geometric
            # transforms are applied identically to both.
            transformed = self.transform(image=image, mask=mask)
            image, mask = transformed['image'], transformed['mask']
        else:
            image_transformed = self.transform(image=image)
            mask_transformed = self.mask_transform(image=mask)
            image, mask = image_transformed['image'], mask_transformed['image']

        # Ensure mask is long tensor for CrossEntropyLoss; drop a leading
        # channel axis if the transform produced one.
        mask = mask.long().squeeze(0) if mask.dim() > 2 else mask.long()

        return image, mask, image_path.name

## Model Initialization

In [None]:
# Import model factory
# The current directory is added to the import path so `model` can be found there.
import sys
sys.path.append('.')
from model import factory

# Initialize model
# Positional arguments after the type: 3 input channels, 2 classes (background + foreground)
model = factory(CONFIG['model_type'], 3, 2)

# Load pretrained weights if available
# Best effort: any failure is reported and training simply starts from scratch.
try:
    model.load_from()
except Exception as e:
    print(f"Could not load pretrained weights: {e}")
    print("Training from scratch")
else:
    print("Loaded pretrained weights successfully")

model = model.to(device)
print(f"Model loaded: {CONFIG['model_type']}")

## Data Loading

In [None]:
# Create datasets
# Only the training split asks for augmentation; val and test use the plain pipeline.
train_dataset = CustomSegmentationDataset(
    CONFIG['data_root'], split='train', image_size=CONFIG['image_size'], augment=True)
val_dataset = CustomSegmentationDataset(
    CONFIG['data_root'], split='val', image_size=CONFIG['image_size'], augment=False)
test_dataset = CustomSegmentationDataset(
    CONFIG['data_root'], split='test', image_size=CONFIG['image_size'], augment=False)

# Create data loaders
# Worker count and memory pinning are the same for all three loaders.
_loader_options = {'num_workers': CONFIG['num_workers'], 'pin_memory': True}
train_loader = DataLoader(
    train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, **_loader_options)
# Validation and test are evaluated one image at a time, in dataset order.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, **_loader_options)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

## Loss Functions and Optimizer

In [None]:
# Loss functions
class DiceLoss(nn.Module):
    """Soft Dice loss computed on softmax probabilities.

    A Dice coefficient is computed for every (sample, class) pair and the
    loss is ``1 - mean`` of those coefficients. ``smooth`` is added to both
    numerator and denominator.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return the scalar Dice loss.

        Args:
            pred: Raw scores with the class axis at dim 1 and two trailing
                axes (dims 2 and 3) that are summed over.
            target: Integer class indices with the same layout as ``pred``
                minus the class axis.
        """
        probs = pred.softmax(dim=1)

        # One-hot encode the target along the class axis (same dtype/device as probs).
        one_hot = torch.zeros_like(probs).scatter_(1, target.unsqueeze(1), 1)

        reduce_axes = (2, 3)
        overlap = (probs * one_hot).sum(dim=reduce_axes)
        denominator = probs.sum(dim=reduce_axes) + one_hot.sum(dim=reduce_axes)

        dice_per_class = (2. * overlap + self.smooth) / (denominator + self.smooth)
        return 1 - dice_per_class.mean()

# Combined loss
# The two criteria are kept separate; the training/validation steps add them.
criterion_dice = DiceLoss()
criterion_ce = nn.CrossEntropyLoss()

# Optimizer
optimizer = AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=1e-4,
)

# Learning rate scheduler
# Cosine annealing whose T_max equals the configured number of epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['num_epochs'],
)

## Training Functions

In [None]:
def calculate_dice_score(pred, target, smooth=1):
    """Return the Dice score between the arg-max prediction and ``target``.

    The score is computed once over all elements of the batch as
    ``(2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``.
    Because it multiplies class indices directly, it measures foreground
    overlap only when labels are 0 (background) and 1 (foreground).

    Args:
        pred: Raw model scores with the class axis at dim 1.
        target: Integer class indices, one per predicted position.
        smooth: Constant added to numerator and denominator; it also makes
            the score 1.0 when both prediction and target are empty.

    Returns:
        The Dice score as a Python float.
    """
    # softmax is monotonic, so arg-max over the raw scores selects the same
    # class without materialising a probability tensor, and without the
    # rounding ties that softmax can introduce between near-equal scores.
    predicted = torch.argmax(pred, dim=1)

    intersection = (predicted * target).sum()
    total = predicted.sum() + target.sum()

    dice = (2. * intersection + smooth) / (total + smooth)
    return dice.item()

def train_epoch(model, train_loader, optimizer, criterion_ce, criterion_dice, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device``, the summed loss
    (``criterion_ce`` + ``criterion_dice``) is back-propagated and the
    optimizer takes one step. A tqdm bar shows per-batch and running values.

    Returns:
        ``(mean_loss, mean_dice)``, both averaged over the number of batches.
    """
    model.train()
    loss_sum = 0
    dice_sum = 0

    progress = tqdm(train_loader, desc='Training')
    # `step` counts batches from 1 so it can be used directly as the divisor
    # for the running averages.
    for step, (images, masks, _) in enumerate(progress, start=1):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()

        logits = model(images)

        # Objective: cross-entropy plus Dice, both evaluated on the same logits.
        loss = criterion_ce(logits, masks) + criterion_dice(logits, masks)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_dice = calculate_dice_score(logits, masks)
        loss_sum += batch_loss
        dice_sum += batch_dice

        progress.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Dice': f'{batch_dice:.4f}',
            'Avg Loss': f'{loss_sum / step:.4f}',
            'Avg Dice': f'{dice_sum / step:.4f}'
        })

    batch_count = len(train_loader)
    return loss_sum / batch_count, dice_sum / batch_count

def validate_epoch(model, val_loader, criterion_ce, criterion_dice, device):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    The model is put in eval mode and every batch is scored under
    ``torch.no_grad()`` with the same summed loss used for training
    (``criterion_ce`` + ``criterion_dice``).

    Returns:
        ``(mean_loss, mean_dice)``, both averaged over the number of batches.
    """
    model.eval()
    loss_sum = 0
    dice_sum = 0

    with torch.no_grad():
        progress = tqdm(val_loader, desc='Validation')
        # `step` counts batches from 1 so it doubles as the running-average divisor.
        for step, (images, masks, _) in enumerate(progress, start=1):
            images = images.to(device)
            masks = masks.to(device)

            logits = model(images)

            loss = criterion_ce(logits, masks) + criterion_dice(logits, masks)

            batch_loss = loss.item()
            batch_dice = calculate_dice_score(logits, masks)
            loss_sum += batch_loss
            dice_sum += batch_dice

            progress.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Dice': f'{batch_dice:.4f}',
                'Avg Loss': f'{loss_sum / step:.4f}',
                'Avg Dice': f'{dice_sum / step:.4f}'
            })

    batch_count = len(val_loader)
    return loss_sum / batch_count, dice_sum / batch_count

## Training Loop

In [None]:
# Training loop
# Tracks the best validation Dice seen so far and the per-epoch history
# that the plotting cell reads afterwards.
best_dice = 0
train_losses = []
val_losses = []
train_dices = []
val_dices = []

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 50)

    # Train
    train_loss, train_dice = train_epoch(model, train_loader, optimizer, criterion_ce, criterion_dice, device)

    # Validate
    val_loss, val_dice = validate_epoch(model, val_loader, criterion_ce, criterion_dice, device)

    # Update learning rate
    scheduler.step()

    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_dices.append(train_dice)
    val_dices.append(val_dice)

    print(f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}")
    # Read after scheduler.step(), so this is the rate the next epoch will use.
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    improved = val_dice > best_dice
    if improved:
        best_dice = val_dice

    # One payload serves both checkpoint files. The scheduler state is stored
    # as well so that a resumed run can continue the learning-rate schedule
    # instead of restarting it.
    checkpoint_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_dice': best_dice,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_dice': train_dice,
        'val_dice': val_dice
    }

    # Save best model
    if improved:
        torch.save(checkpoint_state, os.path.join(CONFIG['save_path'], 'best_model.pth'))
        print(f"New best model saved with Dice: {best_dice:.4f}")

    # Save latest model
    torch.save(checkpoint_state, os.path.join(CONFIG['save_path'], 'latest_model.pth'))

print(f"\nTraining completed! Best Dice score: {best_dice:.4f}")

## Plot Training History

In [None]:
# Plot training history
# Two side-by-side panels: loss curves on the left, Dice curves on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Each entry: axis, train series, val series, train label, val label, title, y-label.
_panels = (
    # Loss plot
    (ax1, train_losses, val_losses,
     'Train Loss', 'Val Loss', 'Training and Validation Loss', 'Loss'),
    # Dice plot
    (ax2, train_dices, val_dices,
     'Train Dice', 'Val Dice', 'Training and Validation Dice Score', 'Dice Score'),
)
for _ax, _train_series, _val_series, _train_label, _val_label, _title, _ylabel in _panels:
    _ax.plot(_train_series, label=_train_label)
    _ax.plot(_val_series, label=_val_label)
    _ax.set_title(_title)
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.legend()
    _ax.grid(True)

plt.tight_layout()
# The figure is written next to the checkpoints before it is displayed.
plt.savefig(os.path.join(CONFIG['save_path'], 'training_history.png'), dpi=300, bbox_inches='tight')
plt.show()

## Load Best Model for Prediction

In [None]:
# Load best model
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
checkpoint = torch.load(
    os.path.join(CONFIG['save_path'], 'best_model.pth'),
    map_location=device,
)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model with Dice score: {checkpoint['best_dice']:.4f}")

# Switch to evaluation mode before running inference.
model.eval()

## Prediction on Test Data

In [None]:
def predict_and_save(model, test_loader, save_path, device):
    """Predict on test data and save results.

    For every image the arg-max class map is scaled by 255, cast to uint8
    and written to ``save_path`` as ``<image stem>_pred.png``. The mask is
    saved at the size the model outputs; it is not resized back to the
    source image.

    Args:
        model: Network to run; it is switched to eval mode here.
        test_loader: Loader yielding ``(images, filenames)`` batches.
        save_path: Output folder; created if it does not exist.
        device: Device the image batches are moved to.
    """
    os.makedirs(save_path, exist_ok=True)

    model.eval()
    processed = 0
    with torch.no_grad():
        pbar = tqdm(test_loader, desc='Predicting')
        for images, filenames in pbar:
            images = images.to(device)

            outputs = model(images)
            # arg-max over raw scores selects the same class as arg-max over
            # softmax probabilities, so the softmax is skipped.
            predictions = torch.argmax(outputs, dim=1)

            # Process each image in the batch
            for i, filename in enumerate(filenames):
                pred_mask = predictions[i].cpu().numpy()

                # Convert to 0-255 range
                pred_mask = (pred_mask * 255).astype(np.uint8)

                # Save prediction
                save_filename = os.path.splitext(filename)[0] + '_pred.png'
                save_filepath = os.path.join(save_path, save_filename)

                Image.fromarray(pred_mask).save(save_filepath)

            # Show a running total rather than the size of the last batch.
            processed += len(filenames)
            pbar.set_postfix({'Processed': processed})

    print(f"Predictions saved to: {save_path}")

# Run prediction
# Masks for every test image are written to the configured prediction folder.
predict_and_save(
    model=model,
    test_loader=test_loader,
    save_path=CONFIG['prediction_path'],
    device=device,
)

## Visualization of Results

In [None]:
def visualize_predictions(test_dataset, model, device, num_samples=6):
    """Visualize some test predictions.

    Draws a two-row figure: randomly chosen test images on top and the
    arg-max prediction for each one below. The figure is saved as
    ``prediction_samples.png`` under ``CONFIG['save_path']`` and then shown.

    Args:
        test_dataset: Dataset whose items are ``(image_tensor, filename)``.
        model: Network to run; it is switched to eval mode here.
        device: Device the image is moved to for inference.
        num_samples: Maximum number of samples to draw. If the dataset is
            smaller, every available sample is shown.
    """
    model.eval()

    # Sampling without replacement cannot draw more items than exist.
    sample_count = min(num_samples, len(test_dataset))
    if sample_count <= 0:
        print("No test samples available to visualize")
        return

    # squeeze=False keeps `axes` two-dimensional even with a single column.
    fig, axes = plt.subplots(2, sample_count, figsize=(20, 8), squeeze=False)

    indices = np.random.choice(len(test_dataset), sample_count, replace=False)

    # Per-channel constants used to undo the normalisation for display.
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

    with torch.no_grad():
        for i, idx in enumerate(indices):
            image, filename = test_dataset[idx]

            # Original image for display (denormalize), clamped to [0, 1]
            # and moved to channels-last for imshow.
            orig_image = torch.clamp(image * channel_std + channel_mean, 0, 1)
            orig_image = orig_image.permute(1, 2, 0).numpy()

            # Prediction (arg-max over raw scores; softmax would not change it)
            image_batch = image.unsqueeze(0).to(device)
            output = model(image_batch)
            pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()

            # Display
            axes[0, i].imshow(orig_image)
            axes[0, i].set_title(f'Original: {filename}')
            axes[0, i].axis('off')

            axes[1, i].imshow(pred, cmap='gray')
            axes[1, i].set_title('Prediction')
            axes[1, i].axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG['save_path'], 'prediction_samples.png'), dpi=300, bbox_inches='tight')
    plt.show()

# Visualize some predictions
# Uses the function's default number of samples.
visualize_predictions(
    test_dataset=test_dataset,
    model=model,
    device=device,
)

## Summary

Training completed successfully! Here's what was accomplished:

1. **Custom Dataset**: Created a dataset class that handles your specific data structure
2. **Model Training**: Trained the Mamba-Sea model with combined CE + Dice loss
3. **Evaluation**: Monitored training with Dice score metrics
4. **Prediction**: Generated predictions on test data and saved them to a new folder
5. **Visualization**: Created plots showing training progress and sample predictions

### Output Files:
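- `./checkpoints/best_model.pth` - Checkpoint of the epoch with the highest validation Dice (model and optimizer state plus metrics)
- `./checkpoints/latest_model.pth` - Checkpoint of the most recent epoch (same contents)  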
- `./checkpoints/training_history.png` - Training curves
- `./checkpoints/prediction_samples.png` - Sample predictions
- `./predictions/` - Folder containing all test predictions

### Next Steps:
- Experiment with different model configurations
- Try additional data augmentation techniques beyond the flips, rotations, and brightness/contrast changes already used
- Tune hyperparameters for better performance
- Evaluate on validation set with additional metrics