ONNX Conversion for LongFormer predictions different #505

Open
2 of 4 tasks
adithya1111 opened this issue Nov 22, 2022 · 10 comments
Labels
bug Something isn't working

Comments

@adithya1111

System Info

Python 3.9.12
OS : MacOS
Optimum Version: 1.5.0

Who can help?

@JingyaHuang @lewtun @fxmarty @ydshieh: Hello all. I have a Longformer model fine-tuned on my dataset and exported to ONNX using Optimum. When I make predictions with the ONNX model, the results don't match the original model's predictions. Details are below. The last two sections show the original and ONNX model predictions: they are close in some cases, but they still seem to be off. I wanted to check whether this is caused by the global attention mask values I am passing. The model was trained with the default hyperparameters. The model config:

{
  "_name_or_path": "/opt/ml/input/data/model-base",
  "architectures": [
    "LongformerForTokenClassification"
  ],
  "attention_mode": "longformer",
  "attention_probs_dropout_prob": 0.1,
  "attention_window": [
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3"
  },
  "ignore_attention_mask": false,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3
  },
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 4098,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "sep_token_id": 2,
  "torch_dtype": "float32",
  "transformers_version": "4.9.1",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

Export code to ONNX :

from pathlib import Path

from optimum.onnxruntime import ORTModelForTokenClassification
import onnxruntime as ort

# original_model_dir: directory containing the fine-tuned PyTorch model
onnx_model = ORTModelForTokenClassification.from_pretrained(original_model_dir, from_transformers=True)
onnx_path = Path("onnx-longformer-model-optimum")
onnx_model.save_pretrained(onnx_path)

Original model predictions code :

ocr_text = 'test 123'
tokenized_values = tokenizer(ocr_text, max_length=int(model_config.max_position_embeddings - 2),
                          truncation=True, return_tensors='pt')

predictions = original_model(tokenized_values["input_ids"])

ONNX inference code:

import onnxruntime as ort
session = ort.InferenceSession("/Users/i849730/Desktop/onnx-testing/onnx-longformer-model-optimum/model.onnx", providers=["CPUExecutionProvider"])
import torch
ocr_text = 'test 123'
tokenized_values = tokenizer(ocr_text, max_length=int(model_config.max_position_embeddings - 2),
                          truncation=True, return_tensors='pt')

predictions = longformer_model(tokenized_values["input_ids"])

session = ort.InferenceSession("/home/ec2-user/SageMaker/invoice-docsep/onnx-conversion-optimum/model.onnx", providers=["CPUExecutionProvider"])

#input_ids = torch.randint(tokenizer.vocab_size, (batch_size, seq_len))
input_ids = tokenized_values["input_ids"]
#attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64)
attention_mask = tokenized_values['attention_mask']

global_attention_mask = torch.zeros_like(input_ids)
# make every second token global
#global_attention_mask[:, 0] = 1
global_attention_mask[:, ::2] = 1
global_attention_mask = global_attention_mask.cpu().detach().numpy()

onnx_inputs = {
    "input_ids": input_ids.cpu().detach().numpy(),
    "attention_mask": attention_mask.cpu().detach().numpy(),
    "global_attention_mask" : global_attention_mask

}

res = session.run(None, onnx_inputs)

Original model predictions :

LongformerTokenClassifierOutput(loss=None, logits=tensor([[[ 0.6209,  0.0719,  0.1107, -0.5316],
         [ 3.0321, -0.2787, -0.6460, -2.5359],
         [ 2.6904,  0.1169, -0.7495, -2.8346],
         [ 0.6474,  0.0761,  0.1041, -0.5438]]], grad_fn=<AddBackward0>), hidden_states=None, attentions=None, global_attentions=None)

ONNX model predictions:

[array([[[ 0.49600145,  0.08062335,  0.12902021, -0.4010917 ],
         [ 3.0400352 , -0.34643874, -0.6276542 , -2.444679  ],
         [ 2.158992  ,  0.02124629, -0.5462518 , -2.094074  ],
         [ 0.6290194 ,  0.06919068,  0.10753635, -0.5197539 ]]],
       dtype=float32)]

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Code snippets shared above

Expected behavior

The outputs of the ONNX model and the original model should match. Here they are somewhat close, but the results still seem to be off.
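To put a number on "somewhat close", here is a minimal sketch that computes the largest absolute difference between the two sets of logits posted above. It uses plain Python lists so it runs without torch or numpy; the helper name `max_abs_diff` is made up for this illustration.

```python
# Logits copied from the two outputs posted above (batch dimension dropped).
pt_logits = [
    [0.6209, 0.0719, 0.1107, -0.5316],
    [3.0321, -0.2787, -0.6460, -2.5359],
    [2.6904, 0.1169, -0.7495, -2.8346],
    [0.6474, 0.0761, 0.1041, -0.5438],
]
onnx_logits = [
    [0.49600145, 0.08062335, 0.12902021, -0.4010917],
    [3.0400352, -0.34643874, -0.6276542, -2.444679],
    [2.158992, 0.02124629, -0.5462518, -2.094074],
    [0.6290194, 0.06919068, 0.10753635, -0.5197539],
]


def max_abs_diff(a, b):
    """Return (largest absolute difference, (token index, class index))."""
    best, where = 0.0, None
    for i, (row_a, row_b) in enumerate(zip(a, b)):
        for j, (x, y) in enumerate(zip(row_a, row_b)):
            d = abs(x - y)
            if d > best:
                best, where = d, (i, j)
    return best, where


diff, position = max_abs_diff(pt_logits, onnx_logits)
print(f"max abs diff = {diff:.4f} at token {position[0]}, class {position[1]}")
```

The largest gap is on the third token, far above what float32 rounding alone would usually explain, which suggests the two runs did not receive the same inputs or did not execute the same computation.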

@fxmarty
Contributor

fxmarty commented Nov 25, 2022

Hi @adithya1111 thank you for the detailed bug report following our discussion!

To give you some context: when you pass from_transformers=True, the original PyTorch model is exported to ONNX and then loaded into ORTModel for inference with ONNX Runtime.

One limitation of the ONNX export (reference), as detailed e.g. in this issue, is that an example input needs to be provided to trace the execution. With longformer, a problem arises: some control flow (if/else, loops) depends on the input dimensions! Hence, if you run inference with a very different input dimension (sequence length) from the one used at export, the ONNX output will be garbage.
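The effect can be illustrated with a toy sketch in plain Python. This is not the Longformer code nor the real `torch.onnx.export` tracer: `pad_to_window` and `trace_pad_to_window` are hypothetical names, and the "tracer" simply evaluates a shape-dependent expression once on the example input and freezes the result.

```python
def pad_to_window(tokens, window):
    """Eager version: the padding amount is recomputed for every input."""
    pad = (-len(tokens)) % window
    return tokens + [0] * pad


def trace_pad_to_window(example_len, window):
    """Toy "tracer": computes the padding once, for the example input,
    and bakes that constant into the returned function."""
    frozen_pad = (-example_len) % window

    def traced(tokens):
        return tokens + [0] * frozen_pad

    return traced


window = 4
traced = trace_pad_to_window(example_len=8, window=window)

same_len = [1] * 8   # same length as the export example
other_len = [1] * 6  # different length at inference time

# Same length as the example: eager and traced outputs agree.
print(len(pad_to_window(same_len, window)), len(traced(same_len)))
# Different length: the frozen padding is wrong, so the lengths differ.
print(len(pad_to_window(other_len, window)), len(traced(other_len)))
```

In the toy, the traced function still runs on the new length without raising, it just computes something different from the eager version, which is the same kind of silent mismatch described above.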

I will add an immediate (partial) solution: an option to pass a target sequence length during the ONNX export, so that the recorded control flow matches what will actually be used during inference with the ONNX model. However, the exported model will only work with certain sequence lengths. You can refer to #503 to track the global dynamic control flow issue.

Alternatively, you can try https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/longformer . I haven't yet, but it's on my todo list!


Easy reproduction of the issue:

python -m transformers.onnx --model fxmarty/tiny-random-longformer --feature default tiny_random_longformer_mine

and

from transformers import AutoTokenizer, AutoModel

import numpy as np
import onnxruntime as ort
import torch

torch.manual_seed(0)

def get_inp(tokenizer, batch_size, seq_len, device="cpu"):
    input_ids = torch.randint(tokenizer.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64)
    global_attention_mask = torch.randint(2, (batch_size, seq_len), dtype=torch.int64)
    
    if device == "cuda:0":
        input_ids = input_ids.to("cuda:0")
        attention_mask = attention_mask.to("cuda:0")
        global_attention_mask = global_attention_mask.to("cuda:0")

    return input_ids, attention_mask, global_attention_mask


model_name = "fxmarty/tiny-random-longformer"

tokenizer = AutoTokenizer.from_pretrained(model_name)
pt_model = AutoModel.from_pretrained(model_name)

onnx_path = "/home/fxmarty/tiny_random_longformer_mine/model.onnx"
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

for batch_size in range(1, 8):
    for seq_len in range(1, 200):
        input_ids, attention_mask, global_attention_mask = get_inp(tokenizer, batch_size, seq_len, device="cpu")
        
        onnx_inputs = {
            "input_ids": input_ids.cpu().detach().numpy(),
            "attention_mask": attention_mask.cpu().detach().numpy(),
            "global_attention_mask": global_attention_mask.cpu().detach().numpy()
        }
            
        res_pt = pt_model(input_ids=input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
        res_ort = session.run(None, onnx_inputs)
        
        pt_last_hidden_state = res_pt["last_hidden_state"].detach().numpy()
        ort_last_hidden_state = res_ort[0]

        if pt_last_hidden_state.shape != ort_last_hidden_state.shape:
            print(f"[x] bs={batch_size}, seq_len={seq_len}. Shape error")
        
        diff = np.max(np.abs(pt_last_hidden_state - ort_last_hidden_state))
        if diff > 1e-3:
            print(f"[x] bs={batch_size}, seq_len={seq_len}. Maxdiff: {diff}")

@adithya1111
Author

adithya1111 commented Nov 29, 2022

Hey @fxmarty: Thanks so much for the detailed explanation and code. I was able to match the predictions with the old code itself. The difference was on the original model's side: I was not passing any global attention mask to it. When I passed it explicitly, as you did above,
predictions = original_model(input_ids = tokenized_values["input_ids"], attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0]
that worked. Thanks for that.

I have a couple of follow-up questions:

  • I know Longformer's global attention can be provided at run time. The model I described above is a NER (token-classification) model. Is there a good standard for the global attention values?
  • For classification, I read that you just set the '<s>' token's global attention to 1 and the rest to 0. Is this correct?
  • I tried experimenting with setting the global attention mask to 1 only for the <s> token:
tokenized_values = tokenizer(ocr_text, max_length=int(model_config.max_position_embeddings - 2),
                          truncation=True, return_tensors='pt')
input_ids = tokenized_values['input_ids']
tokenized_values_length = tokenized_values['input_ids'].shape[1]
attention_mask = torch.ones((1, tokenized_values_length), dtype=torch.int64)
global_attention_mask = torch.tensor([[torch.tensor(1) if tokenizer.convert_ids_to_tokens(int(i))=='<s>' else torch.tensor(0) for i in input_ids[0]]])
  • For my expected entity classes, the probability scores are ~0.01 higher with a global attention mask of all 0s than with only the "<s>" token's global attention set to 1 and the rest to 0. Any suggestions on what a good standard would be?
  • Another issue: if I provide a global attention mask of all 0s, the ONNX inference fails (with a floating point error), which I think is expected. I read about this in another open issue as well. So I cannot use all 0s for the global attention in ONNX for now.
  • For sequences longer than 512 tokens, the issue would be fixed in the upcoming release, right?
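For reference, the "only `<s>` is global" mask from the bullets above can be sketched without the tokenizer by comparing token ids directly. This is a minimal illustration, assuming `bos_token_id` is 0 as in the posted config; the function name and the sample ids are made up. It also refuses to return an all-zero row, since that case is reported above to make the ONNX inference fail.

```python
BOS_TOKEN_ID = 0  # "bos_token_id" in the config posted above


def bos_global_attention_mask(input_ids, bos_token_id=BOS_TOKEN_ID):
    """Return a mask with 1 where the token is <s> and 0 elsewhere.

    input_ids is a batch: a list of lists of token ids.
    """
    mask = [[1 if tok == bos_token_id else 0 for tok in row] for row in input_ids]
    for row_index, row in enumerate(mask):
        if not any(row):
            # An all-zero global attention mask is the failing case
            # reported in this thread, so surface it early.
            raise ValueError(f"row {row_index} has no global attention token")
    return mask


# Illustrative ids only: <s> (0), two arbitrary tokens, </s> (2).
batch = [[0, 21959, 15932, 2]]
print(bos_global_attention_mask(batch))
```

With torch tensors, the same mask can be obtained with `(input_ids == tokenizer.bos_token_id).long()`, which avoids the per-token `convert_ids_to_tokens` call.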

@fxmarty
Contributor

fxmarty commented Nov 29, 2022

Right, you've got a pretty good understanding.

Something I forgot to mention is that longformer is currently not supported in ORTModel, since the global_attention_mask argument is not supported. We should, imo, find a clean solution, since it is not the only model that may have a custom input. The issue is tracked in #479.

I know LongFormer's global attention can be provided at run time. The model I described above is a NER model (token-classification) one. Is there a good standard for the global attention values ?

I am not too familiar with the model, and info on this is lacking in the paper/issues I've read. But you can have a look.

For classification I read that, just set the <s> token's global attention to be 1 and rest as 0. Is this correct ?

I think I've read this as well: allenai/longformer#67 (comment)

Another issue is that, if I provide the global attention all 0s , the ONNX inference fails (with a floating point error) , which I think is expected. I read about this in another open issue as well. So I cannot go with all 0s in global attention in ONNX for now.

Yes, good catch! It's a limitation of the transformers.onnx export right now. I plan to extend the export in optimum.exporters to be more flexible.

For sequences greater than 512 tokens, the issue would be fixed in the upcoming release right ?

What do you mean here? Which issue?

@adithya1111
Author

Thanks @fxmarty for your response. For the last question, I meant this one: #473

@fxmarty
Contributor

fxmarty commented Nov 30, 2022

Oh I see! This issue is fixed. The issue was not specific to sequence lengths >512; rather, depending on which example input you provide to torch.onnx.export, the exported ONNX model may or may not support a given sequence length. For the model I was testing, the issue arose with sequence lengths >512.

What (I hope) you can expect in the next release is better support for indicating the working and failing cases of the exported ONNX model, along with more export-time options to choose from depending on the downstream use case (e.g. which sequence length you expect).

@adithya1111
Author

Cool @fxmarty. I am using Longformer and I see the issue for sequence lengths > 512 tokens, where I get the same error you described in #473.

So the newer version would provide more options to handle > 512 tokens? Is that a fair assumption?

@fxmarty
Contributor

fxmarty commented Dec 9, 2022

Hey sorry for my late reply. Debugging this afternoon I found where the issue is coming from: microsoft/onnxruntime#13920

Basically, the model definition is broken for the ONNX export, and although the export silently passes, the exported model will not work with arbitrary sequence lengths.

#473 should have been fixed though. Can you try with the latest release of transformers?

So the newer version would provide more options to handle > 512 tokens ? Is that a fair assumption ?

Yes, but I've not worked on it yet. It could also be that fixing the model definition itself in transformers will solve the issue, so that the exported ONNX model is usable with arbitrary sequence lengths.

@adithya1111
Author

adithya1111 commented Dec 9, 2022

@fxmarty - Thanks for the comment. I updated to the latest version and re-exported my model to ONNX

from pathlib import Path

from optimum.onnxruntime import ORTModelForTokenClassification
import onnxruntime as ort
original_model_dir = '/Users/i849730/Desktop/onnx-testing/longformer_model/'
onnx_model = ORTModelForTokenClassification.from_pretrained(original_model_dir, from_transformers=True)
onnx_path = Path("onnx-longformer-model-optimum-v1")
onnx_model.save_pretrained(onnx_path)

And my prediction code :

input_ids, attention_mask, global_attention_mask = get_tokenized_values(ocr_text)
onnx_inputs = {
        "input_ids": input_ids.cpu().detach().numpy(),
        "attention_mask": attention_mask.cpu().detach().numpy(),
        "global_attention_mask": global_attention_mask.cpu().detach().numpy()}

res_ort = session.run(None, onnx_inputs)
scores_onnx = get_scores_from_logits(res_ort[0][0])
print(scores_onnx)
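The helper `get_scores_from_logits` is not shown in the thread. Assuming it applies a softmax over the class dimension for each token (a guess based on the probability-like rows in the results), a stand-in could look like this sketch; `scores_from_logits` is a hypothetical name, and the sample rows are logits posted earlier in the issue.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one token's class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scores_from_logits(token_logits):
    """Apply softmax row by row: one probability vector per token."""
    return [softmax(row) for row in token_logits]


# Two rows of logits taken from the original model output posted earlier.
rows = [
    [3.0321, -0.2787, -0.6460, -2.5359],
    [2.6904, 0.1169, -0.7495, -2.8346],
]
for probs in scores_from_logits(rows):
    print([round(p, 4) for p in probs])
```

Comparing raw logits instead of post-softmax scores, as the reproduction script above does, keeps small discrepancies visible that a saturated softmax would hide.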

Results comparison :

ONNX Model:

[[0.4059719  0.21155939 0.28428534 0.09818341]
 [0.40278274 0.16127673 0.3333195  0.10262107]
 [0.40807104 0.15910122 0.33399487 0.09883281]
 ...

Original model

[[0.9999883  0.00000516 0.00000635 0.00000025]
 [0.9999865  0.0000102  0.0000031  0.00000012]
 [0.999992   0.00000503 0.0000029  0.00000013]
 ...
 [0.9999939  0.00000213 0.0000

For fewer than 512 tokens, the outputs of the original model match those of the ONNX model. For more than 512 tokens, I was previously getting a shape error; with the latest release that error is gone, but the predictions don't match.

Questions:

  1. Do I need to pass a dummy input when exporting? My inputs are all variable length.
  2. Do I need to add dynamic_axes in my ONNX export? Would that solve the issue? If so, is there example code showing how to do it?

@fxmarty
Contributor

fxmarty commented Dec 12, 2022

Hey @adithya1111 Thanks again for testing!

With the latest release that error is gone. But the predictions don't match.

Yes, this is "expected". The issue is in the modeling of longformer, which makes the ONNX export silently fail; see pytorch/pytorch#90607. This bug would need to be fixed, or the modeling modified, for the exported model to work for all sequence lengths.

To answer your questions:

  1. Unfortunately, the ONNX export may record a path based on the sample input that does not generalize to all shapes. Currently that's the issue in longformer (due to a bug in PyTorch) and longt5 (unsure, have to investigate). As a temporary solution (and for testing), I am proposing "Support for custom input shapes in exporters onnx" (#575).
  2. No, this is handled automatically in the OnnxConfig.

@fxmarty
Contributor

fxmarty commented Dec 26, 2022

Hi @adithya1111, hope you are doing well!

To follow up on my previous answer: support for overriding the default shapes at ONNX export has been added in the latest release! You can check

optimum-cli export onnx --help

to use it! Hint:

optimum-cli export onnx --model model-name-or-path --sequence_length 450 longformer_onnx/

This should help alleviate your issue until pytorch/pytorch#90607 is fixed on PyTorch's side.
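Since a model exported this way is only expected to work around the chosen sequence length, one possible approach (an assumption, not something confirmed in this thread) is to pad every input up to the export-time length before calling the ONNX session. The sketch below works on plain lists for a single sequence; `pad_to_export_length` is a hypothetical helper, `PAD_TOKEN_ID` comes from the posted config, and 450 is the value from the hint above.

```python
PAD_TOKEN_ID = 1      # "pad_token_id" in the config posted earlier
EXPORT_SEQ_LEN = 450  # value passed to --sequence_length in the hint above


def pad_to_export_length(input_ids, attention_mask, global_attention_mask,
                         target_len=EXPORT_SEQ_LEN, pad_token_id=PAD_TOKEN_ID):
    """Right-pad one sequence and its two masks to target_len.

    Padding positions get attention 0 and global attention 0.
    """
    n = len(input_ids)
    if not (len(attention_mask) == len(global_attention_mask) == n):
        raise ValueError("input_ids and masks must have the same length")
    if n > target_len:
        raise ValueError(f"sequence of length {n} exceeds target {target_len}")
    extra = target_len - n
    return (
        input_ids + [pad_token_id] * extra,
        attention_mask + [0] * extra,
        global_attention_mask + [0] * extra,
    )


# Illustrative ids only: <s>, two arbitrary tokens, </s>.
ids, attn, glob = pad_to_export_length([0, 21959, 15932, 2], [1, 1, 1, 1], [1, 0, 0, 0])
print(len(ids), len(attn), len(glob))
```

Whether the padded ONNX outputs match the PyTorch model on the real tokens should still be checked, for example with the max-difference loop from the reproduction script earlier in the thread.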
