In [None]:
# Install Hugging Face Transformers (plus datasets and Pillow)
!pip install transformers datasets pillow -q

print("✓ Transformers instalado")

In [None]:
from transformers import CvtConfig, CvtModel

print("Verificando modelos CvT en Hugging Face...\n")

# Candidate CvT checkpoints on the Hugging Face Hub.
cvt_models = [
    'microsoft/cvt-13',    # 20M params - recommended
    'microsoft/cvt-21',    # 32M params
    'microsoft/cvt-w24'    # 277M params - very large
]

# Only the config is fetched for each candidate (no weights), which is enough
# to confirm the checkpoint id resolves and to show its architecture.
for model_name in cvt_models:
    try:
        config = CvtConfig.from_pretrained(model_name)
        print(f"✓ {model_name}")
        print(f"  Depth: {config.depth}")
        print(f"  Embed dims: {config.embed_dim}")
        print()
    except Exception as e:
        # Surface the cause (unknown repo id, no network, ...) instead of
        # discarding it, so a failed lookup can actually be diagnosed.
        print(f"✗ {model_name} - Error: {type(e).__name__}: {e}\n")

print("💡 Recomendado: microsoft/cvt-13")

In [None]:
# Core imports
import os
import sys
import json
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import gc
import time
from datetime import datetime

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau, LinearLR
from torch.cuda.amp import autocast, GradScaler

# Torchvision
from torchvision import transforms
from PIL import Image

# Hugging Face Transformers
from transformers import CvtModel, CvtConfig

# Sklearn
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    f1_score, precision_score, recall_score,
    confusion_matrix, classification_report,
    accuracy_score
)

print("✓ Imports loaded successfully (using Hugging Face Transformers)")
print(f"  PyTorch version: {torch.__version__}")

In [None]:
# Select the compute device once and report what was found.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print(f"Device: {device}")

if not _cuda_ok:
    print("⚠️ WARNING: No GPU available, training will be slow!")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")

    # Release cached GPU blocks and collect Python garbage before training.
    torch.cuda.empty_cache()
    gc.collect()
    print("✓ GPU cache cleared")

In [None]:
# Configuration for CvT with Hugging Face (512x512 inputs).
# Single source of truth for paths and hyper-parameters; the other cells of the
# notebook read their settings from this dict.
CONFIG = {
    # Paths
    'base_path': '/home/merivadeneira',
    'ddsm_benign_path': '/home/merivadeneira/Masas/DDSM/Benignas/Resized_512',
    'ddsm_malign_path': '/home/merivadeneira/Masas/DDSM/Malignas/Resized_512',
    'inbreast_benign_path': '/home/merivadeneira/Masas/INbreast/Benignas/Resized_512',
    'inbreast_malign_path': '/home/merivadeneira/Masas/INbreast/Malignas/Resized_512',
    'output_dir': '/home/merivadeneira/Outputs/CvT',
    'metrics_dir': '/home/merivadeneira/Metrics/CvT',

    # Model (HUGGING FACE)
    'model_name': 'CvT_TL_Base_HF',  # used as the prefix of saved artefacts (config JSON, plots, checkpoints)
    'pretrained_model': 'microsoft/cvt-13',
    'pretrained': True,
    'input_size': 512,
    'in_channels': 3,
    'num_classes': 2,

    # Fine-tuning strategy
    # create_model() recognises 'none', 'all_except_head' and 'progressive'.
    'freeze_strategy': 'none',

    # Training
    'batch_size': 16,
    'num_epochs': 100,
    'num_folds': 5,
    'early_stopping_patience': 25,
    'min_delta': 1e-4,

    # Optimizer
    'optimizer': 'AdamW',
    'lr_initial': None,  # placeholder: filled in later by the LR Finder cell
    'lr_min': 1e-7,      # lower bound of the LR range test
    'lr_max': 1e-3,      # upper bound of the LR range test
    'weight_decay': 0.01,
    'betas': (0.9, 0.999),

    # Warmup
    'warmup_epochs': 5,
    'warmup_start_factor': 0.1,

    # Scheduler
    'scheduler': 'ReduceLROnPlateau',
    'scheduler_factor': 0.5,
    'scheduler_patience': 10,
    'scheduler_min_lr': 1e-7,

    # Label smoothing & gradient clipping
    'label_smoothing': 0.1,
    'gradient_clip_norm': 1.0,

    # Data augmentation (consumed by get_transforms)
    'horizontal_flip': 0.5,
    'vertical_flip': 0.3,
    'rotation_degrees': 20,
    'translate': 0.15,
    'scale': (0.85, 1.15),
    'shear': 10,
    'brightness': 0.2,
    'contrast': 0.2,
    'random_erasing_p': 0.2,

    # Normalization (ImageNet)
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],

    # Mixed Precision & Memory
    'use_amp': True,
    'num_workers': 4,
    'pin_memory': True,

    # Reproducibility
    'seed': 42
}

# Create the output and metrics directories (no error if they already exist)
os.makedirs(CONFIG['output_dir'], exist_ok=True)
os.makedirs(CONFIG['metrics_dir'], exist_ok=True)

# Save a JSON snapshot of the configuration next to the metrics.
# NOTE: 'lr_initial' is still None at this point (it is set by the LR Finder
# cell), and tuple values such as 'betas' and 'scale' are written as JSON arrays.
config_path = os.path.join(CONFIG['metrics_dir'], f"{CONFIG['model_name']}_config.json")
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=4)

# Human-readable summary of the most relevant settings
print(f"✓ Configuration saved to: {config_path}")
print(f"\n📊 Model: {CONFIG['model_name']}")
print(f"🏗️ Base model: {CONFIG['pretrained_model']} (Hugging Face)")
print(f"📦 Batch size: {CONFIG['batch_size']}")
print(f"🔢 Input size: {CONFIG['input_size']}x{CONFIG['input_size']} (SIN RESIZE)")
print(f"🎯 Classes: {CONFIG['num_classes']} (Benign vs Malignant)")
print(f"🔄 Folds: {CONFIG['num_folds']}")
print(f"🔥 Warmup epochs: {CONFIG['warmup_epochs']}")
print(f"⚡ Mixed Precision: {CONFIG['use_amp']}")
print(f"🏷️ Label Smoothing: {CONFIG['label_smoothing']}")
print(f"✂️ Gradient Clipping: {CONFIG['gradient_clip_norm']}")

In [None]:
# Set seed for reproducibility
def set_seed(seed=42):
    """
    Seed every random number generator the notebook can draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: the notebook's import cell does not import ``random``.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Favour reproducibility over speed: fixed algorithms, no autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(CONFIG['seed'])
print(f"✓ Random seed set to {CONFIG['seed']}")

In [None]:
def extract_patient_id(filename, dataset='ddsm'):
    """
    Derive a patient identifier from an image filename.

    Used to group images by patient so that one patient never ends up in both
    the training and the validation split.

    DDSM:     P_00041_LEFT_CC_1.png -> P_00041 (first two '_' fields)
    INbreast: 20586908_6c613a14b80a8591_MG_R_CC_ANON_lesion1_ROI.png -> 20586908

    Any other dataset name, or a DDSM name without an underscore, yields the
    filename unchanged.
    """
    fields = filename.split('_')

    if dataset == 'inbreast':
        return fields[0]

    if dataset == 'ddsm' and len(fields) >= 2:
        return '_'.join(fields[:2])

    return filename

# Test
print("Testing patient ID extraction:")
print(f"DDSM: P_00041_LEFT_CC_1.png -> {extract_patient_id('P_00041_LEFT_CC_1.png', 'ddsm')}")
print(f"INbreast: 20586908_6c613a14b80a8591_MG_R_CC_ANON_lesion1_ROI.png -> {extract_patient_id('20586908_6c613a14b80a8591_MG_R_CC_ANON_lesion1_ROI.png', 'inbreast')}")

In [None]:
class MammographyDataset(Dataset):
    """Mammography image dataset yielding ``(image, label)`` pairs.

    Each image is read from disk on access, reduced to a single grayscale
    channel and then expanded back to three identical RGB channels.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.transform = transform
        self.labels = labels
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Grayscale first, then RGB: the three output channels are identical.
        rgb_image = Image.open(self.image_paths[idx]).convert('L').convert('RGB')

        # The transform is optional; without it the PIL image is returned as is.
        if self.transform:
            rgb_image = self.transform(rgb_image)

        return rgb_image, self.labels[idx]

print("✓ Dataset class defined (with grayscale → RGB conversion)")

In [None]:
def get_transforms(train=True):
    """Build the preprocessing pipeline; images keep their original size (no resize).

    Args:
        train: When true, random augmentations (flips, rotation, affine, colour
            jitter, random erasing) are applied around the tensor conversion.
            When false, only tensor conversion and normalization are applied.

    Returns:
        A ``transforms.Compose`` configured from ``CONFIG``.
    """
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=CONFIG['mean'], std=CONFIG['std']),
    ]

    if not train:
        # Evaluation: no resize, no augmentation.
        return transforms.Compose(to_normalized_tensor)

    # Geometric and photometric augmentations operate on the PIL image.
    pil_augmentations = [
        transforms.RandomHorizontalFlip(p=CONFIG['horizontal_flip']),
        transforms.RandomVerticalFlip(p=CONFIG['vertical_flip']),
        transforms.RandomRotation(degrees=CONFIG['rotation_degrees']),
        transforms.RandomAffine(
            degrees=0,
            translate=(CONFIG['translate'], CONFIG['translate']),
            scale=CONFIG['scale'],
            shear=CONFIG['shear']
        ),
        transforms.ColorJitter(
            brightness=CONFIG['brightness'],
            contrast=CONFIG['contrast']
        ),
    ]

    # Random erasing runs last, on the normalized tensor.
    tensor_augmentations = [
        transforms.RandomErasing(p=CONFIG['random_erasing_p']),
    ]

    return transforms.Compose(pil_augmentations + to_normalized_tensor + tensor_augmentations)

print("✓ Transform functions defined (mantiene 512x512 original)")

In [None]:
def load_dataset():
    """
    Collect every image from the DDSM and INbreast directories listed in CONFIG.

    Directories are visited in a fixed order (DDSM benign, DDSM malignant,
    INbreast benign, INbreast malignant) and the files inside each directory
    are sorted by name, so the returned lists are identical on every run and
    on every machine. Missing directories are reported and skipped.

    Returns:
        image_paths: list of file paths.
        labels: list of ints, 0 = benign, 1 = malignant.
        patient_ids: list of patient identifiers, aligned with image_paths.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')

    image_paths = []
    labels = []
    patient_ids = []

    datasets = [
        (CONFIG['ddsm_benign_path'], 0, 'ddsm'),
        (CONFIG['ddsm_malign_path'], 1, 'ddsm'),
        (CONFIG['inbreast_benign_path'], 0, 'inbreast'),
        (CONFIG['inbreast_malign_path'], 1, 'inbreast')
    ]

    for path, label, dataset_name in datasets:
        if not os.path.exists(path):
            print(f"⚠️ WARNING: Path not found: {path}")
            continue

        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sorting keeps the seeded cross-validation split reproducible.
        # The extension check is case-insensitive so '.PNG' / '.JPG' are kept.
        files = sorted(
            f for f in os.listdir(path)
            if f.lower().endswith(image_extensions)
        )

        for filename in files:
            image_paths.append(os.path.join(path, filename))
            labels.append(label)
            patient_ids.append(extract_patient_id(filename, dataset_name))

    return image_paths, labels, patient_ids

# Load data: builds the notebook-level lists that the LR Finder and the
# cross-validation cells read (image_paths, labels, patient_ids).
print("Loading dataset...")
image_paths, labels, patient_ids = load_dataset()

# Overall counts (labels: 0 = benign, 1 = malignant)
print(f"\n📊 Dataset loaded:")
print(f"  Total images: {len(image_paths)}")
print(f"  Benign: {labels.count(0)}")
print(f"  Malignant: {labels.count(1)}")
print(f"  Unique patients: {len(set(patient_ids))}")

# Check class balance: per-class count and share of the whole dataset
class_counts = pd.Series(labels).value_counts()
print(f"\n⚖️ Class distribution:")
for cls, count in class_counts.items():
    percentage = (count / len(labels)) * 100
    cls_name = "Benign" if cls == 0 else "Malignant"
    print(f"  {cls_name}: {count} ({percentage:.1f}%)")

In [None]:
class CvTForImageClassification(nn.Module):
    """
    CvT backbone from Hugging Face with a classification head, for 512x512 inputs.

    The backbone output is reduced to one feature vector per image by average
    pooling and passed through LayerNorm -> Dropout -> Linear.

    NOTE: this notebook defines a class with the same name again in a later
    cell; whichever definition was executed last is the one ``create_model``
    instantiates.

    Args:
        model_name: Hugging Face checkpoint id of the CvT backbone.
        num_classes: Number of output logits.
        pretrained: Load pre-trained weights when true, otherwise build the
            backbone from the checkpoint's config with random weights.
        image_size: Intended input resolution; recorded on the config and
            reported in the log messages.
    """

    def __init__(self, model_name='microsoft/cvt-13', num_classes=2, pretrained=True, image_size=512):
        super().__init__()

        # Load the checkpoint config and record the input size on it.
        config = CvtConfig.from_pretrained(model_name)
        config.image_size = image_size

        if pretrained:
            print(f"Loading pre-trained CvT from {model_name}...")
            print(f"⚠️ Adaptando de 224x224 a {image_size}x{image_size}")
            # ignore_mismatched_sizes lets loading proceed even if some weight
            # shapes differ from the checkpoint.
            self.cvt = CvtModel.from_pretrained(
                model_name,
                config=config,
                ignore_mismatched_sizes=True
            )
        else:
            self.cvt = CvtModel(config)

        # Channel width of the last stage = size of the pooled feature vector.
        self.hidden_size = self.cvt.config.embed_dim[-1]

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(0.1),
            nn.Linear(self.hidden_size, num_classes)
        )

        print(f"✓ Model created with {self.hidden_size} hidden size")
        print(f"✓ Input size: {image_size}x{image_size}")

    def forward(self, pixel_values):
        """Return class logits of shape ``[batch, num_classes]``."""
        outputs = self.cvt(pixel_values=pixel_values)
        features = outputs.last_hidden_state

        # CvtModel returns a spatial feature map [batch, channels, H, W], not a
        # token sequence, so indexing ``[:, 0]`` would select a channel rather
        # than a CLS token. Average over the spatial positions instead.
        if features.dim() == 4:
            pooled = features.mean(dim=(2, 3))
        else:
            # Token-sequence layout [batch, seq_len, hidden]: average the tokens.
            pooled = features.mean(dim=1)

        return self.classifier(pooled)

print("✓ CvTForImageClassification class defined (soporte para 512x512)")

In [None]:
def create_model(pretrained=True, num_classes=2, freeze_strategy='none'):
    """
    Instantiate the CvT classifier described by CONFIG and apply a freezing strategy.

    Args:
        pretrained: Forwarded to the model constructor.
        num_classes: Number of output classes.
        freeze_strategy: 'none' makes every parameter trainable;
            'all_except_head' and 'progressive' both freeze every parameter
            whose name does not contain 'classifier'. Any other value leaves
            the parameters untouched.

    Returns:
        The constructed model.
    """
    model = CvTForImageClassification(
        model_name=CONFIG['pretrained_model'],
        num_classes=num_classes,
        pretrained=pretrained,
        image_size=CONFIG['input_size']
    )

    def _freeze_backbone():
        # Everything outside the classification head stops receiving gradients.
        for param_name, tensor in model.named_parameters():
            if 'classifier' not in param_name:
                tensor.requires_grad = False

    if freeze_strategy == 'none':
        for tensor in model.parameters():
            tensor.requires_grad = True
        print("  🔥 All layers unfrozen (full fine-tuning)")

    elif freeze_strategy == 'all_except_head':
        _freeze_backbone()
        print("  ❄️ Frozen all layers except head")

    elif freeze_strategy == 'progressive':
        _freeze_backbone()
        print("  ❄️ Started with frozen backbone (progressive unfreezing enabled)")

    return model

def count_parameters(model):
    """Return ``(total, trainable)`` parameter counts of ``model``.

    ``trainable`` counts only parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

print("✓ create_model function defined (512x512)")

In [None]:
class CvTForImageClassification(nn.Module):
    """
    CvT backbone from Hugging Face with a Dropout -> Linear classification head.

    NOTE: this definition shadows the class of the same name defined in an
    earlier cell of this notebook.

    Args:
        model_name: Hugging Face checkpoint id of the CvT backbone.
        num_classes: Number of output logits.
        pretrained: Load pre-trained weights when true, otherwise build the
            backbone from the checkpoint's config with random weights.
        image_size: Accepted for interface compatibility with callers; it is
            not used when building the backbone.
    """

    def __init__(self, model_name='microsoft/cvt-13', num_classes=2, pretrained=True, image_size=512):
        super().__init__()

        config = CvtConfig.from_pretrained(model_name)

        if pretrained:
            print(f"Loading pre-trained CvT from {model_name}...")
            self.cvt = CvtModel.from_pretrained(model_name)
        else:
            self.cvt = CvtModel(config)

        # Channel width of the last stage (384 for cvt-13)
        self.hidden_size = config.embed_dim[-1]
        self.num_classes = num_classes

        # Simple classification head
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.hidden_size, num_classes)

        print(f"✓ Model created with {self.hidden_size} hidden size")

    def forward(self, pixel_values):
        """
        Return class logits of shape ``[batch, num_classes]``.

        Raises:
            ValueError: If the pooled backbone features do not have
                ``hidden_size`` channels.
        """
        outputs = self.cvt(pixel_values=pixel_values, return_dict=True)
        features = outputs.last_hidden_state

        # Reduce the backbone output to one vector per image.
        if features.dim() == 4:
            # [batch, channels, H, W] -> [batch, channels]
            pooled = features.mean(dim=[2, 3])
        elif features.dim() == 3:
            # [batch, seq_len, hidden_size] -> [batch, hidden_size]
            pooled = features.mean(dim=1)
        else:
            # Already one vector per image
            pooled = features

        # A projection layer must not be created lazily here: parameters born
        # inside forward() are unknown to an optimizer built beforehand, so the
        # layer would stay at its random initialisation and silently distort
        # the features. Fail loudly on an unexpected feature width instead.
        if pooled.shape[-1] != self.hidden_size:
            raise ValueError(
                f"Backbone features have {pooled.shape[-1]} channels after pooling, "
                f"expected {self.hidden_size}"
            )

        pooled = self.dropout(pooled)

        # Classification: [batch, num_classes]
        return self.classifier(pooled)

print("✓ CvTForImageClassification class defined")

In [None]:
class LRFinder:
    """
    Learning Rate Finder using the LR Range Test.

    The learning rate is increased geometrically from ``start_lr`` to
    ``end_lr`` while training on successive batches; the smoothed loss is
    recorded at every step. Model and optimizer are returned to the state they
    had when the finder was constructed.
    """

    def __init__(self, model, optimizer, criterion, device):
        # Local import: the notebook's import cell does not import ``copy``.
        import copy

        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

        # state_dict() hands back references to the live tensors, so the
        # snapshot must be deep-copied; otherwise every training step would
        # also change the "saved" state and restoring it would be a no-op.
        self.model_state = copy.deepcopy(model.state_dict())
        self.optimizer_state = copy.deepcopy(optimizer.state_dict())

    def _restore(self):
        """Load the snapshot taken at construction time back into model and optimizer."""
        import copy

        self.model.load_state_dict(self.model_state)
        # Hand the optimizer its own copy so later steps cannot touch the snapshot.
        self.optimizer.load_state_dict(copy.deepcopy(self.optimizer_state))

    def _set_lr(self, lr):
        """Apply ``lr`` to every parameter group of the optimizer."""
        for group in self.optimizer.param_groups:
            group['lr'] = lr

    def range_test(self, train_loader, start_lr=1e-7, end_lr=1e-3, num_iter=100, smooth_f=0.05):
        """
        Perform the LR range test.

        Args:
            train_loader: Iterable of ``(inputs, targets)`` batches; it is
                restarted when exhausted.
            start_lr: Learning rate of the first iteration.
            end_lr: Learning rate approached after ``num_iter`` iterations.
            num_iter: Maximum number of training steps.
            smooth_f: Weight of the newest loss in the exponential moving average.

        Returns:
            ``(lrs, losses)``: learning rates and bias-corrected smoothed
            losses, one entry per completed iteration. The run stops early
            when the smoothed loss exceeds four times the best value seen.
        """
        # Start from the pristine snapshot
        self._restore()

        # Constant factor that takes start_lr to end_lr in num_iter steps
        mult = (end_lr / start_lr) ** (1 / num_iter)
        lr = start_lr
        self._set_lr(lr)

        avg_loss = 0.
        best_loss = float('inf')
        batch_num = 0
        losses = []
        lrs = []

        self.model.train()

        iterator = iter(train_loader)

        try:
            for iteration in tqdm(range(num_iter), desc="LR Finder"):
                # Get batch, restarting the loader when it runs out
                try:
                    inputs, targets = next(iterator)
                except StopIteration:
                    iterator = iter(train_loader)
                    inputs, targets = next(iterator)

                batch_num += 1

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                # Forward
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                # Exponential moving average with bias correction
                avg_loss = smooth_f * loss.item() + (1 - smooth_f) * avg_loss
                smoothed_loss = avg_loss / (1 - (1 - smooth_f) ** batch_num)

                # Stop if loss explodes
                if batch_num > 1 and smoothed_loss > 4 * best_loss:
                    print(f"\n⚠️ Loss exploded at LR={lr:.2e}")
                    break

                # Record best loss
                if smoothed_loss < best_loss or batch_num == 1:
                    best_loss = smoothed_loss

                # Store values
                losses.append(smoothed_loss)
                lrs.append(lr)

                # Backward
                loss.backward()
                self.optimizer.step()

                # Update LR
                lr *= mult
                self._set_lr(lr)
        finally:
            # Undo the test's training steps even if a batch raised
            self._restore()

        return lrs, losses

    def plot(self, lrs, losses, skip_start=10, skip_end=5):
        """
        Plot the range-test curve, save it under CONFIG['metrics_dir'] and
        return the suggested learning rate.

        Args:
            lrs: Learning rates returned by ``range_test``.
            losses: Smoothed losses returned by ``range_test``.
            skip_start: Number of leading points to drop.
            skip_end: Number of trailing points to drop.

        Returns:
            One tenth of the learning rate at which the loss was lowest.

        Raises:
            ValueError: If ``lrs`` is empty.
        """
        if len(lrs) == 0:
            raise ValueError("LR range test produced no points to plot")

        if skip_start >= len(lrs):
            skip_start = 0
        if skip_end >= len(lrs):
            skip_end = 0
        # A short run (e.g. an early loss explosion) can be emptied by the two
        # trims together even though each fits on its own; keep all points then.
        if skip_start + skip_end >= len(lrs):
            skip_start, skip_end = 0, 0

        lrs = lrs[skip_start:len(lrs) - skip_end]
        losses = losses[skip_start:len(losses) - skip_end]

        # Find minimum
        min_idx = np.argmin(losses)
        min_lr = lrs[min_idx]

        # Suggested LR (10x smaller than minimum for fine-tuning)
        suggested_lr = min_lr / 10

        plt.figure(figsize=(10, 6))
        plt.plot(lrs, losses)
        plt.xscale('log')
        plt.xlabel('Learning Rate')
        plt.ylabel('Loss')
        plt.title('Learning Rate Finder (Transfer Learning)')
        plt.axvline(x=min_lr, color='r', linestyle='--', label=f'Min Loss LR: {min_lr:.2e}')
        plt.axvline(x=suggested_lr, color='g', linestyle='--', label=f'Suggested LR: {suggested_lr:.2e}')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Save plot
        plot_path = os.path.join(CONFIG['metrics_dir'], f"{CONFIG['model_name']}_lr_finder.png")
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.show()

        print(f"\n📊 LR Finder Results (Transfer Learning):")
        print(f"  Minimum loss LR: {min_lr:.2e}")
        print(f"  Suggested LR: {suggested_lr:.2e}")
        print(f"  Plot saved to: {plot_path}")

        return suggested_lr

print("✓ LRFinder class defined")

In [None]:
# Run LR Finder
print("\n🔍 Running Learning Rate Finder for Transfer Learning...")
print("This may take a few minutes...\n")

# Create temporary model
temp_model = create_model(
    pretrained=CONFIG['pretrained'],
    num_classes=CONFIG['num_classes'],
    freeze_strategy=CONFIG['freeze_strategy']
).to(device)

# Create temporary dataset (up to 100 images for speed).
# load_dataset() appends one directory after another, benign DDSM first, so
# taking the *first* 100 images would give the LR Finder a single class.
# Draw a seeded random subset of the whole dataset instead; a private
# RandomState keeps the global NumPy generator untouched.
temp_subset_size = min(100, len(image_paths))
temp_rng = np.random.RandomState(CONFIG['seed'])
temp_indices = sorted(temp_rng.permutation(len(image_paths))[:temp_subset_size].tolist())
temp_paths = [image_paths[i] for i in temp_indices]
temp_labels = [labels[i] for i in temp_indices]

print(f"Temp dataset: {len(temp_paths)} images")
print(f"Label distribution: Benign={temp_labels.count(0)}, Malignant={temp_labels.count(1)}")

temp_dataset = MammographyDataset(
    temp_paths,
    temp_labels,
    transform=get_transforms(train=True)
)

temp_loader = DataLoader(
    temp_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=0,  # loaded in the main process to ease debugging
    pin_memory=False
)

# Test one batch: smoke test of the data pipeline, the forward pass and the
# loss before the longer LR range test is started.
print("\n🧪 Testing one batch...")
try:
    test_batch = next(iter(temp_loader))
    test_images, test_labels = test_batch
    print(f"  Batch images shape: {test_images.shape}")
    print(f"  Batch labels shape: {test_labels.shape}")
    print(f"  Labels dtype: {test_labels.dtype}")
    print(f"  Labels values: {test_labels}")

    # Test forward pass
    test_images = test_images.to(device)
    test_labels = test_labels.to(device)

    # No gradients are needed for a shape / loss sanity check
    with torch.no_grad():
        test_output = temp_model(test_images)
        print(f"  Model output shape: {test_output.shape}")

        # Test loss
        temp_criterion = nn.CrossEntropyLoss(label_smoothing=CONFIG['label_smoothing'])
        test_loss = temp_criterion(test_output, test_labels)
        print(f"  ✓ Loss computation works: {test_loss.item():.4f}")

    # Drop the references to the test tensors
    del test_images, test_labels, test_output, test_loss

except Exception as e:
    print(f"  ❌ Batch test failed: {e}")
    import traceback
    traceback.print_exc()

    # Clean up and exit: release the temporary objects, then re-raise so the
    # cell stops here instead of running the LR Finder on a broken setup.
    del temp_model, temp_dataset, temp_loader
    torch.cuda.empty_cache()
    gc.collect()
    raise

# Setup for LR Finder: throw-away optimizer and loss for the temporary model.
# The lr given here is only a starting value; range_test overrides it.
temp_optimizer = AdamW(
    temp_model.parameters(),
    lr=1e-7,
    weight_decay=CONFIG['weight_decay'],
    betas=CONFIG['betas']
)

temp_criterion = nn.CrossEntropyLoss(label_smoothing=CONFIG['label_smoothing'])

print("\n🚀 Starting LR Finder...")

# Run LR Finder
try:
    lr_finder = LRFinder(temp_model, temp_optimizer, temp_criterion, device)
    lrs, losses = lr_finder.range_test(
        temp_loader,
        start_lr=CONFIG['lr_min'],
        end_lr=CONFIG['lr_max'],
        num_iter=50  # reduced to 50 iterations to keep the test fast
    )

    # Plot and get suggested LR
    suggested_lr = lr_finder.plot(lrs, losses)

    # Update config: this is where 'lr_initial' (None until now) gets its value
    CONFIG['lr_initial'] = suggested_lr

    print(f"\n✓ Learning rate set to: {CONFIG['lr_initial']:.2e}")

except Exception as e:
    # Best effort: a failed range test must not block training
    print(f"\n❌ LR Finder failed: {e}")
    import traceback
    traceback.print_exc()

    # Set default LR
    CONFIG['lr_initial'] = 1e-4
    print(f"\n⚠️ Using default LR: {CONFIG['lr_initial']:.2e}")

# Clean up the temporary objects and release cached GPU memory
del temp_model, temp_dataset, temp_loader, temp_optimizer, temp_criterion
# lr_finder only exists if its construction succeeded
if 'lr_finder' in locals():
    del lr_finder
torch.cuda.empty_cache()
gc.collect()

print(f"\n✓ LR Finder complete")

In [None]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss. A loss counts
    as an improvement only if it is at least ``min_delta`` below the best loss
    seen so far; after ``patience`` consecutive calls without improvement
    ``early_stop`` becomes True.
    """

    def __init__(self, patience=15, min_delta=1e-4, verbose=True):
        self.patience, self.min_delta, self.verbose = patience, min_delta, verbose
        self.best_loss = None
        self.best_epoch = 0
        self.counter = 0
        self.early_stop = False

    def _remember(self, val_loss, epoch):
        """Store a new best loss together with the epoch that produced it."""
        self.best_loss = val_loss
        self.best_epoch = epoch

    def __call__(self, val_loss, epoch):
        # The very first observation is the baseline.
        if self.best_loss is None:
            self._remember(val_loss, epoch)
            return

        stalled = val_loss > self.best_loss - self.min_delta
        if not stalled:
            self._remember(val_loss, epoch)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"  EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

print("✓ EarlyStopping class defined")

In [None]:
def train_epoch(model, loader, criterion, optimizer, scaler, device, use_amp=True, clip_norm=None):
    """Run one training pass over ``loader``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss taking ``(outputs, targets)``.
        optimizer: Optimizer stepped once per batch.
        scaler: Gradient scaler, used only when ``use_amp`` is true.
        device: Device the batches are moved to.
        use_amp: Run forward and loss under autocast with scaled gradients.
        clip_norm: Maximum gradient norm; ``None`` disables clipping.

    Returns:
        ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and the
        fraction of correctly classified samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    should_clip = clip_norm is not None

    progress = tqdm(loader, desc="Training")
    for inputs, targets in progress:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        if not use_amp:
            # Full-precision step
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if should_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)

            optimizer.step()
        else:
            # Mixed-precision step
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            scaler.scale(loss).backward()

            if should_clip:
                # Gradients must be unscaled before their norm is measured
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)

            scaler.step(optimizer)
            scaler.update()

        # Running statistics
        loss_sum += loss.item() * inputs.size(0)
        predicted = outputs.max(1)[1]
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100.*correct/total:.2f}%'
        })

    return loss_sum / total, correct / total

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Returns:
        ``(epoch_loss, epoch_acc, all_preds, all_targets)``: sample-weighted
        mean loss, accuracy, and the per-sample predicted and true labels in
        loader order.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds, all_targets = [], []

    with torch.no_grad():
        progress = tqdm(loader, desc="Validation")
        for inputs, targets in progress:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss_sum += loss.item() * inputs.size(0)
            predicted = outputs.max(1)[1]
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Keep per-sample results on the CPU for the metric computation
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

            progress.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

    return loss_sum / total, correct / total, all_preds, all_targets

def calculate_metrics(y_true, y_pred):
    """
    Calculate binary classification metrics (positive class = 1).

    Args:
        y_true: True labels (0 or 1).
        y_pred: Predicted labels (0 or 1).

    Returns:
        Dict with 'accuracy', 'precision', 'recall', 'f1', 'specificity' and
        'confusion_matrix' (always 2x2, rows = true class, columns = predicted
        class, in the order [0, 1]).
    """
    # Basic metrics; zero_division=0 avoids warnings when a class is never predicted
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='binary', zero_division=0)
    recall = recall_score(y_true, y_pred, average='binary', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='binary', zero_division=0)

    # Fix the label set so the matrix stays 2x2 even when only one class occurs
    # in y_true and y_pred; without it the matrix collapses to 1x1 and the
    # specificity would be reported as 0 no matter what the predictions were.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    # Specificity = TN / (TN + FP); 0 when there are no true negatives' class samples
    tn, fp, fn, tp = cm.ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'specificity': specificity,
        'confusion_matrix': cm
    }

    return metrics

print("✓ Training functions defined")

In [None]:
def train_with_cross_validation():
    """
    Train model with 10-fold cross validation and warmup scheduler.
    Split by patient to prevent data leakage.
    """

    # Group by patient
    patient_to_indices = defaultdict(list)
    patient_to_label = {}

    for idx, (patient_id, label) in enumerate(zip(patient_ids, labels)):
        patient_to_indices[patient_id].append(idx)
        patient_to_label[patient_id] = label

    unique_patients = list(patient_to_indices.keys())
    patient_labels = [patient_to_label[p] for p in unique_patients]

    print(f"\n📊 Cross Validation Setup:")
    print(f"  Total patients: {len(unique_patients)}")
    print(f"  Total images: {len(image_paths)}")
    print(f"  Folds: {CONFIG['num_folds']}")

    # K-Fold split
    skf = StratifiedKFold(n_splits=CONFIG['num_folds'], shuffle=True, random_state=CONFIG['seed'])

    # Store results
    fold_metrics = []
    fold_histories = []
    fold_cms = []

    # Train each fold
    for fold, (train_patient_idx, val_patient_idx) in enumerate(skf.split(unique_patients, patient_labels)):
        print(f"\n{'='*80}")
        print(f"FOLD {fold + 1}/{CONFIG['num_folds']}")
        print(f"{'='*80}")

        # Get patient IDs for this fold
        train_patients = [unique_patients[i] for i in train_patient_idx]
        val_patients = [unique_patients[i] for i in val_patient_idx]

        # Get image indices
        train_indices = []
        val_indices = []

        for patient in train_patients:
            train_indices.extend(patient_to_indices[patient])

        for patient in val_patients:
            val_indices.extend(patient_to_indices[patient])

        # Create datasets
        train_paths = [image_paths[i] for i in train_indices]
        train_labels = [labels[i] for i in train_indices]
        val_paths = [image_paths[i] for i in val_indices]
        val_labels = [labels[i] for i in val_indices]

        train_dataset = MammographyDataset(train_paths, train_labels, transform=get_transforms(train=True))
        val_dataset = MammographyDataset(val_paths, val_labels, transform=get_transforms(train=False))

        # Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=True,
            num_workers=CONFIG['num_workers'],
            pin_memory=CONFIG['pin_memory']
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,
            num_workers=CONFIG['num_workers'],
            pin_memory=CONFIG['pin_memory']
        )

        print(f"\n📊 Fold {fold + 1} Data:")
        print(f"  Train patients: {len(train_patients)}")
        print(f"  Val patients: {len(val_patients)}")
        print(f"  Train images: {len(train_paths)} (Benign: {train_labels.count(0)}, Malignant: {train_labels.count(1)})")
        print(f"  Val images: {len(val_paths)} (Benign: {val_labels.count(0)}, Malignant: {val_labels.count(1)})")

        # Create model
        model = create_model(
            pretrained=CONFIG['pretrained'],
            num_classes=CONFIG['num_classes'],
            freeze_strategy=CONFIG['freeze_strategy']
        ).to(device)

        # Loss with label smoothing
        criterion = nn.CrossEntropyLoss(label_smoothing=CONFIG['label_smoothing'])

        # Optimizer
        optimizer = AdamW(
            model.parameters(),
            lr=CONFIG['lr_initial'],
            weight_decay=CONFIG['weight_decay'],
            betas=CONFIG['betas']
        )

        # Warmup + ReduceLROnPlateau scheduler
        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=CONFIG['warmup_start_factor'],
            total_iters=CONFIG['warmup_epochs']
        )

        main_scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=CONFIG['scheduler_factor'],
            patience=CONFIG['scheduler_patience'],
            min_lr=CONFIG['scheduler_min_lr'],
            verbose=True
        )

        # Early stopping
        early_stopping = EarlyStopping(
            patience=CONFIG['early_stopping_patience'],
            min_delta=CONFIG['min_delta'],
            verbose=True
        )

        # Mixed precision scaler
        scaler = GradScaler() if CONFIG['use_amp'] else None

        # Training history
        history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }

        best_val_loss = float('inf')
        best_epoch = 0

        # Training loop
        print(f"\n🚀 Starting training with warmup ({CONFIG['warmup_epochs']} epochs)...")
        start_time = time.time()

        for epoch in range(CONFIG['num_epochs']):
            print(f"\n--- Epoch {epoch + 1}/{CONFIG['num_epochs']} ---")

            # Train
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer, scaler, device,
                CONFIG['use_amp'], CONFIG['gradient_clip_norm']
            )

            # Validate
            val_loss, val_acc, val_preds, val_targets = validate_epoch(
                model, val_loader, criterion, device
            )

            # Update scheduler
            if epoch < CONFIG['warmup_epochs']:
                warmup_scheduler.step()
            else:
                main_scheduler.step(val_loss)

            current_lr = optimizer.param_groups[0]['lr']

            # Save history
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            history['lr'].append(current_lr)

            # Print epoch summary
            phase = "Warmup" if epoch < CONFIG['warmup_epochs'] else "Main"
            print(f"\n📊 Epoch {epoch + 1} Summary ({phase}):")
            print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
            print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
            print(f"  LR: {current_lr:.2e}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch + 1

                model_path = os.path.join(
                    CONFIG['output_dir'],
                    f"{CONFIG['model_name']}_fold{fold}.pth"
                )
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                    'config': CONFIG
                }, model_path)

                print(f"  ✓ Model saved: {model_path}")

            # Early stopping (only after warmup)
            if epoch >= CONFIG['warmup_epochs']:
                early_stopping(val_loss, epoch + 1)
                if early_stopping.early_stop:
                    print(f"\n⏹️ Early stopping triggered at epoch {epoch + 1}")
                    print(f"  Best epoch: {best_epoch} with val_loss: {best_val_loss:.4f}")
                    break

            # Clear cache
            torch.cuda.empty_cache()

        training_time = time.time() - start_time
        print(f"\n⏱️ Training time for fold {fold + 1}: {training_time/60:.2f} minutes")

        # Load best model for evaluation
        model_path = os.path.join(CONFIG['output_dir'], f"{CONFIG['model_name']}_fold{fold}.pth")
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Final evaluation
        print(f"\n📊 Final evaluation on validation set...")
        _, _, final_preds, final_targets = validate_epoch(model, val_loader, criterion, device)

        # Calculate metrics
        metrics = calculate_metrics(final_targets, final_preds)

        print(f"\n✅ Fold {fold + 1} Results:")
        print(f"  Accuracy: {metrics['accuracy']*100:.2f}%")
        print(f"  Precision: {metrics['precision']*100:.2f}%")
        print(f"  Recall: {metrics['recall']*100:.2f}%")
        print(f"  F1-Score: {metrics['f1']*100:.2f}%")
        print(f"  Specificity: {metrics['specificity']*100:.2f}%")

        # Store results
        fold_metrics.append(metrics)
        fold_histories.append(history)
        fold_cms.append(metrics['confusion_matrix'])

        # Plot training history
        plot_training_history(history, fold)

        # Plot confusion matrix
        plot_confusion_matrix(metrics['confusion_matrix'], fold)

        # Clean up
        del model, optimizer, train_loader, val_loader
        torch.cuda.empty_cache()
        gc.collect()

    return fold_metrics, fold_histories, fold_cms

# Notebook progress marker: printed when this cell finishes executing, i.e.
# after the `def` statements above have run.
print("✓ train_with_cross_validation function defined")

In [None]:
def plot_training_history(history, fold):
    """Plot loss, accuracy and learning-rate curves for one fold.

    The three panels share a 1-based epoch axis so that they line up with the
    "Epoch N" numbering printed during training, and each panel carries a
    dashed marker at ``CONFIG['warmup_epochs']``. The figure is written as a
    PNG into ``CONFIG['metrics_dir']``, displayed, and then closed.

    Args:
        history: Dict holding the per-epoch lists ``'train_loss'``,
            ``'val_loss'``, ``'train_acc'``, ``'val_acc'`` and ``'lr'``.
            Accuracy values are multiplied by 100 for display.
        fold: Zero-based fold index; shown as ``fold + 1`` in the titles and
            used unchanged in the output file name.
    """

    def _plot_series(ax, values, **plot_kwargs):
        # Plot against 1-based epoch numbers. With the default 0-based list
        # index, the warmup marker at x == warmup_epochs would sit on the
        # first post-warmup epoch instead of the last warmup epoch.
        ax.plot(range(1, len(values) + 1), values, **plot_kwargs)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    loss_ax, acc_ax, lr_ax = axes

    # Loss
    _plot_series(loss_ax, history['train_loss'], label='Train Loss', marker='o')
    _plot_series(loss_ax, history['val_loss'], label='Val Loss', marker='s')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title(f'Fold {fold + 1}: Loss (Transfer Learning)')

    # Accuracy
    _plot_series(acc_ax, [acc * 100 for acc in history['train_acc']], label='Train Acc', marker='o')
    _plot_series(acc_ax, [acc * 100 for acc in history['val_acc']], label='Val Acc', marker='s')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_title(f'Fold {fold + 1}: Accuracy')

    # Learning Rate
    _plot_series(lr_ax, history['lr'], marker='o')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title(f'Fold {fold + 1}: Learning Rate (with Warmup)')
    lr_ax.set_yscale('log')

    # Decoration shared by all three panels. The marker is added after the
    # data series so it is listed last in each legend.
    for ax in axes:
        ax.axvline(x=CONFIG['warmup_epochs'], color='r', linestyle='--', alpha=0.5, label='End of Warmup')
        ax.set_xlabel('Epoch')
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_fold{fold}_history.png"
    )
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    # This function runs once per fold; release the figure explicitly so
    # figures do not pile up on backends that keep them open after show().
    plt.close(fig)

    print(f"  History plot saved: {plot_path}")

def plot_confusion_matrix(cm, fold, normalize=False):
    """Plot, save and display the confusion matrix of one fold.

    Args:
        cm: 2x2 confusion matrix (array-like) with true labels on the rows
            and predicted labels on the columns, in Benign/Malignant order.
        fold: Zero-based fold index; shown as ``fold + 1`` in the title and
            used unchanged in the output file name.
        normalize: When True, each row is divided by its total so the cells
            show per-true-class proportions. A row with no samples is shown
            as zeros.

    Note:
        The output file name does not depend on ``normalize``, so a normalised
        and a raw plot of the same fold are written to the same path.
    """
    cm = np.asarray(cm)

    if normalize:
        cm = cm.astype('float')
        row_totals = cm.sum(axis=1, keepdims=True)
        # Guard against a class with no true samples in this fold: plain
        # division would produce NaN cells for that row.
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='.2f' if normalize else 'd',
        cmap='Blues',
        xticklabels=['Benign', 'Malignant'],
        yticklabels=['Benign', 'Malignant'],
        # Normalised cells are fractions, not counts.
        cbar_kws={'label': 'Proportion' if normalize else 'Count'},
        ax=ax
    )
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    ax.set_title(f'Fold {fold + 1}: Confusion Matrix')

    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_fold{fold}_cm.png"
    )
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Called once per fold; close the figure so figures do not accumulate.
    plt.close(fig)

    print(f"  Confusion matrix saved: {plot_path}")

def save_metrics_summary(fold_metrics):
    """Tabulate per-fold metrics, append a mean ± std row, and save as CSV.

    Args:
        fold_metrics: Sequence of per-fold dicts, each providing the keys
            ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'`` and
            ``'specificity'``.

    Returns:
        pandas.DataFrame with one row per fold (``Fold`` numbered from 1)
        followed by a final ``'Mean ± Std'`` row. The summary cells are
        formatted strings, so the metric columns hold mixed types.

    Side effects:
        Writes ``<model_name>_metrics.csv`` into ``CONFIG['metrics_dir']`` and
        prints the table.
    """
    # (CSV column header, key in each fold's metrics dict)
    columns = [
        ('Accuracy', 'accuracy'),
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1-Score', 'f1'),
        ('Specificity', 'specificity'),
    ]

    metrics_dict = {'Fold': [index + 1 for index in range(len(fold_metrics))]}
    for header, key in columns:
        metrics_dict[header] = [metrics[key] for metrics in fold_metrics]

    # Summary row. Statistics are computed from the per-fold values *before*
    # the formatted string is appended, so every numeric type (int,
    # np.float32, ...) is counted and no isinstance() filtering is needed.
    metrics_dict['Fold'].append('Mean ± Std')
    for header, _ in columns:
        values = metrics_dict[header]
        if values:
            mean_val = np.mean(values)
            std_val = np.std(values)
        else:
            # No folds: keep the "nan ± nan" summary without numpy warnings.
            mean_val = std_val = float('nan')
        values.append(f"{mean_val:.4f} ± {std_val:.4f}")

    df = pd.DataFrame(metrics_dict)

    csv_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_metrics.csv"
    )
    df.to_csv(csv_path, index=False)

    print(f"\n✅ Metrics saved to: {csv_path}")
    print(f"\n{df.to_string(index=False)}")

    return df

def plot_average_confusion_matrix(fold_cms):
    """Plot, save and display the element-wise mean confusion matrix.

    Args:
        fold_cms: Non-empty sequence of equally shaped confusion matrices,
            one per fold, in Benign/Malignant order.

    Raises:
        ValueError: If ``fold_cms`` is empty (there is nothing to average).
    """
    n_folds = len(fold_cms)
    if n_folds == 0:
        raise ValueError("fold_cms is empty: no confusion matrices to average")

    avg_cm = np.mean(fold_cms, axis=0)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        avg_cm,
        annot=True,
        fmt='.1f',
        cmap='Blues',
        xticklabels=['Benign', 'Malignant'],
        yticklabels=['Benign', 'Malignant'],
        cbar_kws={'label': 'Average Count'},
        ax=ax
    )
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    # Report the number of folds actually averaged rather than assuming 10.
    ax.set_title(f'Average Confusion Matrix ({n_folds}-Fold CV - Transfer Learning)')

    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_avg_cm.png"
    )
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure once it has been saved and displayed.
    plt.close(fig)

    print(f"\n✅ Average confusion matrix saved: {plot_path}")

# Notebook progress marker: printed when this cell finishes executing, i.e.
# after the plotting/summary `def` statements above have run.
print("✓ Visualization functions defined")

In [None]:
# Launch the full cross-validation run and report its wall-clock duration.
print(
    "\n" + "="*80,
    "STARTING 10-FOLD CROSS VALIDATION TRAINING (TRANSFER LEARNING - HUGGING FACE)",
    "="*80,
    sep="\n"
)

start_time = time.time()

# Per-fold metrics, training histories and confusion matrices are kept as
# notebook-level variables.
fold_metrics, fold_histories, fold_cms = train_with_cross_validation()

total_time = time.time() - start_time

# Closing banner followed by the elapsed time converted from seconds to hours.
print(
    "\n" + "="*80,
    "TRAINING COMPLETED",
    "="*80,
    f"\n⏱️ Total training time: {total_time/3600:.2f} hours",
    sep="\n"
)