In [1]:
# Copyright (c) Meta, Inc. and its affiliates. All Rights Reserved.

# Initial Setup

* Clone this repo: `git clone https://github.com/facebookresearch/Ego4d.git`
* Download checkpoint/models binaries, place in this directory: [https://dl.fbaipublicfiles.com/ego4d/moments_cvpr_binaries.tgz](https://dl.fbaipublicfiles.com/ego4d/moments_cvpr_binaries.tgz)
* Download annotations via the [CLI](https://cli.ego4d-data.org): `ego4d --output_directory="~/ego4d_data" --datasets annotations video_540ss --metadata`
    * For the miniset (i.e. 1 machine), you can choose to download only the minisets via moments_mini_train_uids.csv/moments_mini_val_uids.csv files:
        `ego4d -y --output_directory ~/ego4d_data/ --datasets video_540ss --video_uid_file ~/path/to/moments_mini_val_uids.csv`
* Create the conda environment: `conda env create -n moments_cvpr --file conda-env.yaml`
    * Alternatively, using your own environment, install the latest pytorchvideo: `pip install git+https://github.com/facebookresearch/pytorchvideo.git`
* Fire up jupyter/vscode/etc here: `jupyter notebook MomentsWorkshop.ipynb`

### Questions:
* Docs: [docs.ego4d-data.org](https://docs.ego4d-data.org)
* Forum: [discuss.ego4d-data.org](https://discuss.ego4d-data.org)
* Email Me Directly: eugenebyrne@fb.com


# Setup Code

In [None]:
# Imports

import logging
import json
import os
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import pytorch_lightning as pl
import torch

from iopath.common.file_io import g_pathmgr
from pytorchvideo.data.clip_sampling import ClipInfo, ClipSampler

# Either use directly from PTV or use the example below
# from pytorchvideo.data.ego4d import Ego4dMomentsDataset

In [3]:
# Inputs - Update first section as relevant

# Repository path where the notebook (and downloaded binaries) are loaded.
# Edit this to match where the Ego4d repo was cloned on this machine.
notebook_path = os.path.expanduser("~/source/ego4d/notebooks/moments_cvpr/")
# Path to the ego4d data downloaded via the CLI (see Initial Setup), local or network
data_path = os.path.expanduser("~/ego4d_data/")
# Checkpoints default to a "checkpoints/" folder under data_path; point elsewhere if preferred
checkpoint_path = os.path.join(data_path, "checkpoints/")

# Fail fast on a misconfigured notebook_path: the notebook file itself must exist there
assert os.path.isfile(os.path.join(notebook_path, "MomentsWorkshop.ipynb")), f"Inputs paths improperly configured - not found: {notebook_path}"

@dataclass
class Ego4dMomentsTrainerParams:
    """
    Tunable inputs for the moments training example.

    All path defaults are derived from the module-level ``data_path`` /
    ``checkpoint_path`` values at class-definition time, so those must be set
    before this cell runs; re-run this cell after changing them.
    """

    # Moments annotation files, one per split
    annotation_path_train: str = os.path.join(data_path, "v1/annotations/moments_train.json")
    annotation_path_val: str = os.path.join(data_path, "v1/annotations/moments_val.json")
    # Primary Ego4D JSON
    metadata_path: str = os.path.join(data_path, "ego4d.json")
    # Assuming downsampled videos
    video_path: str = os.path.join(data_path, "v1/video_540ss/")
    # Compute layout. NOTE(review): 0 GPUs presumably means CPU-only - confirm against the trainer setup
    num_nodes: int = 1
    num_gpus_per_node: int = 0
    # NOTE(review): presumably the DataLoader worker count - confirm where it is consumed
    num_workers: int = 0
    batch_size: int = 8
    # Modalities to use
    use_video: bool = True
    use_audio: bool = False
    use_imu: bool = False
    # window_size: int
    # n_frames: int
    learning_rate: float = 0.0001
    # A fresh list per instance (default_factory) so instances never share a mutable default
    learning_rate_milestones: List[int] = field(default_factory=lambda: [5, 10, 15])
    checkpoint_dir: str = checkpoint_path
    every_n_train_steps: int = 117
    resume_from_checkpoint: Optional[str] = None
    # Testing Purposes Only! Set to -1
    # NOTE(review): -1 presumably removes the epoch limit for a real run - confirm against the trainer
    max_epochs: int = 1
    # Testing Purposes Only! Set to -1
    # NOTE(review): -1 presumably removes the step limit for a real run - confirm against the trainer
    max_steps: int = 100
    image_model_type: str = "resnet18"
    accumulate_grad_batches: int = 1
    # Relative file name; NOTE(review): presumably resolved against notebook_path - confirm at the use site
    label_mapping_file: Optional[str] = "moments_label_ids.json"
    # Recommended! Use the video_540ss downsampled videos
    downsampled: bool = True
    # A miniset (20 train, 10 val) provided for test training
    miniset: bool = True

# Instantiate the parameters with their defaults; override fields here if needed
inputs = Ego4dMomentsTrainerParams()

In [18]:
# Utils

# Notebook-wide logger, shared by the helper functions and dataset classes below
log: logging.Logger = logging.getLogger("Ego4dMoments")
# To see more verbose logging, uncomment below (i.e. print all INFO -> console)
# log.setLevel(logging.INFO)
# sh = logging.StreamHandler(sys.stdout)
# sh.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s %(message)s \t[%(filename)s.%(funcName)s:%(lineno)d]", datefmt="%y%m%d %H:%M:%S"))
# log.addHandler(sh)

def check_window_len(
    s_time: float, e_time: float, w_len: float, video_dur: float
) -> Tuple[float, float]:
    """
    Fit the time window ``[s_time, e_time]`` to a length of ``w_len`` seconds.

    The window is grown or shrunk symmetrically around its midpoint. If that pushes
    the start below zero, the window is slid forward so it begins at 0. When
    ``video_dur`` is truthy and the window ends past it, the window is slid back
    by the excess.

    Args:
        s_time: window start, in seconds.
        e_time: window end, in seconds.
        w_len: required window length, in seconds.
        video_dur: duration of the video/clip; a falsy value (``None`` / 0)
            disables the end-of-video adjustment.

    Returns:
        The adjusted ``(start, end)`` pair.

    Raises:
        AssertionError: if sliding the window back from the end of the video
            would move its start below zero.
    """
    current_len = e_time - s_time
    if abs(current_len - w_len) > 0.001:
        # Spread the difference evenly over both ends of the window
        half_delta = (w_len - current_len) / 2
        s_time = s_time - half_delta
        e_time = e_time + half_delta
        if s_time < 0:
            # Window fell off the front: shift it forward to start at zero
            e_time += -s_time
            s_time = 0

    if video_dur and e_time > video_dur:
        # Window runs off the back: shift it earlier by the overshoot
        overlap = e_time - video_dur
        assert s_time >= overlap, "Incompatible w_len / video_dur"
        s_time -= overlap
        e_time -= overlap
        log.info(
            f"check_window_len: video overlap ({overlap}) adjusted -> ({s_time:.2f}, {e_time:.2f}) video: {video_dur}"  # noqa
        )

    # Sanity check only: report (but still return) a window of the wrong length
    final_len = e_time - s_time
    if abs(final_len - w_len) > 0.01:
        log.error(
            f"check_window_len: invalid time interval: {s_time}, {e_time}",
            stack_info=True,
        )
    return s_time, e_time

class MomentsClipSampler(ClipSampler):
    """
    ClipSampler for Ego4d moments. Will return a fixed `window_sec` window
    around the given annotation, shifting where relevant to account for the end
    of the clip/video.

    clip_start/clip_end is added to the annotation dict to facilitate future lookups.
    """

    def __init__(self, window_sec: float = 0) -> None:
        """
        Args:
            window_sec: fixed window length in seconds. A value <= 0 leaves the
                annotation's own start/end untouched.
        """
        # NOTE(review): the base ClipSampler initializer is not invoked here;
        # this sampler only relies on `window_sec`.
        self.window_sec = window_sec

    def __call__(
        self,
        last_clip_end_time: float,
        video_duration: float,
        annotation: Dict[str, Any],
    ) -> ClipInfo:
        """
        Compute the clip to load for a single annotation.

        Args:
            last_clip_end_time: end time of the previously sampled clip, or None.
            video_duration: duration of the video in seconds (may be None).
            annotation: sample dict carrying `label_video_start_sec` and
                `label_video_end_sec`. It is mutated: `clip_start` and `clip_end`
                are written back into it.

        Returns:
            A ClipInfo for the (possibly windowed) interval, always flagged as the
            last clip of the video.

        Raises:
            AssertionError: if `last_clip_end_time` exceeds `video_duration`, or the
                annotation ends before it starts.
            Exception: if the annotation ends more than 0.1s past the video.
        """
        assert (
            last_clip_end_time is None or last_clip_end_time <= video_duration
        ), f"last_clip_end_time ({last_clip_end_time}) > video_duration ({video_duration})"
        start = annotation["label_video_start_sec"]
        end = annotation["label_video_end_sec"]
        if video_duration is not None and end > video_duration:
            log.error(f"Invalid video_duration/end_sec: {video_duration} / {end}")
            # If it's small, proceed anyway
            if end > video_duration + 0.1:
                # Not every annotation dict carries "video_name"; fall back to the
                # uid so the intended error is raised rather than a KeyError.
                video_id = annotation.get("video_name", annotation.get("video_uid"))
                raise Exception(
                    f"Invalid video_duration/end_sec: {video_duration} / {end} ({video_id})"  # noqa
                )
        assert end >= start, f"end < start: {end:.2f} / {start:.2f}"
        if self.window_sec > 0:
            start, end = check_window_len(start, end, self.window_sec, video_duration)
        annotation["clip_start"] = start
        annotation["clip_end"] = end
        return ClipInfo(start, end, 0, 0, True)

def get_label_id_map(label_id_map_path: str) -> Dict[str, int]:
    """
    Load the label-name -> integer-id mapping from a JSON file.

    Args:
        label_id_map_path: path or URI readable through `g_pathmgr`.

    Returns:
        The parsed JSON content, returned as-is (its structure is not validated).

    Raises:
        FileNotFoundError: if the file cannot be opened or parsed. The underlying
            error is chained as `__cause__` so the real failure stays visible.
    """
    try:
        with g_pathmgr.open(label_id_map_path, "r") as f:
            return json.load(f)
    except Exception as e:
        raise FileNotFoundError(
            f"{label_id_map_path} must be a valid label id json"
        ) from e

class Ego4dImuDataBase(ABC):
    """
    Abstract placeholder describing the interface for Ego4d IMU data access.

    Concrete subclasses must implement `has_imu` and `get_imu_sample`.
    """

    def __init__(self, basepath: str):
        # Stored for subclasses; this base class does not read it
        self.basepath = basepath

    @abstractmethod
    def has_imu(self, video_uid: str) -> bool:
        """Report whether IMU data exists for `video_uid` (subclass responsibility)."""

    @abstractmethod
    def get_imu_sample(
        self, video_uid: str, video_start: float, video_end: float
    ) -> Dict[str, Any]:
        """Return the IMU sample for `video_uid` between `video_start` and `video_end`."""

def get_video_uids(path: str) -> Set[str]:
    """
    Read a file of video uids, one per line, into a set.

    Surrounding whitespace (including the `\\r` left by CRLF line endings) is
    stripped from every line, and blank lines are ignored, so the uids compare
    equal to those in the annotation files.

    Args:
        path: path or URI readable through `g_pathmgr`.

    Returns:
        The set of non-empty, stripped lines.
    """
    with g_pathmgr.open(path, "r") as f:
        stripped_lines = (line.strip() for line in f.read().split("\n"))
        return {uid for uid in stripped_lines if uid}

In [5]:
# LabeledVideoDataset (Provided only for context)

from __future__ import annotations

import gc
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch.utils.data
from pytorchvideo.data.clip_sampling import ClipSampler
from pytorchvideo.data.video import VideoPathHandler

from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
from pytorchvideo.data.utils import MultiProcessSampler


# LabeledVideoDataset below logs through `logger`; alias it to the notebook-wide `log`
logger = log


class LabeledVideoDataset(torch.utils.data.IterableDataset):
    """
    LabeledVideoDataset handles the storage, loading, decoding and clip sampling for a
    video dataset. It assumes each video is stored as either an encoded video
    (e.g. mp4, avi) or a frame video (e.g. a folder of jpg, or png)
    """

    # Upper bound on load/sample attempts within a single __next__ call
    _MAX_CONSECUTIVE_FAILURES = 10

    def __init__(
        self,
        labeled_video_paths: List[Tuple[str, Optional[dict]]],
        clip_sampler: ClipSampler,
        video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
        transform: Optional[Callable[[dict], Any]] = None,
        decode_audio: bool = True,
        decode_video: bool = True,
        decoder: str = "pyav",
    ) -> None:
        """
        Args:
            labeled_video_paths (List[Tuple[str, Optional[dict]]]): List containing
                    video file paths and associated labels. If video paths are a folder
                    it's interpreted as a frame video, otherwise it must be an encoded
                    video.

            clip_sampler (ClipSampler): Defines how clips should be sampled from each
                video. See the clip sampling documentation for more information.

            video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
                video container. This defines the order videos are decoded and,
                if necessary, the distributed split.

            transform (Callable): This callable is evaluated on the clip output before
                the clip is returned. It can be used for user defined preprocessing and
                augmentations on the clips. The clip output format is described in __next__().

            decode_audio (bool): If True, decode audio from video.

            decode_video (bool): If True, decode video frames from a video container.
                Stored on the instance but currently not forwarded to the video loader.

            decoder (str): Defines what type of decoder used to decode a video. Not used for
                frame videos.
        """
        self._decode_audio = decode_audio
        self._decode_video = decode_video
        self._transform = transform
        self._clip_sampler = clip_sampler
        self._labeled_videos = labeled_video_paths
        self._decoder = decoder

        # If a RandomSampler is used we need to pass in a custom random generator that
        # ensures all PyTorch multiprocess workers have the same random seed.
        self._video_random_generator = None
        if video_sampler == torch.utils.data.RandomSampler:
            self._video_random_generator = torch.Generator()
            self._video_sampler = video_sampler(
                self._labeled_videos, generator=self._video_random_generator
            )
        else:
            self._video_sampler = video_sampler(self._labeled_videos)

        self._video_sampler_iter = None  # Initialized on first call to self.__next__()

        # Depending on the clip sampler type, we may want to sample multiple clips
        # from one video. In that case, we keep the store video, label and previous sampled
        # clip time in these variables.
        self._loaded_video_label = None
        self._loaded_clip = None
        self._last_clip_end_time = None
        self.video_path_handler = VideoPathHandler()

    @property
    def video_sampler(self):
        """
        Returns:
            The video sampler that defines video sample order. Note that you'll need to
            use this property to set the epoch for a torch.utils.data.DistributedSampler.
        """
        return self._video_sampler

    @property
    def num_videos(self):
        """
        Returns:
            Number of videos in dataset.
        """
        return len(self.video_sampler)

    def __next__(self) -> dict:
        """
        Retrieves the next clip based on the clip sampling strategy and video sampler.

        Up to `_MAX_CONSECUTIVE_FAILURES` attempts are made; an attempt is skipped when
        the video fails to load, the clip cannot be decoded, or the transform returns
        None.

        Returns:
            A dictionary with the following format (before `transform` is applied).

            .. code-block:: text

                {
                    'video': <video_tensor>,
                    'video_name': <video_name>,
                    'video_index': <video_index>,
                    'clip_index': <clip_index>,
                    'aug_index': <aug_index>,
                    **<info_dict of the sampled video>,
                    'audio': <audio_samples>,  # only when audio was decoded
                }

        Raises:
            StopIteration: propagated from the video sampler once it is exhausted.
            RuntimeError: if every attempt fails.
        """
        if not self._video_sampler_iter:
            # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
            self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler))

        video_id = None
        for i_try in range(self._MAX_CONSECUTIVE_FAILURES):
            # Reuse previously stored video if there are still clips to be sampled from
            # the last loaded video.
            if self._loaded_video_label:
                video, info_dict, video_index = self._loaded_video_label
            else:
                video_index = next(self._video_sampler_iter)
                try:
                    video_path, info_dict = self._labeled_videos[video_index]
                    video_id = info_dict.get("video_name")
                    # TODO: Repeatedly called?
                    # print(f"Video: {video_path}")
                    video = self.video_path_handler.video_from_path(
                        video_path,
                        decode_audio=self._decode_audio,
                        # decode_video=self._decode_video,
                        decoder=self._decoder,
                    )
                    self._loaded_video_label = (video, info_dict, video_index)
                except Exception as e:
                    logger.warning(
                        "Failed to load video with error: {}; trial {}; id: {}".format(
                            e,
                            i_try,
                            video_id,
                        )
                    )
                    logger.exception("Video load exception")
                    continue

            video_id = video.name

            (
                clip_start,
                clip_end,
                clip_index,
                aug_index,
                is_last_clip,
            ) = self._clip_sampler(self._last_clip_end_time, video.duration, info_dict)

            if isinstance(clip_start, list):  # multi-clip in each sample

                # Only load the clips once and reuse previously stored clips if there are multiple
                # views for augmentations to perform on the same clips.
                if aug_index[0] == 0:
                    self._loaded_clip = {}
                    loaded_clip_list = []
                    for i in range(len(clip_start)):
                        clip_dict = video.get_clip(clip_start[i], clip_end[i])
                        if clip_dict is None or clip_dict["video"] is None:
                            self._loaded_clip = None
                            break
                        loaded_clip_list.append(clip_dict)

                    if self._loaded_clip is not None:
                        for key in loaded_clip_list[0].keys():
                            self._loaded_clip[key] = [x[key] for x in loaded_clip_list]

            else:  # single clip case

                # Only load the clip once and reuse previously stored clip if there are multiple
                # views for augmentations to perform on the same clip.
                if aug_index == 0:
                    self._loaded_clip = video.get_clip(clip_start, clip_end)

            self._last_clip_end_time = clip_end

            video_is_null = (
                self._loaded_clip is None or self._loaded_clip["video"] is None
            )
            if (
                is_last_clip[-1] if isinstance(is_last_clip, list) else is_last_clip
            ) or video_is_null:
                # Close the loaded encoded video and reset the last sampled clip time ready
                # to sample a new video on the next iteration.
                self._loaded_video_label[0].close()
                self._loaded_video_label = None
                self._last_clip_end_time = None
                self._clip_sampler.reset()

                # Force garbage collection to release video container immediately
                # otherwise memory can spike.
                gc.collect()

                if video_is_null:
                    logger.warning(
                        "Failed to load clip {}; trial {}".format(video.name, i_try)
                    )
                    continue

            frames = self._loaded_clip["video"]
            audio_samples = self._loaded_clip["audio"]
            sample_dict = {
                "video": frames,
                "video_name": video.name,
                "video_index": video_index,
                "clip_index": clip_index,
                "aug_index": aug_index,
                **info_dict,
                **({"audio": audio_samples} if audio_samples is not None else {}),
            }
            if self._transform is not None:
                sample_dict = self._transform(sample_dict)

                # User can force dataset to continue by returning None in transform.
                if sample_dict is None:
                    logger.info(f"LVD: Sample Bypass: {video.name}")
                    continue

            return sample_dict
        else:
            # Reached only when the loop ran out of attempts without returning
            raise RuntimeError(
                f"a Failed to load video after {self._MAX_CONSECUTIVE_FAILURES} retries. id: {video_id}"  # noqa
            )

    def __iter__(self):
        """
        Reset the video sampler iterator and return this dataset as its own iterator.
        """
        self._video_sampler_iter = None  # Reset video sampler

        # If we're in a PyTorch DataLoader multiprocessing context, we need to use the
        # same seed for each worker's RandomSampler generator. The workers at each
        # __iter__ call are created from the unique value: worker_info.seed - worker_info.id,
        # which we can use for this seed.
        worker_info = torch.utils.data.get_worker_info()
        if self._video_random_generator is not None and worker_info is not None:
            base_seed = worker_info.seed - worker_info.id
            self._video_random_generator.manual_seed(base_seed)

        return self


def labeled_video_dataset(
    data_path: str,
    clip_sampler: ClipSampler,
    video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
    transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    video_path_prefix: str = "",
    decode_audio: bool = True,
    decoder: str = "pyav",
) -> LabeledVideoDataset:
    """
    Build a ``LabeledVideoDataset`` from a path, as used for Ucf101 and Kinetics
    style datasets.

    Args:
        data_path (str): Location of the data, handed to
            ``LabeledVideoPaths.from_path``. How it is read depends on its type:

            * a file path: every line is parsed into a video path and a label;
            * a directory: each subdirectory is treated as one class.

        clip_sampler (ClipSampler): Strategy for sampling clips from each video.

        video_sampler (Type[torch.utils.data.Sampler]): Sampler class for the internal
            video container; it sets the decode order and, if necessary, the
            distributed split.

        transform (Callable): Applied to each clip before it is returned, for
            user-defined preprocessing and augmentation. See ``LabeledVideoDataset``
            for the clip format.

        video_path_prefix (str): Root directory prepended to every video path
            before loading.

        decode_audio (bool): If True, also decode audio from video.

        decoder (str): Type of decoder used to decode a video.

    Returns:
        The constructed ``LabeledVideoDataset``.
    """
    video_paths = LabeledVideoPaths.from_path(data_path)
    video_paths.path_prefix = video_path_prefix
    return LabeledVideoDataset(
        labeled_video_paths=video_paths,
        clip_sampler=clip_sampler,
        video_sampler=video_sampler,
        transform=transform,
        decode_audio=decode_audio,
        decoder=decoder,
    )


In [6]:
# Dataset (Provided for context)

import json
from bisect import bisect_left
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import numpy as np

import torch
import torch.autograd.profiler as profiler
import torch.utils.data
import torchaudio

from iopath.common.file_io import g_pathmgr
# from pytorchvideo.data import LabeledVideoDataset
from pytorchvideo.data.clip_sampling import ClipSampler
# from pytorchvideo.data.ego4d.utils import (
#     Ego4dImuDataBase,
#     get_label_id_map,
#     MomentsClipSampler,
# )
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Div255,
    Normalize,
    RandomShortSideScale,
    ShortSideScale,
)
from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip

class Ego4dMomentsDataset(LabeledVideoDataset):
    """
    Ego4d video/audio/imu dataset for the moments benchmark:
    `<https://ego4d-data.org/docs/benchmarks/episodic-memory/>`

    This dataset handles the parsing of frames, loading and clip sampling for the
    videos.

    IO utilizing :code:`iopath.common.file_io.PathManager` to support
    non-local storage uri's.
    """

    VIDEO_FPS = 30
    AUDIO_FPS = 48000

    def __init__(
        self,
        annotation_path: str,
        metadata_path: str,
        split: Optional[str] = None,
        decode_audio: bool = True,
        imu: bool = False,
        clip_sampler: Optional[ClipSampler] = None,
        video_sampler: Type[
            torch.utils.data.Sampler
        ] = torch.utils.data.SequentialSampler,
        transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        decoder: str = "pyav",
        filtered_labels: Optional[List[str]] = None,
        window_sec: int = 10,
        audio_transform_type: str = "melspectrogram",
        imu_path: str = None,
        label_id_map: Optional[Dict[str, int]] = None,
        label_id_map_path: Optional[str] = None,
        video_path_override: Optional[Callable[[str], str]] = None,
        video_path_handler: Optional[VideoPathHandler] = None,
        eligible_video_uids: Optional[Set[str]] = None,
    ) -> None:
        """
        Parse the moments annotations into one sample per label annotation and
        initialise the underlying `LabeledVideoDataset` with them.

        Args:
            annotation_path (str):
                Path or URI to Ego4d moments annotations json (e.g. moments_train.json).
                Download via:
                `<https://github.com/facebookresearch/Ego4d/blob/main/ego4d/cli/README.md>`

            metadata_path (str):
                Path or URI to primary Ego4d metadata json (ego4d.json). Download via:
                `<https://github.com/facebookresearch/Ego4d/blob/main/ego4d/cli/README.md>`

            split (Optional[str]): train/val/test. Required in practice: any other
                value (including the None default) fails the assertion below.

            decode_audio (bool): If True, decode audio from video.

            imu (bool): If True, load IMU data.

            clip_sampler (ClipSampler):
                A standard PTV ClipSampler. By default, if not specified, `MomentsClipSampler`

            video_sampler (VideoSampler):
                A standard PTV VideoSampler.

            transform (Optional[Callable[[Dict[str, Any]], Any]]):
                User-defined preprocessing/augmentation callable. It is stored as
                `self._transform_source`; the parent dataset is always given
                `self._transform_mm` as its transform.

                    The clip input is a dictionary with the following format:
                        {
                            'video': <video_tensor>,
                            'audio': <audio_tensor>,
                            'imu': <imu_tensor>,
                            'start_time': <float>,
                            'stop_time': <float>
                        }

            decoder (str): Defines what type of decoder used to decode a video within
                `LabeledVideoDataset`.

            filtered_labels (List[str]):
                Optional list of moments labels to filter samples for training.

            window_sec (int): window size in s; stored as `self.window_sec` and
                passed to the default `MomentsClipSampler`.

            audio_transform_type: melspectrogram / spectrogram / mfcc

            imu_path (Optional[str]):
                Path to the ego4d IMU csv file.  Required if imu=True.

            label_id_map / label_id_map_path:
                A map of moments labels to consistent integer ids.  If specified as a path
                we expect a vanilla .json dict[str, int].  Exactly one must be specified.

            video_path_override ((str) -> str):
                An override for video paths, given the video_uid, to support downsampled/etc
                videos.

            video_path_handler (VideoPathHandler):
                Primarily provided as an override for `CachedVideoPathHandler`

            eligible_video_uids (Optional[Set[str]]):
                If non-empty, only videos whose uid is in this set are used.

        Raises:
            AssertionError: on missing paths, an unsupported split, or when not exactly
                one of label_id_map / label_id_map_path is given.
            FileNotFoundError: if the metadata or annotation file is missing or cannot
                be parsed as json (the underlying error is chained).

        Example Usage:
            Ego4dMomentsDataset(
                annotation_path="~/ego4d_data/v1/annotations/moments.json",
                metadata_path="~/ego4d_data/v1/ego4d.json",
                split="train",
                decode_audio=True,
                imu=False,
            )
        """

        assert annotation_path
        assert metadata_path
        assert split in [
            "train",
            "val",
            "test",
        ], f"Split '{split}' not supported for ego4d"
        self.split: str = split
        self.decode_audio = decode_audio
        self.training: bool = split == "train"
        self.window_sec = window_sec
        self._transform_source = transform
        self.audio_transform_type = audio_transform_type
        assert (label_id_map is not None) ^ (
            label_id_map_path is not None
        ), f"Either label_id_map or label_id_map_path required ({label_id_map_path} / {label_id_map})"  # noqa

        # Video preprocessing constants
        self.video_means = (0.45, 0.45, 0.45)
        self.video_stds = (0.225, 0.225, 0.225)
        self.video_crop_size = 224
        self.video_min_short_side_scale = 256
        self.video_max_short_side_scale = 320

        try:
            with g_pathmgr.open(metadata_path, "r") as f:
                metadata = json.load(f)
        except Exception as e:
            raise FileNotFoundError(
                f"{metadata_path} must be a valid metadata json for Ego4D"
            ) from e

        # Per-video metadata, keyed by video_uid
        self.video_metadata_map: Dict[str, Any] = {
            x["video_uid"]: x for x in metadata["videos"]
        }

        if not g_pathmgr.isfile(annotation_path):
            raise FileNotFoundError(f"{annotation_path} not found.")

        try:
            with g_pathmgr.open(annotation_path, "r") as f:
                moments_annotations = json.load(f)
        except Exception as e:
            raise FileNotFoundError(
                f"{annotation_path} must be json for Ego4D dataset"
            ) from e

        self.label_name_id_map: Dict[str, int]
        if label_id_map:
            self.label_name_id_map = label_id_map
        else:
            self.label_name_id_map = get_label_id_map(label_id_map_path)
            assert self.label_name_id_map

        self.num_classes: int = len(self.label_name_id_map)
        log.info(f"Label Classes: {self.num_classes}")

        self.imu_data: Optional[Ego4dImuDataBase] = None
        if imu:
            assert imu_path, "imu_path not provided"
            # NOTE(review): Ego4dImuData must be defined elsewhere before imu=True is
            # usable; only the abstract Ego4dImuDataBase placeholder is visible here.
            self.imu_data = Ego4dImuData(imu_path)

        cnt_samples_bypassed = 0
        cnt_samples_bypassed_labels = 0
        samples = []

        for vid in moments_annotations["videos"]:
            video_uid = vid["video_uid"]
            vsplit = vid["split"]
            if split and vsplit != split:
                continue
            # If IMU, filter videos without IMU
            if self.imu_data and not self.imu_data.has_imu(video_uid):
                continue
            if eligible_video_uids and video_uid not in eligible_video_uids:
                continue
            for clip in vid["clips"]:
                clip_uid = clip["clip_uid"]
                clip_start_sec = clip["video_start_sec"]
                clip_end_sec = clip["video_end_sec"]
                for vann in clip["annotations"]:
                    for lann in vann["labels"]:
                        label = lann["label"]
                        start = lann["start_time"]
                        end = lann["end_time"]
                        # remove sample with same timestamp
                        if start == end:
                            continue
                        start_video = lann["video_start_time"]
                        end_video = lann["video_end_time"]
                        assert end_video >= start_video

                        if abs(start_video - (clip_start_sec + start)) > 0.5:
                            log.debug(
                                f"Suspect clip/video start mismatch: clip: {clip_start_sec:.2f} + {start:.2f} video: {start_video:.2f}"  # noqa
                            )

                        # filter annotation base on the existing label map
                        if filtered_labels and label not in filtered_labels:
                            cnt_samples_bypassed += 1
                            continue
                        video_metadata = self.video_metadata_map[video_uid]

                        if video_metadata.get("is_stereo"):
                            cnt_samples_bypassed += 1
                            continue

                        if video_path_override:
                            video_path = video_path_override(video_uid)
                        else:
                            video_path = video_metadata["manifold_path"]
                        if not video_path:
                            cnt_samples_bypassed += 1
                            log.error(f"Bypassing invalid video_path: {video_uid}")
                            continue

                        sample = {
                            "clip_uid": clip_uid,
                            "video_uid": video_uid,
                            "duration": video_metadata["duration_sec"],
                            "clip_video_start_sec": clip_start_sec,
                            "clip_video_end_sec": clip_end_sec,
                            "labels": [label],
                            "label_video_start_sec": start_video,
                            "label_video_end_sec": end_video,
                            "video_path": video_path,
                        }
                        assert (
                            sample["label_video_end_sec"]
                            > sample["label_video_start_sec"]
                        )

                        if self.label_name_id_map:
                            if label in self.label_name_id_map:
                                sample["labels_id"] = self.label_name_id_map[label]
                            else:
                                # Label is not part of the id map: drop the sample
                                cnt_samples_bypassed_labels += 1
                                continue
                        else:
                            log.error("Missing label_name_id_map")
                        samples.append(sample)

        self.cnt_samples: int = len(samples)

        log.info(
            f"Loaded {self.cnt_samples} samples. Bypass: {cnt_samples_bypassed} Label Lookup Bypass: {cnt_samples_bypassed_labels}"  # noqa
        )
        print(
            f"Loaded {self.cnt_samples} samples. Bypass: {cnt_samples_bypassed} Label Lookup Bypass: {cnt_samples_bypassed_labels}"  # noqa
        )

        for sample in samples:
            assert "labels_id" in sample, f"init: Sample missing labels_id: {sample}"

        if not clip_sampler:
            clip_sampler = MomentsClipSampler(self.window_sec)

        super().__init__(
            [(x["video_path"], x) for x in samples],
            clip_sampler,
            video_sampler,
            transform=self._transform_mm,
            decode_audio=decode_audio,
            decoder=decoder,
        )

        # Must follow super().__init__(), which installs a default VideoPathHandler
        if video_path_handler:
            self.video_path_handler = video_path_handler

    def check_IMU(self, input_dict: Dict[str, Any]) -> bool:
        """
        Report whether the IMU signal attached to a sample is unusable.

        The signal found at ``input_dict["imu"]["signal"]`` is accepted only when
        its shape is 2-D, has at least 200 entries along the first axis and
        exactly 6 along the second.

        Args:
            input_dict: sample dictionary holding an ``"imu"`` entry with a
                ``"signal"`` object exposing ``.shape``.

        Returns:
            ``True`` when the signal is problematic (a warning is logged),
            ``False`` when it passes every check.
        """
        signal_shape = input_dict["imu"]["signal"].shape
        # Short-circuiting keeps the per-axis lookups from running on
        # shapes that are not 2-D.
        well_formed = (
            len(signal_shape) == 2
            and signal_shape[0] != 0
            and signal_shape[0] >= 200
            and signal_shape[1] == 6
        )
        if well_formed:
            return False

        log.warning(f"Problematic Sample: {input_dict}")
        return True

    def _transform_mm(self, sample_dict: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Per-sample transform handed to the parent dataset (``transform=self._transform_mm``).

        Validates the decoded clip, optionally attaches IMU data, applies the
        video transform, optionally pre-processes audio, adds a one-hot label
        vector and finally applies ``self._transform_source`` when it is set.

        Args:
            sample_dict: sample dictionary; this method reads the keys
                ``video_uid``, ``video``, ``labels_id``, ``clip_start``,
                ``clip_end``, ``labels``, ``video_name`` and (when
                ``self.decode_audio`` is truthy) ``audio``.

        Returns:
            The transformed sample dictionary, or ``None`` when the clip does
            not contain the expected number of frames or its IMU sample is
            rejected by ``check_IMU``.

        Raises:
            AssertionError: if ``video_uid`` is falsy, ``video`` is ``None``,
                ``labels_id`` is missing, or the clip bounds are inconsistent.
        """
        log.debug("_transform_mm")
        with profiler.record_function("_transform_mm"):
            video_uid = sample_dict["video_uid"]
            assert video_uid

            assert sample_dict["video"] is not None
            assert (
                "labels_id" in sample_dict
            ), f"Sample missing labels_id: {sample_dict}"

            video = sample_dict["video"]

            # Reject clips whose size along dim 1 differs from FPS * window length
            # (assumes dim 1 is the frame axis — TODO confirm against the decoder).
            expected = int(self.VIDEO_FPS * self.window_sec)
            actual = video.size(1)
            if expected != actual:
                log.error(
                    f"video size mismatch: actual: {actual} expected: {expected} video: {video.size()} uid: {video_uid}",  # noqa
                    stack_info=True,
                )
                return None

            start = sample_dict["clip_start"]
            end = sample_dict["clip_end"]
            assert start >= 0 and end >= start

            # A clip duration that deviates from window_sec is only logged;
            # the sample is still processed.
            if abs((end - start) - self.window_sec) > 0.01:
                log.warning(f"Invalid IMU time window: ({start}, {end})")

            if self.imu_data:
                sample_dict["imu"] = self.imu_data.get_imu_sample(
                    video_uid,
                    start,
                    end,
                )
                # check_IMU returns True for a *bad* signal, in which case the
                # whole sample is dropped.
                if self.check_IMU(sample_dict):
                    log.warning(f"Bad IMU sample: ignoring: {video_uid}")
                    return None

            # _video_transform() builds the transform; it is then applied to the sample.
            sample_dict = self._video_transform()(sample_dict)

            if self.decode_audio:
                audio_fps = self.AUDIO_FPS
                # Replaces the raw audio entry with the dict built by _preproc_audio.
                sample_dict["audio"] = self._preproc_audio(
                    sample_dict["audio"], audio_fps
                )

            labels = sample_dict["labels"]
            one_hot = self.convert_one_hot(labels)
            sample_dict["labels_onehot"] = one_hot

            # Optional user-supplied transform runs last, on the fully populated sample.
            if self._transform_source:
                sample_dict = self._transform_source(sample_dict)

            # Number of active classes in the one-hot vector (logged below).
            lcnt = sum(one_hot)

            log.info(
                f"Sample ({sample_dict['video_name']}): "
                f"({sample_dict['clip_start']:.2f}, {sample_dict['clip_end']:.2f}) "
                f" {sample_dict['labels_id']} | {sample_dict['labels']} | {lcnt}"
            )

            return sample_dict

    # pyre-ignore
    def _video_transform(self):
        """
        This function contains example transforms using both PyTorchVideo and
        TorchVision in the same callable. For 'train' model, we use augmentations (prepended
        with 'Random'), for 'val' we use the respective deterministic function
        """

        assert (
            self.video_means
            and self.video_stds
            and self.video_min_short_side_scale > 0
            and self.video_crop_size > 0
        )

        video_transforms = ApplyTransformToKey(
            key="video",
            transform=Compose(
                # pyre-fixme
                [Div255(), Normalize(self.video_means, self.video_stds)]
                + [  # pyre-fixme
                    RandomShortSideScale(
                        min_size=self.video_min_short_side_scale,
                        max_size=self.video_max_short_side_scale,
                    ),
                    RandomCrop(self.video_crop_size),
                    RandomHorizontalFlip(p=0.5),
                ]
                if self.training
                else [
                    ShortSideScale(self.video_min_short_side_scale),
                    CenterCrop(self.video_crop_size),
                ]
            ),
        )
        return Compose([video_transforms])

    def signal_transform(self, type: str = "spectrogram", sample_rate: int = 48000):
        """
        Create the torchaudio transform used to turn a waveform into features.

        Args:
            type: one of ``"spectrogram"``, ``"melspectrogram"`` or ``"mfcc"``.
            sample_rate: sampling rate forwarded to the mel-spectrogram and MFCC
                transforms; the plain spectrogram does not take it.

        Returns:
            The configured ``torchaudio.transforms`` object.

        Raises:
            ValueError: if ``type`` is not one of the supported names.
        """
        # Every supported transform uses the same hop and the default window.
        hop = 512
        window = None

        if type == "spectrogram":
            return torchaudio.transforms.Spectrogram(
                n_fft=1024,
                win_length=window,
                hop_length=hop,
                center=True,
                pad_mode="reflect",
                power=2.0,
            )

        if type == "melspectrogram":
            return torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=1024,
                win_length=window,
                hop_length=hop,
                center=True,
                pad_mode="reflect",
                power=2.0,
                norm="slaney",
                onesided=True,
                n_mels=64,
                mel_scale="htk",
            )

        if type == "mfcc":
            mel_options = {
                "n_fft": 2048,
                "n_mels": 256,
                "hop_length": hop,
                "mel_scale": "htk",
            }
            return torchaudio.transforms.MFCC(
                sample_rate=sample_rate,
                n_mfcc=256,
                melkwargs=mel_options,
            )

        raise ValueError(type)

    def _preproc_audio(self, audio, audio_fps) -> Dict[str, Any]:
        # convert stero to mono
        # https://github.com/pytorch/audio/issues/363
        waveform_mono = torch.mean(audio, dim=0, keepdim=True)
        return {
            "signal": waveform_mono,
            "spectrogram": self.signal_transform(
                type=self.audio_transform_type,
                sample_rate=audio_fps,
            )(waveform_mono),
            "sampling_rate": audio_fps,
        }

    def convert_one_hot(self, label_list: List[str]) -> List[int]:
        """
        Encode label names as a multi-hot vector of length ``self.num_classes``.

        Args:
            label_list: label names; each must be a key of
                ``self.label_name_id_map``.

        Returns:
            A list of ``0``/``1`` ints with a ``1`` at the id of every label.

        Raises:
            AssertionError: if a label is unknown, or if the number of set
                positions differs from the number of labels (e.g. duplicates).
        """
        recognized = []
        for name in label_list:
            if name in self.label_name_id_map.keys():
                recognized.append(name)
        assert len(recognized) == len(label_list), f"invalid filter {len(label_list)} -> {len(recognized)}: {label_list}"

        encoded = [0] * self.num_classes
        for name in recognized:
            encoded[self.label_name_id_map[name]] = 1

        assert sum(encoded) == len(recognized)
        return encoded

In [7]:
# Dataloader

def create_ego4d_moments_dataset(
    annotation_path: str,
    # NOTE: this default is read from `inputs` once, when the cell defining
    # this function runs; later changes to `inputs.metadata_path` are not picked up.
    metadata_path: str = inputs.metadata_path,
    imu_path: str = "~/ego4d_data/imu/",  # Not publicly available
    label_id_map: Optional[Dict[str, int]] = None,
    split: Optional[str] = None,
    decode_audio: bool = True,
    imu: bool = False,
    eligible_video_uids: Optional[Set[str]] = None,
    **kwargs,
):
    """
    Convenience factory that forwards its arguments to `Ego4dMomentsDataset`.

    Args:
        annotation_path: path to the moments annotation file.
        metadata_path: path to the Ego4D metadata file.
        imu_path: location of IMU data, forwarded unchanged.
        label_id_map: label name -> class id mapping, forwarded unchanged.
        split: split name, forwarded unchanged (the data module in this
            notebook passes "train" / "val").
        decode_audio: forwarded to the dataset's `decode_audio` argument.
        imu: forwarded to the dataset's `imu` argument.
        eligible_video_uids: optional set of video uids, forwarded unchanged
            (presumably restricts which videos are loaded — confirm in the
            dataset constructor).
        **kwargs: any additional keyword arguments accepted by
            `Ego4dMomentsDataset` (e.g. `video_path_override`).

    Returns:
        The constructed `Ego4dMomentsDataset`.
    """
    return Ego4dMomentsDataset(
        annotation_path=annotation_path,
        metadata_path=metadata_path,
        imu_path=imu_path,
        label_id_map=label_id_map,
        split=split,
        decode_audio=decode_audio,
        imu=imu,
        eligible_video_uids=eligible_video_uids,
        **kwargs,
    )

def collate_fn_mm_moments(data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collate per-sample dicts into one batch dict.

    The modalities ``"video"``, ``"spectrogram"``, ``"audio"`` and ``"imu"``
    are optional per sample. Variable-length modalities are cropped to the
    shortest sample in the batch before stacking: dim 1 for ``"audio"`` and
    ``"imu"``, dim 2 for ``"spectrogram"``. ``"video"`` is stacked as-is.

    Args:
        data: list of sample dicts, each holding ``"labels"`` and
            ``"labels_onehot"`` plus any of the modality tensors.

    Returns:
        Dict with ``"labels"`` (list of per-sample label lists),
        ``"labels_onehot"`` (float tensor) and one stacked tensor per
        modality present in the batch.
    """
    log.info("collate_fn_mm_moments")

    # Modality -> axis to crop to the batch minimum (None: no cropping).
    # Insertion order is the per-sample gathering order.
    crop_axis = {"video": None, "spectrogram": 2, "audio": 1, "imu": 1}
    gathered = {key: [] for key in crop_axis}
    lengths = {key: [] for key in crop_axis}
    class_names = []  # labels as class names
    onehot_rows = []  # labels as one hot vectors

    for item in data:
        for key, axis in crop_axis.items():
            if key not in item:
                continue
            gathered[key].append(item[key])
            if axis is not None:
                lengths[key].append(item[key].size(axis))
        class_names.append(item["labels"])
        onehot_rows.append(item["labels_onehot"])

    batch = {
        "labels": class_names,
        "labels_onehot": torch.tensor(onehot_rows).float(),
    }

    for key in ("imu", "audio", "video", "spectrogram"):
        tensors = gathered[key]
        if not tensors:
            continue
        axis = crop_axis[key]
        if axis is not None:
            shortest = min(lengths[key])
            # Same as t[:, :shortest] (axis 1) or t[:, :, :shortest] (axis 2).
            window = (slice(None),) * axis + (slice(None, shortest),)
            tensors = [t[window] for t in tensors]
        batch[key] = torch.stack(tensors)

    return batch

class Ego4dMomentsDataModule(pl.LightningDataModule):
    """
    Lightning data module for Ego4d Moments.

    It is essentially a thin shell over `Ego4dMomentsDataset`: it keeps the
    loader settings and label map, and builds a video-only (no audio, no IMU)
    dataset plus `DataLoader` for the train and val splits on demand.
    """

    def __init__(
        self,
        label_id_map: Dict[str, int] = None,
        batch_size: int = 8,
        num_workers: int = 8,
        train_transforms=None,
        val_transforms=None,
        test_transforms=None,
        dims=None,
        video_path_override: Optional[Callable[[str], str]] = None,
        eligible_video_uids: Set[str] = None,
    ) -> None:
        """
        Args:
            label_id_map: label name -> class id mapping; must be non-empty.
            batch_size: batch size for both loaders.
            num_workers: worker count for both loaders.
            train_transforms, val_transforms, test_transforms, dims:
                forwarded to `pl.LightningDataModule.__init__`.
            video_path_override: optional callable mapping a video uid to a
                path, forwarded to the dataset.
            eligible_video_uids: optional set of video uids, forwarded to the
                dataset.

        Raises:
            AssertionError: if `label_id_map` is missing or empty.
        """
        super().__init__(train_transforms, val_transforms, test_transforms, dims)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.video_path_override = video_path_override
        self.eligible_video_uids = eligible_video_uids
        self.label_name_id_map = label_id_map

        assert self.label_name_id_map, "Failed to load label_name_id_map"
        self.num_classes = len(self.label_name_id_map)
        assert self.num_classes > 0

        log.info(f"Ego4dMomentsDataModule: num_classes: {self.num_classes}")

    def _loader_for(
        self, annotation_path: str, split: str
    ) -> torch.utils.data.DataLoader:
        """Build the video-only dataset for `split` and wrap it in a DataLoader."""
        moments_dataset = create_ego4d_moments_dataset(
            annotation_path=annotation_path,
            split=split,
            decode_audio=False,
            imu=False,
            label_id_map=self.label_name_id_map,
            label_id_map_path=None,
            video_path_override=self.video_path_override,
            eligible_video_uids=self.eligible_video_uids,
        )
        loader_options = {
            "collate_fn": collate_fn_mm_moments,
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": True,
        }
        return torch.utils.data.DataLoader(moments_dataset, **loader_options)

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Loader over `inputs.annotation_path_train` with split "train"."""
        return self._loader_for(inputs.annotation_path_train, "train")

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Loader over `inputs.annotation_path_val` with split "val"."""
        return self._loader_for(inputs.annotation_path_val, "val")

In [8]:
# LFMM Model

# Checkpoint files expected inside `notebook_path`; they are loaded by
# get_image_encoder() below for the "resnet18" and "r2+1d" encoders.
MODEL_PATH_RESNET = os.path.join(notebook_path, "resnet18-f37072fd.pth")
MODEL_PATH_R2p1d = os.path.join(notebook_path, "r2plus1d_18-91a641e6.pth")

# NOTE(review): json, logging, os and torch are already imported by the setup
# cell at the top of the notebook; they are repeated here. `Tuple`/`Dict`, used
# by the definitions below, come only from that setup cell.
import csv
import json
import logging
import os
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
# NOTE(review): PathManager does not appear to be used by the model code in
# this cell (it opens files through `g_pathmgr`) — confirm before removing.
from iopath.common.file_io import PathManager
from torchmetrics import AveragePrecision

# The model code logs through `logger`; alias it to the notebook-wide `log`.
logger = log


def get_image_encoder(image_model_type: str) -> Tuple[nn.Module, int]:
    """
    Build an image/video encoder with its classification head removed.

    The architecture is created without pretrained weights, its state dict is
    loaded from the local checkpoint file (`MODEL_PATH_R2p1d` or
    `MODEL_PATH_RESNET`), and the final `fc` layer is replaced by
    `nn.Identity()` so the module outputs its 512-D embedding.

    Args:
        image_model_type: `"r2+1d"` (torchvision `r2plus1d_18`) or
            `"resnet18"` (torchvision `resnet18`).

    Returns:
        `(encoder, embedding_dims)` where `embedding_dims` is 512.

    Raises:
        ValueError: if `image_model_type` is not a supported name. Previously
            this case silently returned `(None, 0)`, which only failed later
            when the missing encoder was called.
    """
    if image_model_type == "r2+1d":
        build_model = torchvision.models.video.r2plus1d_18
        checkpoint_path = MODEL_PATH_R2p1d
    elif image_model_type == "resnet18":
        build_model = torchvision.models.resnet18
        checkpoint_path = MODEL_PATH_RESNET
    else:
        raise ValueError(
            f"Unsupported image_model_type: {image_model_type!r} "
            "(expected 'r2+1d' or 'resnet18')"
        )

    # Weights come from the locally downloaded checkpoint rather than the
    # torchvision download path, so no internet access is required.
    image_encoder = build_model(pretrained=False)
    with g_pathmgr.open(checkpoint_path, "rb") as f:
        # Keep every tensor on the storage it is deserialized to (CPU).
        previous_state = torch.load(f, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(previous_state)

    # Cut out the classifier head to expose the 512-D embeddings
    image_encoder.fc = nn.Identity()
    embedding_dims = 512

    return (image_encoder, embedding_dims)


class LFMM(nn.Module):
    """
    Late-fusion multimodal classifier.

    Each enabled modality is encoded separately, the embeddings are
    concatenated along dim 1, and a single linear layer maps the fused
    embedding to `num_classes` logits.

    Only the video branch is fully wired up: the IMU and audio encoders are
    still TODOs, so those branches raise `NotImplementedError` instead of
    failing with an attribute lookup error.
    """

    def __init__(
        self,
        num_classes: int,
        image_model_type: str = "resnet18",
        IMU: bool = False,
        VIDEO: bool = True,
        AUDIO: bool = False,
    ) -> None:
        """
        Args:
            num_classes: number of output logits.
            image_model_type: encoder name passed to `get_image_encoder`
                (used only when `VIDEO` is true).
            IMU: reserve a 32-D IMU embedding. No IMU encoder is created
                here; `self.imu_encoder` must be assigned before `forward`.
            VIDEO: enable the image/video encoder.
            AUDIO: enable the audio branch (not implemented).

        Raises:
            AssertionError: if no modality is enabled.
            NotImplementedError: if `AUDIO` is true, because no audio encoder
                is available.

        :TODO: fuse logits or last embedding layers (via concat)?
        :TODO: add audio stream
        """
        super().__init__()

        self.IMU = IMU
        self.VIDEO = VIDEO
        self.AUDIO = AUDIO
        assert self.IMU or self.VIDEO or self.AUDIO

        if self.AUDIO:
            # The audio encoder construction is still a TODO:
            # self.audio_encoder = AudioClassifierModifiedVGG(
            #     None,
            #     num_classes,
            #     pretrained_checkpoint="manifold://fai4ar_supar/tree/checkpoints/mobilenet_v3_small-047dcff4.pth",
            # ).net
            # followed by cutting the classifier head
            # (`self.audio_encoder.classifier[3] = nn.Identity()`) to expose
            # 1024-D embeddings. Without it, stripping the head used to fail
            # with an AttributeError on the missing `audio_encoder`.
            raise NotImplementedError(
                "LFMM: AUDIO=True is not supported yet (no audio encoder is defined)"
            )

        self.embedding_dims = 0

        # Setup the image encoder
        if self.VIDEO:
            self.image_encoder, image_embedding_dims = get_image_encoder(
                image_model_type
            )
            self.embedding_dims += image_embedding_dims

        if self.IMU:
            # Reserve a 32-D embedding for the IMU encoder.
            # TODO: self.imu_encoder = MCNN1DAttentionPooling(32, None).net
            self.embedding_dims += 32

        # Then the final classifier head on top of the concatenated embeddings
        self.clf = nn.Linear(self.embedding_dims, num_classes)

    def forward(
        self,
        images: torch.Tensor,
        imu: torch.Tensor,
        audio: torch.Tensor,
    ) -> torch.Tensor:
        """
        Inputs (shapes as documented by the original author) -
            images: [bsz, channels=3, <num_frames?>, height, width] tensor of images
                --> num_frames optional.  If not given then its just the last frame of the sequence.
            imu: [bsz, channels=6, length]
            audio: [bsz, channels=3, height, width] Tensor of spectrograms

        Inputs of disabled modalities are ignored.

        Returns:
            Logits produced by the linear head over the concatenated embeddings.

        Raises:
            NotImplementedError: if a modality is enabled but its encoder has
                not been attached to the module.
        """
        embeddings = []

        if self.VIDEO:
            # Run the image encoder
            embeddings.append(self.image_encoder(images))

        if self.IMU:
            imu_encoder = getattr(self, "imu_encoder", None)
            if imu_encoder is None:
                raise NotImplementedError(
                    "LFMM: IMU=True requires `imu_encoder` to be assigned "
                    "(no IMU encoder is defined yet)"
                )
            # Run the IMU encoder
            embeddings.append(imu_encoder(imu))

        if self.AUDIO:
            audio_encoder = getattr(self, "audio_encoder", None)
            if audio_encoder is None:
                raise NotImplementedError(
                    "LFMM: AUDIO=True requires `audio_encoder` to be assigned "
                    "(no audio encoder is defined yet)"
                )
            # Run the audio encoder
            embeddings.append(audio_encoder(audio))

        embeddings = torch.cat(embeddings, 1)
        logits = self.clf(embeddings)

        return logits


class LFMMModule(pl.LightningModule):
    """
    LightningModule that trains and evaluates an `LFMM` model as a
    multi-label classifier (binary cross-entropy on logits) and reports
    per-class average precision and mAP at the end of validation/test epochs.
    """

    def __init__(
        self,
        num_classes: int,
        lr: float,
        lr_milestones: List[int],
        image_model_type: str = "resnet18",
        save_checkpoint_dir: str = None,
        IMU: bool = False,
        VIDEO: bool = True,
        AUDIO: bool = False,
        label_id_map: Optional[Dict[str, int]] = None,
        label_mapping_file: Optional[str] = None,
    ):
        """
        Args:
            num_classes: number of classes / output logits.
            lr: learning rate for Adam.
            lr_milestones: epochs at which the LR is multiplied by 0.1.
            image_model_type: encoder name forwarded to `LFMM`.
            save_checkpoint_dir: directory where `test_epoch_end` writes
                `test_per_class_ap.csv`; when `None` the CSV is skipped.
            IMU, VIDEO, AUDIO: modality switches forwarded to `LFMM`.
            label_id_map: class name -> id mapping; takes precedence over
                `label_mapping_file`.
            label_mapping_file: JSON file mapping class name to a dict with
                an `"index"` entry; used only when `label_id_map` is `None`.

        Raises:
            Exception: if building the label mapping fails (the original
                error is chained as the cause).
        """
        super().__init__()

        self.IMU = IMU
        self.VIDEO = VIDEO
        self.AUDIO = AUDIO
        self.num_classes = num_classes
        self.image_model_type = image_model_type

        # id -> class name, used to label the per-class AP metrics.
        self.label_mapping = None
        try:
            ids = set()
            if label_id_map is not None:
                self.label_mapping = {}
                for class_name, class_id in label_id_map.items():
                    self.label_mapping[class_id] = class_name
                    ids.add(class_id)
                assert len(label_id_map) == num_classes, f"label_id_map invalid: {len(label_id_map)} != {num_classes}"
            elif label_mapping_file is not None:
                with g_pathmgr.open(label_mapping_file, "r") as f:
                    class_mapping = json.load(f)
                self.label_mapping = {}
                for class_name, metadata in class_mapping.items():
                    self.label_mapping[metadata["index"]] = class_name
                    ids.add(metadata["index"])
            if not self.label_mapping:
                logger.error("LFMM: No Label Mapping Provided!")
            else:
                # Ids are expected to be exactly 0..num_classes-1, one per label.
                max_id = max(ids)
                if max_id != len(ids) - 1 or len(ids) != len(self.label_mapping) or num_classes != len(self.label_mapping):
                    logger.error(f"Error: LFMM: Label->Id Inconsistency: Labels: {len(self.label_mapping)} num_classes: {num_classes} Ids: {len(ids)} Max: {max_id} Min: {min(ids)}")  # noqa
                else:
                    logger.info(f"LFMM: Valid Labels: {len(self.label_mapping)} num_classes: {num_classes} Max: {max_id} Min: {min(ids)} ({self.VIDEO}|{self.AUDIO}|{self.IMU})")
        except Exception as exc:
            # This also covers failures unrelated to file access (e.g. a
            # label_id_map whose size disagrees with num_classes), so the
            # message is kept generic and the cause is chained.
            message = (
                f"Error building label mapping (file: {label_mapping_file}): {exc}"
            )
            logger.error(message)
            self.label_mapping = None
            raise Exception(message) from exc

        self.model = LFMM(
            num_classes,
            image_model_type=image_model_type,
            IMU=self.IMU,
            VIDEO=self.VIDEO,
            AUDIO=self.AUDIO,
        )

        self.lr = lr
        self.lr_milestones = lr_milestones
        self.save_checkpoint_dir = save_checkpoint_dir

        self.save_hyperparameters()

    def forward(
        self,
        images: torch.Tensor,
        imu: torch.Tensor,
        audio: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to the wrapped `LFMM` model and return its logits."""
        return self.model(images, imu, audio)

    def one_step(self, batch, batch_idx):
        """
        Run one forward pass and compute the loss for a collated batch.

        Disabled modalities are replaced by a scalar placeholder tensor.

        Returns:
            `(loss, output)`: the BCE-with-logits loss against
            `batch["labels_onehot"]` and the raw logits.
        """
        if self.VIDEO:
            # Size: [bsz, 3, num_frames, height, width] - e.g. [12, 3, 20, 112, 112]
            video = batch["video"]
            if self.image_model_type != "r2+1d":
                # 2-D encoders only see the last frame of each clip:
                # [bsz, 3, height, width]
                video = video[:, :, -1]
        else:
            video = torch.tensor(0)

        if self.IMU:
            # Size [bsz, 6, length] - e.g. [12, 6, 966]
            imu = batch["imu"]
        else:
            imu = torch.tensor(0)

        if self.AUDIO:
            # Size: [bsz, 3, height, width] - e.g. [12, 3, 201, 401]
            audio = batch["spectrogram"]
        else:
            audio = torch.tensor(0)

        output = self.forward(video, imu, audio)
        target = batch["labels_onehot"]
        loss = F.binary_cross_entropy_with_logits(output, target)
        return loss, output

    def training_step(self, batch, batch_idx):
        """
        video shape: [bsz, 3, num_frames, height, width]
        imu shape: [bsz, 6, length]
        audio shape: [bsz, 3, num_frames, height, width]

        Returns the training loss.
        """

        loss, _ = self.one_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        video shape: [bsz, 3, num_frames, height, width]
        imu shape: [bsz, 6, length]
        audio shape: [bsz, 3, num_frames, height, width]

        Logs `val/loss` and returns loss, logits and one-hot labels.
        """

        loss, output = self.one_step(batch, batch_idx)
        self.log("val/loss", loss, prog_bar=True)
        return {"loss": loss, "output": output, "label": batch["labels_onehot"]}

    def test_step(self, batch, batch_idx):
        """
        video shape: [bsz, 3, num_frames, height, width]
        imu shape: [bsz, 6, length]
        audio shape: [bsz, 3, num_frames, height, width]

        Logs `test/loss` and returns loss, logits and one-hot labels.
        """

        loss, output = self.one_step(batch, batch_idx)
        self.log("test/loss", loss, prog_bar=True)
        return {"loss": loss, "output": output, "label": batch["labels_onehot"]}

    def validation_step_end(self, batch_parts):
        """
        Accumulate the outputs across the devices, for a single mini-batch step.
        """
        losses = batch_parts["loss"]
        outputs = batch_parts["output"]
        label = batch_parts["label"]

        return {
            "loss": torch.mean(torch.Tensor(losses)),
            "output": outputs,
            "label": label,
        }

    def test_step_end(self, batch_parts):
        """Same per-step reduction as `validation_step_end`."""
        return self.validation_step_end(batch_parts)

    def evaluation_epoch_end(self, step_outputs, prefix):
        """
        Accumulate the AP and mAP metrics across all samples and devices.

        Args:
            step_outputs: list of dicts produced by the `*_step_end` hooks.
            prefix: metric namespace, `"val"` or `"test"`.

        Returns:
            Dict with `{prefix}/total_loss`, `{prefix}/mAP` and
            `{prefix}/per_class_ap` (class name -> AP; empty when no label
            mapping is available).
        """

        mean_loss = torch.mean(torch.Tensor([v["loss"] for v in step_outputs]))
        self.log(
            f"{prefix}/total_loss",
            mean_loss,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )

        outputs = torch.cat([v["output"] for v in step_outputs], 0)
        label = torch.cat([v["label"] for v in step_outputs], 0).long()

        outputs = torch.sigmoid(outputs)
        average_precision = AveragePrecision(pos_label=1)  # do this per class

        per_class_ap = torch.Tensor(
            [
                average_precision(outputs[:, class_idx], label[:, class_idx])
                for class_idx in range(self.num_classes)
            ]
        )

        # mAP only averages classes that have at least one positive sample.
        # The per-class values above are reused rather than recomputed.
        has_positives = (torch.sum(label[:, : self.num_classes], dim=0) > 0).to(
            per_class_ap.device
        )
        nonzero_per_class_ap = per_class_ap[has_positives]

        mAP = torch.mean(nonzero_per_class_ap)
        self.log(
            f"{prefix}/mAP",
            mAP,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )

        per_class_ap_map = {}
        if self.label_mapping is not None:
            # Log out per-class AP metrics
            for label_index, ap in enumerate(per_class_ap):
                class_name = self.label_mapping[label_index]
                per_class_ap_map[class_name] = ap

                self.log(
                    f"{prefix}/AP [{class_name}]",
                    ap,
                    prog_bar=False,
                )

        return {
            f"{prefix}/total_loss": mean_loss,
            f"{prefix}/mAP": mAP,
            f"{prefix}/per_class_ap": per_class_ap_map,
        }

    def validation_epoch_end(self, validation_step_outputs):
        """Compute and log the validation metrics (`val/...`)."""
        return self.evaluation_epoch_end(validation_step_outputs, "val")

    def test_epoch_end(self, test_step_outputs):
        """
        Compute and log the test metrics (`test/...`) and, when
        `save_checkpoint_dir` is set, write the per-class AP to
        `test_per_class_ap.csv` (one header row of sorted class names, one
        row of scores).
        """
        final_results = self.evaluation_epoch_end(test_step_outputs, "test")

        per_class_ap_map = final_results["test/per_class_ap"]

        if self.save_checkpoint_dir is None:
            # `save_checkpoint_dir` defaults to None; os.path.join would raise
            # a TypeError on it, so skip the export instead of crashing.
            logger.warning(
                "save_checkpoint_dir is not set; skipping per-class AP CSV export"
            )
            return final_results

        # Save final results to CSV file before returning
        out_file = os.path.join(self.save_checkpoint_dir, "test_per_class_ap.csv")
        with g_pathmgr.open(out_file, "w") as csvfile:
            csvwriter = csv.writer(csvfile, delimiter=",")

            class_names = sorted(per_class_ap_map.keys())
            csvwriter.writerow(class_names)

            scores = [per_class_ap_map[class_name].item() for class_name in class_names]
            csvwriter.writerow(scores)
        logger.info(f"Saved final per-class AP results to: {out_file}")

        return final_results

    def configure_optimizers(self):
        """Adam with a MultiStepLR schedule (x0.1 at each milestone, stepped per epoch)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]


In [None]:
def video_path_override(x):
    """Return the path of the ``.mp4`` file named after ``x`` under ``inputs.video_path``."""
    return os.path.join(inputs.video_path, f"{x}.mp4")

# Label mapping loaded from the configured label mapping file
moments_label_id_map = get_label_id_map(inputs.label_mapping_file)
# eligible_uids stays None unless the miniset option is enabled
eligible_uids = None
if inputs.miniset:
    # NOTE(review): bare file names — presumably resolved relative to the kernel's
    # working directory; confirm against get_video_uids
    train_uids = get_video_uids("moments_mini_train_uids.csv")
    val_uids = get_video_uids("moments_mini_val_uids.csv")
    # Combine both uid collections with `|` (a union, assuming these are sets)
    eligible_uids = train_uids | val_uids

# Data module built from the label mapping, loader settings from `inputs`,
# the video path resolver and the (optional) uid restriction
data = Ego4dMomentsDataModule(
    moments_label_id_map,
    batch_size=inputs.batch_size,
    num_workers=inputs.num_workers,
    video_path_override=video_path_override,
    eligible_video_uids=eligible_uids,
)

In [10]:
# Model

# Module configured from `inputs`: learning-rate schedule, image model type,
# checkpoint directory and the IMU / video / audio switches. The class count comes
# from the data module and the label mapping is the one loaded for the data module.
model = LFMMModule(
    num_classes=data.num_classes,
    lr=inputs.learning_rate,
    lr_milestones=inputs.learning_rate_milestones,
    image_model_type=inputs.image_model_type,
    save_checkpoint_dir=inputs.checkpoint_dir,
    IMU=inputs.use_imu,
    VIDEO=inputs.use_video,
    AUDIO=inputs.use_audio,
    label_id_map=moments_label_id_map,
)

In [11]:
# Lightning Setup

# Checkpoint callback writing to the configured checkpoint directory; it monitors
# "val/mAP" in "max" mode with save_top_k=5 (i.e. the five best-scoring checkpoints).
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath=inputs.checkpoint_dir,
    # every_n_train_steps=inputs.every_n_train_steps,
    verbose=True,
    monitor="val/mAP",
    mode="max",
    save_top_k=5,
)

In [12]:
# Trainer

# Trainer configured from `inputs` (node/GPU counts, epoch/step limits, resume
# checkpoint, gradient accumulation) with the checkpoint callback attached.
# NOTE(review): `gpus` and `resume_from_checkpoint` are accepted by the Lightning
# version used here; confirm they are still supported before upgrading Lightning.
trainer = pl.Trainer(
    num_nodes=inputs.num_nodes,
    gpus=inputs.num_gpus_per_node,
    # If using multiple GPU/node, you'll want to use DDP (but not in a notebook)
    # strategy=pl.strategies.DDPStrategy(find_unused_parameters=False),
    # logger=pl_logger,
    callbacks=[checkpoint],
    max_epochs=inputs.max_epochs,
    max_steps=inputs.max_steps,
    resume_from_checkpoint=inputs.resume_from_checkpoint,
    accumulate_grad_batches=inputs.accumulate_grad_batches,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
# Train the model; the data module supplies the dataloaders
trainer.fit(model, datamodule=data)


  | Name  | Type | Params
-------------------------------
0 | model | LFMM | 11.2 M
-------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.860    Total estimated model params size (MB)


Epoch 0: : 19it [19:15, 60.80s/it, loss=0.545, v_num=1]
Loaded 275 samples. Bypass: 0 Label Lookup Bypass: 61
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:11<00:00,  6.53s/it]Val Net: out: -10287.2421875 label: 16
Loaded 891 samples. Bypass: 0 Label Lookup Bypass: 173                     
Epoch 0: : 81it [51:39, 50.00s/it, loss=0.088, v_num=1] 

Epoch 0, global step 100: 'val/mAP' was not in top 5


Epoch 0: : 81it [51:39, 50.00s/it, loss=0.088, v_num=1]


In [20]:
# Validate the in-memory model. The data module is passed by keyword (matching the
# trainer.fit call) instead of positionally in the dataloaders slot, and the former
# bare positional `None` is spelled out as ckpt_path=None so no checkpoint is loaded.
metrics = trainer.validate(model, datamodule=data, ckpt_path=None, verbose=True)
metrics

Loaded 275 samples. Bypass: 0 Label Lookup Bypass: 61
Validation DataLoader 0: : 16it [04:08, 15.55s/it]
Validation: 35it [09:34,  9.51s/it]Val Net: out: -700311.0 label: 275
Validation: 35it [09:34, 16.42s/it]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                      Validate metric                                               DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 val/AP ["cut_/_chop_/_slice_a_vegetable,_fruit,_or_meat"]                      0.06919191777706146
val/AP ["put_on_safety_equipment_(e.g._gloves,_helmet,_safet                            nan
                        y_goggles)"]
   val/AP [arrange_/_organize_clothes_in_closet/dresser]                       0.0036496350076049566
        val/AP [arrange_/_organize_items_in_fridge]                                     nan
          val/AP [arrange_/_



[{'val/loss': 0.5302677750587463,
  'val/total_loss': 0.5298373103141785,
  'val/mAP': 0.05019664019346237,
  'val/AP [serve_food_onto_a_plate]': 0.0074136704206466675,
  'val/AP [converse_/_interact_with_someone]': 0.21875594556331635,
  'val/AP [use_phone]': 0.0538051538169384,
  'val/AP [clean_/_wipe_a_table_or_kitchen_counter]': 0.02984566241502762,
  'val/AP [plant_seeds_/_plants_/_flowers_into_ground]': nan,
  'val/AP [tie_up_branches_/_plants_with_string]': nan,
  'val/AP [cut_tree_branch]': nan,
  'val/AP [harvest_vegetables_/_fruits_/_crops_from_trees]': nan,
  'val/AP [remove_weeds_from_ground]': nan,
  'val/AP [cut_other_item_using_tool]': nan,
  'val/AP [throw_away_trash_/_put_trash_in_trash_can]': 0.018948446959257126,
  'val/AP [water_soil_/_plants_/_crops]': 0.017617180943489075,
  'val/AP [wash_hands]': 0.08985372632741928,
  'val/AP [turn-on_/_light_the_stove_burner]': 0.4682539701461792,
  'val/AP [trim_hedges_or_branches]': nan,
  'val/AP [harvest_vegetables_/_fruits

In [None]:
# Headline metric: "val/mAP" from the first entry of the list returned by trainer.validate
print(f"mAP: {metrics[0]['val/mAP']}")

mAP: 0.04237818717956543


# Validate From Checkpoint

In [None]:
# CLI download for val downsampled videos only:
# !python -m ego4d.cli.cli -y --output_directory ~/ego4d_data/ --datasets video_540ss --video_uid_file ~/path/to/moments_mini_val_uids.csv

# Using the checkpoint from the binaries file
model = LFMMModule.load_from_checkpoint("moments.ckpt", label_mapping_file=None, label_id_map=moments_label_id_map)

# Validate the restored model. The data module is passed by keyword (matching the
# trainer.fit call) instead of positionally in the dataloaders slot, and the former
# bare positional `None` is spelled out as ckpt_path=None so the weights just loaded
# into `model` are the ones evaluated.
metrics = trainer.validate(model, datamodule=data, ckpt_path=None, verbose=True)
print(f"mAP: {metrics[0]['val/mAP']}")

# Moments Benchmark Comparison

In [None]:
# https://github.com/EGO4D/episodic-memory/blob/main/MQ/README.md