In [59]:
from torch.utils.data import DataLoader
# from timesformer_pytorch import TimeSformer
import pandas as pd
import torch

from utils.datasets import StronglyLabelledDataset

**Loading dataset**

In [57]:
# Instantiate the project's StronglyLabelledDataset with the 'train' argument
# (presumably the split name -- confirm in utils/datasets.py).
dataset = StronglyLabelledDataset('train')

**Loading TimeSformer Model**

In [56]:
# Load the label ontology and keep only its id/name columns. The number of
# rows is what the model-loading cell passes as `num_labels`.
df = pd.read_json('data/ontology.json').loc[:, ['id', 'name']]

num_classes = df.shape[0]

In [71]:
from transformers import TimesformerForVideoClassification

# Load the 'facebook/timesformer-base-finetuned-k400' checkpoint with a
# classification head sized to `num_classes`. `ignore_mismatched_sizes=True`
# lets the head be re-initialised when its shape differs from the checkpoint
# (see the warning printed below this cell).
model = TimesformerForVideoClassification.from_pretrained(
    'facebook/timesformer-base-finetuned-k400',
    ignore_mismatched_sizes=True,
    num_labels=num_classes,
)

# Config overrides applied after loading: return every layer's hidden states
# and record a 128-frame clip length.
# NOTE(review): `num_frames` is changed after the weights are already built --
# confirm this has the intended effect on the temporal embeddings.
model.config.update({'output_hidden_states': True, 'num_frames': 128})

Some weights of TimesformerForVideoClassification were not initialized from the model checkpoint at facebook/timesformer-base-finetuned-k400 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([632, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([632]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Put the model in evaluation mode. As the last expression of the cell, the
# returned module is also displayed in the cell output.
model.eval()

TimesformerForVideoClassification(
  (timesformer): TimesformerModel(
    (embeddings): TimesformerEmbeddings(
      (patch_embeddings): TimesformerPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (time_drop): Dropout(p=0.0, inplace=False)
    )
    (encoder): TimesformerEncoder(
      (layer): ModuleList(
        (0-11): 12 x TimesformerLayer(
          (drop_path): Identity()
          (attention): TimeSformerAttention(
            (attention): TimesformerSelfAttention(
              (qkv): Linear(in_features=768, out_features=2304, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
            )
            (output): TimesformerSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): TimesformerIntermediate(
            (dense

**Running inference on one example to get the second-to-last layer embeddings (hidden size = 768)**

In [72]:
example_idx = 10
video, audio, metadata = dataset[example_idx]

# Keep the first 128 frames, then prepend a batch axis, in a single step.
# No device move or dtype cast is applied here; the tensor is used as the
# dataset returns it.
video = video[:128].unsqueeze(0)

print(f"Final video shape: {video.shape}")



torch.Size([301, 720, 1280, 3])
torch.Size([301, 3, 720, 1280])
Final video shape: torch.Size([1, 128, 3, 224, 224])


In [73]:
# Run a gradient-free forward pass. Hidden states are requested explicitly so
# this cell does not depend on `model.config.output_hidden_states` having been
# set in an earlier cell (otherwise `outputs.hidden_states` would be None).
with torch.no_grad():
    outputs = model(video, output_hidden_states=True)

# Tuple of hidden states. Per the Transformers documentation it holds the
# embedding output followed by one entry per encoder layer, i.e. 12 + 1
# entries for this 12-layer encoder -- not just the 12 layer outputs.
hidden_states = outputs.hidden_states

# Second-to-last entry of the tuple (the output of the penultimate encoder
# layer). The variable name is kept as-is because later cells refer to it,
# even though it is not the final layer.
last_layer_embeddings = hidden_states[-2]

In [74]:
# Inspect the shape of the selected hidden state; the trailing axis is the
# 768-wide hidden size.
last_layer_embeddings.shape

torch.Size([1, 25089, 768])

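This is a matrix because it holds the full token sequence for the whole clip: one 768-dimensional embedding per patch of every frame. The number of rows equals (number of frames × number of patches per frame) + 1; the extra row is the classification (CLS) token that TimeSformer prepends to the sequence.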

In [75]:
# frames * (frame height * frame width) / (patch height * patch width):
# 128 frames of 224x224 pixels cut into 16x16 patches. The observed sequence
# length is one larger than this product -- presumably the extra
# classification (CLS) token; confirm against the Timesformer embeddings.
(128 * 224 * 224) / (16 * 16)

25088.0

We can mean-pool it over the token dimension to get a single 768-dimensional vector representing the entire video.

In [76]:
# Average over the token axis (dim 1) so each item in the batch collapses to
# one 768-wide vector, then show the resulting shape.
pooled_embedding = torch.mean(last_layer_embeddings, dim=1)

pooled_embedding.shape

torch.Size([1, 768])

In [69]:
# Display the model's configuration object (rendered in full in the cell output).
model.config

TimesformerConfig {
  "_name_or_path": "facebook/timesformer-base-finetuned-k400",
  "architectures": [
    "TimesformerForVideoClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "attention_type": "divided_space_time",
  "drop_path_rate": 0,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8",
    "9": "LABEL_9",
    "10": "LABEL_10",
    "11": "LABEL_11",
    "12": "LABEL_12",
    "13": "LABEL_13",
    "14": "LABEL_14",
    "15": "LABEL_15",
    "16": "LABEL_16",
    "17": "LABEL_17",
    "18": "LABEL_18",
    "19": "LABEL_19",
    "20": "LABEL_20",
    "21": "LABEL_21",
    "22": "LABEL_22",
    "23": "LABEL_23",
    "24": "LABEL_24",
    "25": "LABEL_25",
    "26": "LABEL_26",
    "27": "LABEL_27",
    "28": "LABEL_28",
    "29": "LABEL_29",
    "30": "LABE