Minimal viable support for tf32 on the block pointer path #1172

Merged
merged 2 commits on May 31, 2024
Changes from 1 commit
6 changes: 4 additions & 2 deletions python/tutorials/09-experimental-block-pointer.py
@@ -222,7 +222,8 @@ def matmul(a, b, res_dtype):
# Still we can test our matrix multiplication with block pointers against a native torch implementation (i.e., cuBLAS).

torch.manual_seed(0)
for dtype, res_dtype in [(torch.float16, torch.float32), (torch.bfloat16, torch.float32), (torch.int8, torch.int32)]:
for dtype, res_dtype in [(torch.float16, torch.float32), (torch.bfloat16, torch.float32), (torch.int8, torch.int32),
(torch.float32, torch.float32)]:
if dtype.is_floating_point:
a = torch.randn((512, 512), device='xpu', dtype=dtype)
b = torch.randn((512, 512), device='xpu', dtype=dtype)
@@ -243,8 +244,9 @@ def matmul(a, b, res_dtype):

# Note: the torch.matmul and Triton implementations use different
# algorithms so we need to adjust tolerance.
atol = 4e-2 if dtype == torch.float32 else 1e-4
Contributor

For this, we could double-check with other teams (kernel library, IGC, etc.) that have used TF32 GEMM previously, to make sure this is as expected.

Contributor Author

Yes, that would be good to know. I'd need to raise the bounds again in #1211.

Contributor

What's the rationale for increasing atol vs. rtol?

Contributor Author

Nothing specific, I just played around with the parameters. The maximum relative error is 146 (triton: 0.0146, torch: 0.0001). I think the underlying problem might be that the reference computation is not done with TF32 precision; still investigating how to enable that.

Contributor

Increasing rtol is better because the torch.allclose comparison is:

|input − other| ≤ atol + rtol × |other|

So increasing atol loosens the comparison regardless of the magnitude of the values being compared, whereas rtol scales with them.
If we can force torch to use TF32 precision, that would be ideal.
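
For illustration, a minimal sketch (with made-up values) of how the two tolerances behave differently:

```python
import torch

# Made-up values: one large and one small element, both with 1% relative error.
ref = torch.tensor([100.0, 0.0001])
out = ref * 1.01

# rtol scales with |ref|, so a pure-rtol bound treats both elements consistently.
print(torch.allclose(out, ref, atol=0.0, rtol=2e-2))  # True

# A pure-atol bound tight enough for the small element fails on the large one,
# while one loose enough for the large element is meaningless for the small one.
print(torch.allclose(out, ref, atol=1e-4, rtol=0.0))  # False
print(torch.allclose(out, ref, atol=2.0, rtol=0.0))   # True, but 2.0 >> 0.0001
```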

Contributor Author

Found out how to set the TF32 mode and hence was able to drop the changes to the tolerances.
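
One way this is typically done in PyTorch (the exact setting used in the follow-up commit is not shown in this diff) is the global float32 matmul precision switch; a minimal sketch, assuming the XPU backend honors it:

```python
import torch

# "highest" keeps full FP32 for float32 matmuls; "high" allows TF32 (or a
# comparable reduced-precision format) where the backend supports it, which
# should make the torch.matmul reference comparable to the Triton TF32 kernel.
torch.set_float32_matmul_precision("high")

a = torch.randn((512, 512), device="xpu", dtype=torch.float32)
b = torch.randn((512, 512), device="xpu", dtype=torch.float32)
torch_output = torch.matmul(a, b)
```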

rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol):
if torch.allclose(triton_output, torch_output, atol=atol, rtol=rtol):
print("✅ Triton and Torch match")
else:
exit("❌ Triton and Torch differ")
34 changes: 33 additions & 1 deletion test/Conversion/intel/tritongpu_to_llvm_intel_block_ptr.mlir
@@ -1,4 +1,4 @@
// RUN: TRITON_INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s --convert-triton-intel-gpu-to-llvm | FileCheck %s
// RUN: TRITON_INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s --convert-triton-intel-gpu-to-llvm --split-input-file | FileCheck %s

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
// CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32, vector<8xi32>)
@@ -108,3 +108,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
tt.return
}
}

// -----

// COM: Checks the correct lowering of the A operand load for TF32, i.e. using 4xi32 and vnni=false.

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @matmul_kernel_with_block_pointers_tf32(
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) attributes {triton_gen.intel_reqd_sub_group_size = [16 : i32], triton_gen.max_work_group_size = [512 : i32, 1 : i32, 1 : i32]} {
tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
%c0_i64 = arith.constant 0 : i64
%c0_i32 = arith.constant 0 : i32
%0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x8xf32>>
%1 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x16xf32>>
// CHECK: [[ELEM_SIZE:%.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: [[NUM_BLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: {{%.*}} = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v4i32({{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, [[ELEM_SIZE]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[NUM_BLOCKS]], [[TRANSPOSE]], [[VNNI]], {{%.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<4xi32>
%2 = tt.load %0 {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[ELEM_SIZE:%.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: [[NUM_BLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: [[VAL_60:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v8i32({{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, [[ELEM_SIZE]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[NUM_BLOCKS]], [[TRANSPOSE]], [[VNNI]], {{%.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi32>
%3 = tt.load %1 {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
tt.return
}
}
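
The v4i32 and v8i32 widths checked above follow from the sizing rule in getVectorType (updated in TritonOpsToLLVM.cpp below); a minimal sketch of that rule, applied to the kernel's two tile shapes:

```python
# Sketch of getVectorType's sizing: the tile's total bit size is split across
# the 16 subgroup lanes and expressed in elements of the lane type (i32 here).
def lane_vector_width(shape, elem_bits, lane_elem_bits, lanes=16):
    total_bits = shape[0] * shape[1] * elem_bits
    return (total_bits // lanes) // lane_elem_bits

print(lane_vector_width((8, 8), 32, 32))   # 4 -> LSC2DBlockRead.v4i32 (A operand)
print(lane_vector_width((8, 16), 32, 32))  # 8 -> LSC2DBlockRead.v8i32 (B operand)
```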
60 changes: 58 additions & 2 deletions test/TritonIntelGPU/match-target-size.mlir
@@ -172,9 +172,9 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16

// COM: Test transformation for int8 datatype

// CHECK-LABEL: @matmul_kernel_with_block_pointers
// CHECK-LABEL: @matmul_kernel_with_block_pointers_int8
#warp = #triton_intel_gpu.warp<{sizePerThread = [8, 32], threadsPerWarp = [1, 1], order = [1, 0]}>
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: i32) {
tt.func public @matmul_kernel_with_block_pointers_int8(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: i32) {
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : i32
// CHECK-DAG: [[C32:%.*]] = arith.constant 32 : i32
%cst = arith.constant dense<0> : tensor<8x32xi32, #warp>
@@ -223,3 +223,59 @@ tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<i8> {tt.divisib
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xi32, #warp>>
tt.return
}

// -----

// COM: Test transformation for tf32 datatype

// CHECK-LABEL: @matmul_kernel_with_block_pointers_tf32
#warp = #triton_intel_gpu.warp<{sizePerThread = [8, 32], threadsPerWarp = [1, 1], order = [1, 0]}>
tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: i32) {
// CHECK: [[TZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<8x16xf32>
%cst = arith.constant dense<0.000000e+00> : tensor<8x32xf32, #warp>
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c0_i64 = arith.constant 0 : i64
%c32_i32 = arith.constant 32 : i32
// CHECK-COUNT-4: {{.*}} = tt.make_tensor_ptr %arg0
// CHECK-COUNT-8: {{.*}} = tt.make_tensor_ptr %arg1
%tptr_a = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%tptr_b = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
// CHECK: [[LOOP_RES:%.*]]:14 = scf.for {{.*}} = {{.*}} to {{.*}} step {{.*}} iter_args([[ITER_1:%.*]] = [[TZERO]], [[ITER_2:%.*]] = [[TZERO]], {{.*}})
%35:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %tptr_a, %arg12 = %tptr_b) -> (tensor<8x32xf32, #warp>, !tt.ptr<tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>, !tt.ptr<tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>) : i32 {
// CHECK: [[LD_A1:%.*]] = tt.load %arg[[#first_ptr:]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_A2:%.*]] = tt.load %arg[[#first_ptr+1]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_A3:%.*]] = tt.load %arg[[#first_ptr+2]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_A4:%.*]] = tt.load %arg[[#first_ptr+3]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_B1:%.*]] = tt.load %arg[[#first_ptr+4]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B2:%.*]] = tt.load %arg[[#first_ptr+5]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B3:%.*]] = tt.load %arg[[#first_ptr+6]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B4:%.*]] = tt.load %arg[[#first_ptr+7]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B5:%.*]] = tt.load %arg[[#first_ptr+8]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B6:%.*]] = tt.load %arg[[#first_ptr+9]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B7:%.*]] = tt.load %arg[[#first_ptr+10]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B8:%.*]] = tt.load %arg[[#first_ptr+11]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
%46 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%47 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
// CHECK: [[DOT_1:%.*]] = tt.dot [[LD_A1]], [[LD_B1]], [[ITER_1]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_2:%.*]] = tt.dot [[LD_A2]], [[LD_B2]], [[DOT_1]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_3:%.*]] = tt.dot [[LD_A3]], [[LD_B3]], [[DOT_2]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_4:%.*]] = tt.dot [[LD_A4]], [[LD_B4]], [[DOT_3]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_5:%.*]] = tt.dot [[LD_A1]], [[LD_B5]], [[ITER_2]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_6:%.*]] = tt.dot [[LD_A2]], [[LD_B6]], [[DOT_5]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_7:%.*]] = tt.dot [[LD_A3]], [[LD_B7]], [[DOT_6]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_8:%.*]] = tt.dot [[LD_A4]], [[LD_B8]], [[DOT_7]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
%48 = tt.dot %46, %47, %arg10, inputPrecision = tf32 : tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<8x32xf32, #warp>
// CHECK-COUNT-12: {{.*}} = tt.advance
%49 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%50 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
scf.yield %48, %49, %50 : tensor<8x32xf32, #warp>, !tt.ptr<tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>, !tt.ptr<tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
} {triton_gpu.workload = 3 : i32}
// CHECK: [[TPTR_C1:%.*]] = tt.make_tensor_ptr %arg2,
// CHECK: [[TPTR_C2:%.*]] = tt.make_tensor_ptr %arg2,
%tptr_c = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x32xf32, #warp>>
// CHECK: tt.store [[TPTR_C1:%.*]], [[LOOP_RES]]#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: tt.store [[TPTR_C2:%.*]], [[LOOP_RES]]#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #warp>>
tt.return
}
40 changes: 24 additions & 16 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
@@ -12,9 +12,11 @@ using namespace mlir::triton::gpu::intel;
namespace {

VectorType getVectorType(RankedTensorType tensorType, Type elemType) {
unsigned ratio =
elemType.getIntOrFloatBitWidth() / tensorType.getElementTypeBitWidth();
unsigned num = (tensorType.getNumElements() / 16) / ratio;
// Determine a vector type of the given `elemType` that covers 1/16 of
// `tensorType`, i.e. the amount of data a single subgroup lane will work on.
size_t tensorSize =
tensorType.getNumElements() * tensorType.getElementTypeBitWidth();
size_t num = (tensorSize / 16) / elemType.getIntOrFloatBitWidth();
return vec_ty(elemType, num);
};

@@ -120,11 +122,13 @@ class LoadStorePrefetchOpConversion
assert(tensorType.getRank() <= 2 &&
"only support 1d/2d load/store/prefetch for now");

unsigned dataSize = tensorType.getElementType().getIntOrFloatBitWidth();
Type elemType = tensorType.getElementType();
unsigned dataSize = elemType.getIntOrFloatBitWidth();
unsigned blockHeight = tensorType.getShape()[0];
unsigned blockWidth = tensorType.getShape()[1];
assert((blockWidth == 16 || blockWidth == 32 || blockWidth == 64) &&
"only support 16/32/64 block");
assert((blockWidth == 8 || blockWidth == 16 || blockWidth == 32 ||
blockWidth == 64) &&
"only support 8/16/32/64 block");
auto idxAttr = op->template getAttrOfType<mlir::IntegerAttr>("DotIdx");
unsigned vBlks = 1;
if (dataSize == 16) {
@@ -175,10 +179,11 @@ class LoadStorePrefetchOpConversion
unsigned idx = idxAttr.getInt();
Type resType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
bool isDword = idx == 1 || elemType == f32_ty;
Type vectorType =
getVectorType(cast<RankedTensorType>(op.getResult().getType()),
idx == 0 ? i16_ty : i32_ty);
bool vnni = (idx == 1) && dataSize <= 32;
isDword ? i32_ty : i16_ty);
bool vnni = (idx == 1) && dataSize < 32;
auto load = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
loc, vectorType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY,
dataSize, blockWidth, blockHeight, vBlks, false /*transpose*/, vnni);
@@ -219,12 +224,14 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<DotOp> {
LogicalResult
matchAndRewrite(DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto encodePrecision = [&](Type type) -> TritonGEN::PrecisionType {
auto encodePrecision =
[&](Type type, InputPrecisionAttr attr) -> TritonGEN::PrecisionType {
if (type == bf16_ty)
return TritonGEN::PrecisionType::BF16;
else if (type == f16_ty)
return TritonGEN::PrecisionType::FP16;
else if (type == rewriter.getTF32Type())
else if (type == f32_ty && attr &&
attr.getValue() == InputPrecision::TF32)
return TritonGEN::PrecisionType::TF32;
else if (type.isInteger(8)) {
if (type.isUnsignedInteger())
@@ -236,18 +243,19 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<DotOp> {
return TritonGEN::PrecisionType::UNUSED;
};

TritonGEN::PrecisionType precATy =
encodePrecision(op.getA().getType().getElementType());
TritonGEN::PrecisionType precBTy =
encodePrecision(op.getB().getType().getElementType());
TritonGEN::PrecisionType precATy = encodePrecision(
op.getA().getType().getElementType(), op.getInputPrecisionAttr());
TritonGEN::PrecisionType precBTy = encodePrecision(
op.getB().getType().getElementType(), op.getInputPrecisionAttr());
auto precA =
TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precATy);
auto precB =
TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precBTy);

Location loc = op.getLoc();
Type typeA =
getVectorType(cast<RankedTensorType>(op.getA().getType()), i16_ty);
Type typeA = getVectorType(
cast<RankedTensorType>(op.getA().getType()),
precATy == TritonGEN::PrecisionType::TF32 ? i32_ty : i16_ty);
Value castA = bitcast(adaptor.getA(), typeA);
VectorType typeB =
getVectorType(cast<RankedTensorType>(op.getB().getType()), i32_ty);