# Basic BiLSTM architecture

In [None]:
import torch
import torchmetrics
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class CausalClassifier(pl.LightningModule):
    """Multimodal BiLSTM model producing per-utterance emotion logits and
    pairwise (utterance x utterance) causal logits.

    Each modality is encoded by the matching entry of ``embeddings``, projected
    to ``modality_embedding_dim``, concatenated across modalities and fed to
    two independent bidirectional LSTMs (an "emotion" and a "causal" one).
    The Lightning steps train and evaluate only the pairwise causal head.

    Args:
        embeddings: mapping from modality name to a callable encoder. The
            encoder receives ``batch[name]`` with the first two dimensions
            flattened together.
            NOTE(review): encoders are only registered as sub-modules (moved
            to the right device, optimised, checkpointed) if this is an
            ``nn.ModuleDict``; a plain ``dict`` is stored as-is -- confirm
            which one callers pass.
        input_dim: feature size produced by every modality encoder.
        attention_dim: output feature size of each BiLSTM (both directions
            concatenated).
        modality_embedding_dim: feature size of each projected modality.
        emotion_embedding_dim: number of outputs of the emotion head.
    """

    def __init__(
        self,
        embeddings,
        input_dim=4096,
        attention_dim=512,
        modality_embedding_dim=1024,
        emotion_embedding_dim=7,
    ):
        super().__init__()
        self.embeddings = embeddings
        # Materialise the keys: a dict keys view cannot be indexed, and the
        # step methods need ``self.modalities[0]`` to read the batch shape.
        self.modalities = list(embeddings.keys())
        num_modalities = len(self.modalities)
        fused_dim = modality_embedding_dim * num_modalities

        # One independent two-layer MLP per modality.
        self.projections = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(input_dim, modality_embedding_dim * 2),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(modality_embedding_dim * 2, modality_embedding_dim),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                )
                for _ in self.modalities
            ]
        )
        # NOTE(review): not used by ``forward``; kept so the set of registered
        # parameters (and therefore existing checkpoints) is unchanged.
        self.emotion_linear = nn.Sequential(
            nn.Linear(fused_dim, modality_embedding_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(modality_embedding_dim, attention_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
        )
        # hidden size is attention_dim // 2 per direction, so the concatenated
        # bidirectional output has attention_dim features.
        self.causal_bilstm = nn.LSTM(
            fused_dim,
            attention_dim // 2,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.emotion_bilstm = nn.LSTM(
            fused_dim,
            attention_dim // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.emotion_classifier = nn.Sequential(
            nn.Linear(attention_dim, attention_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(attention_dim // 2, emotion_embedding_dim),
        )

        # Input is an (emotion state, causal state) pair, hence attention_dim * 2.
        self.causal_classifier = nn.Sequential(
            nn.Linear(attention_dim * 2, attention_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(attention_dim // 2, 1),
        )

        # Targets equal to -1 are excluded from both metrics.
        self.f1 = torchmetrics.F1Score(
            task="binary", multidim_average="global", ignore_index=-1
        )
        self.accuracy = torchmetrics.Accuracy(
            task="binary", multidim_average="global", ignore_index=-1
        )

    def forward(self, modality_embeddings):
        """Compute emotion and pairwise causal logits.

        Args:
            modality_embeddings: sequence of tensors, one per modality in the
                same order as ``self.projections``, each of shape
                ``(batch, seq_len, input_dim)``.

        Returns:
            Tuple ``(emotion_logits, causal_logits)`` where ``emotion_logits``
            is ``(batch, seq_len, emotion_embedding_dim)`` and
            ``causal_logits`` is ``(batch, seq_len, seq_len)``. Entry
            ``[b, i, j]`` scores the emotion-LSTM state of utterance ``i``
            paired with the causal-LSTM state of utterance ``j``.
        """
        projections = [
            project(embedding.float())
            for project, embedding in zip(self.projections, modality_embeddings)
        ]

        # Fuse modalities along the feature axis.
        utterance_embeddings = torch.cat(projections, dim=2)
        emotion_embeddings, _ = self.emotion_bilstm(utterance_embeddings)
        causal_embeddings, _ = self.causal_bilstm(utterance_embeddings)

        batch_size, seq_len, _ = emotion_embeddings.shape
        # Build every (i, j) pair: emotion state varies along dim 1, causal
        # state along dim 2.
        emotion_utterances = emotion_embeddings.unsqueeze(2).expand(-1, -1, seq_len, -1)
        causal_utterances = causal_embeddings.unsqueeze(1).expand(-1, seq_len, -1, -1)
        combined_utterances = torch.cat((emotion_utterances, causal_utterances), dim=-1)
        causal_logits = self.causal_classifier(combined_utterances).view(
            batch_size, seq_len, seq_len
        )
        emotion_logits = self.emotion_classifier(emotion_embeddings)

        return emotion_logits, causal_logits

    def _causal_logits(self, batch):
        """Encode every modality in ``batch`` and return the causal logits."""
        assert all(m in batch for m in self.modalities), "incorrect modality input"
        batch_size, num_utterances = batch[self.modalities[0]].shape[:2]

        modality_embeddings = []
        for m in self.modalities:
            # Encoders see one utterance per row, so fold (batch, utterance)
            # into a single leading dimension and unfold it afterwards.
            flat_input = batch[m].reshape(
                batch_size * num_utterances, *batch[m].shape[2:]
            )
            encoded = self.embeddings[m](flat_input)
            modality_embeddings.append(
                encoded.reshape(batch_size, num_utterances, *encoded.shape[1:])
            )

        # forward() returns (emotion_logits, causal_logits); only the causal
        # head is trained and evaluated by the step methods.
        _, causal_logits = self(modality_embeddings)
        return causal_logits

    def _masked_loss(self, logits, labels, utterance_lengths):
        """Binary cross-entropy over utterance pairs with padding zero-weighted.

        Assumes ``labels`` has the same ``(batch, seq_len, seq_len)`` shape as
        ``logits`` and ``utterance_lengths`` is ``(batch,)`` -- TODO confirm
        against the dataset.
        """
        num_utterances = logits.shape[1]
        # Create the index tensor on the logits' device to avoid CPU/GPU mixing.
        positions = torch.arange(num_utterances, device=logits.device)
        valid = positions.unsqueeze(0) < utterance_lengths.to(logits.device).unsqueeze(1)
        # A pair (i, j) counts only when both utterances are real (not padding).
        pair_mask = (valid.unsqueeze(2) & valid.unsqueeze(1)).float()
        return F.binary_cross_entropy_with_logits(
            logits, labels.float(), weight=pair_mask
        )

    def _log_metrics(self, stage, logits, labels):
        """Threshold ``logits`` at probability 0.5 and log F1 and accuracy."""
        preds = (torch.sigmoid(logits) >= 0.5).float()
        self.log(f"{stage}_f1", self.f1(preds, labels))
        self.log(f"{stage}_accuracy", self.accuracy(preds, labels))

    def training_step(self, batch, batch_idx):
        """Run one training batch; log loss and metrics and return the loss."""
        labels = batch["label"]
        logits = self._causal_logits(batch)

        loss = self._masked_loss(logits, labels, batch["utterance_length"])
        self.log("train_loss", loss, on_epoch=True)
        self._log_metrics("train", logits, labels)

        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; log loss and metrics and return the loss."""
        labels = batch["label"]
        logits = self._causal_logits(batch)

        loss = self._masked_loss(logits, labels, batch["utterance_length"])
        self.log("val_loss", loss, on_epoch=True)
        self._log_metrics("val", logits, labels)

        return loss

    def test_step(self, batch, batch_idx):
        """Run one test batch and log metrics (no loss is computed)."""
        labels = batch["label"]
        logits = self._causal_logits(batch)
        self._log_metrics("test", logits, labels)

    # The same metric objects are shared by all stages, so their accumulated
    # state is cleared whenever a new epoch of any stage begins.
    def on_train_epoch_start(self):
        self.f1.reset()
        self.accuracy.reset()

    def on_validation_epoch_start(self):
        self.f1.reset()
        self.accuracy.reset()

    def on_test_epoch_start(self):
        self.f1.reset()
        self.accuracy.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer over all registered parameters (lr=1e-3)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
