# MUSE

## Imports

In [1]:

import tarfile
import os
import os
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
import numpy as np
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report, accuracy_score, f1_score
import joblib
import nltk
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, AutoProcessor, DistilBertTokenizer

# Download the NLTK 'punkt' and 'stopwords' resources.
for nltk_package in ('punkt', 'stopwords'):
    nltk.download(nltk_package)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

  Referenced from: <367D4265-B20F-34BD-94EB-4F3EE47C385B> /opt/anaconda3/envs/audio/lib/python3.12/site-packages/torchvision/image.so
  warn(
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/filiplandin/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/filiplandin/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


## Data

### Utils

In [2]:
def meld_collate_fn(batch):
    """Collate conversation dicts into a dict of per-conversation lists.

    Conversations can hold different numbers of utterances, so nothing is
    stacked across conversations: every value in the result is a Python list
    with one entry per element of ``batch``.

    :param batch: list of dicts with keys "dialog_id", "fbank_list",
        "text_list", "emotion_list" and "sentiment_list".
    :return: dict with keys "dialog_ids", "fbank_lists", "text_lists",
        "emotion_lists" and "sentiment_lists".
    """
    def _stack_utterance_fbanks(fbanks):
        # Turn each array into a tensor, then zero-pad along the first (time)
        # dimension so one conversation becomes a single (utterances, T, ...) tensor.
        as_tensors = [torch.tensor(one_fbank) for one_fbank in fbanks]
        return pad_sequence(as_tensors, batch_first=True)

    return {
        "dialog_ids": [conv["dialog_id"] for conv in batch],
        "fbank_lists": [_stack_utterance_fbanks(conv["fbank_list"]) for conv in batch],
        "text_lists": [conv["text_list"] for conv in batch],
        "emotion_lists": [torch.tensor(conv["emotion_list"]) for conv in batch],
        "sentiment_lists": [torch.tensor(conv["sentiment_list"]) for conv in batch]
    }


def custom_text_preprocessor(text):
    """Identity text preprocessor: hand the input back unchanged."""
    processed = text
    return processed

### Filterbank (fbank) extractor

In [3]:
def wav2fbank(filename, target_length=1024):
    """Load an audio file and return a fixed-length Kaldi-style fbank tensor.

    The waveform is mean-centred, converted to 128 mel-bin filterbank features
    with a 10 ms frame shift, then zero-padded or truncated along the frame
    axis so the result has exactly ``target_length`` frames.

    :param filename: path of the audio file passed to ``torchaudio.load``.
    :param target_length: number of frames in the returned tensor
        (defaults to 1024, the value previously hard-coded).
    :return: tensor of shape (target_length, 128).
    """
    waveform, sr = torchaudio.load(filename)
    waveform = waveform - waveform.mean()

    try:
        fbank = torchaudio.compliance.kaldi.fbank(
            waveform,
            htk_compat=True,
            sample_frequency=sr,
            use_energy=False,
            window_type='hanning',
            num_mel_bins=128,
            dither=0.0,
            frame_shift=10
        )
    except Exception:
        # Best-effort fallback: keep the pipeline running with a near-zero
        # placeholder. Catch Exception rather than using a bare ``except`` so
        # KeyboardInterrupt / SystemExit still stop a long conversion run.
        fbank = torch.zeros([512, 128]) + 0.01
        print('there is a loading error')

    n_frames = fbank.shape[0]

    # Positive -> frames are missing and must be padded; negative -> too many frames.
    p = target_length - n_frames

    # cut and pad
    if p > 0:
        # (left, right, top, bottom): append p zero rows at the end of the frame axis.
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    return fbank

In [4]:
# Convert every utterance's audio to a cached fbank .npy file for train, dev and test.
for mode in ['train', 'dev', 'test']:
    df = pd.read_csv(f'data/{mode}/{mode}_sent_emo.csv')

    # Create the output directory once per split instead of re-checking it for
    # every row; exist_ok makes the cell safe to re-run.
    fbank_dir = f'data/{mode}/fbank'
    os.makedirs(fbank_dir, exist_ok=True)

    for _, row in df.iterrows():
        dia_id = row['Dialogue_ID']
        utt_id = row['Utterance_ID']
        audio_file = f'data/{mode}/audio/dia{dia_id}_utt{utt_id}.wav'
        fbank_file = f'{fbank_dir}/dia{dia_id}_utt{utt_id}.npy'

        # Skip utterances whose features were already extracted by a previous run.
        if os.path.exists(fbank_file):
            print(f'{fbank_file} already exists')
            continue

        fbank = wav2fbank(audio_file)
        np.save(fbank_file, fbank.numpy())

FileNotFoundError: [Errno 2] No such file or directory: 'data/test/test_sent_emo.csv'

### Dataset and dataloader class

In [24]:
class MELDConversationDataset(Dataset):
    """Conversation-level MELD dataset: one item is one whole dialogue.

    Attributes set by ``__init__``:
      * ``dialogues``: list of ``(dialog_id, [utterance dicts])`` with the
        utterances in ascending ``Utterance_ID`` order.
      * ``emotion_class_counts`` / ``sentiment_class_counts``: utterance counts
        keyed by integer class label.
      * ``max_dialogue_size``: number of utterances in the longest dialogue.
    """

    def __init__(self, csv_file, root_dir='./data', mode="train", text_processor=None):
        """
        We'll store a list of (dialog_id, [list_of_utterance_dicts]).
        Each utterance_dict contains:
          {
            "fbank_path": str,
            "transcript": str,
            "emotion": int,
            "sentiment": int
          }

        :param csv_file: CSV file name, read from ``{root_dir}/{csv_file}``. It must
            provide the columns Dialogue_ID, Utterance_ID, Utterance, Emotion, Sentiment.
        :param root_dir: directory that contains ``csv_file``.
        :param mode: split name used to build the fbank paths
            (``data/{mode}/fbank/...``); note these paths do not use ``root_dir``.
        :param text_processor: accepted for API compatibility; currently unused.
        :raises KeyError: if an emotion or sentiment label is not a known class.
        """
        df = pd.read_csv(f'{root_dir}/{csv_file}')

        # Order rows by dialogue, then by utterance inside each dialogue. The
        # grouping below relies on this to keep utterances in numeric order.
        df = df.sort_values(by=['Dialogue_ID', 'Utterance_ID'])

        self.emotion_class_counts = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
        self.sentiment_class_counts = {0: 0, 1: 0, 2: 0}

        dialogues = {}  # key: dialogue_id, value: list of utterance dicts

        for _, row in df.iterrows():
            dia_id = row["Dialogue_ID"]
            utt_id = row["Utterance_ID"]

            fbank_path = f'data/{mode}/fbank/dia{dia_id}_utt{utt_id}.npy'

            # Labels are matched case-insensitively.
            emotion_int = self.emotion_to_int(row["Emotion"].lower())
            sentiment_int = self.sentiment_to_int(row["Sentiment"].lower())

            self.emotion_class_counts[emotion_int] += 1
            self.sentiment_class_counts[sentiment_int] += 1

            dialogues.setdefault(dia_id, []).append({
                "fbank_path": fbank_path,
                "transcript": row["Utterance"],
                "emotion": emotion_int,
                "sentiment": sentiment_int
            })

        # Convert to list of (dialog_id, list_of_utterances). The utterances are
        # deliberately NOT re-sorted by fbank_path: a string sort would place
        # "utt10" before "utt2" and scramble dialogues with more than 10 utterances.
        self.dialogues = list(dialogues.items())

        # Take the maximum over every dialogue, including the last one in the file.
        self.max_dialogue_size = max(
            (len(utterances) for _, utterances in self.dialogues), default=0)

    def emotion_to_int(self, str):
        """Map a lower-case emotion name to its integer class label."""
        str_to_int = {"neutral": 0, "joy": 1, "surprise": 2,
                      "anger": 3, "sadness": 4, "fear": 5, "disgust": 6}
        return str_to_int[str]

    def emotion_to_str(self, int):
        """Map an integer emotion class label back to its name."""
        int_to_str = {0: "neutral", 1: "joy", 2: "surprise",
                      3: "anger", 4: "sadness", 5: "fear", 6: "disgust"}
        return int_to_str[int]

    def sentiment_to_int(self, str):
        """Map a lower-case sentiment name to its integer class label."""
        str_to_int = {"neutral": 0, "positive": 1, "negative": 2}
        return str_to_int[str]

    def sentiment_to_str(self, int):
        """Map an integer sentiment class label back to its name."""
        int_to_str = {0: "neutral", 1: "positive", 2: "negative"}
        return int_to_str[int]

    def __len__(self):
        """Number of dialogues in the dataset."""
        return len(self.dialogues)

    def __getitem__(self, idx):
        """Load one dialogue.

        :return: dict with "dialog_id" plus four parallel lists ("fbank_list",
            "text_list", "emotion_list", "sentiment_list"), one entry per utterance.
        """
        dialog_id, utterances = self.dialogues[idx]

        fbank_list = []
        text_list = []
        emotion_list = []
        sentiment_list = []

        for utt in utterances:
            # Features are read from disk lazily, one .npy file per utterance.
            fbank_list.append(np.load(utt["fbank_path"]))
            text_list.append(utt["transcript"])
            emotion_list.append(utt["emotion"])
            sentiment_list.append(utt["sentiment"])

        return {
            "dialog_id": dialog_id,
            "fbank_list": fbank_list,
            "text_list": text_list,
            "emotion_list": emotion_list,
            "sentiment_list": sentiment_list
        }

In [25]:
# One conversation-level dataset per MELD split.
train_set = MELDConversationDataset(
    csv_file="train_sent_emo.csv", root_dir="data/train", mode="train")
dev_set = MELDConversationDataset(
    csv_file="dev_sent_emo.csv", root_dir="data/dev", mode="dev")
test_set = MELDConversationDataset(
    csv_file='test_sent_emo.csv', root_dir='data/test', mode='test')

print("Data loaded.")

# All three loaders share the same settings: 4 conversations per batch,
# shuffled, 4 worker processes, collated with meld_collate_fn.
loader_kwargs = dict(batch_size=4, shuffle=True, num_workers=4, collate_fn=meld_collate_fn)
train_loader = DataLoader(train_set, **loader_kwargs)
dev_loader = DataLoader(dev_set, **loader_kwargs)
test_loader = DataLoader(test_set, **loader_kwargs)
print("Data loaders created.")

Data loaded.
Data loaders created.


### Miscellaneous

In [26]:
# Number of dialogues in the train, dev and test splits.
print(*(len(split) for split in (train_set, dev_set, test_set)))

1038 114 169


In [27]:
# Class inventories come from the training split's per-class counters.
emotion_counts = train_set.emotion_class_counts
sentiment_counts = train_set.sentiment_class_counts
num_emotions = len(emotion_counts)
num_sentiments = len(sentiment_counts)

print("Number of emotions:", num_emotions)
print("Number of sentiments:", num_sentiments)

# Per-class utterance counts, in integer-label order.
for label_idx in range(7):
    label_name = train_set.emotion_to_str(label_idx)
    print(f"Emotion {label_name} count:", emotion_counts[label_idx])
for label_idx in range(3):
    label_name = train_set.sentiment_to_str(label_idx)
    print(f"Sentiment {label_name} count:", sentiment_counts[label_idx])

print("\n")

Number of emotions: 7
Number of sentiments: 3
Emotion neutral count: 4709
Emotion joy count: 1743
Emotion surprise count: 1205
Emotion anger count: 1109
Emotion sadness count: 683
Emotion fear count: 268
Emotion disgust count: 271
Sentiment neutral count: 4709
Sentiment positive count: 2334
Sentiment negative count: 2945




In [28]:
# Inverse-frequency weights used to balance the emotion classes.
class_counts = train_set.emotion_class_counts
total_samples = sum(class_counts.values())

print("Total samples:", total_samples)

# Relative frequency of each class, indexed by integer label.
class_weights = torch.tensor(
    [class_counts[label] / total_samples for label in range(len(class_counts))]
)

class_weights = 1 / class_weights  # rarer classes get larger weights
class_weights = class_weights / class_weights.sum()  # rescale so the weights sum to 1
class_weights = class_weights.to(device)
print("Class weights:", class_weights)

Total samples: 9988
Class weights: tensor([0.0186, 0.0503, 0.0728, 0.0791, 0.1284, 0.3272, 0.3236])


## Embeddings

## Model

In [38]:
# PARAMETERS
lr = 1e-4
num_epochs = 50

# One loss per task; only the emotion loss uses the class-balancing weights.
criterions = dict(
    emotion=nn.CrossEntropyLoss(weight=class_weights),
    sentiment=nn.CrossEntropyLoss(),
)

# Longest dialogue across the train, dev and test sets.
max_utt = max(split.max_dialogue_size for split in (train_set, dev_set, test_set))
print("Max dialogue size:", max_utt)

Max dialogue size: 24


In [39]:
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions, with dropout after each ReLU."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, dropout_p=0.5):
        """
        :param inplanes: number of input channels.
        :param planes: number of output channels.
        :param stride: stride of the first convolution.
        :param downsample: optional module applied to the input before the
            residual addition.
        :param norm_layer: normalisation layer factory (``nn.BatchNorm2d`` when None).
        :param dropout_p: probability used by the shared dropout layer.
        :raises ValueError: if ``groups`` != 1 or ``base_width`` != 64.
        :raises NotImplementedError: if ``dilation`` > 1.
        """
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")

        # Sub-modules are registered in the same order as before so parameter
        # ordering and state_dict keys are unaffected.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout_p)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``dropout(relu(branch(x) + shortcut(x)))``."""
        # conv -> norm -> ReLU -> dropout, then conv -> norm.
        branch = self.dropout(self.relu(self.bn1(self.conv1(x))))
        branch = self.bn2(self.conv2(branch))

        # The shortcut is the input itself unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.dropout(self.relu(branch))


class ResNet(nn.Module):
    """ResNet-style convolutional encoder that outputs a 768-channel feature map.

    The network is stem (conv1/bn1/relu/maxpool), four residual stages
    (layer1..layer4) and a final 1x1 convolution (``out_conv``) projecting to
    768 channels. No pooling or classification layer is applied in ``forward``.
    """

    def __init__(self, block, layers, modality, num_classes=1000, pool='avgpool', zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        """
        :param block: residual block class; must expose an ``expansion`` attribute.
        :param layers: sequence of four ints, the number of blocks in each stage.
        :param modality: 'audio' (1 input channel) or 'visual' (3 input channels).
        :param num_classes: accepted but not used anywhere in this class.
        :param pool: stored on ``self.pool``; not used in ``forward``.
        :param zero_init_residual: if True, zero the weight of the last norm layer
            of every residual block.
        :param groups: forwarded to every block.
        :param width_per_group: forwarded to every block as ``base_width``.
        :param replace_stride_with_dilation: None or three booleans, one each for
            layer2, layer3 and layer4.
        :param norm_layer: normalisation layer factory (``nn.BatchNorm2d`` when None).
        :raises ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
        :raises NotImplementedError: if ``modality`` is neither 'audio' nor 'visual'.
        """
        super(ResNet, self).__init__()
        self.modality = modality
        self.pool = pool
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running channel count / dilation, updated by _make_layer as stages are built.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # The stem differs only in its number of input channels.
        if modality == 'audio':
            self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
                                   bias=False)
        elif modality == 'visual':
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                   bias=False)
        else:
            raise NotImplementedError(
                'Incorrect modality, should be audio or visual but got {}'.format(modality))

        # 1x1 projection from the last stage's channels to a 768-channel output.
        self.out_conv = nn.Conv2d(
            512 * block.expansion, 768, kernel_size=1, stride=1, bias=False)

        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        # Weight initialisation: Kaiming-normal for every convolution (including
        # out_conv), N(1, 0.02) weights and zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.normal_(m.weight, mean=1, std=0.02)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks as an ``nn.Sequential``.

        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is True)
        so the next stage is wired to this stage's output.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation: later blocks dilate instead of striding.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # The shortcut needs a projection whenever resolution or channel count changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # Only the first block applies the stride / downsample; it uses the
        # dilation that was in effect before this stage.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` into a 768-channel feature map.

        For 'visual' input, ``x`` is unpacked as (B, C, T, H, W) and the time axis
        is folded into the batch axis before the convolutions. For 'audio' input,
        ``x`` goes straight into a 1-input-channel Conv2d.
        """

        if self.modality == 'visual':
            (B, C, T, H, W) = x.size()
            # Move T next to B, then merge them so every frame is processed as an image.
            x = x.permute(0, 2, 1, 3, 4).contiguous()
            x = x.view(B * T, C, H, W)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.out_conv(x)  # x now has 768 channels: [N, 768, H_out, W_out] (N is B*T for visual input)

        out = x

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with ``planes * expansion`` output channels."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        """
        :param inplanes: number of input channels.
        :param planes: base channel count; the block outputs ``planes * expansion``.
        :param stride: stride of the middle 3x3 convolution.
        :param downsample: optional module applied to the input before the
            residual addition.
        :param groups: groups of the 3x3 convolution.
        :param base_width: scales the inner width relative to 64.
        :param dilation: dilation of the 3x3 convolution.
        :param norm_layer: normalisation layer factory (``nn.BatchNorm2d`` when None).
        """
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1.
        # Registration order is kept so state_dict keys and parameter order are unchanged.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        # The shortcut is the input itself unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)


def _resnet(arch, block, layers, modality, progress, **kwargs):
    """Construct a ``ResNet``; ``arch`` and ``progress`` are accepted but not used."""
    return ResNet(block, layers, modality, **kwargs)


def resnet18(modality, progress=True, **kwargs):
    """Build a ``ResNet`` from ``BasicBlock`` with two blocks in each of its four stages.

    Extra keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, blocks_per_stage, modality, progress, **kwargs)

In [42]:
class AudioEncoder(nn.Module):
    """Waveform encoder built on a pretrained Hugging Face audio model.

    Loads ``model_name`` through ``AutoProcessor`` / ``AutoModel``, freezes all
    weights and optionally unfreezes the last encoder layers. ``forward``
    mean-pools the model's last hidden state over the time axis.
    """

    def __init__(self, model_name="facebook/wav2vec2-base",
                 target_sr=16000,
                 fine_tune=False,
                 unfreeze_last_n=2,  # Number of last layers to unfreeze
                 device=device):
        """
        :param model_name: checkpoint passed to ``from_pretrained``.
        :param target_sr: sampling rate reported to the processor in ``forward``.
        :param fine_tune: if True, the last ``unfreeze_last_n`` layers of
            ``self.model.encoder.layers`` are made trainable.
        :param unfreeze_last_n: number of trailing encoder layers to unfreeze.
        :param device: device the pretrained model is moved to; the default is
            the notebook-level ``device`` captured when the class is defined.
        """
        super(AudioEncoder, self).__init__()
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.target_sr = target_sr

        # Freeze entire model
        for param in self.model.parameters():
            param.requires_grad = False

        if fine_tune:
            # Unfreeze only the last `unfreeze_last_n` encoder layers
            total_layers = len(self.model.encoder.layers)
            for layer_idx in range(total_layers - unfreeze_last_n, total_layers):
                for param in self.model.encoder.layers[layer_idx].parameters():
                    param.requires_grad = True

    def forward(self, waveforms):
        """
        :param waveforms: Tensor of shape [B, T] (already at self.target_sr)
        :return: A tensor of shape [B, hidden_dim] (audio embeddings for the batch)
        """
        # Ensure the waveforms are on the correct device
        waveforms = waveforms.to(self.model.device)

        # Prepare inputs for Wav2Vec2Processor
        # The processor receives a CPU copy of the waveforms; its output is moved
        # back to the model's device afterwards.
        # NOTE(review): squeeze((0, 1)) presumably strips extra leading singleton
        # dimensions added by the processor — TODO confirm the expected shape.
        inputs = self.processor(
            waveforms.cpu(),
            sampling_rate=self.target_sr,
            return_tensors="pt",
            padding=True
        ).input_values.squeeze((0, 1)).to(self.model.device)

        # Forward pass
        # Gradients are disabled when the module is in eval mode or when no model
        # parameter is trainable; otherwise they are explicitly enabled.
        with torch.no_grad() if not self.training or not any(
            p.requires_grad for p in self.model.parameters()
        ) else torch.enable_grad():
            outputs = self.model(inputs)
            hidden_states = outputs.last_hidden_state  # shape [B, T, D]

        # Average pooling
        audio_emb = hidden_states.mean(dim=1)  # shape [B, D]

        return audio_emb


class TextEncoder(nn.Module):
    """
    Encodes text with a pretrained Hugging Face model (DistilBERT by default)
    and returns the embedding of the first token of each input.
    """

    def __init__(self,
                 model_name="distilbert-base-uncased",
                 fine_tune=False,
                 unfreeze_last_n_layers=2,
                 device=device):
        """
        :param model_name: checkpoint passed to ``from_pretrained``.
        :param fine_tune: if True, the last ``unfreeze_last_n_layers`` layers of
            ``self.model.transformer.layer`` are made trainable.
        :param unfreeze_last_n_layers: number of trailing transformer layers to unfreeze.
        :param device: device the pretrained model is moved to; the default is
            the notebook-level ``device`` captured when the class is defined.
        """
        super(TextEncoder, self).__init__()
        # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)

        # Freeze everything first; selected layers are unfrozen below when fine-tuning.
        for param in self.model.parameters():
            param.requires_grad = False

        if fine_tune and unfreeze_last_n_layers > 0:
            # The layer stack lives at model.transformer.layer (DistilBERT layout);
            # its length is read from the model rather than assumed.
            total_layers = len(self.model.transformer.layer)
            for layer_idx in range(total_layers - unfreeze_last_n_layers, total_layers):
                for param in self.model.transformer.layer[layer_idx].parameters():
                    param.requires_grad = True

    def forward(self, text_list):
        """
        :param text_list: A list of strings (or a single string) to encode.
        :return: A tensor of shape [batch_size, hidden_dim] with text embeddings
        """
        device = self.model.device

        # If a single string is passed, wrap it into a list
        if isinstance(text_list, str):
            text_list = [text_list]

        encodings = self.tokenizer(
            text_list,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        # Skip autograd when every parameter is frozen. When some layers are
        # trainable, keep the caller's grad mode instead of forcing gradients on:
        # torch.enable_grad() would override an enclosing torch.no_grad() and
        # build graphs during evaluation.
        use_grad = any(p.requires_grad for p in self.model.parameters())
        with torch.set_grad_enabled(use_grad and torch.is_grad_enabled()):
            outputs = self.model(
                input_ids=encodings.input_ids,
                attention_mask=encodings.attention_mask
            )

        # outputs.last_hidden_state -> shape [batch_size, seq_len, hidden_dim]
        # The first-token ([CLS]) embedding serves as the sentence representation.
        cls_emb = outputs.last_hidden_state[:, 0, :]  # shape [B, hidden_dim]
        return cls_emb


class FusionLSTM(nn.Module):
    """Thin wrapper around a batch-first ``nn.LSTM``.

    ``self.hidden_dim`` holds the width of the per-step output: ``hidden_dim``
    for a unidirectional LSTM and twice that for a bidirectional one.
    """

    def __init__(self, input_dim, hidden_dim=256, num_layers=1, bidirectional=False, dropout=0.1):
        super().__init__()
        # Dropout is only passed through when there is more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout
        )
        num_directions = 2 if bidirectional else 1
        self.hidden_dim = hidden_dim * num_directions

    def forward(self, x):
        """
        x: (batch_size, seq_len, input_dim)
        returns: the per-step outputs (batch_size, seq_len, self.hidden_dim)
                 together with the final (hidden, cell) state pair
        """
        step_outputs, final_state = self.lstm(x)
        last_hidden, last_cell = final_state
        return step_outputs, (last_hidden, last_cell)


class MultimodalClassifierWithLSTM(nn.Module):
    """Per-utterance emotion and sentiment classifier for one dialogue.

    Audio and text embeddings of every utterance are concatenated, the sequence
    is zero-padded to ``max_utt`` steps, passed through ``fusion_lstm`` and fed
    to two linear heads.
    """

    def __init__(self, audio_encoder, text_encoder, fusion_lstm, hidden_dim=256, num_emotions=7, num_sentiments=3, max_utt=50):
        """
        :param audio_encoder: module mapping the dialogue's audio input to a 4-D
            feature map with one entry per utterance.
        :param text_encoder: module called once per utterance text.
        :param fusion_lstm: sequence module returning ``(output, (hn, cn))``.
        :param hidden_dim: LSTM hidden size; only used as ``hidden_dim * 2`` for the
            head input width when ``fusion_lstm`` has no ``hidden_dim`` attribute.
        :param num_emotions: number of emotion classes.
        :param num_sentiments: number of sentiment classes.
        :param max_utt: sequence length every dialogue is padded to.
        """
        super(MultimodalClassifierWithLSTM, self).__init__()
        self.audio_encoder = audio_encoder
        self.text_encoder = text_encoder
        self.fusion_lstm = fusion_lstm  # e.g., an instance of FusionLSTM

        # Size the heads from the LSTM's actual output width. Hard-coding
        # hidden_dim * 2 is only right for a bidirectional LSTM and fails with a
        # shape error for a unidirectional one.
        lstm_out_dim = getattr(fusion_lstm, "hidden_dim", hidden_dim * 2)
        self.emotion_head = nn.Linear(lstm_out_dim, num_emotions)
        self.sentiment_head = nn.Linear(lstm_out_dim, num_sentiments)
        self.max_utt = max_utt

    def forward(self, audio_dialog_utts, text_dialog_utts):
        """
        :param audio_dialog_utts: audio input for all utterances of one dialogue,
            in whatever layout ``audio_encoder`` expects.
        :param text_dialog_utts: iterable with one text per utterance.
        :return: ``(emotion_logits, sentiment_logits)`` with ``max_utt`` rows each;
            rows past the real number of utterances come from zero padding.
        """
        # Global-average-pool the encoder's feature map into one vector per utterance.
        audio_emb = self.audio_encoder(audio_dialog_utts)
        audio_emb = F.adaptive_avg_pool2d(audio_emb, 1)
        audio_emb = torch.flatten(audio_emb, 1)

        # Encode each utterance on its own, then stack and drop the per-call batch axis.
        text_emb = [self.text_encoder(utt) for utt in text_dialog_utts]
        text_emb = torch.stack(text_emb)
        text_emb = text_emb.squeeze(1)

        # Combine: (utts, audio_enc_dim + text_enc_dim)
        fused_emb = torch.cat([audio_emb, text_emb], dim=-1)

        # Add a batch axis of 1 for the LSTM: (1, utts, fused_dim)
        fused_emb = fused_emb.unsqueeze(0)

        # Zero-pad the utterance axis up to max_utt.
        # NOTE(review): a dialogue longer than max_utt gives a negative pad
        # amount, which crops the sequence instead of raising — confirm max_utt
        # covers every split.
        padded_fused_emb = F.pad(
            fused_emb, (0, 0, 0, self.max_utt - fused_emb.size(1)))

        # Pass through LSTM
        lstm_out, (hn, cn) = self.fusion_lstm(padded_fused_emb)

        lstm_out = lstm_out.squeeze(0)

        # Classification heads
        emotion_logits = self.emotion_head(lstm_out)
        sentiment_logits = self.sentiment_head(lstm_out)

        return emotion_logits, sentiment_logits

In [43]:
# Assemble the full model. input_dim=1536 is the width of the fused embedding:
# 768 channels from the ResNet's out_conv plus the text embedding
# (presumably 768 for distilbert-base-uncased — TODO confirm).
# The LSTM is bidirectional, so its output width is 2 * hidden_dim = 512.
model = MultimodalClassifierWithLSTM(
    fusion_lstm=FusionLSTM(input_dim=1536, hidden_dim=256, num_layers=1, bidirectional=True, dropout=0.2),
    audio_encoder=resnet18(modality='audio'),
    text_encoder=TextEncoder(fine_tune=True, unfreeze_last_n_layers=1),
    hidden_dim=256,
    num_emotions=num_emotions,
    num_sentiments=num_sentiments,
    max_utt=max_utt
).to(device)

# All model parameters are handed to AdamW, including the frozen text-encoder
# weights (requires_grad=False).
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0001)
experiment_name = 'TEST'

## Training

### Utility functions

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score, roc_curve, auc
from sklearn.preprocessing import label_binarize

def compute_metrics(true_labels, predictions):
    """
    Compute balanced accuracy, macro-averaged F1 and per-class F1 scores.

    :param true_labels: ground-truth class labels.
    :param predictions: predicted class labels.
    :return: dict with "balanced_acc", "macro_f1" and "per_class_f1"; every
        score is rounded to 3 decimals. "per_class_f1" is keyed by the report
        labels that consist only of digits.
    """
    balanced_acc = round(balanced_accuracy_score(true_labels, predictions), 3)

    report = classification_report(
        true_labels, predictions, output_dict=True, zero_division=0
    )

    # Keep only per-class rows; aggregate rows such as "macro avg" are not digit labels.
    per_class_f1 = {}
    for label, values in report.items():
        if label.isdigit():
            per_class_f1[label] = round(values["f1-score"], 3)

    return {
        "balanced_acc": balanced_acc,
        "macro_f1": round(report["macro avg"]["f1-score"], 3),
        "per_class_f1": per_class_f1
    }

def task_result_to_table(task_result):
    """Turn a list of per-epoch result dicts into a DataFrame, one row per epoch.

    Each entry must provide 'epoch', 'train_loss', 'val_loss', 'train_metrics'
    and 'val_metrics'; the two metric dicts are flattened into columns prefixed
    with 'train' and 'val'.
    """
    rows = []
    for res in task_result:
        row = {
            'epoch': res['epoch'],
            'train_loss': res['train_loss'],
            'val_loss': res['val_loss'],
        }
        row.update(flatten_metrics(res['train_metrics'], prefix='train'))
        row.update(flatten_metrics(res['val_metrics'], prefix='val'))
        rows.append(row)

    return pd.DataFrame(rows)

def flatten_metrics(metrics, prefix=''):
    """
    Collapse a two-level metrics dict into a single-level dict for tabular storage.

    :param metrics: Dictionary of metrics per modality (e.g., {'fused': {...}, 'audio': {...}, 'text': {...}})
    :param prefix: Prefix for the generated keys (e.g., 'train' or 'val').
    :return: Dict whose keys are "{prefix}_{modality}_{metric_name}".
    """
    return {
        f"{prefix}_{modality}_{metric_name}": value
        for modality, modality_metrics in metrics.items()
        for metric_name, value in modality_metrics.items()
    }

def plot_metrics(results, task='emotions', metric='acc', modality='fused'):
    """
    Plot one train/validation metric for a task and modality across epochs.

    :param results: Dictionary holding the per-epoch results under the key
        f'results_{task}'; each entry has 'train_metrics' and 'val_metrics'.
    :param task: Task name used to build the results key (default 'emotions').
    :param metric: Metric to plot (e.g., 'acc', 'macro_f1', 'weighted_f1').
    :param modality: Modality to plot ('fused', 'audio', 'text').
    """
    def _metric_series(split_key):
        # One value per epoch for the requested modality and metric.
        return [
            epoch_results[split_key][modality][metric]
            for epoch_results in results[f'results_{task}']
        ]

    train_values = _metric_series('train_metrics')
    val_values = _metric_series('val_metrics')
    epochs = list(range(1, len(train_values) + 1))

    metric_label = metric.capitalize()

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_values, label=f'Train {metric_label}', marker='o')
    plt.plot(epochs, val_values, label=f'Val {metric_label}', marker='o')
    plt.title(f'{task.capitalize()} {metric_label} ({modality.capitalize()})')
    plt.xlabel('Epochs')
    plt.ylabel(metric_label)
    plt.legend()
    plt.grid(True)
    plt.show()

def analyze_results_per_class(true_labels, predicted_labels, class_names, task_name="Sentiment", mode="confusion_matrix"):
    """
    Analyze results per class with confusion matrix, classification report, or ROC curves.

    Every artifact is written to the ``images/alternating`` directory, which is
    created if it does not exist. An unrecognised ``mode`` writes nothing.

    Args:
        true_labels (list or np.ndarray): True labels for the task.
        predicted_labels (list or np.ndarray): Predicted labels from the model.
        class_names (list): List of class names; label ``i`` corresponds to ``class_names[i]``.
        task_name (str): Name of the task (e.g., "Sentiment" or "Emotion"); used in titles and file names.
        mode (str): The type of analysis. Options: "confusion_matrix", "classification_report", "roc_curve".
    """
    # Create the directory the files are actually written to. Creating only
    # 'images' left 'images/alternating' missing, so savefig/open raised
    # FileNotFoundError on a fresh checkout.
    output_dir = 'images/alternating'
    os.makedirs(output_dir, exist_ok=True)

    if mode == "confusion_matrix":
        # Plot confusion matrix
        cm = confusion_matrix(true_labels, predicted_labels, labels=range(len(class_names)))
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
        plt.title(f"Confusion Matrix for {task_name}")
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.xticks(rotation=45)
        plt.yticks(rotation=45)
        plt.savefig(f'{output_dir}/{task_name}_confusion_matrix.png')
        plt.close()

    elif mode == "classification_report":
        # Write the text classification report to a file
        report = classification_report(true_labels, predicted_labels, target_names=class_names, zero_division=0)
        with open(f'{output_dir}/{task_name}_classification_report.txt', 'w') as f:
            f.write(f"Classification Report for {task_name}:\n\n")
            f.write(report)

    elif mode == "roc_curve":
        # Compute and plot one-vs-rest ROC curves. Both inputs are binarized hard
        # labels (not scores), so each curve is built from 0/1 predictions.
        true_binarized = label_binarize(true_labels, classes=range(len(class_names)))
        predicted_binarized = label_binarize(predicted_labels, classes=range(len(class_names)))
        plt.figure(figsize=(8, 6))
        for i, class_name in enumerate(class_names):
            fpr, tpr, _ = roc_curve(true_binarized[:, i], predicted_binarized[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f"{class_name} (AUC = {roc_auc:.2f})")
        plt.plot([0, 1], [0, 1], "k--")  # Random baseline
        plt.title(f"ROC Curve for {task_name}")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.legend(loc="lower right")
        plt.savefig(f'{output_dir}/{task_name}_roc_curve.png')
        plt.close()

def compute_emotion_class_weights(train_set, device, normalize=True):
    """
    Build per-class weights for the emotion task as the inverse of each
    class's frequency in the training set.

    Args:
        train_set: Dataset exposing `samples` (the element at index 2 of every
            sample is used as its emotion label) and `get_emotions_dicts()`,
            whose second return value maps an emotion label to its class index.
        device: Device the resulting tensor is moved to (e.g. 'cpu' or 'cuda').
        normalize (bool): When True (default), rescale the weights so that
            they sum to 1.

    Returns:
        torch.Tensor: float32 tensor holding one weight per class, ordered by
        class index, located on `device`.
    """
    # Emotion label of every training sample (position 2 of each sample)
    labels = [entry[2] for entry in train_set.samples]

    # Distinct labels and the number of occurrences of each
    unique_classes, counts = np.unique(labels, return_counts=True)
    print("Class labels:", unique_classes)
    print("Class counts:", counts)

    # Only the label -> class-index mapping (second return value) is needed
    _unused_mapping, label_to_index = train_set.get_emotions_dicts()

    # Re-order the counts by class index; classes never observed stay at 0
    ordered_counts = np.zeros(len(label_to_index), dtype=np.int64)
    for position in range(len(unique_classes)):
        target_slot = label_to_index[unique_classes[position]]
        ordered_counts[target_slot] = counts[position]
    print("Ordered counts:", ordered_counts)

    # Clamp counts to at least 1 so unseen classes do not divide by zero
    weights = torch.tensor(1.0 / np.maximum(ordered_counts, 1), dtype=torch.float32)

    if normalize:
        weights = weights / weights.sum()

    return weights.to(device)

In [None]:
def train_one_epoch(model, dataloader, optimizer, criterions, emotion_reg=0.5, sentiment_reg=0.5, device='cuda'):
    """
    Run one training epoch of the joint emotion / sentiment model.

    Every batch holds several conversations. Each conversation is passed
    through the model on its own, the per-conversation losses are averaged
    over the batch, and one optimizer step is taken per batch on the weighted
    sum of the two task losses.

    Args:
        model: Called as `model(audio, text_list)`; must return
            `(emotion_logits, sentiment_logits)`.
        dataloader: Iterable of dict batches with the keys "dialog_ids",
            "fbank_lists", "text_lists", "emotion_lists" and
            "sentiment_lists", each indexable per conversation.
        optimizer: Optimizer stepped once per batch.
        criterions (dict): Loss callables stored under 'emotion' and 'sentiment'.
        emotion_reg (float): Weight of the emotion loss in the combined loss.
        sentiment_reg (float): Weight of the sentiment loss in the combined loss.
        device: Device the per-conversation tensors are created on.

    Returns:
        tuple: `(losses, metrics)` where `losses` maps 'emotion' / 'sentiment'
        to the mean per-batch loss over the batches that were processed, and
        `metrics` maps each task to `{'fused': compute_metrics(labels, preds)}`.
    """
    model.train()

    losses = {'emotion': 0.0, 'sentiment': 0.0}
    metrics = {
        'emotion': {'fused': [], 'labels': []},
        'sentiment': {'fused': [], 'labels': []}
    }

    # Counted explicitly so the epoch average neither divides by zero on an
    # empty loader nor requires the loader to implement __len__.
    num_batches = 0

    loop = tqdm(dataloader, desc="Training", leave=False)

    for batch in loop:
        num_conversations = len(batch["dialog_ids"])
        if num_conversations == 0:
            # Nothing to average or back-propagate for an empty batch.
            continue

        batch_emotion_loss = 0.0
        batch_sentiment_loss = 0.0

        batch_emotion_preds = []
        batch_emotion_labels = []
        batch_sentiment_preds = []
        batch_sentiment_labels = []

        for b_idx in range(num_conversations):
            fbank_list      = batch["fbank_lists"][b_idx]
            text_list       = batch["text_lists"][b_idx]
            emotion_list    = batch["emotion_lists"][b_idx]
            sentiment_list  = batch["sentiment_lists"][b_idx]

            # Number of labelled utterances; used to trim the model outputs below
            actual_len = len(emotion_list)

            # Move labels and audio features to the target device
            emotion_tensor = torch.as_tensor(emotion_list, dtype=torch.long, device=device)
            sentiment_tensor = torch.as_tensor(sentiment_list, dtype=torch.long, device=device)
            audio_array = torch.as_tensor(fbank_list, dtype=torch.float, device=device)

            # Insert a singleton axis at position 1 (presumably a channel axis
            # expected by the audio encoder -- TODO confirm against the model)
            audio_array = audio_array.unsqueeze(1)

            # Forward pass
            emotion_logits, sentiment_logits = model(audio_array, text_list)

            # If outputs have shape (1, seq_len, num_classes), drop the leading axis
            if len(emotion_logits.shape) == 3:
                emotion_logits = emotion_logits.squeeze(0)
                sentiment_logits = sentiment_logits.squeeze(0)

            # Keep only the logits that have a matching label. The label
            # tensors already have exactly `actual_len` entries, so they need
            # no slicing.
            emotion_logits = emotion_logits[:actual_len]
            sentiment_logits = sentiment_logits[:actual_len]

            # Compute Loss
            e_loss = criterions['emotion'](emotion_logits, emotion_tensor)
            s_loss = criterions['sentiment'](sentiment_logits, sentiment_tensor)

            batch_emotion_loss += e_loss
            batch_sentiment_loss += s_loss

            # Predictions for the labelled utterances
            e_preds = torch.argmax(emotion_logits, dim=-1).detach().cpu().numpy()
            s_preds = torch.argmax(sentiment_logits, dim=-1).detach().cpu().numpy()

            # Store labels as NumPy integers, the same element type as the
            # predictions. Extending with the raw label container would add
            # 0-d torch tensors when the collate function yields tensors.
            batch_emotion_preds.extend(e_preds)
            batch_emotion_labels.extend(emotion_tensor.cpu().numpy())

            batch_sentiment_preds.extend(s_preds)
            batch_sentiment_labels.extend(sentiment_tensor.cpu().numpy())

        # Average and combine losses over conversations in batch
        batch_emotion_loss /= num_conversations
        batch_sentiment_loss /= num_conversations

        combined_loss = emotion_reg * batch_emotion_loss + sentiment_reg * batch_sentiment_loss
        optimizer.zero_grad()
        combined_loss.backward()
        optimizer.step()

        losses['emotion'] += batch_emotion_loss.item()
        losses['sentiment'] += batch_sentiment_loss.item()
        num_batches += 1

        metrics['emotion']['fused'].extend(batch_emotion_preds)
        metrics['emotion']['labels'].extend(batch_emotion_labels)

        metrics['sentiment']['fused'].extend(batch_sentiment_preds)
        metrics['sentiment']['labels'].extend(batch_sentiment_labels)

    # Mean loss per processed batch (left at 0.0 when nothing was processed)
    processed = max(num_batches, 1)
    losses['emotion'] /= processed
    losses['sentiment'] /= processed

    emotion_metrics = compute_metrics(metrics['emotion']['labels'], metrics['emotion']['fused'])
    sentiment_metrics = compute_metrics(metrics['sentiment']['labels'], metrics['sentiment']['fused'])

    return losses, {
        'emotion': {'fused': emotion_metrics},
        'sentiment': {'fused': sentiment_metrics}
    }