[mlir][vector] Allow unroll of contraction in arbitrary order
Adds support for vector unroll transformations to unroll in different
orders. For example, a `vector.contract` operation can be unrolled into a
smaller set of contractions. There is a choice of how to unroll the
decomposition based on the traversal order of (dim0, dim1, dim2). The
choice of traversal order can now be specified by a callback that is
provided by the caller of the transform. For now, only the
`vector.contract` and `vector.transfer_read`/`vector.transfer_write`
operations support the callback.

Differential Revision: https://reviews.llvm.org/D127004
christopherbate committed Jun 6, 2022
1 parent b79b2b6 commit 1469ebf
Showing 5 changed files with 348 additions and 64 deletions.
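
As a quick usage sketch (editorial illustration, not part of this commit: it
leans on the existing `populateVectorUnrollPatterns` and `setNativeShape`
APIs, assumes a `RewritePatternSet patterns` in scope, and the reverse-order
policy is invented for the example), a caller can plug in a traversal-order
callback like this:

  vector::UnrollVectorOptions options;
  options.setNativeShape({2, 2, 2});
  options.setUnrollTraversalOrderFn(
      [](Operation *op) -> Optional<SmallVector<int64_t>> {
        auto contract = dyn_cast<vector::ContractionOp>(op);
        if (!contract)
          return llvm::None; // Use the default outer-to-inner order.
        // Illustrative policy: traverse the iteration dims in reverse.
        int64_t numLoops = contract.getIteratorTypes().size();
        SmallVector<int64_t> order;
        for (int64_t d = numLoops - 1; d >= 0; --d)
          order.push_back(d);
        return order;
      });
  vector::populateVectorUnrollPatterns(patterns, options);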
@@ -128,6 +128,19 @@ struct UnrollVectorOptions {
    };
    return *this;
  }

  /// Function that returns the traversal order (in terms of "for loop order",
  /// i.e. slowest varying dimension to fastest varying dimension) that should
  /// be used when unrolling the given operation into units of the native
  /// vector size.
  using UnrollTraversalOrderFnType =
      std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
  UnrollVectorOptions &
  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
    traversalOrderCallback = std::move(traversalOrderFn);
    return *this;
  }
};

//===----------------------------------------------------------------------===//
131 changes: 107 additions & 24 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -15,8 +15,11 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>

#define DEBUG_TYPE "vector-unrolling"

@@ -36,20 +39,78 @@ static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
  return elementOffsets;
}

/// A functor that accomplishes the same thing as `getVectorOffset` but allows
/// for reordering the traversal of the dimensions. The order of traversal is
/// given in "for loop order" (outer to inner).
namespace {
class DecomposeShapeIterator {
private:
  SmallVector<int64_t, 4> vectorShape;
  SmallVector<int64_t> loopOrder;
  SmallVector<int64_t> sliceStrides;
  int64_t maxIndexVal{1};

public:
  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> targetShape,
                         ArrayRef<int64_t> loopOrder)
      : vectorShape(targetShape.begin(), targetShape.end()),
        loopOrder(loopOrder.begin(), loopOrder.end()),
        sliceStrides(originalShape.size()) {
    // Compute the count for each dimension.
    SmallVector<int64_t> sliceDimCounts(originalShape.size());
    for (unsigned r = 0; r < originalShape.size(); ++r) {
      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
      maxIndexVal *= sliceDimCounts[r];
    }

    // Reversing "loop order" gives dimensions from fastest varying to slowest
    // varying (smallest stride to largest stride).
    int64_t accum = 1;
    for (auto idx : llvm::reverse(loopOrder)) {
      sliceStrides[idx] = accum;
      accum *= sliceDimCounts[idx];
    }
  }

  // Turn the linear index into a d-tuple based on units of vectors of size
  // `vectorShape`. The linear index is assumed to represent traversal of the
  // dimensions based on `loopOrder`.
  SmallVector<int64_t> delinearize(int64_t index) const {
    // Traverse in for loop order (largest stride to smallest stride).
    SmallVector<int64_t, 4> vectorOffsets(sliceStrides.size());
    for (auto idx : loopOrder) {
      vectorOffsets[idx] = index / sliceStrides[idx];
      index %= sliceStrides[idx];
    }
    return vectorOffsets;
  }

  int64_t maxIndex() const { return maxIndexVal; }

  /// Return the offset within the d-tuple based on the ordering given by
  /// `loopOrder`.
  SmallVector<int64_t> getVectorOffset(int64_t index) const {
    SmallVector<int64_t> vectorOffsets = delinearize(index);
    SmallVector<int64_t> elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
    return elementOffsets;
  }
};
} // namespace
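
// Worked example (editorial note, not in the original patch): for
// originalShape = [4, 4], targetShape = [2, 2], and loopOrder = [1, 0]
// (dim1 outermost), the constructor computes sliceDimCounts = [2, 2],
// maxIndexVal = 4, and, walking reverse(loopOrder) = [0, 1],
// sliceStrides = [1, 2]. Delinearizing then yields:
//
//   index 0 -> vectorOffsets [0, 0] -> elementOffsets [0, 0]
//   index 1 -> vectorOffsets [1, 0] -> elementOffsets [2, 0]
//   index 2 -> vectorOffsets [0, 1] -> elementOffsets [0, 2]
//   index 3 -> vectorOffsets [1, 1] -> elementOffsets [2, 2]
//
// i.e. dim0 varies fastest, which is the slice order the ORDER prefix
// checks in the tests below.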

/// Compute the indices of the slice `index` for a transfer op.
static SmallVector<Value>
sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
                     AffineMap permutationMap, Location loc,
                     OpBuilder &builder) {
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
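  // Worked example (editorial note, not in the original patch): with
  // permutationMap = (d0, d1, d2) -> (d2, d0), base indices [%i, %j, %k],
  // and elementOffsets = [4, 2], result 0 (d2) adds 4 to the d2 index and
  // result 1 (d0) adds 2 to the d0 index, yielding [%i + 2, %j, %k + 4];
  // a broadcast result (constant-0 expression) leaves its index unchanged.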
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
      return constExpr.getValue() == 0;
    return false;
  };
  SmallVector<int64_t, 4> elementOffsets =
      getVectorOffset(originalShape, targetShape, index);
  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
@@ -99,6 +160,20 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  return targetShape;
}

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
    if (order.hasValue()) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

struct UnrollTransferReadPattern
@@ -122,26 +197,30 @@ struct UnrollTransferReadPattern
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute shape ratio of 'shape' and 'sizes'.
    int64_t sliceCount = computeMaxLinearIndex(ratio);

    // Prepare the result vector;
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                          readOp.getIndices().end());
    for (int64_t i = 0; i < sliceCount; i++) {

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(ratio.size(), readOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
@@ -174,20 +253,21 @@ struct UnrollTransferWritePattern
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute shape ratio of 'shape' and 'sizes'.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                          writeOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalIndices.size(), writeOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    Value resultTensor;
    for (int64_t i = 0; i < sliceCount; i++) {
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
          indexToOffsets.getVectorOffset(i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);

      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
@@ -238,18 +318,21 @@ struct UnrollContractionPattern
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);

    // Compute shape ratio of 'shape' and 'sizes'.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIndexingMaps().size(), contractOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    const int64_t sliceCount = indexToOffsets.maxIndex();
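    // Editorial note: the traversal order matters here because `accCache`
    // carries partial accumulator slices across iterations. If the caller
    // places the reduction dimension innermost, consecutive slices map to
    // the same destination offsets, so each accumulator entry is produced
    // and consumed back-to-back.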
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
46 changes: 46 additions & 0 deletions mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns=reverse-unroll-order --split-input-file | FileCheck %s --check-prefix=ORDER

// CHECK-LABEL: func @transfer_read_unroll
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -13,6 +14,19 @@
// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32>

// ORDER-LABEL: func @transfer_read_unroll
// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index
// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index
// ORDER: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// ORDER-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// ORDER-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// ORDER-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// ORDER-NEXT: return %[[VEC3]] : vector<4x4xf32>

func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
@@ -33,6 +47,19 @@ func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
// CHECK-NEXT: vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: return

// ORDER-LABEL: func @transfer_write_unroll
// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index
// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index
// ORDER: %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// ORDER-NEXT: vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// ORDER-NEXT: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// ORDER-NEXT: vector.transfer_write %[[S1]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// ORDER-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// ORDER-NEXT: vector.transfer_write %[[S2]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// ORDER-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// ORDER-NEXT: vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// ORDER-NEXT: return

func.func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@@ -222,6 +249,25 @@ func.func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) ->
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
// CHECK-NEXT: return %[[VEC5]] : vector<6x4xf32>

// ORDER-LABEL: func @transfer_read_unroll_different_rank
// ORDER-DAG: %[[C4:.*]] = arith.constant 4 : index
// ORDER-DAG: %[[C2:.*]] = arith.constant 2 : index
// ORDER-DAG: %[[C0:.*]] = arith.constant 0 : index
// ORDER: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
// ORDER-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
// ORDER-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
// ORDER-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
// ORDER-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
// ORDER-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
// ORDER-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
// ORDER-NEXT: return %[[VEC5]] : vector<6x4xf32>

#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
func.func @transfer_read_unroll_different_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
  %c0 = arith.constant 0 : index
