diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index e22fc1d478e4f..a15e19b24e54b 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -288,12 +288,6 @@ def SelectI1ToNot : // IndexCastOp //===----------------------------------------------------------------------===// -// index_cast(index_cast(x)) -> x, if dstType == srcType. -def IndexCastOfIndexCast : - Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x)), - (replaceWithValue $x), - [(Constraint> $res, $x)]>; - // index_cast(extsi(x)) -> index_cast(x) def IndexCastOfExtSI : Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>; @@ -302,12 +296,6 @@ def IndexCastOfExtSI : // IndexCastUIOp //===----------------------------------------------------------------------===// -// index_castui(index_castui(x)) -> x, if dstType == srcType. -def IndexCastUIOfIndexCastUI : - Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2), - (replaceWithValue $x), - [(Constraint> $res, $x)]>; - // index_castui(extui(x)) -> index_castui(x) def IndexCastUIOfExtUI : Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2), diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 5f10a94522350..569d1869a5abe 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1909,6 +1909,15 @@ OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) { // IndexCastOp //===----------------------------------------------------------------------===// +/// Return the bit-width of \p t for the purpose of index_cast width checks. +/// For vector types use the element type; index maps to its internal storage +/// width (64 on all current targets). 
+static unsigned getIndexCastWidth(Type t) { + if (auto intTy = dyn_cast(getElementTypeOrSelf(t))) + return intTy.getWidth(); + return IndexType::kInternalStorageBitWidth; +} + static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) { if (!areValidCastInputsAndOutputs(inputs, outputs)) return false; @@ -1933,16 +1942,29 @@ OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) { if (auto intTy = dyn_cast(getElementTypeOrSelf(getType()))) resultBitwidth = intTy.getWidth(); - return constFoldCastOp( - adaptor.getOperands(), getType(), - [resultBitwidth](const APInt &a, bool & /*castStatus*/) { - return a.sextOrTrunc(resultBitwidth); - }); + if (auto foldResult = constFoldCastOp( + adaptor.getOperands(), getType(), + [resultBitwidth](const APInt &a, bool & /*castStatus*/) { + return a.sextOrTrunc(resultBitwidth); + })) + return foldResult; + + // index_cast(index_cast(x : A) : B) : A -> x, but only when B is at least + // as wide as A. If B is narrower, the inner cast truncates and the outer + // cast sign-extends, so the round-trip is lossy. 
+ if (auto inner = getOperand().getDefiningOp()) { + Value x = inner.getOperand(); + if (x.getType() == getType()) { + if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType())) + return x; + } + } + return {}; } void arith::IndexCastOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(context); + patterns.add(context); } //===----------------------------------------------------------------------===// @@ -1960,16 +1982,29 @@ OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) { if (auto intTy = dyn_cast(getElementTypeOrSelf(getType()))) resultBitwidth = intTy.getWidth(); - return constFoldCastOp( - adaptor.getOperands(), getType(), - [resultBitwidth](const APInt &a, bool & /*castStatus*/) { - return a.zextOrTrunc(resultBitwidth); - }); + if (auto foldResult = constFoldCastOp( + adaptor.getOperands(), getType(), + [resultBitwidth](const APInt &a, bool & /*castStatus*/) { + return a.zextOrTrunc(resultBitwidth); + })) + return foldResult; + + // index_castui(index_castui(x : A) : B) : A -> x, but only when B is at + // least as wide as A. If B is narrower, the inner cast truncates and the + // outer cast zero-extends, so the round-trip is lossy. 
+ if (auto inner = getOperand().getDefiningOp()) { + Value x = inner.getOperand(); + if (x.getType() == getType()) { + if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType())) + return x; + } + } + return {}; } void arith::IndexCastUIOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(context); + patterns.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 319dfc31ab637..e45adb7287ac4 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -693,8 +693,6 @@ func.func @arith_index_cast(%arg0: i32) -> i32 { // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to !emitc.ptrdiff_t // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : !emitc.ptrdiff_t to !emitc.size_t %idx = arith.index_cast %arg0 : i32 to index - // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to !emitc.ptrdiff_t - // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : !emitc.ptrdiff_t to i32 %int = arith.index_cast %idx : index to i32 // CHECK: %[[Const:.*]] = "emitc.constant" @@ -704,6 +702,7 @@ func.func @arith_index_cast(%arg0: i32) -> i32 { // CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1 %bool = arith.index_cast %idx : index to i1 + // CHECK: return %[[Arg0]] : i32 return %int : i32 } @@ -715,8 +714,6 @@ func.func @arith_index_castui(%arg0: i32) -> i32 { // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32 // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to !emitc.size_t %idx = arith.index_castui %arg0 : i32 to index - // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to ui32 - // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : ui32 to i32 %int = arith.index_castui %idx : index to i32 // CHECK: %[[Const:.*]] = "emitc.constant" @@ 
-726,6 +723,7 @@ func.func @arith_index_castui(%arg0: i32) -> i32 { // CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1 %bool = arith.index_castui %idx : index to i1 + // CHECK: return %[[Arg0]] : i32 return %int : i32 } diff --git a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir index bf1e8580a5b76..497574af2a2d8 100644 --- a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir +++ b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir @@ -237,12 +237,9 @@ func.func @index_cast_2d(%arg0: vector<1x2x3xi1>) { // CHECK: %[[SEXT2:.*]] = llvm.sext %[[EXTRACT2]] : vector<3xi1> to vector<3xi{{.*}}> // CHECK: %[[INSERT2:.*]] = llvm.insertvalue %[[SEXT2]], %[[INSERT1]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>> %0 = arith.index_cast %arg0: vector<1x2x3xi1> to vector<1x2x3xindex> - // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[INSERT2]][0, 0] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>> - // CHECK: %[[TRUNC1:.*]] = llvm.trunc %[[EXTRACT3]] : vector<3xi{{.*}}> to vector<3xi1> - // CHECK: %[[INSERT3:.*]] = llvm.insertvalue %[[TRUNC1]], %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi1>>> - // CHECK: %[[EXTRACT4:.*]] = llvm.extractvalue %[[INSERT2]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>> - // CHECK: %[[TRUNC2:.*]] = llvm.trunc %[[EXTRACT4]] : vector<3xi{{.*}}> to vector<3xi1> - // CHECK: %[[INSERT4:.*]] = llvm.insertvalue %[[TRUNC2]], %[[INSERT3]][0, 1] : !llvm.array<1 x array<2 x vector<3xi1>>> + // The back-cast folds away: index_cast(index_cast(x:i1):index):i1 -> x + // because index (64-bit) is wider than i1, so the round-trip is lossless. 
+ // CHECK-NOT: llvm.trunc %1 = arith.index_cast %0: vector<1x2x3xindex> to vector<1x2x3xi1> return } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 18665e2eb6f4a..ee3e713f8481e 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -724,6 +724,113 @@ func.func @indexCastUIFoldVectorIndexToInt() -> vector<3xi32> { return %int : vector<3xi32> } +// CHECK-LABEL: @indexCastOfIndexCast_lossless +// The intermediate index type (64 bits) is at least as wide as i64 (64 bits), +// so the round-trip is lossless and the chain folds away. +// CHECK: return %arg0 +func.func @indexCastOfIndexCast_lossless(%arg0: i64) -> i64 { + %0 = arith.index_cast %arg0 : i64 to index + %1 = arith.index_cast %0 : index to i64 + return %1 : i64 +} + +// ----- + +// CHECK-LABEL: @indexCastOfIndexCast_lossy +// The intermediate i8 type (8 bits) is narrower than index (64 bits), so +// folding would drop the truncation — must be preserved. +// CHECK: %[[a:.+]] = arith.index_cast %arg0 : index to i8 +// CHECK: %[[b:.+]] = arith.index_cast %[[a]] : i8 to index +// CHECK: return %[[b]] +func.func @indexCastOfIndexCast_lossy(%arg0: index) -> index { + %0 = arith.index_cast %arg0 : index to i8 + %1 = arith.index_cast %0 : i8 to index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @indexCastUIOfIndexCastUI_lossless +// The intermediate index type is at least as wide as i64, so the chain folds. +// CHECK: return %arg0 +func.func @indexCastUIOfIndexCastUI_lossless(%arg0: i64) -> i64 { + %0 = arith.index_castui %arg0 : i64 to index + %1 = arith.index_castui %0 : index to i64 + return %1 : i64 +} + +// ----- + +// CHECK-LABEL: @indexCastUIOfIndexCastUI_lossy +// The intermediate i8 is narrower than index, so the truncation must be kept. 
+// CHECK: %[[a:.+]] = arith.index_castui %arg0 : index to i8 +// CHECK: %[[b:.+]] = arith.index_castui %[[a]] : i8 to index +// CHECK: return %[[b]] +func.func @indexCastUIOfIndexCastUI_lossy(%arg0: index) -> index { + %0 = arith.index_castui %arg0 : index to i8 + %1 = arith.index_castui %0 : i8 to index + return %1 : index +} + +// ----- + +// CHECK-LABEL: @indexCastUIOfIndexCastUI_3way_lossy +// Regression test for the original bug: a 3-element chain where the outermost +// cast pair would be incorrectly folded away, dropping the i8 truncation. +// CHECK: %[[a:.*]] = arith.index_castui %arg0 : i64 to index +// CHECK: %[[b:.*]] = arith.index_castui %[[a]] : index to i8 +// CHECK: %[[c:.*]] = arith.index_castui %[[b]] : i8 to index +// CHECK: return %[[c]] +func.func @indexCastUIOfIndexCastUI_3way_lossy(%arg0: i64) -> index { + %0 = arith.index_castui %arg0 : i64 to index + %1 = arith.index_castui %0 : index to i8 + %2 = arith.index_castui %1 : i8 to index + return %2 : index +} + +// ----- + +// CHECK-LABEL: @indexCastOfIndexCast_3way_lossy +// Signed 3-way chain i8 -> index -> i64 -> index. The trailing pair (index -> +// i64 -> index) folds since i64 is at least as wide as index (64 >= 64); the +// leading i8 -> index sign-extension is kept, so %2 is replaced by %0. +// CHECK: %[[a:.*]] = arith.index_cast %arg0 : i8 to index +// CHECK: return %[[a]] +func.func @indexCastOfIndexCast_3way_lossy(%arg0: i8) -> index { + %0 = arith.index_cast %arg0 : i8 to index + %1 = arith.index_cast %0 : index to i64 + %2 = arith.index_cast %1 : i64 to index + return %2 : index +} + +// ----- + +// CHECK-LABEL: @indexCastOfIndexCast_i8_roundtrip +// i8 -> index -> i8: the intermediate index is at least as wide as i8 (64 >= 8), +// so the round-trip is lossless and the chain folds away. 
+// CHECK: return %arg0 +func.func @indexCastOfIndexCast_i8_roundtrip(%arg0: i8) -> i8 { + %0 = arith.index_cast %arg0 : i8 to index + %1 = arith.index_cast %0 : index to i8 + return %1 : i8 +} + +// ----- + +// CHECK-LABEL: @indexCastOfIndexCast_vector_lossy +// vector<3xi128> -> vector<3xindex> -> vector<3xi128>: i128 (128 bits) is wider +// than the 64-bit index, so the cast is lossy and must NOT fold. +// CHECK: %[[a:.+]] = arith.index_cast %arg0 : vector<3xi128> to vector<3xindex> +// CHECK: %[[b:.+]] = arith.index_cast %[[a]] : vector<3xindex> to vector<3xi128> +// CHECK: return %[[b]] +func.func @indexCastOfIndexCast_vector_lossy(%arg0: vector<3xi128>) -> vector<3xi128> { + %0 = arith.index_cast %arg0 : vector<3xi128> to vector<3xindex> + %1 = arith.index_cast %0 : vector<3xindex> to vector<3xi128> + return %1 : vector<3xi128> +} + +// ----- + // CHECK-LABEL: @signExtendConstant // CHECK: %[[cres:.+]] = arith.constant -2 : i16 // CHECK: return %[[cres]] diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index 6f1a422324e08..4f0d4bb0d8f5d 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -853,8 +853,7 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi // CHECK-SAME: %[[B1:.+]]: i32 // CHECK-DAG: %[[T0:.+]] = linalg.index 0 // CHECK-DAG: %[[CAST1:.+]] = arith.index_cast %[[T0]] : index to i64 -// CHECK-DAG: %[[CAST2:.+]] = arith.index_cast %[[CAST1]] : i64 to index -// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[CAST2]]] +// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[T0]]] // CHECK: linalg.yield %[[CAST1]], %[[EXTRACT]] // CHECK: return %[[RESULT]]#1