In [1]:
import os
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PegasusTokenizer
from transformers import PegasusForConditionalGeneration

import onnx
import onnxruntime as ort

In [2]:
# Checkpoint name handed to `from_pretrained` for both the model and the tokenizer.
ori_model_name = "google/pegasus-xsum"

In [3]:
# Load the tokenizer and the seq2seq model from the same checkpoint name.
ori_tokenizer = PegasusTokenizer.from_pretrained(ori_model_name)
ori_pegasus_model = PegasusForConditionalGeneration.from_pretrained(ori_model_name)

In [4]:
def export_encoder(model, args, exported_model_path):
    """Export an encoder module to an ONNX file.

    The exported graph has one input named ``input_ids`` and one output named
    ``hidden_states``; axes 0 and 1 of both are declared dynamic
    (``batch`` / ``sequence``).

    Args:
        model: ``torch.nn.Module`` to export. It is switched to eval mode.
        args: Example input(s) forwarded unchanged to ``torch.onnx.export``.
        exported_model_path: Destination of the ONNX model. When this is a
            filesystem path, missing parent directories are created first.
    """
    # The exporter does not create directories, so make sure the target
    # folder exists before writing. Non-path targets are passed through as-is.
    if isinstance(exported_model_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(exported_model_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    model.eval()
    with torch.no_grad():
        # Use the public exporter rather than the private `torch.onnx._export`.
        torch.onnx.export(model,
                          args,
                          exported_model_path,
                          export_params=True,
                          opset_version=14,
                          input_names=['input_ids'],
                          output_names=['hidden_states'],
                          dynamic_axes={
                              'input_ids': {0: 'batch', 1: 'sequence'},
                              'hidden_states': {0: 'batch', 1: 'sequence'},
                          })

In [10]:
class DecoderWithLMHead(torch.nn.Module):
    """Wrap a decoder and its LM head so one forward pass yields the top-5
    log-probabilities (and their token indices) for the last position.

    Args:
        decoder: Callable invoked with ``input_ids``, ``attention_mask`` and
            ``encoder_hidden_states`` keyword arguments; element ``[0]`` of
            its result is fed to ``lm_head``.
        lm_head: Callable mapping the decoder output to vocabulary logits.
        final_logits_bias: Value added to the logits produced by ``lm_head``.
    """

    def __init__(self, decoder, lm_head, final_logits_bias):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.final_logits_bias = final_logits_bias

    def forward(self, input_ids, encoder_hidden_states):
        """Return ``(values, indices)`` of the 5 largest last-step log-probs."""
        decoder_out = self.decoder(
            input_ids=input_ids,
            attention_mask=None,
            encoder_hidden_states=encoder_hidden_states,
        )
        vocab_logits = self.lm_head(decoder_out[0]) + self.final_logits_bias
        # Only the final sequence position is scored.
        last_step_log_probs = F.log_softmax(vocab_logits[:, -1, :], dim=1)
        top_values, top_indices = torch.topk(last_step_log_probs, k=5, largest=True)
        return top_values, top_indices

In [11]:
def export_decoder(model, decoder_inputs, encoded, model_path):
    """Export a decoder-with-LM-head module to an ONNX file.

    The exported graph takes ``input_ids`` and ``encoder_hidden_states``
    (axes 0 and 1 dynamic) and produces ``log_softmax`` and ``indices``
    (axis 0 dynamic).

    Args:
        model: ``torch.nn.Module`` to export. It is switched to eval mode.
        decoder_inputs: Example decoder input ids used for the export.
        encoded: Example encoder hidden states used for the export.
        model_path: Destination of the ONNX model. When this is a filesystem
            path, missing parent directories are created first.
    """
    # The exporter does not create directories, so make sure the target
    # folder exists before writing. Non-path targets are passed through as-is.
    if isinstance(model_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(model_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    model.eval()
    with torch.no_grad():
        torch.onnx.export(model,
                          (decoder_inputs, encoded),
                          # Write to the caller-supplied path, not the
                          # notebook-level `output_decoder_path` global.
                          model_path,
                          export_params=True,
                          opset_version=12,
                          input_names=['input_ids', 'encoder_hidden_states'],
                          output_names=['log_softmax', 'indices'],
                          dynamic_axes={
                              'input_ids': {0: 'batch', 1: 'sequence'},
                              'encoder_hidden_states': {0: 'batch', 1: 'sequence'},
                              'log_softmax': {0: 'batch'},
                              'indices': {0: 'batch'},
                          })

In [12]:
# Destination files for the exported encoder and decoder ONNX graphs.
ONNX_OUTPUT_DIR = "./onnx_output"
output_encoder_path = f"{ONNX_OUTPUT_DIR}/encoder_xsum_0129.onnx"
output_decoder_path = f"{ONNX_OUTPUT_DIR}/decoder_lm_xsum_0129.onnx"

In [13]:
# Short sample sentence tokenized to provide example inputs for the ONNX export.
export_text = "\nI have been going over my folder.\n"

In [14]:
def export_encoder_and_decoder(tokenizer, model, export_text, output_encoder_path, output_decoder_path):
    """Export the model's encoder and its decoder (with LM head) to ONNX.

    ``export_text`` is tokenized once; the resulting ``input_ids`` serve as
    the example input for the encoder export and are reused as the example
    decoder input for the decoder export.

    Args:
        tokenizer: Callable returning a mapping with an ``'input_ids'`` entry
            when called with ``return_tensors='pt'``.
        model: Object exposing ``model.encoder``, ``model.decoder``,
            ``lm_head`` and ``final_logits_bias``.
        export_text: Text used to build the example inputs.
        output_encoder_path: Destination passed to ``export_encoder``.
        output_decoder_path: Destination passed to ``export_decoder``.
    """
    sample_ids = tokenizer(export_text, return_tensors='pt')['input_ids']
    encoder = model.model.encoder

    export_encoder(encoder, sample_ids, output_encoder_path)

    decoder_with_head = DecoderWithLMHead(model.model.decoder, model.lm_head, model.final_logits_bias)
    # Example encoder hidden states for the decoder export.
    sample_encoded = encoder(input_ids=sample_ids).last_hidden_state
    export_decoder(decoder_with_head, sample_ids, sample_encoded, output_decoder_path)

In [15]:
export_encoder_and_decoder(ori_tokenizer, ori_pegasus_model, export_text, output_encoder_path, output_decoder_path)

  if input_shape[-1] > 1:
  if attention_mask.size() != (bsz, 1, tgt_len, src_len):


## Validation

In [2]:
### Setup

In [16]:
# ONNX Runtime configuration used when creating the inference sessions.
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 4
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Providers are listed in priority order: prefer CUDA, and fall back to the
# CPU provider explicitly so the notebook also runs on machines without a GPU.
EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
ori_encoder_session = ort.InferenceSession(output_encoder_path, sess_options, providers=EP_list)

In [17]:
text = """I have already taken classes in NLP, ML (both intro and grad level), and algorithms. 
I was even the teaching assistant for algorithms. 
I even was able to combine all of these in a self-project 
where I built a neural machine translation model capable of going from Shakespearean to modern English, 
which was able to make it all the way to the top of Hacker News. 
"""
# Tokenize the validation text as a batch of one and show the id tensor's shape.
test_input = ori_tokenizer([text], return_tensors="pt")
test_input["input_ids"].shape

torch.Size([1, 76])

### Encoder

In [18]:
# Pytorch encoder result
# PyTorch reference: encoder hidden states for the validation input.
with torch.no_grad():
    pt_encoder_outputs = ori_pegasus_model.model.encoder(input_ids=test_input['input_ids'])
    pt_encoder_hidden_state = pt_encoder_outputs.last_hidden_state

# ONNX Runtime: same token ids, fed as a NumPy array.
onnx_encoder_feed = {'input_ids': test_input["input_ids"].numpy()}
encoder_output_ori_onnx = ori_encoder_session.run(None, onnx_encoder_feed)

In [19]:
np.allclose(pt_encoder_hidden_state.numpy(), encoder_output_ori_onnx[0], atol=1e-4)

True

### Decoder

Validate decoder result using greedy search.

#### Pytorch results

In [20]:
def get_top_prob_index(decoder, lm_head, decoder_input_ids, encoder_hidden_states, final_logits_bias, output_topk):
    """Run one decoding step and return the greedy next token.

    Args:
        decoder: Callable invoked with ``input_ids``, ``attention_mask`` and
            ``encoder_hidden_states`` keyword arguments.
        lm_head: Callable mapping ``decoder(...)[0]`` to vocabulary logits.
        decoder_input_ids: Token ids decoded so far.
        encoder_hidden_states: Encoder output passed through to ``decoder``.
        final_logits_bias: Value added to the LM-head logits.
        output_topk: If truthy, the first returned element is the top-5
            ``torch.topk`` result over the log-softmax of the last-step
            logits; otherwise it is the raw last-step logits.

    Returns:
        ``(step_info, best_token_index)`` where ``best_token_index`` is the
        argmax of the last-step logits along dimension 1.
    """
    decoder_out = decoder(
        input_ids=decoder_input_ids,
        attention_mask=None,
        encoder_hidden_states=encoder_hidden_states,
    )
    # Keep only the logits of the final sequence position.
    last_step_logits = (lm_head(decoder_out[0]) + final_logits_bias)[:, -1, :]
    greedy_token = torch.argmax(last_step_logits, dim=1)

    if not output_topk:
        return (last_step_logits, greedy_token)

    top5 = torch.topk(F.log_softmax(last_step_logits, dim=1), k=5, largest=True)
    return (top5, greedy_token)

In [21]:
def get_results_from_pytorch(decoder, lm_head, decoder_input_ids, encoder_hidden_states, final_logits_bias, output_topk,
                             max_length=1000, eos_token_id=1):
    """Greedy-decode with the PyTorch decoder until EOS or ``max_length``.

    Each step calls ``get_top_prob_index`` on the ids generated so far and
    appends the greedy token to the running sequence.

    Args:
        decoder: Decoder callable forwarded to ``get_top_prob_index``.
        lm_head: LM-head callable forwarded to ``get_top_prob_index``.
        decoder_input_ids: Initial decoder ids; new tokens are concatenated
            along the last dimension.
        encoder_hidden_states: Encoder output forwarded to the decoder.
        final_logits_bias: Bias forwarded to ``get_top_prob_index``.
        output_topk: Forwarded to ``get_top_prob_index``; selects what is
            collected per step (top-k result vs. raw logits).
        max_length: Upper bound on the length of the returned id sequence.
            Guards against looping forever if EOS is never produced.
        eos_token_id: Token id that terminates decoding (default ``1``, the
            value this loop originally hard-coded).

    Returns:
        ``(ids, per_step_info)``: the full id sequence including the
        generated tokens, and a list with one entry per decoding step.
        If ``decoder_input_ids`` already has ``max_length`` ids, no step is
        run and the list is empty.
    """
    decoder_input_cur = decoder_input_ids
    next_token_logits_array = []
    while decoder_input_cur.shape[-1] < max_length:
        (next_token_logits, best_next) = get_top_prob_index(
            decoder, lm_head, decoder_input_cur, encoder_hidden_states, final_logits_bias, output_topk)
        next_token_logits_array.append(next_token_logits)
        decoder_input_cur = torch.cat([decoder_input_cur, best_next.unsqueeze(1)], dim=-1)
        # `.all()` keeps the check well-defined when `best_next` has more than
        # one element; with a single sequence it is the plain EOS comparison.
        if bool((best_next == eos_token_id).all()):
            break
    return (decoder_input_cur, next_token_logits_array)

In [22]:
decoder_inputs = torch.tensor([[0]]).long()

In [23]:
# Greedy decoding with the PyTorch decoder, collecting the top-k info of each step.
with torch.no_grad():
    summarization_id_pt, next_token_pt = get_results_from_pytorch(
        decoder=ori_pegasus_model.model.decoder,
        lm_head=ori_pegasus_model.lm_head,
        decoder_input_ids=decoder_inputs,
        encoder_hidden_states=pt_encoder_hidden_state,
        final_logits_bias=ori_pegasus_model.final_logits_bias,
        output_topk=True,
    )

In [25]:
print(len(next_token_pt))

22


In [26]:
next_token_pt[1]

torch.return_types.topk(
values=tensor([[-1.0071, -1.7436, -2.3653, -3.0019, -3.2337]]),
indices=tensor([[346, 131, 133, 245, 123]]))

#### Onnx results

In [27]:
ori_decoder_session = ort.InferenceSession(output_decoder_path, sess_options, providers=EP_list)

In [30]:
test_decoder_output = ori_decoder_session.run(None, {'input_ids': decoder_inputs.numpy(), "encoder_hidden_states": encoder_output_ori_onnx[0]})

In [31]:
test_decoder_output

[array([[-0.65140724, -2.956626  , -3.3299437 , -3.4845095 , -4.494252  ]],
       dtype=float32),
 array([[ 125,  600, 8087,  240,  398]], dtype=int64)]

In [28]:
def get_summarization_ids_onnx(encoder_output, decoder_session, init_decoder_inputs, max_length):
    """Greedy-decode with an ONNX decoder session.

    Args:
        encoder_output: Sequence whose element ``[0]`` is fed to the session
            as ``encoder_hidden_states``.
        decoder_session: Object with a ``run(output_names, feeds)`` method;
            element ``[1][0][0]`` of its result is taken as the next token.
        init_decoder_inputs: Initial ``input_ids`` array; generated tokens
            are concatenated along the last axis.
        max_length: Loop bound; at most ``max_length - 1`` steps are run.

    Returns:
        ``(ids, per_step_outputs)``: the extended id array and the list of
        raw session outputs, one entry per step.
    """
    generated_ids = init_decoder_inputs
    step_outputs = []
    for _ in range(1, max_length):
        feeds = {'input_ids': generated_ids, "encoder_hidden_states": encoder_output[0]}
        step_result = decoder_session.run(None, feeds)
        step_outputs.append(step_result)

        # First entry of the second output is the greedy choice for this step.
        best_token = step_result[1][0][0]
        next_column = np.asarray([best_token])[:, None]
        generated_ids = np.concatenate([generated_ids, next_column], axis=-1)

        if best_token == 1:
            break
    return (generated_ids, step_outputs)

In [29]:
# Greedy decoding through the ONNX decoder session, capped at 1000 ids.
summarization_id_onnx, next_token_onnx = get_summarization_ids_onnx(
    encoder_output=encoder_output_ori_onnx,
    decoder_session=ori_decoder_session,
    init_decoder_inputs=decoder_inputs.numpy(),
    max_length=1000,
)

In [30]:
print(len(summarization_id_onnx[0]))

23


In [31]:
assert(len(next_token_onnx) == len(next_token_pt))

In [32]:
# Compare the top-k log-probabilities of every decoding step (PyTorch vs. ONNX).
mismatched_steps = [
    step
    for step in range(len(next_token_onnx))
    if not np.allclose(next_token_pt[step][0].numpy(), next_token_onnx[step][0], atol=1e-4)
]
not_close = bool(mismatched_steps)
print("Not all decoder result close" if not_close else "All decoder result close")

All decoder result close


In [37]:
def get_decoded_text(tokenizer, summarized_id):
    """Decode each id sequence in ``summarized_id`` to text.

    Every sequence is passed to ``tokenizer.decode`` with
    ``skip_special_tokens=True`` and ``clean_up_tokenization_spaces=False``.

    Returns:
        A list with one decoded string per input sequence, in order.
    """
    decoded = []
    for token_ids in summarized_id:
        decoded.append(
            tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        )
    return decoded

In [38]:
get_decoded_text(ori_tokenizer, summarization_id_pt) == get_decoded_text(ori_tokenizer, summarization_id_onnx)

True

In [39]:
get_decoded_text(ori_tokenizer, summarization_id_onnx)

['I am a junior at the University of California, Berkeley, and am interested in Natural Language Processing (NLP).']