In [1]:
import sys
sys.path.append('..')
import numpy as np
from tqdm import trange
import torch
import torch._inductor.config
import torch._dynamo.config
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
from utils.video import write_video, transpose_and_clip
from utils.gpt import GPT, Config, get_state_dict
from IPython.display import Video

In [2]:
# load model
# The architecture is instantiated inside a meta-device context, then
# load_state_dict(assign=True) installs the checkpoint tensors before the
# model is switched to eval mode and moved to the GPU in bfloat16.
WEIGHTS_URL = 'https://huggingface.co/commaai/commavq-gpt2m/resolve/main/pytorch_model.bin'

config = Config()
with torch.device('meta'):
  model = GPT(config)

# get_state_dict (from utils.gpt) is handed the checkpoint URL.
state_dict = get_state_dict(WEIGHTS_URL)
model.load_state_dict(state_dict, assign=True)
model = model.eval().to(device='cuda', dtype=torch.bfloat16)

In [3]:
# compile
# Replace the single-token decode step with its torch.compile'd version, then
# run one generation on random token ids (presumably so the compilation cost is
# paid here rather than inside the generation loop).
model.decode_one_token = torch.compile(
  model.decode_one_token,
  mode="reduce-overhead",
  fullgraph=True,
)
warmup_len = config.block_size - config.tokens_per_frame
idx = torch.randint(0, config.vocab_size, (warmup_len,), device='cuda')
# The result is not read by the cells shown here; the assignment also keeps the
# tensor from being displayed as the cell output.
y = model.generate(idx, config.tokens_per_frame)

In [4]:
# load tokens
frames = np.load("../examples/tokens.npy").astype(np.int32)
# One BOS id per row, prepended as an extra leading column.
bos_column = np.ones(len(frames), dtype=np.int32)*config.bos_token
frames = np.c_[bos_column, frames]
# Keep the trailing (block_size//tokens_per_frame - 1) rows and flatten them
# into a single 1-D sequence on the GPU.
n_context_frames = config.block_size//config.tokens_per_frame - 1
tokens_condition = torch.tensor(frames[-n_context_frames:].reshape(-1), device='cuda')

In [5]:
# generate! (slow...)
# Each iteration feeds the most recent `context_len` tokens to the model, asks
# for tokens_per_frame new tokens, and appends them to the running sequence.
# Note: re-running this cell keeps extending tokens_condition.
NEW_FRAMES = 20*5
context_len = config.block_size - config.tokens_per_frame
for _ in trange(NEW_FRAMES):
  context = tokens_condition[-context_len:]
  tokens = model.generate(context, config.tokens_per_frame)
  tokens_condition = torch.cat((tokens_condition, tokens), dim=0)

100%|██████████| 100/100 [00:30<00:00,  3.25it/s]


In [6]:
# reshape and remove BOS token
# One row per frame (tokens_per_frame columns), copied to the host as a NumPy array.
frames_with_bos = tokens_condition.reshape(-1, config.tokens_per_frame).cpu().numpy()
# Drop column 0 of every row (the BOS slot) and cast what remains to int64.
# Note: tokens_condition changes from a torch tensor to a NumPy array here, so
# this cell cannot be re-run without re-running the cells above it.
tokens_condition = frames_with_bos[:, 1:].astype(np.int64)

In [7]:
# load decoder
# Builds an ONNX Runtime session for the decoder model with all graph
# optimizations enabled, requesting the CUDA execution provider.
import onnxruntime as ort
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
provider = ["CUDAExecutionProvider"]
# Plain string literal (the f-prefix had no placeholders); the session options
# and provider list are passed by keyword so the call does not rely on
# positional argument order.
decoder_session = ort.InferenceSession('../gpt2m/decoder.onnx', sess_options=options, providers=provider)

[0;93m2024-05-27 16:15:43.721769027 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2024-05-27 16:15:43.721787292 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m


In [8]:
# decode generated tokens to video (same as decode.ipynb)
# Only 'big_decoded_img' is kept, so the session is asked for that output by
# name. This avoids calling get_outputs() and rebuilding a name->array dict on
# every iteration just to read one entry.
decoded_video = []
for i in trange(len(tokens_condition)):
  # Each row must hold 8*16 = 128 token ids for this reshape to succeed.
  outputs = decoder_session.run(['big_decoded_img'], {'encoding_indices': tokens_condition[i].reshape(1,8,16)})
  decoded_video.append(outputs[0])

100%|██████████| 119/119 [00:06<00:00, 18.26it/s]


In [9]:
# transpose and format video
# transpose_and_clip is imported from utils.video; its exact output layout is
# not visible in this notebook.
# NOTE(review): decoded_video is rebound to the helper's result, so re-running
# this cell on its own feeds already-processed data back into the helper —
# re-run the decode cell first.
decoded_video = transpose_and_clip(decoded_video)

In [10]:
# save video
# write_video (from utils.video) receives the frames, the destination path and
# the frame rate; the written file is then embedded in the notebook output.
FPS = 20
save_dst = '/tmp/generated.mp4'
write_video(decoded_video, save_dst, fps=FPS)
Video(save_dst, embed=True, width=700)