diff --git a/3rdparty/tvm b/3rdparty/tvm
index 32c5c790b..39b2ba2fc 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 32c5c790baffe5fa605de52e70640ce67b30f4e6
+Subproject commit 39b2ba2fc24bf2ad441ef7b418c537c2814b21e2
diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py
index f35ac030a..b1422cb0e 100644
--- a/bitblas/tl/macro_generator.py
+++ b/bitblas/tl/macro_generator.py
@@ -97,7 +97,7 @@ def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
         for i, j in T.grid(inst.warp_rows, inst.warp_cols):
             T.ptx_mma(
                 inst.accum_dtype,
-                "m16n8k16",
+                inst.mma_prefix,
                 "row",
                 "col",
                 inst.a_dtype_abbrv,
@@ -114,7 +114,7 @@
 
             T.ptx_mma(
                 inst.accum_dtype,
-                "m16n8k16",
+                inst.mma_prefix,
                 "row",
                 "col",
                 inst.a_dtype_abbrv,
@@ -142,11 +142,10 @@ def LDMATRIX_A(
         stride = inst.chunk
         tx = thread_bindings % inst.WARP_SIZE
         ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
-        # self.ty = (thread_bindings // warp_size) % block_row_warps
-        # self.tz = thread_bindings // (warp_size * block_row_warps)
+
         for i in T.serial(inst.warp_rows):
             T.ptx_ldmatrix(
-                "float16",
+                inst.a_dtype,
                 T.bool(False),
                 4,
                 ".b16",
@@ -154,7 +153,7 @@
                 i * inst.local_size_a,
                 T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x,
                                           ki * inst.micro_size_k,]),
-                get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, False),
+                get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, inst.a_transposed),
             )
 
     @staticmethod
@@ -171,7 +170,7 @@ def LDMATRIX_B(
         tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
         for j in T.serial(inst.warp_cols):
             T.ptx_ldmatrix(
-                "float16",
+                inst.b_dtype,
                 T.bool(False),  # TODO(lei): should be optimized
                 4,
                 ".b16",
@@ -179,7 +178,7 @@
                 j * inst.local_size_b,
                 T.address_of(B_shared_buf[tz * inst.warp_col_tiles + j * inst.micro_size_y,
                                           ki * inst.micro_size_k,]),
-                get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, True),
+                get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, inst.b_transposed),
             )
 
     # STS
@@ -203,13 +202,14 @@ def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
     @staticmethod
     @T.macro
     def GEMM_SS(inst, A_shared_buf, B_shared_buf, C_local_buf, thread_bindings):
-        A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size),
+        # TODO(lei): alloc_buffer within the macro is not supported yet.
+        A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size_a),
                                        inst.a_dtype,
                                        scope="local")
-        B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size),
+        B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size_b),
                                        inst.b_dtype,
                                        scope="local")
-        for ki in T.serial(0, (inst.block_K // inst.micro_size_k)):
+        for ki in T.serial(0, (inst.chunk // inst.micro_size_k)):
             inst.LDMATRIX_A(
                 inst,
                 A_local_buf,
diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py
index d0df62cfa..4910bdc4c 100644
--- a/bitblas/tl/utils.py
+++ b/bitblas/tl/utils.py
@@ -19,7 +19,7 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
     # permutation on 4 banks, each bank has 32 bits
     bank_elems = BANK_SIZE_BYTES // dtype.bits
     new_col_idx_outer = None
-    print(f"coalescent_bits: {coalescent_bits}")
+
     if coalescent_bits % 1024 == 0:
         # Use 8 * 8 permuted layout
         # Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
diff --git a/docs/ExtendOperatorsWithDSL.md b/docs/ExtendOperatorsWithDSL.md
index 279cc2490..8c717b43e 100644
--- a/docs/ExtendOperatorsWithDSL.md
+++ b/docs/ExtendOperatorsWithDSL.md
@@ -1,5 +1,6 @@
 ### Using BitBLAS from DSL
 ```python
+from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
 from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
 from bitblas.base.arch import CUDA
 from bitblas.base.utils import apply_and_build
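
Note (not part of the patch): the docs hunk above only adds the `get_tensorized_func_and_tags` import to the DSL example. Below is a minimal sketch of how that helper typically slots into the rest of that example together with `DefaultPolicy`, `TensorCorePolicy`, `CUDA`, and `apply_and_build`. The matmul definition, shapes, dtypes, target string, and the `latency` attribute are illustrative assumptions, not content of this diff.

```python
import tvm
from tvm.script import tir as T
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.arch import CUDA
from bitblas.base.utils import apply_and_build
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags


# A small NT-layout fp16 matmul written in TVMScript (shapes/dtypes are assumptions).
@tvm.script.ir_module
class MatmulNT:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [1024, 1024], dtype="float16")
        B = T.match_buffer(b, [1024, 1024], dtype="float16")
        C = T.match_buffer(c, [1024, 1024], dtype="float16")
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float16(0)
                # NT layout: B is indexed as [N, K]
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


func = MatmulNT["main"]
arch = CUDA(tvm.target.Target("nvidia/nvidia-a100"))  # target string is an assumption

# Start from the generic policy; switch to the tensor-core policy only when the
# analysis pass can tensorize the function and returns tags for it.
policy = DefaultPolicy(func=func, arch=arch)
tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
if tags:
    policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)

configs = policy.emit_config(topk=20)  # candidate schedule configurations
cpresults, best = apply_and_build(func, configs, arch, parallel_build=True)
print(best.latency)  # attribute name assumed; inspect `best` for details
```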