When I compile a ViT model to LINALG_ON_TENSORS, I get the error below. Any help?
import numpy as np
import torch
import torchvision.models as models
import torch_mlir

model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
input_batch = torch.randn(1, 3, 224, 224)
vit_mod = torch_mlir.compile(model, input_batch,
                             output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
                             use_tracing=False)
Error info:
Traceback (most recent call last):
  File "/home/jq/code/demo/torch_mlir-vit.py", line 9, in <module>
    vit_mod = torch_mlir.compile(model, input_batch,
  File "/home/jq/code/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 451, in compile
    run_pipeline_with_repro_report(
  File "/home/jq/code/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 69, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline:
error: unknown: unsupported by backend contract: module initializers
note: unknown: see current operation: "torch.initialize.global_slots"(%20, %21, %22, %23, %32, %33, %34, %35, %44, %45, %46, %47, %56, %57, %58, %59, %68, %69, %70, %71, %80, %81, %82, %83, %92, %93, %94, %95, %104, %105, %106, %107, %116, %117, %118, %119, %128, %129, %130, %131, %140, %141, %142, %143, %152, %153, %154, %155) <{slotSymNames = [@encoder.layers.encoder_layer_0.self_attention.in_proj_weight, @encoder.layers.encoder_layer_0.self_attention.in_proj_bias, @encoder.layers.encoder_layer_0.self_attention.out_proj.weight, @encoder.layers.encoder_layer_0.self_attention.out_proj.bias, @encoder.layers.encoder_layer_1.self_attention.in_proj_weight, @encoder.layers.encoder_layer_1.self_attention.in_proj_bias, @encoder.layers.encoder_layer_1.self_attention.out_proj.weight, @encoder.layers.encoder_layer_1.self_attention.out_proj.bias, @encoder.layers.encoder_layer_2.self_attention.in_proj_weight, @encoder.layers.encoder_layer_2.self_attention.in_proj_bias, @encoder.layers.encoder_layer_2.self_attention.out_proj.weight, @encoder.layers.encoder_layer_2.self_attention.out_proj.bias, @encoder.layers.encoder_layer_3.self_attention.in_proj_weight, @encoder.layers.encoder_layer_3.self_attention.in_proj_bias, @encoder.layers.encoder_layer_3.self_attention.out_proj.weight, @encoder.layers.encoder_layer_3.self_attention.out_proj.bias, @encoder.layers.encoder_layer_4.self_attention.in_proj_weight, @encoder.layers.encoder_layer_4.self_attention.in_proj_bias, @encoder.layers.encoder_layer_4.self_attention.out_proj.weight, @encoder.layers.encoder_layer_4.self_attention.out_proj.bias, @encoder.layers.encoder_layer_5.self_attention.in_proj_weight, @encoder.layers.encoder_layer_5.self_attention.in_proj_bias, @encoder.layers.encoder_layer_5.self_attention.out_proj.weight, @encoder.layers.encoder_layer_5.self_attention.out_proj.bias, @encoder.layers.encoder_layer_6.self_attention.in_proj_weight, @encoder.layers.encoder_layer_6.self_attention.in_proj_bias, @encoder.layers.encoder_layer_6.self_attention.out_proj.weight, @encoder.layers.encoder_layer_6.self_attention.out_proj.bias, @encoder.layers.encoder_layer_7.self_attention.in_proj_weight, @encoder.layers.encoder_layer_7.self_attention.in_proj_bias, @encoder.layers.encoder_layer_7.self_attention.out_proj.weight, @encoder.layers.encoder_layer_7.self_attention.out_proj.bias, @encoder.layers.encoder_layer_8.self_attention.in_proj_weight, @encoder.layers.encoder_layer_8.self_attention.in_proj_bias, @encoder.layers.encoder_layer_8.self_attention.out_proj.weight, @encoder.layers.encoder_layer_8.self_attention.out_proj.bias, @encoder.layers.encoder_layer_9.self_attention.in_proj_weight, @encoder.layers.encoder_layer_9.self_attention.in_proj_bias, @encoder.layers.encoder_layer_9.self_attention.out_proj.weight, @encoder.layers.encoder_layer_9.self_attention.out_proj.bias, @encoder.layers.encoder_layer_10.self_attention.in_proj_weight, @encoder.layers.encoder_layer_10.self_attention.in_proj_bias, @encoder.layers.encoder_layer_10.self_attention.out_proj.weight, @encoder.layers.encoder_layer_10.self_attention.out_proj.bias, @encoder.layers.encoder_layer_11.self_attention.in_proj_weight, @encoder.layers.encoder_layer_11.self_attention.in_proj_bias, @encoder.layers.encoder_layer_11.self_attention.out_proj.weight, @encoder.layers.encoder_layer_11.self_attention.out_proj.bias]}> : (!torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, 
!torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>) -> ()
note: unknown: this is likely due to InlineGlobalSlots being unable to inline a global slot
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints extra-library=})' /tmp/VisionTransformer.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
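To see where those global slots come from before the backend contract rejects them, one option is to stop the pipeline right after TorchScript import. A minimal sketch, assuming the same legacy torch_mlir.compile API as in the script above, and that OutputType.RAW in this build still means "raw JIT importer output, no lowering passes":

# Reuses `model` and `input_batch` from the script above.
# OutputType.RAW returns the module as imported from TorchScript,
# before torchscript-module-to-torch-backend-pipeline runs, so the
# torch.global_slot ops for the self_attention weights should still
# be visible in the printed IR.
raw_mod = torch_mlir.compile(model, input_batch,
                             output_type=torch_mlir.OutputType.RAW)
print(raw_mod)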
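Not a confirmed fix, but one variation that may be worth trying is tracing instead of scripting. A scripted nn.MultiheadAttention reaches its in_proj_weight/out_proj parameters through data-dependent branches, which looks like the kind of access InlineGlobalSlots can fail to inline; torch.jit.trace records a single path and might avoid it. A sketch under that assumption, with everything else unchanged:

# Same model and input as above; the only change is use_tracing=True,
# which makes torch_mlir.compile run torch.jit.trace instead of
# torch.jit.script before import. Whether this avoids the failing
# torch.initialize.global_slots op is a guess, not a verified result.
vit_mod = torch_mlir.compile(model, input_batch,
                             output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
                             use_tracing=True)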