In [1]:
from pathlib import Path
from tqdm import tqdm
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import pickle
from metadata import RawDataset, Metadata, Sentiment, Split
import matplotlib.pyplot as plt

In [50]:
# Dataset selection: preprocessed per-split embeddings live under data/preprocessed/<name>/,
# and the utterance metadata table is pickled at data/<name>.pkl.
dataset_name = 'CMU-MOSEI'
data_dir = Path("./data", "preprocessed", dataset_name)

# NOTE: pickle.load can execute arbitrary code; only open metadata files you generated yourself.
with Path("./data", f"{dataset_name}.pkl").open("rb") as f:
    metadata = pickle.load(f)

In [51]:
class EmbeddingDataset(Dataset):
    """Map-style dataset over the per-utterance ``*.pt`` files of one split directory.

    Each file is matched to its metadata record by its stem, which must equal
    ``f"{video_id}_{clip_id}"`` of a record in ``metadata``. An item is the dict
    stored in the file (which must contain tensors under "a", "t" and "v" --
    presumably audio/text/visual; confirm against the preprocessing code),
    augmented with an integer ``sentiment`` label and ``<m>_length`` entries.

    Args:
        data_dir: Directory holding the ``*.pt`` files for one split.
        metadata: Iterable of records exposing ``video_id``, ``clip_id`` and
            ``sentiment``.

    Raises:
        FileNotFoundError: If ``data_dir`` contains no ``*.pt`` files.
        KeyError: If an embedding file has no matching metadata record.
        ValueError: If a record's sentiment is not NEGATIVE/NEUTRAL/POSITIVE.
    """

    def __init__(self, data_dir: Path, metadata: RawDataset):
        self.data_dir = Path(data_dir)
        # Sorted so item order (and thus val/test iteration order) is reproducible;
        # raw glob order depends on the filesystem.
        self.files = sorted(self.data_dir.glob("*.pt"))
        if not self.files:
            raise FileNotFoundError(f"no *.pt files found in {self.data_dir}")
        self.metadata = metadata
        self.utterance_id_to_metadata = {}
        self.utterance_id_to_sentiment = {}

        # Index every record by utterance id; later duplicates overwrite earlier ones.
        records_by_id = {f"{m.video_id}_{m.clip_id}": m for m in metadata}
        for file in self.files:
            try:
                m = records_by_id[file.stem]
            except KeyError:
                raise KeyError(f"no metadata record for embedding file {file}") from None
            self.utterance_id_to_metadata[file.stem] = m
            self.utterance_id_to_sentiment[file.stem] = self._sentiment_label(m)

    @staticmethod
    def _sentiment_label(m) -> int:
        """Map a record's sentiment to a class index: NEGATIVE=0, NEUTRAL=1, POSITIVE=2."""
        # Compared with ``==`` (not a dict lookup) so values that merely compare equal
        # to the Sentiment members still match, exactly as an if/elif chain would.
        for label, sentiment in enumerate((Sentiment.NEGATIVE, Sentiment.NEUTRAL, Sentiment.POSITIVE)):
            if m.sentiment == sentiment:
                return label
        raise ValueError(f"{m.sentiment=} not recognized")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load one utterance dict and attach its label and per-modality lengths."""
        file = self.files[idx]
        # NOTE(review): torch.load unpickles; only load files from your own preprocessing.
        datum = torch.load(file)
        datum["sentiment"] = self.utterance_id_to_sentiment[file.stem]
        for modality in ("a", "t", "v"):
            # Length is the size of the first (sequence) dimension.
            datum[f"{modality}_length"] = datum[modality].shape[0]
        return datum

In [52]:
# One dataset per split sub-directory; all three share the same metadata table.
train_ds, val_ds, test_ds = (
    EmbeddingDataset(data_dir / split_name, metadata)
    for split_name in ("TRAIN", "VALIDATION", "TEST")
)

In [53]:
def create_padding_mask(lengths, max_len=None):
    """Build a boolean validity mask from per-item sequence lengths.

    Despite the name, the mask is True at *valid* positions and False at padded
    ones -- the opposite of PyTorch's ``key_padding_mask``/boolean ``attn_mask``
    convention (True = ignore). Invert it with ``~`` before passing it there.

    Args:
        lengths: 1-D integer tensor (or sequence) of lengths, one per item.
        max_len: Width of the mask. Defaults to ``max(lengths)`` (0 when empty).

    Returns:
        Bool tensor of shape ``(len(lengths), max_len)`` on ``lengths``' device,
        with ``mask[i, j] == (j < lengths[i])``.
    """
    lengths = torch.as_tensor(lengths)
    if max_len is None:
        max_len = int(lengths.max()) if lengths.numel() else 0
    # Create positions on the same device as ``lengths`` so GPU inputs don't hit a
    # device mismatch; broadcasting (1, max_len) against (B, 1) yields (B, max_len).
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

In [54]:
def collate_fn_pad(batch: list[dict]):
    """Collate ``EmbeddingDataset`` items into one zero-padded batch dict.

    For each modality key ``m`` in ("a", "t", "v"), sequences are padded with
    ``nn.utils.rnn.pad_sequence`` at its default ``batch_first=False``. The padded
    tensor is therefore sequence-first, ``(max_len, batch, *feature_dims)``, while
    the matching ``m_padding`` mask is batch-first, ``(batch, max_len)``. The mask
    is True at real (non-padded) positions.

    Args:
        batch: Non-empty list of item dicts with keys "a", "t", "v",
            "a_length", "t_length", "v_length" and "sentiment".

    Returns:
        Dict with, per modality ``m``: ``m`` (padded tensor), ``m_length``
        (1-D tensor of original lengths) and ``m_padding`` (bool mask). It also
        holds ``sentiment``, a 1-D tensor of labels. Keys are in the order
        a, a_length, a_padding, t, ..., sentiment.
    """
    collated = {}
    for modality in ("a", "t", "v"):
        padded = nn.utils.rnn.pad_sequence([datum[modality] for datum in batch])
        lengths = torch.tensor([datum[f"{modality}_length"] for datum in batch])
        collated[modality] = padded
        collated[f"{modality}_length"] = lengths
        # Size the mask from the padded tensor's time axis so mask and tensor always agree.
        collated[f"{modality}_padding"] = create_padding_mask(lengths, padded.shape[0])
    collated["sentiment"] = torch.tensor([datum["sentiment"] for datum in batch])
    return collated

In [63]:
# Loader settings shared by all three splits.
BATCH_SIZE = 32
NUM_WORKERS = 0  # 0 = load samples in the main process
SEED = 42

train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn_pad,
    num_workers=NUM_WORKERS,
    # Seeded generator makes the training shuffle order reproducible across kernel restarts.
    generator=torch.Generator().manual_seed(SEED),
)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_pad, num_workers=NUM_WORKERS)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_pad, num_workers=NUM_WORKERS)

In [65]:
# Smoke test: one full pass over the test loader checks that every file loads and
# collates, and that each padded tensor agrees with its mask. `x` stays bound to
# the last batch for inspection in the next cell.
for x in tqdm(test_dl, desc="test smoke pass"):
    n_items = x["sentiment"].shape[0]
    for modality in ("a", "t", "v"):
        # Padded tensors are sequence-first (max_len, batch, ...); masks are (batch, max_len).
        assert x[modality].shape[1] == n_items, modality
        assert x[f"{modality}_padding"].shape == (n_items, x[modality].shape[0]), modality

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 121/121 [00:15<00:00,  7.88it/s]


In [None]:
# Pairwise attention mask for the last batch `x` (presumably video queries attending
# to text keys -- confirm against the model that consumes it).
# (B, Lv, 1) & (B, 1, Lt) -> (B, Lv, Lt): True only where both positions are real tokens.
valid_pairs = x["v_padding"].unsqueeze(-1) & x["t_padding"].unsqueeze(1)
# PyTorch boolean attention masks use True = "not allowed to attend", so invert.
# Recomputed from `x` every run, so re-running this cell no longer flips the mask.
attn_mask = ~valid_pairs
# NOTE(review): nn.MultiheadAttention expects a 3-D attn_mask of shape
# (B * num_heads, Lv, Lt). Also, rows for padded video positions are fully masked,
# which yields NaN after softmax -- handle both before use.