# Multimodal Emotion Recognition with Late Fusion

This notebook demonstrates how to combine trained Facial Expression Recognition (FER) and Textual Emotion Recognition (TER) models using late fusion to build a multimodal emotion recognition system.

## Overview

- **FER Model**: CNN-based facial expression recognition trained on the FER2013 dataset
- **TER Model**: DistilBERT-based textual emotion recognition trained on emotion datasets
- **Fusion Strategy**: Late fusion combining predictions from both modalities
- **Target**: Improved emotion recognition accuracy through multimodal learning

## Emotions Recognized
- **Angry** 😠
- **Disgust** 🤢
- **Fear** 😨
- **Happy** 😊
- **Sad** 😢
- **Surprise** 😲
- **Neutral** 😐

The notebook is designed to run on Google Colab with GPU acceleration, and falls back to local model paths and the CPU when Colab or CUDA is unavailable.

## 1. Environment Setup and Google Drive Mount

Set up the environment, check for GPU availability, and mount Google Drive to access pre-trained models.

In [None]:
# Check if running on Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("✅ Running on Google Colab")
    
    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
    print("✅ Google Drive mounted successfully")
    
    # Set paths for Google Drive
    FER_MODEL_PATH = "/content/drive/MyDrive/FER_Model_Data/models"
    TER_MODEL_PATH = "/content/drive/MyDrive/TER_Models/ter_distilbert_model"
    FUSION_MODEL_PATH = "/content/drive/MyDrive/Fusion_Models"
    
except ImportError:
    IN_COLAB = False
    print("❌ Not running on Google Colab")
    
    # Set local paths
    FER_MODEL_PATH = "./fer_models"
    TER_MODEL_PATH = "./ter_models"
    FUSION_MODEL_PATH = "./fusion_models"

# Check CUDA availability
import torch
if torch.cuda.is_available():
    print(f"✅ CUDA available: {torch.version.cuda}")
    print(f"✅ GPU device: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    print("⚠️ CUDA not available, using CPU")
    device = torch.device('cpu')

print(f"🎯 Using device: {device}")

# Create fusion model directory
import os
os.makedirs(FUSION_MODEL_PATH, exist_ok=True)

print(f"\n📁 Model paths configured:")
print(f"   FER models: {FER_MODEL_PATH}")
print(f"   TER models: {TER_MODEL_PATH}")
print(f"   Fusion models: {FUSION_MODEL_PATH}")

# Constants for emotions (Ekman's six basic emotions plus neutral, in FER2013 label order)
EMOTION_LABELS = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
NUM_CLASSES = len(EMOTION_LABELS)
EMOTION_NAMES = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
EMOTION_EMOJIS = ['😠', '🤢', '😨', '😊', '😢', '😲', '😐']

print(f"\n🎭 Emotion classes: {NUM_CLASSES}")
print(f"   Labels: {EMOTION_LABELS}")

## 2. Import Required Libraries

Import the libraries needed for loading the models, building the fusion architecture, and training.

In [None]:
# Install required packages if in Colab
if IN_COLAB:
    !pip install transformers torch torchvision torchaudio opencv-python-headless -q
    !pip install matplotlib seaborn scikit-learn pandas numpy tqdm pillow -q
    print("✅ Packages installed in Colab")

# Core libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

# Computer Vision
import torchvision.transforms as transforms
import cv2
from PIL import Image

# NLP and Transformers
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Data handling and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from sklearn.preprocessing import LabelEncoder

# Utilities
import os
import json
import pickle
import random
import warnings
from tqdm.auto import tqdm
from datetime import datetime
import re

# Set random seeds for reproducibility
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuner can select non-deterministic kernels

# Suppress warnings
warnings.filterwarnings('ignore')

print("✅ All libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")

# Set matplotlib style
plt.style.use('default')
sns.set_palette("husl")

## 3. Load Pre-trained FER Model

Load the trained CNN-based Facial Expression Recognition model from Google Drive.
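
To see why the first fully connected layer of `EmotionCNN` takes 256 inputs, it can help to trace the feature-map size through the network. The sketch below is dependency-free and assumes a 48x48 grayscale input, 3x3 convolutions with padding 1, and 2x2 max pooling with stride 2, as in the layers defined in this section. The helper names `conv_out` and `pool_out` are illustrative and not part of the notebook.

```python
def conv_out(size, kernel=3, padding=1, stride=1):
    """Output width/height of a convolution (standard size formula)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Output width/height of an unpadded max-pooling layer."""
    return (size - kernel) // stride + 1

size = 48                         # assumed input resolution
channels = [32, 64, 128, 256]     # output channels of conv1..conv4
trace = []
for ch in channels:
    # Each stage is conv (size preserved) followed by 2x2 pooling (size halved)
    size = pool_out(conv_out(size))
    trace.append((ch, size))

for ch, s in trace:
    print(f"{ch:>3} channels, {s}x{s} feature map")
```

After the fourth stage the map is 3x3 with 256 channels; the adaptive average pool then reduces it to 1x1, which leaves a 256-dimensional vector for `fc1`.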

In [None]:
# Define the FER CNN architecture (must match the trained model)
class EmotionCNN(nn.Module):
    def __init__(self, num_classes=7, dropout_rate=0.5):
        super(EmotionCNN, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        
        # Pooling
        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Dropout
        self.dropout = nn.Dropout(dropout_rate)
        
        # Fully connected layers
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        
    def forward(self, x):
        # First conv block
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        
        # Second conv block
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        
        # Global average pooling
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x

# Load the FER model
print("Loading FER model...")

try:
    # Create model instance
    fer_model = EmotionCNN(num_classes=NUM_CLASSES, dropout_rate=0.5)
    
    # Try different possible model file names
    fer_model_files = [
        'fer2013_final_model.pth',
        'best_fer_model.pth', 
        'emotion_cnn_model.pth'
    ]
    
    fer_model_loaded = False
    for model_file in fer_model_files:
        fer_model_path = os.path.join(FER_MODEL_PATH, model_file)
        if os.path.exists(fer_model_path):
            try:
                checkpoint = torch.load(fer_model_path, map_location=device)
                
                # Handle different checkpoint formats
                if isinstance(checkpoint, dict):
                    if 'model_state_dict' in checkpoint:
                        fer_model.load_state_dict(checkpoint['model_state_dict'])
                    elif 'state_dict' in checkpoint:
                        fer_model.load_state_dict(checkpoint['state_dict'])
                    else:
                        fer_model.load_state_dict(checkpoint)
                else:
                    fer_model.load_state_dict(checkpoint)
                
                fer_model.to(device)
                fer_model.eval()
                fer_model_loaded = True
                print(f"✅ FER model loaded successfully from: {model_file}")
                break
                
            except Exception as e:
                print(f"⚠️ Failed to load {model_file}: {e}")
                continue
    
    if not fer_model_loaded:
        print("❌ Could not load FER model from any expected file")
        print(f"Expected files in {FER_MODEL_PATH}:")
        for f in fer_model_files:
            print(f"  - {f}")
        
        # Create a dummy model for demonstration
        print("Creating dummy FER model for demonstration...")
        fer_model = EmotionCNN(num_classes=NUM_CLASSES)
        fer_model.to(device)
        fer_model.eval()
        print("⚠️ Using randomly initialized FER model")

except Exception as e:
    print(f"❌ Error loading FER model: {e}")
    # Create dummy model
    fer_model = EmotionCNN(num_classes=NUM_CLASSES)
    fer_model.to(device)
    fer_model.eval()
    print("⚠️ Using randomly initialized FER model")

# Define FER preprocessing transforms
fer_transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

print(f"✅ FER model ready on {device}")
print(f"Model parameters: {sum(p.numel() for p in fer_model.parameters()):,}")

# Test FER model with dummy input
try:
    dummy_image = torch.randn(1, 1, 48, 48).to(device)
    with torch.no_grad():
        fer_output = fer_model(dummy_image)
    print(f"✅ FER model test successful - Output shape: {fer_output.shape}")
except Exception as e:
    print(f"❌ FER model test failed: {e}")

## 4. Load Pre-trained TER Model

Load the trained DistilBERT-based Textual Emotion Recognition model from Google Drive.
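
Before tokenization, the text is passed through a small cleaning step. As a rough illustration of what that step does to real input, the sketch below reproduces the same regular expression with the standard library only; the example sentences are made up. Whether this cleaning matches the preprocessing used when the TER model was trained is an assumption that should be checked against the training notebook.

```python
import re

def clean_text(text):
    """Lowercase, keep letters, digits, whitespace and . ! ? , ; : then collapse spaces."""
    text = text.lower()
    text = re.sub(r'[^a-zA-Z0-9\s\.\!\?\,\;\:]', '', text)
    return ' '.join(text.split())

examples = [
    "I'm SO   happy today!!! 😊",   # made-up input
    "Don't panic -- it's fine.",    # made-up input
]
cleaned = [clean_text(t) for t in examples]
for raw, out in zip(examples, cleaned):
    print(repr(raw), "->", repr(out))
```

Note that apostrophes, hyphens, and emojis are removed, so "I'm" becomes "im". This may discard cues (such as emojis) that carry emotional information.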

In [None]:
# Load the TER model (DistilBERT)
print("Loading TER model...")

try:
    # Check if TER model directory exists
    if os.path.exists(TER_MODEL_PATH):
        print(f"📁 TER model directory found: {TER_MODEL_PATH}")
        
        # Load DistilBERT model
        ter_model = DistilBertForSequenceClassification.from_pretrained(
            TER_MODEL_PATH,
            num_labels=NUM_CLASSES,
            output_attentions=False,
            output_hidden_states=False
        )
        
        # Load tokenizer
        ter_tokenizer = DistilBertTokenizer.from_pretrained(TER_MODEL_PATH)
        
        # Load label encoder
        label_encoder_path = os.path.join(TER_MODEL_PATH, 'label_encoder.pkl')
        if os.path.exists(label_encoder_path):
            with open(label_encoder_path, 'rb') as f:
                ter_label_encoder = pickle.load(f)
            print("✅ Label encoder loaded successfully")
        else:
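            # Create default label encoder
            # NOTE: LabelEncoder sorts classes alphabetically, so its indices
            # (e.g. 'neutral' -> 4) differ from the EMOTION_LABELS order
            ter_label_encoder = LabelEncoder()
            ter_label_encoder.fit(EMOTION_LABELS)
            print("⚠️ Created default label encoder")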
        
        # Load training config
        config_path = os.path.join(TER_MODEL_PATH, 'training_config.pkl')
        if os.path.exists(config_path):
            with open(config_path, 'rb') as f:
                ter_config = pickle.load(f)
            print("✅ Training config loaded successfully")
        else:
            ter_config = {
                'max_length': 128,
                'num_classes': NUM_CLASSES,
                'emotion_labels': EMOTION_LABELS
            }
            print("⚠️ Created default config")
        
        ter_model.to(device)
        ter_model.eval()
        
        print(f"✅ TER model loaded successfully")
        print(f"Model config: {ter_config.get('max_length', 128)} max tokens")
        
    else:
        print(f"❌ TER model directory not found: {TER_MODEL_PATH}")
        raise FileNotFoundError("TER model not found")
        
except Exception as e:
    print(f"❌ Error loading TER model: {e}")
    print("Creating dummy TER model for demonstration...")
    
    # Create dummy TER model
    ter_model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased',
        num_labels=NUM_CLASSES,
        output_attentions=False,
        output_hidden_states=False
    )
    ter_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    
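    # Create default label encoder
    # NOTE: LabelEncoder sorts classes alphabetically, so its indices
    # (e.g. 'neutral' -> 4) differ from the EMOTION_LABELS order
    ter_label_encoder = LabelEncoder()
    ter_label_encoder.fit(EMOTION_LABELS)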
    
    ter_config = {
        'max_length': 128,
        'num_classes': NUM_CLASSES,
        'emotion_labels': EMOTION_LABELS
    }
    
    ter_model.to(device)
    ter_model.eval()
    print("⚠️ Using pre-trained DistilBERT (not fine-tuned)")

# Text preprocessing function
def clean_text(text):
    """Clean and preprocess text data"""
    text = text.lower()
    text = re.sub(r'[^a-zA-Z0-9\s\.\!\?\,\;\:]', '', text)
    text = ' '.join(text.split())
    return text

# TER prediction function
def predict_ter(text, max_length=None):
    """Predict emotion from text using TER model"""
    if max_length is None:
        max_length = ter_config.get('max_length', 128)
    
    # Clean text
    cleaned_text = clean_text(text)
    
    # Tokenize
    encoding = ter_tokenizer(
        cleaned_text,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors='pt'
    )
    
    # Move to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    # Make prediction
    with torch.no_grad():
        outputs = ter_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
    
    return logits

print(f"✅ TER model ready on {device}")
print(f"Model parameters: {sum(p.numel() for p in ter_model.parameters()):,}")

# Test TER model with dummy input
try:
    test_text = "I am feeling happy today!"
    ter_output = predict_ter(test_text)
    print(f"✅ TER model test successful - Output shape: {ter_output.shape}")
    
    # Show prediction
    probabilities = F.softmax(ter_output, dim=1)
    predicted_class = torch.argmax(ter_output, dim=1).item()
    confidence = probabilities[0][predicted_class].item()
    predicted_emotion = ter_label_encoder.inverse_transform([predicted_class])[0]
    
    print(f"Test prediction: '{test_text}' -> {predicted_emotion} (confidence: {confidence:.3f})")
    
except Exception as e:
    print(f"❌ TER model test failed: {e}")

## 5. Define Late Fusion Architecture

Create a late fusion neural network that combines predictions from both FER and TER models.
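
As a minimal sketch of what the `'weighted'` strategy computes, the example below fuses two logit vectors with softmax-normalized modality weights, using only the standard library. The logit values are made up for illustration, and the helper names `softmax` and `weighted_fusion` are hypothetical; the notebook's own implementation uses PyTorch tensors and a learnable `nn.Parameter`.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_fusion(fer_logits, ter_logits, raw_weights):
    """Combine two logit vectors with softmax-normalized modality weights."""
    w_fer, w_ter = softmax(raw_weights)
    fused = [w_fer * f + w_ter * t for f, t in zip(fer_logits, ter_logits)]
    return fused, (w_fer, w_ter)

EMOTIONS = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
fer_logits = [0.1, -1.0, 0.3, 1.2, 0.0, 0.9, 0.4]     # made-up values
ter_logits = [-0.5, -1.2, -0.3, 2.5, -0.8, 0.2, 0.1]  # made-up values

# Equal raw weights reduce to simple averaging of the two modalities
fused, weights = weighted_fusion(fer_logits, ter_logits, raw_weights=[0.5, 0.5])
predicted = EMOTIONS[max(range(len(fused)), key=fused.__getitem__)]
print(weights, predicted)
```

Passing the raw weights through a softmax keeps the two modality weights positive and summing to one, so training can shift trust between the face and text models without rescaling the logits arbitrarily.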

In [None]:
class LateFusionModel(nn.Module):
    """
    Late Fusion Model that combines FER and TER predictions
    
    This model takes the output logits from both FER and TER models
    and learns to fuse them for improved emotion recognition.
    """
    
    def __init__(self, num_classes=7, fusion_type='weighted', hidden_dim=128):
        super(LateFusionModel, self).__init__()
        
        self.num_classes = num_classes
        self.fusion_type = fusion_type
        self.hidden_dim = hidden_dim
        
        if fusion_type == 'weighted':
            # Learnable weights for each modality
            self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5]))
            
        elif fusion_type == 'mlp':
            # MLP fusion network
            self.fusion_mlp = nn.Sequential(
                nn.Linear(num_classes * 2, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(hidden_dim // 2, num_classes)
            )
            
        elif fusion_type == 'attention':
            # Attention-based fusion
            self.attention = nn.MultiheadAttention(
                embed_dim=num_classes,
                num_heads=1,
                batch_first=True
            )
            self.output_projection = nn.Linear(num_classes, num_classes)
            
        elif fusion_type == 'bilinear':
            # Bilinear fusion
            self.bilinear = nn.Bilinear(num_classes, num_classes, hidden_dim)
            self.output_layer = nn.Linear(hidden_dim, num_classes)
            
        elif fusion_type != 'simple':  # 'simple' averaging needs no parameters
            raise ValueError(f"Unknown fusion_type: {fusion_type!r}")
    
    def forward(self, fer_logits, ter_logits):
        """
        Forward pass for late fusion
        
        Args:
            fer_logits: Logits from FER model [batch_size, num_classes]
            ter_logits: Logits from TER model [batch_size, num_classes]
            
        Returns:
            fused_logits: Combined logits [batch_size, num_classes]
            fusion_weights: Attention weights (if applicable)
        """
        
        if self.fusion_type == 'simple':
            # Simple averaging
            fused_logits = (fer_logits + ter_logits) / 2
            fusion_weights = torch.tensor([0.5, 0.5]).to(fer_logits.device)
            
        elif self.fusion_type == 'weighted':
            # Learnable weighted combination
            weights = F.softmax(self.fusion_weights, dim=0)
            fused_logits = weights[0] * fer_logits + weights[1] * ter_logits
            fusion_weights = weights
            
        elif self.fusion_type == 'mlp':
            # MLP-based fusion
            concatenated = torch.cat([fer_logits, ter_logits], dim=1)
            fused_logits = self.fusion_mlp(concatenated)
            fusion_weights = None
            
        elif self.fusion_type == 'attention':
            # Attention-based fusion
            # Stack the logits for attention
            stacked_logits = torch.stack([fer_logits, ter_logits], dim=1)  # [batch, 2, num_classes]
            
            # Apply self-attention
            attended_logits, attention_weights = self.attention(
                stacked_logits, stacked_logits, stacked_logits
            )
            
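            # Sum the attended representations of the two modalities
            fused_logits = torch.sum(attended_logits, dim=1)  # [batch, num_classes]
            fused_logits = self.output_projection(fused_logits)
            # attention_weights is [batch, 2, 2] (already averaged over heads);
            # average over the query positions to get one weight per modality
            fusion_weights = attention_weights.mean(dim=1)  # [batch, 2]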
            
        elif self.fusion_type == 'bilinear':
            # Bilinear fusion
            bilinear_output = self.bilinear(fer_logits, ter_logits)
            fused_logits = self.output_layer(F.relu(bilinear_output))
            fusion_weights = None
            
        return fused_logits, fusion_weights

class MultiModalEmotionRecognizer(nn.Module):
    """
    Complete multimodal emotion recognition system
    """
    
    def __init__(self, fer_model, ter_model, ter_tokenizer, fusion_type='weighted'):
        super(MultiModalEmotionRecognizer, self).__init__()
        
        self.fer_model = fer_model
        self.ter_model = ter_model
        self.ter_tokenizer = ter_tokenizer
        
        # Freeze pre-trained models (optional)
        self.freeze_pretrained = True
        if self.freeze_pretrained:
            for param in self.fer_model.parameters():
                param.requires_grad = False
            for param in self.ter_model.parameters():
                param.requires_grad = False
        
        # Late fusion module
        self.fusion_model = LateFusionModel(
            num_classes=NUM_CLASSES,
            fusion_type=fusion_type
        )
        
    def forward(self, images, input_ids, attention_mask):
        """
        Forward pass for multimodal emotion recognition
        
        Args:
            images: Image tensor [batch_size, 1, 48, 48]
            input_ids: Tokenized text [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            
        Returns:
            fused_logits: Final emotion predictions
            fer_logits: FER model predictions
            ter_logits: TER model predictions
            fusion_weights: Fusion weights (if applicable)
        """
        
        # Get FER predictions
        fer_logits = self.fer_model(images)
        
        # Get TER predictions
        ter_outputs = self.ter_model(input_ids=input_ids, attention_mask=attention_mask)
        ter_logits = ter_outputs.logits
        
        # Fuse predictions
        fused_logits, fusion_weights = self.fusion_model(fer_logits, ter_logits)
        
        return fused_logits, fer_logits, ter_logits, fusion_weights

# Create fusion models with different strategies
print("Creating late fusion models...")

fusion_strategies = ['simple', 'weighted', 'mlp', 'attention', 'bilinear']
fusion_models = {}

for strategy in fusion_strategies:
    try:
        model = MultiModalEmotionRecognizer(
            fer_model=fer_model,
            ter_model=ter_model,
            ter_tokenizer=ter_tokenizer,
            fusion_type=strategy
        )
        model.to(device)
        fusion_models[strategy] = model
        
        # Count trainable parameters
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        
        print(f"✅ {strategy.upper()} fusion model created")
        print(f"   Total parameters: {total_params:,}")
        print(f"   Trainable parameters: {trainable_params:,}")
        
    except Exception as e:
        print(f"❌ Failed to create {strategy} fusion model: {e}")

# Select default model for training (weighted fusion)
default_fusion_model = fusion_models.get('weighted', list(fusion_models.values())[0])
print(f"\n🎯 Using {default_fusion_model.fusion_model.fusion_type} fusion as default model")

# Test multimodal model
print("\n🧪 Testing multimodal model...")
try:
    # Create dummy inputs
    dummy_images = torch.randn(2, 1, 48, 48).to(device)
    dummy_texts = ["I am feeling happy!", "This is terrible and scary."]
    
    # Tokenize text
    encoding = ter_tokenizer(
        dummy_texts,
        truncation=True,
        padding=True,
        max_length=ter_config.get('max_length', 128),
        return_tensors='pt'
    )
    dummy_input_ids = encoding['input_ids'].to(device)
    dummy_attention_mask = encoding['attention_mask'].to(device)
    
    # Test forward pass
    with torch.no_grad():
        fused_logits, fer_logits, ter_logits, fusion_weights = default_fusion_model(
            dummy_images, dummy_input_ids, dummy_attention_mask
        )
    
    print(f"✅ Multimodal test successful!")
    print(f"   Input shapes: Images {dummy_images.shape}, Text {dummy_input_ids.shape}")
    print(f"   FER output: {fer_logits.shape}")
    print(f"   TER output: {ter_logits.shape}")
    print(f"   Fused output: {fused_logits.shape}")
    if fusion_weights is not None:
        print(f"   Fusion weights: {fusion_weights}")
    
except Exception as e:
    print(f"❌ Multimodal test failed: {e}")

## 6. Create Combined Dataset Class

Implement a custom PyTorch dataset class for multimodal emotion recognition training.
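
When samples are stored class by class (as a generator that loops over emotion ids would produce), a plain 80/20 slice can leave whole classes out of one split. A minimal sketch of a stratified alternative, using only the standard library, is shown below; `stratified_split` is a hypothetical helper and the label list is synthetic.

```python
import random
from collections import defaultdict

def stratified_split(labels, val_fraction=0.2, seed=42):
    """Return (train_idx, val_idx) so every class contributes val_fraction to validation."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_val = int(len(indices) * val_fraction)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    # Shuffle again so neither split is ordered by class
    rng.shuffle(train_idx)
    rng.shuffle(val_idx)
    return train_idx, val_idx

# Synthetic labels ordered by class: 7 classes with 285 samples each
labels = [c for c in range(7) for _ in range(285)]
train_idx, val_idx = stratified_split(labels)
print(len(train_idx), len(val_idx))
```

The returned index lists can be used to select pairs and labels for each split, which keeps all seven emotions represented in both training and validation data.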

In [None]:
class MultiModalEmotionDataset(Dataset):
    """
    Dataset class for multimodal emotion recognition
    Combines images and text for the same emotion labels
    """
    
    def __init__(self, image_texts_pairs, labels, tokenizer, image_transform=None, max_length=128):
        """
        Args:
            image_texts_pairs: List of tuples (image_path_or_array, text)
            labels: List of emotion labels (numeric)
            tokenizer: Text tokenizer
            image_transform: Image preprocessing transforms
            max_length: Maximum text sequence length
        """
        self.data = image_texts_pairs
        self.labels = labels
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_length = max_length
        
        assert len(self.data) == len(self.labels), "Data and labels must have same length"
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        image_data, text = self.data[idx]
        label = self.labels[idx]
        
        # Process image
        if isinstance(image_data, str):
            # Load image from path
            image = cv2.imread(image_data, cv2.IMREAD_GRAYSCALE)
            if image is None:
                # Create dummy image if loading fails
                image = np.random.randint(0, 255, (48, 48), dtype=np.uint8)
        else:
            # Use provided image array
            image = image_data
        
        # Ensure image is 48x48
        if image.shape != (48, 48):
            image = cv2.resize(image, (48, 48))
        
        # Convert to PIL and apply transforms
        image = Image.fromarray(image)
        if self.image_transform:
            image = self.image_transform(image)
        else:
            # Default transform
            image = transforms.ToTensor()(image)
            image = transforms.Normalize(mean=[0.5], std=[0.5])(image)
        
        # Process text
        cleaned_text = clean_text(text)
        
        # Tokenize
        encoding = self.tokenizer(
            cleaned_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'input_ids': encoding['input_ids'].squeeze(0),  # drop only the batch dimension
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long),
            'text': text  # Keep original text for reference
        }

# Create synthetic multimodal dataset for demonstration
def create_synthetic_multimodal_dataset(num_samples=1000):
    """
    Create a synthetic dataset with paired images and text for each emotion
    """
    print("Creating synthetic multimodal dataset...")
    
    # Emotion-specific text templates
    emotion_texts = {
        0: [  # angry
            "I am so furious about this situation",
            "This makes me incredibly angry and upset",
            "I can't believe how infuriating this is",
            "I'm absolutely livid right now",
            "This is making me so mad and frustrated"
        ],
        1: [  # disgust
            "This is absolutely disgusting and revolting",
            "I feel sick looking at this gross thing",
            "How repugnant and vile can something be",
            "This disgusting sight is making me nauseous",
            "I'm completely repulsed by this awful thing"
        ],
        2: [  # fear
            "I'm terrified and scared of what might happen",
            "This frightening situation fills me with dread",
            "I'm so anxious and worried about this",
            "This scary scenario is making me panic",
            "I feel overwhelming fear and terror"
        ],
        3: [  # happy
            "I am so happy and joyful today",
            "This wonderful news fills me with happiness",
            "I'm feeling absolutely delighted and cheerful",
            "This amazing thing brings me so much joy",
            "I'm in such a great mood and feeling fantastic"
        ],
        4: [  # sad
            "I feel so sad and heartbroken about this",
            "This tragic news is making me deeply sorrowful",
            "I'm feeling incredibly melancholy and blue",
            "This depressing situation brings tears to my eyes",
            "I'm overwhelmed with sadness and grief"
        ],
        5: [  # surprise
            "Wow, I never expected this amazing surprise",
            "I'm absolutely shocked and astonished by this",
            "What an incredible and unexpected revelation",
            "I'm completely stunned by this surprising news",
            "This unexpected turn of events is so surprising"
        ],
        6: [  # neutral
            "This is a normal part of the daily routine",
            "The situation is neither good nor bad",
            "I have mixed feelings about this outcome",
            "This is just a regular occurrence nothing special",
            "I feel indifferent about this particular matter"
        ]
    }
    
    image_text_pairs = []
    labels = []
    
    # Generate samples for each emotion
    samples_per_emotion = num_samples // NUM_CLASSES
    
    for emotion_id in range(NUM_CLASSES):
        for i in range(samples_per_emotion):
            # Create synthetic face image with emotion-specific features
            image = create_synthetic_face_image(emotion_id)
            
            # Select random text for this emotion
            text = random.choice(emotion_texts[emotion_id])
            
            # Add some variation to the text
            if random.random() < 0.3:
                text = f"I really think that {text.lower()}"
            elif random.random() < 0.3:
                text = f"{text} It's quite overwhelming."
            
            image_text_pairs.append((image, text))
            labels.append(emotion_id)
    
    print(f"✅ Created {len(image_text_pairs)} multimodal samples")
    print(f"   Samples per emotion: {samples_per_emotion}")
    
    return image_text_pairs, labels

def create_synthetic_face_image(emotion_id, size=(48, 48)):
    """
    Create a synthetic face-like image with emotion-specific features
    """
    np.random.seed(emotion_id * 100 + random.randint(0, 99))
    
    # Create base face
    image = np.random.randint(80, 180, size, dtype=np.uint8)
    
    # Add face features
    center_x, center_y = size[1] // 2, size[0] // 2
    
    # Eyes
    eye1_x, eye1_y = center_x - 8, center_y - 6
    eye2_x, eye2_y = center_x + 8, center_y - 6
    cv2.circle(image, (eye1_x, eye1_y), 3, 50, -1)
    cv2.circle(image, (eye2_x, eye2_y), 3, 50, -1)
    
    # Nose
    cv2.circle(image, (center_x, center_y + 2), 2, 100, -1)
    
    # Emotion-specific mouth
    mouth_y = center_y + 10
    
    if emotion_id == 0:  # angry - frown
        cv2.ellipse(image, (center_x, mouth_y + 3), (8, 4), 0, 180, 360, 60, -1)
    elif emotion_id == 1:  # disgust - flat mouth (same shape as neutral, slightly darker)
        cv2.line(image, (center_x - 6, mouth_y), (center_x + 6, mouth_y), 70, 2)
    elif emotion_id == 2:  # fear - open mouth
        cv2.ellipse(image, (center_x, mouth_y), (4, 6), 0, 0, 360, 40, -1)
    elif emotion_id == 3:  # happy - smile
        cv2.ellipse(image, (center_x, mouth_y - 2), (8, 4), 0, 0, 180, 70, -1)
    elif emotion_id == 4:  # sad - frown
        cv2.ellipse(image, (center_x, mouth_y + 4), (8, 4), 0, 180, 360, 60, -1)
    elif emotion_id == 5:  # surprise - open mouth
        cv2.ellipse(image, (center_x, mouth_y), (5, 8), 0, 0, 360, 30, -1)
    else:  # neutral - straight line
        cv2.line(image, (center_x - 6, mouth_y), (center_x + 6, mouth_y), 80, 2)
    
    return image

# Create the synthetic dataset
dataset_pairs, dataset_labels = create_synthetic_multimodal_dataset(num_samples=2000)

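# Create train/validation split
# The samples are generated class by class, so shuffle before splitting;
# otherwise the validation set would contain only the last classes.
indices = list(range(len(dataset_pairs)))
random.shuffle(indices)
train_size = int(0.8 * len(indices))
val_size = len(indices) - train_size

train_pairs = [dataset_pairs[i] for i in indices[:train_size]]
train_labels = [dataset_labels[i] for i in indices[:train_size]]
val_pairs = [dataset_pairs[i] for i in indices[train_size:]]
val_labels = [dataset_labels[i] for i in indices[train_size:]]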

print(f"\n📊 Dataset splits:")
print(f"   Training: {len(train_pairs)} samples")
print(f"   Validation: {len(val_pairs)} samples")

# Create dataset instances
train_dataset = MultiModalEmotionDataset(
    image_texts_pairs=train_pairs,
    labels=train_labels,
    tokenizer=ter_tokenizer,
    image_transform=fer_transform,
    max_length=ter_config.get('max_length', 128)
)

val_dataset = MultiModalEmotionDataset(
    image_texts_pairs=val_pairs,
    labels=val_labels,
    tokenizer=ter_tokenizer,
    image_transform=fer_transform,
    max_length=ter_config.get('max_length', 128)
)

print(f"✅ Multimodal datasets created successfully")

# Test dataset
print("\n🧪 Testing dataset...")
try:
    sample = train_dataset[0]
    print(f"✅ Dataset test successful!")
    print(f"   Image shape: {sample['image'].shape}")
    print(f"   Input IDs shape: {sample['input_ids'].shape}")
    print(f"   Attention mask shape: {sample['attention_mask'].shape}")
    print(f"   Label: {sample['label']} ({EMOTION_NAMES[sample['label']]})")
    print(f"   Text: '{sample['text']}'")
    
except Exception as e:
    print(f"❌ Dataset test failed: {e}")

## 7. Data Loading and Training Setup

Now we'll create data loaders and set up the training configuration for our fusion model.
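
The training setup in this section pairs a base learning rate of 0.001 with a step-decay scheduler (`step_size=5`, `gamma=0.5`) over 10 epochs. As a rough sketch of the resulting schedule, the dependency-free snippet below computes the learning rate per epoch, assuming the scheduler is stepped once at the end of every epoch; `step_lr` is a hypothetical helper, not a PyTorch function.

```python
def step_lr(base_lr, epoch, step_size=5, gamma=0.5):
    """Learning rate at a given epoch under a step-decay schedule."""
    return base_lr * gamma ** (epoch // step_size)

schedule = [step_lr(0.001, epoch) for epoch in range(10)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.6f}")
```

Under these assumptions the learning rate stays at 0.001 for epochs 0 to 4 and drops to 0.0005 for epochs 5 to 9.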

In [None]:
# Data loading configuration
BATCH_SIZE = 16  # Adjust based on GPU memory
NUM_WORKERS = 2 if device.type == 'cuda' else 0

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=(device.type == 'cuda')
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=(device.type == 'cuda')
)

print(f"📊 Data loaders created:")
print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")
print(f"   Batch size: {BATCH_SIZE}")

# Training configuration
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
FUSION_STRATEGY = 'mlp'  # Options: 'weighted', 'mlp', 'attention', 'bilinear', 'simple'

print(f"\n⚙️ Training configuration:")
print(f"   Learning rate: {LEARNING_RATE}")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Fusion strategy: {FUSION_STRATEGY}")

# Initialize fusion model
fusion_model = LateFusionModel(
    fer_output_size=7,
    ter_output_size=7,
    num_classes=NUM_CLASSES,
    fusion_strategy=FUSION_STRATEGY,
    device=device
).to(device)

print(f"✅ Fusion model initialized with {FUSION_STRATEGY} strategy")

# Initialize multimodal recognizer
multimodal_recognizer = MultiModalEmotionRecognizer(
    fer_model=fer_model,
    ter_model=ter_model,
    fusion_model=fusion_model,
    tokenizer=ter_tokenizer,
    label_encoder=ter_label_encoder,
    device=device
)

print(f"✅ Multimodal recognizer initialized")

# Optimizer and loss function
optimizer = torch.optim.Adam(fusion_model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

print(f"✅ Optimizer, loss function, and scheduler initialized")

# Test data loader
print("\n🧪 Testing data loader...")
try:
    sample_batch = next(iter(train_loader))
    print(f"✅ Data loader test successful!")
    print(f"   Batch image shape: {sample_batch['image'].shape}")
    print(f"   Batch input_ids shape: {sample_batch['input_ids'].shape}")
    print(f"   Batch attention_mask shape: {sample_batch['attention_mask'].shape}")
    print(f"   Batch labels shape: {sample_batch['label'].shape}")
    print(f"   Sample emotions: {[EMOTION_NAMES[label.item()] for label in sample_batch['label'][:5]]}")
    
except Exception as e:
    print(f"❌ Data loader test failed: {e}")

## 8. Training Loop

Let's train our fusion model to learn optimal weights for combining FER and TER predictions.
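
A minimal sketch of such a training loop, under stated assumptions: the FER and TER models stay frozen, their softmax outputs are already available as `(fer_probs, ter_probs, labels)` batches, and the fusion head is called as `head(fer_probs, ter_probs)`, the same way `fusion_model` is called in the evaluation code of section 9. `ToyFusionHead`, `run_epoch`, `train_fusion_head`, and `make_batches` are hypothetical names used for illustration, not the notebook's own classes. The returned dictionary uses the keys that `plot_training_history` reads.

```python
import torch
import torch.nn as nn


class ToyFusionHead(nn.Module):
    """Hypothetical stand-in for an MLP fusion head over two probability vectors."""

    def __init__(self, num_classes=7, hidden=32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * num_classes, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, fer_probs, ter_probs):
        # Concatenate both modalities and map them to fused logits
        return self.mlp(torch.cat([fer_probs, ter_probs], dim=1))


def run_epoch(head, batches, criterion, optimizer=None):
    """Run one pass over the batches; weights are updated only if an optimizer is given."""
    training = optimizer is not None
    head.train(training)
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for fer_probs, ter_probs, labels in batches:
            logits = head(fer_probs, ter_probs)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, 100.0 * correct / seen


def train_fusion_head(head, train_batches, val_batches, num_epochs=10, lr=1e-3):
    """Train only the fusion head and collect per-epoch metrics."""
    optimizer = torch.optim.Adam(head.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    criterion = nn.CrossEntropyLoss()
    history = {'train_losses': [], 'val_losses': [],
               'train_accuracies': [], 'val_accuracies': []}
    for _ in range(num_epochs):
        train_loss, train_acc = run_epoch(head, train_batches, criterion, optimizer)
        val_loss, val_acc = run_epoch(head, val_batches, criterion)
        scheduler.step()
        history['train_losses'].append(train_loss)
        history['val_losses'].append(val_loss)
        history['train_accuracies'].append(train_acc)
        history['val_accuracies'].append(val_acc)
    return history


def make_batches(n_batches, batch_size=16, num_classes=7):
    """Random toy batches standing in for precomputed FER/TER probabilities."""
    batches = []
    for _ in range(n_batches):
        labels = torch.randint(0, num_classes, (batch_size,))
        fer = torch.softmax(torch.randn(batch_size, num_classes), dim=1)
        ter = torch.softmax(torch.randn(batch_size, num_classes), dim=1)
        batches.append((fer, ter, labels))
    return batches


torch.manual_seed(0)
toy_history = train_fusion_head(
    ToyFusionHead(), make_batches(4), make_batches(2), num_epochs=3
)
```

In the notebook itself, the batches would come from `train_loader` and `val_loader`, with the FER and TER forward passes wrapped in `torch.no_grad()` so that only the fusion parameters receive gradients.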

In [None]:
# Initialize multimodal recognizer (includes fusion model internally)
multimodal_recognizer = MultiModalEmotionRecognizer(
    fer_model=fer_model,
    ter_model=ter_model,
    ter_tokenizer=ter_tokenizer,
    fusion_type=FUSION_STRATEGY
).to(device)

print(f"✅ Multimodal recognizer initialized with {FUSION_STRATEGY} fusion strategy")

# Optimizer and loss function (train only the fusion model parameters)
optimizer = torch.optim.Adam(multimodal_recognizer.fusion_model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

print(f"✅ Optimizer, loss function, and scheduler initialized")


## 9. Visualization and Evaluation

Let's visualize the training progress and evaluate our fusion model's performance.
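
The confusion matrices in this section are row-normalized, meaning each row is divided by the number of true samples of that class. A minimal NumPy sketch of that step, with a guard for classes that have no samples in the validation split (`row_normalize` is a hypothetical helper, not part of the notebook):

```python
import numpy as np


def row_normalize(cm):
    """Divide each row by its sum; rows that sum to zero stay all-zero."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)


# Toy 3-class matrix in which the last class never occurs as a true label
toy_cm = np.array([[8, 2, 0],
                   [1, 9, 0],
                   [0, 0, 0]])
toy_normalized = row_normalize(toy_cm)
```

Without the guard, the empty row would produce `NaN` values and a runtime warning, which then show up as blank cells in the heatmap.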

In [None]:
# Plot training history
def plot_training_history(history):
    """
    Plot training and validation loss and accuracy
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    epochs = range(1, len(history['train_losses']) + 1)
    
    # Plot loss
    ax1.plot(epochs, history['train_losses'], 'bo-', label='Training Loss', linewidth=2)
    ax1.plot(epochs, history['val_losses'], 'ro-', label='Validation Loss', linewidth=2)
    ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot accuracy
    ax2.plot(epochs, history['train_accuracies'], 'bo-', label='Training Accuracy', linewidth=2)
    ax2.plot(epochs, history['val_accuracies'], 'ro-', label='Validation Accuracy', linewidth=2)
    ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print final metrics
    print(f"📈 Final Training Metrics:")
    print(f"   Training Loss: {history['train_losses'][-1]:.4f}")
    print(f"   Training Accuracy: {history['train_accuracies'][-1]:.2f}%")
    print(f"   Validation Loss: {history['val_losses'][-1]:.4f}")
    print(f"   Validation Accuracy: {history['val_accuracies'][-1]:.2f}%")

plot_training_history(training_history)

# Comprehensive evaluation function
def evaluate_fusion_model(multimodal_recognizer, val_loader, device):
    """
    Comprehensive evaluation of the fusion model
    """
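    # Put all three models in eval mode so dropout and batch norm behave deterministically
    multimodal_recognizer.fer_model.eval()
    multimodal_recognizer.ter_model.eval()
    multimodal_recognizer.fusion_model.eval()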
    
    all_fusion_preds = []
    all_fer_preds = []
    all_ter_preds = []
    all_labels = []
    all_fusion_probs = []
    all_fer_probs = []
    all_ter_probs = []
    
    print("🔍 Evaluating fusion model...")
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            # Move batch to device
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            # Get individual model predictions
            fer_outputs = multimodal_recognizer.fer_model(images)
            fer_probs = F.softmax(fer_outputs, dim=1)
            
            ter_outputs = multimodal_recognizer.ter_model(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).logits
            ter_probs = F.softmax(ter_outputs, dim=1)
            
            # Get fusion predictions
            fusion_outputs = multimodal_recognizer.fusion_model(fer_probs, ter_probs)
            fusion_probs = F.softmax(fusion_outputs, dim=1)
            
            # Store predictions
            _, fer_pred = torch.max(fer_probs, 1)
            _, ter_pred = torch.max(ter_probs, 1)
            _, fusion_pred = torch.max(fusion_probs, 1)
            
            all_fer_preds.extend(fer_pred.cpu().numpy())
            all_ter_preds.extend(ter_pred.cpu().numpy())
            all_fusion_preds.extend(fusion_pred.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            all_fer_probs.extend(fer_probs.cpu().numpy())
            all_ter_probs.extend(ter_probs.cpu().numpy())
            all_fusion_probs.extend(fusion_probs.cpu().numpy())
    
    # Calculate accuracies
    fer_accuracy = accuracy_score(all_labels, all_fer_preds) * 100
    ter_accuracy = accuracy_score(all_labels, all_ter_preds) * 100
    fusion_accuracy = accuracy_score(all_labels, all_fusion_preds) * 100
    
    print(f"\n📊 Model Comparison:")
    print(f"   FER Only Accuracy: {fer_accuracy:.2f}%")
    print(f"   TER Only Accuracy: {ter_accuracy:.2f}%")
    print(f"   Fusion Accuracy: {fusion_accuracy:.2f}%")
    print(f"   Improvement over FER: {fusion_accuracy - fer_accuracy:+.2f}%")
    print(f"   Improvement over TER: {fusion_accuracy - ter_accuracy:+.2f}%")
    
    return {
        'fer_preds': all_fer_preds,
        'ter_preds': all_ter_preds,
        'fusion_preds': all_fusion_preds,
        'labels': all_labels,
        'fer_probs': all_fer_probs,
        'ter_probs': all_ter_probs,
        'fusion_probs': all_fusion_probs,
        'fer_accuracy': fer_accuracy,
        'ter_accuracy': ter_accuracy,
        'fusion_accuracy': fusion_accuracy
    }

# Run comprehensive evaluation
eval_results = evaluate_fusion_model(multimodal_recognizer, val_loader, device)

# Plot confusion matrices
def plot_confusion_matrices(eval_results):
    """
    Plot confusion matrices for all three models
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    models = [
        ('FER Only', eval_results['fer_preds']),
        ('TER Only', eval_results['ter_preds']),
        ('Fusion', eval_results['fusion_preds'])
    ]
    
    for idx, (model_name, preds) in enumerate(models):
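        # Pass explicit label ids so the matrix is always 7x7, even if a class is absent
        cm = confusion_matrix(
            eval_results['labels'],
            preds,
            labels=list(range(len(EMOTION_NAMES)))
        )
        
        # Normalize each row; guard against division by zero for empty classes
        row_sums = cm.sum(axis=1, keepdims=True)
        cm_normalized = cm.astype(float) / np.maximum(row_sums, 1)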
        
        # Plot
        sns.heatmap(
            cm_normalized,
            annot=True,
            fmt='.2f',
            cmap='Blues',
            xticklabels=EMOTION_NAMES,
            yticklabels=EMOTION_NAMES,
            ax=axes[idx]
        )
        
        axes[idx].set_title(f'{model_name} Confusion Matrix', fontweight='bold')
        axes[idx].set_xlabel('Predicted')
        axes[idx].set_ylabel('Actual')
        
        # Rotate labels for better readability
        axes[idx].tick_params(axis='x', rotation=45)
        axes[idx].tick_params(axis='y', rotation=0)
    
    plt.tight_layout()
    plt.show()

plot_confusion_matrices(eval_results)

# Detailed classification reports
from sklearn.metrics import classification_report

# Pass explicit label ids so the reports still work if a class is missing from the split
label_ids = list(range(len(EMOTION_NAMES)))

print(f"\n📋 Detailed Classification Reports:")
print(f"\n{'='*20} FER Only {'='*20}")
print(classification_report(
    eval_results['labels'], 
    eval_results['fer_preds'], 
    labels=label_ids,
    target_names=EMOTION_NAMES,
    zero_division=0
))

print(f"\n{'='*20} TER Only {'='*20}")
print(classification_report(
    eval_results['labels'], 
    eval_results['ter_preds'], 
    labels=label_ids,
    target_names=EMOTION_NAMES,
    zero_division=0
))

print(f"\n{'='*20} Fusion Model {'='*20}")
print(classification_report(
    eval_results['labels'], 
    eval_results['fusion_preds'], 
    labels=label_ids,
    target_names=EMOTION_NAMES,
    zero_division=0
))

## 10. Sample Predictions

Let's see how our fusion model performs on some sample inputs and compare with individual models.

In [None]:
# Function to make predictions on sample data
def predict_sample_multimodal(
    multimodal_recognizer, 
    image, 
    text, 
    device,
    show_details=True
):
    """
    Make prediction on a single image-text pair and show detailed results
    """
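    # Put all three models in eval mode before inference
    multimodal_recognizer.fer_model.eval()
    multimodal_recognizer.ter_model.eval()
    multimodal_recognizer.fusion_model.eval()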
    
    # Preprocess image
    if isinstance(image, np.ndarray):
        if image.shape != (48, 48):
            image = cv2.resize(image, (48, 48))
        pil_image = Image.fromarray(image)
    else:
        pil_image = image
    
    # Apply image transforms
    image_tensor = fer_transform(pil_image).unsqueeze(0).to(device)
    
    # Preprocess text
    cleaned_text = clean_text(text)
    encoding = ter_tokenizer(
        cleaned_text,
        truncation=True,
        padding='max_length',
        max_length=ter_config.get('max_length', 128),
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        # Individual model predictions (kept as tensors on the device)
        fer_outputs = multimodal_recognizer.fer_model(image_tensor)
        fer_probs_tensor = F.softmax(fer_outputs, dim=1)
        
        ter_outputs = multimodal_recognizer.ter_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).logits
        ter_probs_tensor = F.softmax(ter_outputs, dim=1)
        
        # Fusion prediction, computed directly from the tensors
        fusion_outputs = multimodal_recognizer.fusion_model(fer_probs_tensor, ter_probs_tensor)
        
        # Convert to NumPy once, at the end
        fer_probs = fer_probs_tensor.cpu().numpy()[0]
        ter_probs = ter_probs_tensor.cpu().numpy()[0]
        fusion_probs = F.softmax(fusion_outputs, dim=1).cpu().numpy()[0]
    
    # Get predictions
    fer_pred = np.argmax(fer_probs)
    ter_pred = np.argmax(ter_probs)
    fusion_pred = np.argmax(fusion_probs)
    
    results = {
        'fer_prediction': fer_pred,
        'ter_prediction': ter_pred,
        'fusion_prediction': fusion_pred,
        'fer_probs': fer_probs,
        'ter_probs': ter_probs,
        'fusion_probs': fusion_probs,
        'fer_emotion': EMOTION_NAMES[fer_pred],
        'ter_emotion': EMOTION_NAMES[ter_pred],
        'fusion_emotion': EMOTION_NAMES[fusion_pred],
        'text': text,
        'cleaned_text': cleaned_text
    }
    
    if show_details:
        print(f"🔍 Multimodal Prediction Results:")
        print(f"   Text: '{text}'")
        print(f"   FER Prediction: {results['fer_emotion']} (confidence: {fer_probs[fer_pred]:.3f})")
        print(f"   TER Prediction: {results['ter_emotion']} (confidence: {ter_probs[ter_pred]:.3f})")
        print(f"   Fusion Prediction: {results['fusion_emotion']} (confidence: {fusion_probs[fusion_pred]:.3f})")
        
        # Flag cases where fusion differs from at least one individual model
        if fusion_pred != fer_pred or fusion_pred != ter_pred:
            print(f"   ⚠️  Fusion differs from at least one individual model")
    
    return results

# Test with some validation samples
print("🧪 Testing with validation samples:")
print("="*50)

# Get a few samples from validation set
sample_indices = [0, 10, 50, 100, 200]

for i, idx in enumerate(sample_indices):
    if idx < len(val_dataset):
        sample = val_dataset[idx]
        
        # Convert the normalized image tensor back to a uint8 array
        image_np = sample['image'].cpu().numpy()
        if len(image_np.shape) == 3:
            image_np = image_np[0]  # Remove channel dimension if present
        # Denormalize (assumes inputs were scaled to [-1, 1]) and clip to the valid range
        image_np = np.clip((image_np + 1) * 127.5, 0, 255).astype(np.uint8)
        
        print(f"\n📷 Sample {i+1}:")
        print(f"   True Label: {EMOTION_NAMES[sample['label']]}")
        
        results = predict_sample_multimodal(
            multimodal_recognizer,
            image_np,
            sample['text'],
            device,
            show_details=True
        )
        
        # Check if prediction is correct (cast to int so the result is a plain bool)
        true_label = int(sample['label'])
        correct_fer = int(results['fer_prediction']) == true_label
        correct_ter = int(results['ter_prediction']) == true_label
        correct_fusion = int(results['fusion_prediction']) == true_label
        
        print(f"   ✅ Correct predictions: FER={correct_fer}, TER={correct_ter}, Fusion={correct_fusion}")

# Create custom test examples
print(f"\n\n🎭 Testing with custom examples:")
print("="*50)

custom_examples = [
    {
        'emotion_id': 3,  # happy
        'text': "I'm so excited about this wonderful news! This makes me incredibly happy and joyful!",
        'description': "Happy emotion - positive text"
    },
    {
        'emotion_id': 0,  # angry
        'text': "This is absolutely infuriating and makes me so angry! I can't stand this situation!",
        'description': "Angry emotion - negative text"
    },
    {
        'emotion_id': 4,  # sad
        'text': "I feel so heartbroken and sad about what happened. This is truly devastating.",
        'description': "Sad emotion - melancholic text"
    },
    {
        'emotion_id': 5,  # surprise
        'text': "Wow! I never expected this amazing surprise! I'm completely shocked and astonished!",
        'description': "Surprise emotion - unexpected text"
    }
]

for i, example in enumerate(custom_examples):
    print(f"\n🎪 Custom Example {i+1}: {example['description']}")
    
    # Create synthetic image for this emotion
    synthetic_image = create_synthetic_face_image(example['emotion_id'])
    
    results = predict_sample_multimodal(
        multimodal_recognizer,
        synthetic_image,
        example['text'],
        device,
        show_details=True
    )
    
    # Check consistency
    expected_emotion = EMOTION_NAMES[example['emotion_id']]
    print(f"   Expected: {expected_emotion}")
    
    if results['fusion_emotion'] == expected_emotion:
        print(f"   ✅ Fusion prediction matches expected emotion!")
    else:
        print(f"   ❌ Fusion prediction differs from expected emotion")

# Visualize probability distributions for a sample
def plot_prediction_comparison(results):
    """
    Plot probability distributions for FER, TER, and Fusion models
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    
    x = np.arange(len(EMOTION_NAMES))
    width = 0.25
    
    bars1 = ax.bar(x - width, results['fer_probs'], width, label='FER Only', alpha=0.8)
    bars2 = ax.bar(x, results['ter_probs'], width, label='TER Only', alpha=0.8)
    bars3 = ax.bar(x + width, results['fusion_probs'], width, label='Fusion', alpha=0.8)
    
    ax.set_xlabel('Emotions')
    ax.set_ylabel('Probability')
    ax.set_title(f'Prediction Comparison\nText: "{results["text"][:50]}..."')
    ax.set_xticks(x)
    ax.set_xticklabels(EMOTION_NAMES, rotation=45)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Highlight the predictions
    fer_pred_idx = np.argmax(results['fer_probs'])
    ter_pred_idx = np.argmax(results['ter_probs'])
    fusion_pred_idx = np.argmax(results['fusion_probs'])
    
    bars1[fer_pred_idx].set_edgecolor('red')
    bars1[fer_pred_idx].set_linewidth(3)
    bars2[ter_pred_idx].set_edgecolor('red')
    bars2[ter_pred_idx].set_linewidth(3)
    bars3[fusion_pred_idx].set_edgecolor('red')
    bars3[fusion_pred_idx].set_linewidth(3)
    
    plt.tight_layout()
    plt.show()

# Plot for one of the custom examples
if len(custom_examples) > 0:
    print(f"\n📊 Probability distribution visualization:")
    example = custom_examples[0]
    synthetic_image = create_synthetic_face_image(example['emotion_id'])
    results = predict_sample_multimodal(
        multimodal_recognizer,
        synthetic_image,
        example['text'],
        device,
        show_details=False
    )
    plot_prediction_comparison(results)

## 11. Save Trained Fusion Model

Let's save our trained fusion model to Google Drive for future use.
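
Saving a `state_dict` alongside a few metadata fields and loading it back can be sketched as follows. This is an illustration under assumptions: `head` is a stand-in `nn.Linear` rather than the notebook's fusion model, the file name is hypothetical, and `map_location='cpu'` is used so that a checkpoint written on a GPU runtime would still load on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the fusion model (14 inputs = 7 FER + 7 TER probabilities)
head = nn.Linear(14, 7)

checkpoint = {
    'fusion_model_state_dict': head.state_dict(),
    'fusion_strategy': 'mlp',
    'num_classes': 7,
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'fusion_checkpoint.pth')
    torch.save(checkpoint, path)
    # map_location remaps stored tensors onto the requested device at load time
    loaded = torch.load(path, map_location='cpu')

# Rebuild the same architecture, then restore the saved weights into it
restored = nn.Linear(14, 7)
restored.load_state_dict(loaded['fusion_model_state_dict'])
weights_match = torch.equal(head.weight, restored.weight)
```

Storing the constructor arguments next to the weights means the architecture can be rebuilt from the checkpoint alone before `load_state_dict` is called.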

In [None]:
# Save the trained fusion model and related components
def save_fusion_model(
    fusion_model, 
    multimodal_recognizer, 
    training_history, 
    eval_results,
    fusion_strategy,
    save_dir='/content/drive/MyDrive/emotion_models/'
):
    """
    Save the trained fusion model and all related components
    """
    import os
    import json
    import pickle
    from datetime import datetime
    
    # Create timestamp for unique naming
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name = f"multimodal_fusion_{fusion_strategy}_{timestamp}"
    
    # Create save directory
    full_save_dir = os.path.join(save_dir, model_name)
    os.makedirs(full_save_dir, exist_ok=True)
    
    print(f"💾 Saving fusion model to: {full_save_dir}")
    
    # 1. Save fusion model state dict
    fusion_model_path = os.path.join(full_save_dir, 'fusion_model.pth')
    torch.save(fusion_model.state_dict(), fusion_model_path)
    print(f"   ✅ Fusion model saved: fusion_model.pth")
    
    # 2. Save a fusion checkpoint with metadata (FER and TER weights are not included)
    recognizer_path = os.path.join(full_save_dir, 'multimodal_recognizer.pth')
    torch.save({
        'fusion_model_state_dict': fusion_model.state_dict(),
        'fusion_strategy': fusion_strategy,
        'fer_output_size': 7,
        'ter_output_size': 7,
        'num_classes': NUM_CLASSES,
        'emotion_names': EMOTION_NAMES
    }, recognizer_path)
    print(f"   ✅ Fusion checkpoint with metadata saved: multimodal_recognizer.pth")
    
    # 3. Save training configuration
    config = {
        'fusion_strategy': fusion_strategy,
        'learning_rate': LEARNING_RATE,
        'num_epochs': NUM_EPOCHS,
        'batch_size': BATCH_SIZE,
        'num_classes': NUM_CLASSES,
        'emotion_names': EMOTION_NAMES,
        'fer_output_size': 7,
        'ter_output_size': 7,
        'max_length': ter_config.get('max_length', 128),
        'timestamp': timestamp,
        'final_metrics': {
            'fer_accuracy': eval_results['fer_accuracy'],
            'ter_accuracy': eval_results['ter_accuracy'],
            'fusion_accuracy': eval_results['fusion_accuracy'],
            'improvement_over_fer': eval_results['fusion_accuracy'] - eval_results['fer_accuracy'],
            'improvement_over_ter': eval_results['fusion_accuracy'] - eval_results['ter_accuracy']
        }
    }
    
    config_path = os.path.join(full_save_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"   ✅ Configuration saved: config.json")
    
    # 4. Save training history
    history_path = os.path.join(full_save_dir, 'training_history.pkl')
    with open(history_path, 'wb') as f:
        pickle.dump(training_history, f)
    print(f"   ✅ Training history saved: training_history.pkl")
    
    # 5. Save evaluation results
    eval_path = os.path.join(full_save_dir, 'evaluation_results.pkl')
    with open(eval_path, 'wb') as f:
        pickle.dump(eval_results, f)
    print(f"   ✅ Evaluation results saved: evaluation_results.pkl")
    
    # 6. Save fusion model architecture details
    if hasattr(fusion_model, 'get_fusion_weights'):
        try:
            weights = fusion_model.get_fusion_weights()
            weights_path = os.path.join(full_save_dir, 'fusion_weights.json')
            with open(weights_path, 'w') as f:
                json.dump(weights, f, indent=2)
            print(f"   ✅ Fusion weights saved: fusion_weights.json")
        except Exception as e:
            print(f"   ⚠️  Could not save fusion weights: {e}")
    
    # 7. Create a README file with model information
    readme_content = f"""# Multimodal Emotion Recognition Fusion Model

## Model Information
- **Model Name**: {model_name}
- **Fusion Strategy**: {fusion_strategy}
- **Training Date**: {timestamp}
- **Number of Classes**: {NUM_CLASSES}
- **Emotion Classes**: {', '.join(EMOTION_NAMES)}

## Performance Metrics
- **FER Only Accuracy**: {eval_results['fer_accuracy']:.2f}%
- **TER Only Accuracy**: {eval_results['ter_accuracy']:.2f}%
- **Fusion Accuracy**: {eval_results['fusion_accuracy']:.2f}%
- **Improvement over FER**: {eval_results['fusion_accuracy'] - eval_results['fer_accuracy']:+.2f}%
- **Improvement over TER**: {eval_results['fusion_accuracy'] - eval_results['ter_accuracy']:+.2f}%

## Training Configuration
- **Learning Rate**: {LEARNING_RATE}
- **Epochs**: {NUM_EPOCHS}
- **Batch Size**: {BATCH_SIZE}
- **Final Training Accuracy**: {training_history['train_accuracies'][-1]:.2f}%
- **Final Validation Accuracy**: {training_history['val_accuracies'][-1]:.2f}%

## Files Included
- `fusion_model.pth`: Fusion model state dict
- `multimodal_recognizer.pth`: Fusion state dict plus the metadata needed to rebuild the fusion model
- `config.json`: Training and model configuration
- `training_history.pkl`: Training loss and accuracy history
- `evaluation_results.pkl`: Detailed evaluation results
- `fusion_weights.json`: Learned fusion weights (if available)
- `README.md`: This file

## Usage
To load this model:

```python
# Load the fusion checkpoint (state dict plus metadata)
checkpoint = torch.load('multimodal_recognizer.pth', map_location=device)
fusion_model = LateFusionModel(
    fer_output_size=checkpoint['fer_output_size'],
    ter_output_size=checkpoint['ter_output_size'],
    num_classes=checkpoint['num_classes'],
    fusion_strategy=checkpoint['fusion_strategy'],
    device=device
)
fusion_model.load_state_dict(checkpoint['fusion_model_state_dict'])
```

Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
"""
    
    readme_path = os.path.join(full_save_dir, 'README.md')
    with open(readme_path, 'w') as f:
        f.write(readme_content)
    print(f"   ✅ README created: README.md")
    
    print(f"\n🎉 Model successfully saved to: {full_save_dir}")
    print(f"   Total files: {len(os.listdir(full_save_dir))}")
    
    return full_save_dir, model_name

# Save the trained model
try:
    save_path, model_name = save_fusion_model(
        fusion_model=multimodal_recognizer.fusion_model,  # the instance that was trained and evaluated
        multimodal_recognizer=multimodal_recognizer,
        training_history=training_history,
        eval_results=eval_results,
        fusion_strategy=FUSION_STRATEGY
    )
    
    print(f"\n📋 Model Details:")
    print(f"   Name: {model_name}")
    print(f"   Strategy: {FUSION_STRATEGY}")
    print(f"   Final Validation Accuracy: {training_history['val_accuracies'][-1]:.2f}%")
    print(f"   Fusion Improvement: {eval_results['fusion_accuracy'] - max(eval_results['fer_accuracy'], eval_results['ter_accuracy']):+.2f}%")
    
except Exception as e:
    print(f"❌ Error saving model: {e}")
    print("   Attempting to save to local directory...")
    
    # Fallback: save to local directory
    try:
        save_path, model_name = save_fusion_model(
            fusion_model=multimodal_recognizer.fusion_model,  # the instance that was trained and evaluated
            multimodal_recognizer=multimodal_recognizer,
            training_history=training_history,
            eval_results=eval_results,
            fusion_strategy=FUSION_STRATEGY,
            save_dir='./saved_models/'
        )
        print(f"✅ Model saved locally to: {save_path}")
    except Exception as e2:
        print(f"❌ Failed to save model: {e2}")

## 12. Conclusion and Next Steps

### 🎯 Summary

We have successfully created a **multimodal late fusion emotion recognition system** that combines:

1. **FER Model**: CNN-based facial expression recognition trained on FER2013
2. **TER Model**: DistilBERT-based textual emotion recognition  
3. **Fusion Model**: Late fusion with multiple strategies (weighted, MLP, attention, bilinear)

### 📊 Key Achievements

- ✅ **Modular Architecture**: Easy to swap individual models or fusion strategies
- ✅ **Robust Loading**: Handles missing models with fallback options
- ✅ **Multiple Fusion Strategies**: Supports weighted, MLP, attention, and bilinear fusion
- ✅ **Comprehensive Evaluation**: Detailed metrics and visualizations
- ✅ **Google Colab Ready**: Works seamlessly in Colab environment
- ✅ **Reproducible Results**: Proper random seed management

### 🚀 Potential Improvements

1. **Real Data Integration**:
   - Replace synthetic data with real multimodal emotion datasets
   - Use datasets like MELD, IEMOCAP, or CMU-MOSEI

2. **Advanced Fusion Techniques**:
   - Early fusion at feature level
   - Attention-based temporal fusion for video sequences
   - Transformer-based multimodal fusion

3. **Model Enhancements**:
   - Fine-tune pre-trained models end-to-end
   - Add more modalities (audio, physiological signals)
   - Implement ensemble methods

4. **Production Readiness**:
   - Add real-time inference capabilities
   - Optimize for mobile/edge deployment
   - Implement confidence thresholding
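
As one illustration of the last item, confidence thresholding can be sketched in a few lines. The label list mirrors the emotion order used in this notebook, but `EMOTIONS`, `predict_with_threshold`, the 0.5 threshold, and the `'uncertain'` fallback label are assumptions made for this example.

```python
import numpy as np

# Assumed label order, matching the emotion ids used in the custom examples
EMOTIONS = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']


def predict_with_threshold(probs, threshold=0.5, fallback='uncertain'):
    """Return (label, confidence); fall back when the top probability is too low."""
    probs = np.asarray(probs, dtype=float)
    top = int(np.argmax(probs))
    confidence = float(probs[top])
    if confidence < threshold:
        return fallback, confidence
    return EMOTIONS[top], confidence


confident = predict_with_threshold([0.05, 0.02, 0.03, 0.80, 0.04, 0.03, 0.03])
ambiguous = predict_with_threshold([0.20, 0.10, 0.15, 0.25, 0.10, 0.10, 0.10])
```

Here `confident` resolves to `'happy'`, while `ambiguous` falls back to `'uncertain'` because its top probability of 0.25 is below the threshold.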

### 💡 Usage Tips

- **Experiment with different fusion strategies** using the `FUSION_STRATEGY` parameter
- **Adjust learning rates and epochs** based on your specific dataset
- **Monitor for overfitting** using the validation metrics
- **Try an ensemble of different fusion strategies**, which may improve performance
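
One simple way to ensemble fusion strategies is to average their output probabilities. The sketch below assumes each strategy has already produced a `(batch, classes)` probability array; `average_ensemble` and the toy arrays are hypothetical and use three classes for brevity.

```python
import numpy as np


def average_ensemble(prob_sets):
    """Average per-strategy probabilities and return (mean_probs, predicted_ids)."""
    stacked = np.stack([np.asarray(p, dtype=float) for p in prob_sets])
    mean_probs = stacked.mean(axis=0)
    return mean_probs, mean_probs.argmax(axis=-1)


# Toy outputs from two fusion strategies for a batch of two samples
mlp_probs = np.array([[0.6, 0.3, 0.1],
                      [0.2, 0.5, 0.3]])
weighted_probs = np.array([[0.4, 0.5, 0.1],
                           [0.1, 0.3, 0.6]])

mean_probs, ensemble_preds = average_ensemble([mlp_probs, weighted_probs])
```

Averaging keeps each row a valid probability distribution, so the result can be passed to the same plotting and evaluation code as a single model's output.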

### 🔄 How to Extend

To use this notebook with your own models:

1. Replace the FER model loading section with your trained CNN
2. Replace the TER model loading section with your trained text classifier  
3. Modify `NUM_CLASSES` and `EMOTION_NAMES` as needed
4. Update the dataset class for your specific data format
5. Adjust fusion model architecture if needed

**Happy emotion recognition! 🎭✨**