diff --git a/bitblas/base/common_schedules.py b/bitblas/base/common_schedules.py index 7d528c70a..d4ac7346b 100644 --- a/bitblas/base/common_schedules.py +++ b/bitblas/base/common_schedules.py @@ -22,7 +22,7 @@ from typing import Callable, List from tvm import tir - +from bitblas.utils import retrieve_func_from_module from .analysis import BlockInfo @@ -74,7 +74,7 @@ def get_output_blocks( """ # collect arguments buffer - func = sch.mod["main"] + func = retrieve_func_from_module(sch.mod) args = list(func.buffer_map.values()) output_blocks = [] diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 60560120e..e6e427b5a 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -21,7 +21,12 @@ import tempfile import itertools from tvm.ir.supply import GlobalVarSupply -from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from bitblas.utils import ( + tensor_replace_dp4a, + tensor_remove_make_int4, + tensor_remove_make_int2, + retrieve_func_from_module, +) from bitblas.utils.tensor_adapter import ( np_float2np_bf16,) import logging @@ -58,18 +63,28 @@ def __init__(self, config, sch, mod: Module): self.time_evaluator = None def profile(self, data_distribution="uniform"): - func = self.sch.mod["main"] + func = retrieve_func_from_module(self.sch.mod) device = self.config.arch.device profile_tensors = get_dummy_input_arrays(func, device, distribution=data_distribution) latency = self.time_evaluator(*profile_tensors).mean * 1e3 return latency -def get_roller_hints_from_func(func: tir.PrimFunc, +def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule], arch: TileDevice, topk: int = 10, tensorcore_only: bool = False, allow_gemv: bool = False) -> Optional[List[Hint]]: + func = None + if isinstance(func_or_module, tir.PrimFunc): + func = func_or_module + elif isinstance(func_or_module, IRModule): + func = retrieve_func_from_module(func_or_module) + else: + raise ValueError("Not supported type: ", type(func_or_module)) + + assert func is not None, "The function should not be None" + if tensorcore_only: try: tensorized_func, tags = get_tensorized_func_and_tags( diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 351dd3739..1f596ef9a 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -431,9 +431,8 @@ def is_dequantize(block: BlockRV) -> bool: has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) if not has_uint_input: return False - if len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype): - return False - return True + return not (len(block_stmt.writes) != 1 or + "float" not in str(block_stmt.writes[0].buffer.dtype)) dequantize_blocks = [block for block in blocks if is_dequantize(block)] return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None @@ -552,9 +551,7 @@ def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool: len( collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) > 0) - if not all(conditions): - return False - return True + return all(conditions) # step2. transform function to tensorcore matmul (e.g. conv2d with im2col) def check_sm_version(arch: str) -> int: @@ -677,14 +674,20 @@ def check_last_trait(region: List[Range]): block_stmt = sch.get(main_block) # 16 for 16 bits tensor core while 32 for 8bits tensorcore. - minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32 + minimal_tensorize_spatial_threshold = 16 + minimal_tensorize_reduce_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32 # the batch dimension is not taken into consideration. - extent = block_stmt.iter_vars[1].dom.extent - if isinstance(extent, tir.expr.IntImm) and (extent.value < (1 if allow_gemv else - minimal_tensorize_threshold)): - return func, None - for item_var in block_stmt.iter_vars[2:]: + for item_var in block_stmt.iter_vars[1:]: extent = item_var.dom.extent + iter_type = item_var.iter_type + + if iter_type is IterVar.DataPar: + minimal_tensorize_threshold = minimal_tensorize_spatial_threshold + elif iter_type is IterVar.CommReduce: + minimal_tensorize_threshold = minimal_tensorize_reduce_threshold + else: + raise ValueError(f"Unknown IterVar type {iter_type}") + if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold): return func, None tags = analysis_tensorcore_tags(sch, main_block, target) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 7833865b8..e0d40e6bd 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -141,7 +141,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): ) roller_hints = get_roller_hints_from_func( - ir_module["main"], + ir_module, arch, topk, tensorcore_only=True, @@ -340,7 +340,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): ) roller_hints = get_roller_hints_from_func( - ir_module["main"], + ir_module, arch, topk, tensorcore_only=True, diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py index efdfd58ea..8e55b0231 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py @@ -32,11 +32,17 @@ class MatmulINT4FineGrainScheduler(MatmulFineGrainScheduler): def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + M = self.M K = self.K // 2 # 2xint4 should be packed into one single int8 # Simple TIR Compute Expression storage_dtype = "int8" + + # This is a hack to utilize tensor core + if isinstance(M, int) and M < 16: + M = 16 + ir_module = matmul_select_implementation( - M=self.M, + M=M, N=self.N, K=K, in_dtype=storage_dtype, @@ -46,7 +52,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): ) roller_hints = get_roller_hints_from_func( - ir_module["main"], + ir_module, arch, topk, tensorcore_only=True, @@ -230,6 +236,47 @@ def __post_init__(self): @dataclass class MatmulINT4WeightPropagationScheduler(MatmulWeightPropagationScheduler): + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + M = self.M + K = self.K // 2 # 2xint4 should be packed into one single int8 + # Simple TIR Compute Expression + storage_dtype = "int8" + + # This is a hack to utilize tensor core + if isinstance(M, int) and M < 16: + M = 16 + + ir_module = matmul_select_implementation( + M=M, + N=self.N, + K=K, + in_dtype=storage_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + propagate_b=self.weight_transform_kind) + + roller_hints = get_roller_hints_from_func( + ir_module, + arch, + topk, + tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + def serialze_hints_to_configs(hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs + + return serialze_hints_to_configs(roller_hints) + def apply_config( self, block_row_warps=2, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py index 4d45e1204..036ace634 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py @@ -136,7 +136,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): ) roller_hints = get_roller_hints_from_func( - ir_module["main"], + ir_module, arch, topk, tensorcore_only=True, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py index 942a66a90..060d52b1e 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py @@ -154,7 +154,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): zeros_mode=self.zeros_mode) roller_hints = get_roller_hints_from_func( - ir_module["main"], + ir_module, arch, topk, tensorcore_only=True, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py index e7fb80d24..100ab0a31 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py @@ -32,9 +32,15 @@ class MatmulINT4DequantizeFineGrainedScheduler(MatmulDequantizeFineGrainedSchedu def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + M = self.M K = self.K // 2 # 2xint4 should be packed into one single int8 storage_dtype = "int8" num_bits = self.num_bits * 2 + + # This is a hack to utilize tensor core + if isinstance(M, int) and M < 16: + M = 16 + # INT4XINT2 is equal to int8xint4 with reduced shape # Simple TIR Compute Expression ir_module = matmul_dequantize_select_implementation( @@ -56,7 +62,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): zeros_mode=self.zeros_mode) roller_hints = get_roller_hints_from_func( - ir_module["main"], + ir_module, arch, topk, tensorcore_only=True, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py index 153e1f64a..fef550bd7 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py @@ -3,18 +3,23 @@ from bitblas import tvm as tvm from tvm import DataType import tvm.tl.language as T -from typing import Optional +from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, # noqa: F401 make_swizzle_layout, # noqa: F401 index_to_coordinates, # noqa: F401 ) +from bitblas.base.arch import TileDevice +from bitblas.base.roller.hint import Hint from bitblas.tl.macro_generator import ( INT4TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 ) from bitblas.ops.common import TransformKind # noqa: F401 from dataclasses import dataclass +from bitblas.base.utils import get_roller_hints_from_func from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group +from bitblas.ops.general_matmul.tirscript import ( + matmul_dequantize_select_implementation,) from bitblas.ops.general_matmul.tilelang.dequantize.ladder_weight_transform_tensorcore import ( MatmulDequantizeWeightPropagationScheduler,) @@ -25,6 +30,60 @@ @dataclass class MatmulINT4DequantizeWeightPropagationScheduler(MatmulDequantizeWeightPropagationScheduler): + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + M = self.M + K = self.K // 2 # 2xint4 should be packed into one single int8 + storage_dtype = "int8" + num_bits = self.num_bits * 2 + + # This is a hack to utilize tensor core + if isinstance(M, int) and M < 16: + M = 16 + + # INT4XINT2 is equal to int8xint4 with reduced shape + # Simple TIR Compute Expression + ir_module = matmul_dequantize_select_implementation( + M=M, + N=self.N, + K=K, + in_dtype=storage_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + bit=num_bits, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + zeros_mode=self.zeros_mode) + + roller_hints = get_roller_hints_from_func( + ir_module, + arch, + topk, + tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + for hint in roller_hints: + print(hint) + + def serialze_hints_to_configs(hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs + + return serialze_hints_to_configs(roller_hints) + def apply_config( self, block_row_warps: Optional[int] = None, diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 1b22491f5..94ea042cd 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -21,6 +21,7 @@ from bitblas.builder.wrapper import TIRWrapper, TLWrapper from bitblas.builder.lib_generator import LibraryGenerator from bitblas.common import MAX_ERROR_MESSAGE_LENGTH +from bitblas.utils import retrieve_func_from_module from dataclasses import dataclass import logging import re @@ -317,6 +318,7 @@ def apply_fast_tuning( elif self.is_tilelang_backend(): # Finetune the schedule tuning_configs = self.get_tl_tuning_config(topk=topk) + assert len(tuning_configs) > 0, "No tuning config found for this operator." _, best = tl_apply_and_build( func_or_scheduler, tuning_configs, arch=self.arch, parallel_build=parallel_build) # Return the best Config as Hint @@ -368,11 +370,9 @@ def hardware_aware_finetune( assert ( len(scheduled_mod.get_global_vars()) == 1 ), "The optimized module should only have one global variable for default schedule." - assert ( - "main" in scheduled_mod - ), "The optimized module should have a function named 'main' for default schedule." default_kernal_name = self.kernel_name_generator.generate(best_hint) - func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) + func = retrieve_func_from_module(scheduled_mod).with_attr("global_symbol", + default_kernal_name) scheduled_ir_module = tvm.IRModule({default_kernal_name: func}) self._update_optimized_mod(scheduled_ir_module) @@ -465,9 +465,6 @@ def forward(self, *args): def __call__(self, *args: Any) -> Any: return self.forward(*args) - def update_func(self, func: PrimFunc): - self.ir_module["main"] = func - def update_runtime_module(self, rt_mod=None, srcpath=None, libpath=None): if rt_mod is not None: self.rt_mod = rt_mod diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index 6747d0632..cdbc74a7d 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -14,9 +14,13 @@ from bitblas.base.utils import get_dummy_input_arrays from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from bitblas.utils import ( + tensor_replace_dp4a, + tensor_remove_make_int4, + tensor_remove_make_int2, + retrieve_func_from_module, +) from bitblas.common import MAX_ERROR_MESSAGE_LENGTH - import logging import tempfile @@ -52,7 +56,7 @@ def __init__(self, config, sch, mod: Module): self.time_evaluator = None def profile(self, data_distribution="uniform"): - func = self.sch.mod["main"] + func = retrieve_func_from_module(self.sch.mod) device = self.config.arch.device profile_tensors = get_dummy_input_arrays(func, device, distribution=data_distribution) latency = self.time_evaluator(*profile_tensors).mean * 1e3 diff --git a/bitblas/utils/__init__.py b/bitblas/utils/__init__.py index 2ba3cd5f5..e94b178aa 100644 --- a/bitblas/utils/__init__.py +++ b/bitblas/utils/__init__.py @@ -8,6 +8,8 @@ import subprocess from bitblas.common import BITBLAS_DEFAULT_CACHE_PATH +from tvm import IRModule +from tvm.tir import PrimFunc def get_commit_id(): @@ -21,3 +23,12 @@ def get_commit_id(): def get_default_cache_path(): return BITBLAS_DEFAULT_CACHE_PATH + + +def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: + if not isinstance(ir_module, IRModule): + raise ValueError("Not supported type: ", type(ir_module)) + assert len(ir_module.get_global_vars()) == 1, ( + "The optimized module should only have one global variable for default schedule.") + func = list(ir_module.functions.values())[0] + return func