# Video-text retrieval tutorial

Video-text retrieval is the task of matching a video to the most similar text from a pool of texts, or vice versa. It is a building block for applications such as searching for videos with a text query and answering questions about videos.

`torchmultimodal`'s video-text retrieval model follows work by [Hayes et al. (2022)](https://arxiv.org/abs/2204.08058) and [Xu et al. (2021)](https://arxiv.org/abs/2109.14084). Both papers train a video encoder and text encoder with contrastive learning. The resulting architecture is called VideoCLIP, named after OpenAI's [CLIP](https://arxiv.org/abs/2103.00020) model for image-text retrieval.

In this tutorial, we will learn how to:

*   Instantiate a VideoCLIP model
*   Calculate the model's outputs for a video-text dataset
*   Calculate a similarity matrix
*   Use the similarity matrix to perform multimodal retrieval

## Setup

First, follow the instructions in the [MUGEN dataset readme](https://github.com/facebookresearch/multimodal/blob/main/examples/mugen/data/README.md) to download the MUGEN dataset.

Next, we'll import the necessities.

In [1]:
import torch
from torchvision.io import write_video
from IPython.display import Video
from examples.mugen.retrieval.video_clip import videoclip
from examples.mugen.data.mugen_dataset import MUGENDatasetArgs
from examples.mugen.data.mugen_datamodules import MUGENDataModule
from torchmultimodal.utils.common import load_module_from_url
from torchmultimodal.transforms.bert_text_transform import BertTextTransform
from torchmultimodal.transforms.video_transform import VideoTransform

To display videos, we'll also define a utility function.

In [2]:
def save_and_display_video(video_filename, video_tensor):
    write_video(video_filename, video_tensor, fps=10, video_codec="h264")
    return Video(video_filename, embed=True)

## Instantiate VideoCLIP model

Here we instantiate the VideoCLIP model and load weights obtained by finetuning on the MUGEN dataset. We set `text_pretrained=False, video_pretrained=False` because those flags load encoder weights from pretraining on generic datasets and are designed to be used as a starting point for finetuning, not with an already finetuned checkpoint (see [source](https://github.com/facebookresearch/multimodal/blob/main/examples/mugen/retrieval/video_clip.py) for more details).

In [3]:
# Instantiate the VideoCLIP model and load the weights from URL

model = videoclip(text_pretrained=False, video_pretrained=False)
load_module_from_url(model, "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/videoclip_mugen.pt")

## Calculate model outputs on a single batch

Now we can apply the model to the MUGEN dataset. First we define a `MUGENDatasetArgs` object ([source](https://github.com/facebookresearch/multimodal/blob/main/examples/mugen/data/mugen_dataset.py)) to configure the dataset and a `MUGENDataModule` object ([source](https://github.com/facebookresearch/multimodal/blob/main/examples/mugen/data/mugen_datamodules.py)) for loading and batching data.

Note that usually we would pass in a text transform and video transform to `MUGENDataModule`, but for the sake of displaying the original data in this tutorial, we'll apply transforms later in the pipeline.

In [4]:
mugen_args = MUGENDatasetArgs(
    get_text_desc=True,
    get_game_frame=True,
    get_audio=False, 
    get_seg_map=False, 
    use_manual_annotation=True,
    use_auto_annotation=False,
)

datamodule = MUGENDataModule(
    mugen_args, 
    batch_size=8,
    shuffle=False,
)

Now we fetch a single batch from the dataset. This batch will be the pool of texts and videos to retrieve from.

In [5]:
sample_batch = next(iter(datamodule.test_dataloader()))

LOADING FROM JSON FROM datasets/coinrun/coinrun_dataset_jsons/release/test.json...
NUMBER OF FILES LOADED: 100


Finally, we apply the text and video transforms before inputting the transformed data to the VideoCLIP model. The model output contains the text embeddings and video embeddings that we'll use to calculate similarity.

In [6]:
text_transform = BertTextTransform()
video_transform = VideoTransform()

model.eval()
with torch.no_grad():
    output = model(
        text_transform(sample_batch['text']), 
        video_transform(sample_batch['video'])
    )

## Calculate similarity matrix

With the model output, we then calculate a similarity matrix, which measures the similarity between every pair of text and video in our batch. 

In contrastive learning, the typical way to calculate a similarity matrix is to matrix-multiply the batch of text embeddings by the transpose of the batch of video embeddings, so that each entry is the dot product of one text embedding with one video embedding. Let's define a function to get the similarity matrix from the VideoCLIP model output.

In [7]:
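def get_similarity_matrix(model_output):
    # embeddings_a holds the text embeddings, embeddings_b the video embeddings,
    # so the result has one row per text and one column per video.
    return model_output.embeddings_a @ model_output.embeddings_b.T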

similarity_matrix = get_similarity_matrix(output)
print(similarity_matrix)

tensor([[ 0.2365, -0.2073, -0.1788, -0.1114, -0.2014, -0.0773, -0.2289, -0.1723],
        [-0.1794,  0.2078, -0.1913, -0.1161, -0.1766,  0.0035, -0.1674, -0.1627],
        [-0.2955, -0.3056,  0.1341, -0.3823, -0.2497, -0.3678,  0.1187,  0.0533],
        [-0.1330, -0.1253, -0.2076,  0.1967, -0.0472, -0.0612, -0.2518, -0.1981],
        [-0.0666,  0.0130, -0.0811, -0.0896,  0.1996, -0.0211, -0.2113, -0.1254],
        [-0.0434, -0.1175, -0.1637, -0.0719, -0.0285,  0.1185, -0.1622, -0.1212],
        [-0.3154, -0.3133,  0.0642, -0.3520, -0.4654, -0.2851,  0.2109,  0.0142],
        [-0.1276, -0.2889,  0.0856, -0.3034, -0.2732, -0.2579,  0.0839,  0.1174]])


Take a look at the similarity matrix above for our batch of 8 texts and 8 videos. The item in the 0th row and 1st column, for example, is a measure of the similarity between the 0th text and the 1st video. If we're looking for the most similar video to the 0th text, we'd find the index of the largest number in the 0th row, which happens to be the 0th video.
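
The same check can be repeated for every row and every column at once. The sketch below copies the matrix printed above into plain Python lists so that it runs without the model or the dataset. It assumes that the i-th text and the i-th video in the batch come from the same dataset sample, and the helper names `argmax` and `recall_at_1` are illustrative, not part of `torchmultimodal`.

```python
# Similarity matrix copied from the printed output above
# (rows: texts, columns: videos).
SIMILARITY = [
    [ 0.2365, -0.2073, -0.1788, -0.1114, -0.2014, -0.0773, -0.2289, -0.1723],
    [-0.1794,  0.2078, -0.1913, -0.1161, -0.1766,  0.0035, -0.1674, -0.1627],
    [-0.2955, -0.3056,  0.1341, -0.3823, -0.2497, -0.3678,  0.1187,  0.0533],
    [-0.1330, -0.1253, -0.2076,  0.1967, -0.0472, -0.0612, -0.2518, -0.1981],
    [-0.0666,  0.0130, -0.0811, -0.0896,  0.1996, -0.0211, -0.2113, -0.1254],
    [-0.0434, -0.1175, -0.1637, -0.0719, -0.0285,  0.1185, -0.1622, -0.1212],
    [-0.3154, -0.3133,  0.0642, -0.3520, -0.4654, -0.2851,  0.2109,  0.0142],
    [-0.1276, -0.2889,  0.0856, -0.3034, -0.2732, -0.2579,  0.0839,  0.1174],
]


def argmax(values):
    """Return the index of the largest value in a sequence."""
    return max(range(len(values)), key=lambda i: values[i])


def recall_at_1(matrix):
    """Fraction of rows whose largest entry lies on the diagonal.

    Assumption: row i and column i belong to the same (text, video) pair.
    """
    hits = sum(1 for i, row in enumerate(matrix) if argmax(row) == i)
    return hits / len(matrix)


# Transposing swaps the roles: rows become videos, columns become texts.
transposed = [list(column) for column in zip(*SIMILARITY)]

text_to_video = recall_at_1(SIMILARITY)
video_to_text = recall_at_1(transposed)
print(f"text->video recall@1: {text_to_video:.2f}")
print(f"video->text recall@1: {video_to_text:.2f}")
```

For this batch both values are 1.00: under the pairing assumption, every text ranks its own video first and every video ranks its own text first.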

Thus, we define some helper functions to retrieve the most similar item of one modality to an item of the other modality.

In [8]:
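def retrieve_text_from_video(similarity_matrix, video_index):
    # A column holds one video's similarity to every text; return the best-scoring text.
    return torch.argmax(similarity_matrix[:, video_index]).item()

def retrieve_video_from_text(similarity_matrix, text_index):
    # A row holds one text's similarity to every video; return the best-scoring video.
    return torch.argmax(similarity_matrix[text_index, :]).item()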
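
Taking the `argmax` returns only the single best match. When a ranked shortlist is more useful, `torch.topk` can return the k highest-scoring candidates instead. The following is a minimal sketch of that variation: `retrieve_topk_videos_from_text` is a hypothetical helper, and the small matrix is made-up toy data, not model output.

```python
import torch


def retrieve_topk_videos_from_text(similarity_matrix, text_index, k=3):
    """Return indices and scores of the k videos most similar to one text.

    Hypothetical helper: rows index texts and columns index videos,
    matching the convention used in this tutorial.
    """
    k = min(k, similarity_matrix.shape[1])  # never request more videos than exist
    scores, indices = torch.topk(similarity_matrix[text_index, :], k=k)
    return indices.tolist(), scores.tolist()


# Toy 2x4 similarity matrix (2 texts, 4 videos); the values are illustrative only.
toy_similarity = torch.tensor([
    [0.9, 0.1, 0.4, -0.2],
    [0.0, 0.3, 0.8, 0.5],
])

top_indices, top_scores = retrieve_topk_videos_from_text(toy_similarity, text_index=1, k=2)
print(top_indices)  # video indices ranked from most to least similar
```

The text-from-video direction would work the same way on a column, `similarity_matrix[:, video_index]`.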

## Retrieve text from video

Let's apply what we know about the similarity matrix to retrieve the most similar text to the 0th video in our batch.

First, we find and display the 0th video.

In [9]:
video_idx = 0
video = sample_batch['video'][video_idx]

In [10]:
# Temporarily save the video to file so we can display it

video_filename = f"video_{video_idx}.mp4"
save_and_display_video(video_filename, video)

In [11]:
# Delete the temporary video file

!rm "{video_filename}"

Next, we retrieve the text from the sample batch that is most similar to the 0th video:

In [12]:
text_idx = retrieve_text_from_video(similarity_matrix, video_idx)
text = sample_batch['text'][text_idx]
print(text)

Mugen moves left then right onto a ladder before dismounting the ladder onto the right of the top platform. It then jumps twice to the right onto a box then onto a coin before walking right into a gear being slain.


## Retrieve video from text

We can repeat the process with the modalities switched. We start by using the 6th text as the query:

In [13]:
text_idx = 6
text = sample_batch['text'][text_idx]
print(text)

Mugen jumps onto a ledge and runs from right to left. It jumps on a snail, collects a coin, jumps up to another ledge, runs from left to right, and then collects four coins.


Then we retrieve the most similar video to the 6th text.

In [14]:
video_idx = retrieve_video_from_text(similarity_matrix, text_idx)
video = sample_batch['video'][video_idx]

In [15]:
# Temporarily save the video to file so we can display it

video_filename = f"video_{video_idx}.mp4"
save_and_display_video(video_filename, video)

In [16]:
# Delete the temporary video file

!rm "{video_filename}"

## Conclusion

Now you know the fundamentals of multimodal retrieval and how to use the VideoCLIP model! These ideas can be extended to other datasets and other modalities.
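
As a starting point for other modalities, the retrieval step only needs two batches of embeddings with the same dimensionality. The sketch below uses random tensors as stand-ins for encoder outputs and L2-normalizes them so that the matrix product yields cosine similarities. The function name `cosine_similarity_matrix` and the embedding sizes are assumptions made for illustration, not part of `torchmultimodal`.

```python
import torch
import torch.nn.functional as F


def cosine_similarity_matrix(embeddings_a, embeddings_b):
    """Pairwise cosine similarity between two batches of embeddings.

    Entry (i, j) compares row i of embeddings_a with row j of embeddings_b.
    """
    a = F.normalize(embeddings_a, dim=-1)  # scale each row to unit L2 norm
    b = F.normalize(embeddings_b, dim=-1)
    return a @ b.T


torch.manual_seed(0)
# Random stand-ins for encoder outputs: 4 items of one modality, 6 of the other.
queries = torch.randn(4, 16)
candidates = torch.randn(6, 16)

sim = cosine_similarity_matrix(queries, candidates)
best_candidate_per_query = sim.argmax(dim=1)  # one candidate index per query
best_query_per_candidate = sim.argmax(dim=0)  # one query index per candidate
print(sim.shape, best_candidate_per_query.shape, best_query_per_candidate.shape)
```

Swapping the random tensors for the outputs of any pair of encoders trained into a shared embedding space leaves the rest of the retrieval code unchanged.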