In [1]:
import sys

sys.path.append("..")
import numpy as np
import torch
from tqdm import trange

from utils.video import read_video, transform_img
from utils.vqvae import CompressorConfig, Encoder

In [2]:
# Decode the sample clip, preprocess every frame, and build one float32 CUDA
# tensor. The permute moves the last axis to position 1 (presumably
# channels-last -> channels-first; confirm against transform_img).
frames = (
    torch.from_numpy(
        np.array(
            [
                transform_img(raw_frame)
                for raw_frame in read_video("../examples/sample_video_ecamera.hevc")
            ]
        )
    )
    .permute(0, 3, 1, 2)
    .to(device="cuda")
    .float()
)

In [3]:
# Sanity check on the tensor built above: the leading dimension is the frame
# count; the remaining three are presumably (channels, height, width).
frames.shape

torch.Size([1200, 3, 128, 256])

In [4]:
# Load model: build the encoder and populate it with pretrained weights.
config = CompressorConfig()
# Construct under the "meta" device so that building the module does not
# allocate real parameter storage; the weights are attached by the load below.
with torch.device("meta"):
    encoder = Encoder(config)
# NOTE(review): load_state_dict_from_url is a project helper (utils.vqvae);
# presumably it downloads the checkpoint and forwards assign=True to
# load_state_dict so the loaded tensors replace the meta placeholders — confirm.
encoder.load_state_dict_from_url(
    "https://huggingface.co/commaai/commavq-gpt2m/resolve/main/encoder_pytorch_model.bin",
    assign=True,
)
# Switch to eval mode and move to the GPU (assumes Encoder is a torch.nn.Module,
# whose eval()/to() return the module itself).
encoder = encoder.eval().to(device="cuda")

In [5]:
# encoding loop
# Feed the encoder one frame at a time (batch of 1) with autograd disabled,
# collect each result, then join them along the first dimension.
per_frame_tokens = []
with torch.no_grad():
    for frame_idx in trange(frames.shape[0]):
        single_frame_batch = frames[frame_idx].unsqueeze(0).float()
        per_frame_tokens.append(encoder(single_frame_batch))
tokens = torch.cat(per_frame_tokens, dim=0)

100%|██████████| 1200/1200 [00:04<00:00, 296.80it/s]


In [6]:
# Peek at the first row of `tokens`: the encoder output for frame 0, assuming
# the encoder keeps the batch dimension first.
tokens[0]

tensor([ 178,  424,  193,  312,  367,   99,  529,   11,  178,  376,  392,  213,
           2,  689,  876,  930,  531,  549,  264,  402,  409,  264,  228,  212,
         238,  549,  817,  788,  874,  644,  644,  688,  841,  567,  494,  938,
         688,  110,  171,  215,  284,  183,  553,  484,  597, 1007,  447,  946,
         169,  473,  790,  245,  694,  562,  385,  275,  427,  861, 1001,  460,
         443,   98,  256,   37,  140,  639,  463,  411,  856,  683,  335,   93,
         917,  697,  916,  788,  804,   36,  937,  335,  758,  417,  560,  722,
         658,   45,  166,  879,  364,  888,  671,  736,  742,    7,  274,  470,
         399,  556,  708,  175,  270,  737,  726,  972,  274,  640,   11,  322,
         722,  871,  525,  930,  872,  579,   34,  522,  436,  213,  726,  410,
         335,  616,  848,   11,  873, 1011,  367,  427], device='cuda:0')

In [7]:
# save the tokens! now head over to the decoding notebook
# `tokens` is rebound to a NumPy array here, so a plain `tokens.cpu()` would
# fail if this cell were run a second time ("'numpy.ndarray' object has no
# attribute 'cpu'"). `torch.as_tensor` accepts both a torch tensor (first run)
# and the NumPy array left behind by a previous run, which makes the cell
# safe to re-execute while producing the same array and the same file.
tokens = torch.as_tensor(tokens).cpu().numpy()
np.save("../examples/tokens.npy", tokens)