Skip to content

Commit

Permalink
[mlir][VectorOps] Implement strided_slice conversion
Browse files Browse the repository at this point in the history
Summary:
This diff implements the progressive lowering of strided_slice to either:
  1. extractelement + insertelement for the 1-D case
  2. extract + optional strided_slice + insert for the n-D case.

This combines properly with the other conversion patterns to lower all the way to LLVM.

Appropriate tests are added.

Reviewers: ftynse, rriddle, AlexEichenberger, andydavis1, tetuante

Reviewed By: andydavis1

Subscribers: merge_guards_bot, mehdi_amini, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72310
  • Loading branch information
Nicolas Vasilache committed Jan 9, 2020
1 parent 24b326c commit 65678d9
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 3 deletions.
19 changes: 19 additions & 0 deletions mlir/include/mlir/IR/Attributes.h
Expand Up @@ -215,6 +215,25 @@ class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
/// Returns true when `kind` is the kind value of an array attribute
/// (`StandardAttributes::Array`).
static bool kindof(unsigned kind) {
  const bool isArrayKind = (StandardAttributes::Array == kind);
  return isArrayKind;
}

private:
/// Iterator adaptor over the elements of this array attribute that yields
/// each element converted to `AttrTy` via `Attribute::cast<AttrTy>()`.
///
/// It wraps the array's own `iterator` in an `llvm::mapped_iterator` whose
/// mapping function performs the cast.
template <typename AttrTy>
class attr_value_iterator final
    : public llvm::mapped_iterator<iterator, AttrTy (*)(Attribute)> {
public:
  /// Wraps the underlying element iterator `it`.
  explicit attr_value_iterator(iterator it)
      : llvm::mapped_iterator<iterator, AttrTy (*)(Attribute)>(
            it, [](Attribute attr) { return attr.cast<AttrTy>(); }) {}

  /// Dereferences the wrapped iterator and casts the element directly rather
  /// than going through the stored mapping function. Declared `const`
  /// because dereferencing does not modify the iterator; this lets const
  /// iterators (and code that only holds a const reference to one) be
  /// dereferenced as well.
  AttrTy operator*() const { return (*this->I).template cast<AttrTy>(); }
};

public:
/// Returns a range spanning `begin()` to `end()` of this array whose
/// iterators yield every element cast to `AttrTy`.
template <typename AttrTy>
llvm::iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
  attr_value_iterator<AttrTy> first(begin());
  attr_value_iterator<AttrTy> last(end());
  return llvm::make_range(first, last);
}
};

//===----------------------------------------------------------------------===//
Expand Down
101 changes: 98 additions & 3 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Expand Up @@ -6,10 +6,11 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand All @@ -31,6 +32,7 @@
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
Expand Down Expand Up @@ -723,15 +725,108 @@ class VectorPrintOpConversion : public LLVMOpLowering {
}
};

// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
/// Returns the sign-extended integer values held by `arrayAttr`, skipping the
/// first `dropFront` and the last `dropBack` elements. Every retained element
/// is cast to `IntegerAttr`. At least one element must remain after dropping.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  // Slice the underlying element list once, then walk only the kept part.
  auto kept = arrayAttr.getValue().drop_front(dropFront).drop_back(dropBack);
  SmallVector<int64_t, 4> values;
  values.reserve(arrayAttr.size() - dropFront - dropBack);
  for (Attribute element : kept)
    values.push_back(element.cast<IntegerAttr>().getValue().getSExtValue());
  return values;
}

/// Extracts position `offset` along the leading dimension of `vector`.
/// A vector of rank > 1 gets an `ExtractOp`; otherwise a constant index is
/// materialized and an `ExtractElementOp` producing the element type is
/// emitted.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto sourceType = vector.getType().cast<VectorType>();
  if (sourceType.getRank() <= 1) {
    // Lowest-rank case: address the scalar through an index constant.
    Value position = rewriter.create<ConstantIndexOp>(loc, offset);
    return rewriter.create<vector::ExtractElementOp>(
        loc, sourceType.getElementType(), vector, position);
  }
  return rewriter.create<ExtractOp>(loc, vector, offset);
}

/// Inserts `from` at position `offset` along the leading dimension of `into`
/// and returns the updated vector. A destination of rank > 1 gets an
/// `InsertOp`; otherwise a constant index is materialized and an
/// `InsertElementOp` is emitted.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto destType = into.getType().cast<VectorType>();
  if (destType.getRank() <= 1) {
    // Lowest-rank case: address the scalar slot through an index constant.
    Value position = rewriter.create<ConstantIndexOp>(loc, offset);
    return rewriter.create<vector::InsertElementOp>(loc, destType, from, into,
                                                    position);
  }
  return rewriter.create<InsertOp>(loc, from, into, offset);
}

/// Progressive lowering of StridedSliceOp to either:
///   1. extractelement + insertelement for the 1-D case
///   2. extract + nested slice lowering + insert for the n-D case.
///
/// The lowering is implemented by a private recursive helper that returns the
/// resulting `Value` directly. The pattern therefore never creates an
/// intermediate `StridedSliceOp`, never re-enters `matchAndRewrite`, and
/// never reads the result of an op after it has been handed to
/// `rewriter.replaceOp`.
class VectorStridedSliceOpRewritePattern
    : public OpRewritePattern<StridedSliceOp> {
public:
  using OpRewritePattern<StridedSliceOp>::OpRewritePattern;

  /// Replaces `op` with a vector built by extracting every selected position
  /// of the sliced dimensions from `op.vector()` and inserting it into a
  /// zero-splatted vector of the result type. Always succeeds.
  PatternMatchResult matchAndRewrite(StridedSliceOp op,
                                     PatternRewriter &rewriter) const override {
    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    auto dstType = op.getResult().getType().cast<VectorType>();
    // Materialize the attributes once; the helper peels one leading entry
    // per recursion level.
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.offsets());
    SmallVector<int64_t, 4> sizes = getI64SubArray(op.sizes());
    SmallVector<int64_t, 4> strides = getI64SubArray(op.strides());

    Value res = lowerSlice(rewriter, op.getLoc(), op.vector(), dstType,
                           offsets, sizes, strides);
    rewriter.replaceOp(op, {res});
    return matchSuccess();
  }

private:
  /// Emits the ops computing the strided slice of `source` described by
  /// `offsets` / `sizes` / `strides` (one entry per sliced leading dimension)
  /// and returns the value of type `dstType` holding the result.
  ///
  /// The recursion terminates because each level drops the leading entry of
  /// `offsets`, `sizes` and `strides` and the leading dimension of `dstType`.
  static Value lowerSlice(PatternRewriter &rewriter, Location loc,
                          Value source, VectorType dstType,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides) {
    assert(!offsets.empty() && !sizes.empty() && !strides.empty() &&
           "Unexpected empty slice description");

    auto elemType = dstType.getElementType();
    assert(elemType.isIntOrIndexOrFloat() &&
           "Expected int, index or float element type");

    // Start from an all-zero vector of the destination type and fill it in.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);

    const int64_t offset = offsets.front();
    const int64_t size = sizes.front();
    const int64_t stride = strides.front();

    // More than one sliced dimension: every extracted sub-vector must itself
    // be sliced along the remaining dimensions. Its result type is the
    // destination type without the leading dimension.
    const bool sliceInner = offsets.size() > 1;
    VectorType innerDstType =
        sliceInner ? VectorType::get(dstType.getShape().drop_front(), elemType)
                   : VectorType();

    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value extracted = extractOne(rewriter, loc, source, off);
      if (sliceInner)
        extracted = lowerSlice(rewriter, loc, extracted, innerDstType,
                               offsets.drop_front(), sizes.drop_front(),
                               strides.drop_front());
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    return res;
  }
};

/// Populate the given list with patterns that convert from Vector to LLVM.
///
/// The strided-slice rewrite pattern only needs the MLIR context; the
/// remaining conversion patterns are additionally constructed with
/// `converter`.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // Look the context up once and share it between both insertions.
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorStridedSliceOpRewritePattern>(ctx);
  patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
                  VectorExtractElementOpConversion, VectorExtractOpConversion,
                  VectorInsertElementOpConversion, VectorInsertOpConversion,
                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
                  VectorPrintOpConversion>(ctx, converter);
}

namespace {
Expand Down
61 changes: 61 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Expand Up @@ -423,3 +423,64 @@ func @vector_print_vector(%arg0: vector<2x2xf32>) {
// CHECK: llvm.call @print_close() : () -> ()
// CHECK: llvm.call @print_close() : () -> ()
// CHECK: llvm.call @print_newline() : () -> ()


// Test for the lowering of vector.strided_slice. Three slices are taken:
// a 1-D slice of %arg0, a leading-dimension-only slice of %arg1, and a slice
// of %arg1 along both dimensions. %arg2 is not referenced in the body.
func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
// CHECK-LABEL: llvm.func @strided_slice(

// Case 1: 1-D slice (elements 2 and 3 of a vector<4xf32>). Expected output is
// a zero constant splatted to the result type, then one
// extractelement/insertelement pair per selected element.
%0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">

// Case 2: only the leading dimension of a 2-D vector is sliced (rows 2 and
// 3). Whole vector<8xf32> rows are moved with extractvalue/insertvalue; no
// per-element operations are expected.
%1 = vector.strided_slice %arg1 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm<"[2 x <8 x float>]">
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"[4 x <8 x float>]">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"[2 x <8 x float>]">
// CHECK: llvm.extractvalue %{{.*}}[3] : !llvm<"[4 x <8 x float>]">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"[2 x <8 x float>]">

// Case 3: both dimensions are sliced. Each selected row is extracted, sliced
// as in the 1-D case (its own zero splat plus extractelement/insertelement
// pairs), and the resulting vector<2xf32> is inserted into the 2-D result.
%2 = vector.strided_slice %arg1 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm<"[2 x <2 x float>]">
//
// Subvector vector<8xf32> @2
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <8 x float>]">
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <2 x float>]">
//
// Subvector vector<8xf32> @3
// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <8 x float>]">
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <2 x float>]">

return
}


0 comments on commit 65678d9

Please sign in to comment.