## library imports

In [31]:
import torch
import transformers
import onnx
import onnxruntime.training.onnxblock as onnxblock

## generating artifacts

In [4]:
model = transformers.AutoModel.from_pretrained('google/mobilebert-uncased')
model_name = 'mobilebert-uncased'

Some weights of the model checkpoint at google/mobilebert-uncased were not used when initializing MobileBertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing MobileBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MobileBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [25]:
# ... painstakingly create the random input

# expects
# input_ids = torch.LongTensor of shape (batch size, seq len)
# attention_mask = torch.FloatTensor of shape (batch size, seq len)
# token_type_ids = torch.LongTensor of shape (bs, seq len)

batch_size = 2
seq_len = 25
vocab = 20000
input_ids = torch.randint(vocab, (batch_size, seq_len))
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.float)
token_type_ids = torch.ones((batch_size, seq_len), dtype=torch.long)


In [28]:
torch.onnx.export(model, (input_ids, attention_mask, token_type_ids),
                  f"training_artifacts/{model_name}.onnx", 
                  input_names=["input_ids", "attention_mask", "token_type_ids"],
                  output_names=["output"])

  torch.tensor(1000),
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [32]:
class MobileBERTWithLoss(onnxblock.TrainingModel):
    def __init__(self):
        super().__init__()
        self.loss = onnxblock.loss.CrossEntropyLoss()

    def build(self, loss_node_input_name):
        return self.loss(loss_node_input_name)

In [33]:
# Load the model from the exported inference ONNX file.
onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")
eval_model = None
optimizer_model = None

training_block = MobileBERTWithLoss()

inference_model_output_name = "output"
with onnxblock.onnx_model(onnx_model) as model_accessor:
    loss_output_name = training_block(inference_model_output_name)
    eval_model = model_accessor.eval_model

optimizer_block = onnxblock.optim.AdamW()
with onnxblock.onnx_model() as model_accessor:
    optimizer_outputs = optimizer_block(training_block.parameters())
    optimizer_model = model_accessor.model

2023-03-15 21:30:56.704342473 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/Slice_1_output_0'. It is no longer used by any node.
2023-03-15 21:30:56.704386371 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/ConstantOfShape_output_0'. It is no longer used by any node.
2023-03-15 21:30:56.704392871 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/Concat_1_output_0'. It is no longer used by any node.
2023-03-15 21:30:56.704405170 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/ConstantOfShape_1_output_0'. It is no longer used by any node.
2023-03-15 21:30:56.704425669 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/Transpose_1_output_0'. It is no longer used by any node.
2023-03-15 21:30:56.70

In [34]:

onnxblock.save_checkpoint(training_block.parameters(), f"training_artifacts/{model_name}.ckpt")
onnx.save(onnx_model, f"training_artifacts/{model_name}_training.onnx")
onnx.save(eval_model, f"training_artifacts/{model_name}_eval.onnx")
onnx.save(optimizer_model, f"training_artifacts/{model_name}_optimizer.onnx")

Bad pipe message: %s [b'\xbe\xce\xf1/j\x9e\xbf\xd0\xc5#\xafQR\x93\x7fga\x99 ]\xfe%\xabt\xad-jc\xb7-\xfb\xcd\xe4\xa3\rI\xcb\xac4\xfd\x0b\x0f\x88\x8a\x1b\x80\x9c\xf8\xd2|\xdd\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05', b'\x04\x01\x05\x01\x06\x01\x00']
Bad pipe message: %s [b"\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 <^\n\xc8'h\x8a\x04G\xab\x04\x8d\x1f\x92|\xdc\\\xe9\xb6\xab\t\x88"]
Bad pipe message: %s [b't\x96v\x8e\xa0`\xbd\xaa\xc6&\x17mR\xd7\x05\xc3\xbc,\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]']
Bad pipe message: %s [b"\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k