
GPT-2 torch to linalg #337

Closed
AmosLewis opened this issue Sep 16, 2022 · 6 comments
No description provided.

@AmosLewis AmosLewis self-assigned this Sep 16, 2022
@AmosLewis AmosLewis changed the title from "GPT-2 torch to linalg model support" to "GPT-2 torch to linalg" Sep 16, 2022
AmosLewis commented Sep 16, 2022

Error raised when importing the TorchScript module to MLIR:

  File "torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 233, in compile
        mb.import_module(scripted._c, class_annotator, import_options)
    raise Exception(f"""
Exception: 
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
### Importer C++ Exception:
Unsupported import tensor type: (1,1,.,.)
[ CPUByteType{1,1,1024,1024} ]
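For context (my reading, not stated in the log): a uint8 ("CPUByteType") tensor of shape {1,1,1024,1024} in GPT-2 is almost certainly the causal attention mask that HuggingFace registers as a buffer of shape (1, 1, n_ctx, n_ctx) with n_ctx = 1024, historically stored as torch.uint8 — a dtype the importer did not yet support. A pure-Python illustration of that buffer's contents (not torch-mlir or HuggingFace code; n shrunk from 1024 for readability):

```python
# Illustration of the likely offender: GPT-2's causal attention mask,
# a lower-triangular 0/1 buffer of shape (1, 1, n_ctx, n_ctx) that was
# historically stored as torch.uint8 ("CPUByteType").
n = 4  # the real model uses n_ctx = 1024
bias = [[[[1 if col <= row else 0 for col in range(n)]
          for row in range(n)]]]  # nesting mimics shape (1, 1, n, n)
for row in bias[0][0]:
    print(row)  # lower-triangular 0/1 rows
```

Casting or regenerating this buffer in a supported dtype (e.g. bool or float) is one workaround when the importer rejects uint8.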

AmosLewis commented Sep 19, 2022

The importer error above (Unsupported import tensor type: CPUByteType{1,1,1024,1024}) has been fixed by cherry-picking https://github.com/llvm/torch-mlir/issues/1383, but the fix then exposes a new bug in the pass lowering:

Traceback (most recent call last):
  File "/home/chi/src/ubuntu20/shark/SHARK/tank/gpt2_torch/gpt2.py", line 81, in <module>
    module = torch_mlir.compile(
  File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 247, in compile
    run_pipeline_with_repro_report(
  File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 73, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: 'func.call' op operand type mismatch: expected operand type '!torch.float', but provided '!torch.number' for operand number 1
note: see current operation: %1487 = "func.call"(%1482, %182, %1483, %1484, %1485, %1486) {callee = @__torch_mlir_shape_fn.aten.arange.start} : (!torch.float, !torch.number, !torch.optional<int>, !torch.optional<int>, !torch.optional<Device>, !torch.optional<bool>) -> !torch.list<int>


Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints}' /tmp/HfMaskedLM.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
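The failing callee is the shape function for aten.arange.start, whose leading operands are the start/end scalars — here a value typed !torch.number is fed into a parameter typed !torch.float. The shape computation itself is simple; a hypothetical pure-Python sketch of what such a shape function computes (an illustration, not torch-mlir's actual shape library code):

```python
import math

def arange_start_shape(start, end, step=1.0):
    # torch.arange(start, end, step) yields
    # max(0, ceil((end - start) / step)) elements, so the result
    # shape is that single dimension.
    return [max(0, math.ceil((end - start) / step))]

print(arange_start_shape(0.0, 5.0))        # [5]
print(arange_start_shape(2.0, 10.0, 3.0))  # [3]
```

The type-mismatch error is about the calling convention between the op and this shape function, not about the arithmetic itself.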

AmosLewis commented Sep 21, 2022

The '!torch.float' vs '!torch.number' bug can be fixed by using make_fx; @Shukla-Gaurav has fixed it. The next step is to fix the view op, which might be addressed by llvm/torch-mlir#1353

torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
note: see current operation: %202 = "torch.aten.view"(%201, %184) : (!torch.vtensor<[1,?],si64>, !torch.list<int>) -> !torch.vtensor<[?,5],si64>


Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='torch-backend-to-linalg-on-tensors-backend-pipeline' /tmp/_lambda.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
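The op reshapes a [1,?] tensor to [?,5] — the pattern produced by a call like view(-1, n) where one target dimension must be inferred — and the linalg lowering at the time could not legalize that dynamic case. A pure-Python sketch of the -1 inference rule the lowering has to reproduce (infer_view_shape is a hypothetical helper, not a torch-mlir API):

```python
def infer_view_shape(numel, target):
    # Replicates torch's view() inference: the single -1 entry is
    # replaced so that the total element count is preserved.
    known = 1
    for d in target:
        if d != -1:
            known *= d
    assert numel % known == 0, "target shape incompatible with numel"
    return [numel // known if d == -1 else d for d in target]

print(infer_view_shape(5, [-1, 5]))   # [1, 5]
print(infer_view_shape(10, [-1, 5]))  # [2, 5]
```

With static shapes this division folds to a constant; with a dynamic input dimension it must be emitted as runtime arithmetic, which is the part that was failing to legalize.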

AmosLewis commented Sep 26, 2022

Alex Jakpin's newest patch, llvm/torch-mlir#1353, should fix the view bug. If not, try Shukla-Gaurav's older patch llvm/torch-mlir@18424b5

AmosLewis commented Sep 26, 2022

GPT-2 torch to linalg should now be done.
We just need to wait for the byte-type bug patch PR llvm/torch-mlir#1384 and the view bug patch PR llvm/torch-mlir#1353 to be merged into torch-mlir; then everything for linalg should be good. Here is a temporary Python patch for testing: https://gist.github.com/AmosLewis/d69134cec46de50083d7e50c980ee258

Here are the results compared with shark_inference:

GPT2 Torch Golden OUTPUT:
tensor([[[ -31.8388,  -30.9854,  -34.4231,  ...,  -39.7515,  -38.6848,
           -32.3074],
         [ -99.2055,  -98.8202, -104.2251,  ..., -112.2020, -109.0224,
          -100.2584],
         [-115.6919, -116.9150, -119.1486,  ..., -124.9616, -123.2126,
          -116.6671],
         [-123.0994, -123.1445, -128.7349,  ..., -130.6248, -130.6557,
          -125.1285],
         [ -80.2680,  -81.8277,  -89.0646,  ...,  -94.5047,  -96.1721,
           -83.7583]]], grad_fn=<UnsafeViewBackward0>)
GPT2 Torch to Linalg SharkInference OUTPUT:
tensor([[[ -31.8388,  -30.9854,  -34.4230,  ...,  -39.7514,  -38.6848,
           -32.3073],
         [ -99.2055,  -98.8202, -104.2252,  ..., -112.2021, -109.0225,
          -100.2584],
         [-115.6919, -116.9149, -119.1486,  ..., -124.9616, -123.2126,
          -116.6672],
         [-123.0994, -123.1446, -128.7350,  ..., -130.6248, -130.6558,
          -125.1285],
         [ -80.2680,  -81.8276,  -89.0646,  ...,  -94.5047,  -96.1721,
           -83.7583]]])
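The two outputs agree to roughly four decimal places, so a tolerance-based comparison is the right acceptance check rather than exact equality. A minimal stdlib sketch over the first row of logits above:

```python
import math

# First row of logits from the two outputs above.
golden = [-31.8388, -30.9854, -34.4231, -39.7515, -38.6848, -32.3074]
shark  = [-31.8388, -30.9854, -34.4230, -39.7514, -38.6848, -32.3073]

# The visible drift is ~1e-4, so an absolute tolerance of 1e-3 passes
# while still catching real numerical regressions.
all_close = all(math.isclose(g, s, abs_tol=1e-3)
                for g, s in zip(golden, shark))
print(all_close)  # True
```

In practice the same check over the full tensors would use torch.allclose with comparable tolerances.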

nithinsubbiah commented

@AmosLewis Have we got all the patches merged for GPT-2 lowering to Linalg? Can we close this issue?
