Allow dynamic but ranked types in ops with SameOperandsAndResultShape and SameOperandsAndResultType traits

Currently, the SameOperandsAndResultShape trait allows operands of type tensor<*xf32> and tensor<2xf32> to be mixed, but rejects tensor<?xf32> together with tensor<10xf32>. With this change, shapes are considered compatible when, for each dimension, at least one side is dynamic or both sides are equal.

Also, use the new verifyCompatibleShape helper in the TensorCastOp::areCastCompatible method.
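For illustration, here is a minimal C++ sketch (not part of this change) exercising the new helper on the type pairs mentioned above. The function name shapeCompatibilityExamples is hypothetical, the include paths assume this revision's source layout, and -1 is used as the dynamic-dimension sentinel as in the existing code:

#include <cassert>

#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Hypothetical driver: checks which tensor types the updated helper treats as
// shape-compatible.
void shapeCompatibilityExamples(MLIRContext *ctx) {
  Builder b(ctx);
  Type f32 = b.getF32Type();

  Type unranked = UnrankedTensorType::get(f32);     // tensor<*xf32>
  Type two = RankedTensorType::get({2}, f32);       // tensor<2xf32>
  Type dynamic = RankedTensorType::get({-1}, f32);  // tensor<?xf32>
  Type ten = RankedTensorType::get({10}, f32);      // tensor<10xf32>

  // Unranked vs. ranked was already accepted before this change.
  assert(succeeded(verifyCompatibleShape(unranked, two)));

  // New: a dynamic dimension is compatible with any static size.
  assert(succeeded(verifyCompatibleShape(dynamic, ten)));

  // Two different static sizes are still incompatible.
  assert(failed(verifyCompatibleShape(two, ten)));

  // TensorCastOp::areCastCompatible now delegates to the same helper, so a
  // cast between tensor<?xf32> and tensor<10xf32> is accepted as well.
  assert(TensorCastOp::areCastCompatible(dynamic, ten));
}

This mirrors the relaxed verifier behavior exercised by the new cases in traits.mlir below.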

PiperOrigin-RevId: 273658336
smit-hinsu authored and tensorflower-gardener committed Oct 9, 2019
1 parent b3a6ae8 commit 85b4631
Showing 7 changed files with 73 additions and 52 deletions.
7 changes: 7 additions & 0 deletions mlir/include/mlir/IR/TypeUtilities.h
@@ -52,6 +52,13 @@ SmallVector<Type, 10> getFlattenedTypes(TupleType t);
/// dialect and typeData.
bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData);

/// Returns success if the given two types have compatible shape. That is,
/// they are both scalars (not shaped), or they are both shaped types and at
/// least one is unranked or they have compatible dimensions. Dimensions are
/// compatible if at least one is dynamic or both are equal. The element type
/// does not matter.
LogicalResult verifyCompatibleShape(Type type1, Type type2);

//===----------------------------------------------------------------------===//
// Utility Iterators
//===----------------------------------------------------------------------===//
19 changes: 1 addition & 18 deletions mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2215,24 +2215,7 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) {
if (aT.getElementType() != bT.getElementType())
return false;

// If the either are unranked, then the cast is valid.
auto aRType = aT.dyn_cast<RankedTensorType>();
auto bRType = bT.dyn_cast<RankedTensorType>();
if (!aRType || !bRType)
return true;

// If they are both ranked, they have to have the same rank, and any specified
// dimensions must match.
if (aRType.getRank() != bRType.getRank())
return false;

for (unsigned i = 0, e = aRType.getRank(); i != e; ++i) {
int64_t aDim = aRType.getDimSize(i), bDim = bRType.getDimSize(i);
if (aDim != -1 && bDim != -1 && aDim != bDim)
return false;
}

return true;
return succeeded(verifyCompatibleShape(aT, bT));
}

OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
33 changes: 8 additions & 25 deletions mlir/lib/IR/Operation.cpp
@@ -748,33 +748,13 @@ LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
return success();
}

/// Returns success if the given two types have the same shape. That is,
/// they are both scalars (not shaped), or they are both shaped types and at
/// least one is unranked or they have the same shape. The element type does not
/// matter.
static LogicalResult verifyShapeMatch(Type type1, Type type2) {
auto sType1 = type1.dyn_cast<ShapedType>();
auto sType2 = type2.dyn_cast<ShapedType>();

// Either both or neither type should be shaped.
if (!sType1)
return success(!sType2);
if (!sType2)
return failure();

if (!sType1.hasRank() || !sType2.hasRank())
return success();

return success(sType1.getShape() == sType2.getShape());
}

LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)))
return failure();

auto type = op->getOperand(0)->getType();
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
if (failed(verifyShapeMatch(opType, type)))
if (failed(verifyCompatibleShape(opType, type)))
return op->emitOpError() << "requires the same shape for all operands";
}
return success();
@@ -787,12 +767,12 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {

auto type = op->getOperand(0)->getType();
for (auto resultType : op->getResultTypes()) {
if (failed(verifyShapeMatch(resultType, type)))
if (failed(verifyCompatibleShape(resultType, type)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
}
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
if (failed(verifyShapeMatch(opType, type)))
if (failed(verifyCompatibleShape(opType, type)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
}
@@ -843,13 +823,16 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
return failure();

auto type = op->getResult(0)->getType();
auto elementType = getElementTypeOrSelf(type);
for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
if (resultType != type)
if (getElementTypeOrSelf(resultType) != elementType ||
failed(verifyCompatibleShape(resultType, type)))
return op->emitOpError()
<< "requires the same type for all operands and results";
}
for (auto opType : op->getOperandTypes()) {
if (opType != type)
if (getElementTypeOrSelf(opType) != elementType ||
failed(verifyCompatibleShape(opType, type)))
return op->emitOpError()
<< "requires the same type for all operands and results";
}
31 changes: 31 additions & 0 deletions mlir/lib/IR/TypeUtilities.cpp
@@ -61,6 +61,37 @@ bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
return false;
}

/// Returns success if the given two types have compatible shape. That is,
/// they are both scalars (not shaped), or they are both shaped types and at
/// least one is unranked or they have compatible dimensions. Dimensions are
/// compatible if at least one is dynamic or both are equal. The element type
/// does not matter.
LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
auto sType1 = type1.dyn_cast<ShapedType>();
auto sType2 = type2.dyn_cast<ShapedType>();

// Either both or neither type should be shaped.
if (!sType1)
return success(!sType2);
if (!sType2)
return failure();

if (!sType1.hasRank() || !sType2.hasRank())
return success();

if (sType1.getRank() != sType2.getRank())
return failure();

for (const auto &dims : llvm::zip(sType1.getShape(), sType2.getShape())) {
int64_t dim1 = std::get<0>(dims);
int64_t dim2 = std::get<1>(dims);
if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
dim1 != dim2)
return failure();
}
return success();
}

OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
: llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}

8 changes: 0 additions & 8 deletions mlir/test/IR/invalid-ops.mlir
@@ -297,14 +297,6 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {

// -----

func @func_with_ops(tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) {
^bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
// expected-error@+1 {{requires the same shape for all operands and results}}
%r = "std.select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
}

// -----

func @test_vector.transfer_read(memref<?x?xf32>) {
^bb0(%arg0: memref<?x?xf32>):
%c3 = constant 3 : index
21 changes: 20 additions & 1 deletion mlir/test/IR/traits.mlir
@@ -113,10 +113,11 @@ func @failedSameOperandShape_no_operands() {
// -----

// CHECK: succeededSameOperandAndResultShape
func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>, %t1d: tensor<?xf32>) {
%0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "test.same_operand_and_result_shape"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = "test.same_operand_and_result_shape"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
%3 = "test.same_operand_and_result_shape"(%t1, %t1d) : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
return
}

@@ -143,6 +144,24 @@ func @failedSameOperandAndResultShape_no_operands(%t1: tensor<1xf32>) {

// -----

// CHECK: succeededSameOperandAndResultType
func @succeededSameOperandAndResultType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>, %t1d: tensor<?xf32>) {
%0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "test.same_operand_and_result_type"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = "test.same_operand_and_result_type"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
%3 = "test.same_operand_and_result_type"(%t1, %t1d) : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
return
}

// -----

func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<20xf32>) {
// expected-error@+1 {{requires the same type for all operands and results}}
%0 = "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xf32>
}

// -----

func @failedHasParent_wrong_parent() {
"some.op"() ({
// expected-error@+1 {{'test.child' op expects parent op 'test.parent'}}
6 changes: 6 additions & 0 deletions mlir/test/lib/TestDialect/TestOps.td
@@ -257,6 +257,12 @@ def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
let results = (outs Variadic<AnyVectorOrTensor>);
}

def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
[SameOperandsAndResultType]> {
let arguments = (ins Variadic<AnyVectorOrTensor>);
let results = (outs Variadic<AnyVectorOrTensor>);
}

def ArgAndResHaveFixedElementTypesOp :
TEST_Op<"arg_and_res_have_fixed_element_types",
[PredOpTrait<"fixed type combination",
