In [None]:
import os
# Order CUDA devices by PCI bus id and make only device "6" visible to this
# process. Must run before any CUDA context is created.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "6",
    }
)

Install VideoLLaMA:  
`!pip install semeval/experiments/belikova/videollama/VideoLLaMA`

Download the following weights into `semeval/experiments/belikova/videollama/ckpt`:
- https://huggingface.co/dim/SemEvalParticipants_models/blob/main/belikova/llama_embedding.pth  
- https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Finetuned/tree/main

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from omegaconf import OmegaConf

from semeval.experiments.belikova.videollama.models import (
    VideoLLAMABackbone,
    EmotionCausalClassifier,
)
from semeval.experiments.belikova.videollama.data import (
    EmotionCausalDataset,
)

## Architecture

In [None]:
import torchmetrics
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils

In [None]:
class Attention2d(nn.Module):
    """Additive attention pooling that collapses dimension 2 of its input.

    Each feature vector on the last axis is scored by a small network
    (``Linear -> tanh -> Linear`` without bias). The scores are normalised
    with a softmax over dimension 2 and used as weights for a sum of the
    inputs along that same dimension.
    """

    def __init__(self, input_dim, attention_dim):
        """
        Args:
            input_dim: size of the last axis of tensors passed to ``forward``.
            attention_dim: width of the hidden scoring layer.
        """
        super().__init__()
        self.attention_weights = nn.Linear(input_dim, attention_dim)
        self.context_vector = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, x):
        """Return ``x`` pooled over dimension 2 (that axis is removed)."""
        hidden = self.attention_weights(x).tanh()
        scores = self.context_vector(hidden)  # last axis has size 1
        weights = scores.softmax(dim=2)
        return (weights * x).sum(dim=2)


class CausalClassifier(pl.LightningModule):
    """Pairwise emotion-cause scorer over the utterances of a conversation.

    Every modality is projected, pooled over dimension 2 with ``Attention2d``
    and concatenated into one vector per utterance. Two bidirectional LSTMs
    then process the utterance sequence: one receives the utterance vectors
    concatenated with the emotion embeddings, the other receives the utterance
    vectors alone. The logit for pair ``(i, j)`` is the dot product of
    emotion-LSTM output ``i`` with causal-LSTM output ``j``.

    Args:
        embeddings: mapping from modality name to a callable that turns the raw
            batch tensor of that modality into embeddings of width ``input_dim``.
            Only the keys are needed to build the layers; the values are called
            in the ``*_step`` methods.
        input_dim: width of the embeddings produced by ``embeddings`` values.
        attention_dim: hidden width of each ``Attention2d`` scorer.
        modality_embedding_dim: width of each pooled per-modality vector.
        emotion_embedding_dim: width of the emotion embedding concatenated to
            every utterance vector for the emotion LSTM.
        lstm_hidden_dim: hidden size (per direction) of both LSTMs.
    """

    def __init__(
        self,
        embeddings,
        input_dim=5120,
        attention_dim=128,
        modality_embedding_dim=256,
        emotion_embedding_dim=7,
        lstm_hidden_dim=128,
    ):
        super().__init__()
        self.embeddings = embeddings
        # Materialise the keys: a ``dict_keys`` view cannot be indexed, and the
        # step methods need ``self.modalities[0]``.
        self.modalities = list(embeddings.keys())

        self.projections = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(input_dim, modality_embedding_dim * 2),
                    nn.ReLU(),
                    nn.Linear(modality_embedding_dim * 2, modality_embedding_dim),
                    nn.ReLU(),
                )
                for _ in range(len(embeddings))
            ]
        )
        self.attentions = nn.ModuleList(
            [
                Attention2d(modality_embedding_dim, attention_dim)
                for _ in range(len(embeddings))
            ]
        )

        self.emotion_bilstm = nn.LSTM(
            modality_embedding_dim * len(self.modalities) + emotion_embedding_dim,
            lstm_hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.causal_bilstm = nn.LSTM(
            modality_embedding_dim * len(self.modalities),
            lstm_hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

        self.f1 = torchmetrics.F1Score(
            task="binary", multidim_average="global", ignore_index=-1
        )
        self.accuracy = torchmetrics.Accuracy(
            task="binary", multidim_average="global", ignore_index=-1
        )

    def _run_bilstm(self, lstm, inputs, lengths):
        """Run ``lstm`` over padded ``inputs`` and pad the output back to the input length."""
        packed = rnn_utils.pack_padded_sequence(
            inputs,
            lengths,
            batch_first=True,
            enforce_sorted=False,
        )
        packed_output, _ = lstm(packed)
        # ``total_length`` keeps the time axis equal to the padded input length
        # even when the longest sequence in the batch is shorter than the padding.
        output, _ = rnn_utils.pad_packed_sequence(
            packed_output, batch_first=True, total_length=inputs.size(1)
        )
        return output

    def forward(self, modality_embeddings, utterance_lengths, emotion_embeddings):
        """Compute pairwise logits for one batch of conversations.

        Args:
            modality_embeddings: one tensor per modality, in the order of
                ``self.modalities``. Dimension 2 is pooled away by attention;
                presumably each is (batch, utterances, tokens, input_dim).
            utterance_lengths: number of real utterances per conversation.
            emotion_embeddings: tensor concatenated to the utterance vectors on
                dimension 2 (assumed (batch, utterances, emotion_embedding_dim)).

        Returns:
            Tensor of shape (batch, utterances, utterances); entry ``[b, i, j]``
            pairs emotion-LSTM output ``i`` with causal-LSTM output ``j``.
        """
        projections = [
            attention(project(embedding.float()))
            for attention, project, embedding in zip(
                self.attentions, self.projections, modality_embeddings
            )
        ]
        utterance_embeddings = torch.cat(projections, dim=2)

        # pack_padded_sequence requires the lengths on the CPU.
        lengths = utterance_lengths.cpu()

        causal_utterances = self._run_bilstm(
            self.causal_bilstm, utterance_embeddings, lengths
        )
        emotion_utterances = self._run_bilstm(
            self.emotion_bilstm,
            torch.cat((utterance_embeddings, emotion_embeddings), dim=2),
            lengths,
        )

        logits = torch.bmm(emotion_utterances, causal_utterances.transpose(1, 2))

        return logits

    def _encode_modalities(self, batch):
        """Embed every modality of ``batch``; return the embeddings and the padded utterance count."""
        assert all([m in batch for m in self.modalities]), "incorrect modality input"
        batch_size, num_utterances = batch[self.modalities[0]].shape[:2]

        modality_embeddings = []
        for m in self.modalities:
            # The embedding callables work on a flat batch of utterances, so
            # fold (batch, utterances) together and unfold again afterwards.
            flat = batch[m].reshape(batch_size * num_utterances, *batch[m].shape[2:])
            encoded = self.embeddings[m](flat)
            modality_embeddings.append(
                encoded.reshape(batch_size, num_utterances, *encoded.shape[1:])
            )
        return modality_embeddings, num_utterances

    def _shared_step(self, batch, stage, compute_loss=True):
        """Run the model on ``batch``, log ``{stage}_*`` values and return the loss (or None)."""
        utterance_lengths = batch["utterance_length"]
        emotion_embeddings = batch["emotion_embedding"]
        # Assumed to be (batch, utterances, utterances) with -1 at positions the
        # metrics should ignore (see ``ignore_index=-1``) — TODO confirm.
        labels = batch["label"]

        modality_embeddings, num_utterances = self._encode_modalities(batch)
        logits = self(modality_embeddings, utterance_lengths, emotion_embeddings)

        loss = None
        if compute_loss:
            # Build the validity mask on the logits' device to avoid a
            # CPU/GPU mismatch when the batch lives on the GPU.
            positions = torch.arange(num_utterances, device=logits.device)
            valid = (
                positions.unsqueeze(0) < utterance_lengths.to(logits.device).unsqueeze(1)
            ).float()
            # A pair (i, j) counts only when both utterances are real; this
            # gives the mask the same (batch, U, U) shape as the logits.
            pair_mask = valid.unsqueeze(2) * valid.unsqueeze(1)
            # NOTE(review): the default "mean" reduction still divides by the
            # padded element count, as in the original code.
            loss = F.binary_cross_entropy_with_logits(
                logits, labels.float(), weight=pair_mask
            )
            self.log(f"{stage}_loss", loss, on_epoch=True)

        # Calculate metrics
        preds = (torch.sigmoid(logits) >= 0.5).float()
        self.log(f"{stage}_f1", self.f1(preds, labels))
        self.log(f"{stage}_accuracy", self.accuracy(preds, labels))

        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: masked BCE loss plus train metrics for one batch."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: masked BCE loss plus validation metrics for one batch."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Lightning hook: test metrics only; no loss is computed or returned."""
        self._shared_step(batch, "test", compute_loss=False)

    def on_train_epoch_start(self):
        """Clear the shared metric state before a training epoch."""
        self.f1.reset()
        self.accuracy.reset()

    def on_validation_epoch_start(self):
        """Clear the shared metric state before a validation epoch."""
        self.f1.reset()
        self.accuracy.reset()

    def on_test_epoch_start(self):
        """Clear the shared metric state before a test epoch."""
        self.f1.reset()
        self.accuracy.reset()

    def configure_optimizers(self):
        """Use Adam with a fixed learning rate of 1e-3 over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [None]:
class Attention1d(nn.Module):
    """Additive attention pooling that collapses dimension 1 of its input.

    Each feature vector on the last axis is scored by a small network
    (``Linear -> tanh -> Linear`` without bias). The scores are normalised
    with a softmax over dimension 1 and used as weights for a sum of the
    inputs along that same dimension.
    """

    def __init__(self, input_dim, attention_dim):
        """
        Args:
            input_dim: size of the last axis of tensors passed to ``forward``.
            attention_dim: width of the hidden scoring layer.
        """
        super().__init__()
        self.attention_weights = nn.Linear(input_dim, attention_dim)
        self.context_vector = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, x):
        """Return ``x`` pooled over dimension 1 (that axis is removed)."""
        hidden = self.attention_weights(x).tanh()
        scores = self.context_vector(hidden)  # last axis has size 1
        weights = scores.softmax(dim=1)
        return (weights * x).sum(dim=1)


class EmotionClassifier(pl.LightningModule):
    """Multi-modal emotion classifier.

    Each modality is projected to ``hidden_dim``, pooled over dimension 1 with
    ``Attention1d``, and the pooled vectors of all modalities are concatenated
    and fed to a small MLP that outputs ``num_classes`` logits.

    Args:
        embeddings: mapping from modality name to a callable that turns the raw
            batch tensor of that modality into embeddings of width ``input_dim``.
            Only the keys are needed to build the layers; the values are called
            in the ``*_step`` methods.
        input_dim: width of the embeddings produced by ``embeddings`` values.
        hidden_dim: width of each per-modality projection and of the MLP.
        attention_dim: hidden width of each ``Attention1d`` scorer.
        num_classes: number of emotion classes.
    """

    def __init__(
        self,
        embeddings,
        input_dim=5120,
        hidden_dim=512,
        attention_dim=128,
        num_classes=7,
    ):
        super().__init__()
        self.embeddings = embeddings
        # A list (rather than a live ``dict_keys`` view) keeps the modality
        # order fixed and indexable.
        self.modalities = list(embeddings.keys())

        self.projections = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
                for _ in range(len(embeddings))
            ]
        )
        self.attentions = nn.ModuleList(
            [Attention1d(hidden_dim, attention_dim) for _ in range(len(embeddings))]
        )

        # The classifier consumes one pooled ``hidden_dim`` vector per modality,
        # so its input width must follow the number of modalities instead of
        # being fixed to two.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * len(self.modalities), hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )

        self.accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.f1 = torchmetrics.F1Score(
            task="multiclass", num_classes=num_classes, average="macro"
        )

    def forward(self, embeddings):
        """Return class logits for a batch.

        Args:
            embeddings: one tensor per modality, in the order of
                ``self.modalities``. Dimension 1 is pooled away by attention;
                presumably each is (batch, tokens, input_dim).

        Returns:
            Tensor of logits with ``num_classes`` entries on the last axis.
        """
        projections = [
            attention(project(embedding.float()))
            for attention, project, embedding in zip(
                self.attentions, self.projections, embeddings
            )
        ]
        concat_features = torch.cat(projections, dim=1)
        logits = self.classifier(concat_features)
        return logits

    def _shared_step(self, batch, stage, compute_loss=True, f1_prog_bar=True):
        """Run the model on ``batch``, log ``{stage}_*`` values and return the loss (or None)."""
        embeddings = [self.embeddings[mod](batch[mod]) for mod in self.modalities]
        logits = self(embeddings)
        labels = batch["label"]

        loss = None
        if compute_loss:
            loss = F.cross_entropy(logits, labels)
            self.log(f"{stage}_loss", loss)

        preds = torch.argmax(logits, dim=1)
        self.log(
            f"{stage}_f1",
            self.f1(preds, labels),
            on_step=False,
            on_epoch=True,
            prog_bar=f1_prog_bar,
        )
        self.log(
            f"{stage}_acc",
            self.accuracy(preds, labels),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: cross-entropy loss plus train metrics for one batch."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: cross-entropy loss plus validation metrics for one batch."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Lightning hook: test metrics only; no loss is computed or returned."""
        # The original test step did not show F1 on the progress bar.
        self._shared_step(batch, "test", compute_loss=False, f1_prog_bar=False)

    def on_train_epoch_start(self):
        """Clear the shared metric state before a training epoch."""
        self.f1.reset()
        self.accuracy.reset()

    def on_validation_epoch_start(self):
        """Clear the shared metric state before a validation epoch."""
        self.f1.reset()
        self.accuracy.reset()

    def on_test_epoch_start(self):
        """Clear the shared metric state before a test epoch."""
        self.f1.reset()
        self.accuracy.reset()

    def configure_optimizers(self):
        """Use Adam with a fixed learning rate of 1e-3 over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

## Data preprocessing

In [None]:
# Base directory of the experiment; configs, checkpoints and outputs are
# addressed relative to it in the cells below.
root = "/code/SemEvalParticipants/semeval/experiments/belikova/videollama"
# Target device for the backbone, the text embedding and both classifiers.
device = torch.device("cuda")

In [None]:
# Modalities requested from the dataset; the same list is reused below when
# the classifiers are constructed.
modalities = ["text", "video"]

train_dataset = EmotionCausalDataset(modalities=modalities, split="train")
# NOTE(review): validation reads the "test" split — confirm this is intended.
val_dataset = EmotionCausalDataset(modalities=modalities, split="test")

# Each dataset supplies its own collate function; only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    collate_fn=train_dataset.collater,
    batch_size=3,
    num_workers=4,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    collate_fn=val_dataset.collater,
    batch_size=3,
    num_workers=4,
)

## Model setting

In [None]:
# Maps modality name -> callable that embeds the raw batch tensor of that modality.
embeddings = {}

# Video: the VideoLLaMA backbone's video Q-Former encoder.
config = OmegaConf.load(root + "/configs/backbone.yaml")
video_backbone = VideoLLAMABackbone.from_config(config)
video_backbone.to(device)
embeddings["video"] = video_backbone.encode_videoQformer
# audio_embedding = video_backbone.encode_audioQformer

# Text: a token embedding table initialised from the downloaded checkpoint.
# 5120 matches the classifiers' default ``input_dim``; 32000 is presumably the
# LLaMA vocabulary size — verify against the checkpoint.
text_embedding = nn.Embedding(32000, 5120, padding_idx=0)
# Deserialize onto the CPU first: only one GPU is visible to this process, so a
# checkpoint saved from a different CUDA index would otherwise fail to load.
# The weights are moved to ``device`` right after.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
text_embedding.load_state_dict(
    torch.load(root + "/ckpt/llama_embedding.pth", map_location="cpu")
)
text_embedding.to(device)
embeddings["text"] = text_embedding

In [None]:
# Number of emotion classes; also the width of the emotion embedding that the
# causal classifier concatenates to every utterance vector.
num_classes = 7

# Both classifiers only need the modality names at construction time here, so
# the mapping values are left as None.
emotion_model = EmotionClassifier(
    dict.fromkeys(modalities),
    num_classes=num_classes,
    attention_dim=128,
    hidden_dim=512,
)
causal_model = CausalClassifier(
    dict.fromkeys(modalities),
    emotion_embedding_dim=num_classes,
    lstm_hidden_dim=128,
    modality_embedding_dim=512,
    attention_dim=128,
)
emotion_model.to(device)
causal_model.to(device)

## Training

In [None]:
# Upper bound on training epochs passed to the Lightning trainer.
max_epochs = 30
# Destination of the final checkpoint, under the experiment root.
output_path = f"{root}/output/joint_classification_model_0.ckpt"

In [None]:
# Log the run to Weights & Biases.
wandb_logger = WandbLogger(
    project="emotion_analysis",
    name="joint_classification_0",
)

# Joint model that wraps the shared embeddings and both classifier heads.
model = EmotionCausalClassifier(
    embeddings,
    emotion_classifier=emotion_model,
    causal_classifier=causal_model,
    num_classes=num_classes,
    hidden_dim=512,
)

# devices=-1 lets Lightning use every GPU visible to this process.
trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator="gpu",
    devices=-1,
    max_epochs=max_epochs,
)
trainer.fit(model, train_loader, val_loader)

# Persist the trained weights, attach the checkpoint to the W&B run, then
# close the run.
trainer.save_checkpoint(output_path)
wandb_logger.experiment.save(output_path)
wandb_logger.experiment.finish()
