diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 135d1e4007d49..ad5212fcfda45 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -230,11 +230,22 @@ class ROCDL_SpecialIdRegisterOp : // ROCDL vector types definitions //===----------------------------------------------------------------------===// +class ROCDL_NamedType { + string typeName = name; +} + class ROCDL_ConcreteVector : FixedVectorOfLengthAndType<[length], [elem]>, BuildableType< "::mlir::VectorType::get({" # length # "} ," - # elem.builderCall # ")">; + # elem.builderCall # ")">, + ROCDL_NamedType<"vector<" # length # "x" + # !tolower(!cast(elem)) # ">">; + +class ROCDL_Scalar : + Type, + BuildableType, + ROCDL_NamedType(elem))>; def ROCDL_V2I16Type : ROCDL_ConcreteVector; def ROCDL_V2F16Type : ROCDL_ConcreteVector; @@ -925,7 +936,7 @@ def ROCDL_IglpOpt : ROCDL_ConcreteNonMemIntrOp<"iglp.opt", [], 0, [0], ["variant //===---------------------------------------------------------------------===// // Xdlops intrinsics -class ROCDL_Mfma_IntrOp : +class ROCDL_Mfma_IntrOp : ROCDL_IntrOp, Arguments<(ins ABType:$a, @@ -945,19 +956,10 @@ class ROCDL_Mfma_IntrOp : Example: ```mlir - // MFMA with f32 inputs and 32-wide f32 accumulator. - %r0 = rocdl.mfma.f32.32x32x1f32 %a0, %b0, %c0, 0, 0, 0 : - (f32, f32, vector<32xf32>) -> vector<32xf32> - - // MFMA with i8 inputs and 32-wide i32 accumulator. - %r1 = rocdl.mfma.i32.32x32x4i8 %a1, %a1, %c1, 0, 0, 0 : - (i32, i32, vector<32xi32>) -> vector<32xi32> - - // MFMA with bf16 inputs and 32-wide f32 accumulator. - %r2 = rocdl.mfma.f32.32x32x2bf16 %a2, %a2, %c0, 0, 0, 0 : - (vector<2xi16>, vector<2xi16>, vector<32xf32>) -> vector<32xf32> - ``` - }]; + %r0 = }] # mnemonic # [{ %a0, %b0, %c0, 0, 0, 0 : (}] # ABType.typeName + # [{, }] # ABType.typeName # [{, }] # CDType.typeName # [{) -> }] + # CDType.typeName # [{ + ```}]; } class ROCDL_Mfma_Scale_IntrOp : @@ -1038,21 +1040,21 @@ class ROCDL_Smfmac_IntrOp } // Available on all CDNA. -def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32", /*Type AB=*/F32, /*Type CD=*/ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32", F32, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32", F32, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32", F32, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32", F32, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32", /*Type AB=*/ROCDL_Scalar, /*Type CD=*/ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32", ROCDL_Scalar, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; -def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8", I32, ROCDL_ConcreteVector>; -def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8", I32, ROCDL_ConcreteVector>; -def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8", I32, ROCDL_ConcreteVector>; -def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8", I32, ROCDL_ConcreteVector>; -def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8", I32, ROCDL_ConcreteVector>; +def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8", ROCDL_Scalar, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; @@ -1066,21 +1068,21 @@ def ROCDL_mfma_f32_32x32x8bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8bf16.1k", def ROCDL_mfma_f32_16x16x16bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16bf16.1k", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; // Note: in gfx94x, unlike in gfx90a, the f64 xdlops use the "blgp" argument as // a NEG bitfield. See IntrinsicsAMDGPU.td for more info. -def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.16x16x4f64", F64, ROCDL_ConcreteVector>; -def ROCDL_mfma_f64_4x4x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.4x4x4f64", F64, F64>; +def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.16x16x4f64", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f64_4x4x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.4x4x4f64", ROCDL_Scalar, ROCDL_Scalar>; // New in gfx94x. -def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8", I64, ROCDL_ConcreteVector>; -def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8", I64, ROCDL_ConcreteVector>; +def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8", ROCDL_Scalar, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8", I64, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8", I64, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8", I64, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.fp8", I64, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.bf8", I64, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8", I64, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8", I64, ROCDL_ConcreteVector>; -def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8", I64, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.fp8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.bf8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8", ROCDL_Scalar, ROCDL_ConcreteVector>; +def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8", ROCDL_Scalar, ROCDL_ConcreteVector>; // New in gfx950. def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf16", ROCDL_ConcreteVector, ROCDL_ConcreteVector>; def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x64.i8", ROCDL_ConcreteVector, ROCDL_ConcreteVector>;