# U-KAN Training for Custom Segmentation Dataset

This notebook shows how to train U-KAN for medical image segmentation on a custom dataset laid out as follows:
- train/images (.jpg)
- train/masks (.png)
- val/images (.jpg) 
- val/masks (.png)
- test/images (.jpg) - for prediction only

## Setup and Imports

In [None]:
import os
import sys
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from collections import OrderedDict
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import cv2

# Albumentations for data augmentation
from albumentations.augmentations import transforms
from albumentations.augmentations import geometric
from albumentations.core.composition import Compose
from albumentations import RandomRotate90, Resize

# Model imports
import archs
import losses
from custom_dataset import CustomDataset
from metrics import iou_score, indicators
from utils import AverageMeter, str2bool
from tensorboardX import SummaryWriter

# Choose the compute device once: CUDA when PyTorch can see a GPU, otherwise
# the CPU. Later cells move the model and every batch onto this device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Configuration and Hyperparameters

In [None]:
# Central experiment configuration. Later cells read their settings from this
# dict, and it is written to config.yaml inside the experiment directory.
config = dict(
    # --- Data paths ---
    train_img_dir='./data/train/images',
    train_mask_dir='./data/train/masks',
    val_img_dir='./data/val/images',
    val_mask_dir='./data/val/masks',
    test_img_dir='./data/test/images',
    output_dir='./outputs',
    experiment_name='custom_ukan_experiment',

    # --- Model ---
    arch='UKAN',
    num_classes=1,
    input_channels=3,
    deep_supervision=False,
    input_w=256,
    input_h=256,
    embed_dims=[256, 320, 512],
    no_kan=False,  # True swaps the KAN layers for plain MLPs

    # --- Optimisation ---
    epochs=100,
    batch_size=8,
    lr=1e-4,             # learning rate for non-KAN parameters
    kan_lr=1e-2,         # learning rate for KAN-layer parameters
    weight_decay=1e-4,
    kan_weight_decay=1e-4,
    num_workers=4,

    # --- Loss / optimizer ---
    loss='BCEDiceLoss',
    optimizer='Adam',

    # --- Learning-rate schedule ---
    scheduler='CosineAnnealingLR',
    min_lr=1e-5,

    # --- Early stopping ---
    early_stopping=20,   # epochs without a val-IoU improvement before stopping
    save_best_only=True,
)

print("Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in config.items()))

## Utility Functions

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic kernels with
    autotuning disabled.

    Note: Python fixes string-hash randomisation at interpreter start-up, so
    the ``PYTHONHASHSEED`` export only influences processes launched later.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def create_dir(path):
    """Create directory ``path`` together with any missing parents.

    Does nothing when the directory already exists; returns ``None``.
    """
    os.makedirs(name=path, exist_ok=True)

def save_config(config, save_path):
    """Write the ``config`` mapping to ``save_path`` as YAML.

    Any existing file at ``save_path`` is overwritten. Serialisation uses
    ``yaml.dump`` with its default options.
    """
    with open(save_path, mode='w') as stream:
        yaml.dump(config, stream)

# Seed all RNGs before any data or model objects are created.
seed_everything(42)

# All artefacts of this run live under <output_dir>/<experiment_name>.
exp_dir = os.path.join(config['output_dir'], config['experiment_name'])
create_dir(exp_dir)

# Keep a copy of the exact configuration next to the outputs.
config_path = os.path.join(exp_dir, 'config.yaml')
save_config(config, config_path)
print(f"Experiment directory: {exp_dir}")

## Data Loading and Augmentation

In [None]:
def _resize_and_normalize():
    """Fresh Resize + Normalize transforms shared by both pipelines' tails."""
    return [
        Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ]


# Training adds random 90-degree rotations and flips; validation only resizes
# and normalises.
# NOTE(review): newer albumentations releases deprecate `Flip` in favour of
# HorizontalFlip/VerticalFlip -- confirm it exists in the installed version.
train_transform = Compose(
    [RandomRotate90(), geometric.transforms.Flip(), *_resize_and_normalize()]
)
val_transform = Compose(_resize_and_normalize())


def _make_dataset(split, transform):
    """Labelled CustomDataset for ``split`` ('train' or 'val') using config paths."""
    return CustomDataset(
        img_dir=config[f'{split}_img_dir'],
        mask_dir=config[f'{split}_mask_dir'],
        transform=transform,
        is_test=False,
    )


def _make_loader(dataset, *, train):
    """DataLoader that shuffles and drops the ragged last batch only when training."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=train,
        num_workers=config['num_workers'],
        drop_last=train,
    )


train_dataset = _make_dataset('train', train_transform)
val_dataset = _make_dataset('val', val_transform)

train_loader = _make_loader(train_dataset, train=True)
val_loader = _make_loader(val_dataset, train=False)

for label, count in (
    ("Training samples", len(train_dataset)),
    ("Validation samples", len(val_dataset)),
    ("Training batches", len(train_loader)),
    ("Validation batches", len(val_loader)),
):
    print(f"{label}: {count}")

## Model, Loss, and Optimizer Setup

In [None]:
# Build the architecture named by config['arch'] (defaults to 'UKAN') instead of
# ignoring that setting; an unknown name fails loudly with AttributeError.
model_cls = getattr(archs, config.get('arch', 'UKAN'))
model = model_cls(
    num_classes=config['num_classes'],
    input_channels=config['input_channels'],
    deep_supervision=config['deep_supervision'],
    embed_dims=config['embed_dims'],
    no_kan=config['no_kan'],
)
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Loss: torch's BCEWithLogitsLoss by name, otherwise a class from the local
# `losses` module (e.g. BCEDiceLoss).
if config['loss'] == 'BCEWithLogitsLoss':
    criterion = nn.BCEWithLogitsLoss().to(device)
else:
    criterion = getattr(losses, config['loss'])().to(device)

# KAN layers get their own learning rate and weight decay. Parameters are put
# into two groups, not one group per tensor: per-parameter optimizers such as
# Adam behave the same, but the optimizer state stays compact.
# NOTE(review): the name test assumes KAN-layer parameter names contain both
# 'layer' and 'fc' -- confirm against the module names in archs.
kan_params, base_params = [], []
for name, param in model.named_parameters():
    lowered = name.lower()
    if 'layer' in lowered and 'fc' in lowered:
        kan_params.append(param)
    else:
        base_params.append(param)

# The non-KAN group comes first, so optimizer.param_groups[0]['lr'] (what the
# training loop logs) is the base learning rate. Empty groups are dropped.
param_groups = [
    group
    for group in (
        {'params': base_params, 'lr': config['lr'],
         'weight_decay': config['weight_decay']},
        {'params': kan_params, 'lr': config['kan_lr'],
         'weight_decay': config['kan_weight_decay']},
    )
    if group['params']
]

# Honour config['optimizer'] rather than always constructing Adam.
_OPTIMIZERS = {'Adam': optim.Adam, 'AdamW': optim.AdamW, 'SGD': optim.SGD}
optimizer_name = config.get('optimizer', 'Adam')
if optimizer_name not in _OPTIMIZERS:
    raise ValueError(
        f"Unsupported optimizer {optimizer_name!r}; "
        f"choose one of {sorted(_OPTIMIZERS)}"
    )
optimizer = _OPTIMIZERS[optimizer_name](param_groups)

# Cosine annealing across all epochs; any other scheduler value means a
# constant learning rate.
if config['scheduler'] == 'CosineAnnealingLR':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config['epochs'], eta_min=config['min_lr']
    )
else:
    scheduler = None

print("Model, loss, optimizer, and scheduler initialized!")

## Training Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device, deep_supervision=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Segmentation network. With deep supervision enabled it must
            return a sequence of outputs whose last element is the final
            prediction.
        train_loader: Iterable of ``(images, masks, meta)`` batches.
        criterion: Loss called as ``criterion(output, masks)``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to.
        deep_supervision: Optional override. ``None`` (the default) reads
            ``config['deep_supervision']`` from the notebook globals.

    Returns:
        dict with the epoch means of ``'loss'`` and ``'iou'`` as plain Python
        floats. The means come from AverageMeter, updated with each batch's
        size as the weight.
    """
    if deep_supervision is None:
        deep_supervision = config['deep_supervision']
    model.train()

    # Named *_meter so they do not shadow the imported `losses` module.
    loss_meter = AverageMeter()
    iou_meter = AverageMeter()

    pbar = tqdm(train_loader, desc='Training')
    for images, masks, _meta in pbar:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        if deep_supervision:
            # Average the loss over every supervised head; score the last head.
            loss = sum(criterion(out, masks) for out in outputs) / len(outputs)
            final_output = outputs[-1]
        else:
            loss = criterion(outputs, masks)
            final_output = outputs
        iou, _dice, _ = iou_score(final_output, masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        loss_meter.update(loss.item(), batch_size)
        iou_meter.update(iou, batch_size)

        pbar.set_postfix({
            'Loss': f'{loss_meter.avg:.4f}',
            'IoU': f'{iou_meter.avg:.4f}'
        })

    # iou_score may return NumPy scalars; plain floats keep the history and
    # checkpoints loadable with torch.load(weights_only=True), the default
    # since PyTorch 2.6.
    return {'loss': float(loss_meter.avg), 'iou': float(iou_meter.avg)}

def validate_epoch(model, val_loader, criterion, device, deep_supervision=None):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Segmentation network. With deep supervision enabled it must
            return a sequence of outputs whose last element is the final
            prediction.
        val_loader: Iterable of ``(images, masks, meta)`` batches.
        criterion: Loss called as ``criterion(output, masks)``.
        device: Device each batch is moved to.
        deep_supervision: Optional override. ``None`` (the default) reads
            ``config['deep_supervision']`` from the notebook globals.

    Returns:
        dict with the epoch means of ``'loss'``, ``'iou'`` and ``'dice'`` as
        plain Python floats. The means come from AverageMeter, updated with
        each batch's size as the weight.
    """
    if deep_supervision is None:
        deep_supervision = config['deep_supervision']
    model.eval()

    # Named *_meter so they do not shadow the imported `losses` module.
    loss_meter = AverageMeter()
    iou_meter = AverageMeter()
    dice_meter = AverageMeter()

    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for images, masks, _meta in pbar:
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)
            if deep_supervision:
                # Average the loss over every supervised head; score the last head.
                loss = sum(criterion(out, masks) for out in outputs) / len(outputs)
                final_output = outputs[-1]
            else:
                loss = criterion(outputs, masks)
                final_output = outputs
            iou, dice, _ = iou_score(final_output, masks)

            batch_size = images.size(0)
            loss_meter.update(loss.item(), batch_size)
            iou_meter.update(iou, batch_size)
            dice_meter.update(dice, batch_size)

            pbar.set_postfix({
                'Loss': f'{loss_meter.avg:.4f}',
                'IoU': f'{iou_meter.avg:.4f}',
                'Dice': f'{dice_meter.avg:.4f}'
            })

    # iou_score may return NumPy scalars; plain floats keep the history and
    # checkpoints loadable with torch.load(weights_only=True), the default
    # since PyTorch 2.6.
    return {
        'loss': float(loss_meter.avg),
        'iou': float(iou_meter.avg),
        'dice': float(dice_meter.avg),
    }

## Training Loop

In [None]:
# Initialize training variables
best_iou = 0.0
best_dice = 0.0
patience_counter = 0
train_history = {
    'epoch': [], 'train_loss': [], 'train_iou': [],
    'val_loss': [], 'val_iou': [], 'val_dice': [],
}
# Defined up front so the final checkpoint below still works when
# config['epochs'] is 0 and the loop body never runs.
epoch = -1

# Setup tensorboard writer
writer = SummaryWriter(os.path.join(exp_dir, 'tensorboard'))

print("Starting training...")
print("="*50)

try:
    for epoch in range(config['epochs']):
        print(f"Epoch [{epoch+1}/{config['epochs']}]")

        # Read the LR *before* scheduler.step(): this is the rate used for this
        # epoch's updates. Reading it afterwards reports next epoch's LR.
        current_lr = optimizer.param_groups[0]['lr']

        # Metrics are cast to plain floats because metric helpers may return
        # NumPy scalars, which torch.load(weights_only=True) -- the default
        # since PyTorch 2.6 -- refuses to unpickle from the checkpoints below.
        train_metrics = {
            k: float(v)
            for k, v in train_epoch(model, train_loader, criterion, optimizer, device).items()
        }
        val_metrics = {
            k: float(v)
            for k, v in validate_epoch(model, val_loader, criterion, device).items()
        }

        if scheduler is not None:
            scheduler.step()

        print(f"LR: {current_lr:.6f} | "
              f"Train Loss: {train_metrics['loss']:.4f} | "
              f"Train IoU: {train_metrics['iou']:.4f} | "
              f"Val Loss: {val_metrics['loss']:.4f} | "
              f"Val IoU: {val_metrics['iou']:.4f} | "
              f"Val Dice: {val_metrics['dice']:.4f}")

        # Save to history
        train_history['epoch'].append(epoch + 1)
        train_history['train_loss'].append(train_metrics['loss'])
        train_history['train_iou'].append(train_metrics['iou'])
        train_history['val_loss'].append(val_metrics['loss'])
        train_history['val_iou'].append(val_metrics['iou'])
        train_history['val_dice'].append(val_metrics['dice'])

        # Write to tensorboard
        writer.add_scalar('Train/Loss', train_metrics['loss'], epoch)
        writer.add_scalar('Train/IoU', train_metrics['iou'], epoch)
        writer.add_scalar('Val/Loss', val_metrics['loss'], epoch)
        writer.add_scalar('Val/IoU', val_metrics['iou'], epoch)
        writer.add_scalar('Val/Dice', val_metrics['dice'], epoch)
        writer.add_scalar('Learning_Rate', current_lr, epoch)

        # Keep the checkpoint with the highest validation IoU.
        if val_metrics['iou'] > best_iou:
            best_iou = val_metrics['iou']
            best_dice = val_metrics['dice']
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_iou': best_iou,
                'best_dice': best_dice,
                'config': config
            }, os.path.join(exp_dir, 'best_model.pth'))

            print(f"★ New best model saved! IoU: {best_iou:.4f}, Dice: {best_dice:.4f}")
        else:
            patience_counter += 1

        # Early stopping once val IoU has not improved for `early_stopping` epochs.
        if patience_counter >= config['early_stopping']:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

        print("-" * 50)
finally:
    # Flush and close TensorBoard logs even if training fails or is interrupted.
    writer.close()

# Save final model
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_history': train_history,
    'config': config
}, os.path.join(exp_dir, 'final_model.pth'))

# Save training history
pd.DataFrame(train_history).to_csv(os.path.join(exp_dir, 'training_history.csv'), index=False)

print(f"Training completed! Best IoU: {best_iou:.4f}, Best Dice: {best_dice:.4f}")

## Plot Training History

In [None]:
# Four panels: loss, IoU, validation Dice, and both validation scores overlaid.
# Each entry: (row, col, title, y-label, [(history key, legend label, colour), ...])
epochs_axis = train_history['epoch']
panel_specs = [
    (0, 0, 'Training and Validation Loss', 'Loss',
     [('train_loss', 'Train Loss', 'blue'), ('val_loss', 'Val Loss', 'red')]),
    (0, 1, 'Training and Validation IoU', 'IoU',
     [('train_iou', 'Train IoU', 'blue'), ('val_iou', 'Val IoU', 'red')]),
    (1, 0, 'Validation Dice Score', 'Dice',
     [('val_dice', 'Val Dice', 'green')]),
    (1, 1, 'Validation Metrics', 'Score',
     [('val_iou', 'Val IoU', 'red'), ('val_dice', 'Val Dice', 'green')]),
]

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
for row, col, title, ylabel, series in panel_specs:
    panel = axes[row, col]
    for key, label, colour in series:
        panel.plot(epochs_axis, train_history[key], label=label, color=colour)
    panel.set_title(title)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(ylabel)
    panel.legend()
    panel.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(exp_dir, 'training_curves.png'), dpi=300, bbox_inches='tight')
plt.show()

# Text summary of the run.
banner = "=" * 50
for line in (
    banner,
    "TRAINING SUMMARY",
    banner,
    f"Best Validation IoU: {best_iou:.4f}",
    f"Best Validation Dice: {best_dice:.4f}",
    f"Total Epochs: {len(train_history['epoch'])}",
    f"Experiment Directory: {exp_dir}",
):
    print(line)

## Test Data Prediction

In [None]:
def predict_test_data(model_path, test_img_dir, output_dir, config, threshold=0.5):
    """Segment every image in ``test_img_dir`` and save binary PNG masks.

    Rebuilds the network from ``config`` and loads the weights stored under
    ``'model_state_dict'`` in the checkpoint at ``model_path``. It then runs the
    network over the test images, preprocessed with the notebook's
    ``val_transform``, and writes one ``<img_id>_pred.png`` per image into
    ``output_dir``. In each mask, 255 marks foreground and 0 background.

    Relies on the notebook globals ``device``, ``val_transform`` and
    ``create_dir``.

    Args:
        model_path: Checkpoint file written by the training loop.
        test_img_dir: Directory with the test images.
        output_dir: Destination directory; created if missing.
        config: Configuration dict used to rebuild the model.
        threshold: Probability cut-off applied after the sigmoid (default 0.5).

    Notes:
        Masks are written at the resolution produced by ``val_transform``.
        NOTE(review): resize back to each original image size if downstream
        consumers need that. Only class channel 0 is saved.
    """
    # The checkpoint comes from this notebook, so full unpickling is acceptable.
    # weights_only=False is needed on PyTorch >= 2.6 (where it defaults to True)
    # when the checkpoint holds NumPy scalars such as metric values. Never do
    # this with checkpoints from untrusted sources.
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch releases before 1.13 do not accept `weights_only`.
        checkpoint = torch.load(model_path, map_location=device)

    model_cls = getattr(archs, config.get('arch', 'UKAN'))
    model = model_cls(
        num_classes=config['num_classes'],
        input_channels=config['input_channels'],
        deep_supervision=config['deep_supervision'],
        embed_dims=config['embed_dims'],
        no_kan=config['no_kan']
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    test_dataset = CustomDataset(
        img_dir=test_img_dir,
        mask_dir=None,
        transform=val_transform,
        is_test=True
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        drop_last=False
    )

    create_dir(output_dir)

    print(f"Predicting on {len(test_dataset)} test images...")

    with torch.no_grad():
        for images, _, meta in tqdm(test_loader, desc='Predicting'):
            images = images.to(device)

            outputs = model(images)
            # With deep supervision the last head is the final prediction.
            logits = outputs[-1] if config['deep_supervision'] else outputs
            probs = torch.sigmoid(logits).cpu().numpy()

            # Save every sample in the batch, not only the first one, so the
            # function stays correct if batch_size is ever raised.
            for sample_idx, img_id in enumerate(meta['img_id']):
                pred_mask = np.where(probs[sample_idx, 0] >= threshold, 255, 0).astype(np.uint8)
                output_path = os.path.join(output_dir, f"{img_id}_pred.png")
                # A 2-D uint8 array maps to Pillow mode 'L' automatically; the
                # explicit `mode=` argument is deprecated in recent Pillow.
                Image.fromarray(pred_mask).save(output_path)

    print(f"Predictions saved to: {output_dir}")

# Where predictions go and which checkpoint to use; best_model_path is also
# used by the visualisation cell further down.
test_output_dir = os.path.join(exp_dir, 'test_predictions')
best_model_path = os.path.join(exp_dir, 'best_model.pth')

have_test_dir = os.path.exists(config['test_img_dir'])
have_best_model = os.path.exists(best_model_path)

if not (have_test_dir and have_best_model):
    print("Test images directory or best model not found. Skipping prediction.")
    if not have_test_dir:
        print(f"Test directory not found: {config['test_img_dir']}")
    if not have_best_model:
        print(f"Best model not found: {best_model_path}")
else:
    predict_test_data(best_model_path, config['test_img_dir'], test_output_dir, config)

## Visualization of Results

In [None]:
def visualize_predictions(val_img_dir, val_mask_dir, model_path, config, num_samples=5, threshold=0.5):
    """Plot random validation samples as image / ground truth / prediction rows.

    Rebuilds the network from ``config`` and loads ``'model_state_dict'`` from
    the checkpoint at ``model_path``. It then predicts up to ``num_samples``
    randomly chosen validation images and saves the grid to
    ``validation_predictions.png`` in the notebook's ``exp_dir``. Also relies
    on the notebook globals ``device`` and ``val_transform``.

    Args:
        val_img_dir: Directory of validation images.
        val_mask_dir: Directory of the matching ground-truth masks.
        model_path: Checkpoint file written by the training loop.
        config: Configuration dict used to rebuild the model.
        num_samples: Maximum number of samples (rows) to show.
        threshold: Probability cut-off applied after the sigmoid (default 0.5).
    """
    # Self-produced checkpoint: full unpickling is acceptable (and required on
    # PyTorch >= 2.6 if it holds NumPy scalars). Never use this on untrusted files.
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch releases before 1.13 do not accept `weights_only`.
        checkpoint = torch.load(model_path, map_location=device)

    model_cls = getattr(archs, config.get('arch', 'UKAN'))
    model = model_cls(
        num_classes=config['num_classes'],
        input_channels=config['input_channels'],
        deep_supervision=config['deep_supervision'],
        embed_dims=config['embed_dims'],
        no_kan=config['no_kan']
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    val_vis_dataset = CustomDataset(
        img_dir=val_img_dir,
        mask_dir=val_mask_dir,
        transform=val_transform,
        is_test=False
    )
    if len(val_vis_dataset) == 0:
        print("Validation dataset is empty; nothing to visualize.")
        return

    indices = random.sample(range(len(val_vis_dataset)), min(num_samples, len(val_vis_dataset)))

    # Size the grid by the samples actually drawn (the dataset may hold fewer
    # than num_samples); squeeze=False keeps `axes` 2-D even for one row.
    n_rows = len(indices)
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)

    with torch.no_grad():
        for row, idx in enumerate(indices):
            image, mask, meta = val_vis_dataset[idx]
            img_id = meta['img_id']

            outputs = model(image.unsqueeze(0).to(device))
            logits = outputs[-1] if config['deep_supervision'] else outputs
            prob = torch.sigmoid(logits).cpu().numpy()[0, 0]
            pred_binary = (prob >= threshold).astype(np.uint8)

            # Min-max rescale to [0, 1] for display; guard against a constant
            # image, which would otherwise divide by zero and yield NaNs.
            img_vis = image.permute(1, 2, 0).numpy()
            value_range = img_vis.max() - img_vis.min()
            if value_range > 0:
                img_vis = (img_vis - img_vis.min()) / value_range
            else:
                img_vis = np.zeros_like(img_vis)
            # imshow rejects (H, W, 1) arrays, so show single-channel inputs as 2-D grayscale.
            if img_vis.shape[-1] == 1:
                img_vis = img_vis[..., 0]

            mask_vis = mask[0].numpy()

            axes[row, 0].imshow(img_vis, cmap='gray' if img_vis.ndim == 2 else None)
            axes[row, 0].set_title(f'Original Image\n{img_id}')
            axes[row, 0].axis('off')

            axes[row, 1].imshow(mask_vis, cmap='gray')
            axes[row, 1].set_title('Ground Truth')
            axes[row, 1].axis('off')

            axes[row, 2].imshow(pred_binary, cmap='gray')
            axes[row, 2].set_title('Prediction')
            axes[row, 2].axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(exp_dir, 'validation_predictions.png'), dpi=300, bbox_inches='tight')
    plt.show()

# Show a few validation samples next to their ground truth, but only once a
# best checkpoint has been written.
if not os.path.exists(best_model_path):
    print("Best model not found for visualization.")
else:
    visualize_predictions(
        config['val_img_dir'],
        config['val_mask_dir'],
        best_model_path,
        config,
        num_samples=3,
    )

## Summary and Next Steps

The training is now complete! Here's what was accomplished:

1. **Model Training**: Trained U-KAN on your custom dataset
2. **Validation**: Monitored performance on validation set
3. **Best Model**: Saved the checkpoint with the highest validation IoU
4. **Test Prediction**: Generated binary mask predictions for the test images (skipped if the test directory or the best checkpoint is missing)
5. **Visualization**: Created training curves and sample predictions

### Output Files:
- `best_model.pth`: Best performing model weights
- `final_model.pth`: Final model weights
- `training_history.csv`: Training metrics log
- `training_curves.png`: Training progress visualization
- `test_predictions/`: Folder containing test predictions
- `validation_predictions.png`: Sample validation results
- `config.yaml`: Training configuration

### To use the trained model:
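1. Load the checkpoint with `checkpoint = torch.load('best_model.pth')`. It is a dictionary, not a model, so build the model with the same config and restore the weights with `model.load_state_dict(checkpoint['model_state_dict'])`
2. Use `predict_test_data` to segment new images
3. Adjust the 0.5 probability threshold used to binarize predictions to suit your needs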

### Tips for better performance:
- Increase training epochs if loss is still decreasing
- Experiment with different learning rates
- Try additional data augmentation (the training pipeline currently uses only random 90° rotations and flips)
- Use larger input resolution if computationally feasible
- Consider ensemble methods for final predictions