diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 8b687a7f29bef..29001e26eaaaf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -985,7 +985,6 @@ class ScaleArgInfo {
 //===---------------------------------------------------------------------===//
 // Scaled {fp4,bf8,fp8} to {bf16,f16,f32} conversion intrinsics
 //===---------------------------------------------------------------------===//
-
 foreach smallT = [
   ScaleArgInfo,
   ScaleArgInfo,
@@ -996,6 +995,8 @@ foreach smallT = [
     ScaleArgInfo,
     ScaleArgInfo,
   ] in {
+
+    // Up-scaling: `cvt.scale.pk8.<largeT>.<smallT>` takes a packed smallT `$src` plus `$scale` and `[$scaleSel]`.
     def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op :
       ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name,
                                  [Pure], 1, [2], ["scaleSel"]>,
@@ -1010,13 +1011,30 @@ foreach smallT = [
         attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
       }];
     }
+
+    // Down-scaling: `cvt.scalef32.pk8.<smallT>.<largeT>` takes a packed largeT `$src` and an F32 `$scale`, yields packed smallT.
+    def ROCDL_CvtScaleF32Pk8 # smallT.nameForOp # largeT.nameForOp # Op :
+      ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk8." # smallT.name # "." # largeT.name,
+                                 [Pure], 1>,
+      Arguments<(ins largeT.type:$src, F32:$scale)> {
+      let results = (outs smallT.type:$res);
+      let summary = "Scale and convert packed "
+        # largeT.name # " to packed " # smallT.name ;
+      let description = [{
+        Convert 8 packed }] # largeT.name # [{ values to packed }]
+        # smallT.name # [{, multiplying by the exponent part of `scale`
+        before doing so. This op is for gfx1250+ arch.
+      }];
+      let assemblyFormat = [{
+        attr-dict $src `,` $scale `:` type($res)
+      }];
+    }
   } // foreach largeT
 } // foreach smallTOp
 
 //===---------------------------------------------------------------------===//
 // Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics
 //===---------------------------------------------------------------------===//
-
 foreach smallT = [
   ScaleArgInfo,
   ScaleArgInfo
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 0bad151570029..6134695e9ced6 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1068,6 +1068,38 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
 
 // -----
 
+// CHECK-LABEL: rocdl.cvt.scalef32.pk8
+llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>,
+                                  %v8xf16: vector<8xf16>,
+                                  %v8xbf16: vector<8xbf16>,
+                                  %scale: f32) {
+
+  // CHECK: rocdl.cvt.scalef32.pk8.fp8.f32
+  %0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.bf8.f32
+  %1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.fp4.f32
+  %2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32
+
+  // CHECK: rocdl.cvt.scalef32.pk8.fp8.f16
+  %3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.bf8.f16
+  %4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.fp4.f16
+  %5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32
+
+  // CHECK: rocdl.cvt.scalef32.pk8.fp8.bf16
+  %6 = rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.bf8.bf16
+  %7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.fp4.bf16
+  %8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32
+
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: rocdl.cvt.scale.pk16
 llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index e043a8c533d05..00ee6b795c43a 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1340,6 +1340,34 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
   llvm.return
 }
 
+// CHECK-LABEL: rocdl.cvt.scalef32.pk8
+// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>, %v8xbf16: vector<8xbf16>, %scale: f32) {
+
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+  %0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+  %1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32>
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+  %2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32
+
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+  %3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+  %4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32>
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+  %5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32
+
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+  %6 = rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+  %7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32>
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+  %8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32
+
+  llvm.return
+}
+
 // CHECK-LABEL: @rocdl.cvt.scale.pk16
 // CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
 llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {