diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h index 1efa7ef84ff59..2c0dad6382009 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -49,6 +49,10 @@ ForOp getForInductionVarOwner(Value val); /// value is not an induction variable, then return nullptr. ParallelOp getParallelForInductionVarOwner(Value val); +/// Returns the ForeachThreadOp parent of an thread index variable. +/// If the provided value is not a thread index variable, then return nullptr. +ForeachThreadOp getForeachThreadOpThreadIndexOwner(Value val); + /// Return true if ops a and b (or their ancestors) are in mutually exclusive /// regions/blocks of an IfOp. // TODO: Consider moving this functionality to RegionBranchOpInterface. diff --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h index 7e775c5e90621..462d6b5c42412 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h @@ -20,6 +20,7 @@ namespace mlir { class AffineMap; struct LogicalResult; class Operation; +class OpFoldResult; class RewriterBase; class Value; class ValueRange; @@ -32,8 +33,8 @@ class IfOp; /// step size via the last parameter. The function should return `success` in /// that case. If the first parameter is not an iteration variable, return /// `failure`. -using LoopMatcherFn = - function_ref; +using LoopMatcherFn = function_ref; /// Try to canonicalize an min/max operations in the context of for `loops` with /// a known range. diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 012499f7dad38..878ddc60cee70 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1194,6 +1194,15 @@ PerformConcurrentlyOp ForeachThreadOp::getTerminator() { return cast(getBody()->getTerminator()); } +ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) { + auto tidxArg = val.dyn_cast(); + if (!tidxArg) + return ForeachThreadOp(); + assert(tidxArg.getOwner() && "unlinked block argument"); + auto *containingOp = tidxArg.getOwner()->getParentOp(); + return dyn_cast(containingOp); +} + //===----------------------------------------------------------------------===// // ParallelInsertSliceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index 0f511af14811d..eda6bc6e1cf8b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -138,7 +138,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern { unsigned resultNumber = opResult.getResultNumber(); if (!isShapePreserving(forOp, resultNumber)) return failure(); - rewriter.updateRootInPlace(dimOp, [&](){ + rewriter.updateRootInPlace(dimOp, [&]() { dimOp.sourceMutable().assign(forOp.getIterOperands()[resultNumber]); }); return success(); @@ -153,7 +153,8 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) { + auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub, + OpFoldResult &step) { if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { lb = forOp.getLowerBound(); ub = forOp.getUpperBound(); @@ -171,6 +172,18 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern { } return failure(); } + if (scf::ForeachThreadOp foreachThreadOp = + scf::getForeachThreadOpThreadIndexOwner(iv)) { + for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) { + if (foreachThreadOp.getThreadIndices()[idx] == iv) { + lb = OpBuilder(iv.getContext()).getIndexAttr(0); + ub = foreachThreadOp.getNumThreads()[idx]; + step = OpBuilder(iv.getContext()).getIndexAttr(1); + return success(); + } + } + return failure(); + } return failure(); }; diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp index 6c28cc3d83d87..958b5a2757148 100644 --- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp @@ -201,7 +201,7 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map, static LogicalResult addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv, - Value lb, Value ub, Value step, + OpFoldResult lb, OpFoldResult ub, OpFoldResult step, RewriterBase &rewriter) { // IntegerPolyhedron does not support semi-affine expressions. // Therefore, only constant step values are supported. @@ -210,8 +210,12 @@ addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv, return failure(); unsigned dimIv = constraints.appendDimId(iv); - unsigned dimLb = constraints.appendDimId(lb); - unsigned dimUb = constraints.appendDimId(ub); + auto lbv = lb.dyn_cast(); + unsigned dimLb = + lbv ? constraints.appendDimId(lbv) : constraints.appendDimId(/*num=*/1); + auto ubv = ub.dyn_cast(); + unsigned dimUb = + ubv ? constraints.appendDimId(ubv) : constraints.appendDimId(/*num=*/1); // If loop lower/upper bounds are constant: Add EQ constraint. Optional lbInt = getConstantIntValue(lb); @@ -276,7 +280,7 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, // If `operand` is an iteration variable: Find corresponding loop // bounds and step. Value iv = operand; - Value lb, ub, step; + OpFoldResult lb, ub, step; if (failed(loopMatcher(operand, lb, ub, step))) continue; allIvs.insert(iv); diff --git a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir new file mode 100644 index 0000000000000..b65d0c7049ab6 --- /dev/null +++ b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-opt %s -scf-for-loop-canonicalization -canonicalize | FileCheck %s + +func.func @reduce() -> tensor<128xf32> { + %c2 = arith.constant 2 : index + %cst = arith.constant dense<1.000000e+00> : tensor<1x128x384xf32> + %cst_0 = arith.constant -0.000000e+00 : f32 + %0 = linalg.init_tensor [128, 384] : tensor<128x384xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x384xf32>) -> tensor<128x384xf32> + %2 = linalg.init_tensor [128] : tensor<128xf32> + %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<128xf32>) -> tensor<128xf32> + %4 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) { + %7 = affine.min affine_map<(d0) -> (d0 * -64 + 128, 64)>(%arg0) + %8 = affine.max affine_map<(d0) -> (0, d0)>(%7) + %9 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg0) + %10 = affine.min affine_map<(d0, d1) -> (d1 * -64 + 128, d0)>(%8, %arg0) + + // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}, 0] [64, 384] [1, 1] : tensor<128x384xf32> to tensor<64x384xf32> + // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [64] [1] : tensor<128xf32> to tensor<64xf32> + %11 = tensor.extract_slice %1[%9, 0] [%10, 384] [1, 1] : tensor<128x384xf32> to tensor + %12 = tensor.extract_slice %3[%9] [%10] [1] : tensor<128xf32> to tensor + + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<64x384xf32>) outs(%{{.*}} : tensor<64xf32>) { + %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%11 : tensor) outs(%12 : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %14 = arith.addf %arg1, %arg2 : f32 + linalg.yield %14 : f32 + } -> tensor + + // TODO: canonicalize this cast away. + // CHECK: %[[dyn_casted:.*]] = tensor.cast %{{.*}} : tensor<64xf32> to tensor + // CHECK: scf.foreach_thread.parallel_insert_slice %[[dyn_casted:.*]] into %{{.*}}[%{{.*}}] [64] [1] : tensor into tensor<128xf32> + scf.foreach_thread.perform_concurrently { + scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor into tensor<128xf32> + } + } + return %4 : tensor<128xf32> +}