
Debug TorchScript error from fastNLP_Bert #93742

Closed

jansel opened this issue Mar 23, 2022 · 5 comments
jansel commented Mar 23, 2022

$ ./torchbench.py --backend=ofi -n1 -k fastNLP_Bert
...
ERROR:torchdynamo.utils:jit error
Traceback (most recent call last):
  File "/home/jansel/torchdynamo/torchdynamo/utils.py", line 192, in torchscript
    return torch.jit.trace(model, example_inputs)
  File "/home/jansel/pytorch/torch/jit/_trace.py", line 741, in trace
    return trace_module(
  File "/home/jansel/pytorch/torch/jit/_trace.py", line 958, in trace_module
    module._c._create_method_from_trace(
RuntimeError: 0 INTERNAL ASSERT FAILED at "/home/jansel/pytorch/torch/csrc/jit/ir/alias_analysis.cpp":607, please report a bug to PyTorch. We don't have an op for aten::eq but it isn't a special case.  Argument types: Tensor, bool, 

Candidates:
	aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)
	aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)
	aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> (Tensor(a!))
	aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> (Tensor(a!))
	aten::eq.int_list(int[] a, int[] b) -> (bool)
	aten::eq.device(Device a, Device b) -> (bool)
	aten::eq.bool(bool a, bool b) -> (bool)
	aten::eq.enum(AnyEnumType a, AnyEnumType b) -> (bool)
	aten::eq.int(int a, int b) -> (bool)
	aten::eq.complex(complex a, complex b) -> (bool)
	aten::eq.float(float a, float b) -> (bool)
	aten::eq.int_float(int a, float b) -> (bool)
	aten::eq.float_int(float a, int b) -> (bool)
	aten::eq.float_complex(float a, complex b) -> (bool)
	aten::eq.complex_float(complex a, float b) -> (bool)
	aten::eq(Scalar a, Scalar b) -> (bool)
	aten::eq.str(str a, str b) -> (bool)
	aten::eq.float_list(float[] a, float[] b) -> (bool)
	aten::eq.Tensor_list(Tensor[] a, Tensor[] b) -> (bool)
	aten::eq.bool_list(bool[] a, bool[] b) -> (bool)
	aten::eq.str_list(str[] a, str[] b) -> (bool)

Later on there is this error (when jit.trace fails, it falls back to jit.script):

Traceback (most recent call last):
  File "/home/jansel/torchdynamo/torchdynamo/utils.py", line 197, in torchscript
    return torch.jit.script(model)
  File "/home/jansel/pytorch/torch/jit/_script.py", line 1266, in script
    return torch.jit._recursive.create_script_module(
  File "/home/jansel/pytorch/torch/jit/_recursive.py", line 454, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/jansel/pytorch/torch/jit/_recursive.py", line 520, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/home/jansel/pytorch/torch/jit/_recursive.py", line 371, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: 
Unknown builtin op: aten::LongTensor.
Could not find any similar ops to aten::LongTensor. This op may not exist or may not be currently supported in TorchScript.
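The aten::LongTensor failure comes from the legacy tensor constructor, which is not a registered TorchScript op. A hedged sketch of the scriptable equivalent, assuming the model used the legacy `torch.LongTensor(...)` form:

```python
from typing import List

import torch

# torch.LongTensor(...) is not scriptable; torch.tensor with an
# explicit dtype is the modern, scriptable equivalent.
@torch.jit.script
def make_long(values: List[int]) -> torch.Tensor:
    return torch.tensor(values, dtype=torch.long)

print(make_long([1, 2, 3]))  # tensor([1, 2, 3])
```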

These sound like just missing ops, so perhaps easy to fix? They prevent all TorchScript-based backends from working on these models.

@eellison do these look familiar to you?

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh
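A minimal sketch of the pattern that appears to trigger the first failure (hypothetical repro; exact behavior depends on the PyTorch build). Eager mode promotes the Python bool, while TorchScript's alias analysis reportedly finds no `aten::eq(Tensor, bool)` overload:

```python
import torch

class M(torch.nn.Module):
    def forward(self, ne: torch.Tensor) -> torch.Tensor:
        # Comparing a tensor against a Python bool literal:
        # fine in eager mode, reported to break under jit.trace.
        return ne.eq(False)

m = M()
x = torch.tensor([True, False])
print(m(x))  # eager mode works: tensor([False,  True])
# torch.jit.trace(m, (x,))  # reported to hit the alias_analysis assert above
```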

@eellison

You can assign this to me

eellison referenced this issue Mar 28, 2022
… in JIT

Fix for https://github.com/facebookresearch/torchdynamo/issues/93

Because the constructors follow a non-standard input schema (variadic integers), they are handled specially in ir_emitter.

[ghstack-poisoned]
eellison referenced this issue Apr 5, 2022
… in JIT

Fix for https://github.com/facebookresearch/torchdynamo/issues/93

Because the constructors follow a non-standard input schema (variadic integers), they are handled specially in ir_emitter.

Differential Revision: [D35362762](https://our.internmc.facebook.com/intern/diff/D35362762)

[ghstack-poisoned]
facebook-github-bot referenced this issue Apr 6, 2022
Summary:
Pull Request resolved: #74785

Fix for https://github.com/facebookresearch/torchdynamo/issues/93

Because the constructors follow a non-standard input schema (variadic integers), they are handled specially in ir_emitter.

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D35362762

Pulled By: eellison

fbshipit-source-id: 960badf08ba2ab0818af5fd331aff3542051250f
pytorchmergebot referenced this issue Apr 6, 2022
(cherry picked from commit bd579de)
@eellison eellison closed this as completed Apr 6, 2022
@anijain2305

@eellison I still see this error after pulling in your latest PR.

The error signature looks like this. Reopening for now.


  eq(float a, Tensor b) -> (Tensor):
  Expected a value of type 'float' for argument 'a' but instead found type 'bool'.

  eq(int a, Tensor b) -> (Tensor):
  Expected a value of type 'int' for argument 'a' but instead found type 'bool'.

  eq(complex a, Tensor b) -> (Tensor):
  Expected a value of type 'complex' for argument 'a' but instead found type 'bool'.

The original call is:
  File "<eval_with_key>.1", line 9
    self_word_pieces_lengths = self.self_word_pieces_lengths
    getitem = self_word_pieces_lengths[words];  self_word_pieces_lengths = words = None
    eq = ne.eq(False)
         ~~~~~ <--- HERE
    masked_fill = getitem.masked_fill(eq, 0);  getitem = eq = None
    sum_2 = masked_fill.sum(dim = -1)

@anijain2305 anijain2305 reopened this Apr 7, 2022
eellison commented Apr 7, 2022

@anijain2305 can you post the full error?

@anijain2305

@eellison This is the whole error and stack trace:

Traceback (most recent call last):
  File "torchbench.py", line 999, in run_one_model
    new_result = model_iter_fn(model, example_inputs)
  File "torchbench.py", line 483, in forward_and_backward_pass
    def forward_and_backward_pass(mod, inputs, collect_outputs=True):
  File "torchbench.py", line 483, in forward_and_backward_pass
    def forward_and_backward_pass(mod, inputs, collect_outputs=True):
  File "torchbench.py", line 483, in forward_and_backward_pass
    def forward_and_backward_pass(mod, inputs, collect_outputs=True):
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/fastNLP/models/bert.py", line 256, in forward
    def forward(self, words):
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/fastNLP/embeddings/bert_embedding.py", line 125, in forward
    def forward(self, words):
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/fastNLP/embeddings/bert_embedding.py", line 432, in forward
    def forward(self, words):
  File "/fsx/users/anijain/torchdynamo/torchdynamo/eval_frame.py", line 58, in _fn
    return fn(*args, **kwargs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/users/anijain/functorch/functorch/_src/aot_autograd.py", line 603, in forward
    return compiled_f(
  File "/fsx/users/anijain/torchdynamo/torchdynamo/eval_frame.py", line 58, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/anijain/functorch/functorch/_src/aot_autograd.py", line 162, in forward
    compiled_fw = fw_compiler(fw_module, flat_tensor_args)
  File "/fsx/users/anijain/functorch/functorch/_src/compilers.py", line 70, in ts_compile
    f = torch.jit.script(fx_g)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/jit/_script.py", line 1266, in script
    return torch.jit._recursive.create_script_module(
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/jit/_recursive.py", line 454, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/jit/_recursive.py", line 520, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/jit/_recursive.py", line 371, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError:
Arguments for call are not valid.
The following variants are available:

  aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor):
  Expected a value of type 'Tensor' for argument 'other' but instead found type 'bool'.

  aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor):
  Expected a value of type 'number' for argument 'other' but instead found type 'bool'.

  aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> (Tensor(a!)):
  Expected a value of type 'number' for argument 'other' but instead found type 'bool'.

  aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> (Tensor(a!)):
  Expected a value of type 'Tensor' for argument 'other' but instead found type 'bool'.

  aten::eq.int_list(int[] a, int[] b) -> (bool):
  Expected a value of type 'List[int]' for argument 'a' but instead found type 'Tensor'.

  aten::eq.device(Device a, Device b) -> (bool):
  Expected a value of type 'Device' for argument 'a' but instead found type 'Tensor'.

  aten::eq.bool(bool a, bool b) -> (bool):
  Expected a value of type 'bool' for argument 'a' but instead found type 'Tensor'.

  aten::eq.enum(AnyEnumType a, AnyEnumType b) -> (bool):
  Expected a value of type 'AnyEnumType' for argument 'a' but instead found type 'Tensor'.

  aten::eq.int(int a, int b) -> (bool):
  Expected a value of type 'int' for argument 'b' but instead found type 'bool'.

  aten::eq.complex(complex a, complex b) -> (bool):
  Expected a value of type 'complex' for argument 'b' but instead found type 'bool'.

  aten::eq.float(float a, float b) -> (bool):
  Expected a value of type 'float' for argument 'b' but instead found type 'bool'.

  aten::eq.int_float(int a, float b) -> (bool):
  Expected a value of type 'float' for argument 'b' but instead found type 'bool'.

  aten::eq.float_int(float a, int b) -> (bool):
  Expected a value of type 'int' for argument 'b' but instead found type 'bool'.

  aten::eq.float_complex(float a, complex b) -> (bool):
  Expected a value of type 'complex' for argument 'b' but instead found type 'bool'.

  aten::eq.complex_float(complex a, float b) -> (bool):
  Expected a value of type 'float' for argument 'b' but instead found type 'bool'.

  aten::eq(Scalar a, Scalar b) -> (bool):
  Expected a value of type 'number' for argument 'b' but instead found type 'bool'.

  aten::eq.str(str a, str b) -> (bool):
  Expected a value of type 'str' for argument 'a' but instead found type 'Tensor'.

  aten::eq.float_list(float[] a, float[] b) -> (bool):
  Expected a value of type 'List[float]' for argument 'a' but instead found type 'Tensor'.

  aten::eq.Tensor_list(Tensor[] a, Tensor[] b) -> (bool):
  Expected a value of type 'List[Tensor]' for argument 'a' but instead found type 'Tensor'.

  aten::eq.bool_list(bool[] a, bool[] b) -> (bool):
  Expected a value of type 'List[bool]' for argument 'a' but instead found type 'Tensor'.

  aten::eq.str_list(str[] a, str[] b) -> (bool):
  Expected a value of type 'List[str]' for argument 'a' but instead found type 'Tensor'.

  eq(float a, Tensor b) -> (Tensor):
  Expected a value of type 'Tensor' for argument 'b' but instead found type 'bool'.

  eq(int a, Tensor b) -> (Tensor):
  Expected a value of type 'Tensor' for argument 'b' but instead found type 'bool'.

  eq(complex a, Tensor b) -> (Tensor):
  Expected a value of type 'Tensor' for argument 'b' but instead found type 'bool'.

The original call is:
  File "<eval_with_key>.9", line 8
    sum_1 = torch.ops.aten.sum(ne, [-1])
    index = torch.ops.aten.index(primals_1, [primals_2]);  primals_1 = primals_2 = None
    eq = torch.ops.aten.eq(ne, False)
         ~~~~~~~~~~~~~~~~~ <--- HERE
    masked_fill = torch.ops.aten.masked_fill(index, eq, 0);  index = eq = None
    sum_2 = torch.ops.aten.sum(masked_fill, [-1])

ERROR
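For reference, two rewrites of the failing `eq(ne, False)` call that sidestep the Tensor-vs-bool overload resolution entirely; a hedged sketch, assuming `ne` is a boolean mask as in the graph above:

```python
import torch

ne = torch.tensor([True, False, True])

# Compare against an int Scalar instead of a Python bool,
# which matches the aten::eq.Scalar overload.
eq_int = ne.eq(0)

# Equivalent for a boolean mask: logical negation.
eq_not = ne.logical_not()

print(eq_int)  # tensor([False,  True, False])
print(eq_not)  # tensor([False,  True, False])
```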

eellison commented Apr 7, 2022

@anijain2305 filed #75465

@malfet malfet transferred this issue from pytorch/torchdynamo Feb 1, 2023
@ngimel ngimel closed this as completed Feb 16, 2023