diff --git a/3rdparty/tvm b/3rdparty/tvm
index a12155db8..e52254920 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit a12155db8eb818f54f27c5124ad8332918cae0ec
+Subproject commit e52254920e8ba1719e7c4f68dd684fd8ede79623
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
index e0cd7da66..1b083eafb 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -6,7 +6,7 @@ from typing import Optional, List
 from bitblas.tl.utils import (
     get_mma_micro_size,
-    make_swizzle_layout,
+    make_mma_swizzle_layout as make_swizzle_layout,
 )
 from bitblas.tl.mma_macro_generator import (
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py
index 087fc9f11..9de81d29d 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py
@@ -7,7 +7,7 @@ from typing import Optional, List
 from bitblas.tl.utils import (
     get_mma_micro_size,
-    make_swizzle_layout,
+    make_mma_swizzle_layout as make_swizzle_layout,
 )
 from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
     MatmulFineGrainScheduler,
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py
index 9127c7ae4..fff815e7d 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py
@@ -6,7 +6,7 @@ from typing import Optional, List, Literal
 from bitblas.tl.utils import (
     get_mma_micro_size,  # noqa: F401
-    make_swizzle_layout,  # noqa: F401
+    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
 )
 from bitblas.tl.mma_macro_generator import (
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py
index 54303167e..4bcd75cbe 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py
@@ -6,7 +6,7 @@ from typing import Optional, List
 from bitblas.tl.utils import (
     get_mma_micro_size,  # noqa: F401
-    make_swizzle_layout,  # noqa: F401
+    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
     index_to_coordinates,  # noqa: F401
 )
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py
index d51766cec..afd8849fc 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py
@@ -6,7 +6,7 @@ from typing import Optional
 from bitblas.tl.utils import (
     get_mma_micro_size,  # noqa: F401
-    make_swizzle_layout,  # noqa: F401
+    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
 )
 from .finegrained_primitive_tensorcore import MatmulDequantizeFineGrainedScheduler
 from bitblas.tl.mma_macro_generator import (
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py
index ace3052da..45cd948c9 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py
@@ -6,7 +6,7 @@ from typing import Optional, List
 from bitblas.tl.utils import (
     get_mma_micro_size,  # noqa: F401
-    make_swizzle_layout,  # noqa: F401
+    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
     index_to_coordinates,  # noqa: F401
 )
 from bitblas.base.arch import TileDevice
diff --git a/bitblas/tl/__init__.py b/bitblas/tl/__init__.py
index 3103fbf89..fe8f61522 100644
--- a/bitblas/tl/__init__.py
+++ b/bitblas/tl/__init__.py
@@ -2,7 +2,6 @@
 # Licensed under the MIT License.
 from .utils import (
-    get_swizzle_layout,  # noqa: F401
     mma_store_index_map,  # noqa: F401
     get_ldmatrix_offset,  # noqa: F401
 )
diff --git a/bitblas/tl/mfma_layout.py b/bitblas/tl/mfma_layout.py
index 2ff9be948..79e75e438 100644
--- a/bitblas/tl/mfma_layout.py
+++ b/bitblas/tl/mfma_layout.py
@@ -1,5 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+
+from tvm import DataType
+import tvm.tl.language as T
 
 from tvm.runtime import convert
 
@@ -71,3 +74,52 @@ def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id):
     i = thread_id % 16
     j = local_id + (thread_id // 16) * 4
     return i, j
+
+
+def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
+    i = thread_id % 16
+    j = (thread_id // 16) * 8 + local_id
+    return i, j
+
+
+def shared_16x32_to_local_64x8_layout_A(i, j):
+    thread_id = i + 16 * (j // 8)
+    local = (j % 8)
+    return thread_id, local
+
+
+def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id):
+    i = local_id + (thread_id // 16) * 8
+    j = thread_id % 16
+    return i, j
+
+
+def shared_16x32_to_local_64x8_layout_B(i, j):
+    thread_id = j + (i // 8) * 16
+    local = (i % 8)
+    return thread_id, local
+
+
+def make_mfma_swizzle_layout(shared_buf, vecSize=8):
+    dtype = shared_buf.dtype
+    shape = shared_buf.shape
+
+    numBanks = 32
+    bankBitWidth = 32
+    SIMDWidth = 16
+
+    innerDimLength = shape[-1]
+    typeWidthInBit = DataType(dtype).bits
+
+    elemsPerOneBanksRow = (numBanks * bankBitWidth) // typeWidthInBit
+    perPhase = max(1, elemsPerOneBanksRow // innerDimLength)
+    maxPhase = min(SIMDWidth // perPhase, innerDimLength // vecSize)
+
+    def transform(row, col):
+        phase = (row // perPhase) % maxPhase
+        colOffSwizzled = ((col // vecSize) ^ phase) * vecSize
+        colOffOrdered = col % vecSize
+        colOff = colOffSwizzled + colOffOrdered
+        return row, colOff
+
+    return T.Layout(shape, transform)
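The four 16x32 helpers added above are paired inverses: the shared-to-local maps hand each of the 64 lanes 8 contiguous K elements, and the thread_id_* maps recover the shared coordinates. A quick standalone round-trip check, with the definitions copied from the patch so it runs in plain Python without TVM:

# Round-trip check for the new 16x32 <-> 64x8 MFMA layout helpers
# (A variant shown; the B variant is symmetric with i and j swapped).
def shared_16x32_to_local_64x8_layout_A(i, j):
    return i + 16 * (j // 8), j % 8

def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
    return thread_id % 16, (thread_id // 16) * 8 + local_id

for tid in range(64):
    for local in range(8):
        i, j = thread_id_shared_access_64x8_to_16x32_layout_A(tid, local)
        assert 0 <= i < 16 and 0 <= j < 32
        assert shared_16x32_to_local_64x8_layout_A(i, j) == (tid, local)
print("16x32 A-layout maps are mutual inverses over all 64x8 indices")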
diff --git a/bitblas/tl/mfma_macro_generator.py b/bitblas/tl/mfma_macro_generator.py
index c68fce701..0148cd8bf 100644
--- a/bitblas/tl/mfma_macro_generator.py
+++ b/bitblas/tl/mfma_macro_generator.py
@@ -6,6 +6,7 @@ from tvm import DataType
 from tvm.tir import PrimExpr
 from tvm.runtime import convert
+from typing import Optional
 
 from .utils import (
     mfma_store_index_map,)
@@ -30,22 +31,29 @@ class MatrixCoreIntrinEmitter(object):
         "e5m2_float8": "e5m2",
     }
 
+    # k_pack represents the number of MFMA fragments packed per vectorized instruction.
+    # Detailed information can be found in the Triton implementation:
+    # https://github.com/triton-lang/triton/blob/433037206d8870f0b82a3cd669097001084a29ed/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp#L419
+    k_pack = 1
+
     # Represent the thread binding in the form of (tx, warp_n, warp_m)
     is_m_first = False
 
     def __init__(
         self,
-        a_dtype="float16",
-        b_dtype="float16",
-        accum_dtype="float16",
-        a_transposed=False,
-        b_transposed=False,
-        block_row_warps=2,
-        block_col_warps=2,
-        warp_row_tiles=8,
-        warp_col_tiles=8,
-        chunk=16,
-        reduce_k=1,
-        num_elems_per_byte=1,
+        a_dtype: str = "float16",
+        b_dtype: str = "float16",
+        accum_dtype: str = "float16",
+        a_transposed: bool = False,
+        b_transposed: bool = False,
+        block_row_warps: int = 2,
+        block_col_warps: int = 2,
+        warp_row_tiles: int = 8,
+        warp_col_tiles: int = 8,
+        chunk: int = 16,
+        reduce_k: int = 1,
+        num_elems_per_byte: int = 1,
+        k_pack: Optional[int] = None,
+        is_m_first: Optional[bool] = False,
     ):
         self.a_dtype = a_dtype
         self.b_dtype = b_dtype
@@ -63,6 +71,9 @@ def __init__(
         self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
         self._initialize_mfma_prefix(self.k_dim)
         self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
+        self._initialize_k_pack(k_pack)
+        self._initialize_is_m_first(is_m_first)
+
         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
         self.reduce_k = reduce_k
@@ -113,19 +124,31 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
         self.micro_size_y = n_dim
         self.micro_size_k = k_dim
 
+    def _initialize_k_pack(self, k_pack: Optional[int] = None):
+        if k_pack is not None:
+            self.k_pack = k_pack
+
+    def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
+        if is_m_first is not None:
+            self.is_m_first = is_m_first
+
     def get_ldmatrix_index_map(self, is_b=False):
         from .mfma_layout import (
             shared_16x4_to_local_64x1_layout_A,
             shared_4x16_to_local_64x1_layout_B,
             shared_16x16_to_local_64x4_layout_A,
             shared_16x16_to_local_64x4_layout_B,
+            shared_16x32_to_local_64x8_layout_A,
+            shared_16x32_to_local_64x8_layout_B,
             thread_id_shared_access_64x1_to_16x4_layout_A,
             thread_id_shared_access_64x1_to_4x16_layout_B,
             thread_id_shared_access_64x4_to_16x16_layout_A,
             thread_id_shared_access_64x4_to_16x16_layout_B,
+            thread_id_shared_access_64x8_to_16x32_layout_A,
+            thread_id_shared_access_64x8_to_16x32_layout_B,
         )
 
-        k_dim = self.k_dim
+        k_dim = self.k_dim * self.k_pack
         transposed = self.a_transposed if not is_b else self.b_transposed
         if k_dim == 4:
             index_map = shared_16x4_to_local_64x1_layout_A
@@ -140,6 +163,13 @@ def get_ldmatrix_index_map(self, is_b=False):
             if is_b:
                 index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
                 reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
+        elif k_dim == 32:
+            index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A
+            reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
+
+            if is_b:
+                index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
+                reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
         else:
             raise ValueError("k_dim must be 4 or 16 currently")
@@ -181,6 +211,7 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
         micro_size_x = self.micro_size_x
         micro_size_k = self.micro_size_k
         local_size_a = self.local_size_a
+        k_pack = self.k_pack
         is_transposed = self.a_transposed
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
 
@@ -196,18 +227,20 @@ def _warp_ldmatrix_a(
             tx, _, warp_m = self.extract_thread_binding(thread_bindings)
             if is_transposed:
                 for i in T.serial(warp_rows):
-                    for local_id in T.vectorized(local_size_a):
+                    for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (rk * chunk + ki * micro_size_k,
                                 warp_m * warp_row_tiles + i * micro_size_x)
-                        A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
+                                                                                         r + col]
             else:
                 for i in T.serial(warp_rows):
-                    for local_id in T.vectorized(local_size_a):
+                    for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (warp_m * warp_row_tiles + i * micro_size_x,
                                 rk * chunk + ki * micro_size_k)
-                        A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
+                                                                                         r + col]
 
         return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
 
@@ -218,6 +251,7 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
         micro_size_y = self.micro_size_y
         micro_size_k = self.micro_size_k
         local_size_b = self.local_size_b
+        k_pack = self.k_pack
         is_transposed = self.b_transposed
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
 
@@ -234,22 +268,24 @@ def _warp_ldmatrix_b(
 
             if is_transposed:
                 for j in T.serial(warp_cols):
-                    for local_id in T.vectorized(local_size_b):
+                    for local_id in T.vectorized(k_pack * local_size_b):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (
                             warp_n * warp_col_tiles + j * micro_size_y,
                             rk * chunk + ki * micro_size_k,
                         )
-                        B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
+                                                                                         r + col]
             else:
                 for j in T.serial(warp_cols):
-                    for local_id in T.vectorized(local_size_b):
+                    for local_id in T.vectorized(k_pack * local_size_b):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (
                             rk * chunk + ki * micro_size_k,
                             warp_n * warp_col_tiles + j * micro_size_y,
                         )
-                        B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
+                                                                                         r + col]
 
         return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
 
@@ -259,6 +295,7 @@ def mfma(self, A_local_buf, B_local_buf, C_local_buf):
         local_size_a = self.local_size_a
         local_size_b = self.local_size_b
         local_size_out = self.local_size_out
+        k_pack = self.k_pack
         mfma_suffix = self.mfma_suffix
         a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype
         compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}"
@@ -267,7 +304,7 @@ def mfma(self, A_local_buf, B_local_buf, C_local_buf):
 
         @T.macro
        def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
-            for i, j in T.grid(warp_rows, warp_cols):
+            for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
                 T.tvm_mfma(
                     mfma_suffix,
                     "row",
@@ -276,9 +313,9 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
                     compute_b_dtype,
                     compute_out_dtype,
                     B_local_buf.data,
-                    (j * local_size_b) // local_size_b,
+                    ((j * k_pack + kp) * local_size_b) // local_size_b,
                     A_local_buf.data,
-                    (i * local_size_a) // local_size_a,
+                    ((i * k_pack + kp) * local_size_a) // local_size_a,
                     C_local_buf.data,
                     (i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
                     dtype=compute_out_dtype,
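With k_pack > 1 each lane stages k_pack MFMA fragments contiguously in its local buffer, and the kp loop above feeds them to consecutive tvm_mfma calls. A small sketch of the sizes involved, assuming the fp16 16x16x16 intrinsic and the fragment-size rule local_size = (m_dim * k_dim) // warp_size used throughout this emitter:

# Per-lane fragment sizes for MatrixCoreIntrinEmitter with k_pack
# (assumes fp16 MFMA 16x16x16 on a 64-lane wave).
WARP_SIZE = 64
m_dim = n_dim = k_dim = 16
k_pack = 2

local_size_a = (m_dim * k_dim) // WARP_SIZE   # 4 elements per lane per fragment
vector_width = k_pack * local_size_a          # 8 elements: one ldmatrix_a copy
print(local_size_a, vector_width)             # 4 8

# Offset consumed by the mfma() loop: fragment kp of tile i starts at
# (i * k_pack + kp) * local_size_a in the flat local buffer.
for i in range(2):        # warp_rows
    for kp in range(k_pack):
        print(i, kp, (i * k_pack + kp) * local_size_a)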
diff --git a/bitblas/tl/mma_layout.py b/bitblas/tl/mma_layout.py
index 719885be5..5dab4ba64 100644
--- a/bitblas/tl/mma_layout.py
+++ b/bitblas/tl/mma_layout.py
@@ -1,6 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
+from typing import Union
+from tvm import arith, DataType
+import tvm.tl.language as T
+
 
 def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
     row = thread_id % 16
@@ -42,3 +46,74 @@ def shared_16x32_to_mma_32x16_smoothlayout(i, j):
 
 def shared_32x16_to_mma_32x16_smoothlayout(i, j):
     return (i * 2 + j // 16, j % 16)
+
+
+def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
+    ana = arith.Analyzer()
+    BANK_SIZE_BYTES = 128
+    if isinstance(dtype, str):
+        dtype = DataType(dtype)
+    col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % (
+        BANK_SIZE_BYTES // dtype.bits)
+    # use transaction bits to support diverse dtype.
+    # for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits
+    # for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
+    coalescent_bits = dtype.bits * row_size
+    # permutation on 4 banks, each bank has 32 bits
+    bank_elems = BANK_SIZE_BYTES // dtype.bits
+    new_col_idx_outer = None
+
+    if coalescent_bits % 1024 == 0:
+        # Use 8 * 8 permuted layout
+        # Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
+        # Every row below corresponds to 32 banks
+        # 0 1 2 3 4 5 6 7    ==>    0 1 2 3 4 5 6 7
+        # 0 1 2 3 4 5 6 7    ==>    1 0 3 2 5 4 7 6
+        # 0 1 2 3 4 5 6 7    ==>    2 3 0 1 6 7 4 5
+        # 0 1 2 3 4 5 6 7    ==>    3 2 1 0 7 6 5 4
+        # 0 1 2 3 4 5 6 7    ==>    4 5 6 7 0 1 2 3
+        # 0 1 2 3 4 5 6 7    ==>    5 4 7 6 1 0 3 2
+        # 0 1 2 3 4 5 6 7    ==>    6 7 4 5 2 3 0 1
+        # 0 1 2 3 4 5 6 7    ==>    7 6 5 4 3 2 1 0
+        row_idx_sub = row_idx % bank_elems
+        new_col_idx_outer = col_idx_outer ^ row_idx_sub
+    else:
+        assert coalescent_bits % 512 == 0
+        # Use 8 * 4 permuted layout
+        # Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
+        # Every row below corresponds to 16 banks
+        # 0 1 2 3    ==>    0 1 2 3
+        # 0 1 2 3    ==>    0 1 2 3
+        # 0 1 2 3    ==>    1 0 3 2
+        # 0 1 2 3    ==>    1 0 3 2
+        # 0 1 2 3    ==>    2 3 0 1
+        # 0 1 2 3    ==>    2 3 0 1
+        # 0 1 2 3    ==>    3 2 1 0
+        # 0 1 2 3    ==>    3 2 1 0
+        # View with 8 elements per row:
+        # 0 1 2 3 4 0 1 2 3    ==>    0 1 2 3 0 1 2 3
+        # 0 1 2 3 4 0 1 2 3    ==>    1 0 3 2 1 0 3 2
+        # 0 1 2 3 4 0 1 2 3    ==>    2 3 0 1 2 3 0 1
+        # 0 1 2 3 4 0 1 2 3    ==>    3 2 1 0 3 2 1 0
+        row_idx_sub = row_idx % bank_elems
+        # Interleave elems per byte
+        interleave_elems = 32 // dtype.bits
+        new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems)
+
+    assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits"
+    return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)
+
+
+def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False):
+    dtype = shared_buf.dtype
+    shape = shared_buf.shape
+
+    can_swizzle = shape[-1] * DataType(dtype).bits == 512
+    if is_smooth or not can_swizzle:
+        return T.Layout(shape, lambda *args: args)
+
+    def transform_func(i, j):
+        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
+        return [new_warp_i, new_warp_j]
+
+    return T.Layout(shape, transform_func)
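The 8x8 case of get_swizzle_layout is a plain XOR of the row index into the 8-element column block, which reproduces the permutation table in its comment. A standalone sketch of that branch (fp16 with a 64-element row, 1024 coalescent bits; no TVM required):

# XOR swizzle for the 8 * 8 permuted layout (fp16, row_size = 64):
# bank_elems = 128 // 16 = 8, so each column block holds 8 fp16 values
# and new_col_idx_outer = col_idx_outer ^ (row_idx % 8).
bank_elems = 8
for row in range(8):
    print([col ^ (row % bank_elems) for col in range(8)])
# Row 1 prints [1, 0, 3, 2, 5, 4, 7, 6], matching the table above;
# the columns within each row stay distinct, so a warp's row-wise
# reads hit eight different bank groups and avoid conflicts.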
diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py
index e11dab3e7..1142228f2 100644
--- a/bitblas/tl/utils.py
+++ b/bitblas/tl/utils.py
@@ -1,10 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from tvm import arith
 from tvm import DataType
-import tvm.tl.language as T
-from typing import Union, Literal
+from typing import Literal
 from .mma_layout import (
     ldmatrix_32x8_to_shared_16x16_layout,
     ldmatrix_trans_32x8_to_shared_16x16_layout,
@@ -14,61 +12,9 @@
 )
 from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m)
 
-
-def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
-    ana = arith.Analyzer()
-    BANK_SIZE_BYTES = 128
-    if isinstance(dtype, str):
-        dtype = DataType(dtype)
-    col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % (
-        BANK_SIZE_BYTES // dtype.bits)
-    # use transaction bits to support diverse dtype.
-    # for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits
-    # for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
-    coalescent_bits = dtype.bits * row_size
-    # permutation on 4 banks, each bank has 32 bits
-    bank_elems = BANK_SIZE_BYTES // dtype.bits
-    new_col_idx_outer = None
-
-    if coalescent_bits % 1024 == 0:
-        # Use 8 * 8 permuted layout
-        # Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
-        # Every row below corresponds to 32 banks
-        # 0 1 2 3 4 5 6 7    ==>    0 1 2 3 4 5 6 7
-        # 0 1 2 3 4 5 6 7    ==>    1 0 3 2 5 4 7 6
-        # 0 1 2 3 4 5 6 7    ==>    2 3 0 1 6 7 4 5
-        # 0 1 2 3 4 5 6 7    ==>    3 2 1 0 7 6 5 4
-        # 0 1 2 3 4 5 6 7    ==>    4 5 6 7 0 1 2 3
-        # 0 1 2 3 4 5 6 7    ==>    5 4 7 6 1 0 3 2
-        # 0 1 2 3 4 5 6 7    ==>    6 7 4 5 2 3 0 1
-        # 0 1 2 3 4 5 6 7    ==>    7 6 5 4 3 2 1 0
-        row_idx_sub = row_idx % bank_elems
-        new_col_idx_outer = col_idx_outer ^ row_idx_sub
-    else:
-        assert coalescent_bits % 512 == 0
-        # Use 8 * 4 permuted layout
-        # Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
-        # Every row below corresponds to 16 banks
-        # 0 1 2 3    ==>    0 1 2 3
-        # 0 1 2 3    ==>    0 1 2 3
-        # 0 1 2 3    ==>    1 0 3 2
-        # 0 1 2 3    ==>    1 0 3 2
-        # 0 1 2 3    ==>    2 3 0 1
-        # 0 1 2 3    ==>    2 3 0 1
-        # 0 1 2 3    ==>    3 2 1 0
-        # 0 1 2 3    ==>    3 2 1 0
-        # View with 8 elements per row:
-        # 0 1 2 3 4 0 1 2 3    ==>    0 1 2 3 0 1 2 3
-        # 0 1 2 3 4 0 1 2 3    ==>    1 0 3 2 1 0 3 2
-        # 0 1 2 3 4 0 1 2 3    ==>    2 3 0 1 2 3 0 1
-        # 0 1 2 3 4 0 1 2 3    ==>    3 2 1 0 3 2 1 0
-        row_idx_sub = row_idx % bank_elems
-        # Interleave elems per byte
-        interleave_elems = 32 // dtype.bits
-        new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems)
-
-    assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits"
-    return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)
+from .mma_layout import get_swizzle_layout  # noqa: F401
+from .mma_layout import make_mma_swizzle_layout  # noqa: F401
+from .mfma_layout import make_mfma_swizzle_layout  # noqa: F401
 
 
 # the original implementation and insight is from the following code snippet
@@ -125,21 +71,6 @@ def get_mma_micro_size(dtype: Literal["float16", "int8"]):
     return micro_size_x, micro_size_y, micro_size_k
 
 
-def make_swizzle_layout(shared_buf, is_smooth: bool = False):
-    dtype = shared_buf.dtype
-    shape = shared_buf.shape
-
-    can_swizzle = shape[-1] * DataType(dtype).bits == 512
-    if is_smooth or not can_swizzle:
-        return T.Layout(shape, lambda *args: args)
-
-    def transform_func(i, j):
-        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
-        return [new_warp_i, new_warp_j]
-
-    return T.Layout(shape, transform_func)
-
-
 def index_to_coordinates(index, shape):
     '''
     General Implementation of:
diff --git a/integration/BitNet/int4_kernel/tl_int4xint2.py b/integration/BitNet/int4_kernel/tl_int4xint2.py
index 7fc2fc7a9..afc5842b1 100644
--- a/integration/BitNet/int4_kernel/tl_int4xint2.py
+++ b/integration/BitNet/int4_kernel/tl_int4xint2.py
@@ -8,7 +8,7 @@ from tvm import DataType
 from tvm import tl as TL
 import tvm.tl.language as T
-from bitblas.tl.utils import (make_swizzle_layout, index_to_coordinates)
+from bitblas.tl.utils import (make_mma_swizzle_layout as make_swizzle_layout, index_to_coordinates)
 from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s
 from bitblas.tl.mma_macro_generator import (
diff --git a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py
index c7c80a3f1..88cd45e6c 100644
--- a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py
+++ b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py
@@ -8,7 +8,7 @@ from tvm import DataType
 from tvm import tl as TL
 import tvm.tl.language as T
-from bitblas.tl.utils import make_swizzle_layout, index_to_coordinates
+from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout, index_to_coordinates
 from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitterWithLadderTransform,)
 from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s
diff --git a/integration/BitNet/int4_kernel/tl_int4xint4.py b/integration/BitNet/int4_kernel/tl_int4xint4.py
index e3bc20649..8da5cbb7c 100644
--- a/integration/BitNet/int4_kernel/tl_int4xint4.py
+++ b/integration/BitNet/int4_kernel/tl_int4xint4.py
@@ -7,7 +7,7 @@ from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.tl.utils import (
-    make_swizzle_layout,)
+    make_mma_swizzle_layout as make_swizzle_layout,)
 from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitter,)
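Across every CUDA-side call site the old make_swizzle_layout name is kept via an import-time alias, so kernel bodies are untouched while the helpers gain backend-specific names. A minimal sketch of the pattern, assuming the patched bitblas is importable:

# Backward-compatible aliasing: the canonical helpers are now
# make_mma_swizzle_layout (NVIDIA MMA) and make_mfma_swizzle_layout
# (AMD Matrix Core); bitblas.tl.utils re-exports both, and existing
# kernels keep calling make_swizzle_layout unchanged.
from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout
from bitblas.tl.utils import make_mfma_swizzle_layout  # noqa: F401

# e.g. inside a kernel body:
#     T.annotate_layout({A_shared: make_swizzle_layout(A_shared)})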
diff --git a/testing/python/tilelang/test_tilelang_amd_gemm.py b/testing/python/tilelang/test_tilelang_amd_gemm.py
new file mode 100644
index 000000000..f281f8eb0
--- /dev/null
+++ b/testing/python/tilelang/test_tilelang_amd_gemm.py
@@ -0,0 +1,110 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from bitblas import tvm as tvm
+import bitblas.testing
+from tvm import tl
+
+
+def matmul(
+    M,
+    N,
+    K,
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    num_stages,
+    threads,
+    k_pack=1,
+):
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+    vec_size = 4 * k_pack
+    import tvm.tl.language as T
+
+    @T.prim_func
+    def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer(
+            (M, N), out_dtype)):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared, coalesced_width=vec_size)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared, coalesced_width=vec_size)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=vec_size)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared, coalesced_width=vec_size)
+                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B, k_pack=k_pack)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def run_gemm(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+    k_pack=1,
+):
+    program = matmul(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+        k_pack=k_pack,
+    )
+    mod, params = tl.lower(program)
+    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+
+    def ref_program(A, B):
+        import torch
+
+        if trans_A:
+            A = A.T
+        if trans_B:
+            B = B.T
+        C = torch.matmul(A.to(torch.float), B.to(torch.float))
+        C = C.to(torch.__getattribute__(out_dtype))
+        return C
+
+    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_gemm_f16f32f32_nt():
+    run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
+    run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32, k_pack=2)
+
+
+if __name__ == "__main__":
+    bitblas.testing.main()
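In the AMD GEMM test above the global-to-shared copy width scales with k_pack (vec_size = 4 * k_pack elements per lane), presumably so that k_pack=2 yields 16-byte (dwordx4-sized) transactions for fp16. A quick check of the resulting per-lane transaction sizes, assuming 16-bit elements as in the test:

# Copy width in test_tilelang_amd_gemm.py: vec_size = 4 * k_pack.
elem_bits = 16  # float16
for k_pack in (1, 2):
    vec_size = 4 * k_pack
    print(k_pack, vec_size, vec_size * elem_bits // 8)  # elements, bytes per lane
# k_pack=1 -> 8 bytes per lane per copy; k_pack=2 -> 16 bytes per lane.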
diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
index dd63274e2..66958134c 100644
--- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py
+++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
@@ -8,7 +8,7 @@ from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.quantization import _tir_packed_to_unsigned_convert
-from bitblas.tl.utils import (make_swizzle_layout)
+from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout
 from bitblas.tl.mma_macro_generator import (
     TensorCoreIntrinEmitterWithLadderTransform,)
diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py
index ee93d33b0..3dd5e11da 100644
--- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py
+++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py
@@ -9,7 +9,7 @@ from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.tl.utils import (
-    make_swizzle_layout,)
+    make_mma_swizzle_layout as make_swizzle_layout,)
 from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitter,
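The new MFMA test below annotates its shared buffers with make_mfma_swizzle_layout, whose phase arithmetic (ported from Triton's AMD path) reduces to small constants for typical tiles. A numeric sketch for fp16 with a 32-element inner dimension, matching the test's block_K; plain Python, no TVM needed:

# Phase math of make_mfma_swizzle_layout for a float16 shared tile
# with inner dimension 32 (the values mirror the function above).
num_banks, bank_bits, simd_width, vec_size = 32, 32, 16, 8
type_bits, inner_dim = 16, 32

elems_per_banks_row = (num_banks * bank_bits) // type_bits       # 64
per_phase = max(1, elems_per_banks_row // inner_dim)             # 2
max_phase = min(simd_width // per_phase, inner_dim // vec_size)  # 4

def col_off(row, col):
    phase = (row // per_phase) % max_phase
    return ((col // vec_size) ^ phase) * vec_size + col % vec_size

# First element of each 8-wide vector for rows 0..7: the XOR phase
# rotates vector slots so row pairs start in different bank groups.
print([[col_off(r, c) for c in range(0, 32, 8)] for r in range(8)])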
diff --git a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py
new file mode 100644
index 000000000..1c4696555
--- /dev/null
+++ b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py
@@ -0,0 +1,212 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+import torch.backends
+from bitblas import tvm as tvm
+import bitblas.testing
+from tvm import tl as TL
+import tvm.tl.language as T
+from bitblas.tl.utils import make_mfma_swizzle_layout as make_swizzle_layout
+from bitblas.tl.mfma_macro_generator import (
+    MatrixCoreIntrinEmitter,)
+from bitblas.ops.base_scheduler import simplify_prim_func
+
+torch.manual_seed(0)
+
+
+@simplify_prim_func
+def tl_matmul(
+    M,
+    N,
+    K,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+):
+    assert in_dtype in [
+        "float16",
+        "int8",
+    ], "Currently only float16 and int8 are supported"
+    assert out_dtype in [
+        "float16",
+        "float32",
+        "int32",
+    ], "Currently only float16, float32 and int32 are supported"
+
+    micro_size_x = micro_size_y = micro_size_k = 16
+
+    if out_dtype == "int32":
+        micro_size_k = 32
+
+    block_row_warps = 1
+    block_col_warps = 1
+    warp_row_tiles = 16
+    warp_col_tiles = 16
+    chunk = 32
+    shared_scope = "shared.dyn"
+    cache_write_shared = False
+
+    block_M = block_row_warps * warp_row_tiles
+    block_N = block_col_warps * warp_col_tiles
+    block_K = chunk
+
+    A_shape = (M, K)
+    B_shape = (N, K)
+    A_shared_shape = (block_M, block_K)
+    B_shared_shape = (block_N, block_K)
+    C_shared_shape = (
+        block_M // micro_size_x,
+        block_N // micro_size_y,
+        micro_size_x,
+        micro_size_y,
+    )
+
+    warp_size = 64
+    threads = warp_size * (block_row_warps * block_col_warps)
+    local_size_a = (micro_size_x * micro_size_k) // warp_size
+    local_size_b = (micro_size_y * micro_size_k) // warp_size
+    local_size_c = (micro_size_x * micro_size_y) // warp_size
+    warp_rows = warp_row_tiles // micro_size_x
+    warp_cols = warp_col_tiles // micro_size_y
+
+    # Matrix Core wrapper to auto-generate MFMA code
+    mfma_emitter = MatrixCoreIntrinEmitter(
+        a_dtype=in_dtype,
+        b_dtype=in_dtype,
+        accum_dtype=accum_dtype,
+        a_transposed=False,
+        b_transposed=True,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+    )
+
+    @T.prim_func
+    def main(
+            A: T.Buffer(A_shape, in_dtype),
+            B: T.Buffer(B_shape, in_dtype),
+            C: T.Buffer((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
+            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
+            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
+
+            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+
+            T.annotate_layout({
+                A_shared: make_swizzle_layout(A_shared),
+                B_shared: make_swizzle_layout(B_shared),
+            })
+
+            # Improve L2 Cache
+            T.use_swizzle(panel_size=10)
+
+            T.clear(C_local)
+
+            for ko in T.Pipelined((K // block_K), num_stages=0):
+
+                # Load A into shared memory
+                for i, k in T.Parallel(block_M, block_K):
+                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
+
+                # Load B into shared memory
+                for j, k in T.Parallel(block_N, block_K):
+                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
+
+                for ki in T.serial(0, (block_K // micro_size_k)):
+
+                    # Load A into fragment
+                    mfma_emitter.ldmatrix_a(
+                        A_local,
+                        A_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                    )
+
+                    # Load B into fragment
+                    mfma_emitter.ldmatrix_b(
+                        B_local,
+                        B_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                    )
+
+                    # Perform Matrix Multiplication
+                    mfma_emitter.mfma(A_local, B_local, C_local)
+
+            # Perform STMatrix
+            if cache_write_shared:
+                mfma_emitter.stmatrix(
+                    C_local,
+                    C_shared,
+                    thread_bindings=thread_bindings,
+                )
+
+                # Store shared into global
+                for i, j in T.Parallel(block_M, block_N):
+                    C[by * block_M + i, bx * block_N + j] = C_shared[
+                        i // micro_size_x,
+                        j // micro_size_y,
+                        i % micro_size_x,
+                        j % micro_size_y,
+                    ]
+            else:
+                mfma_emitter.stmatrix(
+                    C_local,
+                    C,
+                    thread_bindings=thread_bindings,
+                    pid_m=by,
+                    pid_n=bx,
+                )
+
+    return main
+
+
+def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"):
+    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
+    mod, params = TL.lower(matmul)
+    src_code = mod.imported_modules[0].get_source()
+    # src_code is the generated kernel source
+    assert src_code is not None
+
+    if in_dtype == "int8":
+        A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
+        B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
+    else:
+        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
+        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
+
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
+
+    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
+
+    mod(A, B, C)
+
+    latency = mod.do_bench(mod.func, warmup=25)
+
+    # Ensure that the latency is not None
+    assert latency is not None
+
+    # Get Reference Result
+    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
+@bitblas.testing.requires_rocm
+def test_assert_tl_matmul():
+    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
+    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
+
+
+if __name__ == "__main__":
+    bitblas.testing.main()
diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py
similarity index 100%
rename from testing/python/tilelang/test_tilelang_macro_gemm.py
rename to testing/python/tilelang/test_tilelang_mma_macro_gemm.py