In [1]:
import os
import sys
import shutil
os.environ['TOKENIZERS_PARALLELISM'] = "False"
from glob import glob

import torch
import pandas as pd
import numpy as np
import json
from torch.nn.functional import cosine_similarity
from utils.video import read_frames_decord
from IPython.display import display, Markdown, Latex

import shared.utils as su
from notebooks.eval_care_retrieval import load_model

In [2]:
data_dir = "/scratch/shared/beegfs/piyush/datasets/Ego4D-HCap/"

meta_dir = f"{data_dir}/egocvr"
video_dir = f"{meta_dir}/clips"

df_base = pd.read_csv(f"{meta_dir}/egocvr_data.csv")
df_anno = pd.read_csv(f"{meta_dir}/egocvr_annotations.csv")

len(df_base), len(df_anno)

(12526, 2295)

In [3]:
from tasks.eval_egocvr import *

USER: <video>
Edit instruction: <sent>
Imagine the given text edit instruction applied on the given video.
Summarize the resulting video in one word: ASSISTANT: 


In [4]:
# Load model
n_frames = 8
model_id = "/work/piyush/experiments/CaRe/Tarsier-7b/nli-9k+ego4d-1k/merged_checkpoint"
encoder = AutoEncoder.from_pretrained(model_id, device_map='auto')
su.misc.num_params(encoder.model)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Loading EncoderForTarsier from /work/piyush/experiments/CaRe/Tarsier-7b/nli-9k+ego4d-1k/merged_checkpoint
### do_image_padding is set as False, images will be resized directly!


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
TarsierForConditionalGeneration has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From ðŸ‘‰v4.50ðŸ‘ˆ onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

::: Number of total parameters in TarsierForConditionalGeneration: 7063.427M


In [5]:
# Compute candidate embeddings
candidate_clip_ids = []
for x in df_anno.target_clip_ids.tolist():
    candidate_clip_ids.extend(eval(x))
candidate_clip_ids = np.unique(candidate_clip_ids)
print('Number of candidate videos: ', len(candidate_clip_ids))

candidates = {}
for c in su.log.tqdm_iterator(candidate_clip_ids, desc='Computing candidate embeddings'):
    video_path = f"{video_dir}/{c}.mp4"
    if not os.path.exists(video_path):
        print(f"Target video does not exist: {c}. Skipping.")
        continue
    else:
        try:
            video_tensor = read_frames_decord(video_path, n_frames)
        except:
            print(f"Error reading video: {c}. ")
            video_tensor = torch.randn(n_frames, 3, 270, 480)
        with torch.no_grad():
            zv = encoder.encode_vision(video_tensor.unsqueeze(0)).cpu().squeeze(0).float()
            zv = torch.nn.functional.normalize(zv, dim=-1)
        candidates[c] = zv
len(candidates)

Number of candidate videos:  1787


Computing candidate embeddings:   0%|          | 0/1787 [00:00<?, ?it/s]

Expanding inputs for image tokens in LLaVa should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Target video does not exist: 041b9423-9695-45cb-bc58-266af7f91039_90_99. Skipping.
Target video does not exist: 099f6f96-5aa7-4da8-a5e0-2e8bc03beee6_1308_1313. Skipping.
Target video does not exist: 27de8eef-27b5-431e-9c4b-668001a4d37c_212_220. Skipping.
Target video does not exist: 2dd04abb-859f-4121-9eeb-4db511b54c7b_171_177. Skipping.
Target video does not exist: 3d8f5230-0c22-4018-86b0-e6d851b74b11_1090_1098. Skipping.
Target video does not exist: 473d22ab-824e-4d9a-8bc0-924d499694d0_10_19. Skipping.
Target video does not exist: 48ca7acc-2b9c-46a9-ab11-be03bd1324d9_1063_1069. Skipping.
Target video does not exist: 6b7263dd-0354-468b-a485-0471ace84f15_348_356. Skipping.
Target video does not exist: 7d2060fa-501e-46db-9d7a-aa0c634889c4_96_104. Skipping.
Target video does not exist: 86315f9b-d779-4f7e-925a-bb6e9bd46f5a_67_73. Skipping.
Target video does not exist: 8bcd5b11-283d-41f9-9b34-8f70b75c2de9_24_30. Skipping.
Target video does not exist: 92f8142a-25aa-444a-ae37-43fae4f95f18_88

[h264 @ 0x702aea80] deblocking_filter_idc 3 out of range
[h264 @ 0x702aea80] decode_slice_header error
[h264 @ 0x702aea80] no frame!


Error reading video: e8283c8f-ecc0-448a-84d2-9f16160d7a4b_521_530. 


[h264 @ 0x71561000] Invalid NAL unit size (13234 > 78).
[h264 @ 0x71561000] Error splitting the input into NAL units.
[h264 @ 0x71561000] Invalid NAL unit size (-788362006 > 13347).
[h264 @ 0x71561000] Error splitting the input into NAL units.
[h264 @ 0x71561000] Invalid NAL unit size (-2006049028 > 2666).
[h264 @ 0x71561000] Error splitting the input into NAL units.
[h264 @ 0x71561000] Invalid NAL unit size (182849744 > 2245).
[h264 @ 0x71561000] Error splitting the input into NAL units.
[h264 @ 0x71561000] Invalid NAL unit size (-1969178081 > 13).
[h264 @ 0x71561000] Error splitting the input into NAL units.
[h264 @ 0x71561000] Invalid NAL unit size (-1301277176 > 13277).
[h264 @ 0x71561000] Error splitting the input into NAL units.
[h264 @ 0x71561000] Invalid NAL unit size (1302312752 > 3208).
[h264 @ 0x71561000] Error splitting the input into NAL units.
[h264 @ 0x71561000] Invalid NAL unit size (-924193714 > 1433).
[h264 @ 0x71561000] Error splitting the input into NAL units.
[h264

Error reading video: ed432ff3-54f7-4fb9-9a67-80b4d43a076c_297_306. 


[h264 @ 0x49079d80] Invalid NAL unit size (-700343720 > 1536).
[h264 @ 0x49079d80] Error splitting the input into NAL units.


Error reading video: f12a83a9-77ee-495c-9f11-2d41137a0b1a_1706_1715. 


[h264 @ 0x702aea80] Invalid NAL unit size (2375 > 6).
[h264 @ 0x702aea80] Error splitting the input into NAL units.


Error reading video: f202143d-1055-497a-9eb5-bedd7c4ae361_2419_2428. 


[h264 @ 0x4923b980] Invalid NAL unit size (555421389 > 5).
[h264 @ 0x4923b980] Error splitting the input into NAL units.


Error reading video: f33dd5b7-ac69-4082-b152-a2db14b3a8cf_3853_3862. 


[h264 @ 0x7055ecc0] Invalid NAL unit size (14355 > 14310).
[h264 @ 0x7055ecc0] Error splitting the input into NAL units.
[h264 @ 0x7055ecc0] Invalid NAL unit size (-500573450 > 5459).
[h264 @ 0x7055ecc0] Error splitting the input into NAL units.
[aac @ 0x7019a240] Reserved bit set.
[aac @ 0x7019a240] Prediction is not allowed in AAC-LC.
[h264 @ 0x7055ecc0] Invalid NAL unit size (-409106368 > 1261).
[h264 @ 0x7055ecc0] Error splitting the input into NAL units.
[h264 @ 0x7055ecc0] Invalid NAL unit size (-836832399 > 43).
[h264 @ 0x7055ecc0] Error splitting the input into NAL units.
[h264 @ 0x7055ecc0] Invalid NAL unit size (-335300106 > 698).
[h264 @ 0x7055ecc0] Error splitting the input into NAL units.
[h264 @ 0x7055ecc0] Invalid NAL unit size (1795842696 > 5055).
[h264 @ 0x7055ecc0] Error splitting the input into NAL units.
[h264 @ 0x7055ecc0] Invalid NAL unit size (878910634 > 1459).
[h264 @ 0x7055ecc0] Error splitting the input into NAL units.
[h264 @ 0x7055ecc0] Invalid NAL unit siz

Error reading video: f4251612-a623-40a0-9fc4-2016c4aa607d_641_650. 
Target video does not exist: f79614b9-8c45-40f4-b1f3-a9f44428498e_81_89. Skipping.
Target video does not exist: fbbc72ed-e55f-4c01-998c-4c573136614d_696_703. Skipping.


1768

In [6]:
# Gather query embeddings
queries = {}
for i in su.log.tqdm_iterator(range(len(df_anno)), desc="Compute query embeddings"):
    row = df_anno.iloc[i].to_dict()
    video_path = f"{video_dir}/{row['video_clip_id']}.mp4"
    if not os.path.exists(video_path):
        print(f"Query video does not exist: {i}. Skipping.")
        continue
    edit_text = row['instruction']
    with torch.no_grad():
        zv = embed_video_text(encoder, video_path, edit_text, n_frames=8)
        zv = torch.nn.functional.normalize(zv, dim=-1)
        zv = zv.cpu().float()
    key = f"{edit_text}|{row['video_clip_id']}"
    queries[key] = zv
len(queries)

Compute query embeddings:   0%|          | 0/2295 [00:00<?, ?it/s]

Query video does not exist: 67. Skipping.
Query video does not exist: 136. Skipping.
Query video does not exist: 246. Skipping.
Query video does not exist: 810. Skipping.


[h264 @ 0x7528c5c0] deblocking_filter_idc 3 out of range
[h264 @ 0x7528c5c0] decode_slice_header error
[h264 @ 0x7528c5c0] no frame!


Error reading video: /scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//egocvr/clips/e8283c8f-ecc0-448a-84d2-9f16160d7a4b_521_530.mp4. Returning random noise.
Query video does not exist: 865. Skipping.


[h264 @ 0x7593bc00] Invalid NAL unit size (14355 > 14310).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (-500573450 > 5459).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[aac @ 0x757179c0] Reserved bit set.
[aac @ 0x757179c0] Prediction is not allowed in AAC-LC.
[h264 @ 0x7593bc00] Invalid NAL unit size (-409106368 > 1261).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (-836832399 > 43).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (-335300106 > 698).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (1795842696 > 5055).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (878910634 > 1459).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit siz

Error reading video: /scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//egocvr/clips/f4251612-a623-40a0-9fc4-2016c4aa607d_641_650.mp4. Returning random noise.
Query video does not exist: 940. Skipping.
Query video does not exist: 1009. Skipping.
Query video does not exist: 1269. Skipping.
Query video does not exist: 1273. Skipping.
Query video does not exist: 1363. Skipping.
Query video does not exist: 1441. Skipping.


[h264 @ 0x76627dc0] Invalid NAL unit size (555421389 > 5).
[h264 @ 0x76627dc0] Error splitting the input into NAL units.


Error reading video: /scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//egocvr/clips/f33dd5b7-ac69-4082-b152-a2db14b3a8cf_3853_3862.mp4. Returning random noise.


[h264 @ 0x7593bc00] Invalid NAL unit size (13234 > 78).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (-788362006 > 13347).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (-2006049028 > 2666).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (182849744 > 2245).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (-1969178081 > 13).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (-1301277176 > 13277).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (1302312752 > 3208).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264 @ 0x7593bc00] Invalid NAL unit size (-924193714 > 1433).
[h264 @ 0x7593bc00] Error splitting the input into NAL units.
[h264

Error reading video: /scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//egocvr/clips/ed432ff3-54f7-4fb9-9a67-80b4d43a076c_297_306.mp4. Returning random noise.
Query video does not exist: 1574. Skipping.
Query video does not exist: 1748. Skipping.
Query video does not exist: 1777. Skipping.
Query video does not exist: 1781. Skipping.


[h264 @ 0x76cb8c00] Invalid NAL unit size (-700343720 > 1536).
[h264 @ 0x76cb8c00] Error splitting the input into NAL units.


Error reading video: /scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//egocvr/clips/f12a83a9-77ee-495c-9f11-2d41137a0b1a_1706_1715.mp4. Returning random noise.
Query video does not exist: 1870. Skipping.


[h264 @ 0x75821740] Invalid NAL unit size (2375 > 6).
[h264 @ 0x75821740] Error splitting the input into NAL units.


Error reading video: /scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//egocvr/clips/f202143d-1055-497a-9eb5-bedd7c4ae361_2419_2428.mp4. Returning random noise.
Query video does not exist: 2224. Skipping.


2273

In [8]:
len(candidates), len(queries)

(1768, 2273)

In [15]:
qid_to_cids = {}
df_anno[['video_clip_id', 'instruction', 'target_clip_ids']].apply(
    lambda x: qid_to_cids.update({f"{x[1]}|{x[0]}": eval(x[2])}), axis=1,
)
len(qid_to_cids)

2290

In [20]:
qid_to_cids['Shake it.|d1d1b6da-e7f8-48e7-9ee4-d8382582695a_971_980'], \
qid_to_cids['Pick up the pack.|fd0e10cd-80ea-4cd1-babe-e6570601b00b_112_121']

(['d1d1b6da-e7f8-48e7-9ee4-d8382582695a_897_906'],
 ['fd0e10cd-80ea-4cd1-babe-e6570601b00b_66_75'])

In [46]:
cid_to_qids = defaultdict(list)
df_anno[['video_clip_id', 'instruction', 'target_clip_ids']].apply(
    lambda x: [cid_to_qids[y].append(f"{x[1]}|{x[0]}") for y in eval(x[2])], axis=1
)
len(cid_to_qids)

1787

In [50]:
# for c in cid_to_qids:
#     if len(cid_to_qids[c]) > 1:
#         print(c, cid_to_qids[c])
#         break

In [62]:
from collections import defaultdict

qids = list(queries.keys())
cids = [x.item() for x in list(candidates.keys())]
q2c = defaultdict(list)
c2q = defaultdict(list)
zq = torch.stack([queries[x] for x in qids])
zc = torch.stack([candidates[x] for x in cids])
print(zq.shape, zc.shape)

scores_q2c = zq @ zc.T
scores_c2q = zc @ zq.T

idx_q2c = defaultdict(list)
for i in range(len(zq)):
    qid = qids[i]
    matched_cids = qid_to_cids[qid]
    matched_cids_indices = [np.where(np.array(cids) == x)[0] for x in matched_cids]
    idx_q2c[i].extend(list(matched_cids_indices))
idx_q2c = {i:list(np.concatenate(idx_q2c[i])) for i in idx_q2c}

idx_c2q = defaultdict(list)
for j in range(len(zc)):
    cid = cids[j]
    matched_qids = cid_to_qids[cid]
    matched_qids_indices = [np.where(np.array(qids) == x)[0] for x in matched_qids]
    idx_c2q[j].extend(list(matched_qids_indices))
idx_c2q = {j:list(np.concatenate(idx_c2q[j])) for j in idx_c2q}

len(idx_q2c), len(idx_c2q)

torch.Size([2273, 4096]) torch.Size([1768, 4096])


(2273, 1768)

In [63]:
index = 44
index = np.random.randint(len(zq))
print("Sample query index: ", index)
print("Sample query: ", qids[index])
print('-' * 60)
print("Matching candidate clip IDs (from index dict) \n", np.array(cids)[idx_q2c[index]])
print('-' * 60)
print("Matching candidate clip IDs (from ground truth) \n", np.array(qid_to_cids[qids[index]]))

print()

Sample query index:  912
Sample query:  Scoop paint with it.|ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_83_92
------------------------------------------------------------
Matching candidate clip IDs (from index dict) 
 ['ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_61_70'
 'ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_108_117'
 'ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_23_32'
 'ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_73_82'
 'ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_43_52']
------------------------------------------------------------
Matching candidate clip IDs (from ground truth) 
 ['ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_61_70'
 'ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_108_117'
 'ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_23_32'
 'ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_73_82'
 'ac4d8dd6-80b9-44e4-8d67-bcd2662bcd7b_43_52']



In [64]:
from utils.general_retrieval_metrics import itm_eval
metrics = itm_eval(
    scores_i2t=scores_q2c.numpy(),
    img2txt=idx_q2c,
    scores_t2i=scores_c2q.numpy(),
    txt2img=idx_c2q,
    add_50=True,
)
metrics

{'txt_r1': np.float64(5.807303123625165),
 'txt_r5': np.float64(77.12274527056753),
 'txt_r10': np.float64(87.59348878134624),
 'txt_r_mean': np.float64(56.84117905851298),
 'img_r1': np.float64(5.599547511312217),
 'img_r5': np.float64(75.1131221719457),
 'img_r10': np.float64(85.35067873303167),
 'img_r_mean': np.float64(55.35444947209654),
 'r_mean': np.float64(56.09781426530476),
 'txt_r50': np.float64(96.56841179058513),
 'img_r50': np.float64(95.6447963800905)}

In [25]:
cids[0]

'002c3b5c-ed86-4af3-99a1-4b497b7c8a86_2396_2405'

In [76]:
idx_q2c_local = defaultdict(list)
for i in range(len(zq)):
    qid = qids[i]
    video_id_query = qid.split('|')[-1].split('_')[0]
    matched_cids = qid_to_cids[qid]
    for cid in matched_cids:
        video_id_candi = cid.split("_")[0]
        if video_id_candi == video_id_query:
            idx = np.where(np.array(cids) == cid)[0]
            idx_q2c_local[i].append(idx)
idx_q2c_local = {j:list(np.concatenate(idx_q2c_local[j])) for j in idx_q2c_local}


idx_c2q_local = defaultdict(list)
for j in range(len(zc)):
    cid = cids[j]
    video_id_candi = cid.split("_")[0]
    
    matched_qids = cid_to_qids[cid]
    for qid in matched_qids:
        video_id_query = qid.split('|')[-1].split('_')[0]
        if video_id_query == video_id_candi:
            idx = np.where(np.array(qids) == qid)[0]
            idx_c2q_local[j].append(idx)
idx_c2q_local = {j:list(np.concatenate(idx_c2q_local[j])) for j in idx_c2q_local}

len(idx_q2c_local), len(idx_c2q_local)

(2273, 1768)

In [77]:
metrics = itm_eval(
    scores_i2t=scores_q2c.numpy(),
    img2txt=idx_q2c_local,
    scores_t2i=scores_c2q.numpy(),
    txt2img=idx_c2q_local,
    add_50=True,
)
metrics

{'txt_r1': np.float64(5.807303123625165),
 'txt_r5': np.float64(77.12274527056753),
 'txt_r10': np.float64(87.59348878134624),
 'txt_r_mean': np.float64(56.84117905851298),
 'img_r1': np.float64(5.599547511312217),
 'img_r5': np.float64(75.1131221719457),
 'img_r10': np.float64(85.35067873303167),
 'img_r_mean': np.float64(55.35444947209654),
 'r_mean': np.float64(56.09781426530476),
 'txt_r50': np.float64(96.56841179058513),
 'img_r50': np.float64(95.6447963800905)}

### Test code

In [5]:
i = 262
row = df_anno.iloc[i].to_dict()

# Source video file
src_file = f"{video_dir}/{row['video_clip_id']}.mp4"

# Just pick a random target clips from potential candidates
target_clip_ids = eval(row['target_clip_ids'])
target_clip_id = np.random.choice(target_clip_ids)
dst_file = f"{video_dir}/{target_clip_id}.mp4"

assert os.path.exists(src_file)
assert os.path.exists(dst_file)

display(
    su.visualize.show_single_image_sequence(src_file, label=row['video_clip_narration'])
)
display(Markdown(f"**Edit instruction**: {row['instruction']}"))
display(
    su.visualize.show_single_image_sequence(dst_file, label=row['target_clip_narration'])
)

VBox(children=(HTML(value='#C C drops the plate of food on the sink slap.'), Output()))

**Edit instruction**: Pick it up.

VBox(children=(HTML(value='#C C picks a plate from the sink slap.'), Output()))

### Load model

In [6]:
from models.modeling_encoders import AutoEncoder

model_id = "/work/piyush/experiments/CaRe/Tarsier-7b/nli-9k+ego4d-1k/merged_checkpoint"
encoder = AutoEncoder.from_pretrained(model_id, device_map='auto')
su.misc.num_params(encoder.model)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Loading EncoderForTarsier from /work/piyush/experiments/CaRe/Tarsier-7b/nli-9k+ego4d-1k/merged_checkpoint
### do_image_padding is set as False, images will be resized directly!


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

::: Number of total parameters in TarsierForConditionalGeneration: 7063.427M


**Compute embeddings for all candidate clips**

In [8]:
candidate_clip_ids = []
for x in df_anno.target_clip_ids.tolist():
    candidate_clip_ids.extend(eval(x))
candidate_clip_ids = np.unique(candidate_clip_ids)
len(candidate_clip_ids)

1787

In [9]:
n_frames = 8

candidates = {}
for c in su.log.tqdm_iterator(candidate_clip_ids, desc='Computing candidate embeddings'):
    video_path = f"{video_dir}/{c}.mp4"
    if not os.path.exists(video_path):
        print(f"Target video does not exist: {i}. Skipping.")
        continue
    else:
        video_tensor = read_frames_decord(video_path, n_frames)
        with torch.no_grad():
            zv = encoder.encode_vision(video_tensor.unsqueeze(0)).cpu().squeeze(0).float()
            zv = torch.nn.functional.normalize(zv, dim=-1)
        candidates[c] = zv
len(candidates)

Computing candidate embeddings:   0%|          | 0/1787 [00:00<?, ?it/s]

Expanding inputs for image tokens in LLaVa should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


**Compute embeddings for all composite queries**

In [13]:
from utils.video import read_frames_decord
from utils.model import transform_pixel_values
from torchvision.transforms.v2 import (
    ToPILImage,
)

PROMPT = "<video>\nEdit instruction: <sent>\n"\
         "Imagine the given text edit instruction applied on the given video.\n"\
         "Summarize the resulting video in one word:"
PROMPT = f"USER: {PROMPT} ASSISTANT: "
print(PROMPT)


def embed_video_text(encoder, video_path, edit_text, n_frames=8, verbose=False):
    generate_kwargs = {
        "max_new_tokens": 1,
        "output_hidden_states": True,
        "return_dict_in_generate": True,
    }

    # Prepare video
    pixel_values = read_frames_decord(video_path, n_frames).unsqueeze(0)
    pixel_values = transform_pixel_values(pixel_values)
    nframes = pixel_values.shape[1]
    to_image = ToPILImage()
    batched_frames = []
    for batch in pixel_values:
        frames = [to_image(v) for v in batch]
        batched_frames.append(frames)

    for frames in batched_frames:

        # Video
        input_prompt = PROMPT.replace("<video>", "<image>"*len(frames))

        # Text
        input_prompt = input_prompt.replace('<sent>', edit_text)

        if verbose:
            print(input_prompt)

        input_ids = encoder.processor.get_text_inputs(input_prompt)
        frames = encoder.processor.get_pixel_values(frames)
        inputs = {
            "input_ids": input_ids,
            "pixel_values": frames
        }
        inputs = {k:v.to(encoder.model.device) for k,v in inputs.items() if v is not None}
        outputs = encoder.model.generate(
            **inputs,
            **generate_kwargs,
        )
        zv = outputs.hidden_states[0][-1][:, -1, :]
        break # Safe to break since it is just one video

    if verbose:
        print(zv.shape)

    return zv.squeeze(0)

USER: <video>
Edit instruction: <sent>
Imagine the given text edit instruction applied on the given video.
Summarize the resulting video in one word: ASSISTANT: 


In [15]:
# Test on a sample
i = 262
row = df_anno.iloc[i].to_dict()
video_path = f"{video_dir}/{row['video_clip_id']}.mp4"
edit_text = row['instruction']
with torch.no_grad():
    zv = embed_video_text(encoder, video_path, edit_text, n_frames=8, verbose=True)
    zv = torch.nn.functional.normalize(zv, dim=-1)
    zv = zv.cpu().float()
zv.shape

USER: <image><image><image><image><image><image><image><image>
Edit instruction: Pick it up.
Imagine the given text edit instruction applied on the given video.
Summarize the resulting video in one word: ASSISTANT: 
torch.Size([1, 4096])


torch.Size([4096])

In [17]:
# Gather query embeddings
queries = {}
for i in su.log.tqdm_iterator(range(len(df_anno)), desc="Compute query embeddings"):
    row = df_anno.iloc[i].to_dict()
    video_path = f"{video_dir}/{row['video_clip_id']}.mp4"
    if not os.path.exists(video_path):
        print(f"Query video does not exist: {i}. Skipping.")
        continue
    edit_text = row['instruction']
    with torch.no_grad():
        zv = embed_video_text(encoder, video_path, edit_text, n_frames=8)
        zv = torch.nn.functional.normalize(zv, dim=-1)
        zv = zv.cpu().float()
    key = f"{edit_text}|{row['video_clip_id']}"
    queries[key] = zv
len(queries)

Compute query embeddings:   0%|          | 0/2295 [00:00<?, ?it/s]

Query video does not exist: 0. Skipping.
Query video does not exist: 1. Skipping.
Query video does not exist: 2. Skipping.
Query video does not exist: 3. Skipping.
Query video does not exist: 7. Skipping.
Query video does not exist: 8. Skipping.
Query video does not exist: 9. Skipping.
Query video does not exist: 10. Skipping.
Query video does not exist: 11. Skipping.
Query video does not exist: 12. Skipping.
Query video does not exist: 13. Skipping.
Query video does not exist: 14. Skipping.
Query video does not exist: 15. Skipping.
Query video does not exist: 16. Skipping.
Query video does not exist: 17. Skipping.
Query video does not exist: 18. Skipping.
Query video does not exist: 19. Skipping.
Query video does not exist: 20. Skipping.
Query video does not exist: 21. Skipping.
Query video does not exist: 22. Skipping.
Query video does not exist: 23. Skipping.
Query video does not exist: 24. Skipping.
Query video does not exist: 25. Skipping.
Query video does not exist: 26. Skipping.

KeyboardInterrupt: 

In [11]:
row

{'video_clip_id': '0c0ec306-2554-42fe-a744-cd1fec78f689_21_30',
 'target_clip_ids': "['0c0ec306-2554-42fe-a744-cd1fec78f689_0_5']",
 'video_clip_narration': '#C C drops the plate of food on the sink slap.',
 'target_clip_narration': '#C C picks a plate from the sink slap.',
 'instruction': 'Pick it up.',
 'modified_captions': ' #C C picks up the bowl of food.'}

### Processing Ego4D clips (one-time)

In [9]:
clips = glob(f"{data_dir}/cut_full_scale/*")
len(clips)

896666

In [10]:
clips[0]

'/scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//cut_full_scale/4cb9c654-bab3-4861-b6bb-64c890f1af0d_5510.7_5513.7.mp4'

In [15]:
clip_ids = list(set([os.path.basename(x).split(".mp4")[0] for x in clips]))
len(clip_ids)

896663

In [16]:
clip_ids[:5]

['609d9772-daa2-45e9-b07a-1ffdadb942b4_467.2_468.2',
 '225a1ffe-3e7e-4ff8-a47e-2e80989077fe_413.0_413.9',
 '09bfaba7-450e-4141-9fe9-7db3a6b2c604_1235.3_1236.4',
 'c5851c24-4ae3-4430-b867-f43800f0f56e_3905.7_3906.9',
 '960b820b-1807-4219-8e74-94c4d00abc41_1861.3_1862.9']

In [17]:
from typing import List, Optional, Tuple
from collections import defaultdict

class ClipMatcher:
    """Efficiently find matching clip IDs for given file IDs."""
    
    def __init__(self, clip_ids: List[str]):
        """
        Initialize the matcher with a list of clip IDs.
        Builds an index for efficient lookup.
        
        Args:
            clip_ids: List of clip ID strings in format 'video_id_start_end'
        """
        self.clip_ids = clip_ids
        self.index = self._build_index(clip_ids)
    
    def _parse_id(self, id_string: str) -> Tuple[str, float, float]:
        """
        Parse an ID string into video_id, start_time, and end_time.
        
        Args:
            id_string: ID string in format 'video_id_start_end'
            
        Returns:
            Tuple of (video_id, start_time, end_time)
        """
        parts = id_string.rsplit('_', 2)
        video_id = parts[0]
        start_time = float(parts[1])
        end_time = float(parts[2])
        return video_id, start_time, end_time
    
    def _build_index(self, clip_ids: List[str]) -> dict:
        """
        Build an index mapping video_id to list of (clip_id, start, end).
        
        Args:
            clip_ids: List of clip ID strings
            
        Returns:
            Dictionary mapping video_id to list of tuples (clip_id, start, end)
        """
        index = defaultdict(list)
        for clip_id in clip_ids:
            video_id, start, end = self._parse_id(clip_id)
            index[video_id].append((clip_id, start, end))
        
        # Sort clips by start time for each video
        for video_id in index:
            index[video_id].sort(key=lambda x: x[1])
        
        return index
    
    def _calculate_overlap(self, start1: float, end1: float, 
                          start2: float, end2: float) -> float:
        """
        Calculate the overlap duration between two time ranges.
        
        Args:
            start1, end1: First time range
            start2, end2: Second time range
            
        Returns:
            Overlap duration in seconds
        """
        overlap_start = max(start1, start2)
        overlap_end = min(end1, end2)
        return max(0, overlap_end - overlap_start)
    
    def find_best_match(self, file_id: str) -> Optional[str]:
        """
        Find the best matching clip ID for a given file ID.
        
        The best match is the clip with maximum temporal overlap.
        
        Args:
            file_id: File ID string in format 'video_id_start_end'
            
        Returns:
            Best matching clip ID, or None if no matching video found
        """
        video_id, file_start, file_end = self._parse_id(file_id)
        
        # Check if video exists in index
        if video_id not in self.index:
            return None
        
        clips = self.index[video_id]
        
        best_clip = None
        max_overlap = 0
        
        for clip_id, clip_start, clip_end in clips:
            overlap = self._calculate_overlap(file_start, file_end, 
                                              clip_start, clip_end)
            
            if overlap > max_overlap:
                max_overlap = overlap
                best_clip = clip_id
        
        return best_clip if max_overlap > 0 else None
    
    def find_all_overlapping(self, file_id: str, 
                            min_overlap: float = 0) -> List[Tuple[str, float]]:
        """
        Find all clips that overlap with the given file ID.
        
        Args:
            file_id: File ID string in format 'video_id_start_end'
            min_overlap: Minimum overlap duration to include (default: 0)
            
        Returns:
            List of tuples (clip_id, overlap_duration) sorted by overlap desc
        """
        video_id, file_start, file_end = self._parse_id(file_id)
        
        if video_id not in self.index:
            return []
        
        clips = self.index[video_id]
        overlapping = []
        
        for clip_id, clip_start, clip_end in clips:
            overlap = self._calculate_overlap(file_start, file_end, 
                                              clip_start, clip_end)
            
            if overlap > min_overlap:
                overlapping.append((clip_id, overlap))
        
        # Sort by overlap duration (descending)
        overlapping.sort(key=lambda x: x[1], reverse=True)
        
        return overlapping


# Example usage
if __name__ == "__main__":
    # Sample clip IDs
    clip_ids = [
        '609d9772-daa2-45e9-b07a-1ffdadb942b4_467.2_468.2',
        '225a1ffe-3e7e-4ff8-a47e-2e80989077fe_413.0_413.9',
        '09bfaba7-450e-4141-9fe9-7db3a6b2c604_1235.3_1236.4',
        'c5851c24-4ae3-4430-b867-f43800f0f56e_3905.7_3906.9',
        '960b820b-1807-4219-8e74-94c4d00abc41_1861.3_1862.9',
        '609d9772-daa2-45e9-b07a-1ffdadb942b4_467.5_469.0',  # Another clip from same video
    ]
    
    # Initialize matcher
    matcher = ClipMatcher(clip_ids)
    
    # Test file ID
    file_id = '609d9772-daa2-45e9-b07a-1ffdadb942b4_467_469'
    
    # Find best match
    best_match = matcher.find_best_match(file_id)
    print(f"File ID: {file_id}")
    print(f"Best match: {best_match}")
    
    # Find all overlapping clips
    all_matches = matcher.find_all_overlapping(file_id)
    print(f"\nAll overlapping clips:")
    for clip_id, overlap in all_matches:
        print(f"  {clip_id}: {overlap:.2f}s overlap")

File ID: 609d9772-daa2-45e9-b07a-1ffdadb942b4_467_469
Best match: 609d9772-daa2-45e9-b07a-1ffdadb942b4_467.5_469.0

All overlapping clips:
  609d9772-daa2-45e9-b07a-1ffdadb942b4_467.5_469.0: 1.50s overlap
  609d9772-daa2-45e9-b07a-1ffdadb942b4_467.2_468.2: 1.00s overlap


In [18]:
clip_ids = list(set([os.path.basename(x).split(".mp4")[0] for x in clips]))
len(clip_ids)

896663

In [19]:
# Initialize matcher
matcher = ClipMatcher(clip_ids)

In [21]:
# Test on a sample
i = 0
row = df_anno.iloc[i].to_dict()
matcher.find_best_match(row['video_clip_id'])

'd1d1b6da-e7f8-48e7-9ee4-d8382582695a_978.8_980.0'

In [25]:
eval(row['target_clip_ids'])

['d1d1b6da-e7f8-48e7-9ee4-d8382582695a_897_906']

In [26]:
df_anno_clean = []
for i in su.log.tqdm_iterator(range(len(df_anno)), desc='Matching'):
    row = df_anno.iloc[i].to_dict()
    video_clip_id = matcher.find_best_match(row['video_clip_id'])
    target_clip_ids = [matcher.find_best_match(x) for x in eval(row['target_clip_ids'])]
    row['video_clip_id'] = video_clip_id
    row['target_clip_ids'] = target_clip_ids
    df_anno_clean.append(row)
df_anno_clean = pd.DataFrame(df_anno_clean)
df_anno_clean.shape

Matching:   0%|          | 0/2295 [00:00<?, ?it/s]

(2295, 6)

In [49]:
all_anno_clip_ids = []
for i in range(len(df_anno)):
    row = df_anno.iloc[i].to_dict()
    all_anno_clip_ids.append(row['video_clip_id'])
    all_anno_clip_ids.extend(eval(row['target_clip_ids']))
all_anno_clip_ids = np.unique(all_anno_clip_ids)
len(all_anno_clip_ids)

1975

In [71]:
all_anno_clip_ids[0]

np.str_('002c3b5c-ed86-4af3-99a1-4b497b7c8a86_2396_2405')

In [51]:
su.io.save_txt(list(all_anno_clip_ids), f'{meta_dir}/all_anno_clip_ids.txt')
!du -sh $meta_dir/*

87K	/scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//egocvr/all_anno_clip_ids.txt
552K	/scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//egocvr/egocvr_annotations.csv
1.6M	/scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//egocvr/egocvr_data.csv


In [52]:
all_anno_video_ids = np.unique([x.split("_")[0] for x in all_anno_clip_ids])
len(all_anno_video_ids)

624

In [65]:
video_files = os.listdir('/scratch/shared/beegfs/shared-datasets/EGO4D/ego4d_data_v1/full_scale/')
video_files = [x.split(".mp4")[0] for x in video_files]
len(video_files)

9647

In [66]:
invalid_vids = []
for c in all_anno_video_ids:
    if c.item() in video_files:
        continue
    else:
        invalid_vids.append(c)
len(invalid_vids)

0

In [70]:
import decord
vp = f"/scratch/shared/beegfs/shared-datasets/EGO4D/ego4d_data_v1/full_scale/{video_files[0]}.mp4"
assert os.path.exists(vp)
vr = decord.VideoReader(vp)
len(vr), vr[0].shape

(216515, torch.Size([1080, 1440, 3]))

In [75]:
!ls /scratch/shared/beegfs/piyush/datasets/Ego4D-HCap/egocvr/all_anno_clip_ids.txt

/scratch/shared/beegfs/piyush/datasets/Ego4D-HCap/egocvr/all_anno_clip_ids.txt


In [76]:
import PIL, PIL.Image
# PIL.Image.fromarray(np.asarray(vr[0]))

In [72]:
1080/360

3.0

In [64]:
invalid_vids[0]

np.str_('002c3b5c-ed86-4af3-99a1-4b497b7c8a86')

In [36]:
# Visualise an example
i = 0
row = df_anno_clean.iloc[i].to_dict()

# Source video file
src_file = f"{data_dir}/cut_full_scale/{row['video_clip_id']}.mp4"
assert os.path.exists(src_file)

# Just pick a random target clips from potential candidates
target_clip_ids = row['target_clip_ids']
target_clip_id = np.random.choice(target_clip_ids)
dst_file = f"{data_dir}/cut_full_scale/{target_clip_id}.mp4"
assert os.path.exists(dst_file)

AssertionError: 

In [37]:
dst_file

'/scratch/shared/beegfs/piyush/datasets/Ego4D-HCap//cut_full_scale/None.mp4'

In [11]:
video_ids = [os.path.basename(x).split('_')[0] for x in clips]

In [12]:
'd1d1b6da-e7f8-48e7-9ee4-d8382582695a' in set(video_ids)

True