# Introduction

This notebook shows the capabilities of the [MAWS CLIP](https://github.com/facebookresearch/maws) features and how to use them in the context of Ego-Exo4D. Every frame of the Ego-Exo4D videos has been fed through this model to produce these features.

This notebook shows that these features can be used for zero-shot classification of each take's task label, achieving 97.8% top-1 and 99% top-2 accuracy.

We recommend reading the accompanying paper: https://arxiv.org/abs/2303.13496

The paper shows that the model's zero-shot, N-shot, and linear-probe performance is strong across a wide variety of tasks.

In [None]:
import json
import os
import random
from operator import itemgetter
from collections import defaultdict

import torch
import numpy as np
import pandas as pd

from PIL import Image
from tqdm.auto import tqdm

from ego4d.research.readers import TorchAudioStreamReader, PyAvReader
# Alias for the video reader class; PyAvReader (imported alongside) is the
# alternative reader. NOTE(review): not referenced by the cells below.
VideoReader = TorchAudioStreamReader

In [None]:
release_dir = "/large_experiments/egoexo/v2/"  # NOTE: changeme

# Metadata files to load, keyed by the name used throughout this notebook.
egoexo = {
    "takes": os.path.join(release_dir, "takes.json"),
    "captures": os.path.join(release_dir, "captures.json"),
    "physical_setting": os.path.join(release_dir, "physical_setting.json"),
    "participants": os.path.join(release_dir, "participants.json"),
    "visual_objects": os.path.join(release_dir, "visual_objects.json"),
    "splits": os.path.join(release_dir, "annotations/splits.json"),
}

# Replace each path with its parsed JSON contents. The `with` block closes
# every file promptly instead of leaking the handle until garbage collection,
# and JSON is read as UTF-8 regardless of the platform's locale default.
# (Assigning to existing keys while iterating is safe: the dict's size is unchanged.)
for k, v in egoexo.items():
    with open(v, encoding="utf-8") as f:
        egoexo[k] = json.load(f)

takes = egoexo["takes"]
captures = egoexo["captures"]
# Take records indexed by their UID for direct lookup.
takes_by_uid = {x["take_uid"]: x for x in takes}

In [None]:
# Take UIDs grouped by split name (indexed with "train" and "val" further down).
splits = egoexo["splits"]["split_to_take_uids"]

# Utils

In [None]:
from ego4d.research.clep.val import accuracy
from ego4d.research.dataset import LabelledFeatureDset
from torch.utils.data import DataLoader

In [None]:
def get_cam_stream_ids(take):
    """Yield ``(cam_id, stream_id)`` for each camera video of a take.

    The ``"best_exo"`` and ``"collage"`` entries of
    ``take["frame_aligned_videos"]`` are skipped. Camera ids containing
    "aria" (case-insensitive) map to the ``"rgb"`` stream; every other camera
    maps to stream ``"0"``.

    Args:
        take: Take record whose ``"frame_aligned_videos"`` mapping is keyed by
            camera id.

    Yields:
        ``(cam_id, stream_id)`` tuples in the mapping's iteration order.
    """
    # Only the camera ids are needed, so iterate the keys rather than .items().
    for cam_id in take["frame_aligned_videos"]:
        if cam_id in ("best_exo", "collage"):
            continue
        stream_id = "rgb" if "aria" in cam_id.lower() else "0"
        yield cam_id, stream_id

def ego_cam_id_filter(take_uid, cam_id):
    """Return True when ``cam_id`` names an Aria camera (case-insensitive).

    ``take_uid`` is accepted only to match the ``(take_uid, cam_id)`` filter
    callback signature used by ``eval_classification``; it is ignored.
    """
    normalized = cam_id.lower()
    return normalized.find("aria") >= 0

def get_data_for(split):
    """Build ``(feature_key, labels)`` pairs for every camera stream of the split's takes.

    Args:
        split: Collection of take UIDs; each take in ``egoexo["takes"]`` is
            kept when its UID is ``in`` this collection.

    Returns:
        A list with one entry per (take, camera stream) yielded by
        ``get_cam_stream_ids``. ``feature_key`` is
        ``"<take_uid>_<cam_id>_<stream_id>"``; ``labels`` holds the fields used
        during evaluation.
    """
    records = []
    for take in egoexo["takes"]:
        if take["take_uid"] not in split:
            continue
        for cam_id, stream_id in get_cam_stream_ids(take):
            labels = {
                # Integer-divided by 1000 and then used as the class index
                # against the zero-shot prompts; presumably maps to the
                # top-level task — TODO confirm against takes.json.
                "parent_task_id": take["parent_task_id"] // 1000,
                "parent_task_name": take["parent_task_name"],
                "take_uid": take["take_uid"],
                "cam_id": cam_id,
                "stream_id": stream_id,
            }
            records.append((f"{take['take_uid']}_{cam_id}_{stream_id}", labels))
    return records

In [None]:
def eval_classification(loader, topk, all_cams=False, cam_id_filter_fn=None):
    """Score zero-shot task classification over a labelled feature loader.

    Runs the global ``model_fn`` on each batch, averages the resulting logits
    over dim 1, groups them by take, and scores them with ``accuracy`` at every
    k in ``topk``.

    Args:
        loader: Iterable of ``(features, labels)`` batches of size 1. ``labels``
            must contain ``"take_uid"``, ``"cam_id"`` and ``"parent_task_id"``;
            string fields arrive as length-1 lists (default collation).
        topk: Sequence of k values to evaluate.
        all_cams: If True, average the logits of all (kept) cameras of a take
            into one prediction per take; otherwise every camera stream is
            scored on its own.
        cam_id_filter_fn: Optional predicate ``(take_uid, cam_id) -> bool``
            called with plain strings; camera streams for which it returns a
            falsy value are ignored.

    Returns:
        A dict with ``"accuracy_by_topk"`` (k -> summed accuracy / n),
        ``"n"`` (number of scored predictions) and ``"incorrect"`` (one
        ``(take_uid, argmax_pred, target)`` tuple per prediction that missed
        at least one requested k). Accuracies are NaN when nothing was scored.
    """
    incorrect = []

    with torch.no_grad():
        cmps = []
        for x, y in tqdm(loader):
            vfs = x.cuda()
            assert vfs.shape[0] == 1

            pred = model_fn(vfs)
            cmps.append((pred, y))

        # Pool logits over dim 1 (presumably the frame axis) and group by take
        # so that the cameras of one take can be combined below.
        logits_by_take = defaultdict(list)
        for logits, y in cmps:
            assert len(y["take_uid"]) == 1
            logits_by_take[y["take_uid"][0]].append((logits.mean(1), y))

        accs = [0 for _ in topk]
        n = 0
        for take_uid, logit_labels in logits_by_take.items():
            _, y = logit_labels[0]
            pred_targs = [
                (p, py["parent_task_id"])
                for (p, py) in logit_labels
                # Unwrap the collated length-1 lists so the filter receives
                # plain strings for both arguments.
                if cam_id_filter_fn is None
                or cam_id_filter_fn(py["take_uid"][0], py["cam_id"][0])
            ]
            if not pred_targs:
                continue

            if all_cams:
                pred = torch.stack([p for p, _ in pred_targs]).mean(0)
                target = y["parent_task_id"]
                pred_targs = [(pred, target)]

            for pred, target in pred_targs:
                missed = False
                for i, acc in enumerate(accuracy(pred, target.cuda(), topk=topk)):
                    if acc != 1:
                        missed = True
                    accs[i] += acc
                # Record each wrong prediction once, not once per failing k.
                if missed:
                    incorrect.append((take_uid, pred.argmax().cpu().item(), target.cpu().item()))
                n += pred.shape[0]

    if n == 0:
        # Nothing survived the filter (or the loader was empty): avoid 0/0.
        accs = [float("nan") for _ in accs]
    else:
        accs = [a / n for a in accs]

    return {
        "accuracy_by_topk": {
            topk[i]: acc
            for i, acc in enumerate(accs)
        },
        "n": n,
        "incorrect": incorrect,
    }

# Feature Pre-Processing (for perf)

In [None]:
from ego4d.research.dataset import save_features_to_hdf5

In [None]:
features_dir = os.path.join(release_dir, "features/maws_clip_2b")
# features_dir = "/checkpoint/miguelmartin/egoexo_features/maws_clip_2b_public"
features_paths = [x for x in os.listdir(features_dir) if x != "config.yaml"]

# Index the feature files as features_by_take_cam[take_uid][(cam_id, stream_id)] -> path.
features_by_take_cam = {}
for x in features_paths:
    # Names look like "<take_uid>_<cam_id>_<stream_id>.pt"
    # (e.g. "000a19fe-776e-4c88-b0c3-2fad016a6025_aria01_rgb.pt"). Take the UID
    # from the first "_" and the stream from the last "_", so a cam id that
    # itself contains "_" is still kept whole.
    take_uid, cam_and_stream = x.split("_", 1)
    cam_id, stream_with_ext = cam_and_stream.rsplit("_", 1)
    stream_id = stream_with_ext.split(".")[0]
    features_by_take_cam.setdefault(take_uid, {})[(cam_id, stream_id)] = os.path.join(
        features_dir, x
    )

features_paths[0]

In [None]:
# Consolidated HDF5 file that the labelled feature datasets read from.
out_path = "/checkpoint/miguelmartin/egoexo_features/maws_clip_2b_public.hdf5"
feature_hdf5_path = out_path
# Feature file names with everything from the first "." removed.
video_uids = [name.partition(".")[0] for name in features_paths]
# NOTE: this will take ~50 minutes
# save_features_to_hdf5(video_uids, feature_dir=features_dir, out_path=out_path)

In [None]:
# Load the MAWS CLIP model from torch.hub, switch it to eval mode, cast it to
# fp16, and move it to the GPU.
model = (
    torch.hub.load("facebookresearch/maws", model="vit_2b14_xlmr_l_maws_clip")
    .eval()
    .half()
    .to("cuda")
)

# Zero-Shot on Task Labels

In [None]:
# Restrict each split to the takes that actually have feature files.
train_takes = set(splits["train"]).intersection(features_by_take_cam)
val_takes = set(splits["val"]).intersection(features_by_take_cam)

In [None]:
# Sanity check: the (cam_id, stream_id) pairs of the first take.
[*get_cam_stream_ids(egoexo["takes"][0])]

In [None]:
# (feature_key, labels) pairs for every camera stream of each split's takes.
train_data = get_data_for(train_takes)
val_data = get_data_for(val_takes)

In [None]:
# Zero-shot text prompts. A prompt's position in this list is its class index:
# the classifier's logits over these prompts are scored against
# `parent_task_id // 1000` from `get_data_for`, so the order presumably must
# match those ids — TODO confirm against takes.json.
txt_labels = [
    "Phone with a QR code",
    "A person is cooking",
    "A person is performing Health related activities such as a COVID-19 test or CPR",
    "A person is at a campsite",
    "A person is performing repair on a bike",
    "A person is playing a musical instrument",
    "A person is playing basketball",
    "A person is rock climbing",
    "A person is playing soccer",
    "A person is dancing",
]
# Text embeddings of the prompts, consumed by `model.classify`.
txt_emb = model.encode_texts(texts=txt_labels)

In [None]:
# example zero-shot inference on one take's aria01 RGB features
feature_path = "000a19fe-776e-4c88-b0c3-2fad016a6025_aria01_rgb.pt"
take_uid = feature_path.split("_")[0]
gt_label = takes_by_uid[take_uid]["parent_task_name"]
feature_path = os.path.join(features_dir, feature_path)
xs = torch.load(feature_path)
# Inference only: disable autograd so no graph is built through the model.
# `txt_emb` already holds the encoded `txt_labels` (previous cell); encoding
# the same prompts again here would only repeat that work.
with torch.no_grad():
    logits = model.classify(text_features=txt_emb, image_features=xs.cuda())
# Average over dim 0 (presumably the per-frame axis), then pick the best prompt.
pred_class = logits.mean(0).argmax().item()
pred_txt = txt_labels[pred_class]

print(f"""
Take: {take_uid}
Prediction: {pred_txt}
GT: {gt_label}
""")

In [None]:
# One sample per batch: eval_classification asserts a batch dimension of 1.
val_dset = LabelledFeatureDset(feature_hdf5_path, val_data)
val_dloader = DataLoader(dataset=val_dset, batch_size=1, shuffle=False)

train_dset = LabelledFeatureDset(feature_hdf5_path, train_data)
train_dloader = DataLoader(dataset=train_dset, batch_size=1, shuffle=False)

In [None]:
def model_fn(x):
    """Return the zero-shot logits of image features ``x`` against ``txt_emb``.

    ``model`` and ``txt_emb`` are looked up at call time, exactly as the
    previous lambda did.
    """
    return model.classify(text_features=txt_emb, image_features=x)


def classifier(x):
    """Identity head: return ``x`` unchanged."""
    return x

In [None]:
# Per-camera accuracy: every camera stream of every val take is scored on its own.
eval_classification(loader=val_dloader, topk=(1, 2, 3, 5))["accuracy_by_topk"]

In [None]:
# Per-take accuracy: logits averaged over all cameras of each val take.
eval_classification(loader=val_dloader, topk=(1, 2, 3, 5), all_cams=True)["accuracy_by_topk"]

In [None]:
# Per-take accuracy using only the Aria (ego) camera streams.
eval_classification(
    loader=val_dloader,
    topk=(1, 2, 3, 5),
    all_cams=True,
    cam_id_filter_fn=ego_cam_id_filter,
)["accuracy_by_topk"]