In [33]:
import os
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PegasusTokenizer
from transformers import PegasusForConditionalGeneration

import onnx
import onnxruntime as ort

In [34]:
# Checkpoint identifier passed to `from_pretrained` when loading the model and tokenizer.
ori_model_name = "google/pegasus-xsum"

In [35]:
# Load the tokenizer and the conditional-generation model from the same checkpoint.
ori_tokenizer = PegasusTokenizer.from_pretrained(ori_model_name)
ori_pegasus_model = PegasusForConditionalGeneration.from_pretrained(ori_model_name)

In [36]:
def export_encoder(model, args, exported_model_path, opset_version=12):
    """Export an encoder module to ONNX with dynamic batch and sequence axes.

    Args:
        model: encoder module to export; it is switched to eval mode first.
        args: example input used for tracing (the tokenized ``input_ids``).
        exported_model_path: destination path of the ``.onnx`` file.
        opset_version: ONNX opset to target (defaults to 12, the previous
            hard-coded value).
    """
    model.eval()
    with torch.no_grad():
        # Use the public exporter: ``torch.onnx._export`` is a private,
        # undocumented API that is not guaranteed to exist across releases.
        torch.onnx.export(
            model,
            args,
            exported_model_path,
            export_params=True,
            opset_version=opset_version,
            input_names=['input_ids'],
            output_names=['hidden_states'],
            dynamic_axes={
                'input_ids': {0: 'batch', 1: 'sequence'},
                'hidden_states': {0: 'batch', 1: 'sequence'},
            },
        )

In [37]:
class DecoderWithLMHead(torch.nn.Module):
    """Combine a decoder, its LM head and the final logits bias into one module.

    ``forward`` returns the ``top_k`` largest log-probabilities for the token
    after the last decoder position, together with their vocabulary indices.
    """

    def __init__(self, decoder, lm_head, final_logits_bias, top_k=5):
        """
        Args:
            decoder: called with ``input_ids``, ``attention_mask`` and
                ``encoder_hidden_states``; its first output is fed to ``lm_head``.
            lm_head: projection applied to the decoder output to obtain logits.
            final_logits_bias: tensor added to the logits (broadcast).
            top_k: number of candidates returned per batch row (default 5,
                the previously hard-coded value).
        """
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.final_logits_bias = final_logits_bias
        self.top_k = top_k

    def forward(self, input_ids, encoder_hidden_states):
        """Return ``(values, indices)`` of the top-k next-token log-probabilities.

        Both outputs have shape ``(batch, top_k)``, sorted in descending order.
        """
        outputs = self.decoder(input_ids=input_ids,
                               # No padding mask is passed; presumably the decoder still
                               # applies its own causal mask — TODO confirm.
                               attention_mask=None,
                               encoder_hidden_states=encoder_hidden_states)
        logits = self.lm_head(outputs[0]) + self.final_logits_bias
        # Only the prediction for the position after the last input token is needed.
        next_token_logits = logits[:, -1, :]
        log_softmaxed = F.log_softmax(next_token_logits, dim=-1)
        topk = torch.topk(log_softmaxed, self.top_k, dim=-1, largest=True)
        return topk.values, topk.indices

In [38]:
def export_decoder(model, decoder_inputs, encoded, model_path, opset_version=12):
    """Export a decoder(+LM head) module to ONNX with dynamic batch/sequence axes.

    Args:
        model: module taking ``(input_ids, encoder_hidden_states)``; it is
            switched to eval mode first.
        decoder_inputs: example decoder ``input_ids`` used for tracing.
        encoded: example encoder hidden states used for tracing.
        model_path: destination path of the ``.onnx`` file.
        opset_version: ONNX opset to target (defaults to 12, the previous
            hard-coded value).
    """
    model.eval()
    with torch.no_grad():
        torch.onnx.export(
            model,
            (decoder_inputs, encoded),
            # Honour the caller's path; previously the global
            # ``output_decoder_path`` was used and this argument was ignored.
            model_path,
            export_params=True,
            opset_version=opset_version,
            input_names=['input_ids', 'encoder_hidden_states'],
            output_names=['log_softmax', 'indices'],
            dynamic_axes={
                'input_ids': {0: 'batch', 1: 'sequence'},
                'encoder_hidden_states': {0: 'batch', 1: 'sequence_encoder_length'},
                'log_softmax': {0: 'batch'},
                'indices': {0: 'batch'},
            },
        )

In [39]:
# Directory for the exported ONNX graphs. Create it up front so the export
# cells do not depend on the directory already existing.
ONNX_OUTPUT_DIR = "./onnx_output"
os.makedirs(ONNX_OUTPUT_DIR, exist_ok=True)

output_encoder_path = f"{ONNX_OUTPUT_DIR}/encoder_xsum_0129.onnx"
output_decoder_path = f"{ONNX_OUTPUT_DIR}/decoder_lm_xsum_0129.onnx"

In [40]:
# Sample text that is tokenized only to trace the models during ONNX export.
# The leading and trailing newlines are part of the traced input.
export_text = "\nI have been going over my folder.\n"

In [41]:
def export_encoder_and_decoder(tokenizer, model, export_text, output_encoder_path, output_decoder_path):
    """Tokenize ``export_text`` and export the encoder and decoder+LM head to ONNX.

    Args:
        tokenizer: tokenizer used to build the example ``input_ids`` for tracing.
        model: model exposing ``model.model.encoder``, ``model.model.decoder``,
            ``model.lm_head`` and ``model.final_logits_bias``.
        export_text: example text that is tokenized to trace both graphs.
        output_encoder_path: destination ``.onnx`` path for the encoder.
        output_decoder_path: destination ``.onnx`` path for the decoder with LM head.
    """
    export_input_ids = tokenizer(export_text, return_tensors='pt')['input_ids']
    encoder = model.model.encoder

    export_encoder(encoder, export_input_ids, output_encoder_path)

    # Example encoder output, used only as a tracing input for the decoder
    # export. Compute it without autograd so the sample tensor carries no grad graph.
    with torch.no_grad():
        encoder_hidden_states = encoder(input_ids=export_input_ids).last_hidden_state

    decoder_lm_head = DecoderWithLMHead(model.model.decoder, model.lm_head, model.final_logits_bias)
    # The encoder input ids are reused as the example decoder input; presumably
    # only their shape/dtype matter here since the sequence axis is exported as dynamic.
    export_decoder(decoder_lm_head, export_input_ids, encoder_hidden_states, output_decoder_path)

In [42]:
# Run the export for the loaded Pegasus checkpoint.
export_encoder_and_decoder(
    tokenizer=ori_tokenizer,
    model=ori_pegasus_model,
    export_text=export_text,
    output_encoder_path=output_encoder_path,
    output_decoder_path=output_decoder_path,
)