From ca42750dbd172d4cdf6bedac76a2231016cad528 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 8 Dec 2024 18:05:54 +0000 Subject: [PATCH 01/17] fix for relax --- bitblas/relax/transform/apply_fast_tuning.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bitblas/relax/transform/apply_fast_tuning.py b/bitblas/relax/transform/apply_fast_tuning.py index 035c93d0d..f7eb7eb03 100644 --- a/bitblas/relax/transform/apply_fast_tuning.py +++ b/bitblas/relax/transform/apply_fast_tuning.py @@ -136,11 +136,13 @@ def transform_module( # pylint: disable=missing-function-docstring trace.apply_to_schedule(sch, remove_postproc=False) updated_functions[g_var] = sch.mod["main"].with_attr("tir.is_scheduled", 1) continue + + specalized_function = func.with_attr("global_symbol", g_var.name_hint) - if check_func_with_dynamic(func): + if check_func_with_dynamic(specalized_function): dispatch_mod = fast_tune_with_dynamic_range( - func, + specalized_function, target=target, topk=self.topk, parallel_build=self.parallel_build, @@ -161,7 +163,7 @@ def transform_module( # pylint: disable=missing-function-docstring else: # otherwise is static shape analysis _, best = fast_tune( - func, + specalized_function, target=target, topk=self.topk, parallel_build=self.parallel_build, From 58fa7bfc049524d779e574342c49cd2d45e6accf Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 8 Dec 2024 18:06:48 +0000 Subject: [PATCH 02/17] lint fix --- bitblas/relax/transform/apply_fast_tuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/relax/transform/apply_fast_tuning.py b/bitblas/relax/transform/apply_fast_tuning.py index f7eb7eb03..00ccf67f3 100644 --- a/bitblas/relax/transform/apply_fast_tuning.py +++ b/bitblas/relax/transform/apply_fast_tuning.py @@ -136,7 +136,7 @@ def transform_module( # pylint: disable=missing-function-docstring trace.apply_to_schedule(sch, remove_postproc=False) updated_functions[g_var] = sch.mod["main"].with_attr("tir.is_scheduled", 1) continue - + specalized_function = func.with_attr("global_symbol", g_var.name_hint) if check_func_with_dynamic(specalized_function): From 8275513bdbf3d349630efe9ffd340d11d6b08b37 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 10 Dec 2024 17:25:53 +0000 Subject: [PATCH 03/17] save import bitblas time --- bitblas/base/arch/__init__.py | 76 +++++-------------- bitblas/base/arch/cdna.py | 4 + bitblas/base/arch/cpu.py | 4 + bitblas/base/arch/cuda.py | 93 +++++++++++++++++++++--- bitblas/base/roller/policy/tensorcore.py | 4 +- bitblas/gpu/matmul_analysis.py | 18 +---- bitblas/tl/mma_layout.py | 12 +++ 7 files changed, 127 insertions(+), 84 deletions(-) diff --git a/bitblas/base/arch/__init__.py b/bitblas/base/arch/__init__.py index c7d7af31b..dd931f617 100644 --- a/bitblas/base/arch/__init__.py +++ b/bitblas/base/arch/__init__.py @@ -1,15 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
from .arch_base import TileDevice -from .cuda import * -from .cpu import * -from .cdna import * +from .cuda import CUDA +from .cpu import CPU +from .cdna import CDNA from typing import Union +from tvm.target import Target -def get_arch(target: Union[str, tvm.target.Target] = "cuda") -> TileDevice: +def get_arch(target: Union[str, Target] = "cuda") -> TileDevice: if isinstance(target, str): - target = tvm.target.Target(target) + target = Target(target) if target.kind.name == "cuda": return CUDA(target) @@ -27,57 +28,14 @@ def auto_infer_current_arch() -> TileDevice: return get_arch("cuda") -def is_cpu_arch(arch: TileDevice) -> bool: - return isinstance(arch, CPU) - - -def is_cuda_arch(arch: TileDevice) -> bool: - return isinstance(arch, CUDA) - - -def is_ampere_arch(arch: TileDevice) -> bool: - conditions = [True] - conditions.append(is_cuda_arch(arch)) - conditions.append(arch.sm_version >= 80 and arch.sm_version < 90) - return all(conditions) - - -def is_volta_arch(arch: TileDevice) -> bool: - conditions = [True] - conditions.append(is_cuda_arch(arch)) - conditions.append(arch.sm_version >= 70) - conditions.append(arch.sm_version < 80) - return all(conditions) - - -def is_cdna_arch(arch: TileDevice) -> bool: - return isinstance(arch, CDNA) - - -def has_mma_support(arch: TileDevice) -> bool: - conditions = [True] - conditions.append(is_cuda_arch(arch)) - conditions.append(arch.sm_version >= 80) - return all(conditions) - - -def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool: - volta_tensorcore_supported = [ - ("float16", "float32"), - ("float16", "float16"), - ] - ampere_tensorcore_supported = [ - ("float16", "float32"), - ("float16", "float16"), - ("int8", "int32"), - ("int4", "int32"), - ("int2", "int32"), - ("int1", "int32"), - ] - - if is_volta_arch(arch): - return (in_dtype, accum_dtype) in volta_tensorcore_supported - elif is_ampere_arch(arch): - return (in_dtype, accum_dtype) in ampere_tensorcore_supported - else: - raise ValueError(f"Unsupported architecture: {arch}") +from .cpu import is_cpu_arch # noqa: F401 +from .cuda import ( + is_cuda_arch, # noqa: F401 + is_volta_arch, # noqa: F401 + is_ampere_arch, # noqa: F401 + is_ada_arch, # noqa: F401 + is_hopper_arch, # noqa: F401 + is_tensorcore_supported_precision, # noqa: F401 + has_mma_support, # noqa: F401 +) +from .cdna import is_cdna_arch # noqa: F401 diff --git a/bitblas/base/arch/cdna.py b/bitblas/base/arch/cdna.py index f6805dc98..cb49041db 100644 --- a/bitblas/base/arch/cdna.py +++ b/bitblas/base/arch/cdna.py @@ -7,6 +7,10 @@ from typing import List, Union +def is_cdna_arch(arch: TileDevice) -> bool: + return isinstance(arch, CDNA) + + class CDNA(TileDevice): def __init__(self, target: Union[Target, str]): diff --git a/bitblas/base/arch/cpu.py b/bitblas/base/arch/cpu.py index 65592cc7d..09a6391f1 100644 --- a/bitblas/base/arch/cpu.py +++ b/bitblas/base/arch/cpu.py @@ -6,6 +6,10 @@ from .arch_base import TileDevice +def is_cpu_arch(arch: TileDevice) -> bool: + return isinstance(arch, CPU) + + # For LLVM Backend, we do not provide the detailed information of the CPU # As the LLVM backend do not required tuning, just maintain the consistency class CPU(TileDevice): diff --git a/bitblas/base/arch/cuda.py b/bitblas/base/arch/cuda.py index 29c65e4a4..fb8c2bbf1 100644 --- a/bitblas/base/arch/cuda.py +++ b/bitblas/base/arch/cuda.py @@ -4,7 +4,7 @@ from bitblas import tvm from tvm.target import Target from .arch_base import TileDevice -from typing import List, Dict, Union +from typing 
import List, Union def check_sm_version(arch: str) -> int: @@ -12,17 +12,94 @@ def check_sm_version(arch: str) -> int: return int(sm_version) if sm_version.isdigit() else -1 +def is_cuda_arch(arch: TileDevice) -> bool: + return isinstance(arch, CUDA) + + +def is_volta_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version >= 70) + conditions.append(arch.sm_version < 80) + return all(conditions) + + +def is_ampere_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version >= 80 and arch.sm_version < 90) + return all(conditions) + + +def is_ada_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version == 89) + return all(conditions) + + +def is_hopper_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version == 90) + return all(conditions) + + +def has_mma_support(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version >= 80) + return all(conditions) + + +volta_tensorcore_supported = [ + ("float16", "float32"), + ("float16", "float16"), +] +ampere_tensorcore_supported = [ + ("float16", "float32"), + ("float16", "float16"), + ("int8", "int32"), + ("int4", "int32"), + ("int2", "int32"), + ("int1", "int32"), +] +ada_tensorcore_supported = [ + ("float16", "float32"), + ("float16", "float16"), + ("int8", "int32"), + ("e5m2_float8", "float32"), + ("e4m3_float8", "float32"), +] +hopper_tensorcore_supported = ada_tensorcore_supported + + +# TODO(lei): we should consider the dtype of the input a and b +# instead of assuming both a and b share the same dtype. 
+# As the tensorcore may supports e4m3_float8 * e5m2_float8 +def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool: + + if is_volta_arch(arch): + return (in_dtype, accum_dtype) in volta_tensorcore_supported + elif is_ampere_arch(arch): + return (in_dtype, accum_dtype) in ampere_tensorcore_supported + elif is_ada_arch(arch): + return (in_dtype, accum_dtype) in ada_tensorcore_supported + elif is_hopper_arch(arch): + return (in_dtype, accum_dtype) in hopper_tensorcore_supported + else: + raise ValueError(f"Unsupported architecture: {arch}") + + class TensorInstruction(object): def __init__( self, name: str, - intrin_group: Dict, shape: List[int], ): self.name: str = name - self.intrin_group: Dict = intrin_group - # only maintain the shape of M and N + # only hold the shape of M and N self.shape: List[int] = shape @@ -58,13 +135,11 @@ def __init__(self, target: Union[Target, str]): self.available_tensor_instructions: List[TensorInstruction] = None def get_avaliable_tensorintrin_shapes(self): - from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group, get_mma_intrin_group - self.available_tensor_instructions = ( - TensorInstruction("mma", get_mma_intrin_group, [16, 16]), - TensorInstruction("wmma", get_wmma_intrin_group, [16, 16]), + TensorInstruction("mma", [16, 16]), + TensorInstruction("wmma", [16, 16]), ) return [t.shape for t in self.available_tensor_instructions] def __repr__(self): - return f"CUDA({self.target})" \ No newline at end of file + return f"CUDA({self.target})" diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index 722095657..ef2dc8587 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -328,9 +328,9 @@ def _score(node, thread): # small is better # TODO: This is a dummy mul which avoid reusing some shared memory. # Should be removed in the future. if td.smem_cost > (self.arch.smem_cap): - info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \ + debug_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \ " use dynamic shared memory." - logger.info(info_message) + logger.debug(debug_message) codegen_dict.shared_scope = "shared.dyn" codegen_dict.shared_scope = "shared.dyn" diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 1f596ef9a..d43d95ffa 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -18,6 +18,7 @@ ) from tvm.target.target import Target from tvm.tir.stmt_functor import pre_order_visit +from bitblas.base.arch import get_arch, is_tensorcore_supported_precision import logging logger = logging.getLogger(__name__) @@ -527,8 +528,6 @@ def get_tensorized_func_and_tags( skip_normalize: bool = False, allow_gemv: bool = False, ) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) """ transform function to matmul if necessary (e.g. transform conv2d with im2col) """ @@ -648,18 +647,9 @@ def check_last_trait(region: List[Range]): block_stmt = sch.get(main_block) if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: - # TODO(lei): we should consider the dtype of the input a and b - # instead of assuming both a and b share the same dtype. 
- # As the tensorcore may supports e4m3_float8 * e5m2_float8 in_dtype, out_dtype = get_in_out_dtypes(block_stmt) - try: - _ = get_mma_intrin_group( - a_dtype=in_dtype, - b_dtype=in_dtype, - out_dtype=out_dtype, - ) - except Exception: - logger.debug("Cannot find the corresponding mma intrin group") + if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)): + logger.debug("The input and output dtype is not supported by tensorcore") return func, None # reindex and transform functions @@ -697,7 +687,7 @@ def check_last_trait(region: List[Range]): def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + from bitblas.tl.mma_layout import ( # pylint: disable=import-outside-toplevel ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, ) diff --git a/bitblas/tl/mma_layout.py b/bitblas/tl/mma_layout.py index 5dab4ba64..443da90eb 100644 --- a/bitblas/tl/mma_layout.py +++ b/bitblas/tl/mma_layout.py @@ -30,6 +30,18 @@ def ldmatrix_16x32_to_shared_16x32_layout_b(thread_id, local_id): return row, col +def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id): + row = thread_id % 16 + col = local_id + (thread_id // 16) * 16 + return row, col + + +def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): + row = (thread_id // 16) * 8 + (thread_id % 8) + col = local_id + 16 * ((thread_id % 16) // 8) + return row, col + + def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): row = 8 * (local_id % 4 // 2) + (thread_id // 4) col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) From fb7de9baf84d65b0efe735bd1e2aa1e2ea77a8d8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 10 Dec 2024 18:08:24 +0000 Subject: [PATCH 04/17] bug fix for tl backend --- .../operators/benchmark_bitblas_matmul.py | 268 +++++++++--------- .../tilelang/dequantize/matmul_dequantize.py | 8 +- 2 files changed, 138 insertions(+), 138 deletions(-) diff --git a/benchmark/operators/benchmark_bitblas_matmul.py b/benchmark/operators/benchmark_bitblas_matmul.py index fb3927bdb..965c3640d 100644 --- a/benchmark/operators/benchmark_bitblas_matmul.py +++ b/benchmark/operators/benchmark_bitblas_matmul.py @@ -4,138 +4,132 @@ from bitblas.utils.target_detector import auto_detect_nvidia_target from bitblas import Matmul, MatmulConfig import argparse +import json - -# Initialize the parser -parser = argparse.ArgumentParser( - description="Benchmark BitBLAS int4 on a specific target." -) - -# Add arguments to the parser -parser.add_argument( - "--target", - type=str, - default=auto_detect_nvidia_target(), - help="Specify the target device for benchmarking." -) -parser.add_argument( - "--group_size", - type=int, - default=None, - help="Group size for grouped quantization." -) -parser.add_argument( - "--A_dtype", - type=str, - default="float16", - choices=["float16", "float32", "float64", "int32", "int8"], # Assuming these are the valid choices - help="Data type of activation A." -) -parser.add_argument( - "--W_dtype", - type=str, - default="int4", - choices=["float16", "float32", "float64", "int32", "int8", "int4", "int2", "int1", "nf4", "fp4_e2m1"], # Assuming these are the valid choices - help="Data type of weight W." 
-) -parser.add_argument( - "--accum_dtype", - type=str, - default="float16", - choices=["float16", "int32"], # Assuming these are the valid choices - help="Data type for accumulation." -) -parser.add_argument( - "--out_dtype", - type=str, - default="float16", - choices=["float16", "float32", "int32", "int8"], # Assuming these are the valid choices - help="Data type for output." -) -parser.add_argument( - "--layout", - type=str, - default="nt", - choices=["nt", "nn"], # Assuming these are the valid choices - help="Matrix layout, 'nt' for non-transpose A and transpose W." -) -parser.add_argument( - "--with_bias", - action="store_true", - help="Include bias in the benchmark." -) -parser.add_argument( - "--with_scaling", - action="store_true", - help="Include scaling factor in the quantization." -) -parser.add_argument( - "--with_zeros", - action="store_true", - help="Include zeros in the quantization." -) -parser.add_argument( - "--zeros_mode", - type=str, - default=None, - choices=["original", "rescale", "quantized"], # Replace with actual modes if applicable - help="Specify the mode for calculating zeros." -) - -# Parse the arguments -args = parser.parse_args() - -# Assign arguments to variables -target = args.target -group_size = args.group_size -A_dtype = args.A_dtype -W_dtype = args.W_dtype -accum_dtype = args.accum_dtype -out_dtype = args.out_dtype -layout = args.layout -with_bias = args.with_bias -group_size = args.group_size -with_scaling = args.with_scaling -with_zeros = args.with_zeros -zeros_mode = args.zeros_mode - -test_shapes = [ - # square test - (MatmulConfig, Matmul, (1, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # BLOOM-176B - (MatmulConfig, Matmul, (1, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # # OPT-65B - (MatmulConfig, Matmul, (1, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # # LLAMA-70B/65B - (MatmulConfig, Matmul, (1, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 8192, 28672, A_dtype, W_dtype, 
out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - - # square test - (MatmulConfig, Matmul, (16384, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # BLOOM-176B - (MatmulConfig, Matmul, (8192, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # # OPT-65B - (MatmulConfig, Matmul, (8192, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # # LLAMA-70B/65B - (MatmulConfig, Matmul, (8192, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), -] +# Initialize the parser +parser = argparse.ArgumentParser( + description="Benchmark BitBLAS int4 on a specific target." +) + +# Add arguments to the parser +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), + help="Specify the target device for benchmarking." +) +parser.add_argument( + "--group_size", + type=int, + default=None, + help="Group size for grouped quantization." +) +parser.add_argument( + "--A_dtype", + type=str, + default="float16", + choices=["float16", "int8"], + help="Data type of activation A." +) +parser.add_argument( + "--W_dtype", + type=str, + default="int4", + choices=["float16", "int32", "int8", "int4", "int2", "int1", "nf4", "fp4_e2m1"], + help="Data type of weight W." +) +parser.add_argument( + "--accum_dtype", + type=str, + default="float16", + choices=["float16", "int32"], # Assuming these are the valid choices + help="Data type for accumulation." +) +parser.add_argument( + "--out_dtype", + type=str, + default="float16", + choices=["float16", "float32", "int32", "int8"], # Assuming these are the valid choices + help="Data type for output." +) +parser.add_argument( + "--layout", + type=str, + default="nt", + choices=["nt", "nn"], # Assuming these are the valid choices + help="Matrix layout, 'nt' for non-transpose A and transpose W." 
+) +parser.add_argument( + "--with_bias", + action="store_true", + help="Include bias in the benchmark." +) +parser.add_argument( + "--with_scaling", + action="store_true", + help="Include scaling factor in the quantization." +) +parser.add_argument( + "--with_zeros", + action="store_true", + help="Include zeros in the quantization." +) +parser.add_argument( + "--zeros_mode", + type=str, + default=None, + choices=["original", "rescale", "quantized"], # Replace with actual modes if applicable + help="Specify the mode for calculating zeros." +) + +parser.add_argument( + "--backend", + type=str, + default="tl", + choices=["tir", "tl"], # Replace with actual modes if applicable + help="Specify the mode for calculating zeros." +) + +default_test_shapes = json.dumps([ + ["MatmulConfig", "Matmul", [1, 16384, 16384, "float16", "int4", "float16", "float16", "nt", False, None, False, False, None]] +]) + +parser.add_argument( + "--test_shapes", + type=str, + default=default_test_shapes, + help="JSON string defining test shapes. Example format: '[[\"MatmulConfig\", \"Matmul\", [1,16384,16384,\"float16\",\"int4\",\"float16\",\"float16\",\"nt\",false,null,false,false,null]]]'" +) + +# Parse the arguments +args = parser.parse_args() + +# Assign arguments to variables +target = args.target +group_size = args.group_size +A_dtype = args.A_dtype +W_dtype = args.W_dtype +accum_dtype = args.accum_dtype +out_dtype = args.out_dtype +layout = args.layout +with_bias = args.with_bias +with_scaling = args.with_scaling +with_zeros = args.with_zeros +zeros_mode = args.zeros_mode +backend = args.backend + +parsed_test_shapes = json.loads(args.test_shapes) +name_to_class = { + "MatmulConfig": MatmulConfig, + "Matmul": Matmul +} + +test_shapes = [] +for item in parsed_test_shapes: + config_class_name, operator_class_name, input_args = item + config_class = name_to_class[config_class_name] + operator_class = name_to_class[operator_class_name] + test_shapes.append((config_class, operator_class, tuple(input_args))) benchmark_sets = [] benchmark_sets.extend(test_shapes) @@ -145,11 +139,11 @@ benchmark_results = {} for config, operator, input_args in benchmark_sets: config = config(*input_args) - matmul = operator(config, target=target, enable_tuning=True) - kernel_latency = matmul.profile_latency() - if matmul.input_transform is not None: - kernel_latency += matmul.ladder_permutate_a.profile_latency() - + op_inst = operator(config, target=target, enable_tuning=True, backend=backend) + kernel_latency = op_inst.profile_latency() + if op_inst.input_transform is not None: + kernel_latency += op_inst.ladder_permutate_a.profile_latency() + print("Time cost is: {:.3f} ms".format(kernel_latency)) profile_config = { @@ -160,7 +154,7 @@ benchmark_results.update(profile_config) -# Define headers for the table +# Define headers for the table headers = [ "PrimFunc", "Input Arguments", diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py index 5a690b970..85189b22c 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py @@ -153,7 +153,13 @@ def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler: def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: for scheduler in [ - self.matmul_dequantize_block_scheduler, + self.gemv_dequantize_simt_scheduler, + self.matmul_dequantize_simt_scheduler, + 
self.matmul_dequantize_block_scheduler, + self.matmul_dequantize_fine_grained_scheduler, + self.matmul_dequantize_weight_propagation_scheduler, + self.matmul_int4_dequantize_fine_grain_scheduler, + self.matmul_int4_dequantize_weight_propagation_scheduler, ]: if isinstance(hint, scheduler.TLHint): return scheduler From 02cf643197db59b75dff6116e238e4e8d45052b6 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 11 Dec 2024 18:02:48 +0000 Subject: [PATCH 05/17] support input transform_kind --- .../operators/benchmark_bitblas_matmul.py | 107 ++------- bitblas/base/utils.py | 5 + bitblas/gpu/__init__.py | 1 + .../tilelang/dequantize/matmul_dequantize.py | 26 ++- bitblas/tl/mma_macro_generator.py | 161 +++++++++++++- .../tilelang/test_tilelang_mma_macro_gemm.py | 207 ++++++++++++++++++ 6 files changed, 404 insertions(+), 103 deletions(-) diff --git a/benchmark/operators/benchmark_bitblas_matmul.py b/benchmark/operators/benchmark_bitblas_matmul.py index 965c3640d..ab87e2ca6 100644 --- a/benchmark/operators/benchmark_bitblas_matmul.py +++ b/benchmark/operators/benchmark_bitblas_matmul.py @@ -1,97 +1,36 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +import bitblas from bitblas.utils.target_detector import auto_detect_nvidia_target from bitblas import Matmul, MatmulConfig import argparse import json +bitblas.set_log_level("DEBUG") # Initialize the parser -parser = argparse.ArgumentParser( - description="Benchmark BitBLAS int4 on a specific target." -) +parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.") # Add arguments to the parser parser.add_argument( "--target", type=str, default=auto_detect_nvidia_target(), - help="Specify the target device for benchmarking." -) -parser.add_argument( - "--group_size", - type=int, - default=None, - help="Group size for grouped quantization." -) -parser.add_argument( - "--A_dtype", - type=str, - default="float16", - choices=["float16", "int8"], - help="Data type of activation A." -) -parser.add_argument( - "--W_dtype", - type=str, - default="int4", - choices=["float16", "int32", "int8", "int4", "int2", "int1", "nf4", "fp4_e2m1"], - help="Data type of weight W." -) -parser.add_argument( - "--accum_dtype", - type=str, - default="float16", - choices=["float16", "int32"], # Assuming these are the valid choices - help="Data type for accumulation." -) -parser.add_argument( - "--out_dtype", - type=str, - default="float16", - choices=["float16", "float32", "int32", "int8"], # Assuming these are the valid choices - help="Data type for output." -) -parser.add_argument( - "--layout", - type=str, - default="nt", - choices=["nt", "nn"], # Assuming these are the valid choices - help="Matrix layout, 'nt' for non-transpose A and transpose W." -) -parser.add_argument( - "--with_bias", - action="store_true", - help="Include bias in the benchmark." -) -parser.add_argument( - "--with_scaling", - action="store_true", - help="Include scaling factor in the quantization." -) -parser.add_argument( - "--with_zeros", - action="store_true", - help="Include zeros in the quantization." -) -parser.add_argument( - "--zeros_mode", - type=str, - default=None, - choices=["original", "rescale", "quantized"], # Replace with actual modes if applicable - help="Specify the mode for calculating zeros." 
-) + help="Specify the target device for benchmarking.") parser.add_argument( "--backend", type=str, - default="tl", + default="tir", choices=["tir", "tl"], # Replace with actual modes if applicable - help="Specify the mode for calculating zeros." -) + help="Specify the mode for calculating zeros.") +# [A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode] default_test_shapes = json.dumps([ - ["MatmulConfig", "Matmul", [1, 16384, 16384, "float16", "int4", "float16", "float16", "nt", False, None, False, False, None]] + # ["MatmulConfig", "Matmul", [1, 16384, 16384, "float16", "int4", "float16", "float16", "nt", False, None, False, False, None]] + [ + "MatmulConfig", "Matmul", + [1, 16384, 16384, "int8", "int8", "int32", "int32", "nt", False, None, False, False, None] + ] ]) parser.add_argument( @@ -106,23 +45,10 @@ # Assign arguments to variables target = args.target -group_size = args.group_size -A_dtype = args.A_dtype -W_dtype = args.W_dtype -accum_dtype = args.accum_dtype -out_dtype = args.out_dtype -layout = args.layout -with_bias = args.with_bias -with_scaling = args.with_scaling -with_zeros = args.with_zeros -zeros_mode = args.zeros_mode backend = args.backend parsed_test_shapes = json.loads(args.test_shapes) -name_to_class = { - "MatmulConfig": MatmulConfig, - "Matmul": Matmul -} +name_to_class = {"MatmulConfig": MatmulConfig, "Matmul": Matmul} test_shapes = [] for item in parsed_test_shapes: @@ -139,12 +65,13 @@ benchmark_results = {} for config, operator, input_args in benchmark_sets: config = config(*input_args) + print(f"Running benchmark for {operator.__name__} with config: {config}") op_inst = operator(config, target=target, enable_tuning=True, backend=backend) kernel_latency = op_inst.profile_latency() if op_inst.input_transform is not None: kernel_latency += op_inst.ladder_permutate_a.profile_latency() - print("Time cost is: {:.3f} ms".format(kernel_latency)) + print("Time cost of {} is: {:.3f} ms".format(str(config), kernel_latency)) profile_config = { f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { @@ -168,7 +95,9 @@ input_args = "-".join(args[1:]) col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), col_widths[0]) col_widths[1] = max((max(len(str(headers[1])), len(input_args)) + 2, col_widths[1])) - col_widths[2] = max(max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, col_widths[2]) + col_widths[2] = max( + max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, + col_widths[2]) break for i, header in enumerate(headers): diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 3a5b6a2e8..011ce6720 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -262,6 +262,11 @@ def tvm_callback_cuda_postproc(code, _): code = tensor_remove_make_int2(code) return code + # assume index_map to be registered + from tvm.tir.tensor_intrin.cuda import ( + get_mma_intrin_group, # noqa: F401 + ) + with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, diff --git a/bitblas/gpu/__init__.py b/bitblas/gpu/__init__.py index 72192f60e..750c42047 100644 --- a/bitblas/gpu/__init__.py +++ b/bitblas/gpu/__init__.py @@ -4,6 +4,7 @@ GPU-generic schedule rules. 
For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead """ + from .fallback import Fallback # noqa: F401 from .element_wise import ElementWise # noqa: F401 from .gemv import GEMV # noqa: F401 diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py index 85189b22c..78896180e 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py @@ -153,13 +153,13 @@ def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler: def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: for scheduler in [ - self.gemv_dequantize_simt_scheduler, - self.matmul_dequantize_simt_scheduler, - self.matmul_dequantize_block_scheduler, - self.matmul_dequantize_fine_grained_scheduler, - self.matmul_dequantize_weight_propagation_scheduler, - self.matmul_int4_dequantize_fine_grain_scheduler, - self.matmul_int4_dequantize_weight_propagation_scheduler, + self.gemv_dequantize_simt_scheduler, + self.matmul_dequantize_simt_scheduler, + self.matmul_dequantize_block_scheduler, + self.matmul_dequantize_fine_grained_scheduler, + self.matmul_dequantize_weight_propagation_scheduler, + self.matmul_int4_dequantize_fine_grain_scheduler, + self.matmul_int4_dequantize_weight_propagation_scheduler, ]: if isinstance(hint, scheduler.TLHint): return scheduler @@ -248,5 +248,17 @@ def is_dynamic(self) -> bool: M, N, K = self.M, self.N, self.K return ((not isinstance(M, int)) or (not isinstance(N, int)) or (not isinstance(K, int))) + def __post_init__(self): + # Validate the matrix transpose settings + assert (self.trans_A is False), "Currently only support Matrix A not transposed" + assert (self.trans_B is True), "Currently only support Matrix B transposed" + assert (self.input_transform_kind == TransformKind.NonTransform + ), "Currently only support NonTransform for input" + + # Legalize group_size + if self.with_scaling and self.group_size == -1: + object.__setattr__(self, "group_size", self.K) + return + __all__ = ["MatmulDequantizeScheduler"] diff --git a/bitblas/tl/mma_macro_generator.py b/bitblas/tl/mma_macro_generator.py index 8e238e0f4..f28233911 100644 --- a/bitblas/tl/mma_macro_generator.py +++ b/bitblas/tl/mma_macro_generator.py @@ -397,10 +397,112 @@ def _initialize_transform_kind(self, transform_kind_a, transform_kind_b): else: raise ValueError("Unsupported transform_kind_b") - if self.transform_kind_a != TransformKind.NonTransform: - raise ValueError("TransformKind A is not supported yet") + assert transform_kind_a in [0, 1, 2, 3], "Input transform stage should be 0, 1, 2, or 3" + assert transform_kind_b in [0, 1, 2, 3], "Weight transform stage should be 0, 1, 2, or 3" - assert transform_kind_b in [0, 3], "Currently only support 0 and 3" + def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + a_dtype = self.a_dtype + a_transposed = self.a_transposed + transform_kind_a = self.transform_kind_a + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_bindings, + rk=0, + ): + stride = A_shared_buf.shape[-1] + tx, _, warp_m = self.extract_thread_binding(thread_bindings) + if transform_kind_a == TransformKind.NonTransform: + for i in 
T.serial(warp_rows): + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_buf[ + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ]), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + elif transform_kind_a == TransformKind.InterWarpTransform: + for i in T.serial(warp_rows): + # Assign B_shared_elem + ri, rj = ( + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_x, + (rj) // micro_size_k, + (ri) % micro_size_x, + (rj) % micro_size_k, + ) + args = (ni, nj, nii, njj) if transform_kind_a > 0 else (ri, rj) + A_shared_elem = A_shared_buf[args] + + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_elem), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + elif transform_kind_a == TransformKind.IntraWarpTransform: + for i in T.serial(warp_rows): + # Assign B_shared_elem + ri, rj = ( + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_x, + (rj) // micro_size_k, + (ri) % micro_size_x, + (rj) % micro_size_k, + ) + A_shared_elem = A_shared_buf[ni, nj, nii, njj] + + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_elem), + tx * local_size_a, + ) + elif transform_kind_a == TransformKind.LDMatrixTransform: + for j in T.serial(warp_rows): + for local_id in T.vectorized(local_size_a): + # Assign A_shared_elem + ri, rj = ( + warp_m * warp_rows + j, + rk * (chunk // micro_size_k) + ki, + ) + rii, rjj = (tx * local_size_a + + local_id) // micro_size_k, (tx * local_size_a + local_id) % ( + micro_size_k) + A_local_buf[j * local_size_a + local_id] = (A_shared_buf[ri, rj, rii, rjj]) + else: + raise ValueError("Unsupported TransformKind for Input A") + + return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk) def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): warp_col_tiles = self.warp_col_tiles @@ -425,7 +527,26 @@ def _warp_ldmatrix_b( stride = B_shared_buf.shape[-1] tx, warp_n, _ = self.extract_thread_binding(thread_bindings) - if transform_kind_b < TransformKind.LDMatrixTransform: + if transform_kind_b == TransformKind.NonTransform: + for j in T.serial(warp_cols): + # Assign B_shared_elem + ri, rj = ( + warp_n * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + B_shared_elem = B_shared_buf[ri, rj] + + T.ptx_ldmatrix( + b_dtype, + T.bool(False), + 4, + ".b16", + B_local_buf.data, + j * local_size_b, + T.address_of(B_shared_elem), + get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), + ) + elif transform_kind_b == TransformKind.InterWarpTransform: for j in T.serial(warp_cols): # Assign B_shared_elem ri, rj = ( @@ -438,8 +559,7 @@ def _warp_ldmatrix_b( (ri) % micro_size_y, (rj) % micro_size_k, ) - args = (ni, nj, nii, njj) if transform_kind_b > 0 else (ri, rj) - B_shared_elem = B_shared_buf[args] + B_shared_elem = B_shared_buf[ni, nj, nii, njj] T.ptx_ldmatrix( b_dtype, @@ -451,7 +571,32 @@ def _warp_ldmatrix_b( T.address_of(B_shared_elem), get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), ) - else: + elif transform_kind_b == TransformKind.IntraWarpTransform: + for j in T.serial(warp_cols): + # Assign B_shared_elem + ri, rj = ( + warp_n * warp_col_tiles + j * micro_size_y, + rk * 
chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_y, + (rj) // micro_size_k, + (ri) % micro_size_y, + (rj) % micro_size_k, + ) + B_shared_elem = B_shared_buf[ni, nj, nii, njj] + + T.ptx_ldmatrix( + b_dtype, + T.bool(False), # TODO(lei): should be optimized + 4, + ".b16", + B_local_buf.data, + j * local_size_b, + T.address_of(B_shared_elem), + tx * local_size_b, + ) + elif transform_kind_b == TransformKind.LDMatrixTransform: local_size_dequantize = local_size_b // num_elems_per_byte for j in T.serial(warp_cols): for local_id in T.vectorized(local_size_dequantize): @@ -466,6 +611,8 @@ def _warp_ldmatrix_b( micro_size_k // num_elems_per_byte) B_local_buf[j * local_size_dequantize + local_id] = ( B_shared_buf[ri, rj, rii, rjj]) + else: + raise ValueError("Unsupported TransformKind for Input B") return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk) diff --git a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py index 821e3aa25..d343d8078 100644 --- a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py @@ -885,6 +885,208 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) +@simplify_prim_func +def tl_matmul_with_ladder_input_weight_transform( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + transform_a, + transform_b, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + assert transform_a > 0, "transform_a should be greater than 0" + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + + warp_rows = 1 + warp_cols = 1 + warp_row_tiles = micro_size_x * warp_rows + warp_col_tiles = micro_size_y * warp_cols + + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = (block_M // micro_size_x, block_K // micro_size_k, micro_size_x, micro_size_k) + B_shared_shape = (block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_a=transform_a, + 
transform_kind_b=transform_b) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k, ii, kk in T.Parallel(block_M // micro_size_x, block_K // micro_size_k, + micro_size_x, micro_size_k): + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i, + ko * (block_K // micro_size_k) + k, ii, kk] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.mma(A_local, B_local, C_local) + + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, + bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, + i % micro_size_x, j % micro_size_y] + + return main + + +def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_dtype, out_dtype, + accum_dtype, transform_a, + transform_b): + matmul = tl_matmul_with_ladder_input_weight_transform(M, N, K, in_dtype, out_dtype, accum_dtype, + transform_a, transform_b) + + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + ladder_permutate_config_A = bitblas.ops.LadderPermutateConfig( + M=M, + N=K, + datatype=in_dtype, + storage_dtype=in_dtype, + propagate_kind="A", + transpose_matrix=False, + transform_kind=transform_a, + ) + + ladder_permutate_a = bitblas.ops.LadderPermutate(ladder_permutate_config_A) + + ladder_permutate_config_B = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + datatype=in_dtype, + storage_dtype=in_dtype, + propagate_kind="B", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate_b = bitblas.ops.LadderPermutate(ladder_permutate_config_B) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + LA = ladder_permutate_a(A.cpu()).cuda() + LB = ladder_permutate_b(B.cpu()).cuda() + + mod(LA, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, 
accum_dtype)) + # print("Ref C: ", ref_c) + # print("C: ", C) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + def test_assert_tl_matmul(): assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") @@ -908,5 +1110,10 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): 256, 1024, 512, "float16", "float16", "float16", 3) +def test_assert_tl_matmul_with_ladder_input_weight_transform(): + assert_tl_matmul_with_ladder_input_weight_transform_correctness(256, 256, 256, "float16", + "float16", "float16", 2, 3) + + if __name__ == "__main__": bitblas.testing.main() From 65fb3b459be7960933c3d45c9dd6afa46bc3eb10 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 11 Dec 2024 20:46:14 +0000 Subject: [PATCH 06/17] hint identifier --- bitblas/base/base_scheduler.py | 3 + bitblas/base/utils.py | 2 +- .../tilelang/dense/gemv_simt.py | 5 + .../general_matmul/tilelang/dense/matmul.py | 10 +- .../tilelang/dense/matmul_simt.py | 5 + .../tilelang/dense/matmul_tensorcore.py | 116 +++++++++++++----- .../tilelang/dense/matmul_wmma.py | 2 + ...atmul_dequantize_tensorcore_finegrained.py | 3 + ..._dequantize_tensorcore_weight_transform.py | 104 ++++++++++++---- bitblas/ops/operator.py | 2 +- bitblas/tl/base_hint.py | 3 + bitblas/tl/tuner.py | 2 +- 12 files changed, 201 insertions(+), 56 deletions(-) diff --git a/bitblas/base/base_scheduler.py b/bitblas/base/base_scheduler.py index 37b75785a..c798402e0 100644 --- a/bitblas/base/base_scheduler.py +++ b/bitblas/base/base_scheduler.py @@ -113,6 +113,9 @@ def apply_config( ) -> PrimFunc: pass + def get_hint_type(self) -> str: + raise NotImplementedError("Get Hint type is not implemented") + def serialize_hints_to_configs(self, hints: List[Hint]) -> List[BaseTLHint]: # Convert Roller Hints to TileLang Hints raise NotImplementedError("Serialization of hints to configs is not implemented") diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 011ce6720..b957be59e 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -222,7 +222,7 @@ def apply_and_build_parallel(func, arch, num_repeats=3, max_workers=10, - timeout=30, + timeout=60, data_distribution="uniform") -> CompileResult: cpresults = [] diff --git a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py index cbd5a4e3f..cdc29026a 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py @@ -23,6 +23,8 @@ class GemvFineGrainSIMTScheduler(MatmulSIMTBaseScheduler): reduce_thread: int = 16 class TLHint(BaseTLHint): + + hint_type: str = "GemvFineGrainSIMTScheduler" def __init__(self): super().__init__() @@ -56,6 +58,9 @@ def __repr__(self): f"reduce_thread: {self.reduce_thread}, " "}") + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index 0881c65f6..7ff325b1c 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -146,8 +146,14 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: self.matmul_int4_fine_grain_scheduler, self.matmul_int4_weight_propagation_scheduler, ]: - if isinstance(hint, scheduler.TLHint): - return scheduler 
+ try: + scheduler_hint_type = scheduler.get_hint_type() + if scheduler_hint_type == hint.hint_type: + return scheduler + except NotImplementedError: + raise ValueError( + f"get_hint_type() is not implemented for {type(scheduler)}") + raise ValueError(f"Unsupported hint type: {type(hint)}") def with_default_config(self, arch: Optional[TileDevice] = None) -> PrimFunc: diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py index 782c9fdb7..9b60d547d 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py @@ -74,6 +74,8 @@ class MatmulFineGrainSIMTScheduler(MatmulSIMTBaseScheduler): class TLHint(BaseTLHint): + hint_type: str = "MatmulFineGrainSIMTScheduler" + def __init__(self): super().__init__() @@ -119,6 +121,9 @@ def __repr__(self): f"chunk: {self.chunk}" "}") + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 653ad894c..317f2b2a0 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -84,6 +84,8 @@ class MatmulBlockScheduler(MatmulBaseScheduler): enable_rasterization: bool = False # Enhance L2 Locality class TLHint(BaseTLHint): + + hint_type = "MatmulBlockScheduler" def __init__(self): super().__init__() @@ -161,6 +163,9 @@ def get_configs_sm80(self): configs = [{**c, 'num_stages': num_stages} for c in configs] return configs + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: @@ -274,6 +279,8 @@ class MatmulFineGrainScheduler(MatmulBaseScheduler): class TLHint(BaseTLHint): + hint_type: str = "MatmulFineGrainScheduler" + def __init__(self): super().__init__() @@ -332,6 +339,9 @@ def __repr__(self): f"enable_rasterization={self.enable_rasterization}" "}") + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: @@ -555,6 +565,9 @@ class MatmulWeightPropagationScheduler(MatmulFineGrainScheduler): # force set default weight transform kind to LDMatrixTransform weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform + class TLHint(MatmulFineGrainScheduler.TLHint): + hint_type: str = "MatmulWeightPropagationScheduler" + def apply_config( self, block_row_warps=2, @@ -573,6 +586,7 @@ def apply_config( trans_A, trans_B = self.trans_A, self.trans_B in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype with_bias = self.with_bias + input_transform_kind, weight_transform_kind = self.input_transform_kind, self.weight_transform_kind # Calculate the micro size per warp using a helper function micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) @@ -581,21 +595,28 @@ def apply_config( block_N = block_col_warps * warp_col_tiles block_K = chunk - # TODO(lei): Can be generalized to analyzed from bank size - pad_factor = 8 if in_dtype == "float16" else 16 - - can_swizzle_a = block_K * DataType(in_dtype).bits == 512 - apply_pad_a = not can_swizzle_a - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) # Define the shapes of matrices and shared memory buffers - 
A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) C_shape = (M, N) Bias_shape = (N,) - A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + is_a_smooth = self.is_a_smooth + is_b_smooth = self.is_b_smooth + + if is_a_smooth: + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + A_shared_shape = ( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ) + else: + A_shape = (M, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, @@ -634,7 +655,8 @@ def apply_config( warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=chunk, - transform_kind_b=self.weight_transform_kind, + transform_kind_a=input_transform_kind, + transform_kind_b=weight_transform_kind, ) cache_write_required = self.check_require_cache() @@ -663,8 +685,8 @@ def main( # Apply memory layout optimizations T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), + A_shared: make_swizzle_layout(A_shared, is_smooth=is_a_smooth), + B_shared: make_swizzle_layout(B_shared, is_smooth=is_b_smooth), }) T.use_swizzle(panel_size=10, enable=enable_rasterization) @@ -675,9 +697,16 @@ def main( # Main matrix multiplication pipeline with multiple stages for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - # Load A matrix into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + if is_a_smooth: + for i, k, ii, kk in T.Parallel( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ): + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), ko * (block_K // micro_size_k), ii, kk] + else: + T.copy(A[by * block_M, ko * block_K], A_shared) # Load B matrix into shared memory for j, k, jj, kk in T.Parallel( @@ -753,10 +782,19 @@ def main( return self.post_process(main) + @property + def is_a_smooth(self): + return self.input_transform_kind > TransformKind.NonTransform + + @property + def is_b_smooth(self): + return self.weight_transform_kind > TransformKind.NonTransform + def __post_init__(self): # Validate the matrix transpose settings assert self.trans_A is False, "Currently only support Matrix A not transposed" assert self.trans_B is True, "Currently only support Matrix B transposed" + assert self.weight_transform_kind > TransformKind.NonTransform, "Weight Transform Kind is required" return @@ -764,6 +802,10 @@ def __post_init__(self): @dataclass class MatmulINT4FineGrainScheduler(MatmulFineGrainScheduler): + @dataclass + class TLHint(MatmulFineGrainScheduler.TLHint): + hint_type: str = "MatmulINT4FineGrainScheduler" + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" M = self.M @@ -972,6 +1014,9 @@ def __post_init__(self): @dataclass class MatmulINT4WeightPropagationScheduler(MatmulWeightPropagationScheduler): + class TLHint(MatmulWeightPropagationScheduler.TLHint): + hint_type: str = "MatmulINT4WeightPropagationScheduler" + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" M = self.M @@ -1047,10 +1092,23 @@ def apply_config( can_swizzle_a = block_K * DataType(storage_dtype).bits == 512 apply_pad_a = not can_swizzle_a + is_a_smooth = self.is_a_smooth + is_b_smooth = self.is_b_smooth + + if 
is_a_smooth: + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + A_shared_shape = ( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ) + else: + A_shape = (M, K) + A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + # Define the shapes of matrices and shared memory buffers - A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) - A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, @@ -1089,6 +1147,7 @@ def apply_config( warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=chunk, + transform_kind_a=self.input_transform_kind, transform_kind_b=self.weight_transform_kind, ) @@ -1116,8 +1175,8 @@ def main( # Apply memory layout optimizations T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - # B_shared: make_swizzle_layout(B_shared), + A_shared: make_swizzle_layout(A_shared, is_smooth=is_a_smooth), + B_shared: make_swizzle_layout(B_shared, is_smooth=is_b_smooth), }) T.use_swizzle(panel_size=10, enable=enable_rasterization) @@ -1128,9 +1187,16 @@ def main( # Main matrix multiplication pipeline with multiple stages for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - # Load A matrix into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + if is_a_smooth: + for i, k, ii, kk in T.Parallel( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ): + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), ko * (block_K // micro_size_k), ii, kk] + else: + T.copy(A[by * block_M, ko * block_K], A_shared) # Load B matrix into shared memory for j, k, jj, kk in T.Parallel( @@ -1398,17 +1464,11 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( block_N = block_col_warps * warp_col_tiles block_K = chunk - # TODO(lei): Can be generalized to analyzed from bank size - pad_factor = 8 if in_dtype == "float16" else 16 - - can_swizzle_a = block_K * DataType(in_dtype).bits == 512 - apply_pad_a = not can_swizzle_a - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) - A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + A_shared_shape = (block_M, block_K) B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py index b14a236ae..447e7e47b 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py @@ -31,6 +31,8 @@ class MatmulFineGrainScheduler(MatmulBaseScheduler): class TLHint(BaseTLHint): + hint_type: str = "MatmulFineGrainScheduler" + def __init__(self): super().__init__() diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py index f2f462926..da649477f 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py @@ -403,6 +403,9 @@ def general_dequant_matmul( @dataclass class 
MatmulINT4DequantizeFineGrainedScheduler(MatmulDequantizeFineGrainedScheduler): + class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): + pass + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" M = self.M diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py index c7c452821..31e5929a2 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py @@ -38,6 +38,9 @@ class MatmulDequantizeWeightPropagationScheduler(MatmulDequantizeFineGrainedSche # force set default weight transform kind to LDMatrixTransform weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform + class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): + pass + def apply_config( self, block_row_warps: Optional[int] = None, @@ -59,12 +62,14 @@ def apply_config( N, K = self.N, self.K assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" trans_A, trans_B = self.trans_A, self.trans_B + input_transform_kind = self.input_transform_kind weight_transform_kind = self.weight_transform_kind assert trans_A is False, "Dequantize only implement for trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" - assert (weight_transform_kind == TransformKind.LDMatrixTransform - ), f"Dequantize only implement for LDMatrixTransform currently, got {weight_transform_kind}" + assert ( + weight_transform_kind == TransformKind.LDMatrixTransform + ), f"Dequantize only implement for LDMatrixTransform currently, got {weight_transform_kind}" in_dtype, out_dtype, accum_dtype = ( self.in_dtype, @@ -101,7 +106,6 @@ def apply_config( if group_size == -1: group_size = K - A_shape = (M, K) B_shape = ( N // micro_size_y, K // micro_size_k, @@ -115,7 +119,21 @@ def apply_config( C_shape = (M, N) Bias_shape = (N,) - A_shared_shape = (block_M, block_K) + is_a_smooth = self.is_a_smooth + is_b_smooth = self.is_b_smooth + + if is_a_smooth: + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + A_shared_shape = ( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ) + else: + A_shape = (M, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, @@ -164,6 +182,7 @@ def apply_config( warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=chunk, + transform_kind_a=input_transform_kind, transform_kind_b=weight_transform_kind, num_elems_per_byte=num_elems_per_byte, ) @@ -200,7 +219,8 @@ def general_dequant_matmul( tx = T.thread_binding(0, threads, thread="threadIdx.x") T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), + A_shared: make_swizzle_layout(A_shared, is_a_smooth), + B_shared: make_swizzle_layout(B_shared, is_b_smooth), }) T.use_swizzle(10, enable=enable_rasterization) @@ -211,7 +231,16 @@ def general_dequant_matmul( for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, ko * block_K], A_shared) + if is_a_smooth: + for i, k, ii, kk in T.Parallel( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ): + A_shared[i, k, ii, kk] = A[by * (block_M // 
micro_size_x), ko * (block_K // micro_size_k), ii, kk] + else: + T.copy(A[by * block_M, ko * block_K], A_shared) for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * vec_load_qb)): @@ -530,19 +559,13 @@ def _normal_fast_dequant_impl( else: T.call_extern( func_name, - T.address_of( - compressed_weight_local[ - j * local_size // num_elems_per_byte - ] - ), + T.address_of(compressed_weight_local[j * local_size // num_elems_per_byte]), T.address_of(dequant_weight_local[j * local_size]), T.address_of(scale_buffer[remapped_i, remapped_j]), - T.address_of( - qzeros_buffer[ - qzeros_remapped_i, - (qzeros_remapped_j // num_elems_per_byte), - ] - ), + T.address_of(qzeros_buffer[ + qzeros_remapped_i, + (qzeros_remapped_j // num_elems_per_byte), + ]), local_size * grouped_k, local_size // num_elems_per_byte, qzeros_remapped_j % num_elems_per_byte, @@ -593,10 +616,21 @@ def get_param_indices( return new_indices + @property + def is_a_smooth(self): + return self.input_transform_kind > TransformKind.NonTransform + + @property + def is_b_smooth(self): + return self.weight_transform_kind > TransformKind.NonTransform + @dataclass class MatmulINT4DequantizeWeightPropagationScheduler(MatmulDequantizeWeightPropagationScheduler): + class TLHint(MatmulDequantizeWeightPropagationScheduler.TLHint): + pass + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" M = self.M @@ -678,8 +712,9 @@ def apply_config( assert trans_A is False, "Dequantize only implement for trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" - assert (weight_transform_kind == TransformKind.LDMatrixTransform - ), f"Dequantize only implement for LDMatrixTransform currently, got {weight_transform_kind}" + assert ( + weight_transform_kind == TransformKind.LDMatrixTransform + ), f"Dequantize only implement for LDMatrixTransform currently, got {weight_transform_kind}" in_dtype, out_dtype, accum_dtype = ( self.in_dtype, @@ -718,7 +753,21 @@ def apply_config( if group_size == -1: group_size = K - A_shape = (M, K) + is_a_smooth = self.is_a_smooth + is_b_smooth = self.is_b_smooth + + if is_a_smooth: + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + A_shared_shape = ( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ) + else: + A_shape = (M, K) + A_shared_shape = (block_M, block_K) + B_shape = ( N // micro_size_y, K // micro_size_k, @@ -731,7 +780,6 @@ def apply_config( micro_size_y, micro_size_k, ) - A_shared_shape = (block_M, block_K) B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, @@ -810,7 +858,8 @@ def general_dequant_matmul( tx = T.thread_binding(0, threads, thread="threadIdx.x") T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), + A_shared: make_swizzle_layout(A_shared, is_a_smooth), + B_shared: make_swizzle_layout(B_shared, is_b_smooth), }) T.use_swizzle(10, enable=enable_rasterization) @@ -821,7 +870,16 @@ def general_dequant_matmul( for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, ko * block_K], A_shared) + if is_a_smooth: + for i, k, ii, kk in T.Parallel( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ): + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), ko * (block_K // micro_size_k), ii, kk] + else: + T.copy(A[by * block_M, ko * block_K], A_shared) # Load B into shared memory # 
TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index efba5773b..856bc809a 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -383,7 +383,7 @@ def hardware_aware_finetune( def get_profile_tensors(self, dynamic_symbolic_constraints: Optional[Dict] = None): if dynamic_symbolic_constraints is None: dynamic_symbolic_constraints = {} - func = self.prim_func + func = retrieve_func_from_module(self.scheduled_ir_module) device = self.arch.device def var_warpper(v): diff --git a/bitblas/tl/base_hint.py b/bitblas/tl/base_hint.py index 24c9bba2f..752ea4145 100644 --- a/bitblas/tl/base_hint.py +++ b/bitblas/tl/base_hint.py @@ -8,6 +8,9 @@ # Base class for Tensor Layout Hints that defines the interface and common functionality for derived classes. class BaseTLHint(ABC): + # hint identifier + hint_type: str = "base" + # Constructor for the BaseTLHint class, takes variable arguments (*args and **kwargs) to allow flexibility. def __init__(self, *args, **kwargs): # Calls the superclass constructor (useful in complex inheritance hierarchies). diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index e58aaf500..a275e91f6 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -72,7 +72,7 @@ def apply_and_build_parallel(scheduler, arch, num_repeats=3, max_workers=10, - timeout=30, + timeout=60, data_distribution="uniform") -> CompileResult: cpresults = [] From ad7bc1c105e63b3e8c0d7aa9527dc10174a2ccc9 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 11 Dec 2024 20:49:41 +0000 Subject: [PATCH 07/17] annotate hint type for dequantize --- .../tilelang/dequantize/gemv_dequantize_simt.py | 5 +++++ .../tilelang/dequantize/matmul_dequantize.py | 10 ++++++++-- .../tilelang/dequantize/matmul_dequantize_simt.py | 5 +++++ .../dequantize/matmul_dequantize_tensorcore.py | 4 ++++ .../matmul_dequantize_tensorcore_finegrained.py | 7 ++++++- .../matmul_dequantize_tensorcore_weight_transform.py | 4 ++-- 6 files changed, 30 insertions(+), 5 deletions(-) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py index 37216ecfc..54e33cd4c 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py @@ -25,6 +25,8 @@ class GemvDequantizeSIMTScheduler(MatmulDequantizeSIMTBaseScheduler): reduce_thread: int = 32 class TLHint(BaseTLHint): + + hint_type: str = "GemvDequantizeSIMTScheduler" def __init__(self): super().__init__() @@ -58,6 +60,9 @@ def __repr__(self): f"reduce_thread: {self.reduce_thread}, " "}") + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py index 78896180e..310d59fc8 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py @@ -161,8 +161,14 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: self.matmul_int4_dequantize_fine_grain_scheduler, self.matmul_int4_dequantize_weight_propagation_scheduler, ]: - if isinstance(hint, scheduler.TLHint): - return scheduler + try: + scheduler_hint_type = scheduler.get_hint_type() + if 
scheduler_hint_type == hint.hint_type: + return scheduler + except NotImplementedError: + raise ValueError( + f"get_hint_type() is not implemented for {type(scheduler)}") + raise ValueError(f"Unsupported hint type: {type(hint)}") def with_default_config(self, arch: Optional[TileDevice] = None) -> PrimFunc: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py index 54a0c54e7..2cd0143c2 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py @@ -442,6 +442,8 @@ class MatmulDequantizeSIMTScheduler(MatmulDequantizeSIMTBaseScheduler): chunk: int = 16 # Usually determines the K-dimension split size class TLHint(BaseTLHint): + + hint_type = "MatmulDequantizeSIMTScheduler" def __init__(self): super().__init__() @@ -488,6 +490,9 @@ def __repr__(self): f"chunk: {self.chunk}" "}") + def get_hint_type(self) -> str: + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py index 85188a4b6..bd834d8f6 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py @@ -478,6 +478,7 @@ class MatmulDequantizeBlockScheduler(MatmulDequantizeBaseScheduler): enable_rasterization: bool = False # Enhance L2 Locality class TLHint(BaseTLHint): + hint_type: str = "MatmulDequantizeBlockScheduler" def __init__(self): super().__init__() @@ -532,6 +533,9 @@ def __repr__(self): "}" ) + def get_hint_type(self) -> str: + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py index da649477f..f7714460d 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py @@ -49,6 +49,8 @@ class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler): enable_rasterization: bool = False # Enhance L2 Locality class TLHint(BaseTLHint): + + hint_type: str = "MatmulDequantizeFineGrainedScheduler" def __init__(self): super().__init__() @@ -108,6 +110,9 @@ def __repr__(self): f"enable_rasterization={self.enable_rasterization}" "}") + def get_hint_type(self) -> str: + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: @@ -404,7 +409,7 @@ def general_dequant_matmul( class MatmulINT4DequantizeFineGrainedScheduler(MatmulDequantizeFineGrainedScheduler): class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): - pass + hint_type: str = "MatmulINT4DequantizeFineGrainedScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py 
b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py index 31e5929a2..1ed2d1fd3 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py @@ -39,7 +39,7 @@ class MatmulDequantizeWeightPropagationScheduler(MatmulDequantizeFineGrainedSche weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): - pass + hint_type: str = "MatmulDequantizeWeightPropagationScheduler" def apply_config( self, @@ -629,7 +629,7 @@ def is_b_smooth(self): class MatmulINT4DequantizeWeightPropagationScheduler(MatmulDequantizeWeightPropagationScheduler): class TLHint(MatmulDequantizeWeightPropagationScheduler.TLHint): - pass + hint_type: str = "MatmulINT4DequantizeWeightPropagationScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" From d635713ca6c65f44b9a54ca8345af4b13443ad3a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 12 Dec 2024 08:49:34 +0000 Subject: [PATCH 08/17] enhance swizzling --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 8e2f4bf39..0f7ee3db4 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 8e2f4bf391ef4a4c48f73a0e05a31b84047c16d9 +Subproject commit 0f7ee3db4ccd616d57fc66b86f31a275952dd371 From a3e97dedc3182ee1fd1631eaf3671618a6c11d93 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 12 Dec 2024 10:39:20 +0000 Subject: [PATCH 09/17] Enhance for hardware aware tuning --- 3rdparty/tvm | 2 +- .../operators/benchmark_bitblas_matmul.py | 14 +- bitblas/base/base_scheduler.py | 2 +- bitblas/base/utils.py | 5 - .../tilelang/dense/gemv_simt.py | 2 +- .../general_matmul/tilelang/dense/matmul.py | 5 +- .../tilelang/dense/matmul_tensorcore.py | 20 +- .../tilelang/dequantize/base.py | 2 - .../dequantize/gemv_dequantize_simt.py | 4 +- .../tilelang/dequantize/matmul_dequantize.py | 7 +- .../dequantize/matmul_dequantize_simt.py | 2 +- .../matmul_dequantize_tensorcore.py | 174 +++++++----------- ...atmul_dequantize_tensorcore_finegrained.py | 10 +- ..._dequantize_tensorcore_weight_transform.py | 6 +- 14 files changed, 98 insertions(+), 157 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 0f7ee3db4..5ec6171dd 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 0f7ee3db4ccd616d57fc66b86f31a275952dd371 +Subproject commit 5ec6171dd779b3bb80634fc950bd32b4bca12659 diff --git a/benchmark/operators/benchmark_bitblas_matmul.py b/benchmark/operators/benchmark_bitblas_matmul.py index ab87e2ca6..9a40db4ad 100644 --- a/benchmark/operators/benchmark_bitblas_matmul.py +++ b/benchmark/operators/benchmark_bitblas_matmul.py @@ -24,12 +24,17 @@ choices=["tir", "tl"], # Replace with actual modes if applicable help="Specify the mode for calculating zeros.") +parser.add_argument("--verbose", type=bool, default=True, help="Enable verbose logging.") + # [A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode] default_test_shapes = json.dumps([ # ["MatmulConfig", "Matmul", [1, 16384, 16384, "float16", "int4", "float16", "float16", "nt", False, None, False, False, None]] [ "MatmulConfig", "Matmul", - [1, 16384, 16384, "int8", "int8", "int32", "int32", "nt", False, None, False, False, 
None] + [ + 16384, 16384, 16384, "float16", "float16", "float16", "float16", "nt", False, None, + False, False, None + ] ] ]) @@ -46,6 +51,7 @@ # Assign arguments to variables target = args.target backend = args.backend +verbose = args.verbose parsed_test_shapes = json.loads(args.test_shapes) name_to_class = {"MatmulConfig": MatmulConfig, "Matmul": Matmul} @@ -73,6 +79,10 @@ print("Time cost of {} is: {:.3f} ms".format(str(config), kernel_latency)) + if verbose: + print(op_inst.scheduled_ir_module) + print(op_inst.get_source()) + profile_config = { f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { "BitBLAS_top20_latency": kernel_latency, @@ -116,4 +126,4 @@ input_args, f"{values['BitBLAS_top20_latency']:.3f} ms", ] - print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]) + "\n") + print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)])) diff --git a/bitblas/base/base_scheduler.py b/bitblas/base/base_scheduler.py index c798402e0..d901a4192 100644 --- a/bitblas/base/base_scheduler.py +++ b/bitblas/base/base_scheduler.py @@ -115,7 +115,7 @@ def apply_config( def get_hint_type(self) -> str: raise NotImplementedError("Get Hint type is not implemented") - + def serialize_hints_to_configs(self, hints: List[Hint]) -> List[BaseTLHint]: # Convert Roller Hints to TileLang Hints raise NotImplementedError("Serialization of hints to configs is not implemented") diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index b957be59e..92822d1bd 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -262,11 +262,6 @@ def tvm_callback_cuda_postproc(code, _): code = tensor_remove_make_int2(code) return code - # assume index_map to be registered - from tvm.tir.tensor_intrin.cuda import ( - get_mma_intrin_group, # noqa: F401 - ) - with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, diff --git a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py index cdc29026a..190360c8f 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py @@ -23,7 +23,7 @@ class GemvFineGrainSIMTScheduler(MatmulSIMTBaseScheduler): reduce_thread: int = 16 class TLHint(BaseTLHint): - + hint_type: str = "GemvFineGrainSIMTScheduler" def __init__(self): diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index 7ff325b1c..f81dd3d1d 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -150,9 +150,8 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: scheduler_hint_type = scheduler.get_hint_type() if scheduler_hint_type == hint.hint_type: return scheduler - except NotImplementedError: - raise ValueError( - f"get_hint_type() is not implemented for {type(scheduler)}") + except NotImplementedError as e: + raise ValueError(f"get_hint_type() is not implemented for {type(scheduler)}") from e raise ValueError(f"Unsupported hint type: {type(hint)}") diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 317f2b2a0..86cfbb22b 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -84,7 +84,7 @@ class MatmulBlockScheduler(MatmulBaseScheduler): enable_rasterization: bool = False 
# Enhance L2 Locality class TLHint(BaseTLHint): - + hint_type = "MatmulBlockScheduler" def __init__(self): @@ -704,7 +704,8 @@ def main( micro_size_x, micro_size_k, ): - A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), ko * (block_K // micro_size_k), ii, kk] + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), + ko * (block_K // micro_size_k), ii, kk] else: T.copy(A[by * block_M, ko * block_K], A_shared) @@ -790,14 +791,6 @@ def is_a_smooth(self): def is_b_smooth(self): return self.weight_transform_kind > TransformKind.NonTransform - def __post_init__(self): - # Validate the matrix transpose settings - assert self.trans_A is False, "Currently only support Matrix A not transposed" - assert self.trans_B is True, "Currently only support Matrix B transposed" - assert self.weight_transform_kind > TransformKind.NonTransform, "Weight Transform Kind is required" - - return - @dataclass class MatmulINT4FineGrainScheduler(MatmulFineGrainScheduler): @@ -943,7 +936,7 @@ def main( # Apply memory layout optimizations T.annotate_layout({ A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared, is_smooth=True), + B_shared: make_swizzle_layout(B_shared), }) # Optional rasterization for L2 locality enhancement @@ -1016,7 +1009,7 @@ class MatmulINT4WeightPropagationScheduler(MatmulWeightPropagationScheduler): class TLHint(MatmulWeightPropagationScheduler.TLHint): hint_type: str = "MatmulINT4WeightPropagationScheduler" - + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" M = self.M @@ -1194,7 +1187,8 @@ def main( micro_size_x, micro_size_k, ): - A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), ko * (block_K // micro_size_k), ii, kk] + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), + ko * (block_K // micro_size_k), ii, kk] else: T.copy(A[by * block_M, ko * block_K], A_shared) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/base.py b/bitblas/ops/general_matmul/tilelang/dequantize/base.py index efe5627de..3d27f0703 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/base.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/base.py @@ -76,8 +76,6 @@ def __post_init__(self): # Validate the matrix transpose settings assert (self.trans_A is False), "Currently only support Matrix A not transposed" assert (self.trans_B is True), "Currently only support Matrix B transposed" - assert (self.input_transform_kind == TransformKind.NonTransform - ), "Currently only support NonTransform for input" # Legalize group_size if self.with_scaling and self.group_size == -1: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py index 54e33cd4c..0d838661c 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py @@ -25,7 +25,7 @@ class GemvDequantizeSIMTScheduler(MatmulDequantizeSIMTBaseScheduler): reduce_thread: int = 32 class TLHint(BaseTLHint): - + hint_type: str = "GemvDequantizeSIMTScheduler" def __init__(self): @@ -62,7 +62,7 @@ def __repr__(self): def get_hint_type(self): return self.TLHint.hint_type - + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py index 
310d59fc8..9716ac075 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py @@ -165,9 +165,8 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: scheduler_hint_type = scheduler.get_hint_type() if scheduler_hint_type == hint.hint_type: return scheduler - except NotImplementedError: - raise ValueError( - f"get_hint_type() is not implemented for {type(scheduler)}") + except NotImplementedError as e: + raise ValueError(f"get_hint_type() is not implemented for {type(scheduler)}") from e raise ValueError(f"Unsupported hint type: {type(hint)}") @@ -258,8 +257,6 @@ def __post_init__(self): # Validate the matrix transpose settings assert (self.trans_A is False), "Currently only support Matrix A not transposed" assert (self.trans_B is True), "Currently only support Matrix B transposed" - assert (self.input_transform_kind == TransformKind.NonTransform - ), "Currently only support NonTransform for input" # Legalize group_size if self.with_scaling and self.group_size == -1: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py index 2cd0143c2..4bdb26f6d 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py @@ -442,7 +442,7 @@ class MatmulDequantizeSIMTScheduler(MatmulDequantizeSIMTBaseScheduler): chunk: int = 16 # Usually determines the K-dimension split size class TLHint(BaseTLHint): - + hint_type = "MatmulDequantizeSIMTScheduler" def __init__(self): diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py index bd834d8f6..9f0ac8165 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py @@ -10,8 +10,7 @@ from bitblas.base.utils import get_roller_hints_from_func from dataclasses import dataclass from bitblas.ops.general_matmul.tirscript import ( - matmul_dequantize_select_implementation, -) + matmul_dequantize_select_implementation,) from bitblas.base.operator_common import QuantizationMemoryStage from bitblas.tl.base_hint import BaseTLHint from bitblas.quantization import ( @@ -99,23 +98,17 @@ def naive_cast_dequant(x): return x.astype(in_dtype) if with_zeros and zeros_mode == "quantized": - dequant_func = _tir_packed_to_unsigned_convert_with_zeros( - storage_type, storage_nbit - ) + dequant_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit) elif source_format == "uint": if num_bits == 8: # 8 num_bits does not need to be compressed dequant_func = naive_cast_dequant else: - dequant_func = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit - ) + dequant_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) elif source_format == "int": if num_bits == 1: # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. 
- dequant_func = _tir_packed_int_to_int_convert( - storage_type, storage_nbit - ) + dequant_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit) elif num_bits == 8: # 8 num_bits does not need to be compressed dequant_func = naive_cast_dequant @@ -198,16 +191,15 @@ def _normal_dequant_impl( vi = index // stride_k vj = index % stride_k dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit - )( - num_bits, - qzeros_buffer[ - (k * stride_k + vj) // group_size, - (pid_n * stride_n + vi) // num_elems_per_byte, - ], - (pid_n * stride_n + vi) % num_elems_per_byte, - dtype=storage_dtype, - ) + storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + (k * stride_k + vj) // group_size, + (pid_n * stride_n + vi) // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) else: raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") @@ -230,19 +222,14 @@ def _normal_dequant_impl( compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) - * scale_local[v // group_size] - ) + ) * scale_local[v // group_size]) elif zeros_mode == "original": - dequant_weight_local[v] = ( - self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) - - zeros_local[v // group_size] - ) * scale_local[v // group_size] + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) - zeros_local[v // group_size]) * scale_local[v // group_size] elif zeros_mode == "rescale": dequant_weight_local[v] = ( self._decode_func( @@ -250,20 +237,15 @@ def _normal_dequant_impl( compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) - * scale_local[v // group_size] - - zeros_local[v // group_size] - ) + ) * scale_local[v // group_size] - zeros_local[v // group_size]) elif zeros_mode == "quantized": - dequant_weight_local[v] = ( - self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - zero=dequant_qzeros_local[v // group_size], - dtype=in_dtype, - ) - ) * scale_local[v // group_size] + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros_local[v // group_size], + dtype=in_dtype, + )) * scale_local[v // group_size] else: raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") @@ -343,16 +325,15 @@ def _normal_fast_dequant_impl( vi = index // stride_k vj = index % stride_k dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit - )( - num_bits, - qzeros_buffer[ - (k * stride_k + vj) // group_size, - (pid_n * stride_n + vi) // num_elems_per_byte, - ], - (pid_n * stride_n + vi) % num_elems_per_byte, - dtype=storage_dtype, - ) + storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + (k * stride_k + vj) // group_size, + (pid_n * stride_n + vi) // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) else: raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") @@ -440,24 +421,10 @@ def dequantize( threads, ) else: - return self._normal_dequant( - compressed_weight_local, - scale_local, - zeros_local, - dequant_qzeros_local, - dequant_weight_local, - scale_buffer, - zeros_buffer, - qzeros_buffer, - local_size, - pid_n, - tx, - k, - i, - stride_n, - stride_k, - threads - ) 
+ return self._normal_dequant(compressed_weight_local, scale_local, zeros_local, + dequant_qzeros_local, dequant_weight_local, scale_buffer, + zeros_buffer, qzeros_buffer, local_size, pid_n, tx, k, i, + stride_n, stride_k, threads) @property def num_elems_per_byte(self): @@ -522,16 +489,14 @@ def get_config_params(self): } def __repr__(self): - return ( - "{" - f"block_M={self.block_M}," - f"block_N={self.block_N}," - f"block_K={self.block_K}," - f"num_stages={self.num_stages}," - f"threads={self.threads}," - f"enable_rasterization={self.enable_rasterization}" - "}" - ) + return ("{" + f"block_M={self.block_M}," + f"block_N={self.block_N}," + f"block_K={self.block_K}," + f"num_stages={self.num_stages}," + f"threads={self.threads}," + f"enable_rasterization={self.enable_rasterization}" + "}") def get_hint_type(self) -> str: return self.TLHint.hint_type @@ -578,9 +543,7 @@ def apply_config( M = self.maybe_dynamic(self.M, "m") N, K = self.N, self.K - assert isinstance(N, int) and isinstance( - K, int - ), "Do not support dynamic N and K Currently" + assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" trans_A, trans_B = self.trans_A, self.trans_B @@ -652,18 +615,17 @@ def apply_config( @T.prim_func def general_shared_dequant_matmul( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - LUT: T.Buffer(LUT_shape, in_dtype), - Scale: T.Buffer(Scale_shape, in_dtype), - Qzeros: T.Buffer(Qzeros_shape, storage_dtype), - Zeros: T.Buffer(Zeros_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), - Bias: T.Buffer(Bias_shape, in_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + LUT: T.Buffer(LUT_shape, in_dtype), + Scale: T.Buffer(Scale_shape, in_dtype), + Qzeros: T.Buffer(Qzeros_shape, storage_dtype), + Zeros: T.Buffer(Zeros_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), + Bias: T.Buffer(Bias_shape, in_dtype), ): with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads - ) as (bx, by): + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_local([local_size_compressed], storage_dtype) @@ -671,9 +633,7 @@ def general_shared_dequant_matmul( zeros_local = T.alloc_local([local_zeros_size], in_dtype) dequant_qzeros_local = T.alloc_local([local_qzeros_size], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) - B_dequantize_shared = T.alloc_shared( - B_dequantize_shared_shape, in_dtype - ) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) C_shared = T.alloc_shared([block_M, block_N], out_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -689,18 +649,12 @@ def general_shared_dequant_matmul( T.copy(A[by * block_M, ko * block_K], A_shared) T.copy(B[bx * block_N, ko * block_K // num_elems_per_byte], B_shared) - for i in T.serial( - block_N - * block_K - // num_elems_per_byte - // (threads * local_size_compressed) - ): + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): index = ( - i * threads * local_size_compressed - + tx * local_size_compressed - + v - ) + i * threads * local_size_compressed + tx * local_size_compressed + + v) vi = index // (block_K // num_elems_per_byte) vj = index % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] diff --git 
a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py index f7714460d..a0242f99b 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py @@ -21,14 +21,6 @@ MatmulDequantizeBaseScheduler, # noqa: F401 ) from bitblas.tl.base_hint import BaseTLHint -from bitblas.quantization import ( - _tir_packed_int_to_int_convert, - _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, - _tir_packed_to_fp4_to_f16, - _tir_u8_to_f8_e4m3_to_f16, - _tir_packed_to_unsigned_convert_with_zeros, -) # GPU warp configuration for NVIDIA GPUs warp_size = 32 @@ -49,7 +41,7 @@ class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler): enable_rasterization: bool = False # Enhance L2 Locality class TLHint(BaseTLHint): - + hint_type: str = "MatmulDequantizeFineGrainedScheduler" def __init__(self): diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py index 1ed2d1fd3..453a76924 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py @@ -238,7 +238,8 @@ def general_dequant_matmul( micro_size_x, micro_size_k, ): - A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), ko * (block_K // micro_size_k), ii, kk] + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), + ko * (block_K // micro_size_k), ii, kk] else: T.copy(A[by * block_M, ko * block_K], A_shared) @@ -877,7 +878,8 @@ def general_dequant_matmul( micro_size_x, micro_size_k, ): - A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), ko * (block_K // micro_size_k), ii, kk] + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), + ko * (block_K // micro_size_k), ii, kk] else: T.copy(A[by * block_M, ko * block_K], A_shared) From bdbc685190ace2b477a42834a91ca8839607a311 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 12 Dec 2024 16:07:12 +0000 Subject: [PATCH 10/17] test fix --- bitblas/base/arch/cuda.py | 2 ++ bitblas/ops/operator.py | 2 +- testing/python/operators/test_general_matmul_tile_schedule.py | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bitblas/base/arch/cuda.py b/bitblas/base/arch/cuda.py index fb8c2bbf1..5e8730d67 100644 --- a/bitblas/base/arch/cuda.py +++ b/bitblas/base/arch/cuda.py @@ -57,6 +57,7 @@ def has_mma_support(arch: TileDevice) -> bool: ("float16", "float16"), ] ampere_tensorcore_supported = [ + ("bfloat16", "float32"), ("float16", "float32"), ("float16", "float16"), ("int8", "int32"), @@ -65,6 +66,7 @@ def has_mma_support(arch: TileDevice) -> bool: ("int1", "int32"), ] ada_tensorcore_supported = [ + ("bfloat16", "float32"), ("float16", "float32"), ("float16", "float16"), ("int8", "int32"), diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 856bc809a..8df3b2a17 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -383,7 +383,7 @@ def hardware_aware_finetune( def get_profile_tensors(self, dynamic_symbolic_constraints: Optional[Dict] = None): if dynamic_symbolic_constraints is None: dynamic_symbolic_constraints = {} - func = 
retrieve_func_from_module(self.scheduled_ir_module) + func = self.prim_func or retrieve_func_from_module(self.scheduled_ir_module) device = self.arch.device def var_warpper(v): diff --git a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py index 660107647..38f49786e 100644 --- a/testing/python/operators/test_general_matmul_tile_schedule.py +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -11,7 +11,6 @@ from bitblas import set_log_level import numpy as np -print("bitblas. path is ", bitblas.__path__) np.random.seed(0) set_log_level(logging.DEBUG) From e30b64fbec98e14ca6181966dbdbee04aa9cc955 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 13 Dec 2024 16:48:08 +0000 Subject: [PATCH 11/17] remove pad factor --- .../general_matmul/tilelang/dense/matmul_tensorcore.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 86cfbb22b..0d178e36a 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -59,6 +59,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): return self.serialize_hints_to_configs(roller_hints) def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + if arch is None: + arch = self.arch return self.get_roller_configs(arch, topk) # check if required shared memory cache @@ -1079,12 +1081,6 @@ def apply_config( block_N = block_col_warps * warp_col_tiles block_K = chunk - # TODO(lei): Can be generalized to analyzed from bank size - pad_factor = 8 if storage_dtype == "float16" else 16 - - can_swizzle_a = block_K * DataType(storage_dtype).bits == 512 - apply_pad_a = not can_swizzle_a - is_a_smooth = self.is_a_smooth is_b_smooth = self.is_b_smooth @@ -1098,7 +1094,7 @@ def apply_config( ) else: A_shape = (M, K) - A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + A_shared_shape = (block_M, block_K) # Define the shapes of matrices and shared memory buffers B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) From 3b2646ac029c032df6f7fe3f7d4380123cd21fb3 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 13 Dec 2024 16:52:28 +0000 Subject: [PATCH 12/17] introduce legalize dyanmic pass --- 3rdparty/tvm | 2 +- bitblas/testing/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 5ec6171dd..e95d30b76 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 5ec6171dd779b3bb80634fc950bd32b4bca12659 +Subproject commit e95d30b769419135c6ee7dbe7b04bf88ad55ab79 diff --git a/bitblas/testing/__init__.py b/bitblas/testing/__init__.py index b8442adcf..a147503f3 100644 --- a/bitblas/testing/__init__.py +++ b/bitblas/testing/__init__.py @@ -84,7 +84,7 @@ def torch_assert_close(tensor_a, if num_mismatched > max_allowed_mismatched: raise AssertionError( f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} " - f"({max_mismatched_ratio * 100:.2f}% allowed). " + f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%). 
" f"Greatest absolute difference: {diff.max().item()}, " f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.") else: From 9462884515b8acf5631001a8ec7d76242824344c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 16 Dec 2024 04:31:01 +0000 Subject: [PATCH 13/17] update 3rdparty --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index e95d30b76..c0d4413cb 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit e95d30b769419135c6ee7dbe7b04bf88ad55ab79 +Subproject commit c0d4413cba049a90cdc6ac6f5fb182c69b5bfb66 From d662748bf32843b93ecaaea9fba1de8413e2f491 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 16 Dec 2024 06:20:50 +0000 Subject: [PATCH 14/17] testfix --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index c0d4413cb..4307584d7 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c0d4413cba049a90cdc6ac6f5fb182c69b5bfb66 +Subproject commit 4307584d77fa920a1b3536bd2bc6f6d566406860 From 8c05d7b4a96018e6f762d3f71d33a432f8539a85 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 16 Dec 2024 06:29:31 +0000 Subject: [PATCH 15/17] test code commit --- bitblas/base/roller/hint.py | 4 + .../tilelang/dense/matmul_simt.py | 2 + .../tilelang/dense/matmul_tensorcore.py | 8 +- .../dequantize/matmul_dequantize_simt.py | 2 + .../matmul_dequantize_tensorcore.py | 2 + ...atmul_dequantize_tensorcore_finegrained.py | 117 +++++++++++++++--- ..._dequantize_tensorcore_weight_transform.py | 111 +++++++++++------ 7 files changed, 188 insertions(+), 58 deletions(-) diff --git a/bitblas/base/roller/hint.py b/bitblas/base/roller/hint.py index 36a1fb7a0..8bb0f624c 100644 --- a/bitblas/base/roller/hint.py +++ b/bitblas/base/roller/hint.py @@ -173,6 +173,10 @@ def __init__(self) -> None: # Config for block reduction self.block_reduction_depth = None # type: int + # TL Specific + # Split-K factor for SM waste optimization + self.split_k_factor: int = 1 + # Experimental self._raxis_order = [] self._step = [] diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py index 9b60d547d..c80d10fbb 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py @@ -47,6 +47,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): return self.serialize_hints_to_configs(roller_hints) def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + if arch is None: + arch = self.arch return self.get_roller_configs(arch, topk) # check if required shared memory cache diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 0d178e36a..d448e61bc 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -706,8 +706,8 @@ def main( micro_size_x, micro_size_k, ): - A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), - ko * (block_K // micro_size_k), ii, kk] + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i, + ko * (block_K // micro_size_k) + k, ii, kk] else: T.copy(A[by * block_M, ko * block_K], A_shared) @@ -1183,8 +1183,8 @@ def main( micro_size_x, micro_size_k, ): - A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), - ko * (block_K // micro_size_k), ii, 
kk] + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i, + ko * (block_K // micro_size_k) + k, ii, kk] else: T.copy(A[by * block_M, ko * block_K], A_shared) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py index 4bdb26f6d..8fcb53f7f 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py @@ -66,6 +66,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): return self.serialize_hints_to_configs(roller_hints) def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + if arch is None: + arch = self.arch return self.get_roller_configs(arch, topk) # check if required shared memory cache diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py index 9f0ac8165..4c91bc144 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py @@ -68,6 +68,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): return self.serialize_hints_to_configs(roller_hints) def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + if arch is None: + arch = self.arch return self.get_roller_configs(arch, topk) # check if required shared memory cache diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py index a0242f99b..61e34f08a 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py @@ -12,7 +12,7 @@ from bitblas.ops.general_matmul.tirscript import ( matmul_dequantize_select_implementation,) from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter, INT4TensorCoreIntrinEmitter) -from bitblas.base.arch import TileDevice +from bitblas.base.arch import TileDevice, is_cuda_arch from bitblas.base.roller.hint import Hint from bitblas.base.roller.rasterization import NoRasterization from bitblas.base.utils import get_roller_hints_from_func @@ -37,8 +37,9 @@ class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler): chunk: int = 32 # Usually determines the K-dimension split size # Other Optimization Parameters - num_stages: int = 2 + num_stages: int = 0 enable_rasterization: bool = False # Enhance L2 Locality + split_k_factor: int = 1 # Split-K factor for SM waste optimization class TLHint(BaseTLHint): @@ -76,6 +77,7 @@ def from_roller_hint(cls, hint: Hint): tl_hint.chunk = chunk tl_hint.num_stages = num_stages tl_hint.enable_rasterization = enable_rasterization + tl_hint.split_k_factor = hint.split_k_factor return tl_hint @@ -88,6 +90,7 @@ def get_config_params(self): "chunk": self.chunk, "num_stages": self.num_stages, "enable_rasterization": self.enable_rasterization, + "split_k_factor": self.split_k_factor, } def __repr__(self): @@ -99,7 +102,8 @@ def __repr__(self): f"block_K={self.chunk}," f"threads={self.block_row_warps * self.block_col_warps * warp_size}," f"num_stages={self.num_stages}," - f"enable_rasterization={self.enable_rasterization}" + 
f"enable_rasterization={self.enable_rasterization}," + f"split_k_factor={self.split_k_factor}" "}") def get_hint_type(self) -> str: @@ -108,7 +112,61 @@ def get_hint_type(self) -> str: def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: + # Extract static shape dimensions for matrix multiplication + M, N, K = self.M, self.N, self.K + + # Determine if the shapes are statically defined (not dynamic) + is_static_shape = isinstance(M, int) and isinstance(N, int) and isinstance(K, int) + + # Check if the architecture is CUDA-based + arch_is_cuda = is_cuda_arch(self.arch) + + # If the architecture is CUDA and we have a static shape, proceed with optimization + if arch_is_cuda and is_static_shape: + sm_waste_threshold = 5e-2 # Allow at most 5% SM waste + num_sms = self.arch.compute_max_core # Get the maximum number of streaming multiprocessors + + # Compute block sizes based on the configuration + block_M = hint.block[0] # Block size in the M dimension + block_N = hint.block[1] # Block size in the N dimension + block_K = hint.rstep[0] # Block size in the K dimension + + # Calculate the grid dimensions in M and N directions + grid_m = M // block_M + grid_n = N // block_N + total_grids = grid_m * grid_n # Total number of grids + + # Initialize the split-k factor (used to distribute K-dimension work across blocks) + split_k_factor = 1 + + # Optimize the split-k factor to minimize SM waste + while True: + # Total grids after applying split-k + total_grids_split_k = total_grids * split_k_factor + + # Calculate the waste in SMs after split-k distribution + waste_sm_splitk = total_grids_split_k - (total_grids_split_k // + num_sms) * num_sms + waste_sm_splitk_ratio = waste_sm_splitk / total_grids_split_k + + # If the SM waste ratio is within the allowed threshold, stop optimization + if waste_sm_splitk_ratio <= sm_waste_threshold: + break + + # Double the split-k factor and check if the resulting K-dimension size is too large + expand_split_k = split_k_factor * 2 + if expand_split_k * block_K >= K: + break + + # Update the split-k factor for the next iteration + split_k_factor = expand_split_k + + # Note: The optimized split_k_factor can be stored or applied to the config if needed + hint.split_k_factor = split_k_factor + + # Convert the hint to a configuration object using the TLHint mapping config = self.TLHint.from_roller_hint(hint) + configs.append(config) return configs @@ -123,6 +181,7 @@ def with_default_config(self): num_stages = getattr(self, "num_stages", 2) enable_rasterization = getattr(self, "enable_rasterization", False) + split_k_factor = getattr(self, "split_k_factor", 1) return self.apply_config( block_row_warps=block_row_warps, @@ -132,6 +191,7 @@ def with_default_config(self): chunk=chunk, num_stages=num_stages, enable_rasterization=enable_rasterization, + split_k_factor=split_k_factor, ) def apply_config( @@ -143,6 +203,7 @@ def apply_config( chunk: Optional[int] = None, num_stages: Optional[int] = None, enable_rasterization=False, + split_k_factor: Optional[int] = None, ): assert block_row_warps is not None, "block_row_warps is required" assert block_col_warps is not None, "block_col_warps is required" @@ -204,6 +265,8 @@ def apply_config( Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) Bias_shape = (N,) + splitK = K // split_k_factor + A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) @@ -253,7 +316,15 @@ def apply_config( chunk=chunk, ) - 
cache_write_required = self.check_require_cache() + enable_split_k = split_k_factor > 1 + + def check_require_cache(): + conditions = [False] + conditions.append(self.check_require_cache()) + conditions.append(enable_split_k) + return any(conditions) + + cache_write_required = check_require_cache() @T.prim_func def general_dequant_matmul( @@ -267,7 +338,8 @@ def general_dequant_matmul( Bias: T.Buffer(Bias_shape, in_dtype), ): with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k_factor, + threads=threads) as (bx, by, bz): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) @@ -296,10 +368,13 @@ def general_dequant_matmul( T.clear(C_frag) - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=num_stages): - T.copy(A[by * block_M, ko * block_K], A_shared) - T.copy(B[bx * block_N, ko * block_K // num_elems_per_byte], B_shared) + T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) + T.copy( + B[bx * block_N, + bz * (splitK // num_elems_per_byte) + ko * block_K // num_elems_per_byte], + B_shared) for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): @@ -359,6 +434,7 @@ def general_dequant_matmul( # Matrix multiplication on fragments mma_emitter.mma(A_frag, B_frag, C_frag) + if cache_write_required: # Store the result back to C shared memory mma_emitter.stmatrix( @@ -377,13 +453,24 @@ def general_dequant_matmul( ] += Bias[bx * block_N + j] # Store results from shared memory to global memory - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] + if enable_split_k: + for i, j in T.Parallel(block_M, block_N // 2): + T.atomic_addx2( + C[by * block_M + i, bx * block_N + j * 2], C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ]) + else: + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + else: # Store the result back to C global memory mma_emitter.stmatrix( diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py index 453a76924..4e1a18697 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py @@ -50,6 +50,7 @@ def apply_config( chunk: Optional[int] = None, num_stages: Optional[int] = None, enable_rasterization=False, + split_k_factor: Optional[int] = None, ): assert block_row_warps is not None, "block_row_warps is required" assert block_col_warps is not None, "block_col_warps is required" @@ -140,7 +141,6 @@ def apply_config( micro_size_y, micro_size_k // num_elems_per_byte, ) - C_shared_shape = ( block_M // micro_size_x, block_N // micro_size_y, @@ -148,6 +148,8 @@ def apply_config( micro_size_y, ) + shared_scope = "shared" + import_source: Optional[str] = None func_name: str = "" if fast_decoding is True: @@ -187,12 +189,21 @@ def apply_config( 
num_elems_per_byte=num_elems_per_byte, ) + splitK = K // split_k_factor + enable_split_k = split_k_factor > 1 + + def check_require_cache(): + conditions = [False] + conditions.append(self.check_require_cache()) + conditions.append(enable_split_k) + return any(conditions) + + cache_write_required = check_require_cache() + vec_load_qb = 16 if block_N * block_K // num_elems_per_byte // threads < vec_load_qb: vec_load_qb = block_N * block_K // num_elems_per_byte // threads - cache_write_required = self.check_require_cache() - @T.prim_func def general_dequant_matmul( A: T.Buffer(A_shape, in_dtype), @@ -205,10 +216,11 @@ def general_dequant_matmul( Bias: T.Buffer(Bias_shape, in_dtype), ): with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, storage_dtype) - C_shared = T.alloc_shared(C_shared_shape, out_dtype) + T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k_factor, + threads=threads) as (bx, by, bz): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) B_frag = T.alloc_local((warp_cols * fragement_size_b // num_elems_per_byte), @@ -229,7 +241,7 @@ def general_dequant_matmul( T.clear(C_frag) - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=num_stages): if is_a_smooth: for i, k, ii, kk in T.Parallel( @@ -238,28 +250,25 @@ def general_dequant_matmul( micro_size_x, micro_size_k, ): - A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), - ko * (block_K // micro_size_k), ii, kk] + A_shared[i, k, ii, + kk] = A[by * (block_M // micro_size_x) + i, + bz * splitK + ko * (block_K // micro_size_k) + k, ii, + kk] else: - T.copy(A[by * block_M, ko * block_K], A_shared) - - for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * vec_load_qb)): - for v in T.vectorized(0, vec_load_qb): - idx = i * threads * vec_load_qb + tx * vec_load_qb + v - vkk = idx % (micro_size_k // num_elems_per_byte) - vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y - vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % ( - block_K // micro_size_k) - vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // - (block_K // micro_size_k)) % ( - block_N // micro_size_y) - B_shared[vj, vk, vjj, vkk] = B[ - bx * (block_N // micro_size_y) + vj, - ko * (block_K // micro_size_k) + vk, - vjj, - vkk, - ] + T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) + + for j, k, jj, kk in T.Parallel( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + (micro_size_k // num_elems_per_byte), + ): + B_shared[j, k, jj, kk] = B[ + bx * (block_N // micro_size_y) + j, + bz * (splitK // micro_size_k) + ko * (block_K // micro_size_k) + k, + jj, + kk, + ] # Perform the matrix multiplication on tensor core fragments for ki in T.serial(0, (block_K // micro_size_k)): @@ -336,14 +345,38 @@ def general_dequant_matmul( j % micro_size_y, ] += Bias[j] - # Store results from shared memory to global memory - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] + # Store results from shared 
memory to global memory + if enable_split_k: + # only for fp16 + if DataType(out_dtype).bits == 16: + for i, j in T.Parallel(block_M, block_N // 2): + m, n = by * block_M + i, bx * block_N + j * 2 + T.atomic_addx2( + C[m, n], C_shared[ + i // micro_size_x, + (j * 2) // micro_size_y, + i % micro_size_x, + (j * 2) % micro_size_y, + ]) + else: + for i, j in T.Parallel(block_M, block_N): + m, n = by * block_M + i, bx * block_N + j + C[m, n] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + else: + for i, j in T.Parallel(block_M, block_N): + m, n = by * block_M + i, bx * block_N + j + C[m, n] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + else: mma_emitter.stmatrix( C_frag, @@ -878,8 +911,8 @@ def general_dequant_matmul( micro_size_x, micro_size_k, ): - A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), - ko * (block_K // micro_size_k), ii, kk] + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i, + ko * (block_K // micro_size_k) + k, ii, kk] else: T.copy(A[by * block_M, ko * block_K], A_shared) From cdd07532663389015dd62e06cd82460d0f0e1d14 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 16 Dec 2024 07:20:25 +0000 Subject: [PATCH 16/17] enhance typing and fix test for int4 dequantize gemm --- .../ops/general_flashatten/tilelang/flashatten.py | 4 ++-- .../tilelang/dense/matmul_tensorcore.py | 14 +++++++------- .../matmul_dequantize_tensorcore_finegrained.py | 7 +++++-- ...atmul_dequantize_tensorcore_weight_transform.py | 7 +++++-- .../operators/test_general_matmul_tilelang_impl.py | 4 ++-- .../test_general_matmul_tilelang_kernel.py | 14 +++++++------- 6 files changed, 28 insertions(+), 22 deletions(-) diff --git a/bitblas/ops/general_flashatten/tilelang/flashatten.py b/bitblas/ops/general_flashatten/tilelang/flashatten.py index 9d76c6dfd..4470e5a51 100644 --- a/bitblas/ops/general_flashatten/tilelang/flashatten.py +++ b/bitblas/ops/general_flashatten/tilelang/flashatten.py @@ -60,7 +60,7 @@ def apply_config( block_N=64, num_stages=2, threads=128, - enable_rasterization=False, + enable_rasterization: bool =False, ): batch, heads, seq_len, dim = self.batch, self.heads, self.seq_len, self.dim trans_K = self.trans_K @@ -185,7 +185,7 @@ def flashatten_blocked( num_stages=2, threads=128, is_causal=False, - enable_rasterization=False, # Enhance L2 Locality + enable_rasterization: bool =False, # Enhance L2 Locality ): Q_shape = (batch, seq_len, heads, dim) if not trans_Q else (batch, dim, heads, seq_len) K_shape = (batch, seq_len, heads, dim) if not trans_K else (batch, dim, heads, seq_len) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index d448e61bc..d465e9183 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -381,7 +381,7 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization=False, + enable_rasterization: bool =False, ): assert block_row_warps is not None, "block_row_warps is required" assert block_col_warps is not None, "block_col_warps is required" @@ -578,7 +578,7 @@ def apply_config( warp_col_tiles=32, chunk=16, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): M = self.maybe_dynamic(self.M, "m") @@ -850,7 +850,7 @@ def apply_config( warp_col_tiles: 
Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization=False, + enable_rasterization: bool =False, ): assert block_row_warps is not None, "block_row_warps is required" assert block_col_warps is not None, "block_col_warps is required" @@ -1061,7 +1061,7 @@ def apply_config( warp_col_tiles=32, chunk=16, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): M = self.maybe_dynamic(self.M, "m") @@ -1264,7 +1264,7 @@ def matmul_blocked( accum_dtype="float16", num_stages=2, threads=128, - enable_rasterization=False, # Enhance L2 Locality + enable_rasterization: bool =False, # Enhance L2 Locality ): A_shape = (K, M) if trans_A else (M, K) B_shape = (N, K) if trans_B else (K, N) @@ -1316,7 +1316,7 @@ def matmul_macro_tensorcore( warp_col_tiles, chunk, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): assert trans_A is False, "Currently only support Matrix A is not transposed" assert trans_B is True, "Currently only support Matrix B is transposed" @@ -1445,7 +1445,7 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( warp_col_tiles, chunk, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): assert trans_A is False, "Currently only support Matrix A is not transposed" assert trans_B is True, "Currently only support Matrix B is transposed" diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py index 61e34f08a..9656bc519 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py @@ -202,7 +202,7 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization=False, + enable_rasterization: bool =False, split_k_factor: Optional[int] = None, ): assert block_row_warps is not None, "block_row_warps is required" @@ -550,7 +550,8 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization=False, + enable_rasterization: bool =False, + split_k_factor: Optional[int] = None, ): assert block_row_warps is not None, "block_row_warps is required" assert block_col_warps is not None, "block_col_warps is required" @@ -558,6 +559,8 @@ def apply_config( assert warp_col_tiles is not None, "warp_col_tiles is required" assert chunk is not None, "chunk is required" assert num_stages is not None, "num_stages is required" + # unused variable + split_k_factor = split_k_factor M = self.maybe_dynamic(self.M, "m") N, K = self.N, self.K diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py index 4e1a18697..cff2800a8 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py @@ -49,7 +49,7 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization=False, + enable_rasterization: bool =False, split_k_factor: Optional[int] = None, ): assert 
block_row_warps is not None, "block_row_warps is required" @@ -727,7 +727,8 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization=False, + enable_rasterization: bool =False, + split_k_factor: Optional[int] = None, ): assert block_row_warps is not None, "block_row_warps is required" assert block_col_warps is not None, "block_col_warps is required" @@ -735,6 +736,8 @@ def apply_config( assert warp_col_tiles is not None, "warp_col_tiles is required" assert chunk is not None, "chunk is required" assert num_stages is not None, "num_stages is required" + # unused variable + split_k_factor = split_k_factor M = self.maybe_dynamic(self.M, "m") N, K = self.N, self.K diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index 5deaeaf41..38e9c1b97 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -86,7 +86,7 @@ def assert_matmul_macro_tensorcore_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): matmul = matmul_macro_tensorcore( M=M, @@ -144,7 +144,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): matmul = matmul_macro_tensorcore_weight_propagation_level_ldmatrix( M=M, diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 676543304..6e1f1e445 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -90,7 +90,7 @@ def assert_matmul_blocked_apply_config_correctness( accum_dtype="float16", num_stages=2, threads=128, - enable_rasterization=False, + enable_rasterization: bool =False, ): matmul = MatmulBlockScheduler( M=M, @@ -196,7 +196,7 @@ def assert_matmul_fine_grained_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): matmul = MatmulFineGrainScheduler( @@ -316,7 +316,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): matmul = MatmulWeightPropagationScheduler( @@ -438,7 +438,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): matmul = MatmulINT4FineGrainScheduler( @@ -566,7 +566,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): matmul = MatmulINT4WeightPropagationScheduler( @@ -737,7 +737,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): matmul = MatmulINT4DequantizeFineGrainedScheduler( M=M, @@ -941,7 +941,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization=False, + enable_rasterization: bool =False, ): matmul = 
MatmulINT4DequantizeWeightPropagationScheduler( M=M, From b9c343c62f74752319d295f9c73fac6e058706d3 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 16 Dec 2024 07:21:25 +0000 Subject: [PATCH 17/17] lint fix --- .../ops/general_flashatten/tilelang/flashatten.py | 4 ++-- .../tilelang/dense/matmul_tensorcore.py | 14 +++++++------- .../matmul_dequantize_tensorcore_finegrained.py | 4 ++-- ...atmul_dequantize_tensorcore_weight_transform.py | 4 ++-- .../operators/test_general_matmul_tilelang_impl.py | 4 ++-- .../test_general_matmul_tilelang_kernel.py | 14 +++++++------- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/bitblas/ops/general_flashatten/tilelang/flashatten.py b/bitblas/ops/general_flashatten/tilelang/flashatten.py index 4470e5a51..d2a5b2857 100644 --- a/bitblas/ops/general_flashatten/tilelang/flashatten.py +++ b/bitblas/ops/general_flashatten/tilelang/flashatten.py @@ -60,7 +60,7 @@ def apply_config( block_N=64, num_stages=2, threads=128, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): batch, heads, seq_len, dim = self.batch, self.heads, self.seq_len, self.dim trans_K = self.trans_K @@ -185,7 +185,7 @@ def flashatten_blocked( num_stages=2, threads=128, is_causal=False, - enable_rasterization: bool =False, # Enhance L2 Locality + enable_rasterization: bool = False, # Enhance L2 Locality ): Q_shape = (batch, seq_len, heads, dim) if not trans_Q else (batch, dim, heads, seq_len) K_shape = (batch, seq_len, heads, dim) if not trans_K else (batch, dim, heads, seq_len) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index d465e9183..4e56a15f3 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -381,7 +381,7 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): assert block_row_warps is not None, "block_row_warps is required" assert block_col_warps is not None, "block_col_warps is required" @@ -578,7 +578,7 @@ def apply_config( warp_col_tiles=32, chunk=16, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): M = self.maybe_dynamic(self.M, "m") @@ -850,7 +850,7 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): assert block_row_warps is not None, "block_row_warps is required" assert block_col_warps is not None, "block_col_warps is required" @@ -1061,7 +1061,7 @@ def apply_config( warp_col_tiles=32, chunk=16, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): M = self.maybe_dynamic(self.M, "m") @@ -1264,7 +1264,7 @@ def matmul_blocked( accum_dtype="float16", num_stages=2, threads=128, - enable_rasterization: bool =False, # Enhance L2 Locality + enable_rasterization: bool = False, # Enhance L2 Locality ): A_shape = (K, M) if trans_A else (M, K) B_shape = (N, K) if trans_B else (K, N) @@ -1316,7 +1316,7 @@ def matmul_macro_tensorcore( warp_col_tiles, chunk, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): assert trans_A is False, "Currently only support Matrix A is not transposed" assert trans_B is True, "Currently only support Matrix B is transposed" 
@@ -1445,7 +1445,7 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( warp_col_tiles, chunk, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): assert trans_A is False, "Currently only support Matrix A is not transposed" assert trans_B is True, "Currently only support Matrix B is transposed" diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py index 9656bc519..ebbdafcc6 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py @@ -202,7 +202,7 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization: bool =False, + enable_rasterization: bool = False, split_k_factor: Optional[int] = None, ): assert block_row_warps is not None, "block_row_warps is required" @@ -550,7 +550,7 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization: bool =False, + enable_rasterization: bool = False, split_k_factor: Optional[int] = None, ): assert block_row_warps is not None, "block_row_warps is required" diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py index cff2800a8..eb1b5c93e 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py @@ -49,7 +49,7 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization: bool =False, + enable_rasterization: bool = False, split_k_factor: Optional[int] = None, ): assert block_row_warps is not None, "block_row_warps is required" @@ -727,7 +727,7 @@ def apply_config( warp_col_tiles: Optional[int] = None, chunk: Optional[int] = None, num_stages: Optional[int] = None, - enable_rasterization: bool =False, + enable_rasterization: bool = False, split_k_factor: Optional[int] = None, ): assert block_row_warps is not None, "block_row_warps is required" diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index 38e9c1b97..e412e2298 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -86,7 +86,7 @@ def assert_matmul_macro_tensorcore_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): matmul = matmul_macro_tensorcore( M=M, @@ -144,7 +144,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): matmul = matmul_macro_tensorcore_weight_propagation_level_ldmatrix( M=M, diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 6e1f1e445..e89701af8 100644 --- 
a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -90,7 +90,7 @@ def assert_matmul_blocked_apply_config_correctness( accum_dtype="float16", num_stages=2, threads=128, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): matmul = MatmulBlockScheduler( M=M, @@ -196,7 +196,7 @@ def assert_matmul_fine_grained_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): matmul = MatmulFineGrainScheduler( @@ -316,7 +316,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): matmul = MatmulWeightPropagationScheduler( @@ -438,7 +438,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): matmul = MatmulINT4FineGrainScheduler( @@ -566,7 +566,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): matmul = MatmulINT4WeightPropagationScheduler( @@ -737,7 +737,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): matmul = MatmulINT4DequantizeFineGrainedScheduler( M=M, @@ -941,7 +941,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( warp_col_tiles=16, chunk=32, num_stages=2, - enable_rasterization: bool =False, + enable_rasterization: bool = False, ): matmul = MatmulINT4DequantizeWeightPropagationScheduler( M=M,
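
The split-K support introduced in the dequantize-matmul patches above has two parts: a tuning-time heuristic that keeps doubling split_k_factor until fewer than about 5% of the streaming multiprocessors would sit idle in the last wave (bounded so that a further doubling never exceeds the K extent), and a kernel decomposition in which each bz block of the new third grid dimension reduces its own slice of K and accumulates the partial result into C with atomic adds (atomic_addx2 in the fp16 path). The following is a minimal standalone Python/NumPy sketch of those two ideas, not code from the patch series itself; the helper names pick_split_k_factor and split_k_matmul_reference, and the concrete sizes in the demo, are illustrative assumptions.

import numpy as np


def pick_split_k_factor(M, N, K, block_M, block_N, block_K, num_sms,
                        sm_waste_threshold=5e-2):
    # Mirrors the loop added to serialize_hints_to_configs: double split_k
    # until the leftover blocks in the last wave fall under the threshold,
    # or until a further doubling would cover the whole K extent.
    total_grids = (M // block_M) * (N // block_N)
    split_k_factor = 1
    while True:
        total_grids_split_k = total_grids * split_k_factor
        waste = total_grids_split_k - (total_grids_split_k // num_sms) * num_sms
        if waste / total_grids_split_k <= sm_waste_threshold:
            break
        expand_split_k = split_k_factor * 2
        if expand_split_k * block_K >= K:
            break
        split_k_factor = expand_split_k
    return split_k_factor


def split_k_matmul_reference(A, B, split_k_factor):
    # Each bz handles the K range [bz * splitK, (bz + 1) * splitK) and its
    # partial product is accumulated into C; the CUDA kernels perform this
    # accumulation with atomic adds in the epilogue.
    M, K = A.shape
    N = B.shape[1]
    splitK = K // split_k_factor
    C = np.zeros((M, N), dtype=A.dtype)
    for bz in range(split_k_factor):
        ks = slice(bz * splitK, (bz + 1) * splitK)
        C += A[:, ks] @ B[ks, :]
    return C


if __name__ == "__main__":
    M, N, K = 256, 256, 512
    factor = pick_split_k_factor(M, N, K, block_M=64, block_N=64, block_K=32,
                                 num_sms=108)  # e.g. an A100-class device
    A = np.random.randn(M, K).astype(np.float32)
    B = np.random.randn(K, N).astype(np.float32)
    np.testing.assert_allclose(
        split_k_matmul_reference(A, B, factor), A @ B, rtol=1e-3, atol=1e-2)

With the example sizes above the heuristic settles on split_k_factor = 8: the 16 output tiles alone cannot occupy 108 SMs, so K is split until the grid grows to 128 blocks; the 5% waste target is never actually met here, and the search stops instead because the next doubling (16 * 32) would reach the full K extent.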