diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 9fa3ec1fc4b21..1f3974846a5ef 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -835,10 +835,17 @@ class ROCDL_ConcreteVector : def ROCDL_V2I16Type : ROCDL_ConcreteVector; def ROCDL_V2F16Type : ROCDL_ConcreteVector; +def ROCDL_V2I32Type : ROCDL_ConcreteVector; def ROCDL_V2BF16Type : ROCDL_ConcreteVector; def ROCDL_V2F32Type : ROCDL_ConcreteVector; +def ROCDL_V3I32Type : ROCDL_ConcreteVector; def ROCDL_V6I32Type : ROCDL_ConcreteVector; def ROCDL_V8I32Type : ROCDL_ConcreteVector; +def ROCDL_V8BF16Type : ROCDL_ConcreteVector; +def ROCDL_V8F16Type : ROCDL_ConcreteVector; +def ROCDL_V8F32Type : ROCDL_ConcreteVector; +def ROCDL_V16BF16Type : ROCDL_ConcreteVector; +def ROCDL_V16F16Type : ROCDL_ConcreteVector; def ROCDL_V16F32Type : ROCDL_ConcreteVector; def ROCDL_V32F16Type : ROCDL_ConcreteVector; def ROCDL_V32BF16Type : ROCDL_ConcreteVector; @@ -975,6 +982,68 @@ class ScaleArgInfo { string nameForOp = typeName; } +//===---------------------------------------------------------------------===// +// Scaled {fp4,bf8,fp8} to {bf16,f16,f32} conversion intrinsics +//===---------------------------------------------------------------------===// + +foreach smallT = [ + ScaleArgInfo, + ScaleArgInfo, + ScaleArgInfo +] in { + foreach largeT = [ + ScaleArgInfo, + ScaleArgInfo, + ScaleArgInfo, + ] in { + def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name, + [Pure], 1, [2], ["scaleSel"]>, + Arguments<(ins smallT.type:$src, I32:$scale, I32Attr:$scaleSel)> { + + let summary = "Scales 8 " # smallT.name # " and converts them to 8 " # largeT.name # "."; + let description = [{ + Available on gfx1250+. 
+ }]; + let results = (outs largeT.type:$res); + let assemblyFormat = [{ + attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res) + }]; + } + } // foreach largeT +} // foreach smallT + +//===---------------------------------------------------------------------===// +// Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics +//===---------------------------------------------------------------------===// + +foreach smallT = [ + ScaleArgInfo, + ScaleArgInfo +] in { + foreach largeT = [ + ScaleArgInfo, + ScaleArgInfo, + ScaleArgInfo, + ] in { + def ROCDL_CvtPkScalePk16 # largeT.nameForOp # smallT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk16." # largeT.name # "." # smallT.name, + [Pure], 1, [2], ["scaleSel"]>, + Arguments<(ins smallT.type:$src, I32:$scale, I32Attr:$scaleSel)> { + + let summary = "Scales 16 " # smallT.name # " and converts them to 16 " # largeT.name # "."; + let description = [{ + Available on gfx1250+. + }]; + let results = (outs largeT.type:$res); + let assemblyFormat = [{ + attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res) + }]; + + } + } // foreach largeT +} // foreach smallT + //===---------------------------------------------------------------------===// // Scaled 32x6-bit float float conversion intrinsics //===---------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 782ef4e154440..959bb35302b20 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -1025,6 +1025,57 @@ llvm.func @rocdl.permlane32.swap(%src : i32) -> !llvm.struct<(i32, i32)> { // ----- +// CHECK-LABEL: rocdl.cvt.scale.pk8 +llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) { + + // CHECK: rocdl.cvt.scale.pk8.f16.fp4 + %0 = rocdl.cvt.scale.pk8.f16.fp4 %i32, %scale[0] : vector<8xf16> + // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 + %1 = rocdl.cvt.scale.pk8.bf16.fp4 %i32, %scale[0] 
: vector<8xbf16> + // CHECK: rocdl.cvt.scale.pk8.f32.fp4 + %2 = rocdl.cvt.scale.pk8.f32.fp4 %i32, %scale[0] : vector<8xf32> + + // CHECK: rocdl.cvt.scale.pk8.f16.fp8 + %3 = rocdl.cvt.scale.pk8.f16.fp8 %v2xi32, %scale[0] : vector<8xf16> + // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 + %4 = rocdl.cvt.scale.pk8.bf16.fp8 %v2xi32, %scale[0] : vector<8xbf16> + // CHECK: rocdl.cvt.scale.pk8.f32.fp8 + %5 = rocdl.cvt.scale.pk8.f32.fp8 %v2xi32, %scale[0] : vector<8xf32> + + // CHECK: rocdl.cvt.scale.pk8.f16.bf8 + %6 = rocdl.cvt.scale.pk8.f16.bf8 %v2xi32, %scale[0] : vector<8xf16> + // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 + %7 = rocdl.cvt.scale.pk8.bf16.bf8 %v2xi32, %scale[0] : vector<8xbf16> + // CHECK: rocdl.cvt.scale.pk8.f32.bf8 + %8 = rocdl.cvt.scale.pk8.f32.bf8 %v2xi32, %scale[0] : vector<8xf32> + + llvm.return +} + +// ----- + +// CHECK-LABEL: rocdl.cvt.scale.pk16 +llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) { + + // CHECK: rocdl.cvt.scale.pk16.f16.fp6 + %0 = rocdl.cvt.scale.pk16.f16.fp6 %v3xi32, %scale[0] : vector<16xf16> + // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 + %1 = rocdl.cvt.scale.pk16.bf16.fp6 %v3xi32, %scale[0] : vector<16xbf16> + // CHECK: rocdl.cvt.scale.pk16.f32.fp6 + %2 = rocdl.cvt.scale.pk16.f32.fp6 %v3xi32, %scale[0] : vector<16xf32> + + // CHECK: rocdl.cvt.scale.pk16.f16.bf6 + %3 = rocdl.cvt.scale.pk16.f16.bf6 %v3xi32, %scale[0] : vector<16xf16> + // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 + %4 = rocdl.cvt.scale.pk16.bf16.bf6 %v3xi32, %scale[0] : vector<16xbf16> + // CHECK: rocdl.cvt.scale.pk16.f32.bf6 + %5 = rocdl.cvt.scale.pk16.f32.bf6 %v3xi32, %scale[0] : vector<16xf32> + + llvm.return +} + +// ----- + // expected-error@below {{attribute attached to unexpected op}} func.func private @expected_llvm_func() attributes { rocdl.kernel } diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index a464358250c38..bebd1b4317b2f 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir 
@@ -1298,6 +1298,54 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 { llvm.return %ret : i32 } +// CHECK-LABEL: rocdl.cvt.scale.pk8 +// CHECK-SAME:(i32 %[[I32:.+]], <2 x i32> %[[V2I32:.+]], i32 %[[SCALE:.+]]) +llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) { + + // CHECK: call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.fp4(i32 %[[I32]], i32 %[[SCALE]], i32 0) + %0 = rocdl.cvt.scale.pk8.f16.fp4 %i32, %scale[0] : vector<8xf16> + // CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.fp4(i32 %[[I32]], i32 %[[SCALE]], i32 0) + %1 = rocdl.cvt.scale.pk8.bf16.fp4 %i32, %scale[0] : vector<8xbf16> + // CHECK: call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f32.fp4(i32 %[[I32]], i32 %[[SCALE]], i32 0) + %2 = rocdl.cvt.scale.pk8.f32.fp4 %i32, %scale[0] : vector<8xf32> + + // CHECK: call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.fp8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0) + %3 = rocdl.cvt.scale.pk8.f16.fp8 %v2xi32, %scale[0] : vector<8xf16> + // CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.fp8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0) + %4 = rocdl.cvt.scale.pk8.bf16.fp8 %v2xi32, %scale[0] : vector<8xbf16> + // CHECK: call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f32.fp8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0) + %5 = rocdl.cvt.scale.pk8.f32.fp8 %v2xi32, %scale[0] : vector<8xf32> + + // CHECK: call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.bf8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0) + %6 = rocdl.cvt.scale.pk8.f16.bf8 %v2xi32, %scale[0] : vector<8xf16> + // CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.bf8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0) + %7 = rocdl.cvt.scale.pk8.bf16.bf8 %v2xi32, %scale[0] : vector<8xbf16> + // CHECK: call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f32.bf8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0) + %8 = rocdl.cvt.scale.pk8.f32.bf8 %v2xi32, %scale[0] : vector<8xf32> + + llvm.return +} + +// CHECK-LABEL: @rocdl.cvt.scale.pk16 +// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 
%[[SCALE:.+]]) +llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) { + + // CHECK: call <16 x half> @llvm.amdgcn.cvt.scale.pk16.f16.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0) + %0 = rocdl.cvt.scale.pk16.f16.fp6 %v3xi32, %scale[0] : vector<16xf16> + // CHECK: call <16 x bfloat> @llvm.amdgcn.cvt.scale.pk16.bf16.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0) + %1 = rocdl.cvt.scale.pk16.bf16.fp6 %v3xi32, %scale[0] : vector<16xbf16> + // CHECK: call <16 x float> @llvm.amdgcn.cvt.scale.pk16.f32.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0) + %2 = rocdl.cvt.scale.pk16.f32.fp6 %v3xi32, %scale[0] : vector<16xf32> + // CHECK: call <16 x half> @llvm.amdgcn.cvt.scale.pk16.f16.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0) + %3 = rocdl.cvt.scale.pk16.f16.bf6 %v3xi32, %scale[0] : vector<16xf16> + // CHECK: call <16 x bfloat> @llvm.amdgcn.cvt.scale.pk16.bf16.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0) + %4 = rocdl.cvt.scale.pk16.bf16.bf6 %v3xi32, %scale[0] : vector<16xbf16> + // CHECK: call <16 x float> @llvm.amdgcn.cvt.scale.pk16.f32.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0) + %5 = rocdl.cvt.scale.pk16.f32.bf6 %v3xi32, %scale[0] : vector<16xf32> + + llvm.return +} + // CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" } // CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024" // CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"