Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[StableHLO][NFC] Move pointwise to linalg conversion to source file #13044

Merged
merged 1 commit into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,92 +100,6 @@ bool allOperandsAreScalarTensors(Operation* op);
/// Returns true if parent op is linalg.
bool isInBodyOfLinalgOps(Operation* op);

/// Converts a HLO operation to a linalg.generic op that contains the
/// corresponding scalar operations.
template <typename OpTy>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<OpTy>::OpConversionPattern;

LogicalResult matchAndRewrite(
OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
// Find maximum rank / number of loops.
auto getRank = [](Value v) {
return v.getType().cast<ShapedType>().getRank();
};
auto isScalar = [&](Value v) { return getRank(v) == 0; };
auto it = llvm::find_if_not(adaptor.getOperands(), isScalar);
Value maxRankArg =
it != adaptor.getOperands().end() ? *it : adaptor.getOperands().front();
int64_t nloops = getRank(maxRankArg);

// Apply only if all operands are scalar or have the same rank. Some ops,
// like `mhlo.select`, support implicit broadcasting of scalars.
if (!llvm::all_of(adaptor.getOperands(), [&](Value v) {
int64_t r = getRank(v);
return r == 0 || r == nloops;
})) {
return rewriter.notifyMatchFailure(
op, "Operands must be os same rank or scalar.");
}

// Find result type, if on tensors.
std::optional<ShapedType> resultTy;
resultTy = this->typeConverter->convertType(op->getResultTypes().front())
.template dyn_cast<ShapedType>();

// Check result type compatibility.
if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != nloops ||
!(resultTy->getElementType().isSignlessIntOrFloat() ||
resultTy->getElementType().isa<ComplexType>())) {
return rewriter.notifyMatchFailure(
op, "mismatched operand/result types or iterator count");
}

if (allOperandsAreScalarTensors(op) && isInBodyOfLinalgOps(op))
return failure();

// Find input/output values and types.
ValueRange inputs = adaptor.getOperands();
Value output =
getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands());

// Create indexing maps.
AffineMap scalarMap = AffineMap::get(nloops, 0, rewriter.getContext());
AffineMap idMap = rewriter.getMultiDimIdentityMap(nloops);
SmallVector<AffineMap, 4> maps;
for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap);
maps.push_back(idMap);

// Build `linalg.generic` op.
bool failed = false;
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, resultTy ? *resultTy : TypeRange{}, inputs, output, maps,
getNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
ValueRange args) {
Type innerResultTy = getElementTypeOrSelf(output);
auto argvec = llvm::to_vector<2>(args.take_front(inputs.size()));
auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter);
Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp(
op, innerResultTy, argvec, &rewriter);
if (innerResult == nullptr) {
failed = true;
} else {
innerResult = postSparsify(op, semiring, innerResult, &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, innerResult);
}
},
linalg::getPrunedAttributeList(op));
if (failed) return failure();

rewriter.replaceOp(op, linalgOp->getResults());
return success();
}
};

} // namespace mlir::iree_compiler::stablehlo

#endif // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h"
#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
#include "iree/compiler/InputConversion/StableHLO/TypeConversion.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"
Expand All @@ -22,6 +23,16 @@ namespace stablehlo = mlir::stablehlo;

int64_t getRank(Value v) { return cast<ShapedType>(v.getType()).getRank(); }

int64_t getMaxRank(ValueRange operands) {
int64_t maxRank = 0;
for (Value operand : operands) {
maxRank = std::max(maxRank, getRank(operand));
}
return maxRank;
}

bool isScalar(Value v) { return getRank(v) == 0; }

/// Inserts block arguments in places where scalar inputs have a nullptr.
SmallVector<Value> interleaveScalarAndBlockArgs(ValueRange scalarInputs,
ValueRange blockArgs) {
Expand All @@ -38,6 +49,47 @@ SmallVector<Value> interleaveScalarAndBlockArgs(ValueRange scalarInputs,
return result;
}

struct PointwiseConversionInfo {
int64_t maxOperandRank = 0;
ShapedType resultType;
};

/// Checks the preconditions for conversion of pointwise HLO ops to linalg.
/// Returns the max operand rank and the result type on success.
FailureOr<PointwiseConversionInfo> checkOperandsAndResults(
Operation* op, ValueRange operands, TypeConverter& typeConverter,
ConversionPatternRewriter& rewriter) {
int64_t maxRank = getMaxRank(operands);

// Apply only if all operands are scalar or have the same rank. Some ops,
// like `mhlo.select`, support implicit broadcasting of scalars.
if (!llvm::all_of(operands, [&](Value v) {
int64_t r = getRank(v);
return r == 0 || r == maxRank;
})) {
return rewriter.notifyMatchFailure(
op, "Operands must be of same rank or scalar.");
}

// Find result type, if on tensors.
auto resultTy = dyn_cast_or_null<ShapedType>(
typeConverter.convertType(op->getResultTypes().front()));

// Check result type compatibility.
if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank ||
!(resultTy.getElementType().isSignlessIntOrFloat() ||
isa<ComplexType>(resultTy.getElementType()))) {
return rewriter.notifyMatchFailure(
op, "mismatched operand/result types or iterator count");
}

// All-scalar pointwise ops inside of linalg ops are processes by
// ScalarHloToArithmeticPattern.
if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure();

return PointwiseConversionInfo{maxRank, resultTy};
}

/// Converts a HLO operation to a linalg.map op that contains the corresponding
/// scalar operations.
template <typename OpTy>
Expand All @@ -48,34 +100,15 @@ struct PointwiseToLinalgMapConverter final : OpConversionPattern<OpTy> {
LogicalResult matchAndRewrite(
OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto loc = op.getLoc();
int64_t maxRank = getMaxRank(adaptor);

// Apply only if all operands are scalar or have the same rank. Some ops,
// like `mhlo.select`, support implicit broadcasting of scalars.
if (!llvm::all_of(adaptor.getOperands(), [&](Value v) {
int64_t r = getRank(v);
return r == 0 || r == maxRank;
})) {
return rewriter.notifyMatchFailure(
op, "Operands must be of same rank or scalar.");
}

// Find result type, if on tensors.
auto resultTy = dyn_cast_or_null<ShapedType>(
this->typeConverter->convertType(op->getResultTypes().front()));

// Check result type compatibility.
if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank ||
!(resultTy.getElementType().isSignlessIntOrFloat() ||
isa<ComplexType>(resultTy.getElementType()))) {
return rewriter.notifyMatchFailure(
op, "mismatched operand/result types or iterator count");
auto conversionInfo = checkOperandsAndResults(
op, adaptor.getOperands(), *this->typeConverter, rewriter);
if (failed(conversionInfo)) {
return failure();
}

// All-scalar pointwise ops inside of linalg ops are processes by
// ScalarHloToArithmeticPattern.
if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure();
int64_t maxRank = conversionInfo->maxOperandRank;
ShapedType resultTy = conversionInfo->resultType;
Location loc = op.getLoc();

// Find input/output values and types.
Value emptyTensor =
Expand Down Expand Up @@ -110,13 +143,64 @@ struct PointwiseToLinalgMapConverter final : OpConversionPattern<OpTy> {
rewriter.replaceOp(op, mapOp->getResults());
return success();
}
};

/// Converts a HLO operation to a linalg.generic op that contains the
/// corresponding scalar operations.
template <typename OpTy>
struct PointwiseToLinalgConverter final : OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpTy::Adaptor;

static int64_t getMaxRank(OpAdaptor adaptor) {
int64_t maxRank = 0;
for (Value operand : adaptor.getOperands()) {
maxRank = std::max(maxRank, getRank(operand));
LogicalResult matchAndRewrite(
OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto conversionInfo = checkOperandsAndResults(
op, adaptor.getOperands(), *this->typeConverter, rewriter);
if (failed(conversionInfo)) {
return failure();
}
return maxRank;

int64_t maxRank = conversionInfo->maxOperandRank;
ShapedType resultTy = conversionInfo->resultType;
Location loc = op.getLoc();

// Find input/output values and types.
ValueRange inputs = adaptor.getOperands();
Value output =
getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands());

// Create indexing maps.
AffineMap scalarMap = AffineMap::get(maxRank, 0, rewriter.getContext());
AffineMap idMap = rewriter.getMultiDimIdentityMap(maxRank);
SmallVector<AffineMap, 4> maps;
for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap);
maps.push_back(idMap);

// Build `linalg.generic` op.
bool failed = false;
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, resultTy ? resultTy : TypeRange{}, inputs, output, maps,
getNParallelLoopsAttrs(maxRank),
[&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
ValueRange args) {
Type innerResultTy = getElementTypeOrSelf(output);
auto argvec = llvm::to_vector<2>(args.take_front(inputs.size()));
Value semiring = preSparsify(op, argvec, innerResultTy, &rewriter);
Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp(
op, innerResultTy, argvec, &rewriter);
if (!innerResult) {
failed = true;
} else {
innerResult = postSparsify(op, semiring, innerResult, &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, innerResult);
}
},
linalg::getPrunedAttributeList(op));
if (failed) return failure();

rewriter.replaceOp(op, linalgOp->getResults());
return success();
}
};
} // namespace
Expand Down