diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index e46e1bd08dee9..0c7a832414074 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -1061,4 +1061,47 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> { let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// AffineDelinearizeIndexOp +//===----------------------------------------------------------------------===// + +def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", + [NoSideEffect]> { + let summary = "delinearize an index"; + let description = [{ + The `affine.delinearize_index` operation takes a single index value and + calculates the multi-index according to the given basis. + + Example: + + ``` + %indices:3 = affine.delinearize_index %linear_index into (%c16, %c224, %c224) : index, index, index + ``` + + In the above example, `%indices:3` conceptually holds the following: + + ``` + #map0 = affine_map<()[s0] -> (s0 floordiv 50176)> + #map1 = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)> + #map2 = affine_map<()[s0] -> (s0 mod 224)> + %indices_0 = affine.apply #map0()[%linear_index] + %indices_1 = affine.apply #map1()[%linear_index] + %indices_2 = affine.apply #map2()[%linear_index] + ``` + }]; + + let arguments = (ins Index:$linear_index, Variadic:$basis); + let results = (outs Variadic:$multi_index); + + let assemblyFormat = [{ + $linear_index `into` ` ` `(` $basis `)` attr-dict `:` type($multi_index) + }]; + + let builders = [ + OpBuilder<(ins "Value":$linear_index, "ArrayRef":$basis)> + ]; + + let hasVerifier = 1; +} + #endif // AFFINE_OPS diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h index 2e18a6fb7f3a1..bab315ecffde4 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -110,6 +110,14 @@ createSuperVectorizePass(ArrayRef virtualVectorSize); /// Overload relying on pass options for initialization. std::unique_ptr> createSuperVectorizePass(); +/// Populate patterns that expand affine index operations into more fundamental +/// operations (not necessarily restricted to Affine dialect). +void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns); + +/// Creates a pass to expand affine index operations into more fundamental +/// operations (not necessarily restricted to Affine dialect). +std::unique_ptr createAffineExpandIndexOpsPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td index d50c22569d56e..1f31bccf5f861 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -397,4 +397,9 @@ def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp" let constructor = "mlir::createSimplifyAffineStructuresPass()"; } +def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> { + let summary = "Lower affine operations operating on indices into more fundamental operations"; + let constructor = "mlir::createAffineExpandIndexOpsPass()"; +} + #endif // MLIR_DIALECT_AFFINE_PASSES diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h index 345f955e2061c..006c61ced2125 100644 --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -304,6 +304,21 @@ Optional> expandAffineMap(OpBuilder &builder, AffineMap affineMap, ValueRange operands); +/// Holds the result of (div a, b) and (mod a, b). +struct DivModValue { + Value quotient; + Value remainder; +}; + +/// Create IR to calculate (div lhs, rhs) and (mod lhs, rhs). +DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs); + +/// Generate the IR to delinearize `linearIndex` given the `basis` and return +/// the multi-index. +FailureOr> delinearizeIndex(OpBuilder &b, Location loc, + Value linearIndex, + ArrayRef basis); + } // namespace mlir #endif // MLIR_DIALECT_AFFINE_UTILS_H diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 766f41dc0771d..ae1a4a320a14d 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4036,6 +4036,34 @@ LogicalResult AffineVectorStoreOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// DelinearizeIndexOp +//===----------------------------------------------------------------------===// + +void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result, + Value linear_index, + ArrayRef basis) { + result.addTypes(SmallVector(basis.size(), builder.getIndexType())); + result.addOperands(linear_index); + SmallVector basisValues = + llvm::to_vector(llvm::map_range(basis, [&](OpFoldResult ofr) -> Value { + Optional staticDim = getConstantIntValue(ofr); + if (staticDim.has_value()) + return builder.create(result.location, + *staticDim); + return ofr.dyn_cast(); + })); + result.addOperands(basisValues); +} + +LogicalResult AffineDelinearizeIndexOp::verify() { + if (getBasis().empty()) + return emitOpError("basis should not be empty"); + if (getNumResults() != getBasis().size()) + return emitOpError("should return an index for each basis element"); + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt index 5616e80d79fb0..e98c935b3f36e 100644 --- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRAffineDialect LINK_LIBS PUBLIC MLIRArithmeticDialect + MLIRDialectUtils MLIRIR MLIRLoopLikeInterface MLIRMemRefDialect diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp new file mode 100644 index 0000000000000..c162aa2f2d058 --- /dev/null +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp @@ -0,0 +1,63 @@ +//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to expand affine index ops into one or more more +// fundamental operations. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Passes.h" + +#include "PassDetail.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +/// Lowers `affine.delinearize_index` into a sequence of division and remainder +/// operations. +struct LowerDelinearizeIndexOps + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op, + PatternRewriter &rewriter) const override { + FailureOr> multiIndex = + delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(), + llvm::to_vector(op.getBasis())); + if (failed(multiIndex)) + return failure(); + rewriter.replaceOp(op, *multiIndex); + return success(); + } +}; + +class ExpandAffineIndexOpsPass + : public AffineExpandIndexOpsBase { +public: + ExpandAffineIndexOpsPass() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + populateAffineExpandIndexOpsPatterns(patterns); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void mlir::populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} + +std::unique_ptr mlir::createAffineExpandIndexOpsPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt index 1a2b2dbb17b80..4601a11bf2894 100644 --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRAffineTransforms AffineDataCopyGeneration.cpp + AffineExpandIndexOps.cpp AffineLoopInvariantCodeMotion.cpp AffineLoopNormalize.cpp AffineParallelize.cpp diff --git a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt index 3be71bd357982..fb26df43b688e 100644 --- a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRAffineUtils MLIRAffineDialect MLIRAffineAnalysis MLIRAnalysis + MLIRArithmeticUtils MLIRMemRefDialect MLIRTransformUtils ) diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index ae949b62a5279..66a0e3640aba6 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AffineExprVisitor.h" @@ -1821,3 +1822,52 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b, return newMemRefType; } + +DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) { + DivModValue result; + AffineExpr d0, d1; + bindDims(b.getContext(), d0, d1); + result.quotient = + makeComposedAffineApply(b, loc, d0.floorDiv(d1), {lhs, rhs}); + result.remainder = makeComposedAffineApply(b, loc, d0 % d1, {lhs, rhs}); + return result; +} + +/// Create IR that computes the product of all elements in the set. +static FailureOr getIndexProduct(OpBuilder &b, Location loc, + ArrayRef set) { + if (set.empty()) + return failure(); + OpFoldResult result = set[0]; + AffineExpr s0, s1; + bindSymbols(b.getContext(), s0, s1); + for (unsigned i = 1, e = set.size(); i < e; i++) + result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]}); + return result; +} + +FailureOr> mlir::delinearizeIndex(OpBuilder &b, Location loc, + Value linearIndex, + ArrayRef dimSizes) { + unsigned numDims = dimSizes.size(); + + SmallVector divisors; + for (unsigned i = 1; i < numDims; i++) { + ArrayRef slice = dimSizes.drop_front(i); + FailureOr prod = getIndexProduct(b, loc, slice); + if (failed(prod)) + return failure(); + divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod)); + } + + SmallVector results; + results.reserve(divisors.size() + 1); + Value residual = linearIndex; + for (Value divisor : divisors) { + DivModValue divMod = getDivMod(b, loc, residual, divisor); + results.push_back(divMod.quotient); + residual = divMod.remainder; + } + results.push_back(residual); + return results; +} diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir new file mode 100644 index 0000000000000..70b7f397ad4fe --- /dev/null +++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-opt %s -affine-expand-index-ops -split-input-file | FileCheck %s + +// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 floordiv 50176)> +// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> (s0 mod 224)> + +// CHECK-LABEL: @static_basis +// CHECK-SAME: (%[[IDX:.+]]: index) +// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[IDX]]] +// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[IDX]]] +// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]] +// CHECK: return %[[N]], %[[P]], %[[Q]] +func.func @static_basis(%linear_index: index) -> (index, index, index) { + %b0 = arith.constant 16 : index + %b1 = arith.constant 224 : index + %b2 = arith.constant 224 : index + %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index + return %1#0, %1#1, %1#2 : index, index, index +} + +// ----- + +// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s2 floordiv (s0 * s1))> +// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) floordiv s1)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) mod s1)> + +// CHECK-LABEL: @dynamic_basis +// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEMREF:.+]]: memref +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] : +// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] : +// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]] +// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]] +// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]] +// CHECK: return %[[N]], %[[P]], %[[Q]] +func.func @dynamic_basis(%linear_index: index, %src: memref) -> (index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %b0 = memref.dim %src, %c0 : memref + %b1 = memref.dim %src, %c1 : memref + %b2 = memref.dim %src, %c2 : memref + %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index + return %1#0, %1#1, %1#2 : index, index, index +} diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir index 03f4f124ddefa..866aa4062f3a6 100644 --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -485,3 +485,19 @@ func.func @missing_for_min(%arg0: index, %arg1: index, %arg2: memref<100xf32>) { } return } + +// ----- + +func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) { + // expected-error@+1 {{'affine.delinearize_index' op should return an index for each basis element}} + %1 = affine.delinearize_index %idx into (%basis0, %basis1) : index + return +} + +// ----- + +func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) { + // expected-error@+1 {{'affine.delinearize_index' op basis should not be empty}} + affine.delinearize_index %idx into () : index + return +} diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir index ad6f3651c1b21..df10163d59822 100644 --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -260,3 +260,12 @@ func.func @affine_for_multiple_yield(%buffer: memref<1024xf32>) -> (f32, f32) { // CHECK-NEXT: %[[res2:.*]] = arith.addf %{{.*}}, %[[iter_arg2]] : f32 // CHECK-NEXT: affine.yield %[[res1]], %[[res2]] : f32, f32 // CHECK-NEXT: } + +// ----- + +// CHECK-LABEL: func @delinearize +func.func @delinearize(%linear_idx: index, %basis0: index, %basis1 :index) -> (index, index) { + // CHECK: affine.delinearize_index %{{.+}} into (%{{.+}}, %{{.+}}) : index, index + %1:2 = affine.delinearize_index %linear_idx into (%basis0, %basis1) : index, index + return %1#0, %1#1 : index, index +}