Make shape.is_broadcastable/shape.cstr_broadcastable nary
This corresponds with the previous work that made shape.broadcast nary.
Additionally, simplify the ConvertShapeConstraints pass: it no longer
lowers the implied shape.is_broadcastable check itself. The combined
result with shape-to-standard is unchanged, whichever order the two
passes run in.

Differential Revision: https://reviews.llvm.org/D96401
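For reference, a minimal sketch of the new nary assembly forms (value names and extent-tensor types are illustrative; the syntax follows the `$shapes attr-dict `:` type($shapes)` assembly formats declared below):

```mlir
// is_broadcastable now takes two or more shapes or extent tensors.
%ok = shape.is_broadcastable %a, %b, %c
    : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>

// cstr_broadcastable likewise yields a witness over 2+ shapes.
%w = shape.cstr_broadcastable %a, %b, %c
    : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
```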
tpopp committed Feb 15, 2021
1 parent e8b9da7 commit 3842d4b
Showing 9 changed files with 334 additions and 218 deletions.
63 changes: 49 additions & 14 deletions mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -190,11 +190,12 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
let assemblyFormat = "$input attr-dict `:` type($input)";
}

def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
let summary = "Determines if 2 shapes can be successfully broadcasted";
def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
[Commutative, InferTypeOpInterface]> {
let summary = "Determines if 2+ shapes can be successfully broadcasted";
let description = [{
Given two input shapes or extent tensors, return a predicate specifying if
they are broadcastable. This broadcastable follows the same logic as what
Given multiple input shapes or extent tensors, return a predicate specifying
if they are broadcastable. This broadcastable follows the same logic as what
shape.broadcast documents.

Concretely, shape.is_broadcastable returning true implies that
@@ -209,11 +210,28 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
```
}];

let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
Shape_ShapeOrExtentTensorType:$rhs);
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
let results = (outs I1:$result);

let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
let builders = [
OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
[{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
];
let extraClassDeclaration = [{
// TODO: This should really be automatic. Figure out how to not need this defined.
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
/*width=*/1));
return success();
};
}];

let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
let verifier = [{ return ::verify(*this); }];

}

def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
@@ -692,11 +710,12 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
let summary = "Determines if 2 shapes can be successfully broadcasted";
def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable",
[Commutative, InferTypeOpInterface]> {
let summary = "Determines if 2+ shapes can be successfully broadcasted";
let description = [{
Given two input shapes or extent tensors, return a witness specifying if
they are broadcastable. This broadcastable follows the same logic as what
Given input shapes or extent tensors, return a witness specifying if they
are broadcastable. This broadcastable follows the same logic as what
shape.broadcast documents.

"cstr" operations represent runtime assertions.
@@ -708,14 +727,30 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
```
}];

let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
Shape_ShapeOrExtentTensorType:$rhs);
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
let results = (outs Shape_WitnessType:$result);

let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";

let builders = [
OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
[{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
];

let extraClassDeclaration = [{
// TODO: This should really be automatic. Figure out how to not need this defined.
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
return success();
};
}];

let hasCanonicalizer = 1;
let hasFolder = 1;
let verifier = [{ return ::verify(*this); }];
}

def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
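The binary OpBuilderDAG shims above keep existing two-operand C++ call sites working by forwarding {lhs, rhs} into the variadic build method, while InferTypeOpInterface supplies the result type. A hypothetical usage sketch (builder, location, and value names are illustrative and assume the usual ODS-generated builders):

```cpp
// Existing two-operand call sites keep compiling via the compatibility builder.
Value pred2 = builder.create<shape::IsBroadcastableOp>(loc, lhs, rhs);

// New call sites can pass any number (>= 2) of shapes or extent tensors;
// the i1 / witness result type is inferred, so no result type is spelled out.
Value predN = builder.create<shape::IsBroadcastableOp>(loc, ValueRange{a, b, c});
Value witN  = builder.create<shape::CstrBroadcastableOp>(loc, ValueRange{a, b, c});
```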
73 changes: 2 additions & 71 deletions mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -19,77 +19,8 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
class ConvertCstrBroadcastableOp
: public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
PatternRewriter &rewriter) const override {
if (op.getType().isa<shape::ShapeType>() ||
op.lhs().getType().isa<shape::ShapeType>() ||
op.rhs().getType().isa<shape::ShapeType>()) {
return rewriter.notifyMatchFailure(
op, "cannot convert error-propagating shapes");
}

auto loc = op.getLoc();
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);

// Find smaller and greater rank and extent tensor.
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
Value lhsRankULE =
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
Type indexTy = rewriter.getIndexType();
Value lesserRank =
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
Value greaterRank =
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
Value greaterRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());

Value rankDiff =
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);

// Generate code to compare the shapes extent by extent, and emit errors for
// non-broadcast-compatible shapes.
// Two extents are broadcast-compatible if
// 1. they are both equal, or
// 2. at least one of them is 1.

rewriter.create<scf::ForOp>(
loc, rankDiff, greaterRank, one, llvm::None,
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
loc, greaterRankOperand, ValueRange{iv});
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
loc, lesserRankOperand, ValueRange{ivShifted});

Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
Value extentsAgree =
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
lesserRankOperandExtent);
auto broadcastIsValid =
b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
lesserRankOperandExtentIsOne));
b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
b.create<scf::YieldOp>(loc);
});

rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}
};
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
@@ -107,7 +38,7 @@ class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {

void mlir::populateConvertShapeConstraintsConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ConvertCstrBroadcastableOp>(ctx);
patterns.insert<CstrBroadcastableToRequire>(ctx);
patterns.insert<ConvertCstrRequireOp>(ctx);
}

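After this change the pass only rewrites the constraint into its predicate form via the CstrBroadcastableToRequire pattern (defined in ShapeToStandard.td below), leaving the is_broadcastable itself for shape-to-standard to lower. Roughly, with illustrative value names and tensor<?xindex> operands:

```mlir
// Input:
%w = shape.cstr_broadcastable %a, %b, %c
    : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>

// After convert-shape-constraints:
%pred = shape.is_broadcastable %a, %b, %c
    : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
%w = shape.cstr_require %pred, "required broadcastable shapes"
```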
119 changes: 70 additions & 49 deletions mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -237,63 +237,84 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// For now, this lowering is only defined on `tensor<?xindex>` operands, not
// on shapes.
IsBroadcastableOp::Adaptor transformed(operands);
if (transformed.lhs().getType().isa<ShapeType>() ||
transformed.rhs().getType().isa<ShapeType>())
if (!llvm::all_of(op.shapes(),
[](Value v) { return !v.getType().isa<ShapeType>(); }))
return failure();

auto loc = op.getLoc();
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
ImplicitLocOpBuilder lb(loc, rewriter);
Value zero = lb.create<ConstantIndexOp>(0);
Value one = lb.create<ConstantIndexOp>(1);
Type indexTy = lb.getIndexType();

// Save all the ranks for bounds checking. Because this is a tensor
// representing the shape extents, the rank is the extent of the only
// dimension in the tensor.
SmallVector<Value> ranks, rankDiffs;
llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
return lb.create<DimOp>(v, zero);
}));

// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
}

// Calculate the difference of ranks and the maximum rank for later offsets.
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
return lb.create<SubIOp>(indexTy, maxRank, v);
}));

// Find smaller and greater rank and extent tensor.
Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
Value lhsRankULE =
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
Type indexTy = rewriter.getIndexType();
Value lesserRank =
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
Value greaterRank =
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
Value rankDiff =
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
Type i1Ty = rewriter.getI1Type();
Value init =
Value trueVal =
rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

// Determine if all overlapping extents are broadcastable.
auto reduceResult = rewriter.create<ForOp>(
loc, rankDiff, greaterRank, one, ValueRange{init},
auto reduceResult = lb.create<ForOp>(
loc, zero, maxRank, one, ValueRange{trueVal},
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
loc, greaterRankOperand, ValueRange{iv});
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
loc, lesserRankOperand, ValueRange{ivShifted});
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
Value extentsAreEqual =
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
lesserRankOperandExtent);
Value broadcastableExtents = b.create<AndOp>(
loc, iterArgs[0],
b.create<OrOp>(loc,
b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
lesserRankOperandExtentIsOne),
extentsAreEqual));
b.create<scf::YieldOp>(loc, broadcastableExtents);
// Find a non-1 dim, if it exists. Note that the first part of this
// could reuse the Broadcast lowering entirely, but we redo the work
// here to make optimizations easier between the two loops.
Value broadcastedDim = getBroadcastedDim(
ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);

Value broadcastable = iterArgs[0];
for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
Value shape, rankDiff;
std::tie(shape, rankDiff) = tup;
Value outOfBounds =
b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
broadcastable =
b.create<IfOp>(
loc, TypeRange{i1Ty}, outOfBounds,
[&](OpBuilder &b, Location loc) {
// Non existent dimensions are always broadcastable
b.create<scf::YieldOp>(loc, broadcastable);
},
[&](OpBuilder &b, Location loc) {
// Every value needs to be either 1, or the same non-1
// value to be broadcastable in this dim.
Value operandDimension =
b.create<SubIOp>(loc, indexTy, iv, rankDiff);
Value dimensionExtent = b.create<tensor::ExtractOp>(
loc, shape, ValueRange{operandDimension});

Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
dimensionExtent, one);
Value equalBroadcasted =
b.create<CmpIOp>(loc, CmpIPredicate::eq,
dimensionExtent, broadcastedDim);
Value result = b.create<AndOp>(
loc, broadcastable,
b.create<OrOp>(loc, equalOne, equalBroadcasted));
b.create<scf::YieldOp>(loc, result);
})
.getResult(0);
}

b.create<scf::YieldOp>(loc, broadcastable);
});

rewriter.replaceOp(op, reduceResult.results().front());
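The rewritten lowering no longer selects a lesser/greater operand pair; it iterates over every dimension up to the maximum rank and checks each shape's extent against the broadcasted extent for that dimension. A scalar C++ sketch of the check the emitted IR performs, using plain std::vector stand-ins (illustrative only, not part of the commit):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Returns true if all shapes (innermost extent last) can be broadcast together.
bool isBroadcastable(const std::vector<std::vector<int64_t>> &shapes) {
  size_t maxRank = 0;
  for (const auto &s : shapes)
    maxRank = std::max(maxRank, s.size());
  for (size_t dim = 0; dim < maxRank; ++dim) {
    int64_t broadcasted = 1; // the non-1 extent seen so far in this dimension
    for (const auto &s : shapes) {
      size_t rankDiff = maxRank - s.size();
      if (dim < rankDiff)
        continue; // this shape has no extent here; missing dims broadcast
      int64_t extent = s[dim - rankDiff];
      if (extent == 1)
        continue; // an extent of 1 is compatible with anything
      if (broadcasted == 1)
        broadcasted = extent; // first non-1 extent fixes the broadcast target
      else if (extent != broadcasted)
        return false; // two different non-1 extents cannot broadcast
    }
  }
  return true;
}
```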
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
@@ -19,9 +19,9 @@ def BroadcastableStringAttr : NativeCodeCall<[{
$_builder.getStringAttr("required broadcastable shapes")
}]>;

def : Pat<(Shape_CstrBroadcastableOp $LHS, $RHS),
def CstrBroadcastableToRequire : Pat<(Shape_CstrBroadcastableOp $shapes),
(Shape_CstrRequireOp
(Shape_IsBroadcastableOp $LHS, $RHS),
(Shape_IsBroadcastableOp $shapes),
(BroadcastableStringAttr))>;

#endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD
26 changes: 24 additions & 2 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -491,6 +491,10 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
}

OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
// TODO: Add folding for the nary case
if (operands.size() != 2)
return nullptr;

// Both operands are not needed if one is a scalar.
if (operands[0] &&
operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
@@ -512,9 +516,9 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
// Lastly, see if folding can be completed based on what constraints are known
// on the input shapes.
SmallVector<int64_t, 6> lhsShape, rhsShape;
if (failed(getShapeVec(lhs(), lhsShape)))
if (failed(getShapeVec(shapes()[0], lhsShape)))
return nullptr;
if (failed(getShapeVec(rhs(), rhsShape)))
if (failed(getShapeVec(shapes()[1], rhsShape)))
return nullptr;

if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
@@ -525,6 +529,13 @@ return nullptr;
return nullptr;
}

static LogicalResult verify(CstrBroadcastableOp op) {
// Ensure that AssumingAllOp contains at least one operand
if (op.getNumOperands() < 2)
return op.emitOpError("required at least 2 input shapes");
return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//
@@ -723,6 +734,17 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
}
}

//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(IsBroadcastableOp op) {
// Ensure that AssumingAllOp contains at least one operand
if (op.getNumOperands() < 2)
return op.emitOpError("required at least 2 input shapes");
return success();
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
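The fold still handles only the two-operand case (the nary case is left as a TODO above), and the new verifiers reject ops built with fewer than two shapes. When the operand shapes are statically known to broadcast, canonicalization collapses the constraint to a constant witness; an illustrative example (shape literals chosen arbitrarily, syntax as of this revision):

```mlir
// Input:
%s0 = shape.const_shape [8, 1] : !shape.shape
%s1 = shape.const_shape [1, 8] : !shape.shape
%w = shape.cstr_broadcastable %s0, %s1 : !shape.shape, !shape.shape

// After canonicalization, %w folds to:
%w = shape.const_witness true
```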
