In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd

In [2]:
class MELD(Dataset):
    def __init__(self, mode, csv, device='cpu', pretrained='/content/MELD.Raw/Pretrained/'):
        super().__init__()
        self.label = pd.read_csv(csv)['Emotion']
        self.video_feature = torch.load(pretrained+f"video_feature_{mode}.pt", map_location=device)
        self.audio_feature = torch.load(pretrained+f"audio_feature_{mode}.pt", map_location=device)
        self.text_feature = torch.load(pretrained+f"text_feature_{mode}.pt", map_location=device)
        self.label_to_index = {'neutral': 0,
                               'surprise': 1,
                               'fear': 2,
                               'sadness': 3,
                               'joy': 4,
                               'disgust': 5,
                               'anger': 6}

    def __getitem__(self, index):
        return self.video_feature[index], self.audio_feature[index], self.text_feature[index], self.label_to_index[self.label[index]]

    def __len__(self):
        return self.video_feature.shape[0]


In [4]:
trainset = MELD('train', './MELD_Data/train.csv', device='mps', pretrained='./Pretrained/')
trainloader = DataLoader(trainset, batch_size=8)

In [5]:
# Only Video
class MELDVideo(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.clf = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.clf(x)
        return x
    
device = 'mps'
EPOCHS = 20
model = MELDVideo(1024, 7).to(device)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
model.train()
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}")
    acc = 0
    for video, _, _, label in tqdm(trainloader):
        label = label.to(device)
        optimizer.zero_grad()
        yhat = model(video)
        loss = loss_fn(yhat, label)
        loss.backward()
        optimizer.step()
        acc += (yhat.argmax(1)==label).float().sum()
    print(f"Accuracy : {100*acc/len(trainset):.2f}%")
print('\n')

Epoch 1


100%|██████████| 1249/1249 [00:03<00:00, 322.95it/s]


Accuracy : 42.41%
Epoch 2


100%|██████████| 1249/1249 [00:02<00:00, 437.23it/s]


Accuracy : 43.59%
Epoch 3


100%|██████████| 1249/1249 [00:02<00:00, 431.85it/s]


Accuracy : 44.22%
Epoch 4


100%|██████████| 1249/1249 [00:02<00:00, 439.50it/s]


Accuracy : 44.60%
Epoch 5


100%|██████████| 1249/1249 [00:02<00:00, 442.01it/s]


Accuracy : 44.68%
Epoch 6


100%|██████████| 1249/1249 [00:03<00:00, 384.76it/s]


Accuracy : 45.22%
Epoch 7


100%|██████████| 1249/1249 [00:02<00:00, 431.63it/s]


Accuracy : 45.54%
Epoch 8


100%|██████████| 1249/1249 [00:02<00:00, 433.79it/s]


Accuracy : 45.93%
Epoch 9


100%|██████████| 1249/1249 [00:02<00:00, 443.73it/s]


Accuracy : 46.19%
Epoch 10


100%|██████████| 1249/1249 [00:02<00:00, 437.69it/s]


Accuracy : 46.36%
Epoch 11


100%|██████████| 1249/1249 [00:03<00:00, 405.38it/s]


Accuracy : 46.66%
Epoch 12


100%|██████████| 1249/1249 [00:02<00:00, 439.55it/s]


Accuracy : 46.72%
Epoch 13


100%|██████████| 1249/1249 [00:02<00:00, 443.11it/s]


Accuracy : 46.78%
Epoch 14


100%|██████████| 1249/1249 [00:02<00:00, 441.81it/s]


Accuracy : 47.05%
Epoch 15


100%|██████████| 1249/1249 [00:02<00:00, 443.50it/s]


Accuracy : 46.97%
Epoch 16


100%|██████████| 1249/1249 [00:02<00:00, 440.67it/s]


Accuracy : 47.16%
Epoch 17


100%|██████████| 1249/1249 [00:02<00:00, 426.68it/s]


Accuracy : 47.17%
Epoch 18


100%|██████████| 1249/1249 [00:02<00:00, 443.13it/s]


Accuracy : 47.23%
Epoch 19


100%|██████████| 1249/1249 [00:02<00:00, 437.98it/s]


Accuracy : 47.24%
Epoch 20


100%|██████████| 1249/1249 [00:02<00:00, 436.34it/s]

Accuracy : 47.31%







In [6]:
# Only Audio
class MELDAudio(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.clf = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.clf(x)
        return x
    
EPOCHS = 10
model = MELDAudio(768, 7).to(device)
optimizer = torch.optim.AdamW(model.parameters())
loss_fn = nn.CrossEntropyLoss()
model.train()
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}")
    acc = 0
    for _, audio, _, label in tqdm(trainloader):
        label = label.to(device)
        optimizer.zero_grad()
        yhat = model(audio)
        loss = loss_fn(yhat, label)
        loss.backward()
        optimizer.step()
        acc += (yhat.argmax(1)==label).float().sum()
    print(f"Accuracy : {100*acc/len(trainset):.2f}%")
print('\n')

Epoch 1


100%|██████████| 1249/1249 [00:03<00:00, 326.49it/s]


Accuracy : 41.82%
Epoch 2


100%|██████████| 1249/1249 [00:03<00:00, 392.94it/s]


Accuracy : 43.22%
Epoch 3


100%|██████████| 1249/1249 [00:03<00:00, 372.25it/s]


Accuracy : 43.70%
Epoch 4


100%|██████████| 1249/1249 [00:03<00:00, 382.65it/s]


Accuracy : 44.04%
Epoch 5


100%|██████████| 1249/1249 [00:03<00:00, 368.48it/s]


Accuracy : 44.42%
Epoch 6


100%|██████████| 1249/1249 [00:03<00:00, 363.09it/s]


Accuracy : 44.65%
Epoch 7


100%|██████████| 1249/1249 [00:03<00:00, 349.71it/s]


Accuracy : 44.96%
Epoch 8


100%|██████████| 1249/1249 [00:03<00:00, 371.11it/s]


Accuracy : 45.13%
Epoch 9


100%|██████████| 1249/1249 [00:03<00:00, 343.75it/s]


Accuracy : 45.23%
Epoch 10


100%|██████████| 1249/1249 [00:03<00:00, 350.58it/s]

Accuracy : 45.35%







In [7]:
class MELDText(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.clf = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.clf(x)
        return x

EPOCHS = 10
model = MELDText(768, 7).to(device)
optimizer = torch.optim.AdamW(model.parameters())
loss_fn = nn.CrossEntropyLoss()
model.train()
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}")
    acc = 0
    for _, _, text, label in tqdm(trainloader):
        label = label.to(device)
        optimizer.zero_grad()
        yhat = model(text)
        loss = loss_fn(yhat, label)
        loss.backward()
        optimizer.step()
        acc += (yhat.argmax(1)==label).float().sum()
    print(f"Accuracy : {100*acc/len(trainset):.2f}%")
print('\n')

Epoch 1


100%|██████████| 1249/1249 [00:03<00:00, 331.12it/s]


Accuracy : 43.89%
Epoch 2


100%|██████████| 1249/1249 [00:02<00:00, 420.50it/s]


Accuracy : 44.86%
Epoch 3


100%|██████████| 1249/1249 [00:02<00:00, 438.06it/s]


Accuracy : 45.60%
Epoch 4


100%|██████████| 1249/1249 [00:02<00:00, 422.13it/s]


Accuracy : 46.30%
Epoch 5


100%|██████████| 1249/1249 [00:03<00:00, 348.90it/s]


Accuracy : 47.00%
Epoch 6


100%|██████████| 1249/1249 [00:03<00:00, 373.69it/s]


Accuracy : 47.52%
Epoch 7


100%|██████████| 1249/1249 [00:03<00:00, 391.44it/s]


Accuracy : 47.79%
Epoch 8


100%|██████████| 1249/1249 [00:03<00:00, 367.80it/s]


Accuracy : 48.22%
Epoch 9


100%|██████████| 1249/1249 [00:03<00:00, 406.10it/s]


Accuracy : 48.48%
Epoch 10


100%|██████████| 1249/1249 [00:02<00:00, 431.66it/s]

Accuracy : 48.87%





