In [1]:
import sys
sys.path.append('..')
import numpy as np
from tqdm import trange
import torch
import torch._inductor.config
import torch._dynamo.config
# Inductor (torch.compile backend) settings; they are set here, before any
# torch.compile call is made further down.
# NOTE(review): presumably enables extra kernel autotuning at the cost of a
# longer compile — confirm against torch._inductor.config.
torch._inductor.config.coordinate_descent_tuning = True
# NOTE(review): presumably gives every generated Triton kernel its own name,
# which helps when profiling — confirm against torch._inductor.config.
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
from utils.video import write_video, transpose_and_clip
from utils.gpt import GPT, GPTConfig
from utils.vqvae import Decoder, CompressorConfig
from IPython.display import Video

In [2]:
# load model
# Instantiate the GPT under the 'meta' device context first, then let the
# checkpoint supply the weights.
# NOTE(review): load_state_dict_from_url is defined in utils.gpt; assign=True
# presumably swaps the loaded tensors in for the meta placeholders — confirm.
config = GPTConfig()
with torch.device('meta'):
  model = GPT(config)
model.load_state_dict_from_url(
  'https://huggingface.co/commaai/commavq-gpt2m/resolve/main/pytorch_model.bin',
  assign=True,
)
# Switch to eval mode, then move to the GPU as bfloat16.
model = model.eval()
model = model.to(device='cuda', dtype=torch.bfloat16)

In [3]:
# compile
# Replace the per-token decode step with its torch.compile'd version.
compiled_decode_step = torch.compile(
  model.decode_one_token,
  mode="reduce-overhead",
  fullgraph=True,
)
model.decode_one_token = compiled_decode_step

# Run one generation on a random prompt so the compile cost is paid here
# rather than in the generation cell. The prompt has the same length as the
# context window passed to model.generate later on; the output `y` is not
# used by any later cell.
warmup_len = config.block_size - config.tokens_per_frame
idx = torch.randint(low=0, high=config.vocab_size, size=(warmup_len,), device='cuda')
y = model.generate(idx, config.tokens_per_frame)

In [4]:
# load tokens
# Read the example token frames (one row per frame) and put a BOS token in
# front of every row.
frame_tokens = np.load("../examples/tokens.npy").astype(np.int32)
bos_column = config.bos_token * np.ones(len(frame_tokens), dtype=np.int32)
frame_tokens = np.column_stack([bos_column, frame_tokens])

# Keep only the trailing frames: one frame fewer than fit in block_size, which
# leaves room for the frame the model generates next. Flatten to a 1-D
# sequence and move it to the GPU.
n_condition_frames = config.block_size//config.tokens_per_frame - 1
tokens_condition = torch.tensor(
  frame_tokens[-n_condition_frames:].reshape(-1),
  device='cuda',
)

In [5]:
# generate! (slow...)
# Each iteration asks the model for one more frame worth of tokens, using the
# most recent `context_len` tokens as the prompt, and appends the new frame to
# the running sequence. Note that this grows `tokens_condition` in place, so
# re-running the cell appends another NEW_FRAMES frames.
NEW_FRAMES = 20*5
# Loop-invariant prompt length: a full block minus the frame being generated.
context_len = config.block_size - config.tokens_per_frame
for _ in trange(NEW_FRAMES):
  tokens = model.generate(tokens_condition[-context_len:], config.tokens_per_frame)
  # torch.cat's documented keyword is `dim` (`axis` is only a NumPy-style alias).
  tokens_condition = torch.cat([tokens_condition, tokens], dim=0)

100%|██████████| 100/100 [00:32<00:00,  3.09it/s]


In [6]:
# reshape and remove BOS token
# View the flat sequence as one row per frame, drop the leading token of each
# row (the BOS column added when the tokens were loaded), and cast to int64.
tokens_condition = (
  tokens_condition
  .reshape(-1, config.tokens_per_frame)[:, 1:]
  .to(dtype=torch.int64)
)

In [7]:
# load decoder
# The decoder's config gets its own name on purpose: rebinding `config` here
# would replace the GPTConfig that the compile, token-loading and generation
# cells read, so re-running any of them after this cell would pick up the
# wrong config object.
decoder_config = CompressorConfig()
# Instantiate under the 'meta' device context, then let the checkpoint supply
# the weights.
# NOTE(review): load_state_dict_from_url is defined in utils.vqvae; assign=True
# presumably swaps the loaded tensors in for the meta placeholders — confirm.
with torch.device('meta'):
  decoder = Decoder(decoder_config)
decoder.load_state_dict_from_url('https://huggingface.co/commaai/commavq-gpt2m/resolve/main/decoder_pytorch_model.bin', assign=True)
decoder = decoder.eval().to(device='cuda')

In [8]:
# decode generated tokens to video (same as decode.ipynb)
# Decode one frame at a time (each row gets a leading batch axis of 1) with
# autograd disabled, then join the per-frame outputs along dim 0 and bring the
# result back to the host as a NumPy array.
with torch.no_grad():
  frame_batches = [
    decoder(tokens_condition[frame_idx].unsqueeze(0))
    for frame_idx in trange(len(tokens_condition))
  ]
decoded_video = torch.cat(frame_batches, dim=0).cpu().numpy()

100%|██████████| 119/119 [00:01<00:00, 78.14it/s]


In [9]:
# transpose and format video
# NOTE(review): transpose_and_clip comes from utils.video and is not shown
# here; presumably it reorders the axes into the layout write_video expects
# and clips pixel values to a displayable range — confirm in utils/video.py.
# The result overwrites decoded_video, so re-running this cell applies the
# transform again to already-transformed data.
decoded_video = transpose_and_clip(decoded_video)

In [10]:
# save video
# Write the frames to save_dst at fps=20 (write_video is from utils.video).
save_dst = '/tmp/generated.mp4'
write_video(decoded_video, save_dst, fps=20)
# Must stay the last expression of the cell so the notebook renders it;
# embed=True stores the video data inside the notebook output instead of
# linking to the file on disk.
Video(save_dst, embed=True, width=700)