In [None]:
import pickle
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, accuracy_score
from torch import nn
from torch.utils.data import Dataset, DataLoader
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def calculate_accuracy(y_pred, y_true):
    """Return ``(wa, ua)`` scores for a set of predictions.

    Args:
        y_pred: 1-D array-like of predicted class labels.
        y_true: 1-D array-like of ground-truth class labels, same length.

    Returns:
        wa: balanced accuracy, i.e. the mean of the per-class recalls, so every
            class counts equally regardless of how many samples it has.
        ua: plain accuracy, i.e. the fraction of samples predicted correctly.

    NOTE(review): much of the speech-emotion-recognition literature uses the
    opposite naming (WA = overall accuracy, UA = mean per-class recall). The
    return order is kept as-is for existing callers; confirm which convention
    the reported numbers are meant to follow.
    """
    # No sample_weight: balanced_accuracy_score already weights every class
    # equally, and a per-class constant weight (e.g. 1/count) cancels out
    # inside each class's recall, so it would not change the result.
    wa = balanced_accuracy_score(y_true, y_pred)
    ua = accuracy_score(y_true, y_pred)
    return wa, ua

In [None]:
def model_prediction(model, dataloader, accuracy_fn, device=None):
    """Run ``model`` over ``dataloader`` in eval mode and score its predictions.

    Args:
        model: module called as ``model(text_embed, audio_embed)`` that returns a
            ``(logits, softmax)`` pair; the prediction for each sample is the
            argmax over dim 1 of the softmax output.
        dataloader: iterable yielding ``(text_embed, audio_embed, label)`` batches
            of tensors.
        accuracy_fn: callable ``accuracy_fn(y_pred, y_true) -> (wa, ua)`` that is
            given NumPy arrays of predictions and labels.
        device: device each input batch is moved to. Defaults to the device of
            the model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(eval_wa, eval_ua, y_true_ls, y_pred_ls)``. The two metrics are
        computed once over the whole dataset. ``y_true_ls`` / ``y_pred_ls`` are
        lists holding one NumPy array of labels / predictions per batch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    y_true_ls = []
    y_pred_ls = []

    model.eval()
    with torch.no_grad():
        for text_embed, audio_embed, label in dataloader:
            text_embed = text_embed.to(device)
            audio_embed = audio_embed.to(device)
            _, output_softmax = model(text_embed, audio_embed)
            y_preds = output_softmax.argmax(dim=1)

            y_true_ls.append(label.cpu().numpy())
            y_pred_ls.append(y_preds.cpu().numpy())

    if not y_true_ls:
        raise ValueError("dataloader yielded no batches; cannot compute accuracy")

    # Score the full prediction set once. Averaging per-batch scores would give
    # a smaller final batch the same weight as a full one, and balanced accuracy
    # is not the mean of per-batch balanced accuracies.
    eval_wa, eval_ua = accuracy_fn(np.concatenate(y_pred_ls, axis=0),
                                   np.concatenate(y_true_ls, axis=0))
    print(f"Total Test WA: {eval_wa:.4f} | Total Test UA: {eval_ua:.4f}")

    return eval_wa, eval_ua, y_true_ls, y_pred_ls

In [None]:
BATCH_SIZE = 128
test_metadata = "features/ECESD_ENG_CMN_BERT_ECAPA_test.pkl"
test_dataset = Customized_Dataset(test_metadata)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class Customized_Dataset(Dataset):
    def __init__(self, metadata):
        super(Customized_Dataset, self).__init__()
        self.data = pickle.load(open(metadata, 'rb'))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        with torch.no_grad():
            text_embed = self.data[idx]['text_embed']
            audio_embed = self.data[idx]['audio_embed']
            label = self.data[idx]['label']

        return text_embed, audio_embed, label

In [None]:
# Create Multi-modal model
class MMSER(nn.Module):
    """Multi-modal classifier over concatenated text and audio embeddings.

    ``forward`` concatenates ``text_embed`` and ``audio_embed`` along dim 1 and
    applies dropout. It then runs three linear layers
    (``input_dim -> 256 -> 64 -> num_classes``) and returns both the raw
    logits and their softmax over dim 1.

    NOTE(review): there is no non-linearity between the linear layers, so the
    stack is mathematically a single affine map. It is deliberately left
    unchanged: the saved checkpoint was presumably trained with exactly this
    architecture, and adding activations would change its predictions.

    Args:
        num_classes: number of output classes.
        input_dim: combined width of ``text_embed`` and ``audio_embed`` along
            dim 1. Defaults to 960, the width the original model hard-coded.
        dropout_p: dropout probability applied to the concatenated embedding.
    """
    def __init__(self, num_classes=4, input_dim=960, dropout_p=0.2):
        super(MMSER, self).__init__()
        self.num_classes = num_classes
        # Attribute names are kept so existing state_dict checkpoints still load.
        self.dropout = nn.Dropout(dropout_p)
        self.linear1 = nn.Linear(input_dim, 256)
        self.linear2 = nn.Linear(256, 64)
        self.linear3 = nn.Linear(64, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, text_embed, audio_embed):
        """Return ``(logits, probabilities)``.

        For 2-D inputs both outputs have shape ``(batch, num_classes)``.
        """
        concat_embed = torch.cat((text_embed, audio_embed), dim=1)
        x = self.dropout(concat_embed)
        x = self.linear1(x)
        x = self.linear2(x)
        y_logits = self.linear3(x)
        y_softmax = self.softmax(y_logits)
        return y_logits, y_softmax

In [None]:
# Rebuild the architecture and load the trained weights. map_location remaps any
# CUDA tensors in the checkpoint onto `device`, so a checkpoint saved on a GPU
# also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
loaded_model = MMSER(num_classes=4)
loaded_model.load_state_dict(
    torch.load("saved_models/ECESD_ENG_CMN_BERT_ECAPA.pt", map_location=device)
)
loaded_model.to(device)

In [None]:
# Evaluate the restored model on the held-out test split.
eces_test_wa, eces_test_ua, y_true_ls, y_pred_ls = model_prediction(
    model=loaded_model,
    dataloader=test_dataloader,
    accuracy_fn=calculate_accuracy,
)

In [None]:
# model_prediction returns one array per batch; flatten each into a single array.
# The isinstance guards make this cell safe to re-run: calling np.concatenate on
# an already-flattened 1-D array raises a ValueError.
if isinstance(y_true_ls, list):
    y_true_ls = np.concatenate(y_true_ls, axis=0)
if isinstance(y_pred_ls, list):
    y_pred_ls = np.concatenate(y_pred_ls, axis=0)

In [None]:
# Pin the label set so the matrix is always num_classes x num_classes, matching
# the tick labels of the heatmap, even when some class never occurs in
# y_true or y_pred.
# Assumes ground-truth labels are integer class ids 0..num_classes-1, the same
# encoding as the argmax predictions -- TODO confirm against the feature files.
cm = confusion_matrix(y_true_ls, y_pred_ls, labels=np.arange(loaded_model.num_classes))
print(cm)

In [None]:
# Row-normalise the confusion matrix into per-class percentages, so each
# non-empty row sums to 100. A class with no ground-truth samples gets a row of
# zeros instead of the NaNs a plain 0/0 division would produce.
row_totals = cm.sum(axis=1, keepdims=True)
cmn = np.divide(cm.astype('float'), row_totals,
                out=np.zeros(cm.shape, dtype=float), where=row_totals != 0) * 100

# Display order is assumed to match label ids 0..3 -- TODO confirm.
CLASS_NAMES = ["Anger", "Happiness", "Neutral", "Sadness"]

fig, ax = plt.subplots(figsize=(8, 5.5))
sns.heatmap(cmn, cmap='flare', annot=True, square=True, linecolor='black', linewidths=0.75,
            ax=ax, fmt='.2f', annot_kws={'size': 16})
ax.set_xlabel('Predicted', fontsize=18, fontweight='bold')
ax.xaxis.set_label_position('bottom')
ax.xaxis.set_ticklabels(CLASS_NAMES, fontsize=16)
ax.set_ylabel('Ground Truth', fontsize=18, fontweight='bold')
ax.yaxis.set_ticklabels(CLASS_NAMES, fontsize=16)
fig.tight_layout()
# fig.savefig("Confusion_Matrices/ECESD_ENG_CMN_BERT_ECAPA.pdf", dpi=600, bbox_inches='tight') # Uncomment to save figures
# fig.savefig("Confusion_Matrices/ECESD_ENG_CMN_BERT_ECAPA.png", dpi=600, bbox_inches='tight') # Uncomment to save figures