In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd

In [2]:
class MELD(Dataset):
    """Map-style dataset over precomputed MELD video/audio/text features.

    Item ``i`` is ``(video, audio, text, label)``:

    - ``video``, ``audio`` and ``text`` are ``feature[i]`` of the tensors loaded
      from ``<pretrained>/{video,audio,text}_feature_<mode>.pt``.
    - ``label`` is the integer id (via ``label_to_index``) of the ``Emotion``
      value in row ``i`` of ``csv``.

    NOTE(review): CSV rows are assumed to line up one-to-one with the first
    dimension of every feature tensor. Nothing here checks that, and
    ``__len__`` only looks at the video tensor. Confirm against the
    feature-extraction step.
    """

    def __init__(self, mode, csv, device='cpu', pretrained='/content/MELD.Raw/Pretrained/'):
        """Load the labels and precomputed features for one split.

        Args:
            mode: split name interpolated into the feature filenames,
                e.g. ``video_feature_{mode}.pt``.
            csv: path to the annotation CSV. Only its ``Emotion`` column is read.
            device: forwarded to ``torch.load`` as ``map_location``.
            pretrained: directory containing the ``*_feature_{mode}.pt`` files.
                Joined with ``os.path.join``, so a trailing separator is
                optional and a ``pathlib.Path`` is accepted as well.
        """
        super().__init__()
        self.label = pd.read_csv(csv)['Emotion']
        # os.path.join instead of string concatenation: a directory passed
        # without a trailing slash would otherwise produce a wrong filename.
        self.video_feature = torch.load(os.path.join(pretrained, f"video_feature_{mode}.pt"), map_location=device)
        self.audio_feature = torch.load(os.path.join(pretrained, f"audio_feature_{mode}.pt"), map_location=device)
        self.text_feature = torch.load(os.path.join(pretrained, f"text_feature_{mode}.pt"), map_location=device)
        # Emotion string -> class id. Any other string raises KeyError in __getitem__.
        self.label_to_index = {'neutral': 0,
                               'surprise': 1,
                               'fear': 2,
                               'sadness': 3,
                               'joy': 4,
                               'disgust': 5,
                               'anger': 6}

    def __getitem__(self, index):
        """Return ``(video, audio, text, label_id)`` for position ``index``.

        Raises:
            KeyError: if the row's ``Emotion`` value is not in ``label_to_index``.
        """
        # .iloc is positional, like the tensor indexing next to it. Plain
        # Series[...] on the default RangeIndex is label-based, so negative
        # indices (e.g. dataset[-1]) would raise KeyError.
        emotion = self.label.iloc[index]
        return (self.video_feature[index],
                self.audio_feature[index],
                self.text_feature[index],
                self.label_to_index[emotion])

    def __len__(self):
        """Number of items, taken from the first dimension of the video features."""
        return self.video_feature.shape[0]
