ONNX Conversion for LongFormer predictions different #505
Comments
Hi @adithya1111, thank you for the detailed bug report following our discussion!

To give you some context: one limitation of the ONNX export (reference), as detailed e.g. in this issue, is that an example input needs to be provided to trace the execution. With longformer an issue arises: some control flow (if/else, loops) depends on the input dimension! Hence, if you use a very different input dimension (sequence length), the ONNX output will be garbage.

I will add an immediate (partial) solution: an option to pass a target sequence length during the ONNX export, so that the recorded control flow matches what will actually be used during inference with the ONNX model. However, the exported model will only work with certain sequence lengths. You can refer to #503 for tracking of the global dynamic control flow issue.

Alternatively, you can give a try to https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/longformer , I haven't tried it yet but it's on my todo!

Easy reproduction of the issue:
```python
from transformers import AutoTokenizer, AutoModel
import numpy as np
import onnxruntime as ort
import torch

torch.manual_seed(0)


def get_inp(tokenizer, batch_size, seq_len, device="cpu"):
    # Random token ids, full attention mask, and a random global attention mask
    input_ids = torch.randint(tokenizer.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64)
    global_attention_mask = torch.randint(2, (batch_size, seq_len), dtype=torch.int64)
    if device == "cuda:0":
        input_ids = input_ids.to("cuda:0")
        attention_mask = attention_mask.to("cuda:0")
        global_attention_mask = global_attention_mask.to("cuda:0")
    return input_ids, attention_mask, global_attention_mask


model_name = "fxmarty/tiny-random-longformer"
tokenizer = AutoTokenizer.from_pretrained(model_name)
pt_model = AutoModel.from_pretrained(model_name)

onnx_path = "/home/fxmarty/tiny_random_longformer_mine/model.onnx"
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# Sweep batch sizes and sequence lengths, comparing PyTorch vs. ONNX Runtime outputs
for batch_size in range(1, 8):
    for seq_len in range(1, 200):
        input_ids, attention_mask, global_attention_mask = get_inp(tokenizer, batch_size, seq_len, device="cpu")
        onnx_inputs = {
            "input_ids": input_ids.cpu().detach().numpy(),
            "attention_mask": attention_mask.cpu().detach().numpy(),
            "global_attention_mask": global_attention_mask.cpu().detach().numpy(),
        }
        res_pt = pt_model(input_ids=input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
        res_ort = session.run(None, onnx_inputs)

        pt_last_hidden_state = res_pt["last_hidden_state"].detach().numpy()
        ort_last_hidden_state = res_ort[0]

        if pt_last_hidden_state.shape != ort_last_hidden_state.shape:
            print(f"[x] bs={batch_size}, seq_len={seq_len}. Shape error")
            continue  # element-wise comparison would fail with mismatched shapes

        diff = np.max(np.abs(pt_last_hidden_state - ort_last_hidden_state))
        if diff > 1e-3:
            print(f"[x] bs={batch_size}, seq_len={seq_len}. Maxdiff: {diff}")
```
Hey @fxmarty: Thanks so much for the detailed explanation and code. I was able to match the predictions with the old code itself. The difference in predictions came from the original model's side: I was not passing any global attention mask. When I explicitly passed it like you mentioned above, the predictions matched. I have a couple of follow-up questions.
Right, you've got a pretty good understanding. Something I forgot to mention is that longformer is currently not supported in the
I am not too familiar with the model, and info on this is lacking in the paper/issues I've read. But you can have a look.
I think I've read this as well: allenai/longformer#67 (comment)
Yes, good catch! It's a limitation of the
What do you mean here? Which issue?
Oh I see! This issue is fixed. So the issue was not specific to sequence lengths >512; it was more that, depending on which example input you provide for the export, the recorded control flow may not match the inputs used at inference. What (I hope) you can expect in the next release is better support indicating the working and failing cases of the exported ONNX model, along with more options at export time to choose from depending on the downstream use case (e.g. which sequence length you expect).
Hey, sorry for my late reply. Debugging this afternoon I found where the issue is coming from: microsoft/onnxruntime#13920. Basically the model definition is broken for the ONNX export, and although the export silently passes, the exported model will not work with arbitrary sequence lengths. #473 should have been fixed though. Can you try with the latest release of transformers?
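As a quick sanity check before re-exporting, a minimal sketch to confirm which versions are installed (assuming `transformers`, `optimum`, and `onnxruntime` are all present in the environment):

```python
# Print the installed versions of the relevant packages before re-exporting.
from importlib.metadata import version

for pkg in ("transformers", "optimum", "onnxruntime"):
    print(pkg, version(pkg))
```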
Yes, but I've not worked on it yet. It could also be that fixing the model definition itself in transformers will solve the issue, and the exported ONNX model will be usable with arbitrary sequence lengths.
@fxmarty - Thanks for the comment. I updated to the latest version and re-exported my model to ONNX.
And my prediction code:
Results comparison: ONNX Model:
Original model
For fewer than 512 tokens the outputs from the original model match the ONNX model, but for more than 512 tokens I was previously getting a shape error. With the latest release that error is gone, but the predictions don't match. Questions:
Hey @adithya1111, thanks again for testing!
Yes, this is "expected". The issue is in the modeling of longformer, making the ONNX export silently fail, see pytorch/pytorch#90607. This bug would need to be fixed, or the modeling modified, for the exported model to work for all sequence lengths. To answer your questions:
Hi @adithya1111, hope you are doing well! To follow up on my previous answer: support for overriding the default shapes at ONNX export has been added in the latest release! You can check
to use it! Hint:
This should allow you to alleviate your issue while pytorch/pytorch#90607 is fixed on PyTorch's side.
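For illustration, a minimal sketch of what overriding the export shapes could look like from Python, assuming an optimum version that exposes `optimum.exporters.onnx.main_export` and accepts input-shape overrides such as `sequence_length` and `batch_size` as keyword arguments (the exact interface may differ, so check the release notes and documentation):

```python
# Sketch only: assumes `main_export` is available and accepts input-shape
# overrides (e.g. `sequence_length`, `batch_size`) as keyword arguments.
from optimum.exporters.onnx import main_export

main_export(
    "your-longformer-checkpoint",   # hypothetical model name or local path
    output="longformer_onnx/",
    # Override the dummy input shapes used for tracing so that the recorded
    # control flow matches the sequence length expected at inference time.
    sequence_length=1024,
    batch_size=1,
)
```

Since the traced control flow is tied to the export-time shape, inputs at inference time should then be padded to the sequence length chosen here.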
System Info
Python: 3.9.12
OS: macOS
Optimum version: 1.5.0
Who can help?
@JingyaHuang @lewtun @fxmarty @ydshieh: Hello guys. I have a longformer model fine-tuned on my dataset. I exported the model to ONNX using Optimum. When I try to make predictions with the ONNX model, the results don't match the original model's predictions. Please find the details below. The final parts contain the original and ONNX model predictions; they are somewhat close in some cases, but they seem to be off. So I wanted to check if this is because of the global attention mask values I am passing. The model was trained using the default hyperparameters.
Export code to ONNX:
Original model predictions code:
ONNX inference code:
Original model predictions:
ONNX model predictions:
Reproduction
Code snippets shared above
Expected behavior
The expected behavior is that the outputs from the ONNX model and the original model should match. Here they are somewhat close, but the results still seem to be off.
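For reference, a minimal sketch of the kind of check that captures this expectation, assuming `pt_outputs` and `onnx_outputs` hold the last hidden states (or logits) from the PyTorch model and the ONNX Runtime session, gathered with identical inputs (names are illustrative):

```python
# Sketch only: `pt_outputs` and `onnx_outputs` are placeholder numpy arrays
# from the PyTorch model and the ONNX Runtime session, run on the same inputs
# (including the same global_attention_mask).
import numpy as np

def assert_outputs_match(pt_outputs, onnx_outputs, atol=1e-4):
    # Raises an AssertionError with a mismatch summary if any element
    # differs by more than `atol`.
    np.testing.assert_allclose(pt_outputs, onnx_outputs, rtol=0, atol=atol)
    print("max abs diff:", np.max(np.abs(pt_outputs - onnx_outputs)))
```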