Description
Using a recent IREE (fd8715fd14eeb4b929c4b2d052377e60709b5c82) with its corresponding torch-mlir (7000187) and llvm-project (a376df0140e67c86a1a48d4ab18ca8a3984b1b0c), the following MLIR testcase, extracted from a source generated by IREE Turbine from a HuggingFace GPT2 model, reproduces a crash:
func.func @f(%in : !torch.vtensor<[1,10,2304],f32>) -> !torch.vtensor<[1,10,768],f32> {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int10 = torch.constant.int 10
  %int768 = torch.constant.int 768
  %int2304 = torch.constant.int 2304
  %int23040 = torch.constant.int 23040
  %lst0 = torch.prim.ListConstruct %int1, %int10, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %lst1 = torch.prim.ListConstruct %int23040, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %out = torch.aten.as_strided %in, %lst0, %lst1, %int0
      : !torch.vtensor<[1,10,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,10,768],f32>
  return %out : !torch.vtensor<[1,10,768],f32>
}
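For reference (an equivalent formulation, not part of the original model): with offset 0 and strides [23040, 2304, 1], which are exactly the contiguous strides of the [1,10,2304] input, this as_strided just keeps the first 768 elements of the last dimension, i.e. it behaves like a plain slice:

%out = torch.aten.slice.Tensor %in, %int2, %int0, %int768, %int1
    : !torch.vtensor<[1,10,2304],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,10,768],f32>

(with %int2 = torch.constant.int 2). In particular, nothing about the computation requires dynamic shapes.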
Compilation with iree-opt --torch-decompose-complex-ops --torch-scalarize-shapes --convert-torch-to-tmtensor --convert-torch-to-tensor --convert-torch-to-linalg repro-1.mlir results in an assertion failure:
iree-opt: /work/iree/main/third_party/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::Value, From = mlir::OpFoldResult]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump:
0. Program arguments: ./tools/iree-opt --torch-decompose-complex-ops --torch-scalarize-shapes --convert-torch-to-tmtensor --convert-torch-to-tensor --convert-torch-to-linalg repro-1.mlir
#0 0x0000fb2df12f9e64 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0x6929e64)
#1 0x0000fb2df12f7884 llvm::sys::RunSignalHandlers() (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0x6927884)
#2 0x0000fb2df12faa40 SignalHandler(int, siginfo_t*, void*) Signals.cpp:0:0
#3 0x0000fb2dfa579968 (linux-vdso.so.1+0x968)
#4 0x0000fb2dea697608 __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
#5 0x0000fb2dea64cb3c raise ./signal/../sysdeps/posix/raise.c:27:6
#6 0x0000fb2dea637e00 abort ./stdlib/abort.c:81:7
#7 0x0000fb2dea645cc0 __assert_fail_base ./assert/assert.c:93:7
#8 0x0000fb2dea645d30 __assert_perror_fail ./assert/assert-perr.c:31:1
#9 0x0000fb2df70043bc mlir::matchConstantIndex() (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0xc6343bc)
#10 0x0000fb2df6e92980 mlir::tensor::ExpandShapeOp::inferOutputShape(mlir::OpBuilder&, mlir::Location, mlir::RankedTensorType, llvm::ArrayRef<llvm::SmallVector<long, 2u>>, llvm::ArrayRef<mlir::OpFoldResult>) (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0xc4c2980)
#11 0x0000fb2df6e92e10 mlir::tensor::ExpandShapeOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Type, mlir::Value, llvm::ArrayRef<llvm::SmallVector<long, 2u>>) (/work/iree/main/out/build/release-20/lib/libIREECompiler.so+0xc4c2e10)
#12 0x0000fb2df24a6138 mlir::tensor::ExpandShapeOp mlir::OpBuilder::create<mlir::tensor::ExpandShapeOp, mlir::Type&, mlir::Value, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&>(mlir::Location, mlir::Type&, mlir::Value&&, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&) DataMovement.cpp:0:0
#13 0x0000fb2df24982a4 (anonymous namespace)::ConvertAtenUnflattenIntOp::matchAndRewrite(mlir::torch::Torch::AtenUnflattenIntOp, mlir::torch::Torch::AtenUnflattenIntOpAdaptor, mlir::ConversionPatternRewriter&) const DataMovement.cpp:0:0
...
In the program above, the torch.aten.as_strided op is lowered by the DecomposeAtenAsStridedOp pattern (torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp) into a sequence containing a few torch.aten.view operations like

%8 = torch.aten.view %6, %7 : !torch.vtensor<[10],si64>, !torch.list<int> -> !torch.vtensor<[1,?,1],si64>

which in turn are transformed by the ScalarizeShapes pass into problematic torch.aten.unflatten.int ops, as in the reduced testcase below (repro-2.mlir):
func.func @f() -> !torch.vtensor<[1,?,1],si64> {
  %none = torch.constant.none
  %int-1 = torch.constant.int -1
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int10 = torch.constant.int 10
  %steps = torch.aten.arange.start_step %int0, %int10, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],si64>
  %sizes = torch.prim.ListConstruct %int1, %int-1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %ret = torch.aten.unflatten.int %steps, %int0, %sizes : !torch.vtensor<[10],si64>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,?,1],si64>
  return %ret : !torch.vtensor<[1,?,1],si64>
}
Compiling this via iree-opt --convert-torch-to-linalg repro-2.mlir again yields the assertion failure and stack trace shown above.
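From the stack trace, the crash happens while ConvertAtenUnflattenIntOp lowers the unflatten to a tensor.expand_shape, roughly of the following form (illustrative IR reconstructed from the trace, not dumped from the compiler; %d stands for an index value supplying the dynamic extent):

%expanded = tensor.expand_shape %t [[0, 1, 2]] output_shape [1, %d, 1]
    : tensor<10xi64> into tensor<1x?x1xi64>

The builder used in frame #11 takes only the reassociation indices and calls ExpandShapeOp::inferOutputShape to recover the output_shape operands; the assertion fires during that inference, when an OpFoldResult that holds an Attribute rather than a Value is cast to mlir::Value.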
Looking at some of the generated ops like

%8 = torch.aten.view %6, %7 : !torch.vtensor<[10],si64>, !torch.list<int> -> !torch.vtensor<[1,?,1],si64>

there isn't really a need to generate a tensor type with a dynamic shape: the result extent is fully determined by the static [10] input shape. While not incorrect per se, the dynamic type is a loss of information. Fixing the decomposition to instead emit

%8 = torch.aten.view %6, %7 : !torch.vtensor<[10],si64>, !torch.list<int> -> !torch.vtensor<[1,10,1],si64>

avoids this issue.
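With the static result type, the unflatten path can build a fully static tensor.expand_shape, along the lines of (again illustrative):

%expanded = tensor.expand_shape %t [[0, 1, 2]] output_shape [1, 10, 1]
    : tensor<10xi64> into tensor<1x10x1xi64>

where no dynamic output extent has to be inferred, so the crashing cast is never reached.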