# Real-Fake Detection Training - CSV-Based Implementation

## Overview
This notebook trains a deep learning model to detect real vs fake (inpainted) images using:
- **CSV-based data splits** (no random splitting)
- **ResNet50** architecture with ImageNet pretrained weights
- **Comprehensive error analysis** on fake images (domain, mask, quality metrics, generative model)
- **Real-time GPU monitoring** during training
- **10+ detailed visualizations** for model performance analysis

## Key Features
- Enhanced progress bar with GPU stats, learning rate, and ETA
- Best model tracking (saves best_model.pth)
- Exhaustive misclassified fake images analysis
- Quality metrics correlation (SSIM, LPIPS, MSE)
- Domain/Mask/Generative Model accuracy breakdowns

## Data
- **Fake images**: ADE20K, CelebAHQ, CityScapes, HumanParsing, OpenImages
- **Generative models**: StableDiffusion v5, StableDiffusion XL, Kandinsky 2.2, OpenJourney
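- **Training**: 75K fake + 16K real images (imbalanced, roughly 4.7:1; handled by the strategy selected in `CLASS_BALANCE_METHOD`)
- **Validation**: 16K fake + 3K real images (imbalanced, roughly 5:1)
- **Test**: 16K fake + 3K real images (imbalanced, roughly 5:1)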

In [1]:
# Cell 2: Imports
import os
import sys
import time
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
# These two imports were fused onto one line, which is a SyntaxError and left
# both `F` (used by FocalLoss) and `optim` (used by the optimizer cell) undefined.
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

# Import utils from current directory
from utils import setup_device, create_training_folder, calculate_metrics, export_metrics_to_csv

# Try to import pynvml for GPU monitoring.
# GPU_MONITORING is read by later cells to decide whether NVML queries are attempted.
try:
    import pynvml
    pynvml.nvmlInit()
    GPU_MONITORING = True
    print("GPU monitoring enabled (pynvml loaded successfully)")
except Exception:
    # Covers both a missing package (ImportError) and an NVML init failure,
    # without swallowing KeyboardInterrupt/SystemExit like a bare `except:` would.
    GPU_MONITORING = False
    print("GPU monitoring not available (install nvidia-ml-py3 for GPU stats)")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

GPU monitoring not available (install nvidia-ml-py3 for GPU stats)
PyTorch version: 2.5.1+cu121
CUDA available: True


In [None]:
# Cell 3: Hyperparameters (All Customizable)
# Every tunable value of the notebook lives in this cell; later cells only read them.

# Training Parameters
BATCH_SIZE = 8
LEARNING_RATE = 0.001
NUM_EPOCHS = 20
WEIGHT_DECAY = 1e-4

# Scheduler Parameters
SCHEDULER_PATIENCE = 2
SCHEDULER_FACTOR = 0.5

# Data Parameters
IMAGE_SIZE = 224
USE_AUGMENTATION = True
NUM_WORKERS = 0

# DATA USAGE RATIO (0.0 to 1.0)
# Control what percentage of CSV data to use for training/val/test
TRAIN_DATA_RATIO = 1.0  # Use 100% of training data (1.0 = all, 0.1 = 10%, etc.)
VAL_DATA_RATIO = 1.0    # Use 100% of validation data
TEST_DATA_RATIO = 1.0   # Use 100% of test data

# CLASS BALANCING STRATEGY
# Choose how to handle class imbalance (more fakes than reals):
#   'undersampling': Drop excess fake samples to match real count (loses data)
#   'loss_weighting': Use all data but weight loss by inverse class frequency (no data loss, RECOMMENDED)
#   'none': Use all data without balancing (model may be biased)
CLASS_BALANCE_METHOD = 'loss_weighting'  # Recommended: 'loss_weighting'

# Loss Function Type
# NOTE: this is the ONLY place these three values are set. Copies of this block had
# been pasted (unindented) into the if/elif branches below, which broke the syntax
# and would have silently reset LOSS_TYPE to 'crossentropy'.
LOSS_TYPE = 'crossentropy'  # Options: 'crossentropy', 'focal'
FOCAL_ALPHA = 0.25  # Weight for focal loss (only used if LOSS_TYPE="focal")
FOCAL_GAMMA = 2.0   # Focusing parameter (only used if LOSS_TYPE="focal")

# RESUME TRAINING FROM CHECKPOINT
# Set to True to continue training from a previous best_model.pth
# Set to False to start fresh training from ImageNet pretrained weights
RESUME_FROM_CHECKPOINT = False

# If RESUME_FROM_CHECKPOINT is True, specify which version to resume from
# Leave as None to auto-detect the latest version, or specify version number (e.g., 1, 2, 3)
RESUME_VERSION = None  # None = auto-detect latest, or specify: 1, 2, 3, etc.

# Paths
FAKE_TRAIN_CSV = 'dataset_splits/fake_only_split/fake_train.csv'
FAKE_VAL_CSV = 'dataset_splits/fake_only_split/fake_val.csv'
FAKE_TEST_CSV = 'dataset_splits/fake_only_split/fake_test.csv'
REAL_TRAIN_CSV = 'dataset_splits/real_only_split/real_train.csv'
REAL_VAL_CSV = 'dataset_splits/real_only_split/real_val.csv'
REAL_TEST_CSV = 'dataset_splits/real_only_split/real_test.csv'

# Output folder
NOTEBOOK_NAME = 'real_fake_detection_csv'

# Random seed for reproducibility
RANDOM_SEED = 42

# Set seeds
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

print("="*60)
print("HYPERPARAMETERS")
print("="*60)
print(f"Batch Size: {BATCH_SIZE}")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Weight Decay: {WEIGHT_DECAY}")
print(f"Image Size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"Augmentation: {USE_AUGMENTATION}")
print(f"Scheduler Patience: {SCHEDULER_PATIENCE}")
print(f"Scheduler Factor: {SCHEDULER_FACTOR}")

print("\n" + "-"*60)
print("DATA USAGE RATIOS")
print("-"*60)
print(f"Train Data Ratio: {TRAIN_DATA_RATIO*100:.1f}% of available data")
print(f"Val Data Ratio: {VAL_DATA_RATIO*100:.1f}% of available data")
print(f"Test Data Ratio: {TEST_DATA_RATIO*100:.1f}% of available data")

print("\n" + "-"*60)
print("CLASS BALANCING STRATEGY")
print("-"*60)
print(f"Method: {CLASS_BALANCE_METHOD}")
if CLASS_BALANCE_METHOD == 'loss_weighting':
    print("  ‚úÖ Using loss weighting (no data loss, RECOMMENDED)")
elif CLASS_BALANCE_METHOD == 'undersampling':
    print("  ‚ö†Ô∏è  Using undersampling (data loss)")
else:
    print("  ‚ö†Ô∏è  No balancing (model may be biased)")

print("\n" + "-"*60)
print("CHECKPOINT RESUME")
print("-"*60)
print(f"Resume from checkpoint: {RESUME_FROM_CHECKPOINT}")
if RESUME_FROM_CHECKPOINT:
    if RESUME_VERSION is None:
        print("  üìÇ Will auto-detect latest version")
    else:
        print(f"  üìÇ Will resume from version: v{RESUME_VERSION}")
else:
    print("  üÜï Starting fresh training from ImageNet weights")
print("="*60)

In [None]:
# Cell 4: GPU Setup and Folder Creation

# Resolve the compute device through the project helper.
device = setup_device()

# Hard requirement: refuse to continue when CUDA is missing.
cuda_ready = torch.cuda.is_available()
if not cuda_ready:
    raise RuntimeError("‚ùå CUDA is not available! GPU training is required.")

print(f"‚úÖ Using device: {device}")
print(f"‚úÖ CUDA is available: {cuda_ready}")
print(f"‚úÖ Current CUDA device: {torch.cuda.current_device()}")
print(f"‚úÖ Device count: {torch.cuda.device_count()}")

# Versioned output folders for this training session (created by the project helper).
base_dir, models_dir, data_dir, viz_dir, version = create_training_folder(NOTEBOOK_NAME)

session_banner = "=" * 60
print(f"\n{session_banner}")
print(f"TRAINING SESSION: {NOTEBOOK_NAME} - Version {version}")
print(f"{session_banner}")
print(f"Base directory: {base_dir}")
print(f"Models directory: {models_dir}")
print(f"Data directory: {data_dir}")
print(f"Visualizations directory: {viz_dir}")
print(f"{session_banner}\n")

# Report, one line per folder, whether it is actually present on disk.
print("DIRECTORY VERIFICATION:")
for dir_label, dir_path in (("Models", models_dir), ("Data", data_dir), ("Viz", viz_dir)):
    dir_found = os.path.exists(dir_path)
    status_icon = "‚úÖ" if dir_found else "‚ùå"
    print(f"  {dir_label} dir exists: {dir_found} {status_icon}")

# Memory figures from torch are in bytes; convert to GiB for display.
bytes_per_gib = 1024**3
print("\nGPU INFORMATION:")
print(f"  GPU Name: {torch.cuda.get_device_name(0)}")
print(f"  GPU Memory Total: {torch.cuda.get_device_properties(0).total_memory / bytes_per_gib:.2f} GB")
print(f"  GPU Memory Allocated: {torch.cuda.memory_allocated(0) / bytes_per_gib:.2f} GB")
print(f"  GPU Memory Cached: {torch.cuda.memory_reserved(0) / bytes_per_gib:.2f} GB")

# Optional NVML handle; any failure here simply switches monitoring off.
if GPU_MONITORING:
    try:
        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        gpu_name = pynvml.nvmlDeviceGetName(gpu_handle)
        print(f"  pynvml GPU Device: {gpu_name}")

        mem_info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle)
        used_gib = mem_info.used / bytes_per_gib
        total_gib = mem_info.total / bytes_per_gib
        free_gib = mem_info.free / bytes_per_gib
        print(f"  GPU Memory Used: {used_gib:.2f} GB / {total_gib:.2f} GB")
        print(f"  GPU Memory Free: {free_gib:.2f} GB")
    except Exception as e:
        print(f"  pynvml monitoring failed: {e}")
        GPU_MONITORING = False

print("\n" + session_banner)
print("üöÄ GPU VERIFICATION COMPLETE - TRAINING WILL USE GPU")
print(session_banner)

In [None]:
# Cell 5: RealFakeCSVDataset Class

class RealFakeCSVDataset(Dataset):
    """
    Dataset that loads real and fake images from CSV files.
    Returns: (image, label, metadata_dict)

    Labels:
        0 = Real
        1 = Fake

    Metadata:
        - Fake: perturbed_img_id, mask_name, domain, ssim, lpips_score, mse,
                model_name, dataset, area_ratio, sem_magnitude
        - Real: real_img_id, parent_dataset

    The fake CSV must contain a 'fake_img_path' column and the real CSV a
    'real_img_path' column; every other metadata column is optional and falls
    back to '' / 0.0 when absent.
    """

    def __init__(self, fake_csv_path, real_csv_path, transform=None,
                 balance_method='loss_weighting', data_ratio=1.0, seed=42):
        """
        Args:
            fake_csv_path: Path to fake images CSV
            real_csv_path: Path to real images CSV
            transform: Image transformations
            balance_method: 'undersampling', 'loss_weighting', or 'none'
            data_ratio: Percentage of data to use (0.0 to 1.0)
            seed: Random seed for reproducibility
        """
        self.transform = transform
        self.balance_method = balance_method

        # Load CSV files
        print(f"Loading fake CSV: {fake_csv_path}")
        fake_df = pd.read_csv(fake_csv_path)
        print(f"  Loaded {len(fake_df):,} fake images from CSV")

        print(f"Loading real CSV: {real_csv_path}")
        real_df = pd.read_csv(real_csv_path)
        print(f"  Loaded {len(real_df):,} real images from CSV")

        # Apply data ratio if < 1.0
        if data_ratio < 1.0:
            print(f"\n‚öôÔ∏è  Applying data ratio: {data_ratio*100:.1f}%")
            fake_sample_size = int(len(fake_df) * data_ratio)
            real_sample_size = int(len(real_df) * data_ratio)

            fake_df = fake_df.sample(n=fake_sample_size, random_state=seed).reset_index(drop=True)
            real_df = real_df.sample(n=real_sample_size, random_state=seed).reset_index(drop=True)

            print(f"  Fake images after ratio: {len(fake_df):,} ({data_ratio*100:.1f}%)")
            print(f"  Real images after ratio: {len(real_df):,} ({data_ratio*100:.1f}%)")

        # Store original counts for class weighting (BEFORE balancing)
        self.num_real = len(real_df)
        self.num_fake = len(fake_df)

        # Apply balancing strategy
        if balance_method == 'undersampling':
            min_count = min(len(fake_df), len(real_df))
            print(f"\n‚öñÔ∏è  Undersampling to {min_count:,} samples each...")

            fake_df = fake_df.sample(n=min_count, random_state=seed).reset_index(drop=True)
            real_df = real_df.sample(n=min_count, random_state=seed).reset_index(drop=True)

            print(f"  Fake images after undersampling: {len(fake_df):,}")
            print(f"  Real images after undersampling: {len(real_df):,}")
            print(f"  ‚ö†Ô∏è  Data loss: {self.num_fake - len(fake_df):,} fake images dropped")

        elif balance_method == 'loss_weighting':
            print(f"\n‚öñÔ∏è  Using loss weighting (keeping all data)")
            print(f"  Real: {len(real_df):,} | Fake: {len(fake_df):,}")
            print(f"  ‚úÖ No data loss - imbalance handled by loss weights")

        else:  # 'none'
            print(f"\n‚ö†Ô∏è  No balancing applied")
            print(f"  Real: {len(real_df):,} | Fake: {len(fake_df):,}")
            # The ratio is undefined without real images; skip it instead of dividing by zero.
            if len(real_df) > 0:
                print(f"  Imbalance ratio: {len(fake_df)/len(real_df):.2f}x more fakes")

        # Prepare data list of (image_path, label, metadata) tuples
        self.data = []

        # Add fake images (label = 1)
        for idx, row in fake_df.iterrows():
            metadata = {
                'label_name': 'fake',
                'perturbed_img_id': row.get('perturbed_img_id', ''),
                'real_img_id': row.get('real_img_id', ''),
                'mask_name': row.get('mask_name', ''),
                'domain': row.get('domain', ''),
                'ssim': row.get('ssim', 0.0),
                'lpips_score': row.get('lpips_score', 0.0),
                'mse': row.get('mse', 0.0),
                'model_name': row.get('model_name', ''),
                'dataset': row.get('dataset', ''),
                'area_ratio': row.get('area_ratio', 0.0),
                'sem_magnitude': row.get('sem_magnitude', 0.0)
            }
            self.data.append((row['fake_img_path'], 1, metadata))

        # Add real images (label = 0)
        for idx, row in real_df.iterrows():
            metadata = {
                'label_name': 'real',
                'real_img_id': row.get('real_img_id', ''),
                'parent_dataset': row.get('parent_dataset', '')
            }
            self.data.append((row['real_img_path'], 0, metadata))

        total_images = len(self.data)
        print(f"\nüìä Final dataset size: {total_images:,} images")
        # Percentages are meaningless (and divide by zero) for an empty dataset.
        if total_images > 0:
            print(f"  Real: {len(real_df):,} ({len(real_df)/total_images*100:.1f}%)")
            print(f"  Fake: {len(fake_df):,} ({len(fake_df)/total_images*100:.1f}%)")

    def __len__(self):
        """Number of (image, label, metadata) samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            (image, label, metadata): the (optionally transformed) RGB image,
            the integer label (0 = real, 1 = fake) and the metadata dict.
            An unreadable image is replaced by a black 224x224 placeholder so
            that a single bad file does not abort an epoch.
        """
        img_path, label, metadata = self.data[idx]

        # Load image
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image as fallback
            image = Image.new('RGB', (224, 224), (0, 0, 0))

        # Apply transforms
        if self.transform:
            image = self.transform(image)

        return image, label, metadata

    # This method had been dedented out of the class, so calling
    # `dataset.get_class_weights()` could not resolve; it belongs here.
    def get_class_weights(self):
        """
        Calculate class weights using direct ratio (most aggressive).

        The counts used are the ones recorded BEFORE any undersampling.

        Returns:
            torch.FloatTensor: Weights for [real, fake] classes

        Raises:
            ValueError: if either class has no samples (weights are undefined).
        """
        if self.num_real == 0 or self.num_fake == 0:
            raise ValueError(
                f"Cannot compute class weights with an empty class "
                f"(real={self.num_real}, fake={self.num_fake})"
            )

        imbalance_ratio = self.num_fake / self.num_real
        print(f"\n‚öñÔ∏è  Class Imbalance: {imbalance_ratio:.2f}x more fakes")

        # Calculate weights
        total = self.num_real + self.num_fake
        weight_real_v1 = total / (2.0 * self.num_real)  # Standard (printed for comparison only)
        weight_fake_v1 = total / (2.0 * self.num_fake)
        weight_real_v3 = imbalance_ratio  # Direct ratio (USED)
        weight_fake_v3 = 1.0

        print(f"  Standard: Real={weight_real_v1:.4f}, Fake={weight_fake_v1:.4f}")
        print(f"  Direct:   Real={weight_real_v3:.4f}, Fake={weight_fake_v3:.4f} ‚Üê USING THIS")
        return torch.FloatTensor([weight_real_v3, weight_fake_v3])


print("RealFakeCSVDataset class defined successfully!")

In [5]:
# Cell 6: Data Transforms

# Pieces shared by every pipeline: target size and ImageNet channel statistics.
target_size = (IMAGE_SIZE, IMAGE_SIZE)
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Augmentation is applied to the training pipeline only, between resize and tensor conversion.
if USE_AUGMENTATION:
    augmentation_steps = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    ]
    augmentation_note = "WITH"
else:
    augmentation_steps = []
    augmentation_note = "WITHOUT"

train_transform = transforms.Compose(
    [transforms.Resize(target_size)]
    + augmentation_steps
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)
print(f"Training transforms: {augmentation_note} augmentation")

# Validation and test share one deterministic pipeline (resize, tensor, normalize).
eval_transform = transforms.Compose([
    transforms.Resize(target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

print("Validation/Test transforms: Basic resize + normalize")
print(f"Image size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print("Normalization: ImageNet statistics")

Training transforms: WITH augmentation
Validation/Test transforms: Basic resize + normalize
Image size: 224x224
Normalization: ImageNet statistics


In [None]:
# Cell 7: Load Datasets and Create DataLoaders
# The stray "Loss Function Type" assignments that had been pasted into the middle of
# the constructor calls (a SyntaxError) are removed; the loss settings are configured
# once, in the hyperparameter cell.
print("="*60)
print("LOADING DATASETS")
print("="*60)

# Training dataset
print("\nTRAINING SET:")
train_dataset = RealFakeCSVDataset(
    fake_csv_path=FAKE_TRAIN_CSV,
    real_csv_path=REAL_TRAIN_CSV,
    transform=train_transform,
    balance_method=CLASS_BALANCE_METHOD,
    data_ratio=TRAIN_DATA_RATIO,
    seed=RANDOM_SEED
)

# Validation dataset
print("\n" + "="*60)
print("VALIDATION SET:")
val_dataset = RealFakeCSVDataset(
    fake_csv_path=FAKE_VAL_CSV,
    real_csv_path=REAL_VAL_CSV,
    transform=eval_transform,
    balance_method=CLASS_BALANCE_METHOD,
    data_ratio=VAL_DATA_RATIO,
    seed=RANDOM_SEED
)

# Test dataset
print("\n" + "="*60)
print("TEST SET:")
test_dataset = RealFakeCSVDataset(
    fake_csv_path=FAKE_TEST_CSV,
    real_csv_path=REAL_TEST_CSV,
    transform=eval_transform,
    balance_method=CLASS_BALANCE_METHOD,
    data_ratio=TEST_DATA_RATIO,
    seed=RANDOM_SEED
)

# Calculate class weights if using loss weighting.
# `class_weights` is read by the loss-configuration cell; it is None for other strategies.
if CLASS_BALANCE_METHOD == 'loss_weighting':
    class_weights = train_dataset.get_class_weights()
    print("\n" + "="*60)
    print("CLASS WEIGHTS FOR LOSS FUNCTION")
    print("="*60)
    print(f"Real (class 0) weight: {class_weights[0]:.4f}")
    print(f"Fake (class 1) weight: {class_weights[1]:.4f}")
    print(f"\nInterpretation:")
    print(f"  - Real images get {class_weights[0]:.2f}x weight in loss")
    print(f"  - Fake images get {class_weights[1]:.2f}x weight in loss")
    print(f"  - Higher weight for minority class compensates for imbalance")
    print("="*60)
else:
    class_weights = None


# Custom collate function to handle metadata as list instead of dict batching
def custom_collate_fn(batch):
    """
    Custom collate function that keeps metadata as a list of dicts
    instead of trying to batch them into a single dict.

    Args:
        batch: List of (image, label, metadata_dict) tuples

    Returns:
        images: Batched tensor of images
        labels: Batched tensor of labels
        metadata: List of metadata dicts (NOT batched)
    """
    images = torch.stack([item[0] for item in batch])
    labels = torch.tensor([item[1] for item in batch])
    metadata = [item[2] for item in batch]  # Keep as list!

    return images, labels, metadata


# Create DataLoaders with custom collate function.
# Only the training loader shuffles; pinned memory is enabled whenever CUDA is present.
use_pinned_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory,
    collate_fn=custom_collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory,
    collate_fn=custom_collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory,
    collate_fn=custom_collate_fn
)

print("\n" + "="*60)
print("DATALOADER SUMMARY")
print("="*60)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Num workers: {NUM_WORKERS}")
print(f"Custom collate: Enabled (metadata as list)")
print("="*60)

In [None]:
# Cell 8: Model Definition

class RealFakeModel(nn.Module):
    """
    ResNet50-based binary classifier for real vs fake image detection.

    Architecture:
        - Backbone: ResNet50 pretrained on ImageNet
        - Classifier: Single FC layer (2048 -> 2 classes)
    """

    def __init__(self, pretrained=True):
        """
        Args:
            pretrained: When True, initialise the backbone with ImageNet weights;
                when False, use randomly initialised weights.
        """
        super().__init__()

        # Load ResNet50. The boolean `pretrained=` argument is deprecated in
        # torchvision >= 0.13 in favour of `weights=`; IMAGENET1K_V1 is the weight
        # set that `pretrained=True` selected. Older torchvision releases have no
        # weights enum, so fall back to the legacy argument there.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            backbone_weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.resnet = models.resnet50(weights=backbone_weights)
        else:
            self.resnet = models.resnet50(pretrained=pretrained)

        # Get number of features from the last layer
        num_features = self.resnet.fc.in_features

        # Replace final FC layer for binary classification
        self.resnet.fc = nn.Linear(num_features, 2)

    def forward(self, x):
        """Return the raw 2-class logits produced by the backbone for batch `x`."""
        return self.resnet(x)

# Initialize model
model = RealFakeModel(pretrained=True).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("="*60)
print("MODEL ARCHITECTURE")
print("="*60)
print(f"Model: ResNet50 (pretrained on ImageNet)")
print(f"Output classes: 2 (Real, Fake)")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Device: {device}")
print("="*60)

# Variables to store resume info (read again by the optimizer and training cells)
resume_checkpoint_path = None
start_epoch = 0
resume_best_val_acc = 0.0

# Check if we should resume from checkpoint.
# Every failure path below flips RESUME_FROM_CHECKPOINT back to False so that the
# rest of the notebook falls through to a fresh start.
if RESUME_FROM_CHECKPOINT:
    print("\n" + "="*60)
    print("LOADING CHECKPOINT")
    print("="*60)

    # Determine which version to load from
    # NOTE(review): this directory name must match what create_training_folder()
    # writes to - confirm against utils (the non-ASCII name may be mis-encoded).
    base_results_dir = f"eƒüitim_sonu√ßlarƒ±/{NOTEBOOK_NAME}"

    if RESUME_VERSION is None:
        # Auto-detect latest version.
        # Check the "models" sub-directory itself: it is the one that gets listed,
        # and checking only the parent crashed when the parent existed without it.
        models_root = os.path.join(base_results_dir, "models")
        if os.path.isdir(models_root):
            versions = []
            for item in os.listdir(models_root):
                if item.startswith("v") and os.path.isdir(os.path.join(models_root, item)):
                    try:
                        versions.append(int(item[1:]))
                    except ValueError:
                        # Folder such as "vfinal" is not a numbered version; ignore it.
                        continue

            if versions:
                latest_version = max(versions)
                resume_checkpoint_path = os.path.join(base_results_dir, f"models/v{latest_version}/best_model.pth")
                print(f"üîç Auto-detected latest version: v{latest_version}")
            else:
                print("‚ùå No previous versions found. Starting fresh training.")
                RESUME_FROM_CHECKPOINT = False
        else:
            print("‚ùå No previous training directory found. Starting fresh training.")
            RESUME_FROM_CHECKPOINT = False
    else:
        # Use specified version
        resume_checkpoint_path = os.path.join(base_results_dir, f"models/v{RESUME_VERSION}/best_model.pth")
        print(f"üìÇ Using specified version: v{RESUME_VERSION}")

    # Load checkpoint if path exists
    if RESUME_FROM_CHECKPOINT and resume_checkpoint_path and os.path.exists(resume_checkpoint_path):
        print(f"üì• Loading checkpoint from: {resume_checkpoint_path}")

        try:
            checkpoint = torch.load(resume_checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])

            # Continue with the epoch after the one stored in the checkpoint.
            start_epoch = checkpoint.get('epoch', 0) + 1
            resume_best_val_acc = checkpoint.get('val_acc', 0.0)

            print(f"‚úÖ Checkpoint loaded successfully!")
            print(f"   Previous epoch: {checkpoint.get('epoch', 0)}")
            print(f"   Previous val_acc: {resume_best_val_acc:.4f}")
            print(f"   Previous val_loss: {checkpoint.get('val_loss', 0.0):.4f}")
            print(f"   Will continue from epoch: {start_epoch + 1}")

        except Exception as e:
            print(f"‚ùå Error loading checkpoint: {e}")
            print("   Starting fresh training instead.")
            RESUME_FROM_CHECKPOINT = False
            start_epoch = 0
            resume_best_val_acc = 0.0

    elif RESUME_FROM_CHECKPOINT:
        print(f"‚ùå Checkpoint not found at: {resume_checkpoint_path}")
        print("   Starting fresh training instead.")
        RESUME_FROM_CHECKPOINT = False
        start_epoch = 0
        resume_best_val_acc = 0.0

    print("="*60)

if not RESUME_FROM_CHECKPOINT:
    print(f"\nüÜï Starting fresh training from ImageNet pretrained weights")

In [None]:
# Cell 8.5: Focal Loss Implementation (Optional - Better for Class Imbalance)

class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance.

    Formula: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    where:
        - p_t is the model's estimated probability for the correct class
        - alpha_t is the class weight (like in weighted CE)
        - gamma is the focusing parameter (default=2.0)
            - gamma=0: Focal Loss = CrossEntropyLoss
            - gamma>0: Reduces loss for well-classified examples

    Key Advantage over Weighted CrossEntropy:
        - Weighted CE treats all examples equally regardless of difficulty
        - Focal Loss down-weights easy examples and focuses on hard negatives
        - For severe imbalance, this prevents the model from being
          overwhelmed by easy majority-class examples

    Reference: https://arxiv.org/abs/1708.02002
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        """
        Args:
            alpha: Per-class weights [weight_class0, weight_class1, ...], given as a
                   tensor or any sequence of numbers. If None, all classes are
                   weighted equally.
            gamma: Focusing parameter. Higher = more focus on hard examples
            reduction: 'mean', 'sum', or 'none'
        """
        super().__init__()
        if alpha is not None and not torch.is_tensor(alpha):
            # Accept plain lists/tuples as well as tensors.
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        # Registered as a (non-persistent) buffer so the weights follow
        # `criterion.to(device)` / `.cuda()`; `self.alpha` still reads as before.
        self.register_buffer('alpha', alpha, persistent=False)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Model logits [batch_size, num_classes]
            targets: Ground truth labels [batch_size]

        Returns:
            Scalar loss for 'mean'/'sum'; per-sample losses [batch_size] otherwise.
        """
        # Per-sample cross entropy, i.e. -log(p_t)
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # Probability of the correct class, recovered from the CE value instead of
        # running a second softmax over the logits.
        p_t = torch.exp(-ce_loss)

        # Focal term (1 - p_t)^gamma down-weights well-classified samples
        focal_loss = (1 - p_t) ** self.gamma * ce_loss

        # Apply class weights if provided
        if self.alpha is not None:
            # Align devices so CPU weights also work with GPU logits.
            alpha_t = self.alpha.to(inputs.device).gather(0, targets)
            focal_loss = alpha_t * focal_loss

        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


print("FocalLoss class defined successfully!")
print("\nFocal Loss advantages for class imbalance:")
print("  1. Down-weights easy examples (high confidence)")
print("  2. Focuses training on hard examples (low confidence)")
print("  3. Prevents majority class from dominating training")
print("  4. Combines class weighting (alpha) with difficulty weighting (gamma)")

In [None]:
# Cell 9: Loss, Optimizer, Scheduler, and History

# Loss function selection: CrossEntropy vs Focal Loss.
# In both cases the class weights computed from the training set are applied only
# when the 'loss_weighting' strategy is active and weights are available.
print("="*60)
print("LOSS FUNCTION CONFIGURATION")
print("="*60)
print(f"Loss type: {LOSS_TYPE}")
print(f"Balance method: {CLASS_BALANCE_METHOD}")

if LOSS_TYPE == 'focal':
    # Use Focal Loss (better for severe class imbalance)
    if CLASS_BALANCE_METHOD == 'loss_weighting' and class_weights is not None:
        class_weights_gpu = class_weights.to(device)
        criterion = FocalLoss(alpha=class_weights_gpu, gamma=FOCAL_GAMMA)
        print(f"\n‚úÖ Using FOCAL LOSS with class weights")
        print(f"  Alpha (class weights): Real={class_weights[0]:.4f}, Fake={class_weights[1]:.4f}")
        print(f"  Gamma (focusing param): {FOCAL_GAMMA}")
        print(f"\n  How Focal Loss works:")
        print(f"    1. Easy examples (high confidence) ‚Üí low loss")
        print(f"    2. Hard examples (low confidence) ‚Üí high loss")
        print(f"    3. Prevents majority class easy examples from dominating")
        print(f"    4. Ideal for severe class imbalance (4-5x or more)")
    else:
        criterion = FocalLoss(alpha=None, gamma=FOCAL_GAMMA)
        print(f"\n‚úÖ Using FOCAL LOSS without class weights")
        print(f"  Gamma: {FOCAL_GAMMA}")
else:
    # Use standard CrossEntropy Loss
    if CLASS_BALANCE_METHOD == 'loss_weighting' and class_weights is not None:
        class_weights_gpu = class_weights.to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights_gpu)
        print(f"\n‚úÖ Using CrossEntropyLoss WITH class weights")
        print(f"  Real (class 0) weight: {class_weights[0]:.4f}")
        print(f"  Fake (class 1) weight: {class_weights[1]:.4f}")
        print(f"\n  How it works:")
        print(f"    - Loss for each sample multiplied by its class weight")
        print(f"    - Minority class (real) errors penalized more heavily")
        print(f"    - Prevents model bias toward majority class")
    else:
        criterion = nn.CrossEntropyLoss()
        print(f"\n‚úÖ Using CrossEntropyLoss (no class weights)")
        print(f"  ‚ö†Ô∏è  No imbalance handling - may bias toward majority class")

print("="*60)

# Optimizer
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# Load optimizer state if resuming from checkpoint.
# Best effort: a missing or incompatible optimizer state only costs the saved
# momentum statistics, so failures are reported and training continues.
if RESUME_FROM_CHECKPOINT and resume_checkpoint_path and os.path.exists(resume_checkpoint_path):
    try:
        checkpoint = torch.load(resume_checkpoint_path, map_location=device)
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(f"\n‚úÖ Optimizer state loaded from checkpoint")
    except Exception as e:
        print(f"\n‚ö†Ô∏è  Could not load optimizer state: {e}")
        print("   Optimizer will start with fresh state")

# Learning rate scheduler.
# mode='max' because the monitored quantity is validation accuracy.
# `verbose` is intentionally not passed: the argument is deprecated in recent
# PyTorch releases. Read the current LR from optimizer.param_groups when logging.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    patience=SCHEDULER_PATIENCE,
    factor=SCHEDULER_FACTOR
)

# Training history (one entry appended per epoch by the training loop)
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'lr': []
}

print("\n" + "="*60)
print("TRAINING CONFIGURATION")
print("="*60)
print(f"Optimizer: Adam")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print(f"Scheduler: ReduceLROnPlateau")
print(f"  Mode: max (validation accuracy)")
print(f"  Patience: {SCHEDULER_PATIENCE} epochs")
print(f"  Factor: {SCHEDULER_FACTOR}")
print(f"Class balancing: {CLASS_BALANCE_METHOD}")
if RESUME_FROM_CHECKPOINT:
    print(f"Resume training: YES (starting from epoch {start_epoch + 1})")
    print(f"Previous best val_acc: {resume_best_val_acc:.4f}")
else:
    print(f"Resume training: NO (fresh start)")
print("="*60)

In [None]:
# Cell 10: GPU Monitoring Helper Functions

def get_gpu_stats():
    """
    Get GPU memory and utilization statistics.

    Reads the module-level ``GPU_MONITORING`` flag and, when it is set,
    queries NVML through ``pynvml`` using the module-level ``gpu_handle``.

    Returns:
        tuple: ``(memory_utilization%, gpu_utilization%)``. ``(0.0, 0.0)`` is
        returned when monitoring is disabled or the NVML query fails.
    """
    if not GPU_MONITORING:
        return 0.0, 0.0

    try:
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle)
        utilization = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)

        mem_util = (mem_info.used / mem_info.total) * 100
        gpu_util = utilization.gpu

        return mem_util, gpu_util
    except Exception:
        # Best-effort: a monitoring failure must never interrupt training.
        # Catch Exception rather than using a bare `except` so that
        # KeyboardInterrupt / SystemExit still propagate — this function is
        # called once per training batch.
        return 0.0, 0.0

def format_time(seconds):
    """
    Format a duration in seconds as HH:MM:SS.

    Args:
        seconds: Duration in seconds (int or float). Fractional seconds are
            truncated. Negative values are clamped to zero so that a clock
            adjustment cannot produce output such as "-1:59:55".

    Returns:
        str: Zero-padded "HH:MM:SS". The hours field grows past two digits
        for durations of 100 hours or more.
    """
    total_seconds = max(int(seconds), 0)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

print("GPU monitoring helper functions defined successfully!")

In [None]:
# Cell 11: Training Loop with Enhanced Progress Bar
#
# Trains and validates the model for epochs start_epoch..NUM_EPOCHS-1.
# Side effects: writes best_model.pth (full checkpoint dict) and
# model_epoch_<n>.pth (weights only) into models_dir, fills `history`, and
# steps the LR scheduler on validation accuracy.

print("="*60)
print("STARTING TRAINING")
print("="*60)

# CRITICAL: Verify model and data are on GPU (the asserts below abort the
# cell when the model or a batch is not on a CUDA device)
print(f"\nüîç GPU VERIFICATION BEFORE TRAINING:")
print(f"  Model device: {next(model.parameters()).device}")
print(f"  Expected device: {device}")
assert next(model.parameters()).device.type == 'cuda', "‚ùå Model is NOT on GPU!"
print(f"  ‚úÖ Model confirmed on GPU")

# Test a batch to verify GPU is being used
print(f"\nüîç Testing GPU with first batch...")
first_batch = next(iter(train_loader))
test_images, test_labels, _ = first_batch
test_images = test_images.to(device)
test_labels = test_labels.to(device)
print(f"  Batch images device: {test_images.device}")
print(f"  Batch labels device: {test_labels.device}")
assert test_images.device.type == 'cuda', "‚ùå Data is NOT on GPU!"
print(f"  ‚úÖ Data confirmed on GPU")
print(f"  ‚úÖ GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
print(f"\n{'='*60}")

# Best model tracking - use resume value if resuming
if RESUME_FROM_CHECKPOINT and resume_best_val_acc > 0:
    best_val_acc = resume_best_val_acc
    best_epoch = start_epoch
    print(f"\nüìÇ Resuming from checkpoint:")
    print(f"   Starting best_val_acc: {best_val_acc:.4f}")
    print(f"   Starting from epoch: {start_epoch + 1}")
else:
    best_val_acc = 0.0
    best_epoch = 0
    print(f"\nüÜï Starting fresh training")

# Training start time
training_start_time = time.time()

for epoch in range(start_epoch, NUM_EPOCHS):
    epoch_start_time = time.time()

    # Learning rate in effect for this epoch (the scheduler is only stepped
    # at the end of the epoch)
    current_lr = optimizer.param_groups[0]['lr']

    print(f"\nEpoch [{epoch+1}/{NUM_EPOCHS}] - LR: {current_lr:.6f}")
    print("-" * 60)

    # ==================== TRAINING PHASE ====================
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    # Create progress bar for training
    num_train_batches = len(train_loader)
    train_pbar = tqdm(train_loader, desc="Training", unit="batch")

    for batch_idx, (images, labels, metadata) in enumerate(train_pbar):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Calculate accuracy (argmax over the class dimension)
        predicted = outputs.argmax(dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        train_loss += loss.item()

        # Running metrics: mean of the per-batch losses seen so far
        avg_loss = train_loss / (batch_idx + 1)
        current_acc = train_correct / train_total

        # Get GPU stats
        mem_util, gpu_util = get_gpu_stats()

        # Estimate time remaining from the average batch time this epoch
        batches_done = batch_idx + 1
        batches_left = num_train_batches - batches_done
        batch_time = (time.time() - epoch_start_time) / batches_done
        eta_seconds = batch_time * batches_left

        # Update progress bar
        train_pbar.set_postfix({
            'Loss': f'{avg_loss:.4f}',
            'Acc': f'{current_acc:.3f}',
            'LR': f'{current_lr:.6f}',
            'GPU_Mem': f'{mem_util:.0f}%',
            'GPU_Use': f'{gpu_util:.0f}%',
            'ETA': format_time(eta_seconds)
        })

    # Calculate epoch training metrics
    epoch_train_loss = train_loss / num_train_batches
    epoch_train_acc = train_correct / train_total

    # ==================== VALIDATION PHASE ====================
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    val_pbar = tqdm(val_loader, desc="Validation", unit="batch")

    with torch.no_grad():
        for batch_idx, (images, labels, metadata) in enumerate(val_pbar):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Calculate accuracy
            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
            val_loss += loss.item()

            # Calculate running metrics
            avg_loss = val_loss / (batch_idx + 1)
            current_acc = val_correct / val_total

            # Update progress bar
            val_pbar.set_postfix({
                'Loss': f'{avg_loss:.4f}',
                'Acc': f'{current_acc:.3f}'
            })

    # Calculate epoch validation metrics
    epoch_val_loss = val_loss / len(val_loader)
    epoch_val_acc = val_correct / val_total

    # Record the whole epoch in one step so every list in `history` always
    # has the same length, even if the run is interrupted part-way through
    # an epoch (later cells build a DataFrame from these lists).
    history['lr'].append(current_lr)
    history['train_loss'].append(epoch_train_loss)
    history['train_acc'].append(epoch_train_acc)
    history['val_loss'].append(epoch_val_loss)
    history['val_acc'].append(epoch_val_acc)

    # Print epoch summary
    epoch_time = time.time() - epoch_start_time
    print(f"\nEpoch [{epoch+1}/{NUM_EPOCHS}] Summary:")
    print(f"  Train Loss: {epoch_train_loss:.4f} | Train Acc: {epoch_train_acc:.4f}")
    print(f"  Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")
    print(f"  Time: {format_time(epoch_time)}")
    print(f"  GPU Memory: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

    # Save best model (strict improvement in validation accuracy only)
    if epoch_val_acc > best_val_acc:
        best_val_acc = epoch_val_acc
        best_epoch = epoch + 1  # 1-based, matches the printed epoch number

        # Save best model checkpoint ('epoch' is stored 0-based)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': epoch_val_acc,
            'val_loss': epoch_val_loss
        }, os.path.join(models_dir, 'best_model.pth'))

        print(f"  ‚úÖ NEW BEST MODEL SAVED! Val Acc: {epoch_val_acc:.4f}")

    # Always save epoch checkpoint (model weights only, no optimizer state)
    torch.save(
        model.state_dict(),
        os.path.join(models_dir, f'model_epoch_{epoch+1}.pth')
    )

    # Update learning rate scheduler (driven by validation accuracy)
    scheduler.step(epoch_val_acc)

    print("-" * 60)

# Training complete
total_training_time = time.time() - training_start_time

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)
print(f"Total training time: {format_time(total_training_time)}")
print(f"Best validation accuracy: {best_val_acc:.4f} (Epoch {best_epoch})")
print(f"Best model saved at: {os.path.join(models_dir, 'best_model.pth')}")
print(f"Final GPU Memory: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
print("="*60)

In [None]:
# Cell 12: Save Training History

# Number of epochs recorded in `history` during this session
actual_epochs_trained = len(history['train_loss'])

# Epoch number of the first recorded row: numbering continues from the
# checkpoint when resuming, otherwise it starts at 1
epoch_start_num = start_epoch + 1 if RESUME_FROM_CHECKPOINT else 1

# Output column name -> key in the `history` dict, in output column order
history_columns = {
    'train_loss': 'train_loss',
    'train_acc': 'train_acc',
    'val_loss': 'val_loss',
    'val_acc': 'val_acc',
    'learning_rate': 'lr',
}

history_df = pd.DataFrame({
    'epoch': list(range(epoch_start_num, epoch_start_num + actual_epochs_trained)),
    **{column: history[key] for column, key in history_columns.items()},
})

# Persist the per-epoch table next to the other run outputs
history_path = os.path.join(data_dir, 'training_history.csv')
history_df.to_csv(history_path, index=False)

print(f"Training history saved to: {history_path}")
print("\nHistory summary:")
print(f"Epochs trained in this session: {actual_epochs_trained}")
if RESUME_FROM_CHECKPOINT:
    print(f"Resumed from epoch: {start_epoch + 1}")
print("\nHistory data:")
print(history_df)

In [None]:
# Cell 13: Plot Training Curves

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# X axis: the epoch numbers actually recorded in `history`
epochs = list(range(epoch_start_num, epoch_start_num + actual_epochs_trained))

# One row per panel:
# (axis, train key, val key, train legend, val legend, y label, title, mark best accuracy?)
curve_panels = [
    (axes[0], 'train_loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Loss', 'Training and Validation Loss', False),
    (axes[1], 'train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy',
     'Accuracy', 'Training and Validation Accuracy', True),
]

for (panel_ax, train_key, val_key, train_legend, val_legend,
     y_label, panel_title, mark_best_acc) in curve_panels:
    panel_ax.plot(epochs, history[train_key], 'b-', label=train_legend, linewidth=2)
    panel_ax.plot(epochs, history[val_key], 'r-', label=val_legend, linewidth=2)
    # Vertical marker at the epoch with the best validation accuracy
    panel_ax.axvline(x=best_epoch, color='red', linestyle='--', linewidth=2, alpha=0.7,
                     label=f'Best Epoch ({best_epoch})')
    if mark_best_acc:
        # Accuracy panel only: horizontal line and text at the best value
        panel_ax.axhline(y=best_val_acc, color='green', linestyle=':', linewidth=1.5, alpha=0.5)
        panel_ax.text(best_epoch + 0.5, best_val_acc, f'Best: {best_val_acc:.4f}',
                      fontsize=10, color='green', fontweight='bold')
    panel_ax.set_xlabel('Epoch', fontsize=12)
    panel_ax.set_ylabel(y_label, fontsize=12)
    panel_ax.set_title(panel_title, fontsize=14, fontweight='bold')
    panel_ax.legend(fontsize=10)
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()

# Write the figure to disk before showing it
plot_path = os.path.join(viz_dir, 'training_curves.png')
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Training curves saved to: {plot_path}")
if RESUME_FROM_CHECKPOINT:
    print(f"Note: Plot shows epochs {epoch_start_num} to {epoch_start_num + actual_epochs_trained - 1} (resumed training)")

In [None]:
# Cell 13.5: Comprehensive Training Visualization
#
# Builds one 3x3 figure from `history`: six curve/bar panels, a per-epoch
# table spanning two grid cells, and a key-metrics table. The figure is saved
# to viz_dir as comprehensive_training_analysis.png and then shown.
# Relies on names set by earlier cells: history, epoch_start_num,
# actual_epochs_trained, best_epoch, best_val_acc, version, viz_dir.

fig = plt.figure(figsize=(20, 12))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# 1. Loss Curves (Top Left)
ax1 = fig.add_subplot(gs[0, 0])
# Epoch numbers on the x axis, shared by every panel in this figure
epochs_plot = list(range(epoch_start_num, epoch_start_num + actual_epochs_trained))
ax1.plot(epochs_plot, history['train_loss'], 'b-o', label='Train Loss', linewidth=2, markersize=6)
ax1.plot(epochs_plot, history['val_loss'], 'r-s', label='Val Loss', linewidth=2, markersize=6)
ax1.axvline(x=best_epoch, color='green', linestyle='--', linewidth=2, alpha=0.5, label=f'Best Epoch ({best_epoch})')
ax1.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax1.set_ylabel('Loss', fontsize=11, fontweight='bold')
ax1.set_title('Training & Validation Loss', fontsize=13, fontweight='bold')
ax1.legend(fontsize=9)
ax1.grid(True, alpha=0.3)

# 2. Accuracy Curves (Top Middle)
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(epochs_plot, history['train_acc'], 'b-o', label='Train Acc', linewidth=2, markersize=6)
ax2.plot(epochs_plot, history['val_acc'], 'r-s', label='Val Acc', linewidth=2, markersize=6)
ax2.axvline(x=best_epoch, color='green', linestyle='--', linewidth=2, alpha=0.5, label=f'Best Epoch ({best_epoch})')
# Horizontal reference line at the best validation accuracy
ax2.axhline(y=best_val_acc, color='orange', linestyle=':', linewidth=1.5, alpha=0.7)
ax2.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax2.set_ylabel('Accuracy', fontsize=11, fontweight='bold')
ax2.set_title('Training & Validation Accuracy', fontsize=13, fontweight='bold')
ax2.legend(fontsize=9)
ax2.grid(True, alpha=0.3)

# 3. Learning Rate Schedule (Top Right)
ax3 = fig.add_subplot(gs[0, 2])
ax3.plot(epochs_plot, history['lr'], 'g-o', linewidth=2, markersize=6)
ax3.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax3.set_ylabel('Learning Rate', fontsize=11, fontweight='bold')
ax3.set_title('Learning Rate Schedule', fontsize=13, fontweight='bold')
# Log scale on the y axis for the learning-rate values
ax3.set_yscale('log')
ax3.grid(True, alpha=0.3)

# 4. Loss Gap (Train - Val) (Middle Left)
ax4 = fig.add_subplot(gs[1, 0])
loss_gap = [t - v for t, v in zip(history['train_loss'], history['val_loss'])]
# Green where train loss is below val loss (negative gap), red otherwise
colors = ['green' if gap < 0 else 'red' for gap in loss_gap]
ax4.bar(epochs_plot, loss_gap, color=colors, alpha=0.6, edgecolor='black')
ax4.axhline(y=0, color='black', linestyle='-', linewidth=1)
ax4.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax4.set_ylabel('Train Loss - Val Loss', fontsize=11, fontweight='bold')
ax4.set_title('Overfitting Monitor (Loss Gap)', fontsize=13, fontweight='bold')
ax4.grid(True, alpha=0.3, axis='y')

# 5. Accuracy Gap (Val - Train) (Middle Middle)
ax5 = fig.add_subplot(gs[1, 1])
acc_gap = [v - t for t, v in zip(history['train_acc'], history['val_acc'])]
# Green where validation accuracy exceeds training accuracy, red otherwise
colors = ['green' if gap > 0 else 'red' for gap in acc_gap]
ax5.bar(epochs_plot, acc_gap, color=colors, alpha=0.6, edgecolor='black')
ax5.axhline(y=0, color='black', linestyle='-', linewidth=1)
ax5.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax5.set_ylabel('Val Acc - Train Acc', fontsize=11, fontweight='bold')
ax5.set_title('Generalization Gap', fontsize=13, fontweight='bold')
ax5.grid(True, alpha=0.3, axis='y')

# 6. Epoch Improvement (Middle Right)
ax6 = fig.add_subplot(gs[1, 2])
# Change in validation accuracy versus the previous recorded epoch; the first
# recorded epoch has no predecessor in `history`, so its bar is fixed at 0
val_improvements = [0] + [history['val_acc'][i] - history['val_acc'][i-1] for i in range(1, len(history['val_acc']))]
colors = ['green' if imp > 0 else 'red' for imp in val_improvements]
ax6.bar(epochs_plot, val_improvements, color=colors, alpha=0.6, edgecolor='black')
ax6.axhline(y=0, color='black', linestyle='-', linewidth=1)
ax6.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax6.set_ylabel('Validation Acc Improvement', fontsize=11, fontweight='bold')
ax6.set_title('Per-Epoch Validation Improvement', fontsize=13, fontweight='bold')
ax6.grid(True, alpha=0.3, axis='y')

# 7. Training Progress Summary Table (Bottom Left, spans two grid columns)
ax7 = fig.add_subplot(gs[2, :2])
ax7.axis('off')

# One pre-formatted table row per recorded epoch
summary_data = []
for i, epoch in enumerate(epochs_plot):
    summary_data.append([
        epoch,
        f"{history['train_loss'][i]:.4f}",
        f"{history['train_acc'][i]:.4f}",
        f"{history['val_loss'][i]:.4f}",
        f"{history['val_acc'][i]:.4f}",
        f"{history['lr'][i]:.2e}",
        "‚úÖ BEST" if epoch == best_epoch else ""
    ])

table = ax7.table(cellText=summary_data,
                  colLabels=['Epoch', 'Train Loss', 'Train Acc', 'Val Loss', 'Val Acc', 'LR', 'Best'],
                  cellLoc='center',
                  loc='center',
                  bbox=[0, 0, 1, 1])
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1, 2)

# Style header (row 0 holds the 7 column labels)
for i in range(7):
    table[(0, i)].set_facecolor('#4CAF50')
    table[(0, i)].set_text_props(weight='bold', color='white')

# Highlight best epoch row (data rows start at table row 1, hence i+1)
for i, epoch in enumerate(epochs_plot):
    if epoch == best_epoch:
        for j in range(7):
            table[(i+1, j)].set_facecolor('#FFE082')

ax7.set_title('Detailed Training History', fontsize=13, fontweight='bold', pad=20)

# 8. Best Metrics Summary (Bottom Right)
ax8 = fig.add_subplot(gs[2, 2])
ax8.axis('off')

# Row 0 is the header, row 1 an empty spacer; "Final" values are the last
# entries in `history`
best_metrics = [
    ['Metric', 'Value'],
    ['', ''],
    ['Best Epoch', f'{best_epoch}'],
    ['Best Val Acc', f'{best_val_acc:.4f}'],
    ['Final Train Acc', f'{history["train_acc"][-1]:.4f}'],
    ['Final Val Acc', f'{history["val_acc"][-1]:.4f}'],
    ['Final Train Loss', f'{history["train_loss"][-1]:.4f}'],
    ['Final Val Loss', f'{history["val_loss"][-1]:.4f}'],
    ['Final LR', f'{history["lr"][-1]:.2e}'],
    ['Total Epochs', f'{actual_epochs_trained}']
]

table2 = ax8.table(cellText=best_metrics,
                   cellLoc='left',
                   loc='center',
                   bbox=[0, 0, 1, 1])
table2.auto_set_font_size(False)
table2.set_fontsize(10)
table2.scale(1, 2.5)

# Style header
table2[(0, 0)].set_facecolor('#2196F3')
table2[(0, 0)].set_text_props(weight='bold', color='white')
table2[(0, 1)].set_facecolor('#2196F3')
table2[(0, 1)].set_text_props(weight='bold', color='white')

# Bold metric names (skips the header and spacer rows)
for i in range(2, len(best_metrics)):
    table2[(i, 0)].set_text_props(weight='bold')

ax8.set_title('Best Model Metrics', fontsize=13, fontweight='bold', pad=20)

plt.suptitle(f'Comprehensive Training Analysis - Version {version}', 
             fontsize=16, fontweight='bold', y=0.995)

# Save
comprehensive_path = os.path.join(viz_dir, 'comprehensive_training_analysis.png')
plt.savefig(comprehensive_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Comprehensive training analysis saved to: {comprehensive_path}")

In [None]:
# Cell 14: Load Best Model and Run Test Inference
#
# Restores the weights stored in best_model.pth, runs the model over
# test_loader, and collects per-sample predictions, labels, softmax
# confidences and metadata for the analysis cells that follow.

print("="*60)
print("TEST EVALUATION")
print("="*60)

# Load best model
best_model_path = os.path.join(models_dir, 'best_model.pth')
print(f"\nLoading best model from epoch {best_epoch}")
print(f"Path: {best_model_path}")
print(f"Best validation accuracy: {best_val_acc:.4f}")

# map_location keeps the checkpoint tensors on the active device, matching
# how the resume checkpoint is loaded earlier in the notebook
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("\nRunning test inference...")

# Storage for predictions and metadata
all_predictions = []
all_labels = []
all_confidences = []
all_metadata = []

# Run inference
with torch.no_grad():
    test_pbar = tqdm(test_loader, desc="Test Inference", unit="batch")

    for images, labels, metadata_batch in test_pbar:
        images = images.to(device)

        # Forward pass; confidence is the softmax probability of the
        # predicted class
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        confidences, predictions = torch.max(probs, dim=1)

        # Store results (labels are never moved off the CPU here)
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.numpy())
        all_confidences.extend(confidences.cpu().numpy())
        # NOTE(review): assumes metadata_batch is a list of per-sample dicts
        # (the results-DataFrame cell calls .get() on each item) — confirm
        # against the DataLoader's collate_fn.
        all_metadata.extend(metadata_batch)

# Convert to numpy arrays
all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
all_confidences = np.array(all_confidences)

# Calculate test accuracy
test_accuracy = accuracy_score(all_labels, all_predictions)

print(f"\n{'='*60}")
print(f"TEST RESULTS")
print(f"{'='*60}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Total test samples: {len(all_labels)}")
print(f"Correct predictions: {(all_predictions == all_labels).sum()}")
print(f"Incorrect predictions: {(all_predictions != all_labels).sum()}")
print(f"{'='*60}")

In [None]:
# Cell 14.5: Train/Val/Test Comprehensive Comparison


def _collect_predictions(loader, desc):
    """Run the current `model` over `loader` and gather per-sample outputs.

    Args:
        loader: Iterable yielding (images, labels, metadata) batches.
        desc: Label shown on the tqdm progress bar.

    Returns:
        tuple of numpy arrays: (predicted class indices, true labels,
        softmax confidence of the predicted class), one entry per sample.
    """
    preds, targets, confs = [], [], []
    with torch.no_grad():
        for images, labels, _ in tqdm(loader, desc=desc):
            images = images.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            confidences, predictions = torch.max(probs, dim=1)

            preds.extend(predictions.cpu().numpy())
            targets.extend(labels.cpu().numpy())
            confs.extend(confidences.cpu().numpy())
    return np.array(preds), np.array(targets), np.array(confs)


# First, we need to get train and val predictions for comparison
print("Running inference on Train and Val sets for comprehensive comparison...")

model.eval()

# Train set inference
# NOTE(review): train_loader is reused as-is, so whatever shuffling or
# augmentation it is configured with also applies here — confirm intended.
train_preds, train_labels_list, train_confs = _collect_predictions(train_loader, "Train Inference")

# Val set inference
val_preds, val_labels_list, val_confs = _collect_predictions(val_loader, "Val Inference")

# Calculate metrics
train_acc = accuracy_score(train_labels_list, train_preds)
val_acc = accuracy_score(val_labels_list, val_preds)

# Real/Fake breakdown: per-class accuracy (label 0 = real, label 1 = fake)
train_real_acc = (train_preds[train_labels_list == 0] == 0).mean()
train_fake_acc = (train_preds[train_labels_list == 1] == 1).mean()

val_real_acc = (val_preds[val_labels_list == 0] == 0).mean()
val_fake_acc = (val_preds[val_labels_list == 1] == 1).mean()

# Test-set values come from the arrays built by the test-inference cell
test_real_acc = (all_predictions[all_labels == 0] == 0).mean()
test_fake_acc = (all_predictions[all_labels == 1] == 1).mean()

# Create comprehensive comparison plot
# 2x4 grid: top row = accuracy / per-class accuracy / sample counts /
# confidence histograms; bottom row = three confusion matrices and a summary
# table. Saved to viz_dir as train_val_test_comparison.png, then shown.
fig = plt.figure(figsize=(20, 10))
gs = fig.add_gridspec(2, 4, hspace=0.3, wspace=0.3)

# 1. Overall Accuracy Comparison (Top Left)
ax1 = fig.add_subplot(gs[0, 0])
sets = ['Train', 'Val', 'Test']
accuracies = [train_acc, val_acc, test_accuracy]
# One colour per split; reused by the sample-count panel below
colors = ['#2196F3', '#FF9800', '#4CAF50']
bars = ax1.bar(sets, accuracies, color=colors, alpha=0.7, edgecolor='black', linewidth=2)

# Print each accuracy just above its bar
for i, (bar, acc) in enumerate(zip(bars, accuracies)):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
             f'{acc:.4f}',
             ha='center', va='bottom', fontsize=12, fontweight='bold')

ax1.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
ax1.set_title('Overall Accuracy: Train vs Val vs Test', fontsize=13, fontweight='bold')
# Upper limit above 1.0 leaves room for the value labels
ax1.set_ylim([0, 1.1])
ax1.grid(True, alpha=0.3, axis='y')

# 2. Real vs Fake Accuracy by Set (Top Middle-Left)
ax2 = fig.add_subplot(gs[0, 1])
x = np.arange(3)
width = 0.35

real_accs = [train_real_acc, val_real_acc, test_real_acc]
fake_accs = [train_fake_acc, val_fake_acc, test_fake_acc]

# Grouped bars: real to the left of each tick, fake to the right
bars1 = ax2.bar(x - width/2, real_accs, width, label='Real', color='#03A9F4', alpha=0.8, edgecolor='black')
bars2 = ax2.bar(x + width/2, fake_accs, width, label='Fake', color='#FF5722', alpha=0.8, edgecolor='black')

# Add value labels
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                 f'{height:.3f}',
                 ha='center', va='bottom', fontsize=9, fontweight='bold')

ax2.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
ax2.set_title('Real vs Fake Detection Accuracy', fontsize=13, fontweight='bold')
ax2.set_xticks(x)
ax2.set_xticklabels(sets)
ax2.legend(fontsize=10)
ax2.set_ylim([0, 1.15])
ax2.grid(True, alpha=0.3, axis='y')

# 3. Sample Count Comparison (Top Middle-Right)
ax3 = fig.add_subplot(gs[0, 2])
sample_counts = [len(train_labels_list), len(val_labels_list), len(all_labels)]
bars = ax3.bar(sets, sample_counts, color=colors, alpha=0.7, edgecolor='black', linewidth=2)

# Count labels sit a fixed 100 units above each bar
for bar, count in zip(bars, sample_counts):
    height = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2., height + 100,
             f'{count:,}',
             ha='center', va='bottom', fontsize=11, fontweight='bold')

ax3.set_ylabel('Number of Samples', fontsize=12, fontweight='bold')
ax3.set_title('Dataset Size Comparison', fontsize=13, fontweight='bold')
ax3.grid(True, alpha=0.3, axis='y')

# 4. Confidence Distribution Comparison (Top Right)
# Overlaid histograms of the predicted-class confidence for each split
ax4 = fig.add_subplot(gs[0, 3])
ax4.hist(train_confs, bins=30, alpha=0.5, label='Train', color='#2196F3', edgecolor='black')
ax4.hist(val_confs, bins=30, alpha=0.5, label='Val', color='#FF9800', edgecolor='black')
ax4.hist(all_confidences, bins=30, alpha=0.5, label='Test', color='#4CAF50', edgecolor='black')

# Dashed vertical line at each split's mean confidence
ax4.axvline(train_confs.mean(), color='#2196F3', linestyle='--', linewidth=2, label=f'Train Œº={train_confs.mean():.3f}')
ax4.axvline(val_confs.mean(), color='#FF9800', linestyle='--', linewidth=2, label=f'Val Œº={val_confs.mean():.3f}')
ax4.axvline(all_confidences.mean(), color='#4CAF50', linestyle='--', linewidth=2, label=f'Test Œº={all_confidences.mean():.3f}')

ax4.set_xlabel('Confidence Score', fontsize=12, fontweight='bold')
ax4.set_ylabel('Frequency', fontsize=12, fontweight='bold')
ax4.set_title('Confidence Distribution Comparison', fontsize=13, fontweight='bold')
ax4.legend(fontsize=8)
ax4.grid(True, alpha=0.3)

# 5. Confusion Matrices Side by Side (Bottom)
# Rows are true labels, columns are predictions; tick labels are Real, Fake
# Train CM
ax5 = fig.add_subplot(gs[1, 0])
train_cm = confusion_matrix(train_labels_list, train_preds)
sns.heatmap(train_cm, annot=True, fmt='d', cmap='Blues', ax=ax5,
            xticklabels=['Real', 'Fake'], yticklabels=['Real', 'Fake'],
            cbar_kws={'label': 'Count'}, square=True)
ax5.set_title(f'Train Confusion Matrix\n(Acc: {train_acc:.4f})', fontsize=12, fontweight='bold')
ax5.set_xlabel('Predicted', fontsize=11, fontweight='bold')
ax5.set_ylabel('True', fontsize=11, fontweight='bold')

# Val CM
ax6 = fig.add_subplot(gs[1, 1])
val_cm = confusion_matrix(val_labels_list, val_preds)
sns.heatmap(val_cm, annot=True, fmt='d', cmap='Oranges', ax=ax6,
            xticklabels=['Real', 'Fake'], yticklabels=['Real', 'Fake'],
            cbar_kws={'label': 'Count'}, square=True)
ax6.set_title(f'Val Confusion Matrix\n(Acc: {val_acc:.4f})', fontsize=12, fontweight='bold')
ax6.set_xlabel('Predicted', fontsize=11, fontweight='bold')
ax6.set_ylabel('True', fontsize=11, fontweight='bold')

# Test CM
ax7 = fig.add_subplot(gs[1, 2])
test_cm = confusion_matrix(all_labels, all_predictions)
sns.heatmap(test_cm, annot=True, fmt='d', cmap='Greens', ax=ax7,
            xticklabels=['Real', 'Fake'], yticklabels=['Real', 'Fake'],
            cbar_kws={'label': 'Count'}, square=True)
ax7.set_title(f'Test Confusion Matrix\n(Acc: {test_accuracy:.4f})', fontsize=12, fontweight='bold')
ax7.set_xlabel('Predicted', fontsize=11, fontweight='bold')
ax7.set_ylabel('True', fontsize=11, fontweight='bold')

# 6. Summary Statistics Table (Bottom Right)
ax8 = fig.add_subplot(gs[1, 3])
ax8.axis('off')

# Row 0 is the header and row 1 an empty spacer; data rows start at index 2
summary_stats = [
    ['Metric', 'Train', 'Val', 'Test'],
    ['', '', '', ''],
    ['Overall Acc', f'{train_acc:.4f}', f'{val_acc:.4f}', f'{test_accuracy:.4f}'],
    ['Real Acc', f'{train_real_acc:.4f}', f'{val_real_acc:.4f}', f'{test_real_acc:.4f}'],
    ['Fake Acc', f'{train_fake_acc:.4f}', f'{val_fake_acc:.4f}', f'{test_fake_acc:.4f}'],
    ['Samples', f'{len(train_labels_list):,}', f'{len(val_labels_list):,}', f'{len(all_labels):,}'],
    ['Mean Conf', f'{train_confs.mean():.4f}', f'{val_confs.mean():.4f}', f'{all_confidences.mean():.4f}'],
    ['Std Conf', f'{train_confs.std():.4f}', f'{val_confs.std():.4f}', f'{all_confidences.std():.4f}']
]

table = ax8.table(cellText=summary_stats,
                  cellLoc='center',
                  loc='center',
                  bbox=[0, 0, 1, 1])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2.5)

# Style header row
for i in range(4):
    table[(0, i)].set_facecolor('#455A64')
    table[(0, i)].set_text_props(weight='bold', color='white')

# Style metric names
for i in range(2, len(summary_stats)):
    table[(i, 0)].set_facecolor('#ECEFF1')
    table[(i, 0)].set_text_props(weight='bold')

# Highlight best values in each row
# The formatted strings are parsed back to floats (thousands separators
# stripped) and the largest value in the row is highlighted.
# NOTE(review): this also highlights the maximum in the 'Samples' and
# 'Std Conf' rows, where larger is not obviously "best" — confirm intent.
for i in range(2, len(summary_stats)):
    values = [float(summary_stats[i][j].replace(',', '')) for j in range(1, 4)]
    best_idx = values.index(max(values))
    table[(i, best_idx + 1)].set_facecolor('#C8E6C9')
    table[(i, best_idx + 1)].set_text_props(weight='bold')

ax8.set_title('Performance Summary', fontsize=13, fontweight='bold', pad=20)

plt.suptitle(f'Train / Validation / Test Comprehensive Comparison - Version {version}', 
             fontsize=16, fontweight='bold', y=0.995)

# Save
comparison_path = os.path.join(viz_dir, 'train_val_test_comparison.png')
plt.savefig(comparison_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"\nTrain/Val/Test comparison saved to: {comparison_path}")
print(f"\n{'='*60}")
print("COMPREHENSIVE PERFORMANCE SUMMARY")
print(f"{'='*60}")
print(f"Train Accuracy: {train_acc:.4f} (Real: {train_real_acc:.4f}, Fake: {train_fake_acc:.4f})")
print(f"Val Accuracy:   {val_acc:.4f} (Real: {val_real_acc:.4f}, Fake: {val_fake_acc:.4f})")
print(f"Test Accuracy:  {test_accuracy:.4f} (Real: {test_real_acc:.4f}, Fake: {test_fake_acc:.4f})")
print(f"{'='*60}")

In [None]:
# Cell 15: Build Comprehensive Results DataFrame

print("Building comprehensive results DataFrame...")

# Prediction columns first: class index 0 is reported as 'real', any other
# index as 'fake'
results_data = {}
results_data['true_label'] = ['fake' if label != 0 else 'real' for label in all_labels]
results_data['predicted_label'] = ['fake' if pred != 0 else 'real' for pred in all_predictions]
results_data['confidence'] = all_confidences
results_data['correct'] = np.equal(all_predictions, all_labels).astype(int)

# Then one column per metadata field, pulled from the per-sample dicts
# collected during test inference (None where a sample lacks the field)
metadata_keys = [
    'label_name', 'perturbed_img_id', 'real_img_id', 'mask_name', 'domain',
    'ssim', 'lpips_score', 'mse', 'model_name', 'dataset', 'area_ratio',
    'sem_magnitude', 'parent_dataset',
]
for key in metadata_keys:
    results_data[key] = [sample_meta.get(key) for sample_meta in all_metadata]

# Assemble the per-sample results table
results_df = pd.DataFrame(results_data)

print(f"Results DataFrame shape: {results_df.shape}")
print("\nFirst few rows:")
print(results_df.head())

print("\nClass distribution:")
print(results_df['true_label'].value_counts())

print("\nPrediction distribution:")
print(results_df['predicted_label'].value_counts())

print("\nAccuracy by true label:")
print(results_df.groupby('true_label')['correct'].mean())

In [None]:
# Cell 16: Export Misclassified Fake Images

print("="*60)
print("MISCLASSIFIED FAKE IMAGES ANALYSIS")
print("="*60)

# Row masks and counts, computed once and reused in the report below
is_fake = results_df['true_label'] == 'fake'
fake_total = is_fake.sum()
fake_correct = (is_fake & (results_df['correct'] == 1)).sum()

# Fake images the model labelled as real (true=fake, predicted=real)
misclassified_fake = results_df[is_fake & (results_df['predicted_label'] == 'real')].copy()

print(f"\nTotal fake images: {fake_total}")
print(f"Correctly classified fake: {fake_correct}")
print(f"Misclassified fake (predicted as real): {len(misclassified_fake)}")
print(f"Fake detection accuracy: {fake_correct / fake_total:.4f}")

# Columns written to the CSV
export_columns = ['perturbed_img_id', 'real_img_id', 'confidence', 'mask_name', 'domain',
                  'ssim', 'lpips_score', 'mse', 'model_name', 'dataset',
                  'area_ratio', 'sem_magnitude']

misclassified_export = misclassified_fake.loc[:, export_columns].copy()

# Write the misclassified rows next to the other run outputs
misclassified_path = os.path.join(data_dir, 'misclassified_fake_images.csv')
misclassified_export.to_csv(misclassified_path, index=False)

print(f"\nMisclassified fake images saved to: {misclassified_path}")
print(f"Total rows exported: {len(misclassified_export)}")

# Preview the first rows only when there is something to show
if not misclassified_fake.empty:
    print("\nSample of misclassified fake images:")
    print(misclassified_export.head(10))

In [None]:
# Cell 17.5: Comprehensive Error Analysis Visualization

fig = plt.figure(figsize=(20, 12))
gs = fig.add_gridspec(3, 3, hspace=0.35, wspace=0.3)

# 1. Error Distribution by Class (Top Left)
ax1 = fig.add_subplot(gs[0, 0])

# Per true label: number of samples, number correct, number wrong, fraction wrong
error_data = (
    results_df.groupby('true_label')['correct']
    .agg(['count', 'sum'])
    .assign(errors=lambda t: t['count'] - t['sum'])
    .assign(error_rate=lambda t: t['errors'] / t['count'])
)

width = 0.35
x = np.arange(len(error_data))

# Two bars per class, side by side: correct on the left, errors on the right
bar_look = dict(alpha=0.8, edgecolor='black')
bars1 = ax1.bar(x - width/2, error_data['sum'], width, label='Correct', color='#4CAF50', **bar_look)
bars2 = ax1.bar(x + width/2, error_data['errors'], width, label='Errors', color='#F44336', **bar_look)

# Add values on bars (the error bars additionally show the class error rate)
caption_look = dict(ha='center', va='bottom', fontsize=9, fontweight='bold')

for rect in bars1:
    top = rect.get_height()
    ax1.text(rect.get_x() + rect.get_width()/2., top, f'{int(top):,}', **caption_look)

for rect, rate in zip(bars2, error_data['error_rate']):
    top = rect.get_height()
    ax1.text(rect.get_x() + rect.get_width()/2., top,
             f'{int(top):,}\n({rate*100:.1f}%)', **caption_look)

# Axis cosmetics
ax1.set_title('Error Distribution by Class', fontsize=13, fontweight='bold')
ax1.set_ylabel('Count', fontsize=12, fontweight='bold')
ax1.set_xticks(x)
ax1.set_xticklabels(error_data.index)
ax1.grid(True, alpha=0.3, axis='y')
ax1.legend(fontsize=10)

# 2. Confidence vs Correctness (Top Middle)
ax2 = fig.add_subplot(gs[0, 1])

# Row masks for the two outcomes; the summary table further down reuses them
correct_mask = results_df['correct'] == 1
incorrect_mask = results_df['correct'] == 0

# Confidence values per outcome and their means, computed once and reused in the labels
conf_when_correct = results_df.loc[correct_mask, 'confidence']
conf_when_incorrect = results_df.loc[incorrect_mask, 'confidence']
mean_conf_when_correct = conf_when_correct.mean()
mean_conf_when_incorrect = conf_when_incorrect.mean()

# One horizontal strip of points per outcome: y=1 for correct, y=0 for incorrect
ax2.scatter(conf_when_correct, np.ones(len(conf_when_correct)),
            alpha=0.3, s=10, color='green', label=f'Correct ({len(conf_when_correct):,})')
ax2.scatter(conf_when_incorrect, np.zeros(len(conf_when_incorrect)),
            alpha=0.3, s=10, color='red', label=f'Incorrect ({len(conf_when_incorrect):,})')

# Dashed vertical lines at the group means (legend text uses a real Greek mu; the
# previous labels contained a mis-encoded byte sequence instead)
ax2.axvline(mean_conf_when_correct, color='green', linestyle='--', linewidth=2, alpha=0.7,
            label=f'Correct μ={mean_conf_when_correct:.3f}')
ax2.axvline(mean_conf_when_incorrect, color='red', linestyle='--', linewidth=2, alpha=0.7,
            label=f'Incorrect μ={mean_conf_when_incorrect:.3f}')

ax2.set_xlabel('Prediction Confidence', fontsize=12, fontweight='bold')
ax2.set_ylabel('Correctness', fontsize=12, fontweight='bold')
ax2.set_yticks([0, 1])
ax2.set_yticklabels(['Incorrect', 'Correct'])
ax2.set_title('Confidence vs Correctness', fontsize=13, fontweight='bold')
ax2.legend(fontsize=8, loc='center left')
ax2.grid(True, alpha=0.3)

# 3. Error Rate by Confidence Bins (Top Right)
ax3 = fig.add_subplot(gs[0, 2])

# Ten equal-width bins over the observed confidence range
conf_bins = pd.cut(results_df['confidence'], bins=10)

# observed=True: only bins that actually contain samples become rows. Stating it
# explicitly avoids depending on the deprecated implicit default for categorical groupers
# and on a 0/0 -> NaN -> dropna round trip to discard empty bins.
error_by_conf = results_df.groupby(conf_bins, observed=True)['correct'].agg(['count', 'sum'])
error_by_conf['error_rate'] = 1 - (error_by_conf['sum'] / error_by_conf['count'])
# Still drop rows whose rate is undefined (e.g. a bin whose 'correct' values are all missing)
error_by_conf = error_by_conf.dropna()

# Plot each bin at the midpoint of its interval
bin_centers = [interval.mid for interval in error_by_conf.index]
ax3.plot(bin_centers, error_by_conf['error_rate'] * 100, 'o-', linewidth=2,
         markersize=8, color='#E91E63', markeredgecolor='black', markeredgewidth=1)
ax3.fill_between(bin_centers, 0, error_by_conf['error_rate'] * 100, alpha=0.3, color='#E91E63')

ax3.set_xlabel('Confidence Score', fontsize=12, fontweight='bold')
ax3.set_ylabel('Error Rate (%)', fontsize=12, fontweight='bold')
ax3.set_title('Error Rate by Confidence Level', fontsize=13, fontweight='bold')
ax3.grid(True, alpha=0.3)

# 4. Domain Error Analysis (Middle Left)
ax4 = fig.add_subplot(gs[1, 0])

# Per domain: samples, correct, wrong, wrong fraction -- worst domain first
domain_errors = (
    fake_results.groupby('domain')['correct']
    .agg(['count', 'sum'])
    .assign(errors=lambda t: t['count'] - t['sum'])
    .assign(error_rate=lambda t: t['errors'] / t['count'])
    .sort_values('error_rate', ascending=False)
)

# Traffic-light colouring: red above 30% errors, orange above 15%, green otherwise
colors_domain = []
for rate in domain_errors['error_rate']:
    if rate > 0.3:
        shade = '#F44336'
    elif rate > 0.15:
        shade = '#FF9800'
    else:
        shade = '#4CAF50'
    colors_domain.append(shade)

domain_rows = range(len(domain_errors))
bars = ax4.barh(domain_rows, domain_errors['error_rate'] * 100,
                color=colors_domain, alpha=0.7, edgecolor='black')

# Label each bar with "rate% (errors/total)", placed just right of the bar end
for pos, (rate, n_wrong, n_total) in enumerate(
        zip(domain_errors['error_rate'], domain_errors['errors'], domain_errors['count'])):
    ax4.text(rate * 100 + 1, pos,
             f"{rate*100:.1f}% ({int(n_wrong)}/{int(n_total)})",
             va='center', fontsize=9, fontweight='bold')

ax4.set_yticks(domain_rows)
ax4.set_yticklabels(domain_errors.index)
ax4.set_xlabel('Error Rate (%)', fontsize=12, fontweight='bold')
ax4.set_title('Error Rate by Domain (Fake Images)', fontsize=13, fontweight='bold')
ax4.grid(True, alpha=0.3, axis='x')

# 5. Generative Model Error Analysis (Middle Middle)
ax5 = fig.add_subplot(gs[1, 1])

# Same breakdown as the domain panel, grouped by the 'dataset' column instead
model_errors = fake_results.groupby('dataset')['correct'].agg(['count', 'sum'])
model_errors = model_errors.assign(errors=model_errors['count'] - model_errors['sum'])
model_errors = model_errors.assign(error_rate=model_errors['errors'] / model_errors['count'])
model_errors = model_errors.sort_values('error_rate', ascending=False)

# Red above 30% errors, orange above 15%, green otherwise
model_rates = model_errors['error_rate'].to_numpy()
colors_model = np.where(
    model_rates > 0.3, '#F44336',
    np.where(model_rates > 0.15, '#FF9800', '#4CAF50'),
).tolist()

model_rows = range(len(model_errors))
bars = ax5.barh(model_rows, model_errors['error_rate'] * 100,
                color=colors_model, alpha=0.7, edgecolor='black')

# "rate% (errors/total)" next to every bar
for pos, (rate, n_wrong, n_total) in enumerate(
        zip(model_errors['error_rate'], model_errors['errors'], model_errors['count'])):
    caption = f"{rate*100:.1f}% ({int(n_wrong)}/{int(n_total)})"
    ax5.text(rate * 100 + 1, pos, caption, va='center', fontsize=9, fontweight='bold')

ax5.set_yticks(model_rows)
ax5.set_yticklabels(model_errors.index)
ax5.set_xlabel('Error Rate (%)', fontsize=12, fontweight='bold')
ax5.set_title('Error Rate by Generative Model', fontsize=13, fontweight='bold')
ax5.grid(True, alpha=0.3, axis='x')

# 6. Quality Metrics for Misclassified vs Correct (Middle Right)
ax6 = fig.add_subplot(gs[1, 2])

# Metric columns to compare and the tick label shown for each
metric_labels = ['SSIM', 'LPIPS', 'MSE']
metrics_to_plot = ['ssim', 'lpips_score', 'mse']

# Split the fake images by whether the prediction was right
fake_was_correct = fake_results['correct']
misclass_fake = fake_results.loc[fake_was_correct == 0]
correct_fake = fake_results.loc[fake_was_correct == 1]

# Bar geometry: one pair of bars per metric
width = 0.35
x_pos = np.arange(len(metrics_to_plot))

# Normalize metrics to 0-1 range for comparison
def normalize(series):
    """Min-max scale a pandas Series to the 0-1 range.

    Args:
        series: numeric Series; callers pass it with missing values already dropped.

    Returns:
        A Series on the same index holding (value - min) / (max - min).
        If the range is zero or undefined (constant or empty input) the result
        is all zeros instead of the NaNs that a 0/0 division would produce.
    """
    lowest = series.min()
    span = series.max() - lowest
    # `not span > 0` is also true when span is NaN, which is what an empty series yields
    if not span > 0:
        return (series - lowest) * 0.0
    return (series - lowest) / span

# Scale every metric over ALL fake images first, then average per group.
# Scaling each group on its own min/max (as before) puts the two bars of a metric on
# different scales, so their heights could not be compared with each other.
misclass_means = []
correct_means = []
for metric_name in metrics_to_plot:
    has_value = fake_results[metric_name].notna()
    scaled = normalize(fake_results.loc[has_value, metric_name])
    # Positional outcome flags for exactly the rows that were scaled
    outcome = fake_results.loc[has_value, 'correct'].to_numpy()
    misclass_means.append(scaled[outcome == 0].mean())
    correct_means.append(scaled[outcome == 1].mean())

# Paired bars per metric: misclassified on the left, correctly detected on the right
bars1 = ax6.bar(x_pos - width/2, misclass_means, width, label='Misclassified',
                color='#F44336', alpha=0.7, edgecolor='black')
bars2 = ax6.bar(x_pos + width/2, correct_means, width, label='Correct',
                color='#4CAF50', alpha=0.7, edgecolor='black')

# Add values
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        ax6.text(bar.get_x() + bar.get_width()/2., height,
                 f'{height:.3f}',
                 ha='center', va='bottom', fontsize=9, fontweight='bold')

ax6.set_ylabel('Normalized Value', fontsize=12, fontweight='bold')
ax6.set_title('Quality Metrics: Misclassified vs Correct', fontsize=13, fontweight='bold')
ax6.set_xticks(x_pos)
ax6.set_xticklabels(metric_labels)
ax6.legend(fontsize=10)
ax6.grid(True, alpha=0.3, axis='y')

# 7. Top 10 Hardest Masks (Bottom Left)
ax7 = fig.add_subplot(gs[2, :2])  # spans the first two columns of the bottom row

mask_errors = fake_results.groupby('mask_name')['correct'].agg(['count', 'sum'])
# Keep masks with at least 10 samples. .copy() makes the filtered frame independent, so
# the column assignments below never write into a slice of the unfiltered table.
mask_errors = mask_errors[mask_errors['count'] >= 10].copy()
mask_errors['errors'] = mask_errors['count'] - mask_errors['sum']
mask_errors['error_rate'] = mask_errors['errors'] / mask_errors['count']
hardest_masks = mask_errors.nlargest(10, 'error_rate')

# Red above 50% errors, orange above 30%, amber otherwise
colors_mask = ['#F44336' if rate > 0.5 else '#FF9800' if rate > 0.3 else '#FFC107'
               for rate in hardest_masks['error_rate']]

bars = ax7.barh(range(len(hardest_masks)), hardest_masks['error_rate'] * 100,
                color=colors_mask, alpha=0.7, edgecolor='black')

# "rate% (errors/total)" just right of each bar end
for i, (idx, row) in enumerate(hardest_masks.iterrows()):
    ax7.text(row['error_rate'] * 100 + 1, i,
             f"{row['error_rate']*100:.1f}% ({int(row['errors'])}/{int(row['count'])})",
             va='center', fontsize=9, fontweight='bold')

ax7.set_yticks(range(len(hardest_masks)))
# str() so non-string mask names can still be truncated to 30 characters
ax7.set_yticklabels([str(idx)[:30] for idx in hardest_masks.index], fontsize=9)
ax7.set_xlabel('Error Rate (%)', fontsize=12, fontweight='bold')
ax7.set_title('Top 10 Hardest Masks (Min 10 samples)', fontsize=13, fontweight='bold')
ax7.grid(True, alpha=0.3, axis='x')

# 8. Error Summary Statistics (Bottom Right)
ax8 = fig.add_subplot(gs[2, 2])
ax8.axis('off')  # this cell of the grid holds a table, not a plot

# Error counts overall and per true class
total_errors = (results_df['correct'] == 0).sum()
real_errors = ((results_df['true_label'] == 'real') & (results_df['correct'] == 0)).sum()
fake_errors = ((results_df['true_label'] == 'fake') & (results_df['correct'] == 0)).sum()

# domain_errors and model_errors are sorted by error rate, worst first, so row 0 is the
# hardest entry. The domain row previously read from hardest_masks and therefore showed
# a mask name and mask error rate under the "Hardest Domain" label.
hardest_domain_name = str(domain_errors.index[0])[:15]
hardest_domain_rate = domain_errors.iloc[0]['error_rate']
hardest_model_name = str(model_errors.index[0])[:15]
hardest_model_rate = model_errors.iloc[0]['error_rate']

# Rows of the table; empty rows act as visual separators
error_stats = [
    ['Error Type', 'Count', '%'],
    ['', '', ''],
    ['Total Errors', f'{total_errors:,}', f'{total_errors/len(results_df)*100:.2f}%'],
    ['Real → Fake', f'{real_errors:,}', f'{real_errors/len(results_df)*100:.2f}%'],
    ['Fake → Real', f'{fake_errors:,}', f'{fake_errors/len(results_df)*100:.2f}%'],
    ['', '', ''],
    ['Hardest Domain', hardest_domain_name, f'{hardest_domain_rate*100:.1f}%'],
    ['Hardest Model', hardest_model_name, f'{hardest_model_rate*100:.1f}%'],
    ['', '', ''],
    ['Avg Confidence', '', ''],
    ['  Correct', f'{results_df[correct_mask]["confidence"].mean():.4f}', ''],
    ['  Incorrect', f'{results_df[incorrect_mask]["confidence"].mean():.4f}', ''],
]

table = ax8.table(cellText=error_stats,
                  cellLoc='left',
                  loc='center',
                  bbox=[0, 0, 1, 1])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2)

# Style header
for i in range(3):
    table[(0, i)].set_facecolor('#D32F2F')
    table[(0, i)].set_text_props(weight='bold', color='white')

# Highlight important rows (the three error-count rows)
for i in [2, 3, 4]:
    table[(i, 0)].set_facecolor('#FFEBEE')
    table[(i, 0)].set_text_props(weight='bold')

ax8.set_title('Error Summary', fontsize=13, fontweight='bold', pad=20)

# Figure-level title for the whole error-analysis dashboard
fig.suptitle(f'Comprehensive Error Analysis - Version {version}',
             fontsize=16, fontweight='bold', y=0.995)

# Save
error_analysis_path = os.path.join(viz_dir, 'comprehensive_error_analysis.png')
fig.savefig(error_analysis_path, dpi=300, bbox_inches='tight')
plt.show()

# Console recap of the headline numbers shown in the figure
summary_rule = '=' * 60
error_share = total_errors / len(results_df) * 100
worst_domain_rate = domain_errors.iloc[0]['error_rate'] * 100
worst_model_rate = model_errors.iloc[0]['error_rate'] * 100

summary_lines = [
    f"Comprehensive error analysis saved to: {error_analysis_path}",
    f"\n{summary_rule}",
    "ERROR ANALYSIS SUMMARY",
    summary_rule,
    f"Total Errors: {total_errors:,} ({error_share:.2f}%)",
    f"  Real misclassified as Fake: {real_errors:,}",
    f"  Fake misclassified as Real: {fake_errors:,}",
    f"\nHardest Domain: {domain_errors.index[0]} (Error Rate: {worst_domain_rate:.2f}%)",
    f"Hardest Model: {model_errors.index[0]} (Error Rate: {worst_model_rate:.2f}%)",
    summary_rule,
]
print("\n".join(summary_lines))

In [None]:
# Cell 17: Domain/Mask/Model Accuracy Analysis

section_rule = "=" * 60
table_rule = "-" * 60

print(section_rule)
print("DETAILED ACCURACY ANALYSIS")
print(section_rule)

# Filter for fake images only
fake_results = results_df.loc[results_df['true_label'] == 'fake'].copy()

# Columns produced for every grouping: group size, number correct, fraction correct
accuracy_columns = [
    ('total_images', 'count'),
    ('correct_predictions', 'sum'),
    ('accuracy', 'mean'),
]

# 1. Domain-wise accuracy (best domain first)
print("\n1. DOMAIN-WISE ACCURACY:")
print(table_rule)
domain_accuracy = (
    fake_results.groupby('domain')['correct']
    .agg(accuracy_columns)
    .sort_values('accuracy', ascending=False)
)
print(domain_accuracy)

domain_path = os.path.join(data_dir, 'domain_accuracy.csv')
domain_accuracy.to_csv(domain_path)
print(f"\nDomain accuracy saved to: {domain_path}")

# 2. Mask-wise accuracy (full table is saved, only the first 20 rows are printed)
print("\n2. MASK-WISE ACCURACY (Top 20):")
print(table_rule)
mask_accuracy = (
    fake_results.groupby('mask_name')['correct']
    .agg(accuracy_columns)
    .sort_values('accuracy', ascending=False)
)
print(mask_accuracy.head(20))

mask_path = os.path.join(data_dir, 'mask_accuracy.csv')
mask_accuracy.to_csv(mask_path)
print(f"\nMask accuracy saved to: {mask_path}")

# 3. Generative model accuracy (grouped by the 'dataset' column)
print("\n3. GENERATIVE MODEL ACCURACY:")
print(table_rule)
model_accuracy = (
    fake_results.groupby('dataset')['correct']
    .agg(accuracy_columns)
    .sort_values('accuracy', ascending=False)
)
print(model_accuracy)

model_path = os.path.join(data_dir, 'generative_model_accuracy.csv')
model_accuracy.to_csv(model_path)
print(f"\nGenerative model accuracy saved to: {model_path}")

print("\n" + section_rule)

In [None]:
# Cell 18: Standard Metrics (Confusion Matrix, Classification Report)

metrics_banner = "=" * 60
metrics_divider = "-" * 60

print(metrics_banner)
print("STANDARD METRICS")
print(metrics_banner)

# Confusion matrix
cm = confusion_matrix(all_labels, all_predictions)
print("\nConfusion Matrix:")
print(cm)

# Save confusion matrix as CSV, with rows labelled as true classes and columns as predictions
cm_df = pd.DataFrame(cm, index=['True Real', 'True Fake'], columns=['Pred Real', 'Pred Fake'])
cm_path = os.path.join(data_dir, 'confusion_matrix_counts.csv')
cm_df.to_csv(cm_path)
print(f"\nConfusion matrix saved to: {cm_path}")

# Classification report: a text version for the console ...
print(f"\n{metrics_divider}")
print("Classification Report:")
print(metrics_divider)
class_names = ['Real', 'Fake']
report = classification_report(all_labels, all_predictions, target_names=class_names, digits=4)
print(report)

# ... and a dict version that is turned into a table and saved as CSV
report_dict = classification_report(all_labels, all_predictions,
                                    target_names=class_names, output_dict=True)
report_df = pd.DataFrame(report_dict).T
report_path = os.path.join(data_dir, 'classification_report.csv')
report_df.to_csv(report_path)
print(f"\nClassification report saved to: {report_path}")

print(metrics_banner)

In [None]:
# Cell 19: Confusion Matrix Plot

cm_fig, cm_ax = plt.subplots(figsize=(8, 6))

# Create heatmap (integer counts annotated in every cell)
heatmap_options = dict(
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=['Real', 'Fake'],
    yticklabels=['Real', 'Fake'],
    cbar_kws={'label': 'Count'},
    square=True,
    linewidths=1,
    linecolor='gray',
)
sns.heatmap(cm, ax=cm_ax, **heatmap_options)

cm_ax.set_title('Confusion Matrix - Real vs Fake Detection', fontsize=14, fontweight='bold', pad=20)
cm_ax.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
cm_ax.set_ylabel('True Label', fontsize=12, fontweight='bold')

# Add accuracy annotation below the heatmap (position is in axes-fraction coordinates)
accuracy_text = f'Overall Accuracy: {test_accuracy:.4f}'
cm_ax.text(1.0, -0.15, accuracy_text, ha='center', va='top',
           fontsize=11, fontweight='bold', transform=cm_ax.transAxes)

cm_fig.tight_layout()

# Save plot
cm_plot_path = os.path.join(viz_dir, 'confusion_matrix.png')
cm_fig.savefig(cm_plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Confusion matrix plot saved to: {cm_plot_path}")

In [None]:
# Cell 20: Domain & Mask Accuracy Bar Charts

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
domain_ax, mask_ax = axes

# Plot 1: Domain accuracy (ascending, so the best domain ends up at the top)
domain_plot_data = domain_accuracy.sort_values('accuracy', ascending=True)
domain_rows = range(len(domain_plot_data))
domain_ax.barh(domain_rows, domain_plot_data['accuracy'], color='steelblue')
domain_ax.set_yticks(domain_rows)
domain_ax.set_yticklabels(domain_plot_data.index)
domain_ax.set_xlabel('Accuracy', fontsize=12, fontweight='bold')
domain_ax.set_title('Detection Accuracy by Domain', fontsize=14, fontweight='bold')
domain_ax.grid(True, alpha=0.3, axis='x')

# Add value labels: "accuracy (number of images)"
for pos, (acc, n_images) in enumerate(
        zip(domain_plot_data['accuracy'], domain_plot_data['total_images'])):
    domain_ax.text(acc + 0.01, pos, f"{acc:.3f} ({int(n_images)})",
                   va='center', fontsize=9)

# Plot 2: Top 15 masks by accuracy
mask_plot_data = mask_accuracy.nlargest(15, 'accuracy').sort_values('accuracy', ascending=True)
mask_rows = range(len(mask_plot_data))
mask_ax.barh(mask_rows, mask_plot_data['accuracy'], color='coral')
mask_ax.set_yticks(mask_rows)
mask_ax.set_yticklabels(mask_plot_data.index, fontsize=9)
mask_ax.set_xlabel('Accuracy', fontsize=12, fontweight='bold')
mask_ax.set_title('Detection Accuracy by Mask (Top 15)', fontsize=14, fontweight='bold')
mask_ax.grid(True, alpha=0.3, axis='x')

# Add value labels
for pos, acc in enumerate(mask_plot_data['accuracy']):
    mask_ax.text(acc + 0.01, pos, f"{acc:.3f}", va='center', fontsize=8)

fig.tight_layout()

# Save plot
domain_mask_path = os.path.join(viz_dir, 'domain_mask_accuracy.png')
fig.savefig(domain_mask_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Domain & Mask accuracy charts saved to: {domain_mask_path}")

In [None]:
# Cell 21: Quality Metrics Scatter Plots

# One row of three panels: SSIM, LPIPS, MSE
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 5))

# Prepare data - bin quality metrics and calculate accuracy
def bin_and_calculate_accuracy(df, metric_col, bins=10):
    """Bin a metric and calculate accuracy per bin.

    Args:
        df: DataFrame with a 'correct' column and the column named by `metric_col`.
            It is not modified.
        metric_col: name of the metric column to bin; rows where it is missing are ignored.
        bins: passed to `pd.cut` (number of equal-width bins, or explicit edges).

    Returns:
        DataFrame indexed by bin interval with columns
        'accuracy' (mean of 'correct'), 'count' (non-null 'correct' values) and
        'metric_value' (mean of the metric inside the bin). Bins without samples are
        kept, with count 0 and NaN means. An input with no usable rows yields an
        empty frame with the same three columns.
    """
    usable = df.dropna(subset=[metric_col])
    if usable.empty:
        # pd.cut cannot derive bin edges from an empty column
        return pd.DataFrame(columns=['accuracy', 'count', 'metric_value'])

    # assign() builds a new frame, so nothing is written onto the dropna() result or `df`
    binned = usable.assign(bin=pd.cut(usable[metric_col], bins=bins))

    # observed=False spelled out: keep one row per bin even when the bin is empty,
    # independent of the pandas default for categorical groupers
    result = binned.groupby('bin', observed=False).agg(
        accuracy=('correct', 'mean'),
        count=('correct', 'count'),
        metric_value=(metric_col, 'mean'),
    )
    return result

# Plots 1-3: SSIM, LPIPS and MSE vs accuracy -- identical panels, so drive them from a table
# of (metric column, x-axis label, panel title, bubble colour)
scatter_specs = [
    ('ssim', 'SSIM (Structural Similarity)', 'Detection Accuracy vs SSIM', 'blue'),
    ('lpips_score', 'LPIPS (Perceptual Distance)', 'Detection Accuracy vs LPIPS', 'green'),
    ('mse', 'MSE (Mean Squared Error)', 'Detection Accuracy vs MSE', 'red'),
]

binned_by_metric = {}
for panel, (metric_col, x_label, panel_title, bubble_color) in zip(axes, scatter_specs):
    binned = bin_and_calculate_accuracy(fake_results, metric_col, bins=10)
    binned_by_metric[metric_col] = binned

    # One bubble per bin; area grows with the number of samples in the bin
    panel.scatter(binned['metric_value'], binned['accuracy'],
                  s=binned['count']*2, alpha=0.6, color=bubble_color)
    panel.set_xlabel(x_label, fontsize=12, fontweight='bold')
    panel.set_ylabel('Detection Accuracy', fontsize=12, fontweight='bold')
    panel.set_title(panel_title, fontsize=14, fontweight='bold')
    panel.grid(True, alpha=0.3)
    panel.text(0.05, 0.95, 'Bubble size = sample count',
               transform=panel.transAxes, fontsize=9, va='top')

# Keep the per-metric tables available under their individual names
ssim_data = binned_by_metric['ssim']
lpips_data = binned_by_metric['lpips_score']
mse_data = binned_by_metric['mse']

fig.tight_layout()

# Save plot
quality_path = os.path.join(viz_dir, 'quality_metrics_scatter.png')
fig.savefig(quality_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Quality metrics scatter plots saved to: {quality_path}")
print("\nInterpretation:")
for hint in (
    "- SSIM: Higher values = more similar to original (harder to detect?)",
    "- LPIPS: Lower values = more perceptually similar (harder to detect?)",
    "- MSE: Lower values = closer to original pixel values (harder to detect?)",
):
    print(hint)

In [None]:
# Cell 22: Confidence Histogram

conf_fig, conf_ax = plt.subplots(figsize=(10, 6))

# Split data by correctness
was_correct = results_df['correct']
correct_confidences = results_df.loc[was_correct == 1, 'confidence']
incorrect_confidences = results_df.loc[was_correct == 0, 'confidence']

# Group means, drawn as dashed lines and repeated in the printed statistics
mean_correct = correct_confidences.mean()
mean_incorrect = incorrect_confidences.mean()

# Plot histograms (50 bins each, overlaid with transparency)
hist_look = dict(bins=50, alpha=0.6, edgecolor='black')
conf_ax.hist(correct_confidences, color='green',
             label=f'Correct Predictions (n={len(correct_confidences)})', **hist_look)
conf_ax.hist(incorrect_confidences, color='red',
             label=f'Incorrect Predictions (n={len(incorrect_confidences)})', **hist_look)

# Add mean lines
conf_ax.axvline(mean_correct, color='darkgreen', linestyle='--', linewidth=2,
                label=f'Mean Correct: {mean_correct:.3f}')
conf_ax.axvline(mean_incorrect, color='darkred', linestyle='--', linewidth=2,
                label=f'Mean Incorrect: {mean_incorrect:.3f}')

conf_ax.set_xlabel('Confidence Score', fontsize=12, fontweight='bold')
conf_ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
conf_ax.set_title('Prediction Confidence Distribution', fontsize=14, fontweight='bold')
conf_ax.legend(fontsize=10)
conf_ax.grid(True, alpha=0.3)

conf_fig.tight_layout()

# Save plot
confidence_path = os.path.join(viz_dir, 'confidence_histogram.png')
conf_fig.savefig(confidence_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Confidence histogram saved to: {confidence_path}")
print("\nConfidence Statistics:")
print(f"  Correct predictions - Mean: {mean_correct:.4f}, Std: {correct_confidences.std():.4f}")
print(f"  Incorrect predictions - Mean: {mean_incorrect:.4f}, Std: {incorrect_confidences.std():.4f}")

In [None]:
# Cell 23: Generative Model Comparison Bar Chart

model_fig, model_ax = plt.subplots(figsize=(10, 6))

# Sort by accuracy (ascending, so the most detectable model is drawn at the top)
model_plot_data = model_accuracy.sort_values('accuracy', ascending=True)
model_rows = range(len(model_plot_data))

# Create bar chart
bars = model_ax.barh(model_rows, model_plot_data['accuracy'], color='teal')

# Customize axes
model_ax.set_yticks(model_rows)
model_ax.set_yticklabels(model_plot_data.index)
model_ax.set_xlabel('Detection Accuracy', fontsize=12, fontweight='bold')
model_ax.set_ylabel('Generative Model', fontsize=12, fontweight='bold')
model_ax.set_title('Fake Detection Accuracy by Generative Model', fontsize=14, fontweight='bold')
model_ax.grid(True, alpha=0.3, axis='x')

# Add value labels: "accuracy (n=number of images)"
for pos, (acc, n_images) in enumerate(
        zip(model_plot_data['accuracy'], model_plot_data['total_images'])):
    model_ax.text(acc + 0.01, pos, f"{acc:.3f} (n={int(n_images)})",
                  va='center', fontsize=10, fontweight='bold')

# Leave room on the right for the value labels
model_ax.set_xlim(0, 1.1)
model_fig.tight_layout()

# Save plot
model_comparison_path = os.path.join(viz_dir, 'generative_model_comparison.png')
model_fig.savefig(model_comparison_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Generative model comparison saved to: {model_comparison_path}")
print("\nInterpretation:")
print("- Higher accuracy = fakes from this model are easier to detect")
print("- Lower accuracy = fakes from this model are more convincing/harder to detect")

In [None]:
# Cell 24: Misclassified Examples Grid (4x4 Images)

# Note: This requires loading actual images from disk
# Draws up to 16 randomly sampled misclassified fake images, each titled with its metadata.

if len(misclassified_fake) > 0:
    grid_rows, grid_cols = 4, 4
    grid_slots = grid_rows * grid_cols

    # Select up to 16 random misclassified examples (seeded, so the grid is reproducible)
    num_examples = min(grid_slots, len(misclassified_fake))
    sample_misclassified = misclassified_fake.sample(n=num_examples, random_state=RANDOM_SEED)

    # Image paths are looked up in the fake-test CSV by perturbed_img_id.
    # Build the id -> path lookup once; the first row per id wins.
    fake_test_df = pd.read_csv(FAKE_TEST_CSV)
    path_by_id = (
        fake_test_df.drop_duplicates(subset='perturbed_img_id')
        .set_index('perturbed_img_id')['fake_img_path']
    )

    # Create figure
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(16, 16))
    axes = axes.flatten()

    plotted = 0
    for _, row in sample_misclassified.iterrows():
        image_id = row['perturbed_img_id']
        if image_id not in path_by_id.index:
            print(f"No image path found for {image_id}; skipping")
            continue
        img_path = path_by_id.loc[image_id]

        try:
            # The context manager closes the file handle as soon as the pixels are
            # converted; convert() returns an independent in-memory image
            with Image.open(img_path) as source_image:
                img = source_image.convert('RGB')
        except Exception as e:
            # Best effort: report the unreadable image and move on to the next one
            print(f"Error loading image {img_path}: {e}")
            continue

        # Create title with metadata (str() so missing text fields cannot break slicing)
        title = f"{str(row['mask_name'])[:15]}\n"
        title += f"Domain: {row['domain']}\n"
        title += f"Model: {str(row['dataset'])[:20]}\n"
        title += f"Conf: {row['confidence']:.3f} | SSIM: {row['ssim']:.3f}"

        # Draw only after everything for this slot is ready, so a failure never leaves
        # a half-filled axis behind
        slot = axes[plotted]
        slot.imshow(img)
        slot.axis('off')
        slot.set_title(title, fontsize=8, color='red')
        plotted += 1

    # Hide unused subplots
    for i in range(plotted, grid_slots):
        axes[i].axis('off')

    plt.suptitle('Misclassified Fake Images (Predicted as Real)',
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()

    # Save plot
    misclass_grid_path = os.path.join(viz_dir, 'misclassified_examples_grid.png')
    plt.savefig(misclass_grid_path, dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Misclassified examples grid saved to: {misclass_grid_path}")
    print(f"Displayed {plotted} misclassified fake images")
else:
    print("No misclassified fake images to display!")

In [None]:
# Cell 25: Final Summary Report
# Prints a console recap of the run: configuration, training results, test performance,
# fake-image error analysis and the files written to the output directories.
# Section headers and bullets use real Unicode symbols; the previous literals were
# mis-encoded byte sequences.

report_rule = "=" * 80
section_rule = "-" * 80

print(report_rule)
print(" " * 25 + "TRAINING SUMMARY REPORT")
print(report_rule)

print("\n📊 MODEL INFORMATION")
print(section_rule)
print("Architecture: ResNet50 (pretrained on ImageNet)")
print(f"Training Device: {device}")

print("\n📁 DATASET STATISTICS")
print(section_rule)
print(f"Training samples: {len(train_dataset):,}")
print(f"Validation samples: {len(val_dataset):,}")
print(f"Test samples: {len(test_dataset):,}")
print(f"Batch size: {BATCH_SIZE}")

print("\n⚙️ TRAINING CONFIGURATION")
print(section_rule)
print(f"Epochs: {NUM_EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Weight decay: {WEIGHT_DECAY}")
print("Optimizer: Adam")
print(f"Scheduler: ReduceLROnPlateau (patience={SCHEDULER_PATIENCE}, factor={SCHEDULER_FACTOR})")
print(f"Augmentation: {'Enabled' if USE_AUGMENTATION else 'Disabled'}")

print("\n🏆 TRAINING RESULTS")
print(section_rule)
print(f"Best epoch: {best_epoch}/{NUM_EPOCHS}")
print(f"Best validation accuracy: {best_val_acc:.4f}")
print(f"Final training accuracy: {history['train_acc'][-1]:.4f}")
print(f"Final training loss: {history['train_loss'][-1]:.4f}")
print(f"Total training time: {format_time(total_training_time)}")

print("\n🎯 TEST PERFORMANCE")
print(section_rule)
# Coerce to arrays so the element-wise comparison also works if these are plain lists
test_predictions_arr = np.asarray(all_predictions)
test_labels_arr = np.asarray(all_labels)
n_test_correct = int((test_predictions_arr == test_labels_arr).sum())
n_test_incorrect = int((test_predictions_arr != test_labels_arr).sum())
print(f"Overall test accuracy: {test_accuracy:.4f}")
print(f"Total test samples: {len(all_labels):,}")
print(f"Correct predictions: {n_test_correct:,}")
print(f"Incorrect predictions: {n_test_incorrect:,}")

# Real vs Fake breakdown
real_acc = results_df[results_df['true_label'] == 'real']['correct'].mean()
fake_acc = results_df[results_df['true_label'] == 'fake']['correct'].mean()
print(f"\nReal image detection accuracy: {real_acc:.4f}")
print(f"Fake image detection accuracy: {fake_acc:.4f}")

print("\n🔍 ERROR ANALYSIS (FAKE IMAGES)")
print(section_rule)
print(f"Total fake images in test: {(results_df['true_label'] == 'fake').sum():,}")
print(f"Correctly detected fakes: {((results_df['true_label'] == 'fake') & (results_df['correct'] == 1)).sum():,}")
print(f"Misclassified fakes: {len(misclassified_fake):,}")

# Best / worst groups: take idxmin/idxmax on the accuracy column only
if len(domain_accuracy) > 0:
    domain_acc = domain_accuracy['accuracy']
    print(f"\nHardest domain: {domain_acc.idxmin()} (Acc: {domain_acc.min():.4f})")
    print(f"Easiest domain: {domain_acc.idxmax()} (Acc: {domain_acc.max():.4f})")

if len(model_accuracy) > 0:
    model_acc = model_accuracy['accuracy']
    print(f"\nMost detectable generative model: {model_acc.idxmax()} (Acc: {model_acc.max():.4f})")
    print(f"Least detectable generative model: {model_acc.idxmin()} (Acc: {model_acc.min():.4f})")

print("\n💾 OUTPUT FILES")
print(section_rule)
print(f"Base directory: {base_dir}")
n_model_files = sum(1 for f in os.listdir(models_dir) if f.endswith('.pth'))
print(f"\nModels ({n_model_files} files):")
print(f"  - {models_dir}")
print("\nData CSVs:")
print(f"  - {data_dir}")
for csv_file in sorted(os.listdir(data_dir)):
    if csv_file.endswith('.csv'):
        print(f"    • {csv_file}")

print("\nVisualizations:")
print(f"  - {viz_dir}")
for img_file in sorted(os.listdir(viz_dir)):
    if img_file.endswith('.png'):
        print(f"    • {img_file}")

print("\n" + report_rule)
print(" " * 30 + "TRAINING COMPLETE!")
print(report_rule)