# EffiSegNet Training with EfficientNet V2-S

This notebook trains an improved EffiSegNet model with an EfficientNet V2-S backbone for medical image segmentation.

## Dataset Structure
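- `Train/Image/` — training images (`.jpg`)
- `Train/Mask/` — training masks (`.png`, binary with pixel values 0 and 255)
- `Val/Image/` — validation images (`.jpg`)
- `Val/Mask/` — validation masks (`.png`, binary with pixel values 0 and 255)
- `Test/Image/` — test images (`.jpg`, no masks provided)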

In [None]:
# Import required libraries
import os
import torch
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2
from pathlib import Path

# Import custom modules
from models.effisegnet import EffiSegNetBN
from datamodule import CustomSegDataset
from network_module import Net

# Seed every RNG (including dataloader worker processes) so runs are repeatable,
# and let float32 matmuls use PyTorch's reduced-precision "medium" mode for speed.
L.seed_everything(42, workers=True)
torch.set_float32_matmul_precision("medium")

print("Libraries imported successfully!")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
gpu_count = torch.cuda.device_count()
print(f"CUDA device count: {gpu_count}")

In [None]:
# Central hyper-parameter table; the model, datamodule, optimizer and trainer
# cells below all read their settings from here.
CONFIG = dict(
    model_name='efficientnet_v2_s',
    img_size=(384, 384),
    batch_size=4,  # reduced for the V2-S backbone
    learning_rate=1e-4,
    max_epochs=100,
    num_workers=2,  # reduced for stability
    ch=64,
    pretrained=True,
    freeze_encoder=False,
    deep_supervision=False,
    data_root='./data',  # point this at your data directory
)

print("Configuration:")
print("\n".join(f"  {name}: {val}" for name, val in CONFIG.items()))

In [None]:
# Build EffiSegNetBN with the backbone / decoder options taken from CONFIG.
model = EffiSegNetBN(
    model_name=CONFIG['model_name'],
    ch=CONFIG['ch'],
    pretrained=CONFIG['pretrained'],
    freeze_encoder=CONFIG['freeze_encoder'],
    deep_supervision=CONFIG['deep_supervision'],
)

print(f"Model initialized: {CONFIG['model_name']}")
all_params = list(model.parameters())
total_count = sum(param.numel() for param in all_params)
trainable_count = sum(param.numel() for param in all_params if param.requires_grad)
print(f"Model parameters: {total_count:,}")
print(f"Trainable parameters: {trainable_count:,}")

In [None]:
# Build the datamodule from CONFIG; setup() is expected to populate
# train_set / val_set / test_set, which are read below.
dataset = CustomSegDataset(
    root_dir=CONFIG['data_root'],
    img_size=CONFIG['img_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
)
dataset.setup()

print("Dataset initialized:")
for label, split_attr in (
    ("Train", "train_set"),
    ("Validation", "val_set"),
    ("Test", "test_set"),
):
    print(f"  {label} samples: {len(getattr(dataset, split_attr))}")

In [None]:
# Visualize sample data
def visualize_samples(dataset, num_samples=4):
    """
    Plot up to ``num_samples`` image/mask pairs from one training batch.

    Row 0 shows the (un-normalised) images and row 1 the matching masks.

    Args:
        dataset: Datamodule whose ``train_dataloader()`` yields
            ``(images, masks)`` batches. Assumes images are (B, C, H, W)
            tensors normalised with ImageNet mean/std and masks are
            (B, 1, H, W) — TODO confirm against the datamodule.
        num_samples: Maximum number of pairs to show; capped at the batch size.
    """
    images, masks = next(iter(dataset.train_dataloader()))

    n_show = min(num_samples, len(images))
    if n_show < 1:
        # plt.subplots rejects a zero-column grid; report instead of crashing.
        print("Nothing to visualize: empty batch or num_samples < 1.")
        return

    # ImageNet statistics used to undo the input normalisation for display.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Size the grid to the pairs actually shown (no blank panels when the batch
    # is smaller than num_samples). squeeze=False keeps `axes` 2-D even for a
    # single column, so axes[row, col] indexing always works.
    _, axes = plt.subplots(2, n_show, figsize=(15, 8), squeeze=False)

    for i in range(n_show):
        img = images[i].permute(1, 2, 0).cpu().numpy()
        img = np.clip(img * std + mean, 0, 1)
        mask = masks[i][0].cpu().numpy()

        axes[0, i].imshow(img)
        axes[0, i].set_title(f'Image {i+1}')
        axes[0, i].axis('off')

        axes[1, i].imshow(mask, cmap='gray')
        axes[1, i].set_title(f'Mask {i+1}')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

# Plot a few samples, but only when the training split actually has data.
if len(dataset.train_set) == 0:
    print("No training data found. Please check your data directory structure.")
else:
    visualize_samples(dataset)

In [None]:
# Initialize training components
from monai.losses import DiceCELoss
from lightning.pytorch import loggers
import torch.optim as optim

# Dice + cross-entropy loss on raw logits; sigmoid=True makes the loss apply
# the sigmoid itself. NOTE(review): MONAI ignores include_background=False
# (with a warning) for single-channel predictions — confirm the model's
# output channel count if this flag is meant to matter.
criterion = DiceCELoss(sigmoid=True, include_background=False)

# Scheduler factory: an instance is handed to Net, presumably called with the
# optimizer Net creates — confirm in network_module.
class CosineScheduler:
    """
    Callable factory for a cosine-annealing learning-rate scheduler.

    Calling an instance with an optimizer returns
    ``torch.optim.lr_scheduler.CosineAnnealingLR`` configured with the stored
    settings.

    NOTE(review): ``T_max`` is counted in scheduler steps. It is set to
    ``max_epochs`` in this notebook, which spans the whole run only if the
    scheduler is stepped once per epoch — confirm the interval used by Net.

    Args:
        T_max: Number of scheduler steps to anneal from the base LR to ``eta_min``.
        eta_min: Lower bound on the learning rate (default ``1e-6``, the value
            previously hard-coded here).
    """

    def __init__(self, T_max, eta_min=1e-6):
        self.T_max = T_max
        self.eta_min = eta_min

    def __call__(self, optimizer):
        """Return a CosineAnnealingLR scheduler bound to ``optimizer``."""
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.T_max,
            eta_min=self.eta_min,
        )

    def __repr__(self):
        return f"{type(self).__name__}(T_max={self.T_max}, eta_min={self.eta_min})"

# One cosine-annealing factory spanning CONFIG['max_epochs'] scheduler steps.
scheduler_class = CosineScheduler(CONFIG['max_epochs'])

# Wrap model + loss + optimizer/scheduler settings in the training module.
# The optimizer is passed as a class (AdamW), not an instance.
net = Net(
    model=model,
    criterion=criterion,
    optimizer=optim.AdamW,
    lr=CONFIG['learning_rate'],
    scheduler=scheduler_class,
)

# TensorBoard logs go under logs/<model_name>_training/.
logger = loggers.TensorBoardLogger(
    save_dir="logs/",
    name=f"{CONFIG['model_name']}_training",
)

print("Training components initialized!")

In [None]:
# Configure the Lightning Trainer.
# NOTE(review): because an explicit logger is supplied, Lightning's default
# checkpoint callback saves under the logger's run directory (logs/...),
# not under default_root_dir — confirm where checkpoints actually land.
trainer = L.Trainer(
    max_epochs=CONFIG['max_epochs'],
    accelerator='auto',
    logger=logger,
    log_every_n_steps=10,
    deterministic='warn',          # warn (rather than fail) on non-deterministic ops
    enable_checkpointing=True,
    default_root_dir='./checkpoints',
    gradient_clip_val=1.0,         # clip gradients for training stability
    accumulate_grad_batches=1,     # raise to emulate a larger batch size
)

print("Trainer initialized!")
print(f"Training will run for {CONFIG['max_epochs']} epochs")

In [None]:
# Start training.
# A failure is reported with a hint and then re-raised: swallowing it would let
# the validation and prediction cells below run silently on untrained weights.
print("Starting training...")
try:
    trainer.fit(net, dataset)
except Exception as e:
    print(f"Training failed with error: {e}")
    print("Please check your data directory structure and paths.")
    raise
else:
    print("Training completed successfully!")

In [None]:
# Score the weights currently held by `net` on the validation split
# (no checkpoint path is given, so the in-memory model is used).
print("Running validation...")
val_results = trainer.validate(net, dataset)
print(f"Validation results: {val_results}")

In [None]:
# Test the model (if test set has ground truth)
# Uncomment if your test set has ground truth masks
# print("Running test...")
# test_results = trainer.test(net, dataset)
# print("Test results:", test_results)

In [None]:
# Prediction on test set and save results
import torch.nn.functional as F

def _original_image_size(img_path):
    """Return ``(width, height)`` of the image file at ``img_path``, or None if it is not a file."""
    path = Path(img_path)
    if not path.is_file():
        return None
    # PIL only parses the header here. NOTE(review): EXIF orientation is not
    # applied — confirm this matches how the datamodule loads test images.
    with Image.open(path) as im:
        return im.size


def _plot_test_prediction(image, mask, filename):
    """Show one un-normalised input image next to its predicted mask."""
    img = image.permute(1, 2, 0).cpu().numpy()
    # Undo ImageNet mean/std normalisation — assumes the datamodule used these stats.
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)

    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.title(f'Original Image: {filename}')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(mask, cmap='gray')
    plt.title(f'Predicted Mask: {filename}')
    plt.axis('off')

    plt.tight_layout()
    plt.show()


def predict_and_save_test_results(model, dataset, output_dir="./predictions",
                                  threshold=0.5, resize_to_original=True,
                                  num_preview=4):
    """
    Predict on the test set and save binary mask predictions.

    Each mask is written to ``output_dir`` as ``<image stem>_pred.png`` with
    foreground pixels set to 255 and background to 0.

    Args:
        model: Module called as ``model(images)`` that returns logits. If it
            wraps an inner network at ``model.model`` whose ``deep_supervision``
            attribute is truthy, the output is unpacked as ``(logits, _)``.
        dataset: Object whose ``test_dataloader()`` yields
            ``(images, image_paths)`` batches.
        output_dir: Destination directory; created if missing.
        threshold: Probability cut-off applied after the sigmoid.
        resize_to_original: When True and an image path points to an existing
            file, the mask is resized (nearest neighbour, so it stays binary)
            to that file's pixel size before saving. Otherwise it is saved at
            the model's input resolution.
        num_preview: How many predictions from the first batch to plot.

    Raises:
        OSError: If OpenCV fails to write a mask file.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Move the model once, up front, rather than on every batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # getattr keeps this working for plain modules that have no `.model` wrapper.
    inner = getattr(model, "model", None)
    deep_supervision = bool(getattr(inner, "deep_supervision", False))

    test_loader = dataset.test_dataloader()

    with torch.no_grad():
        for batch_idx, (images, image_paths) in enumerate(test_loader):
            images = images.to(device)

            if deep_supervision:
                logits, _ = model(images)
            else:
                logits = model(images)

            # Probabilities -> binary masks.
            masks = (torch.sigmoid(logits) > threshold).float()

            for i, img_path in enumerate(image_paths):
                filename = Path(img_path).stem

                # Assumes masks is (B, 1, H, W): channel 0 is the foreground map.
                mask = (masks[i, 0].cpu().numpy() * 255).astype(np.uint8)

                out_mask = mask
                if resize_to_original:
                    size = _original_image_size(img_path)
                    # cv2.resize takes dsize as (width, height), same order as PIL's .size.
                    if size is not None and size != (mask.shape[1], mask.shape[0]):
                        out_mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

                save_path = os.path.join(output_dir, f"{filename}_pred.png")
                # cv2.imwrite signals failure by returning False, not by raising.
                if not cv2.imwrite(save_path, out_mask):
                    raise OSError(f"Failed to write predicted mask to {save_path}")

                if batch_idx == 0 and i < num_preview:
                    _plot_test_prediction(images[i], mask, filename)

    print(f"Predictions saved to: {output_dir}")

# Write test-set mask predictions using the trained network module.
predict_and_save_test_results(model=net, dataset=dataset)

## Training Summary

Once all cells have run, the notebook has completed the following steps:

1. **Model Architecture**: Used improved EffiSegNet with EfficientNet V2-S backbone
2. **Training**: Trained on your custom dataset with Train/Val splits
3. **Evaluation**: Validated the model performance
4. **Prediction**: Generated predictions for test images and saved them to `./predictions/` folder

### Next Steps:
- Check the `./predictions/` folder for your test set predictions
- Review training logs in TensorBoard: `tensorboard --logdir logs/`
- Analyze model performance metrics
- Fine-tune hyperparameters if needed

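### Model Performance Metrics:
With $P$ the set of predicted foreground pixels and $G$ the set of ground-truth foreground pixels:
- Dice score: overlap between the predicted and ground-truth masks, $\mathrm{Dice} = \frac{2|P \cap G|}{|P| + |G|}$
- IoU (Intersection over Union): $\mathrm{IoU} = \frac{|P \cap G|}{|P \cup G|}$, a stricter overlap measure than Dice
- Precision / Recall: precision is the fraction of predicted foreground pixels that are correct, $\frac{|P \cap G|}{|P|}$; recall is the fraction of ground-truth foreground pixels that are detected, $\frac{|P \cap G|}{|G|}$
- F1 score: the harmonic mean of precision and recall; for binary masks it equals the Dice score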