# 🏆 Insurance Fraud Detection Model Training - Step by Step Guide

This notebook provides a comprehensive, educational approach to training a CNN model for insurance fraud detection. Each step is clearly explained with detailed comments so you can understand and modify the training process.

## 📋 What You'll Learn:
1. **Dataset Analysis** - Understanding the fraud vs non-fraud data distribution
2. **Data Preprocessing** - Image transformations and augmentation techniques  
3. **Model Architecture** - EfficientNet-B1 with custom classifier
4. **Loss Functions** - Focal Loss for handling class imbalance
5. **Training Process** - Complete training loop with monitoring
6. **Evaluation Metrics** - Precision, Recall, F1-Score analysis
7. **Model Optimization** - Techniques for better performance

## 🎯 Training Goals:
- **Target Precision**: 87-88% for fraud detection
- **Target Recall**: 85%+ to catch most fraud cases
- **Training Time**: 30-45 minutes on GPU
- **Model Size**: Optimized for deployment

## 📦 Step 1: Import Required Libraries

Let's start by importing all the necessary libraries for our fraud detection model training.

In [None]:
# Core PyTorch libraries for deep learning
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

# Pre-trained models library
import timm

# Data manipulation and visualization
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import os
from collections import Counter
import json
import gc
import time

# Machine learning metrics
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.metrics import classification_report

# Suppress warnings for cleaner output
import warnings
warnings.filterwarnings('ignore')

print("🔧 All libraries imported successfully!")
print("📱 PyTorch version:", torch.__version__)
print("🎯 Ready to start fraud detection training!")

## 🖥️ Step 2: GPU Setup and Environment Configuration

Let's check if we have GPU available and set up our computing environment for optimal training performance.

In [None]:
# Setup device for training (GPU if available, CPU otherwise)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Clear GPU memory if available
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()  # Python garbage collection
    
    print("🚀 GPU AVAILABLE!")
    print(f"   Device: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"   CUDA Version: {torch.version.cuda}")
    print("   ✅ Perfect for fast training!")
else:
    print("⚠️ GPU NOT AVAILABLE - Using CPU")
    print("   Training will be slower but still functional")
    print("   Consider using Google Colab or Kaggle for GPU access")

print(f"\n🎯 Training device: {device}")
print("🔧 Environment ready for fraud detection training!")

## 📁 Step 3: Dataset Path Configuration

Let's define the paths to our fraud detection dataset. The dataset is organized with separate folders for training and testing, each containing 'Fraud' and 'Non-Fraud' subfolders.

In [None]:
# Dataset directory paths
train_dir = r'../Insurance-Fraud-Detection/train'
test_dir = r'../Insurance-Fraud-Detection/test'

# Verify the dataset structure
print("📁 Dataset Structure:")
print("=" * 50)

# Check if directories exist
if os.path.exists(train_dir):
    print(f"✅ Training directory found: {train_dir}")
    
    # Count training samples
    fraud_train_path = os.path.join(train_dir, 'Fraud')
    non_fraud_train_path = os.path.join(train_dir, 'Non-Fraud')
    
    if os.path.exists(fraud_train_path):
        fraud_count = len([f for f in os.listdir(fraud_train_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"   🚨 Fraud samples: {fraud_count:,}")
    
    if os.path.exists(non_fraud_train_path):
        non_fraud_count = len([f for f in os.listdir(non_fraud_train_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"   ✅ Non-fraud samples: {non_fraud_count:,}")
        
        # Calculate class imbalance ratio
        total_train = fraud_count + non_fraud_count
        fraud_ratio = fraud_count / total_train * 100
        print(f"   📊 Fraud ratio: {fraud_ratio:.1f}%")
        print(f"   ⚖️ Class imbalance: {non_fraud_count/fraud_count:.1f}:1 (non-fraud:fraud)")
else:
    print("❌ Training directory not found!")

if os.path.exists(test_dir):
    print(f"✅ Test directory found: {test_dir}")
    
    # Count test samples
    fraud_test_path = os.path.join(test_dir, 'Fraud')
    non_fraud_test_path = os.path.join(test_dir, 'Non-Fraud')
    
    if os.path.exists(fraud_test_path):
        fraud_test_count = len([f for f in os.listdir(fraud_test_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"   🚨 Test fraud samples: {fraud_test_count:,}")
    
    if os.path.exists(non_fraud_test_path):
        non_fraud_test_count = len([f for f in os.listdir(non_fraud_test_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"   ✅ Test non-fraud samples: {non_fraud_test_count:,}")
else:
    print("❌ Test directory not found!")

print("\n🎯 Dataset analysis complete!")

## 🔍 Step 4: Focal Loss Implementation

**Why Focal Loss?** 
- Our dataset has severe class imbalance (25:1 ratio)
- Standard cross-entropy loss gets overwhelmed by easy non-fraud examples
- Focal Loss focuses training on hard-to-classify examples
- Alpha parameter handles class imbalance
- Gamma parameter focuses on hard examples

In [None]:
class FocalLoss(nn.Module):
    """
    🎯 Focal Loss for handling class imbalance in fraud detection
    
    Formula: FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)
    
    Parameters:
    - alpha: Balancing factor for rare class (fraud) vs common class (non-fraud)
    - gamma: Focusing parameter - higher values focus more on hard examples
    - reduction: How to aggregate the loss across batch
    """
    
    def __init__(self, alpha=0.75, gamma=5.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha      # 75% focus on fraud detection
        self.gamma = gamma      # Focus on hard examples
        self.reduction = reduction
        
        print(f"🎯 Focal Loss Configuration:")
        print(f"   α (fraud focus): {alpha:.2f} - Emphasizes fraud class")
        print(f"   γ (hard examples): {gamma:.1f} - Focuses on difficult cases")
        print(f"   📈 This helps with severe class imbalance!")
    
    def forward(self, inputs, targets):
        """
        Forward pass of focal loss
        
        Args:
            inputs: Model predictions (logits) [batch_size, num_classes]
            targets: True labels [batch_size]
        
        Returns:
            focal_loss: Computed focal loss value
        """
        # Calculate standard cross-entropy loss
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', label_smoothing=0.1)
        
        # Calculate probability of correct class
        pt = torch.exp(-ce_loss)
        
        # Clamp to avoid numerical instability
        pt = torch.clamp(pt, min=1e-8, max=1-1e-8)
        
        # Calculate alpha term (class balancing)
        # If target is 0 (fraud), use alpha; if 1 (non-fraud), use (1-alpha)
        alpha_t = torch.where(targets == 0, self.alpha, 1 - self.alpha)
        
        # Calculate focal weight: (1-pt)^gamma
        focal_weight = (1 - pt) ** self.gamma
        focal_weight = torch.clamp(focal_weight, min=1e-8, max=30.0)
        
        # Final focal loss: alpha * focal_weight * cross_entropy
        focal_loss = alpha_t * focal_weight * ce_loss
        
        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Test the focal loss implementation
print("🧪 Testing Focal Loss...")
focal_loss = FocalLoss(alpha=0.75, gamma=5.0)

# Create dummy data for testing
dummy_logits = torch.randn(4, 2)  # 4 samples, 2 classes
dummy_targets = torch.tensor([0, 1, 0, 1])  # Mix of fraud and non-fraud

test_loss = focal_loss(dummy_logits, dummy_targets)
print(f"✅ Focal loss test successful! Loss value: {test_loss:.4f}")
print("🎯 Ready to use for training!")

## 🏗️ Step 5: Model Architecture - EfficientNet-B1 Based Fraud Detector

**Why EfficientNet-B1?**
- Excellent balance between accuracy and speed
- Pre-trained on ImageNet for strong feature extraction
- Compound scaling for optimal performance
- Perfect size for fraud detection (not too big, not too small)

In [None]:
class FraudDetectionModel(nn.Module):
    """
    🏗️ EfficientNet-B1 based fraud detection model
    
    Architecture:
    1. EfficientNet-B1 backbone (pre-trained on ImageNet)
    2. Custom classifier head with dropout and batch normalization
    3. Optimized for fraud detection with bias initialization
    """
    
    def __init__(self, num_classes=2):
        super(FraudDetectionModel, self).__init__()
        
        print("🏗️ Building Fraud Detection Model...")
        
        # Load pre-trained EfficientNet-B1 backbone
        self.backbone = timm.create_model(
            'efficientnet_b1', 
            pretrained=True,    # Use ImageNet pre-trained weights
            num_classes=0       # Remove the original classifier
        )
        
        # Get the number of features from backbone
        self.num_features = self.backbone.num_features
        print(f"   📐 Backbone features: {self.num_features}")
        
        # Custom classifier head for fraud detection
        self.classifier = nn.Sequential(
            # First layer with high dropout to prevent overfitting
            nn.Dropout(0.4),
            nn.Linear(self.num_features, 256),
            nn.BatchNorm1d(256),    # Batch normalization for stable training
            nn.ReLU(),
            
            # Second layer with moderate dropout
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            
            # Final classification layer
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )
        
        # Initialize bias for better fraud detection
        self._initialize_classifier_bias()
        
        print(f"   🎯 Classifier input features: {self.num_features}")
        print(f"   🎯 Output classes: {num_classes}")
        print("   ✅ Model architecture ready!")
    
    def _initialize_classifier_bias(self):
        """
        🎯 Initialize classifier bias for better fraud detection
        
        This gives the model a slight preference toward detecting fraud,
        which is important given the class imbalance.
        """
        with torch.no_grad():
            # Slight bias toward fraud detection (class 0)
            self.classifier[-1].bias[0] = 1.0   # Fraud class
            self.classifier[-1].bias[1] = -0.5  # Non-fraud class
            
        print("   🎯 Classifier bias initialized for fraud detection")
    
    def forward(self, x):
        """
        Forward pass through the model
        
        Args:
            x: Input images [batch_size, 3, 224, 224]
            
        Returns:
            output: Class logits [batch_size, num_classes]
        """
        # Extract features using EfficientNet backbone
        features = self.backbone(x)
        
        # Classify using custom head
        output = self.classifier(features)
        
        return output
    
    def get_model_info(self):
        """Get detailed information about the model"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        
        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'backbone_features': self.num_features,
            'model_size_mb': total_params * 4 / (1024 ** 2)  # Approximate size in MB
        }

# Create and test the model
print("🧪 Creating Fraud Detection Model...")
model = FraudDetectionModel(num_classes=2)

# Get model information
model_info = model.get_model_info()
print(f"\n📊 Model Information:")
print(f"   📐 Total parameters: {model_info['total_parameters']:,}")
print(f"   🎯 Trainable parameters: {model_info['trainable_parameters']:,}")
print(f"   💾 Approximate model size: {model_info['model_size_mb']:.1f} MB")

# Test model with dummy input
print(f"\n🧪 Testing model with dummy input...")
dummy_input = torch.randn(2, 3, 224, 224)  # 2 images, 3 channels, 224x224
dummy_output = model(dummy_input)
print(f"   ✅ Input shape: {dummy_input.shape}")
print(f"   ✅ Output shape: {dummy_output.shape}")
print(f"   ✅ Model test successful!")

# Move model to device (GPU/CPU)
model = model.to(device)
print(f"🚀 Model moved to {device}")
print("🎯 Fraud Detection Model ready for training!")

## 🖼️ Step 6: Data Preprocessing and Augmentation

**Why Data Augmentation?**
- Increases dataset diversity without collecting new images
- Helps model generalize better to unseen fraud cases
- Reduces overfitting by creating variations
- Simulates real-world image variations (rotation, lighting, etc.)

In [None]:
def create_data_transforms():
    """
    🖼️ Create optimized data transformations for fraud detection
    
    Training transforms include augmentation for better generalization
    Validation transforms are minimal for consistent evaluation
    """
    
    # Training transforms with augmentation
    train_transform = transforms.Compose([
        # Resize and crop
        transforms.Resize((240, 240)),      # Slightly larger for cropping
        transforms.CenterCrop(224),         # EfficientNet-B1 input size
        
        # Augmentation techniques
        transforms.RandomHorizontalFlip(p=0.5),     # 50% chance to flip
        transforms.RandomRotation(degrees=10),       # Rotate ±10 degrees
        transforms.ColorJitter(                      # Color variations
            brightness=0.1,    # ±10% brightness
            contrast=0.1,      # ±10% contrast
            saturation=0.1,    # ±10% saturation
            hue=0.05          # ±5% hue
        ),
        
        # Convert to tensor and normalize
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet mean
            std=[0.229, 0.224, 0.225]   # ImageNet std
        )
    ])
    
    # Validation transforms (no augmentation)
    val_transform = transforms.Compose([
        transforms.Resize((240, 240)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    
    print("🖼️ Data Transforms Created:")
    print("   📈 Training: Resize → Crop → Augment → Normalize")
    print("   📊 Validation: Resize → Crop → Normalize")
    print("   🎯 Input size: 224x224 (optimized for EfficientNet)")
    
    return train_transform, val_transform

# Create the transforms
train_transform, val_transform = create_data_transforms()

# Test transforms with a dummy image
print("\n🧪 Testing transforms...")
dummy_pil_image = Image.new('RGB', (300, 300), color='red')

# Test training transform
train_tensor = train_transform(dummy_pil_image)
print(f"   ✅ Training transform output: {train_tensor.shape}")
print(f"   📊 Tensor range: [{train_tensor.min():.3f}, {train_tensor.max():.3f}]")

# Test validation transform  
val_tensor = val_transform(dummy_pil_image)
print(f"   ✅ Validation transform output: {val_tensor.shape}")
print("   🎯 Transforms ready for dataset creation!")

## 📊 Step 7: Custom Dataset Class with Smart Balancing

**Key Features:**
- Handles severe class imbalance (200 fraud vs 5000 non-fraud)
- Intelligent sampling to maintain fraud detection performance
- Configurable fraud ratio for optimal training
- Memory-efficient loading

In [None]:
class FraudDataset(Dataset):
    """
    📊 Custom dataset for fraud detection with intelligent class balancing
    
    Features:
    - Loads all fraud samples (never lose fraud data)
    - Smart non-fraud sampling for optimal training balance
    - Configurable fraud ratio
    - Memory-efficient image loading
    """
    
    def __init__(self, data_dir, transform=None, fraud_ratio=0.3, max_samples=None):
        """
        Initialize the fraud detection dataset
        
        Args:
            data_dir: Path to data directory (contains 'Fraud' and 'Non-Fraud' folders)
            transform: Data transformations to apply
            fraud_ratio: Target ratio of fraud samples (0.3 = 30% fraud)
            max_samples: Maximum total samples to load (for memory management)
        """
        self.data_dir = data_dir
        self.transform = transform
        self.image_files = []
        self.labels = []
        
        print(f"📂 Loading dataset from: {data_dir}")
        print(f"🎯 Target fraud ratio: {fraud_ratio:.1%}")
        
        # Define class paths
        fraud_dir = os.path.join(data_dir, 'Fraud')
        non_fraud_dir = os.path.join(data_dir, 'Non-Fraud')
        
        # Load ALL fraud samples (critical - never lose fraud data!)
        fraud_files = [f for f in os.listdir(fraud_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
        for img_name in fraud_files:
            self.image_files.append(os.path.join(fraud_dir, img_name))
            self.labels.append(0)  # Fraud = 0
        
        fraud_count = len(fraud_files)
        print(f"   🚨 Loaded fraud samples: {fraud_count}")
        
        # Calculate target non-fraud count based on desired ratio
        # fraud_ratio = fraud_count / (fraud_count + non_fraud_count)
        # Solving for non_fraud_count:
        target_non_fraud = int(fraud_count * (1 - fraud_ratio) / fraud_ratio)
        
        # Apply max_samples limit if specified
        if max_samples:
            available_for_non_fraud = max_samples - fraud_count
            target_non_fraud = min(target_non_fraud, available_for_non_fraud)
        
        print(f"   🎯 Target non-fraud samples: {target_non_fraud}")
        
        # Load non-fraud samples up to target count
        non_fraud_files = [f for f in os.listdir(non_fraud_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
        # Shuffle for random sampling
        import random
        random.shuffle(non_fraud_files)
        
        loaded_non_fraud = 0
        for img_name in non_fraud_files:
            if loaded_non_fraud >= target_non_fraud:
                break
                
            self.image_files.append(os.path.join(non_fraud_dir, img_name))
            self.labels.append(1)  # Non-fraud = 1
            loaded_non_fraud += 1
        
        # Calculate actual statistics
        total_samples = len(self.labels)
        actual_fraud_ratio = fraud_count / total_samples
        
        print(f"   ✅ Loaded non-fraud samples: {loaded_non_fraud}")
        print(f"   📊 Total samples: {total_samples}")
        print(f"   📈 Actual fraud ratio: {actual_fraud_ratio:.1%}")
        print(f"   ⚖️ Balance ratio: {loaded_non_fraud/fraud_count:.1f}:1 (non-fraud:fraud)")
        
        # Calculate class distribution
        self.class_counts = Counter(self.labels)
        print(f"   📋 Class distribution: {dict(self.class_counts)}")
    
    def __len__(self):
        """Return total number of samples"""
        return len(self.image_files)
    
    def __getitem__(self, idx):
        """
        Get a single sample from the dataset
        
        Args:
            idx: Sample index
            
        Returns:
            image: Transformed image tensor
            label: Class label (0=fraud, 1=non-fraud)
        """
        # Load image
        img_path = self.image_files[idx]
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"⚠️ Error loading {img_path}: {e}")
            # Return a black image as fallback
            image = Image.new('RGB', (224, 224), color='black')
        
        label = self.labels[idx]
        
        # Apply transforms if provided
        if self.transform:
            image = self.transform(image)
        
        return image, label
    
    def get_class_weights(self):
        """
        Calculate class weights for loss function balancing
        
        Returns:
            weights: List of weights for each class [fraud_weight, non_fraud_weight]
        """
        fraud_count = self.class_counts[0]
        non_fraud_count = self.class_counts[1]
        total_count = fraud_count + non_fraud_count
        
        # Inverse frequency weighting
        fraud_weight = total_count / (2 * fraud_count)
        non_fraud_weight = total_count / (2 * non_fraud_count)
        
        return [fraud_weight, non_fraud_weight]
    
    def visualize_samples(self, num_samples=4):
        """
        Visualize random samples from the dataset
        
        Args:
            num_samples: Number of samples to show
        """
        fig, axes = plt.subplots(2, num_samples//2, figsize=(12, 8))
        axes = axes.flatten()
        
        # Get random indices
        indices = random.sample(range(len(self)), num_samples)
        
        for i, idx in enumerate(indices):
            image, label = self[idx]
            
            # Convert tensor back to PIL for visualization
            if isinstance(image, torch.Tensor):
                # Denormalize
                mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
                std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
                image = image * std + mean
                image = torch.clamp(image, 0, 1)
                image = transforms.ToPILImage()(image)
            
            axes[i].imshow(image)
            axes[i].set_title(f"{'🚨 Fraud' if label == 0 else '✅ Non-Fraud'}")
            axes[i].axis('off')
        
        plt.tight_layout()
        plt.show()

# Test the dataset
print("🧪 Testing FraudDataset...")
test_dataset = FraudDataset(
    data_dir=train_dir,
    transform=train_transform,
    fraud_ratio=0.3,
    max_samples=1000  # Small sample for testing
)

print(f"\n📊 Dataset Test Results:")
print(f"   📏 Dataset length: {len(test_dataset)}")
print(f"   ⚖️ Class weights: {test_dataset.get_class_weights()}")

# Test loading a sample
sample_image, sample_label = test_dataset[0]
print(f"   🖼️ Sample image shape: {sample_image.shape}")
print(f"   🏷️ Sample label: {sample_label} ({'Fraud' if sample_label == 0 else 'Non-Fraud'})")
print("   ✅ Dataset test successful!")

## 🔄 Step 8: Create Data Loaders with Advanced Sampling

**WeightedRandomSampler:**
- Ensures balanced training despite class imbalance
- Gives fraud samples higher probability of being selected
- Creates effective oversampling without data duplication
- Improves fraud detection recall

In [None]:
def create_data_loaders(batch_size=16, fraud_ratio=0.3, num_workers=0):
    """
    🔄 Create optimized data loaders for fraud detection training
    
    Args:
        batch_size: Number of samples per batch
        fraud_ratio: Target fraud ratio in dataset
        num_workers: Number of worker processes for data loading
        
    Returns:
        train_loader: Training data loader with weighted sampling
        test_loader: Test data loader for evaluation
    """
    
    print("🔄 Creating Data Loaders...")
    print(f"   📦 Batch size: {batch_size}")
    print(f"   🎯 Target fraud ratio: {fraud_ratio:.1%}")
    
    # Create datasets
    train_dataset = FraudDataset(
        data_dir=train_dir,
        transform=train_transform,
        fraud_ratio=fraud_ratio
    )
    
    test_dataset = FraudDataset(
        data_dir=test_dir,
        transform=val_transform,
        fraud_ratio=0.5,  # Keep natural test distribution
        max_samples=2000  # Limit test size for faster evaluation
    )
    
    # Calculate sample weights for training
    print(f"\n⚖️ Calculating sample weights for balanced training...")
    
    # Get class weights
    class_weights = train_dataset.get_class_weights()
    fraud_weight = class_weights[0] * 2.0  # Extra emphasis on fraud
    non_fraud_weight = class_weights[1]
    
    print(f"   🚨 Fraud weight: {fraud_weight:.2f}")
    print(f"   ✅ Non-fraud weight: {non_fraud_weight:.2f}")
    
    # Create sample weights list
    sample_weights = []
    for label in train_dataset.labels:
        if label == 0:  # Fraud
            sample_weights.append(fraud_weight)
        else:  # Non-fraud
            sample_weights.append(non_fraud_weight)
    
    # Create weighted sampler for balanced training
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights) * 2,  # 2x oversampling
        replacement=True  # Allow sampling with replacement
    )
    
    print(f"   🔄 Sampler created with {len(sample_weights) * 2:,} samples per epoch")
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,  # Use weighted sampler
        num_workers=num_workers,
        pin_memory=True if device.type == 'cuda' else False,
        drop_last=True  # Drop incomplete batches for stable training
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,  # Keep consistent test order
        num_workers=num_workers,
        pin_memory=True if device.type == 'cuda' else False
    )
    
    print(f"\n📊 Data Loader Statistics:")
    print(f"   🚂 Training batches per epoch: {len(train_loader):,}")
    print(f"   🧪 Test batches: {len(test_loader):,}")
    print(f"   📦 Samples per training batch: {batch_size}")
    print(f"   🔄 Total training samples per epoch: {len(train_loader) * batch_size:,}")
    
    return train_loader, test_loader, train_dataset, test_dataset

# Create the data loaders
print("🚀 Creating optimized data loaders...")
train_loader, test_loader, train_dataset, test_dataset = create_data_loaders(
    batch_size=12,  # Optimized for GPU memory
    fraud_ratio=0.3,
    num_workers=0  # Set to 0 for Windows compatibility
)

print("✅ Data loaders created successfully!")
print("🎯 Ready for model training!")

## 📈 Step 9: Training Setup and Optimization

**Optimizer Strategy:**
- **AdamW**: Better weight decay than standard Adam
- **Differential Learning Rates**: Lower LR for pre-trained backbone, higher for classifier
- **ReduceLROnPlateau**: Automatically reduce learning rate when stuck
- **Gradient Clipping**: Prevents gradient explosion

In [None]:
# Initialize loss function
criterion = FocalLoss(alpha=0.75, gamma=5.0)

# Setup optimizer with differential learning rates
optimizer = optim.AdamW([
    # Lower learning rate for pre-trained backbone
    {'params': model.backbone.parameters(), 'lr': 2e-5, 'weight_decay': 0.01},
    # Higher learning rate for new classifier
    {'params': model.classifier.parameters(), 'lr': 1e-3, 'weight_decay': 0.01}
])

# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='max',          # Monitor for maximum metric (like F1-score)
    factor=0.7,          # Reduce LR by 30% when plateau
    patience=2,          # Wait 2 epochs before reducing
    min_lr=1e-7,         # Minimum learning rate
    verbose=True
)

print("📈 Training Setup Complete:")
print(f"   🎯 Loss Function: Focal Loss (α=0.75, γ=5.0)")
print(f"   🔧 Optimizer: AdamW with differential learning rates")
print(f"   📊 Backbone LR: 2e-5 (fine-tuning)")
print(f"   📊 Classifier LR: 1e-3 (learning from scratch)")
print(f"   📉 Scheduler: ReduceLROnPlateau (patience=2)")
print("   ✅ Ready to start training!")

## 📊 Step 10: Evaluation Metrics Functions

Let's create comprehensive evaluation functions to monitor our model's performance during training.

In [None]:
def calculate_metrics(y_true, y_pred, verbose=True):
    """
    📊 Calculate comprehensive metrics for fraud detection
    
    Args:
        y_true: True labels
        y_pred: Predicted labels
        verbose: Whether to print detailed results
        
    Returns:
        metrics: Dictionary containing all calculated metrics
    """
    
    # Basic accuracy
    accuracy = accuracy_score(y_true, y_pred)
    
    # Precision, Recall, F1 for each class
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )
    
    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    
    # Extract fraud-specific metrics (class 0)
    fraud_precision = precision[0] if len(precision) > 0 else 0.0
    fraud_recall = recall[0] if len(recall) > 0 else 0.0
    fraud_f1 = f1[0] if len(f1) > 0 else 0.0
    
    # Extract non-fraud metrics (class 1)
    non_fraud_precision = precision[1] if len(precision) > 1 else 0.0
    non_fraud_recall = recall[1] if len(recall) > 1 else 0.0
    non_fraud_f1 = f1[1] if len(f1) > 1 else 0.0
    
    # Custom balanced score (emphasizes fraud detection)
    balance_score = (0.6 * fraud_precision + 0.4 * fraud_recall)
    
    # Create metrics dictionary
    metrics = {
        'accuracy': accuracy,
        'fraud_precision': fraud_precision,
        'fraud_recall': fraud_recall,
        'fraud_f1': fraud_f1,
        'non_fraud_precision': non_fraud_precision,
        'non_fraud_recall': non_fraud_recall,
        'non_fraud_f1': non_fraud_f1,
        'balance_score': balance_score,
        'confusion_matrix': cm
    }
    
    if verbose:
        print(f"📊 Evaluation Metrics:")
        print(f"   🎯 Overall Accuracy: {accuracy*100:.1f}%")
        print(f"   🚨 Fraud Precision: {fraud_precision*100:.1f}%")
        print(f"   🚨 Fraud Recall: {fraud_recall*100:.1f}%")
        print(f"   🚨 Fraud F1-Score: {fraud_f1*100:.1f}%")
        print(f"   ✅ Non-Fraud Precision: {non_fraud_precision*100:.1f}%")
        print(f"   ✅ Non-Fraud Recall: {non_fraud_recall*100:.1f}%")
        print(f"   ⚖️ Balance Score: {balance_score*100:.1f}%")
        print(f"   📋 Confusion Matrix:")
        print(f"      Predicted:  [Fraud] [Non-Fraud]")
        print(f"      Fraud:      [{cm[0,0]:4d}]   [{cm[0,1]:4d}]")
        print(f"      Non-Fraud:  [{cm[1,0]:4d}]   [{cm[1,1]:4d}]")
    
    return metrics

def plot_confusion_matrix(cm, title="Confusion Matrix"):
    """
    📊 Plot confusion matrix with nice visualization
    """
    plt.figure(figsize=(8, 6))
    
    # Create heatmap
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=['Fraud', 'Non-Fraud'],
                yticklabels=['Fraud', 'Non-Fraud'])
    
    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

def plot_training_history(history):
    """
    📈 Plot training history with multiple metrics
    """
    epochs = range(1, len(history['train_loss']) + 1)
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss plot
    axes[0,0].plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    axes[0,0].plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    axes[0,0].set_title('Training and Validation Loss')
    axes[0,0].set_xlabel('Epoch')
    axes[0,0].set_ylabel('Loss')
    axes[0,0].legend()
    axes[0,0].grid(True)
    
    # Accuracy plot
    axes[0,1].plot(epochs, history['accuracy'], 'g-', label='Accuracy')
    axes[0,1].set_title('Model Accuracy')
    axes[0,1].set_xlabel('Epoch')
    axes[0,1].set_ylabel('Accuracy')
    axes[0,1].legend()
    axes[0,1].grid(True)
    
    # Fraud metrics plot
    axes[1,0].plot(epochs, history['fraud_precision'], 'b-', label='Precision')
    axes[1,0].plot(epochs, history['fraud_recall'], 'r-', label='Recall')
    axes[1,0].plot(epochs, history['fraud_f1'], 'g-', label='F1-Score')
    axes[1,0].set_title('Fraud Detection Metrics')
    axes[1,0].set_xlabel('Epoch')
    axes[1,0].set_ylabel('Score')
    axes[1,0].legend()
    axes[1,0].grid(True)
    
    # Balance score plot
    axes[1,1].plot(epochs, history['balance_score'], 'purple', label='Balance Score')
    axes[1,1].set_title('Balance Score (60% Precision + 40% Recall)')
    axes[1,1].set_xlabel('Epoch')
    axes[1,1].set_ylabel('Score')
    axes[1,1].legend()
    axes[1,1].grid(True)
    
    plt.tight_layout()
    plt.show()

print("📊 Evaluation functions created:")
print("   🎯 calculate_metrics() - Comprehensive metric calculation")
print("   📋 plot_confusion_matrix() - Visual confusion matrix")
print("   📈 plot_training_history() - Training progress visualization")
print("   ✅ Ready for training monitoring!")

## 🚀 Step 11: Complete Training Loop

This is the heart of our fraud detection training - a comprehensive training loop with real-time monitoring and automatic best model saving.

In [None]:
# Training configuration
NUM_EPOCHS = 15
TARGET_PRECISION = 87.0  # Target 87% precision
TARGET_RECALL = 85.0     # Target 85% recall

# Initialize training history
history = {
    'train_loss': [],
    'val_loss': [],
    'accuracy': [],
    'fraud_precision': [],
    'fraud_recall': [],
    'fraud_f1': [],
    'balance_score': []
}

# Best model tracking
best_balance_score = 0.0
best_model_state = None
patience_counter = 0

print("🚀 Starting Fraud Detection Training!")
print(f"🎯 Target: {TARGET_PRECISION}% Precision, {TARGET_RECALL}% Recall")
print(f"⏱️ Maximum epochs: {NUM_EPOCHS}")
print("=" * 70)

training_start_time = time.time()

for epoch in range(NUM_EPOCHS):
    epoch_start_time = time.time()
    
    print(f"\n🔄 EPOCH {epoch+1}/{NUM_EPOCHS}")
    print("-" * 50)
    
    # =============================================================================
    # TRAINING PHASE
    # =============================================================================
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    print("📚 Training Phase...")
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        # Move data to device
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping to prevent explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Update weights
        optimizer.step()
        
        # Statistics
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        
        # Progress update every 100 batches
        if (batch_idx + 1) % 100 == 0:
            current_acc = 100.0 * train_correct / train_total
            current_loss = train_loss / (batch_idx + 1)
            progress = (batch_idx + 1) / len(train_loader) * 100
            print(f"   Batch {batch_idx+1:4d}/{len(train_loader)} ({progress:5.1f}%) | "
                  f"Loss: {current_loss:.4f} | Acc: {current_acc:.1f}%")
    
    # Calculate training metrics
    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100.0 * train_correct / train_total
    
    # =============================================================================
    # VALIDATION PHASE
    # =============================================================================
    print("🧪 Validation Phase...")
    model.eval()
    val_loss = 0.0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate validation metrics
    avg_val_loss = val_loss / len(test_loader)
    metrics = calculate_metrics(all_labels, all_predictions, verbose=False)
    
    # Extract key metrics
    accuracy = metrics['accuracy'] * 100
    fraud_precision = metrics['fraud_precision'] * 100
    fraud_recall = metrics['fraud_recall'] * 100
    fraud_f1 = metrics['fraud_f1'] * 100
    balance_score = metrics['balance_score'] * 100
    
    # =============================================================================
    # EPOCH SUMMARY
    # =============================================================================
    epoch_time = time.time() - epoch_start_time
    total_time = time.time() - training_start_time
    
    print(f"\n📊 EPOCH {epoch+1} RESULTS:")
    print(f"   ⏱️ Time: {epoch_time:.1f}s | Total: {total_time/60:.1f}min")
    print(f"   📉 Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
    print(f"   🎯 Accuracy: {accuracy:.1f}%")
    print(f"   🚨 Fraud Precision: {fraud_precision:.1f}%")
    print(f"   🚨 Fraud Recall: {fraud_recall:.1f}%")
    print(f"   🚨 Fraud F1-Score: {fraud_f1:.1f}%")
    print(f"   ⚖️ Balance Score: {balance_score:.1f}%")
    
    # Update history
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['accuracy'].append(accuracy)
    history['fraud_precision'].append(fraud_precision)
    history['fraud_recall'].append(fraud_recall)
    history['fraud_f1'].append(fraud_f1)
    history['balance_score'].append(balance_score)
    
    # =============================================================================
    # MODEL SAVING AND EARLY STOPPING
    # =============================================================================
    
    # Save best model
    if balance_score > best_balance_score and fraud_precision >= 75:
        best_balance_score = balance_score
        best_model_state = model.state_dict().copy()
        torch.save(model.state_dict(), 'best_fraud_model.pth')
        patience_counter = 0
        print(f"   ✅ NEW BEST MODEL! Balance Score: {balance_score:.1f}%")
    else:
        patience_counter += 1
    
    # Check if target achieved
    if fraud_precision >= TARGET_PRECISION and fraud_recall >= TARGET_RECALL:
        print(f"\n🎉 TARGET ACHIEVED!")
        print(f"   🎯 Precision: {fraud_precision:.1f}% (≥{TARGET_PRECISION}%)")
        print(f"   🚨 Recall: {fraud_recall:.1f}% (≥{TARGET_RECALL}%)")
        print(f"   ⏱️ Time to target: {total_time/60:.1f} minutes")
        break
    
    # Learning rate scheduling
    scheduler.step(balance_score)
    
    # Early stopping check
    if patience_counter >= 5:
        print(f"\n⏸️ Early stopping triggered (patience=5)")
        break
    
    print("=" * 70)

# =============================================================================
# TRAINING COMPLETE
# =============================================================================
total_training_time = time.time() - training_start_time

print(f"\n🏁 TRAINING COMPLETE!")
print(f"⏱️ Total training time: {total_training_time/60:.1f} minutes")
print(f"🏆 Best balance score: {best_balance_score:.1f}%")
print(f"📈 Training epochs completed: {len(history['train_loss'])}")

# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("✅ Best model loaded for final evaluation")

print("🎯 Training phase complete! Ready for final evaluation.")

## 📊 Step 12: Final Evaluation and Visualization

Let's evaluate our trained model and visualize the training progress and final performance.

In [None]:
# Final comprehensive evaluation
print("🧪 Performing Final Comprehensive Evaluation...")
print("=" * 60)

model.eval()
final_predictions = []
final_labels = []
final_probabilities = []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = model(images)
        probabilities = F.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)
        
        final_predictions.extend(predicted.cpu().numpy())
        final_labels.extend(labels.cpu().numpy())
        final_probabilities.extend(probabilities.cpu().numpy())

# Calculate final metrics
final_metrics = calculate_metrics(final_labels, final_predictions, verbose=True)

print(f"\n🏆 FINAL MODEL PERFORMANCE:")
print(f"   🎯 Overall Accuracy: {final_metrics['accuracy']*100:.1f}%")
print(f"   🚨 Fraud Precision: {final_metrics['fraud_precision']*100:.1f}%")
print(f"   🚨 Fraud Recall: {final_metrics['fraud_recall']*100:.1f}%")
print(f"   🚨 Fraud F1-Score: {final_metrics['fraud_f1']*100:.1f}%")
print(f"   ⚖️ Balance Score: {final_metrics['balance_score']*100:.1f}%")

# Success evaluation
fraud_precision_final = final_metrics['fraud_precision'] * 100
fraud_recall_final = final_metrics['fraud_recall'] * 100

if fraud_precision_final >= 87 and fraud_recall_final >= 85:
    print("\n🎉 TARGET ACHIEVED! Model ready for deployment!")
    success_level = "EXCELLENT"
elif fraud_precision_final >= 85 and fraud_recall_final >= 80:
    print("\n✅ VERY GOOD performance! Close to target.")
    success_level = "VERY GOOD"
elif fraud_precision_final >= 80:
    print("\n📈 GOOD performance! Consider additional training.")
    success_level = "GOOD"
else:
    print("\n⚠️ Performance below target. Review hyperparameters.")
    success_level = "NEEDS IMPROVEMENT"

print(f"🏅 Performance Level: {success_level}")

In [None]:
# Visualize training history
print("\n📈 Plotting Training History...")
plot_training_history(history)

# Plot confusion matrix
print("\n📋 Plotting Final Confusion Matrix...")
plot_confusion_matrix(final_metrics['confusion_matrix'], 
                     title="Final Model Confusion Matrix")

# Classification report
print("\n📊 Detailed Classification Report:")
print(classification_report(final_labels, final_predictions, 
                          target_names=['Fraud', 'Non-Fraud'], digits=3))

## 💾 Step 13: Model Saving and Export

Save the trained model for deployment and future use.

In [None]:
# Save complete model
print("💾 Saving trained model...")

# Save the complete model (architecture + weights)
torch.save(model, 'final_fraud_detection_model.pth')
print("✅ Complete model saved as 'final_fraud_detection_model.pth'")

# Save only the state dict (more efficient)
torch.save(model.state_dict(), 'final_fraud_model_weights.pth')
print("✅ Model weights saved as 'final_fraud_model_weights.pth'")

# Save training history
import json
with open('training_history.json', 'w') as f:
    # Convert numpy arrays to lists for JSON serialization
    history_for_json = {}
    for key, value in history.items():
        if isinstance(value[0], np.ndarray):
            history_for_json[key] = [float(v) for v in value]
        else:
            history_for_json[key] = value
    json.dump(history_for_json, f, indent=2)
print("✅ Training history saved as 'training_history.json'")

# Save model info
model_info = {
    'model_name': 'EfficientNet-B1 Fraud Detector',
    'input_size': [3, 224, 224],
    'num_classes': 2,
    'final_performance': {
        'accuracy': float(final_metrics['accuracy']),
        'fraud_precision': float(final_metrics['fraud_precision']),
        'fraud_recall': float(final_metrics['fraud_recall']),
        'fraud_f1': float(final_metrics['fraud_f1']),
        'balance_score': float(final_metrics['balance_score'])
    },
    'training_epochs': len(history['train_loss']),
    'training_time_minutes': float(total_training_time / 60),
    'target_achieved': fraud_precision_final >= 87 and fraud_recall_final >= 85
}

with open('model_info.json', 'w') as f:
    json.dump(model_info, f, indent=2)
print("✅ Model info saved as 'model_info.json'")

print(f"\n📁 Files saved:")
print(f"   🧠 final_fraud_detection_model.pth - Complete model")
print(f"   ⚖️ final_fraud_model_weights.pth - Model weights only")
print(f"   📈 training_history.json - Training progress data")
print(f"   📊 model_info.json - Model metadata and performance")

print(f"\n🎯 Model Summary:")
print(f"   🏗️ Architecture: EfficientNet-B1 + Custom Classifier")
print(f"   📐 Parameters: {model_info['final_performance']}")
print(f"   ⏱️ Training Time: {model_info['training_time_minutes']:.1f} minutes")
print(f"   🎯 Target Achieved: {'✅ YES' if model_info['target_achieved'] else '❌ NO'}")

print("\n🎉 Model training and saving complete!")
print("🚀 Your fraud detection model is ready for deployment!")