diff --git a/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h new file mode 100644 index 0000000000000..0ff92bc85668c --- /dev/null +++ b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h @@ -0,0 +1,27 @@ +//===- PtrToLLVM.h - Ptr to LLVM dialect conversion -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H +#define MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H + +#include + +namespace mlir { +class DialectRegistry; +class LLVMTypeConverter; +class RewritePatternSet; +namespace ptr { +/// Populate the convert to LLVM patterns for the `ptr` dialect. +void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); +/// Register the convert to LLVM interface for the `ptr` dialect. +void registerConvertPtrToLLVMInterface(DialectRegistry ®istry); +} // namespace ptr +} // namespace mlir + +#endif // MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 134fe8e14ca38..71986f83c4870 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -50,6 +50,7 @@ add_subdirectory(NVVMToLLVM) add_subdirectory(OpenACCToSCF) add_subdirectory(OpenMPToLLVM) add_subdirectory(PDLToPDLInterp) +add_subdirectory(PtrToLLVM) add_subdirectory(ReconcileUnrealizedCasts) add_subdirectory(SCFToControlFlow) add_subdirectory(SCFToEmitC) diff --git a/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt new file mode 100644 index 0000000000000..2d416be13ee30 --- /dev/null +++ b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRPtrToLLVM + PtrToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PtrToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRPtrDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + ) diff --git a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp new file mode 100644 index 0000000000000..a0758aa8b1369 --- /dev/null +++ b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp @@ -0,0 +1,440 @@ +//===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h" + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/IR/TypeUtilities.h" +#include + +using namespace mlir; + +namespace { +//===----------------------------------------------------------------------===// +// FromPtrOpConversion +//===----------------------------------------------------------------------===// +struct FromPtrOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +//===----------------------------------------------------------------------===// +// GetMetadataOpConversion +//===----------------------------------------------------------------------===// +struct GetMetadataOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +//===----------------------------------------------------------------------===// +// PtrAddOpConversion +//===----------------------------------------------------------------------===// +struct PtrAddOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +//===----------------------------------------------------------------------===// +// ToPtrOpConversion +//===----------------------------------------------------------------------===// +struct ToPtrOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +//===----------------------------------------------------------------------===// +// TypeOffsetOpConversion +//===----------------------------------------------------------------------===// +struct TypeOffsetOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Internal functions +//===----------------------------------------------------------------------===// + +// Function to create an LLVM struct type representing a memref metadata. +static FailureOr +createMemRefMetadataType(MemRefType type, + const LLVMTypeConverter &typeConverter) { + MLIRContext *context = type.getContext(); + // Get the address space. + FailureOr addressSpace = typeConverter.getMemRefAddressSpace(type); + if (failed(addressSpace)) + return failure(); + + // Get pointer type (using address space 0 by default) + auto ptrType = LLVM::LLVMPointerType::get(context, *addressSpace); + + // Get the strides offsets and shape. + SmallVector strides; + int64_t offset; + if (failed(type.getStridesAndOffset(strides, offset))) + return failure(); + ArrayRef shape = type.getShape(); + + // Use index type from the type converter for the descriptor elements + Type indexType = typeConverter.getIndexType(); + + // For a ranked memref, the descriptor contains: + // 1. The pointer to the allocated data + // 2. The pointer to the aligned data + // 3. The dynamic offset? + // 4. The dynamic sizes? + // 5. The dynamic strides? + SmallVector elements; + + // Allocated pointer. + elements.push_back(ptrType); + + // Potentially add the dynamic offset. + if (offset == ShapedType::kDynamic) + elements.push_back(indexType); + + // Potentially add the dynamic sizes. + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic) + elements.push_back(indexType); + } + + // Potentially add the dynamic strides. + for (int64_t stride : strides) { + if (stride == ShapedType::kDynamic) + elements.push_back(indexType); + } + return LLVM::LLVMStructType::getLiteral(context, elements); +} + +//===----------------------------------------------------------------------===// +// FromPtrOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult FromPtrOpConversion::matchAndRewrite( + ptr::FromPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get the target memref type + auto mTy = dyn_cast(op.getResult().getType()); + if (!mTy) + return rewriter.notifyMatchFailure(op, "Expected memref result type"); + + if (!op.getMetadata() && op.getType().hasPtrMetadata()) { + return rewriter.notifyMatchFailure( + op, "Can convert only memrefs with metadata"); + } + + // Convert the result type + Type descriptorTy = getTypeConverter()->convertType(mTy); + if (!descriptorTy) + return rewriter.notifyMatchFailure(op, "Failed to convert result type"); + + // Get the strides, offsets and shape. + SmallVector strides; + int64_t offset; + if (failed(mTy.getStridesAndOffset(strides, offset))) { + return rewriter.notifyMatchFailure(op, + "Failed to get the strides and offset"); + } + ArrayRef shape = mTy.getShape(); + + // Create a new memref descriptor + Location loc = op.getLoc(); + auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy); + + // Set the allocated and aligned pointers. + desc.setAllocatedPtr( + rewriter, loc, + rewriter.create(loc, adaptor.getMetadata(), 0)); + desc.setAlignedPtr(rewriter, loc, adaptor.getPtr()); + + // Extract metadata from the passed struct. + unsigned fieldIdx = 1; + + // Set dynamic offset if needed. + if (offset == ShapedType::kDynamic) { + Value offsetValue = rewriter.create( + loc, adaptor.getMetadata(), fieldIdx++); + desc.setOffset(rewriter, loc, offsetValue); + } else { + desc.setConstantOffset(rewriter, loc, offset); + } + + // Set dynamic sizes if needed. + for (auto [i, dim] : llvm::enumerate(shape)) { + if (dim == ShapedType::kDynamic) { + Value sizeValue = rewriter.create( + loc, adaptor.getMetadata(), fieldIdx++); + desc.setSize(rewriter, loc, i, sizeValue); + } else { + desc.setConstantSize(rewriter, loc, i, dim); + } + } + + // Set dynamic strides if needed. + for (auto [i, stride] : llvm::enumerate(strides)) { + if (stride == ShapedType::kDynamic) { + Value strideValue = rewriter.create( + loc, adaptor.getMetadata(), fieldIdx++); + desc.setStride(rewriter, loc, i, strideValue); + } else { + desc.setConstantStride(rewriter, loc, i, stride); + } + } + + rewriter.replaceOp(op, static_cast(desc)); + return success(); +} + +//===----------------------------------------------------------------------===// +// GetMetadataOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult GetMetadataOpConversion::matchAndRewrite( + ptr::GetMetadataOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto mTy = dyn_cast(op.getPtr().getType()); + if (!mTy) + return rewriter.notifyMatchFailure(op, "Only memref metadata is supported"); + + // Get the metadata type. + FailureOr mdTy = + createMemRefMetadataType(mTy, *getTypeConverter()); + if (failed(mdTy)) { + return rewriter.notifyMatchFailure(op, + "Failed to create the metadata type"); + } + + // Get the memref descriptor. + MemRefDescriptor descriptor(adaptor.getPtr()); + + // Get the strides offsets and shape. + SmallVector strides; + int64_t offset; + if (failed(mTy.getStridesAndOffset(strides, offset))) { + return rewriter.notifyMatchFailure(op, + "Failed to get the strides and offset"); + } + ArrayRef shape = mTy.getShape(); + + // Create a new LLVM struct to hold the metadata + Location loc = op.getLoc(); + Value sV = rewriter.create(loc, *mdTy); + + // First element is the allocated pointer. + sV = rewriter.create( + loc, sV, descriptor.allocatedPtr(rewriter, loc), 0); + + // Track the current field index. + unsigned fieldIdx = 1; + + // Add dynamic offset if needed. + if (offset == ShapedType::kDynamic) { + sV = rewriter.create( + loc, sV, descriptor.offset(rewriter, loc), fieldIdx++); + } + + // Add dynamic sizes if needed. + for (auto [i, dim] : llvm::enumerate(shape)) { + if (dim != ShapedType::kDynamic) + continue; + sV = rewriter.create( + loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++); + } + + // Add dynamic strides if needed + for (auto [i, stride] : llvm::enumerate(strides)) { + if (stride != ShapedType::kDynamic) + continue; + sV = rewriter.create( + loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++); + } + rewriter.replaceOp(op, sV); + return success(); +} + +//===----------------------------------------------------------------------===// +// PtrAddOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult +PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get and check the base. + Value base = adaptor.getBase(); + if (!isa(base.getType())) + return rewriter.notifyMatchFailure(op, "Incompatible pointer type"); + + // Get the offset. + Value offset = adaptor.getOffset(); + + // Ptr assumes the offset is in bytes. + Type elementType = IntegerType::get(rewriter.getContext(), 8); + + // Convert the `ptradd` flags. + LLVM::GEPNoWrapFlags flags; + switch (op.getFlags()) { + case ptr::PtrAddFlags::none: + flags = LLVM::GEPNoWrapFlags::none; + break; + case ptr::PtrAddFlags::nusw: + flags = LLVM::GEPNoWrapFlags::nusw; + break; + case ptr::PtrAddFlags::nuw: + flags = LLVM::GEPNoWrapFlags::nuw; + break; + case ptr::PtrAddFlags::inbounds: + flags = LLVM::GEPNoWrapFlags::inbounds; + break; + } + + // Create the GEP operation with appropriate arguments + rewriter.replaceOpWithNewOp(op, base.getType(), elementType, + base, ValueRange{offset}, flags); + return success(); +} + +//===----------------------------------------------------------------------===// +// ToPtrOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult +ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Bail if it's not a memref. + if (!isa(op.getPtr().getType())) + return rewriter.notifyMatchFailure(op, "Expected a memref input"); + + // Extract the aligned pointer from the memref descriptor. + rewriter.replaceOp( + op, MemRefDescriptor(adaptor.getPtr()).alignedPtr(rewriter, op.getLoc())); + return success(); +} + +//===----------------------------------------------------------------------===// +// TypeOffsetOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult TypeOffsetOpConversion::matchAndRewrite( + ptr::TypeOffsetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Convert the type attribute. + Type type = getTypeConverter()->convertType(op.getElementType()); + if (!type) + return rewriter.notifyMatchFailure(op, "Couldn't convert the type"); + + // Convert the result type. + Type rTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!rTy) + return rewriter.notifyMatchFailure(op, "Couldn't convert the result type"); + + // TODO: Use MLIR's data layout. We don't use it because overall support is + // still flaky. + + // Create an LLVM pointer type for the GEP operation. + auto ptrTy = LLVM::LLVMPointerType::get(getContext()); + + // Create a GEP operation to compute the offset of the type. + auto offset = + LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type, + LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy), + ArrayRef({LLVM::GEPArg(1)})); + + // Replace the original op with a PtrToIntOp using the computed offset. + rewriter.replaceOpWithNewOp(op, rTy, offset.getRes()); + return success(); +} + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert Ptr to LLVM. +struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &converter, + RewritePatternSet &patterns) const final { + ptr::populatePtrToLLVMConversionPatterns(converter, patterns); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// API +//===----------------------------------------------------------------------===// + +void mlir::ptr::populatePtrToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // Add address space conversions. + converter.addTypeAttributeConversion( + [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace) + -> TypeConverter::AttributeConversionResult { + if (type.getMemorySpace() != memorySpace) + return TypeConverter::AttributeConversionResult::na(); + return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0); + }); + + // Add type conversions. + converter.addConversion([&](ptr::PtrType type) -> Type { + std::optional maybeAttr = + converter.convertTypeAttribute(type, type.getMemorySpace()); + auto memSpace = + maybeAttr ? dyn_cast_or_null(*maybeAttr) : IntegerAttr(); + if (!memSpace) + return {}; + return LLVM::LLVMPointerType::get(type.getContext(), + memSpace.getValue().getSExtValue()); + }); + + // Convert ptr metadata of memref type. + converter.addConversion([&](ptr::PtrMetadataType type) -> Type { + auto mTy = dyn_cast(type.getType()); + if (!mTy) + return {}; + FailureOr res = + createMemRefMetadataType(mTy, converter); + return failed(res) ? Type() : res.value(); + }); + + // Add conversion patterns. + patterns.add(converter); +} + +void mlir::ptr::registerConvertPtrToLLVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) { + dialect->addInterfaces(); + }); +} diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index 232ddaf6762c4..69a85dbe141ce 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -28,6 +28,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" +#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h" #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" @@ -81,6 +82,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) { registerConvertMemRefToEmitCInterface(registry); registerConvertMemRefToLLVMInterface(registry); registerConvertNVVMToLLVMInterface(registry); + ptr::registerConvertPtrToLLVMInterface(registry); registerConvertOpenMPToLLVMInterface(registry); registerConvertSCFToEmitCInterface(registry); ub::registerConvertUBToLLVMInterface(registry); diff --git a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir new file mode 100644 index 0000000000000..dc645fe0480fa --- /dev/null +++ b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir @@ -0,0 +1,318 @@ +// RUN: mlir-opt %s -convert-to-llvm | FileCheck %s + +// Tests different variants of ptr_add operation with various attributes +// (regular, nusw, nuw, inbounds) +// CHECK-LABEL: llvm.func @test_ptr_add( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: i64) -> !llvm.struct<(ptr, ptr, ptr, ptr)> { +// CHECK: %[[VAL_0:.*]] = llvm.getelementptr %[[ARG0]]{{\[}}%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK: %[[VAL_1:.*]] = llvm.getelementptr nusw %[[ARG0]]{{\[}}%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr nuw %[[ARG0]]{{\[}}%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK: %[[VAL_3:.*]] = llvm.getelementptr inbounds %[[ARG0]]{{\[}}%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK: %[[VAL_4:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, ptr, ptr)> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_4]][0] : !llvm.struct<(ptr, ptr, ptr, ptr)> +// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_5]][1] : !llvm.struct<(ptr, ptr, ptr, ptr)> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_6]][2] : !llvm.struct<(ptr, ptr, ptr, ptr)> +// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_7]][3] : !llvm.struct<(ptr, ptr, ptr, ptr)> +// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)> +// CHECK: } +func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) { + %0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index + %1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index + %2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index + %3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index + return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space> +} + +// Tests type_offset operation which returns the size of different types +// CHECK-LABEL: llvm.func @test_type_offset() -> !llvm.struct<(i64, i64, i64)> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.zero : !llvm.ptr +// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][1] : (!llvm.ptr) -> !llvm.ptr, f32 +// CHECK: %[[VAL_2:.*]] = llvm.ptrtoint %[[VAL_1]] : !llvm.ptr to i64 +// CHECK: %[[VAL_3:.*]] = llvm.mlir.zero : !llvm.ptr +// CHECK: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_3]][1] : (!llvm.ptr) -> !llvm.ptr, i64 +// CHECK: %[[VAL_5:.*]] = llvm.ptrtoint %[[VAL_4]] : !llvm.ptr to i64 +// CHECK: %[[VAL_6:.*]] = llvm.mlir.zero : !llvm.ptr +// CHECK: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)> +// CHECK: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 +// CHECK: %[[VAL_9:.*]] = llvm.mlir.poison : !llvm.struct<(i64, i64, i64)> +// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_9]][0] : !llvm.struct<(i64, i64, i64)> +// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_10]][1] : !llvm.struct<(i64, i64, i64)> +// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_11]][2] : !llvm.struct<(i64, i64, i64)> +// CHECK: llvm.return %[[VAL_12]] : !llvm.struct<(i64, i64, i64)> +// CHECK: } +func.func @test_type_offset() -> (index, index, index) { + %0 = ptr.type_offset f32 : index + %1 = ptr.type_offset i64 : index + %2 = ptr.type_offset !llvm.struct<(i32, f64)> : index + return %0, %1, %2 : index, index, index +} + +// Tests converting a memref to a pointer using to_ptr +// CHECK-LABEL: llvm.func @test_to_ptr( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64) -> !llvm.ptr { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: llvm.return %[[VAL_6]] : !llvm.ptr +// CHECK: } +func.func @test_to_ptr(%arg0: memref<10xf32, #ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> { + %0 = ptr.to_ptr %arg0 : memref<10xf32, #ptr.generic_space> -> <#ptr.generic_space> + return %0 : !ptr.ptr<#ptr.generic_space> +} + +// Tests extracting metadata from a static-sized memref +// CHECK-LABEL: llvm.func @test_get_metadata_static( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64) -> !llvm.struct<(ptr)> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_8:.*]] = llvm.mlir.undef : !llvm.struct<(ptr)> +// CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_7]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.struct<(ptr)> +// CHECK: llvm.return %[[VAL_10]] : !llvm.struct<(ptr)> +// CHECK: } +func.func @test_get_metadata_static(%arg0: memref<10x20xf32, #ptr.generic_space>) -> !ptr.ptr_metadata> { + %0 = ptr.get_metadata %arg0 : memref<10x20xf32, #ptr.generic_space> + return %0 : !ptr.ptr_metadata> +} + +// Tests extracting metadata from a dynamically-sized memref +// CHECK-LABEL: llvm.func @test_get_metadata_dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64) -> !llvm.struct<(ptr, i64, i64, i64)> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_8:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_7]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: %[[VAL_11:.*]] = llvm.extractvalue %[[VAL_7]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_11]], %[[VAL_10]][1] : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: %[[VAL_13:.*]] = llvm.extractvalue %[[VAL_7]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_14:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_12]][2] : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: %[[VAL_15:.*]] = llvm.extractvalue %[[VAL_7]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_16:.*]] = llvm.insertvalue %[[VAL_15]], %[[VAL_14]][3] : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: llvm.return %[[VAL_16]] : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: } +func.func @test_get_metadata_dynamic(%arg0: memref) -> !ptr.ptr_metadata> { + %0 = ptr.get_metadata %arg0 : memref + return %0 : !ptr.ptr_metadata> +} + +// Tests reconstructing a static-sized memref from a pointer and metadata +// CHECK-LABEL: llvm.func @test_from_ptr_static( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.struct<(ptr)>) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(ptr)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_2]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_3]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(10 : index) : i64 +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(20 : index) : i64 +// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_7]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(20 : index) : i64 +// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.return %[[VAL_13]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: } +func.func @test_from_ptr_static(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: !ptr.ptr_metadata>) -> memref<10x20xf32, #ptr.generic_space> { + %0 = ptr.from_ptr %arg0 metadata %arg1 : <#ptr.generic_space> -> memref<10x20xf32, #ptr.generic_space> + return %0 : memref<10x20xf32, #ptr.generic_space> +} + +// Tests reconstructing a dynamically-sized memref from a pointer and metadata +// CHECK-LABEL: llvm.func @test_from_ptr_dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.struct<(ptr, i64, i64, i64)>) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_2]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_3]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_8:.*]] = llvm.extractvalue %[[ARG1]][2] : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_7]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[ARG1]][3] : !llvm.struct<(ptr, i64, i64, i64)> +// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.return %[[VAL_13]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: } +func.func @test_from_ptr_dynamic(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: !ptr.ptr_metadata>) -> memref { + %0 = ptr.from_ptr %arg0 metadata %arg1 : <#ptr.generic_space> -> memref + return %0 : memref +} + +// Tests a round-trip conversion of a memref with mixed static/dynamic dimensions +// CHECK-LABEL: llvm.func @test_memref_mixed( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64, %[[ARG7:.*]]: i64, %[[ARG8:.*]]: i64) -> !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG7]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_7]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[ARG8]], %[[VAL_8]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_9]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_11:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i64, i64)> +// CHECK: %[[VAL_12:.*]] = llvm.extractvalue %[[VAL_9]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][0] : !llvm.struct<(ptr, i64, i64)> +// CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_9]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_15:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_13]][1] : !llvm.struct<(ptr, i64, i64)> +// CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_17:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_15]][2] : !llvm.struct<(ptr, i64, i64)> +// CHECK: llvm.return %[[VAL_9]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: } +func.func @test_memref_mixed(%arg0: memref<10x?x30xf32, #ptr.generic_space>) -> memref<10x?x30xf32, #ptr.generic_space> { + %0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space> + %1 = ptr.get_metadata %arg0 : memref<10x?x30xf32, #ptr.generic_space> + %2 = ptr.from_ptr %0 metadata %1 : <#ptr.generic_space> -> memref<10x?x30xf32, #ptr.generic_space> + return %2 : memref<10x?x30xf32, #ptr.generic_space> +} + +// Tests a round-trip conversion of a strided memref with explicit offset +// CHECK-LABEL: llvm.func @test_memref_strided( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_8:.*]] = llvm.extractvalue %[[VAL_7]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_9:.*]] = llvm.mlir.undef : !llvm.struct<(ptr)> +// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_7]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0] : !llvm.struct<(ptr)> +// CHECK: llvm.return %[[VAL_7]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: } +func.func @test_memref_strided(%arg0: memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space>) -> memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space> { + %0 = ptr.to_ptr %arg0 : memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space> -> <#ptr.generic_space> + %1 = ptr.get_metadata %arg0 : memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space> + %2 = ptr.from_ptr %0 metadata %1 : <#ptr.generic_space> -> memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space> + return %2 : memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space> +} + +// Tests a comprehensive scenario with fully dynamic memref, including pointer arithmetic +// CHECK-LABEL: llvm.func @test_comprehensive_dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_8:.*]] = llvm.extractvalue %[[VAL_7]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_9:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_7]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_12:.*]] = llvm.extractvalue %[[VAL_7]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][1] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_7]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_15:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_13]][2] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_7]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_17:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_15]][3] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_18:.*]] = llvm.extractvalue %[[VAL_7]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_19:.*]] = llvm.insertvalue %[[VAL_18]], %[[VAL_17]][4] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_20:.*]] = llvm.extractvalue %[[VAL_7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_19]][5] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_22:.*]] = llvm.mlir.zero : !llvm.ptr +// CHECK: %[[VAL_23:.*]] = llvm.getelementptr %[[VAL_22]][1] : (!llvm.ptr) -> !llvm.ptr, f32 +// CHECK: %[[VAL_24:.*]] = llvm.ptrtoint %[[VAL_23]] : !llvm.ptr to i64 +// CHECK: %[[VAL_25:.*]] = llvm.getelementptr inbounds %[[VAL_8]]{{\[}}%[[VAL_24]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK: %[[VAL_26:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_27:.*]] = llvm.extractvalue %[[VAL_21]][0] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_28:.*]] = llvm.insertvalue %[[VAL_27]], %[[VAL_26]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_29:.*]] = llvm.insertvalue %[[VAL_25]], %[[VAL_28]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_30:.*]] = llvm.extractvalue %[[VAL_21]][1] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_31:.*]] = llvm.insertvalue %[[VAL_30]], %[[VAL_29]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_32:.*]] = llvm.extractvalue %[[VAL_21]][2] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_31]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_34:.*]] = llvm.extractvalue %[[VAL_21]][3] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_35:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_33]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_36:.*]] = llvm.extractvalue %[[VAL_21]][4] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_37:.*]] = llvm.insertvalue %[[VAL_36]], %[[VAL_35]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[VAL_38:.*]] = llvm.extractvalue %[[VAL_21]][5] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)> +// CHECK: %[[VAL_39:.*]] = llvm.insertvalue %[[VAL_38]], %[[VAL_37]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.return %[[VAL_39]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: } +func.func @test_comprehensive_dynamic(%arg0: memref, #ptr.generic_space>) -> memref, #ptr.generic_space> { + %0 = ptr.to_ptr %arg0 : memref, #ptr.generic_space> -> <#ptr.generic_space> + %1 = ptr.get_metadata %arg0 : memref, #ptr.generic_space> + %2 = ptr.type_offset f32 : index + %3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index + %4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref, #ptr.generic_space> + return %4 : memref, #ptr.generic_space> +} + +// Tests a round-trip conversion of a 0D (scalar) memref +// CHECK-LABEL: llvm.func @test_memref_0d( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64) -> !llvm.struct<(ptr, ptr, i64)> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_3]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[VAL_5:.*]] = llvm.mlir.undef : !llvm.struct<(ptr)> +// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_3]][0] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][0] : !llvm.struct<(ptr)> +// CHECK: llvm.return %[[VAL_3]] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: } +func.func @test_memref_0d(%arg0: memref) -> memref { + %0 = ptr.to_ptr %arg0 : memref -> <#ptr.generic_space> + %1 = ptr.get_metadata %arg0 : memref + %2 = ptr.from_ptr %0 metadata %1 : <#ptr.generic_space> -> memref + return %2 : memref +} + +// Tests ptr indexing with a pointer coming from a memref. +// CHECK-LABEL: llvm.func @test_memref_ptradd_indexing( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64, %[[ARG7:.*]]: i64, %[[ARG8:.*]]: i64, %[[ARG9:.*]]: i64) -> !llvm.ptr { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG7]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_7]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[ARG8]], %[[VAL_8]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_9]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> +// CHECK: %[[VAL_11:.*]] = llvm.mlir.zero : !llvm.ptr +// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_11]][1] : (!llvm.ptr) -> !llvm.ptr, f32 +// CHECK: %[[VAL_13:.*]] = llvm.ptrtoint %[[VAL_12]] : !llvm.ptr to i64 +// CHECK: %[[VAL_14:.*]] = llvm.mul %[[VAL_13]], %[[ARG9]] : i64 +// CHECK: %[[VAL_15:.*]] = llvm.getelementptr %[[VAL_10]]{{\[}}%[[VAL_14]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK: llvm.return %[[VAL_15]] : !llvm.ptr +// CHECK: } +func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_space>, %arg1: index) -> !ptr.ptr<#ptr.generic_space> { + %0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space> + %1 = ptr.type_offset f32 : index + %2 = arith.muli %1, %arg1 : index + %3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index + return %3 : !ptr.ptr<#ptr.generic_space> +}