Skip to content

Commit

Permalink
[StableHLO][NFC] Factor out dot prod lowering patterns (#13023)
Browse files Browse the repository at this point in the history
This is primarily to reduce build times and length of the main stablehlo
to linalg conversion pass file.

Also clean up the code:
-  Use `LogicalResult` for verification success/failure
-  Use free cast functions
-  Make patterns `final`
-  Replace some local `auto` with the actual type where not obvious

Issue: #12678
  • Loading branch information
kuhar authored Apr 11, 2023
1 parent 8b2f0b4 commit 4cdf0cc
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 283 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ iree_compiler_cc_library(
"LegalizeToLinalgUtils.cpp",
"Passes.cpp",
"StableHLOToLinalg.cpp",
"StableHLOToLinalgDotProd.cpp",
"StableHLOToLinalgPointwise.cpp",
"TypeConversion.cpp",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ iree_cc_library(
"LegalizeToLinalgUtils.cpp"
"Passes.cpp"
"StableHLOToLinalg.cpp"
"StableHLOToLinalgDotProd.cpp"
"StableHLOToLinalgPointwise.cpp"
"TypeConversion.cpp"
DEPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,29 @@ Value coerceTensorShape(OpBuilder& builder, Location loc,
value);
}

LogicalResult verifyHloOpBufferOrTensorSemantics(Operation* op) {
auto isRankedTensor = [](Value val) {
return isa<RankedTensorType>(val.getType());
};
if (!llvm::all_of(op->getOperands(), isRankedTensor)) return failure();
return success(llvm::all_of(op->getResults(), isRankedTensor));
}

Value fillTensorWithZeros(OpBuilder& builder, Location loc, Value tensor) {
auto type = cast<ShapedType>(tensor.getType());
Value zero;
// Complex numbers are a special case.
if (auto complexType = type.getElementType().dyn_cast<ComplexType>()) {
auto zeroElement = builder.getZeroAttr(complexType.getElementType());
auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement});
zero = builder.create<complex::ConstantOp>(loc, complexType, zeroAttr);
} else {
auto zeroAttr = builder.getZeroAttr(type.getElementType());
zero = builder.create<arith::ConstantOp>(loc, zeroAttr);
}
return builder.create<linalg::FillOp>(loc, zero, tensor).result();
}

Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp,
OpBuilder* b) {
// Apply for semi-ring operations that lower to elaborate code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType,
Value coerceTensorShape(OpBuilder& builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType);

/// Verifies |op|'s semantics by checking if all operands and results have
/// ranged tensor types.
LogicalResult verifyHloOpBufferOrTensorSemantics(Operation* op);

/// Fills |tensor| with a zero constant of the matching type. Returns the new
/// value.
Value fillTensorWithZeros(OpBuilder& builder, Location loc, Value tensor);

/// Sparsifies a (block of) operation(s) that cannot be handled directly
/// by the sparse compiler but has well-known semi-ring semantics.
///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context,
// Fine-grained patterns used by the implementation.
//===----------------------------------------------------------------------===//
namespace detail {
/// Populates the patterns that convert from StableHLO to Linalg on tensors.
/// Populates the patterns that convert from elementwise StableHLO ops to Linalg
/// on tensors.
void populatePointwiseStableHloToLinalgConversionPatterns(
MLIRContext* context, TypeConverter& typeConverter,
RewritePatternSet* patterns, bool enablePrimitiveOps);

/// Populates the patterns that convert from dot product StableHLO ops to Linalg
/// on tensors.
void populateStableHloDotProdToLinalgConversionPatterns(
MLIRContext* context, TypeConverter& typeConverter,
RewritePatternSet* patterns);

} // namespace detail

} // namespace mlir::iree_compiler::stablehlo
Expand Down
Loading

0 comments on commit 4cdf0cc

Please sign in to comment.