Minimal viable support for tf32 on the block pointer path #1172

Merged 2 commits on May 31, 2024
5 changes: 4 additions & 1 deletion python/tutorials/09-experimental-block-pointer.py
@@ -222,7 +222,8 @@ def matmul(a, b, res_dtype):
# Still we can test our matrix multiplication with block pointers against a native torch implementation (i.e., cuBLAS).

torch.manual_seed(0)
for dtype, res_dtype in [(torch.float16, torch.float32), (torch.bfloat16, torch.float32), (torch.int8, torch.int32)]:
for dtype, res_dtype in [(torch.float16, torch.float32), (torch.bfloat16, torch.float32), (torch.int8, torch.int32),
(torch.float32, torch.float32)]:
if dtype.is_floating_point:
a = torch.randn((512, 512), device='xpu', dtype=dtype)
b = torch.randn((512, 512), device='xpu', dtype=dtype)
@@ -232,6 +233,8 @@ def matmul(a, b, res_dtype):

triton_output = matmul(a, b, res_dtype)
if dtype.is_floating_point:
torch.xpu.set_fp32_math_mode(torch.xpu.utils.FP32MathMode.TF32 if dtype ==
torch.float32 else torch.xpu.utils.FP32MathMode.FP32)
torch_output = torch.matmul(a, b).to(res_dtype)
else:
# torch.matmul clamps values to input dtype; IPEX doesn't support int32 matmul
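
To see how the reference comparison is meant to behave, the math-mode toggle from the hunk above can be folded into a small helper. This is only a sketch built on the torch.xpu.set_fp32_math_mode API already used in the diff; the helper name is hypothetical:

import torch

def reference_matmul(a, b, res_dtype):
    # For fp32 inputs, let the torch reference use TF32 so it matches the
    # tf32 inputPrecision of the block-pointer kernel; otherwise keep full fp32.
    mode = (torch.xpu.utils.FP32MathMode.TF32
            if a.dtype == torch.float32
            else torch.xpu.utils.FP32MathMode.FP32)
    torch.xpu.set_fp32_math_mode(mode)
    return torch.matmul(a, b).to(res_dtype)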
34 changes: 33 additions & 1 deletion test/Conversion/intel/tritongpu_to_llvm_intel_block_ptr.mlir
@@ -1,4 +1,4 @@
// RUN: TRITON_INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s --convert-triton-intel-gpu-to-llvm | FileCheck %s
// RUN: TRITON_INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s --convert-triton-intel-gpu-to-llvm --split-input-file | FileCheck %s

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
// CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32, vector<8xi32>)
@@ -108,3 +108,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
tt.return
}
}

// -----

// COM: Checks the correct lowering of the A operand load for TF32, i.e. using 4xi32 and vnni=false.

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @matmul_kernel_with_block_pointers_tf32(
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) attributes {triton_gen.intel_reqd_sub_group_size = [16 : i32], triton_gen.max_work_group_size = [512 : i32, 1 : i32, 1 : i32]} {
tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
%c0_i64 = arith.constant 0 : i64
%c0_i32 = arith.constant 0 : i32
%0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x8xf32>>
%1 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x16xf32>>
// CHECK: [[ELEM_SIZE:%.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: [[NUM_BLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: {{%.*}} = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v4i32({{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, [[ELEM_SIZE]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[NUM_BLOCKS]], [[TRANSPOSE]], [[VNNI]], {{%.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<4xi32>
%2 = tt.load %0 {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[ELEM_SIZE:%.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: [[NUM_BLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: [[VAL_60:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v8i32({{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, [[ELEM_SIZE]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[NUM_BLOCKS]], [[TRANSPOSE]], [[VNNI]], {{%.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi32>
%3 = tt.load %1 {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
tt.return
}
}
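
The constants checked above can be read off the tensor type: for tf32 the element size stays 32-bit and VNNI packing is disabled for both operands. A small Python sketch of that parameter derivation (the function name is hypothetical; it mirrors the vnni = (idx == 1) && dataSize < 32 logic added in TritonOpsToLLVM.cpp below):

def block_read_params(tile_shape, elem_bits, dot_idx):
    # Parameters passed to GenISA.LSC2DBlockRead for one operand tile.
    return {
        "elemSizeInBits": elem_bits,
        "tileWidth": tile_shape[1],
        "tileHeight": tile_shape[0],
        "numBlocks": 1,
        "transpose": False,
        # VNNI packing only applies to the B operand and only for sub-dword elements.
        "vnni": dot_idx == 1 and elem_bits < 32,
    }

# A operand (DotIdx = 0), tensor<8x8xf32>: 32-bit elements, vnni = false
assert block_read_params((8, 8), 32, 0)["vnni"] is False
# B operand (DotIdx = 1), tensor<8x16xf32>: vnni is also false because the elements are 32-bit
assert block_read_params((8, 16), 32, 1)["vnni"] is False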
60 changes: 58 additions & 2 deletions test/TritonIntelGPU/match-target-size.mlir
@@ -172,9 +172,9 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16

// COM: Test transformation for int8 datatype

// CHECK-LABEL: @matmul_kernel_with_block_pointers
// CHECK-LABEL: @matmul_kernel_with_block_pointers_int8
#warp = #triton_intel_gpu.warp<{sizePerThread = [8, 32], threadsPerWarp = [1, 1], order = [1, 0]}>
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: i32) {
tt.func public @matmul_kernel_with_block_pointers_int8(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: i32) {
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : i32
// CHECK-DAG: [[C32:%.*]] = arith.constant 32 : i32
%cst = arith.constant dense<0> : tensor<8x32xi32, #warp>
@@ -223,3 +223,59 @@ tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<i8> {tt.divisib
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xi32, #warp>>
tt.return
}

// -----

// COM: Test transformation for tf32 datatype

// CHECK-LABEL: @matmul_kernel_with_block_pointers_tf32
#warp = #triton_intel_gpu.warp<{sizePerThread = [8, 32], threadsPerWarp = [1, 1], order = [1, 0]}>
tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: i32) {
// CHECK: [[TZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<8x16xf32>
%cst = arith.constant dense<0.000000e+00> : tensor<8x32xf32, #warp>
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c0_i64 = arith.constant 0 : i64
%c32_i32 = arith.constant 32 : i32
// CHECK-COUNT-4: {{.*}} = tt.make_tensor_ptr %arg0
// CHECK-COUNT-8: {{.*}} = tt.make_tensor_ptr %arg1
%tptr_a = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%tptr_b = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
// CHECK: [[LOOP_RES:%.*]]:14 = scf.for {{.*}} = {{.*}} to {{.*}} step {{.*}} iter_args([[ITER_1:%.*]] = [[TZERO]], [[ITER_2:%.*]] = [[TZERO]], {{.*}})
%35:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %tptr_a, %arg12 = %tptr_b) -> (tensor<8x32xf32, #warp>, !tt.ptr<tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>, !tt.ptr<tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>) : i32 {
// CHECK: [[LD_A1:%.*]] = tt.load %arg[[#first_ptr:]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_A2:%.*]] = tt.load %arg[[#first_ptr+1]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_A3:%.*]] = tt.load %arg[[#first_ptr+2]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_A4:%.*]] = tt.load %arg[[#first_ptr+3]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_B1:%.*]] = tt.load %arg[[#first_ptr+4]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B2:%.*]] = tt.load %arg[[#first_ptr+5]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B3:%.*]] = tt.load %arg[[#first_ptr+6]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B4:%.*]] = tt.load %arg[[#first_ptr+7]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B5:%.*]] = tt.load %arg[[#first_ptr+8]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B6:%.*]] = tt.load %arg[[#first_ptr+9]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B7:%.*]] = tt.load %arg[[#first_ptr+10]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B8:%.*]] = tt.load %arg[[#first_ptr+11]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
%46 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%47 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
// CHECK: [[DOT_1:%.*]] = tt.dot [[LD_A1]], [[LD_B1]], [[ITER_1]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_2:%.*]] = tt.dot [[LD_A2]], [[LD_B2]], [[DOT_1]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_3:%.*]] = tt.dot [[LD_A3]], [[LD_B3]], [[DOT_2]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_4:%.*]] = tt.dot [[LD_A4]], [[LD_B4]], [[DOT_3]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_5:%.*]] = tt.dot [[LD_A1]], [[LD_B5]], [[ITER_2]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_6:%.*]] = tt.dot [[LD_A2]], [[LD_B6]], [[DOT_5]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_7:%.*]] = tt.dot [[LD_A3]], [[LD_B7]], [[DOT_6]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_8:%.*]] = tt.dot [[LD_A4]], [[LD_B8]], [[DOT_7]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
%48 = tt.dot %46, %47, %arg10, inputPrecision = tf32 : tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<8x32xf32, #warp>
// CHECK-COUNT-12: {{.*}} = tt.advance
%49 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%50 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
scf.yield %48, %49, %50 : tensor<8x32xf32, #warp>, !tt.ptr<tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>, !tt.ptr<tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
} {triton_gpu.workload = 3 : i32}
// CHECK: [[TPTR_C1:%.*]] = tt.make_tensor_ptr %arg2,
// CHECK: [[TPTR_C2:%.*]] = tt.make_tensor_ptr %arg2,
%tptr_c = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x32xf32, #warp>>
// CHECK: tt.store [[TPTR_C1:%.*]], [[LOOP_RES]]#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: tt.store [[TPTR_C2:%.*]], [[LOOP_RES]]#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #warp>>
tt.return
}
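
The CHECK counts above follow from splitting the warp-level tf32 tiles into DPAS-native shapes (8x8 for the A operand, 8x16 for the B operand and the accumulator). A rough Python sketch of that bookkeeping, under the assumption of those native tile sizes:

def split_counts(a_shape, b_shape, a_tile=(8, 8), b_tile=(8, 16)):
    # How many native loads each warp tile splits into, and how many
    # tt.dot ops the warp-level dot expands to.
    a_loads = (a_shape[0] // a_tile[0]) * (a_shape[1] // a_tile[1])
    b_loads = (b_shape[0] // b_tile[0]) * (b_shape[1] // b_tile[1])
    k_steps = a_shape[1] // a_tile[1]      # reduction dimension
    n_steps = b_shape[1] // b_tile[1]      # accumulator tiles
    return a_loads, b_loads, k_steps * n_steps

# A: 8x32, B: 32x32 -> 4 A loads, 8 B loads, 8 chained dots into 2 accumulators
assert split_counts((8, 32), (32, 32)) == (4, 8, 8)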
40 changes: 24 additions & 16 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
@@ -12,9 +12,11 @@ using namespace mlir::triton::gpu::intel;
namespace {

VectorType getVectorType(RankedTensorType tensorType, Type elemType) {
unsigned ratio =
elemType.getIntOrFloatBitWidth() / tensorType.getElementTypeBitWidth();
unsigned num = (tensorType.getNumElements() / 16) / ratio;
// Determine a vector type of the given `elemType` that covers 1/16 of
// `tensorType`, i.e. the amount of data a single subgroup lane will work on.
size_t tensorSize =
tensorType.getNumElements() * tensorType.getElementTypeBitWidth();
size_t num = (tensorSize / 16) / elemType.getIntOrFloatBitWidth();
return vec_ty(elemType, num);
};
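
The reworked computation sizes the per-lane vector from the total tile payload, which also covers the tf32 case where the lane element (i32) is as wide as the tensor element. A Python sketch of the rule, assuming a 16-lane subgroup as in the comment above:

def per_lane_vector_length(num_elements, tensor_elem_bits, lane_elem_bits, lanes=16):
    # Split the whole tile's payload evenly over the subgroup lanes, then express
    # each lane's share in units of the lane element type.
    tensor_bits = num_elements * tensor_elem_bits
    return (tensor_bits // lanes) // lane_elem_bits

# tf32 A tile tensor<8x8xf32> read as i32 lanes  -> vector<4xi32>
assert per_lane_vector_length(8 * 8, 32, 32) == 4
# tf32 B tile tensor<8x16xf32> read as i32 lanes -> vector<8xi32>
assert per_lane_vector_length(8 * 16, 32, 32) == 8
# an 8x16 f16 tile read as i16 lanes             -> vector<8xi16>
assert per_lane_vector_length(8 * 16, 16, 16) == 8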

@@ -120,11 +122,13 @@ class LoadStorePrefetchOpConversion
assert(tensorType.getRank() <= 2 &&
"only support 1d/2d load/store/prefetch for now");

unsigned dataSize = tensorType.getElementType().getIntOrFloatBitWidth();
Type elemType = tensorType.getElementType();
unsigned dataSize = elemType.getIntOrFloatBitWidth();
unsigned blockHeight = tensorType.getShape()[0];
unsigned blockWidth = tensorType.getShape()[1];
assert((blockWidth == 16 || blockWidth == 32 || blockWidth == 64) &&
"only support 16/32/64 block");
assert((blockWidth == 8 || blockWidth == 16 || blockWidth == 32 ||
blockWidth == 64) &&
"only support 8/16/32/64 block");
auto idxAttr = op->template getAttrOfType<mlir::IntegerAttr>("DotIdx");
unsigned vBlks = 1;
if (dataSize == 16) {
@@ -175,10 +179,11 @@ class LoadStorePrefetchOpConversion
unsigned idx = idxAttr.getInt();
Type resType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
bool isDword = idx == 1 || elemType == f32_ty;
Type vectorType =
getVectorType(cast<RankedTensorType>(op.getResult().getType()),
idx == 0 ? i16_ty : i32_ty);
bool vnni = (idx == 1) && dataSize <= 32;
isDword ? i32_ty : i16_ty);
bool vnni = (idx == 1) && dataSize < 32;
auto load = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
loc, vectorType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY,
dataSize, blockWidth, blockHeight, vBlks, false /*transpose*/, vnni);
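
The lane element type of the load result is now picked per operand: B loads (DotIdx = 1) keep using dword lanes, and A loads switch from word to dword lanes when the element type is f32. A minimal Python sketch of that selection (names are illustrative):

def lane_elem_bits(dot_idx, elem_is_f32):
    # Mirrors `isDword = idx == 1 || elemType == f32_ty`: B operands and any
    # fp32 (tf32) operand are loaded into i32 lanes, everything else into i16.
    return 32 if dot_idx == 1 or elem_is_f32 else 16

assert lane_elem_bits(0, elem_is_f32=True) == 32    # tf32 A operand -> vector<Nxi32>
assert lane_elem_bits(0, elem_is_f32=False) == 16   # f16/bf16/i8 A operand -> vector<Nxi16>
assert lane_elem_bits(1, elem_is_f32=False) == 32   # B operand is always dword-based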
@@ -219,12 +224,14 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<DotOp> {
LogicalResult
matchAndRewrite(DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto encodePrecision = [&](Type type) -> TritonGEN::PrecisionType {
auto encodePrecision =
[&](Type type, InputPrecisionAttr attr) -> TritonGEN::PrecisionType {
if (type == bf16_ty)
return TritonGEN::PrecisionType::BF16;
else if (type == f16_ty)
return TritonGEN::PrecisionType::FP16;
else if (type == rewriter.getTF32Type())
else if (type == f32_ty && attr &&
attr.getValue() == InputPrecision::TF32)
return TritonGEN::PrecisionType::TF32;
else if (type.isInteger(8)) {
if (type.isUnsignedInteger())
@@ -236,18 +243,19 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<DotOp> {
return TritonGEN::PrecisionType::UNUSED;
};

TritonGEN::PrecisionType precATy =
encodePrecision(op.getA().getType().getElementType());
TritonGEN::PrecisionType precBTy =
encodePrecision(op.getB().getType().getElementType());
TritonGEN::PrecisionType precATy = encodePrecision(
op.getA().getType().getElementType(), op.getInputPrecisionAttr());
TritonGEN::PrecisionType precBTy = encodePrecision(
op.getB().getType().getElementType(), op.getInputPrecisionAttr());
auto precA =
TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precATy);
auto precB =
TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precBTy);

Location loc = op.getLoc();
Type typeA =
getVectorType(cast<RankedTensorType>(op.getA().getType()), i16_ty);
Type typeA = getVectorType(
cast<RankedTensorType>(op.getA().getType()),
precATy == TritonGEN::PrecisionType::TF32 ? i32_ty : i16_ty);
Value castA = bitcast(adaptor.getA(), typeA);
VectorType typeB =
getVectorType(cast<RankedTensorType>(op.getB().getType()), i32_ty);
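
With this change, TF32 is no longer detected via a dedicated tf32 element type (which tt.dot operands never actually carry); it is derived from an f32 element type combined with the dot op's inputPrecision attribute, and the A operand is then bitcast to an i32-based vector. A Python sketch of the resulting mapping, with string names standing in for TritonGEN::PrecisionType and the int8 branch abbreviated as it is in the diff:

def encode_precision(elem_type, input_precision=None):
    # elem_type: "bf16", "f16", "f32", "i8", "u8"; input_precision: "tf32" or None.
    if elem_type == "bf16":
        return "BF16"
    if elem_type == "f16":
        return "FP16"
    if elem_type == "f32" and input_precision == "tf32":
        return "TF32"
    if elem_type == "u8":
        return "U8"
    if elem_type == "i8":
        return "S8"
    return "UNUSED"

assert encode_precision("f32", "tf32") == "TF32"
assert encode_precision("f32") == "UNUSED"   # plain fp32 falls through, matching the minimal scope of this PR
assert encode_precision("bf16") == "BF16"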