
[vit][compile] error: unknown: unsupported by backend contract: module initializers. How to lower vit_b_16 to linalg? #2469

@followtheart

Description

When I compile the ViT model to LINALG_ON_TENSORS, I get the following error. Any help?

import numpy as np
import torch
import torchvision.models as models
import torch_mlir

model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
input_batch = torch.randn(1,3,224,224)
vit_mod = torch_mlir.compile(model, input_batch, 
                            output_type=torch_mlir.OutputType.LINALG_ON_TENSORS, 
                            use_tracing=False)

Error output:

Traceback (most recent call last):
  File "/home/jq/code/demo/torch_mlir-vit.py", line 9, in <module>
    vit_mod = torch_mlir.compile(model, input_batch, 
  File "/home/jq/code/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 451, in compile
    run_pipeline_with_repro_report(
  File "/home/jq/code/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 69, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:


python exception: Failure while executing pass pipeline:
error: unknown: unsupported by backend contract: module initializers
note: unknown: see current operation: "torch.initialize.global_slots"(%20, %21, %22, %23, %32, %33, %34, %35, %44, %45, %46, %47, %56, %57, %58, %59, %68, %69, %70, %71, %80, %81, %82, %83, %92, %93, %94, %95, %104, %105, %106, %107, %116, %117, %118, %119, %128, %129, %130, %131, %140, %141, %142, %143, %152, %153, %154, %155) <{slotSymNames = [@encoder.layers.encoder_layer_0.self_attention.in_proj_weight, @encoder.layers.encoder_layer_0.self_attention.in_proj_bias, @encoder.layers.encoder_layer_0.self_attention.out_proj.weight, @encoder.layers.encoder_layer_0.self_attention.out_proj.bias, @encoder.layers.encoder_layer_1.self_attention.in_proj_weight, @encoder.layers.encoder_layer_1.self_attention.in_proj_bias, @encoder.layers.encoder_layer_1.self_attention.out_proj.weight, @encoder.layers.encoder_layer_1.self_attention.out_proj.bias, @encoder.layers.encoder_layer_2.self_attention.in_proj_weight, @encoder.layers.encoder_layer_2.self_attention.in_proj_bias, @encoder.layers.encoder_layer_2.self_attention.out_proj.weight, @encoder.layers.encoder_layer_2.self_attention.out_proj.bias, @encoder.layers.encoder_layer_3.self_attention.in_proj_weight, @encoder.layers.encoder_layer_3.self_attention.in_proj_bias, @encoder.layers.encoder_layer_3.self_attention.out_proj.weight, @encoder.layers.encoder_layer_3.self_attention.out_proj.bias, @encoder.layers.encoder_layer_4.self_attention.in_proj_weight, @encoder.layers.encoder_layer_4.self_attention.in_proj_bias, @encoder.layers.encoder_layer_4.self_attention.out_proj.weight, @encoder.layers.encoder_layer_4.self_attention.out_proj.bias, @encoder.layers.encoder_layer_5.self_attention.in_proj_weight, @encoder.layers.encoder_layer_5.self_attention.in_proj_bias, @encoder.layers.encoder_layer_5.self_attention.out_proj.weight, @encoder.layers.encoder_layer_5.self_attention.out_proj.bias, @encoder.layers.encoder_layer_6.self_attention.in_proj_weight, @encoder.layers.encoder_layer_6.self_attention.in_proj_bias, @encoder.layers.encoder_layer_6.self_attention.out_proj.weight, @encoder.layers.encoder_layer_6.self_attention.out_proj.bias, @encoder.layers.encoder_layer_7.self_attention.in_proj_weight, @encoder.layers.encoder_layer_7.self_attention.in_proj_bias, @encoder.layers.encoder_layer_7.self_attention.out_proj.weight, @encoder.layers.encoder_layer_7.self_attention.out_proj.bias, @encoder.layers.encoder_layer_8.self_attention.in_proj_weight, @encoder.layers.encoder_layer_8.self_attention.in_proj_bias, @encoder.layers.encoder_layer_8.self_attention.out_proj.weight, @encoder.layers.encoder_layer_8.self_attention.out_proj.bias, @encoder.layers.encoder_layer_9.self_attention.in_proj_weight, @encoder.layers.encoder_layer_9.self_attention.in_proj_bias, @encoder.layers.encoder_layer_9.self_attention.out_proj.weight, @encoder.layers.encoder_layer_9.self_attention.out_proj.bias, @encoder.layers.encoder_layer_10.self_attention.in_proj_weight, @encoder.layers.encoder_layer_10.self_attention.in_proj_bias, @encoder.layers.encoder_layer_10.self_attention.out_proj.weight, @encoder.layers.encoder_layer_10.self_attention.out_proj.bias, @encoder.layers.encoder_layer_11.self_attention.in_proj_weight, @encoder.layers.encoder_layer_11.self_attention.in_proj_bias, @encoder.layers.encoder_layer_11.self_attention.out_proj.weight, @encoder.layers.encoder_layer_11.self_attention.out_proj.bias]}> : (!torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, 
!torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>, !torch.tensor<[2304,768],f32>, !torch.tensor<[2304],f32>, !torch.tensor<[768,768],f32>, !torch.tensor<[768],f32>) -> ()
note: unknown: this is likely due to InlineGlobalSlots being unable to inline a global slot

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints extra-library=})' /tmp/VisionTransformer.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
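
A possible workaround, not verified against this model: the note above points at InlineGlobalSlots, and scripting nn.MultiheadAttention is what introduces the global-slot uses (the self_attention in_proj/out_proj parameters) that the pass cannot inline. Switching torch_mlir.compile to tracing removes the Python-level control flow inside the attention module and may leave only slot uses the pass can fold away. The sketch below is the same repro with only use_tracing changed; treat it as a suggestion, not a confirmed fix.

import torch
import torchvision.models as models
import torch_mlir

model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
input_batch = torch.randn(1, 3, 224, 224)

# Same call as in the report above, but traced instead of scripted
# (untested workaround for the "module initializers" contract failure).
vit_mod = torch_mlir.compile(model, input_batch,
                             output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
                             use_tracing=True)

If tracing still fails, compiling with output_type=torch_mlir.OutputType.RAW should (if I read the API right) emit the imported module without running the lowering pipeline, which makes it easier to see where the self_attention global slots come from before the backend contract check rejects them.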
