diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 9fa3ec1fc4b21..e308d601b1c6b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -1291,6 +1291,38 @@ def ROCDL_CvtScaleF32PkFp4F32Op : }]; } +//===----------------------------------------------------------------------===// +// MED3 operations +//===----------------------------------------------------------------------===// + +def ROCDL_Med3F16Op : ROCDL_ConcreteNonMemIntrOp<"med3.f16", [Pure], 1>, + Arguments<(ins F16:$src0, + F16:$src1, + F16:$src2)> { + let results = (outs F16:$res); + let summary = "Median of three half-precision float values"; + let assemblyFormat = [{ + $src0 `,` $src1 `,` $src2 attr-dict `:` `(` type($src0) `,` type($src1) `,` type($src2) `)` `->` type($res) + }]; + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_fmed3, {$src0, $src1, $src2}, {moduleTranslation.convertType(op.getSrc0().getType())}); + }]; +} + +def ROCDL_Med3F32Op : ROCDL_ConcreteNonMemIntrOp<"med3.f32", [Pure], 1>, + Arguments<(ins F32:$src0, + F32:$src1, + F32:$src2)> { + let results = (outs F32:$res); + let summary = "Median of three single-precision float values"; + let assemblyFormat = [{ + $src0 `,` $src1 `,` $src2 attr-dict `:` `(` type($src0) `,` type($src1) `,` type($src2) `)` `->` type($res) + }]; + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_fmed3, {$src0, $src1, $src2}, {moduleTranslation.convertType(op.getSrc0().getType())}); + }]; +} + //===----------------------------------------------------------------------===// // ROCDL target attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index a464358250c38..d4871b67c8724 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1298,6 +1298,20 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 { llvm.return %ret : i32 } +llvm.func @test_med3_f16(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 { + // CHECK-LABEL: define half @test_med3_f16(half %0, half %1, half %2) + %0 = rocdl.med3.f16 %arg0, %arg1, %arg2 : (f16, f16, f16) -> f16 + llvm.return %0 : f16 + // CHECK: call half @llvm.amdgcn.fmed3.f16(half %0, half %1, half %2) +} + +llvm.func @test_med3_f32(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 { + // CHECK-LABEL: define float @test_med3_f32(float %0, float %1, float %2) + %0 = rocdl.med3.f32 %arg0, %arg1, %arg2 : (f32, f32, f32) -> f32 + llvm.return %0 : f32 + // CHECK: call float @llvm.amdgcn.fmed3.f32(float %0, float %1, float %2) +} + // CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" } // CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024" // CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"