In [1]:
import os

# Must run before any CUDA-using library is initialised: order GPUs by PCI bus
# id and expose only the device with index 3 to this process.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "3",
    }
)

In [2]:
# Install videollama
!pip install /code/SemEvalParticipants/semeval/experiments/belikova/videollama/VideoLLaMA

# Download all weights to /code/SemEvalParticipants/semeval/experiments/belikova/videollama/ckpt
# https://huggingface.co/dim/SemEvalParticipants_models/blob/main/belikova/llama_embedding.pth
# https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Finetuned/tree/main

In [3]:
import random

import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
from transformers import LlamaTokenizer
from datasets import load_dataset
from omegaconf import OmegaConf

from semeval.experiments.belikova.videollama.models.backbone import VideoLLAMABackbone
from video_llama.processors import AlproVideoTrainProcessor, AlproVideoEvalProcessor
from video_llama.processors.video_processor import load_video
from video_llama.models.ImageBind.data import load_and_transform_audio_data

In [3]:
# Base directory of the Video-LLaMA experiment; the tokenizer, checkpoint, config
# and output paths used in this notebook are built by appending to it.
ROOT = "/code/SemEvalParticipants/semeval/experiments/belikova/videollama"

## Model

In [4]:
class MultimodalClassifier(pl.LightningModule):
    """Emotion classifier that fuses text, video and audio embeddings.

    Each modality is embedded by the callable supplied at construction time,
    projected from ``input_dim`` to ``hidden_dim`` by its own linear layer and
    mean-pooled over dimension 1. The three pooled vectors are concatenated and
    fed to a two-layer MLP that produces ``num_classes`` logits.

    Args:
        video_embedding: Callable applied to the ``"video"`` batch entry.
        audio_embedding: Callable applied to the ``"audio"`` batch entry.
        text_embedding: Callable applied to the ``"text"`` batch entry.
        input_dim: Size of the last dimension of every embedding.
        hidden_dim: Size of each projected modality vector.
        num_classes: Number of target classes.

    The embedding callables are expected to return tensors shaped
    ``(batch, tokens, input_dim)``: the projection acts on the last dimension
    and the pooling removes dimension 1.
    """

    def __init__(
            self,
            video_embedding,
            audio_embedding,
            text_embedding,
            input_dim=5120,
            hidden_dim=1024,
            num_classes=7,
        ):
        super().__init__()
        self.video_embedding = video_embedding
        self.audio_embedding = audio_embedding
        self.text_embedding = text_embedding
        # One projection per modality, in the order used by forward(): text, video, audio.
        self.projections = nn.ModuleList([nn.Linear(input_dim, hidden_dim) for _ in range(3)])

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )

        # Separate metric objects per stage so their accumulated states never mix.
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.test_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

        self.train_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average="macro")
        self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average="macro")
        self.test_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average="macro")

    def forward(self, text, video, audio):
        """Return class logits of shape ``(batch, num_classes)`` for one batch."""
        embeddings = [
            self.text_embedding(text),
            self.video_embedding(video),
            self.audio_embedding(audio),
        ]

        # Cast to float32 before projecting, then average over dimension 1.
        projections = [
            torch.mean(p(e.float()), dim=1)
            for p, e in zip(self.projections, embeddings)
        ]
        concat_features = torch.cat(projections, dim=1)
        logits = self.classifier(concat_features)
        return logits

    def _shared_step(self, batch):
        """Run the forward pass on ``batch``; return ``(logits, labels, preds)``."""
        labels = batch["label"]
        logits = self(batch["text"], batch["video"], batch["audio"])
        preds = torch.argmax(logits, dim=1)
        return logits, labels, preds

    def training_step(self, batch, batch_idx):
        """Compute the training loss and accumulate the training metrics."""
        logits, labels, preds = self._shared_step(batch)

        loss = nn.functional.cross_entropy(logits, labels)
        self.log("train_loss", loss)

        # Log the metric objects, not per-batch values: Lightning then calls
        # compute() on the state accumulated over the whole epoch. Averaging
        # per-batch macro-F1 values does not give the epoch macro-F1.
        self.train_f1.update(preds, labels)
        self.train_accuracy.update(preds, labels)
        self.log("train_f1", self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and accumulate the validation metrics."""
        logits, labels, preds = self._shared_step(batch)

        loss = nn.functional.cross_entropy(logits, labels)
        self.log("val_loss", loss)

        self.val_f1.update(preds, labels)
        self.val_accuracy.update(preds, labels)
        self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Accumulate the test metrics; no loss is computed."""
        _, labels, preds = self._shared_step(batch)

        self.test_f1.update(preds, labels)
        self.test_accuracy.update(preds, labels)
        self.log("test_f1", self.test_f1, on_step=False, on_epoch=True)
        self.log('test_acc', self.test_accuracy, on_step=False, on_epoch=True, prog_bar=True)

    # The hooks below clear both metrics of a stage before its epoch starts, so
    # no state left over from an earlier loop can leak into the new epoch.
    def on_train_epoch_start(self):
        self.train_f1.reset()
        self.train_accuracy.reset()

    def on_validation_epoch_start(self):
        self.val_f1.reset()
        self.val_accuracy.reset()

    def on_test_epoch_start(self):
        self.test_f1.reset()
        self.test_accuracy.reset()

    def configure_optimizers(self):
        """Optimise every parameter registered on this module with Adam (lr=1e-3)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

## Data preprocessing

In [5]:
# Emotion vocabulary; an emotion's position in this list is its class index.
all_emotions = [
    "surprise", "fear", "sadness", "neutral", "joy", "anger", "disgust",
]

# Lookup tables in both directions: emotion name -> index and index -> emotion name.
emotions2labels = dict(zip(all_emotions, range(len(all_emotions))))
labels2emotions = dict(enumerate(all_emotions))

In [6]:
class SemEvalDataset(Dataset):
    """Video / audio / text emotion samples from a Hugging Face annotation set.

    Every item is a dict with the keys ``"video"`` (transformed frames),
    ``"text"`` (token ids of the utterance), ``"audio"`` (audio clips of the
    same file) and ``"label"`` (index of the emotion in ``emotions2labels``).

    Args:
        data_name: Dataset identifier passed to ``load_dataset``.
        root: Directory containing the video files named by ``video_name``.
        split: Dataset split; ``"train"`` selects the training video
            processor, any other value the evaluation processor.
        num_frames: Number of video frames, also used as the number of audio
            clips per video.
        resize_size: Frame height and width passed to the loader and processor.
        tokenizer_name: Path or name of the Llama tokenizer to load.
    """

    def __init__(
            self,
            data_name="dim/SemEval_training_data_emotions",
            root="/code/data/video_with_audio",
            split="train",
            num_frames=8,
            resize_size=224,
            tokenizer_name=ROOT + "/ckpt/llama-2-13b-chat-hf",
        ):
        self.root = root
        self.annotation = load_dataset(data_name, split=split)
        self.num_frames = num_frames
        self.resize_size = resize_size
        processor_cls = AlproVideoTrainProcessor if split == "train" else AlproVideoEvalProcessor
        self.transform = processor_cls(
            image_size=resize_size,
            n_frms=num_frames,
        ).transform
        self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
        # The tokenizer is given the unknown token as its padding token.
        self.tokenizer.pad_token = self.tokenizer.unk_token

    def __len__(self):
        return len(self.annotation)

    def _load_sample(self, sample, video_path, device):
        """Build the item dict for one annotation row; raise if anything fails."""
        result = {}
        result["video"] = self.transform(
            load_video(
                video_path=video_path,
                n_frms=self.num_frames,
                height=self.resize_size,
                width=self.resize_size,
                sampling ="uniform",
                return_msg = False,
            )
        )
        result["text"] = self.tokenizer(
            sample["text"],
            return_tensors="pt",
            padding="longest",
            max_length=512,
            truncation=True,
        ).input_ids[0]
        result["audio"] = load_and_transform_audio_data(
            [video_path],
            device=device,
            clips_per_video=self.num_frames,
        )[0]
        result["label"] = emotions2labels[sample["emotion"]]
        # Explicit check instead of `assert` so it still runs under `python -O`.
        n_video = result["video"].shape[1]
        n_audio = result["audio"].shape[0]
        if not (n_video == self.num_frames == n_audio):
            raise ValueError(
                f"expected {self.num_frames} video frames and audio clips, "
                f"got {n_video} frames and {n_audio} clips"
            )
        return result

    def __getitem__(self, index, num_retries=10, device="cpu"):
        """Return the item at ``index``, falling back to random items on failure.

        If a sample cannot be loaded, the failure is logged and a randomly
        chosen other index is tried instead, up to ``num_retries`` attempts.

        Raises:
            RuntimeError: If no attempt succeeds; the last underlying error is
                attached as the cause.
        """
        import logging  # local import keeps the class self-contained in the notebook

        last_error = None
        for _ in range(num_retries):
            sample = self.annotation[index]
            video_path = "/".join([self.root, sample["video_name"]])
            try:
                return self._load_sample(sample, video_path, device)
            except Exception as error:
                # Best-effort loading: report the broken sample instead of hiding it,
                # then retry with a different, randomly chosen index.
                last_error = error
                logging.getLogger(__name__).warning(
                    "Could not load sample %s (%s): %r", index, video_path, error
                )
                index = random.randint(0, len(self) - 1)
        raise RuntimeError(f"Failed to fetch sample after {num_retries} retries.") from last_error

    def collater(self, instances):
        """Collate items into a batch, right-padding the token ids with the pad id."""
        text_ids = [instance["text"] for instance in instances]
        text_ids = torch.nn.utils.rnn.pad_sequence(
            text_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )

        batch = {
            "video": torch.stack([instance['video'] for instance in instances]),
            "text": text_ids,
            "audio": torch.stack([instance['audio'] for instance in instances]),
            "label": torch.tensor([instance['label'] for instance in instances]),
        }

        return batch

In [7]:
# The "train" split is used for optimisation; the "test" split serves as validation data.
train_dataset = SemEvalDataset(split="train")
val_dataset = SemEvalDataset(split="test")

# Each loader collates with its own dataset's `collater`; only training data is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=train_dataset.collater,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    num_workers=4,
    collate_fn=val_dataset.collater,
)

## Training

In [8]:
# Derive the class count from the label vocabulary so the two cannot drift apart.
num_classes = len(all_emotions)
max_epochs = 20
# Destination of the checkpoint written after training.
output_path = ROOT + "/output/emo_classification_model_0.ckpt"

In [9]:
device = torch.device("cuda")

# Build the Video-LLaMA backbone from its config and move it to the GPU. Only its
# `encode_videoQformer` / `encode_audioQformer` methods are used below, as the
# video and audio embedding functions.
# NOTE(review): the backbone is not registered on the LightningModule, so Lightning
# will not switch it between train/eval mode — confirm `from_config` leaves it in
# the intended mode.
config = OmegaConf.load(ROOT + "/configs/backbone.yaml")
video_backbone = VideoLLAMABackbone.from_config(config)
video_backbone.to(device)
video_embedding = video_backbone.encode_videoQformer
audio_embedding = video_backbone.encode_audioQformer

# Token-embedding table (32000 x 5120) restored from the saved weights.
# map_location="cpu" makes loading independent of the device the file was saved
# from; the table is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
text_embedding = nn.Embedding(32000, 5120, padding_idx=0)
text_embedding.load_state_dict(
    torch.load(ROOT + "/ckpt/llama_embedding.pth", map_location="cpu")
)
# Freeze through the public API instead of the private `_freeze` constructor argument.
text_embedding.weight.requires_grad_(False)
text_embedding.to(device);

In [10]:
# Experiment tracking on Weights & Biases.
wandb_logger = WandbLogger(name="multimodal_classification_0", project="emotion_analysis")

# Classifier built on top of the three embedding functions prepared above.
model = MultimodalClassifier(
    video_embedding,
    audio_embedding,
    text_embedding,
    num_classes=num_classes,
)

# devices=-1 asks Lightning to use every GPU visible to this process.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator="gpu",
    devices=-1,
    logger=wandb_logger,
)
trainer.fit(model, train_loader, val_loader)

# Write the final checkpoint locally, attach it to the W&B run, then close the run.
trainer.save_checkpoint(output_path)
wandb_logger.experiment.save(output_path)
wandb_logger.experiment.finish()

In [None]:
# W&B run link (project emotion_analysis): https://wandb.ai/julia-bel/emotion_analysis/runs/r04ozcch