In [None]:
# Core imports
import copy
import gc
import json
import math
import os
import random
import sys
import time
import warnings
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler  # Mixed precision

# Torchvision
from torchvision import transforms
from PIL import Image

# Sklearn
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    f1_score, precision_score, recall_score,
    confusion_matrix, classification_report,
    accuracy_score
)

print("✓ Imports loaded successfully")

In [None]:
# GPU verification: choose the compute device once and report what was found.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f"Device: {device}")

if not has_cuda:
    print("⚠️ WARNING: No GPU available, training will be slow!")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")

    # Release cached CUDA blocks and run the Python garbage collector
    # before any model is created.
    torch.cuda.empty_cache()
    gc.collect()
    print("✓ GPU cache cleared")

In [None]:
# Configuration
# All dataset/output locations hang off one base directory.
_BASE_DIR = '/home/merivadeneira'
_ROI_SUBDIR = 'Resized_512'

# One tuple per CvT stage, in the field order given by _STAGE_FIELDS.
_STAGE_FIELDS = ('embed_dim', 'depth', 'num_heads', 'kernel_size', 'stride', 'padding')
_STAGE_SPECS = (
    (64, 1, 1, 7, 4, 2),
    (192, 2, 3, 3, 2, 1),
    (384, 10, 6, 3, 2, 1),
)

CONFIG = {
    # Paths
    'base_path': _BASE_DIR,
    'ddsm_benign_path': f'{_BASE_DIR}/Masas/DDSM/Benignas/{_ROI_SUBDIR}',
    'ddsm_malign_path': f'{_BASE_DIR}/Masas/DDSM/Malignas/{_ROI_SUBDIR}',
    'inbreast_benign_path': f'{_BASE_DIR}/Masas/INbreast/Benignas/{_ROI_SUBDIR}',
    'inbreast_malign_path': f'{_BASE_DIR}/Masas/INbreast/Malignas/{_ROI_SUBDIR}',
    'output_dir': f'{_BASE_DIR}/Outputs/CvT',
    'metrics_dir': f'{_BASE_DIR}/Metrics/CvT',

    # Model
    'model_name': 'CvT_0_Base',
    'input_size': 512,
    'in_channels': 1,  # Grayscale
    'num_classes': 2,  # Binary: Benign vs Malignant

    # CvT Architecture (CvT-13): list of per-stage dicts
    'stages': [dict(zip(_STAGE_FIELDS, spec)) for spec in _STAGE_SPECS],
    'mlp_ratio': 4.0,
    'qkv_bias': True,
    'drop_rate': 0.0,
    'attn_drop_rate': 0.0,
    'drop_path_rate': 0.1,

    # Training
    'batch_size': 16,  # NOTE: nothing in this notebook lowers it automatically; reduce by hand on OOM
    'num_epochs': 100,
    'num_folds': 5,
    'early_stopping_patience': 25,
    'min_delta': 1e-4,

    # Optimizer
    'optimizer': 'AdamW',
    'lr_initial': None,  # Filled in later by the LR Finder cell
    'lr_min': 1e-7,
    'lr_max': 1e-2,
    'weight_decay': 0.01,
    'betas': (0.9, 0.999),

    # Scheduler
    'scheduler': 'ReduceLROnPlateau',
    'scheduler_factor': 0.5,
    'scheduler_patience': 10,
    'scheduler_min_lr': 1e-7,

    # Data augmentation
    'horizontal_flip': 0.5,
    'rotation_degrees': 15,
    'translate': 0.1,
    'scale': (0.9, 1.1),
    'shear': 10,
    'brightness': 0.1,
    'random_erasing_p': 0.15,

    # Normalization (grayscale)
    'mean': [0.5],
    'std': [0.5],

    # Mixed Precision & Memory
    'use_amp': True,  # Automatic Mixed Precision
    'gradient_checkpointing': True,
    'num_workers': 4,
    'pin_memory': True,

    # Reproducibility
    'seed': 42
}

# Create output directories
for _dir_key in ('output_dir', 'metrics_dir'):
    os.makedirs(CONFIG[_dir_key], exist_ok=True)

# Save configuration (written before the LR finder runs, so 'lr_initial' is still None here)
config_path = os.path.join(CONFIG['metrics_dir'], f"{CONFIG['model_name']}_config.json")
with open(config_path, 'w') as config_file:
    json.dump(CONFIG, config_file, indent=4)

print(f"✓ Configuration saved to: {config_path}")
_summary_lines = [
    f"\n Model: {CONFIG['model_name']}",
    f" Batch size: {CONFIG['batch_size']}",
    f" Input size: {CONFIG['input_size']}x{CONFIG['input_size']}",
    f" Classes: {CONFIG['num_classes']} (Benign vs Malignant)",
    f" Folds: {CONFIG['num_folds']}",
    f" Mixed Precision: {CONFIG['use_amp']}",
    f" Gradient Checkpointing: {CONFIG['gradient_checkpointing']}",
]
for _line in _summary_lines:
    print(_line)

In [None]:
# Set seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch
    (CPU and all CUDA devices), and puts cuDNN in deterministic mode
    with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Python's own RNG was previously left unseeded; any library code that
    # draws from `random` would otherwise differ from run to run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(CONFIG['seed'])
print(f"✓ Random seed set to {CONFIG['seed']}")

In [None]:
def extract_patient_id(filename, dataset='ddsm'):
    """
    Derive a patient identifier from an image filename.

    The identifier is used to group images by patient so that a patient
    never ends up on both sides of a train/validation split.

    DDSM format: P_00041_LEFT_CC_1.png -> P_00041
    INbreast format: 20586908_6c613a14b80a8591_MG_R_CC_ANON_lesion1_ROI.png -> 20586908

    Args:
        filename: Bare file name (no directory part).
        dataset: 'ddsm' or 'inbreast'; any other value returns `filename` unchanged.

    Returns:
        The patient ID string, or `filename` itself when no rule applies.
    """
    tokens = filename.split('_')

    if dataset == 'inbreast':
        # Leading underscore-separated field is the patient number.
        return tokens[0]

    if dataset == 'ddsm' and len(tokens) >= 2:
        # First two fields joined back together, e.g. "P" + "00041".
        return '_'.join(tokens[:2])

    return filename  # Fallback

# Test
print("Testing patient ID extraction:")
_id_examples = [
    ('DDSM', 'P_00041_LEFT_CC_1.png', 'ddsm'),
    ('INbreast', '20586908_6c613a14b80a8591_MG_R_CC_ANON_lesion1_ROI.png', 'inbreast'),
]
for _tag, _sample_name, _dataset_key in _id_examples:
    print(f"{_tag}: {_sample_name} -> {extract_patient_id(_sample_name, _dataset_key)}")

In [None]:
class MammographyDataset(Dataset):
    """Map-style dataset that loads grayscale images from disk on demand.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-by-index with `image_paths`.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If `image_paths` and `labels` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return `(image, label)` for sample `idx`."""
        # Open inside a context manager so the file handle is released as
        # soon as the pixel data has been converted, instead of waiting for
        # garbage collection (matters with many samples and worker processes).
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert('L')  # Grayscale

        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]

print("✓ Dataset class defined")

In [None]:
def get_transforms(train=True):
    """Build the torchvision preprocessing pipeline.

    Args:
        train: When True, geometric/photometric augmentations (parameters
            read from CONFIG) run before tensor conversion and random
            erasing runs after normalization. When False, only tensor
            conversion and normalization are applied.

    Returns:
        A `transforms.Compose` instance.
    """
    # Steps shared by the training and evaluation pipelines.
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=CONFIG['mean'], std=CONFIG['std']),
    ]

    if not train:
        return transforms.Compose(tensor_steps)

    # These operate on the PIL image, so they come before ToTensor.
    augmentation_steps = [
        transforms.RandomHorizontalFlip(p=CONFIG['horizontal_flip']),
        transforms.RandomRotation(degrees=CONFIG['rotation_degrees']),
        transforms.RandomAffine(
            degrees=0,
            translate=(CONFIG['translate'], CONFIG['translate']),
            scale=CONFIG['scale'],
            shear=CONFIG['shear']
        ),
        transforms.ColorJitter(brightness=CONFIG['brightness']),
    ]
    # RandomErasing works on tensors, so it is appended after normalization.
    erasing_step = [transforms.RandomErasing(p=CONFIG['random_erasing_p'])]

    return transforms.Compose(augmentation_steps + tensor_steps + erasing_step)

print("✓ Transform functions defined")

In [None]:
def load_dataset():
    """
    Collect every image from the DDSM and INbreast folders listed in CONFIG.

    Folders are visited in a fixed order (DDSM benign, DDSM malignant,
    INbreast benign, INbreast malignant) and files inside each folder are
    sorted by name, so the returned lists are identical from run to run.
    A missing folder is reported and skipped.

    Returns:
        (image_paths, labels, patient_ids): three parallel lists; label is
        0 for benign and 1 for malignant.
    """
    image_paths = []
    labels = []
    patient_ids = []

    sources = [
        (CONFIG['ddsm_benign_path'], 0, 'ddsm'),
        (CONFIG['ddsm_malign_path'], 1, 'ddsm'),
        (CONFIG['inbreast_benign_path'], 0, 'inbreast'),
        (CONFIG['inbreast_malign_path'], 1, 'inbreast')
    ]
    image_extensions = ('.png', '.jpg', '.jpeg')

    for folder, label, dataset_name in sources:
        if not os.path.isdir(folder):
            print(f"⚠️ WARNING: Path not found: {folder}")
            continue

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # sample order (and therefore any seeded split built on it)
        # reproducible. The extension check ignores case so that files such
        # as "X.PNG" are not silently dropped.
        filenames = sorted(
            name for name in os.listdir(folder)
            if name.lower().endswith(image_extensions)
        )

        for filename in filenames:
            image_paths.append(os.path.join(folder, filename))
            labels.append(label)
            patient_ids.append(extract_patient_id(filename, dataset_name))

    return image_paths, labels, patient_ids

# Load data
print("Loading dataset...")
image_paths, labels, patient_ids = load_dataset()

# Headline counts for the loaded dataset.
print(f"\n📊 Dataset loaded:")
_dataset_summary = [
    ("Total images", len(image_paths)),
    ("Benign", labels.count(0)),
    ("Malignant", labels.count(1)),
    ("Unique patients", len(set(patient_ids))),
]
for _caption, _value in _dataset_summary:
    print(f"  {_caption}: {_value}")

# Check class balance
class_counts = pd.Series(labels).value_counts()
print(f"\n⚖️ Class distribution:")
_n_labels = len(labels)
for class_id, n_in_class in class_counts.items():
    share = (n_in_class / _n_labels) * 100
    class_name = "Malignant" if class_id != 0 else "Benign"
    print(f"  {class_name}: {n_in_class} ({share:.1f}%)")

In [None]:
class ConvEmbedding(nn.Module):
    """Turn an image-shaped feature map into a normalized token sequence.

    A strided 2D convolution produces `embed_dim` channels; the spatial
    grid is then flattened into tokens and layer-normalized.
    """

    def __init__(self, in_channels, embed_dim, kernel_size=3, stride=2, padding=1):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size, stride, padding)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Map `[B, C, H, W]` to `([B, H'*W', embed_dim], (H', W'))`."""
        feature_map = self.proj(x)  # [B, embed_dim, H', W']
        out_h, out_w = feature_map.shape[-2:]

        # Collapse the grid to a sequence and put channels last.
        tokens = feature_map.flatten(2).permute(0, 2, 1)  # [B, H'*W', embed_dim]
        return self.norm(tokens), (out_h, out_w)

print("✓ ConvEmbedding defined")

In [None]:
class ConvPositionEncoding(nn.Module):
    """Depthwise-convolution positional encoding added residually to tokens."""

    def __init__(self, dim, kernel_size=3):
        super().__init__()
        # groups=dim makes the convolution depthwise; the padding keeps the
        # spatial size unchanged for odd kernel sizes.
        self.proj = nn.Conv2d(
            dim, dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=dim
        )

    def forward(self, x, size):
        """Apply the encoding.

        Args:
            x: Tokens of shape `[B, N, C]` with `N == H * W`.
            size: `(H, W)` grid the tokens were flattened from.

        Returns:
            Tensor of shape `[B, N, C]`.
        """
        height, width = size
        batch, _, channels = x.shape

        # Back to image layout so the convolution can see spatial neighbours.
        grid = x.transpose(1, 2).view(batch, channels, height, width)
        grid = grid + self.proj(grid)

        return grid.flatten(2).transpose(1, 2)

print("✓ ConvPositionEncoding defined")

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Q, K and V are produced by a single shared linear layer; this module
    does not apply any convolutional projection itself.

    Args:
        dim: Token embedding size; must be divisible by `num_heads`.
        num_heads: Number of attention heads.
        qkv_bias: Whether the QKV linear layer has a bias.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.

    Raises:
        ValueError: If `num_heads` is not positive or does not divide `dim`.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail at construction time; otherwise the mismatch only surfaces as
        # an opaque reshape error inside forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over tokens; input and output are both `[B, N, C]`."""
        B, N, C = x.shape

        # Generate Q, K, V: [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention weights: [B, heads, N, N]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge the heads back into C channels.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

print("✓ Attention defined")

In [None]:
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    `hidden_features` and `out_features` fall back to `in_features` when
    they are not given (or are falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_width, output_width)
        # The same dropout module is applied after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

print("✓ Mlp defined")

In [None]:
class DropPath(nn.Module):
    """Stochastic depth: zero the whole input of randomly chosen samples.

    Active only in training mode with `drop_prob > 0`; surviving samples
    are scaled by `1 / (1 - drop_prob)`.
    """

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.:
            return x

        keep_prob = 1 - self.drop_prob
        # One random draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        keep_mask.add_(keep_prob).floor_()
        return x.div(keep_prob) * keep_mask

print("✓ DropPath defined")

In [None]:
class CvTBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Both residual branches pass through the same stochastic-depth module
    (an identity when `drop_path` is 0).
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
                 drop=0., attn_drop=0., drop_path=0.):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x, size):
        """Run the block on `[B, N, C]` tokens; `size` is accepted but not used here."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)

print("✓ CvTBlock defined")

In [None]:
class CvTStage(nn.Module):
    """One CvT stage: conv token embedding, conv position encoding, then `depth` blocks.

    The stochastic-depth rate rises linearly from 0 to `drop_path_rate`
    across the blocks of this stage.
    """

    def __init__(self, in_channels, embed_dim, depth, num_heads, mlp_ratio,
                 qkv_bias, drop_rate, attn_drop_rate, drop_path_rate,
                 kernel_size=3, stride=2, padding=1):
        super().__init__()

        # Image -> tokens
        self.patch_embed = ConvEmbedding(
            in_channels, embed_dim,
            kernel_size=kernel_size, stride=stride, padding=padding
        )

        # Positional information injected via a 3x3 depthwise convolution
        self.pos_embed = ConvPositionEncoding(embed_dim, kernel_size=3)

        # Per-block drop-path rates for this stage
        drop_path_schedule = torch.linspace(0, drop_path_rate, depth).tolist()
        shared_block_kwargs = dict(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
        )
        self.blocks = nn.ModuleList([
            CvTBlock(drop_path=rate, **shared_block_kwargs)
            for rate in drop_path_schedule
        ])

    def forward(self, x):
        """Map `[B, C, H, W]` to `(tokens [B, N, embed_dim], (H', W'))`."""
        tokens, size = self.patch_embed(x)
        tokens = self.pos_embed(tokens, size)

        for block in self.blocks:
            tokens = block(tokens, size)

        return tokens, size

print("✓ CvTStage defined")

In [None]:
class CvT(nn.Module):
    """Convolutional Vision Transformer classifier.

    The input passes through a sequence of `CvTStage`s; the tokens of the
    last stage are layer-normalized, averaged over the sequence, and fed
    to a linear classification head.

    Args:
        in_channels: Channels of the input image.
        num_classes: Size of the output logit vector.
        stages: List of per-stage dicts with keys 'embed_dim', 'depth',
            'num_heads', 'kernel_size', 'stride', 'padding'. Defaults to
            the three-stage CvT-13 layout.
        mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path_rate:
            Forwarded unchanged to every stage.
    """

    def __init__(self, in_channels=1, num_classes=2, stages=None,
                 mlp_ratio=4., qkv_bias=True, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0.1):
        super().__init__()

        if stages is None:
            # Default CvT-13 configuration
            stages = [
                dict(embed_dim=64, depth=1, num_heads=1, kernel_size=7, stride=4, padding=2),
                dict(embed_dim=192, depth=2, num_heads=3, kernel_size=3, stride=2, padding=1),
                dict(embed_dim=384, depth=10, num_heads=6, kernel_size=3, stride=2, padding=1),
            ]

        self.num_stages = len(stages)

        # Each stage consumes the embedding width of the one before it;
        # the first consumes the raw image channels.
        stage_inputs = [in_channels] + [cfg['embed_dim'] for cfg in stages[:-1]]

        self.stages = nn.ModuleList()
        for stage_in_channels, cfg in zip(stage_inputs, stages):
            self.stages.append(
                CvTStage(
                    in_channels=stage_in_channels,
                    embed_dim=cfg['embed_dim'],
                    depth=cfg['depth'],
                    num_heads=cfg['num_heads'],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=drop_path_rate,
                    kernel_size=cfg['kernel_size'],
                    stride=cfg['stride'],
                    padding=cfg['padding']
                )
            )

        # Classification head
        final_dim = stages[-1]['embed_dim']
        self.norm = nn.LayerNorm(final_dim)
        self.head = nn.Linear(final_dim, num_classes)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initializer applied to every submodule via `self.apply`."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return

        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        else:
            return

        # Linear and Conv2d share the same bias treatment.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits `[B, num_classes]` for an image batch `[B, C, H, W]`."""
        size = None
        for stage in self.stages:
            if size is not None:
                # Tokens from the previous stage go back to image layout.
                batch, _, channels = x.shape
                x = x.transpose(1, 2).reshape(batch, channels, *size)
            x, size = stage(x)

        # Normalize, then average over the token dimension.
        pooled = self.norm(x).mean(dim=1)  # [B, embed_dim]

        return self.head(pooled)

print("✓ CvT model defined")

In [None]:
# Test model
def count_parameters(model):
    """Return the number of scalar parameters in `model` that require gradients."""
    trainable = (param for param in model.parameters() if param.requires_grad)
    return sum(param.numel() for param in trainable)

# Create model from the architecture settings in CONFIG
model_kwargs = {
    key: CONFIG[key]
    for key in (
        'in_channels', 'num_classes', 'stages', 'mlp_ratio',
        'qkv_bias', 'drop_rate', 'attn_drop_rate', 'drop_path_rate',
    )
}
model = CvT(**model_kwargs).to(device)

# Test forward pass on a random batch of two 512x512 single-channel images
with torch.no_grad():
    dummy_input = torch.randn(2, 1, 512, 512).to(device)
    output = model(dummy_input)

print(f"\n✓ Model test passed")
print(f"  Input shape: {dummy_input.shape}")
print(f"  Output shape: {output.shape}")
print(f"  Parameters: {count_parameters(model):,}")

# Clean up
del model, dummy_input, output
torch.cuda.empty_cache()
gc.collect()

In [None]:
class LRFinder:
    """Learning Rate Finder using the LR Range Test.

    The learning rate is increased geometrically from `start_lr` towards
    `end_lr` while training on successive batches; the smoothed loss is
    recorded at each step. The model and optimizer are returned to the
    state they had when the finder was constructed.

    Args:
        model: Network to probe.
        optimizer: Optimizer bound to `model`'s parameters.
        criterion: Loss callable `criterion(outputs, targets)`.
        device: Device that input batches are moved to.
    """

    def __init__(self, model, optimizer, criterion, device):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

        # state_dict() returns references to the live parameter tensors, so
        # the snapshots must be deep-copied; otherwise training during the
        # range test silently rewrites them and the later "reset" is a no-op.
        self.model_state = copy.deepcopy(model.state_dict())
        self.optimizer_state = copy.deepcopy(optimizer.state_dict())

    def _restore_initial_state(self):
        """Load the snapshots taken in __init__ back into model and optimizer."""
        self.model.load_state_dict(self.model_state)
        # Hand the optimizer its own copy so later steps cannot touch the snapshot.
        self.optimizer.load_state_dict(copy.deepcopy(self.optimizer_state))

    def _set_lr(self, lr):
        """Apply `lr` to every parameter group of the optimizer."""
        for group in self.optimizer.param_groups:
            group['lr'] = lr

    def range_test(self, train_loader, start_lr=1e-7, end_lr=1, num_iter=100, smooth_f=0.05):
        """Perform LR range test.

        Args:
            train_loader: Iterable of `(inputs, targets)` batches; restarted
                when exhausted.
            start_lr: Learning rate of the first iteration.
            end_lr: Upper end of the geometric schedule.
            num_iter: Maximum number of training iterations.
            smooth_f: Weight of the newest loss in the exponential average.

        Returns:
            `(lrs, losses)`: learning rates tried and the bias-corrected
            smoothed loss observed at each.

        Raises:
            ValueError: If `train_loader` yields no batches.
        """

        # Reset model and optimizer
        self._restore_initial_state()

        # Calculate LR multiplier
        mult = (end_lr / start_lr) ** (1 / num_iter)
        lr = start_lr
        self._set_lr(lr)

        avg_loss = 0.
        best_loss = float('inf')
        batch_num = 0
        losses = []
        lrs = []

        self.model.train()

        iterator = iter(train_loader)

        for iteration in tqdm(range(num_iter), desc="LR Finder"):
            # Get batch, cycling through the loader as often as needed
            try:
                inputs, targets = next(iterator)
            except StopIteration:
                iterator = iter(train_loader)
                try:
                    inputs, targets = next(iterator)
                except StopIteration:
                    raise ValueError("train_loader yielded no batches") from None

            batch_num += 1

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Forward
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            # Exponentially smoothed loss with bias correction
            avg_loss = smooth_f * loss.item() + (1 - smooth_f) * avg_loss
            smoothed_loss = avg_loss / (1 - (1 - smooth_f) ** batch_num)

            # NaN/inf never compares greater than the threshold below, so it
            # needs its own stop condition.
            if not math.isfinite(smoothed_loss):
                print(f"\n⚠️ Loss became non-finite at LR={lr:.2e}")
                break

            # Stop if loss explodes
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                print(f"\n⚠️ Loss exploded at LR={lr:.2e}")
                break

            # Record best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss

            # Store values
            losses.append(smoothed_loss)
            lrs.append(lr)

            # Backward
            loss.backward()
            self.optimizer.step()

            # Update LR
            lr *= mult
            self._set_lr(lr)

        # Reset model and optimizer
        self._restore_initial_state()

        return lrs, losses

    def plot(self, lrs, losses, skip_start=10, skip_end=5):
        """Plot LR range test results and return a suggested learning rate.

        Args:
            lrs, losses: Output of `range_test`.
            skip_start: Number of leading points to hide.
            skip_end: Number of trailing points to hide.

        Returns:
            One tenth of the learning rate at which the plotted loss is lowest.

        Raises:
            ValueError: If `lrs` is empty.
        """
        n_points = len(lrs)
        if n_points == 0:
            raise ValueError("No LR finder results to plot")

        if skip_start >= n_points:
            skip_start = 0
        if skip_end >= n_points:
            skip_end = 0
        # A short run (e.g. an early loss explosion) can leave nothing after
        # trimming both ends; fall back to plotting everything.
        if skip_start + skip_end >= n_points:
            skip_start = skip_end = 0

        stop = n_points - skip_end
        lrs = lrs[skip_start:stop]
        losses = losses[skip_start:stop]

        # Find minimum
        min_idx = int(np.argmin(losses))
        min_lr = lrs[min_idx]

        # Suggested LR (10x smaller than minimum)
        suggested_lr = min_lr / 10

        plt.figure(figsize=(10, 6))
        plt.plot(lrs, losses)
        plt.xscale('log')
        plt.xlabel('Learning Rate')
        plt.ylabel('Loss')
        plt.title('Learning Rate Finder')
        plt.axvline(x=min_lr, color='r', linestyle='--', label=f'Min Loss LR: {min_lr:.2e}')
        plt.axvline(x=suggested_lr, color='g', linestyle='--', label=f'Suggested LR: {suggested_lr:.2e}')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Save plot
        plot_path = os.path.join(CONFIG['metrics_dir'], f"{CONFIG['model_name']}_lr_finder.png")
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.show()

        print(f"\n📊 LR Finder Results:")
        print(f"  Minimum loss LR: {min_lr:.2e}")
        print(f"  Suggested LR: {suggested_lr:.2e}")
        print(f"  Plot saved to: {plot_path}")

        return suggested_lr

print("✓ LRFinder class defined")

In [None]:
# Run LR Finder
print("\n🔍 Running Learning Rate Finder...")
print("This may take a few minutes...\n")

if not image_paths:
    raise RuntimeError("No images were loaded; cannot run the LR finder.")

# Create temporary model and dataloader for LR finding
temp_model = CvT(
    in_channels=CONFIG['in_channels'],
    num_classes=CONFIG['num_classes'],
    stages=CONFIG['stages'],
    mlp_ratio=CONFIG['mlp_ratio'],
    qkv_bias=CONFIG['qkv_bias'],
    drop_rate=CONFIG['drop_rate'],
    attn_drop_rate=CONFIG['attn_drop_rate'],
    drop_path_rate=CONFIG['drop_path_rate']
).to(device)

# Create temporary dataset (up to 500 images for speed).
# load_dataset() appends one folder after another, so simply taking the
# first 500 entries may yield images from a single class and give a
# misleading loss curve. Draw a seeded random subset from the whole list;
# a private Generator leaves NumPy's global RNG state untouched.
temp_rng = np.random.default_rng(CONFIG['seed'])
temp_indices = sorted(
    temp_rng.choice(
        len(image_paths),
        size=min(500, len(image_paths)),
        replace=False
    ).tolist()
)
temp_paths = [image_paths[i] for i in temp_indices]
temp_labels = [labels[i] for i in temp_indices]

temp_dataset = MammographyDataset(
    temp_paths,
    temp_labels,
    transform=get_transforms(train=True)
)

temp_loader = DataLoader(
    temp_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=CONFIG['pin_memory']
)

# Setup for LR Finder (the initial lr is overwritten by range_test)
temp_optimizer = AdamW(
    temp_model.parameters(),
    lr=1e-7,
    weight_decay=CONFIG['weight_decay'],
    betas=CONFIG['betas']
)
temp_criterion = nn.CrossEntropyLoss()

# Run LR Finder
lr_finder = LRFinder(temp_model, temp_optimizer, temp_criterion, device)
lrs, losses = lr_finder.range_test(
    temp_loader,
    start_lr=CONFIG['lr_min'],
    end_lr=CONFIG['lr_max'],
    num_iter=100
)

# Plot and get suggested LR
suggested_lr = lr_finder.plot(lrs, losses)

# Update config (in memory only; the JSON written earlier is not refreshed here)
CONFIG['lr_initial'] = suggested_lr

# Clean up
del temp_model, temp_dataset, temp_loader, temp_optimizer, temp_criterion, lr_finder, temp_rng
torch.cuda.empty_cache()
gc.collect()

print(f"\n✓ Learning rate set to: {CONFIG['lr_initial']:.2e}")

In [None]:
class EarlyStopping:
    """Signal when validation loss has stopped improving.

    Call the instance once per epoch. An epoch counts as an improvement
    when its loss is finite and at least `min_delta` below the best loss
    seen so far (the first finite loss always counts). After `patience`
    consecutive non-improving calls, `early_stop` becomes True.

    Attributes:
        best_loss: Lowest accepted validation loss, or None before the
            first finite value.
        best_epoch: Epoch number passed alongside `best_loss`.
        counter: Consecutive calls without improvement.
        early_stop: Set to True once `counter` reaches `patience`.
    """

    def __init__(self, patience=25, min_delta=1e-4, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_epoch = 0

    def __call__(self, val_loss, epoch):
        """Record the validation loss of `epoch` and update the stop flag."""
        # NaN compares False against everything, so without this check a NaN
        # loss would be taken as an improvement, overwrite best_loss, and
        # make every later epoch look like an improvement as well.
        is_finite = math.isfinite(val_loss)
        improved = is_finite and (
            self.best_loss is None
            or val_loss <= self.best_loss - self.min_delta
        )

        if improved:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"  EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

print("✓ EarlyStopping class defined")

In [None]:
def train_epoch(model, loader, criterion, optimizer, scaler, device, use_amp=True):
    """Run one optimization pass over `loader`.

    Args:
        model: Network to train (switched to training mode here).
        loader: Iterable of `(inputs, targets)` batches.
        criterion: Loss callable `criterion(outputs, targets)`.
        optimizer: Optimizer stepped once per batch.
        scaler: Gradient scaler; only used when `use_amp` is True.
        device: Device that batches are moved to.
        use_amp: Run the forward pass under autocast and step through `scaler`.

    Returns:
        `(epoch_loss, epoch_acc)`: sample-weighted mean loss and the
        fraction of correctly classified samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc="Training")
    for inputs, targets in progress:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        if not use_amp:
            # Plain full-precision step
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        else:
            # Mixed precision: forward under autocast, backward/step via the scaler
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Statistics
        batch_loss = loss.item()
        loss_sum += batch_loss * inputs.size(0)
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        # Update progress bar
        progress.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{100.*n_correct/n_seen:.2f}%'
        })

    return loss_sum / n_seen, n_correct / n_seen

print("✓ train_epoch function defined")

In [None]:
def validate_epoch(model, loader, criterion, device):
    """Evaluate `model` on every batch of `loader` without gradient tracking.

    Args:
        model: Network to evaluate (switched to eval mode here).
        loader: Iterable of `(inputs, targets)` batches.
        criterion: Loss callable `criterion(outputs, targets)`.
        device: Device that batches are moved to.

    Returns:
        `(epoch_loss, epoch_acc, all_preds, all_targets)`: sample-weighted
        mean loss, accuracy, and flat lists of predicted and true labels
        in iteration order.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    all_preds = []
    all_targets = []

    progress = tqdm(loader, desc="Validation")
    with torch.no_grad():
        for inputs, targets in progress:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            batch_loss = loss.item()

            predicted = outputs.max(1)[1]
            loss_sum += batch_loss * inputs.size(0)
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

            # Keep per-sample results on the CPU for metric computation.
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

            progress.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'acc': f'{100.*n_correct/n_seen:.2f}%'
            })

    return loss_sum / n_seen, n_correct / n_seen, all_preds, all_targets

print("✓ validate_epoch function defined")

In [None]:
def calculate_metrics(y_true, y_pred):
    """Calculate binary classification metrics (positive class = 1).

    Args:
        y_true: True labels (0 or 1).
        y_pred: Predicted labels (0 or 1).

    Returns:
        Dict with 'accuracy', 'precision', 'recall', 'f1', 'specificity'
        and 'confusion_matrix' (2x2 array, rows = true class, columns =
        predicted class, in label order [0, 1]).
    """

    # Basic metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='binary', zero_division=0)
    recall = recall_score(y_true, y_pred, average='binary', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='binary', zero_division=0)

    # Pin the label set: without it the matrix shrinks to 1x1 whenever only
    # one class occurs in both y_true and y_pred, which made specificity
    # come out as 0 even when every negative was classified correctly.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    # Specificity (TN / (TN + FP))
    tn, fp, fn, tp = cm.ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'specificity': specificity,
        'confusion_matrix': cm
    }

    return metrics

print("✓ calculate_metrics function defined")

In [None]:
def train_with_cross_validation():
    """
    Train the model with patient-level stratified K-fold cross validation.

    Images are grouped by patient and the folds are built over patients, so
    all images of one patient land in the same split (prevents leakage between
    train and validation). The number of folds is ``CONFIG['num_folds']``.

    For every fold a fresh model, optimizer, LR scheduler and early-stopping
    tracker are created. The checkpoint with the lowest validation loss is
    written to ``CONFIG['output_dir']``, reloaded after training and evaluated
    on the fold's validation set; history and confusion-matrix plots are saved
    through ``plot_training_history`` and ``plot_confusion_matrix``.

    Reads the notebook-level globals ``patient_ids``, ``labels``,
    ``image_paths``, ``CONFIG`` and ``device``.

    Returns:
        tuple: ``(fold_metrics, fold_histories, fold_cms)``, each a list with
        one entry per fold: the dict returned by ``calculate_metrics``, the
        per-epoch history dict, and the confusion matrix.
    """

    # Group by patient
    patient_to_indices = defaultdict(list)
    patient_to_label = {}

    for idx, (patient_id, label) in enumerate(zip(patient_ids, labels)):
        patient_to_indices[patient_id].append(idx)
        # Stratification needs one label per patient; when a patient has
        # images with different labels, the last one seen is used.
        patient_to_label[patient_id] = label

    unique_patients = list(patient_to_indices.keys())
    patient_labels = [patient_to_label[p] for p in unique_patients]

    print(f"\n📊 Cross Validation Setup:")
    print(f"  Total patients: {len(unique_patients)}")
    print(f"  Total images: {len(image_paths)}")
    print(f"  Folds: {CONFIG['num_folds']}")

    # K-Fold split (over patients, stratified by the per-patient label)
    skf = StratifiedKFold(n_splits=CONFIG['num_folds'], shuffle=True, random_state=CONFIG['seed'])

    # Checkpoints are written inside the loop; make sure the directory exists.
    os.makedirs(CONFIG['output_dir'], exist_ok=True)

    # Store results
    fold_metrics = []
    fold_histories = []
    fold_cms = []

    # Train each fold
    for fold, (train_patient_idx, val_patient_idx) in enumerate(skf.split(unique_patients, patient_labels)):
        print(f"\n{'='*80}")
        print(f"FOLD {fold + 1}/{CONFIG['num_folds']}")
        print(f"{'='*80}")

        # Get patient IDs for this fold
        train_patients = [unique_patients[i] for i in train_patient_idx]
        val_patients = [unique_patients[i] for i in val_patient_idx]

        # Get image indices
        train_indices = []
        val_indices = []

        for patient in train_patients:
            train_indices.extend(patient_to_indices[patient])

        for patient in val_patients:
            val_indices.extend(patient_to_indices[patient])

        # Create datasets
        train_paths = [image_paths[i] for i in train_indices]
        train_labels = [labels[i] for i in train_indices]
        val_paths = [image_paths[i] for i in val_indices]
        val_labels = [labels[i] for i in val_indices]

        train_dataset = MammographyDataset(train_paths, train_labels, transform=get_transforms(train=True))
        val_dataset = MammographyDataset(val_paths, val_labels, transform=get_transforms(train=False))

        # Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=True,
            num_workers=CONFIG['num_workers'],
            pin_memory=CONFIG['pin_memory']
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,
            num_workers=CONFIG['num_workers'],
            pin_memory=CONFIG['pin_memory']
        )

        print(f"\n📊 Fold {fold + 1} Data:")
        print(f"  Train patients: {len(train_patients)}")
        print(f"  Val patients: {len(val_patients)}")
        print(f"  Train images: {len(train_paths)} (Benign: {train_labels.count(0)}, Malignant: {train_labels.count(1)})")
        print(f"  Val images: {len(val_paths)} (Benign: {val_labels.count(0)}, Malignant: {val_labels.count(1)})")

        # Create model
        model = CvT(
            in_channels=CONFIG['in_channels'],
            num_classes=CONFIG['num_classes'],
            stages=CONFIG['stages'],
            mlp_ratio=CONFIG['mlp_ratio'],
            qkv_bias=CONFIG['qkv_bias'],
            drop_rate=CONFIG['drop_rate'],
            attn_drop_rate=CONFIG['attn_drop_rate'],
            drop_path_rate=CONFIG['drop_path_rate']
        ).to(device)

        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = AdamW(
            model.parameters(),
            lr=CONFIG['lr_initial'],
            weight_decay=CONFIG['weight_decay'],
            betas=CONFIG['betas']
        )

        # Scheduler (driven by the validation loss, hence mode='min')
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=CONFIG['scheduler_factor'],
            patience=CONFIG['scheduler_patience'],
            min_lr=CONFIG['scheduler_min_lr'],
            verbose=True
        )

        # Early stopping
        early_stopping = EarlyStopping(
            patience=CONFIG['early_stopping_patience'],
            min_delta=CONFIG['min_delta'],
            verbose=True
        )

        # Mixed precision scaler
        scaler = GradScaler() if CONFIG['use_amp'] else None

        # Training history
        history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }

        best_val_loss = float('inf')
        # 1-based epoch of the best checkpoint; 0 means none was saved yet.
        best_epoch = 0

        # Single source of truth for this fold's checkpoint location.
        model_path = os.path.join(
            CONFIG['output_dir'],
            f"{CONFIG['model_name']}_fold{fold}.pth"
        )

        # Training loop
        print(f"\n🚀 Starting training...")
        start_time = time.time()

        for epoch in range(CONFIG['num_epochs']):
            print(f"\n--- Epoch {epoch + 1}/{CONFIG['num_epochs']} ---")

            # Train
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer, scaler, device, CONFIG['use_amp']
            )

            # Validate
            val_loss, val_acc, val_preds, val_targets = validate_epoch(
                model, val_loader, criterion, device
            )

            # Update scheduler
            scheduler.step(val_loss)
            current_lr = optimizer.param_groups[0]['lr']

            # Save history
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            history['lr'].append(current_lr)

            # Print epoch summary
            print(f"\n📊 Epoch {epoch + 1} Summary:")
            print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
            print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
            print(f"  LR: {current_lr:.2e}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch + 1

                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                    'config': CONFIG
                }, model_path)

                print(f"  ✓ Model saved: {model_path}")

            # Early stopping
            early_stopping(val_loss, epoch + 1)
            if early_stopping.early_stop:
                print(f"\n⏹️ Early stopping triggered at epoch {epoch + 1}")
                print(f"  Best epoch: {best_epoch} with val_loss: {best_val_loss:.4f}")
                break

            # Clear cache
            torch.cuda.empty_cache()

        training_time = time.time() - start_time
        print(f"\n⏱️ Training time for fold {fold + 1}: {training_time/60:.2f} minutes")

        # Load best model for evaluation. Only reload when THIS run wrote a
        # checkpoint: if the validation loss never improved (e.g. NaN losses),
        # the file on disk is either missing or left over from an earlier run.
        if best_epoch > 0:
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            # The checkpoint also carries optimizer state; release it now so it
            # does not stay resident while the next fold trains.
            del checkpoint
        else:
            print(f"\n⚠️ No checkpoint was saved for fold {fold + 1}; evaluating the final weights")

        # Final evaluation
        print(f"\n📊 Final evaluation on validation set...")
        _, _, final_preds, final_targets = validate_epoch(model, val_loader, criterion, device)

        # Calculate metrics
        metrics = calculate_metrics(final_targets, final_preds)

        print(f"\n✅ Fold {fold + 1} Results:")
        print(f"  Accuracy: {metrics['accuracy']*100:.2f}%")
        print(f"  Precision: {metrics['precision']*100:.2f}%")
        print(f"  Recall: {metrics['recall']*100:.2f}%")
        print(f"  F1-Score: {metrics['f1']*100:.2f}%")
        print(f"  Specificity: {metrics['specificity']*100:.2f}%")

        # Store results
        fold_metrics.append(metrics)
        fold_histories.append(history)
        fold_cms.append(metrics['confusion_matrix'])

        # Plot training history
        plot_training_history(history, fold)

        # Plot confusion matrix
        plot_confusion_matrix(metrics['confusion_matrix'], fold)

        # Clean up so the next fold starts with as much free memory as possible
        del model, optimizer, scheduler, train_loader, val_loader
        torch.cuda.empty_cache()
        gc.collect()

    return fold_metrics, fold_histories, fold_cms

print("✓ train_with_cross_validation function defined")

In [None]:
def plot_training_history(history, fold):
    """
    Plot loss, accuracy and learning-rate curves for one fold and save them.

    Args:
        history: Dict with per-epoch lists under ``'train_loss'``,
            ``'val_loss'``, ``'train_acc'``, ``'val_acc'`` (fractions) and
            ``'lr'``.
        fold: Zero-based fold index; shown 1-based in titles and used as-is in
            the output file name.

    The figure is written to
    ``CONFIG['metrics_dir']/<model_name>_fold<fold>_history.png``, displayed,
    and then closed.
    """

    def epoch_axis(values):
        # Entry i of a history list belongs to epoch i + 1, matching the
        # 1-based epoch numbers printed during training.
        return range(1, len(values) + 1)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Loss
    axes[0].plot(epoch_axis(history['train_loss']), history['train_loss'], label='Train Loss', marker='o')
    axes[0].plot(epoch_axis(history['val_loss']), history['val_loss'], label='Val Loss', marker='s')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title(f'Fold {fold + 1}: Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Accuracy (stored as fractions, plotted as percentages)
    train_acc_pct = [acc*100 for acc in history['train_acc']]
    val_acc_pct = [acc*100 for acc in history['val_acc']]
    axes[1].plot(epoch_axis(train_acc_pct), train_acc_pct, label='Train Acc', marker='o')
    axes[1].plot(epoch_axis(val_acc_pct), val_acc_pct, label='Val Acc', marker='s')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title(f'Fold {fold + 1}: Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    # Learning Rate (log scale so scheduler step-downs are visible)
    axes[2].plot(epoch_axis(history['lr']), history['lr'], marker='o')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Learning Rate')
    axes[2].set_title(f'Fold {fold + 1}: Learning Rate')
    axes[2].set_yscale('log')
    axes[2].grid(True, alpha=0.3)

    fig.tight_layout()

    # Save
    os.makedirs(CONFIG['metrics_dir'], exist_ok=True)
    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_fold{fold}_history.png"
    )
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    # This runs once per fold; close explicitly so figures do not pile up on
    # backends where show() leaves them open.
    plt.close(fig)

    print(f"  History plot saved: {plot_path}")

print("✓ plot_training_history function defined")

In [None]:
def plot_confusion_matrix(cm, fold, normalize=False):
    """
    Draw one fold's confusion matrix as a heatmap and save it.

    Args:
        cm: 2x2 confusion matrix with rows = true label and columns =
            predicted label, ordered Benign, Malignant.
        fold: Zero-based fold index; shown 1-based in the title and used as-is
            in the output file name.
        normalize: When True, every row is divided by its total so cells show
            the per-true-class proportion instead of raw counts.

    The figure is written to
    ``CONFIG['metrics_dir']/<model_name>_fold<fold>_cm.png``, displayed, and
    then closed.
    """

    cm = np.asarray(cm)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no samples has an all-zero row; dividing by 1 keeps it
        # at 0 instead of producing NaN cells from 0/0.
        cm = cm.astype('float') / np.where(row_totals == 0, 1, row_totals)

    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='.2f' if normalize else 'd',
        cmap='Blues',
        xticklabels=['Benign', 'Malignant'],
        yticklabels=['Benign', 'Malignant'],
        # The colour bar must describe what the cells actually contain.
        cbar_kws={'label': 'Proportion' if normalize else 'Count'}
    )
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title(f'Fold {fold + 1}: Confusion Matrix')

    # Save
    os.makedirs(CONFIG['metrics_dir'], exist_ok=True)
    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_fold{fold}_cm.png"
    )
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    # This runs once per fold; close explicitly so figures do not pile up on
    # backends where show() leaves them open.
    plt.close(fig)

    print(f"  Confusion matrix saved: {plot_path}")

print("✓ plot_confusion_matrix function defined")

In [None]:
def save_metrics_summary(fold_metrics):
    """
    Build the per-fold metrics table, append a summary row and save it as CSV.

    Args:
        fold_metrics: List of per-fold dicts with the keys ``'accuracy'``,
            ``'precision'``, ``'recall'``, ``'f1'`` and ``'specificity'``.

    Returns:
        pandas.DataFrame: One row per fold (``Fold`` = 1..N) followed by a
        ``'Mean ± Std'`` row whose cells are ``"mean ± std"`` strings
        (population standard deviation, 4 decimals). The same table is written
        to ``CONFIG['metrics_dir']/<model_name>_metrics.csv`` and printed.
    """

    # CSV column name -> key in the per-fold metrics dict
    column_keys = {
        'Accuracy': 'accuracy',
        'Precision': 'precision',
        'Recall': 'recall',
        'F1-Score': 'f1',
        'Specificity': 'specificity',
    }

    # Extract metrics
    metrics_dict = {'Fold': list(range(1, len(fold_metrics) + 1))}
    for column, key in column_keys.items():
        metrics_dict[column] = [metrics[key] for metrics in fold_metrics]

    # Calculate mean and std while the columns still hold only numbers. Doing
    # it before the string row is appended avoids filtering values by type,
    # which silently dropped integer-valued metrics (e.g. a specificity of 0).
    summary_row = {
        column: f"{np.mean(metrics_dict[column]):.4f} ± {np.std(metrics_dict[column]):.4f}"
        for column in column_keys
    }

    metrics_dict['Fold'].append('Mean ± Std')
    for column, cell in summary_row.items():
        metrics_dict[column].append(cell)

    # Create DataFrame
    df = pd.DataFrame(metrics_dict)

    # Save to CSV
    os.makedirs(CONFIG['metrics_dir'], exist_ok=True)
    csv_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_metrics.csv"
    )
    df.to_csv(csv_path, index=False)

    print(f"\n✅ Metrics saved to: {csv_path}")
    print(f"\n{df.to_string(index=False)}")

    return df

print("✓ save_metrics_summary function defined")

In [None]:
def plot_average_confusion_matrix(fold_cms):
    """
    Plot the element-wise mean of the per-fold confusion matrices and save it.

    Args:
        fold_cms: List of confusion matrices, one per fold, all of the same
            shape (rows = true label, columns = predicted label).

    The figure is written to
    ``CONFIG['metrics_dir']/<model_name>_avg_cm.png``, displayed, and then
    closed.
    """

    # Average confusion matrix
    avg_cm = np.mean(fold_cms, axis=0)

    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(
        avg_cm,
        annot=True,
        fmt='.1f',
        cmap='Blues',
        xticklabels=['Benign', 'Malignant'],
        yticklabels=['Benign', 'Malignant'],
        cbar_kws={'label': 'Average Count'}
    )
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    # Report the number of folds actually averaged rather than assuming 10.
    plt.title(f'Average Confusion Matrix ({len(fold_cms)}-Fold CV)')

    # Save
    os.makedirs(CONFIG['metrics_dir'], exist_ok=True)
    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_avg_cm.png"
    )
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    print(f"\n✅ Average confusion matrix saved: {plot_path}")

print("✓ plot_average_confusion_matrix function defined")

In [None]:
# Start training: run the full cross-validation and time it.
# The banner takes the fold count from CONFIG so it stays correct when
# CONFIG['num_folds'] is changed.
print("\n" + "="*80)
print(f"STARTING {CONFIG['num_folds']}-FOLD CROSS VALIDATION TRAINING")
print("="*80)

start_time = time.time()

# Train (one model per fold; per-fold results are kept for the summary cells)
fold_metrics, fold_histories, fold_cms = train_with_cross_validation()

# Wall-clock seconds for all folds together
total_time = time.time() - start_time

print("\n" + "="*80)
print("TRAINING COMPLETED")
print("="*80)
print(f"\n⏱️ Total training time: {total_time/3600:.2f} hours")

In [None]:
# Persist the per-fold metrics table (CSV) and keep it for inspection.
print("\n📊 Saving final metrics summary...")
metrics_df = save_metrics_summary(fold_metrics)

# Render and save the confusion matrix averaged over all folds.
print("\n📊 Plotting average confusion matrix...")
plot_average_confusion_matrix(fold_cms)

# Closing banner followed by the locations of everything that was written.
print("\n" + "=" * 80, "✅ ALL RESULTS SAVED", "=" * 80, sep="\n")
print(
    f"\n📁 Output directory: {CONFIG['output_dir']}",
    f"📊 Metrics directory: {CONFIG['metrics_dir']}",
    "\n🎉 Training complete!",
    sep="\n",
)

In [None]:
# Final summary: headline numbers for the whole cross-validation run.
print("\n" + "="*80)
print("FINAL SUMMARY")
print("="*80)

# Calculate overall statistics (one value per fold for each metric)
accuracies = [m['accuracy'] for m in fold_metrics]
precisions = [m['precision'] for m in fold_metrics]
recalls = [m['recall'] for m in fold_metrics]
f1s = [m['f1'] for m in fold_metrics]
specificities = [m['specificity'] for m in fold_metrics]

print(f"\n Model: {CONFIG['model_name']}")
# NOTE(review): architecture name and parameter count are hard-coded text, not
# derived from the instantiated model; confirm them if CONFIG['stages'] changes.
print(f" Architecture: CvT-13 (3-stage hierarchical)")
print(f" Parameters: ~20M")
print(f" Dataset: DDSM + INbreast")
print(f" Images: {len(image_paths)} total")
print(f" Patients: {len(set(patient_ids))} unique")
print(f" Folds: {CONFIG['num_folds']}")

# Mean and population standard deviation across folds, as percentages
print(f"\n Performance (Mean ± Std):")
print(f"  Accuracy:    {np.mean(accuracies)*100:.2f}% ± {np.std(accuracies)*100:.2f}%")
print(f"  Precision:   {np.mean(precisions)*100:.2f}% ± {np.std(precisions)*100:.2f}%")
print(f"  Recall:      {np.mean(recalls)*100:.2f}% ± {np.std(recalls)*100:.2f}%")
print(f"  F1-Score:    {np.mean(f1s)*100:.2f}% ± {np.std(f1s)*100:.2f}%")
print(f"  Specificity: {np.mean(specificities)*100:.2f}% ± {np.std(specificities)*100:.2f}%")

print(f"\n⏱ Training time: {total_time/3600:.2f} hours")
print(f" Models saved: {CONFIG['num_folds']} folds")
print(f" Metrics saved: {CONFIG['metrics_dir']}")

# Use the configured name so the message stays correct when the model name
# in CONFIG is changed.
print(f"\n {CONFIG['model_name']} training completed successfully!")