In [76]:
from onnxruntime.training import artifacts
import torch
import onnx
import transformers

In [77]:
pipeline = transformers.pipeline(
    "text-generation",
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)

# transformers_model = transformers.LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", config=transformers.LlamaConfig(num_hidden_layers=2), ignore_mismatched_sizes=True)

In [78]:
tokenizer = transformers.AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")

In [79]:
class FlatModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *local_inputs):
        return self.model(inputs.input_ids, inputs.attention_mask)

model = FlatModel(pipeline.model)
# model = FlatModel(transformers_model)

In [81]:
input_names = ["input_ids", "attention_mask", "position_ids"]
output_names = ["loss", "logits"]

torch.onnx.export(model,
                  (inputs["input_ids"], inputs["attention_mask"]),
                  "tinyllama.onnx",
                  input_names = input_names, 
                  output_names = output_names,
                  export_params=True,
                  training=torch.onnx.TrainingMode.TRAINING,
                  do_constant_folding=False,
                  dynamic_axes={
                    "input_ids": {0: "batch_size", 1: "sequence_length"},
                    "attention_mask": {0: "batch_size", 1: "sequence_length"},
                    "logits": {0: "batch_size", 1: "sequence_length"}
                  }
                  )

In [82]:
requires_grad = []
frozen_params = []
for name, param in model.named_parameters():
    if param.requires_grad:
        requires_grad.append(name)
    else:
        frozen_params.append(name)

for name, param in model.named_buffers():
    frozen_params.append(name)

onnx_model = onnx.load("tinyllama.onnx")

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    optimizer=artifacts.OptimType.AdamW,
)



ValueError: This protobuf of onnx model is too large (>2GB). Call check_model with model path instead.