From 4f5d0ff81b6a149e8b90b9050079a71d81793896 Mon Sep 17 00:00:00 2001
From: ravil-mobile
Date: Tue, 7 Oct 2025 08:53:40 +0000
Subject: [PATCH] [ROCDL] Added rocdl.cvt.scalef32.sr.pk8 ops

Add ROCDL ops that wrap the `llvm.amdgcn.cvt.scalef32.sr.pk8.*` intrinsics.
Each op takes a packed source vector, an i32 `seed` and an f32 `scale`.
The tests exercise every {fp8, bf8, fp4} x {f32, f16, bf16} combination:
fp8/bf8 results are vector<2xi32>, fp4 results are i32.

Tests are added to both the ROCDL dialect test and the LLVM IR translation
test.

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 18 +++++++++++
 mlir/test/Dialect/LLVMIR/rocdl.mlir          | 33 ++++++++++++++++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 33 ++++++++++++++++++++
 3 files changed, 84 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 29001e26eaaaf..db1b7e3af62fd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -1029,6 +1029,24 @@ foreach smallT = [
         attr-dict $src `,` $scale `:` type($res)
       }];
     }
+
+
+    def ROCDL_CvtScaleF32SrPk8 # smallT.nameForOp # largeT.nameForOp # Op :
+      ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk8." # smallT.name # "." # largeT.name,
+                                 [Pure], 1>,
+      Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> {
+      let results = (outs smallT.type:$res);
+      let summary = "Scale and convert packed "
+        # largeT.name # " to packed " # smallT.name # " with stochastic rounding";
+      let description = [{
+        Convert 8 packed }] # largeT.name # [{ values to packed }]
+        # smallT.name # [{, multiplying by the exponent part of `scale`
+        before doing so and applying stochastic rounding. This op is for gfx1250+ arch.
+      }];
+      let assemblyFormat = [{
+        attr-dict $src `,` $seed `,` $scale `:` type($res)
+      }];
+    }
   } // foreach largeT
 } // foreach smallTOp
 
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 6134695e9ced6..a88b59aeb61b2 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1100,6 +1100,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>,
 
 // -----
 
+// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8
+llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
+                                     %v8xf16: vector<8xf16>,
+                                     %v8xbf16: vector<8xbf16>,
+                                     %seed: i32,
+                                     %scale: f32) {
+
+  // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f32
+  %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f32
+  %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f32
+  %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
+
+  // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f16
+  %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f16
+  %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f16
+  %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
+
+  // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.bf16
+  %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.bf16
+  %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.bf16
+  %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
+
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: rocdl.cvt.scale.pk16
 llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 00ee6b795c43a..1c0c2eba002aa 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1368,6 +1368,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>
   llvm.return
 }
 
+// CHECK-LABEL: @rocdl.cvt.scalef32.sr.pk8
+// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
+                                     %v8xf16: vector<8xf16>,
+                                     %v8xbf16: vector<8xbf16>,
+                                     %seed: i32,
+                                     %scale: f32) {
+
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+  %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+  %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+  %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
+
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+  %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+  %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+  %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
+
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+  %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+  %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+  %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
+
+  llvm.return
+}
+
+
+
 // CHECK-LABEL: @rocdl.cvt.scale.pk16
 // CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
 llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {