BART + ONNX torch.jit error iterabletree cannot be used as a value #14491

Closed
polly-morphism opened this issue Nov 22, 2021 · 11 comments
@polly-morphism

Environment info

onnx 1.10.2
onnxruntime 1.9.0

  • transformers version: transformers 4.13.0.dev0
  • Platform: Ubuntu 18.04
  • Python version: 3.8
  • PyTorch version (GPU?): torch 1.8.0 gpu
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help

@fatcat-z @mfuntowicz @sgugger, @patil-suraj

Information

Model I am using: BartForConditionalGeneration

The problem arises when using: the official example script run_onnx_exporter.py

To reproduce

Steps to reproduce the behavior:

python3.8 run_onnx_exporter.py --model_name_or_path facebook/bart-base

2021-11-22 17:34:47 | INFO | __main__ |  [run_onnx_exporter.py:224] Exporting model to ONNX
/home/pverzun/.local/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:217: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/home/pverzun/.local/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:223: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attention_mask.size() != (bsz, 1, tgt_len, src_len):
/home/pverzun/.local/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:254: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
/home/pverzun/.local/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:888: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if input_shape[-1] > 1:
/home/pverzun/.local/lib/python3.8/site-packages/torch/jit/_trace.py:934: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
  module._c._create_method_from_trace(
/home/pverzun/.local/lib/python3.8/site-packages/torch/jit/_trace.py:152: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
  if a.grad is not None:

Traceback (most recent call last):
  File "run_onnx_exporter.py", line 229, in <module>
    main()
  File "run_onnx_exporter.py", line 225, in main
    export_and_validate_model(model, tokenizer, output_name, num_beams, max_length)
  File "run_onnx_exporter.py", line 116, in export_and_validate_model
    bart_script_model = torch.jit.script(BARTBeamSearchGenerator(model))
  File "/home/pverzun/.local/lib/python3.8/site-packages/torch/jit/_script.py", line 942, in script
    return torch.jit._recursive.create_script_module(
  File "/home/pverzun/.local/lib/python3.8/site-packages/torch/jit/_recursive.py", line 391, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/pverzun/.local/lib/python3.8/site-packages/torch/jit/_recursive.py", line 448, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/pverzun/.local/lib/python3.8/site-packages/torch/jit/_script.py", line 391, in _construct
    init_fn(script_module)
  File "/home/pverzun/.local/lib/python3.8/site-packages/torch/jit/_recursive.py", line 428, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/pverzun/.local/lib/python3.8/site-packages/torch/jit/_recursive.py", line 452, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/home/pverzun/.local/lib/python3.8/site-packages/torch/jit/_recursive.py", line 335, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: 
iterabletree cannot be used as a value:
  File "/home/pverzun/.local/lib/python3.8/site-packages/transformers/configuration_utils.py", line 387
        if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
            self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
            self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
                                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

Expected behavior

BART is converted to ONNX with no issues.

@LysandreJik
Member

Pinging @michaelbenayoun on the issue :)

@fatcat-z
Contributor

@LysandreJik I will take a look at this.

@LysandreJik
Member

Thank you, @fatcat-z!

@polly-morphism
Author

OK, so it seems like a problem with versions. It works with:

torch==1.10.0
numpy==1.21.4
onnx==1.10.2
onnxruntime==1.9.0

and the latest transformers.

@pranavpawar3
Contributor

@patil-suraj @LysandreJik

There is another bug(?) in the run_onnx_exporter.py script. On the line where the dynamic axes are declared, attention_mask isn't included in the set. Any reason why? This prevents inputs of any size other than the ONNX sample input from working.

However, adding the attention_mask object to the dynamic axes resolves the issue, and I'm able to convert and test the model.

Please let me know if this needs to be changed; I can open a PR, or somebody from the HF side can amend the changes instead.
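For reference, the fix described above could look roughly like the sketch below. The exact input names and axis labels in the script may differ; the ones here are assumptions for illustration, not the script's actual identifiers.

```python
# Hypothetical dynamic-axes mapping for the BART ONNX export.
# Including "attention_mask" (the entry reported missing) lets the
# exported graph accept inputs shaped differently from the sample input.
dynamic_axes = {
    "input_ids": {0: "batch", 1: "seq"},
    "attention_mask": {0: "batch", 1: "seq"},  # previously missing
    "output_ids": {0: "batch", 1: "seq_out"},
}

# The export call would then pass this mapping, e.g.:
# torch.onnx.export(
#     scripted_model,
#     (input_ids, attention_mask, num_beams, max_length),
#     output_name,
#     input_names=["input_ids", "attention_mask", "num_beams", "max_length"],
#     dynamic_axes=dynamic_axes,
#     opset_version=14,
# )
```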

@polly-morphism
Author

Even after solving the attention mask issue, I still wasn't able to get a faster model after converting BART to ONNX. Perhaps quantization could help, but on the same text I got 6 s with the PyTorch model on GPU and 70 s with the ONNX optimized graph.
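One common cause of this kind of slowdown (an assumption here, not something verified in this thread) is that the plain onnxruntime package only ships CPUExecutionProvider, so the ONNX graph runs on CPU while the PyTorch baseline runs on GPU. A quick diagnostic sketch:

```python
# Hypothetical diagnostic: check which execution providers onnxruntime
# actually exposes before comparing timings against a GPU PyTorch model.
# If CUDAExecutionProvider is absent, inference silently runs on CPU.
import importlib.util

def pick_providers():
    """Return the providers to request, preferring CUDA when available.

    Falls back to an empty list when onnxruntime isn't installed, so this
    sketch stays runnable without it."""
    if importlib.util.find_spec("onnxruntime") is None:
        return []
    import onnxruntime as ort
    available = ort.get_available_providers()
    return [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
            if p in available]

# Usage (model path assumed):
# import onnxruntime as ort
# sess = ort.InferenceSession("bart.onnx", providers=pick_providers())
```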

@fatcat-z
Contributor

This was designed as an example showing how to export BART + beam search to ONNX successfully. It doesn't cover all scenarios. Your PR to make it better would be appreciated. Thanks!

@forglin

forglin commented Nov 30, 2021

I tested the versions of the major packages and determined that upgrading PyTorch from 1.8.0 to 1.9.1 solves this bug. However, PyTorch 1.9.1 does not support opset_version 14, so an upgrade to 1.10.0 is needed.
I think the PyTorch version in requirements.txt should be updated.
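The version constraint described above can be expressed as a small guard. The threshold below is an assumption drawn only from this thread (opset 14 export requiring torch >= 1.10), not from the library's own documentation:

```python
# Hypothetical guard: fail fast when the installed torch is too old
# for an opset 14 export.
def supports_opset_14(torch_version: str) -> bool:
    """Rough check that a torch version string is at least 1.10."""
    # Strip any local build suffix like "+cu111" before parsing.
    base = torch_version.split("+")[0]
    major, minor = (int(x) for x in base.split(".")[:2])
    return (major, minor) >= (1, 10)

print(supports_opset_14("1.9.1"))   # False: would need a lower opset
print(supports_opset_14("1.10.0"))  # True
```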

@fatcat-z
Contributor

fatcat-z commented Dec 5, 2021

> I tested the versions of the major packages. It is determined that upgrading pytorch from 1.8.0 to 1.9.1 can solve this bug. However in 1.9.1 pytorch does not support opset_version 14 and needs to be upgraded to 1.10.0. I think the version of pytorch in requirement.txt can be modified.

Good catch! Fixed in #14310

@lewtun
Member

lewtun commented Dec 8, 2021

Hey @polly-morphism @diruoshui, given the PyTorch version fix in #14310 can we now close this issue?

@polly-morphism
Author

> Hey @polly-morphism @diruoshui, given the PyTorch version fix in #14310 can we now close this issue?

Yes, thank you!

@lewtun lewtun closed this as completed Dec 8, 2021