Enable TF32 test.
Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com>
jopperm committed May 30, 2024
1 parent aca33d2 commit 3b09e6d
Showing 4 changed files with 119 additions and 21 deletions.
6 changes: 4 additions & 2 deletions python/tutorials/09-experimental-block-pointer.py
@@ -222,7 +222,8 @@ def matmul(a, b, res_dtype):
# Still, we can test our matrix multiplication with block pointers against a native torch implementation (i.e., cuBLAS).

torch.manual_seed(0)
-for dtype, res_dtype in [(torch.float16, torch.float32), (torch.bfloat16, torch.float32), (torch.int8, torch.int32)]:
+for dtype, res_dtype in [(torch.float16, torch.float32), (torch.bfloat16, torch.float32), (torch.int8, torch.int32),
+                         (torch.float32, torch.float32)]:
if dtype.is_floating_point:
a = torch.randn((512, 512), device='xpu', dtype=dtype)
b = torch.randn((512, 512), device='xpu', dtype=dtype)
@@ -243,8 +244,9 @@ def matmul(a, b, res_dtype):

# Note: the torch.matmul and Triton implementations use different
# algorithms, so we need to adjust the tolerance.
+atol = 4e-2 if dtype == torch.float32 else 1e-4
rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
-if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol):
+if torch.allclose(triton_output, torch_output, atol=atol, rtol=rtol):
print("✅ Triton and Torch match")
else:
exit("❌ Triton and Torch differ")
34 changes: 33 additions & 1 deletion test/Conversion/intel/tritongpu_to_llvm_intel_block_ptr.mlir
@@ -1,4 +1,4 @@
-// RUN: TRITON_INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s --convert-triton-intel-gpu-to-llvm | FileCheck %s
+// RUN: TRITON_INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s --convert-triton-intel-gpu-to-llvm --split-input-file | FileCheck %s

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
// CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32, vector<8xi32>)
@@ -108,3 +108,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
tt.return
}
}

// -----

// COM: Checks the correct lowering of the A operand load for TF32, i.e. using 4xi32 and vnni=false.

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @matmul_kernel_with_block_pointers_tf32(
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) attributes {triton_gen.intel_reqd_sub_group_size = [16 : i32], triton_gen.max_work_group_size = [512 : i32, 1 : i32, 1 : i32]} {
tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
%c0_i64 = arith.constant 0 : i64
%c0_i32 = arith.constant 0 : i32
%0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x8xf32>>
%1 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x16xf32>>
// CHECK: [[ELEM_SIZE:%.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: [[NUM_BLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: {{%.*}} = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v4i32({{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, [[ELEM_SIZE]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[NUM_BLOCKS]], [[TRANSPOSE]], [[VNNI]], {{%.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<4xi32>
%2 = tt.load %0 {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[ELEM_SIZE:%.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: [[NUM_BLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
// CHECK: [[VAL_60:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v8i32({{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, [[ELEM_SIZE]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[NUM_BLOCKS]], [[TRANSPOSE]], [[VNNI]], {{%.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi32>
%3 = tt.load %1 {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
tt.return
}
}
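
The v4i32 / v8i32 payload types checked above follow from distributing each 2D block over a 16-lane subgroup. A small sketch of that arithmetic (the helper name and lane count are my own; compare getVectorType in TritonOpsToLLVM.cpp below):

def payload_vector_len(rows, cols, elem_bits, payload_bits=32, lanes=16):
    # Total tile bits, split across the subgroup, in dword-sized payload elems.
    return rows * cols * elem_bits // lanes // payload_bits

print(payload_vector_len(8, 8, 32))   # 4 -> LSC2DBlockRead.v4i32 (8x8 A tile)
print(payload_vector_len(8, 16, 32))  # 8 -> LSC2DBlockRead.v8i32 (8x16 B tile)

Both loads are emitted with vnni = false because VNNI packing only repacks sub-dword (16-bit) elements into dwords; TF32 values already occupy a full dword.
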
60 changes: 58 additions & 2 deletions test/TritonIntelGPU/match-target-size.mlir
Original file line number Diff line number Diff line change
@@ -172,9 +172,9 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16

// COM: Test transformation for int8 datatype

-// CHECK-LABEL: @matmul_kernel_with_block_pointers
+// CHECK-LABEL: @matmul_kernel_with_block_pointers_int8
#warp = #triton_intel_gpu.warp<{sizePerThread = [8, 32], threadsPerWarp = [1, 1], order = [1, 0]}>
-tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: i32) {
+tt.func public @matmul_kernel_with_block_pointers_int8(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: i32) {
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : i32
// CHECK-DAG: [[C32:%.*]] = arith.constant 32 : i32
%cst = arith.constant dense<0> : tensor<8x32xi32, #warp>
@@ -223,3 +223,59 @@ tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<i8> {tt.divisib
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xi32, #warp>>
tt.return
}

// -----

// COM: Test transformation for tf32 datatype

// CHECK-LABEL: @matmul_kernel_with_block_pointers_tf32
#warp = #triton_intel_gpu.warp<{sizePerThread = [8, 32], threadsPerWarp = [1, 1], order = [1, 0]}>
tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: i32) {
// CHECK: [[TZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<8x16xf32>
%cst = arith.constant dense<0.000000e+00> : tensor<8x32xf32, #warp>
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c0_i64 = arith.constant 0 : i64
%c32_i32 = arith.constant 32 : i32
// CHECK-COUNT-4: {{.*}} = tt.make_tensor_ptr %arg0
// CHECK-COUNT-8: {{.*}} = tt.make_tensor_ptr %arg1
%tptr_a = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%tptr_b = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
// CHECK: [[LOOP_RES:%.*]]:14 = scf.for {{.*}} = {{.*}} to {{.*}} step {{.*}} iter_args([[ITER_1:%.*]] = [[TZERO]], [[ITER_2:%.*]] = [[TZERO]], {{.*}})
%35:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %tptr_a, %arg12 = %tptr_b) -> (tensor<8x32xf32, #warp>, !tt.ptr<tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>, !tt.ptr<tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>) : i32 {
// CHECK: [[LD_A1:%.*]] = tt.load %arg[[#first_ptr:]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_A2:%.*]] = tt.load %arg[[#first_ptr+1]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_A3:%.*]] = tt.load %arg[[#first_ptr+2]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_A4:%.*]] = tt.load %arg[[#first_ptr+3]] {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x8xf32>>
// CHECK: [[LD_B1:%.*]] = tt.load %arg[[#first_ptr+4]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B2:%.*]] = tt.load %arg[[#first_ptr+5]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B3:%.*]] = tt.load %arg[[#first_ptr+6]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B4:%.*]] = tt.load %arg[[#first_ptr+7]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B5:%.*]] = tt.load %arg[[#first_ptr+8]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B6:%.*]] = tt.load %arg[[#first_ptr+9]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B7:%.*]] = tt.load %arg[[#first_ptr+10]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: [[LD_B8:%.*]] = tt.load %arg[[#first_ptr+11]] {DotIdx = 1 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
%46 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%47 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
// CHECK: [[DOT_1:%.*]] = tt.dot [[LD_A1]], [[LD_B1]], [[ITER_1]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_2:%.*]] = tt.dot [[LD_A2]], [[LD_B2]], [[DOT_1]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_3:%.*]] = tt.dot [[LD_A3]], [[LD_B3]], [[DOT_2]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_4:%.*]] = tt.dot [[LD_A4]], [[LD_B4]], [[DOT_3]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_5:%.*]] = tt.dot [[LD_A1]], [[LD_B5]], [[ITER_2]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_6:%.*]] = tt.dot [[LD_A2]], [[LD_B6]], [[DOT_5]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_7:%.*]] = tt.dot [[LD_A3]], [[LD_B7]], [[DOT_6]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
// CHECK: [[DOT_8:%.*]] = tt.dot [[LD_A4]], [[LD_B8]], [[DOT_7]], inputPrecision = tf32 : tensor<8x8xf32> * tensor<8x16xf32> -> tensor<8x16xf32>
%48 = tt.dot %46, %47, %arg10, inputPrecision = tf32 : tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<8x32xf32, #warp>
// CHECK-COUNT-12: {{.*}} = tt.advance
%49 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%50 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
scf.yield %48, %49, %50 : tensor<8x32xf32, #warp>, !tt.ptr<tensor<8x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>, !tt.ptr<tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
} {triton_gpu.workload = 3 : i32}
// CHECK: [[TPTR_C1:%.*]] = tt.make_tensor_ptr %arg2,
// CHECK: [[TPTR_C2:%.*]] = tt.make_tensor_ptr %arg2,
%tptr_c = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x32xf32, #warp>>
// CHECK: tt.store [[TPTR_C1:%.*]], [[LOOP_RES]]#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
// CHECK: tt.store [[TPTR_C2:%.*]], [[LOOP_RES]]#1 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf32>>
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #warp>>
tt.return
}
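
The load, dot, and advance counts checked above follow from splitting the 8x32 * 32x32 tf32 dot into DPAS-sized tiles of (8x8) * (8x16) -> (8x16). A quick sketch of that bookkeeping (shapes read off the test; the arithmetic is mine):

M, N, K = 8, 32, 32  # original dot: (8x32) * (32x32) -> (8x32)
m, n, k = 8, 16, 8   # target tile:  (8x8)  * (8x16)  -> (8x16)

a_loads  = (M // m) * (K // k)             # 4 loads of tensor<8x8xf32>
b_loads  = (K // k) * (N // n)             # 8 loads of tensor<8x16xf32>
dots     = (M // m) * (N // n) * (K // k)  # 8 tt.dot ops, chained along K
advances = a_loads + b_loads               # 12 tt.advance ops (CHECK-COUNT-12)

print(a_loads, b_loads, dots, advances)    # 4 8 8 12
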
40 changes: 24 additions & 16 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
@@ -12,9 +12,11 @@ using namespace mlir::triton::gpu::intel;
namespace {

VectorType getVectorType(RankedTensorType tensorType, Type elemType) {
-unsigned ratio =
-    elemType.getIntOrFloatBitWidth() / tensorType.getElementTypeBitWidth();
-unsigned num = (tensorType.getNumElements() / 16) / ratio;
+// Determine a vector type of the given `elemType` that covers 1/16 of
+// `tensorType`, i.e. the amount of data a single subgroup lane will work on.
+size_t tensorSize =
+    tensorType.getNumElements() * tensorType.getElementTypeBitWidth();
+size_t num = (tensorSize / 16) / elemType.getIntOrFloatBitWidth();
return vec_ty(elemType, num);
};
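
My reading of why the helper was rewritten: the old code divided the element widths first, and that integer ratio truncates to zero whenever the payload element is narrower than the tensor element (e.g., an i16 payload for f32 data); the new code multiplies out the total tile size before dividing, so every exact case stays exact. A Python sketch of both variants (function names are mine):

def old_len(num_elems, tensor_bits, payload_bits):
    ratio = payload_bits // tensor_bits  # 16 // 32 == 0 for i16 vs. f32...
    return (num_elems // 16) // ratio    # ...so this would divide by zero

def new_len(num_elems, tensor_bits, payload_bits):
    return (num_elems * tensor_bits // 16) // payload_bits

print(new_len(8 * 8, 32, 32))  # 4 -> vector<4xi32> for the 8x8 TF32 A tile
print(new_len(8 * 8, 32, 16))  # 8, where old_len would raise ZeroDivisionError
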

@@ -120,11 +122,13 @@ class LoadStorePrefetchOpConversion
assert(tensorType.getRank() <= 2 &&
"only support 1d/2d load/store/prefetch for now");

-unsigned dataSize = tensorType.getElementType().getIntOrFloatBitWidth();
+Type elemType = tensorType.getElementType();
+unsigned dataSize = elemType.getIntOrFloatBitWidth();
unsigned blockHeight = tensorType.getShape()[0];
unsigned blockWidth = tensorType.getShape()[1];
-assert((blockWidth == 16 || blockWidth == 32 || blockWidth == 64) &&
-       "only support 16/32/64 block");
+assert((blockWidth == 8 || blockWidth == 16 || blockWidth == 32 ||
+        blockWidth == 64) &&
+       "only support 8/16/32/64 block");
auto idxAttr = op->template getAttrOfType<mlir::IntegerAttr>("DotIdx");
unsigned vBlks = 1;
if (dataSize == 16) {
@@ -175,10 +179,11 @@ class LoadStorePrefetchOpConversion
unsigned idx = idxAttr.getInt();
Type resType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
+bool isDword = idx == 1 || elemType == f32_ty;
Type vectorType =
    getVectorType(cast<RankedTensorType>(op.getResult().getType()),
-                 idx == 0 ? i16_ty : i32_ty);
+                 isDword ? i32_ty : i16_ty);
-bool vnni = (idx == 1) && dataSize <= 32;
+bool vnni = (idx == 1) && dataSize < 32;
auto load = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
loc, vectorType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY,
dataSize, blockWidth, blockHeight, vBlks, false /*transpose*/, vnni);
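
The effect of the two changed predicates, sketched as a small decision table (a Python mirror, with my own names: elem_bits stands for dataSize, dot_idx for the DotIdx attribute):

def payload_and_vnni(dot_idx, elem_bits, is_f32):
    is_dword = dot_idx == 1 or is_f32       # f32 A tiles now take the i32 path
    vnni = dot_idx == 1 and elem_bits < 32  # was `<= 32` before this commit
    return ("i32" if is_dword else "i16"), vnni

print(payload_and_vnni(0, 32, True))   # ('i32', False): TF32 A operand
print(payload_and_vnni(1, 32, True))   # ('i32', False): TF32 B operand, no repacking
print(payload_and_vnni(1, 16, False))  # ('i32', True):  f16/bf16 B operand keeps VNNI
print(payload_and_vnni(0, 16, False))  # ('i16', False): f16/bf16 A operand
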
@@ -219,12 +224,14 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<DotOp> {
LogicalResult
matchAndRewrite(DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-auto encodePrecision = [&](Type type) -> TritonGEN::PrecisionType {
+auto encodePrecision =
+    [&](Type type, InputPrecisionAttr attr) -> TritonGEN::PrecisionType {
if (type == bf16_ty)
return TritonGEN::PrecisionType::BF16;
else if (type == f16_ty)
return TritonGEN::PrecisionType::FP16;
-else if (type == rewriter.getTF32Type())
+else if (type == f32_ty && attr &&
+         attr.getValue() == InputPrecision::TF32)
return TritonGEN::PrecisionType::TF32;
else if (type.isInteger(8)) {
if (type.isUnsignedInteger())
@@ -236,18 +243,19 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<DotOp> {
return TritonGEN::PrecisionType::UNUSED;
};

-TritonGEN::PrecisionType precATy =
-    encodePrecision(op.getA().getType().getElementType());
-TritonGEN::PrecisionType precBTy =
-    encodePrecision(op.getB().getType().getElementType());
+TritonGEN::PrecisionType precATy = encodePrecision(
+    op.getA().getType().getElementType(), op.getInputPrecisionAttr());
+TritonGEN::PrecisionType precBTy = encodePrecision(
+    op.getB().getType().getElementType(), op.getInputPrecisionAttr());
auto precA =
TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precATy);
auto precB =
TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precBTy);

Location loc = op.getLoc();
-Type typeA =
-    getVectorType(cast<RankedTensorType>(op.getA().getType()), i16_ty);
+Type typeA = getVectorType(
+    cast<RankedTensorType>(op.getA().getType()),
+    precATy == TritonGEN::PrecisionType::TF32 ? i32_ty : i16_ty);
Value castA = bitcast(adaptor.getA(), typeA);
VectorType typeB =
getVectorType(cast<RankedTensorType>(op.getB().getType()), i32_ty);
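
The key behavioral change here, as I read the diff: the old check compared against MLIR's dedicated TF32 type, which a Triton f32 tensor never carries, so the TF32 branch was effectively unreachable; the new check keys off the dot op's inputPrecision attribute instead. A Python approximation of the revised mapping (string stand-ins for the MLIR types are mine):

def encode_precision(elem_type, input_precision=None):
    if elem_type == "bf16":
        return "BF16"
    if elem_type == "f16":
        return "FP16"
    if elem_type == "f32" and input_precision == "tf32":
        return "TF32"
    if elem_type == "u8":
        return "U8"
    if elem_type == "i8":
        return "S8"
    return "UNUSED"

print(encode_precision("f32", "tf32"))  # TF32: DPAS runs in TF32 mode
print(encode_precision("f32"))          # UNUSED: plain f32 dot, no inputPrecision
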
