<a href="https://colab.research.google.com/github/nebuchad-nezzar/Python-Notebooks/blob/main/Transfer_Learningipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:

!pip install torch torchvision tqdm matplotlib

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.models import resnet18  # Using ResNet18 as it's lighter
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Check if GPU is available (important for Colab)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define data transforms
transform = transforms.Compose([
    transforms.Resize(224),  # ResNet expects 224x224 images
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
def load_dataset():
    # Download and load training set
    full_dataset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform
    )

    # Split into train, validation, and test (60%, 20%, 20%)
    total_size = len(full_dataset)
    train_size = int(0.6 * total_size)
    val_size = int(0.2 * total_size)
    test_size = total_size - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size]
    )

    return train_dataset, val_dataset, test_dataset

# Create data loaders
def create_data_loaders(train_dataset, val_dataset, test_dataset, batch_size=32):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader

# Define the model
def load_model():
    # Load pre-trained ResNet18
    model = resnet18(pretrained=True)

    # Modify the final layer for CIFAR-10 (10 classes)
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 10)

    return model.to(device)

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5):
    best_val_acc = 0.0

    # Lists to store metrics for plotting
    train_losses = []
    val_accuracies = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        # Training loop with progress bar
        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Calculate average loss for this epoch
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)

        # Validation phase
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = 100 * correct / total
        val_accuracies.append(val_accuracy)

        print(f'Epoch {epoch+1}:')
        print(f'Training Loss: {epoch_loss:.3f}')
        print(f'Validation Accuracy: {val_accuracy:.2f}%')

        # Save best model
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')

    # Plot training progress
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.subplot(1, 2, 2)
    plt.plot(val_accuracies)
    plt.title('Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.show()

# Function to show predictions
def show_predictions(model, test_loader, classes, num_images=5):
    model.eval()

    # Get a batch of test images
    dataiter = iter(test_loader)
    images, labels = next(dataiter)
    images = images[:num_images]
    labels = labels[:num_images]

    # Get predictions
    images = images.to(device)
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

    # Show images and predictions
    fig = plt.figure(figsize=(15, 3))
    for idx in range(num_images):
        ax = plt.subplot(1, num_images, idx + 1)
        img = images[idx].cpu().numpy().transpose((1, 2, 0))
        img = img * 0.5 + 0.5  # Denormalize
        plt.imshow(img)
        color = 'green' if predicted[idx] == labels[idx] else 'red'
        ax.set_title(f'Pred: {classes[predicted[idx]]}\nTrue: {classes[labels[idx]]}',
                    color=color)
        plt.axis('off')
    plt.show()

# Main execution
def main():
    # CIFAR-10 classes
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
              'dog', 'frog', 'horse', 'ship', 'truck')

    # Load and split dataset
    print("Loading dataset...")
    train_dataset, val_dataset, test_dataset = load_dataset()

    # Create data loaders
    print("Creating data loaders...")
    train_loader, val_loader, test_loader = create_data_loaders(
        train_dataset, val_dataset, test_dataset
    )

    # Initialize model and training components
    print("Initializing model...")
    model = load_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train model
    print("Starting training...")
    train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5)

    # Show some predictions
    print("\nShowing predictions...")
    show_predictions(model, test_loader, classes)

if __name__ == "__main__":
    main()

Using device: cpu
Loading dataset...
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 39.9MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Creating data loaders...
Initializing model...


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 107MB/s]


Starting training...


Epoch 1/5:   1%|          | 9/938 [01:15<2:09:08,  8.34s/it]


KeyboardInterrupt: 

ResNet Training Pipeline with Anti-Overfitting Techniques

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.models import resnet18
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

class EnhancedResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(EnhancedResNet, self).__init__()
        # Load pretrained ResNet
        self.resnet = resnet18(pretrained=True)

        # Remove the final fully connected layer
        num_features = self.resnet.fc.in_features

        # Add dropout and new classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),  # Add dropout with 0.5 probability
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),  # Add another dropout layer
            nn.Linear(512, num_classes)
        )

        self.resnet.fc = self.classifier

    def forward(self, x):
        return self.resnet(x)

# Enhanced data augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Validation transform (no augmentation)
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class EarlyStopping:
    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

def train_model_enhanced(model, train_loader, val_loader, criterion, optimizer,
                        scheduler, num_epochs=30, device='cuda'):
    early_stopping = EarlyStopping(patience=5)
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': []
    }

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with dropout enabled (model.train() enables dropout)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Add L2 regularization loss
            l2_lambda = 0.01
            l2_reg = torch.tensor(0.).to(device)
            for param in model.parameters():
                l2_reg += torch.norm(param)
            loss += l2_lambda * l2_reg

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        # Calculate epoch metrics
        epoch_train_loss = train_loss / len(train_loader)
        epoch_val_loss = val_loss / len(val_loader)
        epoch_train_acc = 100. * train_correct / train_total
        epoch_val_acc = 100. * val_correct / val_total

        # Update learning rate based on validation loss
        scheduler.step(epoch_val_loss)

        # Store history
        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        history['train_acc'].append(epoch_train_acc)
        history['val_acc'].append(epoch_val_acc)

        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%')
        print(f'Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

        # Early stopping check
        early_stopping(epoch_val_loss)
        if early_stopping.should_stop:
            print("Early stopping triggered!")
            break

    return history

def plot_training_history(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot losses
    ax1.plot(history['train_loss'], label='Training Loss')
    ax1.plot(history['val_loss'], label='Validation Loss')
    ax1.set_title('Model Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()

    # Plot accuracies
    ax2.plot(history['train_acc'], label='Training Accuracy')
    ax2.plot(history['val_acc'], label='Validation Accuracy')
    ax2.set_title('Model Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()

    plt.show()

def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load datasets with different transforms for train and validation
    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_transform
    )
    val_size = 5000
    train_size = len(train_dataset) - val_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # Initialize model
    model = EnhancedResNet().to(device)

    # Initialize training components
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

    # Train model
    history = train_model_enhanced(
        model, train_loader, val_loader, criterion, optimizer, scheduler
    )

    # Plot training history
    plot_training_history(history)

if __name__ == "__main__":
    main()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:13<00:00, 12.6MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 102MB/s]
Epoch 1/30: 100%|██████████| 1407/1407 [05:26<00:00,  4.31it/s]



Epoch 1/30:
Train Loss: 4.7061, Train Acc: 32.60%
Val Loss: 1.8388, Val Acc: 32.02%
Learning Rate: 0.001000


Epoch 2/30: 100%|██████████| 1407/1407 [05:27<00:00,  4.30it/s]



Epoch 2/30:
Train Loss: 2.9094, Train Acc: 34.60%
Val Loss: 1.9309, Val Acc: 28.82%
Learning Rate: 0.001000


Epoch 3/30: 100%|██████████| 1407/1407 [05:19<00:00,  4.40it/s]



Epoch 3/30:
Train Loss: 2.4537, Train Acc: 38.28%
Val Loss: 1.6342, Val Acc: 39.10%
Learning Rate: 0.001000


Epoch 4/30: 100%|██████████| 1407/1407 [05:17<00:00,  4.43it/s]



Epoch 4/30:
Train Loss: 2.2566, Train Acc: 40.09%
Val Loss: 1.6213, Val Acc: 39.86%
Learning Rate: 0.001000


Epoch 5/30: 100%|██████████| 1407/1407 [05:17<00:00,  4.44it/s]



Epoch 5/30:
Train Loss: 2.1525, Train Acc: 42.11%
Val Loss: 1.5871, Val Acc: 41.78%
Learning Rate: 0.001000


Epoch 6/30: 100%|██████████| 1407/1407 [05:15<00:00,  4.47it/s]



Epoch 6/30:
Train Loss: 2.1093, Train Acc: 44.03%
Val Loss: 1.5521, Val Acc: 43.66%
Learning Rate: 0.001000


Epoch 7/30: 100%|██████████| 1407/1407 [05:13<00:00,  4.49it/s]



Epoch 7/30:
Train Loss: 2.0713, Train Acc: 45.26%
Val Loss: 1.9379, Val Acc: 32.46%
Learning Rate: 0.001000


Epoch 8/30: 100%|██████████| 1407/1407 [05:15<00:00,  4.46it/s]



Epoch 8/30:
Train Loss: 2.0585, Train Acc: 45.85%
Val Loss: 1.5567, Val Acc: 42.76%
Learning Rate: 0.001000


Epoch 9/30: 100%|██████████| 1407/1407 [05:16<00:00,  4.45it/s]



Epoch 9/30:
Train Loss: 2.0510, Train Acc: 46.17%
Val Loss: 1.5250, Val Acc: 44.06%
Learning Rate: 0.001000


Epoch 10/30: 100%|██████████| 1407/1407 [05:15<00:00,  4.46it/s]



Epoch 10/30:
Train Loss: 2.0403, Train Acc: 46.18%
Val Loss: 4.6024, Val Acc: 18.76%
Learning Rate: 0.001000


Epoch 11/30: 100%|██████████| 1407/1407 [05:13<00:00,  4.48it/s]



Epoch 11/30:
Train Loss: 2.0229, Train Acc: 46.92%
Val Loss: 1.6196, Val Acc: 42.10%
Learning Rate: 0.001000


Epoch 12/30: 100%|██████████| 1407/1407 [05:14<00:00,  4.47it/s]



Epoch 12/30:
Train Loss: 2.0207, Train Acc: 46.99%
Val Loss: 1.9084, Val Acc: 36.14%
Learning Rate: 0.001000


Epoch 13/30: 100%|██████████| 1407/1407 [05:15<00:00,  4.46it/s]



Epoch 13/30:
Train Loss: 2.0281, Train Acc: 47.34%
Val Loss: 1.5871, Val Acc: 40.76%
Learning Rate: 0.000100


Epoch 14/30: 100%|██████████| 1407/1407 [05:14<00:00,  4.48it/s]



Epoch 14/30:
Train Loss: 1.8604, Train Acc: 51.79%
Val Loss: 1.2433, Val Acc: 55.14%
Learning Rate: 0.000100


Epoch 15/30: 100%|██████████| 1407/1407 [05:12<00:00,  4.50it/s]



Epoch 15/30:
Train Loss: 1.7859, Train Acc: 53.78%
Val Loss: 1.2542, Val Acc: 54.34%
Learning Rate: 0.000100


Epoch 16/30: 100%|██████████| 1407/1407 [05:12<00:00,  4.51it/s]



Epoch 16/30:
Train Loss: 1.7531, Train Acc: 54.25%
Val Loss: 1.2396, Val Acc: 55.46%
Learning Rate: 0.000100


Epoch 17/30: 100%|██████████| 1407/1407 [05:14<00:00,  4.47it/s]



Epoch 17/30:
Train Loss: 1.7310, Train Acc: 54.60%
Val Loss: 1.2032, Val Acc: 56.50%
Learning Rate: 0.000100


Epoch 18/30: 100%|██████████| 1407/1407 [05:14<00:00,  4.47it/s]



Epoch 18/30:
Train Loss: 1.7213, Train Acc: 54.76%
Val Loss: 1.1888, Val Acc: 57.34%
Learning Rate: 0.000100


Epoch 19/30: 100%|██████████| 1407/1407 [05:15<00:00,  4.46it/s]



Epoch 19/30:
Train Loss: 1.7076, Train Acc: 55.14%
Val Loss: 1.2487, Val Acc: 55.34%
Learning Rate: 0.000100


Epoch 20/30: 100%|██████████| 1407/1407 [05:14<00:00,  4.47it/s]



Epoch 20/30:
Train Loss: 1.6958, Train Acc: 55.39%
Val Loss: 1.1999, Val Acc: 57.64%
Learning Rate: 0.000100


Epoch 21/30: 100%|██████████| 1407/1407 [05:14<00:00,  4.48it/s]



Epoch 21/30:
Train Loss: 1.6860, Train Acc: 55.27%
Val Loss: 1.1946, Val Acc: 57.08%
Learning Rate: 0.000100


Epoch 22/30: 100%|██████████| 1407/1407 [05:14<00:00,  4.47it/s]



Epoch 22/30:
Train Loss: 1.6735, Train Acc: 56.04%
Val Loss: 1.1724, Val Acc: 57.70%
Learning Rate: 0.000100


Epoch 23/30: 100%|██████████| 1407/1407 [05:14<00:00,  4.47it/s]



Epoch 23/30:
Train Loss: 1.6592, Train Acc: 56.30%
Val Loss: 1.2155, Val Acc: 56.12%
Learning Rate: 0.000100


Epoch 24/30: 100%|██████████| 1407/1407 [05:14<00:00,  4.47it/s]



Epoch 24/30:
Train Loss: 1.6518, Train Acc: 56.52%
Val Loss: 1.1706, Val Acc: 57.94%
Learning Rate: 0.000100


Epoch 25/30:  60%|█████▉    | 843/1407 [03:08<02:00,  4.70it/s]

Training Analysis with Stopping Criteria Visualization

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter

class TrainingAnalyzer:
    def __init__(self, history):
        self.history = history
        self.train_loss = history['train_loss']
        self.val_loss = history['val_loss']
        self.train_acc = history['train_acc']
        self.val_acc = history['val_acc']
        self.epochs = range(1, len(self.train_loss) + 1)

    def plot_comprehensive_analysis(self):
        """Create a comprehensive visualization of training metrics"""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # 1. Loss Curves with Smoothing
        self._plot_smoothed_loss(ax1)

        # 2. Training vs Validation Accuracy
        self._plot_accuracy_comparison(ax2)

        # 3. Generalization Gap
        self._plot_generalization_gap(ax3)

        # 4. Loss Gradient Analysis
        self._plot_loss_gradient(ax4)

        plt.tight_layout()
        return fig

    def _plot_smoothed_loss(self, ax):
        """Plot smoothed loss curves to better visualize trends"""
        window = min(15, len(self.train_loss) // 3)
        if window % 2 == 0:
            window += 1

        smoothed_train = savgol_filter(self.train_loss, window, 3)
        smoothed_val = savgol_filter(self.val_loss, window, 3)

        ax.plot(self.epochs, self.train_loss, 'lightblue', label='Training Loss (Raw)', alpha=0.3)
        ax.plot(self.epochs, self.val_loss, 'lightcoral', label='Validation Loss (Raw)', alpha=0.3)
        ax.plot(self.epochs, smoothed_train, 'blue', label='Training Loss (Smoothed)')
        ax.plot(self.epochs, smoothed_val, 'red', label='Validation Loss (Smoothed)')

        # Find divergence point
        diff = np.array(self.val_loss) - np.array(self.train_loss)
        divergence_point = np.where(diff > np.mean(diff) + np.std(diff))[0]
        if len(divergence_point) > 0:
            div_epoch = divergence_point[0]
            ax.axvline(x=div_epoch, color='gray', linestyle='--', alpha=0.5)
            ax.text(div_epoch + 0.5, max(self.train_loss), f'Divergence at epoch {div_epoch}')

        ax.set_title('Loss Curves (Raw and Smoothed)')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def _plot_accuracy_comparison(self, ax):
        """Plot accuracy curves with optimal stopping point"""
        ax.plot(self.epochs, self.train_acc, 'b-', label='Training Accuracy')
        ax.plot(self.epochs, self.val_acc, 'r-', label='Validation Accuracy')

        # Find optimal accuracy point
        best_val_epoch = np.argmax(self.val_acc) + 1
        ax.axvline(x=best_val_epoch, color='green', linestyle='--', alpha=0.5)
        ax.text(best_val_epoch + 0.5, max(self.train_acc), f'Best validation at epoch {best_val_epoch}')

        ax.set_title('Accuracy Progression')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Accuracy (%)')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def _plot_generalization_gap(self, ax):
        """Plot the generalization gap (difference between train and validation metrics)"""
        acc_gap = np.array(self.train_acc) - np.array(self.val_acc)
        loss_gap = np.array(self.val_loss) - np.array(self.train_loss)

        ax.plot(self.epochs, acc_gap, 'b-', label='Accuracy Gap')
        ax.plot(self.epochs, loss_gap, 'r-', label='Loss Gap')

        # Find point where gap starts growing consistently
        window = 5
        gap_gradient = np.gradient(savgol_filter(acc_gap, window, 3))
        consistent_growth = np.where(gap_gradient > np.mean(gap_gradient) + np.std(gap_gradient))[0]

        if len(consistent_growth) > 0:
            growth_point = consistent_growth[0]
            ax.axvline(x=growth_point, color='red', linestyle='--', alpha=0.5)
            ax.text(growth_point + 0.5, max(acc_gap), f'Gap growth at epoch {growth_point}')

        ax.set_title('Generalization Gap')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Gap (Train - Validation)')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def _plot_loss_gradient(self, ax):
        """Plot loss gradients to identify plateaus and unstable regions"""
        val_gradient = np.gradient(self.val_loss)
        train_gradient = np.gradient(self.train_loss)

        ax.plot(self.epochs, train_gradient, 'b-', label='Training Loss Gradient')
        ax.plot(self.epochs, val_gradient, 'r-', label='Validation Loss Gradient')

        # Find plateau points
        plateau_threshold = 0.001
        plateau_points = np.where(np.abs(val_gradient) < plateau_threshold)[0]

        if len(plateau_points) > 0:
            first_plateau = plateau_points[0]
            ax.axvline(x=first_plateau, color='purple', linestyle='--', alpha=0.5)
            ax.text(first_plateau + 0.5, max(train_gradient), f'First plateau at epoch {first_plateau}')

        ax.set_title('Loss Gradients')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Gradient')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def get_stopping_recommendations(self):
        """Provide stopping recommendations based on multiple criteria"""
        recommendations = []

        # 1. Best validation accuracy
        best_val_epoch = np.argmax(self.val_acc) + 1
        recommendations.append(f"Best validation accuracy achieved at epoch {best_val_epoch}")

        # 2. Generalization gap analysis
        acc_gap = np.array(self.train_acc) - np.array(self.val_acc)
        gap_threshold = np.mean(acc_gap) + np.std(acc_gap)
        overfitting_epochs = np.where(acc_gap > gap_threshold)[0]
        if len(overfitting_epochs) > 0:
            recommendations.append(f"Significant generalization gap detected at epoch {overfitting_epochs[0]}")

        # 3. Loss plateau detection
        val_gradient = np.gradient(self.val_loss)
        plateau_threshold = 0.001
        plateau_points = np.where(np.abs(val_gradient) < plateau_threshold)[0]
        if len(plateau_points) > 0:
            recommendations.append(f"Loss plateau detected at epoch {plateau_points[0]}")

        return recommendations

# Example usage:
def analyze_training(history):
    analyzer = TrainingAnalyzer(history)

    # Plot comprehensive analysis
    analyzer.plot_comprehensive_analysis()
    plt.show()

    # Print recommendations
    print("\nStopping Recommendations:")
    for rec in analyzer.get_stopping_recommendations():
        print(f"- {rec}")

# Add this to your training loop:
history = train_model_enhanced(model, train_loader, val_loader, criterion, optimizer, scheduler)
analyze_training(history)

Class Performance Analysis and Improvement Recommendations

In [None]:
import torch
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
from collections import defaultdict

class ClassPerformanceAnalyzer:
    def __init__(self, model, test_loader, classes, device):
        self.model = model
        self.test_loader = test_loader
        self.classes = classes
        self.device = device
        self.class_correct = defaultdict(int)
        self.class_total = defaultdict(int)
        self.predictions = []
        self.true_labels = []

    def analyze(self):
        """Run complete analysis of class performance"""
        self.model.eval()
        confusion_mat = None

        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                _, predicted = torch.max(outputs, 1)

                # Store predictions and true labels
                self.predictions.extend(predicted.cpu().numpy())
                self.true_labels.extend(labels.cpu().numpy())

                # Update per-class accuracy
                c = (predicted == labels).squeeze()
                for i in range(len(labels)):
                    label = labels[i]
                    self.class_correct[label] += c[i].item()
                    self.class_total[label] += 1

        return self._generate_performance_report()

    def _generate_performance_report(self):
        """Generate comprehensive performance report"""
        # Calculate per-class accuracy
        class_accuracy = {}
        for i in range(len(self.classes)):
            accuracy = 100 * self.class_correct[i] / self.class_total[i]
            class_accuracy[self.classes[i]] = accuracy

        # Sort classes by performance
        sorted_classes = sorted(class_accuracy.items(), key=lambda x: x[1], reverse=True)

        # Generate confusion matrix
        conf_matrix = confusion_matrix(self.true_labels, self.predictions)

        return {
            'class_accuracy': class_accuracy,
            'sorted_performance': sorted_classes,
            'confusion_matrix': conf_matrix,
            'classification_report': classification_report(
                self.true_labels,
                self.predictions,
                target_names=self.classes
            )
        }

    def plot_performance_analysis(self, results):
        """Create visualization of class performance"""
        fig = plt.figure(figsize=(20, 10))

        # 1. Class Accuracy Bar Plot
        plt.subplot(1, 2, 1)
        accuracies = [acc for _, acc in results['sorted_performance']]
        classes = [cls for cls, _ in results['sorted_performance']]
        colors = ['green' if acc >= 75 else 'orange' if acc >= 60 else 'red' for acc in accuracies]

        plt.bar(classes, accuracies, color=colors)
        plt.title('Per-Class Accuracy')
        plt.xlabel('Classes')
        plt.ylabel('Accuracy (%)')
        plt.xticks(rotation=45)

        # Add horizontal lines for reference
        plt.axhline(y=75, color='g', linestyle='--', alpha=0.3, label='Good (75%)')
        plt.axhline(y=60, color='orange', linestyle='--', alpha=0.3, label='Fair (60%)')
        plt.legend()

        # 2. Confusion Matrix Heatmap
        plt.subplot(1, 2, 2)
        sns.heatmap(results['confusion_matrix'],
                   annot=True,
                   fmt='d',
                   cmap='Blues',
                   xticklabels=self.classes,
                   yticklabels=self.classes)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')

        plt.tight_layout()
        return fig

    def generate_improvement_recommendations(self, results):
        """Generate specific recommendations for worst performing classes"""
        recommendations = {}
        for class_name, accuracy in results['sorted_performance'][-3:]:  # Bottom 3 classes
            conf_matrix = results['confusion_matrix']
            class_idx = self.classes.index(class_name)

            # Find most common misclassifications
            misclassifications = [
                (self.classes[i], conf_matrix[class_idx][i])
                for i in range(len(self.classes))
                if i != class_idx and conf_matrix[class_idx][i] > 0
            ]
            misclassifications.sort(key=lambda x: x[1], reverse=True)

            recommendations[class_name] = {
                'accuracy': accuracy,
                'common_confusions': misclassifications[:3],
                'specific_recommendations': self._get_specific_recommendations(
                    class_name,
                    misclassifications
                )
            }

        return recommendations

    def _get_specific_recommendations(self, class_name, confusions):
        """Generate specific recommendations based on confusion patterns"""
        recommendations = []

        # Data-related recommendations
        recommendations.append({
            'category': 'Data Enhancement',
            'suggestions': [
                f"Collect more diverse {class_name} images",
                f"Apply targeted augmentation for {class_name} characteristics",
                "Use stratified sampling to balance class representation"
            ]
        })

        # Model-related recommendations
        recommendations.append({
            'category': 'Model Adjustments',
            'suggestions': [
                "Implement class-weighted loss function",
                f"Add attention mechanisms to focus on {class_name} distinctive features",
                "Fine-tune model layers specifically for problematic classes"
            ]
        })

        # Training-related recommendations
        recommendations.append({
            'category': 'Training Strategy',
            'suggestions': [
                "Use curriculum learning starting with easy examples",
                "Implement hard negative mining",
                "Apply mixup augmentation for confused classes"
            ]
        })

        return recommendations

def main():
    # Example usage with your existing model and data
    analyzer = ClassPerformanceAnalyzer(model, test_loader, classes, device)
    results = analyzer.analyze()

    # Plot performance analysis
    analyzer.plot_performance_analysis(results)
    plt.show()

    # Get recommendations for worst performing classes
    recommendations = analyzer.generate_improvement_recommendations(results)

    # Print detailed analysis
    print("\nClass Performance Analysis:")
    print("-" * 50)
    for class_name, accuracy in results['sorted_performance']:
        print(f"{class_name:10s}: {accuracy:.2f}%")

    print("\nDetailed Recommendations for Worst Performing Classes:")
    print("-" * 50)
    for class_name, rec in recommendations.items():
        print(f"\nClass: {class_name} (Accuracy: {rec['accuracy']:.2f}%)")
        print("Common confusions:")
        for confused_class, count in rec['common_confusions']:
            print(f"- Often confused with {confused_class} ({count} instances)")

        print("\nImprovement Recommendations:")
        for category in rec['specific_recommendations']:
            print(f"\n{category['category']}:")
            for suggestion in category['suggestions']:
                print(f"- {suggestion}")

if __name__ == "__main__":
    main()