In [1]:
import json
import os
import random
from operator import itemgetter
from collections import defaultdict

import torch
import numpy as np
import pandas as pd

from PIL import Image
from tqdm.auto import tqdm

from ego4d.research.readers import TorchAudioStreamReader, PyAvReader
VideoReader = TorchAudioStreamReader

In [2]:
def get_cam_stream_ids(take):
    for cam_id, vs in take["frame_aligned_videos"].items():
        if cam_id in ("best_exo", "collage"):
            continue
        stream_id = "0"
        if "aria" in cam_id.lower():
            stream_id = "rgb"
        
        yield cam_id, stream_id

In [3]:
RELEASE_DIR = "/checkpoint/miguelmartin/egoexo_data/dev"  # NOTE: changeme

egoexo = {
    "takes": os.path.join(RELEASE_DIR, "takes.json"),
    "captures": os.path.join(RELEASE_DIR, "captures.json"),
    "physical_setting": os.path.join(RELEASE_DIR, "physical_setting.json"),
    "participants": os.path.join(RELEASE_DIR, "participants.json"),
    "visual_objects": os.path.join(RELEASE_DIR, "visual_objects.json"),
    "splits": os.path.join(RELEASE_DIR, "annotations/splits.json"),
}

for k, v in egoexo.items():
    egoexo[k] = json.load(open(v))

takes = egoexo["takes"]
captures = egoexo["captures"]
takes_by_uid = {x["take_uid"]: x for x in takes}

In [4]:
splits = egoexo["splits"]["split_to_take_uids"]

# Utils

In [5]:
from ego4d.research.dataset import save_ego4d_features_to_hdf5

# TODO: do dropped
def get_data_for(split):
    return [
        (
            f"{x['take_uid']}_" + "_".join(cam_stream_id),
            {
                "parent_task_id": x["parent_task_id"] // 1000,
                "parent_task_name": x["parent_task_name"],
                "take_uid": x["take_uid"],
                "cam_id": cam_stream_id[0],
                "stream_id": cam_stream_id[1],
            },
        )
        for x in egoexo["takes"] if x["take_uid"] in split
        for cam_stream_id in get_cam_stream_ids(x)
    ]
from torch.utils.data import DataLoader, Dataset

import bisect
import math
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import h5py
import torch
from ego4d.research.readers import PyAvReader, StridedReader, TorchAudioStreamReader

from tqdm.auto import tqdm


class LabelledFeatureDset(torch.utils.data.Dataset):
    """
    A simple utility class to load features associated with labels. The input this
    method requires is as follows:
        1. `feature_hdf5_path`: the features transposed to a HDF5 file.
            See `save_ego4d_features_to_hdf5`
        2. `uid_label_pairs` a list of (uid, label). `label` can be anything
            `uid` is a unique id associated to the `feature_hdf5_path` file.
        3. `aggr_function` a function to aggregate based off given label
    """

    def __init__(
        self,
        feature_hdf5_path: str,
        uid_label_pairs: List[Tuple[str, Any]],
        aggr_function: Optional[Callable[[torch.Tensor, Any], torch.Tensor]] = None,
    ):
        self.features = h5py.File(feature_hdf5_path)
        self.aggr_function = (
            aggr_function
            if aggr_function is not None
            else lambda x, _: torch.tensor(x[0:]).squeeze()
        )
        self.uid_label_pairs = uid_label_pairs
        f_keys = set(self.features.keys())
        l_keys = set(uid for uid, _ in self.uid_label_pairs)
        if len(l_keys - f_keys) > 0:
            print(f"WARN: missing {len(l_keys - f_keys)} keys in feature hdf5 path: {feature_hdf5_path}")
            self.uid_label_pairs = [(uid, label) for uid, label in self.uid_label_pairs if uid in f_keys]
    
    def __len__(self):
        return len(self.uid_label_pairs)

    def __getitem__(self, idx: int):
        uid, label = self.uid_label_pairs[idx]
        feat = self.aggr_function(self.features[uid], label)
        return feat, label


# Zero-Shot

In [6]:
features_dir = "/checkpoint/miguelmartin/egoexo_features/maws_clip_2b_public"
features_paths = [x for x in os.listdir(features_dir) if x != "config.yaml"]

features_by_take_cam = {}
for x in features_paths:
    take_uid, cam_id, stream_id_pt = x.split("_")
    stream_id = stream_id_pt = stream_id_pt.split(".")[0]
    if take_uid not in features_by_take_cam:
        features_by_take_cam[take_uid] = {}
    key = (cam_id, stream_id)
    features_by_take_cam[take_uid][key] = os.path.join(features_dir, x)
    
    
features_paths[0]

'18664e6b-14c7-4a39-97e0-fd5af39268fb_aria01_rgb.pt'

In [7]:
features_by_take_cam[take_uid].keys()

dict_keys([('gp03', '0'), ('gp01', '0'), ('aria01', 'rgb'), ('gp04', '0'), ('gp05', '0'), ('gp02', '0'), ('gp06', '0')])

In [11]:
tqdm._instances.clear()


In [12]:
out_path = "/checkpoint/miguelmartin/egoexo_features/maws_clip_2b_public.hdf5"
video_uids = [x.split(".")[0] for x in features_paths]
feature_hdf5_path = out_path
save_ego4d_features_to_hdf5(video_uids, feature_dir=features_dir, out_path=out_path)

video_uid:   0%|          | 0/23929 [00:00<?, ?it/s]

In [13]:
train_takes = set(splits["train"]) & set(features_by_take_cam.keys())
val_takes = set(splits["val"]) & set(features_by_take_cam.keys())

In [14]:
list(get_cam_stream_ids(egoexo["takes"][0]))

[('aria01', 'rgb'),
 ('cam01', '0'),
 ('cam02', '0'),
 ('cam03', '0'),
 ('cam04', '0')]

In [15]:
val_data = get_data_for(val_takes)
train_data = get_data_for(train_takes)

In [18]:
# f_keys = set(val_dset.features.keys())
# l_keys = set(uid for uid, _ in val_dset.uid_label_pairs)

In [19]:
from ego4d.research.clep.val import accuracy

In [20]:
from maws.model_builder import build_model
model = build_model("vit_2b14_xlmr_l", "maws_clip")
model = model.eval().half()
model = model.to("cuda")

In [21]:
txt_labels = [
    "Phone with a QR code",
    "A person is cooking",
    "A person is performing Health related activities such as a COVID-19 test or CPR",
    "A person is at a campsite",
    "A person is performing repair on a bike",
    "A person is playing a musical instrument",
    "A person is playing basketball",
    "A person is rock climbing",
    "A person is playing soccer",
    "A person is dancing",
]

In [22]:
val_dset = LabelledFeatureDset(feature_hdf5_path, val_data)
val_dloader = DataLoader(val_dset, batch_size=1, shuffle=False)

train_dset = LabelledFeatureDset(feature_hdf5_path, train_data)
train_dloader = DataLoader(train_dset, batch_size=1, shuffle=False)

In [23]:
txt_emb = model.encode_texts(texts=txt_labels)

In [24]:
model_fn = lambda x: model.classify(text_features=txt_emb, image_features=x)
classifier = lambda x: x

In [25]:
def eval_classification(loader, topk, all_cams=False, cam_id_filter_fn=None):
    incorrect = []

    with torch.no_grad():
        cmps = []
        for x, y in tqdm(loader):
            vfs = x.cuda()
            assert vfs.shape[0] == 1
            
            pred = model_fn(vfs)
            cmps.append((pred, y))
        
        probs_by_take = defaultdict(list)
        for logits, y in cmps:
            assert len(y["take_uid"]) == 1
            probs_by_take[y["take_uid"][0]].append((logits.mean(1), y))
        
        accs = [0 for x in topk]
        n = 0
        for take_uid, prob_labels in probs_by_take.items():
            _, y = prob_labels[0]
            pred_targs = [
                (p, py["parent_task_id"])
                for (p, py) in prob_labels
                if cam_id_filter_fn is None or cam_id_filter_fn(py["take_uid"], py["cam_id"][0])
            ]
            if len(pred_targs) == 0:
                continue
            
            if all_cams:
                pred = torch.stack([x for x, _ in pred_targs]).mean(0) 
                target = y["parent_task_id"]
                pred_targs = [(pred, target)]
            
            for pred, target in pred_targs:
                for i, acc in enumerate(accuracy(pred, target.cuda(), topk=topk)):
                    if acc != 1:
                        incorrect.append((take_uid, pred.argmax().cpu().item(), target.cpu().item()))
                    accs[i] += acc
                n += pred.shape[0]

    accs = [x/n for x in accs]
    
    return {
        "accuracy_by_topk": {
            topk[i]: acc
            for i, acc in enumerate(accs)
        },
        "n": n,
        "incorrect": incorrect,
    }

In [26]:
def ego_cam_id_filter(take_uid, cam_id):
    return "aria" in cam_id.lower()

In [27]:
eval_classification(val_dloader, topk=(1, 2, 3, 5), all_cams=False, cam_id_filter_fn=None)["accuracy_by_topk"]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3797/3797 [04:18<00:00, 14.71it/s]


{1: 0.9399525941532789,
 2: 0.9723465894126943,
 3: 0.9855148801685542,
 5: 0.9971029760337108}

In [None]:
eval_classification(val_dloader, topk=(1, 2, 3, 5), all_cams=True, cam_id_filter_fn=None)["accuracy_by_topk"]

 67%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                        | 2557/3797 [00:45<00:04, 292.95it/s]

In [None]:
eval_classification(val_dloader, topk=(1, 2, 3, 5), all_cams=True, cam_id_filter_fn=ego_cam_id_filter)["accuracy_by_topk"]

# Fine-Tune

# Kepstep

In [None]:
keystep_train = json.load(open("/large_experiments/egoexo/dataset/annotations/keystep_train.json"))
keystep_val = json.load(open("/large_experiments/egoexo/dataset/annotations/keystep_val.json"))

In [None]:
keystep_train["taxonomy"].keys()