In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
import numpy as np
import math

In [4]:

# Fix the global RNG so the random dummy inputs below (and the later torch.randn /
# nn.Linear / nn.MultiheadAttention initialisations) are reproducible on Restart & Run All.
torch.manual_seed(0)

dummy_visual_embeddings = torch.rand(
    64, 300, 256
)  # batch_size, timesteps, visual_embedding_dim
dummy_audio_embeddings = torch.rand(
    64, 600, 1024
)  # batch_size, timesteps, audio_embedding_dim
dummy_text_embeddings = torch.rand(
    64, 50, 768
)  # batch_size, timesteps, text_embedding_dim

# For text, the first timestep is already the CLS token: RoBERTa prepends a CLS token to every
# sequence in the batch, so (unlike audio and visual below) no separate CLS vector is prepended here.

In [6]:
# Bind the dummy tensors to the generic names that every downstream cell reads.
audio_embeddings, text_embeddings_cls, visual_embeddings = (
    dummy_audio_embeddings,
    dummy_text_embeddings,
    dummy_visual_embeddings,
)

In [7]:
# CLS vectors (with requires_grad set) for the two modalities whose encoders do not
# supply one, plus linear maps lifting visual (256-d) and audio (1024-d) features to
# the shared 768-d width used by text. Creation order is kept so the RNG draws match.
cls_visual = torch.randn(768).requires_grad_()
cls_audio = torch.randn(768).requires_grad_()
visual_projection = nn.Linear(256, 768)
audio_projection = nn.Linear(1024, 768)

In [8]:
def _prepend_cls(cls_token, projection, embeddings):
    """Project ``embeddings`` and put ``cls_token`` in front of every sequence.

    ``embeddings`` is batch-major ``(batch, time, in_dim)``; the result is
    ``(batch, 1 + time, out_dim)`` with the CLS vector at time index 0.
    """
    batch_cls = cls_token.expand(embeddings.shape[0], 1, cls_token.shape[0])
    return torch.cat((batch_cls, projection(embeddings)), dim=1)


# nn.MultiheadAttention is used with its default batch_first=False, so every
# sequence is converted to time-major layout: (timesteps, batch_size, embedding_dim).
audio_embeddings_cls = _prepend_cls(cls_audio, audio_projection, audio_embeddings).transpose(0, 1)
x_text = text_embeddings_cls.transpose(0, 1)
visual_embeddings_cls = _prepend_cls(cls_visual, visual_projection, visual_embeddings).transpose(0, 1)

In [9]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a time-major sequence.

    Uses the sin/cos scheme from "Attention Is All You Need": even feature
    indices receive ``sin(pos / 10000^(2i/d))`` and odd indices the matching
    ``cos``. The table is precomputed once for ``max_len`` positions and stored
    with ``register_buffer`` (not a trainable parameter).

    Args:
        embedding_dim: Size of the last (feature) dimension of the inputs.
            May be odd; the final column then only receives a sine term.
        max_len: Number of positions precomputed; ``forward`` inputs must have
            at most this many timesteps.
    """

    def __init__(self, embedding_dim: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # 10000^(-2i/d) for each even index 2i; length ceil(embedding_dim / 2).
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim)
        )
        pe = torch.zeros(max_len, 1, embedding_dim)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are only floor(embedding_dim / 2) odd columns, so trim div_term
        # to match; this is a no-op for even embedding_dim and avoids a shape
        # mismatch for odd embedding_dim.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        Arguments:
            x: Tensor, shape ``[timesteps, batch_size, embedding_dim]`` with
                ``timesteps <= max_len``.

        Returns:
            A new tensor with the same shape as ``x``; the encoding is
            broadcast over the batch dimension and ``x`` is not modified.
        """
        return x + self.pe[: x.size(0)]

In [10]:
# In the paper (https://ieeexplore.ieee.org/document/9206016), these pretrained embeddings are combined with
# sinusoidal positional encodings, which is a bit unusual: Wav2Vec2 already provides a kind of positional
# "encoding" through its convolutional layers.
# max_len only has to cover the longest CLS-prefixed sequence:
# 1 + 300 visual timesteps and 1 + 600 audio timesteps for the dummy data.
positional_encoder_visual = PositionalEncoding(768, max_len=400)
positional_encoder_audio = PositionalEncoding(768, max_len=800)

In [11]:
# Add the sinusoidal position signal to each CLS-prefixed, time-major sequence.
# NOTE(review): this rebinds its own inputs, so re-running this cell without first
# re-running the CLS/permute cell adds the encoding a second time.
visual_embeddings_cls, audio_embeddings_cls = (
    positional_encoder_visual(visual_embeddings_cls),
    positional_encoder_audio(audio_embeddings_cls),
)

In [13]:
# The Self-Attention layers have residual connections and are followed by Layer Normalization as described
# in the paper "Attention is All You Need"
# dropout already defaults to 0.0 and kdim/vdim default to embed_dim, so only the
# model width (768) and the number of heads (4) need to be spelled out.
visual_self_attention = nn.MultiheadAttention(768, 4)
audio_self_attention = nn.MultiheadAttention(768, 4)

In [201]:
# Self-attention within each modality: query, key and value are all the same
# sequence (positional order is query, key, value). No mask is applied.
x_visual, _ = visual_self_attention(
    visual_embeddings_cls,
    visual_embeddings_cls,
    visual_embeddings_cls,
    need_weights=False,
)
x_audio, _ = audio_self_attention(
    audio_embeddings_cls,
    audio_embeddings_cls,
    audio_embeddings_cls,
    need_weights=False,
)

In [202]:
# Residual connection around each self-attention layer.
# NOTE(review): not idempotent; re-running this cell alone adds the residual again.
x_visual = visual_embeddings_cls + x_visual
x_audio = audio_embeddings_cls + x_audio

In [203]:
# One LayerNorm per modality, normalising over the 768-d feature axis.
visual_layernorm = nn.LayerNorm(768)
audio_layernorm = nn.LayerNorm(768)

In [204]:
# Post-norm "Add & Norm": normalise the residual sums from the previous cell.
# NOTE(review): re-running this cell alone normalises the tensors a second time.
x_visual, x_audio = visual_layernorm(x_visual), audio_layernorm(x_audio)

In [205]:
# Each IMA layer uses a host modality's CLS token as the query and attends over the entire embedding
# sequence of a target modality. There are therefore 6 IMA attention blocks in total, one for each ordered
# pair (permutation) of the three modalities.

def _make_ima_block():
    """Return a fresh (attention, layernorm) pair for one host -> target IMA block.

    dropout=0.0 and kdim = vdim = embed_dim are the nn.MultiheadAttention defaults.
    """
    return nn.MultiheadAttention(embed_dim=768, num_heads=4), nn.LayerNorm(768)


# Creation order matches the original so random initialisation draws are unchanged.
audio_visual_ima, audio_visual_layernorm = _make_ima_block()
audio_text_ima, audio_text_layernorm = _make_ima_block()

visual_audio_ima, visual_audio_layernorm = _make_ima_block()
visual_text_ima, visual_text_layernorm = _make_ima_block()

text_visual_ima, text_visual_layernorm = _make_ima_block()
text_audio_ima, text_audio_layernorm = _make_ima_block()

In [206]:
def _ima_cls(ima, layernorm, host, target):
    """Attend from ``host``'s CLS token over the whole ``target`` sequence, then Add & Norm.

    ``host`` and ``target`` are time-major; index 0 along time of ``host`` is its
    CLS token. Slicing with ``0:1`` keeps the time axis, so the query is a
    length-1 sequence and the result keeps a leading time axis of size 1.
    """
    host_cls = host[0:1]
    attended, _ = ima(
        query=host_cls,
        key=target,
        value=target,
        need_weights=False,
        attn_mask=None,
    )
    return layernorm(attended + host_cls)


# NOTE(review): no key_padding_mask is passed, so if real batches contain padded
# positions the CLS query will attend to those positions too.
x_audio_visual_cls = _ima_cls(audio_visual_ima, audio_visual_layernorm, x_audio, x_visual)
x_audio_text_cls = _ima_cls(audio_text_ima, audio_text_layernorm, x_audio, x_text)

x_visual_audio_cls = _ima_cls(visual_audio_ima, visual_audio_layernorm, x_visual, x_audio)
x_visual_text_cls = _ima_cls(visual_text_ima, visual_text_layernorm, x_visual, x_text)

x_text_visual_cls = _ima_cls(text_visual_ima, text_visual_layernorm, x_text, x_visual)
x_text_audio_cls = _ima_cls(text_audio_ima, text_audio_layernorm, x_text, x_audio)

In [209]:
# The two intermodal (IMA) representations of each host modality are fused by element-wise
# multiplication; the paper reports better performance with this approach than with concatenation.
# squeeze(0) drops only the length-1 query/time axis, giving (batch_size, 768). A bare squeeze()
# would also drop the batch axis when batch_size == 1 and break the torch.cat(..., dim=1) below.
x_fused_audio = (x_audio_text_cls * x_audio_visual_cls).squeeze(0)
x_fused_visual = (x_visual_audio_cls * x_visual_text_cls).squeeze(0)
x_fused_text = (x_text_audio_cls * x_text_visual_cls).squeeze(0)

In [210]:
x_fused_text.size()  # expected: (batch_size, 768)

torch.Size([64, 768])

In [213]:
# Stack the three fused 768-d vectors side by side (audio, visual, text order).
x_fused_multimodal = torch.cat(
    [x_fused_audio, x_fused_visual, x_fused_text],
    dim=1,
)

In [214]:
x_fused_multimodal.size()  # expected: (batch_size, 3 * 768)

torch.Size([64, 2304])

In [216]:
# Map the concatenated (audio, visual, text) fused vectors to 3 output scores.
output_projection = nn.Linear(3 * 768, 3)

In [217]:
# Raw, unnormalised scores: no softmax is applied here.
predictions = output_projection(
    x_fused_multimodal
)

In [218]:
predictions.size()  # expected: (batch_size, 3)

torch.Size([64, 3])