In [10]:
import pickle
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset
from torch_geometric.data import Data
import numpy as np
from collections import Counter
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, accuracy_score
from torch.utils.data import Dataset, DataLoader
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [11]:
class Customized_Dataset(Dataset):
    def __init__(self, metadata):
        super(Customized_Dataset, self).__init__()
        self.data = pickle.load(open(metadata, 'rb'))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        with torch.no_grad():
            text_embed = self.data[idx]['text_embed']
            audio_embed = self.data[idx]['audio_embed']
            label = self.data[idx]['label']

        return text_embed, audio_embed, label

In [12]:
BATCH_SIZE = 128
train_metadata = "C:/Users/admin/Documents/Speech-Emotion_Recognition-2/features/IEMOCAP_BERT_wav2vec_train.pkl"
val_metadata = "C:/Users/admin/Documents/Speech-Emotion_Recognition-2/features/IEMOCAP_BERT_wav2vec_val.pkl"
test_metadata = "C:/Users/admin/Documents/Speech-Emotion_Recognition-2/features/IEMOCAP_BERT_wav2vec_test.pkl"
train_dataset = Customized_Dataset(train_metadata)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = Customized_Dataset(val_metadata)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [13]:
class HybridAttentionMMSER(nn.Module):
    def __init__(self, num_classes=4, text_dim=768, audio_dim=768):
        super(HybridAttentionMMSER, self).__init__()
        self.num_classes = num_classes

        # Feature extraction layers
        self.text_fc = nn.Sequential(
            nn.Linear(text_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.BatchNorm1d(128)
        )
        self.audio_fc = nn.Sequential(
            nn.Linear(audio_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.BatchNorm1d(128)
        )

        # Cross-Modality Attention Mechanism
        self.text_to_audio_attention = nn.Sequential(
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
            nn.Softmax(dim=1)
        )
        self.audio_to_text_attention = nn.Sequential(
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
            nn.Softmax(dim=1)
        )

        # Fusion Layer (Hybrid: concatenation + element-wise product)
        self.fusion_fc = nn.Sequential(
            nn.Linear(128 * 3, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.BatchNorm1d(64)
        )

        # Classifier
        self.classifier = nn.Linear(64, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, text_embed, audio_embed):
        # Feature extraction
        text_features = self.text_fc(text_embed)
        audio_features = self.audio_fc(audio_embed)

        # Cross-modality attention
        text_att_weights = self.text_to_audio_attention(audio_features)  # Attention from audio to text
        audio_att_weights = self.audio_to_text_attention(text_features)  # Attention from text to audio

        text_att_features = text_features * audio_att_weights  # Apply attention weights
        audio_att_features = audio_features * text_att_weights  # Apply attention weights

        # Fusion
        concat_features = torch.cat((text_att_features, audio_att_features, text_features * audio_features), dim=1)
        fused_features = self.fusion_fc(concat_features)

        # Classification
        y_logits = self.classifier(fused_features)
        y_softmax = self.softmax(y_logits)

        return y_logits, y_softmax


In [14]:
def calculate_accuracy(y_pred, y_true):
    class_weights = {cls: 1.0/count for cls, count in Counter(y_true).items()}
    wa = balanced_accuracy_score(y_true, y_pred, sample_weight=[class_weights[cls] for cls in y_true])
    ua = accuracy_score(y_true, y_pred)
    return wa, ua

In [15]:
def train_step(model, dataloader, optim, loss_fn, accuracy_fn):
    train_loss = 0.0
    train_wa = 0.0
    train_ua = 0.0
    y_true_ls = []
    y_pred_ls = []
    
    model.train()
    for batch, (text_embed, audio_embed, label) in enumerate(dataloader):
        text_embed = text_embed.to(device)
        audio_embed = audio_embed.to(device)
        label = label.to(device)
        output_logits, output_softmax = model(text_embed, audio_embed)
        output_logits, output_softmax = output_logits.to(device), output_softmax.to(device)
        y_preds = output_softmax.argmax(dim=1).to(device)
        
        wa, ua = calculate_accuracy(y_preds.cpu().numpy(), label.cpu().numpy())
        y_true_ls.append(label.cpu().numpy())
        y_pred_ls.append(y_preds.cpu().numpy())
        loss = loss_fn(output_logits, label)
        
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        train_loss += loss.item()
        train_wa += wa
        train_ua += ua
        
        # if batch % 20 == 0:
        #     print(f"\tBatch {batch}: Train loss: {loss:.5f} | Train WA : {wa:.4f} | Train UA : {ua:.4f}")
        #     print("----------------------------------------")
        
    train_loss /= len(dataloader)
    train_wa /= len(dataloader)
    train_ua /= len(dataloader)
    print(f"Total Train loss: {train_loss:.5f} | Total Train WA : {wa:.4f} | Total Train UA : {ua:.4f}")
    
    return train_loss, train_wa, train_ua, y_true_ls, y_pred_ls

In [16]:
def eval_step(model, dataloader, loss_fn, accuracy_fn):
    eval_loss = 0.0
    eval_wa = 0.0
    eval_ua = 0.0
    y_true_ls = []
    y_pred_ls = []
    
    model.eval()
    with torch.no_grad():
        for batch, (text_embed, audio_embed, label) in enumerate(dataloader):
            text_embed = text_embed.to(device)
            audio_embed = audio_embed.to(device)
            label = label.to(device)
            output_logits, output_softmax = model(text_embed, audio_embed)
            output_logits, output_softmax = output_logits.to(device), output_softmax.to(device)
            y_preds = output_softmax.argmax(dim=1).to(device)
            
            wa, ua = calculate_accuracy(y_preds.cpu().numpy(), label.cpu().numpy())
            y_true_ls.append(label.cpu().numpy())
            y_pred_ls.append(y_preds.cpu().numpy())
            loss = loss_fn(output_logits, label)

            eval_loss += loss.item()
            eval_wa += wa
            eval_ua += ua

            # if batch % 20 == 0:
            #     print(f"\tBatch {batch}: Test loss: {loss:.5f} | Test WA : {wa:.4f} | Test UA : {ua:.4f}")
            #     print("----------------------------------------")
        
        eval_loss /= len(dataloader)
        eval_wa /= len(dataloader)
        eval_ua /= len(dataloader)
        print(f"Total Test loss: {eval_loss:.5f} | Total Test WA: {eval_wa:.4f} | Total Test UA: {eval_ua:4f}")
        
        return eval_loss, eval_wa, eval_ua, y_true_ls, y_pred_ls

In [17]:
epochs = []
train_loss_hist, train_wa_hist, train_ua_hist, val_loss_hist, val_wa_hist, val_ua_hist = [], [], [], [], [], []

best_w = {}

best_train_loss, best_val_loss = 10000, 10000
best_wa, best_ua = 0.0, 0.0

NUM_EPOCHS = 200

model=HybridAttentionMMSER(num_classes=4)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

criterion = nn.CrossEntropyLoss()
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [18]:
# Train Multi-modal model
for epoch in range(NUM_EPOCHS):
    print("Epoch", epoch)
    train_loss, train_wa, train_ua, y_true_ls, y_pred_ls = train_step(model, train_dataloader, optimizer, criterion, calculate_accuracy)
    val_loss, val_wa, val_ua, y_true_ls, y_pred_ls = eval_step(model, val_dataloader, criterion, calculate_accuracy)
    lr_scheduler.step()
    
    epochs.append(epoch)
    train_loss_hist.append(train_loss)
    val_loss_hist.append(val_loss)
    train_wa_hist.append(train_wa*100)
    val_wa_hist.append(val_wa*100)
    train_ua_hist.append(train_ua*100)
    val_ua_hist.append(val_ua*100)
    
    if train_loss < best_train_loss and val_loss < best_val_loss:
        best_train_loss, best_val_loss = train_loss, val_loss
        best_wa, best_ua = val_wa, val_ua
        torch.save(model.state_dict(), "C:/Users/admin/Documents/Speech-Emotion_Recognition-2/saved_models/IEMOCAP_ENG_CMN_BERT_wav2vec.pt")
        best_w = model.state_dict()
    
    print("\n==============================\n")
print("Best WA: ", best_wa)
print("Best UA: ", best_ua)


Epoch 0
Total Train loss: 1.14789 | Total Train WA : 0.6750 | Total Train UA : 0.6727
Total Test loss: 1.32858 | Total Test WA: 0.4211 | Total Test UA: 0.396354


Epoch 1
Total Train loss: 0.91570 | Total Train WA : 0.6766 | Total Train UA : 0.6545
Total Test loss: 1.33109 | Total Test WA: 0.4481 | Total Test UA: 0.444992


Epoch 2
Total Train loss: 0.80578 | Total Train WA : 0.7299 | Total Train UA : 0.7273
Total Test loss: 1.27438 | Total Test WA: 0.4566 | Total Test UA: 0.459736


Epoch 3
Total Train loss: 0.75156 | Total Train WA : 0.8332 | Total Train UA : 0.8182
Total Test loss: 1.27954 | Total Test WA: 0.4617 | Total Test UA: 0.495433


Epoch 4
Total Train loss: 0.69676 | Total Train WA : 0.7731 | Total Train UA : 0.7273
Total Test loss: 1.52345 | Total Test WA: 0.4461 | Total Test UA: 0.432732


Epoch 5
Total Train loss: 0.66139 | Total Train WA : 0.8006 | Total Train UA : 0.8000
Total Test loss: 1.43138 | Total Test WA: 0.4308 | Total Test UA: 0.456811


Epoch 6
Total Train lo