In [2]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GenerationConfig, GPT2Config
import torch
import torch.nn as nn
import os
from pathlib import Path

import transformers
from onnxruntime import InferenceSession
from transformers.onnx.features import FeaturesManager
from optimum.exporters.tasks import TasksManager

In [3]:
# Seed torch's global RNG once so random weight init and random inputs are reproducible.
SEED = 2333
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f4488f031d0>

In [13]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')

# Replace the token embedding and LM head with freshly initialised modules of a
# larger vocabulary.
# NOTE(review): 50264 is presumably GPT-2's 50257 padded up to a multiple of 8 — confirm.
# NOTE(review): these modules are randomly initialised (from the torch seed) and are
# NOT tied to each other, unlike stock GPT-2 — confirm untied random weights are intended.
PADDED_VOCAB_SIZE = 50264
model.transformer.wte = nn.Embedding(PADDED_VOCAB_SIZE, model.config.n_embd)
model.lm_head = nn.Linear(model.config.n_embd, PADDED_VOCAB_SIZE, bias=False)
# Keep the config in sync with the swapped modules; otherwise anything that reads
# config.vocab_size still sees the original 50257.
model.config.vocab_size = PADDED_VOCAB_SIZE
# Call eval() after the surgery so the newly created modules are in eval mode too.
model.eval()

In [14]:
# Keep only the first transformer block (presumably to get a small single-block graph to export).
model.transformer.h = nn.ModuleList([model.transformer.h[0]])
# Keep the config consistent with the truncated stack. Without this, config.n_layer
# still reports the pretrained depth, and anything that sizes per-layer structures
# from it would disagree with len(model.transformer.h).
model.config.n_layer = len(model.transformer.h)
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50264, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50264, bias=False)
)


In [15]:
# Stage 0: a random prompt batch with no cache yet.
BATCH_SIZE, PROMPT_LEN, VOCAB_SIZE = 64, 255, 50264
encoded_input_0 = dict(
    input_ids=torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, PROMPT_LEN)),
    past_key_values=None,
)

In [16]:
%%time
with torch.no_grad():
    # Request the cache explicitly. A later cell sets model.config.use_cache = False.
    # Relying on the config default would make this cell return no
    # past_key_values (and crash below) when it is re-run after that cell.
    output = model(**encoded_input_0, use_cache=True)
print(output.logits.shape)
print(output.past_key_values[0][0].size())

torch.Size([64, 255, 50264])
torch.Size([64, 12, 255, 64])
CPU times: user 38.6 s, sys: 2.92 s, total: 41.5 s
Wall time: 2.69 s


In [19]:
cache = output.past_key_values
print(len(cache))     # number of cached layers
print(len(cache[0]))  # entries in one layer's cache (key and value)

1
2


In [17]:
# prepare for stage 1: greedily pick the highest-scoring token at the last position
last_step_logits = output.logits[..., -1, :]
next_input = last_step_logits.argmax(dim=-1, keepdim=True)
next_input.size()

torch.Size([64, 1])

In [25]:
%%time

# Stage 1 without current past return.
# This mutates the shared config, so later forward calls on `model` also skip
# returning a cache unless they pass use_cache explicitly.
model.config.use_cache = False
with torch.no_grad():
    next_output = model(input_ids=next_input, past_key_values=output.past_key_values)
print(next_output.logits.shape)

torch.Size([64, 1, 50264])
CPU times: user 675 ms, sys: 1.89 ms, total: 677 ms
Wall time: 47 ms


In [26]:
# Export
import numpy as np

def save_to(x, fp):
    """Write a tensor's raw bytes to ``fp`` (no header), creating parent dirs.

    The tensor's dtype is printed so whoever reads the dump knows how to
    reinterpret the bytes (e.g. ``np.fromfile(fp, dtype=...)``). Shape is not
    recorded in the file.

    Args:
        x: torch tensor to dump; it is detached and moved to CPU first.
        fp: destination file path; may be a bare filename.
    """
    parent = os.path.dirname(fp)
    # os.makedirs('') raises, so only create a directory when fp names one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    print(x.dtype)
    # .cpu() is a no-op for CPU tensors and makes CUDA tensors convertible.
    np_x = x.detach().cpu().numpy()
    # ndarray.tofile writes the raw element bytes in C order.
    np_x.tofile(fp)

def save_past_key_values(x, n, d):
    """Dump the first ``n`` layers of a key/value cache as raw binary files.

    For layer ``i`` this writes ``<d>/past_key_values/past.<i>.key`` and
    ``<d>/past_key_values/past.<i>.value`` via ``save_to``.

    Args:
        x: indexable per-layer cache; ``x[i]`` must unpack into ``(key, value)``.
        n: number of layers to write (layers ``0 .. n-1``).
        d: output root directory.
    """
    cache_dir = os.path.join(d, "past_key_values")
    for layer_idx in range(n):
        key, value = x[layer_idx]
        stem = os.path.join(cache_dir, f"past.{layer_idx}")
        save_to(key, f"{stem}.key")
        save_to(value, f"{stem}.value")

def past_key_values_names(n: int, prefix: str):
    """Yield flat tensor names for an ``n``-layer key/value cache.

    Order is ``<prefix>.0.key, <prefix>.0.value, <prefix>.1.key, ...``:
    layers ascending, key before value within each layer.
    """
    for layer_idx in range(n):
        for part in ("key", "value"):
            yield f"{prefix}.{layer_idx}.{part}"

def prepare_dimension(in_name, out_name):
    """Build a ``dynamic_axes`` mapping for ``torch.onnx.export``.

    Every tensor whose name starts with ``past`` or ``present`` gets axis 2
    marked dynamic. Axis 2 is the sequence axis of the cached key/value
    tensors, e.g. ``[64, 12, 255, 64]`` for a 255-token prompt. All other
    names (``input_ids``, ``logits``) are left out, i.e. static.

    Past and present caches get *different* symbolic names. The present cache
    extends the past cache with the current step's tokens, so sharing one name
    would declare their lengths equal.

    Args:
        in_name: iterable of ONNX input names.
        out_name: iterable of ONNX output names.

    Returns:
        dict mapping tensor name -> ``{axis: symbolic_dim_name}``.
    """
    axis_name_by_prefix = {
        "past": "past_sequence_length",
        "present": "present_sequence_length",
    }
    ds = {}
    # Unpack instead of `+` so any iterables (lists, tuples, generators) work.
    for name in [*in_name, *out_name]:
        for prefix, axis_name in axis_name_by_prefix.items():
            if name.startswith(prefix):
                ds[name] = {2: axis_name}
                break
    return ds

dummy_input = {
    'input_ids': next_input,
    'past_key_values': output.past_key_values
}

# Derive the layer count from the cache actually fed to the model, so the ONNX
# input names always line up with the flattened past_key_values tensors.
n_cache_layers = len(dummy_input['past_key_values'])
input_names = ["input_ids"] + list(past_key_values_names(n_cache_layers, "past"))
# Logits only: no `present.*` outputs are returned while use_cache is disabled.
output_names = ["logits"]
# TODO: pass prepare_dimension(input_names, output_names) as dynamic_axes to the
# export if a variable past length is needed.
print(input_names)
print(output_names)

['input_ids', 'past.0.key', 'past.0.value']
['logits']


In [28]:
export_dir = Path("./model_onnx")
export_dir.mkdir(parents=True, exist_ok=True)
# Export a logits-only graph to match output_names. Disable the cache here
# explicitly so this cell does not depend on an earlier cell having mutated the
# shared config.
model.config.use_cache = False
torch.onnx.export(model,
                  # Documented form of `args`: a tuple whose trailing dict is
                  # passed to forward() as keyword arguments.
                  (dummy_input,),
                  str(export_dir / "gpt2Block.onnx"),
                  verbose=False,
                  input_names=input_names,
                  output_names=output_names,
                  # Fixed shapes: the graph is specialised to dummy_input's sizes.
                  dynamic_axes=None)

verbose: False, log level: Level.ERROR



In [29]:
past_cache = dummy_input['past_key_values']
save_to(dummy_input['input_ids'], './model_onnx/input_ids.dat')
# Save every layer actually present in the cache rather than a hard-coded count,
# so the dump stays in sync if more transformer blocks are kept.
save_past_key_values(past_cache, len(past_cache), './model_onnx')

torch.int64
torch.float32
torch.float32
