diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td index c169f48e573d0..c97bd04d32896 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td @@ -79,4 +79,14 @@ def Ptr_PtrAddFlags : I32Enum<"PtrAddFlags", "Pointer add flags", [ let cppNamespace = "::mlir::ptr"; } +//===----------------------------------------------------------------------===// +// Ptr diff flags enum properties. +//===----------------------------------------------------------------------===// + +def Ptr_PtrDiffFlags : I8BitEnum<"PtrDiffFlags", "Pointer difference flags", [ + I8BitEnumCase<"none", 0>, I8BitEnumCase<"nuw", 1>, I8BitEnumCase<"nsw", 2> + ]> { + let cppNamespace = "::mlir::ptr"; +} + #endif // PTR_ENUMS diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 468a3004d5c62..e14f64330c294 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -415,6 +415,63 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ }]; } +//===----------------------------------------------------------------------===// +// PtrDiffOp +//===----------------------------------------------------------------------===// + +def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [ + Pure, AllTypesMatch<["lhs", "rhs"]>, SameOperandsAndResultShape + ]> { + let summary = "Pointer difference operation"; + let description = [{ + The `ptr_diff` operation computes the difference between two pointers, + returning an integer or index value representing the number of bytes + between them. + + The operation supports both scalar and shaped types with value semantics: + - When both operands are scalar: produces a single difference value + - When both are shaped: performs element-wise subtraction, + shapes must be the same + + The operation also supports the following flags: + - `none`: No flags are set. + - `nuw`: No Unsigned Wrap, if the subtraction causes an unsigned overflow + (that is: the result would be negative), the result is a poison value. + - `nsw`: No Signed Wrap, if the subtraction causes a signed overflow, the + result is a poison value. + + NOTE: The pointer difference is calculated using an integer type specified + by the data layout. The final result will be sign-extended or truncated to + fit the result type as necessary. + + Example: + + ```mlir + // Scalar pointers + %diff = ptr.ptr_diff %p1, %p2 : !ptr.ptr<#ptr.generic_space> -> i64 + + // Shaped pointers + %diffs = ptr.ptr_diff nsw %ptrs1, %ptrs2 : + vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64> + ``` + }]; + let arguments = (ins + Ptr_PtrLikeType:$lhs, Ptr_PtrLikeType:$rhs, + DefaultValuedProp, "PtrDiffFlags::none">:$flags + ); + let results = (outs Ptr_IntLikeType:$result); + let assemblyFormat = [{ + ($flags^)? $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result) + }]; + let extraClassDeclaration = [{ + /// Returns the operand's ptr type. + ptr::PtrType getPtrType(); + /// Returns the result's underlying int type. + Type getIntType(); + }]; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index f0209af8a1ca3..51f25f755a8a6 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -391,6 +392,39 @@ LogicalResult PtrAddOp::inferReturnTypes( return success(); } +//===----------------------------------------------------------------------===// +// PtrDiffOp +//===----------------------------------------------------------------------===// + +LogicalResult PtrDiffOp::verify() { + // If the operands are not shaped early exit. + if (!isa(getLhs().getType())) + return success(); + + // Just check the container type matches, `SameOperandsAndResultShape` handles + // the actual shape. + if (getResult().getType().getTypeID() != getLhs().getType().getTypeID()) { + return emitError() << "expected the result to have the same container " + "type as the operands when operands are shaped"; + } + + return success(); +} + +ptr::PtrType PtrDiffOp::getPtrType() { + Type lhsType = getLhs().getType(); + if (auto shapedType = dyn_cast(lhsType)) + return cast(shapedType.getElementType()); + return cast(lhsType); +} + +Type PtrDiffOp::getIntType() { + Type resultType = getResult().getType(); + if (auto shapedType = dyn_cast(resultType)) + return shapedType.getElementType(); + return resultType; +} + //===----------------------------------------------------------------------===// // ToPtrOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp index 7e610cd42e931..8d6fffcca45f2 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp @@ -351,6 +351,42 @@ translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder, return success(); } +/// Translate ptr.ptr_diff operation operation to LLVM IR. +static LogicalResult +translatePtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs()); + llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs()); + + if (!lhs || !rhs) + return ptrDiffOp.emitError("Failed to lookup operands"); + + // Translate result type to LLVM type + llvm::Type *resultType = + moduleTranslation.convertType(ptrDiffOp.getResult().getType()); + if (!resultType) + return ptrDiffOp.emitError("Failed to translate result type"); + + PtrDiffFlags flags = ptrDiffOp.getFlags(); + + // Convert both pointers to integers using ptrtoaddr, and compute the + // difference: lhs - rhs + llvm::Value *llLhs = builder.CreatePtrToAddr(lhs); + llvm::Value *llRhs = builder.CreatePtrToAddr(rhs); + llvm::Value *result = builder.CreateSub( + llLhs, llRhs, /*Name=*/"", + /*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw, + /*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw); + + // Convert the difference to the expected result type by truncating or + // extending. + if (result->getType() != resultType) + result = builder.CreateIntCast(result, resultType, /*isSigned=*/true); + + moduleTranslation.mapValue(ptrDiffOp.getResult(), result); + return success(); +} + /// Implementation of the dialect interface that translates operations belonging /// to the `ptr` dialect to LLVM IR. class PtrDialectLLVMIRTranslationInterface @@ -371,6 +407,9 @@ class PtrDialectLLVMIRTranslationInterface .Case([&](PtrAddOp ptrAddOp) { return translatePtrAddOp(ptrAddOp, builder, moduleTranslation); }) + .Case([&](PtrDiffOp ptrDiffOp) { + return translatePtrDiffOp(ptrDiffOp, builder, moduleTranslation); + }) .Case([&](LoadOp loadOp) { return translateLoadOp(loadOp, builder, moduleTranslation); }) diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir index cc1eeb3cb5744..83e1c880650c5 100644 --- a/mlir/test/Dialect/Ptr/invalid.mlir +++ b/mlir/test/Dialect/Ptr/invalid.mlir @@ -70,3 +70,11 @@ func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64> return %res : tensor<8x!ptr.ptr<#ptr.generic_space>> } + +// ----- + +func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> vector<8xi64> { + // expected-error@+1 {{the result to have the same container type as the operands when operands are shaped}} + %res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64> + return %res : vector<8xi64> +} diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir index 7b2254185f57c..0a906ad559e21 100644 --- a/mlir/test/Dialect/Ptr/ops.mlir +++ b/mlir/test/Dialect/Ptr/ops.mlir @@ -211,3 +211,31 @@ func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.p %addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>> return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>> } + +/// Test ptr_diff operations with scalar pointers +func.func @ptr_diff_scalar_ops(%ptr1: !ptr.ptr<#ptr.generic_space>, %ptr2: !ptr.ptr<#ptr.generic_space>) -> (i64, index, i32) { + %diff_i64 = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i64 + %diff_index = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> index + %diff_i32 = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#ptr.generic_space> -> i32 + return %diff_i64, %diff_index, %diff_i32 : i64, index, i32 +} + +/// Test ptr_diff operations with vector pointers +func.func @ptr_diff_vector_ops(%ptrs1: vector<4x!ptr.ptr<#ptr.generic_space>>, %ptrs2: vector<4x!ptr.ptr<#ptr.generic_space>>) -> (vector<4xi64>, vector<4xindex>) { + %diff_i64 = ptr.ptr_diff none %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xi64> + %diff_index = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xindex> + return %diff_i64, %diff_index : vector<4xi64>, vector<4xindex> +} + +/// Test ptr_diff operations with tensor pointers +func.func @ptr_diff_tensor_ops(%ptrs1: tensor<8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<8x!ptr.ptr<#ptr.generic_space>>) -> (tensor<8xi64>, tensor<8xi32>) { + %diff_i64 = ptr.ptr_diff nsw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi64> + %diff_i32 = ptr.ptr_diff nsw | nuw %ptrs1, %ptrs2 : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xi32> + return %diff_i64, %diff_i32 : tensor<8xi64>, tensor<8xi32> +} + +/// Test ptr_diff operations with 2D tensor pointers +func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %ptrs2: tensor<4x8x!ptr.ptr<#ptr.generic_space>>) -> tensor<4x8xi64> { + %diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64> + return %diff : tensor<4x8xi64> +} diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir index 2fa794130ec52..e2687e52ece57 100644 --- a/mlir/test/Target/LLVMIR/ptr.mlir +++ b/mlir/test/Target/LLVMIR/ptr.mlir @@ -281,3 +281,99 @@ llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> { %res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32 llvm.return %res : !ptr.ptr<#llvm.address_space<0>> } + +// CHECK-LABEL: define i64 @ptr_diff_scalar +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret i64 %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_scalar(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 { + %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +} + +// CHECK-LABEL: define i32 @ptr_diff_scalar_i32 +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: %[[TRUNC:.*]] = trunc i64 %[[DIFF]] to i32 +// CHECK-NEXT: ret i32 %[[TRUNC]] +// CHECK-NEXT: } +llvm.func @ptr_diff_scalar_i32(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i32 { + %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i32 + llvm.return %diff : i32 +} + +// CHECK-LABEL: define <4 x i64> @ptr_diff_vector +// CHECK-SAME: (<4 x ptr> %[[PTRS1:.*]], <4 x ptr> %[[PTRS2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <4 x ptr> %[[PTRS1]] to <4 x i64> +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <4 x ptr> %[[PTRS2]] to <4 x i64> +// CHECK-NEXT: %[[DIFF:.*]] = sub <4 x i64> %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret <4 x i64> %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_vector(%ptrs1: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<4x!ptr.ptr<#llvm.address_space<0>>>) -> vector<4xi64> { + %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<4x!ptr.ptr<#llvm.address_space<0>>> -> vector<4xi64> + llvm.return %diffs : vector<4xi64> +} + +// CHECK-LABEL: define <8 x i32> @ptr_diff_vector_i32 +// CHECK-SAME: (<8 x ptr> %[[PTRS1:.*]], <8 x ptr> %[[PTRS2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <8 x ptr> %[[PTRS1]] to <8 x i64> +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <8 x ptr> %[[PTRS2]] to <8 x i64> +// CHECK-NEXT: %[[DIFF:.*]] = sub <8 x i64> %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: %[[TRUNC:.*]] = trunc <8 x i64> %[[DIFF]] to <8 x i32> +// CHECK-NEXT: ret <8 x i32> %[[TRUNC]] +// CHECK-NEXT: } +llvm.func @ptr_diff_vector_i32(%ptrs1: vector<8x!ptr.ptr<#llvm.address_space<0>>>, %ptrs2: vector<8x!ptr.ptr<#llvm.address_space<0>>>) -> vector<8xi32> { + %diffs = ptr.ptr_diff %ptrs1, %ptrs2 : vector<8x!ptr.ptr<#llvm.address_space<0>>> -> vector<8xi32> + llvm.return %diffs : vector<8xi32> +} + +// CHECK-LABEL: define i64 @ptr_diff_with_constants() { +// CHECK-NEXT: ret i64 4096 +// CHECK-NEXT: } +llvm.func @ptr_diff_with_constants() -> i64 { + %ptr1 = ptr.constant #ptr.address<0x2000> : !ptr.ptr<#llvm.address_space<0>> + %ptr2 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<0>> + %diff = ptr.ptr_diff %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +} + +// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub nsw i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret i64 %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_with_flags_nsw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 { + %diff = ptr.ptr_diff nsw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +} + +// CHECK-LABEL: define i64 @ptr_diff_with_flags_nuw +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub nuw i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret i64 %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_with_flags_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 { + %diff = ptr.ptr_diff nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +} + +// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw_nuw +// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) { +// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64 +// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64 +// CHECK-NEXT: %[[DIFF:.*]] = sub nuw nsw i64 %[[P1INT]], %[[P2INT]] +// CHECK-NEXT: ret i64 %[[DIFF]] +// CHECK-NEXT: } +llvm.func @ptr_diff_with_flags_nsw_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.ptr<#llvm.address_space<0>>) -> i64 { + %diff = ptr.ptr_diff nsw | nuw %ptr1, %ptr2 : !ptr.ptr<#llvm.address_space<0>> -> i64 + llvm.return %diff : i64 +}