In [None]:
# Use a pipeline as a high-level helper
import os
from pathlib import Path

from transformers import pipeline

# Hugging Face Hub id of the fine-tuned multi-label VideoMAE checkpoint to evaluate.
model_ckpt = "omermazig/videomae-base-finetuned-kinetics-finetuned-nba-data-2-batch-5-epochs-399-train-vids-multilabel"

# Root folder of the dataset; later cells glob `<split>/<class>/<video>` under it.
# Set the VIDEO_DATASET_ROOT environment variable to run the notebook on another
# machine; when it is unset the original local path is used unchanged.
dataset_root_path = Path(
    os.environ.get(
        "VIDEO_DATASET_ROOT",
        r"C:\Users\User\Google Drive\dataset_2_classes_small",
    )
)

In [None]:
# Build the video-classification pipeline for the checkpoint, then keep handles to
# the model and image processor it holds so later cells can use them directly.
pipe = pipeline(task="video-classification", model=model_ckpt)
trained_model, image_processor = pipe.model, pipe.image_processor

In [None]:
# Alternative to the pipeline cell above: load the image processor and the model directly.
# from transformers import AutoImageProcessor, VideoMAEForVideoClassification
#
# image_processor = AutoImageProcessor.from_pretrained(model_ckpt)
# trained_model = VideoMAEForVideoClassification.from_pretrained(model_ckpt)

In [None]:
# Only files with this extension are collected.
video_extension = "avi"

# Gather every video under train/, val/ and test/ (in that order); each video is
# expected two levels below the dataset root: <split>/<class directory>/<file>.
all_video_file_paths = [
    video_path
    for pattern in (
        f"train/*/*.{video_extension}",
        f"val/*/*.{video_extension}",
        f"test/*/*.{video_extension}",
    )
    for video_path in dataset_root_path.glob(pattern)
]
# Show the first few matches as a quick sanity check.
all_video_file_paths[:5]

In [None]:
# A video's class label is the name of the directory that directly contains it.
# `.name` is used rather than `.stem`: `.stem` would cut a directory name at its
# last dot (e.g. "v1.5_dunk" -> "v1"), silently merging or renaming classes.
class_labels = sorted({path.parent.name for path in all_video_file_paths})
# Integer ids follow the alphabetical order of the class names.
# NOTE(review): MultiLabelLabeledDataset indexes `id2label` with the integer label
# produced by pytorchvideo for each split - confirm that every split contains the
# same set of class directories so both numberings agree.
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}

# Each class name is an underscore-joined combination of "mini" labels; the
# multi-label targets are defined over the sorted set of those components.
class_mini_labels = sorted({mini_label for label in class_labels for mini_label in label.split('_')})
multi_label2id = {label: i for i, label in enumerate(class_mini_labels)}
id2multi_label = {i: label for label, i in multi_label2id.items()}

print(f"Unique classes: {list(label2id.keys())}.")
print(f"Unique mini classes: {list(multi_label2id.keys())}.")

In [None]:
import pytorchvideo.data

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)

In [None]:
from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
import torch
from typing import Type, Optional, Callable, Dict, Any
from pytorchvideo.data import ClipSampler, LabeledVideoDataset
from torch.utils.data import Dataset

class MultiLabelLabeledDataset(LabeledVideoDataset):
    """`LabeledVideoDataset` whose samples carry a multi-hot label list.

    Each example produced by the parent class has its integer ``"label"``
    replaced by a list of 0/1 flags, one per entry of the module-level
    ``class_mini_labels``. The integer label is translated to its class name
    through the module-level ``id2label`` mapping, and the name is split on
    ``"_"`` to find which components are present.

    NOTE(review): this assumes the integer labels assigned by
    ``LabeledVideoPaths.from_path`` use the same numbering as ``id2label`` -
    confirm for every split that is loaded.
    """

    def __init__(
        self,
        data_path: str,
        clip_sampler: ClipSampler,
        video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
        transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        video_path_prefix: str = "",
        decode_audio: bool = True,
        decoder: str = "pyav",
    ):
        """Discover the labeled videos under ``data_path`` and set up the parent dataset.

        Args:
            data_path: Location handed to ``LabeledVideoPaths.from_path``.
            clip_sampler: Forwarded to ``LabeledVideoDataset``.
            video_sampler: Forwarded to ``LabeledVideoDataset``.
            transform: Optional callable applied by the parent to each example dict.
            video_path_prefix: Stored as ``path_prefix`` on the discovered paths.
            decode_audio: Forwarded to ``LabeledVideoDataset``.
            decoder: Forwarded to ``LabeledVideoDataset``.
        """
        video_paths = LabeledVideoPaths.from_path(data_path)
        video_paths.path_prefix = video_path_prefix
        super().__init__(
            video_paths,
            clip_sampler,
            video_sampler,
            transform,
            decode_audio,
            decoder,
        )
        # Ordered component names; position i of every label vector refers to entry i.
        self.label_components = class_mini_labels

    def convert_label_to_binary_vector(self, label):
        """Return a 0/1 list marking which label components the class ``label`` contains."""
        present = id2label[label].split("_")
        return [int(name in present) for name in self.label_components]

    def __next__(self):
        """Fetch the parent's next example and swap its integer label for the multi-hot list."""
        sample = super().__next__()
        sample["label"] = self.convert_label_to_binary_vector(sample["label"])
        return sample

In [None]:
import os

# Normalisation statistics come from the checkpoint's image processor.
mean, std = image_processor.image_mean, image_processor.image_std

# The processor's `size` dict either has explicit "height"/"width" entries or a
# single "shortest_edge" value, which is then used for both dimensions.
if "shortest_edge" not in image_processor.size:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

# Number of frames per clip, as declared in the model config.
num_frames_to_sample = trained_model.config.num_frames

# Validation and Test datasets' transformations.
# Only the "video" entry of each example is transformed: frames are subsampled,
# divided by 255 (presumably raw values are in 0-255 - TODO confirm), normalised
# with the processor statistics, and finally resized to `resize_to`.
inference_transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            UniformTemporalSubsample(num_frames_to_sample),
            Lambda(lambda frames: frames / 255.0),
            Normalize(mean, std),
            Resize(resize_to),
        ]),
    ),
])

# Clip duration handed to the "uniform" clip sampler in build_evaluate_dataset.
MAX_VIDEO_DURATION = 6

def build_evaluate_dataset(dataset_type: str):
    """Create the evaluation dataset for one split of the dataset.

    Args:
        dataset_type: Sub-directory of ``dataset_root_path`` to load
            (the notebook uses ``'val'`` and ``'test'``).

    Returns:
        A ``MultiLabelLabeledDataset`` that uses a "uniform" clip sampler with
        ``MAX_VIDEO_DURATION``, applies ``inference_transform`` and skips audio decoding.
    """
    split_dir = os.path.join(dataset_root_path, dataset_type)
    uniform_sampler = pytorchvideo.data.make_clip_sampler("uniform", MAX_VIDEO_DURATION)
    return MultiLabelLabeledDataset(
        data_path=split_dir,
        clip_sampler=uniform_sampler,
        transform=inference_transform,
        decode_audio=False,
    )

In [None]:
from transformers import TrainingArguments, Trainer

# Number of examples per device in each evaluation batch.
batch_size = 4

# Evaluation-only arguments. `remove_unused_columns=False` is set explicitly;
# presumably so the Trainer leaves the "video"/"label" entries untouched for
# collate_fn - TODO confirm.
args = TrainingArguments(
    output_dir="kuku",
    remove_unused_columns=False,
    per_device_eval_batch_size=batch_size,
)

In [None]:
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, classification_report
from transformers import EvalPrediction
import torch


# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    """Compute multi-label evaluation metrics from raw logits.

    Besides returning the metrics, a per-label ``classification_report`` (using
    ``class_mini_labels`` as target names) is printed.

    Args:
        predictions: Logits of shape (num_examples, num_labels).
        labels: Ground-truth 0/1 indicator array with the same shape.
        threshold: Probability at or above which a label counts as predicted.

    Returns:
        Dict with ``'f1'`` (micro-averaged), ``'roc_auc'`` (micro-averaged) and
        ``'accuracy'`` (as computed by ``accuracy_score`` on the indicator arrays).
    """
    # Sigmoid turns every logit into an independent per-label probability.
    probs = torch.sigmoid(torch.as_tensor(predictions, dtype=torch.float32)).numpy()
    # Thresholding yields the hard 0/1 predictions used by F1, accuracy and the report.
    y_pred = (probs >= threshold).astype(np.float64)
    y_true = labels

    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    # ROC AUC is a ranking metric: it must be given the continuous scores. Feeding it
    # the thresholded predictions collapses the curve to a single operating point and
    # makes the value depend on `threshold`.
    roc_auc = roc_auc_score(y_true, probs, average='micro')
    accuracy = accuracy_score(y_true, y_pred)

    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    report = classification_report(y_true, y_pred, target_names=class_mini_labels)
    print(report)
    return metrics

def compute_metrics(p: EvalPrediction):
    """Metric callback for `Trainer`.

    Takes a named tuple holding `predictions` (the model logits as Numpy arrays)
    and `label_ids` (the ground-truth labels as Numpy arrays). When
    `predictions` is a tuple, only its first element is used.

    Returns:
        The metrics dict produced by `multi_label_metrics`.
    """
    logits = p.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    return multi_label_metrics(predictions=logits, labels=p.label_ids)

In [None]:
import torch


def collate_fn(examples):
    """Collate dataset examples into a batch dict for `Trainer`.

    Args:
        examples: Iterable of dicts with a 4-D "video" tensor and a "label" list.

    Returns:
        Dict with "pixel_values" (the stacked videos, each with its first two
        axes swapped) and "labels" (a float tensor of the stacked label lists).
    """
    clips = []
    targets = []
    for sample in examples:
        # Swap the first two axes so each clip is (num_frames, num_channels, height, width).
        clips.append(sample["video"].permute(1, 0, 2, 3))
        targets.append(sample["label"])
    return {
        "pixel_values": torch.stack(clips),
        "labels": torch.tensor(targets, dtype=torch.float),
    }

In [None]:
def build_trainer(dataset):
    """Create an evaluation `Trainer` for `dataset`.

    The model, arguments, image processor, metric callback and collate function
    are the module-level `trained_model`, `args`, `image_processor`,
    `compute_metrics` and `collate_fn`.

    Args:
        dataset: Dataset registered as the trainer's `eval_dataset`.

    Returns:
        The configured `Trainer`.
    """
    return Trainer(
        model=trained_model,
        args=args,
        data_collator=collate_fn,
        eval_dataset=dataset,
        tokenizer=image_processor,
        compute_metrics=compute_metrics,
    )

In [None]:
# Evaluate the checkpoint on the validation split, then on the test split.
for inference_dataset_type in ("val", "test"):
    # Each split gets its own dataset and its own Trainer wrapping that dataset.
    inference_dataset = build_evaluate_dataset(inference_dataset_type)
    inference_trainer = build_trainer(inference_dataset)
    # Header line so the metrics of the two splits are easy to tell apart.
    print(f"---------{inference_dataset_type}---------")
    results = inference_trainer.evaluate(inference_dataset)
    # Rich display of the metrics dict returned by `evaluate`.
    display(results)