In [None]:
import os

# Run from the repository root so relative paths used later in this notebook
# (e.g. "config/default.yaml") resolve. Set the PROJECT_ROOT environment
# variable to override the default location without editing the notebook.
os.chdir(os.environ.get("PROJECT_ROOT", "/root/cvpr24_video_retrieval/"))

In [None]:
import yaml
import numpy as np
import torch

from momaapi import MOMA
from utils.attr_dict import AttrDict
from utils.main_utils import seed_everything
from utils.clip_sampler import s3d_clip_sampler
from models.s3d import S3D
from tqdm import tqdm

In [None]:
# Read the experiment configuration and expose it with attribute-style access
# (cfg.SEED, cfg.PATH.CKPT_PATH, ...).
with open("config/default.yaml", "r") as cfg_file:
    raw_cfg = yaml.load(cfg_file, yaml.FullLoader)
cfg = AttrDict(raw_cfg)

# Seeding is optional: only applied when the config defines SEED.
if "SEED" in cfg:
    seed_everything(cfg.SEED)

# Allow TF32-style reduced-precision float32 matmuls for speed.
torch.set_float32_matmul_precision("high")

In [None]:
# Build the S3D backbone and initialise it from the Kinetics-400 checkpoint.
model = S3D(cfg)

path = os.path.join(cfg.PATH.CKPT_PATH, "S3D", "S3D_kinetics400.pt")
assert os.path.isfile(path), f"S3D checkpoint not found: {path}"

# Load tensors onto the CPU so the device the checkpoint was saved from does
# not matter; the whole model is moved to the GPU at the end of this cell.
weight_dict = torch.load(path, map_location="cpu")
# state_dict() returns references to the model's parameters/buffers, so
# copy_() into these tensors updates the model itself.
model_dict = model.state_dict()

DATAPARALLEL_PREFIX = "module."
loaded_keys = set()
unmatched_keys = []
for ckpt_name, param in weight_dict.items():
    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    # Strip only that leading prefix; a substring test would also mangle any
    # parameter whose name merely contains "module" somewhere inside it.
    if ckpt_name.startswith(DATAPARALLEL_PREFIX):
        name = ckpt_name[len(DATAPARALLEL_PREFIX):]
    else:
        name = ckpt_name

    if name not in model_dict:
        unmatched_keys.append(ckpt_name)
        continue

    assert param.size() == model_dict[name].size(), (
        f"shape mismatch for {name}: checkpoint {tuple(param.size())} "
        f"vs model {tuple(model_dict[name].size())}"
    )
    model_dict[name].copy_(param)
    loaded_keys.add(name)

# Surface partial loads instead of silently running with random weights.
missing_keys = sorted(set(model_dict) - loaded_keys)
if unmatched_keys:
    print(f"Ignored {len(unmatched_keys)} checkpoint keys absent from the model, "
          f"e.g. {unmatched_keys[:10]}")
if missing_keys:
    print(f"{len(missing_keys)} model keys were not found in the checkpoint, "
          f"e.g. {missing_keys[:10]}")

model.eval()
model = model.to("cuda")

In [None]:
def transform(snippet):
    '''Stack a list of frames into a normalised S3D input tensor.

    Each frame is expected channels-last with 3 channels (h x w x 3). Pixel
    values are rescaled via (2 * x - 255) / 255, which maps [0, 255] onto
    [-1, 1] (assumes 8-bit-range input).

    Returns a float32 tensor of shape 1 x 3 x n_frames x h x w.
    '''
    # h x w x (3 * n_frames): frames are laid side by side along the channel axis.
    stacked = torch.from_numpy(np.concatenate(snippet, axis=-1))
    channels_first = stacked.permute(2, 0, 1).contiguous().float()
    height, width = channels_first.shape[1], channels_first.shape[2]

    scaled = (channels_first * 2. - 255) / 255

    # (3 * n_frames) x h x w -> 1 x n_frames x 3 x h x w -> 1 x 3 x n_frames x h x w
    per_frame = scaled.view(1, -1, 3, height, width)
    return per_frame.permute(0, 2, 1, 3, 4)

In [None]:
# Extract one S3D embedding per sampled clip for every raw MOMA video and save
# them to <feat_path>/<video_id>.npy as an (n_clips x embedding_dim) array.
s3d_args = cfg.MODEL.VIDEO.S3D
raw_path = "/data/dir_moma/videos/raw"
feat_path = "/data/dir_moma/feats/s3d"

# Clip-sampling parameters are identical for every video: read them once.
clip_duration = s3d_args.clip_duration
frames_per_clip = s3d_args.frames_per_clip
stride = s3d_args.stride

os.makedirs(feat_path, exist_ok=True)

videos_without_clips = []
# Sorted so the processing order (and any failure point) is reproducible.
for filename in tqdm(sorted(os.listdir(raw_path))):
    video_file = os.path.join(raw_path, filename)
    if not os.path.isfile(video_file):
        continue
    # splitext drops the extension whatever its length (".mp4", ".webm", ...).
    vid = os.path.splitext(filename)[0]

    sampled_clips = s3d_clip_sampler(video_file, clip_duration, frames_per_clip, stride)

    embeddings = []
    with torch.no_grad():
        for clip in sampled_clips:
            # clip: [h x w x 3, ... (x n_frames)]
            clip = transform(clip)  # 1 x 3 x n_frames x h x w
            emb = model(clip.cuda())  # 1 x 1024
            embeddings.append(emb.cpu().numpy())

    if not embeddings:
        # np.concatenate([]) would raise and abort the whole run; record the
        # video and move on instead.
        videos_without_clips.append(vid)
        continue

    embeddings = np.concatenate(embeddings, axis=0)  # n_clips x 1024
    np.save(os.path.join(feat_path, f"{vid}.npy"), embeddings)

if videos_without_clips:
    print(f"No clips were sampled for {len(videos_without_clips)} videos "
          f"(no features saved): {videos_without_clips}")