diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 11aae4de26c1e..f9ea8dba105a4 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -368,6 +368,14 @@ LogicalResult IndexCastOpLowering::matchAndRewrite( return success(); } + // Memref index_cast is a no-op at the LLVM level since LLVM uses opaque + // pointers and memrefs of different integer/index element types all convert + // to the same LLVM struct type. + if (isa(op.getIn().getType())) { + rewriter.replaceOp(op, adaptor.getIn()); + return success(); + } + bool isNonNeg = false; if constexpr (std::is_same_v) isNonNeg = op.getNonNeg(); diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 6a6016c4f5b16..75601e215744c 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -160,6 +160,30 @@ func.func @index_castui_nneg_not_set(%arg0: i1) { // ----- +// Memref index_cast is a no-op at the LLVM level since LLVM uses opaque +// pointers and all memrefs with integer or index element types convert to the +// same struct type. Verify that no sext/zext/trunc is generated. + +// CHECK-LABEL: @memref_index_cast +// CHECK-NOT: llvm.sext +// CHECK-NOT: llvm.trunc +func.func @memref_index_cast(%arg0: memref<3xi32>) -> memref<3xindex> { + %0 = arith.index_cast %arg0 : memref<3xi32> to memref<3xindex> + return %0 : memref<3xindex> +} + +// ----- + +// CHECK-LABEL: @memref_index_castui +// CHECK-NOT: llvm.zext +// CHECK-NOT: llvm.trunc +func.func @memref_index_castui(%arg0: memref<3xi32>) -> memref<3xindex> { + %0 = arith.index_castui %arg0 : memref<3xi32> to memref<3xindex> + return %0 : memref<3xindex> +} + +// ----- + // Checking conversion of signed integer types to floating point. // CHECK-LABEL: @sitofp func.func @sitofp(%arg0 : i32, %arg1 : i64) {