In [42]:
import numpy as np
import onnxruntime as rt
from scipy.special import softmax

# Locations of the two exported ONNX graphs used by this notebook.
decoder_model_path = "/Users/james/playground/mario-gpt/notebooks/custom_onnx_output/model.onnx"
encoder_model_path = "/Users/james/playground/mario-gpt/notebooks/bert_base/encoder_model.onnx"

# `sess` runs the token generator; `encoder_sess` turns the text prompt into
# hidden states that are fed to `sess` as `encoder_hidden_states`.
sess = rt.InferenceSession(decoder_model_path)
encoder_sess = rt.InferenceSession(encoder_model_path)

In [39]:
# Scratch check of the indexing used later on the model output:
# taking the last entry along axis 1 and flattening must leave the
# innermost values addressable by a plain 1-D index.
a = np.arange(1, 6).reshape(1, 1, 5)
print(a)
print(a[0, 0, 2])

last_step = a[:, -1, :]
b = last_step.flatten()
print(b[2])

[[[1 2 3 4 5]]]
3
3


In [40]:
# Print the name of every input tensor the generator session declares.
# The loop variable is deliberately not called `input`: in a notebook the
# name would stay bound after the cell runs and shadow the builtin `input()`.
for model_input in sess.get_inputs():
    print(model_input.name)

input_ids
attention_mask
position_ids
encoder_hidden_states


In [73]:
# Encode prompt
import json

# Token-string -> id mapping used by `tokenize` below.
vocab_path = "/Users/james/playground/mario-gpt/notebooks/bart_base/bart_vocab.json"

# The context manager closes the handle even if json.load raises, and the
# explicit encoding keeps the read independent of the platform locale.
with open(vocab_path, 'r', encoding='utf-8') as f:
    vocab_data = json.load(f)

def tokenize(text, vocab, max_seq_length):
    """Convert whitespace-separated text into token ids and an attention mask.

    The text is split on whitespace, each piece is looked up in ``vocab``
    (falling back to the ``'<unk>'`` id), the content is truncated so that the
    result including ``<s>`` and ``</s>`` fits in ``max_seq_length``, and the
    two special tokens are added around it.

    NOTE(review): this is a plain whitespace split, not a subword tokenizer.
    If the vocabulary file holds subword pieces, most words will map to
    ``<unk>`` - confirm against the tokenizer the encoder was trained with.

    Args:
        text: Prompt string.
        vocab: Mapping from token string to integer id. Must contain the keys
            ``'<s>'``, ``'</s>'`` and ``'<unk>'``.
        max_seq_length: Maximum length of the returned sequence, special
            tokens included. Values below 2 cannot fit both special tokens;
            in that case all content tokens are dropped and the result is
            just ``[<s>, </s>]``.

    Returns:
        Tuple ``(token_ids, attention_mask)`` of 1-D ``int64`` arrays of equal
        length; the mask is all ones.

    Raises:
        KeyError: If one of the required special tokens is missing from
            ``vocab``.
    """
    unk_id = vocab['<unk>']
    token_ids = [vocab.get(token, unk_id) for token in text.split()]

    # Two slots are reserved for <s> and </s>. Clamp at zero so a budget below
    # 2 does not become a negative slice bound (which would keep nearly all
    # tokens instead of none).
    content_budget = max(max_seq_length - 2, 0)
    token_ids = token_ids[:content_budget]

    # Add <s> and </s>
    token_ids = [vocab['<s>']] + token_ids + [vocab['</s>']]
    attention_mask = [1] * len(token_ids)

    # Pin the dtype: NumPy's default integer width is platform dependent.
    return (
        np.array(token_ids, dtype=np.int64),
        np.array(attention_mask, dtype=np.int64),
    )

prompt = "many pipes, many enemies, some blocks, high elevation"

# Tokenize the prompt, then prepend an axis of length 1 so both arrays are
# 2-D with a single row.
token_ids, attention_mask = tokenize(prompt, vocab_data, 1024)
token_ids = token_ids[np.newaxis, :]
attention_mask = attention_mask[np.newaxis, :]

encoder_inputs = {'input_ids': token_ids, 'attention_mask': attention_mask}
encoded = encoder_sess.run(None, encoder_inputs)

# Average the first encoder output over axis 1 and reshape the result to
# (1, 1, last_dim); this is what the generator receives as
# `encoder_hidden_states`.
encoder_output = encoded[0]
pooled = encoder_output.mean(1)
hidden = pooled.reshape(1, 1, encoder_output.shape[-1])
print(hidden.shape)

(1, 1, 768)


In [79]:
# Generate seed and run iterations
#
# Autoregressive sampling: each step feeds the tokens generated so far (or a
# trailing window of them) to `sess`, keeps the `top_k` largest logits of the
# last position, applies a temperature, and samples the next token.
#
# Sampling uses NumPy's global random state; call np.random.seed(...) before
# this cell if a reproducible level is needed.
level_height = 14   # height of mario level: tokens per level column
top_k = 16          # number of highest-logit candidates kept per step
temperature = 2.0   # logits are divided by this before softmax

# Integer dtype is pinned to int64 so the arrays fed to the session do not
# depend on the platform's default NumPy integer width.
seed = np.array([[56]], dtype=np.int64)
out = seed
num_steps = 700
context_len = 700 - 28   # 672 tokens = 48 full columns of 14

for step in range(num_steps):
    inp_ids = out
    if out.shape[-1] > context_len:
        # Keep `context_len` tokens plus the tokens of the unfinished column,
        # so the window always starts on a column boundary.
        diff = inp_ids.shape[-1] % level_height
        ctx = context_len + diff
        inp_ids = inp_ids[:, -ctx:].copy()

    n_vals = inp_ids.shape[-1]
    position_ids = np.arange(n_vals, dtype=np.int64).reshape(1, n_vals)
    attention_mask = np.ones((1, n_vals), dtype=np.int64)
    preds = sess.run(
        None,
        {
            "input_ids": inp_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            'encoder_hidden_states': hidden,
        },
    )

    # Logits of the last position only, as a 1-D vector over the vocabulary.
    logits = preds[0][:, -1, :].flatten()

    # Top-k filtering: everything outside the k largest logits gets -inf,
    # which softmax turns into probability 0.
    indices_of_top_k = np.argpartition(logits, -top_k)[-top_k:]
    filtered = np.full_like(logits, -np.inf)
    filtered[indices_of_top_k] = logits[indices_of_top_k]

    logits = filtered / temperature
    probs = softmax(logits)

    # Named `next_token` rather than `next` so the builtin is not shadowed
    # for the rest of the kernel session.
    next_token = np.random.choice(len(probs), p=probs)
    out = np.concatenate(
        [out, np.array([[next_token]], dtype=np.int64)], axis=-1
    )

print(out)

[[56 56 88 13 13 79 79 13 13 13 13 13 13 13 13 88 88 13 13 13 79 13 13 13
  13 13 13 13 13 13 88 88 13 13 13 13 13 13 13 13 13 13 13 13 13 88 88 13
  13 13 13 13 13 13 13 13 13 13 13 13 88 79 79 13 13 13 13 13 13 13 13 13
  13 13 88 79 79 13 13 13 13 13 13 13 13 13 13 13 88 79 79 13 13 13 13 13
  13 13 13 13 13 13 88 13 13 13 13 13 13 13 13 13 13 13 13 88 13 13 13 13
  13 13 13 13 13 13 56 56 88 13 13 13 13 13 37 13 13 13 13 13 56 56 88 88
  13 13 13 13 13 13 13 13 13 13 56 56 56 88 88 13 13 13 13 13 13 13 13 13
  56 56 56 56 88 88 13 13 13 13 13 13 13 13 56 56 56 13 13 88 88 13 13 13
  13 13 13 13 56 59 59 59 28 88 13 13 13 13 13 13 13 13 56 61 61 61 30 88
  13 13 13 13 13 13 13 13 56 56 56 56 56 88 88 79 13 13 13 13 13 13 13 13
  13 13 13 13 88 56 13 13 13 13 13 13 13 13 13 13 13 13 88 79 13 13 13 13
  13 13 13 13 13 13 13 13 88 79 13 13 13 13 13 13 56 59 59 59 28 88 13 13
  13 13 13 13 13 13 56 61 61 61 30 88 13 13 13 13 13 13 13 13 56 56 56 56
  13 88 13 13 13 13 13 13 13 13 56 56 

In [80]:
# Parse the map
#
# Turns the generated token ids in `out` into a printable text level:
# ids -> token strings -> columns of `level_height` tokens -> text rows.
import json

level_height = 14  # tokens per level column

tok_path = "/Users/james/playground/mario-gpt/notebooks/Mario-GPT2-700-context-length/onnx_output/tokenizer.json"
# The context manager closes the handle even if json.load raises.
with open(tok_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# The file maps token string -> id; invert it to look tokens up by id.
tokenizer = {i: c for c, i in data['model']['vocab'].items()}

str_list = [tokenizer[int(token_id)] for token_id in out[0]]

# Slice the flat token list into complete columns. Slicing (instead of
# flushing a column only when the next one starts) also emits the final
# column when the token count is an exact multiple of `level_height`;
# a trailing incomplete column is dropped.
n_cols = len(str_list) // level_height
cols = [
    "".join(str_list[start:start + level_height])
    for start in range(0, n_cols * level_height, level_height)
]

# Transpose columns into rows. Index 0 of every column becomes the last
# printed row, hence the reversed iteration.
# Assumes every token string is exactly one character, so that position i in
# the joined column string is token i - TODO confirm for special tokens.
rows = []
for i in reversed(range(level_height)):
    rows.append("".join(column[i] for column in cols))

# Named `level_map` rather than `map` so the builtin is not shadowed.
level_map = "\n".join(rows)

print(level_map)

--------------------------------------------------
--------------------------------------------------
--------------------------------------------------
--------------------------------------------------
--------------------------------------------------
---------E----------------------------------------
----------------oXoo------------------xxxxxx------
oo--ooo------x--xxxx---------------xxxxS-x--xoo---
o---ooo-----xxxxx---xxx-----------xxXSS-SQ---x----
---xxxxx---xx-<>X---<>-xx--------xxXXX--------x---
--xx----x-xxX-[]X---[]XX-x------xxXXXX---------x--
xxx------xxXXX[]X---[]XX--x----xxXXXXX----------x-
Xx-------XXXXX[]X---[]XX---xxxxxXXXXXX--SSSSSSSS-x
X--------XXXXXXXX---XXXXXX-XXXXXXXXXX-----------XX
