Skip to content

Commit

Permalink
[StableHLO][NFC] Factor out and clean up pointwise patterns (#13018)
Browse files Browse the repository at this point in the history
This is primarily to improve compilation times when iterating on the
rest of StableHLO to Linalg conversion patterns. These pointwise
patterns create a large number of class template instantiations and take
>20s to compile on my machine.

Also clean up the code:
-  Use free cast functions
-  Make the conversion patterns final
-  Prefer static/free helper functions

Issue: #12678
  • Loading branch information
kuhar authored and jpienaar committed May 1, 2023
1 parent 251866b commit 7f85b92
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ iree_compiler_cc_library(
"LegalizeToLinalgUtils.cpp",
"Passes.cpp",
"StableHLOToLinalg.cpp",
"StableHLOToLinalgPointwise.cpp",
"TypeConversion.cpp",
],
hdrs = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ iree_cc_library(
"LegalizeToLinalgUtils.cpp"
"Passes.cpp"
"StableHLOToLinalg.cpp"
"StableHLOToLinalgPointwise.cpp"
"TypeConversion.cpp"
DEPS
::PassHeaders
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType,
: getEmptyTensor(b, loc, resultType, sizes);
}

Value coerceTensorShape(OpBuilder& builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType) {
return builder.createOrFold<tensor::CastOp>(
loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()),
value);
}

Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp,
OpBuilder* b) {
// Apply for semi-ring operations that lower to elaborate code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ Value getEmptyTensor(OpBuilder& b, Location loc, ShapedType type,
Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType,
Operation* op, ValueRange operands);

/// Ensures a tensor has the same shape (not including the element type) as
/// another.
Value coerceTensorShape(OpBuilder& builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType);

/// Sparsifies a (block of) operation(s) that cannot be handled directly
/// by the sparse compiler but has well-known semi-ring semantics.
///
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context,
RewritePatternSet* patterns,
bool enablePrimitiveOps);

//===----------------------------------------------------------------------===//
// Fine-grained patterns used by the implementation.
//===----------------------------------------------------------------------===//
namespace detail {
/// Populates the patterns that convert from StableHLO to Linalg on tensors.
void populatePointwiseStableHloToLinalgConversionPatterns(
MLIRContext* context, TypeConverter& typeConverter,
RewritePatternSet* patterns, bool enablePrimitiveOps);

} // namespace detail

} // namespace mlir::iree_compiler::stablehlo

#endif // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_REWRITERS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,6 @@ Value extractIndexFromTensor(OpBuilder& builder, Location loc, Value tensor,
loc, builder.getIndexType(), extracted);
}

/// Ensures a tensor has the same shape (not including the element type) as
/// another.
Value coerceTensorShape(OpBuilder& builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType) {
return builder.createOrFold<tensor::CastOp>(
loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()),
value);
}

//===----------------------------------------------------------------------===//
// stablehlo.RngOp conversion patterns.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2909,110 +2900,6 @@ class DotGeneralOpConversion
}
};

/// Converts a HLO operation to a linalg.map op that contains the corresponding
/// scalar operations.
template <typename OpTy>
class PointwiseToLinalgMapConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<OpTy>::OpConversionPattern;

LogicalResult matchAndRewrite(
OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
int64_t maxRank = getMaxRank(adaptor);

// Apply only if all operands are scalar or have the same rank. Some ops,
// like `mhlo.select`, support implicit broadcasting of scalars.
if (!llvm::all_of(adaptor.getOperands(), [&](Value v) {
int64_t r = getRank(v);
return r == 0 || r == maxRank;
})) {
return rewriter.notifyMatchFailure(
op, "Operands must be of same rank or scalar.");
}

// Find result type, if on tensors.
std::optional<ShapedType> resultTy;
resultTy = this->typeConverter->convertType(op->getResultTypes().front())
.template dyn_cast<ShapedType>();

// Check result type compatibility.
if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != maxRank ||
!(resultTy->getElementType().isSignlessIntOrFloat() ||
resultTy->getElementType().isa<ComplexType>())) {
return rewriter.notifyMatchFailure(
op, "mismatched operand/result types or iterator count");
}

// All-scalar pointwise ops inside of linalg ops are processes by
// ScalarHloToArithmeticPattern.
if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure();

// Find input/output values and types.
Value emptyTensor =
getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands());

// Mapped inputs are cast to the same shape as the init tensor.
// Values from scalar inputs are extracted and used directly in the block.
SmallVector<Value> mappedInputs;
SmallVector<Value> scalarInputs;
for (Value input : adaptor.getOperands()) {
if (getRank(input) == maxRank) {
mappedInputs.push_back(coerceTensorShape(
rewriter, loc, cast<TypedValue<ShapedType>>(input),
emptyTensor.getType()));
scalarInputs.push_back(nullptr);
} else {
scalarInputs.push_back(rewriter.create<tensor::ExtractOp>(loc, input));
}
}

auto mapOp = rewriter.create<linalg::MapOp>(
loc, mappedInputs, emptyTensor,
[&](OpBuilder& b, Location loc, ValueRange args) {
Value innerResult = stablehlo::StableHloOpToStdScalarOp::mapOp(
op, getElementTypeOrSelf(emptyTensor),
interleaveScalarAndBlockArgs(scalarInputs, args), &b);

b.create<linalg::YieldOp>(loc, innerResult);
},
linalg::getPrunedAttributeList(op));

rewriter.replaceOp(op, mapOp->getResults());
return success();
}

protected:
int64_t getRank(Value v) const {
return v.getType().cast<ShapedType>().getRank();
}

int64_t getMaxRank(typename OpTy::Adaptor adaptor) const {
int64_t maxRank = 0;
for (auto operand : adaptor.getOperands()) {
maxRank = std::max(maxRank, getRank(operand));
}
return maxRank;
}

// Inserts block arguments in places where scalar inputs have a nullptr.
SmallVector<Value> interleaveScalarAndBlockArgs(ValueRange scalarInputs,
ValueRange blockArgs) const {
SmallVector<Value> result;
auto argsIter = blockArgs.begin();
for (Value scalarInput : scalarInputs) {
if (scalarInput) {
result.push_back(scalarInput);
} else {
result.push_back(*argsIter);
++argsIter;
}
}
return result;
}
};

class SetDimensionSizeConverter
: public OpConversionPattern<stablehlo::SetDimensionSizeOp> {
public:
Expand Down Expand Up @@ -3108,6 +2995,9 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context,
SelectAndScatterNoOverlapConverter,
ReduceRegionReturnOpConversion>(typeConverter, context);

detail::populatePointwiseStableHloToLinalgConversionPatterns(
context, typeConverter, patterns, enablePrimitiveOps);

if (enablePrimitiveOps) {
patterns->add<
BroadcastInDimOpToBroadcastConverter,
Expand All @@ -3116,52 +3006,6 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context,
IotaToMapConverter<stablehlo::IotaOp>,
IotaToMapConverter<stablehlo::DynamicIotaOp>,
MapOpToMapConverter,
PointwiseToLinalgMapConverter<stablehlo::AbsOp>,
PointwiseToLinalgMapConverter<stablehlo::AddOp>,
PointwiseToLinalgMapConverter<stablehlo::AndOp>,
PointwiseToLinalgMapConverter<stablehlo::Atan2Op>,
PointwiseToLinalgMapConverter<stablehlo::BitcastConvertOp>,
PointwiseToLinalgMapConverter<stablehlo::CbrtOp>,
PointwiseToLinalgMapConverter<stablehlo::CeilOp>,
PointwiseToLinalgMapConverter<stablehlo::ClampOp>,
PointwiseToLinalgMapConverter<stablehlo::ClzOp>,
PointwiseToLinalgMapConverter<stablehlo::CompareOp>,
PointwiseToLinalgMapConverter<stablehlo::ComplexOp>,
PointwiseToLinalgMapConverter<stablehlo::ConvertOp>,
PointwiseToLinalgMapConverter<stablehlo::CosineOp>,
PointwiseToLinalgMapConverter<stablehlo::DivOp>,
PointwiseToLinalgMapConverter<stablehlo::ExpOp>,
PointwiseToLinalgMapConverter<stablehlo::Expm1Op>,
PointwiseToLinalgMapConverter<stablehlo::FloorOp>,
PointwiseToLinalgMapConverter<stablehlo::ImagOp>,
PointwiseToLinalgMapConverter<stablehlo::IsFiniteOp>,
PointwiseToLinalgMapConverter<stablehlo::Log1pOp>,
PointwiseToLinalgMapConverter<stablehlo::LogOp>,
PointwiseToLinalgMapConverter<stablehlo::LogisticOp>,
PointwiseToLinalgMapConverter<stablehlo::MaxOp>,
PointwiseToLinalgMapConverter<stablehlo::MinOp>,
PointwiseToLinalgMapConverter<stablehlo::MulOp>,
PointwiseToLinalgMapConverter<stablehlo::NegOp>,
PointwiseToLinalgMapConverter<stablehlo::NotOp>,
PointwiseToLinalgMapConverter<stablehlo::OrOp>,
PointwiseToLinalgMapConverter<stablehlo::PopulationCountOp>,
PointwiseToLinalgMapConverter<stablehlo::PowOp>,
PointwiseToLinalgMapConverter<stablehlo::RealOp>,
PointwiseToLinalgMapConverter<stablehlo::ReducePrecisionOp>,
PointwiseToLinalgMapConverter<stablehlo::RemOp>,
PointwiseToLinalgMapConverter<stablehlo::RoundNearestEvenOp>,
PointwiseToLinalgMapConverter<stablehlo::RoundOp>,
PointwiseToLinalgMapConverter<stablehlo::RsqrtOp>,
PointwiseToLinalgMapConverter<stablehlo::SelectOp>,
PointwiseToLinalgMapConverter<stablehlo::ShiftLeftOp>,
PointwiseToLinalgMapConverter<stablehlo::ShiftRightArithmeticOp>,
PointwiseToLinalgMapConverter<stablehlo::ShiftRightLogicalOp>,
PointwiseToLinalgMapConverter<stablehlo::SignOp>,
PointwiseToLinalgMapConverter<stablehlo::SineOp>,
PointwiseToLinalgMapConverter<stablehlo::SqrtOp>,
PointwiseToLinalgMapConverter<stablehlo::SubtractOp>,
PointwiseToLinalgMapConverter<stablehlo::TanhOp>,
PointwiseToLinalgMapConverter<stablehlo::XorOp>,
ReduceOpToReduceConverter,
TransposeOpToTransposeConverter
>(typeConverter, context);
Expand All @@ -3173,52 +3017,6 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context,
HloBroadcastInDimConverter,
HloDynamicBroadcastInDimConverter,
MapOpToGenericConverter,
PointwiseToLinalgConverter<stablehlo::AbsOp>,
PointwiseToLinalgConverter<stablehlo::AddOp>,
PointwiseToLinalgConverter<stablehlo::AndOp>,
PointwiseToLinalgConverter<stablehlo::Atan2Op>,
PointwiseToLinalgConverter<stablehlo::BitcastConvertOp>,
PointwiseToLinalgConverter<stablehlo::CbrtOp>,
PointwiseToLinalgConverter<stablehlo::CeilOp>,
PointwiseToLinalgConverter<stablehlo::ClampOp>,
PointwiseToLinalgConverter<stablehlo::ClzOp>,
PointwiseToLinalgConverter<stablehlo::CompareOp>,
PointwiseToLinalgConverter<stablehlo::ComplexOp>,
PointwiseToLinalgConverter<stablehlo::ConvertOp>,
PointwiseToLinalgConverter<stablehlo::CosineOp>,
PointwiseToLinalgConverter<stablehlo::DivOp>,
PointwiseToLinalgConverter<stablehlo::ExpOp>,
PointwiseToLinalgConverter<stablehlo::Expm1Op>,
PointwiseToLinalgConverter<stablehlo::FloorOp>,
PointwiseToLinalgConverter<stablehlo::ImagOp>,
PointwiseToLinalgConverter<stablehlo::IsFiniteOp>,
PointwiseToLinalgConverter<stablehlo::Log1pOp>,
PointwiseToLinalgConverter<stablehlo::LogOp>,
PointwiseToLinalgConverter<stablehlo::LogisticOp>,
PointwiseToLinalgConverter<stablehlo::MaxOp>,
PointwiseToLinalgConverter<stablehlo::MinOp>,
PointwiseToLinalgConverter<stablehlo::MulOp>,
PointwiseToLinalgConverter<stablehlo::NegOp>,
PointwiseToLinalgConverter<stablehlo::NotOp>,
PointwiseToLinalgConverter<stablehlo::OrOp>,
PointwiseToLinalgConverter<stablehlo::PopulationCountOp>,
PointwiseToLinalgConverter<stablehlo::PowOp>,
PointwiseToLinalgConverter<stablehlo::RealOp>,
PointwiseToLinalgConverter<stablehlo::ReducePrecisionOp>,
PointwiseToLinalgConverter<stablehlo::RemOp>,
PointwiseToLinalgConverter<stablehlo::RoundNearestEvenOp>,
PointwiseToLinalgConverter<stablehlo::RoundOp>,
PointwiseToLinalgConverter<stablehlo::RsqrtOp>,
PointwiseToLinalgConverter<stablehlo::SelectOp>,
PointwiseToLinalgConverter<stablehlo::ShiftLeftOp>,
PointwiseToLinalgConverter<stablehlo::ShiftRightArithmeticOp>,
PointwiseToLinalgConverter<stablehlo::ShiftRightLogicalOp>,
PointwiseToLinalgConverter<stablehlo::SignOp>,
PointwiseToLinalgConverter<stablehlo::SineOp>,
PointwiseToLinalgConverter<stablehlo::SqrtOp>,
PointwiseToLinalgConverter<stablehlo::SubtractOp>,
PointwiseToLinalgConverter<stablehlo::TanhOp>,
PointwiseToLinalgConverter<stablehlo::XorOp>,
ReduceOpToGenericConverter,
TransposeConverter<stablehlo::TransposeOp>
>(typeConverter, context);
Expand Down
Loading

0 comments on commit 7f85b92

Please sign in to comment.