# Artificial Intelligence - Project

## Utils

### Import Utils

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import torchaudio
from torchsummary import summary


import cv2
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

objc[36059]: Class AVFFrameReceiver is implemented in both /Users/tt/miniforge3/envs/ds/lib/python3.8/site-packages/av/.dylibs/libavdevice.60.3.100.dylib (0x1465b8760) and /Users/tt/miniforge3/envs/ds/lib/libavdevice.59.7.100.dylib (0x16982c778). One of the two will be used. Which one is undefined.
objc[36059]: Class AVFAudioReceiver is implemented in both /Users/tt/miniforge3/envs/ds/lib/python3.8/site-packages/av/.dylibs/libavdevice.60.3.100.dylib (0x1465b87b0) and /Users/tt/miniforge3/envs/ds/lib/libavdevice.59.7.100.dylib (0x16982c7c8). One of the two will be used. Which one is undefined.


### Download Data

In [2]:
# Download Dataset
!wget https://web.eecs.umich.edu/~mihalcea/downloads/MELD.Raw.tar.gz --no-check-certificate

# Unzip Data
!tar -xf MELD.Raw.tar.gz

# Unzip Train data
!tar -xf ./MELD.Raw/train.tar.gz
# Unzip Valid data
!tar -xf ./MELD.Raw/dev.tar.gz
# Unzip Test data
!tar -xf ./MELD.Raw/test.tar.gz

--2024-06-10 07:31:05--  https://web.eecs.umich.edu/~mihalcea/downloads/MELD.Raw.tar.gz
Resolving web.eecs.umich.edu (web.eecs.umich.edu)... 141.212.113.214
Connecting to web.eecs.umich.edu (web.eecs.umich.edu)|141.212.113.214|:443... connected.
  Unable to locally verify the issuer's authority.
HTTP request sent, awaiting response... 200 OK
Length: 10878146150 (10G) [application/x-gzip]
Saving to: ‘MELD.Raw.tar.gz’


2024-06-10 07:42:54 (14.7 MB/s) - ‘MELD.Raw.tar.gz’ saved [10878146150/10878146150]



In [14]:
# Remove missing data
BASE_DIR = './MELD_Data/'
df_train = pd.read_csv(BASE_DIR + 'train.csv')
for idx in range(len(df_train)):
    try:
        filename = BASE_DIR + 'train/' + 'dia' + str(df_train.iloc[idx]['Dialogue_ID']) + '_utt' + str(df_train.iloc[idx]['Utterance_ID']) + '.mp4'
        cv2.VideoCapture(filename)
        torchaudio.load(filename)
    except:
        df_train.drop(idx, inplace=True)
df_train.to_csv('./train_filtered.csv')

[mov,mp4,m4a,3gp,3g2,mj2 @ 0x302d25df0] moov atom not found
OpenCV: Couldn't read video stream from file "./MELD_Data/train/dia125_utt3.mp4"


In [15]:
len(df_train)

6318

In [16]:
df_train_origin = pd.read_csv(BASE_DIR + 'train.csv')
len(df_train_origin)

9989

In [18]:
# Remove missing data
BASE_DIR = './MELD_Data/'
df_valid = pd.read_csv(BASE_DIR + 'valid.csv')
print(f"Before : {len(df_valid)}")
for idx in range(len(df_valid)):
    try:
        filename = BASE_DIR + 'valid/' + 'dia' + str(df_valid.iloc[idx]['Dialogue_ID']) + '_utt' + str(df_valid.iloc[idx]['Utterance_ID']) + '.mp4'
        cv2.VideoCapture(filename)
        torchaudio.load(filename)
    except:
        df_valid.drop(idx, inplace=True)
print(f"After : {len(df_valid)}")
df_valid.to_csv('./valid_filtered.csv')

Before : 1109


OpenCV: Couldn't read video stream from file "./MELD_Data/valid/dia110_utt7.mp4"


After : 1107


In [2]:
# Remove missing data
BASE_DIR = './MELD_Data/'
df_test = pd.read_csv(BASE_DIR + 'test.csv')
print(f"Before : {len(df_test)}")
for idx in range(len(df_test)):
    try:
        filename = BASE_DIR + 'test/' + 'dia' + str(df_test.iloc[idx]['Dialogue_ID']) + '_utt' + str(df_test.iloc[idx]['Utterance_ID']) + '.mp4'
        cv2.VideoCapture(filename)
        torchaudio.load(filename)
    except:
        df_test.drop(idx, inplace=True)
print(f"After : {len(df_test)}")
df_test.to_csv('./test_filtered.csv')

Before : 2610
After : 1266


### Preprocess and load data

In [3]:
# DataLoader
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchaudio
import torch.nn as nn
from transformers import BertTokenizer

import cv2
import os
import numpy as np
import pandas as pd

class MELDDataset(Dataset):
    def __init__(self, csv, path, transform=None, max_video_len=30, max_audio_len=16000, max_text_len=128):
        self.df = pd.read_csv(csv)
        self.label = {self.df['Emotion'].unique()[i] : i for i in range(len(self.df['Emotion'].unique()))}
        self.path = path
        self.max_video_len = max_video_len
        self.max_audio_len = max_audio_len
        self.max_text_len = max_text_len
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.video_transform = transform

    def __len__(self):
        return len(self.df)

    def load_video(self, video_path):
        frames = []
        try:
            cap = cv2.VideoCapture(video_path)
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = self.video_transform(frame)
                frames.append(frame)
                if len(frames) >= self.max_video_len:
                    break
            cap.release()
            frames = frames[:self.max_video_len]
        except:
            pass
        if len(frames)==0:
            frames.extend([torch.zeros((64, 64))] * (self.max_video_len - len(frames)))
        elif len(frames) < self.max_video_len:
            frames.extend([torch.zeros_like(frames[0])] * (self.max_video_len - len(frames)))
        return torch.stack(frames)

    def load_audio(self, audio_path, mel_bins=128, target_length=1024):
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            fbank = torchaudio.compliance.kaldi.fbank(
                waveform, htk_compat=True, sample_frequency=sample_rate, use_energy=False,
                window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
            n_frames = fbank.shape[0]
            p = target_length - n_frames
            if p > 0:
                m = torch.nn.ZeroPad2d((0, 0, 0, p))
                fbank = m(fbank)
            elif p < 0:
                fbank = fbank[0:target_length, :]

            fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
            return fbank
        except:
            return torch.zeros((target_length, mel_bins))

    def tokenize_text(self, text):
        encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_text_len, return_tensors='pt')
        return encoding['input_ids'].squeeze(), encoding['attention_mask'].squeeze()

    def __getitem__(self, idx):
        filename = 'dia' + str(self.df.iloc[idx]['Dialogue_ID']) + '_utt' + str(self.df.iloc[idx]['Utterance_ID']) + '.mp4'
        text = self.df.iloc[idx]['Utterance'].replace('\x92', "'")
        video = self.load_video(self.path + filename)
        audio = self.load_audio(self.path + filename)
        text, attention_mask = self.tokenize_text(text)
        label = self.label[self.df.iloc[idx]['Emotion']]
        return video, audio, text, attention_mask, label

def collate_fn(batch):
    videos, audios, texts, attention_masks, labels = zip(*batch)
    videos = torch.stack(videos)
    audios = torch.stack(audios)
    texts = torch.stack(texts)
    attention_masks = torch.stack(attention_masks)
    labels = torch.tensor(labels)
    return videos, audios, texts, attention_masks, labels

def MELD(datatype, transform=None, batch_size=2, collate=collate_fn):
    """DataLoader. \\
    Expected File structure is: \\
    ├── train\\
    ├── valid\\
    ├── test  \\
    ├── train.csv\\
    ├── valid.csv\\
    └── test.csv\\
    Change if you want. \\
    If transform is None, it just resizes data and returns Tensor.\\
    Video (Batch, Frame, Channel, Height, Width) \\
    Audio (Batch, Channel, Sample) \\
    Text  (Batch, tokenized Length)\\
    Label (Batch)
    """
    # Data to load
    if datatype == 'train':
        csv_file = '/content/train_sent_emo.csv'
        data_folder = '/content/train_splits/'
    elif datatype == 'valid':
        csv_file = '/content/MELD.Raw/dev_sent_emo.csv'
        data_folder = '/content/dev_splits_complete/'
    elif datatype == 'test':
        csv_file = '/content/MELD.Raw/test_sent_emo.csv'
        data_folder = '/content/output_repeated_splits_test/'
    # transform
    if transform is None:
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])
    # Load data
    dataset = MELDDataset(csv_file, data_folder, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)
    return dataloader

TypeError: expected string or bytes-like object

In [30]:
# DataLoader
train_loader = MELD('train')
valid_loader = MELD('valid')
test_loader = MELD('test')

# Video

## Video processing model - Swin3D/B

In [20]:
## Select Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Now using {device} device")

Now using cuda device


In [27]:
from torchvision.models.video import swin3d_b
train_loader = MELD('train')
video_model = swin3d_b(weights='KINETICS400_IMAGENET22K_V1')
video_model.head = nn.Identity()
video_model = video_model.to(device)

  0%|          | 0/4995 [00:00<?, ?it/s]

565
torch.Size([1024, 128])


  0%|          | 0/4995 [00:00<?, ?it/s]

145
torch.Size([1024, 128])





## Extract Video Feature (Train)

In [None]:
features = torch.Tensor([]).to(device)
with torch.no_grad():
    video_model.eval()
    for video, _, _, _, _ in tqdm(train_loader):
        video = video.to(device)
        video = video.permute(0, 2, 1, 3, 4)
        feature = video_model(video)
        features = torch.concat([features, feature])
torch.save(features, 'video_feature_train.pt')

# Download feature tensor
files.download('video_feature_train.pt')

## Extract Video Feature (Valid)

In [None]:
features = torch.Tensor([]).to(device)
with torch.no_grad():
    video_model.eval()
    for video, _, _, _, _ in tqdm(valid_loader):
        video = video.to(device)
        video = video.permute(0, 2, 1, 3, 4)
        feature = video_model(video)
        features = torch.concat([features, feature])
torch.save(features, 'video_feature_valid.pt')

# Download feature tensor
files.download('video_feature_valid.pt')

## Extract Video Feature (Test)

In [None]:
features = torch.Tensor([]).to(device)
with torch.no_grad():
    video_model.eval()
    for video, _, _, _, _ in tqdm(test_loader):
        video = video.to(device)
        video = video.permute(0, 2, 1, 3, 4)
        feature = video_model(video)
        features = torch.concat([features, feature])
torch.save(features, 'video_feature_test.pt')

# Download feature tensor
files.download('video_feature_test.pt')

# Text

## Text processing Model - DeBERTa

In [None]:
import torch.nn.functional as F
from transformers import DebertaV2Model

class AngleSDETextEmbeddingModel(nn.Module):
    def __init__(self):
        super(AngleSDETextEmbeddingModel, self).__init__()
        self.deberta = DebertaV2Model.from_pretrained('microsoft/deberta-v3-base')
        self.sde_layer = nn.Linear(768, 768)
        self.angle_weight = nn.Parameter(torch.Tensor(768, 768))
        nn.init.uniform_(self.angle_weight)

    def angle_optimization(self, embeddings):
        norm = embeddings.norm(p=2, dim=1, keepdim=True)
        normalized_embeddings = embeddings / norm
        angles = torch.mm(normalized_embeddings, self.angle_weight)
        return angles

    def forward(self, input_ids, attention_mask):
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        pooled_output = last_hidden_state[:, 0, :]

        # Apply SDE
        sde_output = self.sde_layer(pooled_output)
        sde_output = F.relu(sde_output)

        # Apply AnglE
        angle_output = self.angle_optimization(sde_output)

        return angle_output

text_model = AngleSDETextEmbeddingModel()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/371M [00:00<?, ?B/s]

AngleSDETextEmbeddingModel(
  (deberta): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 768, padding_idx=0)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): StableDropout()
              (dropout): StableDropout()
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
 

## Extract Text Feature (Train)

In [None]:
text_model = text_model.to(device)
features = torch.Tensor([]).to(device)
with torch.no_grad():
    text_model.eval()
    for _, _, text, attn, _ in tqdm(train_loader):
        text, attn = text.to(device), attn.to(device)
        feature = text_model(text, attn)
        features = torch.concat([features, feature])
torch.save(features, 'text_feature_train.pt')
files.download('text_feature_train.pt')

## Extract Text Feature (Valid)

In [None]:
text_model = text_model.to(device)
features = torch.Tensor([]).to(device)
with torch.no_grad():
    text_model.eval()
    for _, _, text, attn, _ in tqdm(valid_loader):
        text, attn = text.to(device), attn.to(device)
        feature = text_model(text, attn)
        features = torch.concat([features, feature])
torch.save(features, 'text_feature_valid.pt')
files.download('text_feature_valid.pt')

## Extract Text Feature (Test)

In [None]:
text_model = text_model.to(device)
features = torch.Tensor([]).to(device)
with torch.no_grad():
    text_model.eval()
    for _, _, text, attn, _ in tqdm(test_loader):
        text, attn = text.to(device), attn.to(device)
        feature = text_model(text, attn)
        features = torch.concat([features, feature])
torch.save(features, 'text_feature_test.pt')
files.download('text_feature_test.pt')

# Audio

## Audio processing model: AST

In [None]:
# Initial Settings
import sys
!git clone https://github.com/YuanGongND/ast
sys.path.append('./ast')
%cd /content/ast/
!pip install timm==0.4.5
!pip install wget
import os, csv, argparse, wget
os.environ['TORCH_HOME'] = '/content/ast/pretrained_models'
if os.path.exists('/content/ast/pretrained_models') == False:
    os.mkdir('/content/ast/pretrained_models')
import torch, torchaudio, timm
import numpy as np
from torch.cuda.amp import autocast
import IPython
import torch.nn as nn

Cloning into 'ast'...
remote: Enumerating objects: 649, done.[K
remote: Counting objects: 100% (209/209), done.[K
remote: Compressing objects: 100% (50/50), done.[K
remote: Total 649 (delta 172), reused 159 (delta 159), pack-reused 440[K
Receiving objects: 100% (649/649), 2.41 MiB | 24.89 MiB/s, done.
Resolving deltas: 100% (360/360), done.
/content/ast
Collecting timm==0.4.5
  Downloading timm-0.4.5-py3-none-any.whl (287 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m287.4/287.4 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.4->timm==0.4.5)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.4->timm==0.4.5)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.4->timm==0.4.5)
  Using cached nvidia_cuda_cupti_cu12-12

In [None]:
# Model Definition
from src.models import ASTModel
class ASTModelVis(ASTModel):
    def get_att_map(self, block, x):
        qkv = block.attn.qkv
        num_heads = block.attn.num_heads
        scale = block.attn.scale
        B, N, C = x.shape
        qkv = qkv(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        return attn

    def forward_visualization(self, x):
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)

        B = x.shape[0]
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        # save the attention map of each of 12 Transformer layer
        att_list = []
        for blk in self.v.blocks:
            cur_att = self.get_att_map(blk, x)
            att_list.append(cur_att)
            x = blk(x)
        return att_list

In [None]:
# Model Initiation and load pre-trained weights
audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
if os.path.exists('/content/ast/pretrained_models/audio_mdl.pth') == False:
    wget.download(audioset_mdl_url, out='/content/ast/pretrained_models/audio_mdl.pth')
input_tdim = 1024
ast_mdl = ASTModelVis(label_dim=527, input_tdim=input_tdim, imagenet_pretrain=False, audioset_pretrain=False)
checkpoint_path = '/content/ast/pretrained_models/audio_mdl.pth'
checkpoint = torch.load(checkpoint_path, map_location='cuda')
audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0])
audio_model.load_state_dict(checkpoint)
audio_model.module.mlp_head[1] = nn.Identity()

---------------AST Model Summary---------------
ImageNet pretraining: False, AudioSet pretraining: False
frequncey stride=10, time stride=10
number of patches=1212
DataParallel(
  (module): ASTModelVis(
    (v): DistilledVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (blocks): ModuleList(
        (0-11): 12 x Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out

## Extract Audio Feature (Train)

In [None]:
audio_model = audio_model.to(device)
features = torch.Tensor([]).to(device)
with torch.no_grad():
    audio_model.eval()
    for _, audio, _, _, _ in tqdm(train_loader):
        audio = audio.to(device)
        feature = audio_model(audio)
        features = torch.concat([features, feature])
torch.save(features, 'audio_feature_train.pt')
files.download('audio_feature_train.pt')

## Extract Audio Feature (Valid)

In [None]:
audio_model = audio_model.to(device)
features = torch.Tensor([]).to(device)
with torch.no_grad():
    audio_model.eval()
    for _, audio, _, _, _ in tqdm(valid_loader):
        audio = audio.to(device)
        feature = audio_model(audio)
        features = torch.concat([features, feature])
torch.save(features, 'audio_feature_valid.pt')
files.download('audio_feature_valid.pt')

## Extract Audio Feature (Test)

In [None]:
audio_model = audio_model.to(device)
features = torch.Tensor([]).to(device)
with torch.no_grad():
    audio_model.eval()
    for _, audio, _, _, _ in tqdm(test_loader):
        audio = audio.to(device)
        feature = audio_model(audio)
        features = torch.concat([features, feature])
torch.save(features, 'audio_feature_test.pt')
files.download('audio_feature_test.pt')

# Final Classifying model

In [None]:
# DataLoader
train_loader = MELD('train')
valid_loader = MELD('valid')

In [None]:
###################
# Vanilla Version #
###################
class MELDClassifier(nn.Module):
    def __init__(self, video_model, audio_model, text_model):
        # Feature Extractor
        super().__init__()
        self.video_model = video_model
        self.audio_model = audio_model
        self.text_model = text_model
        self.video_mapping = nn.Linear(1024, 768)
        self.audio_mapping = nn.Linear(768, 768)
        self.text_mapping = nn.Linear(768, 768)
        # Weight
        self.Wv = nn.Parameter(torch.Tensor([1]))
        self.Wa = nn.Parameter(torch.Tensor([1]))
        self.Wt = nn.Parameter(torch.Tensor([1]))
        # Classifier
        self.clf = nn.Linear(768, 7)

    def forward(self, v, a, t, attm):
        # Extract video feature and map
        fv = self.video_model(v)
        fv = self.video_mapping(fv)
        # Extract audio feature and map
        fa = self.audio_model(a)
        fa = self.audio_mapping(fa)
        # Extract text feature and map
        ft = self.text_model(t, attm)
        ft = self.text_mapping(ft)
        # Option 1: just weighting them
        feature = self.Wv * fv + self.Wa * fa + self.Wt * ft
        output = self.clf(feature)
        return output

In [None]:
# Model
model = MELDClassifier(video_model=video_model,
                       audio_model=audio_model,
                       text_model=text_model).to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model.train()
# Train
for epoch in range(10):
    loss_tmp = 0
    acc_tmp = 0
    for video, audio, text, attm, label in tqdm(train_loader):
        video, audio, text, attm, label = video.to(device), audio.to(device), text.to(device), attm.to(device), label.to(device)
        optimizer.zero_grad()
        video = video.permute(0, 2, 1, 3, 4)
        yhat = model(video, audio, text, attm)
        loss = loss_fn(yhat, label)
        loss.backward()
        optimizer.step()
        loss_tmp += loss.item()
        acc_tmp += (yhat.argmax(1) == label).type(torch.float).sum().item()
    print(f"Epoch {epoch} : Accuracy {acc_tmp:.2f}, Loss {loss_tmp:.2f}")
    torch.save(model.state_dict(), f'model_vanilla_epoch_{epoch}.pt')

In [None]:
# Validation
load_existing_pt = -1  # Indicate specific epoch
if load_existing_pt >= 1:
    model.load_state_dict(torch.load(f'/content/model_vanilla_epoch_{str(load_existing_pt)}.pt'))

with torch.no_grad():
    model.eval()
    loss_tmp = 0
    acc_tmp = 0
    for video, audio, text, attm, label in tqdm(valid_loader):
        video, audio, text, attm, label = video.to(device), audio.to(device), text.to(device), attm.to(device), label.to(device)
        video = video.permute(0, 2, 1, 3, 4)
        yhat = model(video, audio, text, attm)
        loss = loss_fn(yhat, label)
        loss_tmp += loss.item()
        acc_tmp += (yhat.argmax(1) == label).type(torch.float).sum().item()
    print(f"Accuracy {acc_tmp:.2f}, Loss {loss_tmp:.2f}")

In [None]:
###################
# DDM+CFM Version #
###################
class MELDClassifier(nn.Module):
    def __init__(self, video_model, audio_model, text_model):
        # Feature Extractor
        super().__init__()
        self.video_model = video_model
        self.audio_model = audio_model
        self.text_model = text_model
        # Modality and Utterance encoder
        self.video_projection = nn.Linear(1024, 768)
        self.video_me = nn.Linear(768, 768)
        self.video_ue = nn.Linear(768, 768)
        self.audio_me = nn.Linear(768, 768)
        self.audio_ue = nn.Linear(768, 768)
        self.text_me = nn.Linear(768, 768)
        self.text_ue = nn.Linear(768, 768)
        # Weight
        self.Wv = nn.Linear(768*3, 1)
        self.Wa = nn.Linear(768*3, 1)
        self.Wt = nn.Linear(768*3, 1)
        # Auxilary Loss
        self.mse = nn.MSELoss()
        # TCP
        self.TCPv = nn.Linear(768*3, 7)
        self.TCPa = nn.Linear(768*3, 7)
        self.TCPt = nn.Linear(768*3, 7)
        # Classifier
        self.clf = nn.Linear(768*3, 7)

    def forward(self, v, a, t, attm):
        loss = 0
        # Extract video feature, map and encode
        fv = self.video_model(v)
        fv = self.video_projection(fv)
        fv_m = self.video_me(fv)
        fv_u = self.video_ue(fv)
        # Extract audio feature and encode
        fa = self.audio_model(a)
        fa_m = self.audio_me(fa)
        fa_u = self.audio_ue(fa)
        # Extract text feature and encode
        ft = self.text_model(t, attm)
        ft_m = self.text_me(ft)
        ft_u = self.text_ue(ft)

        ############ DDM ############
        if self.training:
            # Contrastive Learning: Prepare
            B = ft.shape[0]
            f_modality = torch.empty(B, 768*3)
            f_modality[0::3] = fv_m
            f_modality[1::3] = fa_m
            f_modality[2::3] = ft_m
            f_utterance = torch.empty(B, 768*3)
            f_utterance[0::3] = fv_u
            f_utterance[1::3] = fa_u
            f_utterance[2::3] = ft_u
            # Contrastive Learning: Modality
            cos_sim = F.cosine_similarity(f_modality.unsqueeze(1), f_modality.unsqueeze(0), dim=-1)
            pos_indices = torch.arange(0, B*3).reshape(3, B).T
            pos_loss = 0
            for i in range(3):
                loss += (1 - cos_sim[pos_indices[:, i], pos_indices[:, i]]).mean()
            for i in range(3):
                for j in range(i + 1, 3):
                    loss += cos_sim[pos_indices[:, i], pos_indices[:, j]].mean()
            # Contrastive Learning: Utterance
            cos_sim = F.cosine_similarity(f_utterance.unsqueeze(1), f_utterance.unsqueeze(0), dim=-1)
            loss = torch.mean(torch.triu(cos_sim, diagonal=1))
        # Concat each feature vectors
        fv = torch.concat([fv, fv_m, fv_u])
        fa = torch.concat([fa, fa_m, fa_u])
        ft = torch.concat([ft, ft_m, ft_u])

        ############ CFM ############
        Wv = torch.sigmoid(self.Wv(fv))
        Wa = torch.sigmoid(self.Wa(fa))
        Wt = torch.sigmoid(self.Wt(ft))
        if self.training:
            # TCP, True Classification Probability
            logit_v = self.TCPv(fv)
            logit_a = self.TCPa(fa)
            logit_t = self.TCPt(ft)
            # TCP Loss
            Lv = F.softmax(logit_v).max(1)[0]
            La = F.softmax(logit_a).max(1)[0]
            Lt = F.softmax(logit_t).max(1)[0]
            loss += self.mse(Lv, Wv)
            loss += self.mse(La, Wa)
            loss += self.mse(Lt, Wt)

        # Weight
        h = Wv*fv + Wa*fa + Wt*ft

        ############ Classifier ############
        output = self.clf(h)
        if self.training:
            return output, loss, logit_v, logit_a, logit_t
        else:
            return output

In [None]:
# Model
model = MELDClassifier(video_model=video_model,
                       audio_model=audio_model,
                       text_model=text_model).to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Train
model.train()
for epoch in range(10):
    loss_tmp = 0
    acc_tmp = 0
    for video, audio, text, attm, label in tqdm(train_loader):
        video, audio, text, attm, label = video.to(device), audio.to(device), text.to(device), attm.to(device), label.to(device)
        optimizer.zero_grad()
        video = video.permute(0, 2, 1, 3, 4)
        yhat, loss, lv, la, lt = model(video, audio, text, attm)
        loss += loss_fn(yhat, label)
        loss += loss_fn(yhat, lv)
        loss += loss_fn(yhat, la)
        loss += loss_fn(yhat, lt)
        loss.backward()
        optimizer.step()
        loss_tmp += loss.item()
        acc_tmp += (yhat.argmax(1) == label).type(torch.float).sum().item()
    print(f"Epoch {epoch} : Accuracy {acc_tmp:.2f}, Loss {loss_tmp:.2f}")
    torch.save(model.state_dict(), f'model_epoch_{epoch}.pt')

In [None]:
# Validation
load_existing_pt = -1  # Indicate specific epoch
if load_existing_pt >= 1:
    model.load_state_dict(torch.load(f'/content/model_epoch_{str(load_existing_pt)}.pt'))

with torch.no_grad():
    model.eval()
    loss_tmp = 0
    acc_tmp = 0
    for video, audio, text, attm, label in tqdm(valid_loader):
        video, audio, text, attm, label = video.to(device), audio.to(device), text.to(device), attm.to(device), label.to(device)
        video = video.permute(0, 2, 1, 3, 4)
        yhat = model(video, audio, text, attm)
        loss = loss_fn(yhat, label)
        loss_tmp += loss.item()
        acc_tmp += (yhat.argmax(1) == label).type(torch.float).sum().item()
    print(f"Accuracy {acc_tmp:.2f}, Loss {loss_tmp:.2f}")