Skip to content

Commit

Permalink
Add interpreter for RoundNearestEvenOp (#1423)
Browse files Browse the repository at this point in the history
Here are the constraints for the RoundNearestEvenOp:
```
(I1) operand is a tensor of floating-point type.
(C1) `operand` and `result` have the same type.
```
I1 and C1 are covered by the ODS, so no additional tests are added.

closes #1111
  • Loading branch information
ghpvnist committed Apr 19, 2023
1 parent 30fbc8c commit f2666c0
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 5 deletions.
4 changes: 3 additions & 1 deletion docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -4520,10 +4520,12 @@ specification.

```mlir
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
```

&nbsp;[More Examples](../stablehlo/tests/interpret_round_nearest_even.mlir)

### rsqrt

#### Semantics
Expand Down
2 changes: 1 addition & 1 deletion docs/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ one of the following tracking labels.
| rng | yes | yes | yes | yes | no |
| rng_bit_generator | yes | revisit | infeasible | yes | no |
| round_nearest_afz | yes | yes | yes | yes | no |
| round_nearest_even | yes | yes | yes | yes | no |
| round_nearest_even | yes | yes | yes | yes | yes |
| rsqrt | yes | yes | yes | yes | yes |
| scatter | yes | revisit | yes | no | no |
| select | yes | yes | yes | yes | yes |
Expand Down
5 changes: 3 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,8 @@ def StableHLO_RoundOp: StableHLO_UnaryElementwiseOp<"round_nearest_afz",
}

def StableHLO_RoundNearestEvenOp: StableHLO_UnaryElementwiseOp<"round_nearest_even",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
[Pure, HLO_CompatibleOperandsAndResultType /*round_nearest_even_c1*/],
HLO_FpTensor /*round_nearest_even_i1*/> { /*round_nearest_even_c1*/
let summary = "RoundNearestEven operation";
let description = [{
Performs element-wise rounding towards the nearest integer, breaking ties
Expand All @@ -551,7 +552,7 @@ def StableHLO_RoundNearestEvenOp: StableHLO_UnaryElementwiseOp<"round_nearest_ev

Example:
```mlir
%result = stablehlo.round_nearest_even %operand : tensor<5xf32>
%result = stablehlo.round_nearest_even %operand : tensor<5xf64>
```
}];
}
Expand Down
10 changes: 10 additions & 0 deletions stablehlo/reference/Element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,16 @@ Element rem(const Element &e1, const Element &e2) {
});
}

Element roundNearestEven(const Element &el) {
auto type = el.getType();
if (!isSupportedFloatType(type))
report_fatal_error(invalidArgument("Unsupported element type: %s",
debugString(type).c_str()));
auto val = el.getFloatValue();
val.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
return Element(type, val);
}

Element rsqrt(const Element &el) {
return mapWithUpcastToDouble(
el, [](double e) { return 1.0 / std::sqrt(e); },
Expand Down
4 changes: 4 additions & 0 deletions stablehlo/reference/Element.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ Element real(const Element &e);
/// Returns the remainder for two Element objects.
Element rem(const Element &e1, const Element &e2);

/// Returns the value rounded to nearest integer, breaking ties towards the
/// even, of Element object.
Element roundNearestEven(const Element &el);

/// Returns reverse square root of Element object.
Element rsqrt(const Element &e);

Expand Down
13 changes: 13 additions & 0 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ SmallVector<Tensor> eval(
Tensor runtimeResult =
evalReverseOp(runtimeOperand, dimensions, reverseOp.getType());
scope.add(op.getResults(), {runtimeResult});
} else if (auto roundNearestEvenOp = dyn_cast<RoundNearestEvenOp>(op)) {
Tensor runtimeOperand = scope.find(roundNearestEvenOp.getOperand());
Tensor runtimeResult =
evalRoundNearestEvenOp(runtimeOperand, roundNearestEvenOp.getType());
scope.add(op.getResults(), {runtimeResult});
} else if (auto rsqrtOp = dyn_cast<RsqrtOp>(op)) {
Tensor runtimeOperand = scope.find(rsqrtOp.getOperand());
Tensor runtimeResult = evalRsqrtOp(runtimeOperand, rsqrtOp.getType());
Expand Down Expand Up @@ -726,6 +731,14 @@ Tensor evalReverseOp(const Tensor &operand, Axes dimensions,
return result;
}

Tensor evalRoundNearestEvenOp(const Tensor &operand, ShapedType resultType) {
Tensor result(resultType);
for (auto resultIt = result.index_begin(); resultIt != result.index_end();
++resultIt)
result.set(*resultIt, roundNearestEven(operand.get(*resultIt)));
return result;
}

Tensor evalRsqrtOp(const Tensor &operand, ShapedType resultType) {
Tensor result(resultType);
for (auto resultIt = result.index_begin(); resultIt != result.index_end();
Expand Down
1 change: 1 addition & 0 deletions stablehlo/reference/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ Tensor evalRemOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType);
Tensor evalReshapeOp(const Tensor &operand, ShapedType resultType);
Tensor evalReverseOp(const Tensor &operand, Axes dimensions,
ShapedType resultType);
Tensor evalRoundNearestEvenOp(const Tensor &operand, ShapedType resultType);
Tensor evalRsqrtOp(const Tensor &operand, ShapedType resultType);
Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue,
const Tensor &onFalse, ShapedType resultType);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
8 changes: 8 additions & 0 deletions stablehlo/tests/interpret_round_nearest_even.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

func.func @round_nearest_even_op_test_f64() {
%operand = stablehlo.constant dense<[-2.5, 0.4, 0.5, 0.6, 2.5]> : tensor<5xf64>
%result = stablehlo.round_nearest_even %operand : tensor<5xf64>
check.expect_almost_eq_const %result, dense<[-2.0, 0.0, 0.0, 1.0, 2.0]> : tensor<5xf64>
func.return
}

0 comments on commit f2666c0

Please sign in to comment.