In [1]:
import sys 
sys.path.append('..')
import numpy as np
import onnxruntime as ort
from tqdm import trange
from utils.sampling import softmax, multinomial
from utils.video import write_video, transpose_and_clip
from IPython.display import Video

In [2]:
# Token-stream layout used throughout the notebook.
TOKENS_PER_FRAME = 129                    # 1 BOS token + 8*16 image tokens per frame
MAX_CONTEXT_SIZE = 20 * TOKENS_PER_FRAME  # context window: 20 frames' worth of tokens
BOS_TOKEN        = 1024                   # id prepended to every frame's tokens

In [3]:
# build the ONNX Runtime session for the GPT model with all graph optimizations on
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
provider = ["CUDAExecutionProvider", "CPUExecutionProvider"]
session = ort.InferenceSession('../gpt2m/gpt2m.onnx', sess_options=options, providers=provider)

# summarise every model input/output as name -> (shape, type) and print both maps
input_shapes = dict((node.name, (node.shape, node.type)) for node in session.get_inputs())
output_shapes = dict((node.name, (node.shape, node.type)) for node in session.get_outputs())
print('input shapes : ', input_shapes)
print('output shapes: ', output_shapes)

input shapes :  {'input_ids': (['batch_size', 'seq_len'], 'tensor(int32)'), 'past_0': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_1': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_2': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_3': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_4': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_5': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_6': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_7': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_8': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_9': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_10': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_11': ([2, 'batch_size', 16, 'past_seq_len', 64], 'tensor(float16)'), 'past_12': ([2, 'batch_size', 16, 

In [4]:
def generate_frame_tokens(session, tokens, device='cuda', device_id=0):
  """Autoregressively sample TOKENS_PER_FRAME new tokens from the model.

  The first step feeds the whole conditioning sequence together with empty
  (length-0) key/value caches. Every later step feeds only the token sampled
  in the previous step and re-uses the `present_<i>` outputs of the previous
  run as the `past_<i>` inputs of the next one.

  Args:
    session: onnxruntime InferenceSession exposing an `input_ids` input, one
      `past_<i>` input per layer, a `logits` output and one `present_<i>`
      output per layer.
    tokens: integer array of shape (batch, seq_len) with the conditioning
      tokens. It is converted to a C-contiguous int32 array if it is not
      one already.
    device: device type on which inputs are placed and outputs are
      allocated. Defaults to 'cuda', the value that used to be hard-coded.
    device_id: index of that device. Defaults to 0.

  Returns:
    The sampled tokens of all steps concatenated along axis 1. The exact
    shape depends on what `multinomial` returns per step -- presumably
    (batch, 1), giving (batch, TOKENS_PER_FRAME); confirm in utils.sampling.
  """
  # The model declares input_ids as int32; make sure the buffer handed to
  # ONNX Runtime has that dtype and a contiguous layout (no copy if it does).
  tokens = np.ascontiguousarray(tokens, dtype=np.int32)
  batch_size = tokens.shape[0]

  # One `past_<i>` cache input exists per layer; count them rather than
  # hard-coding the depth of one particular exported model.
  n_layers = sum(inp.name.startswith('past_') for inp in session.get_inputs())

  # Initial inputs: the prompt plus an empty cache per layer. The 16 and 64
  # match the fixed dimensions of the `past_<i>` inputs this notebook prints
  # when it loads the model.
  data = {'input_ids': tokens,
          **{f'past_{j}': np.zeros((2, batch_size, 16, 0, 64), dtype=np.float16)
             for j in range(n_layers)}
          }

  # Keep every bound OrtValue referenced from Python for as long as it is bound.
  bound_inputs = {
      name: ort.OrtValue.ortvalue_from_numpy(array, device, device_id)
      for name, array in data.items()
  }

  io_binding = session.io_binding()
  for name, value in bound_inputs.items():
    io_binding.bind_ortvalue_input(name, value)

  output_tokens = []
  for _ in range(TOKENS_PER_FRAME):
    # Outputs are bound again before every run, logits first and then the
    # per-layer caches; the indexing into `ort_output` below relies on the
    # outputs coming back in this order.
    io_binding.bind_output('logits', device, device_id)
    for j in range(n_layers):
      io_binding.bind_output(f'present_{j}', device, device_id)

    session.run_with_iobinding(io_binding)
    ort_output = io_binding.get_outputs()

    # Sample the next token from the distribution at the last position only.
    logits = ort_output[0].numpy()[:, -1, :].astype(np.float64)
    probs = softmax(logits, axis=1)
    next_tokens = multinomial(probs).astype(np.int32)
    output_tokens.append(next_tokens)

    # Next step: feed just the sampled token and hand the freshly produced
    # caches back to the model as its past state.
    bound_inputs['input_ids'] = ort.OrtValue.ortvalue_from_numpy(next_tokens, device, device_id)
    io_binding.bind_ortvalue_input('input_ids', bound_inputs['input_ids'])
    for j in range(n_layers):
      bound_inputs[f'past_{j}'] = ort_output[1 + j]
      io_binding.bind_ortvalue_input(f'past_{j}', bound_inputs[f'past_{j}'])

  return np.concatenate(output_tokens, axis=1)

In [5]:
# load the example tokens (assumed: one row of image tokens per frame -- confirm against the file)
frame_tokens = np.load("../examples/tokens.npy").astype(np.int32)
# put a BOS token in front of every row
bos_column = np.full(len(frame_tokens), BOS_TOKEN, dtype=np.int32)
frame_tokens = np.c_[bos_column, frame_tokens]
# keep the last (context frames - 1) rows, leaving room for one generated frame,
# and flatten them into a single (1, n_tokens) sequence
n_context_frames = MAX_CONTEXT_SIZE // TOKENS_PER_FRAME - 1
tokens_condition = frame_tokens[-n_context_frames:].reshape(1, -1)

In [6]:
# generate! (slow...)
# each step conditions on the most recent tokens (context window minus one frame)
# and appends the newly sampled frame to the running sequence
NEW_FRAMES = 20*5
context_len = MAX_CONTEXT_SIZE - TOKENS_PER_FRAME
for _ in trange(NEW_FRAMES):
  context = tokens_condition[:, -context_len:]
  tokens = generate_frame_tokens(session, context)
  tokens_condition = np.concatenate((tokens_condition, tokens), axis=1)

100%|██████████| 100/100 [00:48<00:00,  2.07it/s]


In [7]:
# one row per frame, drop the leading BOS column, and widen to int64 for the decoder
tokens_condition = (
  tokens_condition
  .reshape(-1, TOKENS_PER_FRAME)[:, 1:]
  .astype(np.int64)
)

In [8]:
# create the decoder session, re-using the options and providers of the GPT session
decoder_session = ort.InferenceSession(
  '../gpt2m/decoder.onnx',
  sess_options=options,
  providers=provider,
)

In [9]:
# decode generated tokens to video (same as decode.ipynb)
# the decoder's output names do not change between runs, so look them up once
output_names = [out.name for out in decoder_session.get_outputs()]
decoded_video = []
for i in trange(len(tokens_condition)):
  # the decoder takes one frame at a time as a (1, 8, 16) grid of indices
  frame_indices = tokens_condition[i].reshape(1, 8, 16)
  frame_outputs = decoder_session.run(None, {'encoding_indices': frame_indices})
  outputs = dict(zip(output_names, frame_outputs))
  decoded_video.append(outputs['big_decoded_img'])

100%|██████████| 119/119 [00:06<00:00, 18.07it/s]


In [10]:
# transpose and format video
# `transpose_and_clip` is imported from utils.video; presumably it reorders the
# axes and clips pixel values so the frames can be written as video -- confirm
# in utils/video.py.
# NOTE: this rebinds `decoded_video` (until now the list of per-frame decoder
# outputs) to the helper's return value, so re-running this cell on its own
# passes the already-processed value back into the helper.
decoded_video = transpose_and_clip(decoded_video)

In [11]:
# write the frames to an mp4 at 20 fps, then embed the file in the cell output
save_dst = '/tmp/generated.mp4'
write_video(decoded_video, save_dst, fps=20)
Video(data=save_dst, embed=True, width=700)