In [68]:
from itertools import islice
from typing import Callable, Literal, Set, Tuple
from urllib.parse import urlsplit

import datasets
import torch
import torch.nn as nn
import torchvision.transforms.v2 as tv_transforms
import torchvision.transforms.v2.functional as F

In [69]:
# Auto-reload module to access .py files easily
%load_ext autoreload
%autoreload 2

import os
import sys

# Make ../src (resolved against the current working directory) importable so the
# local `transforms` module can be imported; the path is appended only once.
src_path = os.path.abspath(os.path.join("..", "src"))
if src_path not in sys.path:
    sys.path.append(src_path)

import transforms as custom_transforms

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [71]:
# URL path prefix under which XD-Violence video files are hosted on the HF Hub.
XD_VIOLENCE_VIDEO_URL_PATH = "/datasets/jherng/xd-violence/resolve/main/data/video"
# Label/annotation columns not needed for preprocessing; dropped from both splits.
XD_VIOLENCE_UNUSED_COLUMNS = ["binary_target", "multilabel_targets", "frame_annotations"]


def xd_violence_relative_dir(full_filepath: str) -> str:
    """
    Return the directory of an XD-Violence video URL relative to the dataset's video root.

    E.g. ".../resolve/main/data/video/1-1004/A.Beautiful.Mind...mp4" -> "1-1004".
    A file directly under the video root yields "".
    """
    url_path = urlsplit(full_filepath).path
    # Keep what follows the video root, drop the leading "/" and the trailing filename.
    relative_path = url_path.split(XD_VIOLENCE_VIDEO_URL_PATH)[-1].lstrip("/")
    return "/".join(relative_path.split("/")[:-1])


def init_hf_dataset(
    hf_dataset_name: str, progress: Set[str]
) -> "Tuple[datasets.IterableDataset, Callable[[str], str]]":
    """
    Initialize a HuggingFace dataset (train and test splits concatenated) and filter
    out videos that have already been processed.

    Note: currently only streaming HuggingFace datasets are supported, not
    non-streaming ones.

    Args:
        hf_dataset_name: HuggingFace Hub dataset id. Only "jherng/xd-violence" is
            supported.
        progress: keys of videos already processed, formatted as
            "<relative_dir>/<id>", e.g.
            "1-1004/A.Beautiful.Mind.2001__#00-01-45_00-02-50_label_A".
            NOTE(review): streaming `filter` is presumably evaluated lazily, so the
            set is read during iteration rather than at call time.

    Returns:
        A tuple ``(dataset, extract_relative_dir)``. ``dataset`` is the filtered
        streaming dataset. ``extract_relative_dir`` maps a video URL to its directory
        relative to the dataset's video root.

    Raises:
        ValueError: if ``hf_dataset_name`` is not supported.
    """
    if hf_dataset_name != "jherng/xd-violence":
        raise ValueError(
            f"Dataset {hf_dataset_name} not supported. Currently only supports ['jherng/xd-violence']."
        )

    # Load both splits identically and drop columns unused for preprocessing.
    split_datasets = [
        datasets.load_dataset(
            hf_dataset_name, name="video", split=split, streaming=True
        ).remove_columns(XD_VIOLENCE_UNUSED_COLUMNS)
        for split in ("train", "test")
    ]
    combined_ds = datasets.concatenate_datasets(split_datasets)

    # Filter out videos that have already been processed.
    # Assumes there is always a sub-directory just above the file, e.g. 1-1004 from
    # https://huggingface.co/datasets/.../1-1004/A.Beautiful.Mind.2001__%2300-01-45_00-02-50_label_A.mp4
    combined_ds = combined_ds.filter(
        lambda x: "/".join([xd_violence_relative_dir(x["path"]), x["id"]])
        not in progress
    )

    return combined_ds, xd_violence_relative_dir

# Shared preprocessing configuration. Entries set to None are to be supplied by
# upstream (build_preprocessing_pipeline writes them); the rest are fixed defaults.
preprocessing_cfg = {
    # --- supplied by upstream ---
    "io_backend": None,
    "id_key": None,
    "path_key": None,
    "num_clips": None,
    "crop_type": None,
    # --- temporal clip sampling ---
    "clip_len": 32,
    "sampling_rate": 2,
    # --- spatial resize / crop sizes ---
    "resize_size": 256,
    "crop_size": 224,
    # --- per-channel normalization (values match the common ImageNet statistics) ---
    "mean": (0.485, 0.456, 0.406),
    "std": (0.229, 0.224, 0.225),
}

def build_preprocessing_pipeline(
    io_backend: Literal["http", "local"],
    id_key: str = "id",
    path_key: str = "path",
    num_clips: int = -1,
    crop_type: Literal["10-crop", "5-crop", "center"] = "5-crop",
) -> nn.Module:
    """
    Build the video preprocessing pipeline as a torchvision v2 ``Compose``.

    Takes in a whole video and returns a tensor of shape
    (num_clips, num_crops, num_channels, clip_len, crop_h, crop_w)
    = (num_clips, num_crops, 3, 32, 224, 224) with the default config.

    Side effect: the upstream-supplied entries of the module-level
    ``preprocessing_cfg`` (io_backend, id_key, path_key, num_clips, crop_type) are
    overwritten with this call's arguments. The pipeline is then built from that dict.

    Args:
        io_backend: backend forwarded to ``VideoReaderInit`` ("http" or "local").
        id_key: name of the id field in the input example (forwarded to
            ``AdaptDataFormat``).
        path_key: name of the path/URL field in the input example (forwarded to
            ``AdaptDataFormat``).
        num_clips: forwarded to ``TemporalClipSample``; -1 presumably means
            "all clips" (TODO confirm in the transform).
        crop_type: spatial cropping strategy: "5-crop", "10-crop" or "center".

    Returns:
        The composed transform.

    Raises:
        KeyError: if ``crop_type`` is not supported. It is raised before
            ``preprocessing_cfg`` is modified.
    """
    supported_crop_types = ("5-crop", "10-crop", "center")
    if crop_type not in supported_crop_types:
        # Validate up front so a bad argument cannot leave the shared config
        # half-updated (previously the KeyError surfaced only after the writes).
        raise KeyError(
            f"Unsupported crop_type {crop_type!r}; expected one of {list(supported_crop_types)}."
        )

    # Record this call's arguments in the shared config; this side effect is kept
    # because upstream code presumably reads preprocessing_cfg afterwards.
    preprocessing_cfg["io_backend"] = io_backend
    preprocessing_cfg["id_key"] = id_key
    preprocessing_cfg["path_key"] = path_key
    preprocessing_cfg["num_clips"] = num_clips
    preprocessing_cfg["crop_type"] = crop_type

    crop_type_config = {
        "5-crop": custom_transforms.FiveCrop,
        "10-crop": custom_transforms.TenCrop,
        "center": custom_transforms.CenterCrop,
    }

    pipeline = [
        custom_transforms.AdaptDataFormat(
            id_key=preprocessing_cfg["id_key"],
            path_key=preprocessing_cfg["path_key"],
        ),
        custom_transforms.VideoReaderInit(io_backend=preprocessing_cfg["io_backend"]),
        custom_transforms.TemporalClipSample(
            clip_len=preprocessing_cfg["clip_len"],
            sampling_rate=preprocessing_cfg["sampling_rate"],
            num_clips=preprocessing_cfg["num_clips"],
        ),
        custom_transforms.VideoDecode(),
        custom_transforms.Resize(size=preprocessing_cfg["resize_size"]),
        crop_type_config[preprocessing_cfg["crop_type"]](
            size=preprocessing_cfg["crop_size"]
        ),
        custom_transforms.ToDType(dtype=torch.float32, scale=True),
        custom_transforms.Normalize(
            mean=preprocessing_cfg["mean"], std=preprocessing_cfg["std"]
        ),
        # lead_dims=2 presumably skips the (num_clips, num_crops) leading axes — TODO confirm.
        custom_transforms.ConvertTCHWToCTHW(lead_dims=2),
        custom_transforms.PackInputs(preserved_meta=["id", "filename"]),
    ]

    return tv_transforms.Compose(pipeline)

In [74]:
# Take only the dataset. The relative-dir helper (second element) is not needed here.
# Indexing avoids rebinding `_`: a user-assigned `_` stops IPython from updating
# its `_`/`__`/`___` output-history variables for the rest of the session.
hf_dataset = init_hf_dataset("jherng/xd-violence", progress=set())[0]
preprocessing = build_preprocessing_pipeline(io_backend="http", num_clips=-1)

hf_dataset, preprocessing

(<datasets.iterable_dataset.IterableDataset at 0x26d497ae230>,
 Compose(
       AdaptDataFormat(id_key=id, path_key=path)
       VideoReaderInit(io_backend=http)
       TemporalClipSample(clip_len=32, num_clips=-1, sampling_rate=2)
       VideoDecode()
       Resize(size=256)
       FiveCrop(size=224)
       ToDType(dtype=torch.float32, scale=True)
       Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
       ConvertTCHWToCTHW(lead_dims=2)
       PackInputs(preserved_meta=['id', 'filename'])
 ))

In [None]:
# Smoke-test the preprocessing pipeline on the first few streamed videos.
NUM_PREVIEW_VIDEOS = 3

for i, raw_example in enumerate(islice(hf_dataset, NUM_PREVIEW_VIDEOS)):
    print(i)
    print(raw_example)

    # Keep the raw and processed examples under distinct names for later inspection.
    processed_example = preprocessing(raw_example)

    print(processed_example["meta"])
    print(processed_example["inputs"].size())

0
{'id': 'A.Beautiful.Mind.2001__#00-01-45_00-02-50_label_A', 'path': 'https://huggingface.co/datasets/jherng/xd-violence/resolve/main/data/video/1-1004/A.Beautiful.Mind.2001__%2300-01-45_00-02-50_label_A.mp4'}
