In [None]:
!pip install pytorchvideo transformers evaluate accelerate decord -U

In [None]:
# Use a pipeline as a high-level helper
from transformers import pipeline

# Identifier of the fine-tuned checkpoint that the notebook loads through `pipeline(...)`
# (it has the "<user>/<repo>" shape of a Hugging Face Hub model id).
model_ckpt = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multilabel-3"

In [None]:
import pathlib

# Name of the dataset directory, looked up under the Drive root (Colab) or the local
# Google Drive folder (PC).
dataset_name = "dataset"

# Absolute location where Google Drive gets mounted on Colab. Every Colab path below is
# derived from it, so the result does not depend on the current working directory.
colab_drive_mount_point = "/content/drive"

try:
    # For running on colab
    from google.colab import drive
    drive.mount(colab_drive_mount_point)
    root_path = pathlib.Path(colab_drive_mount_point, "MyDrive")
    dataset_root_path = root_path.joinpath(dataset_name)
except ModuleNotFoundError:
    # For running on PC
    # dataset_root_path = pathlib.Path('UCF101_subset')
    root_path = pathlib.Path(".")
    # NOTE(review): machine-specific absolute path; adjust when running elsewhere.
    dataset_root_path = pathlib.Path(r"C:\\Users\User\Google Drive").joinpath(dataset_name)
    is_colab = False
else:
    is_colab = True

In [None]:
# Build the video-classification pipeline once, then keep direct handles on the two
# parts the rest of the notebook uses: the model and its image processor.
pipe = pipeline("video-classification", model=model_ckpt)
trained_model, image_processor = pipe.model, pipe.image_processor

In [None]:
# Alternative to the pipeline cell above: load the image processor and the model directly.
#
# from transformers import AutoImageProcessor, VideoMAEForVideoClassification
#
# image_processor = AutoImageProcessor.from_pretrained(model_ckpt)
# trained_model = VideoMAEForVideoClassification.from_pretrained(model_ckpt)

In [None]:
video_extension = "avi"

# Collect every video under <split>/<class>/, visiting the splits in train, val, test order.
all_video_file_paths = [
    video_path
    for split_name in ("train", "val", "test")
    for video_path in dataset_root_path.glob(f"{split_name}/*/*.{video_extension}")
]
all_video_file_paths[:5]

In [None]:
# Original (single-label) classes: one per parent directory of the discovered videos.
class_labels = sorted(set(path.parent.stem for path in all_video_file_paths))
label2id = dict(zip(class_labels, range(len(class_labels))))
id2label = dict(enumerate(class_labels))

# "Mini" labels: the label set the fine-tuned model itself predicts.
id2multi_label = trained_model.config.id2label
multi_label2id = trained_model.config.label2id
class_mini_labels = list(id2multi_label.values())

# Map each original class id to a multi-hot vector over the mini labels: position k is 1
# when the k-th mini label occurs as a substring of the original class name.
original_id2multi_id = {
    id_: [int(mini_label in label) for mini_label in class_mini_labels]
    for id_, label in id2label.items()
}

print(f"Unique classes: {list(label2id.keys())}.")
print(f"Unique mini classes: {list(multi_label2id.keys())}.")

In [None]:
from typing import Callable, List, Dict
import torch
import pytorchvideo.data

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)


class ApplyTransformToListUnderKey:
    """
    Transform wrapper for dictionary samples that hold a list of items under one key.

    The wrapped transform is applied to every element of that list, and each element
    is replaced in place by its transformed value.

    Args:
        key (str): dictionary key whose list of values is transformed
        transform (callable): function applied to each element of the list
    """

    def __init__(self, key: str, transform: Callable):
        self._key = key
        self._transform = transform

    def __call__(self, x: Dict[str, List[torch.Tensor]]) -> Dict[str, List[torch.Tensor]]:
        items = x[self._key]
        # Overwrite each slot of the existing list, so both the list and the dict keep their identity.
        for index, item in enumerate(items):
            items[index] = self._transform(item)
        return x

In [None]:
from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
import torch
from typing import Type, Optional, Callable, Dict, Any
from pytorchvideo.data import ClipSampler, LabeledVideoDataset
from torch.utils.data import Dataset

class MultiLabelLabeledDataset(LabeledVideoDataset):
    """
    Labeled video dataset that yields a multi-label target instead of a class id.

    Videos are discovered with ``LabeledVideoPaths.from_path(data_path)``. Each example
    produced by the parent dataset has its ``"label"`` replaced by the entry stored for
    that label in the notebook-level ``original_id2multi_id`` mapping, which must be
    defined before an instance is created.

    Args:
        data_path: path handed to ``LabeledVideoPaths.from_path``.
        clip_sampler: sampler deciding which clips are taken from each video.
        video_sampler: sampler class used to order the videos.
        transform: optional callable applied to each example dictionary.
        video_path_prefix: value assigned to ``path_prefix`` of the discovered paths.
        decode_audio: forwarded to the parent dataset.
        decoder: name of the decoding backend, forwarded to the parent dataset.
    """

    def __init__(self, data_path: str, clip_sampler: ClipSampler,
                 video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
                 transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, video_path_prefix: str = "",
                 decode_audio: bool = True, decoder: str = "pyav"):
        labeled_video_paths = LabeledVideoPaths.from_path(data_path)
        labeled_video_paths.path_prefix = video_path_prefix
        # Forward the optional arguments by keyword: a positional call would silently bind
        # `decoder` to the wrong parameter if the parent signature has an extra parameter
        # between `decode_audio` and `decoder` (NOTE(review): some pytorchvideo versions
        # appear to add a `decode_video` flag there - confirm against the installed version).
        super().__init__(
            labeled_video_paths,
            clip_sampler,
            video_sampler=video_sampler,
            transform=transform,
            decode_audio=decode_audio,
            decoder=decoder,
        )
        # Mapping from original class id to its multi-label target, taken from notebook scope.
        self.original_id2multi_id = original_id2multi_id

    def __next__(self):
        """Return the parent's next example with its label swapped for the multi-label target."""
        example = super().__next__()
        example["label"] = self.original_id2multi_id[example["label"]]
        return example

In [None]:
import os

# Normalisation statistics and the target frame size both come from the image processor.
mean, std = image_processor.image_mean, image_processor.image_std
height, width = (
    (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
resize_to = (height, width)

# Clip length in seconds: the number of frames the model consumes, spaced `sample_rate`
# frames apart, at an assumed frame rate of `fps`.
num_frames_to_sample = trained_model.config.num_frames
sample_rate, fps = 4, 30
clip_duration = num_frames_to_sample * sample_rate / fps

# Transformations for the validation and test datasets; they run on every clip in the
# list stored under the "video" key.
# NOTE(review): the scale and crop steps are random, so evaluation results can differ
# between runs - confirm this is intended.
inference_transform = Compose(
    [
        ApplyTransformToListUnderKey(
            key="video",
            transform=Compose(
                [
                    # Keep `num_frames_to_sample` frames per clip.
                    UniformTemporalSubsample(num_frames_to_sample),
                    # Rescale by 1/255 before normalising (presumably uint8-range pixels).
                    Lambda(lambda x: x / 255.0),
                    Normalize(mean, std),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(resize_to),
                ]
            ),
        ),
    ]
)

# Number of clips drawn from each video at evaluation time: passed to the "random_multi"
# clip sampler and used by the metrics to group per-clip predictions back into videos.
CLIPS_FROM_SINGLE_VIDEO = 5


def build_evaluate_dataset(dataset_type: str):
    """
    Create the evaluation dataset for one split of the dataset directory.

    Args:
        dataset_type: name of the split sub-directory under `dataset_root_path`
            (the notebook uses "val" and "test").

    Returns:
        A `MultiLabelLabeledDataset` that samples `CLIPS_FROM_SINGLE_VIDEO` clips of
        `clip_duration` seconds per video, without audio, and applies `inference_transform`.
    """
    split_directory = os.path.join(dataset_root_path, dataset_type)
    multi_clip_sampler = pytorchvideo.data.make_clip_sampler(
        "random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO
    )
    return MultiLabelLabeledDataset(
        data_path=split_directory,
        clip_sampler=multi_clip_sampler,
        decode_audio=False,
        transform=inference_transform,
    )

# Build the validation and test datasets (in that order).
val_dataset, test_dataset = (build_evaluate_dataset(split) for split in ("val", "test"))

In [None]:
import torch


def collate_fn(examples):
    """
    The collation function used by `Trainer` to turn dataset examples into a batch.

    Two example layouts are supported, detected from the first example:

    * ``example["video"]`` is a tensor (training: one clip per example). The clips are
      stacked and there is one label row per example.
    * ``example["video"]`` is a list of tensors (evaluation: several clips per video).
      All clips of all examples are concatenated along the batch dimension, and each
      example's label is repeated once per clip so labels line up with `pixel_values`.

    Every clip has its first two dimensions swapped; the intent is
    (num_frames, num_channels, height, width), which assumes the incoming clips are
    (num_channels, num_frames, height, width).

    Args:
        examples: non-empty list of dicts with a "video" and a "label" entry.

    Returns:
        Dict with "pixel_values" (stacked clips) and "labels" (float tensor).

    Raises:
        ValueError: if "video" is neither a tensor nor a list.
    """
    first_video = examples[0]["video"]
    if isinstance(first_video, torch.Tensor):
        # This is for training, where each training entry is a single video
        pixel_values = torch.stack([example["video"].permute(1, 0, 2, 3) for example in examples])
        clips_per_example = None
    elif isinstance(first_video, list):
        # This is for evaluation, where each evaluation entry is multiple clips from a single video
        pixel_values = torch.cat(
            [
                torch.stack([clip.permute(1, 0, 2, 3) for clip in example["video"]])
                for example in examples
            ]
        )
        clips_per_example = torch.tensor([len(example["video"]) for example in examples])
    else:
        raise ValueError("Unrecognized input structure!")

    # Pure torch on purpose: this cell then needs no import from a later cell.
    labels = torch.tensor([example["label"] for example in examples], dtype=torch.int64)
    if clips_per_example is not None:
        # TODO - Maybe find a way to not duplicate the labels. They are repeated only so their
        # first dimension matches pixel_values (the loss is computed during evaluation).
        # Repeating by each example's own clip count keeps labels aligned with the clips.
        labels = torch.repeat_interleave(labels, clips_per_example, dim=0)
    return {"pixel_values": pixel_values, "labels": labels.to(torch.float)}

In [None]:
from transformers import TrainingArguments, Trainer

# Evaluation batch size: larger on Colab than on a local machine.
if is_colab:
    batch_size = 4
else:
    batch_size = 1

args = TrainingArguments(
    # Keep the Trainer from dropping dataset fields that the custom collate function reads.
    remove_unused_columns=False,
    per_device_eval_batch_size=batch_size,
    output_dir="kuku",
)

In [None]:
import statistics
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, classification_report
from transformers import EvalPrediction
import torch


# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    """
    Compute video-level multi-label metrics from clip-level logits.

    Predictions and labels arrive with one row per clip, where consecutive groups of
    `CLIPS_FROM_SINGLE_VIDEO` rows belong to the same video. The logits of each group are
    averaged into one row per video, passed through a sigmoid, and thresholded to get the
    predicted label set. The first label row of each group is used as the video's ground
    truth (the rows of a group are duplicates).

    A per-label classification report is printed as a side effect.

    Args:
        predictions: array of logits, presumably (num_clips, num_labels).
        labels: array of ground-truth multi-hot rows with the same first dimension.
        threshold: probability at or above which a label counts as predicted.

    Returns:
        Dict with micro-averaged "f1", micro-averaged "roc_auc" and subset "accuracy".
    """
    n = labels.shape[0] // CLIPS_FROM_SINGLE_VIDEO
    # Average the clip logits of each video into a single row per video.
    video_logits = np.array(
        [np.average(batch, axis=0) for batch in np.array_split(predictions, n)]
    )
    # Apply a sigmoid to turn the averaged logits into per-label probabilities.
    probs = torch.sigmoid(torch.from_numpy(video_logits)).numpy()
    # Next, use the threshold to turn them into 0/1 predictions.
    y_pred = (probs >= threshold).astype(np.float64)
    # Finally, compute the metrics against one ground-truth row per video.
    y_true = np.array([batch[0] for batch in np.array_split(labels, n)])
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    # ROC AUC ranks scores, so it must be fed the probabilities: thresholded 0/1
    # predictions would collapse the ROC curve to a single operating point.
    roc_auc = roc_auc_score(y_true, probs, average='micro')
    accuracy = accuracy_score(y_true, y_pred)
    # Return as dictionary.
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    report = classification_report(y_true, y_pred, target_names=class_mini_labels)
    print(report)
    return metrics

# the compute_metrics function takes a Named Tuple as input:
# predictions, which are the logits of the model as Numpy arrays,
# and label_ids, which are the ground-truth labels as Numpy arrays.
def compute_metrics(eval_pred):
    """Adapter for `Trainer`: unpack the evaluation output and return the multi-label metrics."""
    logits, label_ids = eval_pred.predictions, eval_pred.label_ids
    return multi_label_metrics(predictions=logits, labels=label_ids)

In [None]:
def build_trainer(dataset):
    """
    Create a `Trainer` configured for evaluating `trained_model` on `dataset`.

    Args:
        dataset: evaluation dataset handed to the trainer as `eval_dataset`.

    Returns:
        A `Trainer` that uses the notebook-level `args`, `collate_fn` and `compute_metrics`.
    """
    import inspect

    trainer_kwargs = {
        "eval_dataset": dataset,
        "compute_metrics": compute_metrics,
        "data_collator": collate_fn,
    }
    # Newer transformers releases deprecate the `tokenizer` argument of `Trainer` in favour
    # of `processing_class`. The notebook installs unpinned versions, so pick whichever
    # keyword the installed `Trainer` actually accepts.
    init_parameters = inspect.signature(Trainer.__init__).parameters
    processor_keyword = "processing_class" if "processing_class" in init_parameters else "tokenizer"
    trainer_kwargs[processor_keyword] = image_processor

    return Trainer(trained_model, args, **trainer_kwargs)

In [None]:
# Evaluate the model on each held-out split and show its metrics.
for inference_dataset_type in ('val', 'test'):
    # A fresh dataset and trainer are built for every split.
    inference_dataset = build_evaluate_dataset(inference_dataset_type)
    inference_trainer = build_trainer(inference_dataset)
    # Header first, so the report printed during evaluation is attributed to the right split.
    print(f"---------{inference_dataset_type}---------")
    results = inference_trainer.evaluate(eval_dataset=inference_dataset)
    display(results)

In [None]:
# When running on Colab, release the runtime once the notebook has finished.
try:
    from google.colab import runtime
    # Give the printed output time to sync with the browser before disconnecting.
    import time
    time.sleep(10)
    runtime.unassign()
except ModuleNotFoundError:
    # Not running on Colab - nothing to release.
    pass