# COVID-19 Image Classification - Model Evaluation

This notebook evaluates trained models on the test dataset and generates comprehensive performance metrics.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import json
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix

import config
from dataset import get_data_loaders
from models import get_model
from utils import (
    set_seed, get_device, load_checkpoint,
    plot_confusion_matrix, plot_roc_curves,
    calculate_metrics, save_metrics
)

# Fix the random seed using the value from the project config. set_seed is a
# project helper (utils); presumably it seeds the RNGs used by NumPy/PyTorch
# -- confirm in utils.set_seed.
set_seed(config.RANDOM_SEED)
# Device that evaluate_model moves the model and each image batch onto.
# NOTE(review): get_device presumably prefers a GPU when available -- confirm.
device = get_device()

print('Setup complete!')

## Load Data

In [None]:
# Build the data loaders from the dataset directory named in the config.
# get_data_loaders returns four values; of these, only test_loader is used
# by the cells visible in this notebook (inside evaluate_model).
# NOTE(review): batch_size is hard-coded to 32 here -- confirm whether the
# project config defines a batch size that should be used instead.
train_loader, val_loader, test_loader, class_weights = get_data_loaders(
    config.DATASET_DIR,
    batch_size=32
)

# Human-readable class names, used below for report rows, axis tick labels
# and plot legends.
class_names = config.CLASS_NAMES
print(f'Class names: {class_names}')

## Evaluate Models

Load and evaluate each trained model.

In [None]:
def evaluate_model(model_name, checkpoint_path):
    """Evaluate a single trained model on the test set.

    Builds the architecture named ``model_name``, restores its weights from
    ``checkpoint_path`` via ``load_checkpoint`` and runs it over the
    module-level ``test_loader`` with gradients disabled.

    Args:
        model_name: Architecture identifier passed to ``get_model``.
        checkpoint_path: Path of the checkpoint passed to ``load_checkpoint``.

    Returns:
        dict with the keys:
            'model_name': the ``model_name`` argument.
            'accuracy': test accuracy in percent (0-100).
            'labels': NumPy array of the true labels.
            'predictions': NumPy array of the predicted class indices.
            'probabilities': NumPy array of per-class softmax scores,
                one row per sample.
            'checkpoint': whatever ``load_checkpoint`` returned.

    Raises:
        ValueError: If ``test_loader`` yields no samples, because accuracy
            is undefined in that case.
    """
    print(f"\n{'='*80}")
    print(f"Evaluating {model_name.upper()}")
    print(f"{'='*80}")

    # Load model
    model = get_model(model_name, num_classes=config.NUM_CLASSES, pretrained=False)
    checkpoint = load_checkpoint(checkpoint_path, model)
    model = model.to(device)
    model.eval()

    # Collect whole batches and concatenate once at the end instead of
    # extending Python lists one sample at a time.
    label_batches = []
    prediction_batches = []
    probability_batches = []
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            probabilities = torch.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)

            total += labels.size(0)
            label_batches.append(labels.cpu())
            prediction_batches.append(predicted.cpu())
            probability_batches.append(probabilities.cpu())

    if total == 0:
        # Fail with a clear message rather than a bare ZeroDivisionError.
        raise ValueError(
            f"Test loader produced no samples; cannot evaluate {model_name}"
        )

    all_labels = torch.cat(label_batches).numpy()
    all_predictions = torch.cat(prediction_batches).numpy()
    all_probabilities = torch.cat(probability_batches).numpy()

    correct = int((all_predictions == all_labels).sum())
    accuracy = 100. * correct / total

    return {
        'model_name': model_name,
        'accuracy': accuracy,
        'labels': all_labels,
        'predictions': all_predictions,
        'probabilities': all_probabilities,
        'checkpoint': checkpoint,
    }

In [None]:
# Run the evaluation for every architecture that has a saved "best"
# checkpoint; results are keyed by model name.
results = {}

# (model name, checkpoint path) pairs; every checkpoint lives in
# config.CHECKPOINT_DIR.
models_to_evaluate = [
    (name, os.path.join(config.CHECKPOINT_DIR, filename))
    for name, filename in (
        ('vgg16', 'vgg16_best.pth'),
        ('resnet50', 'resnet50_best.pth'),
        ('densenet121', 'densenet121_best.pth'),
    )
]

for model_name, checkpoint_path in models_to_evaluate:
    # A missing checkpoint is reported and skipped, not treated as an error.
    if not os.path.exists(checkpoint_path):
        print(f"Warning: Checkpoint not found for {model_name} at {checkpoint_path}")
        continue
    results[model_name] = evaluate_model(model_name, checkpoint_path)

## Results Summary

In [None]:
# Create summary table: one row per evaluated model with its validation
# accuracy (read from the checkpoint) and its test accuracy.
summary_data = []

for model_name, result in results.items():
    # Fall back to 'N/A' when the checkpoint carries no 'best_val_acc' entry.
    val_acc = result['checkpoint'].get('best_val_acc', 'N/A')
    test_acc = result['accuracy']

    # Format any plain number (int or float) to two decimals; anything else
    # (e.g. the 'N/A' placeholder) is shown unchanged.
    if isinstance(val_acc, (int, float)):
        val_acc_text = f"{val_acc:.2f}"
    else:
        val_acc_text = val_acc

    summary_data.append({
        'Model': model_name.upper(),
        'Validation Acc (%)': val_acc_text,
        'Test Acc (%)': f"{test_acc:.2f}"
    })

summary_df = pd.DataFrame(summary_data)
print("\nModel Performance Summary:")
print(summary_df.to_string(index=False))

# Save to CSV. This is the first write into the results directory, so make
# sure it exists before writing.
os.makedirs(config.RESULTS_DIR, exist_ok=True)
summary_path = os.path.join(config.RESULTS_DIR, 'model_summary.csv')
summary_df.to_csv(summary_path, index=False)
print(f"\nSummary saved to {summary_path}")

## Detailed Metrics for Each Model

In [None]:
# Print a per-class classification report for every evaluated model and
# store its metrics as JSON in the results directory.
for model_name in results:
    result = results[model_name]
    y_true, y_pred = result['labels'], result['predictions']

    print(f"\n{'='*80}")
    print(f"{model_name.upper()} - Detailed Classification Report")
    print(f"{'='*80}")

    # Text report with four decimal places, rows named after class_names.
    report = classification_report(y_true, y_pred, target_names=class_names, digits=4)
    print(report)

    # Metrics from the project helper, extended with the test accuracy
    # measured in evaluate_model, written to "<model_name>_metrics.json".
    metrics = calculate_metrics(y_true, y_pred, class_names)
    metrics['test_accuracy'] = result['accuracy']
    metrics_path = os.path.join(config.RESULTS_DIR, f"{model_name}_metrics.json")
    save_metrics(metrics, metrics_path)

## Confusion Matrices

In [None]:
# Plot confusion matrices for all models, one subplot per model.
if not results:
    # plt.subplots(1, 0) raises, so skip plotting when nothing was evaluated.
    print("No evaluated models available - skipping confusion matrices.")
else:
    n_models = len(results)
    # Fix the label set so the matrix always has one row/column per entry in
    # class_names, even if a class never occurs in the test labels or
    # predictions. Labels are the integer indices 0..len(class_names)-1,
    # the same convention the ROC cell uses.
    label_ids = list(range(len(class_names)))

    # 6 inches per model (18 for three models); squeeze=False keeps the
    # axes indexable even when there is only one model.
    fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 5), squeeze=False)
    axes = axes.ravel()

    for idx, (model_name, result) in enumerate(results.items()):
        cm = confusion_matrix(result['labels'], result['predictions'], labels=label_ids)

        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names,
                    ax=axes[idx], cbar_kws={'label': 'Count'})
        axes[idx].set_title(f'{model_name.upper()}\nAccuracy: {result["accuracy"]:.2f}%',
                            fontsize=14, fontweight='bold')
        axes[idx].set_ylabel('True Label', fontsize=12)
        axes[idx].set_xlabel('Predicted Label', fontsize=12)

    plt.tight_layout()
    plt.savefig(os.path.join(config.RESULTS_DIR, 'confusion_matrices.png'), dpi=300, bbox_inches='tight')
    plt.show()

## ROC Curves

In [None]:
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

# One-vs-rest ROC curves: one figure per model, one curve per class.
# The first three colours match the original palette; the list is indexed
# modulo its length so classes beyond the palette size still get a curve
# instead of being silently dropped.
colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'magenta', 'cyan']
n_classes = len(class_names)

for model_name, result in results.items():
    print(f"\nGenerating ROC curves for {model_name.upper()}...")

    # Binarize labels into one indicator column per class.
    y_true_bin = label_binarize(result['labels'], classes=list(range(n_classes)))
    if n_classes == 2:
        # For two classes label_binarize returns a single column (the
        # indicator of class 1); expand it so column i belongs to class i.
        y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
    y_probs = result['probabilities']

    plt.figure(figsize=(10, 8))

    for i, class_name in enumerate(class_names):
        color = colors[i % len(colors)]
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
        roc_auc = auc(fpr, tpr)

        plt.plot(fpr, tpr, color=color, linewidth=2,
                 label=f'{class_name} (AUC = {roc_auc:.3f})')

    # Diagonal reference line: performance of a random classifier.
    plt.plot([0, 1], [0, 1], 'k--', linewidth=2, label='Random Classifier')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title(f'{model_name.upper()} - ROC Curves', fontsize=14, fontweight='bold')
    plt.legend(loc='lower right', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plt.savefig(os.path.join(config.RESULTS_DIR, f'{model_name}_roc_curves.png'),
                dpi=300, bbox_inches='tight')
    plt.show()

## Per-Class Performance Comparison

In [None]:
# Compare per-class precision, recall and F1-score across models as grouped
# bar charts (one subplot per metric, one bar group per class).
metrics_to_plot = ['precision', 'recall', 'f1-score']
titles = ['Precision', 'Recall', 'F1-Score']

# Compute each model's metrics once, rather than once per plotted metric.
per_model_metrics = {
    model_name: calculate_metrics(result['labels'], result['predictions'], class_names)
    for model_name, result in results.items()
}

fig, axes = plt.subplots(1, len(metrics_to_plot), figsize=(18, 5))

x = np.arange(len(class_names))
# Share a fixed 0.75-wide slot per class among the models (0.25 per bar for
# three models) and centre the tick under each group, so the layout stays
# correct for any number of evaluated models.
n_models = max(len(per_model_metrics), 1)
width = 0.75 / n_models
tick_offset = width * (n_models - 1) / 2

for ax, metric, title in zip(axes, metrics_to_plot, titles):
    for idx, (model_name, model_metrics) in enumerate(per_model_metrics.items()):
        scores = [
            model_metrics['per_class_metrics'][class_name][metric]
            for class_name in class_names
        ]
        ax.bar(x + idx * width, scores, width, label=model_name.upper())

    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel(title, fontsize=12)
    ax.set_title(f'Per-Class {title}', fontsize=14, fontweight='bold')
    ax.set_xticks(x + tick_offset)
    ax.set_xticklabels(class_names, rotation=15, ha='right')
    if per_model_metrics:
        # Only draw a legend when there is at least one labelled bar series.
        ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim([0, 1.1])

plt.tight_layout()
plt.savefig(os.path.join(config.RESULTS_DIR, 'per_class_metrics.png'), dpi=300, bbox_inches='tight')
plt.show()

## Conclusion

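The evaluation is complete. The summary CSV, the per-model metrics JSON files, and the figures (confusion matrices, ROC curves, per-class metrics) are saved in the directory configured as `config.RESULTS_DIR`.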