# Model

In [None]:
import torch
import torch.nn as nn
import torchmetrics
import torch.nn.functional as F
import pytorch_lightning as pl
import json
from torchmetrics.classification import MulticlassF1Score, MulticlassAccuracy, MulticlassConfusionMatrix
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay
import wandb
from pytorch_lightning.loggers import WandbLogger


In [None]:



# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")




In [None]:
class BimodalEmotionRecognitionModel(pl.LightningModule):
    """Two-branch MLP emotion classifier over audio and video features.

    Each modality is flattened per sample and passed through its own stack of
    three ReLU-activated linear layers. The two branch outputs are
    concatenated and mapped to ``num_classes`` logits by a final linear layer.

    Training uses class-weighted cross-entropy (``class_weights``). Validation
    and test use the unweighted ``self.loss``.

    Args:
        input_size_audio: Audio feature count per sample after flattening.
        input_size_video: Video feature count per sample after flattening.
        hidden_size: Output width of the first linear layer of each branch.
        num_classes: Number of output classes (logits).
        vocab: Optional label vocabulary containing an ``"idx2label"`` mapping
            from stringified class index to class name. When given, it
            supplies the confusion-matrix tick labels. Otherwise ``VOCAB_PATH``
            is read at the end of each validation epoch.
        extra_hidden_size: Width of the second and third layers of each branch.
        class_counts: Optional per-class sample counts used to derive the
            training-loss weights ``total / (num_classes * count_c)``. When
            omitted, ``DEFAULT_CLASS_COUNTS`` is used if it has ``num_classes``
            entries; otherwise all weights are 1.0.

    Raises:
        ValueError: If ``class_counts`` does not have ``num_classes`` entries.
    """

    # Per-class counts that used to be hard-coded; kept as the default so
    # existing 8-class callers get identical loss weights.
    DEFAULT_CLASS_COUNTS = (96, 192, 192, 192, 192, 192, 192, 192)
    # Fallback vocabulary location, used only when no ``vocab`` dict is given.
    VOCAB_PATH = "/mnt/c/users/admin/desktop/github/bimodal_emotion_recognition_with_ravdess_dataset/vocab.json"

    def __init__(self, input_size_audio, input_size_video, hidden_size=64, num_classes=8, vocab=None, extra_hidden_size=64, class_counts=None):
        super().__init__()
        self.validation_step_outputs = []
        self.test_step_outputs = []

        # Audio branch: input -> hidden_size -> extra_hidden_size -> extra_hidden_size.
        self.audio_fc = nn.Linear(input_size_audio, hidden_size)
        self.audio_hidden = nn.Linear(hidden_size, extra_hidden_size)
        # Fed by audio_hidden, so the input width must be extra_hidden_size
        # (hidden_size here crashed whenever the two sizes differed).
        self.audio_hidden2 = nn.Linear(extra_hidden_size, extra_hidden_size)

        # Video branch mirrors the audio branch.
        self.video_fc = nn.Linear(input_size_video, hidden_size)
        self.video_hidden = nn.Linear(hidden_size, extra_hidden_size)
        self.video_hidden2 = nn.Linear(extra_hidden_size, extra_hidden_size)

        # Fusion head over the concatenated branch outputs.
        self.final_fc = nn.Linear(extra_hidden_size * 2, num_classes)

        self.relu = nn.ReLU()

        # Unweighted loss used by validation_step and test_step.
        self.loss = nn.CrossEntropyLoss()
        self.f1_score = MulticlassF1Score(num_classes=num_classes)
        self.accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.conf_matrix = MulticlassConfusionMatrix(num_classes=num_classes)

        # Store the vocabulary (used for confusion-matrix labels when it is a dict).
        self.vocab = vocab

        # Inverse-frequency class weights for the training loss.
        if class_counts is None:
            class_counts = (
                self.DEFAULT_CLASS_COUNTS
                if len(self.DEFAULT_CLASS_COUNTS) == num_classes
                else (1,) * num_classes  # uniform counts -> every weight is 1.0
            )
        counts = torch.as_tensor(class_counts, dtype=torch.float32)
        if counts.numel() != num_classes:
            raise ValueError(
                f"class_counts has {counts.numel()} entries but num_classes is {num_classes}"
            )
        # A non-persistent buffer moves with the module (.to(), Lightning's
        # device placement) instead of being pinned to one device at
        # construction time, and adds no key to the state_dict.
        self.register_buffer("class_weights", counts.sum() / (num_classes * counts), persistent=False)

    def forward(self, audio, video):
        """Return class logits of shape ``(batch, num_classes)``.

        Each input is flattened to ``(batch, -1)``. The flattened widths must
        equal ``input_size_audio`` and ``input_size_video`` respectively.
        """
        # reshape (unlike view) also accepts non-contiguous tensors. The dtype
        # cast guards against features stored as float64 in the .npy files,
        # which would otherwise clash with float32 layer weights.
        audio = audio.reshape(audio.size(0), -1).to(self.audio_fc.weight.dtype)
        audio = self.relu(self.audio_fc(audio))
        audio = self.relu(self.audio_hidden(audio))
        audio = self.relu(self.audio_hidden2(audio))

        video = video.reshape(video.size(0), -1).to(self.video_fc.weight.dtype)
        video = self.relu(self.video_fc(video))
        video = self.relu(self.video_hidden(video))
        video = self.relu(self.video_hidden2(video))

        combined = torch.cat([audio, video], dim=1)
        return self.final_fc(combined)

    def training_step(self, batch, batch_idx):
        """Compute the class-weighted cross-entropy loss and log train metrics."""
        audio, video, target = batch
        output = self(audio, video)
        preds = output.argmax(dim=1)

        loss = F.cross_entropy(output, target, weight=self.class_weights)
        f1 = self.f1_score(preds, target)
        acc = self.accuracy(preds, target)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_f1', f1, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_epoch=True, prog_bar=True)
        # Also logged directly to the active wandb run, in addition to self.log.
        wandb.log({"train_loss": loss, "train_f1": f1, "train_acc": acc})
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the unweighted loss, log val metrics and accumulate the confusion matrix."""
        audio, video, target = batch
        output = self(audio, video)
        preds = output.argmax(dim=1)

        loss = self.loss(output, target)
        f1 = self.f1_score(preds, target)
        acc = self.accuracy(preds, target)
        self.conf_matrix.update(preds, target)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_f1', f1, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True)
        # Also logged directly to the active wandb run, in addition to self.log.
        wandb.log({"val_loss": loss, "f1_val": f1, "val_acc": acc})
        return loss

    def _class_names(self):
        """Return class names ordered by class index, from ``self.vocab`` or ``VOCAB_PATH``."""
        vocab_data = self.vocab
        if not (isinstance(vocab_data, dict) and "idx2label" in vocab_data):
            with open(self.VOCAB_PATH, "r") as f:
                vocab_data = json.load(f)
        idx2label = vocab_data["idx2label"]
        # JSON object keys are strings, hence str(i).
        return [idx2label[str(i)] for i in range(len(idx2label))]

    def on_validation_epoch_end(self):
        """Plot this epoch's validation confusion matrix, then reset it for the next epoch."""
        confusion_matrix = self.conf_matrix.compute().cpu().numpy()
        # Without a reset the metric keeps accumulating, so every plot after
        # the first would mix predictions from all previous epochs.
        self.conf_matrix.reset()

        class_names = self._class_names()

        fig = plt.figure(figsize=(8, 8))
        sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
        plt.title("Confusion Matrix")
        plt.xlabel("Predicted Label")
        plt.ylabel("True Label")
        plt.show()
        # Release the figure; otherwise one figure per epoch stays in memory.
        plt.close(fig)

    def test_step(self, batch, batch_idx):
        """Compute the unweighted loss and log test metrics."""
        audio, video, target = batch
        output = self(audio, video)
        preds = output.argmax(dim=1)

        loss = self.loss(output, target)
        f1 = self.f1_score(preds, target)
        acc = self.accuracy(preds, target)

        self.log('test_loss', loss, on_epoch=True, prog_bar=True)
        self.log('test_f1', f1, on_epoch=True, prog_bar=True)
        self.log('test_acc', acc, on_epoch=True, prog_bar=True)
        # Also logged directly to the active wandb run, in addition to self.log.
        wandb.log({"test_loss": loss, "test_f1": f1, "test_acc": acc})
        return loss

    def configure_optimizers(self):
        """Return Adam (lr=1e-3, weight_decay=1e-5) over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)


# Dataset and DataModule classes

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import json

class BimodalEmotionDataset(Dataset):
    """Map-style dataset of ``(audio_feature, video_feature, label)`` samples.

    Args:
        csv_file: CSV with ``video_feature_path``, ``audio_feature_path`` and
            ``label`` columns. The path columns point to files readable by
            ``np.load``.
        vocab_file: JSON file whose ``label2idx`` mapping converts a label
            string into its integer class index.
    """

    def __init__(self, csv_file, vocab_file):
        self.data = pd.read_csv(csv_file)
        with open(vocab_file, 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk.

        Returns:
            ``(audio_feature, video_feature, label)``. The features are the
            arrays returned by ``np.load``, and the label is a 0-dim integer
            tensor.

        Rows are addressed by position (``iloc``), so an out-of-range index
        raises ``IndexError``. This lets plain iteration over the dataset stop
        cleanly instead of failing with ``KeyError``.
        """
        row = self.data.iloc[idx]
        video_feature = np.load(row['video_feature_path'])
        audio_feature = np.load(row['audio_feature_path'])
        label = torch.tensor(self.vocab['label2idx'][row['label']])

        return audio_feature, video_feature, label

class BimodalEmotionRecognitionDataModule(pl.LightningDataModule):
    """LightningDataModule that splits one CSV-backed dataset into train/val/test.

    ``setup`` first holds out 20% of the samples as the test set, then 10% of
    the remainder as the validation set, both with ``random_state=42``. The
    splits are ``torch.utils.data.Subset`` views over a single
    ``BimodalEmotionDataset``, so feature files are loaded only when a sample
    is requested.

    Args:
        batch_size: Batch size used by every DataLoader.
        num_workers: Worker count used by every DataLoader.
        csv_file: Path to the sample CSV passed to ``BimodalEmotionDataset``.
        vocab_file: Path to the vocabulary JSON passed to ``BimodalEmotionDataset``.
    """

    def __init__(
        self,
        batch_size=32,
        num_workers=12,
        csv_file="/mnt/c/users/admin/desktop/github/bimodal_emotion_recognition_with_ravdess_dataset/example.csv",
        vocab_file="/mnt/c/users/admin/desktop/github/bimodal_emotion_recognition_with_ravdess_dataset/vocab.json",
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.csv_file = csv_file
        self.vocab_file = vocab_file

    def setup(self, stage=None):
        dataset = BimodalEmotionDataset(self.csv_file, self.vocab_file)
        # Split positions rather than the dataset itself. Given a plain Dataset,
        # sklearn indexes it element by element and returns lists, which
        # np.load-ed every feature file up front. The shuffle depends only on
        # the sample count and random_state, so splitting the index range
        # yields the same partition in the same order.
        indices = np.arange(len(dataset))
        train_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)
        train_idx, val_idx = train_test_split(train_idx, test_size=0.1, random_state=42)
        self.dataset = dataset
        self.train_dataset = torch.utils.data.Subset(dataset, train_idx.tolist())
        self.val_dataset = torch.utils.data.Subset(dataset, val_idx.tolist())
        self.test_dataset = torch.utils.data.Subset(dataset, test_idx.tolist())

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    @staticmethod
    def _labels(dataset):
        """Return the integer labels of ``dataset`` as a list.

        For a ``Subset`` of ``BimodalEmotionDataset`` the labels come from the
        CSV ``label`` column, so no feature file is loaded. Any other dataset
        is iterated and each sample's third element is read.
        """
        if isinstance(dataset, torch.utils.data.Subset) and isinstance(dataset.dataset, BimodalEmotionDataset):
            base = dataset.dataset
            label2idx = base.vocab['label2idx']
            column = base.data['label']
            return [int(label2idx[column.iloc[i]]) for i in dataset.indices]
        return [item[2].item() for item in dataset]

    def _count_labels(self, labels):
        """Count occurrences of each class index in ``labels``.

        After ``setup``, the result has at least one entry per vocabulary class
        (absent classes count 0), so it lines up with the class names.
        """
        vocab_dataset = getattr(self, 'dataset', None)
        minlength = len(vocab_dataset.vocab['label2idx']) if vocab_dataset is not None else 0
        # An explicit int dtype keeps np.bincount working on an empty list.
        return np.bincount(np.asarray(labels, dtype=np.int64), minlength=minlength)

    def whole_dataset_class_distribution(self):
        """Return per-class sample counts over the train, val and test splits combined."""
        labels = (
            self._labels(self.train_dataset)
            + self._labels(self.val_dataset)
            + self._labels(self.test_dataset)
        )
        return self._count_labels(labels)

    def class_distribution(self, dataset):
        """Return per-class sample counts for ``dataset``."""
        return self._count_labels(self._labels(dataset))

# Trainer

In [None]:
import pytorch_lightning as pl

if __name__ == "__main__":
    wandb.init(project="Emotion_Reg_Evalu", config={"input_size_audio": 512, "input_size_video": 1568 * 768, "hidden_size": 64, "num_classes": 8, "extra_hidden_size": 64})
    # Build the model from the logged config so the run's recorded
    # hyperparameters always match the model that is actually trained.
    config = wandb.config

    pl.seed_everything(42)

    data_module = BimodalEmotionRecognitionDataModule(batch_size=128, num_workers=12)
    data_module.setup()
    vocab = data_module.dataset.vocab
    # Class names ordered by class index; JSON object keys are strings.
    class_names = [vocab["idx2label"][str(i)] for i in range(len(vocab["idx2label"]))]

    def plot_class_distribution(title, counts):
        """Show a bar chart of per-class sample counts in its own figure."""
        # A fresh figure per chart: with a non-interactive backend plt.show()
        # is a no-op, so reusing the current figure would stack all the bars.
        fig = plt.figure()
        plt.bar(class_names, counts)
        plt.title(title)
        plt.xticks(rotation=45, ha='right')  # Rotate x-axis labels for readability
        plt.show()
        plt.close(fig)

    plot_class_distribution('Whole Dataset Class Distribution', data_module.whole_dataset_class_distribution())
    for title, split in (
        ('Training Dataset Class Distribution', data_module.train_dataset),
        ('Validation Dataset Class Distribution', data_module.val_dataset),
        ('Test Dataset Class Distribution', data_module.test_dataset),
    ):
        plot_class_distribution(title, data_module.class_distribution(split))

    model = BimodalEmotionRecognitionModel(
        input_size_audio=config.input_size_audio,
        input_size_video=config.input_size_video,
        hidden_size=config.hidden_size,
        num_classes=config.num_classes,
        vocab=vocab,
        extra_hidden_size=config.extra_hidden_size,
    ).to(device)

    wandb_logger = pl.loggers.WandbLogger()

    trainer = pl.Trainer(
        max_epochs=50,
        num_sanity_val_steps=0,
        logger=wandb_logger,
        log_every_n_steps=50,
    )

    trainer.fit(model, data_module)
    trainer.test(model, datamodule=data_module)
