Skip to content

Commit

Permalink
[mlir][linalg] Improve implementation of hoist padding.
Browse files Browse the repository at this point in the history
Instead of relying on adhoc bounds calculations, use a projection-based
implementation. This simplifies the implementation and finds more static
constant sizes than previously/

Differential Revision: https://reviews.llvm.org/D106054
  • Loading branch information
nicolasvasilache committed Jul 15, 2021
1 parent 5024fe9 commit 01bdb0f
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 125 deletions.
67 changes: 67 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Analysis/ConstraintsSet.h
@@ -0,0 +1,67 @@
//===- ConstraintsSet.h - Extensions for FlatAffineConstraints --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Linalg-specific constraints set extensions.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_
#define MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_

#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/AffineMap.h"

namespace mlir {
class ValueRange;

/// Linalg-specific constraints set extensions.
class ConstraintsSet : public FlatAffineConstraints {
public:
ConstraintsSet() : FlatAffineConstraints() {}

/// Assuming `val` is defined by `val = affine.min map (operands)`, introduce
/// all the constraints `val >= expr_i(operands)`, where expr_i are all the
/// results of `map`.
// This API avoids taking a dependence on the AffineMinOp definition.
LogicalResult composeMin(Value val, AffineMap map, ValueRange operands) {
return composeMinOrMaxMapAndOperands(val, map, operands, /*min=*/true);
}

/// Assuming `val` is defined by `val = affine.max map (operands)`, introduce
/// all the constraints `val <= expr_i(operands)`, where expr_i are all the
/// results of `map`.
// This API avoids taking a dependence on the AffineMaxOp definition.
LogicalResult composeMax(Value val, AffineMap map, ValueRange operands) {
return composeMinOrMaxMapAndOperands(val, map, operands, /*min=*/false);
}

/// Assuming `val` is defined by `val = affine.apply map (operands)`, call
/// composeMap.
// This API avoids taking a dependence on the AffineMApplyOp definition.
LogicalResult composeAffineApply(Value val, AffineMap map,
ValueRange operands);

/// Asserts the identifier `id` is in the constraints set and returns it.
unsigned lookupPos(Value id) const;

/// If v is not in the constraint set, insert it as a dim or symbol depending
/// on `asDim`.
/// Return success if v is of dim id type when `asDim` is true and of symbol
/// id type when `asDim` is false.
/// Return failure otherwise.
LogicalResult ensureIdOfType(Value v, bool asDim);

private:
/// Implementation detail for composeMin/Max.
LogicalResult composeMinOrMaxMapAndOperands(Value val, AffineMap map,
ValueRange operands, bool min);
};

} // namespace mlir

#endif // MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Linalg/Analysis/CMakeLists.txt
@@ -1,12 +1,15 @@
add_mlir_dialect_library(MLIRLinalgAnalysis
ConstraintsSet.cpp
DependenceAnalysis.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg

LINK_LIBS PUBLIC
MLIRAnalysis
MLIRIR
MLIRLinalg
MLIRLoopAnalysis
MLIRMemRef
MLIRStandard
)
87 changes: 87 additions & 0 deletions mlir/lib/Dialect/Linalg/Analysis/ConstraintsSet.cpp
@@ -0,0 +1,87 @@
//===- ConstraintsSet.cpp - Extensions for FlatAffineConstraints ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Linalg-specific constraints set extensions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/ConstraintsSet.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;

unsigned ConstraintsSet::lookupPos(Value id) const {
unsigned pos;
if (!findId(id, &pos)) {
llvm::errs() << "Lookup failed: " << id << "\n";
llvm_unreachable("Lookup failed");
}
return pos;
}

LogicalResult ConstraintsSet::ensureIdOfType(Value v, bool asDim) {
if (!containsId(v)) {
if (asDim)
addDimId(getNumDimIds(), v);
else
addSymbolId(getNumSymbolIds(), v);
return success();
}
unsigned pos = lookupPos(v);
return success((asDim && pos < getNumDimIds()) ||
(!asDim && getNumDimIds() <= pos &&
pos < getNumDimIds() + getNumSymbolIds()));
}

LogicalResult ConstraintsSet::composeAffineApply(Value val, AffineMap map,
ValueRange operands) {
AffineValueMap avm(map, operands, val);
return composeMap(&avm);
}

LogicalResult ConstraintsSet::composeMinOrMaxMapAndOperands(Value val,
AffineMap map,
ValueRange operands,
bool min) {
ConstraintsSet localCst;
std::vector<SmallVector<int64_t, 8>> flatExprs;
if (failed(getFlattenedAffineExprs(map, &flatExprs, &localCst)))
return failure();
assert(flatExprs.size() == map.getNumResults() &&
"incorrect number of flattened expressiosn");

// Local vars on a per-need basis.
if (localCst.getNumLocalIds() != 0)
return failure();

// Add one inequality for each result connecting `val` to the other ids in
// `operands`. For instance, uf the expression is:
// `16 * i0 + i1` and
// `min` is true
// add:
// -d_val + 16 * i0 + i1 >= 0.
for (const auto &flatExpr : flatExprs) {
assert(flatExpr.size() >= operands.size() + 1);
SmallVector<int64_t, 8> ineq(getNumCols(), 0);
for (unsigned i = 0, e = operands.size(); i < e; i++)
ineq[lookupPos(operands[i])] = min ? flatExpr[i] : -flatExpr[i];

// Set the coefficient for `d_val`.
ineq[lookupPos(val)] = min ? -1 : 1;

// Set the constant term (upper bound in flatExpr is exclusive).
ineq[getNumCols() - 1] = min ? flatExpr[flatExpr.size() - 1] - 1
: -flatExpr[flatExpr.size() - 1];

// Add the inequality connecting the result of the map to the rest.
addInequality(ineq);
}

return success();
}

0 comments on commit 01bdb0f

Please sign in to comment.