Normalize MemRefType lowering to LLVM as strided MemRef descriptor
This CL finishes the implementation of the lowering part of the [strided memref RFC](https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio).

Strided memrefs correspond conceptually to the following templated C++ struct:
```
template <typename Elem, size_t Rank>
struct {
  Elem *ptr;
  int64_t offset;
  int64_t sizes[Rank];
  int64_t strides[Rank];
};
```
The linearization procedure for address calculation for strided memrefs is the same as for linalg views:
`base_offset + SUM_i index_i * stride_i`.
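For illustration, here is a minimal C++ sketch of that linearization against the struct above (the struct and helper names are illustrative only, not part of this CL):
```
#include <cstddef>
#include <cstdint>

// Same layout as the conceptual struct above, given a name for illustration.
template <typename Elem, size_t Rank>
struct StridedMemRefDescriptor {
  Elem *ptr;
  int64_t offset;
  int64_t sizes[Rank];
  int64_t strides[Rank];
};

// Linearize a multi-dimensional subscript exactly as described:
//   base_offset + SUM_i index_i * stride_i
template <typename Elem, size_t Rank>
Elem &elementAt(const StridedMemRefDescriptor<Elem, Rank> &d,
                const int64_t (&indices)[Rank]) {
  int64_t linear = d.offset;
  for (size_t i = 0; i < Rank; ++i)
    linear += indices[i] * d.strides[i];
  return d.ptr[linear];
}
```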

A follow-up CL will unify Linalg and Standard by removing !linalg.view in favor of strided memrefs.

PiperOrigin-RevId: 272033399
Nicolas Vasilache authored and tensorflower-gardener committed Sep 30, 2019
1 parent 2713f36 commit 923b33e
Showing 13 changed files with 389 additions and 254 deletions.
8 changes: 2 additions & 6 deletions mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
@@ -187,12 +187,8 @@ class ViewOpConversion : public ConversionPattern {
if (type.getShape()[dim] != -1) {
return i64cst(type.getShape()[dim]);
}
int dynamicDimPos = 0;
for (int i = 0; i < dim; ++i)
if (type.getShape()[i] == -1)
++dynamicDimPos;
return intrinsics::extractvalue(
int64Ty, memref, rewriter.getI64ArrayAttr({1, dynamicDimPos}));
return intrinsics::extractvalue(int64Ty, memref,
rewriter.getI64ArrayAttr({2, dim}));
};

// Helper function to obtain the data pointer of the given `memref`.
19 changes: 12 additions & 7 deletions mlir/examples/Linalg/Linalg3/Execution.cpp
@@ -68,30 +68,35 @@ FuncOp makeFunctionWithAMatmulOp(ModuleOp module, StringRef name) {
// This is equivalent to the structure that the conversion produces.
struct MemRefDescriptor2D {
float *ptr;
int64_t sz1;
int64_t sz2;
int64_t offset;
int64_t sizes[2];
int64_t strides[2];
};

// Allocate a 2D memref of the given sizes, fill in its descriptor (offset,
// sizes and row-major strides), and initialize all values with 1.0f.
static MemRefDescriptor2D allocateInit2DMemref(int64_t sz1, int64_t sz2) {
MemRefDescriptor2D descriptor;
descriptor.ptr = static_cast<float *>(malloc(sizeof(float) * sz1 * sz2));
descriptor.sz1 = sz1;
descriptor.sz2 = sz2;
descriptor.offset = 0;
descriptor.sizes[0] = sz1;
descriptor.sizes[1] = sz2;
descriptor.strides[0] = sz2;
descriptor.strides[1] = 1;
for (int64_t i = 0, e = sz1 * sz2; i < e; ++i)
descriptor.ptr[i] = 1.0f;
return descriptor;
}

// Print the contents of the memref given its descriptor.
static void print2DMemref(const MemRefDescriptor2D &descriptor) {
for (int64_t i = 0; i < descriptor.sz1; ++i) {
for (int64_t i = 0; i < descriptor.sizes[0]; ++i) {
llvm::outs() << '[';
for (int64_t j = 0; j < descriptor.sz2; ++j) {
for (int64_t j = 0; j < descriptor.sizes[1]; ++j) {
if (j != 0)
llvm::outs() << ", ";
llvm::outs() << descriptor.ptr[i * descriptor.sz2 + j];
llvm::outs() << descriptor.ptr[i * descriptor.strides[0] +
j * descriptor.strides[1]];
}
llvm::outs() << "]\n";
}
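A short usage sketch for the helpers above (not part of the change; it assumes the usual includes from the surrounding file and that the malloc'ed buffer is released with `free`):
```
// Hypothetical driver: allocate a 4x3 memref descriptor with row-major
// strides {3, 1}, print its contents, then free the underlying buffer.
int main() {
  MemRefDescriptor2D d = allocateInit2DMemref(4, 3);
  print2DMemref(d); // prints 4 rows of 3 elements, all initialized to 1.0f
  free(d.ptr);
  return 0;
}
```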
45 changes: 29 additions & 16 deletions mlir/g3doc/ConversionToLLVMDialect.md
@@ -51,28 +51,41 @@ For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
### Memref Types

Memref types in MLIR have both static and dynamic information associated with
them. The dynamic information comprises the buffer pointer as well as sizes of
any dynamically sized dimensions. Memref types are normalized and converted to a
descriptor that is only dependent on the rank of the memref. The descriptor
contains the pointer to the data buffer followed by an array containing as many
64-bit integers as the rank of the memref. The array represents the size, in
number of elements, of the memref along the given dimension. For constant memref
dimensions, the corresponding size entry is a constant whose runtime value
matches the static value. This normalization serves as an ABI for the memref
type to interoperate with externally linked functions. In the particular case of
rank `0` memrefs, the size array is omitted, resulting in a wrapped pointer.
them. The dynamic information comprises the buffer pointer as well as sizes and
strides of any dynamically sized dimensions. Memref types are normalized and
converted to a descriptor that is only dependent on the rank of the memref. The
descriptor contains:

1. the pointer to the data buffer, followed by
2. a lowered `index`-type integer containing the distance between the beginning
of the buffer and the first element to be accessed through the memref,
followed by
3. an array containing as many `index`-type integers as the rank of the memref:
the array represents the size, in number of elements, of the memref along
the given dimension. For constant MemRef dimensions, the corresponding size
entry is a constant whose runtime value must match the static value,
followed by
4. a second array containing as many `index`-type integers as the rank of the
MemRef: the second array represents the "strides" (in the tensor abstraction
sense), i.e. the distance, in number of elements of the underlying buffer,
between two consecutive entries along each dimension.

This normalization serves as an ABI for the memref type to interoperate with
externally linked functions. In the particular case of rank `0` memrefs, the
size and stride arrays are omitted, resulting in a struct containing only a
pointer and an offset.

Examples:

```mlir {.mlir}
memref<f32> -> !llvm.type<"{ float* }">
memref<1 x f32> -> !llvm.type<"{ float*, [1 x i64] }">
memref<? x f32> -> !llvm.type<"{ float*, [1 x i64] }">
memref<10x42x42x43x123 x f32> -> !llvm.type<"{ float*, [5 x i64] }">
memref<10x?x42x?x123 x f32> -> !llvm.type<"{ float*, [5 x i64] }">
memref<f32> -> !llvm.type<"{ float*, i64 }">
memref<1 x f32> -> !llvm.type<"{ float*, i64, [1 x i64], [1 x i64] }">
memref<? x f32> -> !llvm.type<"{ float*, i64, [1 x i64], [1 x i64] }">
memref<10x42x42x43x123 x f32> -> !llvm.type<"{ float*, i64, [5 x i64], [5 x i64] }">
memref<10x?x42x?x123 x f32> -> !llvm.type<"{ float*, i64, [5 x i64], [5 x i64] }">
// Memref types can have vectors as element types
memref<1x? x vector<4xf32>> -> !llvm.type<"{ <4 x float>*, [1 x i64] }">
memref<1x? x vector<4xf32>> -> !llvm.type<"{ <4 x float>*, i64, [1 x i64], [1 x i64] }">
```
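For the default identity layout, the strides stored in the descriptor follow the usual row-major rule: the innermost stride is `1` and each outer stride is the product of all inner sizes. A small sketch of that computation (the helper is hypothetical and only illustrates the convention, it is not part of the conversion):
```
#include <cstdint>
#include <vector>

// Row-major strides for an identity-layout memref: strides[rank-1] = 1 and
// strides[i] = strides[i+1] * sizes[i+1]. For memref<10x?x42xf32> with the
// dynamic size bound to 7 at runtime, this returns {7 * 42, 42, 1}.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = running;
    running *= sizes[i];
  }
  return strides;
}
```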

### Function Types
18 changes: 9 additions & 9 deletions mlir/include/mlir/IR/StandardTypes.h
@@ -375,22 +375,22 @@ class MemRefType
/// where K and ki's are constants or symbols.
///
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with kDynamicStride). Strides encode the distance in
/// the number of elements between successive entries along a particular
/// dimension.
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `1`
/// and the distance between two consecutive elements along the inner
/// dimension is `64`.
/// or dynamic (encoded with kDynamicStrideOrOffset). Strides encode the
/// distance in the number of elements between successive entries along a
/// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
/// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
/// elements in which the distance between two consecutive elements along the
/// outer dimension is `64` and the distance between two consecutive elements
/// along the inner dimension is `1`.
///
/// If a simple strided form cannot be extracted from the composition of the
/// layout map, returns llvm::None.
///
/// The convention is that the strides for dimensions d0, .., dn appear in
/// order, followed by the constant offset, to make indexing into the result
/// intuitive.
static constexpr int64_t kDynamicStride = std::numeric_limits<int64_t>::min();
static constexpr int64_t kDynamicStrideOrOffset =
std::numeric_limits<int64_t>::min();
LogicalResult getStridesAndOffset(SmallVectorImpl<int64_t> &strides) const;

static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
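A usage sketch of the convention documented above, mirroring how the lowering below consumes the result (assumes the usual MLIR headers and namespaces; `memrefType` is any strided MemRefType):
```
// The result vector holds the strides for d0..dn in order, followed by the
// offset at the back; dynamic entries are encoded as kDynamicStrideOrOffset.
static void inspectStrides(MemRefType memrefType) {
  SmallVector<int64_t, 4> stridesAndOffset;
  if (failed(memrefType.getStridesAndOffset(stridesAndOffset)))
    return; // not a strided memref
  int64_t offset = stridesAndOffset.back();
  ArrayRef<int64_t> strides = ArrayRef<int64_t>(stridesAndOffset).drop_back();
  // For memref<42x16xf32, (d0, d1) -> (64 * d0 + d1)>:
  //   strides == {64, 1}, offset == 0.
  (void)offset;
  (void)strides;
}
```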
180 changes: 123 additions & 57 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -125,24 +125,36 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
// contains:
// 1. the pointer to the data buffer, followed by
// 2. an array containing as many 64-bit integers as the rank of the MemRef:
// the array represents the size, in number of elements, of the memref along
// the given dimension. For constant MemRef dimensions, the corresponding size
// entry is a constant whose runtime value must match the static value.
// 2. a lowered `index`-type integer containing the distance between the
// beginning of the buffer and the first element to be accessed through the
// view, followed by
// 3. an array containing as many `index`-type integers as the rank of the
// MemRef: the array represents the size, in number of elements, of the memref
// along the given dimension. For constant MemRef dimensions, the
// corresponding size entry is a constant whose runtime value must match the
// static value, followed by
// 4. a second array containing as many `index`-type integers as the rank of
// the MemRef: the second array represents the "strides" (in the tensor
// abstraction sense), i.e. the distance, in number of elements of the
// underlying buffer, between two consecutive entries along each dimension.
// TODO(ntv, zinenko): add assertions for the static cases.
//
// template <typename Elem, size_t Rank>
// struct {
// Elem *ptr;
// int64_t offset;
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
static unsigned kPtrPosInMemRefDescriptor = 0;
static unsigned kSizePosInMemRefDescriptor = 1;
static unsigned kOffsetPosInMemRefDescriptor = 1;
static unsigned kSizePosInMemRefDescriptor = 2;
static unsigned kStridePosInMemRefDescriptor = 3;
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
assert((type.getAffineMaps().empty() ||
(type.getAffineMaps().size() == 1 &&
type.getAffineMaps().back().isIdentity())) &&
"Non-identity layout maps must have been normalized away");
SmallVector<int64_t, 4> strides;
bool strideSuccess = succeeded(type.getStridesAndOffset(strides));
assert(strideSuccess &&
"Non-strided layout maps must have been normalized away");
(void)strideSuccess;
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
@@ -151,9 +163,9 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
auto rank = type.getRank();
if (rank > 0) {
auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank());
return LLVM::LLVMType::getStructTy(ptrTy, arrayTy);
return LLVM::LLVMType::getStructTy(ptrTy, indexTy, arrayTy, arrayTy);
}
return LLVM::LLVMType::getStructTy(ptrTy);
return LLVM::LLVMType::getStructTy(ptrTy, indexTy);
}

// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
Expand Down Expand Up @@ -538,7 +550,8 @@ struct ConstLLVMOpLowering
// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
static bool isSupportedMemRefType(MemRefType type) {
return llvm::all_of(type.getAffineMaps(),
return type.getAffineMaps().empty() ||
llvm::all_of(type.getAffineMaps(),
[](AffineMap map) { return map.isIdentity(); });
}

Expand All @@ -552,7 +565,21 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {

PatternMatchResult match(Operation *op) const override {
MemRefType type = cast<AllocOp>(op).getType();
return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
if (isSupportedMemRefType(type))
return matchSuccess();

SmallVector<int64_t, 4> stridesAndOffset;
auto successStrides = type.getStridesAndOffset(stridesAndOffset);
if (failed(successStrides))
return matchFailure();

// Dynamic strides are ok if they can be deduced from dynamic sizes (which
// is guaranteed when succeeded(successStrides)).
// Dynamic offset however can never be alloc'ed.
if (stridesAndOffset.back() == MemRefType::kDynamicStrideOrOffset)
return matchFailure();

return matchSuccess();
}

void rewrite(Operation *op, ArrayRef<Value *> operands,
@@ -621,6 +648,16 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
ArrayRef<Value *>(allocated));

SmallVector<int64_t, 4> stridesAndOffset;
auto successStrides = type.getStridesAndOffset(stridesAndOffset);
assert(succeeded(successStrides) && "unexpected non-strided memref");
(void)successStrides;

ArrayRef<int64_t> strides = ArrayRef<int64_t>(stridesAndOffset).drop_back();
// 0-D memref corner case: no strides and a single size entry.
assert(((type.getRank() == 0 && strides.empty() && sizes.size() == 1) ||
        strides.size() == sizes.size()) &&
       "unexpected number of strides");

// Create the MemRef descriptor.
auto structType = lowering.convertType(type);
Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
@@ -629,14 +666,43 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, allocated,
rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));

// Store dynamically allocated sizes in the descriptor. Static and dynamic
// sizes are all passed in as operands.
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor,
createIndexConstant(rewriter, op->getLoc(), stridesAndOffset.back()),
rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor));

if (type.getRank() == 0)
// No size/stride arrays in 0-D memref, use the descriptor value.
return rewriter.replaceOp(op, memRefDescriptor);

// Store all sizes in the descriptor.
Value *runningStride = nullptr;
// Iterate strides in reverse order, compute runningStride and strideValues.
auto nStrides = strides.size();
SmallVector<Value *, 4> strideValues(nStrides, nullptr);
for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) {
int64_t index = nStrides - 1 - indexedStride.index();
if (strides[index] == MemRefType::kDynamicStrideOrOffset)
// Identity layout map is enforced in the match function, so we compute:
// `runningStride *= sizes[index + 1]`
runningStride = runningStride
? rewriter.create<LLVM::MulOp>(
op->getLoc(), runningStride, sizes[index + 1])
: createIndexConstant(rewriter, op->getLoc(), 1);
else
runningStride =
createIndexConstant(rewriter, op->getLoc(), strides[index]);
strideValues[index] = runningStride;
}
// Fill size and stride descriptors in memref.
for (auto indexedSize : llvm::enumerate(sizes)) {
int64_t index = indexedSize.index();
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, strideValues[index],
rewriter.getI64ArrayAttr({kStridePosInMemRefDescriptor, index}));
}

// Return the final value of the descriptor.
@@ -874,55 +940,55 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
return linearized;
}

// Given the MemRef type, a descriptor and a list of indices, extract the data
// buffer pointer from the descriptor, convert multi-dimensional subscripts
// into a linearized index (using dynamic size data from the descriptor if
// necessary) and get the pointer to the buffer element identified by the
// indices.
Value *getElementPtr(Location loc, Type elementTypePtr,
ArrayRef<int64_t> shape, Value *memRefDescriptor,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter) const {
// Get the list of MemRef sizes. Static sizes are defined as constants.
// Dynamic sizes are extracted from the MemRef descriptor, where they start
// from the position 1 (the buffer is at position 0).
SmallVector<Value *, 4> sizes;
for (auto en : llvm::enumerate(shape)) {
int64_t s = en.value();
int64_t index = en.index();
if (s == -1) {
Value *size = rewriter.create<LLVM::ExtractValueOp>(
loc, this->getIndexType(), memRefDescriptor,
rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
sizes.push_back(size);
// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
Value *getStridedElementPtr(Location loc, Type elementTypePtr,
Value *memRefDescriptor,
ArrayRef<Value *> indices,
ArrayRef<int64_t> stridesAndOffset,
ConversionPatternRewriter &rewriter) const {
auto indexTy = this->getIndexType();
Value *base = rewriter.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, memRefDescriptor,
rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
Value *offset =
stridesAndOffset.back() == MemRefType::kDynamicStrideOrOffset
? rewriter.create<LLVM::ExtractValueOp>(
loc, indexTy, memRefDescriptor,
rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor))
: this->createIndexConstant(rewriter, loc, stridesAndOffset.back());
auto strides = stridesAndOffset.drop_back();
for (int i = 0, e = indices.size(); i < e; ++i) {
Value *stride;
if (strides[i] != MemRefType::kDynamicStrideOrOffset) {
// Use static stride.
auto attr =
rewriter.getIntegerAttr(rewriter.getIndexType(), strides[i]);
stride = rewriter.create<LLVM::ConstantOp>(loc, indexTy, attr);
} else {
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
// TODO(ntv, zinenko): assert dynamic descriptor size is constant.
// Use dynamic stride.
stride = rewriter.create<LLVM::ExtractValueOp>(
loc, indexTy, memRefDescriptor,
rewriter.getIndexArrayAttr({kStridePosInMemRefDescriptor, i}));
}
Value *additionalOffset =
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
offset = rewriter.create<LLVM::AddOp>(loc, offset, additionalOffset);
}

// The second and subsequent operands are access subscripts. Obtain the
// linearized address in the buffer.
Value *subscript = indices.empty()
? nullptr
: linearizeSubscripts(rewriter, loc, indices, sizes);

Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, memRefDescriptor,
rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
SmallVector<Value *, 2> gepSubValues(1, dataPtr);
if (subscript)
gepSubValues.push_back(subscript);
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, gepSubValues,
ArrayRef<NamedAttribute>{});
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offset);
}
Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,

Value *getDataPtr(Location loc, MemRefType type, Value *memRefDesc,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter,
llvm::Module &module) const {
auto ptrType = getMemRefElementPtrType(type, this->lowering);
auto shape = type.getShape();
return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
SmallVector<int64_t, 4> stridesAndOffset;
auto res = type.getStridesAndOffset(stridesAndOffset);
assert(succeeded(res) && "expected strided MemRef");
(void)res;
return getStridedElementPtr(loc, ptrType, memRefDesc, indices,
stridesAndOffset, rewriter);
}
};

