In [15]:
import os
import sys
import shutil
# Set before the heavier imports below so it is already in the environment when
# any HF tokenizer is created (presumably to disable tokenizers' parallelism and
# its fork warning — confirm intent).
os.environ['TOKENIZERS_PARALLELISM'] = "False"

import torch
import pandas as pd
import numpy as np
import decord
import json
from torch.nn.functional import cosine_similarity
from utils.video import read_frames_decord
from IPython.display import display, Markdown, Latex
import matplotlib.pyplot as plt
# Global plot style for every figure in this notebook.
plt.rcParams["font.family"] = "serif"

import shared.utils as su
from notebooks.eval_care_retrieval import load_model
# NOTE(review): `read_frames_decord` is already imported above; this repeat is harmless and can be dropped.
from utils.video import read_frames_decord
from utils.model import transform_pixel_values
from torchvision.transforms.v2 import (
    ToPILImage,
)

In [16]:
data_dir = "/scratch/shared/beegfs/piyush/datasets/WebVid-CoVR"
video_dir = '/datasets/WebVid/videos'

# Query (video1) paths excluded from evaluation. Built from `video_dir` so the
# filter keeps matching if the video root is moved.
# NOTE(review): the reason this video is problematic is not recorded — document it.
PROBLEMATIC_VIDEO1_PATHS = {f"{video_dir}/108401_108450/6507308.mp4"}

df = pd.read_csv(f"{data_dir}/webvid8m-covr_test-cleaned.csv")
print("Number of rows in CoVR-test: ", len(df))

# Resolve relative video names to absolute paths under `video_dir`.
df['video1_path'] = df['video1'].apply(lambda x: f"{video_dir}/{x}")
df['video2_path'] = df['video2'].apply(lambda x: f"{video_dir}/{x}")

# Keep only rows whose query (video1) and target (video2) files both exist on disk.
df = df[df.video1_path.apply(os.path.exists) & df.video2_path.apply(os.path.exists)]
print("Number of rows with all videos available: ", len(df))

# Paths that occur both as a query video and as a target video.
d = set(df['video1_path']).intersection(set(df['video2_path']))

# Remove problematic videos
df = df[~df.video1_path.isin(PROBLEMATIC_VIDEO1_PATHS)]
print("Number of rows after removing problematic videos: ", len(df))

df.iloc[0].to_dict()

Number of rows in CoVR-test:  2556
Number of rows with all videos available:  (2556, 11)


{'txt1': 'Digital network, white lines and dots',
 'txt2': 'Digital network, green lines and dots',
 'sim_txt': 0.8944025,
 'pth1': '112/1016223889',
 'pth2': '112/1016223877',
 'edit': 'replace the white lines and dots with green',
 'scores': '[0.3524, 0.4087, 0.4136, 0.4111, 0.4137, 0.4084, 0.4062, 0.4054, 0.4025, 0.413, 0.4115, 0.4105, 0.4123, 0.4099, 0.4107]',
 'video1': '070651_070700/1016223889.mp4',
 'video2': '032701_032750/1016223877.mp4',
 'video1_path': '/datasets/WebVid/videos/070651_070700/1016223889.mp4',
 'video2_path': '/datasets/WebVid/videos/032701_032750/1016223877.mp4'}

In [3]:
# Single-frame query embeddings from `query_embeddings_single_frame.pkl`.
# Loaded under its own name so that loading the `_with_gen_caption` variant
# into `query_embeddings_single_frame` does not overwrite this result.
query_embeddings_single_frame_base = su.io.load_pkl(
    f"{data_dir}/query_embeddings_single_frame.pkl"
)
len(query_embeddings_single_frame_base)

2555

In [17]:
# Single-frame query embeddings from the `_with_gen_caption` pickle; this is
# the query set passed to `gather_metrics` further down.
query_embeddings_single_frame = su.io.load_pkl(
    "{}/{}".format(data_dir, "query_embeddings_single_frame_with_gen_caption.pkl")
)
len(query_embeddings_single_frame)

2555

In [18]:
# Gallery (candidate / target video) embeddings; per the file name these were
# presumably computed from 15 frames per video.
candidate_embeddings = su.io.load_pkl(
    "{}/{}".format(data_dir, "gallery_embeddings-nframes_15.pkl")
)
len(candidate_embeddings)

2555

In [19]:
def gather_metrics(query_embeds, candidates, rows=None):
    """Compute query->candidate retrieval recall metrics with ``itm_eval``.

    For each row of ``rows`` the query key is ``"{edit}|{video1}"`` and the
    ground-truth candidate key is ``video2``. Rows whose query or candidate
    embedding is missing are reported and skipped. The remaining embeddings
    are stacked in the same order, so query ``i`` matches candidate ``i``
    (the diagonal of the score matrix).

    Args:
        query_embeds: mapping ``"{edit}|{video1}"`` -> embedding tensor.
        candidates: mapping ``video2`` -> embedding tensor.
            Both are assumed to hold 1-D tensors of a common dimension D, so the
            stacked matrices are (N, D) — TODO confirm for new embedding files.
        rows: DataFrame with ``edit``, ``video1`` and ``video2`` columns.
            Defaults to the notebook-global ``df`` (original behaviour).

    Returns:
        The ``itm_eval`` metrics whose name contains ``'txt'`` — its i2t
        direction, which here is query->candidate.

    Raises:
        ValueError: if no row has both a query and a candidate embedding.
    """
    from utils.general_retrieval_metrics import itm_eval

    if rows is None:
        rows = df  # notebook-global evaluation table

    zq, zc = [], []
    for i, (edit, video1, video2) in enumerate(
        zip(rows['edit'], rows['video1'], rows['video2'])
    ):
        query_key = f"{edit}|{video1}"
        candi_key = video2
        if query_key not in query_embeds or candi_key not in candidates:
            print(f"Missing value for {i}. Skipped.")
            continue
        zq.append(query_embeds[query_key])
        zc.append(candidates[candi_key])

    if not zq:
        # torch.stack([]) would fail with an opaque error; be explicit instead.
        raise ValueError("No row has both a query and a candidate embedding.")

    # detach/cpu so `.numpy()` also works for grad-tracking or GPU tensors.
    zq = torch.stack(zq).detach().cpu().numpy()
    zc = torch.stack(zc).detach().cpu().numpy()
    print(zq.shape, zc.shape)

    # itm_eval speaks image/text: map image -> query and text -> candidate, so the
    # query->candidate metrics are its i2t ("txt_*") outputs.
    score_q2c = zq @ zc.T
    # The reverse direction is just the transpose; copy so the two score arrays
    # do not alias, and skip a second N x N x D matmul.
    score_c2q = score_q2c.T.copy()
    indices = {i: i for i in range(len(score_q2c))}
    metrics = itm_eval(
        scores_i2t=score_q2c,
        scores_t2i=score_c2q,
        txt2img=indices,
        img2txt=indices,
        add_50=True,
    )
    return {k: v for k, v in metrics.items() if 'txt' in k}


# Score the single-frame queries against the gallery; only the txt_*
# (query -> candidate) recalls are kept by gather_metrics.
metrics_single_frame = gather_metrics(
    query_embeds=query_embeddings_single_frame,
    candidates=candidate_embeddings,
)
print(json.dumps(metrics_single_frame, indent=2))

(2555, 4096) (2555, 4096)
{
  "txt_r1": 49.19765166340509,
  "txt_r5": 74.12915851272015,
  "txt_r10": 82.15264187866927,
  "txt_r_mean": 68.4931506849315,
  "txt_r50": 95.22504892367905
}


#### BLIP-2 candidate (gallery) embeddings

In [10]:
def load_candidate_embeddings(feat_dir, video_ids=None, map_location=None):
    """Load one saved embedding per candidate (target) video from ``feat_dir``.

    The embedding for video id ``"<dir>/<name>.mp4"`` is read with
    ``torch.load`` from ``"{feat_dir}/<dir>/<name>.pth"``.

    Args:
        feat_dir: directory holding the ``.pth`` feature files.
        video_ids: candidate video ids to load. Defaults to the notebook-global
            ``df.video2`` column (original behaviour). Repeated ids are loaded once.
        map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` forces tensors
            saved on a GPU onto the CPU. ``None`` keeps torch's default.

    Returns:
        dict mapping video id -> whatever ``torch.load`` returned for it,
        in first-seen order of ``video_ids``.
    """
    if video_ids is None:
        video_ids = df.video2.tolist()
    # dict.fromkeys de-duplicates while preserving order, so a target video
    # shared by several rows is read from disk only once.
    unique_ids = list(dict.fromkeys(video_ids))

    suffix = ".mp4"
    embeds = {}
    for video_id in su.log.tqdm_iterator(unique_ids):
        # Strip only a trailing ".mp4"; splitting on ".mp4" would also cut at an
        # occurrence earlier in the path.
        stem = video_id[:-len(suffix)] if video_id.endswith(suffix) else video_id
        embeds[video_id] = torch.load(f"{feat_dir}/{stem}.pth", map_location=map_location)
    return embeds

# BLIP-2 gallery features: one .pth file per target video under this directory.
candidate_embeddings_blip2 = load_candidate_embeddings(
    feat_dir=f"{data_dir}/blip2-vid-embs-large-all"
)
len(candidate_embeddings_blip2)

  0%|          | 0/2555 [00:00<?, ?it/s]

2555

In [12]:
# Inspect the stored BLIP-2 feature shape for one sample target video.
candidate_embeddings_blip2['032701_032750/1016223877.mp4'].size()

torch.Size([15, 32, 256])