In [1]:
# Add these at the top of your notebook cell
import matplotlib
matplotlib.use('inline')  # or 'Agg'
import matplotlib.pyplot as plt
plt.ion()  # Turn on interactive mode

# Also add this to force display
%matplotlib inline
import torchvision.transforms.v2 as T

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
from pathlib import Path
import tifffile
from typing import Tuple, Optional, List, Dict, Iterator
import logging
import re
from collections import defaultdict

class IMCDataset(Dataset):
    """Dataset for IMC (Imaging Mass Cytometry) data for tissue type and condition prediction.

    Every sub-directory of ``data_dir`` matching ``ROI*`` whose name looks like
    ``ROI###_TISSUE_[Benign_]TMA###[N]`` and that contains ``input/imc/*.ome.tiff``
    becomes one sample:

    * ``tissue`` is the second ``_``-separated field of the folder name;
    * ``condition`` is ``"Benign"`` if that word occurs in the folder name,
      otherwise ``"Malignant"``.

    ROI folders and the files inside them are visited in sorted order, so the
    sample list (and therefore every sample index) is identical each time the
    same directory is scanned. ``create_kfold_dataloaders`` depends on this: it
    computes fold indices on one instance and applies them to other instances
    built over the same directory.

    ``__getitem__`` returns ``(image, label_idx)``, or ``(image, label_idx, mask)``
    when ``use_mask`` is set and the sample has a mask file.
    """

    def __init__(
        self,
        data_dir: str,
        transform=None,
        target_transform=None,
        use_mask: bool = False,
        channels: Optional[List[int]] = None,
        image_size: Optional[Tuple[int, int]] = None,
        normalize: bool = True,
        arcsinh_transform: bool = True,
        cofactor: float = 5.0,
        classification_task: str = "condition",
        split: str = "train"
    ):
        """
        Args:
            data_dir: Directory containing ROI folders
            transform: Optional transform applied to the image tensor. If None and
                ``image_size`` is given, a default pipeline from
                ``get_imc_transforms`` is used (augmented when ``split == "train"``,
                resize-only otherwise).
            target_transform: Optional transform applied to the integer label
            use_mask: Whether to load the ``*_mask.tiff`` file of each ROI
            channels: List of channel indices to use (if None, uses all)
            image_size: Size to resize images to (height, width)
            normalize: Whether to z-score each channel independently
            arcsinh_transform: Whether to apply ``arcsinh(x / cofactor)``
            cofactor: Cofactor for arcsinh transformation
            classification_task: What to predict - "condition" (Benign vs
                Malignant), "tissue", or "both" (tissue and condition combined).
                With prostate-only data, "tissue" yields a single class.
            split: "train" selects the augmented default transform; any other
                value selects the non-augmented one. Stored as ``self.split``.

        Raises:
            ValueError: if ``classification_task`` is not one of the values above
        """
        self.data_dir = Path(data_dir)
        self.target_transform = target_transform
        self.use_mask = use_mask
        self.channels = channels
        self.image_size = image_size
        self.normalize = normalize
        self.arcsinh_transform = arcsinh_transform
        self.cofactor = cofactor
        self.classification_task = classification_task
        self.split = split

        # Scan directory and parse folder names
        self.samples = self._scan_directory()

        # Create label mappings based on classification task
        self._create_label_mappings()

        logging.info(f"Loaded {len(self.samples)} samples")
        logging.info(f"Classification task: {self.classification_task}")
        logging.info(f"Labels: {self.unique_labels}")

        # Use the default transform pipeline only when none was given and a
        # target size is known.
        if transform is None and image_size is not None:
            train_transform, val_transform = self.get_imc_transforms(image_size)
            self.transform = train_transform if split == "train" else val_transform
        else:
            self.transform = transform

    def _scan_directory(self):
        """Scan ``data_dir`` for ROI folders and return one info dict per valid ROI.

        Folders are skipped when the name does not start with ``ROI<digits>``,
        has fewer than three ``_``-separated fields, lacks ``input/imc``, or has
        no ``*.ome.tiff`` file. Every listing is sorted so the result does not
        depend on filesystem iteration order.
        """
        samples = []

        for roi_folder in sorted(self.data_dir.glob("ROI*")):
            if not roi_folder.is_dir():
                continue

            # Parse folder name: ROI###_TISSUE_[Benign_]TMA###[N]
            folder_name = roi_folder.name

            roi_match = re.match(r"ROI(\d+)", folder_name)
            if not roi_match:
                continue
            roi_num = int(roi_match.group(1))

            parts = folder_name.split('_')
            if len(parts) < 3:
                continue
            tissue = parts[1]  # e.g. PROSTATE

            condition = "Benign" if "Benign" in folder_name else "Malignant"

            imc_path = roi_folder / "input" / "imc"
            if not imc_path.exists():
                continue

            # First .ome.tiff in sorted order (deterministic if several exist)
            ome_files = sorted(imc_path.glob("*.ome.tiff"))
            if not ome_files:
                continue
            image_file = ome_files[0]

            mask_file = None
            if self.use_mask:
                mask_files = sorted(imc_path.glob("*_mask.tiff"))
                if mask_files:
                    mask_file = mask_files[0]

            samples.append({
                'roi_num': roi_num,
                'tissue': tissue,
                'condition': condition,
                'folder_name': folder_name,
                'image_path': image_file,
                'mask_path': mask_file,
                'roi_folder': roi_folder
            })

        return samples

    def _get_label(self, sample: Dict) -> str:
        """Return the label string of ``sample`` for the configured classification task."""
        if self.classification_task == "tissue":
            return sample['tissue']
        if self.classification_task == "condition":
            return sample['condition']
        if self.classification_task == "both":
            return f"{sample['tissue']}_{sample['condition']}"
        raise ValueError("classification_task must be 'tissue', 'condition', or 'both'")

    def _create_label_mappings(self):
        """Build sorted ``unique_labels`` and the label <-> index dictionaries."""
        # Validate even when there are no samples, so a bad task fails early.
        if self.classification_task not in ("tissue", "condition", "both"):
            raise ValueError("classification_task must be 'tissue', 'condition', or 'both'")

        self.unique_labels = sorted({self._get_label(sample) for sample in self.samples})
        self.label_to_idx = {label: idx for idx, label in enumerate(self.unique_labels)}
        self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label_idx)`` or ``(image, label_idx, mask)``.

        NOTE(review): when ``use_mask`` is set but some ROIs have no mask file,
        samples return tuples of different lengths, which the default DataLoader
        collate cannot batch together.
        """
        sample = self.samples[idx]

        image = self._load_image(sample['image_path'])

        mask = None
        if self.use_mask and sample['mask_path']:
            mask = self._load_mask(sample['mask_path'])

        label_idx = self.label_to_idx[self._get_label(sample)]

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label_idx = self.target_transform(label_idx)

        if mask is not None:
            return image, label_idx, mask
        return image, label_idx

    def _load_image(self, img_path: Path) -> torch.Tensor:
        """Load an IMC image as a float32 ``[C, H, W]`` tensor.

        Steps: read the TIFF, add a channel axis to 2-D images, move a trailing
        channel axis to the front, select ``self.channels``, optionally apply
        ``arcsinh(x / cofactor)``, optionally z-score each channel with non-zero
        std, and bilinearly resize to ``image_size`` when it differs.
        """
        img = tifffile.imread(str(img_path)).astype(np.float32)

        if img.ndim == 2:
            # Single channel image: add channel dimension
            img = img[np.newaxis, ...]
        elif img.ndim == 3:
            # OME-TIFF usually has channels first; treat a last axis smaller
            # than the first as a channels-last layout and transpose it.
            if img.shape[2] < img.shape[0]:
                img = np.transpose(img, (2, 0, 1))

        if self.channels is not None:
            img = img[self.channels]

        # arcsinh compresses the heavy-tailed ion-count intensities
        if self.arcsinh_transform:
            img = np.arcsinh(img / self.cofactor)

        if self.normalize:
            for c in range(img.shape[0]):
                channel = img[c]
                std = channel.std()
                if std > 0:  # constant channels are left as-is
                    img[c] = (channel - channel.mean()) / std

        if self.image_size is not None and self.image_size != img.shape[1:]:
            img = torch.nn.functional.interpolate(
                torch.from_numpy(img).unsqueeze(0),
                size=self.image_size,
                mode='bilinear',
                align_corners=False
            ).squeeze(0).numpy()

        return torch.from_numpy(img)

    def get_imc_transforms(self, image_size: Tuple[int, int]):
        """
        Get transforms for IMC tensors.

        Args:
            image_size: Target image size (height, width)

        Returns:
            train_transform: random flips, rotation, affine jitter, then resize
            val_transform: resize only
        """
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation([-90, 90]),
            transforms.RandomAffine(degrees=[-45, 45], translate=[0.05, 0.05]),
            transforms.Resize(image_size, antialias=True),
        ])

        val_transform = transforms.Compose([
            transforms.Resize(image_size, antialias=True),
        ])

        return train_transform, val_transform

    def _load_mask(self, mask_path: Path) -> torch.Tensor:
        """Load a segmentation mask as an int64 tensor.

        NOTE(review): the mask keeps its stored size and is not passed through
        ``transform``, so it is neither resized nor augmented together with the
        image. Confirm this is intended before using it as a pixel-level target.
        """
        mask = tifffile.imread(str(mask_path))
        return torch.from_numpy(mask.astype(np.int64))

    def get_class_weights(self, indices: Optional[List[int]] = None):
        """Inverse-frequency class weights, ordered like ``self.unique_labels``.

        weight[c] = n_samples / (n_classes * count[c]). A class absent from the
        selected samples gets weight 0.0.

        Args:
            indices: sample indices to count (list, numpy array or tensor). When
                None or empty, all samples are used.

        Returns:
            1-D float32 tensor of length ``len(self.unique_labels)``
        """
        from collections import Counter

        if indices is not None and len(indices) > 0:
            # numpy arrays and tensors both expose .tolist()
            if hasattr(indices, 'tolist'):
                indices = indices.tolist()
            samples_to_use = [self.samples[i] for i in indices]
        else:
            samples_to_use = self.samples

        label_counts = Counter(self._get_label(sample) for sample in samples_to_use)
        total_samples = len(samples_to_use)
        num_classes = len(self.unique_labels)

        weights = [
            total_samples / (num_classes * label_counts[label]) if label in label_counts else 0.0
            for label in self.unique_labels
        ]
        return torch.FloatTensor(weights)
    
    

def create_kfold_dataloaders(
    data_dir: str,
    k_folds: int = 5,
    batch_size: int = 32,
    num_workers: int = 4,
    pin_memory: bool = True,
    random_seed: int = 17025,
    **dataset_kwargs
) -> Tuple[Iterator[Tuple[DataLoader, DataLoader, Dict]], Dict]:
    """
    Create k-fold cross-validation dataloaders

    Folds are stratified by label with sklearn's ``StratifiedKFold`` when
    possible. If sklearn is not installed or stratification raises
    ``ValueError``, a seeded random (non-stratified) split is used instead.

    Args:
        data_dir: Directory containing ROI folders
        k_folds: Number of folds for cross-validation
        batch_size: Batch size for dataloaders
        num_workers: Number of workers for dataloaders
        pin_memory: Whether to pin memory
        random_seed: Random seed for reproducibility of the split
        **dataset_kwargs: Additional arguments for IMCDataset. A ``split`` entry
            only affects the dataset used to compute the folds; the per-fold
            datasets always use split='train' / split='val'.

    Returns:
        Iterator yielding (train_loader, val_loader, fold_info) for each fold.
            The per-fold train/val datasets are built when the iterator is
            first advanced.
        Overall dataset info dictionary

    Raises:
        ValueError: if no samples are found or k_folds is not in [2, n_samples]
        RuntimeError: (while iterating) if the per-fold datasets list the ROIs
            in a different order than the dataset the folds were computed on
    """
    from collections import Counter

    # Full dataset: source of labels, fold indices and class weights
    full_dataset = IMCDataset(
        data_dir=data_dir,
        **dataset_kwargs
    )

    if len(full_dataset) == 0:
        raise ValueError("No valid samples found")

    if k_folds < 2:
        raise ValueError("k_folds must be at least 2")
    if k_folds > len(full_dataset):
        raise ValueError(f"k_folds ({k_folds}) cannot be greater than number of samples ({len(full_dataset)})")

    # Labels for stratification (same rule as IMCDataset's label mapping)
    if full_dataset.classification_task == "tissue":
        labels = [sample['tissue'] for sample in full_dataset.samples]
    elif full_dataset.classification_task == "condition":
        labels = [sample['condition'] for sample in full_dataset.samples]
    else:
        labels = [f"{sample['tissue']}_{sample['condition']}" for sample in full_dataset.samples]

    class_counts = Counter(labels)
    min_class_count = min(class_counts.values())

    print(f" Class distribution: {dict(class_counts)}")

    if min_class_count < k_folds:
        print(f" Warning: Smallest class has only {min_class_count} samples, but k_folds={k_folds}")
        print("   Consider reducing k_folds or this may cause stratification issues")

    try:
        from sklearn.model_selection import StratifiedKFold

        skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=random_seed)
        indices = np.arange(len(full_dataset))
        labels_array = np.array(labels)

        fold_splits = list(skf.split(indices, labels_array))
        print(f" Using stratified {k_folds}-fold cross-validation")

    except (ImportError, ValueError) as e:
        # Fall back to regular k-fold if sklearn not available or stratification fails
        print(f" Stratified split failed ({e}), using regular {k_folds}-fold split")

        # Seed a private generator rather than calling torch.manual_seed, which
        # would silently reset the global RNG that the caller later relies on
        # (model init, dropout, DataLoader shuffling).
        generator = torch.Generator().manual_seed(random_seed)
        indices = torch.randperm(len(full_dataset), generator=generator).tolist()
        fold_size = len(full_dataset) // k_folds

        fold_splits = []
        for i in range(k_folds):
            start_idx = i * fold_size
            # The last fold absorbs the remainder when n is not divisible by k
            end_idx = (i + 1) * fold_size if i < k_folds - 1 else len(full_dataset)

            val_indices = indices[start_idx:end_idx]
            train_indices = indices[:start_idx] + indices[end_idx:]

            fold_splits.append((train_indices, val_indices))

    overall_info = {
        'num_classes': len(full_dataset.unique_labels),
        'class_names': full_dataset.unique_labels,
        'total_samples': len(full_dataset),
        'classification_task': full_dataset.classification_task,
        'k_folds': k_folds
    }

    def fold_generator():
        """Yield (train_loader, val_loader, fold_info) for each fold."""
        # The train/val datasets differ only in their transforms and do not
        # depend on the fold, so build them once rather than once per fold.
        per_fold_kwargs = {key: value for key, value in dataset_kwargs.items() if key != 'split'}
        train_dataset_full = IMCDataset(data_dir=data_dir, split='train', **per_fold_kwargs)
        val_dataset_full = IMCDataset(data_dir=data_dir, split='val', **per_fold_kwargs)

        # Fold indices were computed on full_dataset; they are only meaningful if
        # the re-scanned datasets list the same ROIs in the same order.
        expected_order = [sample['folder_name'] for sample in full_dataset.samples]
        for split_name, dataset in (("train", train_dataset_full), ("val", val_dataset_full)):
            if [sample['folder_name'] for sample in dataset.samples] != expected_order:
                raise RuntimeError(
                    f"The {split_name} dataset lists ROIs in a different order than the "
                    "dataset used to compute the folds; fold indices would be misaligned"
                )

        for fold_idx, (train_indices, val_indices) in enumerate(fold_splits):
            print(f"\n Fold {fold_idx + 1}/{k_folds}")

            train_dataset = torch.utils.data.Subset(train_dataset_full, train_indices)
            val_dataset = torch.utils.data.Subset(val_dataset_full, val_indices)

            # drop_last avoids a ragged final batch, but with fewer training
            # samples than one batch it would leave the loader with zero batches.
            train_loader = DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=len(train_indices) >= batch_size
            )

            val_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=False
            )

            # Class weights computed from this fold's training samples only
            class_weights = full_dataset.get_class_weights(train_indices)

            train_dist = Counter(labels[i] for i in train_indices)
            val_dist = Counter(labels[i] for i in val_indices)

            fold_info = {
                'fold_idx': fold_idx,
                'train_size': len(train_indices),
                'val_size': len(val_indices),
                'class_weights': class_weights,
                'train_distribution': dict(train_dist),
                'val_distribution': dict(val_dist)
            }

            print(f"   Train samples: {fold_info['train_size']}")
            print(f"   Val samples: {fold_info['val_size']}")
            print(f"   Train distribution: {fold_info['train_distribution']}")
            print(f"   Val distribution: {fold_info['val_distribution']}")

            yield train_loader, val_loader, fold_info

    print(f"📊 K-Fold Dataset Summary:")
    print(f"   Classification task: {overall_info['classification_task']}")
    print(f"   Classes: {overall_info['class_names']}")
    print(f"   Total samples: {overall_info['total_samples']}")
    print(f"   Number of folds: {overall_info['k_folds']}")

    return fold_generator(), overall_info


def create_single_fold_dataloaders(
    data_dir: str,
    fold_idx: int,
    k_folds: int = 5,
    batch_size: int = 32,
    num_workers: int = 4,
    pin_memory: bool = True,
    random_seed: int = 17025,
    **dataset_kwargs
) -> Tuple[DataLoader, DataLoader, Dict]:
    """
    Create dataloaders for a specific fold (useful for parallel processing)

    The split is computed by create_kfold_dataloaders with the same arguments,
    so a given (k_folds, random_seed) pair always selects the same fold.

    Args:
        data_dir: Directory containing ROI folders
        fold_idx: Which fold to create (0-indexed, 0 <= fold_idx < k_folds)
        k_folds: Total number of folds
        batch_size, num_workers, pin_memory, random_seed, **dataset_kwargs:
            Same as create_kfold_dataloaders

    Returns:
        train_loader, val_loader, fold_info for the specified fold

    Raises:
        ValueError: if fold_idx is outside [0, k_folds), or for any invalid
            input rejected by create_kfold_dataloaders
    """
    # Validate before scanning the data directory; otherwise an out-of-range
    # index (e.g. negative) builds every fold before failing.
    if not 0 <= fold_idx < k_folds:
        raise ValueError(f"Fold {fold_idx} not found (max fold index: {k_folds-1})")

    fold_generator, _overall_info = create_kfold_dataloaders(
        data_dir=data_dir,
        k_folds=k_folds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        random_seed=random_seed,
        **dataset_kwargs
    )

    # The generator is sequential: earlier folds are produced on the way, but
    # nothing past the requested fold is built.
    for i, fold in enumerate(fold_generator):
        if i == fold_idx:
            return fold

    raise ValueError(f"Fold {fold_idx} not found (max fold index: {k_folds-1})")



In [4]:
from sympy import Ge
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.ops import DropBlock2d
import torchvision.transforms as transforms
import torchvision.datasets as dataset
import torch.nn.functional as F
import torchvision.models as models
from typing import List, Tuple, Optional


class EfficientNetMarkerEncoder(nn.Module):
    """EfficientNet-B3 based encoder mapping one IMC marker channel to a feature vector.

    Input ``[batch_size, 1, H, W]`` -> output ``[batch_size, feature_dim]``.

    The backbone's stem convolution is replaced by a 1-input-channel conv. When
    ``pretrained`` is True, the new stem is initialised from the pretrained RGB
    stem by summing its weights over the input-channel axis. This keeps the
    ImageNet-learned first-layer filters instead of re-randomising them.

    Args:
        feature_dim: size of the output feature vector
        pretrained: load torchvision's default EfficientNet-B3 weights
        dropout_rate: dropout probability inside the feature head
    """
    def __init__(self, feature_dim: int = 128, pretrained: bool = True, dropout_rate: float = 0.3):
        super(EfficientNetMarkerEncoder, self).__init__()
        self.feature_dim = feature_dim

        weights = models.EfficientNet_B3_Weights.DEFAULT if pretrained else None
        self.backbone = models.efficientnet_b3(weights=weights)

        # torchvision's EfficientNet classifier is Sequential(Dropout, Linear);
        # the Linear's in_features is the width of the pooled backbone features.
        # Read it before the classifier is replaced below.
        backbone_out_features = self.backbone.classifier[1].in_features

        # Replace the 3-channel stem conv with a single-channel one
        original_first_conv = self.backbone.features[0][0]
        new_first_conv = nn.Conv2d(
            in_channels=1,
            out_channels=original_first_conv.out_channels,
            kernel_size=original_first_conv.kernel_size,
            stride=original_first_conv.stride,
            padding=original_first_conv.padding,
            bias=False
        )
        if weights is not None:
            # Summing the RGB filters keeps the pretrained response to a
            # channel-replicated grayscale input.
            with torch.no_grad():
                new_first_conv.weight.copy_(
                    original_first_conv.weight.sum(dim=1, keepdim=True)
                )
        self.backbone.features[0][0] = new_first_conv

        # Drop the ImageNet classifier; pooling is applied explicitly in forward()
        self.backbone.classifier = nn.Identity()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

        # Projection from backbone features to feature_dim
        self.feature_head = nn.Sequential(
            nn.LayerNorm(backbone_out_features),
            nn.Linear(backbone_out_features, 512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, feature_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

    def forward(self, x):
        """x: ``[batch_size, 1, H, W]`` -> ``[batch_size, feature_dim]``."""
        features = self.backbone.features(x)  # [batch_size, C, H', W'], C = backbone_out_features
        features = self.avgpool(features)     # [batch_size, C, 1, 1]
        features = self.flatten(features)     # [batch_size, C]
        features = self.feature_head(features)
        return features


class SimpleFusionHead(nn.Module):
    """Fusion head: mean-pool per-marker features, then classify with a small MLP.

    Args:
        feature_dim: size of each marker's feature vector
        num_markers: number of markers (kept as an attribute; mean pooling
            itself works for any marker count)
        num_classes: number of output logits
        dropout_rate: dropout probability placed before every Linear layer
    """
    def __init__(self, feature_dim=128, num_markers=62, num_classes=2, dropout_rate=0.5):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_markers = num_markers
        self.num_classes = num_classes

        # LayerNorm, then (Dropout, Linear, GELU) per hidden size, then
        # (Dropout, Linear) to the logits. This order fixes the state_dict keys
        # (classifier.0/2/5/8), so keep it stable for saved checkpoints.
        layers = [nn.LayerNorm(feature_dim)]
        in_dim = feature_dim
        for out_dim in (256, 64):
            layers.extend([nn.Dropout(dropout_rate), nn.Linear(in_dim, out_dim), nn.GELU()])
            in_dim = out_dim
        layers.extend([nn.Dropout(dropout_rate), nn.Linear(in_dim, num_classes)])
        self.classifier = nn.Sequential(*layers)

    def forward(self, marker_features):
        """Map ``[batch_size, num_markers, feature_dim]`` features to predictions.

        Returns:
            dict with 'logits' (``[batch_size, num_classes]``) and
            'pooled_features' (``[batch_size, feature_dim]``, the mean over markers)
        """
        pooled = marker_features.mean(dim=1)
        return {'logits': self.classifier(pooled), 'pooled_features': pooled}


class SimpleIMCClassifier(nn.Module):
    """IMC classifier: one shared EfficientNet-B3 marker encoder applied to every
    marker channel, followed by mean-pooling fusion and an MLP head.

    Args:
        num_markers: number of marker channels expected in the input
        feature_dim: per-marker feature size produced by the encoder
        num_classes: number of output classes
        pretrained: load pretrained weights into the encoder backbone
        dropout_rate: dropout used in both the encoder head and the fusion head
        shared_backbone: must be True (one encoder shared by all markers)

    Raises:
        NotImplementedError: if ``shared_backbone`` is False
    """
    def __init__(
            self,
            num_markers=62,
            feature_dim=128,
            num_classes=2,
            pretrained=True,
            dropout_rate=0.5,
            shared_backbone=True
    ):
        super(SimpleIMCClassifier, self).__init__()
        if not shared_backbone:
            # Separate per-marker backbones were never implemented; accepting
            # the flag silently produced a model with no encoder whose first
            # forward() failed on an unbound `features` variable.
            raise NotImplementedError("SimpleIMCClassifier only supports shared_backbone=True")

        self.num_markers = num_markers
        self.feature_dim = feature_dim
        self.shared_backbone = shared_backbone

        self.marker_encoder = EfficientNetMarkerEncoder(
            feature_dim=feature_dim,
            pretrained=pretrained,
            dropout_rate=dropout_rate
        )
        print(f"Using EfficientNet-B3 backbone (pretrained={pretrained})")

        self.fusion_head = SimpleFusionHead(
            feature_dim=feature_dim,
            num_markers=num_markers,
            num_classes=num_classes,
            dropout_rate=dropout_rate
        )

    def forward(self, imc_data):
        """Classify a batch of IMC images.

        Args:
            imc_data: tensor of shape ``[batch_size, num_markers, H, W]``

        Returns:
            the fusion head's output dict ('logits', 'pooled_features')
        """
        _, num_markers, _, _ = imc_data.shape
        assert num_markers == self.num_markers, f"Expected {self.num_markers}, got {num_markers}"

        # Each marker is encoded in its own call (not folded into the batch), so
        # batch-dependent layers in the backbone see one marker at a time.
        marker_features = torch.stack(
            [self.marker_encoder(imc_data[:, i:i + 1, :, :]) for i in range(num_markers)],
            dim=1,
        )  # [batch_size, num_markers, feature_dim]

        return self.fusion_head(marker_features)

    def freeze_backbones(self):
        """Freeze the encoder's backbone parameters.

        The encoder's feature head and the fusion head stay trainable.
        """
        if self.shared_backbone:
            for param in self.marker_encoder.backbone.parameters():
                param.requires_grad = False
        print("Frozen backbone - only training fusion head")

    def unfreeze_backbones(self):
        """Unfreeze the encoder's backbone parameters for full fine-tuning."""
        if self.shared_backbone:
            for param in self.marker_encoder.backbone.parameters():
                param.requires_grad = True
        print("Unfrozen EfficientNet-B3 backbones - training full model")


def create_simple_model(
        num_markers=62,
        num_classes=2,
        feature_dim=128,
        pretrained=True,
        shared_backbone=True,
        dropout_rate=0.5,
        device=(
            'cuda' if torch.cuda.is_available()
            else 'mps' if torch.backends.mps.is_available()
            else 'cpu'
        )
):
    """Create a SimpleIMCClassifier (EfficientNet-B3 backbone), move it to
    ``device`` and print a parameter summary.

    Args:
        num_markers: number of marker channels per image
        num_classes: number of output classes
        feature_dim: per-marker feature size
        pretrained: load pretrained backbone weights
        shared_backbone: only True is supported; False is replaced by True
            with a printed notice
        dropout_rate: dropout used throughout the model
        device: target device. The default prefers CUDA, then Apple MPS, then
            CPU, so the model can be created on machines without MPS.

    Returns:
        the model, already on ``device``
    """
    if not shared_backbone:
        print("separate backbones not implemented, using shared")
        shared_backbone = True

    model = SimpleIMCClassifier(
        num_markers=num_markers,
        num_classes=num_classes,
        feature_dim=feature_dim,
        pretrained=pretrained,
        shared_backbone=shared_backbone,
        dropout_rate=dropout_rate
    )

    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Simple Model Summary")
    print(f"Markers: {num_markers}, Backbone: EfficientNet-B3, pretrained: {pretrained}")
    print(f"Feature_dim: {feature_dim}, classes: {num_classes}, device: {device}")
    print(f"Dropout rate: {dropout_rate}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    return model

# Prefer CUDA, then Apple MPS, then CPU; moving a model to 'mps' fails on
# machines without MPS.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
# Create simplified model
model = create_simple_model(
    num_markers=62,
    num_classes=2,  # Update this based on your dataset_info
    pretrained=True,
    shared_backbone=True,
    feature_dim=64,  # Reduced feature dimension
    dropout_rate=0.6,  # Higher dropout for regularization
    device=device
)

Using EfficientNet-B3 backbone (pretrained=True)
Simple Model Summary
Markers: 62, Backbone: EfficientNet-B3, pretrained: True
Feature_dim: 64, classes: 2, device: cuda
Dropout rate: 0.6
Total parameters: 11,551,706
Trainable parameters: 11,551,706


## Training Script

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns
from pathlib import Path
import json
import time
from tqdm import tqdm
import logging




In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import numpy as np
import json
import pandas as pd

# IMCTrainer (revised): corrects the learning_rate and weight_decay defaults of the earlier version (see the comments on those parameters)
class IMCTrainer:
    """Train and validate one model on a single train/validation split.

    ``train`` runs two phases: first the model is trained with
    ``model.freeze_backbones()`` applied, then ``model.unfreeze_backbones()``
    is called and the whole model is fine-tuned at one tenth of the current
    learning rate. Whenever validation accuracy improves, a checkpoint is
    written to ``save_dir`` and its path is stored in ``best_model_path``.

    The model is expected to return a dict whose ``'logits'`` entry is fed to
    ``nn.CrossEntropyLoss`` together with the integer class labels.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        fold_info,
        dataset_info,
        device=(
            'cuda' if torch.cuda.is_available()
            else 'mps' if torch.backends.mps.is_available()
            else 'cpu'
        ),
        learning_rate=1e-3,
        weight_decay=1e-4,
        use_class_weights=True,
        use_mixed_precision=True,
        save_dir='./checkpoints'
    ):
        """
        Args:
            model: nn.Module returning ``{'logits': ...}`` and implementing
                ``freeze_backbones()`` / ``unfreeze_backbones()``.
            train_loader: iterable of ``(images, labels)`` training batches.
            val_loader: iterable of ``(images, labels)`` validation batches.
            fold_info: per-fold metadata dict; if it contains a
                ``'class_weights'`` tensor and ``use_class_weights`` is True,
                the cross-entropy loss is weighted by it.
            dataset_info: dataset metadata, stored in every checkpoint.
            device: torch device string (default: CUDA, else MPS, else CPU).
            learning_rate: initial AdamW learning rate.
            weight_decay: AdamW weight decay.
            use_class_weights: use ``fold_info['class_weights']`` when present.
            use_mixed_precision: request CUDA autocast + GradScaler. Only
                honoured when ``device`` is a CUDA device.
            save_dir: directory for checkpoints and plots (created, including
                parents, if missing).
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.fold_info = fold_info
        self.dataset_info = dataset_info
        self.device = device
        self.save_dir = Path(save_dir)
        # parents=True so nested paths such as './checkpoints/fold_1' work on a fresh run.
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Loss function, optionally class-weighted for imbalanced folds
        if use_class_weights and 'class_weights' in fold_info:
            class_weights = fold_info['class_weights'].to(device)
            self.criterion = nn.CrossEntropyLoss(weight=class_weights)
            print(f"Using class weights: {class_weights}")
        else:
            self.criterion = nn.CrossEntropyLoss()

        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # Halve the learning rate once validation loss has stopped improving
        # for more than 5 epochs.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
        )

        # The autocast/GradScaler used below are CUDA-typed, so mixed precision
        # is only enabled when actually training on a CUDA device.
        on_cuda = torch.device(device).type == 'cuda'
        self.use_mixed_precision = use_mixed_precision and on_cuda
        if self.use_mixed_precision:
            self.scaler = torch.amp.GradScaler("cuda")
            print("Using mixed precision training")
        elif use_mixed_precision:
            print(f"Mixed precision requested but device is '{device}'; training in full precision")

        # Per-epoch training history (one entry per epoch in each list)
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'learning_rates': []
        }
        self.best_val_acc = 0.0
        self.best_model_path = None

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            (epoch_loss, epoch_acc): mean of the per-batch losses and the
            accuracy over every training sample seen this epoch.
        """
        self.model.train()
        running_loss = 0.0
        num_correct = 0
        num_seen = 0
        all_preds = []
        all_labels = []

        pbar = tqdm(self.train_loader, desc="Training")
        for images, labels in pbar:
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()

            if self.use_mixed_precision:
                with torch.amp.autocast("cuda"):
                    output = self.model(images)
                    loss = self.criterion(output['logits'], labels)
                # Loss scaling guards against fp16 gradient underflow.
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                output = self.model(images)
                loss = self.criterion(output['logits'], labels)
                loss.backward()
                self.optimizer.step()

            running_loss += loss.item()

            _, preds = torch.max(output['logits'], 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            # Running accuracy from counts instead of re-scoring the whole epoch every batch.
            num_correct += (preds == labels).sum().item()
            num_seen += labels.numel()

            pbar.set_postfix({
                'Loss': f"{loss.item():.4f}",
                'Acc': f"{num_correct / num_seen:.4f}"
            })

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = accuracy_score(all_labels, all_preds)
        return epoch_loss, epoch_acc

    def validate_epoch(self):
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            dict with ``'loss'`` (mean per-batch loss), ``'accuracy'``,
            support-weighted ``'precision'``/``'recall'``/``'f1'``, and the
            raw ``'predictions'`` / ``'labels'`` lists.
        """
        self.model.eval()
        running_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc="Validation")
            for images, labels in pbar:
                images, labels = images.to(self.device), labels.to(self.device)

                if self.use_mixed_precision:
                    with torch.amp.autocast("cuda"):
                        output = self.model(images)
                        loss = self.criterion(output['logits'], labels)
                else:
                    output = self.model(images)
                    loss = self.criterion(output['logits'], labels)

                running_loss += loss.item()

                _, preds = torch.max(output['logits'], 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

                pbar.set_postfix({'Loss': f"{loss.item():.4f}"})

        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = accuracy_score(all_labels, all_preds)

        # zero_division=0: a class that is never predicted scores 0 without
        # flooding the log with UndefinedMetricWarning.
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds, average='weighted', zero_division=0
        )

        return {
            'loss': epoch_loss,
            'accuracy': epoch_acc,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'predictions': all_preds,
            'labels': all_labels
        }

    def _run_epoch(self, epoch, checkpoint_prefix):
        """Train + validate once, log to history, checkpoint on a new best
        validation accuracy, then step the plateau scheduler."""
        train_loss, train_acc = self.train_epoch()
        val_results = self.validate_epoch()

        self.history['train_loss'].append(train_loss)
        self.history['val_loss'].append(val_results['loss'])
        self.history['train_acc'].append(train_acc)
        self.history['val_acc'].append(val_results['accuracy'])
        # Record the LR that was used for this epoch (before the scheduler step).
        self.history['learning_rates'].append(self.optimizer.param_groups[0]['lr'])

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_results['loss']:.4f}, Val Acc: {val_results['accuracy']:.4f}")
        print(f"Val F1: {val_results['f1']:.4f}")

        if val_results['accuracy'] > self.best_val_acc:
            self.best_val_acc = val_results['accuracy']
            self.save_checkpoint(epoch, f"{checkpoint_prefix}_epoch_{epoch+1}.pth", val_results)

        self.scheduler.step(val_results['loss'])
        return val_results

    def train(self, num_epochs=50, freeze_backbone_epochs=5):
        """Two-phase training loop.

        Args:
            num_epochs: total number of epochs over both phases.
            freeze_backbone_epochs: epochs (counted within ``num_epochs``)
                trained with frozen backbones before full fine-tuning.

        Returns:
            The ``history`` dict of per-epoch losses, accuracies and LRs.
        """
        print(f"Starting training for {num_epochs} epochs")
        print(f"First {freeze_backbone_epochs} epochs with frozen backbone")

        # Phase 1: train with frozen backbones
        if freeze_backbone_epochs > 0:
            print(f"Training fusion head only {freeze_backbone_epochs} epochs")
            self.model.freeze_backbones()

            for epoch in range(freeze_backbone_epochs):
                print(f"\nEpoch {epoch+1}/{freeze_backbone_epochs}")
                self._run_epoch(epoch, "best_frozen")

        # Phase 2: unfreeze and fine-tune the complete model
        remaining_epochs = num_epochs - freeze_backbone_epochs
        if remaining_epochs > 0:
            print(f"Phase 2: fine tuning entire model ({remaining_epochs} epochs)")
            self.model.unfreeze_backbones()

            # Smaller steps so the freshly unfrozen backbone is not disrupted
            for param_group in self.optimizer.param_groups:
                param_group['lr'] *= 0.1
            print(f"🔽 Reduced learning rate to {self.optimizer.param_groups[0]['lr']}")

            for epoch in range(freeze_backbone_epochs, num_epochs):
                print(f"\nEpoch {epoch+1}/{num_epochs}")
                self._run_epoch(epoch, "best_finetuned")

                # Early stop once the plateau scheduler has driven the LR to ~0
                if self.optimizer.param_groups[0]['lr'] < 1e-7:
                    print("💤 Learning rate too small, stopping training")
                    break

        print(f"Training completed! Best validation accuracy: {self.best_val_acc:.4f}")
        return self.history

    def save_checkpoint(self, epoch, filename, val_results=None):
        """Save model/optimizer/scheduler state plus history and metadata to
        ``save_dir / filename`` and remember it as ``best_model_path``."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc,
            'history': self.history,
            'dataset_info': self.dataset_info,
            'fold_info': self.fold_info
        }

        if val_results:
            checkpoint['val_results'] = val_results

        checkpoint_path = self.save_dir / filename
        torch.save(checkpoint, checkpoint_path)
        self.best_model_path = checkpoint_path
        print(f"💾 Saved checkpoint: {checkpoint_path}")

    def plot_training_history(self, save_name='training_history.png'):
        """Plot loss, accuracy and learning-rate curves from ``self.history``.

        The figure is saved to ``save_dir / save_name`` and shown. If no epoch
        has been recorded yet, a message is printed and nothing is plotted.
        """
        if not self.history['train_loss']:
            print("No training history to plot")
            return None

        epochs = range(1, len(self.history['train_loss']) + 1)
        fig, (ax_loss, ax_acc, ax_lr) = plt.subplots(1, 3, figsize=(18, 5))

        ax_loss.plot(epochs, self.history['train_loss'], label='Train')
        ax_loss.plot(epochs, self.history['val_loss'], label='Validation')
        ax_loss.set(title='Loss', xlabel='Epoch', ylabel='Cross-entropy loss')
        ax_loss.legend()
        ax_loss.grid(True, alpha=0.3)

        ax_acc.plot(epochs, self.history['train_acc'], label='Train')
        ax_acc.plot(epochs, self.history['val_acc'], label='Validation')
        ax_acc.set(title='Accuracy', xlabel='Epoch', ylabel='Accuracy')
        ax_acc.legend()
        ax_acc.grid(True, alpha=0.3)

        ax_lr.plot(epochs, self.history['learning_rates'])
        ax_lr.set_yscale('log')
        ax_lr.set(title='Learning rate', xlabel='Epoch', ylabel='LR')
        ax_lr.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(self.save_dir / save_name, dpi=150, bbox_inches='tight')
        plt.show()
        # Close so repeated per-fold plots do not accumulate open figures.
        plt.close(fig)
        return None


# New K-Fold Training Manager
class KFoldIMCTrainer:
    """Manager for K-fold cross-validation training.

    For every fold yielded by ``fold_generator`` a fresh model is built with
    ``model_factory`` and trained by an ``IMCTrainer``. Afterwards summary
    statistics are computed, written to ``save_dir`` (a JSON summary and a
    ``torch.save``-d file with all per-fold results) and plotted.
    """

    def __init__(
        self,
        model_factory,
        fold_generator,
        dataset_info,
        device=(
            'cuda' if torch.cuda.is_available()
            else 'mps' if torch.backends.mps.is_available()
            else 'cpu'
        ),
        save_dir='./kfold_results',
        **trainer_kwargs
    ):
        """
        Args:
            model_factory: zero-argument callable returning a fresh model.
            fold_generator: iterable of ``(train_loader, val_loader, fold_info)``.
            dataset_info: dict; this class reads its 'k_folds',
                'total_samples', 'num_classes' and 'class_names' keys.
            device: torch device string (default: CUDA, else MPS, else CPU).
            save_dir: output directory (created, including parents, if missing).
            **trainer_kwargs: extra keyword arguments for every IMCTrainer.
        """
        self.model_factory = model_factory
        self.fold_generator = fold_generator
        self.dataset_info = dataset_info
        self.device = device
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.trainer_kwargs = trainer_kwargs

        # Results storage
        self.fold_results = []
        self.cv_summary = {}

    def train_all_folds(self, num_epochs=50, freeze_backbone_epochs=5):
        """Train every fold, then summarise, save and plot the results.

        Returns:
            The ``cv_summary`` dict built by ``compute_cv_summary``.
        """
        print(f"🔄 Starting {self.dataset_info['k_folds']}-Fold Cross-Validation")
        print(f"📊 Dataset: {self.dataset_info['total_samples']} samples, {self.dataset_info['num_classes']} classes")
        print(f"🏷️ Classes: {self.dataset_info['class_names']}")
        print("="*80)

        fold_results = []

        for fold_idx, (train_loader, val_loader, fold_info) in enumerate(self.fold_generator):
            print(f"\n🎯 FOLD {fold_idx + 1}/{self.dataset_info['k_folds']}")
            print(f"   Train samples: {fold_info['train_size']}")
            print(f"   Val samples: {fold_info['val_size']}")
            print(f"   Train distribution: {fold_info['train_distribution']}")
            print(f"   Val distribution: {fold_info['val_distribution']}")
            print("-"*60)

            # Fresh weights per fold so no information leaks between folds
            model = self.model_factory().to(self.device)

            fold_save_dir = self.save_dir / f"fold_{fold_idx + 1}"
            fold_save_dir.mkdir(parents=True, exist_ok=True)

            trainer = IMCTrainer(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                fold_info=fold_info,
                dataset_info=self.dataset_info,
                device=self.device,
                save_dir=fold_save_dir,
                **self.trainer_kwargs
            )

            history = trainer.train(
                num_epochs=num_epochs,
                freeze_backbone_epochs=freeze_backbone_epochs
            )

            fold_results.append({
                'fold_idx': fold_idx,
                'best_val_acc': trainer.best_val_acc,
                'history': history,
                'fold_info': fold_info,
                'best_model_path': trainer.best_model_path
            })

            print(f"✅ Fold {fold_idx + 1} completed - Best Val Acc: {trainer.best_val_acc:.4f}")

            # Release this fold's model/optimizer before the next fold allocates its own
            del model, trainer
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self.fold_results = fold_results
        self.compute_cv_summary()
        self.save_cv_results()
        self.plot_cv_results()

        return self.cv_summary

    def compute_cv_summary(self):
        """Compute cross-validation summary statistics into ``self.cv_summary``.

        With no fold results, a minimal summary (``num_folds == 0`` and empty
        accuracy lists) is stored instead of raising on empty reductions.
        """
        if not self.fold_results:
            self.cv_summary = {
                'num_folds': 0,
                'best_val_accuracies': [],
                'final_train_accuracies': [],
                'final_val_accuracies': [],
            }
            print("\n⚠️ No fold results available; cross-validation summary is empty.")
            return

        val_accuracies = [fold['best_val_acc'] for fold in self.fold_results]
        final_train_accs = [fold['history']['train_acc'][-1] for fold in self.fold_results]
        final_val_accs = [fold['history']['val_acc'][-1] for fold in self.fold_results]

        self.cv_summary = {
            'num_folds': len(self.fold_results),
            'best_val_accuracies': val_accuracies,
            'final_train_accuracies': final_train_accs,
            'final_val_accuracies': final_val_accs,

            # Mean and (population) std of best validation accuracies
            'mean_val_acc': np.mean(val_accuracies),
            'std_val_acc': np.std(val_accuracies),
            'min_val_acc': np.min(val_accuracies),
            'max_val_acc': np.max(val_accuracies),

            # Overall performance
            'mean_final_train_acc': np.mean(final_train_accs),
            'mean_final_val_acc': np.mean(final_val_accs),

            # Best fold (0-based index into fold_results)
            'best_fold_idx': int(np.argmax(val_accuracies)),
            'best_fold_acc': np.max(val_accuracies)
        }

        print("\n" + "="*80)
        print("📈 CROSS-VALIDATION SUMMARY")
        print("="*80)
        print(f"Mean Validation Accuracy: {self.cv_summary['mean_val_acc']:.4f} ± {self.cv_summary['std_val_acc']:.4f}")
        print(f"Best Fold: {self.cv_summary['best_fold_idx'] + 1} (Acc: {self.cv_summary['best_fold_acc']:.4f})")
        print(f"Range: [{self.cv_summary['min_val_acc']:.4f}, {self.cv_summary['max_val_acc']:.4f}]")

        print(f"\nFold-by-fold results:")
        for i, acc in enumerate(val_accuracies):
            print(f"  Fold {i+1}: {acc:.4f}")

    def save_cv_results(self):
        """Write ``cv_summary.json`` and ``detailed_cv_results.pth`` to ``save_dir``."""
        summary_path = self.save_dir / 'cv_summary.json'
        with open(summary_path, 'w') as f:
            # Convert numpy types to native Python types for JSON serialization
            json_summary = {}
            for key, value in self.cv_summary.items():
                if isinstance(value, np.ndarray):
                    json_summary[key] = value.tolist()
                elif isinstance(value, (np.integer, np.floating)):
                    json_summary[key] = value.item()
                else:
                    json_summary[key] = value
            json.dump(json_summary, f, indent=2)

        # Full results (histories, fold_info tensors, paths) via torch.save
        detailed_results = {
            'dataset_info': self.dataset_info,
            'cv_summary': self.cv_summary,
            'fold_results': self.fold_results
        }

        results_path = self.save_dir / 'detailed_cv_results.pth'
        torch.save(detailed_results, results_path)

        print(f"💾 Results saved to {self.save_dir}")

    def plot_cv_results(self):
        """Plot a 2x2 cross-validation overview and save it as ``cv_results.png``."""
        if not self.fold_results:
            print("No fold results to plot!")
            return

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # 1. Best validation accuracy per fold
        fold_numbers = range(1, len(self.fold_results) + 1)
        val_accs = self.cv_summary['best_val_accuracies']

        ax1.bar(fold_numbers, val_accs, alpha=0.7, color='skyblue', edgecolor='black')
        ax1.axhline(y=self.cv_summary['mean_val_acc'], color='red', linestyle='--',
                    label=f"Mean: {self.cv_summary['mean_val_acc']:.4f}")
        ax1.set_xlabel('Fold')
        ax1.set_ylabel('Best Validation Accuracy')
        ax1.set_title('Validation Accuracy Across Folds')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # 2. Validation-accuracy curves for all folds
        for i, fold_result in enumerate(self.fold_results):
            epochs = range(1, len(fold_result['history']['val_acc']) + 1)
            ax2.plot(epochs, fold_result['history']['val_acc'],
                     alpha=0.7, label=f'Fold {i+1}')

        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Validation Accuracy')
        ax2.set_title('Validation Accuracy Curves')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # 3. Distribution of best accuracies
        ax3.hist(val_accs, bins=min(10, len(val_accs)), alpha=0.7, color='lightgreen', edgecolor='black')
        ax3.axvline(x=self.cv_summary['mean_val_acc'], color='red', linestyle='--',
                    label=f"Mean: {self.cv_summary['mean_val_acc']:.4f}")
        ax3.set_xlabel('Validation Accuracy')
        ax3.set_ylabel('Frequency')
        ax3.set_title('Distribution of Best Validation Accuracies')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # 4. Summary statistics as text
        ax4.axis('off')
        summary_text = f"""
Cross-Validation Summary

Number of Folds: {self.cv_summary['num_folds']}
Mean Val Accuracy: {self.cv_summary['mean_val_acc']:.4f} ± {self.cv_summary['std_val_acc']:.4f}

Best Fold: {self.cv_summary['best_fold_idx'] + 1}
Best Accuracy: {self.cv_summary['best_fold_acc']:.4f}

Min Accuracy: {self.cv_summary['min_val_acc']:.4f}
Max Accuracy: {self.cv_summary['max_val_acc']:.4f}

Dataset: {self.dataset_info['total_samples']} samples
Classes: {', '.join(self.dataset_info['class_names'])}
        """
        ax4.text(0.1, 0.9, summary_text, transform=ax4.transAxes, fontsize=12,
                 verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        plt.tight_layout()
        plt.savefig(self.save_dir / 'cv_results.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)


# End-to-end K-fold pipeline: build fold dataloaders, train every fold, return the CV summary
def train_kfold_imc_model(
    data_dir,
    model_factory,
    k_folds=5,
    num_epochs=50,
    freeze_backbone_epochs=5,
    batch_size=16,
    learning_rate=1e-3,
    weight_decay=1e-4,
    **dataset_kwargs
):
    """Run the complete K-fold training pipeline.

    Builds the per-fold dataloaders with ``create_kfold_dataloaders``, trains
    every fold through a ``KFoldIMCTrainer`` whose results are written to
    ``./kfold_results_{k_folds}fold``, and returns the summary.

    Args:
        data_dir: path to the IMC data.
        model_factory: zero-argument callable returning a fresh model
            (e.g. ``lambda: create_simple_model(...)``).
        k_folds: number of cross-validation folds.
        num_epochs: training epochs per fold.
        freeze_backbone_epochs: epochs per fold trained with frozen backbone.
        batch_size: dataloader batch size.
        learning_rate: initial learning rate for every fold's trainer.
        weight_decay: weight decay for every fold's trainer.
        **dataset_kwargs: forwarded to ``create_kfold_dataloaders``.

    Returns:
        ``(cv_summary, kfold_trainer)``: the summary dict from
        ``train_all_folds`` and the trainer object holding per-fold results.
    """
    folds, info = create_kfold_dataloaders(
        data_dir=data_dir,
        k_folds=k_folds,
        batch_size=batch_size,
        **dataset_kwargs
    )

    # Optimiser settings shared by every fold's IMCTrainer
    optimiser_options = dict(learning_rate=learning_rate, weight_decay=weight_decay)
    runner = KFoldIMCTrainer(
        model_factory=model_factory,
        fold_generator=folds,
        dataset_info=info,
        save_dir=f'./kfold_results_{k_folds}fold',
        **optimiser_options
    )

    summary = runner.train_all_folds(
        num_epochs=num_epochs,
        freeze_backbone_epochs=freeze_backbone_epochs
    )
    return summary, runner

In [8]:
import torch
import torch.nn as nn
import numpy as np
import json
from pathlib import Path
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix

# Pick the best available accelerator, falling back to CPU so this cell also
# runs on machines that have neither CUDA nor Apple MPS.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using {device}")

# Training configuration
config = {
    'batch_size': 3,
    'image_size': (256, 256),
    'k_folds': 5,
    'num_epochs': 30,
    'freeze_backbone_epochs': 5,
    'learning_rate': 1e-3,  # initial AdamW learning rate
    'weight_decay': 1e-4,   # AdamW weight decay
    'feature_dim': 128,
    'dropout_rate': 0.6,
    'num_workers': 4,
    'random_seed': 17025
}

# Seed the global torch/numpy RNGs so model initialisation (and any other draws
# from these generators) is repeatable across runs.
torch.manual_seed(config['random_seed'])
np.random.seed(config['random_seed'])

print("Loading data for k-fold cross-validation...")

fold_generator, dataset_info = create_kfold_dataloaders(
    data_dir='./data',
    k_folds=config['k_folds'],
    classification_task='condition',
    batch_size=config['batch_size'],
    image_size=config['image_size'],
    arcsinh_transform=True,
    cofactor=5.0,
    num_workers=config['num_workers'],
    random_seed=config['random_seed']
)

print(f" Dataset Info:")
print(f"   Classes: {dataset_info['class_names']}")
print(f"   Total samples: {dataset_info['total_samples']}")
print(f"   K-folds: {dataset_info['k_folds']}")

# Storage for cross-validation results. 'fold_histories' and 'best_models' get
# exactly one entry per fold (None for a failed fold) so they can be indexed
# by fold_idx.
cv_results = {
    'fold_results': [],
    'fold_histories': [],
    'best_models': [],
    'dataset_info': dataset_info,
    'config': config
}

# Per-metric lists across successful folds (same order as successful folds)
aggregate_metrics = defaultdict(list)


def _json_default(obj):
    """``json.dump`` fallback: tensors/arrays -> lists, numpy scalars ->
    Python scalars, anything else (e.g. ``Path``) -> ``str``."""
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return str(obj)


print(f"\n Starting {config['k_folds']}-Fold Cross-Validation...")

for fold_idx, (train_loader, val_loader, fold_info) in enumerate(fold_generator):
    print(f"\n{'='*60}")
    print(f" FOLD {fold_idx + 1}/{config['k_folds']}")
    print(f"{'='*60}")
    print(f"Train samples: {fold_info['train_size']}")
    print(f"Val samples: {fold_info['val_size']}")
    print(f"Train distribution: {fold_info['train_distribution']}")
    print(f"Val distribution: {fold_info['val_distribution']}")

    # Fresh model for every fold
    model = create_simple_model(
        num_markers=62,
        num_classes=dataset_info['num_classes'],
        pretrained=True,
        shared_backbone=True,
        feature_dim=config['feature_dim'],
        dropout_rate=config['dropout_rate'],
        device=device
    )

    fold_save_dir = f'./checkpoints/fold_{fold_idx + 1}'
    Path(fold_save_dir).mkdir(parents=True, exist_ok=True)

    trainer = IMCTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        fold_info=fold_info,
        dataset_info=dataset_info,
        device=device,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        use_class_weights=True,
        use_mixed_precision=True,
        save_dir=fold_save_dir
    )

    print(f" Starting training for fold {fold_idx + 1}...")
    fold_failed = False
    try:
        history = trainer.train(
            num_epochs=config['num_epochs'],
            freeze_backbone_epochs=config['freeze_backbone_epochs']
        )
        best_val_acc = trainer.best_val_acc
        best_model_path = trainer.best_model_path

        # The per-fold plot is best-effort: skip it if the trainer has no such
        # method, and never let a plotting error turn a finished fold into a failure.
        if hasattr(trainer, 'plot_training_history'):
            try:
                trainer.plot_training_history()
            except Exception as plot_error:
                print(f" Could not plot training history for fold {fold_idx + 1}: {plot_error}")
    except Exception as e:
        fold_failed = True
        print(f" Error in fold {fold_idx + 1}: {str(e)}")
        cv_results['fold_results'].append({
            'fold_idx': fold_idx,
            'error': str(e),
            'fold_info': fold_info
        })
        # Placeholders keep these lists index-aligned with fold_idx
        cv_results['fold_histories'].append(None)
        cv_results['best_models'].append(None)
    finally:
        # Free the model/optimizer before the next fold allocates its own
        del model, trainer
        if device == 'cuda':
            torch.cuda.empty_cache()

    if fold_failed:
        continue

    fold_result = {
        'fold_idx': fold_idx,
        'best_val_acc': best_val_acc,
        'final_train_acc': history['train_acc'][-1] if history['train_acc'] else 0,
        'final_val_acc': history['val_acc'][-1] if history['val_acc'] else 0,
        'final_train_loss': history['train_loss'][-1] if history['train_loss'] else float('inf'),
        'final_val_loss': history['val_loss'][-1] if history['val_loss'] else float('inf'),
        'best_val_loss': min(history['val_loss']) if history['val_loss'] else float('inf'),
        'fold_info': fold_info,
        'best_model_path': str(best_model_path) if best_model_path else None,
        # Number of epochs actually run (training can stop early when the LR
        # collapses), not a detected convergence point.
        'convergence_epoch': len(history['val_acc']) if history['val_acc'] else 0
    }

    cv_results['fold_results'].append(fold_result)
    cv_results['fold_histories'].append(history)
    cv_results['best_models'].append(best_model_path)

    aggregate_metrics['val_acc'].append(best_val_acc)
    aggregate_metrics['final_train_acc'].append(fold_result['final_train_acc'])
    aggregate_metrics['final_val_acc'].append(fold_result['final_val_acc'])
    aggregate_metrics['final_train_loss'].append(fold_result['final_train_loss'])
    aggregate_metrics['final_val_loss'].append(fold_result['final_val_loss'])
    aggregate_metrics['best_val_loss'].append(fold_result['best_val_loss'])
    aggregate_metrics['convergence_epochs'].append(fold_result['convergence_epoch'])

    # Save fold-specific results; _json_default handles tensors (including
    # those nested inside fold_info) and numpy values.
    fold_results_path = Path(fold_save_dir) / 'fold_results.json'
    with open(fold_results_path, 'w') as f:
        json.dump(fold_result, f, indent=2, default=_json_default)

    print(f" Fold {fold_idx + 1} completed!")
    print(f"   Best Val Accuracy: {best_val_acc:.4f}")
    print(f"   Final Train Accuracy: {fold_result['final_train_acc']:.4f}")
    print(f"   Final Val Accuracy: {fold_result['final_val_acc']:.4f}")
    print(f"   Convergence Epoch: {fold_result['convergence_epoch']}")
    print(f"   Results saved to: {fold_results_path}")

successful_folds = [r for r in cv_results['fold_results'] if 'best_val_acc' in r]

if not successful_folds:
    print(" No folds completed successfully!")
    # Raise instead of exit(): inside a Jupyter kernel exit() ends the session
    # (losing all in-memory state); an exception only stops this cell.
    raise RuntimeError("No folds completed successfully!")

print(f"\n{'='*60}")
print(f" CROSS-VALIDATION RESULTS SUMMARY")
print(f"{'='*60}")
print(f"Successful folds: {len(successful_folds)}/{config['k_folds']}")

# Calculate means and standard deviations
def safe_mean_std(values):
    """Return ``(mean, population std)`` of *values*, or ``(0.0, 0.0)`` when
    *values* is empty."""
    if values:
        return np.mean(values), np.std(values)
    return 0.0, 0.0

# Mean ± population std of each per-fold metric over the successful folds
(
    (mean_val_acc, std_val_acc),
    (mean_train_acc, std_train_acc),
    (mean_final_val_acc, std_final_val_acc),
    (mean_val_loss, std_val_loss),
    (mean_convergence, std_convergence),
) = (
    safe_mean_std(aggregate_metrics[metric_name])
    for metric_name in (
        'val_acc', 'final_train_acc', 'final_val_acc', 'best_val_loss', 'convergence_epochs'
    )
)

print(
    "\n📊 Performance Metrics:",
    f"Best Validation Accuracy: {mean_val_acc:.4f} ± {std_val_acc:.4f}",
    f"Final Training Accuracy: {mean_train_acc:.4f} ± {std_train_acc:.4f}",
    f"Final Validation Accuracy: {mean_final_val_acc:.4f} ± {std_final_val_acc:.4f}",
    f"Best Validation Loss: {mean_val_loss:.4f} ± {std_val_loss:.4f}",
    f"Convergence Epochs: {mean_convergence:.1f} ± {std_convergence:.1f}",
    sep="\n",
)

print("\n📋 Per-fold results:")
for result in successful_folds:
    i = result['fold_idx']
    line = (
        f"  Fold {i+1}: Val Acc = {result['best_val_acc']:.4f}, "
        f"Final Val Acc = {result['final_val_acc']:.4f}, "
        f"Epochs = {result['convergence_epoch']}"
    )
    print(line)

# aggregate_metrics['val_acc'] is filled in the same order as successful_folds,
# so the argmax indexes directly into successful_folds.
if len(aggregate_metrics['val_acc']) > 0:
    best_fold_idx = np.argmax(aggregate_metrics['val_acc'])
    best_fold = successful_folds[best_fold_idx]
    print(f"\n Best performing fold: Fold {best_fold['fold_idx'] + 1}")
    print(f"   Best validation accuracy: {best_fold['best_val_acc']:.4f}")
    print(f"   Model path: {best_fold['best_model_path']}")

def plot_comprehensive_cv_results(cv_results, aggregate_metrics, config):
    """Plot a 2x3 dashboard summarising k-fold cross-validation results.

    Panels: best validation accuracy per fold, validation-accuracy curves,
    validation-loss curves, accuracy box plots, epochs trained per fold, and a
    text summary. The figure is saved to
    ``./checkpoints/cv_comprehensive_results.png`` and then shown.

    Args:
        cv_results: dict with ``'fold_results'`` (list of per-fold dicts; a fold
            counts as successful when it has ``'best_val_acc'``),
            ``'fold_histories'`` (indexed by each result's ``'fold_idx'``) and
            ``'dataset_info'`` (with ``'total_samples'`` and ``'class_names'``).
        aggregate_metrics: dict of metric lists read here: ``'val_acc'``,
            ``'final_val_acc'``, ``'final_train_acc'``, ``'convergence_epochs'``.
        config: dict with ``'k_folds'``, ``'num_epochs'``,
            ``'freeze_backbone_epochs'``, ``'batch_size'``, ``'learning_rate'``.

    Returns:
        None. Prints a message and returns early when no fold succeeded.
    """

    def _legend_if_labeled(ax):
        # Only draw a legend when something was plotted with a label; an empty
        # ax.legend() just emits a "No artists with labels" warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    # Failed folds have no 'best_val_acc'.
    successful_results = [r for r in cv_results['fold_results'] if 'best_val_acc' in r]
    if not successful_results:
        print("No successful results to plot!")
        return
    successful_histories = [cv_results['fold_histories'][r['fold_idx']] for r in successful_results]

    # Tick labels for the per-fold bar charts; fold_idx is 0-based.
    fold_labels = [f"Fold {r['fold_idx'] + 1}" for r in successful_results]

    # Summary numbers come from the arguments, not from notebook globals, so the
    # function gives correct output wherever it is called.
    mean_val_acc, std_val_acc = safe_mean_std(aggregate_metrics['val_acc'])
    mean_train_acc, std_train_acc = safe_mean_std(aggregate_metrics['final_train_acc'])
    best_fold = max(successful_results, key=lambda r: r['best_val_acc'])

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle(f'{config["k_folds"]}-Fold Cross-Validation Results', fontsize=16, fontweight='bold')

    # 1. Best validation accuracy per fold
    ax = axes[0, 0]
    val_accs = [r['best_val_acc'] for r in successful_results]
    mean_best_acc = np.mean(val_accs)
    ax.bar(fold_labels, val_accs, alpha=0.7, color='skyblue')
    ax.axhline(y=mean_best_acc, color='red', linestyle='--', label=f"Mean: {mean_best_acc:.3f}")
    ax.set_title('Best Validation Accuracy by Fold')
    ax.set_ylabel('Accuracy')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='x', rotation=45)

    # 2. & 3. Validation accuracy / loss curves, one line per fold
    curve_panels = (
        (axes[0, 1], 'val_acc', 'Validation Accuracy Curves', 'Accuracy'),
        (axes[0, 2], 'val_loss', 'Validation Loss Curves', 'Loss'),
    )
    for ax, key, title, ylabel in curve_panels:
        for result, history in zip(successful_results, successful_histories):
            if history[key]:
                epochs = range(1, len(history[key]) + 1)
                ax.plot(epochs, history[key], alpha=0.7, label=f"Fold {result['fold_idx'] + 1}")
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        _legend_if_labeled(ax)
        ax.grid(True, alpha=0.3)

    # 4. Box plot of the accuracy metrics that have values
    box_specs = [
        (aggregate_metrics[key], label)
        for key, label in (
            ('val_acc', 'Best Val Acc'),
            ('final_val_acc', 'Final Val Acc'),
            ('final_train_acc', 'Final Train Acc'),
        )
        if aggregate_metrics[key]
    ]
    if box_specs:
        ax = axes[1, 0]
        box_data, box_labels = zip(*box_specs)
        # Label ticks explicitly: boxplot's `labels=` kwarg is deprecated
        # (renamed to `tick_labels` in Matplotlib 3.9).
        ax.boxplot(list(box_data))
        ax.set_xticklabels(box_labels)
        ax.set_title('Accuracy Distribution')
        ax.set_ylabel('Accuracy')
        ax.grid(True, alpha=0.3)
        ax.tick_params(axis='x', rotation=45)

    # 5. Epochs trained per fold
    if aggregate_metrics['convergence_epochs']:
        ax = axes[1, 1]
        mean_epochs = np.mean(aggregate_metrics['convergence_epochs'])
        ax.bar(fold_labels, [r['convergence_epoch'] for r in successful_results],
               alpha=0.7, color='lightgreen')
        ax.axhline(y=mean_epochs, color='red', linestyle='--', label=f"Mean: {mean_epochs:.1f}")
        ax.set_title('Training Epochs per Fold')
        ax.set_ylabel('Epochs')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.tick_params(axis='x', rotation=45)

    # 6. Text summary
    axes[1, 2].axis('off')
    summary_text = f"""
Cross-Validation Summary

Successful Folds: {len(successful_results)}/{config['k_folds']}
Dataset: {cv_results['dataset_info']['total_samples']} samples
Classes: {', '.join(cv_results['dataset_info']['class_names'])}

Mean Val Accuracy: {mean_val_acc:.4f} ± {std_val_acc:.4f}
Mean Train Accuracy: {mean_train_acc:.4f} ± {std_train_acc:.4f}

Best Fold: {best_fold['fold_idx'] + 1}
Best Accuracy: {best_fold['best_val_acc']:.4f}

Configuration:
- Epochs: {config['num_epochs']}
- Frozen Epochs: {config['freeze_backbone_epochs']}
- Batch Size: {config['batch_size']}
- Learning Rate: {config['learning_rate']}
    """

    axes[1, 2].text(0.05, 0.95, summary_text, transform=axes[1, 2].transAxes,
                    fontsize=10, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    fig.tight_layout()

    # Save the plot (create the directory first; savefig does not).
    plot_path = Path('./checkpoints') / 'cv_comprehensive_results.png'
    plot_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f" Cross-validation plots saved to: {plot_path}")
    plt.show()

import json  # re-importing is harmless; keeps this cell self-sufficient

checkpoint_dir = Path('./checkpoints')
# Create the output directory up front: the plot below also saves into it.
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Create and show the comprehensive plot
if successful_folds:
    plot_comprehensive_cv_results(cv_results, aggregate_metrics, config)


def _to_jsonable(obj):
    """Fallback encoder for values the stdlib ``json`` encoder rejects.

    Tensors and numpy arrays become (nested) lists, numpy scalars become Python
    scalars, and paths / torch devices become strings. Anything else still raises
    ``TypeError`` so unexpected objects are not silently stringified.
    """
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (Path, torch.device)):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save comprehensive cross-validation results
final_cv_results = {
    'summary_statistics': {
        'successful_folds': len(successful_folds),
        'total_folds': config['k_folds'],
        'mean_val_acc': mean_val_acc,
        'std_val_acc': std_val_acc,
        'mean_train_acc': mean_train_acc,
        'std_train_acc': std_train_acc,
        'mean_final_val_acc': mean_final_val_acc,
        'std_final_val_acc': std_final_val_acc,
        'mean_val_loss': mean_val_loss,
        'std_val_loss': std_val_loss,
        'mean_convergence_epochs': mean_convergence,
        'std_convergence_epochs': std_convergence,
    },
    'per_fold_results': cv_results['fold_results'],
    'aggregate_metrics': dict(aggregate_metrics),
    'dataset_info': {k: v.tolist() if isinstance(v, torch.Tensor) else v
                     for k, v in dataset_info.items()},
    'config': config
}

# best_fold exists only if an earlier cell selected one.
best_fold_record = globals().get('best_fold')
if best_fold_record is not None:
    final_cv_results['summary_statistics']['best_fold'] = best_fold_record['fold_idx'] + 1
    final_cv_results['summary_statistics']['best_fold_acc'] = best_fold_record['best_val_acc']
    final_cv_results['best_model_path'] = best_fold_record['best_model_path']

# Save final results. Encode fully before touching the file so an unserialisable
# value cannot leave a truncated JSON file behind.
final_results_path = checkpoint_dir / 'cv_final_results.json'
final_results_path.write_text(json.dumps(final_cv_results, indent=2, default=_to_jsonable))

print(f"\n Final cross-validation results saved to: {final_results_path}")

if best_fold_record is not None:
    print(f"Best model from Fold {best_fold_record['fold_idx'] + 1}: {best_fold_record['best_model_path']}")
    print(f" Overall performance: {mean_val_acc:.4f} ± {std_val_acc:.4f}")

    # Instructions for using the best model
    print("\n To use the best model for inference:")
    print("```python")
    print("import torch")
    print("from your_model_module import create_simple_model")
    print("")
    print("# Load the best model")
    print(f"device = '{device}'")
    print(f"model = create_simple_model(num_markers=62, num_classes={dataset_info['num_classes']}, device=device)")
    print(f"checkpoint = torch.load('{best_fold_record['best_model_path']}', map_location=device)")
    print("model.load_state_dict(checkpoint['model_state_dict'])")
    print("model.eval()")
    print("```")
else:
    print(" No successful folds completed!")

print("\n K-Fold Cross-Validation Complete!")
print("Check the './checkpoints' directory for detailed results and model files.")

Using cuda
Loading data for k-fold cross-validation...
 Class distribution: {'Malignant': 55, 'Benign': 8}
 Using stratified 5-fold cross-validation
📊 K-Fold Dataset Summary:
   Classification task: condition
   Classes: ['Benign', 'Malignant']
   Total samples: 63
   Number of folds: 5
 Dataset Info:
   Classes: ['Benign', 'Malignant']
   Total samples: 63
   K-folds: 5

 Starting 5-Fold Cross-Validation...

 Fold 1/5
   Train samples: 50
   Val samples: 13
   Train distribution: {'Malignant': 44, 'Benign': 6}
   Val distribution: {'Benign': 2, 'Malignant': 11}

 FOLD 1/5
Train samples: 50
Val samples: 13
Train distribution: {'Malignant': 44, 'Benign': 6}
Val distribution: {'Benign': 2, 'Malignant': 11}
Using EfficientNet-B3 backbone (pretrained=True)
Simple Model Summary
Markers: 62, Backbone: EfficientNet-B3, pretrained: True
Feature_dim: 128, classes: 2, device: cuda
Dropout rate: 0.6
Total parameters: 11,601,050
Trainable parameters: 11,601,050
Using class weights: tensor([4.1667,

  with autocast():
  with autocast():
  with autocast():
  with autocast():
  with autocast():
Training:  25%|██▌       | 4/16 [00:14<00:42,  3.57s/it, Loss=0.4409, Acc=0.7500]


KeyboardInterrupt: 

In [None]:
# Quick fix - add this after create_kfold_dataloaders
# NOTE(review): this normalises the samples of a *fresh* IMCDataset built in this
# cell. Unless create_kfold_dataloaders reuses this exact object, the datasets
# behind the dataloaders are not affected — verify, and apply the normalisation
# where those datasets are built if needed.
# NOTE(review): this cell's output listed a LIVER sample, so the data is not
# prostate-only as the IMCDataset task comment assumes — confirm.
print("🔧 Checking and fixing sample data types...")


def _as_label(value, default="unknown"):
    """Collapse a label that may be wrapped in a list/tuple/array into a plain string.

    Sequences yield their first element as a string (``default`` when empty);
    any other value is passed through ``str``.
    """
    if isinstance(value, np.ndarray):
        # 0-d arrays have no len(); flatten so they (and N-d arrays) index like 1-d.
        value = value.ravel()
    if isinstance(value, (list, tuple, np.ndarray)):
        return str(value[0]) if len(value) > 0 else default
    return str(value)


# Get the full dataset to inspect it
full_dataset = IMCDataset(
    data_dir='./data',
    classification_task='condition'
)

# Debug the first few samples
for i, sample in enumerate(full_dataset.samples[:3]):
    print(f"Sample {i}:")
    print(f"  tissue: {type(sample['tissue'])} = {sample['tissue']}")
    print(f"  condition: {type(sample['condition'])} = {sample['condition']}")

# Coerce every label to a plain string, counting values that were not already str.
n_converted = 0
for sample in full_dataset.samples:
    for key in ('tissue', 'condition'):
        raw_value = sample[key]
        sample[key] = _as_label(raw_value)
        if type(raw_value) is not str:
            n_converted += 1

print(f"✅ Fixed sample data types ({n_converted} non-string label values converted)")

🔧 Checking and fixing sample data types...
Sample 0:
  tissue: <class 'str'> = LIVER
  condition: <class 'str'> = Malignant
Sample 1:
  tissue: <class 'str'> = PROSTATE
  condition: <class 'str'> = Malignant
Sample 2:
  tissue: <class 'str'> = PROSTATE
  condition: <class 'str'> = Malignant
✅ Fixed sample data types
