## library imports

In [35]:
import torch
import transformers
import onnx
import onnxruntime.training.onnxblock as onnxblock

## generating artifacts

In [36]:
model = transformers.AutoModel.from_pretrained('google/mobilebert-uncased')
model_name = 'mobilebert-uncased'

Some weights of the model checkpoint at google/mobilebert-uncased were not used when initializing MobileBertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing MobileBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MobileBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [37]:
# ... painstakingly create the random input

# expects
# input_ids = torch.LongTensor of shape (batch size, seq len)
# attention_mask = torch.FloatTensor of shape (batch size, seq len)
# token_type_ids = torch.LongTensor of shape (bs, seq len)

batch_size = 2
seq_len = 25
vocab = 20000
input_ids = torch.randint(vocab, (batch_size, seq_len))
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.float)
token_type_ids = torch.ones((batch_size, seq_len), dtype=torch.long)


In [40]:
torch.onnx.export(model, (input_ids, attention_mask, token_type_ids),
                  f"training_artifacts/{model_name}.onnx", 
                  input_names=["input_ids", "attention_mask", "token_type_ids"],
                  output_names=["output"])

In [42]:
class MobileBERTWithLoss(onnxblock.TrainingModel):
    def __init__(self):
        super().__init__()
        self.loss = onnxblock.loss.CrossEntropyLoss()

    def build(self, loss_node_input_name):
        return self.loss(loss_node_input_name)


# Load the model from the exported inference ONNX file.
onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")
eval_model = None
optimizer_model = None

training_block = MobileBERTWithLoss()

inference_model_output_name = "output"
with onnxblock.onnx_model(onnx_model) as model_accessor:
    loss_output_name = training_block(inference_model_output_name)
    eval_model = model_accessor.eval_model

optimizer_block = onnxblock.optim.AdamW()
with onnxblock.onnx_model() as model_accessor:
    optimizer_outputs = optimizer_block(training_block.parameters())
    optimizer_model = model_accessor.model

2023-03-16 00:02:58.228826786 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/Slice_1_output_0'. It is no longer used by any node.
2023-03-16 00:02:58.228864285 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/ConstantOfShape_output_0'. It is no longer used by any node.
2023-03-16 00:02:58.228949181 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/Concat_1_output_0'. It is no longer used by any node.
2023-03-16 00:02:58.228955281 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/ConstantOfShape_1_output_0'. It is no longer used by any node.
2023-03-16 00:02:58.228967580 [I:onnxruntime:Default, graph.cc:3493 CleanUnusedInitializersAndNodeArgs] Removing initializer '/embeddings/Transpose_1_output_0'. It is no longer used by any node.
2023-03-16 00:02:58.22

In [43]:

onnxblock.save_checkpoint(training_block.parameters(), f"training_artifacts/{model_name}.ckpt")
onnx.save(onnx_model, f"training_artifacts/{model_name}_training.onnx")
onnx.save(eval_model, f"training_artifacts/{model_name}_eval.onnx")
onnx.save(optimizer_model, f"training_artifacts/{model_name}_optimizer.onnx")