diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
index 4542f57a62d79..78006d2dec40d 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td
@@ -22,6 +22,34 @@ class Ptr_Attr
+//===----------------------------------------------------------------------===//
+// AddressAttr
+//===----------------------------------------------------------------------===//
+
+def Ptr_AddressAttr : Ptr_Attr<"Address", "address", [
+    DeclareAttrInterfaceMethods<TypedAttrInterface>
+  ]> {
+  let summary = "Address attribute";
+  let description = [{
+    The `address` attribute represents a raw memory address, expressed in bytes.
+
+    Example:
+
+    ```mlir
+    #ptr.address<0x1000> : !ptr.ptr<#ptr.generic_space>
+    ```
+  }];
+  let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type,
+                        APIntParameter<"">:$value);
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "PtrType":$type,
+                                        "const llvm::APInt &":$value), [{
+      return $_get(type.getContext(), type, value);
+    }]>
+  ];
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // GenericSpaceAttr
 //===----------------------------------------------------------------------===//
@@ -37,16 +65,42 @@ def Ptr_GenericSpaceAttr :
     - Load and store operations are always valid, regardless of the type.
     - Atomic operations are always valid, regardless of the type.
    - Cast operations to `generic_space` are always valid.
-
+
     Example:

     ```mlir
-    #ptr.generic_space
+    #ptr.generic_space : !ptr.ptr<#ptr.generic_space>
     ```
   }];
   let assemblyFormat = "";
 }

+//===----------------------------------------------------------------------===//
+// NullAttr
+//===----------------------------------------------------------------------===//
+
+def Ptr_NullAttr : Ptr_Attr<"Null", "null", [
+    DeclareAttrInterfaceMethods<TypedAttrInterface>
+  ]> {
+  let summary = "Null pointer attribute";
+  let description = [{
+    The `null` attribute represents a null pointer.
+
+    Example:
+
+    ```mlir
+    #ptr.null
+    ```
+  }];
+  let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type);
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "PtrType":$type), [{
+      return $_get(type.getContext(), type);
+    }]>
+  ];
+  let assemblyFormat = "";
+}
+
 //===----------------------------------------------------------------------===//
 // SpecAttr
 //===----------------------------------------------------------------------===//
@@ -62,7 +116,7 @@ def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> {
     - [Optional] index: bitwidth that should be used when performing index
       computations for the type. Setting the field to `kOptionalSpecValue`,
       means the field is optional.
-
+
     Furthermore, the attribute will verify that all present values are
    divisible by 8 (number of bits in a byte), and that `preferred` > `abi`.
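For reference, a minimal sketch of how the two new attributes are meant to be used with the `ptr.constant` operation introduced later in this patch; the values and the generic memory space are illustrative and mirror the tests added below:

```mlir
// Null pointer in the generic memory space.
%null = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
// Pointer holding the fixed raw address 0x1000 (4096).
%addr = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#ptr.generic_space>
```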
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h index bb01ceaaeea54..c252f9efd0471 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h @@ -21,6 +21,12 @@ #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" #include "mlir/Dialect/Ptr/IR/PtrEnums.h" +namespace mlir { +namespace ptr { +class PtrType; +} // namespace ptr +} // namespace mlir + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc" diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 3ac12978b947c..468a3004d5c62 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -36,7 +36,7 @@ class Ptr_ShapedValueType allowedTypes, list preds = []> : /*cppType=*/"::mlir::ShapedType">; // A ptr-like type, either scalar or shaped type with value semantics. -def Ptr_PtrLikeType : +def Ptr_PtrLikeType : AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>; // An int-like type, either scalar or shaped type with value semantics. @@ -57,6 +57,31 @@ def Ptr_Mask1DType : def Ptr_Ptr1DType : Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>; +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +def Ptr_ConstantOp : Pointer_Op<"constant", [ + ConstantLike, Pure, AllTypesMatch<["value", "result"]> + ]> { + let summary = "Pointer constant operation"; + let description = [{ + The `constant` operation produces a pointer constant. The attribute must be + a typed attribute of pointer type. + + Example: + + ```mlir + // Create a null pointer + %null = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space> + ``` + }]; + let arguments = (ins TypedAttrInterface:$value); + let results = (outs Ptr_PtrType:$result); + let assemblyFormat = "attr-dict $value"; + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // FromPtrOp //===----------------------------------------------------------------------===// @@ -81,7 +106,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [ ```mlir %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref - + // Cast the `%ptr` to a memref without utilizing metadata. %memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref ``` @@ -361,13 +386,13 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ // Scalar base and offset %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32 %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32 - + // Shaped base with scalar offset %ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32 - + // Scalar base with shaped offset %x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32> - + // Both base and offset are shaped %ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32> ``` @@ -382,7 +407,7 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ }]; let hasFolder = 1; let extraClassDeclaration = [{ - /// `ViewLikeOp::getViewSource` method. + /// `ViewLikeOp::getViewSource` method. Value getViewSource() { return getBase(); } /// Returns the ptr type of the operation. 
@@ -418,7 +443,7 @@ def Ptr_ScatterOp : Pointer_Op<"scatter", [
     // Scatter values to multiple memory locations
     ptr.scatter %value, %ptrs, %mask :
       vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
-
+
     // Scatter with alignment
     ptr.scatter %value, %ptrs, %mask alignment = 8 :
       vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index f45b88dc6deca..0b4f91cd750b8 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -103,10 +103,11 @@ struct FieldParser<
 /// Parse any integer.
 template <typename IntT>
-struct FieldParser<IntT, std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
+struct FieldParser<IntT, std::enable_if_t<(std::is_integral<IntT>::value ||
+                                            std::is_same_v<IntT, APInt>),
+                                           IntT>> {
   static FailureOr<IntT> parse(AsmParser &parser) {
-    IntT value = 0;
+    IntT value{};
     if (parser.parseInteger(value))
       return failure();
     return value;
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 284c998690170..f0209af8a1ca3 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -56,6 +56,12 @@ verifyAlignment(std::optional<int64_t> alignment,
   return success();
 }

+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
+
 //===----------------------------------------------------------------------===//
 // FromPtrOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index d777667022a98..7e610cd42e931 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -29,7 +29,7 @@ namespace {
 /// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
 static llvm::AtomicOrdering
-convertAtomicOrdering(ptr::AtomicOrdering ordering) {
+translateAtomicOrdering(ptr::AtomicOrdering ordering) {
   switch (ordering) {
   case ptr::AtomicOrdering::not_atomic:
     return llvm::AtomicOrdering::NotAtomic;
@@ -49,10 +49,10 @@ convertAtomicOrdering(ptr::AtomicOrdering ordering) {
   llvm_unreachable("Unknown atomic ordering");
 }

-/// Convert ptr.ptr_add operation
+/// Translate ptr.ptr_add operation to LLVM IR.
 static LogicalResult
-convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
-                LLVM::ModuleTranslation &moduleTranslation) {
+translatePtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
+                  LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
   llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());

@@ -83,18 +83,19 @@ convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
   return success();
 }

-/// Convert ptr.load operation
-static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
-                                   LLVM::ModuleTranslation &moduleTranslation) {
+/// Translate ptr.load operation to LLVM IR.
+static LogicalResult +translateLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr()); if (!ptr) return loadOp.emitError("Failed to lookup pointer operand"); - // Convert result type to LLVM type + // Translate result type to LLVM type llvm::Type *resultType = moduleTranslation.convertType(loadOp.getValue().getType()); if (!resultType) - return loadOp.emitError("Failed to convert result type"); + return loadOp.emitError("Failed to translate result type"); // Create the load instruction. llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0)); @@ -102,7 +103,7 @@ static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder, resultType, ptr, alignment, loadOp.getVolatile_()); // Set op flags and metadata. - loadInst->setAtomic(convertAtomicOrdering(loadOp.getOrdering())); + loadInst->setAtomic(translateAtomicOrdering(loadOp.getOrdering())); // Set sync scope if specified if (loadOp.getSyncscope().has_value()) { llvm::LLVMContext &ctx = builder.getContext(); @@ -135,10 +136,10 @@ static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder, return success(); } -/// Convert ptr.store operation +/// Translate ptr.store operation to LLVM IR. static LogicalResult -convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { +translateStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue()); llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr()); @@ -151,7 +152,7 @@ convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder, builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_()); // Set op flags and metadata. - storeInst->setAtomic(convertAtomicOrdering(storeOp.getOrdering())); + storeInst->setAtomic(translateAtomicOrdering(storeOp.getOrdering())); // Set sync scope if specified if (storeOp.getSyncscope().has_value()) { llvm::LLVMContext &ctx = builder.getContext(); @@ -178,21 +179,21 @@ convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder, return success(); } -/// Convert ptr.type_offset operation +/// Translate ptr.type_offset operation to LLVM IR. static LogicalResult -convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { - // Convert the element type to LLVM type +translateTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + // Translate the element type to LLVM type llvm::Type *elementType = moduleTranslation.convertType(typeOffsetOp.getElementType()); if (!elementType) - return typeOffsetOp.emitError("Failed to convert the element type"); + return typeOffsetOp.emitError("Failed to translate the element type"); - // Convert result type + // Translate result type llvm::Type *resultType = moduleTranslation.convertType(typeOffsetOp.getResult().getType()); if (!resultType) - return typeOffsetOp.emitError("Failed to convert the result type"); + return typeOffsetOp.emitError("Failed to translate the result type"); // Use GEP with null pointer to compute type size/offset. 
llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0)); @@ -204,10 +205,10 @@ convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder, return success(); } -/// Convert ptr.gather operation +/// Translate ptr.gather operation to LLVM IR. static LogicalResult -convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { +translateGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs()); llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask()); llvm::Value *passthrough = @@ -216,11 +217,11 @@ convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder, if (!ptrs || !mask || !passthrough) return gatherOp.emitError("Failed to lookup operands"); - // Convert result type to LLVM type. + // Translate result type to LLVM type. llvm::Type *resultType = moduleTranslation.convertType(gatherOp.getResult().getType()); if (!resultType) - return gatherOp.emitError("Failed to convert result type"); + return gatherOp.emitError("Failed to translate result type"); // Get the alignment. llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0)); @@ -233,10 +234,10 @@ convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder, return success(); } -/// Convert ptr.masked_load operation +/// Translate ptr.masked_load operation to LLVM IR. static LogicalResult -convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { +translateMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr()); llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask()); llvm::Value *passthrough = @@ -245,11 +246,11 @@ convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder, if (!ptr || !mask || !passthrough) return maskedLoadOp.emitError("Failed to lookup operands"); - // Convert result type to LLVM type. + // Translate result type to LLVM type. llvm::Type *resultType = moduleTranslation.convertType(maskedLoadOp.getResult().getType()); if (!resultType) - return maskedLoadOp.emitError("Failed to convert result type"); + return maskedLoadOp.emitError("Failed to translate result type"); // Get the alignment. llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0)); @@ -262,10 +263,11 @@ convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder, return success(); } -/// Convert ptr.masked_store operation +/// Translate ptr.masked_store operation to LLVM IR. static LogicalResult -convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { +translateMaskedStoreOp(MaskedStoreOp maskedStoreOp, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue()); llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr()); llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask()); @@ -281,10 +283,10 @@ convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder, return success(); } -/// Convert ptr.scatter operation +/// Translate ptr.scatter operation to LLVM IR. 
 static LogicalResult
-convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
-                 LLVM::ModuleTranslation &moduleTranslation) {
+translateScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
   llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
   llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
@@ -300,7 +302,56 @@ convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
   return success();
 }

-/// Implementation of the dialect interface that converts operations belonging
+/// Translate ptr.constant operation to LLVM IR.
+static LogicalResult
+translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation) {
+  // Translate result type to LLVM type
+  llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>(
+      moduleTranslation.convertType(constantOp.getResult().getType()));
+  if (!resultType)
+    return constantOp.emitError("Expected a valid pointer type");
+
+  llvm::Value *result = nullptr;
+
+  TypedAttr value = constantOp.getValue();
+  if (auto nullAttr = dyn_cast<NullAttr>(value)) {
+    // Create a null pointer constant
+    result = llvm::ConstantPointerNull::get(resultType);
+  } else if (auto addressAttr = dyn_cast<AddressAttr>(value)) {
+    // Create an integer constant and convert it to a pointer
+    llvm::APInt addressValue = addressAttr.getValue();
+
+    // Determine the integer type width based on the target's pointer size
+    llvm::DataLayout dataLayout =
+        moduleTranslation.getLLVMModule()->getDataLayout();
+    unsigned pointerSizeInBits =
+        dataLayout.getPointerSizeInBits(resultType->getAddressSpace());
+
+    // Extend or truncate the address value to match pointer size if needed
+    if (addressValue.getBitWidth() != pointerSizeInBits) {
+      if (addressValue.getBitWidth() > pointerSizeInBits) {
+        constantOp.emitWarning()
+            << "Truncating address value to fit pointer size";
+      }
+      addressValue = addressValue.getBitWidth() < pointerSizeInBits
+                         ? addressValue.zext(pointerSizeInBits)
+                         : addressValue.trunc(pointerSizeInBits);
+    }
+
+    // Create the integer constant and convert it to a pointer
+    llvm::Type *intType = builder.getIntNTy(pointerSizeInBits);
+    llvm::Value *intValue = llvm::ConstantInt::get(intType, addressValue);
+    result = builder.CreateIntToPtr(intValue, resultType);
+  } else {
+    return constantOp.emitError("Unsupported constant attribute type");
+  }
+
+  moduleTranslation.mapValue(constantOp.getResult(), result);
+  return success();
+}
+
+/// Implementation of the dialect interface that translates operations belonging
 /// to the `ptr` dialect to LLVM IR.
class PtrDialectLLVMIRTranslationInterface : public LLVMTranslationDialectInterface { @@ -314,30 +365,35 @@ class PtrDialectLLVMIRTranslationInterface LLVM::ModuleTranslation &moduleTranslation) const final { return llvm::TypeSwitch(op) + .Case([&](ConstantOp constantOp) { + return translateConstantOp(constantOp, builder, moduleTranslation); + }) .Case([&](PtrAddOp ptrAddOp) { - return convertPtrAddOp(ptrAddOp, builder, moduleTranslation); + return translatePtrAddOp(ptrAddOp, builder, moduleTranslation); }) .Case([&](LoadOp loadOp) { - return convertLoadOp(loadOp, builder, moduleTranslation); + return translateLoadOp(loadOp, builder, moduleTranslation); }) .Case([&](StoreOp storeOp) { - return convertStoreOp(storeOp, builder, moduleTranslation); + return translateStoreOp(storeOp, builder, moduleTranslation); }) .Case([&](TypeOffsetOp typeOffsetOp) { - return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation); + return translateTypeOffsetOp(typeOffsetOp, builder, + moduleTranslation); }) .Case([&](GatherOp gatherOp) { - return convertGatherOp(gatherOp, builder, moduleTranslation); + return translateGatherOp(gatherOp, builder, moduleTranslation); }) .Case([&](MaskedLoadOp maskedLoadOp) { - return convertMaskedLoadOp(maskedLoadOp, builder, moduleTranslation); + return translateMaskedLoadOp(maskedLoadOp, builder, + moduleTranslation); }) .Case([&](MaskedStoreOp maskedStoreOp) { - return convertMaskedStoreOp(maskedStoreOp, builder, - moduleTranslation); + return translateMaskedStoreOp(maskedStoreOp, builder, + moduleTranslation); }) .Case([&](ScatterOp scatterOp) { - return convertScatterOp(scatterOp, builder, moduleTranslation); + return translateScatterOp(scatterOp, builder, moduleTranslation); }) .Default([&](Operation *op) { return op->emitError("Translation for operation '") diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir index 51e5ac3ae691d..7b2254185f57c 100644 --- a/mlir/test/Dialect/Ptr/ops.mlir +++ b/mlir/test/Dialect/Ptr/ops.mlir @@ -114,7 +114,7 @@ func.func @masked_store_ops_tensor(%value: tensor<8xi64>, %ptr: !ptr.ptr<#ptr.ge } /// Test operations with LLVM address space -func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>, +func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>, %mask: vector<4xi1>, %value: vector<4xf32>, %passthrough: vector<4xf32>) -> vector<4xf32> { // Gather from shared memory (address space 3) %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf32> @@ -189,3 +189,25 @@ func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.gener %res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64 return %res : tensor<8x!ptr.ptr<#ptr.generic_space>> } + +/// Test constant operations with null pointer +func.func @constant_null_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>) { + %null_generic = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space> + %null_as1 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<1>> + return %null_generic, %null_as1 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>> +} + +/// Test constant operations with address values +func.func @constant_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>) { + %addr_0 = ptr.constant #ptr.address<0> : 
!ptr.ptr<#ptr.generic_space> + %addr_1000 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>> + %addr_deadbeef = ptr.constant #ptr.address<0xDEADBEEF> : !ptr.ptr<#llvm.address_space<3>> + return %addr_0, %addr_1000, %addr_deadbeef : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>> +} + +/// Test constant operations with large address values +func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>) { + %addr_max32 = ptr.constant #ptr.address<0xFFFFFFFF> : !ptr.ptr<#ptr.generic_space> + %addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>> + return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>> +} diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir index 9b99dd8e3a3eb..2fa794130ec52 100644 --- a/mlir/test/Target/LLVMIR/ptr.mlir +++ b/mlir/test/Target/LLVMIR/ptr.mlir @@ -41,10 +41,10 @@ llvm.func @type_offset(%arg0: !ptr.ptr<#llvm.address_space<0>>) -> !llvm.struct< %2 = ptr.type_offset i16 : i32 %3 = ptr.type_offset i32 : i32 %4 = llvm.mlir.poison : !llvm.struct<(i32, i32, i32, i32)> - %5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)> - %6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)> - %7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)> - %8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)> + %5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)> + %6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)> + %7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)> + %8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)> llvm.return %8 : !llvm.struct<(i32, i32, i32, i32)> } @@ -194,7 +194,7 @@ llvm.func @scatter_ops_i64(%value: vector<8xi64>, %ptrs: vector<8x!ptr.ptr<#llvm // CHECK-NEXT: call void @llvm.masked.store.v4f64.p3(<4 x double> %[[VALUE_F64]], ptr addrspace(3) %[[PTR_SHARED]], i32 8, <4 x i1> %[[MASK]]) // CHECK-NEXT: ret void // CHECK-NEXT: } -llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>, +llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>, %mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) { // Test with shared memory address space (3) and f64 elements %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64> @@ -255,3 +255,29 @@ llvm.func @llvm_ops_with_ptr_nvvm_values(%arg0: !llvm.ptr) { llvm.store %1, %arg0 : !ptr.ptr<#nvvm.memory_space>, !llvm.ptr llvm.return } + +// CHECK-LABEL: define { ptr, ptr addrspace(1), ptr addrspace(2) } @constant_address_op() { +// CHECK-NEXT: ret { ptr, ptr addrspace(1), ptr addrspace(2) } { ptr null, ptr addrspace(1) inttoptr (i64 4096 to ptr addrspace(1)), ptr addrspace(2) inttoptr (i64 3735928559 to ptr addrspace(2)) } +llvm.func @constant_address_op() -> + !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, + !ptr.ptr<#llvm.address_space<1>>, + !ptr.ptr<#llvm.address_space<2>>)> { + %0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>> + %1 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>> + %2 = ptr.constant #ptr.address<3735928559> : !ptr.ptr<#llvm.address_space<2>> + %3 = llvm.mlir.poison : 
!llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> + %4 = llvm.insertvalue %0, %3[0] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> + %5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> + %6 = llvm.insertvalue %2, %5[2] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> + llvm.return %6 : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> +} + +// Test gep folders. +// CHECK-LABEL: define ptr @ptr_add_cst() { +// CHECK-NEXT: ret ptr inttoptr (i64 42 to ptr) +llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> { + %off = llvm.mlir.constant(42 : i32) : i32 + %ptr = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>> + %res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32 + llvm.return %res : !ptr.ptr<#llvm.address_space<0>> +}
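As an additional usage sketch (not part of this patch): for an address narrower than the target pointer width, `translateConstantOp` zero-extends the value before building the `inttoptr` constant, so the example below would lower to `inttoptr (i64 42 to ptr)` on a 64-bit data layout. The function name and input are hypothetical.

```mlir
// Hypothetical example: the address 42 occupies fewer bits than the 64-bit
// pointer width, so it is zero-extended during translation to LLVM IR.
llvm.func @narrow_address() -> !ptr.ptr<#llvm.address_space<0>> {
  %p = ptr.constant #ptr.address<42> : !ptr.ptr<#llvm.address_space<0>>
  llvm.return %p : !ptr.ptr<#llvm.address_space<0>>
}
```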