diff --git a/3rdparty/tvm b/3rdparty/tvm index 8e2f4bf39..5ec6171dd 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 8e2f4bf391ef4a4c48f73a0e05a31b84047c16d9 +Subproject commit 5ec6171dd779b3bb80634fc950bd32b4bca12659 diff --git a/benchmark/operators/benchmark_bitblas_matmul.py b/benchmark/operators/benchmark_bitblas_matmul.py index fb3927bdb..9a40db4ad 100644 --- a/benchmark/operators/benchmark_bitblas_matmul.py +++ b/benchmark/operators/benchmark_bitblas_matmul.py @@ -1,141 +1,67 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +import bitblas from bitblas.utils.target_detector import auto_detect_nvidia_target from bitblas import Matmul, MatmulConfig import argparse +import json - -# Initialize the parser -parser = argparse.ArgumentParser( - description="Benchmark BitBLAS int4 on a specific target." -) - -# Add arguments to the parser -parser.add_argument( - "--target", - type=str, - default=auto_detect_nvidia_target(), - help="Specify the target device for benchmarking." -) -parser.add_argument( - "--group_size", - type=int, - default=None, - help="Group size for grouped quantization." -) -parser.add_argument( - "--A_dtype", - type=str, - default="float16", - choices=["float16", "float32", "float64", "int32", "int8"], # Assuming these are the valid choices - help="Data type of activation A." -) -parser.add_argument( - "--W_dtype", - type=str, - default="int4", - choices=["float16", "float32", "float64", "int32", "int8", "int4", "int2", "int1", "nf4", "fp4_e2m1"], # Assuming these are the valid choices - help="Data type of weight W." -) -parser.add_argument( - "--accum_dtype", - type=str, - default="float16", - choices=["float16", "int32"], # Assuming these are the valid choices - help="Data type for accumulation." -) -parser.add_argument( - "--out_dtype", - type=str, - default="float16", - choices=["float16", "float32", "int32", "int8"], # Assuming these are the valid choices - help="Data type for output." -) -parser.add_argument( - "--layout", - type=str, - default="nt", - choices=["nt", "nn"], # Assuming these are the valid choices - help="Matrix layout, 'nt' for non-transpose A and transpose W." -) -parser.add_argument( - "--with_bias", - action="store_true", - help="Include bias in the benchmark." -) -parser.add_argument( - "--with_scaling", - action="store_true", - help="Include scaling factor in the quantization." -) -parser.add_argument( - "--with_zeros", - action="store_true", - help="Include zeros in the quantization." -) -parser.add_argument( - "--zeros_mode", - type=str, - default=None, - choices=["original", "rescale", "quantized"], # Replace with actual modes if applicable - help="Specify the mode for calculating zeros." -) - -# Parse the arguments -args = parser.parse_args() - -# Assign arguments to variables -target = args.target -group_size = args.group_size -A_dtype = args.A_dtype -W_dtype = args.W_dtype -accum_dtype = args.accum_dtype -out_dtype = args.out_dtype -layout = args.layout -with_bias = args.with_bias -group_size = args.group_size -with_scaling = args.with_scaling -with_zeros = args.with_zeros -zeros_mode = args.zeros_mode - -test_shapes = [ - # square test - (MatmulConfig, Matmul, (1, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # BLOOM-176B - (MatmulConfig, Matmul, (1, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # # OPT-65B - (MatmulConfig, Matmul, (1, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # # LLAMA-70B/65B - (MatmulConfig, Matmul, (1, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (1, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - - # square test - (MatmulConfig, Matmul, (16384, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # BLOOM-176B - (MatmulConfig, Matmul, (8192, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # # OPT-65B - (MatmulConfig, Matmul, (8192, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - # # LLAMA-70B/65B - (MatmulConfig, Matmul, (8192, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), - (MatmulConfig, Matmul, (8192, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), -] +bitblas.set_log_level("DEBUG") +# Initialize the parser +parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.") + +# Add arguments to the parser +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), + help="Specify the target device for benchmarking.") + +parser.add_argument( + "--backend", + type=str, + default="tir", + choices=["tir", "tl"], # Replace with actual modes if applicable + help="Specify the mode for calculating zeros.") + +parser.add_argument("--verbose", type=bool, default=True, help="Enable verbose logging.") + +# [A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode] +default_test_shapes = json.dumps([ + # ["MatmulConfig", "Matmul", [1, 16384, 16384, "float16", "int4", "float16", "float16", "nt", False, None, False, False, None]] + [ + "MatmulConfig", "Matmul", + [ + 16384, 16384, 16384, "float16", "float16", "float16", "float16", "nt", False, None, + False, False, None + ] + ] +]) + +parser.add_argument( + "--test_shapes", + type=str, + default=default_test_shapes, + help="JSON string defining test shapes. Example format: '[[\"MatmulConfig\", \"Matmul\", [1,16384,16384,\"float16\",\"int4\",\"float16\",\"float16\",\"nt\",false,null,false,false,null]]]'" +) + +# Parse the arguments +args = parser.parse_args() + +# Assign arguments to variables +target = args.target +backend = args.backend +verbose = args.verbose + +parsed_test_shapes = json.loads(args.test_shapes) +name_to_class = {"MatmulConfig": MatmulConfig, "Matmul": Matmul} + +test_shapes = [] +for item in parsed_test_shapes: + config_class_name, operator_class_name, input_args = item + config_class = name_to_class[config_class_name] + operator_class = name_to_class[operator_class_name] + test_shapes.append((config_class, operator_class, tuple(input_args))) benchmark_sets = [] benchmark_sets.extend(test_shapes) @@ -145,12 +71,17 @@ benchmark_results = {} for config, operator, input_args in benchmark_sets: config = config(*input_args) - matmul = operator(config, target=target, enable_tuning=True) - kernel_latency = matmul.profile_latency() - if matmul.input_transform is not None: - kernel_latency += matmul.ladder_permutate_a.profile_latency() - - print("Time cost is: {:.3f} ms".format(kernel_latency)) + print(f"Running benchmark for {operator.__name__} with config: {config}") + op_inst = operator(config, target=target, enable_tuning=True, backend=backend) + kernel_latency = op_inst.profile_latency() + if op_inst.input_transform is not None: + kernel_latency += op_inst.ladder_permutate_a.profile_latency() + + print("Time cost of {} is: {:.3f} ms".format(str(config), kernel_latency)) + + if verbose: + print(op_inst.scheduled_ir_module) + print(op_inst.get_source()) profile_config = { f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { @@ -160,7 +91,7 @@ benchmark_results.update(profile_config) -# Define headers for the table +# Define headers for the table headers = [ "PrimFunc", "Input Arguments", @@ -174,7 +105,9 @@ input_args = "-".join(args[1:]) col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), col_widths[0]) col_widths[1] = max((max(len(str(headers[1])), len(input_args)) + 2, col_widths[1])) - col_widths[2] = max(max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, col_widths[2]) + col_widths[2] = max( + max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, + col_widths[2]) break for i, header in enumerate(headers): @@ -193,4 +126,4 @@ input_args, f"{values['BitBLAS_top20_latency']:.3f} ms", ] - print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]) + "\n") + print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)])) diff --git a/bitblas/base/arch/__init__.py b/bitblas/base/arch/__init__.py index c7d7af31b..dd931f617 100644 --- a/bitblas/base/arch/__init__.py +++ b/bitblas/base/arch/__init__.py @@ -1,15 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from .arch_base import TileDevice -from .cuda import * -from .cpu import * -from .cdna import * +from .cuda import CUDA +from .cpu import CPU +from .cdna import CDNA from typing import Union +from tvm.target import Target -def get_arch(target: Union[str, tvm.target.Target] = "cuda") -> TileDevice: +def get_arch(target: Union[str, Target] = "cuda") -> TileDevice: if isinstance(target, str): - target = tvm.target.Target(target) + target = Target(target) if target.kind.name == "cuda": return CUDA(target) @@ -27,57 +28,14 @@ def auto_infer_current_arch() -> TileDevice: return get_arch("cuda") -def is_cpu_arch(arch: TileDevice) -> bool: - return isinstance(arch, CPU) - - -def is_cuda_arch(arch: TileDevice) -> bool: - return isinstance(arch, CUDA) - - -def is_ampere_arch(arch: TileDevice) -> bool: - conditions = [True] - conditions.append(is_cuda_arch(arch)) - conditions.append(arch.sm_version >= 80 and arch.sm_version < 90) - return all(conditions) - - -def is_volta_arch(arch: TileDevice) -> bool: - conditions = [True] - conditions.append(is_cuda_arch(arch)) - conditions.append(arch.sm_version >= 70) - conditions.append(arch.sm_version < 80) - return all(conditions) - - -def is_cdna_arch(arch: TileDevice) -> bool: - return isinstance(arch, CDNA) - - -def has_mma_support(arch: TileDevice) -> bool: - conditions = [True] - conditions.append(is_cuda_arch(arch)) - conditions.append(arch.sm_version >= 80) - return all(conditions) - - -def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool: - volta_tensorcore_supported = [ - ("float16", "float32"), - ("float16", "float16"), - ] - ampere_tensorcore_supported = [ - ("float16", "float32"), - ("float16", "float16"), - ("int8", "int32"), - ("int4", "int32"), - ("int2", "int32"), - ("int1", "int32"), - ] - - if is_volta_arch(arch): - return (in_dtype, accum_dtype) in volta_tensorcore_supported - elif is_ampere_arch(arch): - return (in_dtype, accum_dtype) in ampere_tensorcore_supported - else: - raise ValueError(f"Unsupported architecture: {arch}") +from .cpu import is_cpu_arch # noqa: F401 +from .cuda import ( + is_cuda_arch, # noqa: F401 + is_volta_arch, # noqa: F401 + is_ampere_arch, # noqa: F401 + is_ada_arch, # noqa: F401 + is_hopper_arch, # noqa: F401 + is_tensorcore_supported_precision, # noqa: F401 + has_mma_support, # noqa: F401 +) +from .cdna import is_cdna_arch # noqa: F401 diff --git a/bitblas/base/arch/cdna.py b/bitblas/base/arch/cdna.py index f6805dc98..cb49041db 100644 --- a/bitblas/base/arch/cdna.py +++ b/bitblas/base/arch/cdna.py @@ -7,6 +7,10 @@ from typing import List, Union +def is_cdna_arch(arch: TileDevice) -> bool: + return isinstance(arch, CDNA) + + class CDNA(TileDevice): def __init__(self, target: Union[Target, str]): diff --git a/bitblas/base/arch/cpu.py b/bitblas/base/arch/cpu.py index 65592cc7d..09a6391f1 100644 --- a/bitblas/base/arch/cpu.py +++ b/bitblas/base/arch/cpu.py @@ -6,6 +6,10 @@ from .arch_base import TileDevice +def is_cpu_arch(arch: TileDevice) -> bool: + return isinstance(arch, CPU) + + # For LLVM Backend, we do not provide the detailed information of the CPU # As the LLVM backend do not required tuning, just maintain the consistency class CPU(TileDevice): diff --git a/bitblas/base/arch/cuda.py b/bitblas/base/arch/cuda.py index 29c65e4a4..5e8730d67 100644 --- a/bitblas/base/arch/cuda.py +++ b/bitblas/base/arch/cuda.py @@ -4,7 +4,7 @@ from bitblas import tvm from tvm.target import Target from .arch_base import TileDevice -from typing import List, Dict, Union +from typing import List, Union def check_sm_version(arch: str) -> int: @@ -12,17 +12,96 @@ def check_sm_version(arch: str) -> int: return int(sm_version) if sm_version.isdigit() else -1 +def is_cuda_arch(arch: TileDevice) -> bool: + return isinstance(arch, CUDA) + + +def is_volta_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version >= 70) + conditions.append(arch.sm_version < 80) + return all(conditions) + + +def is_ampere_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version >= 80 and arch.sm_version < 90) + return all(conditions) + + +def is_ada_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version == 89) + return all(conditions) + + +def is_hopper_arch(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version == 90) + return all(conditions) + + +def has_mma_support(arch: TileDevice) -> bool: + conditions = [True] + conditions.append(is_cuda_arch(arch)) + conditions.append(arch.sm_version >= 80) + return all(conditions) + + +volta_tensorcore_supported = [ + ("float16", "float32"), + ("float16", "float16"), +] +ampere_tensorcore_supported = [ + ("bfloat16", "float32"), + ("float16", "float32"), + ("float16", "float16"), + ("int8", "int32"), + ("int4", "int32"), + ("int2", "int32"), + ("int1", "int32"), +] +ada_tensorcore_supported = [ + ("bfloat16", "float32"), + ("float16", "float32"), + ("float16", "float16"), + ("int8", "int32"), + ("e5m2_float8", "float32"), + ("e4m3_float8", "float32"), +] +hopper_tensorcore_supported = ada_tensorcore_supported + + +# TODO(lei): we should consider the dtype of the input a and b +# instead of assuming both a and b share the same dtype. +# As the tensorcore may supports e4m3_float8 * e5m2_float8 +def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool: + + if is_volta_arch(arch): + return (in_dtype, accum_dtype) in volta_tensorcore_supported + elif is_ampere_arch(arch): + return (in_dtype, accum_dtype) in ampere_tensorcore_supported + elif is_ada_arch(arch): + return (in_dtype, accum_dtype) in ada_tensorcore_supported + elif is_hopper_arch(arch): + return (in_dtype, accum_dtype) in hopper_tensorcore_supported + else: + raise ValueError(f"Unsupported architecture: {arch}") + + class TensorInstruction(object): def __init__( self, name: str, - intrin_group: Dict, shape: List[int], ): self.name: str = name - self.intrin_group: Dict = intrin_group - # only maintain the shape of M and N + # only hold the shape of M and N self.shape: List[int] = shape @@ -58,13 +137,11 @@ def __init__(self, target: Union[Target, str]): self.available_tensor_instructions: List[TensorInstruction] = None def get_avaliable_tensorintrin_shapes(self): - from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group, get_mma_intrin_group - self.available_tensor_instructions = ( - TensorInstruction("mma", get_mma_intrin_group, [16, 16]), - TensorInstruction("wmma", get_wmma_intrin_group, [16, 16]), + TensorInstruction("mma", [16, 16]), + TensorInstruction("wmma", [16, 16]), ) return [t.shape for t in self.available_tensor_instructions] def __repr__(self): - return f"CUDA({self.target})" \ No newline at end of file + return f"CUDA({self.target})" diff --git a/bitblas/base/base_scheduler.py b/bitblas/base/base_scheduler.py index 37b75785a..d901a4192 100644 --- a/bitblas/base/base_scheduler.py +++ b/bitblas/base/base_scheduler.py @@ -113,6 +113,9 @@ def apply_config( ) -> PrimFunc: pass + def get_hint_type(self) -> str: + raise NotImplementedError("Get Hint type is not implemented") + def serialize_hints_to_configs(self, hints: List[Hint]) -> List[BaseTLHint]: # Convert Roller Hints to TileLang Hints raise NotImplementedError("Serialization of hints to configs is not implemented") diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index 722095657..ef2dc8587 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -328,9 +328,9 @@ def _score(node, thread): # small is better # TODO: This is a dummy mul which avoid reusing some shared memory. # Should be removed in the future. if td.smem_cost > (self.arch.smem_cap): - info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \ + debug_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \ " use dynamic shared memory." - logger.info(info_message) + logger.debug(debug_message) codegen_dict.shared_scope = "shared.dyn" codegen_dict.shared_scope = "shared.dyn" diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 3a5b6a2e8..92822d1bd 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -222,7 +222,7 @@ def apply_and_build_parallel(func, arch, num_repeats=3, max_workers=10, - timeout=30, + timeout=60, data_distribution="uniform") -> CompileResult: cpresults = [] diff --git a/bitblas/gpu/__init__.py b/bitblas/gpu/__init__.py index 72192f60e..750c42047 100644 --- a/bitblas/gpu/__init__.py +++ b/bitblas/gpu/__init__.py @@ -4,6 +4,7 @@ GPU-generic schedule rules. For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead """ + from .fallback import Fallback # noqa: F401 from .element_wise import ElementWise # noqa: F401 from .gemv import GEMV # noqa: F401 diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 1f596ef9a..d43d95ffa 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -18,6 +18,7 @@ ) from tvm.target.target import Target from tvm.tir.stmt_functor import pre_order_visit +from bitblas.base.arch import get_arch, is_tensorcore_supported_precision import logging logger = logging.getLogger(__name__) @@ -527,8 +528,6 @@ def get_tensorized_func_and_tags( skip_normalize: bool = False, allow_gemv: bool = False, ) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) """ transform function to matmul if necessary (e.g. transform conv2d with im2col) """ @@ -648,18 +647,9 @@ def check_last_trait(region: List[Range]): block_stmt = sch.get(main_block) if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: - # TODO(lei): we should consider the dtype of the input a and b - # instead of assuming both a and b share the same dtype. - # As the tensorcore may supports e4m3_float8 * e5m2_float8 in_dtype, out_dtype = get_in_out_dtypes(block_stmt) - try: - _ = get_mma_intrin_group( - a_dtype=in_dtype, - b_dtype=in_dtype, - out_dtype=out_dtype, - ) - except Exception: - logger.debug("Cannot find the corresponding mma intrin group") + if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)): + logger.debug("The input and output dtype is not supported by tensorcore") return func, None # reindex and transform functions @@ -697,7 +687,7 @@ def check_last_trait(region: List[Range]): def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + from bitblas.tl.mma_layout import ( # pylint: disable=import-outside-toplevel ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, ) diff --git a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py index cbd5a4e3f..190360c8f 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py @@ -24,6 +24,8 @@ class GemvFineGrainSIMTScheduler(MatmulSIMTBaseScheduler): class TLHint(BaseTLHint): + hint_type: str = "GemvFineGrainSIMTScheduler" + def __init__(self): super().__init__() @@ -56,6 +58,9 @@ def __repr__(self): f"reduce_thread: {self.reduce_thread}, " "}") + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index 0881c65f6..f81dd3d1d 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -146,8 +146,13 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: self.matmul_int4_fine_grain_scheduler, self.matmul_int4_weight_propagation_scheduler, ]: - if isinstance(hint, scheduler.TLHint): - return scheduler + try: + scheduler_hint_type = scheduler.get_hint_type() + if scheduler_hint_type == hint.hint_type: + return scheduler + except NotImplementedError as e: + raise ValueError(f"get_hint_type() is not implemented for {type(scheduler)}") from e + raise ValueError(f"Unsupported hint type: {type(hint)}") def with_default_config(self, arch: Optional[TileDevice] = None) -> PrimFunc: diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py index 782c9fdb7..9b60d547d 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py @@ -74,6 +74,8 @@ class MatmulFineGrainSIMTScheduler(MatmulSIMTBaseScheduler): class TLHint(BaseTLHint): + hint_type: str = "MatmulFineGrainSIMTScheduler" + def __init__(self): super().__init__() @@ -119,6 +121,9 @@ def __repr__(self): f"chunk: {self.chunk}" "}") + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 653ad894c..0d178e36a 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -59,6 +59,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): return self.serialize_hints_to_configs(roller_hints) def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + if arch is None: + arch = self.arch return self.get_roller_configs(arch, topk) # check if required shared memory cache @@ -85,6 +87,8 @@ class MatmulBlockScheduler(MatmulBaseScheduler): class TLHint(BaseTLHint): + hint_type = "MatmulBlockScheduler" + def __init__(self): super().__init__() @@ -161,6 +165,9 @@ def get_configs_sm80(self): configs = [{**c, 'num_stages': num_stages} for c in configs] return configs + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: @@ -274,6 +281,8 @@ class MatmulFineGrainScheduler(MatmulBaseScheduler): class TLHint(BaseTLHint): + hint_type: str = "MatmulFineGrainScheduler" + def __init__(self): super().__init__() @@ -332,6 +341,9 @@ def __repr__(self): f"enable_rasterization={self.enable_rasterization}" "}") + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: @@ -555,6 +567,9 @@ class MatmulWeightPropagationScheduler(MatmulFineGrainScheduler): # force set default weight transform kind to LDMatrixTransform weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform + class TLHint(MatmulFineGrainScheduler.TLHint): + hint_type: str = "MatmulWeightPropagationScheduler" + def apply_config( self, block_row_warps=2, @@ -573,6 +588,7 @@ def apply_config( trans_A, trans_B = self.trans_A, self.trans_B in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype with_bias = self.with_bias + input_transform_kind, weight_transform_kind = self.input_transform_kind, self.weight_transform_kind # Calculate the micro size per warp using a helper function micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) @@ -581,21 +597,28 @@ def apply_config( block_N = block_col_warps * warp_col_tiles block_K = chunk - # TODO(lei): Can be generalized to analyzed from bank size - pad_factor = 8 if in_dtype == "float16" else 16 - - can_swizzle_a = block_K * DataType(in_dtype).bits == 512 - apply_pad_a = not can_swizzle_a - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) # Define the shapes of matrices and shared memory buffers - A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) C_shape = (M, N) Bias_shape = (N,) - A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + is_a_smooth = self.is_a_smooth + is_b_smooth = self.is_b_smooth + + if is_a_smooth: + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + A_shared_shape = ( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ) + else: + A_shape = (M, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, @@ -634,7 +657,8 @@ def apply_config( warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=chunk, - transform_kind_b=self.weight_transform_kind, + transform_kind_a=input_transform_kind, + transform_kind_b=weight_transform_kind, ) cache_write_required = self.check_require_cache() @@ -663,8 +687,8 @@ def main( # Apply memory layout optimizations T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), + A_shared: make_swizzle_layout(A_shared, is_smooth=is_a_smooth), + B_shared: make_swizzle_layout(B_shared, is_smooth=is_b_smooth), }) T.use_swizzle(panel_size=10, enable=enable_rasterization) @@ -675,9 +699,17 @@ def main( # Main matrix multiplication pipeline with multiple stages for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - # Load A matrix into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + if is_a_smooth: + for i, k, ii, kk in T.Parallel( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ): + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), + ko * (block_K // micro_size_k), ii, kk] + else: + T.copy(A[by * block_M, ko * block_K], A_shared) # Load B matrix into shared memory for j, k, jj, kk in T.Parallel( @@ -753,17 +785,22 @@ def main( return self.post_process(main) - def __post_init__(self): - # Validate the matrix transpose settings - assert self.trans_A is False, "Currently only support Matrix A not transposed" - assert self.trans_B is True, "Currently only support Matrix B transposed" + @property + def is_a_smooth(self): + return self.input_transform_kind > TransformKind.NonTransform - return + @property + def is_b_smooth(self): + return self.weight_transform_kind > TransformKind.NonTransform @dataclass class MatmulINT4FineGrainScheduler(MatmulFineGrainScheduler): + @dataclass + class TLHint(MatmulFineGrainScheduler.TLHint): + hint_type: str = "MatmulINT4FineGrainScheduler" + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" M = self.M @@ -901,7 +938,7 @@ def main( # Apply memory layout optimizations T.annotate_layout({ A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared, is_smooth=True), + B_shared: make_swizzle_layout(B_shared), }) # Optional rasterization for L2 locality enhancement @@ -972,6 +1009,9 @@ def __post_init__(self): @dataclass class MatmulINT4WeightPropagationScheduler(MatmulWeightPropagationScheduler): + class TLHint(MatmulWeightPropagationScheduler.TLHint): + hint_type: str = "MatmulINT4WeightPropagationScheduler" + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" M = self.M @@ -1041,16 +1081,23 @@ def apply_config( block_N = block_col_warps * warp_col_tiles block_K = chunk - # TODO(lei): Can be generalized to analyzed from bank size - pad_factor = 8 if storage_dtype == "float16" else 16 + is_a_smooth = self.is_a_smooth + is_b_smooth = self.is_b_smooth - can_swizzle_a = block_K * DataType(storage_dtype).bits == 512 - apply_pad_a = not can_swizzle_a + if is_a_smooth: + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + A_shared_shape = ( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ) + else: + A_shape = (M, K) + A_shared_shape = (block_M, block_K) # Define the shapes of matrices and shared memory buffers - A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) - A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, @@ -1089,6 +1136,7 @@ def apply_config( warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=chunk, + transform_kind_a=self.input_transform_kind, transform_kind_b=self.weight_transform_kind, ) @@ -1116,8 +1164,8 @@ def main( # Apply memory layout optimizations T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - # B_shared: make_swizzle_layout(B_shared), + A_shared: make_swizzle_layout(A_shared, is_smooth=is_a_smooth), + B_shared: make_swizzle_layout(B_shared, is_smooth=is_b_smooth), }) T.use_swizzle(panel_size=10, enable=enable_rasterization) @@ -1128,9 +1176,17 @@ def main( # Main matrix multiplication pipeline with multiple stages for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - # Load A matrix into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + if is_a_smooth: + for i, k, ii, kk in T.Parallel( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ): + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), + ko * (block_K // micro_size_k), ii, kk] + else: + T.copy(A[by * block_M, ko * block_K], A_shared) # Load B matrix into shared memory for j, k, jj, kk in T.Parallel( @@ -1398,17 +1454,11 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( block_N = block_col_warps * warp_col_tiles block_K = chunk - # TODO(lei): Can be generalized to analyzed from bank size - pad_factor = 8 if in_dtype == "float16" else 16 - - can_swizzle_a = block_K * DataType(in_dtype).bits == 512 - apply_pad_a = not can_swizzle_a - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) - A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + A_shared_shape = (block_M, block_K) B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py index b14a236ae..447e7e47b 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py @@ -31,6 +31,8 @@ class MatmulFineGrainScheduler(MatmulBaseScheduler): class TLHint(BaseTLHint): + hint_type: str = "MatmulFineGrainScheduler" + def __init__(self): super().__init__() diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/base.py b/bitblas/ops/general_matmul/tilelang/dequantize/base.py index efe5627de..3d27f0703 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/base.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/base.py @@ -76,8 +76,6 @@ def __post_init__(self): # Validate the matrix transpose settings assert (self.trans_A is False), "Currently only support Matrix A not transposed" assert (self.trans_B is True), "Currently only support Matrix B transposed" - assert (self.input_transform_kind == TransformKind.NonTransform - ), "Currently only support NonTransform for input" # Legalize group_size if self.with_scaling and self.group_size == -1: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py index 37216ecfc..0d838661c 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py @@ -26,6 +26,8 @@ class GemvDequantizeSIMTScheduler(MatmulDequantizeSIMTBaseScheduler): class TLHint(BaseTLHint): + hint_type: str = "GemvDequantizeSIMTScheduler" + def __init__(self): super().__init__() @@ -58,6 +60,9 @@ def __repr__(self): f"reduce_thread: {self.reduce_thread}, " "}") + def get_hint_type(self): + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py index 5a690b970..9716ac075 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py @@ -153,10 +153,21 @@ def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler: def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: for scheduler in [ + self.gemv_dequantize_simt_scheduler, + self.matmul_dequantize_simt_scheduler, self.matmul_dequantize_block_scheduler, + self.matmul_dequantize_fine_grained_scheduler, + self.matmul_dequantize_weight_propagation_scheduler, + self.matmul_int4_dequantize_fine_grain_scheduler, + self.matmul_int4_dequantize_weight_propagation_scheduler, ]: - if isinstance(hint, scheduler.TLHint): - return scheduler + try: + scheduler_hint_type = scheduler.get_hint_type() + if scheduler_hint_type == hint.hint_type: + return scheduler + except NotImplementedError as e: + raise ValueError(f"get_hint_type() is not implemented for {type(scheduler)}") from e + raise ValueError(f"Unsupported hint type: {type(hint)}") def with_default_config(self, arch: Optional[TileDevice] = None) -> PrimFunc: @@ -242,5 +253,15 @@ def is_dynamic(self) -> bool: M, N, K = self.M, self.N, self.K return ((not isinstance(M, int)) or (not isinstance(N, int)) or (not isinstance(K, int))) + def __post_init__(self): + # Validate the matrix transpose settings + assert (self.trans_A is False), "Currently only support Matrix A not transposed" + assert (self.trans_B is True), "Currently only support Matrix B transposed" + + # Legalize group_size + if self.with_scaling and self.group_size == -1: + object.__setattr__(self, "group_size", self.K) + return + __all__ = ["MatmulDequantizeScheduler"] diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py index 54a0c54e7..4bdb26f6d 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py @@ -443,6 +443,8 @@ class MatmulDequantizeSIMTScheduler(MatmulDequantizeSIMTBaseScheduler): class TLHint(BaseTLHint): + hint_type = "MatmulDequantizeSIMTScheduler" + def __init__(self): super().__init__() @@ -488,6 +490,9 @@ def __repr__(self): f"chunk: {self.chunk}" "}") + def get_hint_type(self) -> str: + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py index 85188a4b6..9f0ac8165 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py @@ -10,8 +10,7 @@ from bitblas.base.utils import get_roller_hints_from_func from dataclasses import dataclass from bitblas.ops.general_matmul.tirscript import ( - matmul_dequantize_select_implementation, -) + matmul_dequantize_select_implementation,) from bitblas.base.operator_common import QuantizationMemoryStage from bitblas.tl.base_hint import BaseTLHint from bitblas.quantization import ( @@ -99,23 +98,17 @@ def naive_cast_dequant(x): return x.astype(in_dtype) if with_zeros and zeros_mode == "quantized": - dequant_func = _tir_packed_to_unsigned_convert_with_zeros( - storage_type, storage_nbit - ) + dequant_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit) elif source_format == "uint": if num_bits == 8: # 8 num_bits does not need to be compressed dequant_func = naive_cast_dequant else: - dequant_func = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit - ) + dequant_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) elif source_format == "int": if num_bits == 1: # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. - dequant_func = _tir_packed_int_to_int_convert( - storage_type, storage_nbit - ) + dequant_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit) elif num_bits == 8: # 8 num_bits does not need to be compressed dequant_func = naive_cast_dequant @@ -198,16 +191,15 @@ def _normal_dequant_impl( vi = index // stride_k vj = index % stride_k dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit - )( - num_bits, - qzeros_buffer[ - (k * stride_k + vj) // group_size, - (pid_n * stride_n + vi) // num_elems_per_byte, - ], - (pid_n * stride_n + vi) % num_elems_per_byte, - dtype=storage_dtype, - ) + storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + (k * stride_k + vj) // group_size, + (pid_n * stride_n + vi) // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) else: raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") @@ -230,19 +222,14 @@ def _normal_dequant_impl( compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) - * scale_local[v // group_size] - ) + ) * scale_local[v // group_size]) elif zeros_mode == "original": - dequant_weight_local[v] = ( - self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) - - zeros_local[v // group_size] - ) * scale_local[v // group_size] + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) - zeros_local[v // group_size]) * scale_local[v // group_size] elif zeros_mode == "rescale": dequant_weight_local[v] = ( self._decode_func( @@ -250,20 +237,15 @@ def _normal_dequant_impl( compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) - * scale_local[v // group_size] - - zeros_local[v // group_size] - ) + ) * scale_local[v // group_size] - zeros_local[v // group_size]) elif zeros_mode == "quantized": - dequant_weight_local[v] = ( - self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - zero=dequant_qzeros_local[v // group_size], - dtype=in_dtype, - ) - ) * scale_local[v // group_size] + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros_local[v // group_size], + dtype=in_dtype, + )) * scale_local[v // group_size] else: raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") @@ -343,16 +325,15 @@ def _normal_fast_dequant_impl( vi = index // stride_k vj = index % stride_k dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit - )( - num_bits, - qzeros_buffer[ - (k * stride_k + vj) // group_size, - (pid_n * stride_n + vi) // num_elems_per_byte, - ], - (pid_n * stride_n + vi) % num_elems_per_byte, - dtype=storage_dtype, - ) + storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + (k * stride_k + vj) // group_size, + (pid_n * stride_n + vi) // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) else: raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") @@ -440,24 +421,10 @@ def dequantize( threads, ) else: - return self._normal_dequant( - compressed_weight_local, - scale_local, - zeros_local, - dequant_qzeros_local, - dequant_weight_local, - scale_buffer, - zeros_buffer, - qzeros_buffer, - local_size, - pid_n, - tx, - k, - i, - stride_n, - stride_k, - threads - ) + return self._normal_dequant(compressed_weight_local, scale_local, zeros_local, + dequant_qzeros_local, dequant_weight_local, scale_buffer, + zeros_buffer, qzeros_buffer, local_size, pid_n, tx, k, i, + stride_n, stride_k, threads) @property def num_elems_per_byte(self): @@ -478,6 +445,7 @@ class MatmulDequantizeBlockScheduler(MatmulDequantizeBaseScheduler): enable_rasterization: bool = False # Enhance L2 Locality class TLHint(BaseTLHint): + hint_type: str = "MatmulDequantizeBlockScheduler" def __init__(self): super().__init__() @@ -521,16 +489,17 @@ def get_config_params(self): } def __repr__(self): - return ( - "{" - f"block_M={self.block_M}," - f"block_N={self.block_N}," - f"block_K={self.block_K}," - f"num_stages={self.num_stages}," - f"threads={self.threads}," - f"enable_rasterization={self.enable_rasterization}" - "}" - ) + return ("{" + f"block_M={self.block_M}," + f"block_N={self.block_N}," + f"block_K={self.block_K}," + f"num_stages={self.num_stages}," + f"threads={self.threads}," + f"enable_rasterization={self.enable_rasterization}" + "}") + + def get_hint_type(self) -> str: + return self.TLHint.hint_type def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] @@ -574,9 +543,7 @@ def apply_config( M = self.maybe_dynamic(self.M, "m") N, K = self.N, self.K - assert isinstance(N, int) and isinstance( - K, int - ), "Do not support dynamic N and K Currently" + assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" trans_A, trans_B = self.trans_A, self.trans_B @@ -648,18 +615,17 @@ def apply_config( @T.prim_func def general_shared_dequant_matmul( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - LUT: T.Buffer(LUT_shape, in_dtype), - Scale: T.Buffer(Scale_shape, in_dtype), - Qzeros: T.Buffer(Qzeros_shape, storage_dtype), - Zeros: T.Buffer(Zeros_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), - Bias: T.Buffer(Bias_shape, in_dtype), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + LUT: T.Buffer(LUT_shape, in_dtype), + Scale: T.Buffer(Scale_shape, in_dtype), + Qzeros: T.Buffer(Qzeros_shape, storage_dtype), + Zeros: T.Buffer(Zeros_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), + Bias: T.Buffer(Bias_shape, in_dtype), ): with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads - ) as (bx, by): + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_local([local_size_compressed], storage_dtype) @@ -667,9 +633,7 @@ def general_shared_dequant_matmul( zeros_local = T.alloc_local([local_zeros_size], in_dtype) dequant_qzeros_local = T.alloc_local([local_qzeros_size], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) - B_dequantize_shared = T.alloc_shared( - B_dequantize_shared_shape, in_dtype - ) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) C_shared = T.alloc_shared([block_M, block_N], out_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -685,18 +649,12 @@ def general_shared_dequant_matmul( T.copy(A[by * block_M, ko * block_K], A_shared) T.copy(B[bx * block_N, ko * block_K // num_elems_per_byte], B_shared) - for i in T.serial( - block_N - * block_K - // num_elems_per_byte - // (threads * local_size_compressed) - ): + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): index = ( - i * threads * local_size_compressed - + tx * local_size_compressed - + v - ) + i * threads * local_size_compressed + tx * local_size_compressed + + v) vi = index // (block_K // num_elems_per_byte) vj = index % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py index f2f462926..a0242f99b 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py @@ -21,14 +21,6 @@ MatmulDequantizeBaseScheduler, # noqa: F401 ) from bitblas.tl.base_hint import BaseTLHint -from bitblas.quantization import ( - _tir_packed_int_to_int_convert, - _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, - _tir_packed_to_fp4_to_f16, - _tir_u8_to_f8_e4m3_to_f16, - _tir_packed_to_unsigned_convert_with_zeros, -) # GPU warp configuration for NVIDIA GPUs warp_size = 32 @@ -50,6 +42,8 @@ class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler): class TLHint(BaseTLHint): + hint_type: str = "MatmulDequantizeFineGrainedScheduler" + def __init__(self): super().__init__() @@ -108,6 +102,9 @@ def __repr__(self): f"enable_rasterization={self.enable_rasterization}" "}") + def get_hint_type(self) -> str: + return self.TLHint.hint_type + def serialize_hints_to_configs(self, hints: List[Hint]): configs = [] for hint in hints: @@ -403,6 +400,9 @@ def general_dequant_matmul( @dataclass class MatmulINT4DequantizeFineGrainedScheduler(MatmulDequantizeFineGrainedScheduler): + class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): + hint_type: str = "MatmulINT4DequantizeFineGrainedScheduler" + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" M = self.M diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py index c7c452821..453a76924 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py @@ -38,6 +38,9 @@ class MatmulDequantizeWeightPropagationScheduler(MatmulDequantizeFineGrainedSche # force set default weight transform kind to LDMatrixTransform weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform + class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): + hint_type: str = "MatmulDequantizeWeightPropagationScheduler" + def apply_config( self, block_row_warps: Optional[int] = None, @@ -59,12 +62,14 @@ def apply_config( N, K = self.N, self.K assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" trans_A, trans_B = self.trans_A, self.trans_B + input_transform_kind = self.input_transform_kind weight_transform_kind = self.weight_transform_kind assert trans_A is False, "Dequantize only implement for trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" - assert (weight_transform_kind == TransformKind.LDMatrixTransform - ), f"Dequantize only implement for LDMatrixTransform currently, got {weight_transform_kind}" + assert ( + weight_transform_kind == TransformKind.LDMatrixTransform + ), f"Dequantize only implement for LDMatrixTransform currently, got {weight_transform_kind}" in_dtype, out_dtype, accum_dtype = ( self.in_dtype, @@ -101,7 +106,6 @@ def apply_config( if group_size == -1: group_size = K - A_shape = (M, K) B_shape = ( N // micro_size_y, K // micro_size_k, @@ -115,7 +119,21 @@ def apply_config( C_shape = (M, N) Bias_shape = (N,) - A_shared_shape = (block_M, block_K) + is_a_smooth = self.is_a_smooth + is_b_smooth = self.is_b_smooth + + if is_a_smooth: + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + A_shared_shape = ( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ) + else: + A_shape = (M, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, @@ -164,6 +182,7 @@ def apply_config( warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=chunk, + transform_kind_a=input_transform_kind, transform_kind_b=weight_transform_kind, num_elems_per_byte=num_elems_per_byte, ) @@ -200,7 +219,8 @@ def general_dequant_matmul( tx = T.thread_binding(0, threads, thread="threadIdx.x") T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), + A_shared: make_swizzle_layout(A_shared, is_a_smooth), + B_shared: make_swizzle_layout(B_shared, is_b_smooth), }) T.use_swizzle(10, enable=enable_rasterization) @@ -211,7 +231,17 @@ def general_dequant_matmul( for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, ko * block_K], A_shared) + if is_a_smooth: + for i, k, ii, kk in T.Parallel( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ): + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), + ko * (block_K // micro_size_k), ii, kk] + else: + T.copy(A[by * block_M, ko * block_K], A_shared) for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * vec_load_qb)): @@ -530,19 +560,13 @@ def _normal_fast_dequant_impl( else: T.call_extern( func_name, - T.address_of( - compressed_weight_local[ - j * local_size // num_elems_per_byte - ] - ), + T.address_of(compressed_weight_local[j * local_size // num_elems_per_byte]), T.address_of(dequant_weight_local[j * local_size]), T.address_of(scale_buffer[remapped_i, remapped_j]), - T.address_of( - qzeros_buffer[ - qzeros_remapped_i, - (qzeros_remapped_j // num_elems_per_byte), - ] - ), + T.address_of(qzeros_buffer[ + qzeros_remapped_i, + (qzeros_remapped_j // num_elems_per_byte), + ]), local_size * grouped_k, local_size // num_elems_per_byte, qzeros_remapped_j % num_elems_per_byte, @@ -593,10 +617,21 @@ def get_param_indices( return new_indices + @property + def is_a_smooth(self): + return self.input_transform_kind > TransformKind.NonTransform + + @property + def is_b_smooth(self): + return self.weight_transform_kind > TransformKind.NonTransform + @dataclass class MatmulINT4DequantizeWeightPropagationScheduler(MatmulDequantizeWeightPropagationScheduler): + class TLHint(MatmulDequantizeWeightPropagationScheduler.TLHint): + hint_type: str = "MatmulINT4DequantizeWeightPropagationScheduler" + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" M = self.M @@ -678,8 +713,9 @@ def apply_config( assert trans_A is False, "Dequantize only implement for trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" - assert (weight_transform_kind == TransformKind.LDMatrixTransform - ), f"Dequantize only implement for LDMatrixTransform currently, got {weight_transform_kind}" + assert ( + weight_transform_kind == TransformKind.LDMatrixTransform + ), f"Dequantize only implement for LDMatrixTransform currently, got {weight_transform_kind}" in_dtype, out_dtype, accum_dtype = ( self.in_dtype, @@ -718,7 +754,21 @@ def apply_config( if group_size == -1: group_size = K - A_shape = (M, K) + is_a_smooth = self.is_a_smooth + is_b_smooth = self.is_b_smooth + + if is_a_smooth: + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + A_shared_shape = ( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ) + else: + A_shape = (M, K) + A_shared_shape = (block_M, block_K) + B_shape = ( N // micro_size_y, K // micro_size_k, @@ -731,7 +781,6 @@ def apply_config( micro_size_y, micro_size_k, ) - A_shared_shape = (block_M, block_K) B_shared_shape = ( block_N // micro_size_y, block_K // micro_size_k, @@ -810,7 +859,8 @@ def general_dequant_matmul( tx = T.thread_binding(0, threads, thread="threadIdx.x") T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), + A_shared: make_swizzle_layout(A_shared, is_a_smooth), + B_shared: make_swizzle_layout(B_shared, is_b_smooth), }) T.use_swizzle(10, enable=enable_rasterization) @@ -821,7 +871,17 @@ def general_dequant_matmul( for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, ko * block_K], A_shared) + if is_a_smooth: + for i, k, ii, kk in T.Parallel( + block_M // micro_size_x, + block_K // micro_size_k, + micro_size_x, + micro_size_k, + ): + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x), + ko * (block_K // micro_size_k), ii, kk] + else: + T.copy(A[by * block_M, ko * block_K], A_shared) # Load B into shared memory # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index efba5773b..8df3b2a17 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -383,7 +383,7 @@ def hardware_aware_finetune( def get_profile_tensors(self, dynamic_symbolic_constraints: Optional[Dict] = None): if dynamic_symbolic_constraints is None: dynamic_symbolic_constraints = {} - func = self.prim_func + func = self.prim_func or retrieve_func_from_module(self.scheduled_ir_module) device = self.arch.device def var_warpper(v): diff --git a/bitblas/relax/transform/apply_fast_tuning.py b/bitblas/relax/transform/apply_fast_tuning.py index 035c93d0d..00ccf67f3 100644 --- a/bitblas/relax/transform/apply_fast_tuning.py +++ b/bitblas/relax/transform/apply_fast_tuning.py @@ -137,10 +137,12 @@ def transform_module( # pylint: disable=missing-function-docstring updated_functions[g_var] = sch.mod["main"].with_attr("tir.is_scheduled", 1) continue - if check_func_with_dynamic(func): + specalized_function = func.with_attr("global_symbol", g_var.name_hint) + + if check_func_with_dynamic(specalized_function): dispatch_mod = fast_tune_with_dynamic_range( - func, + specalized_function, target=target, topk=self.topk, parallel_build=self.parallel_build, @@ -161,7 +163,7 @@ def transform_module( # pylint: disable=missing-function-docstring else: # otherwise is static shape analysis _, best = fast_tune( - func, + specalized_function, target=target, topk=self.topk, parallel_build=self.parallel_build, diff --git a/bitblas/tl/base_hint.py b/bitblas/tl/base_hint.py index 24c9bba2f..752ea4145 100644 --- a/bitblas/tl/base_hint.py +++ b/bitblas/tl/base_hint.py @@ -8,6 +8,9 @@ # Base class for Tensor Layout Hints that defines the interface and common functionality for derived classes. class BaseTLHint(ABC): + # hint identifier + hint_type: str = "base" + # Constructor for the BaseTLHint class, takes variable arguments (*args and **kwargs) to allow flexibility. def __init__(self, *args, **kwargs): # Calls the superclass constructor (useful in complex inheritance hierarchies). diff --git a/bitblas/tl/mma_layout.py b/bitblas/tl/mma_layout.py index 5dab4ba64..443da90eb 100644 --- a/bitblas/tl/mma_layout.py +++ b/bitblas/tl/mma_layout.py @@ -30,6 +30,18 @@ def ldmatrix_16x32_to_shared_16x32_layout_b(thread_id, local_id): return row, col +def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id): + row = thread_id % 16 + col = local_id + (thread_id // 16) * 16 + return row, col + + +def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): + row = (thread_id // 16) * 8 + (thread_id % 8) + col = local_id + 16 * ((thread_id % 16) // 8) + return row, col + + def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): row = 8 * (local_id % 4 // 2) + (thread_id // 4) col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) diff --git a/bitblas/tl/mma_macro_generator.py b/bitblas/tl/mma_macro_generator.py index 8e238e0f4..f28233911 100644 --- a/bitblas/tl/mma_macro_generator.py +++ b/bitblas/tl/mma_macro_generator.py @@ -397,10 +397,112 @@ def _initialize_transform_kind(self, transform_kind_a, transform_kind_b): else: raise ValueError("Unsupported transform_kind_b") - if self.transform_kind_a != TransformKind.NonTransform: - raise ValueError("TransformKind A is not supported yet") + assert transform_kind_a in [0, 1, 2, 3], "Input transform stage should be 0, 1, 2, or 3" + assert transform_kind_b in [0, 1, 2, 3], "Weight transform stage should be 0, 1, 2, or 3" - assert transform_kind_b in [0, 3], "Currently only support 0 and 3" + def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0): + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + local_size_a = self.local_size_a + a_dtype = self.a_dtype + a_transposed = self.a_transposed + transform_kind_a = self.transform_kind_a + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_bindings, + rk=0, + ): + stride = A_shared_buf.shape[-1] + tx, _, warp_m = self.extract_thread_binding(thread_bindings) + if transform_kind_a == TransformKind.NonTransform: + for i in T.serial(warp_rows): + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_buf[ + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ]), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + elif transform_kind_a == TransformKind.InterWarpTransform: + for i in T.serial(warp_rows): + # Assign B_shared_elem + ri, rj = ( + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_x, + (rj) // micro_size_k, + (ri) % micro_size_x, + (rj) % micro_size_k, + ) + args = (ni, nj, nii, njj) if transform_kind_a > 0 else (ri, rj) + A_shared_elem = A_shared_buf[args] + + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_elem), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), + ) + elif transform_kind_a == TransformKind.IntraWarpTransform: + for i in T.serial(warp_rows): + # Assign B_shared_elem + ri, rj = ( + warp_m * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_x, + (rj) // micro_size_k, + (ri) % micro_size_x, + (rj) % micro_size_k, + ) + A_shared_elem = A_shared_buf[ni, nj, nii, njj] + + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of(A_shared_elem), + tx * local_size_a, + ) + elif transform_kind_a == TransformKind.LDMatrixTransform: + for j in T.serial(warp_rows): + for local_id in T.vectorized(local_size_a): + # Assign A_shared_elem + ri, rj = ( + warp_m * warp_rows + j, + rk * (chunk // micro_size_k) + ki, + ) + rii, rjj = (tx * local_size_a + + local_id) // micro_size_k, (tx * local_size_a + local_id) % ( + micro_size_k) + A_local_buf[j * local_size_a + local_id] = (A_shared_buf[ri, rj, rii, rjj]) + else: + raise ValueError("Unsupported TransformKind for Input A") + + return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk) def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): warp_col_tiles = self.warp_col_tiles @@ -425,7 +527,26 @@ def _warp_ldmatrix_b( stride = B_shared_buf.shape[-1] tx, warp_n, _ = self.extract_thread_binding(thread_bindings) - if transform_kind_b < TransformKind.LDMatrixTransform: + if transform_kind_b == TransformKind.NonTransform: + for j in T.serial(warp_cols): + # Assign B_shared_elem + ri, rj = ( + warp_n * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + B_shared_elem = B_shared_buf[ri, rj] + + T.ptx_ldmatrix( + b_dtype, + T.bool(False), + 4, + ".b16", + B_local_buf.data, + j * local_size_b, + T.address_of(B_shared_elem), + get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), + ) + elif transform_kind_b == TransformKind.InterWarpTransform: for j in T.serial(warp_cols): # Assign B_shared_elem ri, rj = ( @@ -438,8 +559,7 @@ def _warp_ldmatrix_b( (ri) % micro_size_y, (rj) % micro_size_k, ) - args = (ni, nj, nii, njj) if transform_kind_b > 0 else (ri, rj) - B_shared_elem = B_shared_buf[args] + B_shared_elem = B_shared_buf[ni, nj, nii, njj] T.ptx_ldmatrix( b_dtype, @@ -451,7 +571,32 @@ def _warp_ldmatrix_b( T.address_of(B_shared_elem), get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), ) - else: + elif transform_kind_b == TransformKind.IntraWarpTransform: + for j in T.serial(warp_cols): + # Assign B_shared_elem + ri, rj = ( + warp_n * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_y, + (rj) // micro_size_k, + (ri) % micro_size_y, + (rj) % micro_size_k, + ) + B_shared_elem = B_shared_buf[ni, nj, nii, njj] + + T.ptx_ldmatrix( + b_dtype, + T.bool(False), # TODO(lei): should be optimized + 4, + ".b16", + B_local_buf.data, + j * local_size_b, + T.address_of(B_shared_elem), + tx * local_size_b, + ) + elif transform_kind_b == TransformKind.LDMatrixTransform: local_size_dequantize = local_size_b // num_elems_per_byte for j in T.serial(warp_cols): for local_id in T.vectorized(local_size_dequantize): @@ -466,6 +611,8 @@ def _warp_ldmatrix_b( micro_size_k // num_elems_per_byte) B_local_buf[j * local_size_dequantize + local_id] = ( B_shared_buf[ri, rj, rii, rjj]) + else: + raise ValueError("Unsupported TransformKind for Input B") return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk) diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index e58aaf500..a275e91f6 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -72,7 +72,7 @@ def apply_and_build_parallel(scheduler, arch, num_repeats=3, max_workers=10, - timeout=30, + timeout=60, data_distribution="uniform") -> CompileResult: cpresults = [] diff --git a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py index 660107647..38f49786e 100644 --- a/testing/python/operators/test_general_matmul_tile_schedule.py +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -11,7 +11,6 @@ from bitblas import set_log_level import numpy as np -print("bitblas. path is ", bitblas.__path__) np.random.seed(0) set_log_level(logging.DEBUG) diff --git a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py index 821e3aa25..d343d8078 100644 --- a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py @@ -885,6 +885,208 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) +@simplify_prim_func +def tl_matmul_with_ladder_input_weight_transform( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + transform_a, + transform_b, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + assert transform_a > 0, "transform_a should be greater than 0" + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + + warp_rows = 1 + warp_cols = 1 + warp_row_tiles = micro_size_x * warp_rows + warp_col_tiles = micro_size_y * warp_cols + + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = (block_M // micro_size_x, block_K // micro_size_k, micro_size_x, micro_size_k) + B_shared_shape = (block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_a=transform_a, + transform_kind_b=transform_b) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k, ii, kk in T.Parallel(block_M // micro_size_x, block_K // micro_size_k, + micro_size_x, micro_size_k): + A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i, + ko * (block_K // micro_size_k) + k, ii, kk] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.mma(A_local, B_local, C_local) + + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, + bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, + i % micro_size_x, j % micro_size_y] + + return main + + +def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_dtype, out_dtype, + accum_dtype, transform_a, + transform_b): + matmul = tl_matmul_with_ladder_input_weight_transform(M, N, K, in_dtype, out_dtype, accum_dtype, + transform_a, transform_b) + + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + ladder_permutate_config_A = bitblas.ops.LadderPermutateConfig( + M=M, + N=K, + datatype=in_dtype, + storage_dtype=in_dtype, + propagate_kind="A", + transpose_matrix=False, + transform_kind=transform_a, + ) + + ladder_permutate_a = bitblas.ops.LadderPermutate(ladder_permutate_config_A) + + ladder_permutate_config_B = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + datatype=in_dtype, + storage_dtype=in_dtype, + propagate_kind="B", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate_b = bitblas.ops.LadderPermutate(ladder_permutate_config_B) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + LA = ladder_permutate_a(A.cpu()).cuda() + LB = ladder_permutate_b(B.cpu()).cuda() + + mod(LA, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + # print("Ref C: ", ref_c) + # print("C: ", C) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + def test_assert_tl_matmul(): assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") @@ -908,5 +1110,10 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): 256, 1024, 512, "float16", "float16", "float16", 3) +def test_assert_tl_matmul_with_ladder_input_weight_transform(): + assert_tl_matmul_with_ladder_input_weight_transform_correctness(256, 256, 256, "float16", + "float16", "float16", 2, 3) + + if __name__ == "__main__": bitblas.testing.main()