Skip to content

Commit

Permalink
[mlir][Vector] Add lowering of 1-D vector transfer_read/write to masked load/store
Browse files Browse the repository at this point in the history

Summary:
This revision adds support to lower 1-D vector transfers to LLVM.
A mask of the vector length is created that compares the base offset + linear index to the dimension of the memref.
In each position where this does not overflow (i.e. offset + vector index < memref dim), the mask is set to 1.

A notable fact is that the lowering uses llvm.dialect_cast so the pattern can be written in the simplest form, targeting a mix of the vector and LLVM dialects and letting other conversions kick in afterwards.

Differential Revision: https://reviews.llvm.org/D77703
  • Loading branch information
Nicolas Vasilache committed Apr 9, 2020
1 parent 413467f commit 8345b86
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 70 deletions.
Expand Up @@ -398,6 +398,29 @@ class ConvertToLLVMPattern : public ConversionPattern {
// Creates a constant of the LLVM index type holding `value`.
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                          uint64_t value) const;

// Given subscript indices and array sizes in row-major order,
//   i_n, i_{n-1}, ..., i_1
//   s_n, s_{n-1}, ..., s_1
// obtain a value that corresponds to the linearized subscript
//   \sum_k i_k * \prod_{j=1}^{k-1} s_j
// by accumulating the running linearized value.
// Note that `indices` and `allocSizes` are passed in the same order as they
// appear in load/store operations and memref type declarations.
Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                          ArrayRef<Value> indices,
                          ArrayRef<Value> allocSizes) const;

// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
Value getStridedElementPtr(Location loc, Type elementTypePtr,
                           Value descriptor, ArrayRef<Value> indices,
                           ArrayRef<int64_t> strides, int64_t offset,
                           ConversionPatternRewriter &rewriter) const;

// Computes a pointer to the element at `indices` in the strided memref
// `memRefDesc` of type `type`, extracting strides and offset from the type
// and delegating to getStridedElementPtr.
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                 ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
                 llvm::Module &module) const;

protected:
/// Reference to the type converter, with potential extensions.
LLVMTypeConverter &typeConverter;
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Expand Up @@ -73,6 +73,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,

/// Vector type utilities.
/// Returns the element type of this vector type.
LLVMType getVectorElementType();
/// Returns the number of elements in this vector type.
unsigned getVectorNumElements();
/// Returns true if the underlying LLVM type is a vector type.
bool isVectorTy();

/// Function type utilities.
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Builders.h
Expand Up @@ -111,6 +111,7 @@ class Builder {
IntegerAttr getI16IntegerAttr(int16_t value);
IntegerAttr getI32IntegerAttr(int32_t value);
IntegerAttr getI64IntegerAttr(int64_t value);
/// Creates an integer attribute whose type is the builder's index type.
IntegerAttr getIndexAttr(int64_t value);

/// Signed and unsigned integer attribute getters.
IntegerAttr getSI32IntegerAttr(int32_t value);
Expand Down
119 changes: 55 additions & 64 deletions mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Expand Up @@ -735,6 +735,61 @@ Value ConvertToLLVMPattern::createIndexConstant(
return createIndexAttrConstant(builder, loc, getIndexType(), value);
}

// Folds `indices` against `allocSizes` (both in row-major order, as they
// appear in load/store ops and memref type declarations) into one linearized
// subscript value:
//   \sum_k i_k * \prod_{j=1}^{k-1} s_j
Value ConvertToLLVMPattern::linearizeSubscripts(
    ConversionPatternRewriter &builder, Location loc, ArrayRef<Value> indices,
    ArrayRef<Value> allocSizes) const {
  assert(indices.size() == allocSizes.size() &&
         "mismatching number of indices and allocation sizes");
  assert(!indices.empty() && "cannot linearize a 0-dimensional access");

  // Horner-style accumulation: acc = acc * s_dim + i_dim.
  auto indexTy = this->getIndexType();
  Value acc = indices.front();
  for (size_t dim = 1, rank = allocSizes.size(); dim < rank; ++dim) {
    acc = builder.create<LLVM::MulOp>(
        loc, indexTy, ArrayRef<Value>{acc, allocSizes[dim]});
    acc = builder.create<LLVM::AddOp>(
        loc, indexTy, ArrayRef<Value>{acc, indices[dim]});
  }
  return acc;
}

// Strided getElementPtr variant: produces a GEP at
//   base + offset + index_0 * stride_0 + ... + index_n * stride_n.
// Dynamic strides/offset are loaded from the memref descriptor; static ones
// are materialized as index constants.
Value ConvertToLLVMPattern::getStridedElementPtr(
    Location loc, Type elementTypePtr, Value descriptor,
    ArrayRef<Value> indices, ArrayRef<int64_t> strides, int64_t offset,
    ConversionPatternRewriter &rewriter) const {
  MemRefDescriptor memRefDescriptor(descriptor);
  Value base = memRefDescriptor.alignedPtr(rewriter, loc);

  // Seed the linearized offset with the (static or dynamic) base offset.
  Value linearized =
      offset == MemRefType::getDynamicStrideOrOffset()
          ? memRefDescriptor.offset(rewriter, loc)
          : this->createIndexConstant(rewriter, loc, offset);

  for (size_t dim = 0, rank = indices.size(); dim < rank; ++dim) {
    // Materialize this dimension's stride, then add index * stride.
    Value stride =
        strides[dim] == MemRefType::getDynamicStrideOrOffset()
            ? memRefDescriptor.stride(rewriter, loc, dim)
            : this->createIndexConstant(rewriter, loc, strides[dim]);
    Value scaled = rewriter.create<LLVM::MulOp>(loc, indices[dim], stride);
    linearized = rewriter.create<LLVM::AddOp>(loc, linearized, scaled);
  }
  return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, linearized);
}

// Returns a pointer to the element at `indices` within the strided memref
// `memRefDesc` of type `type`: extracts the static strides/offset from the
// memref type and delegates to getStridedElementPtr.
Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
                                       Value memRefDesc,
                                       ArrayRef<Value> indices,
                                       ConversionPatternRewriter &rewriter,
                                       llvm::Module &module) const {
  SmallVector<int64_t, 4> strides;
  int64_t offset;
  auto res = getStridesAndOffset(type, strides, offset);
  assert(succeeded(res) && "unexpected non-strided memref");
  (void)res; // Silence unused-variable warning in release builds.
  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
  return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
                              offset, rewriter);
}

/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
Expand Down Expand Up @@ -1913,70 +1968,6 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
MemRefType type = cast<Derived>(op).getMemRefType();
return isSupportedMemRefType(type) ? success() : failure();
}

// Given subscript indices and array sizes in row-major order,
//   i_n, i_{n-1}, ..., i_1
//   s_n, s_{n-1}, ..., s_1
// obtain a value that corresponds to the linearized subscript
//   \sum_k i_k * \prod_{j=1}^{k-1} s_j
// by accumulating the running linearized value.
// Note that `indices` and `allocSizes` are passed in the same order as they
// appear in load/store operations and memref type declarations.
Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                          ArrayRef<Value> indices,
                          ArrayRef<Value> allocSizes) const {
  assert(indices.size() == allocSizes.size() &&
         "mismatching number of indices and allocation sizes");
  assert(!indices.empty() && "cannot linearize a 0-dimensional access");

  Value linearized = indices.front();
  // Horner-style accumulation: linearized = linearized * s_i + i_i.
  for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
    linearized = builder.create<LLVM::MulOp>(
        loc, this->getIndexType(),
        ArrayRef<Value>{linearized, allocSizes[i]});
    linearized = builder.create<LLVM::AddOp>(
        loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
  }
  return linearized;
}

// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
// Dynamic strides/offset (MemRefType::getDynamicStrideOrOffset()) are read
// from the memref descriptor; static ones become index constants.
Value getStridedElementPtr(Location loc, Type elementTypePtr,
                           Value descriptor, ArrayRef<Value> indices,
                           ArrayRef<int64_t> strides, int64_t offset,
                           ConversionPatternRewriter &rewriter) const {
  MemRefDescriptor memRefDescriptor(descriptor);

  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
  // Running linearized offset, seeded with the base offset.
  Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
                          ? memRefDescriptor.offset(rewriter, loc)
                          : this->createIndexConstant(rewriter, loc, offset);

  for (int i = 0, e = indices.size(); i < e; ++i) {
    Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
                       ? memRefDescriptor.stride(rewriter, loc, i)
                       : this->createIndexConstant(rewriter, loc, strides[i]);
    // offsetValue += indices[i] * strides[i].
    Value additionalOffset =
        rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
    offsetValue =
        rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
  }
  return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
}

// Computes a pointer to the element at `indices` in the strided memref
// `memRefDesc` of type `type`: extracts strides/offset from the type and
// delegates to getStridedElementPtr.
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                 ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
                 llvm::Module &module) const {
  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  assert(succeeded(successStrides) && "unexpected non-strided memref");
  (void)successStrides; // Silence unused-variable warning in release builds.
  return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
                              offset, rewriter);
}
};

// Load operation is lowered to obtaining a pointer to the indexed element
Expand Down
145 changes: 139 additions & 6 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
Expand Down Expand Up @@ -894,6 +895,129 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
}
};

// Replaces a 1-D vector transfer op with its masked LLVM counterpart, given
// the already-computed vector-typed data pointer and comparison mask.
// Specialized below for TransferReadOp (masked load with a padding-value
// pass-through) and TransferWriteOp (masked store).
template <typename ConcreteOp>
void replaceTransferOp(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Operation *op, ArrayRef<Value> operands, Value dataPtr,
                       Value mask);

// Rewrites a vector.transfer_read as an llvm.masked_load whose pass-through
// operand is a splat of the transfer's padding value.
template <>
void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
                                       LLVMTypeConverter &typeConverter,
                                       Location loc, Operation *op,
                                       ArrayRef<Value> operands, Value dataPtr,
                                       Value mask) {
  auto xferOp = cast<TransferReadOp>(op);
  VectorType vectorType = xferOp.getVectorType();
  Type llvmVectorType = typeConverter.convertType(vectorType);

  // Build the pass-through vector: splat the padding value, then cast it
  // into the LLVM dialect.
  Value passThrough =
      rewriter.create<SplatOp>(loc, vectorType, xferOp.padding());
  passThrough =
      rewriter.create<LLVM::DialectCastOp>(loc, llvmVectorType, passThrough);

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      op, llvmVectorType.cast<LLVM::LLVMType>(), dataPtr, mask,
      ValueRange{passThrough}, rewriter.getI32IntegerAttr(1));
}

// Rewrites a vector.transfer_write as an llvm.masked_store of the already
// converted vector operand.
template <>
void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
                                        LLVMTypeConverter &typeConverter,
                                        Location loc, Operation *op,
                                        ArrayRef<Value> operands, Value dataPtr,
                                        Value mask) {
  TransferWriteOpOperandAdaptor adaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
}

// Overload set that yields the operand adaptor matching the concrete
// transfer-op type, so templated pattern code can stay type-agnostic.
static TransferReadOpOperandAdaptor
getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
  return TransferReadOpOperandAdaptor(operands);
}

static TransferWriteOpOperandAdaptor
getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
  return TransferWriteOpOperandAdaptor(operands);
}

/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
/// 1. Bitcast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);
    // Only 1-D transfers with an identity permutation map are handled;
    // anything else is left for other patterns.
    if (xferOp.getMemRefType().getRank() != 1)
      return failure();
    if (!xferOp.permutation_map().isIdentity())
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    auto vectorDataPtr =
        rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    // Cast into the LLVM dialect; remaining standard ops are lowered later
    // by other conversion patterns.
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    //    Only the first transfer index is used (rank is 1, checked above).
    Value offsetIndex = *(xferOp.indices().begin());
    offsetIndex = rewriter.create<IndexCastOp>(
        loc, vectorCmpType.getElementType(), offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim the memref dimension, compute the vector comparison mask:
    //      [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), 0);
    dim =
        rewriter.create<IndexCastOp>(loc, vectorCmpType.getElementType(), dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, operands,
                                  vectorDataPtr, mask);

    return success();
  }
};

class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorPrintOpConversion(MLIRContext *context,
Expand Down Expand Up @@ -1079,16 +1203,25 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
/// Populates `patterns` with the vector-to-LLVM conversion patterns: first
/// the rewrites that only need the MLIR context, then the conversions that
/// additionally need the LLVM type converter.
// NOTE: the diff-flattened page showed both the pre- and post-change pattern
// registrations; only the updated registration (which includes the two
// VectorTransferConversion instantiations) is kept, since registering the
// same patterns twice is redundant.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  patterns
      .insert<VectorBroadcastOpConversion,
              VectorReductionOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion>(ctx, converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Expand Up @@ -1774,6 +1774,9 @@ bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
/// Returns the element type of the underlying LLVM vector type.
LLVMType LLVMType::getVectorElementType() {
  return get(getContext(), getUnderlyingType()->getVectorElementType());
}
/// Returns the number of elements of the underlying LLVM vector type.
unsigned LLVMType::getVectorNumElements() {
  return getUnderlyingType()->getVectorNumElements();
}
/// Returns true if the underlying LLVM type is a vector type.
bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }

/// Function type utilities.
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/IR/Builders.cpp
Expand Up @@ -93,6 +93,10 @@ DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
return DictionaryAttr::get(value, context);
}

/// Builds an IntegerAttr of the builder's index type holding `value`
/// (stored as a 64-bit APInt).
IntegerAttr Builder::getIndexAttr(int64_t value) {
  return IntegerAttr::get(getIndexType(), APInt(64, value));
}

/// Builds an IntegerAttr of i64 type holding `value`.
IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
  return IntegerAttr::get(getIntegerType(64), APInt(64, value));
}
Expand Down

0 comments on commit 8345b86

Please sign in to comment.