In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


In [13]:
text_feature = torch.load('./Pretrained/text_feature_train.pt', map_location='mps')
text_feature

tensor([[7.8029, 7.7918, 8.0412,  ..., 7.3728, 7.9075, 7.8302],
        [7.6541, 7.6650, 7.8158,  ..., 7.7378, 7.8277, 7.9552],
        [7.7719, 7.8054, 7.7597,  ..., 7.4221, 7.9540, 7.6630],
        ...,
        [7.5831, 7.7927, 8.0133,  ..., 7.5580, 7.9469, 7.5431],
        [7.6394, 7.8788, 7.6675,  ..., 7.5630, 7.8884, 7.9090],
        [7.8469, 7.5502, 8.1130,  ..., 7.9290, 8.1559, 8.4261]],
       device='mps:0')

In [4]:
text_feature.shape

torch.Size([9989, 768])

In [14]:
import pandas as pd
labels = pd.read_csv('./MELD_Data/train.csv')['Emotion']#.unique()
labels

0        neutral
1        neutral
2        neutral
3        neutral
4       surprise
          ...   
9984     neutral
9985     neutral
9986    surprise
9987     neutral
9988         joy
Name: Emotion, Length: 9989, dtype: object

In [53]:
lbl_idx = {labels.unique()[i] : i for i in range(len(labels.unique()))}
lbl_idx

{'neutral': 0,
 'surprise': 1,
 'fear': 2,
 'sadness': 3,
 'joy': 4,
 'disgust': 5,
 'anger': 6}

In [56]:
lbl = []
for i in range(len(labels)):
    temp = [0] * len(lbl_idx)
    temp[lbl_idx[labels[i]]] = 1
    lbl.append(temp)
lbl = torch.Tensor(lbl).to('mps')

In [58]:
lbl.shape

torch.Size([9989, 7])

In [73]:
class MELD(Dataset):
    def __init__(self, feat_tensor, df):
        super().__init__()
        self.data = torch.load(feat_tensor, map_location='mps')
        self.labels = pd.read_csv(df)['Emotion']

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        data = self.data[idx]
        label = self.labels[idx]
        return data, lbl_idx[label]

In [74]:
class Text(nn.Module):
    def __init__(self, f_dim, num_class):
        super().__init__()
        self.clf = nn.Linear(f_dim, num_class)
    
    def forward(self, x):
        x = self.clf(x)
        return x

In [78]:
model = Text(768, 7).to('mps')
dataset = MELD('./Pretrained/text_feature_train.pt', './MELD_Data/train.csv')
train_loader = DataLoader(dataset, batch_size=4)
optim = torch.optim.AdamW(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(5):
    print(f"Epoch {epoch+1}")
    acc = 0
    loss_tmp = 0
    for data, label in train_loader:
        optim.zero_grad()
        data, label = data.to('mps'), label.to('mps')
        yhat = model(data)
        loss = loss_fn(yhat, label)
        loss_tmp += loss.item()
        loss.backward()
        optim.step()
        acc += (yhat.argmax(1) == label).type(torch.float).sum().item()
    print(f"Accuracy : {acc/len(dataset)*100:.2f}%")

Epoch 1
Accuracy : 32.68%
Epoch 2
Accuracy : 32.09%
Epoch 3
Accuracy : 32.31%
Epoch 4
Accuracy : 32.14%
Epoch 5
Accuracy : 31.84%
