Skip to content

Commit

Permalink
[StableHLO] Add missing pointwise op tests (#13066)
Browse files Browse the repository at this point in the history
Add tests for `stablehlo.not` and `stablehlo.complex`. These were not
present in the equivalent mlir-hlo tests.

Issue: #12678
  • Loading branch information
kuhar committed Apr 13, 2023
1 parent d904673 commit e670ee5
Showing 1 changed file with 42 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1445,4 +1445,45 @@ func.func @reduce_precision(%arg0: tensor<1x2x3x4xf32>)
-> tensor<1x2x3x4xf32> {
%0 = "stablehlo.reduce_precision"(%arg0) {exponent_bits=3:i32, mantissa_bits=3:i32} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
return %0 : tensor<1x2x3x4xf32>
}
}

// -----

// CHECK-LABEL: func @integer_not
// CHECK-SAME: (%[[ARG:.+]]: tensor<2x2xi32>)
// CHECK-PRIMITIVE-LABEL: func @integer_not
// CHECK-PRIMITIVE-SAME: (%[[ARG:.+]]: tensor<2x2xi32>)
func.func @integer_not(%arg: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[CST_N1:.+]] = arith.constant -1 : i32
// CHECK: linalg.generic
// CHECK: (%[[IN:.+]]: i32, %{{.+}}: i32)
// CHECK: %[[V_NOT:.+]] = arith.xori %[[IN]], %[[CST_N1]] : i32
// CHECK: linalg.yield %[[V_NOT]] : i32
// CHECK-PRIMITIVE: %[[CST_N1:.+]] = arith.constant -1 : i32
// CHECK-PRIMITIVE: linalg.map
// CHECK-PRIMITIVE: (%[[IN:.+]]: i32)
// CHECK-PRIMITIVE: %[[V_NOT:.+]] = arith.xori %[[IN]], %[[CST_N1]] : i32
// CHECK-PRIMITIVE: linalg.yield %[[V_NOT]] : i32
%0 = "stablehlo.not"(%arg) : (tensor<2x2xi32>) -> tensor<2x2xi32>
func.return %0 : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @float_complex
// CHECK-SAME: (%[[LHS:.+]]: tensor<2x2xf32>, %[[RHS:.+]]: tensor<2x2xf32>)
// CHECK-PRIMITIVE-LABEL: func @float_complex
// CHECK-PRIMITIVE-SAME: (%[[LHS:.+]]: tensor<2x2xf32>, %[[RHS:.+]]: tensor<2x2xf32>)
func.func @float_complex(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>> {
// CHECK: %[[INIT]] = tensor.empty() : tensor<2x2xcomplex<f32>>
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK: (%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %{{.+}}: complex<f32>
// CHECK: %[[RES:.+]] = complex.create %[[IN0]], %[[IN1]] : complex<f32>
// CHECK: linalg.yield %[[RES]] : complex<f32>
// CHECK-PRIMITIVE: %[[INIT]] = tensor.empty() : tensor<2x2xcomplex<f32>>
// CHECK-PRIMITIVE: linalg.map { complex.create } ins(%[[LHS]], %[[RHS]] : tensor<2x2xf32>, tensor<2x2xf32>)
// CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor<2x2xcomplex<f32>>)
%0 = "stablehlo.complex"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
func.return %0 : tensor<2x2xcomplex<f32>>
}

0 comments on commit e670ee5

Please sign in to comment.