In [1]:
from transformers import pipeline, set_seed
from transformers import BertTokenizer, BertModel, BertConfig
import torch
import torch.nn as nn
import os

from optimum.exporters.tasks import TasksManager

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import numpy as np

def save_to(x: torch.Tensor, fp) -> None:
    """Write a tensor's raw element bytes to ``fp`` and print its dtype.

    The file has no header. It holds only the flattened elements in C
    (row-major) order with the machine's native byte order, as produced by
    ``numpy.ndarray.tofile``. A reader must already know the dtype and shape,
    e.g. ``np.fromfile(fp, dtype=np.int64).reshape(64, 20)``.

    Args:
        x: tensor to dump. It may require grad and may live on any device.
        fp: destination path; an existing file is overwritten.
    """
    print(x.dtype)
    # .numpy() only works on CPU memory, so move CUDA tensors to the host first.
    # .detach() is needed because .numpy() refuses tensors that require grad.
    np_x = x.detach().cpu().numpy()
    with open(fp, 'wb') as f:
        np_x.tofile(f)

In [3]:
# Use one intra-op CPU thread for torch. This presumably keeps the %%time
# numbers comparable across runs. It does not appear to limit onnxruntime's
# own thread pool: the ORT cell reports CPU time well above wall time.
torch.set_num_threads(1)
# Seed the global torch RNG so every torch.randint draw below is reproducible.
# The trailing ';' suppresses display of the returned Generator.
torch.manual_seed(2333);

In [4]:
# The mapping's keys are the export task names optimum supports for BERT -> ONNX.
bert_tasks = [*TasksManager.get_supported_tasks_for_model_type("bert", "onnx")]
bert_tasks

['default',
 'masked-lm',
 'sequence-classification',
 'multiple-choice',
 'token-classification',
 'question-answering']

In [12]:
from optimum.onnxruntime import ORTModelForMaskedLM
from transformers import AutoTokenizer

# Synthetic batch shape for the ORT masked-LM benchmark.
ORT_BATCH_SIZE, ORT_SEQ_LEN = 64, 20
ORT_INPUT_SHAPE = (ORT_BATCH_SIZE, ORT_SEQ_LEN)
# bert-base-uncased vocabulary size. It matches Embedding(30522, 768) and the
# 30522-wide logits, so ids are drawn from [0, VOCAB_SIZE).
VOCAB_SIZE = 30522

# NOTE(review): `tokenizer` is not used by any visible cell; inputs are synthetic.
tokenizer = AutoTokenizer.from_pretrained("./bert.onnx")
model = ORTModelForMaskedLM.from_pretrained("./bert.onnx")
encoded_input = {
    'input_ids': torch.randint(0, VOCAB_SIZE, ORT_INPUT_SHAPE),
    'token_type_ids': torch.zeros(ORT_INPUT_SHAPE, dtype=torch.long),
    # Independent random 0/1 per position. This is not contiguous padding, but
    # it is fine for timing and for golden-output dumps.
    'attention_mask': torch.randint(0, 2, ORT_INPUT_SHAPE),
}
encoded_input

{'input_ids': tensor([[22143, 28113,  7618,  ..., 11057, 12451,  9348],
         [18854, 20765, 10706,  ..., 14001, 21047, 25356],
         [ 4659, 30503,  1001,  ...,  4444, 29327,  7125],
         ...,
         [ 1535, 10410, 26120,  ..., 27442, 12478,   239],
         [22372, 10083,  4459,  ..., 22625,  5628,  9219],
         [27550,  9595, 25063,  ...,  9420,  3319, 28344]]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'attention_mask': tensor([[1, 0, 0,  ..., 0, 0, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [0, 1, 1,  ..., 0, 1, 0],
         ...,
         [1, 1, 0,  ..., 1, 1, 1],
         [0, 1, 0,  ..., 1, 1, 0],
         [1, 0, 0,  ..., 0, 1, 1]])}

In [13]:
%%time
output = model(**encoded_input)
output.logits.size()

CPU times: user 10.6 s, sys: 328 ms, total: 10.9 s
Wall time: 738 ms


torch.Size([64, 20, 30522])

In [9]:
# Write each ORT input tensor to "<key>.dat" in the working directory.
for input_name in ('input_ids', 'token_type_ids', 'attention_mask'):
    save_to(encoded_input[input_name], f'{input_name}.dat')

torch.int64
torch.int64
torch.int64


In [4]:
# Load the bare encoder. The checkpoint's cls.* head weights go unused, which
# is what the "weights ... were not used" warning reports. eval() returns the
# module and turns off dropout for inference.
bert = BertModel.from_pretrained('bert-base-uncased').eval()
print(bert)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [13]:
# A 64 x 512 batch of random token ids drawn from the 30522-token vocabulary.
x = torch.randint(low=0, high=30522, size=(64, 512))

In [14]:
%%time
y = bert(x)
y.last_hidden_state.size()

CPU times: user 2min 51s, sys: 30.4 s, total: 3min 22s
Wall time: 3min 22s


torch.Size([64, 512, 768])

In [7]:
# Trace BERT on one fixed (64, 512) batch of token ids. No dynamic_axes are
# passed, so the exported graph should be specialised to that shape, as the
# file name says.
dummy_input = torch.randint(0, 30522, (64, 512))
# NOTE(review): the input is token ids, not text. "prediction" names only the
# first model output; any further outputs (presumably pooler_output) get
# auto-generated names. Kept as-is because consumers of the .onnx file may
# rely on these names.
input_names = ["text"]
output_names = ["prediction"]

torch.onnx.export(
    bert,
    dummy_input,
    "./bert_64x512.onnx",
    input_names=input_names,
    output_names=output_names,
    verbose=False,
)

In [15]:
# save_to(x, fp) writes one tensor per call. The previous 3-argument form
# raised TypeError against the definition in this notebook. The names below
# are distinct from input_ids.dat (the 64x20 ORT input) so it is not clobbered.
# NOTE(review): confirm these file names match what the consumer expects.
save_to(x, 'bert_64x512_input_ids.dat')
save_to(y.last_hidden_state, 'bert_64x512_last_hidden_state.dat')

torch.int64 torch.float32
