# Fine-tuning for Video Classification with 🤗 Transformers

NOTE: This is heavily based on [this tutorial](https://github.com/huggingface/notebooks/blob/main/examples/video_classification.ipynb)

This notebook shows how to fine-tune a pre-trained Vision model for Video Classification on a custom dataset. The idea is to add a randomly initialized classification head on top of a pre-trained encoder and fine-tune the model altogether on a labeled dataset.


## Dataset

This notebook uses our own NBA shots video dataset. The dataset was prepared using [this notebook](./create_videos_bank.ipynb).

## Model

We'll fine-tune the [VideoMAE model](https://huggingface.co/docs/transformers/model_doc/videomae), which was pre-trained on the [Kinetics 400 dataset](https://www.deepmind.com/open-source/kinetics). You can find the other variants of VideoMAE available on 🤗 Hub [here](https://huggingface.co/models?search=videomae). You can also extend this notebook to use other video models such as [X-CLIP](https://huggingface.co/docs/transformers/model_doc/xclip#transformers.XCLIPVisionModel). 

**Note** that for models that don't ship with a classification head, you'll have to attach one manually (randomly initialized). This is not the case for VideoMAE, which already provides a [`VideoMAEForVideoClassification`](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEForVideoClassification) class.

## Data preprocessing

This notebook leverages [TorchVision's](https://pytorch.org/vision/stable/transforms.html) and [PyTorchVideo's](https://pytorchvideo.org/) transforms for applying data preprocessing transformations including data augmentation.

---

Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those two parameters, then the rest of the notebook should run smoothly.

In [None]:
model_ckpt = "MCG-NJU/videomae-base-finetuned-kinetics" # pre-trained model from which to fine-tune
batch_size = 4 # batch size for training and evaluation

Before we start, let's install the `pytorchvideo`, `transformers`, `evaluate`, and `accelerate` libraries.

In [None]:
!pip install pytorchvideo transformers evaluate accelerate -U

If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.

To be able to share your model with the community, there are a few more steps to follow.

First you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) then execute the following cell and input your token:

In [None]:
from huggingface_hub import notebook_login

notebook_login()

Uploading model checkpoints requires Git-LFS, so make sure it is installed. The following cell configures Git to store your credentials:

In [None]:
!git config --global credential.helper store

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

In [None]:
from transformers.utils import send_example_telemetry

send_example_telemetry("video_classification_notebook", framework="pytorch")

## Fine-tuning a model on a video classification task

In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) vision models on a Video Classification dataset.

Given a video, the goal is to predict an appropriate class for it, like "DUNK".

### Loading the dataset

Broadly, `dataset_root_path` is organized like so (the file names are illustrative; the clips in our dataset are `.avi` files):

```bash
dataset/
    train/
        JUMP_SHOT/
            video_1.mp4
            video_2.mp4
            ...
        DUNK/
            video_1.mp4
            video_2.mp4
            ...
        ...
    val/
        JUMP_SHOT/
            video_1.mp4
            video_2.mp4
            ...
        DUNK/
            video_1.mp4
            video_2.mp4
            ...
        ...
    test/
        JUMP_SHOT/
            video_1.mp4
            video_2.mp4
            ...
        DUNK/
            video_1.mp4
            video_2.mp4
            ...
        ...
```

Let's now count the number of total videos we have. 

In [None]:
import shutil
import tempfile
import pathlib

dataset_name = "dataset"

try:
    # For running on colab
    from google.colab import drive
    drive.mount('/content/drive')
    root_path = pathlib.Path(f"./drive/MyDrive/")
    drive_dataset_root_path = root_path.joinpath(dataset_name)
    # Copy dataset from drive to local to avoid failure in loading videos 
    dataset_root_path = pathlib.Path(tempfile.mkdtemp()).joinpath(dataset_name)
    shutil.copytree(drive_dataset_root_path, dataset_root_path)
except ModuleNotFoundError:
    # For running on PC
    # dataset_root_path = pathlib.Path('UCF101_subset')
    root_path = pathlib.Path(".")
    dataset_root_path = pathlib.Path(r"C:\Users\User\Google Drive").joinpath(dataset_name)
    is_colab = False
else:
    is_colab = True

In [None]:
video_extension = "avi"

video_count_train = len(list(dataset_root_path.glob(f"train/*/*.{video_extension}")))
video_count_val = len(list(dataset_root_path.glob(f"val/*/*.{video_extension}")))
video_count_test = len(list(dataset_root_path.glob(f"test/*/*.{video_extension}")))
video_total = video_count_train + video_count_val + video_count_test
print(f"Total videos: {video_total}")

In [None]:
all_video_file_paths = (
    list(dataset_root_path.glob(f"train/*/*.{video_extension}"))
    + list(dataset_root_path.glob(f"val/*/*.{video_extension}"))
    + list(dataset_root_path.glob(f"test/*/*.{video_extension}"))
)
all_video_file_paths[:5]

The video paths, when `sorted`, appear like so:

```py
...
'dataset/train/DUNK/0011900011_169.avi',
'dataset/train/DUNK/0011900035_719.avi',
'dataset/train/DUNK/0011900038_625.avi',
'dataset/train/DUNK/0011900042_580.avi',
'dataset/train/DUNK/0011900062_182.avi'
...
```

We notice that there are video clips belonging to the same game, where the game id is denoted by the prefix of the video file name: `0022200892_607.avi` and `0022200892_563.avi`, for example.
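
To make the naming convention concrete, here is a minimal sketch that groups clips by game id. It assumes that every file stem has the form `<game id>_<clip id>`; the helper `game_id` and the directory part of the paths are hypothetical and only serve the illustration.

```python
from collections import defaultdict
from pathlib import PurePosixPath

# File names taken from the examples above; the directory part is illustrative.
paths = [
    PurePosixPath("dataset/train/DUNK/0022200892_607.avi"),
    PurePosixPath("dataset/train/DUNK/0022200892_563.avi"),
    PurePosixPath("dataset/train/DUNK/0011900011_169.avi"),
]


def game_id(path):
    """Return the game id, assuming the stem looks like '<game id>_<clip id>'."""
    return path.stem.split("_")[0]


# Collect the clip file names that belong to each game.
clips_per_game = defaultdict(list)
for path in paths:
    clips_per_game[game_id(path)].append(path.name)

for gid, clips in sorted(clips_per_game.items()):
    print(gid, clips)
```

A grouping like this is handy if you want to check whether clips from the same game ended up in different splits.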

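Next up, we derive the set of labels we have in the dataset. Because some class names are compound (words joined by `_`), we also split them into "mini labels" so that the task can be framed as multi-label classification. For both label sets we create two dictionaries that'll be helpful when initializing the model: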

* `label2id`: maps the class names to integers.
* `id2label`: maps the integers to class names. 

In [None]:

from collections import defaultdict

from typing import Iterable, List

def _update_before_and_after_word(mini_labels, before_word, after_word):
    n = len(mini_labels)

    before_word[mini_labels[0]].add(None)
    after_word[mini_labels[-1]].add(None)

    for i in range(n - 1):
        after_word[mini_labels[i]].add(mini_labels[i + 1])
        before_word[mini_labels[i + 1]].add(mini_labels[i])


def get_multi_labels_from_labels(original_labels: Iterable[str]):
    # Create an empty set to store unique words
    unique_mini_labels = set()
    before_word = defaultdict(set)
    after_word = defaultdict(set)

    # Loop through the labels
    for value in original_labels:
        # Split the value into words using "_" as the delimiter
        mini_labels = value.split('_')
        # Update dicts for building exclusive pairs later
        _update_before_and_after_word(mini_labels, before_word, after_word)
        # Add the new mini labels to the unique_mini_labels set
        unique_mini_labels.update(mini_labels)

    for first_word, possible_second_words in after_word.items():
        if len(possible_second_words) == 1:
            second_word = next(iter(possible_second_words))
            if second_word is not None and len(before_word[second_word]) == 1:
                new_mini_label = f"{first_word}_{second_word}"
                unique_mini_labels -= {first_word, second_word}
                unique_mini_labels.add(new_mini_label)

    return unique_mini_labels

In [None]:
class_labels = sorted({path.parent.stem for path in all_video_file_paths})
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}

class_mini_labels = sorted(get_multi_labels_from_labels(class_labels))
multi_label2id = {label: i for i, label in enumerate(class_mini_labels)}
id2multi_label = {i: label for label, i in multi_label2id.items()}

original_id2multi_id = {id_:[1 if mini_label in label else 0 for mini_label in class_mini_labels] for id_, label in id2label.items()}

print(f"Unique classes: {list(label2id.keys())}.")
print(f"Unique mini classes: {list(multi_label2id.keys())}.")
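
As a rough illustration of what the multi-hot mapping looks like, here is a simplified sketch. The class names are hypothetical, and the sketch deliberately skips the step in `get_multi_labels_from_labels` that merges words which only ever appear together; it also matches whole words rather than substrings.

```python
# Hypothetical class names; the real ones come from the dataset folders.
class_labels = ["DUNK", "HOOK_SHOT", "JUMP_SHOT", "LAYUP"]

# Simplification: every "_"-separated word becomes a mini label.
mini_labels = sorted({word for label in class_labels for word in label.split("_")})


def to_multi_hot(label):
    """Return one 0/1 entry per mini label for the given class name."""
    words = set(label.split("_"))
    return [1 if mini in words else 0 for mini in mini_labels]


print(mini_labels)
for label in class_labels:
    print(label, to_multi_hot(label))
```

With these toy names, `JUMP_SHOT` and `HOOK_SHOT` share the `SHOT` entry, which is the kind of structure a multi-label head can exploit.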

We've got 5 unique classes. For each class we have 1600 videos in the training set. 

### Loading the model

In the next cell, we initialize a video classification model where the encoder is initialized with the pre-trained parameters and the classification head is randomly initialized. We also initialize the image processor associated with the model. This will come in handy when writing the preprocessing pipeline for our dataset.

In [None]:
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification


image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt)
model = VideoMAEForVideoClassification.from_pretrained(
    model_ckpt,
    label2id=multi_label2id,
    id2label=id2multi_label,
    problem_type="multi_label_classification",
    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)

The warning is telling us we are throwing away some weights (e.g. the weights and bias of the `classifier` layer) and randomly initializing some other (the weights and bias of a new `classifier` layer). This is expected in this case, because we are adding a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.

**Note** that [this checkpoint](https://huggingface.co/MCG-NJU/videomae-base-finetuned-kinetics) is a good starting point for this task because it was obtained by fine-tuning on a similar downstream task with considerable domain overlap. For comparison, you can check out [this checkpoint](https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset), which was obtained by further fine-tuning `MCG-NJU/videomae-base-finetuned-kinetics` on a UCF101 subset.

### Constructing the datasets for training

For preprocessing the videos, we'll leverage the [PyTorch Video library](https://pytorchvideo.org/). We start by importing the dependencies we need. 

In [None]:
from typing import Callable, List, Dict
import torch
import pytorchvideo.data

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)


class ApplyTransformToListUnderKey:
    """
    Applies transform to key of dictionary input, where there is a list of values under it.

    Args:
        key (str): the dictionary key the transform is applied to
        transform (callable): the transform that is applied for each element
    """

    def __init__(self, key: str, transform: Callable):
        self._key = key
        self._transform = transform

    def __call__(self, x: Dict[str, List[torch.Tensor]]) -> Dict[str, List[torch.Tensor]]:
        for i in range(len(x[self._key])):
            x[self._key][i] = self._transform(x[self._key][i])
        return x

In [None]:
from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
import torch
from typing import Type, Optional, Callable, Dict, Any
from pytorchvideo.data import ClipSampler, LabeledVideoDataset
from torch.utils.data import Dataset

class MultiLabelLabeledDataset(LabeledVideoDataset):
    def __init__(self, data_path: str, clip_sampler: ClipSampler,
                 video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
                 transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, video_path_prefix: str = "",
                 decode_audio: bool = True, decoder: str = "pyav"):
        labeled_video_paths = LabeledVideoPaths.from_path(data_path)
        labeled_video_paths.path_prefix = video_path_prefix
        super().__init__(labeled_video_paths, clip_sampler, video_sampler, transform, decode_audio, decoder)
        self.original_id2multi_id = original_id2multi_id

    def __next__(self):
        example = super().__next__()
        original_id = example["label"]
        multi_id = self.original_id2multi_id[original_id]
        example["label"] = multi_id
        return example

For the training dataset transformations, we use a combination of uniform temporal subsampling, pixel normalization, random short-side scaling, random cropping, and random horizontal flipping. For the validation and test dataset transformations, we keep the chain the same except for horizontal flipping; the random scaling and cropping stay in place so that the several clips sampled from each video are also cropped differently. To learn more about the details of these transformations, check out the [official documentation of PyTorch Video](https://pytorchvideo.org).

We'll use the `image_processor` associated with the pre-trained model to obtain the following information:

* Image mean and standard deviation with which the video frame pixels will be normalized.
* Spatial resolution to which the video frames will be resized.

In [None]:
import os

mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
    height = width = image_processor.size["shortest_edge"]
else:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
resize_to = (height, width)

num_frames_to_sample = model.config.num_frames
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps


# Training dataset transformations.
train_transform = Compose(
    [
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(num_frames_to_sample),
                    Lambda(lambda x: x / 255.0),
                    Normalize(mean, std),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(resize_to),
                    RandomHorizontalFlip(p=0.5),
                ]
            ),
        ),
    ]
)

# Training dataset.
train_dataset = MultiLabelLabeledDataset(
    data_path=os.path.join(dataset_root_path, "train"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
    decode_audio=False,
    transform=train_transform,
)

# Validation and Test datasets' transformations.
inference_transform = Compose(
    [
        ApplyTransformToListUnderKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(num_frames_to_sample),
                    Lambda(lambda x: x / 255.0),
                    Normalize(mean, std),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(resize_to),
                ]
            ),
        ),
    ]
)

CLIPS_FROM_SINGLE_VIDEO = 5

def build_evaluate_dataset(dataset_type: str):
    # Validation and evaluation datasets.
    dataset = MultiLabelLabeledDataset(
        data_path=os.path.join(dataset_root_path, dataset_type),
        clip_sampler=pytorchvideo.data.make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO),
        decode_audio=False,
        transform=inference_transform,
    )
    return dataset

val_dataset = build_evaluate_dataset("val")
test_dataset = build_evaluate_dataset("test")
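
The `clip_duration` used above can be checked by hand. The sketch below assumes that `model.config.num_frames` is 16 for this checkpoint and that the source videos really run at 30 frames per second; both values are assumptions for the illustration.

```python
num_frames_to_sample = 16  # assumed value of model.config.num_frames
sample_rate = 4            # keep roughly every 4th source frame
fps = 30                   # assumed frame rate of the source videos

# Number of source frames one clip has to cover, and how long that is in seconds.
source_frames_spanned = num_frames_to_sample * sample_rate
clip_duration = source_frames_spanned / fps

print(f"{clip_duration:.3f} s, {source_frames_spanned} source frames")
```

Under these assumptions each sampled clip covers a bit more than two seconds of video, from which `UniformTemporalSubsample` keeps 16 evenly spaced frames.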

**Note**: The above dataset pipelines are adapted from the [official PyTorch Video example](https://pytorchvideo.org/docs/tutorial_classification#dataset). Instead of calling [`pytorchvideo.data.Ucf101()`](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html#pytorchvideo.data.Ucf101) directly, we use `MultiLabelLabeledDataset`, a small subclass of `LabeledVideoDataset` (the type that `pytorchvideo.data.Ucf101()` returns) that converts each class id into a multi-hot label. It expects the same one-folder-per-class structure shown above.

In [None]:
# We can access the `num_videos` attribute to know the number of videos we have in the
# dataset.
train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos

Let's now take a preprocessed video from the dataset and investigate it. 

In [None]:
sample_video = next(iter(train_dataset))
display(sample_video['label'])

In [None]:
def investigate_video(sample_video):
    """Utility to investigate the keys present in a single video sample."""
    for k in sample_video:
        v = sample_video[k][0] if isinstance(sample_video[k],list) else sample_video[k] 
        if k == "video":
            print(k, v.shape)
        else:
            print(k, v)

    print(f"Video labels: {[id2multi_label[i] for i, label in enumerate(sample_video['label']) if label == 1]}")


investigate_video(sample_video)

We can also visualize the preprocessed videos for easier debugging. 

In [None]:
import imageio
import numpy as np
from IPython.display import Image


def unnormalize_img(img):
    """Un-normalizes the image pixels."""
    img = (img * std) + mean
    # Clip before casting, otherwise out-of-range values wrap around in uint8.
    img = (img * 255).clip(0, 255).astype("uint8")
    return img


def create_gif(video_tensor, filename="sample.gif"):
    """Prepares a GIF from a video tensor.
    
    The video tensor is expected to have the following shape:
    (num_frames, num_channels, height, width).
    """
    frames = []
    for video_frame in video_tensor:
        frame_unnormalized = unnormalize_img(video_frame.permute(1, 2, 0).numpy())
        frames.append(frame_unnormalized)
    kargs = {"duration": 0.25, "loop": 0}
    imageio.mimsave(filename, frames, "GIF", **kargs)
    return filename


def display_gif(video_tensor, gif_name="sample.gif"):
    """Prepares and displays a GIF from a video tensor."""
    video_tensor = video_tensor.permute(1, 0, 2, 3)
    gif_filename = create_gif(video_tensor, gif_name)
    return Image(filename=gif_filename)

In [None]:
video_tensor = sample_video["video"]
display_gif(video_tensor)

### Training the model

We'll leverage [`Trainer`](https://huggingface.co/docs/transformers/main_classes/trainer) from 🤗 Transformers for training the model. To instantiate a `Trainer`, we will need to define the training configuration and an evaluation metric. The most important piece is [`TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments), a class that contains all the attributes to configure the training. It requires an output folder name, which will be used to save the checkpoints of the model. It also helps sync all the information in the model repository on 🤗 Hub.

Most of the training arguments are pretty self-explanatory, but one that is quite important here is `remove_unused_columns=False`. By default it's `True`, which makes the `Trainer` drop any features not used by the model's call function; that is usually ideal because it makes it easier to unpack inputs into the model's call function. In our case, however, we need the unused features ('video' in particular) in order to create `pixel_values` (a mandatory key our model expects in its inputs), so we set it to `False`.

In [None]:
from transformers import TrainingArguments, Trainer, SchedulerType

model_name = "videomae-finetuned"
num_epochs = 25
model_type = "multilabel" if type(sample_video['label']) is list else "multiclass"
new_model_name = f"{model_name}-nba-{len(class_labels)}-class-{batch_size}-batch-{video_count_train}-vid-{model_type}"
steps_per_epoch = train_dataset.num_videos // batch_size
output_dir = new_model_name
remote_output_dir = root_path.joinpath(new_model_name)
resuming_from_checkpoint = remote_output_dir.exists()
if resuming_from_checkpoint and pathlib.Path(output_dir) != remote_output_dir:
    shutil.copytree(remote_output_dir, output_dir)

args = TrainingArguments(
    output_dir=output_dir,
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit = 2, # Only 2 models are saved - The best one and the last one
    learning_rate=1.5e-5, # linearly scaled lr: (5e-4 / 128) * 4 ≈ 1.56e-5, i.e. base lr 5e-4 at batch size 128, scaled to batch size 4
    lr_scheduler_type=SchedulerType.COSINE,
    weight_decay=0.05,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,
    logging_steps=10,
    max_steps=steps_per_epoch * num_epochs, # Duplication of `num_train_epochs` because it throws otherwise.
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    push_to_hub=True,
    num_train_epochs=num_epochs,
    ignore_data_skip=True
)

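Normally there's no need to define `max_steps` when instantiating `TrainingArguments`. Here we had to specify it because our dataset (a subclass of `LabeledVideoDataset`, which is an iterable dataset) doesn't implement the `__len__()` method, so the `Trainer` cannot infer the number of steps per epoch on its own.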
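
To get a feel for the numbers, here is the step arithmetic worked out with the figures quoted earlier (5 classes, 1600 training videos each). It assumes a single device and no gradient accumulation, and the way the warmup length is rounded is an assumption of this sketch rather than a statement about the `Trainer` internals.

```python
num_classes = 5           # from the text above
videos_per_class = 1600   # from the text above
batch_size = 4
num_epochs = 25
warmup_ratio = 0.1

num_train_videos = num_classes * videos_per_class     # 8000 videos
steps_per_epoch = num_train_videos // batch_size      # 2000 optimizer steps
max_steps = steps_per_epoch * num_epochs              # 50000 optimizer steps
warmup_steps = round(max_steps * warmup_ratio)        # 5000 steps of warmup

print(steps_per_epoch, max_steps, warmup_steps)
```

If your dataset size or batch size differs, `max_steps` has to be recomputed accordingly, which is why the notebook derives it from `train_dataset.num_videos`.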

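Next, we need to define a function that computes the metrics from the predictions. Because this is a multi-label problem, we don't take the argmax of the logits. Instead, we average the logits of the clips that come from the same video, apply a sigmoid, and threshold the resulting probabilities at 0.5. We then report the micro-averaged F1 score and the (subset) accuracy using scikit-learn: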

In [None]:
import statistics
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, classification_report
from transformers import EvalPrediction
import torch


# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    n = labels.shape[0] // CLIPS_FROM_SINGLE_VIDEO
    predictions = [
        np.average(batch, axis=0) for batch in np.array_split(predictions, n)
    ]
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = [batch[0] for batch in np.array_split(labels, n)]
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    # roc_auc = roc_auc_score(y_true, y_pred, average='micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               # 'roc_auc': roc_auc,
               'accuracy': accuracy}
    report = classification_report(y_true, y_pred, target_names=class_mini_labels)
    print(report)
    return metrics

# the compute_metrics function takes a Named Tuple as input:
# predictions, which are the logits of the model as Numpy arrays,
# and label_ids, which are the ground-truth labels as Numpy arrays.
def compute_metrics(eval_pred):
    """Computes multi-label F1 and accuracy on the evaluation predictions."""
    return multi_label_metrics(predictions=eval_pred.predictions, labels=eval_pred.label_ids)

**A note on evaluation**:

In the [VideoMAE paper](https://arxiv.org/abs/2203.12602), the authors use the following evaluation strategy. They evaluate the model on several clips from test videos and apply different crops to those clips and report the aggregate score. We implement a version of this.
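
The aggregation can be sketched without any dependencies. The logits below are made up for illustration (5 clips from one video, 3 labels); the steps mirror what `multi_label_metrics` does: average the clip logits, apply a sigmoid, then threshold at 0.5.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Made-up logits: one row per clip, one column per label.
clip_logits = [
    [2.0, -1.0, 0.5],
    [1.5, -2.0, -0.5],
    [3.0, -0.5, 0.0],
    [0.5, -1.5, -1.0],
    [1.0, -1.0, 0.5],
]

num_labels = len(clip_logits[0])
# Average each label's logit over all clips of the video.
mean_logits = [
    sum(clip[j] for clip in clip_logits) / len(clip_logits) for j in range(num_labels)
]
# Turn the averaged logits into probabilities and threshold them.
probs = [sigmoid(value) for value in mean_logits]
prediction = [1 if p >= 0.5 else 0 for p in probs]

print(mean_logits, prediction)
```

Note that the third label is positive in two individual clips but negative on average, so the video-level prediction leaves it out.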

We also define a `collate_fn`, which will be used to batch examples together.
Each batch consists of 2 keys, namely `pixel_values` and `labels`.

In [None]:
import torch


def collate_fn(examples):
    """The collation function to be used by `Trainer` to prepare data batches."""
    # permute to (num_frames, num_channels, height, width)
    if isinstance(examples[0]["video"], torch.Tensor):
        # This is for training, where each training entry is a single video
        pixel_values = torch.stack([example["video"].permute(1, 0, 2, 3) for example in examples])
    elif isinstance(examples[0]["video"], list):
        # This is for evaluation, where each evaluation entry is multiple clips from a single video
        pixel_values = torch.cat(
            [torch.stack([single_example.permute(1, 0, 2, 3) for single_example in example["video"]]) for example in
             examples]
        )
    else:
        raise ValueError("Unrecognized input structure!")
        
    labels = np.array([example["label"] for example in examples], dtype=np.int64)
    # TODO - Maybe find a way to avoid duplicating those labels (they are duplicated only so their first dimension matches pixel_values, because the loss is still computed during evaluation)
    if isinstance(examples[0]["video"], list):
        # This is for evaluation, where each evaluation entry is multiple clips from a single video
        labels = np.repeat(labels, CLIPS_FROM_SINGLE_VIDEO, axis=0)
    return {"pixel_values": pixel_values, "labels": torch.tensor(labels, dtype=torch.float)}
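
The label duplication in the evaluation branch is easier to follow with a toy batch. In this sketch plain strings stand in for clip tensors and the label vectors are hypothetical; the two list comprehensions play the roles that `torch.cat` and `np.repeat(..., axis=0)` play in `collate_fn`.

```python
CLIPS_FROM_SINGLE_VIDEO = 5

# Two hypothetical evaluation examples; strings stand in for clip tensors.
examples = [
    {"video": [f"video0_clip{j}" for j in range(CLIPS_FROM_SINGLE_VIDEO)], "label": [1, 0, 0]},
    {"video": [f"video1_clip{j}" for j in range(CLIPS_FROM_SINGLE_VIDEO)], "label": [0, 1, 1]},
]

# Flatten the clips video by video (the role of torch.cat in collate_fn).
flat_clips = [clip for example in examples for clip in example["video"]]
# Repeat each label once per clip (the role of np.repeat with axis=0).
flat_labels = [
    example["label"] for example in examples for _ in range(CLIPS_FROM_SINGLE_VIDEO)
]

for clip, label in zip(flat_clips, flat_labels):
    print(clip, label)
```

Because both lists are ordered video by video, row `i` of the labels always belongs to row `i` of the clips, and the metric function can later split them back into groups of `CLIPS_FROM_SINGLE_VIDEO`.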

Then we just need to pass all of this along with our datasets to the `Trainer`:

In [None]:
from transformers import EarlyStoppingCallback

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=15, early_stopping_threshold=0.01)]
)

You might wonder why we pass along the `image_processor` as a tokenizer when we already preprocessed our data. This is only to make sure the image processor configuration file (stored as JSON) will also be uploaded to the repo on the Hub.

Now we can finetune our model by calling the `train` method:

In [None]:
import traceback
import shutil

try:
    if not resuming_from_checkpoint:
        results = trainer.train()
    else:
        results = trainer.train(resume_from_checkpoint=True)
except Exception as e:
    print(traceback.format_exc())

We can check with the `evaluate` method that our `Trainer` did reload the best model properly (if it was not the last one):

In [None]:
val_results = trainer.evaluate(val_dataset)

In [None]:
display(val_results)

In [None]:
test_results = trainer.evaluate(test_dataset)

In [None]:
display(test_results)

In [None]:
trainer.save_model()
trainer.log_metrics("val", val_results)
trainer.save_metrics("val", val_results)
trainer.log_metrics("test", test_results)
trainer.save_metrics("test", test_results)
trainer.save_state()

You can now upload the result of the training to the Hub, just execute this instruction (note that the Trainer will automatically create a model card as well as Tensorboard logs - see the "Training metrics" tab - amazing isn't it?):

In [None]:
trainer.push_to_hub()

In [None]:
import time

shutil.copytree(output_dir, remote_output_dir.parent.joinpath(f"new_{new_model_name}"))
# Sleep so Google Drive has time to sync the files (about 2GB, so 2 minutes should be enough)
time.sleep(2 * 60)

Now that our model is trained, let's use it to run inference on a video from `test_dataset`. 

## Inference

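We reuse the fine-tuned model that is already in memory (with `load_best_model_at_end=True`, the `Trainer` reloads the best checkpoint when training ends) and fetch a video from `test_dataset`.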

In [None]:
trained_model = model

In [None]:
sample_test_video = next(iter(test_dataset))
investigate_video(sample_test_video)

We then prepare the video as a `torch.Tensor` and run inference. 

In [None]:
def run_inference(model, sample_test_video):
    """Utility to run inference given a model and test video.
    
    The video is assumed to be preprocessed already.
    """
    video = sample_test_video['video']
    label = sample_test_video['label']
    # (num_frames, num_channels, height, width)
    
    permuted_sample_test_videos = [clip.permute(1, 0, 2, 3) for clip in video]

    inputs = {
        "pixel_values": torch.stack(permuted_sample_test_videos),
        # "labels": torch.tensor(np.repeat(label, CLIPS_FROM_SINGLE_VIDEO, axis=0))  # this can be skipped if you don't have labels available.
    }
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model = model.to(device)

    # forward pass
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    return logits

In [None]:
logits = run_inference(trained_model, sample_test_video)

In [None]:
display(logits)

We can now check if the model got the prediction right. 

In [None]:
for sample_test_clip in sample_test_video["video"]:
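    # display() is needed here: inside a loop, the returned Image is not rendered automatically
    display(display_gif(sample_test_clip))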

In [None]:
# average the logits over the clips of the video, then apply sigmoid + threshold
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(logits.mean(dim=0).cpu())
predictions = np.zeros(probs.shape)
predictions[np.where(probs >= 0.5)] = 1
print("Predicted class:", [label for i, label in enumerate(model.config.id2label.values()) if predictions[i]])

And it looks like it got it right!

You can also use this model to bring in your own videos. Check out [this Space](https://huggingface.co/spaces/omermazig/videomae-finetuned-nba-5-class) to know more. The Space will also show you how to run inference for a single video file.

<br><div align=center>
    <img src="https://i.ibb.co/7nW4Rkn/sample-results.gif" width=700/>
</div>

## Next steps

Now that you've learned to train a well-performing video classification model on a custom dataset, here is some homework for you:

* Increase the dataset size: include more classes and more samples per class. 
* Try out different hyperparameters to study how the model converges.
* Analyze the classes for which the model fails to perform well. 
* Try out a different video encoder.

Don't forget to share your models with the community =)

In [None]:
try:
    # For running on colab
    from google.colab import runtime
    runtime.unassign()
except ModuleNotFoundError:
    # I guess we're not on colab...
    pass