# CIFAR-10 Image Classification Analysis

This notebook demonstrates how to use the trained CIFAR-10 model for image classification and analysis.

In [None]:
# Import required libraries
import sys
import os
sys.path.append('../api')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from main import ImageClassifier

print("✅ Libraries imported successfully")

In [None]:
# Initialize the classifier and report its metadata
classifier = ImageClassifier()
model_info = classifier.get_model_info()

# One heading line, then one indented "key: value" line per metadata entry.
print(
    "Model Information:",
    *(f"  {field}: {info}" for field, info in model_info.items()),
    sep="\n",
)

In [None]:
# Load and display CIFAR-10 sample images
from torchvision import datasets, transforms

# CIFAR-10 test split; ToTensor turns each PIL image into a CHW tensor
# (hence the permute before plotting below).
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.CIFAR10(
    root='../cifar10_data',
    train=False,
    download=True,
    transform=transform,
)

# First ten test images in a 2x5 grid, each titled with its class name.
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
fig.suptitle('CIFAR-10 Sample Images', fontsize=16)

for idx, ax in enumerate(axes.flat):
    image, label = test_dataset[idx]
    # matplotlib's imshow wants channel-last (H, W, C) arrays.
    ax.imshow(image.permute(1, 2, 0).numpy())
    ax.set_title(f'{test_dataset.classes[label]}')
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Test model predictions on sample images
def test_prediction(image_idx, top_k=3):
    """Show one test image next to the classifier's top-k predictions.

    Args:
        image_idx: Index into the global ``test_dataset``.
        top_k: Number of predictions requested from ``classifier.predict``
            and plotted (default 3).

    Side effects:
        Draws a two-panel matplotlib figure (image | confidence bars) and
        prints the top-1 prediction, the true label and whether they match.
    """
    image, true_label = test_dataset[image_idx]
    true_class = test_dataset.classes[true_label]

    # classifier.predict is given a PIL image, so convert the dataset tensor back.
    image_pil = transforms.ToPILImage()(image)
    results = classifier.predict(image_pil, top_k=top_k)

    classes = [pred['class'] for pred in results['top_predictions']]
    confidences = [pred['confidence'] for pred in results['top_predictions']]

    fig, (ax_img, ax_pred) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: the image (tensor is CHW; matplotlib expects HWC).
    ax_img.imshow(image.permute(1, 2, 0))
    ax_img.set_title(f'True Label: {true_class}')
    ax_img.axis('off')

    # Right panel: confidence bars; the true class (if present) is green.
    bars = ax_pred.barh(classes, confidences)
    for bar, class_name in zip(bars, classes):
        bar.set_color('green' if class_name == true_class else 'lightblue')
    ax_pred.set_xlabel('Confidence')
    # Title reflects how many predictions actually came back.
    ax_pred.set_title(f'Top {len(classes)} Predictions')
    ax_pred.set_xlim(0, 1)
    # barh draws the first entry at the bottom; flip so the best guess is on top.
    ax_pred.invert_yaxis()

    fig.tight_layout()
    plt.show()

    print(f"Predicted: {results['predicted_class']} (confidence: {results['confidence']:.3f})")
    print(f"Actual: {true_class}")
    print(f"Correct: {'✅' if results['predicted_class'] == true_class else '❌'}")

# Spot-check every 10th test image among the first 50
for sample_idx in range(0, 50, 10):
    print(f"\n--- Testing Image {sample_idx} ---")
    test_prediction(sample_idx)

In [None]:
# Analyze model performance on a subset of test data
def analyze_performance(num_samples=100):
    """Measure top-1 accuracy of ``classifier`` on the first images of ``test_dataset``.

    Args:
        num_samples: How many images to evaluate, starting at index 0. Values
            larger than the dataset are clamped to its length; ``None``
            evaluates the whole test set.

    Returns:
        ``(overall_accuracy, class_correct, class_total)``: ``overall_accuracy``
        is a float (0.0 when nothing could be scored); the two dicts map every
        class name to its count of correct and scored images.

    Images whose prediction raises are reported, skipped and excluded from all
    counts, so the accuracy covers successfully scored images only.
    """
    dataset_size = len(test_dataset)
    n_eval = dataset_size if num_samples is None else min(num_samples, dataset_size)
    n_eval = max(n_eval, 0)

    correct = 0
    total = 0
    failed = 0
    class_correct = {class_name: 0 for class_name in test_dataset.classes}
    class_total = {class_name: 0 for class_name in test_dataset.classes}

    # Report the number actually evaluated, not the (possibly larger) request.
    print(f"Analyzing performance on {n_eval} test samples...")

    for i in range(n_eval):
        image, true_label = test_dataset[i]
        true_class = test_dataset.classes[true_label]
        image_pil = transforms.ToPILImage()(image)

        try:
            results = classifier.predict(image_pil)
            predicted_class = results['predicted_class']
        except Exception as e:
            # Best-effort evaluation: note the failure and keep going.
            print(f"Error processing image {i}: {e}")
            failed += 1
            continue

        total += 1
        class_total[true_class] += 1
        if predicted_class == true_class:
            correct += 1
            class_correct[true_class] += 1

    overall_accuracy = correct / total if total > 0 else 0.0
    print(f"\nOverall Accuracy: {overall_accuracy:.3f} ({correct}/{total})")
    if failed:
        # Skipped images are not in the denominator; surface that explicitly.
        print(f"Skipped {failed} image(s) whose prediction raised an error")

    print("\nPer-class Accuracy:")
    for class_name in test_dataset.classes:
        n_class = class_total[class_name]
        if n_class > 0:
            class_acc = class_correct[class_name] / n_class
            print(f"  {class_name}: {class_acc:.3f} ({class_correct[class_name]}/{n_class})")

    return overall_accuracy, class_correct, class_total

# Run performance analysis on the first N_EVAL_SAMPLES test images
N_EVAL_SAMPLES = 100
accuracy, class_correct, class_total = analyze_performance(num_samples=N_EVAL_SAMPLES)

## Conclusion

This notebook demonstrates:
1. ✅ Loading the pre-trained CIFAR-10 model
2. ✅ Making predictions on sample images
3. ✅ Visualizing predictions and measuring accuracy on a sample of the test set
4. ✅ Running in an isolated virtual environment with Jupyter

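The `ImageClassifier` used here is imported from the `main` module in `../api`, which backs the FastAPI service. Note that the accuracy above is estimated from only the first 100 test images; evaluate on the full test set (e.g. `analyze_performance(len(test_dataset))`) before treating the model as production-ready.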