Skip to content

Commit

Permalink
Make std.divis and std.diviu support ElementsAttr folding.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 282434465
  • Loading branch information
benvanik authored and tensorflower-gardener committed Nov 25, 2019
1 parent f87b2fd commit 38d7870
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 33 deletions.
44 changes: 20 additions & 24 deletions mlir/lib/Dialect/StandardOps/Ops.cpp
Expand Up @@ -1391,19 +1391,16 @@ void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary operation takes two operands");

auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
return {};

// Don't fold if it requires division by zero.
if (rhs.getValue().isNullValue())
return {};

// Don't fold if it would overflow.
bool overflow;
auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow);
return overflow ? IntegerAttr() : IntegerAttr::get(lhs.getType(), result);
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
if (overflowOrDiv0 || !b) {
overflowOrDiv0 = true;
return a;
}
return a.sdiv_ov(b, overflowOrDiv0);
});
return overflowOrDiv0 ? Attribute() : result;
}

//===----------------------------------------------------------------------===//
Expand All @@ -1413,17 +1410,16 @@ OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary operation takes two operands");

auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
return {};

// Don't fold if it requires division by zero.
auto rhsValue = rhs.getValue();
if (rhsValue.isNullValue())
return {};

return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhsValue));
// Don't fold if it would require a division by zero.
bool div0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
if (div0 || !b) {
div0 = true;
return a;
}
return a.udiv(b);
});
return div0 ? Attribute() : result;
}

// ---------------------------------------------------------------------------
Expand Down
80 changes: 71 additions & 9 deletions mlir/test/Transforms/constant-fold.mlir
Expand Up @@ -213,40 +213,102 @@ func @mulf_splat_tensor() -> tensor<4xf32> {
// -----

// CHECK-LABEL: func @simple_divis
func @simple_divis() -> (i32, i32) {
func @simple_divis() -> (i32, i32, i32) {
// CHECK-DAG: [[C0:%.+]] = constant 0
%z = constant 0 : i32
// CHECK-DAG: [[C6:%.+]] = constant 6
%0 = constant 6 : i32
%1 = constant 2 : i32

// CHECK-NEXT:[[C3:%.+]] = constant 3 : i32
// CHECK-NEXT: [[C3:%.+]] = constant 3 : i32
%2 = divis %0, %1 : i32

%3 = constant -2 : i32

// CHECK-NEXT: [[CM3:%.+]] = constant -3 : i32
%4 = divis %0, %3 : i32

// CHECK-NEXT: return [[C3]], [[CM3]]
return %2, %4 : i32, i32
// CHECK-NEXT: [[XZ:%.+]] = divis [[C6]], [[C0]]
%5 = divis %0, %z : i32

// CHECK-NEXT: return [[C3]], [[CM3]], [[XZ]]
return %2, %4, %5 : i32, i32, i32
}

// -----

// CHECK-LABEL: func @divis_splat_tensor
func @divis_splat_tensor() -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
// CHECK-DAG: [[C0:%.+]] = constant dense<0>
%z = constant dense<0> : tensor<4xi32>
// CHECK-DAG: [[C6:%.+]] = constant dense<6>
%0 = constant dense<6> : tensor<4xi32>
%1 = constant dense<2> : tensor<4xi32>

// CHECK-NEXT: [[C3:%.+]] = constant dense<3> : tensor<4xi32>
%2 = divis %0, %1 : tensor<4xi32>

%3 = constant dense<-2> : tensor<4xi32>

// CHECK-NEXT: [[CM3:%.+]] = constant dense<-3> : tensor<4xi32>
%4 = divis %0, %3 : tensor<4xi32>

// CHECK-NEXT: [[XZ:%.+]] = divis [[C6]], [[C0]]
%5 = divis %0, %z : tensor<4xi32>

// CHECK-NEXT: return [[C3]], [[CM3]], [[XZ]]
return %2, %4, %5 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}

// -----

// CHECK-LABEL: func @simple_diviu
func @simple_diviu() -> (i32, i32) {
func @simple_diviu() -> (i32, i32, i32) {
%z = constant 0 : i32
// CHECK-DAG: [[C6:%.+]] = constant 6
%0 = constant 6 : i32
%1 = constant 2 : i32

// CHECK-NEXT:[[C3:%.+]] = constant 3 : i32
// CHECK-DAG: [[C3:%.+]] = constant 3 : i32
%2 = diviu %0, %1 : i32

%3 = constant -2 : i32

// Unsigned division interprets -2 as 2^32-2, so the result is 0.
// CHECK-NEXT:[[C0:%.+]] = constant 0 : i32
// CHECK-DAG: [[C0:%.+]] = constant 0 : i32
%4 = diviu %0, %3 : i32

// CHECK-NEXT: return [[C3]], [[C0]]
return %2, %4 : i32, i32
// CHECK-NEXT: [[XZ:%.+]] = diviu [[C6]], [[C0]]
%5 = diviu %0, %z : i32

// CHECK-NEXT: return [[C3]], [[C0]], [[XZ]]
return %2, %4, %5 : i32, i32, i32
}


// -----

// CHECK-LABEL: func @diviu_splat_tensor
func @diviu_splat_tensor() -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%z = constant dense<0> : tensor<4xi32>
// CHECK-DAG: [[C6:%.+]] = constant dense<6>
%0 = constant dense<6> : tensor<4xi32>
%1 = constant dense<2> : tensor<4xi32>

// CHECK-DAG: [[C3:%.+]] = constant dense<3> : tensor<4xi32>
%2 = diviu %0, %1 : tensor<4xi32>

%3 = constant dense<-2> : tensor<4xi32>

// Unsigned division interprets -2 as 2^32-2, so the result is 0.
// CHECK-DAG: [[C0:%.+]] = constant dense<0> : tensor<4xi32>
%4 = diviu %0, %3 : tensor<4xi32>

// CHECK-NEXT: [[XZ:%.+]] = diviu [[C6]], [[C0]]
%5 = diviu %0, %z : tensor<4xi32>

// CHECK-NEXT: return [[C3]], [[C0]], [[XZ]]
return %2, %4, %5 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}

// -----
Expand Down

0 comments on commit 38d7870

Please sign in to comment.