In [1]:
import os
import sys
import tempfile

import numpy as np
import onnxruntime as ort
from IPython.display import Video
from tqdm import trange

sys.path.append('..')  # make the repo-level `utils` package importable
from utils.sampling import softmax, multinomial
from utils.video import write_video, transpose_and_clip

In [2]:
TOKENS_PER_FRAME = 129                                # 1 BOS token + 128 image tokens (decoded as an 8x16 grid)
CONTEXT_FRAMES   = 20                                 # maximum number of frames kept as model context
MAX_CONTEXT_SIZE = CONTEXT_FRAMES*TOKENS_PER_FRAME    # derived, so it cannot drift from TOKENS_PER_FRAME
BOS_TOKEN        = 1024                               # last id of the 1025-entry vocabulary; starts every frame

In [3]:
# load tokens
# Each row of tokens.npy is one frame (assumed to hold TOKENS_PER_FRAME - 1 image
# tokens — TODO confirm). Prepend a BOS column, keep only the most recent frames
# that fit in the context while leaving room for one generated frame, and flatten
# to a single (1, n) sequence.
raw_frame_tokens = np.load("../examples/tokens.npy").astype(np.int64)
bos_column = np.full(len(raw_frame_tokens), BOS_TOKEN, dtype=np.int64)
n_context_frames = MAX_CONTEXT_SIZE // TOKENS_PER_FRAME - 1
tokens_condition = np.c_[bos_column, raw_frame_tokens][-n_context_frames:].reshape(1, -1)

In [4]:
# load model session
options = ort.SessionOptions()

provider = 'CUDAExecutionProvider'
session = ort.InferenceSession('../models/gpt2m.onnx', options, [provider])
# onnxruntime may fall back to CPU when the requested provider cannot be loaded;
# generate_frame_tokens places its tensors on the 'cuda' device, so fail early here.
if provider not in session.get_providers():
  raise RuntimeError(f'{provider} is not available; active providers: {session.get_providers()}')
# print shapes
input_shapes = {i.name: i.shape for i in session.get_inputs()}
output_shapes = {i.name: i.shape for i in session.get_outputs()}
print('input shapes : ', input_shapes)
print('output shapes: ', output_shapes)

input shapes :  {'tokens': ['b', 't_present'], 'kvcache': ['b', 24, 't_past', 2048], 'use_cache_branch': [1]}
output shapes:  {'logits': ['b', 1, 1025], 'kvcache_out': ['b', 24, 't_past_out', 2048]}


In [5]:
def generate_frame_tokens(session, tokens, device='cuda', device_id=0,
                          enforce_frame_structure=True):
  """Sample the TOKENS_PER_FRAME tokens of one new frame, one token per model call.

  The first call runs the model on the whole context with an empty KV cache
  (use_cache_branch=0). Every later call binds the KV cache returned by the
  previous call (use_cache_branch=1). Inputs and outputs are bound on `device`
  through an IO binding, so the cache is passed between steps as device
  OrtValues rather than numpy arrays.

  Args:
    session: onnxruntime InferenceSession with inputs 'tokens', 'kvcache',
      'use_cache_branch' and outputs 'logits', 'kvcache_out'.
    tokens: int64 array of shape (batch, t) holding the conditioning context.
    device, device_id: device type/id used for OrtValues and bound outputs.
    enforce_frame_structure: if True, the frame's first token is forced to
      BOS_TOKEN and BOS_TOKEN is excluded from every other position, matching
      the BOS-prefixed frame layout of the context and the BOS-stripping done
      before decoding.

  Returns:
    Array of shape (batch, TOKENS_PER_FRAME) with the newly generated tokens
    (the first one is BOS_TOKEN when enforce_frame_structure is True).
  """
  batch = tokens.shape[0]
  data = {'tokens': tokens,
          # empty cache laid out as (b, 24, t_past=0, 2048), matching the 'kvcache' input signature
          'kvcache': np.zeros((batch, 24, 0, 2048), dtype=np.float16),
          'use_cache_branch': np.array([0], dtype=np.int32)}

  io_binding = session.io_binding()
  # Hold references to the bound OrtValues (as the original did), presumably so
  # their device buffers stay alive while they are bound.
  data_ortvalue = {}

  def bind_input(name, array):
    # Copy `array` to the device and bind it as model input `name`.
    data[name] = array
    data_ortvalue[name] = ort.OrtValue.ortvalue_from_numpy(array, device, device_id)
    io_binding.bind_ortvalue_input(name, data_ortvalue[name])

  for name in ('tokens', 'kvcache', 'use_cache_branch'):
    bind_input(name, data[name])
  io_binding.bind_output('logits', device)

  for i in range(TOKENS_PER_FRAME):
    # Bind a fresh 'kvcache_out' every step: the previous step's cache output
    # is re-bound below as the 'kvcache' input.
    io_binding.bind_output('kvcache_out', device)

    if i == 1:
      # From the second step on the model consumes the cache of the previous
      # step; the flag stays bound to 1 for all remaining steps.
      bind_input('use_cache_branch', np.array([1], dtype=np.int32))

    session.run_with_iobinding(io_binding)
    ort_output = io_binding.get_outputs()

    if enforce_frame_structure and i == 0:
      # The step still has to run (it fills the KV cache for the context), but
      # every frame starts with BOS by construction, so don't leave it to sampling.
      next_tokens = np.full((batch, 1), BOS_TOKEN, dtype=data['tokens'].dtype)
    else:
      logits = ort_output[0].numpy()[:, -1, :].astype(np.float64)
      probs = softmax(logits, axis=1)
      if enforce_frame_structure:
        # BOS belongs only at the frame start: the decode step strips column 0
        # and hands the remaining ids to the decoder as image tokens.
        probs[:, BOS_TOKEN] = 0.0
        probs /= probs.sum(axis=1, keepdims=True)
      next_tokens = multinomial(probs)

    bind_input('tokens', np.concatenate([data['tokens'], next_tokens], axis=1))
    io_binding.bind_ortvalue_input('kvcache', ort_output[1])

  return data['tokens'][:, -TOKENS_PER_FRAME:]

In [6]:
# generate! (slow...)
# 100 new frames; at the fps=20 used when writing the video this is 5 seconds.
# NOTE: tokens_condition is extended in place, so re-running this cell appends
# another NEW_FRAMES frames — re-run the token-loading cell first.
NEW_FRAMES = 20*5
for _ in trange(NEW_FRAMES):
  # condition on the most recent tokens, leaving room for the frame being generated
  recent_context = tokens_condition[:, -(MAX_CONTEXT_SIZE - TOKENS_PER_FRAME):]
  tokens = generate_frame_tokens(session, recent_context)
  tokens_condition = np.concatenate((tokens_condition, tokens), axis=1)

100%|██████████| 100/100 [02:29<00:00,  1.49s/it]


In [7]:
# reshape and remove BOS token
frames_with_bos = tokens_condition.reshape(-1, TOKENS_PER_FRAME)
# Stripping column 0 is only correct if every frame starts with BOS and holds no
# BOS afterwards; otherwise a real token would be discarded and BOS ids would be
# passed on to the decoder. Fail loudly instead of decoding corrupted frames.
assert (frames_with_bos[:, 0] == BOS_TOKEN).all(), 'a frame does not start with BOS_TOKEN'
assert (frames_with_bos[:, 1:] != BOS_TOKEN).all(), 'BOS_TOKEN found inside a frame'
tokens_condition = frames_with_bos[:, 1:]

In [8]:
# load decoder (reuses the session options and execution provider of the GPT model)
decoder_session = ort.InferenceSession('../models/decoder.onnx', options, [provider])

In [9]:
# decode generated tokens to video (same as decode.ipynb)
# Each frame's image tokens are reshaped into the (1, 8, 16) 'encoding_indices'
# input of the decoder; only the 'big_decoded_img' output is requested.
decoded_video = []
for i in trange(len(tokens_condition)):
  encoding_indices = tokens_condition[i].reshape(1, 8, 16)
  (big_decoded_img,) = decoder_session.run(['big_decoded_img'], {'encoding_indices': encoding_indices})
  decoded_video.append(big_decoded_img)

100%|██████████| 119/119 [00:01<00:00, 59.95it/s]


In [10]:
# transpose and format video
# Converts the list of per-frame 'big_decoded_img' decoder outputs into the video
# object that write_video saves in the next cell. The exact axis order and value
# clipping live in utils.video.transpose_and_clip (not visible here) — NOTE(review): confirm.
decoded_video = transpose_and_clip(decoded_video)

In [11]:
# save video to the platform's temp directory (a hard-coded '/tmp' does not exist on Windows)
save_dst = os.path.join(tempfile.gettempdir(), 'generated.mp4')
write_video(decoded_video, save_dst, fps=20)
Video(save_dst, embed=True, width=700)