[mlir][linalg] Use affine apply in im2col gather index calculations
Differential Revision: https://reviews.llvm.org/D146816
qedawkins committed Mar 24, 2023
1 parent 4e209a9 commit f5150ee
Showing 2 changed files with 36 additions and 81 deletions.
56 changes: 19 additions & 37 deletions mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -41,47 +42,28 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
return builder.create<arith::MulFOp>(loc, x, y);
}

// Unrolls the given composite `index` into a set of subindices with maximum
// iteration ranges specified by `factors` according to the following
// assumptions:
// 1. The iteration range for `index` is [0, f1 * f2 * ... * fn] i.e. the
// product of the given list of factors
// 2. The iterators corresponding to the entries in `factors` are ordered from
// slowest to fastest varying
// Each subindex is then computed as:
// subindex[i] = floor( (index % (fi * ... * fn)) / (fi-1 * ... * fn) )
static SmallVector<Value, 3> unrollIndex(OpBuilder &b, Location loc,
Value index,
ArrayRef<int64_t> factors) {
// Delinearizes the given composite `index` by the basis specified in `factors`.
static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
ArrayRef<int64_t> factors) {
assert(factors.size() >= 1 && "empty factor list");
SmallVector<Value, 3> indices(factors.size());
int64_t runningProd = 1;
for (int i = factors.size() - 1, end = 0; i >= end; i--) {
Value unrolledIndex = index;
if (i > 0) {
Value modBase = b.create<arith::ConstantOp>(
loc, b.getIndexAttr(runningProd * factors[i]));
unrolledIndex = b.create<arith::RemUIOp>(loc, unrolledIndex, modBase);
}
if (runningProd > 1) {
Value divDenom =
b.create<arith::ConstantOp>(loc, b.getIndexAttr(runningProd));
unrolledIndex = b.create<arith::DivUIOp>(loc, unrolledIndex, divDenom);
}
runningProd *= factors[i];
indices[i] = unrolledIndex;
}
return indices;
SmallVector<Value> basis;
for (int64_t f : factors)
basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
FailureOr<SmallVector<Value>> multiIndex =
delinearizeIndex(b, loc, index, basis);
assert(!failed(multiIndex) && "Failed to linearize img2col index");
return *multiIndex;
}
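For reference, the arithmetic that this delinearization materializes per element, written as a standalone plain-integer C++ sketch rather than MLIR builder calls (the `delinearize` helper below is illustrative only, not part of the commit or of the MLIR utility's API):

#include <cassert>
#include <cstdint>
#include <vector>

// Decompose a linear `index` into one sub-index per factor, with the factors
// ordered from slowest to fastest varying. Plain integers are used here for
// illustration; the utility above builds the equivalent IR and returns Values.
static std::vector<int64_t> delinearize(int64_t index,
                                        const std::vector<int64_t> &factors) {
  assert(!factors.empty() && "empty factor list");
  std::vector<int64_t> indices(factors.size());
  int64_t runningProd = 1;
  for (int64_t i = static_cast<int64_t>(factors.size()) - 1; i >= 0; --i) {
    indices[i] = (index / runningProd) % factors[i];
    runningProd *= factors[i];
  }
  return indices;
}

int main() {
  // With factors {fh, fw, ic} = {3, 3, 4}, as in the NHWC test below,
  // index 17 maps to (17 / 12, (17 % 12) / 4, 17 % 4) = (1, 1, 1).
  std::vector<int64_t> sub = delinearize(17, {3, 3, 4});
  assert(sub[0] == 1 && sub[1] == 1 && sub[2] == 1);
  return 0;
}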

// Given indices corresponding to iterators in the output (oIndex) and filter
// (fIndex) for a convolution, compute the convolved index for the
// input as `oIndex * stride + fIndex`.
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
Value fIndex, int64_t stride) {
Value strideVal = b.create<arith::ConstantOp>(loc, b.getIndexAttr(stride));
Value convIndex = b.create<arith::MulIOp>(loc, oIndex, strideVal);
return b.create<arith::AddIOp>(loc, convIndex, fIndex);
AffineExpr oExpr, fExpr;
bindSymbols(b.getContext(), oExpr, fExpr);
AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
return makeComposedAffineApply(b, loc, convMap, ValueRange{oIndex, fIndex});
}

FailureOr<std::pair<Operation *, Operation *>>
@@ -159,12 +141,12 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

// Recover the original iteration indices from the problem/input sizes.
SmallVector<Value, 3> mIndices = unrollIndex(
SmallVector<Value> mIndices = unrollIndex(
nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
auto ohIndex = mIndices[0];
auto owIndex = mIndices[1];

SmallVector<Value, 3> kIndices = unrollIndex(
SmallVector<Value> kIndices = unrollIndex(
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
auto fhIndex = kIndices[0];
auto fwIndex = kIndices[1];
@@ -443,13 +425,13 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

// Recover the original iteration indices from the problem/input sizes.
SmallVector<Value, 3> kIndices = unrollIndex(
SmallVector<Value> kIndices = unrollIndex(
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
auto icIndex = kIndices[0];
auto fhIndex = kIndices[1];
auto fwIndex = kIndices[2];

SmallVector<Value, 3> nIndices = unrollIndex(
SmallVector<Value> nIndices = unrollIndex(
nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
auto ohIndex = nIndices[0];
auto owIndex = nIndices[1];
61 changes: 17 additions & 44 deletions mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -37,29 +37,12 @@ transform.sequence failures(propagate) {
// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index

// Unrolled output shape indices.
// CHECK: %[[C14:.+]] = arith.constant 14 : index
// CHECK: %[[OWINDEX:.+]] = arith.remui %[[MINDEX]], %[[C14]] : index
// CHECK: %[[C14_1:.+]] = arith.constant 14 : index
// CHECK: %[[OHINDEX:.+]] = arith.divui %[[MINDEX]], %[[C14_1]] : index
// Compute input channel/convolved indices.
// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])

// Unrolled filter shape indices.
// CHECK: %[[C4:.+]] = arith.constant 4 : index
// CHECK: %[[ICINDEX:.+]] = arith.remui %[[KINDEX]], %[[C4]] : index
// CHECK: %[[C12:.+]] = arith.constant 12 : index
// CHECK: %[[FWREM:.+]] = arith.remui %[[KINDEX]], %[[C12]] : index
// CHECK: %[[C4_2:.+]] = arith.constant 4 : index
// CHECK: %[[FWINDEX:.+]] = arith.divui %[[FWREM]], %[[C4_2]] : index
// CHECK: %[[C12_3:.+]] = arith.constant 12 : index
// CHECK: %[[FHINDEX:.+]] = arith.divui %[[KINDEX]], %[[C12_3]] : index

// Compute input indices.
// CHECK: %[[SH:.+]] = arith.constant 1 : index
// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
// CHECK: %[[SW:.+]] = arith.constant 1 : index
// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
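The three affine.apply maps in the new CHECK lines are the previous unroll / multiply-by-stride / add sequence folded into single expressions by makeComposedAffineApply. A standalone check of that equivalence in plain C++ (illustration only, not part of the commit, using the concrete sizes of this test: 14x14 output, 3x3x4 filter, unit stride):

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t m = 0; m < 14 * 14; ++m) {     // linalg.index 1: linearized oh * ow
    for (int64_t k = 0; k < 3 * 3 * 4; ++k) { // linalg.index 2: linearized fh * fw * ic
      // Removed form: unrolled indices, then stride-multiply (stride 1) and add.
      int64_t ow = m % 14, oh = m / 14;
      int64_t ic = k % 4, fw = (k % 12) / 4, fh = k / 12;
      int64_t convH = oh * 1 + fh;
      int64_t convW = ow * 1 + fw;
      // New form: the composed affine maps from the CHECK lines above.
      assert(ic == k % 4);
      assert(convH == m / 14 + k / 12);
      assert(convW == m % 14 + (k % 12) / 4);
    }
  }
  return 0;
}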
@@ -234,6 +217,13 @@ transform.sequence failures(propagate) {
// -----

// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// Im2col maps
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 9)>
// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1) -> (d0 floordiv 14 + (d1 mod 9) floordiv 3)>
// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 - (d0 floordiv 14) * 14 - (d1 floordiv 3) * 3)>


// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -252,29 +242,12 @@ transform.sequence failures(propagate) {
// CHECK: %[[KINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[NINDEX:.+]] = linalg.index 2 : index

// Unrolled filter shape indices.
// CHECK: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[FWINDEX:.+]] = arith.remui %[[KINDEX]], %[[C3]] : index
// CHECK: %[[C9:.+]] = arith.constant 9 : index
// CHECK: %[[FHREM:.+]] = arith.remui %[[KINDEX]], %[[C9]] : index
// CHECK: %[[C3_1:.+]] = arith.constant 3 : index
// CHECK: %[[FHINDEX:.+]] = arith.divui %[[FHREM]], %[[C3_1]] : index
// CHECK: %[[C9_2:.+]] = arith.constant 9 : index
// CHECK: %[[ICINDEX:.+]] = arith.divui %[[KINDEX]], %[[C9_2]] : index

// Unrolled output shape indices.
// CHECK: %[[C14:.+]] = arith.constant 14 : index
// CHECK: %[[OWINDEX:.+]] = arith.remui %[[NINDEX]], %[[C14]] : index
// CHECK: %[[C14_3:.+]] = arith.constant 14 : index
// CHECK: %[[OHINDEX:.+]] = arith.divui %[[NINDEX]], %[[C14_3]] : index
// Compute input channel/convolved indices.
// CHECK: %[[ICINDEX:.+]] = affine.apply #[[MAP1]](%[[KINDEX]])
// CHECK: %[[CONVH:.+]] = affine.apply #[[MAP7]](%[[NINDEX]], %[[KINDEX]])
// CHECK: %[[CONVW:.+]] = affine.apply #[[MAP8]](%[[NINDEX]], %[[KINDEX]])

// Compute input indices.
// CHECK: %[[SH:.+]] = arith.constant 1 : index
// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
// CHECK: %[[SW:.+]] = arith.constant 1 : index
// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
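MAP8 above looks unlike the hand-written forms because affine canonicalization expresses mod via floordiv: d0 + d1 - (d0 floordiv 14) * 14 - (d1 floordiv 3) * 3 is exactly d0 mod 14 + d1 mod 3, i.e. ow + fw for this unit-stride NCHW case (MAP1 and MAP7 likewise recover ic and oh + fh). A small plain-C++ check of that identity over the index ranges of this test (illustration only, not part of the commit):

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t d0 = 0; d0 < 14 * 14; ++d0) {     // linearized oh * ow position
    for (int64_t d1 = 0; d1 < 4 * 3 * 3; ++d1) { // linearized ic * fh * fw position
      int64_t canonical = d0 + d1 - (d0 / 14) * 14 - (d1 / 3) * 3; // MAP8
      int64_t modForm = d0 % 14 + d1 % 3;                          // ow + fw
      assert(canonical == modForm);
    }
  }
  return 0;
}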
