# RWKV-UNet Training for Custom Segmentation Dataset

This notebook trains RWKV-UNet on a custom binary segmentation dataset and writes predicted masks for the unlabeled test split. The data is expected under `data/` with the following structure:
- train/images (.jpg)
- train/masks (.png with 0/255 values)
- val/images (.jpg) 
- val/masks (.png with 0/255 values)
- test/images (.jpg, no masks)

In [None]:
# Import required libraries
import argparse
import logging
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

# Import your modules
from datasets.dataset_custom import CustomDataset, CustomTransform
from utils import DiceLoss, test_single_volume
from rwkv_unet import RWKV_UNet

# Pick the compute device once and reuse it everywhere: CUDA when present, otherwise CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f"Using device: {device}")

## Configuration and Hyperparameters

In [None]:
# Configuration
class Config:
    """Every tunable setting used by this notebook, kept in one place.

    Settings are plain class attributes, so they can be read from the class
    itself (``Config.img_size``) or from the ``config`` instance created below.
    """

    # --- Dataset locations (relative to the working directory) ---
    train_images_dir = "data/train/images"
    train_masks_dir = "data/train/masks"
    val_images_dir = "data/val/images"
    val_masks_dir = "data/val/masks"
    test_images_dir = "data/test/images"

    # --- Network ---
    img_size = 224     # spatial size fed to the network (test images are resized to img_size x img_size)
    num_classes = 2    # binary segmentation: background + foreground
    in_channels = 1

    # --- Optimisation ---
    batch_size = 8
    max_epochs = 30
    base_lr = 1e-3
    weight_decay = 1e-4

    # --- Artefacts ---
    pretrained_path = 'net_B.pth'  # pretrained encoder weights
    output_dir = "outputs/custom_segmentation"
    predictions_dir = "predictions/custom_test"

    # --- Reproducibility ---
    seed = 1234


config = Config()

# Make sure every output location exists before anything is written there.
for _directory in (config.output_dir, config.predictions_dir):
    os.makedirs(_directory, exist_ok=True)

## Data Loading and Visualization

In [None]:
# Create datasets.
# Train and validation share the project's CustomTransform.
transform = CustomTransform(config.img_size)

# Test images have no masks, so a plain torchvision pipeline is used instead:
# resize to the network input size, convert to a tensor, then map [0, 1] -> [-1, 1].
# NOTE(review): confirm CustomTransform normalizes the same way; otherwise test
# inputs are distributed differently from what the model saw during training.
test_transform = transforms.Compose([
    transforms.Resize((config.img_size, config.img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_dataset = CustomDataset(
    images_dir=config.train_images_dir,
    masks_dir=config.train_masks_dir,
    transform=transform,
)
val_dataset = CustomDataset(
    images_dir=config.val_images_dir,
    masks_dir=config.val_masks_dir,
    transform=transform,
)
test_dataset = CustomDataset(
    images_dir=config.test_images_dir,
    is_test=True,
    transform=test_transform,
)

for _split_label, _split_dataset in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{_split_label} dataset size: {len(_split_dataset)}")

# Data loaders: only the training split is shuffled; evaluation splits go one image at a time.
_eval_loader_kwargs = dict(batch_size=1, shuffle=False, num_workers=1)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, **_eval_loader_kwargs)
test_loader = DataLoader(test_dataset, **_eval_loader_kwargs)

In [None]:
# Visualize some training samples
def visualize_samples(dataset, num_samples=4):
    """Plot the first ``num_samples`` image/mask pairs of ``dataset``.

    Row 0 shows the de-normalized images, row 1 the corresponding masks.

    Args:
        dataset: indexable dataset whose items are dicts with 'image' and
            'label' tensors.
        num_samples: number of columns to draw; clamped to ``len(dataset)``.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        print("visualize_samples: nothing to show (empty dataset or num_samples < 1)")
        return

    # squeeze=False keeps `axes` 2-D even when num_samples == 1, so axes[row, col] always works.
    fig, axes = plt.subplots(2, num_samples, figsize=(15, 6), squeeze=False)

    for i in range(num_samples):
        sample = dataset[i]
        image = sample['image'].squeeze().cpu().numpy()
        # Drop any singleton channel dim so imshow receives a 2-D array.
        mask = sample['label'].squeeze().cpu().numpy()

        # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1].
        # NOTE(review): assumes CustomTransform uses that normalization — confirm.
        image = (image + 1) / 2

        axes[0, i].imshow(image, cmap='gray')
        axes[0, i].set_title(f'Image {i+1}')
        axes[0, i].axis('off')

        axes[1, i].imshow(mask, cmap='gray')
        axes[1, i].set_title(f'Mask {i+1}')
        axes[1, i].axis('off')

    fig.tight_layout()
    plt.show()

visualize_samples(train_dataset)

## Model Initialization

In [None]:
# Set random seeds.
# Python's `random` is seeded too, in case the project's data transforms draw from it
# (NOTE(review): confirm against CustomTransform). torch.manual_seed also covers CUDA devices.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)

# Initialize model (pretrained encoder weights are read from config.pretrained_path)
model = RWKV_UNet(
    in_channels=config.in_channels,
    num_classes=config.num_classes,
    img_size=config.img_size,
    pretrained_path=config.pretrained_path
).to(device)

num_params = sum(p.numel() for p in model.parameters())
print(f"Model initialized with {num_params} parameters")

# Loss functions; train_epoch combines them as a weighted sum
ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(config.num_classes)

# Optimizer and scheduler: cosine decay from base_lr to 0 over max_epochs scheduler steps
optimizer = optim.AdamW(model.parameters(), lr=config.base_lr, weight_decay=config.weight_decay)
scheduler = CosineAnnealingLR(optimizer, T_max=config.max_epochs, eta_min=0)

# TensorBoard logging under <output_dir>/log
writer = SummaryWriter(os.path.join(config.output_dir, 'log'))

## Training Loop

In [None]:
def train_epoch(model, train_loader, optimizer, ce_loss, dice_loss, device,
                ce_weight=0.3, dice_weight=0.7):
    """Run one optimisation pass over ``train_loader``.

    The per-batch loss is ``ce_weight * CE + dice_weight * Dice``. The defaults
    0.3 / 0.7 are the weights this notebook has always used.

    Args:
        model: segmentation network returning per-class logits along dim 1.
        train_loader: iterable of dicts with 'image' and 'label' tensors.
        optimizer: optimizer stepping ``model``'s parameters.
        ce_loss: cross-entropy criterion called as ``ce_loss(outputs, labels)``.
        dice_loss: Dice criterion called as ``dice_loss(outputs, labels, softmax=True)``.
        device: device the batches are moved to.
        ce_weight: weight of the cross-entropy term.
        dice_weight: weight of the Dice term.

    Returns:
        float: mean of the per-batch loss values, or 0.0 if the loader is empty.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(train_loader, desc="Training")

    for batch in progress_bar:
        images = batch['image'].to(device)
        # CrossEntropyLoss needs integer class indices; cast in case the dataset yields float/uint8 masks.
        labels = batch['label'].to(device).long()

        optimizer.zero_grad()

        outputs = model(images)
        loss_ce = ce_loss(outputs, labels)
        loss_dice = dice_loss(outputs, labels, softmax=True)
        loss = ce_weight * loss_ce + dice_weight * loss_dice

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix({'Loss': f'{batch_loss:.4f}'})

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return total_loss / num_batches if num_batches else 0.0

def validate_epoch(model, val_loader, device):
    """Return the mean foreground (class 1) Dice score of ``model`` over ``val_loader``.

    Dice is computed per sample as 2|P ∩ G| / (|P| + |G|), where P and G are the
    pixels predicted / labelled as class 1. Samples where both P and G are empty
    (zero denominator) are skipped and do not contribute to the mean.

    Args:
        model: segmentation network returning per-class logits along dim 1.
        val_loader: iterable of dicts with 'image' and 'label' tensors.
        device: device the batches are moved to.

    Returns:
        float: mean Dice over the counted samples, or 0.0 if none were counted.
    """
    model.eval()
    total_dice = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            image = batch['image'].to(device)
            label = batch['label'].to(device)

            # Softmax is monotonic, so argmax over the raw logits selects the same class.
            pred = torch.argmax(model(image), dim=1)

            # Flatten all but the batch dim so Dice is computed per sample, even when batch_size > 1.
            pred_fg = (pred == 1).flatten(start_dim=1).float()
            label_fg = (label == 1).flatten(start_dim=1).float()

            intersection = (pred_fg * label_fg).sum(dim=1)
            denominator = pred_fg.sum(dim=1) + label_fg.sum(dim=1)

            counted = denominator > 0
            if counted.any():
                total_dice += (2.0 * intersection[counted] / denominator[counted]).sum().item()
                num_samples += int(counted.sum().item())

    return total_dice / num_samples if num_samples > 0 else 0.0

In [None]:
# Training loop
# Dice lies in [0, 1], so starting below it guarantees the first epoch writes
# best_model.pth. The test-prediction cell loads that file unconditionally.
best_dice = -1.0
best_model_path = os.path.join(config.output_dir, 'best_model.pth')
checkpoint_every = 10  # save a periodic snapshot every N epochs
train_losses = []
val_dices = []

for epoch in range(config.max_epochs):
    print(f"\nEpoch {epoch+1}/{config.max_epochs}")

    # Training
    train_loss = train_epoch(model, train_loader, optimizer, ce_loss, dice_loss, device)
    train_losses.append(train_loss)

    # Validation
    val_dice = validate_epoch(model, val_loader, device)
    val_dices.append(val_dice)

    # Record the LR actually used this epoch *before* the scheduler advances it;
    # reading it after scheduler.step() would log next epoch's LR.
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    # Logging
    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Dice/Validation', val_dice, epoch)
    writer.add_scalar('LR', current_lr, epoch)

    print(f"Train Loss: {train_loss:.4f}, Val Dice: {val_dice:.4f}")

    # Save best model
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved with Dice: {best_dice:.4f}")

    # Periodic checkpoint
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint_path = os.path.join(config.output_dir, f'epoch_{epoch+1}.pth')
        torch.save(model.state_dict(), checkpoint_path)

# Save final model
final_model_path = os.path.join(config.output_dir, 'final_model.pth')
torch.save(model.state_dict(), final_model_path)

print(f"\nTraining completed! Best Dice score: {best_dice:.4f}")

## Training Visualization

In [None]:
# Plot training curves.
# The x-axis uses the same 1-based epoch numbers printed during training.
fig, (ax_loss, ax_dice) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(range(1, len(train_losses) + 1), train_losses)
ax_loss.set_title('Training Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.grid(True)

ax_dice.plot(range(1, len(val_dices) + 1), val_dices)
ax_dice.set_title('Validation Dice Score')
ax_dice.set_xlabel('Epoch')
ax_dice.set_ylabel('Dice Score')
ax_dice.grid(True)

fig.tight_layout()
plt.show()

# Display some validation predictions
def show_predictions(model, dataset, num_samples=4, device=None):
    """Plot image / ground truth / prediction for the first ``num_samples`` items.

    Args:
        model: segmentation network returning per-class logits along dim 1.
        dataset: indexable dataset whose items are dicts with 'image' and 'label' tensors.
        num_samples: number of columns to draw; clamped to ``len(dataset)``.
        device: device to run inference on. Defaults to the device holding the
            model's parameters, so the function does not rely on a global.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        print("show_predictions: nothing to show (empty dataset or num_samples < 1)")
        return
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(3, num_samples, figsize=(15, 9), squeeze=False)

    with torch.no_grad():
        for i in range(num_samples):
            sample = dataset[i]
            image = sample['image'].unsqueeze(0).to(device)  # add a batch dimension
            label = sample['label'].squeeze().cpu().numpy()

            # Softmax is monotonic, so argmax over the raw logits selects the same class.
            pred = torch.argmax(model(image), dim=1).squeeze().cpu().numpy()

            # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1].
            img_show = (sample['image'].squeeze().cpu().numpy() + 1) / 2

            panels = ((img_show, 'Image'), (label, 'Ground Truth'), (pred, 'Prediction'))
            for row, (data, title) in enumerate(panels):
                axes[row, i].imshow(data, cmap='gray')
                axes[row, i].set_title(f'{title} {i+1}')
                axes[row, i].axis('off')

    fig.tight_layout()
    plt.show()

show_predictions(model, val_dataset)

## Test Data Prediction

In [None]:
# Load best model for testing.
# map_location lets a checkpoint written on GPU be restored on a CPU-only machine (and vice versa).
best_model_path = os.path.join(config.output_dir, 'best_model.pth')
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

print("Generating predictions on test data...")

# Create prediction output directory
test_predictions_dir = os.path.join(config.predictions_dir, 'masks')
os.makedirs(test_predictions_dir, exist_ok=True)

# Process test data
num_saved = 0
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Predicting"):
        image = batch['image'].to(device)
        case_name = batch['case_name'][0]

        # Softmax is monotonic, so argmax over the raw logits selects the same class.
        pred = torch.argmax(model(image), dim=1).squeeze().cpu().numpy()

        # Class 1 -> 255, class 0 -> 0, matching the 0/255 mask convention of the dataset.
        pred_mask = (pred * 255).astype(np.uint8)
        # NOTE(review): the mask is at the network input size (img_size x img_size), not the
        # original image size; resize it here if downstream consumers expect the original size.

        # Replace whatever extension case_name has (or append one if it has none) with .png,
        # so masks are always stored losslessly (a '.JPG' name would otherwise be saved as JPEG).
        output_filename = os.path.splitext(case_name)[0] + '.png'
        output_path = os.path.join(test_predictions_dir, output_filename)

        # A 2-D uint8 array is stored as an 8-bit grayscale ('L') image; passing mode= is deprecated.
        Image.fromarray(pred_mask).save(output_path)
        num_saved += 1

print(f"Predictions saved to: {test_predictions_dir}")
print(f"Total predictions generated: {num_saved}")

## Visualize Test Predictions

In [None]:
# Visualize some test predictions
def visualize_test_predictions(num_samples=6, predictions_dir=None):
    """Show test images (top row) above their saved predicted masks (bottom row).

    Args:
        num_samples: maximum number of image/prediction pairs to draw.
        predictions_dir: folder holding the predicted PNG masks. Defaults to
            ``<config.predictions_dir>/masks``, where the prediction cell writes
            them, so this cell does not depend on a variable from another cell.
    """
    if num_samples < 1:
        print("visualize_test_predictions: num_samples must be >= 1")
        return
    if predictions_dir is None:
        predictions_dir = os.path.join(config.predictions_dir, 'masks')

    # Sorted for a stable order. Only images with a saved prediction are kept,
    # so stray files (or images that were not predicted) don't leave blank columns.
    pairs = []
    for filename in sorted(os.listdir(config.test_images_dir)):
        if len(pairs) >= num_samples:
            break
        pred_path = os.path.join(predictions_dir, os.path.splitext(filename)[0] + '.png')
        if os.path.exists(pred_path):
            pairs.append((filename, pred_path))

    if not pairs:
        print(f"No predictions found in {predictions_dir}")
        return

    # 3 inches per column, as in the original 6-column / 18-inch layout; squeeze=False keeps axes 2-D.
    fig, axes = plt.subplots(2, len(pairs), figsize=(3 * len(pairs), 6), squeeze=False)

    for i, (filename, pred_path) in enumerate(pairs):
        # Context managers close the underlying files; convert() returns a loaded copy.
        with Image.open(os.path.join(config.test_images_dir, filename)) as img_file:
            image = img_file.convert('RGB')
        with Image.open(pred_path) as pred_file:
            prediction = pred_file.convert('L')

        axes[0, i].imshow(image)
        axes[0, i].set_title(f'Test Image {i+1}')
        axes[0, i].axis('off')

        axes[1, i].imshow(prediction, cmap='gray')
        axes[1, i].set_title(f'Prediction {i+1}')
        axes[1, i].axis('off')

    fig.tight_layout()
    plt.show()

visualize_test_predictions()

## Summary and Results

The training is complete! Here's what was accomplished:

1. **Data Loading**: Successfully loaded custom dataset with train/val/test splits
2. **Model Training**: Trained RWKV-UNet for binary segmentation
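3. **Validation**: Monitored performance on the validation split using the foreground (class 1) Dice score
4. **Test Prediction**: Generated predictions for all test images using the best checkpoint (`best_model.pth`)
5. **Output**: Saved predictions as PNG masks (0/255 format) in the predictions folder. The masks are at the network input resolution (`img_size` × `img_size`), not necessarily the original image size.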

### Key Files Generated:
- `outputs/custom_segmentation/best_model.pth` - Best performing model
- `outputs/custom_segmentation/final_model.pth` - Final model after all epochs
- `predictions/custom_test/masks/` - Test predictions folder

### Next Steps:
- Analyze prediction quality
- Fine-tune hyperparameters if needed
- Apply post-processing if required

In [None]:
# Clean up: close the TensorBoard SummaryWriter created in the model-initialization cell.
writer.close()
print("Training notebook execution completed!")