diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index 00ab3ed76e278..5f7940b9da28b 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -43,6 +43,11 @@ func.func @rsqrt(%arg: complex) -> complex { func.return %sqrt : complex } +func.func @conj(%arg: complex) -> complex { + %conj = complex.conj %arg : complex + func.return %conj : complex +} + // %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...] func.func @test_binary(%input: tensor>, %func: (complex, complex) -> complex) { @@ -216,5 +221,36 @@ func.func @entry() { call @test_unary(%rsqrt_test_cast, %rsqrt_func) : (tensor>, (complex) -> complex) -> () + // complex.conj test + %conj_test = arith.constant dense<[ + (-1.0, -1.0), + // CHECK: -1.0 + // CHECK-NEXT: 1.0 + (-1.0, 1.0), + // CHECK-NEXT: -1.0 + // CHECK-NEXT: -1.0 + (0.0, 0.0), + // CHECK-NEXT: 0 + // CHECK-NEXT: 0 + (0.0, 1.0), + // CHECK-NEXT: 0 + // CHECK-NEXT: -1.0 + (1.0, -1.0), + // CHECK-NEXT: 1.0 + // CHECK-NEXT: -1.0 + (1.0, 0.0), + // CHECK-NEXT: 1.0 + // CHECK-NEXT: 0 + (1.0, 1.0) + // CHECK-NEXT: 1.0 + // CHECK-NEXT: -1.0 + ]> : tensor<7xcomplex> + %conj_test_cast = tensor.cast %conj_test + : tensor<7xcomplex> to tensor> + + %conj_func = func.constant @conj : (complex) -> complex + call @test_unary(%conj_test_cast, %conj_func) + : (tensor>, (complex) -> complex) -> () + func.return }