In [None]:
import os
import random
import numpy as np
import time
import logging
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import GradScaler, autocast

import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models.feature_extraction import create_feature_extractor

from timm.models.vision_transformer import VisionTransformer
from timm.data.mixup import Mixup
from timm.data.auto_augment import rand_augment_transform, auto_augment_transform
from timm.loss import LabelSmoothingCrossEntropy

import sklearn.model_selection as skms
from PIL import Image

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed handed to every random number generator.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Deterministic cuDNN kernels and no autotuning: slower, but repeatable
    # from run to run on CUDA devices.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
# Seed the Python, NumPy and torch RNGs once, at import time, before any of
# the training code below is executed.
set_seed()

# Setup logging: INFO and above are written both to "training.log" (relative
# to the current working directory) and to the console.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers (e.g. when this notebook cell is executed a second time).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("training.log"),
        logging.StreamHandler()
    ]
)
# getLogger() without a name returns the root logger configured above; the
# rest of this file logs through this module-level name.
logger = logging.getLogger()

# Custom Dataset with caching capability
class CustomImageDataset(Dataset):
    """Image-folder dataset (one sub-directory per class) with optional caching.

    Expected layout::

        img_dir/
            <class_a>/img0.jpg ...
            <class_b>/img1.png ...

    Classes are the sorted names of the non-hidden sub-directories of
    ``img_dir``; a label is the index of its class in that sorted list.
    Samples are enumerated in sorted order, so two instances built from the
    same directory list identical images at identical indices — indices
    computed on one instance (e.g. k-fold splits) are valid on another.

    Args:
        img_dir: root directory containing one sub-directory per class.
        transform: optional callable applied to each RGB PIL image in
            ``__getitem__``.
        cache_images: if True, decode every image once in ``__init__`` and
            keep the RGB PIL images in ``image_cache`` (keyed by path string).
    """

    # Supported image extensions (compared case-insensitively).
    VALID_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'})

    def __init__(self, img_dir, transform=None, cache_images=False):
        from collections import Counter

        self.img_dir = Path(img_dir)
        self.transform = transform
        self.cache_images = cache_images

        # Class directories, sorted for a deterministic order. Hidden
        # directories (e.g. ".ipynb_checkpoints") are not classes.
        self.class_dirs = sorted(
            (d for d in self.img_dir.iterdir()
             if d.is_dir() and not d.name.startswith('.')),
            key=lambda d: d.name,
        )
        self.classes = [d.name for d in self.class_dirs]
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        self.img_paths = []
        self.targets = []
        for class_dir in self.class_dirs:
            class_idx = self.class_to_idx[class_dir.name]
            # iterdir() order is filesystem-dependent, hence sorted(). Hidden
            # files (e.g. macOS "._name.jpg" metadata) are not images even
            # when their suffix says so.
            for img_path in sorted(class_dir.iterdir()):
                if (img_path.suffix.lower() in self.VALID_EXTENSIONS
                        and not img_path.name.startswith('.')
                        and img_path.is_file()):
                    self.img_paths.append(img_path)
                    self.targets.append(class_idx)

        # (path, label) pairs, index-aligned with img_paths / targets.
        self.samples = list(zip(self.img_paths, self.targets))

        logger.info(f"Found {len(self.samples)} images in {len(self.classes)} classes")
        counts = Counter(self.targets)
        for cls, idx in self.class_to_idx.items():
            logger.info(f"Class {cls}: {counts[idx]} images")

        # Decoded RGB images keyed by str(path); only written when cache_images.
        self.image_cache = {}
        if self.cache_images:
            logger.info("Caching images for faster training...")
            for idx in range(len(self.samples)):
                self._load_image(idx)
            logger.info(f"Cached {len(self.image_cache)} images")

    def _load_image(self, idx):
        """Return sample ``idx`` as an RGB PIL image, via the cache when enabled.

        An unreadable file is logged and replaced by a gray 224x224
        placeholder (not cached), so a single bad file does not abort training.
        """
        img_path, _ = self.samples[idx]
        key = str(img_path)
        if self.cache_images and key in self.image_cache:
            return self.image_cache[key]

        try:
            # The context manager closes the file; convert() returns a fully
            # loaded image that stays valid after the file is closed.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            logger.error(f"Error loading image {img_path}: {e}")
            return Image.new('RGB', (224, 224), color='gray')

        if self.cache_images:
            self.image_cache[key] = image
        return image

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is transformed if a transform is set."""
        image = self._load_image(idx)
        label = self.samples[idx][1]

        if self.transform:
            image = self.transform(image)

        return image, label

# Advanced data augmentation pipeline
def get_transforms(img_size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the training and evaluation transform pipelines.

    Args:
        img_size: output side length; both pipelines yield square
            ``img_size x img_size`` tensors.
        mean: per-channel normalization mean (ImageNet by default, matching
            the pretrained ResNet backbone).
        std: per-channel normalization std (ImageNet by default).

    Returns:
        ``(train_transform, basic_transform)``: the augmented training
        pipeline and the deterministic resize + normalize pipeline used for
        validation / inference.
    """
    mean, std = list(mean), list(std)
    # Normalize is stateless, so one instance can be shared by both pipelines.
    normalize = transforms.Normalize(mean=mean, std=std)

    # Basic transformations (validation / inference).
    basic_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])

    # timm's augmentation ops read ``translate_const`` (absolute translation
    # in pixels) and ``img_mean`` (0-255 fill colour for exposed borders) from
    # hparams. Supply them explicitly, derived the same way timm's own
    # ``create_transform`` does, instead of falling back to timm's defaults.
    ra_hparams = {
        'img_size': img_size,
        'translate_const': int(img_size * 0.45),
        'img_mean': tuple(min(255, round(255 * m)) for m in mean),
    }

    # Advanced training augmentations.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        # RandAugment from timm: operates on PIL images, so it precedes ToTensor.
        rand_augment_transform('rand-m9-n3-mstd0.5', ra_hparams),
        transforms.ToTensor(),
        normalize,
        # RandomErasing works on tensors, so it comes after ToTensor/Normalize.
        transforms.RandomErasing(p=0.25),
    ])

    return train_transform, basic_transform

# ResNet-DeiT Hybrid Model
class ResNetDeiTHybrid(nn.Module):
    """ResNet-18 convolutional backbone feeding a transformer encoder (DeiT-style).

    Pipeline: ResNet-18 ``layer4`` feature map -> 1x1 conv + BatchNorm + GELU
    projection to ``embed_dim`` -> flatten to tokens -> prepend a learnable
    [CLS] token and add learnable positional embeddings ->
    ``nn.TransformerEncoder`` -> LayerNorm on the [CLS] token -> MLP head.

    Args:
        num_classes: number of output logits.
        resnet_pretrained: load torchvision's default weights for ResNet-18.
        embed_dim: token width; must be divisible by ``num_heads``.
        depth: number of transformer encoder layers.
        num_heads: attention heads per encoder layer.
        dropout: dropout for the encoder layers (including attention) and
            the MLP head.
        attn_dropout: unused — ``nn.TransformerEncoderLayer`` applies
            ``dropout`` to attention; kept for interface compatibility.
        img_size: side of the square input the positional embedding is sized
            for. The default 224 gives a 7x7 feature map (49 tokens).
    """

    def __init__(self, num_classes, resnet_pretrained=True, embed_dim=384, depth=8,
                 num_heads=6, dropout=0.2, attn_dropout=0.1, img_size=224):
        super().__init__()

        # 1. ResNet feature extractor (everything up to and including layer4;
        #    the global pooling and fc layer are not part of the graph).
        weights = ResNet18_Weights.DEFAULT if resnet_pretrained else None
        resnet = resnet18(weights=weights)
        self.resnet_features = create_feature_extractor(
            resnet,
            return_nodes={'layer4': 'features'}
        )

        # 2. Project ResNet-18's 512 layer4 channels to the transformer width.
        feature_dim = 512
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(feature_dim, embed_dim, kernel_size=1),
            nn.BatchNorm2d(embed_dim),
            nn.GELU()
        )

        # 3. Transformer part. One token per layer4 spatial position.
        self.img_size = img_size
        side = self._feature_map_size(img_size)
        self.num_patches = side * side  # 49 for 224x224 inputs

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ),
            num_layers=depth
        )

        # 4. Classification head with dropout.
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, num_classes)
        )

        self._init_weights()

    @staticmethod
    def _feature_map_size(img_size):
        """Spatial side of ResNet-18's ``layer4`` output for a square input."""
        # The stem conv, the max-pool and the first block of layer2/3/4 each
        # downsample with stride 2 and padding, mapping side -> ceil(side / 2).
        side = img_size
        for _ in range(5):
            side = (side + 1) // 2
        return side

    def _init_weights(self):
        """Initialize the [CLS]/positional embeddings and the MLP head."""
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embedding, std=0.02)

        for m in self.fc.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map images ``[B, 3, img_size, img_size]`` to logits ``[B, num_classes]``."""
        batch_size = x.shape[0]

        # 1. ResNet features: [B, 512, side, side]
        features = self.resnet_features(x)['features']

        # 2. Tokens: [B, embed_dim, side, side] -> [B, side*side, embed_dim]
        patch_embeddings = self.to_patch_embedding(features)
        patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
        if patch_embeddings.shape[1] != self.num_patches:
            raise ValueError(
                f"Expected {self.num_patches} tokens (model built for "
                f"{self.img_size}x{self.img_size} inputs) but got "
                f"{patch_embeddings.shape[1]} for input of shape {tuple(x.shape)}; "
                f"construct the model with a matching img_size."
            )

        # 3. Prepend [CLS] and add positional embeddings: [B, num_patches + 1, embed_dim]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, patch_embeddings), dim=1)
        x = x + self.pos_embedding

        # 4. Transformer encoder.
        x = self.transformer(x)

        # 5. Classify from the [CLS] token.
        x = self.norm(x[:, 0])
        return self.fc(x)

# Knowledge Distillation Loss
class DistillationLoss(nn.Module):
    """Blend hard-label cross-entropy with temperature-softened KL distillation.

    ``loss = (1 - alpha) * CE(student, labels)
             + alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))``

    Args:
        alpha: weight of the distillation (soft) term; the hard-label term
            receives ``1 - alpha``.
        temperature: softening temperature ``T`` applied to both logit sets.
    """

    def __init__(self, alpha=0.5, temperature=2.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, outputs, labels, teacher_outputs):
        """Return the blended loss for student ``outputs`` against ``labels`` and the teacher."""
        temp = self.temperature
        student_log_probs = F.log_softmax(outputs / temp, dim=1)
        teacher_probs = F.softmax(teacher_outputs / temp, dim=1)
        # The T**2 factor keeps the soft term's gradient scale comparable to
        # the hard term as the temperature changes.
        distill_term = F.kl_div(
            student_log_probs, teacher_probs, reduction='batchmean'
        ) * (temp ** 2)
        label_term = self.ce_loss(outputs, labels)
        return self.alpha * distill_term + (1 - self.alpha) * label_term

# Training function
def train_one_epoch(model, train_loader, optimizer, criterion, device, epoch,
                    teacher_model=None, mixup_fn=None, scaler=None, distill_criterion=None):
    """Run one training epoch and return the mean per-batch training loss.

    Args:
        model: network being trained (switched to train mode here).
        train_loader: iterable of ``(inputs, integer_targets)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss for hard integer targets (e.g. LabelSmoothingCrossEntropy).
        device: device every batch is moved to.
        epoch: epoch number, used only in progress logs.
        teacher_model: optional teacher; when given, its outputs (computed
            without gradients) feed ``distill_criterion`` instead of ``criterion``.
        mixup_fn: optional batch mixer. timm's ``Mixup`` returns
            ``(inputs, soft_targets)``; a callable returning
            ``(inputs, targets_a, targets_b, lam)`` is also accepted.
        scaler: optional ``GradScaler``; when given, the forward pass runs
            under ``autocast()`` and backward uses gradient scaling.
        distill_criterion: ``loss(outputs, targets, teacher_outputs)``, used
            together with ``teacher_model``.

    Returns:
        Mean of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    import contextlib

    model.train()
    if teacher_model is not None:
        teacher_model.eval()

    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    batch_time = AverageMeter()
    end = time.time()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Mixup / CutMix: either two hard-label sets + mixing weight (4-tuple
        # convention) or, with timm's Mixup, per-sample soft label distributions.
        lerp_targets = None
        soft_targets = False
        if mixup_fn is not None:
            mixed = mixup_fn(inputs, targets)
            if len(mixed) == 4:
                inputs, targets_a, targets_b, lam = mixed
                lerp_targets = (targets_a, targets_b, lam)
            else:
                inputs, targets = mixed
                soft_targets = True

        # torch.cuda.amp.autocast only changes CUDA ops; on other devices it is inert.
        amp_ctx = autocast() if scaler is not None else contextlib.nullcontext()
        with amp_ctx:
            outputs = model(inputs)

            if teacher_model is not None:
                with torch.no_grad():
                    teacher_outputs = teacher_model(inputs)
                loss = distill_criterion(outputs, targets, teacher_outputs)
            elif lerp_targets is not None:
                targets_a, targets_b, lam = lerp_targets
                loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
            elif soft_targets:
                # Hard-label criteria cannot consume probability targets; use
                # cross-entropy against the full distribution. timm's Mixup
                # applies its own label smoothing to these targets.
                loss = torch.sum(-targets * F.log_softmax(outputs, dim=-1), dim=-1).mean()
            else:
                loss = criterion(outputs, targets)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # Accuracy is only meaningful against unmixed integer labels.
        if mixup_fn is None:
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        batch_time.update(time.time() - end)
        end = time.time()

        if (batch_idx + 1) % 10 == 0:
            acc = 100. * correct / total if mixup_fn is None else -1
            logger.info(f'Epoch: {epoch} | Batch: {batch_idx+1}/{len(train_loader)} | '
                        f'Loss: {total_loss/(batch_idx+1):.4f} | Acc: {acc:.2f}% | '
                        f'Time: {batch_time.avg:.3f}s/batch')

    if num_batches == 0:
        raise ValueError("train_one_epoch() received an empty train_loader")
    return total_loss / num_batches

# Evaluation function
def evaluate(model, val_loader, criterion, device):
    """Compute the sample-weighted mean loss and top-1 accuracy (%) of ``model``.

    Runs in eval mode with gradient tracking disabled.

    Args:
        model: classifier returning logits of shape ``[B, num_classes]``.
        val_loader: iterable of ``(inputs, integer_targets)`` batches.
        criterion: loss returning the mean loss over a batch.
        device: device every batch is moved to.

    Returns:
        ``(avg_loss, acc)``. ``avg_loss`` is averaged over samples, so a short
        final batch is not over-weighted; ``acc`` is a percentage in [0, 100].

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = targets.size(0)

            # criterion returns a per-batch mean; re-weight it by batch size.
            total_loss += criterion(outputs, targets).item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate() received an empty val_loader")

    avg_loss = total_loss / total
    acc = 100. * correct / total
    return avg_loss, acc

# Utility class for measurement
class AverageMeter(object):
    """Keep the latest value and the running weighted mean of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count

# Progressive unfreezing - helper function
def unfreeze_model_parts(model, epoch, total_epochs, layers_to_unfreeze):
    """Set ``requires_grad=True`` on the part of ``model`` scheduled for ``epoch``.

    Args:
        model: a ``ResNetDeiTHybrid``-like module exposing ``resnet_features``
            and ``transformer`` (an ``nn.TransformerEncoder``).
        epoch: current 0-based epoch.
        total_epochs: unused; kept for interface compatibility.
        layers_to_unfreeze: mapping ``epoch -> part``. Supported parts:
            ``'layer1'``..``'layer4'`` (ResNet stages), ``'transformer_last'``
            (last two encoder layers), ``'transformer_all'`` and ``'all'``.
            Unknown names are logged and ignored.

    Note:
        Unfrozen parameters only train if the optimizer also holds them. An
        optimizer built from just the initially trainable parameters will
        never update parts unfrozen here.
    """
    if epoch not in layers_to_unfreeze:
        return
    part = layers_to_unfreeze[epoch]
    logger.info(f"Unfreezing {part}")

    if part in ('layer1', 'layer2', 'layer3', 'layer4'):
        modules = [getattr(model.resnet_features, part)]
    elif part == 'transformer_last':
        # nn.TransformerEncoder keeps its blocks in the ``layers`` ModuleList.
        # Its children() yield that list as ONE item, so slicing children()
        # would unfreeze every layer rather than the last two.
        modules = list(model.transformer.layers[-2:])
    elif part == 'transformer_all':
        modules = [model.transformer]
    elif part == 'all':
        modules = [model]
    else:
        logger.warning(f"Unknown part to unfreeze: {part!r}; ignoring")
        return

    for module in modules:
        for param in module.parameters():
            param.requires_grad = True

# Main training function with k-fold cross-validation
def train_model_with_kfold(
    data_dir,
    num_classes,
    img_size=224,
    batch_size=32,
    num_epochs=30,
    learning_rate=1e-4,
    weight_decay=1e-4,
    k_folds=5,
    embed_dim=384,
    transformer_depth=6,
    num_heads=6,
    dropout=0.2,
    use_mixup=True,
    mixup_alpha=0.2,
    cutmix_alpha=0.2,
    label_smoothing=0.1,
    use_amp=True,
    use_distillation=False,
    distillation_alpha=0.5,
    cache_images=True,
    teacher_weights_path=None,
    num_workers=4,
):
    """Train one ``ResNetDeiTHybrid`` per stratified fold and return the best model.

    For each of ``k_folds`` stratified splits of the image folder, a fresh
    model is trained with progressive unfreezing, AdamW, cosine warm restarts
    and early stopping on validation accuracy. Its best weights are saved to
    ``model_fold_<k>_acc_<acc>.pth``.

    Args:
        data_dir: image-folder root (one sub-directory per class).
        num_classes: expected number of classes; must match the folder.
        img_size: side length produced by the transforms. Note that the model
            is constructed with its default positional-embedding size (224x224).
        batch_size, num_epochs, learning_rate, weight_decay: usual training
            hyper-parameters (AdamW).
        k_folds: number of stratified folds (>= 2).
        embed_dim, transformer_depth, num_heads, dropout: model hyper-parameters.
        use_mixup, mixup_alpha, cutmix_alpha: timm Mixup/CutMix settings.
        label_smoothing: smoothing for the hard-label loss and for Mixup targets.
        use_amp: CUDA mixed precision; only takes effect on a CUDA device.
        use_distillation: distil from a ViT-B/16 teacher. This requires
            ``teacher_weights_path``; otherwise a warning is logged and
            distillation is skipped.
        distillation_alpha: weight of the distillation term.
        cache_images: keep decoded images in memory.
        teacher_weights_path: state_dict of a ViT-B/16 whose ``heads[0]``
            already outputs ``num_classes`` logits (i.e. a fine-tuned teacher).
        num_workers: DataLoader worker processes.

    Returns:
        The fold model with the highest final validation accuracy.

    Raises:
        ValueError: if ``num_epochs < 1``, the folder holds no images, or its
            class count differs from ``num_classes``.
    """
    import copy

    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")
    pin_memory = device.type == "cuda"  # pinned host memory only helps CUDA copies

    train_transform, val_transform = get_transforms(img_size)

    # Scan (and optionally cache) the folder ONCE. The train/val views are
    # shallow copies sharing the sample list and image cache and differing
    # only in transform, so fold indices always address the same images.
    full_dataset = CustomImageDataset(
        img_dir=data_dir,
        transform=None,
        cache_images=cache_images
    )
    if len(full_dataset.classes) != num_classes:
        raise ValueError(
            f"num_classes={num_classes} but found {len(full_dataset.classes)} class "
            f"directories in {data_dir}: {full_dataset.classes}"
        )
    if len(full_dataset) == 0:
        raise ValueError(f"No images found in {data_dir}")
    train_dataset = copy.copy(full_dataset)
    train_dataset.transform = train_transform
    val_dataset = copy.copy(full_dataset)
    val_dataset.transform = val_transform

    # Inverse-frequency class weights, normalised to sum to num_classes.
    # NOTE(review): only logged; LabelSmoothingCrossEntropy takes no class weights.
    class_counts = np.zeros(num_classes)
    for _, label in full_dataset.samples:
        class_counts[label] += 1
    class_weights = torch.FloatTensor(1.0 / np.maximum(class_counts, 1))
    class_weights = class_weights / class_weights.sum() * num_classes
    logger.info(f"Class weights: {class_weights}")

    # Stratified k-fold keeps the class distribution in every split.
    indices = list(range(len(full_dataset)))
    labels = [label for _, label in full_dataset.samples]
    skf = skms.StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)

    # timm's Mixup returns (mixed_inputs, soft_targets); its own
    # label_smoothing is applied to those soft targets.
    mixup_fn = None
    if use_mixup:
        mixup_fn = Mixup(
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            label_smoothing=label_smoothing,
            num_classes=num_classes
        )

    criterion = LabelSmoothingCrossEntropy(smoothing=label_smoothing)

    # Teacher for knowledge distillation. It must already be trained on these
    # classes: a freshly replaced head gives uninformative soft targets.
    # NOTE(review): to avoid leakage the teacher should not have seen each
    # fold's validation images; vit_b_16 also expects 224x224 inputs.
    teacher_model = None
    distill_criterion = None
    if use_distillation:
        if teacher_weights_path is None:
            logger.warning(
                "use_distillation=True but no teacher_weights_path given; the "
                "teacher head would be untrained, so distillation is disabled."
            )
        else:
            teacher_model = torchvision.models.vit_b_16(weights=None)
            teacher_model.heads[0] = nn.Linear(teacher_model.heads[0].in_features, num_classes)
            teacher_model.load_state_dict(torch.load(teacher_weights_path, map_location="cpu"))
            teacher_model = teacher_model.to(device)
            teacher_model.eval()
            for param in teacher_model.parameters():
                param.requires_grad = False
            distill_criterion = DistillationLoss(alpha=distillation_alpha)

    best_models = []
    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(indices, labels)):
        logger.info(f"Starting fold {fold+1}/{k_folds}")

        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(train_idx),
            num_workers=num_workers,
            pin_memory=pin_memory,
            # timm's Mixup requires an even batch size; drop the ragged last batch.
            drop_last=mixup_fn is not None,
        )
        val_loader = DataLoader(
            dataset=val_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(val_idx),
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

        model = ResNetDeiTHybrid(
            num_classes=num_classes,
            resnet_pretrained=True,
            embed_dim=embed_dim,
            depth=transformer_depth,
            num_heads=num_heads,
            dropout=dropout
        )

        # Start by training only the head, [CLS] token and positional embeddings.
        for param in model.parameters():
            param.requires_grad = False
        for param in model.fc.parameters():
            param.requires_grad = True
        model.cls_token.requires_grad = True
        model.pos_embedding.requires_grad = True
        # NOTE(review): to_patch_embedding, transformer and norm are randomly
        # initialised but start frozen; to_patch_embedding and norm are only
        # unfrozen by the final 'all' stage. Early stopping (patience 5) may
        # also end training before later stages run. Confirm this is intended.

        layers_to_unfreeze = {
            5: 'transformer_last',  # last two transformer layers at epoch 5
            10: 'transformer_all',  # whole transformer at epoch 10
            15: 'layer3',           # ResNet layer3 at epoch 15
            20: 'layer2',           # ResNet layer2 at epoch 20
            25: 'all'               # everything at epoch 25
        }

        model = model.to(device)

        # Register ALL parameters. AdamW skips parameters whose .grad is None
        # (still frozen), and they start training once unfrozen. Filtering on
        # requires_grad here would exclude later-unfrozen parts permanently.
        optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=5,                        # first restart after 5 epochs
            T_mult=2,                     # double the period after each restart
            eta_min=learning_rate / 100   # floor learning rate
        )

        # torch.cuda.amp only accelerates CUDA; elsewhere it would be inert.
        scaler = GradScaler() if use_amp and device.type == "cuda" else None

        # Early stopping. best_acc starts below any possible accuracy so the
        # first epoch always produces a snapshot.
        best_acc = -1.0
        patience = 5
        patience_counter = 0
        best_model_state = None

        for epoch in range(num_epochs):
            unfreeze_model_parts(model, epoch, num_epochs, layers_to_unfreeze)

            train_loss = train_one_epoch(
                model, train_loader, optimizer, criterion, device, epoch,
                teacher_model, mixup_fn, scaler, distill_criterion
            )

            val_loss, val_acc = evaluate(model, val_loader, criterion, device)

            scheduler.step()

            logger.info(f"Fold {fold+1} | Epoch {epoch+1}/{num_epochs} | "
                        f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                        f"Val Acc: {val_acc:.2f}%")

            if val_acc > best_acc:
                best_acc = val_acc
                patience_counter = 0
                # state_dict() returns references to the live tensors; without
                # a deep copy the "best" snapshot would follow later updates.
                best_model_state = copy.deepcopy(model.state_dict())
                logger.info(f"New best model with accuracy: {best_acc:.2f}%")
            else:
                patience_counter += 1

            if patience_counter >= patience:
                logger.info(f"Early stopping triggered after epoch {epoch+1}")
                break

        model_save_path = f"model_fold_{fold+1}_acc_{best_acc:.2f}.pth"
        torch.save(best_model_state, model_save_path)

        # Restore the best snapshot and re-evaluate it.
        model.load_state_dict(best_model_state)
        final_val_loss, final_val_acc = evaluate(model, val_loader, criterion, device)
        logger.info(f"Fold {fold+1} final results - Val Loss: {final_val_loss:.4f} | Val Acc: {final_val_acc:.2f}%")

        best_models.append((model, final_val_acc))
        fold_accuracies.append(final_val_acc)

    mean_acc = sum(fold_accuracies) / len(fold_accuracies)
    std_acc = np.std(fold_accuracies)
    logger.info("K-fold cross-validation results:")
    logger.info(f"Mean accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")

    best_fold = np.argmax(fold_accuracies)
    best_model, best_acc = best_models[best_fold]
    logger.info(f"Best model from fold {best_fold+1} with accuracy: {best_acc:.2f}%")

    return best_model

# Inference function for the trained model
def predict_with_model(model, image_path, transform, device):
    """Predict the class index of a single image file.

    Args:
        model: classifier returning logits of shape ``[1, num_classes]``.
        image_path: path to the image file.
        transform: callable mapping an RGB PIL image to a ``[C, H, W]``
            tensor, e.g. the evaluation pipeline returned by ``get_transforms``.
        device: device to run inference on.

    Returns:
        The predicted class index as a Python ``int``.
    """
    model.eval()
    # The context manager closes the file (Pillow may otherwise keep it open,
    # e.g. for multi-frame TIFFs); convert() returns a fully loaded copy.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)
        _, predicted = outputs.max(1)

    return predicted.item()

# Example usage
if __name__ == "__main__":
    # Training configuration; keys map 1:1 onto train_model_with_kfold arguments.
    config = {
        "data_dir": "data",  # Update with your dataset path
        "num_classes": 3,  # Update with your number of classes
        "img_size": 224,
        "batch_size": 16,
        "num_epochs": 30,
        "learning_rate": 5e-5,
        "weight_decay": 1e-4,
        "k_folds": 5,
        "embed_dim": 384,
        "transformer_depth": 6,
        "num_heads": 6,
        "dropout": 0.2,
        "use_mixup": True,
        "mixup_alpha": 0.2,
        "cutmix_alpha": 0.2,
        "label_smoothing": 0.1,
        "use_amp": True,
        # Distillation is off: it needs a teacher already fine-tuned on these
        # classes. A ViT whose classification head was just replaced produces
        # uninformative soft targets and would only add noise to training.
        "use_distillation": False,
        "distillation_alpha": 0.5,
        "cache_images": True
    }

    # Train with k-fold cross-validation; returns the best fold's model.
    best_model = train_model_with_kfold(**config)

    # Save final model weights.
    torch.save(best_model.state_dict(), "best_resnet_deit_hybrid_model.pth")

    logger.info("Training completed and best model saved!")

2025-03-26 10:25:32,884 - INFO - Using device: mps
2025-03-26 10:25:32,888 - INFO - Found 1525 images in 3 classes
2025-03-26 10:25:32,889 - INFO - Class Healthy: 407 images
2025-03-26 10:25:32,889 - INFO - Class Monkeypox: 563 images
2025-03-26 10:25:32,889 - INFO - Class Other: 555 images
2025-03-26 10:25:32,889 - INFO - Caching images for faster training...
2025-03-26 10:25:35,116 - INFO - Cached 1525 images
2025-03-26 10:25:35,152 - INFO - Class weights: tensor([1.2214, 0.8829, 0.8957])
2025-03-26 10:25:35,162 - INFO - Starting fold 1/5
2025-03-26 10:25:35,165 - INFO - Found 1525 images in 3 classes
2025-03-26 10:25:35,166 - INFO - Class Healthy: 407 images
2025-03-26 10:25:35,166 - INFO - Class Monkeypox: 563 images
2025-03-26 10:25:35,166 - INFO - Class Other: 555 images
2025-03-26 10:25:35,166 - INFO - Caching images for faster training...
2025-03-26 10:25:36,131 - INFO - Cached 1525 images
2025-03-26 10:25:36,246 - INFO - Found 1525 images in 3 classes
2025-03-26 10:25:36,246 -