[StableHLO][NFC] Move pointwise to linalg conversion to source file
This moves the pointwise StableHLO op to `linalg.generic` conversion out
of the common header and into the matching source file.

In addition, unify the common precondition checks used by both pointwise
patterns and pull them out into a non-template function. This reduces the
amount of duplicated code and cuts the compilation time of this source
file from 28s to 22s on my machine.

Also clean up the moved pattern.

Issue: iree-org#12678
kuhar committed Apr 12, 2023
1 parent 9466b5e commit 11dbbec
Showing 2 changed files with 115 additions and 117 deletions.
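
The compile-time saving comes from a standard C++ technique: code inside a
class template is re-instantiated for every template argument, whereas a plain
function is compiled exactly once and shared. Below is a minimal standalone
sketch of that refactoring pattern; the names (checkPreconditions, Pattern)
are hypothetical and not part of this commit.

#include <iostream>
#include <vector>

// Compiled once, shared by every Pattern<OpTy> instantiation below.
bool checkPreconditions(const std::vector<int>& operands) {
  return !operands.empty();
}

// Each distinct OpTy still stamps out a copy of this struct, but the shared
// validation logic is no longer duplicated per instantiation.
template <typename OpTy>
struct Pattern {
  bool matchAndRewrite(const std::vector<int>& operands) const {
    if (!checkPreconditions(operands)) return false;
    // ... OpTy-specific rewriting would go here ...
    return true;
  }
};

int main() {
  Pattern<int> intPattern;
  Pattern<double> doublePattern;
  std::vector<int> operands{1, 2, 3};
  std::cout << intPattern.matchAndRewrite(operands)
            << doublePattern.matchAndRewrite(operands) << '\n';  // prints 11
}

Only the thin matchAndRewrite wrapper is duplicated per type here, which is
the same effect the commit achieves by hoisting the precondition checks out
of the two pattern templates.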
@@ -100,92 +100,6 @@ bool allOperandsAreScalarTensors(Operation* op);
/// Returns true if parent op is linalg.
bool isInBodyOfLinalgOps(Operation* op);

- /// Converts a HLO operation to a linalg.generic op that contains the
- /// corresponding scalar operations.
- template <typename OpTy>
- class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
- public:
- using OpConversionPattern<OpTy>::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- OpTy op, typename OpTy::Adaptor adaptor,
- ConversionPatternRewriter& rewriter) const final {
- auto loc = op.getLoc();
- // Find maximum rank / number of loops.
- auto getRank = [](Value v) {
- return v.getType().cast<ShapedType>().getRank();
- };
- auto isScalar = [&](Value v) { return getRank(v) == 0; };
- auto it = llvm::find_if_not(adaptor.getOperands(), isScalar);
- Value maxRankArg =
- it != adaptor.getOperands().end() ? *it : adaptor.getOperands().front();
- int64_t nloops = getRank(maxRankArg);
-
- // Apply only if all operands are scalar or have the same rank. Some ops,
- // like `mhlo.select`, support implicit broadcasting of scalars.
- if (!llvm::all_of(adaptor.getOperands(), [&](Value v) {
- int64_t r = getRank(v);
- return r == 0 || r == nloops;
- })) {
- return rewriter.notifyMatchFailure(
- op, "Operands must be os same rank or scalar.");
- }
-
- // Find result type, if on tensors.
- std::optional<ShapedType> resultTy;
- resultTy = this->typeConverter->convertType(op->getResultTypes().front())
- .template dyn_cast<ShapedType>();
-
- // Check result type compatibility.
- if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != nloops ||
- !(resultTy->getElementType().isSignlessIntOrFloat() ||
- resultTy->getElementType().isa<ComplexType>())) {
- return rewriter.notifyMatchFailure(
- op, "mismatched operand/result types or iterator count");
- }
-
- if (allOperandsAreScalarTensors(op) && isInBodyOfLinalgOps(op))
- return failure();
-
- // Find input/output values and types.
- ValueRange inputs = adaptor.getOperands();
- Value output =
- getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands());
-
- // Create indexing maps.
- AffineMap scalarMap = AffineMap::get(nloops, 0, rewriter.getContext());
- AffineMap idMap = rewriter.getMultiDimIdentityMap(nloops);
- SmallVector<AffineMap, 4> maps;
- for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap);
- maps.push_back(idMap);
-
- // Build `linalg.generic` op.
- bool failed = false;
- auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, resultTy ? *resultTy : TypeRange{}, inputs, output, maps,
- getNParallelLoopsAttrs(nloops),
- [&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
- ValueRange args) {
- Type innerResultTy = getElementTypeOrSelf(output);
- auto argvec = llvm::to_vector<2>(args.take_front(inputs.size()));
- auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter);
- Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp(
- op, innerResultTy, argvec, &rewriter);
- if (innerResult == nullptr) {
- failed = true;
- } else {
- innerResult = postSparsify(op, semiring, innerResult, &rewriter);
- nestedBuilder.create<linalg::YieldOp>(loc, innerResult);
- }
- },
- linalg::getPrunedAttributeList(op));
- if (failed) return failure();
-
- rewriter.replaceOp(op, linalgOp->getResults());
- return success();
- }
- };

} // namespace mlir::iree_compiler::stablehlo

#endif // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_
@@ -12,6 +12,7 @@
#include "iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h"
#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
#include "iree/compiler/InputConversion/StableHLO/TypeConversion.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"
@@ -22,6 +23,16 @@ namespace stablehlo = mlir::stablehlo;

int64_t getRank(Value v) { return cast<ShapedType>(v.getType()).getRank(); }

+ int64_t getMaxRank(ValueRange operands) {
+ int64_t maxRank = 0;
+ for (Value operand : operands) {
+ maxRank = std::max(maxRank, getRank(operand));
+ }
+ return maxRank;
+ }
+
+ bool isScalar(Value v) { return getRank(v) == 0; }
+
/// Inserts block arguments in places where scalar inputs have a nullptr.
SmallVector<Value> interleaveScalarAndBlockArgs(ValueRange scalarInputs,
ValueRange blockArgs) {
@@ -38,6 +49,47 @@ SmallVector<Value> interleaveScalarAndBlockArgs(ValueRange scalarInputs,
return result;
}

+ struct PointwiseConversionInfo {
+ int64_t maxOperandRank = 0;
+ ShapedType resultType;
+ };
+
+ /// Checks the preconditions for conversion of pointwise HLO ops to linalg.
+ /// Returns the max operand rank and the result type on success.
+ FailureOr<PointwiseConversionInfo> checkOperandsAndResults(
+ Operation* op, ValueRange operands, TypeConverter& typeConverter,
+ ConversionPatternRewriter& rewriter) {
+ int64_t maxRank = getMaxRank(operands);
+
+ // Apply only if all operands are scalar or have the same rank. Some ops,
+ // like `mhlo.select`, support implicit broadcasting of scalars.
+ if (!llvm::all_of(operands, [&](Value v) {
+ int64_t r = getRank(v);
+ return r == 0 || r == maxRank;
+ })) {
+ return rewriter.notifyMatchFailure(
+ op, "Operands must be of same rank or scalar.");
+ }
+
+ // Find result type, if on tensors.
+ auto resultTy = dyn_cast_or_null<ShapedType>(
+ typeConverter.convertType(op->getResultTypes().front()));
+
+ // Check result type compatibility.
+ if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank ||
+ !(resultTy.getElementType().isSignlessIntOrFloat() ||
+ isa<ComplexType>(resultTy.getElementType()))) {
+ return rewriter.notifyMatchFailure(
+ op, "mismatched operand/result types or iterator count");
+ }
+
+ // All-scalar pointwise ops inside of linalg ops are processed by
+ // ScalarHloToArithmeticPattern.
+ if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure();
+
+ return PointwiseConversionInfo{maxRank, resultTy};
+ }
+
/// Converts a HLO operation to a linalg.map op that contains the corresponding
/// scalar operations.
template <typename OpTy>
@@ -48,34 +100,15 @@ struct PointwiseToLinalgMapConverter final : OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite(
OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
- auto loc = op.getLoc();
- int64_t maxRank = getMaxRank(adaptor);
-
- // Apply only if all operands are scalar or have the same rank. Some ops,
- // like `mhlo.select`, support implicit broadcasting of scalars.
- if (!llvm::all_of(adaptor.getOperands(), [&](Value v) {
- int64_t r = getRank(v);
- return r == 0 || r == maxRank;
- })) {
- return rewriter.notifyMatchFailure(
- op, "Operands must be of same rank or scalar.");
- }
-
- // Find result type, if on tensors.
- auto resultTy = dyn_cast_or_null<ShapedType>(
- this->typeConverter->convertType(op->getResultTypes().front()));
-
- // Check result type compatibility.
- if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank ||
- !(resultTy.getElementType().isSignlessIntOrFloat() ||
- isa<ComplexType>(resultTy.getElementType()))) {
- return rewriter.notifyMatchFailure(
- op, "mismatched operand/result types or iterator count");
- }
-
- // All-scalar pointwise ops inside of linalg ops are processes by
- // ScalarHloToArithmeticPattern.
- if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure();
+ auto conversionInfo = checkOperandsAndResults(
+ op, adaptor.getOperands(), *this->typeConverter, rewriter);
+ if (failed(conversionInfo)) {
+ return failure();
+ }
+
+ int64_t maxRank = conversionInfo->maxOperandRank;
+ ShapedType resultTy = conversionInfo->resultType;
+ Location loc = op.getLoc();

// Find input/output values and types.
Value emptyTensor =
@@ -110,13 +143,64 @@ struct PointwiseToLinalgMapConverter final : OpConversionPattern<OpTy> {
rewriter.replaceOp(op, mapOp->getResults());
return success();
}
-
- static int64_t getMaxRank(OpAdaptor adaptor) {
- int64_t maxRank = 0;
- for (Value operand : adaptor.getOperands()) {
- maxRank = std::max(maxRank, getRank(operand));
- }
- return maxRank;
- }
+ };
+
+ /// Converts a HLO operation to a linalg.generic op that contains the
+ /// corresponding scalar operations.
+ template <typename OpTy>
+ struct PointwiseToLinalgConverter final : OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
+ using OpAdaptor = typename OpTy::Adaptor;
+
+ LogicalResult matchAndRewrite(
+ OpTy op, OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const override {
+ auto conversionInfo = checkOperandsAndResults(
+ op, adaptor.getOperands(), *this->typeConverter, rewriter);
+ if (failed(conversionInfo)) {
+ return failure();
+ }
+
+ int64_t maxRank = conversionInfo->maxOperandRank;
+ ShapedType resultTy = conversionInfo->resultType;
+ Location loc = op.getLoc();
+
+ // Find input/output values and types.
+ ValueRange inputs = adaptor.getOperands();
+ Value output =
+ getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands());
+
+ // Create indexing maps.
+ AffineMap scalarMap = AffineMap::get(maxRank, 0, rewriter.getContext());
+ AffineMap idMap = rewriter.getMultiDimIdentityMap(maxRank);
+ SmallVector<AffineMap, 4> maps;
+ for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap);
+ maps.push_back(idMap);
+
+ // Build `linalg.generic` op.
+ bool failed = false;
+ auto linalgOp = rewriter.create<linalg::GenericOp>(
+ loc, resultTy ? resultTy : TypeRange{}, inputs, output, maps,
+ getNParallelLoopsAttrs(maxRank),
+ [&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
+ ValueRange args) {
+ Type innerResultTy = getElementTypeOrSelf(output);
+ auto argvec = llvm::to_vector<2>(args.take_front(inputs.size()));
+ Value semiring = preSparsify(op, argvec, innerResultTy, &rewriter);
+ Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp(
+ op, innerResultTy, argvec, &rewriter);
+ if (!innerResult) {
+ failed = true;
+ } else {
+ innerResult = postSparsify(op, semiring, innerResult, &rewriter);
+ nestedBuilder.create<linalg::YieldOp>(loc, innerResult);
+ }
+ },
+ linalg::getPrunedAttributeList(op));
+ if (failed) return failure();
+
+ rewriter.replaceOp(op, linalgOp->getResults());
+ return success();
+ }
};
} // namespace
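
For context, conversion patterns like these are typically instantiated once
per op type when populating a rewrite set, so every instantiation now shares
the single compiled copy of checkOperandsAndResults. The sketch below uses the
standard MLIR registration API; the helper name and the op list are
illustrative assumptions, not taken from this commit.

#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

// Hypothetical registration helper; op list chosen purely for illustration.
void populatePointwisePatterns(mlir::TypeConverter& typeConverter,
                               mlir::RewritePatternSet& patterns,
                               mlir::MLIRContext* context) {
  // Each add<>() stamps out the converter template for one StableHLO op type.
  patterns.add<PointwiseToLinalgConverter<mlir::stablehlo::AddOp>,
               PointwiseToLinalgConverter<mlir::stablehlo::MulOp>>(
      typeConverter, context);
}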
