# Facial Expression Recognition (FER) Model Training
## Using FER 2013 Dataset

This notebook builds and trains a Convolutional Neural Network (CNN) for facial expression recognition using the FER 2013 dataset. It is set up to run on Google Colab with GPU acceleration and falls back to the CPU when no GPU is available.

### Emotions Recognized:
- 0: Angry 😠
- 1: Disgust 🤢
- 2: Fear 😨
- 3: Happy 😊
- 4: Sad 😢
- 5: Surprise 😲
- 6: Neutral 😐

## 1. Setup and Environment Check

In [None]:
# Check if running on Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("✅ Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("❌ Not running on Google Colab")

# Check GPU availability
import torch
if torch.cuda.is_available():
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    print("⚠️ GPU not available, using CPU")
    device = torch.device('cpu')

print(f"Using device: {device}")

## 2. Install Required Packages

In [None]:
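# Install required packages (most of these are already preinstalled on Colab)
# PyTorch is preinstalled on Colab. Reinstalling it after `import torch` in section 1
# requires a runtime restart, so only uncomment the next line for a local CUDA 11.8 setup:
# !pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install opencv-python-headless matplotlib seaborn scikit-learn pandas numpy pillow tqdm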

## 3. Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from tqdm import tqdm
import os
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✅ All libraries imported successfully")

## 4. Download and Setup FER 2013 Dataset

In [None]:
# Download FER 2013 dataset from Kaggle
if IN_COLAB:
    # Optional: mount Google Drive if you keep the dataset or checkpoints there
    # from google.colab import drive
    # drive.mount('/content/drive')
    
    # Install Kaggle API
    !pip install kaggle
    
    # Credentials: upload kaggle.json and copy it to ~/.kaggle/kaggle.json (chmod 600)
    print("Please upload your kaggle.json file to access the FER 2013 dataset")
    print("Or download the dataset manually from: https://www.kaggle.com/datasets/msambare/fer2013")
    
    # Create directories
    os.makedirs('/content/data', exist_ok=True)
    
    # Uncomment the following lines after setting up Kaggle credentials.
    # The archive contains train/ and test/ folders, so extract it into fer2013/
    # to match dataset_path below:
    # !kaggle datasets download -d msambare/fer2013 -p /content/data
    # !unzip -q /content/data/fer2013.zip -d /content/data/fer2013
    
    dataset_path = '/content/data/fer2013'
else:
    # Local setup: dataset_path should contain train/<emotion>/ and test/<emotion>/ folders
    dataset_path = './data/fer2013'
    os.makedirs(dataset_path, exist_ok=True)
    print(f"Please download FER 2013 dataset to: {dataset_path}")

print(f"Dataset path: {dataset_path}")
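Before building the datasets, it can help to confirm that the folder layout matches what `FER2013Dataset` (section 5) expects: `<split>/<emotion>/<image>`. The sketch below uses a hypothetical helper, `count_images_per_class`, that counts image files per emotion folder and skips splits that do not exist.

```python
import os

EMOTIONS = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
IMAGE_EXTS = ('.png', '.jpg', '.jpeg')


def count_images_per_class(root, splits=('train', 'validation', 'test')):
    """Return {split: {emotion: n_images}} for every split folder found under root."""
    counts = {}
    for split in splits:
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            continue  # e.g. the Kaggle release has no validation/ folder
        counts[split] = {}
        for emotion in EMOTIONS:
            emotion_dir = os.path.join(split_dir, emotion)
            if os.path.isdir(emotion_dir):
                # Same extension filter as FER2013Dataset
                n = sum(1 for f in os.listdir(emotion_dir) if f.lower().endswith(IMAGE_EXTS))
            else:
                n = 0
            counts[split][emotion] = n
    return counts
```

For example, `count_images_per_class(dataset_path)` should report non-zero counts for `train` and `test`; an empty result usually means the archive was extracted one directory level too deep or too shallow.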

## 5. Custom Dataset Class for FER 2013

In [None]:
class FER2013Dataset(Dataset):
    def __init__(self, data_dir, split='train', transform=None):
        """
        FER 2013 Dataset class
        
        Args:
            data_dir (str): Path to the dataset directory
            split (str): 'train', 'test', or 'validation'
            transform: Data transformations to apply
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        
        # Emotion labels
        self.emotion_labels = {
            'angry': 0, 'disgust': 1, 'fear': 2, 'happy': 3,
            'sad': 4, 'surprise': 5, 'neutral': 6
        }
        
        self.label_to_emotion = {v: k for k, v in self.emotion_labels.items()}
        
        # Load image paths and labels
        self.images = []
        self.labels = []
        
        split_dir = os.path.join(data_dir, split)
        
        if os.path.exists(split_dir):
            for emotion in self.emotion_labels.keys():
                emotion_dir = os.path.join(split_dir, emotion)
                if os.path.exists(emotion_dir):
                    for img_file in sorted(os.listdir(emotion_dir)):  # sorted for a reproducible order
                        if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                            self.images.append(os.path.join(emotion_dir, img_file))
                            self.labels.append(self.emotion_labels[emotion])
        
        print(f"Loaded {len(self.images)} images for {split} split")
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        
        # Load image
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        
        if image is None:
            # Create a dummy image if loading fails
            image = np.zeros((48, 48), dtype=np.uint8)
        
        # Resize to 48x48 if needed
        if image.shape != (48, 48):
            image = cv2.resize(image, (48, 48))
        
        # Convert to PIL Image for transforms
        image = Image.fromarray(image)
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

print("✅ FER2013Dataset class defined")
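FER 2013 is class-imbalanced (Disgust in particular has far fewer images than the other emotions), and `FER2013Dataset` keeps its integer labels in `self.labels`, so class frequencies are easy to inspect. Below is a minimal sketch of inverse-frequency class weights, assuming the "balanced" heuristic that scikit-learn's `compute_class_weight` also uses (weight_k = N / (K * count_k)); `balanced_class_weights` is a hypothetical helper, not part of this notebook.

```python
import numpy as np


def balanced_class_weights(labels, num_classes=7):
    """Inverse-frequency weights N / (num_classes * count_k), like scikit-learn's 'balanced' mode."""
    labels = np.asarray(labels, dtype=np.int64)
    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    counts = np.maximum(counts, 1.0)  # avoid division by zero for classes with no samples
    return len(labels) / (num_classes * counts)


print(balanced_class_weights([0, 0, 0, 1], num_classes=2))  # the rare class gets the larger weight
```

To try it, compute the weights from a training-split dataset, e.g. `weights = balanced_class_weights(FER2013Dataset(dataset_path, split='train').labels)`, and pass them to the loss with `nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32, device=device))`. Whether this helps here is an empirical question: it can trade some overall accuracy for better recall on the rare classes.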

## 6. Data Transformations and Augmentation

In [None]:
# Define data transformations
train_transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

val_test_transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

print("✅ Data transformations defined")
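`ToTensor()` scales 8-bit pixels to [0, 1], and `Normalize(mean=[0.5], std=[0.5])` then computes (x - 0.5) / 0.5, which maps [0, 1] onto [-1, 1]; the visualization code later inverts this with (x + 1) / 2. The NumPy sketch below reproduces both steps on a raw array (the function names are illustrative, not torchvision APIs).

```python
import numpy as np


def to_model_range(pixels_uint8):
    """Mirror ToTensor() + Normalize(mean=[0.5], std=[0.5]) for a uint8 grayscale array."""
    x = pixels_uint8.astype(np.float32) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    return (x - 0.5) / 0.5                       # Normalize: [0, 1] -> [-1, 1]


def to_display_range(x):
    """Undo the normalization for plotting: [-1, 1] -> [0, 1], same as (x + 1) / 2."""
    return x * 0.5 + 0.5


pixels = np.array([[0, 128], [64, 255]], dtype=np.uint8)
normalized = to_model_range(pixels)
print(normalized.min(), normalized.max())  # -1.0 1.0
```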

## 7. Load Datasets

In [None]:
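# Create datasets
import copy
from torch.utils.data import Subset

train_dataset = FER2013Dataset(dataset_path, split='train', transform=train_transform)
val_dataset = FER2013Dataset(dataset_path, split='validation', transform=val_test_transform)
test_dataset = FER2013Dataset(dataset_path, split='test', transform=val_test_transform)

if len(train_dataset) > 0 and len(val_dataset) == 0:
    # The Kaggle release (msambare/fer2013) ships only train/ and test/ folders,
    # so hold out 10% of the training images for validation. The validation view
    # shares the same file list but uses the non-augmented transform.
    val_view = copy.copy(train_dataset)
    val_view.transform = val_test_transform
    perm = torch.randperm(len(train_dataset), generator=torch.Generator().manual_seed(42)).tolist()
    n_val = int(0.1 * len(perm))
    val_dataset = Subset(val_view, perm[:n_val])
    train_dataset = Subset(train_dataset, perm[n_val:])

if len(train_dataset) == 0 or len(test_dataset) == 0:
    # FER2013Dataset does not raise on a missing folder; it simply finds no images
    print("No images found. Please make sure the FER 2013 dataset is properly downloaded and extracted.")
    print("Creating dummy datasets for demonstration...")
    
    class DummyDataset(Dataset):
        """Random 1x48x48 tensors with random labels, for smoke-testing the pipeline only."""
        def __init__(self, size=1000):
            self.size = size
        
        def __len__(self):
            return self.size
        
        def __getitem__(self, idx):
            # Already a tensor, so no transform is applied
            image = torch.randn(1, 48, 48)
            label = torch.randint(0, 7, (1,)).item()
            return image, label
    
    train_dataset = DummyDataset(size=20000)
    val_dataset = DummyDataset(size=3000)
    test_dataset = DummyDataset(size=3000)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")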
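Because FER 2013 is class-imbalanced, a purely random hold-out can leave the rare classes under- or over-represented in the validation set. A stratified split draws the same fraction from every class. Below is a minimal NumPy sketch with `stratified_split` as a hypothetical helper; scikit-learn's `train_test_split(..., stratify=labels)` implements the same idea.

```python
import numpy as np


def stratified_split(labels, val_fraction=0.1, seed=42):
    """Split sample indices so each class keeps roughly the same share in train and validation."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for cls in np.unique(labels):
        cls_idx = np.flatnonzero(labels == cls)  # indices of this class
        rng.shuffle(cls_idx)
        n_val = int(round(val_fraction * len(cls_idx)))
        val_idx.extend(cls_idx[:n_val].tolist())
        train_idx.extend(cls_idx[n_val:].tolist())
    return sorted(train_idx), sorted(val_idx)
```

The returned index lists can be wrapped with `torch.utils.data.Subset`, e.g. `Subset(dataset, train_idx)`, using the labels stored in `FER2013Dataset.labels`.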

## 8. Create Data Loaders

In [None]:
# Use larger batches on GPU; smaller batches keep CPU runs manageable
BATCH_SIZE = 64 if device.type == 'cuda' else 32
NUM_WORKERS = 2 if IN_COLAB else 0

# Create data loaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=NUM_WORKERS,
    pin_memory=True if device.type == 'cuda' else False
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=NUM_WORKERS,
    pin_memory=True if device.type == 'cuda' else False
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=NUM_WORKERS,
    pin_memory=True if device.type == 'cuda' else False
)

print(f"✅ Data loaders created with batch size: {BATCH_SIZE}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")
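`torch.manual_seed(42)` in section 3 fixes the global RNG, but the shuffle order and any NumPy or `random` calls inside worker processes can still vary between runs. PyTorch's reproducibility notes recommend passing a seeded `torch.Generator` and a `worker_init_fn` to the `DataLoader`; the sketch below follows that pattern, with `make_reproducible_loader` as a hypothetical wrapper.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # PyTorch gives each worker a distinct base seed; reuse it for NumPy and `random`
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_reproducible_loader(dataset, batch_size, num_workers=0, seed=42, **kwargs):
    """Hypothetical wrapper: a shuffling DataLoader whose order depends only on `seed`."""
    g = torch.Generator()
    g.manual_seed(seed)  # drives the shuffle order and the workers' base seeds
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, worker_init_fn=seed_worker,
                      generator=g, **kwargs)


toy = TensorDataset(torch.arange(100))
first = next(iter(make_reproducible_loader(toy, batch_size=10)))[0]
again = next(iter(make_reproducible_loader(toy, batch_size=10)))[0]
print(torch.equal(first, again))  # True
```

GPU kernels (for example some cuDNN convolutions) can still be non-deterministic, so bit-identical CUDA runs need extra settings such as `torch.backends.cudnn.deterministic = True`.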

## 9. Visualize Sample Data

In [None]:
# Emotion labels for visualization
emotion_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
emotion_emojis = ['😠', '🤢', '😨', '😊', '😢', '😲', '😐']

# Visualize some sample images
def visualize_samples(data_loader, num_samples=8):
    data_iter = iter(data_loader)
    images, labels = next(data_iter)
    
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    axes = axes.ravel()
    
    # The 2x4 grid holds at most 8 images
    for i in range(min(num_samples, len(images), len(axes))):
        # Convert tensor to numpy and denormalize
        img = images[i].squeeze().numpy()
        img = (img + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
        label = labels[i].item()
        
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f'{emotion_names[label]} {emotion_emojis[label]}')
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()

print("Sample images from training set:")
visualize_samples(train_loader)

## 10. Define CNN Model Architecture

In [None]:
class EmotionCNN(nn.Module):
    def __init__(self, num_classes=7, dropout_rate=0.5):
        super(EmotionCNN, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        
        # Pooling
        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Dropout
        self.dropout = nn.Dropout(dropout_rate)
        
        # Fully connected layers
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        
    def forward(self, x):
        # First conv block
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        
        # Second conv block
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        
        # Global average pooling
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x

# Create model instance
model = EmotionCNN(num_classes=7, dropout_rate=0.5)
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✅ Model created and moved to {device}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Model summary
print("\nModel Architecture:")
print(model)
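Each conv layer uses a 3x3 kernel with `padding=1`, so it preserves the spatial size, and each `MaxPool2d(2, 2)` halves it: 48 → 24 → 12 → 6 → 3, after which `AdaptiveAvgPool2d((1, 1))` collapses the 256 x 3 x 3 map to the 256-dim vector that `fc1` expects. The pure-Python sketch below traces those shapes and counts learnable parameters by hand (conv: in·out·k² + out; BatchNorm: 2·channels; linear: in·out + out). For the architecture as defined above, it should agree with `total_params`.

```python
def conv_params(c_in, c_out, k=3):
    return c_in * c_out * k * k + c_out    # kernel weights + bias


def bn_params(c):
    return 2 * c                           # learnable gamma and beta; running stats are buffers


def linear_params(n_in, n_out):
    return n_in * n_out + n_out            # weight matrix + bias


channels = [1, 32, 64, 128, 256]
size = 48
total = 0
for c_in, c_out in zip(channels[:-1], channels[1:]):
    total += conv_params(c_in, c_out) + bn_params(c_out)
    size //= 2                             # conv (padding=1) keeps the size, MaxPool2d(2, 2) halves it
    print(f"conv {c_in}->{c_out}: feature map {c_out} x {size} x {size}")

# AdaptiveAvgPool2d((1, 1)) turns 256 x 3 x 3 into a 256-dim vector (no parameters)
for n_in, n_out in [(256, 512), (512, 256), (256, 7)]:
    total += linear_params(n_in, n_out)

print(f"Total parameters: {total:,}")     # Total parameters: 653,511
```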

## 11. Define Loss Function and Optimizer

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5  # `verbose` is deprecated; the training loop prints the LR
)

print("✅ Loss function, optimizer, and scheduler defined")
print(f"Loss function: {criterion}")
print(f"Optimizer: {optimizer}")
print(f"Scheduler: {scheduler}")
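`ReduceLROnPlateau` with `patience=5` lowers the learning rate only after the monitored loss has failed to improve for more than 5 consecutive epochs (on the 6th flat epoch), and `factor=0.5` halves it. With the early-stopping patience of 10 used in section 13, this leaves room for an LR reduction before early stopping triggers. A small sketch that feeds the scheduler a flat loss curve to show the timing:

```python
import torch

# A single dummy parameter is enough to drive an optimizer and scheduler
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=5)

lrs = []
for val_loss in [1.0] + [1.0] * 6:  # one improvement (vs. the initial +inf), then six flat epochs
    sched.step(val_loss)
    lrs.append(opt.param_groups[0]['lr'])

print(lrs)  # [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.0005]
```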

## 12. Training Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train the model for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    progress_bar = tqdm(train_loader, desc='Training')
    
    for batch_idx, (data, target) in enumerate(progress_bar):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        # Update progress bar
        progress_bar.set_postfix({
            'Loss': f'{running_loss/(batch_idx+1):.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })
    
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc

def validate_epoch(model, val_loader, criterion, device):
    """Validate the model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc='Validation')
        
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(device), target.to(device)
            
            output = model(data)
            loss = criterion(output, target)
            
            running_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            # Update progress bar
            progress_bar.set_postfix({
                'Loss': f'{running_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })
    
    epoch_loss = running_loss / len(val_loader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc

print("✅ Training and validation functions defined")
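`train_epoch` and `validate_epoch` report `running_loss / len(loader)`, the mean of per-batch mean losses. When the final batch is smaller than `BATCH_SIZE`, it gets the same weight as a full batch. The difference is small here, but a per-sample mean is the more exact epoch loss; the sketch below contrasts the two with made-up numbers.

```python
def mean_of_batch_means(batch_losses):
    """What running_loss / len(loader) computes: every batch counts equally."""
    return sum(batch_losses) / len(batch_losses)


def per_sample_mean(batch_losses, batch_sizes):
    """CrossEntropyLoss returns a per-batch mean, so re-weight each batch by its size."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Made-up example: one full batch of 64 and a final partial batch of 16
losses, sizes = [1.0, 2.0], [64, 16]
print(mean_of_batch_means(losses))         # 1.5
print(per_sample_mean(losses, sizes))      # 1.2
```

In the training functions, this would amount to accumulating `loss.item() * target.size(0)` and dividing by `total` instead of `len(loader)`.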

## 13. Training Loop

In [None]:
# Training parameters
NUM_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 10
BEST_VAL_LOSS = float('inf')
PATIENCE_COUNTER = 0

# Lists to store training history
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

print(f"Starting training for {NUM_EPOCHS} epochs...")
print(f"Device: {device}")
print(f"Batch size: {BATCH_SIZE}")
print("-" * 60)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 40)
    
    # Training phase
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validation phase
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
    
    # Update learning rate scheduler
    scheduler.step(val_loss)
    
    # Store metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    
    # Print epoch results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Early stopping and model saving
    if val_loss < BEST_VAL_LOSS:
        BEST_VAL_LOSS = val_loss
        PATIENCE_COUNTER = 0
        
        # Save best model
        if IN_COLAB:
            torch.save(model.state_dict(), '/content/best_fer_model.pth')
        else:
            torch.save(model.state_dict(), 'best_fer_model.pth')
        
        print(f"✅ New best model saved! Val Loss: {val_loss:.4f}")
    else:
        PATIENCE_COUNTER += 1
        print(f"⏳ No improvement. Patience: {PATIENCE_COUNTER}/{EARLY_STOPPING_PATIENCE}")
    
    # Early stopping
    if PATIENCE_COUNTER >= EARLY_STOPPING_PATIENCE:
        print(f"\n🛑 Early stopping triggered after {epoch+1} epochs")
        break

print("\n🎉 Training completed!")
print(f"Best validation loss: {BEST_VAL_LOSS:.4f}")

## 14. Plot Training History

In [None]:
# Plot training history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Plot loss
ax1.plot(train_losses, label='Training Loss', color='blue')
ax1.plot(val_losses, label='Validation Loss', color='red')
ax1.set_title('Model Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)

# Plot accuracy
ax2.plot(train_accuracies, label='Training Accuracy', color='blue')
ax2.plot(val_accuracies, label='Validation Accuracy', color='red')
ax2.set_title('Model Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

# Print final metrics
print(f"Final Training Accuracy: {train_accuracies[-1]:.2f}%")
print(f"Final Validation Accuracy: {val_accuracies[-1]:.2f}%")
print(f"Best Validation Loss: {min(val_losses):.4f}")
print(f"Best Validation Accuracy: {max(val_accuracies):.2f}%")
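The "final" accuracies printed above come from the last epoch, while the checkpoint saved during training (and reloaded in section 15) is the epoch with the lowest validation loss, and `max(val_accuracies)` may come from yet another epoch. A small sketch that reports the metrics of the checkpointed epoch from the history lists; `summarize_best_epoch` is a hypothetical helper.

```python
def summarize_best_epoch(val_losses, val_accuracies, train_accuracies):
    """Metrics of the epoch with the lowest validation loss (the one whose weights were saved)."""
    # min() returns the first minimum, matching the strict `<` used when saving checkpoints
    best = min(range(len(val_losses)), key=val_losses.__getitem__)
    return {
        'epoch': best + 1,  # 1-based, as in the training log
        'val_loss': val_losses[best],
        'val_acc': val_accuracies[best],
        'train_acc': train_accuracies[best],
    }


# Made-up history: the lowest loss (epoch 2) is not the highest accuracy (epoch 3)
print(summarize_best_epoch([1.0, 0.8, 0.9], [50.0, 55.0, 57.0], [52.0, 60.0, 65.0]))
```

In the notebook this would be called as `summarize_best_epoch(val_losses, val_accuracies, train_accuracies)`.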

## 15. Load Best Model and Test

In [None]:
# Load the best model
checkpoint_path = '/content/best_fer_model.pth' if IN_COLAB else 'best_fer_model.pth'
try:
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("✅ Best model loaded successfully")
except (FileNotFoundError, RuntimeError) as e:
    print(f"⚠️ Could not load saved model ({e}), using current model")

# Test the model
def test_model(model, test_loader, device):
    model.eval()
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc='Testing')
        
        for data, target in progress_bar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    
    return np.array(all_predictions), np.array(all_targets)

# Run test
print("Testing the model...")
test_predictions, test_targets = test_model(model, test_loader, device)

# Calculate test accuracy
test_accuracy = accuracy_score(test_targets, test_predictions)
print(f"\n🎯 Test Accuracy: {test_accuracy*100:.2f}%")

## 16. Detailed Evaluation Metrics

In [None]:
# Classification report
print("📊 Classification Report:")
print("-" * 50)
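class_report = classification_report(
    test_targets,
    test_predictions,
    labels=list(range(len(emotion_names))),  # report all 7 classes even if one is absent from the test set
    target_names=emotion_names,
    digits=4,
    zero_division=0
)
print(class_report)

# Confusion Matrix (rows = true labels, columns = predictions)
cm = confusion_matrix(test_targets, test_predictions, labels=list(range(len(emotion_names))))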

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm, 
    annot=True, 
    fmt='d', 
    cmap='Blues',
    xticklabels=[f'{name}\n{emoji}' for name, emoji in zip(emotion_names, emotion_emojis)],
    yticklabels=[f'{name}\n{emoji}' for name, emoji in zip(emotion_names, emotion_emojis)]
)
plt.title('Confusion Matrix - Facial Expression Recognition')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

# Per-class accuracy
print("\n📈 Per-class Accuracy:")
print("-" * 30)
for i, emotion in enumerate(emotion_names):
    class_mask = test_targets == i
    if np.sum(class_mask) > 0:
        class_acc = np.sum(test_predictions[class_mask] == i) / np.sum(class_mask)
        print(f"{emotion} {emotion_emojis[i]}: {class_acc*100:.2f}%")
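In scikit-learn's `confusion_matrix`, rows are true labels and columns are predictions, so the per-class accuracy computed above is the diagonal divided by the row sums, which is also the per-class recall in the classification report. A vectorized NumPy sketch (hypothetical `per_class_accuracy`) that returns NaN for classes absent from the test set:

```python
import numpy as np


def per_class_accuracy(cm):
    """Diagonal over row sums: the fraction of each true class predicted correctly (recall)."""
    cm = np.asarray(cm, dtype=np.float64)
    support = cm.sum(axis=1)  # number of test samples per true class
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(support > 0, np.diag(cm) / support, np.nan)
```

Applied to the matrix above, `per_class_accuracy(cm)` should match the printed per-class percentages (divided by 100).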

## 17. Sample Predictions Visualization

In [None]:
def visualize_predictions(model, test_loader, device, num_samples=12):
    """Visualize sample predictions"""
    model.eval()
    
    # Get a batch of test data
    data_iter = iter(test_loader)
    images, labels = next(data_iter)
    images, labels = images.to(device), labels.to(device)
    
    with torch.no_grad():
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)
        probabilities = F.softmax(outputs, dim=1)
    
    # Move back to CPU for visualization
    images = images.cpu()
    labels = labels.cpu()
    predictions = predictions.cpu()
    probabilities = probabilities.cpu()
    
    # Create subplot
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.ravel()
    
    for i in range(min(num_samples, len(images))):
        # Denormalize image
        img = images[i].squeeze().numpy()
        img = (img + 1) / 2  # Convert from [-1, 1] to [0, 1]
        
        # Get prediction info
        true_label = labels[i].item()
        pred_label = predictions[i].item()
        confidence = probabilities[i][pred_label].item()
        
        # Set title color based on correctness
        color = 'green' if true_label == pred_label else 'red'
        
        # Plot image
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(
            f'True: {emotion_names[true_label]} {emotion_emojis[true_label]}\n'
            f'Pred: {emotion_names[pred_label]} {emotion_emojis[pred_label]}\n'
            f'Conf: {confidence:.2f}',
            color=color,
            fontsize=10
        )
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()

print("Sample predictions (Green = Correct, Red = Incorrect):")
visualize_predictions(model, test_loader, device)

## 18. Save Final Model and Results

In [None]:
# Save the final model
if IN_COLAB:
    model_path = '/content/fer2013_final_model.pth'
    results_path = '/content/training_results.json'
else:
    model_path = 'fer2013_final_model.pth'
    results_path = 'training_results.json'

# Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'test_accuracy': float(test_accuracy),  # plain float keeps the file loadable with torch.load(weights_only=True)
    'best_val_loss': BEST_VAL_LOSS,
    'emotion_labels': emotion_names,
    'model_architecture': str(model)
}, model_path)

print(f"✅ Model saved to: {model_path}")

# Save training results
import json

results = {
    'test_accuracy': float(test_accuracy),
    'best_validation_loss': float(BEST_VAL_LOSS),
    'best_validation_accuracy': float(max(val_accuracies)),
    'final_training_accuracy': float(train_accuracies[-1]),
    'final_validation_accuracy': float(val_accuracies[-1]),
    'epochs_trained': len(train_losses),
    'train_losses': [float(x) for x in train_losses],
    'train_accuracies': [float(x) for x in train_accuracies],
    'val_losses': [float(x) for x in val_losses],
    'val_accuracies': [float(x) for x in val_accuracies],
    'emotion_labels': emotion_names,
    'model_parameters': total_params
}

with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"✅ Results saved to: {results_path}")

# Summary
print("\n" + "="*60)
print("🎉 TRAINING SUMMARY")
print("="*60)
print(f"📊 Test Accuracy: {test_accuracy*100:.2f}%")
print(f"📈 Best Validation Accuracy: {max(val_accuracies):.2f}%")
print(f"📉 Best Validation Loss: {BEST_VAL_LOSS:.4f}")
print(f"🕐 Epochs Trained: {len(train_losses)}")
print(f"🔧 Model Parameters: {total_params:,}")
print(f"💾 Model saved to: {model_path}")
print(f"📋 Results saved to: {results_path}")
print("="*60)

## 19. Model Inference Function

In [None]:
def predict_emotion(model, image_path, transform, device):
    """
    Predict emotion from a single image
    
    Args:
        model: Trained emotion recognition model
        image_path: Path to the image file
        transform: Image preprocessing transforms
        device: Device to run inference on
    
    Returns:
        predicted_emotion: Predicted emotion name
        confidence: Prediction confidence
        probabilities: All class probabilities
    """
    model.eval()
    
    # Load and preprocess image
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError(f"Could not load image from {image_path}")
    
    # Resize to 48x48
    image = cv2.resize(image, (48, 48))
    
    # Convert to PIL Image and apply transforms
    image = Image.fromarray(image)
    image_tensor = transform(image).unsqueeze(0).to(device)
    
    # Make prediction
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = F.softmax(output, dim=1)
        confidence, predicted = torch.max(probabilities, 1)
    
    predicted_emotion = emotion_names[predicted.item()]
    confidence_score = confidence.item()
    all_probs = probabilities.cpu().numpy()[0]
    
    return predicted_emotion, confidence_score, all_probs

def display_prediction_results(emotion, confidence, probabilities):
    """
    Display prediction results in a formatted way
    """
    print(f"\n🎯 Predicted Emotion: {emotion} {emotion_emojis[emotion_names.index(emotion)]}")
    print(f"📊 Confidence: {confidence:.4f} ({confidence*100:.2f}%)")
    print("\n📈 All Probabilities:")
    print("-" * 40)
    
    for emotion_name, emoji, prob in zip(emotion_names, emotion_emojis, probabilities):
        print(f"{emotion_name} {emoji}: {prob:.4f} ({prob*100:.2f}%)")

print("✅ Inference functions defined")
print("\nTo use the model for inference on a new image:")
print("```python")
print("emotion, confidence, probs = predict_emotion(model, 'path/to/image.jpg', val_test_transform, device)")
print("display_prediction_results(emotion, confidence, probs)")
print("```")
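`predict_emotion` returns the full probability vector as a NumPy array. When two emotions score similarly, showing the top few can be more informative than the single argmax. A minimal sketch; `top_k_emotions` is a hypothetical helper and assumes the class order of `emotion_names`.

```python
import numpy as np

EMOTION_NAMES = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']


def top_k_emotions(probabilities, k=3, names=EMOTION_NAMES):
    """Return the k most likely (emotion, probability) pairs, highest first."""
    probabilities = np.asarray(probabilities)
    order = np.argsort(probabilities)[::-1][:k]  # indices sorted by descending probability
    return [(names[i], float(probabilities[i])) for i in order]
```

For example, after `emotion, confidence, probs = predict_emotion(model, 'path/to/image.jpg', val_test_transform, device)`, calling `top_k_emotions(probs)` lists the three highest-scoring emotions.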

## 20. Instructions for Using the Trained Model

### 🚀 How to Use Your Trained FER Model

#### **Loading the Model:**
```python
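# Create model instance (assumes EmotionCNN and device are defined, e.g. by running this notebook)
model = EmotionCNN(num_classes=7)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# on Colab the file is at '/content/fer2013_final_model.pth'
checkpoint = torch.load('fer2013_final_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()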
```

#### **Making Predictions:**
```python
# For a single image
emotion, confidence, probs = predict_emotion(
    model, 
    'path/to/your/image.jpg', 
    val_test_transform, 
    device
)
display_prediction_results(emotion, confidence, probs)
```

#### **Integration Tips:**
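- The model expects 48x48 grayscale face crops: FER 2013 images are tightly cropped, roughly centered faces, so detect and crop the face before running inference on arbitrary photos
- Use the validation/test preprocessing (`val_test_transform`), without the training augmentations
- The model outputs 7 emotion classes in this order: Angry, Disgust, Fear, Happy, Sad, Surprise, Neutral
- For real-time video, run inference on the GPU when available and batch frames where latency allows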

#### **Performance Expectations:**
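- Test accuracy: printed in the training summary (section 18) and stored as a fraction under `test_accuracy` in `training_results.json`
- Best validation accuracy: stored as a percentage under `best_validation_accuracy`
- Model size: stored under `model_parameters` (653,511 parameters for the `EmotionCNN` defined in section 10)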

#### **Next Steps:**
1. **Fine-tuning**: Retrain on your specific domain data
2. **Deployment**: Convert to ONNX for production use
3. **Integration**: Combine with face detection for end-to-end emotion recognition
4. **Evaluation**: Test on your specific use case data

---

### 📚 Additional Resources:
- [PyTorch Documentation](https://pytorch.org/docs/)
- [FER 2013 Dataset](https://www.kaggle.com/datasets/msambare/fer2013)
- [OpenCV Face Detection](https://docs.opencv.org/4.x/db/d28/tutorial_cascade_classifier.html)

### 🎉 Congratulations!
You have successfully trained a Facial Expression Recognition model! The model is ready for inference and can be integrated into your applications.