In [1]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GenerationConfig, GPT2Config
import torch
import torch.nn as nn
import os
from pathlib import Path

import transformers
from onnxruntime import InferenceSession
from transformers.onnx.features import FeaturesManager
from optimum.exporters.tasks import TasksManager

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np

def save_to(x, fp, verbose=True):
    """Write tensor ``x`` to ``fp`` as raw bytes with no header.

    Parent directories of ``fp`` are created when needed. The file holds the
    tensor's elements in C (row-major) order using its native dtype. A reader
    must therefore know dtype and shape separately, e.g.
    ``np.fromfile(fp, dtype=np.float32).reshape(shape)``.

    Args:
        x: tensor to save. It is detached from autograd and moved to CPU first.
        fp: destination file path. A bare filename writes to the working directory.
        verbose: when True (the default), print ``x.dtype`` before writing.
    """
    parent = os.path.dirname(fp)
    if parent:  # os.makedirs('') raises, so skip it for bare filenames like 'x.dat'
        os.makedirs(parent, exist_ok=True)
    if verbose:
        print(x.dtype)
    # .cpu() is a no-op for CPU tensors and is required before .numpy() on GPU tensors.
    np_x = x.detach().cpu().numpy()
    with open(fp, 'wb') as f:
        # ndarray.tofile always writes C order, even for non-contiguous views
        # such as transposed tensors.
        np_x.tofile(f)

def save_past_key_values(x, n, d):
    """Save a per-layer (key, value) cache as raw binary files under ``d``.

    For each layer ``i`` in ``range(n)`` this writes
    ``<d>/past_key_values/past.<i>.key`` and ``<d>/past_key_values/past.<i>.value``
    via ``save_to``. These names mirror the ``past.<i>.key`` / ``past.<i>.value``
    ONNX input names built elsewhere in this notebook.

    Args:
        x: indexable per-layer cache, where ``x[i]`` unpacks to ``(key, value)``.
        n: number of layers to write. ``None`` means every layer (``len(x)``),
            which avoids hard-coding the model depth at the call site.
        d: output root directory.
    """
    if n is None:
        n = len(x)
    out_dir = os.path.join(d, "past_key_values")
    for layer in range(n):
        key, value = x[layer]
        stem = os.path.join(out_dir, "{}.{}".format('past', layer))
        save_to(key, stem + '.key')
        save_to(value, stem + '.value')

In [3]:
# Reproducibility: the Stage 0 cell draws random token ids with torch.randint.
SEED = 2333
torch.manual_seed(SEED)
# Limit PyTorch to one intra-op CPU thread. Presumably this is so the %%time
# measurements reflect single-threaded inference; confirm the intent.
torch.set_num_threads(1)

In [4]:
# Load the pretrained checkpoint and its tokenizer from a single model id.
MODEL_NAME = 'gpt2-medium'
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
# eval() turns off the Dropout layers so forward passes are deterministic.
model.eval()
# NOTE(review): this sample encoding is only displayed. Later stages use random
# input_ids instead.
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
encoded_input

{'input_ids': tensor([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [5]:
# Show the module tree. As the cell's last expression it is rendered via the rich display.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout)

In [6]:
# Stage 0: full forward pass over a batch of random prompts with an empty KV
# cache. The resulting past_key_values are what Stage 1 extends.
BATCH_SIZE, PROMPT_LEN = 64, 256
encoded_input_0 = {
    # randint's upper bound is exclusive, so ids lie in [0, vocab_size - 1].
    'input_ids': torch.randint(0, model.config.vocab_size, (BATCH_SIZE, PROMPT_LEN)),
    'past_key_values': None,
}


In [7]:
%%time
output = model(**encoded_input_0)


CPU times: user 4min 50s, sys: 40.9 s, total: 5min 31s
Wall time: 5min 31s


In [8]:
# Stage 0 shapes: the logits, then the cached key tensor of the first layer.
logits_shape = output.logits.shape
layer0_key_shape = output.past_key_values[0][0].size()
print(logits_shape)
print(layer0_key_shape)

torch.Size([64, 256, 50257])
torch.Size([64, 16, 256, 64])


In [9]:
# Prepare Stage 1: greedily pick the highest-scoring token for every row, using
# only the logits at the final prompt position.
last_position_logits = output.logits[..., -1, :]
next_input = last_position_logits.argmax(dim=-1, keepdim=True)  # keepdim keeps a trailing length-1 axis
next_input.size()

torch.Size([64, 1])

In [10]:
%%time

# Stage 1
next_output = model(next_input, past_key_values=output.past_key_values)
print(next_output.logits.shape)
print(next_output.past_key_values[0][0].size())

torch.Size([64, 1, 50257])
torch.Size([64, 16, 257, 64])
CPU times: user 3.58 s, sys: 3.24 s, total: 6.83 s
Wall time: 7 s


In [11]:
# Export
def past_key_values_names(n: int, prefix: str):
    """Yield tensor names for an ``n``-layer key/value cache.

    For every layer index ``i`` from 0 to ``n - 1``, yields ``"<prefix>.<i>.key"``
    and then ``"<prefix>.<i>.value"``.
    """
    for layer in range(n):
        key_name = "{}.{}.key".format(prefix, layer)
        value_name = "{}.{}.value".format(prefix, layer)
        yield from (key_name, value_name)
# ONNX tensor names:
# - inputs: token ids, plus a key and a value per layer of the incoming cache ("past.*")
# - outputs: logits, plus the extended cache ("present.*")
# The layer count comes from the model config instead of a hard-coded 24.
n_layers = model.config.n_layer
input_names = ["input_ids"] + list(past_key_values_names(n_layers, "past"))
output_names = ["logits"] + list(past_key_values_names(n_layers, "present"))
print(input_names)
print(output_names)

['input_ids', 'past.0.key', 'past.0.value', 'past.1.key', 'past.1.value', 'past.2.key', 'past.2.value', 'past.3.key', 'past.3.value', 'past.4.key', 'past.4.value', 'past.5.key', 'past.5.value', 'past.6.key', 'past.6.value', 'past.7.key', 'past.7.value', 'past.8.key', 'past.8.value', 'past.9.key', 'past.9.value', 'past.10.key', 'past.10.value', 'past.11.key', 'past.11.value', 'past.12.key', 'past.12.value', 'past.13.key', 'past.13.value', 'past.14.key', 'past.14.value', 'past.15.key', 'past.15.value', 'past.16.key', 'past.16.value', 'past.17.key', 'past.17.value', 'past.18.key', 'past.18.value', 'past.19.key', 'past.19.value', 'past.20.key', 'past.20.value', 'past.21.key', 'past.21.value', 'past.22.key', 'past.22.value', 'past.23.key', 'past.23.value']
['logits', 'present.0.key', 'present.0.value', 'present.1.key', 'present.1.value', 'present.2.key', 'present.2.value', 'present.3.key', 'present.3.value', 'present.4.key', 'present.4.value', 'present.5.key', 'present.5.value', 'present.6.

In [14]:
# Trace the single-token decoding step: one new token per row plus the Stage 0 cache.
# NOTE(review): no dynamic_axes are passed, so the graph is traced with these
# tensors' concrete shapes. Confirm static shapes are intended.
dummy_input = {
    'input_ids': next_input,
    'past_key_values': output.past_key_values,
}
onnx_path = "./model_onnx/gpt2.onnx"
# Otherwise ./model_onnx is first created by save_to in the next cell, so a
# fresh-kernel Run All would fail here when opening the output file.
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
torch.onnx.export(model,
                  dummy_input,
                  onnx_path,
                  verbose=False,
                  input_names=input_names,
                  output_names=output_names)

In [15]:
# Dump the exact tensors used as export inputs as raw binaries. Presumably they
# are consumed by another runtime; confirm.
# The layer count is taken from the cache itself rather than a hard-coded 24.
save_to(dummy_input['input_ids'], './model_onnx/input_ids.dat')
save_past_key_values(dummy_input['past_key_values'],
                     len(dummy_input['past_key_values']),
                     './model_onnx')

torch.int64
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
