diff --git a/3rdparty/tvm b/3rdparty/tvm index 180359a9a..8847ba9a6 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 180359a9ab7a04e6b3ad472a12295d9d3056fa95 +Subproject commit 8847ba9a6562b08b77d0223a33601f34d8100404 diff --git a/bitblas/tl/mfma_macro_generator.py b/bitblas/tl/mfma_macro_generator.py index 2fe8c0061..cd95ed794 100644 --- a/bitblas/tl/mfma_macro_generator.py +++ b/bitblas/tl/mfma_macro_generator.py @@ -2,8 +2,9 @@ # Licensed under the MIT License. import tvm.tl.language as T - +from typing import Tuple from tvm import DataType +from tvm.tir import PrimExpr from tvm.runtime import convert from .utils import ( mfma_store_index_map,) @@ -142,12 +143,16 @@ def get_ldmatrix_index_map(self, is_b=False): return index_map, reverse_index_map - def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0): + def extract_thread_binding(self, thread_id) -> Tuple[PrimExpr, PrimExpr, PrimExpr]: WARP_SIZE = self.WARP_SIZE block_row_warps = self.block_row_warps block_col_warps = self.block_col_warps + return thread_id % WARP_SIZE, (thread_id // WARP_SIZE) % block_col_warps, ( + thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps + + def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0): warp_row_tiles = self.warp_row_tiles - warp_cols = self.warp_cols + warp_rows = self.warp_rows chunk = self.chunk micro_size_x = self.micro_size_x micro_size_k = self.micro_size_k @@ -164,17 +169,16 @@ def _warp_ldmatrix_a( thread_bindings, rk=0, ): - tx = thread_bindings % WARP_SIZE - tz = (thread_bindings // (WARP_SIZE * block_col_warps)) % block_row_warps + tx, _, tz = self.extract_thread_binding(thread_bindings) if is_transposed: - for i in T.serial(warp_cols): + for i in T.serial(warp_rows): for local_id in T.vectorized(local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (rk * chunk + ki * micro_size_k, tz * warp_row_tiles + i * micro_size_x) A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col] else: - for i in T.serial(warp_cols): + for i in T.serial(warp_rows): for local_id in T.vectorized(local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (tz * warp_row_tiles + i * micro_size_x, @@ -184,9 +188,6 @@ def _warp_ldmatrix_a( return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk) def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): - - WARP_SIZE = self.WARP_SIZE - block_col_warps = self.block_col_warps warp_col_tiles = self.warp_col_tiles warp_cols = self.warp_cols chunk = self.chunk @@ -205,8 +206,7 @@ def _warp_ldmatrix_b( thread_bindings, rk=0, ): - tx = thread_bindings % WARP_SIZE - ty = (thread_bindings // WARP_SIZE) % block_col_warps + tx, ty, _ = self.extract_thread_binding(thread_bindings) if is_transposed: for j in T.serial(warp_cols): @@ -263,7 +263,6 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): return _warp_mma(A_local_buf, B_local_buf, C_local_buf) def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None): - WARP_SIZE = self.WARP_SIZE block_row_warps = self.block_row_warps block_col_warps = self.block_col_warps warp_rows = self.warp_rows @@ -281,23 +280,17 @@ def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None): # equal to the warp_size @T.macro def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings): - tx = thread_bindings % WARP_SIZE - ty = (thread_bindings // WARP_SIZE) % block_row_warps - tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps - + tx, ty, tz = self.extract_thread_binding(thread_bindings) for i, j in T.grid(warp_rows, warp_cols): for local_id in T.serial(local_size_out): row, col = T.meta_var(mfma_store_index_map(tx, local_id)) - C_buf[ty * warp_rows + i, tz * warp_cols + j, row, + C_buf[tz * warp_rows + i, ty * warp_cols + j, row, col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] @T.macro def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings): - tx = thread_bindings % WARP_SIZE - ty = (thread_bindings // WARP_SIZE) % block_row_warps - tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps - + tx, ty, tz = self.extract_thread_binding(thread_bindings) for i, j in T.grid(warp_rows, warp_cols): for local_id in T.serial(local_size_out): row, col = T.meta_var(mfma_store_index_map(tx, local_id))