Skip to content

Commit

Permalink
Normalize lowering of MemRef types
Browse files Browse the repository at this point in the history
The RFC for unifying Linalg and Affine compilation passes into an end-to-end flow with a predictable ABI and linkage to external function calls raised the question of why we have variable sized descriptors for memrefs depending on whether they have static or dynamic dimensions  (https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio).

This CL standardizes the ABI on the rank of the memrefs.
The LLVM struct for a memref becomes equivalent to:
```
template <typename Elem, size_t Rank>
struct {
  Elem *ptr;
  int64_t sizes[Rank];
};
```

PiperOrigin-RevId: 270947276
  • Loading branch information
Nicolas Vasilache authored and tensorflower-gardener committed Sep 24, 2019
1 parent 74cdbf5 commit 42d8fa6
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 258 deletions.
10 changes: 4 additions & 6 deletions mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
Expand Up @@ -181,8 +181,8 @@ class ViewOpConversion : public ConversionPattern {
// Helper function to obtain the size of the given `memref` along the
// dimension `dim`. For static dimensions, emits a constant; for dynamic
// dimensions, extracts the size from the memref descriptor.
auto memrefSize = [int64Ty, pos, i64cst](MemRefType type, Value *memref,
int dim) -> Value * {
auto memrefSize = [&rewriter, int64Ty, i64cst](
MemRefType type, Value *memref, int dim) -> Value * {
assert(dim < type.getRank());
if (type.getShape()[dim] != -1) {
return i64cst(type.getShape()[dim]);
Expand All @@ -191,14 +191,12 @@ class ViewOpConversion : public ConversionPattern {
for (int i = 0; i < dim; ++i)
if (type.getShape()[i] == -1)
++dynamicDimPos;
return intrinsics::extractvalue(int64Ty, memref, pos(1 + dynamicDimPos));
return intrinsics::extractvalue(
int64Ty, memref, rewriter.getI64ArrayAttr({1, dynamicDimPos}));
};

// Helper function to obtain the data pointer of the given `memref`.
auto memrefPtr = [pos](MemRefType type, Value *memref) -> Value * {
if (type.hasStaticShape())
return memref;

auto elementTy = linalg::convertLinalgType(type.getElementType())
.cast<LLVM::LLVMType>()
.getPointerTo();
Expand Down
23 changes: 17 additions & 6 deletions mlir/examples/toy/Ch5/mlir/LateLowering.cpp
Expand Up @@ -149,6 +149,7 @@ class PrintOpConversion : public ConversionPattern {

// Create our loop nest now
using namespace edsc;
using extractvalue = intrinsics::ValueBuilder<LLVM::ExtractValueOp>;
using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
ScopedContext scope(rewriter, loc);
ValueHandle zero = intrinsics::constant_index(0);
Expand All @@ -157,26 +158,36 @@ class PrintOpConversion : public ConversionPattern {
IndexedValue iOp(operand);
IndexHandle i, j, M(vOp.ub(0));

auto *dialect = op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto i8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();

ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
if (vOp.rank() == 1) {
// clang-format off
LoopBuilder(&i, zero, M, 1)([&]{
llvmCall(retTy,
rewriter.getSymbolRefAttr(printfFunc),
{fmtCst, iOp(i)});
{extractvalue(i8PtrTy, fmtCst, rewriter.getIndexArrayAttr(0)),
iOp(i)});
});
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol});
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc),
{extractvalue(i8PtrTy, fmtEol, rewriter.getIndexArrayAttr(0))});
// clang-format on
} else {
IndexHandle N(vOp.ub(1));
// clang-format off
LoopBuilder(&i, zero, M, 1)([&]{
LoopBuilder(&j, zero, N, 1)([&]{
llvmCall(retTy,
rewriter.getSymbolRefAttr(printfFunc),
{fmtCst, iOp(i, j)});
llvmCall(
retTy,
rewriter.getSymbolRefAttr(printfFunc),
{extractvalue(i8PtrTy, fmtCst, rewriter.getIndexArrayAttr(0)),
iOp(i, j)});
});
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol});
llvmCall(
retTy,
rewriter.getSymbolRefAttr(printfFunc),
{extractvalue(i8PtrTy, fmtEol, rewriter.getIndexArrayAttr(0))});
});
// clang-format on
}
Expand Down
39 changes: 15 additions & 24 deletions mlir/g3doc/ConversionToLLVMDialect.md
Expand Up @@ -52,36 +52,27 @@ For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and

Memref types in MLIR have both static and dynamic information associated with
them. The dynamic information comprises the buffer pointer as well as sizes of
any dynamically sized dimensions. Memref types are converted into either LLVM IR
pointer types if they are fully statically shaped; or to LLVM IR structure types
if they contain dynamic sizes. In the latter case, the first element of the
structure is a pointer to the converted (using these rules) memref element type,
followed by as many elements as the memref has dynamic sizes. The type of each
of these size arguments will be the LLVM type that results from converting the
MLIR `index` type. Zero-dimensional memrefs are treated as pointers to the
elemental type.
any dynamically sized dimensions. Memref types are normalized and converted to a
descriptor that is only dependent on the rank of the memref. The descriptor
contains the pointer to the data buffer followed by an array containing as many
64-bit integers as the rank of the memref. The array represents the size, in
number of elements, of the memref along the given dimension. For constant memref
dimensions, the corresponding size entry is a constant whose runtime value
matches the static value. This normalization serves as an ABI for the memref
type to interoperate with externally linked functions. In the particular case of
rank `0` memrefs, the size array is omitted, resulting in a wrapped pointer.

Examples:

```mlir {.mlir}
// All of the following are converted to just a pointer type because
// of fully static sizes.
memref<f32>
memref<1 x f32>
memref<10x42x42x43x123 x f32>
// resulting type
!llvm.type<"float*">
// All of the following are converted to a three-element structure
memref<?x? x f32>
memref<42x?x10x35x1x? x f32>
// resulting type assuming 64-bit pointers
!llvm.type<"{float*, i64, i64}">
memref<f32> -> !llvm.type<"{ float* }">
memref<1 x f32> -> !llvm.type<"{ float*, [1 x i64] }">
memref<? x f32> -> !llvm.type<"{ float*, [1 x i64] }">
memref<10x42x42x43x123 x f32> -> !llvm.type<"{ float*, [5 x i64] }">
memref<10x?x42x?x123 x f32> -> !llvm.type<"{ float*, [5 x i64] }">
// Memref types can have vectors as element types
memref<1x? x vector<4xf32>>
// which get converted as well
!llvm.type<"{<4 x float>*, i64}">
memref<1x? x vector<4xf32>> -> !llvm.type<"{ <4 x float>*, [1 x i64] }">
```

### Function Types
Expand Down
185 changes: 58 additions & 127 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Expand Up @@ -122,27 +122,38 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
.getPointerTo();
}

// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
// we return a pointer to the converted element type. Otherwise we return an
// LLVM structure type, where the first element of the structure type is a
// pointer to the elemental type of the MemRef and the following N elements are
// values of the Index type, one for each of N dynamic dimensions of the MemRef.
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
// contains:
// 1. the pointer to the data buffer, followed by
// 2. an array containing as many 64-bit integers as the rank of the MemRef:
// the array represents the size, in number of elements, of the memref along
// the given dimension. For constant MemRef dimensions, the corresponding size
// entry is a constant whose runtime value must match the static value.
// TODO(ntv, zinenko): add assertions for the static cases.
//
// template <typename Elem, size_t Rank>
// struct {
// Elem *ptr;
// int64_t sizes[Rank]; // omitted when rank == 0
// };
static unsigned kPtrPosInMemRefDescriptor = 0;
static unsigned kSizePosInMemRefDescriptor = 1;
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
assert((type.getAffineMaps().empty() ||
(type.getAffineMaps().size() == 1 &&
type.getAffineMaps().back().isIdentity())) &&
"Non-identity layout maps must have been normalized away");
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto ptrType = elementType.getPointerTo(type.getMemorySpace());

// Extra value for the memory space.
unsigned numDynamicSizes = type.getNumDynamicDims();
// If memref is statically-shaped we return the underlying pointer type.
if (numDynamicSizes == 0)
return ptrType;

SmallVector<LLVM::LLVMType, 8> types(numDynamicSizes + 1, getIndexType());
types.front() = ptrType;

return LLVM::LLVMType::getStructTy(llvmDialect, types);
auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
auto indexTy = getIndexType();
auto rank = type.getRank();
if (rank > 0) {
auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank());
return LLVM::LLVMType::getStructTy(ptrTy, arrayTy);
}
return LLVM::LLVMType::getStructTy(ptrTy);
}

// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
Expand Down Expand Up @@ -600,25 +611,22 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
ArrayRef<Value *>(allocated));

// Deal with static memrefs
if (numOperands == 0)
return rewriter.replaceOp(op, allocated);

// Create the MemRef descriptor.
auto structType = lowering.convertType(type);
Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
op->getLoc(), structType, ArrayRef<Value *>{});

memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, allocated,
rewriter.getIndexArrayAttr(0));
rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));

// Store dynamically allocated sizes in the descriptor. Dynamic sizes are
// passed in as operands.
for (auto indexedSize : llvm::enumerate(operands)) {
// Store dynamically allocated sizes in the descriptor. Static and dynamic
// sizes are all passed in as operands.
for (auto indexedSize : llvm::enumerate(sizes)) {
int64_t index = indexedSize.index();
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
rewriter.getIndexArrayAttr(1 + indexedSize.index()));
rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
}

// Return the final value of the descriptor.
Expand Down Expand Up @@ -679,60 +687,12 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
ConversionPatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
OperandAdaptor<MemRefCastOp> transformed(operands);
auto targetType = memRefCastOp.getType();
auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();

// Copy the data buffer pointer.
auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
Value *buffer =
extractMemRefElementPtr(rewriter, op->getLoc(), transformed.source(),
elementTypePtr, sourceType.hasStaticShape());
// Account for static memrefs as target types
if (targetType.hasStaticShape())
return rewriter.replaceOp(op, buffer);

// Create the new MemRef descriptor.
auto structType = lowering.convertType(targetType);
Value *newDescriptor = rewriter.create<LLVM::UndefOp>(
op->getLoc(), structType, ArrayRef<Value *>{});
// Otherwise target type is dynamic memref, so create a proper descriptor.
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, newDescriptor, buffer,
rewriter.getIndexArrayAttr(0));

// Fill in the dynamic sizes of the new descriptor. If the size was
// dynamic, copy it from the old descriptor. If the size was static, insert
// the constant. Note that the positions of dynamic sizes in the
// descriptors start from 1 (the buffer pointer is at position zero).
int64_t sourceDynamicDimIdx = 1;
int64_t targetDynamicDimIdx = 1;
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
// Ignore new static sizes (they will be known from the type). If the
// size was dynamic, update the index of dynamic types.
if (targetType.getShape()[i] != -1) {
if (sourceType.getShape()[i] == -1)
++sourceDynamicDimIdx;
continue;
}

auto sourceSize = sourceType.getShape()[i];
Value *size =
sourceSize == -1
? rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(),
transformed.source(), // NB: dynamic memref
rewriter.getIndexArrayAttr(sourceDynamicDimIdx++))
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, newDescriptor, size,
rewriter.getIndexArrayAttr(targetDynamicDimIdx++));
}
assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
"source dynamic dimensions were not processed");
assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
"target dynamic dimensions were not set up");

rewriter.replaceOp(op, newDescriptor);
// memref_cast is defined for source and destination memref types with the
// same element type, same mappings, same address space and same rank.
// Therefore a simple bitcast suffices. If not it is undefined behavior.
auto targetStructType = lowering.convertType(memRefCastOp.getType());
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
transformed.source());
}
};

Expand All @@ -754,25 +714,16 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();

auto shape = type.getShape();
uint64_t index = dimOp.getIndex();
int64_t index = dimOp.getIndex();
// Extract dynamic size from the memref descriptor and define static size
// as a constant.
if (shape[index] == -1) {
// Find the position of the dynamic dimension in the list of dynamic sizes
// by counting the number of preceding dynamic dimensions. Start from 1
// because the buffer pointer is at position zero.
int64_t position = 1;
for (uint64_t i = 0; i < index; ++i) {
if (shape[i] == -1)
++position;
}
if (ShapedType::isDynamic(shape[index]))
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
op, getIndexType(), transformed.memrefOrTensor(),
rewriter.getIndexArrayAttr(position));
} else {
rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
else
rewriter.replaceOp(
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
}
}
};

Expand Down Expand Up @@ -829,61 +780,41 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
// Dynamic sizes are extracted from the MemRef descriptor, where they start
// from the position 1 (the buffer is at position 0).
SmallVector<Value *, 4> sizes;
unsigned dynamicSizeIdx = 1;
for (int64_t s : shape) {
for (auto en : llvm::enumerate(shape)) {
int64_t s = en.value();
int64_t index = en.index();
if (s == -1) {
Value *size = rewriter.create<LLVM::ExtractValueOp>(
loc, this->getIndexType(), memRefDescriptor,
rewriter.getIndexArrayAttr(dynamicSizeIdx++));
rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
sizes.push_back(size);
} else {
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
// TODO(ntv, zinenko): assert dynamic descriptor size is constant.
}
}

// The second and subsequent operands are access subscripts. Obtain the
// linearized address in the buffer.
Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
Value *subscript = indices.empty()
? nullptr
: linearizeSubscripts(rewriter, loc, indices, sizes);

Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, memRefDescriptor, rewriter.getIndexArrayAttr(0));
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
ArrayRef<Value *>{dataPtr, subscript},
loc, elementTypePtr, memRefDescriptor,
rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
SmallVector<Value *, 2> gepSubValues(1, dataPtr);
if (subscript)
gepSubValues.push_back(subscript);
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, gepSubValues,
ArrayRef<NamedAttribute>{});
}
// This is a getElementPtr variant, where the value is a direct raw pointer.
// If a shape is empty, we are dealing with a zero-dimensional memref. Return
// the pointer unmodified in this case. Otherwise, linearize subscripts to
// obtain the offset with respect to the base pointer. Use this offset to
// compute and return the element pointer.
Value *getRawElementPtr(Location loc, Type elementTypePtr,
ArrayRef<int64_t> shape, Value *rawDataPtr,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter) const {
if (shape.empty())
return rawDataPtr;

SmallVector<Value *, 4> sizes;
for (int64_t s : shape) {
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
}

Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
return rewriter.create<LLVM::GEPOp>(
loc, elementTypePtr, ArrayRef<Value *>{rawDataPtr, subscript},
ArrayRef<NamedAttribute>{});
}

Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter,
llvm::Module &module) const {
auto ptrType = getMemRefElementPtrType(type, this->lowering);
auto shape = type.getShape();
if (type.hasStaticShape()) {
// NB: If memref was statically-shaped, dataPtr is pointer to raw data.
return getRawElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
}
return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
}
};
Expand Down

0 comments on commit 42d8fa6

Please sign in to comment.