-
Notifications
You must be signed in to change notification settings - Fork 15.6k
[mlir][rocdl] add gfx950 smfmac instructions to rocdl dialect #171737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Eric Feng <Eric.Feng@amd.com>
|
@llvm/pr-subscribers-mlir-amdgpu @llvm/pr-subscribers-mlir Author: Eric Feng (efric) Changes. Full diff: https://github.com/llvm/llvm-project/pull/171737.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0edb208a8fcba..fe8a854cd1321 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -592,6 +592,21 @@ def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.b
def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
+// New in gfx950.
+def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf16">;
+def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.f16">;
+def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x128.i8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.fp8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.fp8">;
+def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf16">;
+def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f16">;
+def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x64.i8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.fp8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.fp8">;
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 1b50feea418b6..745fea8e38955 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -288,7 +288,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg6 : vector<8 x i16>,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
- %arg9 : vector<16xi32>) -> vector<4 x f32> {
+ %arg9 : vector<16xi32>,
+ %arg10 : vector<8 x f16>,
+ %arg11 : vector<16 x f16>,
+ %arg12 : vector<8 x bf16>,
+ %arg13 : vector<16 x bf16>,
+ %arg14 : vector<8 x i32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
@@ -362,6 +367,76 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+ i32, i32, i32) -> vector<4xi32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+ i32, i32, i32) -> vector<16xi32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
llvm.return %r0 : vector<4 x f32>
}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 7be6d6ba4d7be..868597fba92a6 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -528,7 +528,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg6 : vector<8 x i16>,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
- %arg9 : vector<16xi32>) -> vector<4 x f32> {
+ %arg9 : vector<16xi32>,
+ %arg10 : vector<8 x f16>,
+ %arg11 : vector<16 x f16>,
+ %arg12 : vector<8 x bf16>,
+ %arg13 : vector<16 x bf16>,
+ %arg14 : vector<8 x i32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
@@ -598,12 +603,81 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
-
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+ i32, i32, i32) -> vector<4xi32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+ i32, i32, i32) -> vector<16xi32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
llvm.return %r0 : vector<4 x f32>
}
|
|
@llvm/pr-subscribers-mlir-llvm Author: Eric Feng (efric) Changes. Full diff: https://github.com/llvm/llvm-project/pull/171737.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0edb208a8fcba..fe8a854cd1321 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -592,6 +592,21 @@ def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.b
def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
+// New in gfx950.
+def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf16">;
+def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.f16">;
+def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x128.i8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.fp8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.fp8">;
+def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf16">;
+def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f16">;
+def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x64.i8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.fp8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.fp8">;
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 1b50feea418b6..745fea8e38955 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -288,7 +288,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg6 : vector<8 x i16>,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
- %arg9 : vector<16xi32>) -> vector<4 x f32> {
+ %arg9 : vector<16xi32>,
+ %arg10 : vector<8 x f16>,
+ %arg11 : vector<16 x f16>,
+ %arg12 : vector<8 x bf16>,
+ %arg13 : vector<16 x bf16>,
+ %arg14 : vector<8 x i32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
@@ -362,6 +367,76 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
+ // CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+ i32, i32, i32) -> vector<4xi32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+ i32, i32, i32) -> vector<16xi32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
llvm.return %r0 : vector<4 x f32>
}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 7be6d6ba4d7be..868597fba92a6 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -528,7 +528,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
%arg6 : vector<8 x i16>,
%arg7 : vector<2xi32>,
%arg8 : vector<4xi32>,
- %arg9 : vector<16xi32>) -> vector<4 x f32> {
+ %arg9 : vector<16xi32>,
+ %arg10 : vector<8 x f16>,
+ %arg11 : vector<16 x f16>,
+ %arg12 : vector<8 x bf16>,
+ %arg13 : vector<16 x bf16>,
+ %arg14 : vector<8 x i32>) -> vector<4 x f32> {
%csti32 = llvm.mlir.constant(42 : i32) : i32
// CHECK-LABEL: rocdl.smfmac
@@ -598,12 +603,81 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
-
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
i32, i32, i32) -> vector<16xf32>
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+ (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+ %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+ i32, i32, i32) -> vector<4xi32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+ i32, i32, i32) -> vector<4xf32>
+
+ // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+ %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+ i32, i32, i32) -> vector<16xi32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+ %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+ (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+ i32, i32, i32) -> vector<16xf32>
+
llvm.return %r0 : vector<4 x f32>
}
|
Signed-off-by: Eric Feng <Eric.Feng@amd.com>
|
Thanks for reviewing. Could you help me merge this, as I don't have the permissions? |
|
Let's wait for @krzysz00 and let him merge if it looks OK. |
|
I'll check tomorrow whether the signature is different for the sparse ones.
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, so it looks like the MfmaIntrOp class is doing that thing WMMA was doing until recently where we're just declaring varargs, which means, for example, that immargs aren't being represented as attributes.
Fixing that should happen but is off-topic for this PR. Therefore, approved.
…71737) Signed-off-by: Eric Feng <Eric.Feng@amd.com>
No description provided.