# In [None]:  (Jupyter notebook cell marker)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

class EnsembleModel(nn.Module):
    """Late-fusion ensemble of a facial and a speech emotion-recognition model.

    Each sub-model is run on its own input and the two outputs are combined
    according to ``fusion_method``:

    - ``'concat'``: concatenate both outputs along dim 1 and apply a learned
      ``nn.Linear(2 * num_classes, num_classes)``.
    - ``'weighted'``: learned scalar mix ``a * facial + (1 - a) * speech`` with
      ``a = sigmoid(alpha)``.
    - ``'attention'``: a small MLP maps the concatenated outputs to two
      softmax weights per sample, which are used to mix the two outputs.
    - ``'max'``: element-wise maximum of the two outputs (no parameters).
    - any other value: element-wise mean of the two outputs (no parameters).

    The concat/attention layers expect ``2 * num_classes`` input features, so
    both sub-models are assumed to return ``(batch, num_classes)`` tensors --
    TODO confirm for the concrete sub-models used.

    Parameters
    ----------
    facial_model : nn.Module
        Pre-trained facial model, called as ``facial_model(facial_input)``.
    speech_model : nn.Module
        Pre-trained speech model, called as ``speech_model(speech_input)``.
    num_classes : int
        Number of classes output by each sub-model and by the ensemble.
    fusion_method : str
        ``'concat'``, ``'weighted'``, ``'attention'`` or ``'max'``; any other
        value falls back to averaging the two outputs.
    freeze_base_models : bool, default True
        If True, the sub-models' parameters get ``requires_grad=False`` and
        the sub-models are kept in eval mode even when the ensemble itself is
        put in training mode.  This keeps dropout off and stops BatchNorm
        running statistics from being updated in the frozen networks while
        only the fusion layers are trained.
    """

    def __init__(self, facial_model, speech_model, num_classes=8, fusion_method='weighted',
                 freeze_base_models=True):
        super().__init__()

        # Pre-trained modality models
        self.facial_model = facial_model
        self.speech_model = speech_model
        self.fusion_method = fusion_method
        self.freeze_base_models = freeze_base_models

        if freeze_base_models:
            for base_model in (self.facial_model, self.speech_model):
                for param in base_model.parameters():
                    param.requires_grad = False

        if fusion_method == 'concat':
            # Input is both sub-model outputs side by side.
            self.fusion_layer = nn.Linear(num_classes * 2, num_classes)
        elif fusion_method == 'weighted':
            # Raw (pre-sigmoid) mixing logit.  sigmoid(0) == 0.5, so both
            # modalities start out equally weighted; a raw value of 0.5 would
            # start at sigmoid(0.5) ~= 0.62 in favour of the facial model.
            self.alpha = nn.Parameter(torch.zeros(1))
        elif fusion_method == 'attention':
            # Per-sample modality weights computed from both outputs.
            self.attention = nn.Sequential(
                nn.Linear(num_classes * 2, 64),
                nn.ReLU(),
                nn.Linear(64, 2),
                nn.Softmax(dim=1)
            )

        # Apply the frozen-model mode policy immediately (see ``train``).
        self.train(self.training)

    def train(self, mode=True):
        """Set training mode, keeping frozen sub-models in eval mode.

        ``nn.Module.eval()`` is implemented as ``self.train(False)``, so this
        override also governs ``eval()``.
        """
        super().train(mode)
        if self.freeze_base_models:
            self.facial_model.eval()
            self.speech_model.eval()
        return self

    def forward(self, facial_input, speech_input):
        """Run both sub-models and fuse their outputs with ``self.fusion_method``."""
        facial_output = self.facial_model(facial_input)
        speech_output = self.speech_model(speech_input)

        if self.fusion_method == 'concat':
            combined = torch.cat((facial_output, speech_output), dim=1)
            return self.fusion_layer(combined)

        if self.fusion_method == 'weighted':
            alpha = torch.sigmoid(self.alpha)  # constrain the weight to (0, 1)
            return alpha * facial_output + (1 - alpha) * speech_output

        if self.fusion_method == 'attention':
            combined = torch.cat((facial_output, speech_output), dim=1)
            weights = self.attention(combined)  # (batch, 2), rows sum to 1
            weighted_facial = weights[:, 0].unsqueeze(1) * facial_output
            weighted_speech = weights[:, 1].unsqueeze(1) * speech_output
            return weighted_facial + weighted_speech

        if self.fusion_method == 'max':
            # Element-wise maximum score for each class
            return torch.max(facial_output, speech_output)

        # Default (including unrecognised method names): plain average
        return (facial_output + speech_output) / 2

# Function to train the ensemble model
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy)."""
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for facial_batch, speech_batch, labels in loader:
        facial_batch = facial_batch.to(device)
        speech_batch = speech_batch.to(device)
        labels = labels.to(device)

        outputs = model(facial_batch, speech_batch)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a short final batch does not skew the epoch
        # mean (assumes ``criterion`` averages over the batch).
        loss_sum += loss.item() * batch_size
        correct += (outputs.detach().argmax(dim=1) == labels).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError("train_loader yielded no samples")
    return loss_sum / seen, correct / seen


def _validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns (mean loss, accuracy, predicted labels, true labels).
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for facial_batch, speech_batch, labels in loader:
            facial_batch = facial_batch.to(device)
            speech_batch = speech_batch.to(device)
            labels = labels.to(device)

            outputs = model(facial_batch, speech_batch)
            loss = criterion(outputs, labels)
            predicted = outputs.argmax(dim=1)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (predicted == labels).sum().item()
            seen += batch_size

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if seen == 0:
        raise ValueError("val_loader yielded no samples")
    return loss_sum / seen, correct / seen, all_preds, all_labels


def train_ensemble(ensemble_model, train_loader, val_loader, criterion, optimizer,
                   device, num_epochs=50, early_stopping_patience=10,
                   checkpoint_path='best_ensemble_model.pth'):
    """
    Train the ensemble model with early stopping on validation loss.

    Each epoch runs one optimisation pass over ``train_loader`` and one
    gradient-free pass over ``val_loader``. Whenever the validation loss
    improves, the model's ``state_dict`` is saved to ``checkpoint_path``.
    Training stops after ``early_stopping_patience`` consecutive epochs
    without improvement.

    Parameters
    ----------
    ensemble_model : nn.Module
        Model called as ``ensemble_model(facial_batch, speech_batch)``.
    train_loader : DataLoader
        Yields ``(facial_batch, speech_batch, labels)`` for training.
    val_loader : DataLoader
        Yields ``(facial_batch, speech_batch, labels)`` for validation.
    criterion : loss function
        Called as ``criterion(outputs, labels)``. Reported losses are
        per-sample means, which assumes it averages over the batch
        (e.g. ``reduction='mean'``).
    optimizer : torch.optim.Optimizer
        Optimizer over the trainable parameters.
    device : torch.device
        Device to train on (cuda/cpu).
    num_epochs : int
        Maximum number of epochs to train.
    early_stopping_patience : int
        Number of epochs to wait for a validation-loss improvement before
        stopping.
    checkpoint_path : str or path-like
        File the best ``state_dict`` is written to.

    Returns
    -------
    dict
        Per-epoch lists under 'train_loss', 'train_acc', 'val_loss',
        'val_acc', 'precision', 'recall' and 'f1'. Precision, recall and F1
        are support-weighted over the validation set.

    Raises
    ------
    ValueError
        If either loader yields no samples.

    Notes
    -----
    The model keeps the weights from the last epoch that ran. The best
    weights exist only in ``checkpoint_path``.
    """
    best_val_loss = float('inf')
    patience_counter = 0
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'precision': [], 'recall': [], 'f1': []
    }

    for epoch in range(num_epochs):
        train_loss, train_acc = _train_one_epoch(
            ensemble_model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc, val_preds, val_labels = _validate_one_epoch(
            ensemble_model, val_loader, criterion, device
        )

        # NOTE(review): zero_division=1 scores a class that is never predicted
        # as precision 1.0, which can inflate the weighted precision.
        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, val_preds, average='weighted', zero_division=1
        )

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['precision'].append(precision)
        history['recall'].append(recall)
        history['f1'].append(f1)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")
        print("-" * 50)

        # Early stopping on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(ensemble_model.state_dict(), checkpoint_path)
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    return history

# Example usage
def _plot_ensemble_history(history):
    """Plot accuracy/loss curves and precision/recall/F1 curves.

    ``history`` is the dict returned by ``train_ensemble``.
    """
    import matplotlib.pyplot as plt

    # Accuracy and loss side by side
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_acc'], label='Train Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.tight_layout()
    plt.show()

    # Validation precision / recall / F1
    plt.figure(figsize=(10, 5))
    plt.plot(history['precision'], label='Precision')
    plt.plot(history['recall'], label='Recall')
    plt.plot(history['f1'], label='F1-Score')
    plt.title('Model Metrics')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    plt.tight_layout()
    plt.show()


def create_ensemble_and_train(facial_model_path, speech_model_path, train_loader, val_loader):
    """
    Build an attention-fusion ensemble from two saved models, train it and plot the results.

    Parameters
    ----------
    facial_model_path, speech_model_path : str or path-like
        Saved ``state_dict`` checkpoints for the facial and speech models.
    train_loader, val_loader : DataLoader
        Passed to ``train_ensemble``. Each yields
        ``(facial_batch, speech_batch, labels)``.

    Returns
    -------
    EnsembleModel
        The trained ensemble. If ``train_ensemble`` wrote a best checkpoint
        during this run, the lowest-validation-loss weights from it are
        loaded back into the model.

    NOTE(review): ``CNNModel``, ``YOUR_FACIAL_INPUT_SIZE`` and
    ``YOUR_SPEECH_INPUT_SIZE`` are not defined in this section. They look like
    placeholders that must be supplied before this can run.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Default checkpoint file written by train_ensemble on each improvement.
    best_checkpoint = 'best_ensemble_model.pth'

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles; only load trusted checkpoint files
    # (consider weights_only=True on torch >= 1.13).
    facial_model = CNNModel(input_size=YOUR_FACIAL_INPUT_SIZE, num_classes=8).to(device)
    facial_model.load_state_dict(torch.load(facial_model_path, map_location=device))

    speech_model = CNNModel(input_size=YOUR_SPEECH_INPUT_SIZE, num_classes=8).to(device)
    speech_model.load_state_dict(torch.load(speech_model_path, map_location=device))

    ensemble_model = EnsembleModel(speech_model=speech_model, facial_model=facial_model,
                                   num_classes=8, fusion_method='attention').to(device)

    criterion = nn.CrossEntropyLoss()

    # Only optimize the fusion parameters; frozen base-model params are skipped.
    optimizer = torch.optim.Adam(
        [p for p in ensemble_model.parameters() if p.requires_grad],
        lr=0.001,
        weight_decay=1e-5
    )

    history = train_ensemble(
        ensemble_model,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        device,
        num_epochs=100,
        early_stopping_patience=15
    )

    # train_ensemble leaves the last epoch's weights in the model. It saves a
    # checkpoint whenever val loss beats the best so far (initially inf), so a
    # checkpoint was written in this run iff some epoch had a finite val loss.
    # Only reload in that case, so a stale file from an earlier run is not used.
    if np.isfinite(history['val_loss']).any():
        ensemble_model.load_state_dict(torch.load(best_checkpoint, map_location=device))

    _plot_ensemble_history(history)

    return ensemble_model

# Data preprocessing and model evaluation utilities
class MultiModalDataset(torch.utils.data.Dataset):
    """Dataset of aligned (facial, speech, label) samples for multimodal input.

    Parameters
    ----------
    facial_features : array-like
        Features from facial expressions; the first axis indexes samples.
    speech_features : array-like
        Features from speech audio; the first axis indexes samples.
    labels : array-like
        Integer emotion labels, one per sample.

    Raises
    ------
    ValueError
        If the three inputs do not hold the same number of samples.
        Mismatched lengths would otherwise pair facial and speech features
        with the wrong labels.
    """

    def __init__(self, facial_features, speech_features, labels):
        # as_tensor converts lists/arrays to the requested dtype and avoids a
        # copy when a NumPy input already has that dtype.
        self.facial_features = torch.as_tensor(facial_features, dtype=torch.float32)
        self.speech_features = torch.as_tensor(speech_features, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.long)

        num_samples = len(self.labels)
        if len(self.facial_features) != num_samples or len(self.speech_features) != num_samples:
            raise ValueError(
                "facial_features, speech_features and labels must have the same length, "
                f"got {len(self.facial_features)}, {len(self.speech_features)} and {num_samples}"
            )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(facial, speech, label)`` tensors for sample ``idx``."""
        return self.facial_features[idx], self.speech_features[idx], self.labels[idx]
    
def _plot_confusion_matrix(y_true, y_pred, num_classes):
    """Show a ``num_classes x num_classes`` confusion-matrix heatmap."""
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import matplotlib.pyplot as plt

    class_ids = list(range(num_classes))
    # Passing ``labels`` keeps the matrix num_classes x num_classes, aligned
    # with the tick labels, even if some class never appears in y_true/y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='g', cmap='Blues',
                xticklabels=class_ids, yticklabels=class_ids)
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.title('Confusion Matrix')
    plt.show()


def evaluate_ensemble(model, test_loader, device, num_classes=8):
    """Evaluate the ensemble model on test data.

    Prints accuracy and support-weighted precision/recall/F1, then shows a
    confusion-matrix heatmap.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(facial_batch, speech_batch)``; the predicted class is
        the argmax over dim 1 of its output.
    test_loader : DataLoader
        Yields ``(facial_batch, speech_batch, labels)``.
    device : torch.device
        Device the inputs are moved to.
    num_classes : int, default 8
        Number of classes shown on the confusion-matrix axes.

    Returns
    -------
    tuple
        ``(accuracy, precision, recall, f1)``.

    Raises
    ------
    ValueError
        If ``test_loader`` yields no samples.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for facial_batch, speech_batch, labels in test_loader:
            outputs = model(facial_batch.to(device), speech_batch.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            # .cpu() so this also works if the dataset yields device tensors.
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        raise ValueError("test_loader yielded no samples")

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=1
    )

    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    _plot_confusion_matrix(all_labels, all_preds, num_classes)

    return accuracy, precision, recall, f1