diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 4f81a3874650d..3f3ec7b59eb3d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -299,6 +299,7 @@ class UnaryFn:
     square = UnaryFnType("square")
     tanh = UnaryFnType("tanh")
     erf = UnaryFnType("erf")
+    conj = UnaryFnType("conj")
 
 
 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..10f1083b11758 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -468,16 +468,22 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
     def _unary_exp(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.ExpOp(x).result
+        if _is_complex_type(x.type):
+            return complex.ExpOp(x).result
         raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
     def _unary_log(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.LogOp(x).result
+        if _is_complex_type(x.type):
+            return complex.LogOp(x).result
         raise NotImplementedError("Unsupported 'log' operand: {x}")
 
     def _unary_abs(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.AbsFOp(x).result
+        if _is_complex_type(x.type):
+            return complex.AbsOp(x).result
         raise NotImplementedError("Unsupported 'abs' operand: {x}")
 
     def _unary_ceil(self, x: Value) -> Value:
@@ -497,6 +503,11 @@ def _unary_negf(self, x: Value) -> Value:
             return complex.NegOp(x).result
         raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
+    def _unary_conj(self, x: Value) -> Value:
+        if _is_complex_type(x.type):
+            return complex.ConjOp(x).result
+        raise NotImplementedError("Unsupported 'conj' operand: {x}")
+
     def _binary_add(self, lhs: Value, rhs: Value) -> Value:
         if _is_floating_point_type(lhs.type):
             return arith.AddFOp(lhs, rhs).result
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index f8e034fb0e48b..d23c48daebad7 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -30,7 +30,7 @@ def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
 
 
 @linalg_structured_op
-def elemwise_unary_poly(
+def elemwise_unary_poly_cast(
     I=TensorDef(T),
     O=TensorDef(U, output=True),
     fun=UnaryFnAttrDef(default=UnaryFn.exp),
@@ -39,6 +39,14 @@ def elemwise_unary_poly(
     O[None] = fun(cast(U, I[None]))
 
 
+@linalg_structured_op
+def elemwise_unary_poly(
+    I=TensorDef(T),
+    O=TensorDef(U, output=True),
+    fun=UnaryFnAttrDef(default=UnaryFn.exp),
+):
+    O[None] = fun(I[None])
+
 @linalg_structured_op(op_name="custom_op_name")
 def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
     O[D.n] = I[D.n]
@@ -84,6 +92,17 @@ def test_i32_index(init_result):
         def test_f32_elemwise_exp(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
 
+        # CHECK-LABEL: @test_c32_elemwise_exp
+        # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT: %[[EXP:.+]] = complex.exp %[[IN]] : complex<f32>
+        # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_exp(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
         # CHECK-LABEL: @test_f32_elemwise_log
         # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
         # CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32
@@ -95,10 +114,21 @@ def test_f32_elemwise_exp(input, init_result):
         def test_f32_elemwise_log(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
 
+        # CHECK-LABEL: @test_c32_elemwise_log
+        # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT: %[[LOG:.+]] = complex.log %[[IN]] : complex<f32>
+        # CHECK-NEXT: linalg.yield %[[LOG]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_log(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
         # CHECK-LABEL: @test_f32_elemwise_abs
         # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT: %[[EXP:.+]] = math.absf %[[IN]] : f32
-        # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: %[[ABS:.+]] = math.absf %[[IN]] : f32
+        # CHECK-NEXT: linalg.yield %[[ABS]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -106,10 +136,21 @@ def test_f32_elemwise_log(input, init_result):
         def test_f32_elemwise_abs(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
 
+        # CHECK-LABEL: @test_c32_elemwise_abs
+        # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: f32)
+        # CHECK-NEXT: %[[ABS:.+]] = complex.abs %[[IN]] : complex<f32>
+        # CHECK-NEXT: linalg.yield %[[ABS]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_c32_elemwise_abs(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
         # CHECK-LABEL: @test_f32_elemwise_ceil
         # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT: %[[EXP:.+]] = math.ceil %[[IN]] : f32
-        # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: %[[CEIL:.+]] = math.ceil %[[IN]] : f32
+        # CHECK-NEXT: linalg.yield %[[CEIL]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -119,8 +160,8 @@ def test_f32_elemwise_ceil(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_floor
         # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT: %[[EXP:.+]] = math.floor %[[IN]] : f32
-        # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: %[[FLOOR:.+]] = math.floor %[[IN]] : f32
+        # CHECK-NEXT: linalg.yield %[[FLOOR]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -130,8 +171,8 @@ def test_f32_elemwise_floor(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_neg
         # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT: %[[EXP:.+]] = arith.negf %[[IN]] : f32
-        # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: %[[NEG:.+]] = arith.negf %[[IN]] : f32
+        # CHECK-NEXT: linalg.yield %[[NEG]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -141,8 +182,8 @@ def test_f32_elemwise_neg(input, init_result):
 
         # CHECK-LABEL: @test_c32_elemwise_neg
         # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
-        # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
-        # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT: %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
+        # CHECK-NEXT: linalg.yield %[[NEG]] : complex<f32>
         # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
@@ -150,6 +191,19 @@ def test_f32_elemwise_neg(input, init_result):
         def test_c32_elemwise_neg(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
 
+        # CHECK-LABEL: @test_c32_elemwise_conj
+        # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT: %[[CONJ:.+]] = complex.conj %[[IN]] : complex<f32>
+        # CHECK-NEXT: linalg.yield %[[CONJ]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_conj(input, init_result):
+            return elemwise_unary_poly(
+                input, outs=[init_result], fun=UnaryFn.conj, cast=None
+            )
+
         # Just check that we don't assert out on name mismatch.
         # CHECK-LABEL: @test_non_default_op_name
         @func.FuncOp.from_py_func(