diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index a384273ba30eb..ad66d19cbf3e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -1885,6 +1885,33 @@ def ROCDL_FMed3Op : ROCDL_IntrOp<"fmed3", [0], [], [Pure, AllTypesMatch<["res",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Math operations
+//===----------------------------------------------------------------------===//
+
+class ROCDL_Math_IntrOp<string mnemonic, list<Trait> traits = [Pure]> :
+  ROCDL_IntrOp<mnemonic, [0], [], traits, 1>,
+  Arguments<(ins LLVM_AnyFloat:$arg)> {
+  let results = (outs LLVM_AnyFloat:$res);
+  let description = [{
+    Note: In the general case, prefer the conventional `arith`, `math`, or `llvm` ops over this.
+    Use this ROCDL-specific operation only when you fully understand its implication and
+    when it is strictly necessary. This op is usually chosen when a small loss in precision is
+    acceptable in exchange for higher execution speed.
+  }];
+  let assemblyFormat =
+    "$arg qualified(type($arg)) attr-dict `->` qualified(type($res))";
+}
+
+def ROCDLTanh : ROCDL_Math_IntrOp<"tanh">;
+def ROCDLSin : ROCDL_Math_IntrOp<"sin">;
+def ROCDLCos : ROCDL_Math_IntrOp<"cos">;
+def ROCDLRcp : ROCDL_Math_IntrOp<"rcp">;
+def ROCDLExp : ROCDL_Math_IntrOp<"exp">; // NOTE(review): not exercised by the tests in this patch; confirm llvm.amdgcn.exp is a math intrinsic (it may be the export intrinsic).
+def ROCDLExp2 : ROCDL_Math_IntrOp<"exp2">;
+def ROCDLLog : ROCDL_Math_IntrOp<"log">;
+def ROCDLSqrt : ROCDL_Math_IntrOp<"sqrt">;
+
 //===----------------------------------------------------------------------===//
 // ROCDL target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 27bf4163b9b7e..e03ba80d1807b 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -49,6 +49,59 @@ func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4
   llvm.return %0 : vector<4xf16>
 }
 
+func.func @rocdl.math.ops(%a: f32, %b: f16, %c: bf16) {
+  // CHECK-LABEL: rocdl.math.ops
+  // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f32 -> f32
+  // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f16 -> f16
+  // CHECK: %{{.*}} = rocdl.tanh %{{.*}} bf16 -> bf16
+  %tanh0 = rocdl.tanh %a f32 -> f32
+  %tanh1 = rocdl.tanh %b f16 -> f16
+  %tanh2 = rocdl.tanh %c bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.sin %{{.*}} f32 -> f32
+  // CHECK: %{{.*}} = rocdl.sin %{{.*}} f16 -> f16
+  // CHECK: %{{.*}} = rocdl.sin %{{.*}} bf16 -> bf16
+  %sin0 = rocdl.sin %a f32 -> f32
+  %sin1 = rocdl.sin %b f16 -> f16
+  %sin2 = rocdl.sin %c bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.cos %{{.*}} f32 -> f32
+  // CHECK: %{{.*}} = rocdl.cos %{{.*}} f16 -> f16
+  // CHECK: %{{.*}} = rocdl.cos %{{.*}} bf16 -> bf16
+  %cos0 = rocdl.cos %a f32 -> f32
+  %cos1 = rocdl.cos %b f16 -> f16
+  %cos2 = rocdl.cos %c bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f32 -> f32
+  // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f16 -> f16
+  // CHECK: %{{.*}} = rocdl.rcp %{{.*}} bf16 -> bf16
+  %rcp0 = rocdl.rcp %a f32 -> f32
+  %rcp1 = rocdl.rcp %b f16 -> f16
+  %rcp2 = rocdl.rcp %c bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f32 -> f32
+  // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f16 -> f16
+  // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} bf16 -> bf16
+  %exp2_0 = rocdl.exp2 %a f32 -> f32
+  %exp2_1 = rocdl.exp2 %b f16 -> f16
+  %exp2_2 = rocdl.exp2 %c bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.log %{{.*}} f32 -> f32
+  // CHECK: %{{.*}} = rocdl.log %{{.*}} f16 -> f16
+  // CHECK: %{{.*}} = rocdl.log %{{.*}} bf16 -> bf16
+  %log0 = rocdl.log %a f32 -> f32
+  %log1 = rocdl.log %b f16 -> f16
+  %log2 = rocdl.log %c bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f32 -> f32
+  // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f16 -> f16
+  // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} bf16 -> bf16
+  %sqrt0 = rocdl.sqrt %a f32 -> f32
+  %sqrt1 = rocdl.sqrt %b f16 -> f16
+  %sqrt2 = rocdl.sqrt %c bf16 -> bf16
+  llvm.return
+}
+
 func.func @rocdl.barrier() {
   // CHECK: rocdl.barrier
   rocdl.barrier
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 86b69812787b8..c6764cfb7db0d 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -61,6 +61,59 @@ llvm.func @kernel_func_workgroups()
   llvm.return
 }
 
+llvm.func @kernel_math_ops(%a: f32, %b: f16, %c: bf16) {
+  // CHECK-LABEL: kernel_math_ops
+  // CHECK: call float @llvm.amdgcn.tanh.f32(float %{{.*}})
+  // CHECK: call half @llvm.amdgcn.tanh.f16(half %{{.*}})
+  // CHECK: call bfloat @llvm.amdgcn.tanh.bf16(bfloat %{{.*}})
+  %tanh0 = rocdl.tanh %a f32 -> f32
+  %tanh1 = rocdl.tanh %b f16 -> f16
+  %tanh2 = rocdl.tanh %c bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.sin.f32(float %{{.*}})
+  // CHECK: call half @llvm.amdgcn.sin.f16(half %{{.*}})
+  // CHECK: call bfloat @llvm.amdgcn.sin.bf16(bfloat %{{.*}})
+  %sin0 = rocdl.sin %a f32 -> f32
+  %sin1 = rocdl.sin %b f16 -> f16
+  %sin2 = rocdl.sin %c bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.cos.f32(float %{{.*}})
+  // CHECK: call half @llvm.amdgcn.cos.f16(half %{{.*}})
+  // CHECK: call bfloat @llvm.amdgcn.cos.bf16(bfloat %{{.*}})
+  %cos0 = rocdl.cos %a f32 -> f32
+  %cos1 = rocdl.cos %b f16 -> f16
+  %cos2 = rocdl.cos %c bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.rcp.f32(float %{{.*}})
+  // CHECK: call half @llvm.amdgcn.rcp.f16(half %{{.*}})
+  // CHECK: call bfloat @llvm.amdgcn.rcp.bf16(bfloat %{{.*}})
+  %rcp0 = rocdl.rcp %a f32 -> f32
+  %rcp1 = rocdl.rcp %b f16 -> f16
+  %rcp2 = rocdl.rcp %c bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.exp2.f32(float %{{.*}})
+  // CHECK: call half @llvm.amdgcn.exp2.f16(half %{{.*}})
+  // CHECK: call bfloat @llvm.amdgcn.exp2.bf16(bfloat %{{.*}})
+  %exp2_0 = rocdl.exp2 %a f32 -> f32
+  %exp2_1 = rocdl.exp2 %b f16 -> f16
+  %exp2_2 = rocdl.exp2 %c bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.log.f32(float %{{.*}})
+  // CHECK: call half @llvm.amdgcn.log.f16(half %{{.*}})
+  // CHECK: call bfloat @llvm.amdgcn.log.bf16(bfloat %{{.*}})
+  %log0 = rocdl.log %a f32 -> f32
+  %log1 = rocdl.log %b f16 -> f16
+  %log2 = rocdl.log %c bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.sqrt.f32(float %{{.*}})
+  // CHECK: call half @llvm.amdgcn.sqrt.f16(half %{{.*}})
+  // CHECK: call bfloat @llvm.amdgcn.sqrt.bf16(bfloat %{{.*}})
+  %sqrt0 = rocdl.sqrt %a f32 -> f32
+  %sqrt1 = rocdl.sqrt %b f16 -> f16
+  %sqrt2 = rocdl.sqrt %c bf16 -> bf16
+  llvm.return
+}
+
 llvm.func @known_block_sizes()
     attributes {rocdl.kernel,
       rocdl.flat_work_group_size = "128,128",