diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h index d5055f023cdc8..8e86808cc424a 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h @@ -189,15 +189,13 @@ class UnrankedMemRefDescriptor : public StructBuilder { /// `unpack`. static unsigned getNumUnpackedValues() { return 2; } - /// Builds IR computing the sizes in bytes (suitable for opaque allocation) - /// and appends the corresponding values into `sizes`. `addressSpaces` - /// which must have the same length as `values`, is needed to handle layouts - /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)). - static void computeSizes(OpBuilder &builder, Location loc, + /// Builds and returns IR computing the size in bytes (suitable for opaque + /// allocation). `addressSpace` is needed to handle layouts where + /// sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)). + static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, - ArrayRef values, - ArrayRef addressSpaces, - SmallVectorImpl &sizes); + UnrankedMemRefDescriptor desc, + unsigned addressSpace); /// TODO: The following accessors don't take alignment rules between elements /// of the descriptor struct into account. For some architectures, it might be diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp index fce7a3f324b86..522e91421ff55 100644 --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, results.push_back(d.memRefDescPtr(builder, loc)); } -void UnrankedMemRefDescriptor::computeSizes( +Value UnrankedMemRefDescriptor::computeSize( OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, - ArrayRef values, ArrayRef addressSpaces, - SmallVectorImpl &sizes) { - if (values.empty()) - return; - assert(values.size() == addressSpaces.size() && - "must provide address space for each descriptor"); + UnrankedMemRefDescriptor desc, unsigned addressSpace) { // Cache the index type. Type indexType = typeConverter.getIndexType(); @@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes( builder, loc, indexType, llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8)); - sizes.reserve(sizes.size() + values.size()); - for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) { - // Emit IR computing the memory necessary to store the descriptor. This - // assumes the descriptor to be - // { type*, type*, index, index[rank], index[rank] } - // and densely packed, so the total size is - // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). - // TODO: consider including the actual size (including eventual padding due - // to data layout) into the unranked descriptor. - Value pointerSize = createIndexAttrConstant( - builder, loc, indexType, - llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); - Value doublePointerSize = - LLVM::MulOp::create(builder, loc, indexType, two, pointerSize); - - // (1 + 2 * rank) * sizeof(index) - Value rank = desc.rank(builder, loc); - Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank); - Value doubleRankIncremented = - LLVM::AddOp::create(builder, loc, indexType, doubleRank, one); - Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType, - doubleRankIncremented, indexSize); - - // Total allocation size. - Value allocationSize = LLVM::AddOp::create( - builder, loc, indexType, doublePointerSize, rankIndexSize); - sizes.push_back(allocationSize); - } + // Emit IR computing the memory necessary to store the descriptor. This + // assumes the descriptor to be + // { type*, type*, index, index[rank], index[rank] } + // and densely packed, so the total size is + // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). + // TODO: consider including the actual size (including eventual padding due + // to data layout) into the unranked descriptor. + Value pointerSize = createIndexAttrConstant( + builder, loc, indexType, + llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); + Value doublePointerSize = + LLVM::MulOp::create(builder, loc, indexType, two, pointerSize); + + // (1 + 2 * rank) * sizeof(index) + Value rank = desc.rank(builder, loc); + Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank); + Value doubleRankIncremented = + LLVM::AddOp::create(builder, loc, indexType, doubleRank, one); + Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType, + doubleRankIncremented, indexSize); + + // Total allocation size. + Value allocationSize = LLVM::AddOp::create(builder, loc, indexType, + doublePointerSize, rankIndexSize); + return allocationSize; } Value UnrankedMemRefDescriptor::allocatedPtr( diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 2568044f1fd32..72f41fd01fe7c 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -239,12 +239,6 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( if (unrankedMemrefs.empty()) return success(); - // Compute allocation sizes. - SmallVector sizes; - UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(), - unrankedMemrefs, unrankedAddressSpaces, - sizes); - // Get frequently used types. Type indexType = getTypeConverter()->getIndexType(); @@ -267,8 +261,10 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( Type type = origTypes[i]; if (!isa(type)) continue; - Value allocationSize = sizes[unrankedMemrefPos++]; UnrankedMemRefDescriptor desc(operands[i]); + Value allocationSize = UnrankedMemRefDescriptor::computeSize( + builder, loc, *getTypeConverter(), desc, + unrankedAddressSpaces[unrankedMemrefPos++]); // Allocate memory, copy, and free the source if necessary. Value memory = diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 9216e2a35a5ae..262e0e7a30c63 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering auto result = UnrankedMemRefDescriptor::poison( rewriter, loc, typeConverter->convertType(resultTypeU)); result.setRank(rewriter, loc, rank); - SmallVector sizes; - UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), - result, resultAddrSpace, sizes); - Value resultUnderlyingSize = sizes.front(); + Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize( + rewriter, loc, *getTypeConverter(), result, resultAddrSpace); Value resultUnderlyingDesc = LLVM::AllocaOp::create(rewriter, loc, getPtrType(), rewriter.getI8Type(), resultUnderlyingSize); @@ -1530,12 +1528,11 @@ struct MemRefReshapeOpLowering auto targetDesc = UnrankedMemRefDescriptor::poison( rewriter, loc, typeConverter->convertType(targetType)); targetDesc.setRank(rewriter, loc, resultRank); - SmallVector sizes; - UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), - targetDesc, addressSpace, sizes); + Value allocationSize = UnrankedMemRefDescriptor::computeSize( + rewriter, loc, *getTypeConverter(), targetDesc, addressSpace); Value underlyingDescPtr = LLVM::AllocaOp::create( rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8), - sizes.front()); + allocationSize); targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); // Extract pointers and offset from the source memref.