diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py index 790281a1f..f35ac030a 100644 --- a/bitblas/tl/macro_generator.py +++ b/bitblas/tl/macro_generator.py @@ -63,7 +63,7 @@ def __init__( self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) self.warp_rows = warp_row_tiles // self.micro_size_x self.warp_cols = warp_col_tiles // self.micro_size_y - self._initialize_thread_axis(threads, self.WARP_SIZE, block_row_warps, block_col_warps) + self.threads = threads def _initialize_k_dim(self, a_dtype="float16"): self.k_dim = 256 // DataType(a_dtype).bits @@ -91,17 +91,6 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): self.micro_size_y = n_dim self.micro_size_k = k_dim - def _initialize_thread_axis(self, - threads=128, - warp_size=32, - block_row_warps=2, - block_col_warps=2): - self.threads = threads - # thread_bindings = T.env_thread("threadIdx.x") - # self.tx = thread_bindings % warp_size - # self.ty = (thread_bindings // warp_size) % block_row_warps - # self.tz = thread_bindings // (warp_size * block_row_warps) - @staticmethod @T.macro def MMA(inst, A_local_buf, B_local_buf, C_local_buf): @@ -209,3 +198,32 @@ def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings): C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row, col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) + j * inst.local_size_out + local_id] + + # Allow GEMM from shared memory to local memory + @staticmethod + @T.macro + def GEMM_SS(inst, A_shared_buf, B_shared_buf, C_local_buf, thread_bindings): + A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size), + inst.a_dtype, + scope="local") + B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size), + inst.b_dtype, + scope="local") + for ki in T.serial(0, (inst.block_K // inst.micro_size_k)): + inst.LDMATRIX_A( + inst, + A_local_buf, + A_shared_buf, + ki, + thread_bindings=thread_bindings, + ) + + inst.LDMATRIX_B( + inst, + B_local_buf, + B_shared_buf, + ki, + thread_bindings=thread_bindings, + ) + + inst.MMA(inst, A_local_buf, B_local_buf, C_local_buf)