In [1]:
import os

# Work from the project checkout so that relative paths used later in the
# notebook (such as "config/default.yaml") resolve against it.
project_root = "/root/cvpr24_video_retrieval"
os.chdir(project_root)

In [2]:
import sys
import yaml
import ndjson
import numpy as np
import torch

from momaapi import MOMA
from torchvision import transforms
from utils.attr_dict import AttrDict
from utils.main_utils import seed_everything
from models.frozen import FrozenInTime
from utils.clip_sampler import frozen_clip_sampler
from utils.frozen_utils import state_dict_data_parallel_fix
from tqdm import tqdm

In [3]:
# Parse the YAML configuration and wrap it so that entries can be read as
# attributes (cfg.SEED, cfg.PATH, ...).
with open("config/default.yaml", "r") as config_file:
    cfg = AttrDict(yaml.load(config_file, yaml.FullLoader))

# Seeding is optional: it only happens when the config declares a SEED entry.
if "SEED" in cfg:
    seed_everything(cfg.SEED)

# Select PyTorch's "high" float32 matmul precision setting (see the PyTorch
# documentation for the exact speed/precision trade-off).
torch.set_float32_matmul_precision("high")

In [4]:
# Build the model and initialise it from the pretrained checkpoint.
model = FrozenInTime(cfg)

path = os.path.join(
    cfg.PATH.CKPT_PATH, "FROZEN", "cc-webvid2m-4f_stformer_b_16_224.pth.tar"
)
assert os.path.isfile(path), f"checkpoint not found: {path}"
# Make modules under utils/ importable by their bare names before the
# checkpoint is unpickled.
# NOTE(review): presumably the checkpoint pickles objects from there -- confirm.
sys.path.append("utils")
# Deserialize onto the CPU regardless of the device the checkpoint was saved
# from; the model is moved to the GPU explicitly further down.
checkpoint = torch.load(path, map_location="cpu")
state_dict = checkpoint["state_dict"]
new_state_dict = state_dict_data_parallel_fix(
    state_dict, model.state_dict()
)
# strict=False tolerates key mismatches, so surface them instead of silently
# running with partially initialised weights.
load_result = model.load_state_dict(new_state_dict, strict=False)
missing_keys = list(getattr(load_result, "missing_keys", []))
unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
if missing_keys or unexpected_keys:
    print(f"checkpoint keys missing from state dict: {missing_keys}")
    print(f"checkpoint keys not used by the model: {unexpected_keys}")

model.eval()
model = model.to("cuda")

In [5]:
def transform(snippet):
    """Convert one clip of raw frames into a resized tensor.

    Args:
        snippet: NumPy array holding the frames of a clip; assumed to be laid
            out as n_frames x h x w x 3 (channels last), which the permute
            below turns into channels first.

    Returns:
        A torch tensor of the frames, divided by 255 and resized to
        224 x 224, laid out as n_frames x 3 x 224 x 224.
    """
    frames = torch.from_numpy(snippet).permute(0, 3, 1, 2)  # n_frames x 3 x h x w
    scaled = frames / 255.
    resize = transforms.Resize(size=(224, 224))
    return resize(scaled)

In [6]:
# Extract per-clip embeddings for every MOMA activity video and store them
# as one .npy file per activity id.
frozen_args = cfg.MODEL.FROZEN
raw_path = "/data/dir_moma/videos/raw"
feat_path = "/data/dir_moma/feats/frozen"
clip_duration = frozen_args.clip_duration
num_frames = frozen_args.num_frames

# Create the output directory up front so that np.save cannot fail on a
# missing folder after the model has already been run.
os.makedirs(feat_path, exist_ok=True)

moma = MOMA("/data/dir_moma")
for split in ["train", "val", "test"]:
    ids_act = moma.get_ids_act(split=split)
    for act in tqdm(moma.get_anns_act(ids_act=ids_act), desc=f"SPLIT=({split})"):
        # NOTE(review): this activity is skipped on purpose, but the reason is
        # not recorded here -- presumably its video cannot be processed; confirm.
        if act.id == "1YzGUyM3P2k":
            continue

        sampled_clips = frozen_clip_sampler(
            os.path.join(raw_path, f"{act.id}.mp4"),
            clip_duration=clip_duration,
            num_frames=num_frames,
        )

        # Transform and embed one clip at a time instead of first buffering
        # every transformed clip of the video in memory.
        embeddings = []
        for clip in sampled_clips:
            # clip after transform: n_frames(=4) x 3 x 224 x 224
            frames = transform(clip)
            with torch.no_grad():
                emb = model(frames.cuda())  # 1 x n_patches x 768
            embeddings.append(emb.cpu().numpy())

        # np.concatenate raises on an empty list; report the video and keep
        # going rather than aborting the whole run.
        if not embeddings:
            tqdm.write(f"no clips sampled for {act.id}; no features saved")
            continue

        embeddings = np.concatenate(embeddings, axis=0)
        np.save(os.path.join(feat_path, f"{act.id}.npy"), embeddings)

SPLIT=(train): 100%|██████████| 904/904 [22:49<00:00,  1.51s/it]
SPLIT=(val): 100%|██████████| 226/226 [05:47<00:00,  1.54s/it]
SPLIT=(test): 100%|██████████| 282/282 [06:51<00:00,  1.46s/it]


In [None]:
# SANITY CHECK: list every raw video that has no saved feature file.
# The activity id that the extraction loop skips explicitly ("1YzGUyM3P2k")
# is expected to show up here if its video is present in raw_path.
missing = []
raw_path = "/data/dir_moma/videos/raw"
feat_path = "/data/dir_moma/feats/frozen"
# Sorted so that the report is the same on every run (os.listdir order is
# arbitrary).
for filename in sorted(os.listdir(raw_path)):
    vid, ext = os.path.splitext(filename)
    # Feature extraction only reads "<id>.mp4" files, so ignore anything else
    # instead of slicing a fixed four characters off an unrelated name.
    if ext != ".mp4":
        continue
    if not os.path.exists(os.path.join(feat_path, f"{vid}.npy")):
        missing.append(vid)

print(f"missing: {missing} ({len(missing)})")