Decomposing aten.as_strided causes assertion failure in MLIR #4315

@momchil-velikov

Description

This reproduces with a recent IREE (fd8715fd14eeb4b929c4b2d052377e60709b5c82), its corresponding torch-mlir (7000187),
and llvm-project (a376df0140e67c86a1a48d4ab18ca8a3984b1b0c).

The following MLIR testcase is extracted from a source generated by IREE Turbine from a HuggingFace GPT2 model:

  func.func @f(%in : !torch.vtensor<[1,10,2304],f32>) -> !torch.vtensor<[1,10,768],f32> {
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int10 = torch.constant.int 10
    %int768 = torch.constant.int 768
    %int2304 = torch.constant.int 2304
    %int23040 = torch.constant.int 23040

    %lst0 = torch.prim.ListConstruct %int1, %int10, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %lst1 = torch.prim.ListConstruct %int23040, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>

    %out = torch.aten.as_strided %in, %lst0, %lst1, %int0
      : !torch.vtensor<[1,10,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,10,768],f32>

    return %out : !torch.vtensor<[1,10,768],f32>
  }
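For reference, the as_strided computation above can be sketched in NumPy (a hypothetical stand-in for the torch op; note that NumPy's as_strided takes byte strides, while torch.aten.as_strided uses element strides):

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

# Stand-in for the !torch.vtensor<[1,10,2304],f32> input.
x = np.arange(1 * 10 * 2304, dtype=np.float32).reshape(1, 10, 2304)

# torch.aten.as_strided uses element strides; NumPy wants byte strides,
# so scale by the element size.
elem = x.itemsize
y = as_strided(x, shape=(1, 10, 768),
               strides=(23040 * elem, 2304 * elem, 1 * elem))

# With these strides the op is just a slice of the last dimension,
# so every extent of the result is statically known.
assert np.array_equal(y, x[:, :, :768])
```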

Compiling with iree-opt --torch-decompose-complex-ops --torch-scalarize-shapes --convert-torch-to-tmtensor --convert-torch-to-tensor --convert-torch-to-linalg repro-1.mlir
results in an assertion failure:

iree-opt: /work/iree/main/third_party/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::Value, From = mlir::OpFoldResult]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump:
0.      Program arguments: ./tools/iree-opt --torch-decompose-complex-ops --torch-scalarize-shapes --convert-torch-to-tmtensor --convert-torch-to-tensor --convert-torch-to-linalg repro-1.mlir
 #0 0x0000fb2df12f9e64 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0x6929e64)
 #1 0x0000fb2df12f7884 llvm::sys::RunSignalHandlers() (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0x6927884)
 #2 0x0000fb2df12faa40 SignalHandler(int, siginfo_t*, void*) Signals.cpp:0:0
 #3 0x0000fb2dfa579968 (linux-vdso.so.1+0x968)
 #4 0x0000fb2dea697608 __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
 #5 0x0000fb2dea64cb3c raise ./signal/../sysdeps/posix/raise.c:27:6
 #6 0x0000fb2dea637e00 abort ./stdlib/abort.c:81:7
 #7 0x0000fb2dea645cc0 __assert_fail_base ./assert/assert.c:93:7
 #8 0x0000fb2dea645d30 __assert_perror_fail ./assert/assert-perr.c:31:1
 #9 0x0000fb2df70043bc mlir::matchConstantIndex() (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0xc6343bc)
#10 0x0000fb2df6e92980 mlir::tensor::ExpandShapeOp::inferOutputShape(mlir::OpBuilder&, mlir::Location, mlir::RankedTensorType, llvm::ArrayRef<llvm::SmallVector<long, 2u>>, llvm::ArrayRef<mlir::OpFoldResult>) (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0xc4c2980)
#11 0x0000fb2df6e92e10 mlir::tensor::ExpandShapeOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Type, mlir::Value, llvm::ArrayRef<llvm::SmallVector<long, 2u>>) (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0xc4c2e10)
#12 0x0000fb2df24a6138 mlir::tensor::ExpandShapeOp mlir::OpBuilder::create<mlir::tensor::ExpandShapeOp, mlir::Type&, mlir::Value, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&>(mlir::Location, mlir::Type&, mlir::Value&&, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&) DataMovement.cpp:0:0
#13 0x0000fb2df24982a4 (anonymous namespace)::ConvertAtenUnflattenIntOp::matchAndRewrite(mlir::torch::Torch::AtenUnflattenIntOp, mlir::torch::Torch::AtenUnflattenIntOpAdaptor, mlir::ConversionPatternRewriter&) const DataMovement.cpp:0:0
...

In the program above, torch.aten.as_strided is lowered by the DecomposeAtenAsStridedOp pattern (torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp)
into a sequence containing several torch.aten.view operations such as

  %8 = torch.aten.view %6, %7 : !torch.vtensor<[10],si64>, !torch.list<int> -> !torch.vtensor<[1,?,1],si64>

which are in turn transformed by the ScalarizeShapes pass into problematic torch.aten.unflatten.int ops, as in the following reproducer:

func.func @f() -> !torch.vtensor<[1,?,1],si64> {
  %none = torch.constant.none
  %int-1 = torch.constant.int -1
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int10 = torch.constant.int 10

  %steps = torch.aten.arange.start_step %int0, %int10, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],si64>
  %sizes = torch.prim.ListConstruct %int1, %int-1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>

  %ret = torch.aten.unflatten.int %steps, %int0, %sizes : !torch.vtensor<[10],si64>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,?,1],si64>
  return %ret : !torch.vtensor<[1,?,1],si64>
}

Compiling this via iree-opt --convert-torch-to-linalg repro-2.mlir yields the same assertion failure and stack trace as above.

Looking at some of the generated ops like

%8 = torch.aten.view %6, %7 : !torch.vtensor<[10],si64>, !torch.list<int> -> !torch.vtensor<[1,?,1],si64>

there is no need to produce a tensor type with a dynamic shape: the result extent is statically known. While not incorrect
per se, the dynamic shape loses information. Fixing the decomposition to instead emit

%8 = torch.aten.view %6, %7 : !torch.vtensor<[10],si64>, !torch.list<int> -> !torch.vtensor<[1,10,1],si64>

avoids the assertion failure.
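To illustrate why the static result type is recoverable: a -1 in the unflatten size list can be resolved whenever the extent of the dimension being split is known, as in this hypothetical sketch of the shape inference:

```python
# Hypothetical sketch of torch.aten.unflatten.int result-shape inference:
# a single -1 in the size list is resolved from the known extent of the
# dimension being unflattened.
def infer_unflatten_sizes(dim_size, sizes):
    known = 1       # product of the explicitly given extents
    wildcard = None  # index of the -1 entry, if any
    for i, s in enumerate(sizes):
        if s == -1:
            wildcard = i
        else:
            known *= s
    out = list(sizes)
    if wildcard is not None:
        out[wildcard] = dim_size // known
    return out

# The repro's case: a [10] tensor unflattened at dim 0 with sizes [1, -1, 1]
# has the statically known result shape [1, 10, 1].
print(infer_unflatten_sizes(10, [1, -1, 1]))  # -> [1, 10, 1]
```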
