diff --git a/3rdparty/tvm b/3rdparty/tvm index e1c5b0897..71fe7ce82 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit e1c5b089737e47a3849afa87df2432c13b633594 +Subproject commit 71fe7ce827396b98a3169343c3744e788a82566c diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py index 63433a52a..fd8ec43ae 100644 --- a/bitblas/tl/macro_generator.py +++ b/bitblas/tl/macro_generator.py @@ -472,3 +472,210 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): ) return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + + +class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): + + def mma(self, A_local_buf, B_local_buf, C_local_buf): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = "int4" + b_dtype_abbrv = "int4" + accum_dtype = self.accum_dtype + accum_dtype_abbrv = accum_dtype + mma_prefix = "m16n8k32" + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + ''' + A[16, 32], B[16, 32], C[16, 16] + A_local_size -> 16 + B_local_size -> 16 + C_local_size -> 8 + For each m16n8k32 inst + For A: m16k32 consume 16 int4 elements -> 8 A_local_size + For A: n8k32 consume 8 int4 elements -> 4 B_local_size + For C: m16n8 consume 4 int32 elements -> 4 C_local_size + ''' + + # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), + ) + + # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_a) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_b) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + + +class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform): + + def mma(self, A_local_buf, B_local_buf, C_local_buf): + + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = "int4" + b_dtype_abbrv = "int4" + accum_dtype = self.accum_dtype + accum_dtype_abbrv = "int32" + mma_prefix = "m16n8k32" + + @T.macro + def 
_warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + ''' + A[16, 32], B[16, 32], C[16, 16] + A_local_size -> 16 + B_local_size -> 16 + C_local_size -> 8 + For each m16n8k32 inst + For A: m16k32 consume 16 int4 elements -> 8 A_local_size + For A: n8k32 consume 8 int4 elements -> 4 B_local_size + For C: m16n8 consume 4 int32 elements -> 4 C_local_size + ''' + + # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), + ) + + # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_a) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_b) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, + T.bool(False), + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py index 053dbe4d5..2c88bec64 100644 --- a/bitblas/tl/utils.py +++ b/bitblas/tl/utils.py @@ -120,12 +120,12 @@ def get_mma_micro_size(dtype: Literal["float16", "int8"]): return micro_size_x, micro_size_y, micro_size_k -def make_swizzle_layout(shared_buf): +def make_swizzle_layout(shared_buf, is_smooth: bool = False): dtype = shared_buf.dtype shape = shared_buf.shape can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: + if is_smooth or not can_swizzle: return T.Layout(shape, lambda *args: args) def transform_func(i, j): diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py new file mode 100644 index 000000000..37c210b91 --- /dev/null +++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py @@ -0,0 +1,416 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
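+
+# These tests exercise the INT4 tensor core intrinsic emitters. Torch has no
+# native int4 dtype, so operands are carried as int8 with two 4-bit values
+# packed per byte (low nibble first); the kernels therefore operate on K // 2
+# int8 columns, and the test inputs are packed into compressed_A /
+# compressed_B before being handed to the TL profiler.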
+ +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import ( + make_swizzle_layout,) + +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitter, + INT4TensorCoreIntrinEmitterWithLadderTransform, +) +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + K = K // 2 + + micro_size_x = micro_size_y = micro_size_k = 16 + + if accum_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K) # int8 storage represents int4*2 + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + 
mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod(compressed_A, compressed_B, C) + print(C) + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@simplify_prim_func +def tl_matmul_weight_only_transform( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + K = K // 2 + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + transform_b = 3 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = ( + block_M, + (block_K + pad_factor) if apply_pad_a else block_K, + ) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = 
INT4TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + transform_b = 3 + + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + LB = ladder_permutate(compressed_B.cpu()).cuda() + + 
mod(compressed_A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul_weight_only_transform(): + assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32") + + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py index cc4839568..4c4cf8f59 100644 --- a/testing/python/tilelang/test_tilelang_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_macro_gemm.py @@ -472,7 +472,9 @@ def tl_matmul_with_ladder_weight_only_transform( warp_size = 32 threads = warp_size * (block_row_warps * block_col_warps) - local_size = (micro_size_x * micro_size_y) // warp_size + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -501,9 +503,9 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), in_dtype) - B_local = T.alloc_local((warp_cols * local_size), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") T.annotate_layout({ @@ -667,7 +669,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( warp_size = 32 threads = warp_size * (block_row_warps * block_col_warps) - local_size = (micro_size_x * micro_size_y) // warp_size + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -704,10 +708,10 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), in_dtype) - B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype) - B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b // num_elems_per_byte), storage_dtype) + B_dequantize_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) reduced_accum_res = T.alloc_local(0, accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") rk = T.thread_binding(0, 
reduce_k, "threadIdx.y") @@ -765,15 +769,16 @@ def main( ) for j in T.serial(warp_cols): - local_size_b = mma_emitter.local_size_b - T.call_extern('handle', 'decode_i4u_to_f16', - T.address_of(B_local[j * local_size_b // num_elems_per_byte]), - T.address_of(B_dequantize_local[j * local_size_b]), 8) + T.call_extern( + 'handle', 'decode_i4u_to_f16', + T.address_of(B_local[j * mma_emitter.local_size_b // + num_elems_per_byte]), + T.address_of(B_dequantize_local[j * mma_emitter.local_size_b]), 8) mma_emitter.mma(A_local, B_dequantize_local, C_local) if reduce_k > 1: - for n in T.serial(warp_rows * warp_cols * local_size): + for n in T.serial(warp_rows * warp_cols * local_size_c): T.attr( T.comm_reducer(lambda x, y: x + y, [T.float16(0)]), "reduce_scope",