In [1]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GenerationConfig, GPT2Config
import torch
import torch.nn as nn
import os
from pathlib import Path

import transformers
from onnxruntime import InferenceSession
from transformers.onnx.features import FeaturesManager
from optimum.exporters.tasks import TasksManager

In [2]:
import numpy as np

def save_to(x, fp):
    """Write the raw element data of tensor ``x`` to the file ``fp``.

    The tensor is detached, copied to host memory if necessary, converted to
    a NumPy array and written with ``ndarray.tofile``. The file therefore
    holds only the elements in C order; shape and dtype are not recorded,
    which is why the dtype is printed for reference.

    Args:
        x: torch tensor to dump.
        fp: destination path. Missing parent directories are created.
    """
    parent = os.path.dirname(fp)
    # dirname() is '' for a bare file name and os.makedirs('') raises, so the
    # parent is only created when the path actually has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    print(x.dtype)
    # .cpu() is a no-op for host tensors and makes tensors living on another
    # device convertible to NumPy as well.
    np_x = x.detach().cpu().numpy()
    with open(fp, 'wb') as f:
        np_x.tofile(f)

def save_past_key_values(x, n, d):
    """Dump the key and value tensors of the first ``n`` entries of ``x``.

    Entry ``i`` of ``x`` is unpacked as a ``(key, value)`` pair; the two
    tensors are written through ``save_to`` to
    ``<d>/past_key_values/past.<i>.key`` and ``<d>/past_key_values/past.<i>.value``.

    Args:
        x: indexable collection of ``(key, value)`` pairs.
        n: number of leading entries of ``x`` to write.
        d: base output directory.
    """
    for layer_idx in range(n):
        key, value = x[layer_idx]
        stem = os.path.join(d, "past_key_values", "{}.{}".format('past', layer_idx))
        # Key first, then value, each with its own file suffix.
        for suffix, tensor in (('.key', key), ('.value', value)):
            save_to(tensor, stem + suffix)

In [3]:
# Seed PyTorch's global RNG so the random token ids drawn with torch.randint
# in the Stage 0 cell are repeatable when the notebook is run top to bottom.
torch.manual_seed(2333)
# torch.set_num_threads(1)

<torch._C.Generator at 0x7fa5d98b9110>

In [4]:
# Load the pretrained 'gpt2' tokenizer and the matching LM-head model.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# Put the module in evaluation mode before running any forward pass.
model.eval()
# model.config.return_dict = False
# Sample prompt, tokenized to PyTorch tensors; reused by the generate() cell.
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
# Bare last expression: display the encoded prompt.
encoded_input

{'input_ids': tensor([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [5]:
# Print the module tree to inspect the layers that make up the model.
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [6]:
# Stage 0: a batch of 64 random sequences of 256 token ids, run without any
# cached keys/values.
encoded_input_0 = dict(
    input_ids=torch.randint(0, 50257, (64, 256)),
    past_key_values=None,
)


In [7]:
%%time
# Stage 0 forward pass over the whole random batch. This is inference only, so
# run it under no_grad to avoid recording an autograd graph for the logits and
# the returned key/value cache.
with torch.no_grad():
    output = model(**encoded_input_0)
print(output.logits.shape)
print(output.past_key_values[0][0].size())

torch.Size([64, 256, 50257])
torch.Size([64, 12, 256, 64])
CPU times: user 2min 32s, sys: 33.7 s, total: 3min 6s
Wall time: 14.7 s


In [8]:
# prepare for stage 1: for every sequence take the highest-scoring token id at
# the last position, keeping a trailing axis of length one
next_input = output.logits[..., -1, :].argmax(dim=-1, keepdim=True)
next_input.shape

torch.Size([64, 1])

In [9]:
%%time

# Stage 1: feed only the newly chosen token and reuse the Stage 0 key/value
# cache. Inference only, so no autograd graph is recorded.
with torch.no_grad():
    next_output = model(next_input, past_key_values=output.past_key_values)
print(next_output.logits.shape)
print(next_output.past_key_values[0][0].size())

torch.Size([64, 1, 50257])
torch.Size([64, 12, 257, 64])
CPU times: user 4.14 s, sys: 5.4 s, total: 9.54 s
Wall time: 685 ms


In [16]:
# Export
def past_key_values_names(n: int, prefix: str):
    """Generate tensor names for ``n`` key/value pairs.

    For each index ``i`` in ``range(n)`` yields ``"<prefix>.<i>.key"``
    followed by ``"<prefix>.<i>.value"``.
    """
    templates = ("{}.{}.key", "{}.{}.value")
    for layer_idx in range(n):
        yield from (template.format(prefix, layer_idx) for template in templates)

def prepare_dimension(in_name, out_name):
    """Build a dynamic-axes mapping for the cache tensors.

    Every name in ``in_name + out_name`` that starts with ``'past'`` or
    ``'present'`` is mapped to ``{2: 'seq_length'}``; all other names are
    left out of the result.

    Args:
        in_name: list of input tensor names.
        out_name: list of output tensor names.

    Returns:
        dict from tensor name to its ``{axis: label}`` mapping.
    """
    cache_prefixes = ('past', 'present')
    return {
        name: {2: 'seq_length'}
        for name in in_name + out_name
        if name.startswith(cache_prefixes)
    }

# One key/value pair is named per transformer block; take the block count from
# the model configuration instead of hard-coding it.
n_layers = model.config.n_layer
input_names = ["input_ids"] + list(past_key_values_names(n_layers, "past"))
output_names = ["logits"] + list(past_key_values_names(n_layers, "present"))
dyn_dimensions = prepare_dimension(input_names, output_names)
print(input_names)
print(output_names)
print(dyn_dimensions)

# Example inputs for the export: the token chosen after Stage 0 together with
# the Stage 0 key/value cache.
dummy_input = {
    'input_ids': next_input,
    'past_key_values': output.past_key_values
}

['input_ids', 'past.0.key', 'past.0.value', 'past.1.key', 'past.1.value', 'past.2.key', 'past.2.value', 'past.3.key', 'past.3.value', 'past.4.key', 'past.4.value', 'past.5.key', 'past.5.value', 'past.6.key', 'past.6.value', 'past.7.key', 'past.7.value', 'past.8.key', 'past.8.value', 'past.9.key', 'past.9.value', 'past.10.key', 'past.10.value', 'past.11.key', 'past.11.value']
['logits', 'present.0.key', 'present.0.value', 'present.1.key', 'present.1.value', 'present.2.key', 'present.2.value', 'present.3.key', 'present.3.value', 'present.4.key', 'present.4.value', 'present.5.key', 'present.5.value', 'present.6.key', 'present.6.value', 'present.7.key', 'present.7.value', 'present.8.key', 'present.8.value', 'present.9.key', 'present.9.value', 'present.10.key', 'present.10.value', 'present.11.key', 'present.11.value']
{'past.0.key': {2: 'seq_length'}, 'past.0.value': {2: 'seq_length'}, 'past.1.key': {2: 'seq_length'}, 'past.1.value': {2: 'seq_length'}, 'past.2.key': {2: 'seq_length'}, 'past

In [17]:
# Make sure the target folder exists so the export does not depend on another
# cell having created it first.
os.makedirs("./model_onnx", exist_ok=True)
# The named example inputs are passed as a tuple ending with a dict, which is
# the documented form for keyword arguments.
# NOTE(review): dynamic_axes=None exports with the fixed shapes of dummy_input;
# dyn_dimensions computed in the previous cell is not used - confirm that a
# static-shape model is intended.
torch.onnx.export(model,
                  (dummy_input,),
                  "./model_onnx/gpt2.onnx",
                  verbose=False,
                  input_names=input_names,
                  output_names=output_names,
                  dynamic_axes=None)

In [15]:
# Dump the example inputs used for the export as raw binary files next to the
# ONNX model. The number of key/value pairs comes from the model configuration
# rather than a hard-coded 12.
save_to(dummy_input['input_ids'], './model_onnx/input_ids.dat')
save_past_key_values(dummy_input['past_key_values'], model.config.n_layer, './model_onnx')

torch.int64
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32


In [19]:
# Generation settings: produce at most 128 new tokens, reuse cached
# keys/values between steps, and use the tokenizer's EOS id as the pad id.
gen_cfg = GenerationConfig(
    max_new_tokens=128,
    use_cache=True,
    pad_token_id=tokenizer.eos_token_id,
)

In [20]:
%%time
# Generate a continuation of the sample prompt using gen_cfg and print the
# decoded text of the first (only) sequence.
# NOTE(review): this rebinds `output`, which earlier cells use for the Stage 0
# model output (.logits / .past_key_values); re-running those cells after this
# one may fail until the Stage 0 forward pass is executed again.
output = model.generate(**encoded_input, generation_config=gen_cfg)
print(tokenizer.decode(output[0]))

Replace me by any text you'd like.

I'm not sure if you're aware of this, but I'm not sure if you're aware of this, but I'm not sure if you're aware of this, but I'm not sure if you're aware of this, but I'm not sure if you're aware of this, but I'm not sure if you're aware of this, but I'm not sure if you're aware of this, but I'm not sure if you're aware of this, but I'm not sure if you're aware of this, but I'm not sure if you're aware of this, but I'm not sure if you
CPU times: user 5.14 s, sys: 264 ms, total: 5.41 s
Wall time: 5.46 s
