# Advanced Melanoma Segmentation with Self-Supervised Learning + Mask R-CNN

This notebook implements a state-of-the-art deep learning pipeline for melanoma lesion segmentation using:

1. **Self-Supervised Learning (SSL)**: Pre-train a Vision Transformer (ViT) using the Masked Autoencoder (MAE) approach
2. **Mask R-CNN Fine-tuning**: Use the SSL backbone with Detectron2 for instance segmentation
3. **Advanced Augmentation**: Albumentations pipeline for robust training

## Pipeline Overview
- Phase 1: Environment Setup & Imports
- Phase 2: SSL Pre-training with MAE
- Phase 3: Mask R-CNN Fine-tuning with Custom ViT Backbone
- Phase 4: Evaluation & Model Export


## Phase 1: Environment Setup & Imports


In [None]:
# Core imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils

# Detectron2 imports
import detectron2
from detectron2.utils.logger import setup_logger
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data import build_detection_train_loader, build_detection_test_loader
from detectron2.data import detection_utils as utils
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

# Hugging Face Transformers for MAE
from transformers import ViTMAEForPreTraining, ViTMAEConfig, ViTFeatureExtractor

# Additional libraries
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import json
from pathlib import Path
import random
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Seed every random number generator used in the notebook with the same
# value so that runs are repeatable.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(42)

# Turn on Detectron2's logger
setup_logger()

# Report library versions and the available accelerator
_cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"Detectron2 version: {detectron2.__version__}")
print(f"CUDA available: {_cuda_ready}")
if _cuda_ready:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"CUDA device: {_gpu_name}")
    print(f"CUDA memory: {_gpu_bytes / 1e9:.1f} GB")


In [None]:
# Configuration and Paths
class Config:
    """
    Central configuration for the notebook: paths, hyper-parameters, device.

    All settings are class attributes, so they can be read from the class or
    from an instance. Creating an instance makes sure the data and model
    directories exist on disk.
    """

    # Data paths
    DATA_ROOT = Path("data")
    UNLABELED_IMAGES_PATH = DATA_ROOT / "images"  # For SSL pre-training
    TRAIN_IMAGES_PATH = DATA_ROOT / "train" / "images"  # For supervised training
    TRAIN_MASKS_PATH = DATA_ROOT / "train" / "masks"
    VAL_IMAGES_PATH = DATA_ROOT / "val" / "images"
    VAL_MASKS_PATH = DATA_ROOT / "val" / "masks"

    # Model paths
    MODELS_PATH = Path("models")
    SSL_BACKBONE_PATH = MODELS_PATH / "ssl_vit_backbone.pth"
    FINAL_MODEL_PATH = MODELS_PATH / "final_lesion_segmenter.pth"
    CONFIG_PATH = MODELS_PATH / "config.yaml"

    # Training parameters
    IMAGE_SIZE = (224, 224)  # For ViT
    BATCH_SIZE_SSL = 32  # For MAE pre-training
    BATCH_SIZE_DETECTRON = 4  # For Mask R-CNN training
    NUM_EPOCHS_SSL = 100
    LEARNING_RATE_SSL = 1.5e-4
    LEARNING_RATE_DETECTRON = 1e-4
    MAX_ITER_DETECTRON = 5000

    # MAE parameters
    MASK_RATIO = 0.75  # 75% of patches masked
    PATCH_SIZE = 16

    # Device
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(self):
        # ``__post_init__`` is a dataclass hook and is never invoked on a
        # plain class, so the directories are created explicitly here.
        self.ensure_directories()

    def ensure_directories(self):
        """Create the model and data directories if they don't exist."""
        for path in [self.MODELS_PATH, self.DATA_ROOT,
                     self.UNLABELED_IMAGES_PATH, self.TRAIN_IMAGES_PATH,
                     self.TRAIN_MASKS_PATH, self.VAL_IMAGES_PATH,
                     self.VAL_MASKS_PATH]:
            path.mkdir(parents=True, exist_ok=True)

    # Kept so that code calling the previous method name keeps working.
    __post_init__ = ensure_directories


config = Config()
print(f"Configuration loaded. Device: {config.DEVICE}")
print(f"Models will be saved to: {config.MODELS_PATH}")
print(f"Data should be placed in: {config.DATA_ROOT}")


## Phase 2: Self-Supervised Learning with MAE

### Dataset for Unlabeled Images


In [None]:
class UnlabeledSkinDataset(Dataset):
    """
    Dataset for unlabeled skin images used in SSL pre-training.

    Recursively collects every image file below ``data_dir`` and returns each
    one as a normalized ``(3, H, W)`` float tensor of size ``image_size``.

    Args:
        data_dir: Directory that is searched recursively for images.
        image_size: ``(height, width)`` of the returned tensors.

    Raises:
        ValueError: If no image file is found below ``data_dir``.
    """

    # Lower-case file suffixes that are treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

    def __init__(self, data_dir, image_size=(224, 224)):
        self.data_dir = Path(data_dir)
        self.image_size = image_size

        # Find all image files
        self.image_paths = self._find_images(self.data_dir)

        if not self.image_paths:
            raise ValueError(f"No images found in {data_dir}")

        print(f"Found {len(self.image_paths)} images for SSL training")

        # ViT preprocessing transforms
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # ImageNet normalization
        ])

    @classmethod
    def _find_images(cls, data_dir):
        """
        Return the sorted list of image files below ``data_dir``.

        The suffix is compared case-insensitively in a single directory walk.
        Globbing once per lower-case and once per upper-case pattern lists
        every file twice on case-insensitive filesystems and misses
        mixed-case suffixes such as ``.Jpg``. Sorting makes the sample order
        independent of the filesystem's listing order.
        """
        return sorted(
            path for path in Path(data_dir).rglob('*')
            if path.suffix.lower() in cls.IMAGE_EXTENSIONS and path.is_file()
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]

        try:
            # The context manager closes the file handle as soon as the
            # pixels have been decoded.
            with Image.open(image_path) as image:
                return self.transform(image.convert('RGB'))

        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort a
            # long pre-training run, so it is reported and replaced.
            print(f"Error loading image {image_path}: {e}")
            # Return a black image as fallback
            return torch.zeros(3, *self.image_size)

# Example usage (will be used later when data is available)
# ssl_dataset = UnlabeledSkinDataset(config.UNLABELED_IMAGES_PATH)
# ssl_dataloader = DataLoader(ssl_dataset, batch_size=config.BATCH_SIZE_SSL, 
#                           shuffle=True, num_workers=4, pin_memory=True)

print("UnlabeledSkinDataset class defined successfully")


### MAE Pre-training Implementation


In [None]:
class MAETrainer:
    """
    Trainer class for Masked Autoencoder (MAE) pre-training.
    Uses HuggingFace ViTMAEForPreTraining for efficient implementation.

    Args:
        config: Object providing ``DEVICE``, ``LEARNING_RATE_SSL``,
            ``NUM_EPOCHS_SSL``, ``IMAGE_SIZE``, ``MODELS_PATH`` and
            ``SSL_BACKBONE_PATH``; ``MASK_RATIO`` is used when present.
    """

    def __init__(self, config):
        self.config = config
        self.device = config.DEVICE

        # Initialize MAE model. The configured mask ratio is forwarded so the
        # setting actually takes effect; 0.75 is used if it is not defined.
        self.model = ViTMAEForPreTraining.from_pretrained(
            'facebook/vit-mae-base',
            mask_ratio=getattr(config, 'MASK_RATIO', 0.75),
        )
        self.model.to(self.device)

        # Optimizer and scheduler
        self.optimizer = optim.AdamW(self.model.parameters(), lr=config.LEARNING_RATE_SSL,
                                   weight_decay=0.05)

        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=config.NUM_EPOCHS_SSL
        )

        # Loss tracking
        self.train_losses = []

    def train_epoch(self, dataloader, epoch):
        """
        Train for one epoch.

        Args:
            dataloader: Iterable yielding batches of image tensors.
            epoch: 1-based epoch number, used only for log messages.

        Returns:
            Mean reconstruction loss over the batches of this epoch.
        """
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx, images in enumerate(dataloader):
            images = images.to(self.device)

            # Forward pass through MAE
            outputs = self.model(images)
            loss = outputs.loss

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

            # Log progress
            if batch_idx % 10 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}/{len(dataloader)}, '
                      f'Loss: {loss.item():.4f}')

        avg_loss = epoch_loss / num_batches
        self.train_losses.append(avg_loss)

        return avg_loss

    def visualize_reconstruction(self, dataloader, epoch, num_samples=4):
        """
        Visualize MAE reconstruction results.

        Args:
            dataloader: Iterable yielding batches of normalized image tensors.
            epoch: Epoch number shown in the figure title.
            num_samples: Maximum number of images taken from the first batch.
        """
        self.model.eval()

        with torch.no_grad():
            # Get a batch of images
            images = next(iter(dataloader))[:num_samples].to(self.device)
            # The batch may hold fewer images than requested.
            shown = images.shape[0]

            # Get MAE outputs
            outputs = self.model(images)

            # ``logits`` holds per-patch pixel predictions of shape
            # (B, num_patches, patch_size**2 * 3); fold them back into
            # (B, 3, H, W) images before treating them as pictures.
            reconstructed = self.model.unpatchify(outputs.logits)

            # Denormalize images for visualization
            mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

            original = images * std + mean
            reconstructed = reconstructed * std + mean

            # Clip values to [0, 1]
            original = torch.clamp(original, 0, 1)
            reconstructed = torch.clamp(reconstructed, 0, 1)

            # squeeze=False keeps ``axes`` two-dimensional for a single column
            fig, axes = plt.subplots(2, shown, figsize=(15, 6), squeeze=False)

            for i in range(shown):
                # Original images
                axes[0, i].imshow(original[i].cpu().permute(1, 2, 0))
                axes[0, i].set_title(f'Original {i+1}')
                axes[0, i].axis('off')

                # Reconstructed images
                axes[1, i].imshow(reconstructed[i].cpu().permute(1, 2, 0))
                axes[1, i].set_title(f'Reconstructed {i+1}')
                axes[1, i].axis('off')

            plt.suptitle(f'MAE Reconstruction - Epoch {epoch}')
            plt.tight_layout()
            plt.show()

    def train(self, dataloader):
        """
        Main training loop: run every epoch, step the learning-rate
        scheduler, and show reconstructions every 20 epochs.
        """
        print(f"Starting MAE pre-training for {self.config.NUM_EPOCHS_SSL} epochs...")

        for epoch in range(self.config.NUM_EPOCHS_SSL):
            print(f"\nEpoch {epoch+1}/{self.config.NUM_EPOCHS_SSL}")

            # Train for one epoch
            avg_loss = self.train_epoch(dataloader, epoch+1)

            # Update learning rate
            self.scheduler.step()

            # Visualize reconstructions every 20 epochs
            if (epoch + 1) % 20 == 0:
                self.visualize_reconstruction(dataloader, epoch+1)

            print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")

        print("MAE pre-training completed!")

    @staticmethod
    def _encoder_state_dict(named_tensors):
        """
        Build the encoder-only state dict from ``(name, tensor)`` pairs.

        Entries whose name starts with ``decoder`` are dropped. A leading
        ``vit.`` is removed so the keys match a stand-alone ViT encoder, and
        tensors are detached so that plain tensors are serialized.
        """
        encoder_state = {}
        for name, tensor in named_tensors:
            if name.startswith('decoder'):  # Only save encoder weights
                continue
            key = name[len('vit.'):] if name.startswith('vit.') else name
            encoder_state[key] = tensor.detach().cpu()
        return encoder_state

    def save_backbone(self):
        """Save only the ViT encoder weights (backbone) plus a JSON config."""
        print("Saving SSL backbone weights...")

        # Extract encoder weights
        encoder_state_dict = self._encoder_state_dict(self.model.named_parameters())

        # The target directories may not exist yet.
        self.config.SSL_BACKBONE_PATH.parent.mkdir(parents=True, exist_ok=True)
        self.config.MODELS_PATH.mkdir(parents=True, exist_ok=True)

        # Save backbone weights
        torch.save(encoder_state_dict, self.config.SSL_BACKBONE_PATH)

        # Also save the full model config for reference
        config_dict = {
            'model_name': 'facebook/vit-mae-base',
            'image_size': self.config.IMAGE_SIZE,
            'patch_size': 16,
            'hidden_size': 768,
            'num_hidden_layers': 12,
            'num_attention_heads': 12,
            'intermediate_size': 3072,
            'hidden_dropout_prob': 0.0,
            'attention_probs_dropout_prob': 0.0,
            'initializer_range': 0.02,
            'layer_norm_eps': 1e-12,
        }

        with open(self.config.MODELS_PATH / "ssl_config.json", 'w') as f:
            json.dump(config_dict, f, indent=2)

        print(f"SSL backbone saved to {self.config.SSL_BACKBONE_PATH}")

print("MAETrainer class defined successfully")


In [None]:
# SSL Training Execution
def run_ssl_training():
    """
    Run the complete MAE pre-training workflow.

    Returns early (after printing a hint) when a backbone file is already
    present or when the unlabeled image directory is missing or empty.
    Otherwise it builds the dataset and loader, trains, saves the backbone
    and plots the per-epoch loss. Any exception raised during that work is
    reported on stdout instead of being propagated.
    """
    backbone_path = config.SSL_BACKBONE_PATH
    unlabeled_dir = config.UNLABELED_IMAGES_PATH

    # Nothing to do when a backbone has already been trained
    if backbone_path.exists():
        print(f"SSL backbone already exists at {backbone_path}")
        print("Skipping SSL training. To retrain, delete the existing file first.")
        return

    # Training needs at least one entry in the unlabeled directory
    has_unlabeled_data = unlabeled_dir.exists() and any(unlabeled_dir.iterdir())
    if not has_unlabeled_data:
        print(f"No unlabeled images found in {unlabeled_dir}")
        print("Please place unlabeled skin images in the data/images/ directory")
        return

    try:
        print("Creating unlabeled dataset...")
        loader = DataLoader(
            UnlabeledSkinDataset(unlabeled_dir),
            batch_size=config.BATCH_SIZE_SSL,
            shuffle=True,
            num_workers=2,  # Reduced for Windows compatibility
            pin_memory=torch.cuda.is_available(),
        )

        print("Initializing MAE trainer...")
        mae_trainer = MAETrainer(config)

        # Train, then persist the encoder weights
        mae_trainer.train(loader)
        mae_trainer.save_backbone()

        # Loss curve over epochs
        plt.figure(figsize=(10, 6))
        plt.plot(mae_trainer.train_losses)
        plt.title('MAE Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.show()

        print("SSL pre-training completed successfully!")

    except Exception as err:
        print(f"Error during SSL training: {err}")
        print("Make sure you have unlabeled images in the data/images/ directory")

# Uncomment the line below to run SSL training when data is available
# run_ssl_training()

print("SSL training function defined. Call run_ssl_training() when data is ready.")


## Phase 3: Mask R-CNN Fine-tuning with Custom ViT Backbone

### Dataset for Labeled Images (Detectron2 Format)


In [None]:
# Lower-case suffixes accepted for input images and tried for mask files.
_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
_MASK_SUFFIXES = ('.png', '.jpg', '.jpeg')


def _list_image_files(img_dir):
    """Return the image files directly inside ``img_dir``, sorted, each once."""
    img_dir = Path(img_dir)
    if not img_dir.is_dir():
        return []
    # One case-insensitive pass avoids the duplicates that separate
    # lower/upper-case glob patterns produce on case-insensitive filesystems;
    # sorting keeps ``image_id`` stable between runs.
    return sorted(
        path for path in img_dir.iterdir()
        if path.suffix.lower() in _IMAGE_SUFFIXES and path.is_file()
    )


def _find_mask_path(mask_dir, img_path):
    """Return the ``<stem>_mask.<ext>`` file for ``img_path`` or ``None``."""
    mask_dir = Path(mask_dir)
    # The image's own suffix is tried first, then the alternatives.
    for suffix in (img_path.suffix,) + _MASK_SUFFIXES:
        candidate = mask_dir / (img_path.stem + "_mask" + suffix)
        if candidate.exists():
            return candidate
    return None


def get_melanoma_dicts(img_dir, mask_dir):
    """
    Load dataset in Detectron2 format.
    Converts binary masks to COCO format with bounding boxes and segmentation polygons.

    Args:
        img_dir: Directory containing the images (not searched recursively).
        mask_dir: Directory containing masks named ``<image stem>_mask.<ext>``.

    Returns:
        List of record dicts with ``file_name``, ``image_id``, ``height``,
        ``width`` and ``annotations``. Images that cannot be read or that
        have no readable mask are skipped.
    """
    dataset_dicts = []

    for idx, img_path in enumerate(_list_image_files(img_dir)):
        # Load image to get dimensions
        image = cv2.imread(str(img_path))
        if image is None:
            continue
        height, width = image.shape[:2]

        # Load corresponding mask
        mask_path = _find_mask_path(mask_dir, img_path)
        if mask_path is None:
            print(f"Warning: No mask found for {img_path}")
            continue

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"Warning: Could not load mask {mask_path}")
            continue

        # Find contours in mask (every non-zero pixel counts as foreground)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        objs = []
        for contour in contours:
            # Skip very small contours
            area = cv2.contourArea(contour)
            if area < 100:  # Minimum area threshold
                continue

            # Convert contour to polygon format
            polygon = contour.flatten().tolist()
            if len(polygon) < 6:
                # A polygon needs at least three (x, y) points.
                continue

            # Get bounding box
            x, y, w, h = cv2.boundingRect(contour)

            objs.append({
                "bbox": [x, y, x + w, y + h],  # [x1, y1, x2, y2] format
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [polygon],
                "category_id": 0,  # Single class: lesion
                "area": area,
            })

        dataset_dicts.append({
            "file_name": str(img_path),
            "image_id": idx,
            "height": height,
            "width": width,
            "annotations": objs,
        })

    print(f"Loaded {len(dataset_dicts)} images from {img_dir}")
    return dataset_dicts

def register_datasets():
    """
    Register the training and validation datasets with Detectron2.

    Safe to call repeatedly (for example when a notebook cell is re-run):
    an existing registration under the same name is removed first, because
    Detectron2 refuses to register a name twice.
    """

    def _register(name, loader):
        # Replace a stale registration instead of failing on the duplicate.
        if name in DatasetCatalog.list():
            DatasetCatalog.remove(name)
        DatasetCatalog.register(name, loader)
        MetadataCatalog.get(name).set(thing_classes=["lesion"])

    # The loaders read ``config`` lazily, i.e. when the dataset is built.
    _register("melanoma_train",
              lambda: get_melanoma_dicts(str(config.TRAIN_IMAGES_PATH),
                                         str(config.TRAIN_MASKS_PATH)))
    _register("melanoma_val",
              lambda: get_melanoma_dicts(str(config.VAL_IMAGES_PATH),
                                         str(config.VAL_MASKS_PATH)))

    print("Datasets registered with Detectron2")

# Make both splits known to Detectron2
register_datasets()

# Smoke test: load both splits once and summarize what was found.
# A failure is reported on stdout rather than raised.
try:
    _train_dirs = (str(config.TRAIN_IMAGES_PATH), str(config.TRAIN_MASKS_PATH))
    train_dicts = get_melanoma_dicts(*_train_dirs)
    _val_dirs = (str(config.VAL_IMAGES_PATH), str(config.VAL_MASKS_PATH))
    val_dicts = get_melanoma_dicts(*_val_dirs)

    print(f"Training samples: {len(train_dicts)}")
    print(f"Validation samples: {len(val_dicts)}")

    if train_dicts:
        _first_record = train_dicts[0]
        print(f"Sample training image: {_first_record['file_name']}")
        print(f"Number of annotations: {len(_first_record['annotations'])}")

except Exception as err:
    print(f"Dataset loading test failed: {err}")
    print("Make sure you have labeled images and masks in data/train/ and data/val/ directories")


In [None]:
# Albumentations augmentation pipeline
def get_augmentation_pipeline():
    """
    Build the Albumentations pipelines used for training and validation.

    Returns:
        ``(train_transform, val_transform)``. The training pipeline chains
        geometric, photometric, distortion, noise/blur and dropout
        augmentations; the validation pipeline only resizes to 512x512.
        Both end with ImageNet normalization and tensor conversion, and both
        expect ``pascal_voc`` boxes with a ``class_labels`` label field.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def closing_steps():
        # Fresh transform objects for every pipeline
        return [
            A.Normalize(mean=list(imagenet_mean), std=list(imagenet_std)),
            ToTensorV2(),
        ]

    def box_settings():
        return A.BboxParams(format='pascal_voc', label_fields=['class_labels'])

    # Shared border handling for the warping transforms
    constant_border = dict(border_mode=cv2.BORDER_CONSTANT, value=0)

    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=45,
                           p=0.7, **constant_border),
        A.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0),
                            ratio=(0.8, 1.2), p=0.5),
    ]

    photometric = [
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30,
                             val_shift_limit=20, p=0.3),
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),
    ]

    distortions = [
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3,
                           **constant_border),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.2,
                         **constant_border),
        A.OpticalDistortion(distort_limit=0.1, shift_limit=0.05, p=0.2,
                            **constant_border),
    ]

    noise_and_blur = [
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
        A.GaussianBlur(blur_limit=(3, 7), p=0.2),
        A.MotionBlur(blur_limit=7, p=0.2),
    ]

    # Cutout-style occlusion
    occlusion = [
        A.CoarseDropout(max_holes=8, max_height=32, max_width=32,
                        min_holes=1, min_height=8, min_width=8,
                        fill_value=0, p=0.3),
    ]

    train_steps = (geometric + photometric + distortions
                   + noise_and_blur + occlusion + closing_steps())
    train_transform = A.Compose(train_steps, bbox_params=box_settings())

    # Validation: deterministic resize only
    val_steps = [A.Resize(height=512, width=512)] + closing_steps()
    val_transform = A.Compose(val_steps, bbox_params=box_settings())

    return train_transform, val_transform

# Test augmentation pipeline
def visualize_augmentations(image_path, mask_path, num_samples=4):
    """
    Visualize augmentation effects on a sample image.

    Column 0 shows the original image and mask; the remaining
    ``num_samples - 1`` columns show independently augmented versions.

    Args:
        image_path: Path to the image file.
        mask_path: Path to the matching grayscale mask file.
        num_samples: Total number of columns (original included), >= 1.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
        FileNotFoundError: If the image or the mask cannot be read.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    # Load image and mask; cv2.imread signals failure by returning None.
    image = cv2.imread(str(image_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    # Get augmentation pipeline
    train_transform, _ = get_augmentation_pipeline()

    # squeeze=False keeps ``axes`` two-dimensional even for one column
    fig, axes = plt.subplots(2, num_samples, figsize=(20, 8), squeeze=False)

    # Original
    axes[0, 0].imshow(image)
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')

    axes[1, 0].imshow(mask, cmap='gray')
    axes[1, 0].set_title('Original Mask')
    axes[1, 0].axis('off')

    # De-normalization constants matching the pipeline's Normalize step
    std = np.array([0.229, 0.224, 0.225])
    mean = np.array([0.485, 0.456, 0.406])

    # Augmented samples
    for i in range(1, num_samples):
        try:
            # The pipeline is configured with bbox parameters, so a
            # full-image dummy box and label are supplied.
            augmented = train_transform(
                image=image,
                mask=mask,
                bboxes=[[0, 0, image.shape[1], image.shape[0]]],  # Dummy bbox
                class_labels=[0]  # Dummy class
            )

            aug_image = augmented['image'].permute(1, 2, 0).numpy()
            aug_mask = augmented['mask'].numpy()

            # Denormalize image for display
            aug_image = np.clip(aug_image * std + mean, 0, 1)

            axes[0, i].imshow(aug_image)
            axes[0, i].set_title(f'Augmented {i}')
            axes[0, i].axis('off')

            axes[1, i].imshow(aug_mask, cmap='gray')
            axes[1, i].set_title(f'Augmented Mask {i}')
            axes[1, i].axis('off')

        except Exception as e:
            # Best-effort preview: a failing sample is marked in its panel
            # and the remaining samples are still drawn.
            print(f"Error in augmentation {i}: {e}")
            axes[0, i].text(0.5, 0.5, 'Error', ha='center', va='center')
            axes[1, i].text(0.5, 0.5, 'Error', ha='center', va='center')

    plt.tight_layout()
    plt.show()

print("Augmentation pipeline defined successfully")
print("Use visualize_augmentations(image_path, mask_path) to test augmentations")


In [None]:
import fvcore.nn.weight_init as weight_init
from detectron2.layers import Conv2d, ShapeSpec
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY

class ViTFeatureExtractor(nn.Module):
    """
    Extract multi-scale features from ViT encoder.
    Converts 1D patch tokens back to 2D feature maps at different scales.

    Args:
        vit_model: Model exposing ``config.patch_size``,
            ``config.hidden_size``, ``embeddings`` (with
            ``patch_embeddings.projection``, ``cls_token`` and
            ``position_embeddings``) and ``encoder.layer``.
        feature_layers: 1-based indices of the encoder layers whose outputs
            are returned.
    """

    def __init__(self, vit_model, feature_layers=(3, 6, 9, 12)):
        super().__init__()
        self.vit_model = vit_model
        # Copied into a fresh list so no default object is shared between
        # instances.
        self.feature_layers = list(feature_layers)
        self.patch_size = vit_model.config.patch_size
        self.hidden_size = vit_model.config.hidden_size

        # Freeze ViT parameters
        for param in self.vit_model.parameters():
            param.requires_grad = False

        # Feature projection layers to match Detectron2's expected channels
        self.feature_projections = nn.ModuleList([
            nn.Conv2d(self.hidden_size, 256, 1) for _ in range(len(self.feature_layers))
        ])

        # Initialize projection layers
        for proj in self.feature_projections:
            weight_init.c2_msra_fill(proj)

    @staticmethod
    def _resize_position_embeddings(position_embeddings, grid_hw):
        """
        Resize ``(1, 1 + N, D)`` position embeddings to a new patch grid.

        The first entry (CLS) is kept as is; the remaining ``N`` entries are
        treated as a square grid and bicubically interpolated to
        ``grid_hw = (rows, cols)``. The input is returned unchanged when the
        grid already matches.

        Raises:
            ValueError: If ``N`` is not a perfect square.
        """
        grid_h, grid_w = grid_hw
        num_positions = position_embeddings.shape[1] - 1
        side = int(round(num_positions ** 0.5))
        if side * side != num_positions:
            raise ValueError(
                f"Cannot arrange {num_positions} position embeddings as a square grid"
            )
        if (grid_h, grid_w) == (side, side):
            return position_embeddings

        dim = position_embeddings.shape[-1]
        cls_pos = position_embeddings[:, :1]
        patch_pos = position_embeddings[:, 1:].reshape(1, side, side, dim).permute(0, 3, 1, 2)
        patch_pos = nn.functional.interpolate(
            patch_pos, size=(grid_h, grid_w), mode='bicubic', align_corners=False
        )
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, dim)
        return torch.cat((cls_pos, patch_pos), dim=1)

    def forward(self, x):
        """
        Extract multi-scale features from ViT.
        Args:
            x: Input tensor of shape (B, C, H, W). H and W may differ from
               the resolution the ViT was trained at and from each other.
        Returns:
            List of feature maps, one per selected layer in order of depth,
            each of shape (B, 256, H // patch_size, W // patch_size).
        """
        B = x.shape[0]
        embed = self.vit_model.embeddings

        # Apply the patch projection convolution directly so that inputs of
        # any resolution are accepted and the 2D grid shape is known.
        patch_map = embed.patch_embeddings.projection(x)  # (B, hidden, gh, gw)
        grid_h, grid_w = patch_map.shape[-2:]
        patches = patch_map.flatten(2).transpose(1, 2)  # (B, gh*gw, hidden)

        # Add CLS token and positional embeddings (resized to this grid)
        cls_tokens = embed.cls_token.expand(B, -1, -1)
        embeddings = torch.cat((cls_tokens, patches), dim=1)
        embeddings = embeddings + self._resize_position_embeddings(
            embed.position_embeddings, (grid_h, grid_w)
        )

        # Apply transformer blocks and extract features at specified layers
        hidden_states = embeddings
        features = []

        for depth, block in enumerate(self.vit_model.encoder.layer, start=1):
            block_out = block(hidden_states)
            # Accept both a tuple/list of outputs and a bare tensor.
            hidden_states = block_out[0] if isinstance(block_out, (tuple, list)) else block_out

            if depth not in self.feature_layers:
                continue

            # Remove CLS token and fold the tokens back into the 2D grid
            patch_features = hidden_states[:, 1:].reshape(
                B, grid_h, grid_w, self.hidden_size
            ).permute(0, 3, 1, 2)  # (B, C, H, W)

            # Project to expected channel size
            projection = self.feature_projections[self.feature_layers.index(depth)]
            features.append(projection(patch_features))

        return features

class ViTBackbone(Backbone):
    """
    ViT backbone for Detectron2.

    Wraps ``ViTFeatureExtractor`` and exposes its feature maps under the
    names ``res3`` ... ``res6``. All maps share one resolution: the input
    size divided by the ViT patch size.

    Args:
        vit_model: ViT encoder passed on to ``ViTFeatureExtractor``.
        feature_layers: 1-based encoder layers to tap; the first four are
            exposed.
    """

    # Output names, assigned to the extracted features in order of depth.
    _FEATURE_NAMES = ("res3", "res4", "res5", "res6")
    # Channel count produced by the extractor's 1x1 projection layers.
    _PROJECTION_CHANNELS = 256

    def __init__(self, vit_model, feature_layers=(3, 6, 9, 12)):
        super().__init__()
        self.feature_extractor = ViTFeatureExtractor(vit_model, list(feature_layers))

        # One spatial position per patch, so the stride is the patch size
        # (16 for the base model) rather than a fixed number.
        self._patch_stride = int(self.feature_extractor.patch_size)

        num_outputs = min(len(self._FEATURE_NAMES),
                          len(self.feature_extractor.feature_layers))
        self._out_features = list(self._FEATURE_NAMES[:num_outputs])
        self._out_feature_channels = {
            name: self._PROJECTION_CHANNELS for name in self._out_features
        }
        self._out_feature_strides = {
            name: self._patch_stride for name in self._out_features
        }

    @property
    def size_divisibility(self):
        """Input height/width must be a multiple of the patch size."""
        return self._patch_stride

    def forward(self, x):
        """
        Forward pass through ViT backbone.
        Args:
            x: Input tensor of shape (B, C, H, W)
        Returns:
            Dict mapping feature name to feature map
        """
        features = self.feature_extractor(x)
        # zip stops at the shorter sequence, so surplus features are ignored.
        return dict(zip(self._out_features, features))

    def output_shape(self):
        """Return the ``ShapeSpec`` (channels, stride) of every output."""
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name],
                stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

def _timm_vit_to_hf_state(timm_state, num_layers=12):
    """
    Translate a timm ViT state dict into HuggingFace ``ViTModel`` key names.

    The fused ``qkv`` projection is split into equal query/key/value thirds.
    The position embedding is copied whole (CLS position included) and the
    final layer norm is mapped; the classification head is dropped.
    """
    hf_state = {
        'embeddings.patch_embeddings.projection.weight': timm_state['patch_embed.proj.weight'],
        'embeddings.patch_embeddings.projection.bias': timm_state['patch_embed.proj.bias'],
        'embeddings.cls_token': timm_state['cls_token'],
        # Both layouts keep the CLS position first, so nothing is removed.
        'embeddings.position_embeddings': timm_state['pos_embed'],
        'layernorm.weight': timm_state['norm.weight'],
        'layernorm.bias': timm_state['norm.bias'],
    }

    # (HuggingFace suffix, timm suffix) pairs that map one-to-one per layer.
    direct_pairs = [
        ('attention.output.dense', 'attn.proj'),
        ('layernorm_before', 'norm1'),
        ('layernorm_after', 'norm2'),
        ('intermediate.dense', 'mlp.fc1'),
        ('output.dense', 'mlp.fc2'),
    ]

    for i in range(num_layers):
        prefix = f'encoder.layer.{i}'
        timm_prefix = f'blocks.{i}'

        # Attention: split the fused projection along the output dimension
        for kind in ('weight', 'bias'):
            query, key, value = timm_state[f'{timm_prefix}.attn.qkv.{kind}'].chunk(3, dim=0)
            hf_state[f'{prefix}.attention.attention.query.{kind}'] = query
            hf_state[f'{prefix}.attention.attention.key.{kind}'] = key
            hf_state[f'{prefix}.attention.attention.value.{kind}'] = value

        for hf_name, timm_name in direct_pairs:
            for kind in ('weight', 'bias'):
                hf_state[f'{prefix}.{hf_name}.{kind}'] = timm_state[f'{timm_prefix}.{timm_name}.{kind}']

    return hf_state


def _normalize_ssl_state(ssl_weights):
    """Drop decoder entries and a leading ``vit.`` from checkpoint keys."""
    cleaned = {}
    for name, tensor in ssl_weights.items():
        if name.startswith('decoder'):
            continue
        key = name[len('vit.'):] if name.startswith('vit.') else name
        cleaned[key] = tensor
    return cleaned


@BACKBONE_REGISTRY.register()
def build_vit_backbone(cfg, input_shape: ShapeSpec):
    """
    Build ViT backbone for Detectron2.
    This function is called by Detectron2's model builder.

    Loads the SSL pre-trained encoder from ``config.SSL_BACKBONE_PATH`` when
    that file exists; otherwise falls back to ImageNet weights from timm.

    Args:
        cfg: Detectron2 config (not read here).
        input_shape: Input ``ShapeSpec`` (not read here).

    Returns:
        A ``ViTBackbone`` wrapping the loaded encoder.

    Raises:
        RuntimeError: If the checkpoint lacks weights the encoder needs.
    """
    from transformers import ViTModel, ViTConfig

    # Check if SSL backbone exists
    ssl_path = config.SSL_BACKBONE_PATH
    if not ssl_path.exists():
        print(f"Warning: SSL backbone not found at {ssl_path}")
        print("Using ImageNet pre-trained ViT instead")

        # Use ImageNet pre-trained ViT
        timm_model = timm.create_model('vit_base_patch16_224', pretrained=True)

        # Create ViT config
        vit_config = ViTConfig(
            image_size=224,
            patch_size=16,
            num_channels=3,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            layer_norm_eps=1e-6,  # epsilon used by timm's ViT layer norms
        )

        # The pooler has no timm counterpart and is not used for features.
        vit_model = ViTModel(vit_config, add_pooling_layer=False)

        # Map timm weights to HuggingFace format and load them
        vit_model.load_state_dict(_timm_vit_to_hf_state(timm_model.state_dict(), num_layers=12))

    else:
        print(f"Loading SSL backbone from {ssl_path}")

        # Load config; fall back to the library defaults (ViT-Base) when the
        # JSON file written alongside the weights is missing.
        ssl_config_path = config.MODELS_PATH / "ssl_config.json"
        config_dict = {}
        if ssl_config_path.exists():
            with open(ssl_config_path, 'r') as f:
                config_dict = json.load(f)
        else:
            print(f"Warning: {ssl_config_path} not found, using default ViT-Base config")
        # Bookkeeping entry, not an architecture setting.
        config_dict.pop('model_name', None)

        # Create ViT model (without the pooler, which MAE does not train)
        vit_config = ViTConfig(**config_dict)
        vit_model = ViTModel(vit_config, add_pooling_layer=False)

        # Load SSL weights.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ssl_weights = _normalize_ssl_state(torch.load(ssl_path, map_location='cpu'))
        load_result = vit_model.load_state_dict(ssl_weights, strict=False)
        if load_result.missing_keys:
            # Missing weights would silently stay randomly initialized.
            raise RuntimeError(
                f"SSL checkpoint {ssl_path} lacks weights for: {load_result.missing_keys}"
            )
        if load_result.unexpected_keys:
            print(f"Warning: ignored unexpected checkpoint keys: {load_result.unexpected_keys}")

        print("SSL backbone loaded successfully!")

    # Create backbone
    backbone = ViTBackbone(vit_model)
    return backbone

print("Custom ViT backbone implementation completed")
print("Backbone registered with Detectron2 as 'build_vit_backbone'")


In [None]:
def setup_detectron2_config():
    """
    Configure Detectron2 for Mask R-CNN training with ViT backbone.

    Starts from the model-zoo ``mask_rcnn_R_50_FPN_3x`` config, swaps the
    backbone builder for ``"build_vit_backbone"`` and overrides input, model,
    solver and output settings. Solver values and the output location are read
    from the module-level ``config`` object.

    Returns:
        The populated Detectron2 config node.

    Side effects:
        Creates ``cfg.OUTPUT_DIR`` on disk if it does not exist.
    """

    # Get default config
    cfg = get_cfg()

    # Start with a standard Mask R-CNN config
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    # NOTE(review): the merged YAML also sets MODEL.WEIGHTS to a ResNet-50
    # checkpoint; confirm whether that should be cleared for the ViT backbone.

    # Replace backbone with our custom ViT backbone
    cfg.MODEL.BACKBONE.NAME = "build_vit_backbone"

    # Input configuration
    cfg.INPUT.MIN_SIZE_TRAIN = (512, 640, 704, 768)  # Multi-scale training
    cfg.INPUT.MAX_SIZE_TRAIN = 1024
    cfg.INPUT.MIN_SIZE_TEST = 512
    cfg.INPUT.MAX_SIZE_TEST = 1024

    # Images reach the model in BGR channel order: the training mapper converts
    # back to BGR before tensorising and the visualisation helper feeds
    # cv2.imread output directly. Make that explicit (it is also Detectron2's
    # default) so the normalisation statistics below can be checked against it.
    cfg.INPUT.FORMAT = "BGR"

    # Pixel normalization (ImageNet stats for ViT), listed in BGR order to match
    # cfg.INPUT.FORMAT. The previous RGB-ordered lists applied the red-channel
    # statistics to the blue channel and vice versa.
    # NOTE(review): the backbone consequently receives channels in BGR order;
    # confirm the ViT backbone wrapper accounts for that.
    cfg.MODEL.PIXEL_MEAN = [103.53, 116.28, 123.675]
    cfg.MODEL.PIXEL_STD = [57.375, 57.12, 58.395]

    # Model configuration
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # Single class: lesion
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5

    # RPN configuration
    cfg.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 2000
    cfg.MODEL.RPN.POST_NMS_TOPK_TRAIN = 1000
    cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 1000
    cfg.MODEL.RPN.POST_NMS_TOPK_TEST = 1000

    # Training configuration
    cfg.DATASETS.TRAIN = ("melanoma_train",)
    cfg.DATASETS.TEST = ("melanoma_val",)

    cfg.DATALOADER.NUM_WORKERS = 2  # Reduced for Windows compatibility
    cfg.DATALOADER.ASPECT_RATIO_GROUPING = True

    # Solver configuration
    cfg.SOLVER.IMS_PER_BATCH = config.BATCH_SIZE_DETECTRON
    cfg.SOLVER.BASE_LR = config.LEARNING_RATE_DETECTRON
    cfg.SOLVER.MAX_ITER = config.MAX_ITER_DETECTRON
    # NOTE(review): the decay steps are fixed while MAX_ITER comes from
    # `config`; verify they still fall inside the schedule when MAX_ITER changes.
    cfg.SOLVER.STEPS = (3000, 4500)  # Learning rate decay steps
    cfg.SOLVER.GAMMA = 0.1
    cfg.SOLVER.WARMUP_ITERS = 500
    cfg.SOLVER.WARMUP_FACTOR = 1.0 / 1000
    cfg.SOLVER.WEIGHT_DECAY = 0.0001
    cfg.SOLVER.CHECKPOINT_PERIOD = 500

    # Testing configuration
    cfg.TEST.EVAL_PERIOD = 500

    # Output configuration
    cfg.OUTPUT_DIR = str(config.MODELS_PATH / "detectron2_output")

    # Create output directory
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    return cfg

# Build the shared Detectron2 config once; later cells read this module-level `cfg`.
cfg = setup_detectron2_config()

# Echo the settings most worth double-checking before a long training run.
print("Detectron2 configuration setup completed")
print(
    f"Output directory: {cfg.OUTPUT_DIR}",
    f"Training iterations: {cfg.SOLVER.MAX_ITER}",
    f"Batch size: {cfg.SOLVER.IMS_PER_BATCH}",
    f"Learning rate: {cfg.SOLVER.BASE_LR}",
    sep="\n",
)


### Custom Trainer with Augmentation Integration


In [None]:
class AlbumentationsMapper:
    """
    Custom mapper that applies Albumentations augmentations to Detectron2 dataset.

    For every dataset item the mapper loads the image, rasterises each polygon
    annotation into a binary mask, optionally runs the Albumentations pipeline
    on image / masks / boxes together, and returns a dict carrying:

    * ``"image"``: float32 tensor in (C, H, W) layout, BGR channel order
    * ``"height"`` / ``"width"``: size of that tensor
    * ``"instances"``: Detectron2 ``Instances`` with boxes, classes and bit masks
    * ``"annotations"``: per-instance dicts with tensor ``bbox`` / ``segmentation``
      (kept so existing consumers of this key continue to work)
    """

    def __init__(self, augmentation_pipeline, is_train=True):
        """
        Args:
            augmentation_pipeline: Callable accepting ``image``, ``masks``,
                ``bboxes`` and ``class_labels`` keyword arguments and returning
                a dict with the same keys; may be ``None`` to disable.
            is_train: Augmentation is only applied when this is true.
        """
        self.augmentation_pipeline = augmentation_pipeline
        self.is_train = is_train

    @staticmethod
    def _rasterize_annotation(anno, height, width):
        """Draw one annotation's polygons into a (height, width) uint8 mask of 0/255."""
        mask = np.zeros((height, width), dtype=np.uint8)
        for seg in anno.get("segmentation", []):
            if len(seg) >= 6:  # At least 3 points for a polygon
                pts = np.array(seg).reshape(-1, 2).astype(np.int32)
                cv2.fillPoly(mask, [pts], 255)
        return mask

    @staticmethod
    def _bbox_to_xyxy(anno):
        """Return the annotation box as [x1, y1, x2, y2]; non-XYXY input is read as XYWH."""
        bbox = anno["bbox"]
        if anno["bbox_mode"] == BoxMode.XYXY_ABS:
            return list(bbox)
        x1, y1, w, h = bbox
        return [x1, y1, x1 + w, y1 + h]

    @staticmethod
    def _box_from_mask(mask):
        """Return the tight XYXY box around the non-zero pixels, or None for an empty mask."""
        ys, xs = np.nonzero(mask)
        if xs.size == 0:
            return None
        return [float(xs.min()), float(ys.min()), float(xs.max()) + 1.0, float(ys.max()) + 1.0]

    @classmethod
    def _realign(cls, aug_masks, aug_bboxes, aug_labels, original_labels):
        """
        Pair augmented masks, boxes and labels instance by instance.

        The pipeline may return fewer boxes/labels than masks (boxes that no
        longer qualify are dropped while every mask is kept). Zipping the lists
        in that state would attach masks to the wrong boxes, so when the counts
        differ the boxes are rebuilt from the masks, which are still in the
        original instance order, and instances whose mask became empty are
        discarded.
        """
        aug_masks = list(aug_masks)
        if len(aug_bboxes) == len(aug_masks):
            return aug_masks, list(aug_bboxes), list(aug_labels)

        masks, bboxes, labels = [], [], []
        for mask, label in zip(aug_masks, original_labels):
            box = cls._box_from_mask(mask)
            if box is not None:
                masks.append(mask)
                bboxes.append(box)
                labels.append(label)
        return masks, bboxes, labels

    def __call__(self, dataset_dict):
        """
        Apply augmentations to a single dataset item.
        Args:
            dataset_dict: Dataset item in Detectron2 format
        Returns:
            Augmented dataset item
        """
        dataset_dict = dataset_dict.copy()

        # Load image
        image = utils.read_image(dataset_dict["file_name"], format="BGR")

        # Convert BGR to RGB for albumentations
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Prepare masks and bboxes for augmentation
        masks = []
        bboxes = []
        class_labels = []

        # Items without an "annotations" key are treated as having no instances.
        for anno in dataset_dict.get("annotations", []):
            masks.append(
                self._rasterize_annotation(anno, dataset_dict["height"], dataset_dict["width"])
            )
            bboxes.append(self._bbox_to_xyxy(anno))
            class_labels.append(anno["category_id"])

        # Apply augmentation
        if self.is_train and self.augmentation_pipeline:
            try:
                augmented = self.augmentation_pipeline(
                    image=image,
                    masks=masks,
                    bboxes=bboxes,
                    class_labels=class_labels
                )
            except Exception as e:
                # Best effort: keep the un-augmented sample rather than abort
                # the data loader.
                print(f"Augmentation error: {e}")
            else:
                # Results are taken over only after the whole pipeline
                # succeeded, so image and targets can never get out of sync.
                image = augmented["image"]
                masks, bboxes, class_labels = self._realign(
                    augmented["masks"],
                    augmented["bboxes"],
                    augmented["class_labels"],
                    class_labels,
                )

        # Convert RGB back to BGR for Detectron2. This has to happen on every
        # path (augmented, fallback and test), not only after a successful
        # augmentation, otherwise those samples reach the model as RGB.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Convert image to a (C, H, W) float tensor
        image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1), dtype=np.float32))

        # Update dataset dict
        dataset_dict["image"] = image
        dataset_dict["height"] = image.shape[1]
        dataset_dict["width"] = image.shape[2]

        # Convert masks and bboxes back to Detectron2 format
        tensor_annotations = []
        array_annotations = []
        for mask, bbox, class_id in zip(masks, bboxes, class_labels):
            mask = np.ascontiguousarray(mask, dtype=np.uint8)
            x1, y1, x2, y2 = bbox

            tensor_annotations.append({
                "bbox": torch.as_tensor([x1, y1, x2, y2], dtype=torch.float32),
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": torch.as_tensor(mask),
                "category_id": class_id,
            })
            # Same instance expressed with plain lists / arrays, the form
            # `annotations_to_instances` accepts for bitmask segmentations.
            array_annotations.append({
                "bbox": [float(x1), float(y1), float(x2), float(y2)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": mask,
                "category_id": class_id,
            })

        dataset_dict["annotations"] = tensor_annotations

        # Detectron2 models take their ground truth from the "instances" entry.
        instances = utils.annotations_to_instances(
            array_annotations,
            (image.shape[1], image.shape[2]),
            mask_format="bitmask",
        )
        dataset_dict["instances"] = utils.filter_empty_instances(instances)

        return dataset_dict

class AugmentationTrainer(DefaultTrainer):
    """
    Custom trainer that integrates Albumentations with Detectron2.

    Training batches come from ``AlbumentationsMapper`` with the training
    transform returned by ``get_augmentation_pipeline()``; evaluation batches
    use the same mapper with augmentation disabled. A COCO-style evaluator is
    provided so the periodic evaluation requested through
    ``cfg.TEST.EVAL_PERIOD`` has something to run.
    """

    def __init__(self, cfg):
        """Build the trainer, then keep the training transform on the instance."""
        super().__init__(cfg)

        # Get augmentation pipeline (same source the train loader uses)
        self.train_transform = self.get_augmentation_pipeline()

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Build training data loader with augmentation.
        """
        mapper = AlbumentationsMapper(
            augmentation_pipeline=cls.get_augmentation_pipeline(),
            is_train=True
        )

        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def get_augmentation_pipeline(cls):
        """Get augmentation pipeline for training"""
        # The module-level factory returns a pair; only the first element
        # (the training transform) is used here.
        train_transform, _ = get_augmentation_pipeline()
        return train_transform

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Build test data loader without augmentation.
        """
        mapper = AlbumentationsMapper(
            augmentation_pipeline=None,
            is_train=False
        )

        return build_detection_test_loader(cfg, dataset_name, mapper=mapper)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Build the evaluator used for periodic and final evaluation.

        The base trainer leaves this hook unimplemented, in which case
        evaluation is skipped. This uses the same evaluator type and output
        directory as the standalone ``evaluate_model()`` helper.
        """
        # Imported locally: the file's top-level imports do not include it.
        from detectron2.evaluation import COCOEvaluator

        return COCOEvaluator(dataset_name, output_dir=cfg.OUTPUT_DIR)

# Smoke-test the trainer: constructing it exercises the data loaders, which
# needs the datasets to be present, so a failure here is reported, not raised.
try:
    trainer = AugmentationTrainer(cfg)
except Exception as init_error:
    print(f"Error initializing trainer: {init_error}")
    print("This is expected if datasets are not available yet")
else:
    print("Custom trainer initialized successfully")

print("Custom trainer with augmentation integration implemented")


### Training Execution and Evaluation


In [None]:
def run_mask_rcnn_training():
    """
    Train Mask R-CNN on the melanoma dataset and export the result.

    Does nothing if the final model file already exists or if either the
    training or the validation split yields no samples. Otherwise trains from
    scratch (no resume) with ``AugmentationTrainer`` and hands the trainer to
    ``save_final_model``. Any exception raised while training or saving is
    printed together with its traceback instead of being propagated.
    """
    import traceback

    # An existing export means there is nothing to do
    if config.FINAL_MODEL_PATH.exists():
        print(
            f"Final model already exists at {config.FINAL_MODEL_PATH}",
            "Skipping training. To retrain, delete the existing file first.",
            sep="\n",
        )
        return

    # Load both splits up front so missing data is reported before training
    training_samples = get_melanoma_dicts(
        str(config.TRAIN_IMAGES_PATH), str(config.TRAIN_MASKS_PATH)
    )
    validation_samples = get_melanoma_dicts(
        str(config.VAL_IMAGES_PATH), str(config.VAL_MASKS_PATH)
    )

    if not training_samples:
        print(
            f"No training data found in {config.TRAIN_IMAGES_PATH}",
            "Please place labeled images and masks in the data/train/ directory",
            sep="\n",
        )
        return

    if not validation_samples:
        print(
            f"No validation data found in {config.VAL_IMAGES_PATH}",
            "Please place labeled images and masks in the data/val/ directory",
            sep="\n",
        )
        return

    print(
        "Starting Mask R-CNN training...",
        f"Training samples: {len(training_samples)}",
        f"Validation samples: {len(validation_samples)}",
        sep="\n",
    )

    try:
        # Fresh run: resume=False ignores any checkpoint left in the output dir
        mask_rcnn_trainer = AugmentationTrainer(cfg)
        mask_rcnn_trainer.resume_or_load(resume=False)
        mask_rcnn_trainer.train()

        print("Training completed successfully!")

        # Export weights and config
        save_final_model(mask_rcnn_trainer)

    except Exception as training_error:
        print(f"Error during training: {training_error}")
        traceback.print_exc()

def save_final_model(trainer):
    """
    Export the trainer's current model weights and the active config.

    Writes ``final_model.pth`` into ``cfg.OUTPUT_DIR`` through a
    ``DetectionCheckpointer``, copies that file to ``config.FINAL_MODEL_PATH``
    when it exists, and dumps ``cfg`` as text to ``config.CONFIG_PATH``.
    """
    import shutil

    print("Saving final model...")

    # Serialise the weights currently held by the trainer
    DetectionCheckpointer(trainer.model, save_dir=cfg.OUTPUT_DIR).save("final_model")

    # Mirror the checkpoint into the models directory
    exported_weights = Path(cfg.OUTPUT_DIR) / "final_model.pth"
    if exported_weights.exists():
        shutil.copy2(exported_weights, config.FINAL_MODEL_PATH)
        print(f"Final model saved to {config.FINAL_MODEL_PATH}")

    # Store the config next to the weights
    with open(config.CONFIG_PATH, 'w') as config_file:
        config_file.write(cfg.dump())
    print(f"Config saved to {config.CONFIG_PATH}")

def evaluate_model():
    """
    Score the exported model on the ``melanoma_val`` split.

    Returns early with a message when ``config.FINAL_MODEL_PATH`` is missing.
    Otherwise builds the model from ``cfg``, loads the exported weights, runs
    COCO-style evaluation, prints the metrics and visualises predictions for
    the first four validation samples. Exceptions are printed with their
    traceback rather than raised.
    """
    import traceback

    weights_path = config.FINAL_MODEL_PATH
    if not weights_path.exists():
        print("No trained model found. Please train the model first.")
        return

    print("Evaluating model on validation set...")

    try:
        # Rebuild the network and restore the exported weights
        model = build_model(cfg)
        DetectionCheckpointer(model).load(str(weights_path))
        model.eval()

        # Samples and metadata of the validation split
        validation_samples = DatasetCatalog.get("melanoma_val")
        validation_metadata = MetadataCatalog.get("melanoma_val")

        # Evaluation utilities are only needed here, hence the local import
        from detectron2.evaluation import COCOEvaluator, inference_on_dataset
        from detectron2.data import build_detection_test_loader

        coco_evaluator = COCOEvaluator("melanoma_val", output_dir=cfg.OUTPUT_DIR)
        validation_loader = build_detection_test_loader(cfg, "melanoma_val")

        metrics = inference_on_dataset(model, validation_loader, coco_evaluator)

        print("Evaluation results:")
        print(metrics)

        # Qualitative check on a handful of samples
        visualize_predictions(model, validation_samples[:4], validation_metadata)

    except Exception as evaluation_error:
        print(f"Error during evaluation: {evaluation_error}")
        traceback.print_exc()

def visualize_predictions(model, dataset_dicts, metadata, confidence_threshold=0.5):
    """
    Visualize model predictions on sample images.

    Args:
        model: Callable taking a list of ``{"image": tensor}`` dicts and
            returning one dict per input that holds an ``"instances"`` entry.
        dataset_dicts: Iterable of dataset items; only ``"file_name"`` is read.
        metadata: Dataset metadata handed to the ``Visualizer``.
        confidence_threshold: Predictions scoring at or below this value are
            not drawn.

    Samples whose image file cannot be read are skipped with a message.
    """

    print("Visualizing predictions...")

    for d in dataset_dicts:
        # Load image (cv2 returns None instead of raising on failure)
        img = cv2.imread(d["file_name"])
        if img is None:
            print(f"Could not read image: {d['file_name']}")
            continue

        # Run inference on the image as loaded, in (C, H, W) float32 layout
        with torch.no_grad():
            outputs = model([{"image": torch.as_tensor(img.transpose(2, 0, 1).astype("float32"))}])[0]

        # Filter predictions by confidence
        instances = outputs["instances"]
        instances = instances[instances.scores > confidence_threshold]

        # Visualize; the channel axis is reversed because the Visualizer
        # output is passed straight to matplotlib below
        v = Visualizer(img[:, :, ::-1], metadata=metadata, scale=1.2)
        out = v.draw_instance_predictions(instances.to("cpu"))

        # Display
        plt.figure(figsize=(12, 8))
        plt.imshow(out.get_image())
        plt.title(f"Predictions (confidence > {confidence_threshold})")
        plt.axis('off')
        plt.show()

# Remind the notebook user which entry points this cell defined.
print(
    "Training and evaluation functions defined",
    "Call run_mask_rcnn_training() to start training",
    "Call evaluate_model() to evaluate trained model",
    sep="\n",
)


## Usage Instructions

### Complete Training Pipeline

1. **Prepare Data**:
   - Place unlabeled images in `data/images/` for SSL training
   - Place labeled images in `data/train/images/` and masks in `data/train/masks/`
   - Place validation images in `data/val/images/` and masks in `data/val/masks/`

2. **Run SSL Pre-training** (Optional):
   ```python
   run_ssl_training()  # Train MAE backbone on unlabeled data
   ```

3. **Run Mask R-CNN Training**:
   ```python
   run_mask_rcnn_training()  # Train Mask R-CNN with ViT backbone
   ```

4. **Evaluate Model**:
   ```python
   evaluate_model()  # Evaluate on validation set
   ```

### Expected Results

- **SSL Backbone**: `models/ssl_vit_backbone.pth`
- **Final Model**: `models/final_lesion_segmenter.pth`
- **Config**: `models/config.yaml`

The trained model can be used with the standalone `predict.py` script for inference.
