diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index d2168c850..90fab86d0 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -6,7 +6,7 @@ from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind from concurrent.futures import ThreadPoolExecutor, as_completed import numpy as np -from typing import List, Tuple, Optional, Dict, Union, Literal +from typing import List, Tuple, Optional, Dict, Union, Literal, Callable from tvm import tir, IRModule from tvm.runtime import Module from tvm.tir import Schedule @@ -455,13 +455,13 @@ def create_dispatch_func(g_var: str, func: tir.PrimFunc, refactored_funcs: List[ def create_dispatch_mod(g_var: str, original_func: tir.PrimFunc, - specialized_funcs: List[tir.PrimFunc]) -> IRModule: + specialized_funcs: List[tir.PrimFunc], function_symbols) -> IRModule: dispatch_mod: IRModule = tvm.IRModule() g_var_supply = GlobalVarSupply(dispatch_mod) refactored_funcs = [] - for func in specialized_funcs: + for f_var, func in zip(function_symbols, specialized_funcs): params, buffers_to_declare = collect_buffers_to_declare(func) - global_symbol, device_func = refactor_specialized_func(g_var, func, params, + global_symbol, device_func = refactor_specialized_func(f_var, func, params, buffers_to_declare) global_symbol = g_var_supply.fresh_global(global_symbol, add_prefix=False) dispatch_mod[global_symbol] = device_func @@ -478,6 +478,7 @@ def fast_tune_with_dynamic_range( parallel_build: bool = True, global_symbol: Optional[str] = None, dynamic_range: Optional[Dict[str, List[int]]] = None, + kernel_name_generator: Optional[Callable] = None, ) -> IRModule: if dynamic_range is None: dynamic_range = {} @@ -517,12 +518,30 @@ def fast_tune_with_dynamic_range( # Convert the Cartesian product to a list of dictionaries specialize_items: List[Dict] = [dict(zip(opt_shapes.keys(), values)) for values in product_list] + function_symbols: List[str] = [] specilized_tuned_funcs: List[tir.PrimFunc] = [] for item in specialize_items: func = func.with_attr("opt_shapes", item) _, best = fast_tune(func, target, topk, parallel_build) if best is None: return None - specilized_tuned_funcs.append(best.sch.mod["main"]) - - return create_dispatch_mod(global_symbol, func, specilized_tuned_funcs) + specialized_func = best.sch.mod["main"] + function_symbol = global_symbol + if kernel_name_generator is not None: + scheduled_mod = best.sch.mod + best_hint = best.config + assert len(scheduled_mod.get_global_vars()) == 1, ( + "The optimized module should only have one global variable for default schedule.") + assert "main" in scheduled_mod, ( + "The optimized module should have a function named 'main' for default schedule.") + default_kernal_name = kernel_name_generator.generate(best_hint) + specialized_func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) + function_symbol = default_kernal_name + + function_symbols.append(function_symbol) + specilized_tuned_funcs.append(specialized_func) + + assert global_symbol is not None, "The global_symbol should not be None" + assert len(function_symbols) == len(specilized_tuned_funcs), ( + "The length of global_symbols should be equal to the length of specilized_tuned_funcs") + return create_dispatch_mod(global_symbol, func, specilized_tuned_funcs, function_symbols) diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py index 59d63298b..f39c7cfab 100644 --- a/bitblas/builder/wrapper/tir.py +++ b/bitblas/builder/wrapper/tir.py @@ -13,6 +13,22 @@ logger = logging.getLogger(__name__) +PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """ + cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {}); +""" + +PREDEF_INIT_FUNC = """ +extern "C" void init() {{ + {} +}} +""" + +PREDEF_HOST_FUNC = """ +extern "C" void call({}) {{ +{} +}} +""" + class TIRCUDASourceWrapper(object): _TYPE_MAP = { @@ -77,16 +93,11 @@ def get_cuda_init_func(self): call_str = """""" # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call if self.dynamic_smem_buf is not None: - call_str = """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(self.function_name, self.dynamic_smem_buf) + call_str = ( + PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name, + self.dynamic_smem_buf)) # Format the initialization function using the call_str - init_funcs = """ - extern "C" void init() {{ - {} - }} - """.format(call_str) + init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs def update_lib_code(self, code: str): @@ -162,18 +173,19 @@ def legalize_c(p): call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, smem_str, call_args) # Create the host function wrapper for the CUDA kernel - host_func = """ - extern "C" void call({}) {{ - {} - }} - """.format(def_args, call_str) + host_func = PREDEF_HOST_FUNC.format(def_args, call_str) # Combine the source, initialization function, and host function to form the complete library code lib_code = self.source + init_func + host_func return lib_code @property def prim_func(self): - return self.mod["main"] + if len(self.mod.get_global_vars()) == 1: + return self.mod[self.mod.get_global_vars()[0]] + elif "main" in self.mod: + return self.mod["main"] + else: + raise ValueError("Unable to determine primary function.") class TIRCUDASourceWrapperWithDynamic(TIRCUDASourceWrapper): @@ -188,16 +200,10 @@ def get_cuda_init_func(self): for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): if dynamic_smem_buf is not None: # Format the cudaFuncSetAttribute call for dynamic shared memory - call_str += """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(function_name, dynamic_smem_buf) + call_str += ( + PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(function_name, dynamic_smem_buf)) # Define the init function that will set the attributes for each kernel - init_funcs = """ -extern "C" void init() {{ - {} -}} - """.format(call_str) + init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs def create_dispatch_func(self, code, function_informations): @@ -278,8 +284,8 @@ def legalize_c(p): (symbolic,) = list(dynamic_symbolic_set) range_str = opt_shapes[symbolic] if last_range == 0: - call_str = "if ({} == 0) return; \n".format(symbolic,) - call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + call_str = " if ({} == 0) return; \n".format(symbolic,) + call_str += " if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format( symbolic, range_str, function_name, @@ -289,7 +295,7 @@ def legalize_c(p): call_args, ) else: - call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + call_str = " else if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format( symbolic, range_str, function_name, @@ -299,18 +305,13 @@ def legalize_c(p): call_args, ) if last_range == num_items - 1: - call_str += ( - "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - function_name, grid_str, block_str, smem_str, call_args)) + call_str += (" else {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format( + function_name, grid_str, block_str, smem_str, call_args)) last_range += 1 _call_str += call_str # Wrap the kernel dispatch logic in an external C function - host_func = """ -extern "C" void call({}) {{ - {} -}} - """.format(def_args, _call_str) + host_func = PREDEF_HOST_FUNC.format(def_args, _call_str) return host_func def parse_source_information(self): @@ -381,10 +382,6 @@ def compare_map_objects(map_obj): lib_code = self.source + init_func + host_func return lib_code - @property - def prim_func(self): - return self.mod["main"] - class TIRWrapper(BaseWrapper): diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index 6c5ea1ebe..597b2a34f 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -108,8 +108,8 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path): # For writing optimized.py file optimized_file_path = os.path.join(config_path, "optimized.py") with open(optimized_file_path, "w") as optimized_file: - if op_inst.optimized_func is not None: - optimized_file.write(op_inst.optimized_func.script(show_meta=False)) + if op_inst.optimized_mod is not None: + optimized_file.write(op_inst.optimized_mod.script(show_meta=False)) if op_inst.libpath is not None: # copy lib name to the same directory as the artifact srcpath = op_inst.srcpath diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 2945996df..dfd22e6e8 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -6,8 +6,9 @@ from functools import reduce from enum import IntEnum from bitblas.base.arch.cuda import CUDA +from bitblas.base.roller.hint import Hint from typing import Any, Literal, Optional, Tuple, Union -from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU +from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU, BaseKernelNameGenerator from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation from .tirscript.matmul_impl import select_implementation as consistent_implementation from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 @@ -226,6 +227,85 @@ def __post_init__(self): object.__setattr__(self, "storage_dtype", self.W_dtype) +class MatmulKernelNameGenerator(BaseKernelNameGenerator): + + KERNEL_PREFIX = "matmul" + + @staticmethod + def serialize_hint(hint: Optional[Hint] = None) -> str: + if hint is None: + return "default" + else: + if hint.use_tc: + hint_prefix = "tc" + BM, BN = hint.block + WM, WN = hint.warp + BK = hint.rstep[-1] + reduce_k = hint.block_reduction_depth + pipeline_stage = hint.pipeline_stage + hint_name = f"{hint_prefix}x{BM}x{BN}x{BK}w{WM}x{WN}" + if reduce_k is not None and reduce_k > 1: + hint_name += f"xr{reduce_k}" + if pipeline_stage > 1: + hint_name += f"xp{pipeline_stage}" + return hint_name + else: + hint_prefix = "simt" + # do not annotate for simt currently + return hint_prefix + + @staticmethod + def simplify_dtype(dtype: str) -> str: + if dtype == "float32": + return "f32" + elif dtype == "float16": + return "f16" + elif dtype == "bfloat16": + return "bf16" + elif dtype.startswith("int"): + return f"i{dtype[3:]}" + elif dtype.startswith("uint"): + return f"u{dtype[4:]}" + return dtype + + def generate(self, hint=None) -> str: + config = self.config + kernel_name = self.KERNEL_PREFIX + shape_str = f"n{self.config.N}k{self.config.K}" + if isinstance(config.M, int): + shape_str = f"m{config.M}" + shape_str + + A_dtype = self.simplify_dtype(config.A_dtype) + W_dtype = self.simplify_dtype(config.W_dtype) + + precision_str = (f"{A_dtype}x{W_dtype}") + kernel_name = "_".join([kernel_name, shape_str, precision_str]) + + # if config.with_scaling: + # kernel_name += "Scale" + + # if config.with_zeros: + # if config.zeros_mode == "original": + # kernel_name += "OriginalZeros" + # elif config.zeros_mode == "rescale": + # precision_str += "RescaleZeros" + # elif config.zeros_mode == "quantized": + # precision_str += "QuantizedZeros" + # else: + # raise ValueError(f"Unsupported zeros mode: {config.zeros_mode}") + + # if config.propagate_a is not TransformKind.NonTransform: + # kernel_name += f"_pa{config.propagate_a.value}" + # if config.propagate_b is not TransformKind.NonTransform: + # kernel_name += f"_pb{config.propagate_b.value}" + + kernel_name = "_".join([kernel_name, self.serialize_hint(hint)]) + return kernel_name + + def is_valid_config(self, config: OperatorConfig) -> bool: + return isinstance(config, MatmulConfig) + + class Matmul(Operator): # TODO(lei): This should be improved into a general datatype class. @@ -350,6 +430,9 @@ def dispatch_tir(self, # output data type self.torch_output_dtype = getattr(torch, self.out_dtype) + def get_kernel_name_generator(self): + return MatmulKernelNameGenerator(self.config) + def _alloc_workspace(self): return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() diff --git a/bitblas/ops/ladder_permutate/__init__.py b/bitblas/ops/ladder_permutate/__init__.py index d09ee6dac..65ad06679 100644 --- a/bitblas/ops/ladder_permutate/__init__.py +++ b/bitblas/ops/ladder_permutate/__init__.py @@ -38,7 +38,7 @@ def __init__( target = self.target if target.kind.name == "cuda": - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + self.optimized_mod = self.apply_default_schedule(self.prim_func_mod, target) if enable_tuning: self.hardware_aware_finetune() if not from_database: diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index f6fa4cca0..a94da9969 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -8,11 +8,12 @@ from tvm.contrib.dlpack import to_pytorch_func import bitblas import ctypes -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Tuple import numpy as np from bitblas.base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy from bitblas.base.arch import get_arch +from bitblas.base.roller.hint import Hint from bitblas.builder.wrapper import TIRWrapper from bitblas.builder.lib_generator import LibraryGenerator from dataclasses import dataclass @@ -21,6 +22,16 @@ logger = logging.getLogger(__name__) +APPLY_SCHEDULE_FAILED_MESSAGE = ("Failed to apply default schedule for operator {} " + "With target {} and hint {}. \n" + "The error message: {} " + "Please perform hardware-aware tuning manually.") + +BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE = ("Failed to build runtime library for operator {} " + "With target {} and hint {}. \n" + "The error message: {} " + "Please perform hardware-aware tuning manually.") + class TransformKind(IntEnum): NonTransform = 0 @@ -35,6 +46,24 @@ class OperatorConfig: pass +class BaseKernelNameGenerator(ABC): + """Optional class for generating kernel names based on the config and hint""" + + def __init__(self, config: OperatorConfig): + assert self.is_valid_config(config), (f"Invalid config for {self.__class__.__name__}: " + f"{config}") + self.config = config + + @abstractmethod + def is_valid_config(self, config: OperatorConfig): + pass + + @abstractmethod + def generate(self, hint: Hint = None) -> str: + '''Generate the kernel name based on the config and hint''' + pass + + class Operator(ABC): def __init__(self, name, config: OperatorConfig, target: Target = None): @@ -44,24 +73,30 @@ def __init__(self, name, config: OperatorConfig, target: Target = None): self.config = config self.target = target self.prim_func_mod = self._select_implementation() - self.optimized_func = None + self.optimized_mod = None self.rt_mod = None self.time_evaluator = None self.arch = get_arch(target) if target else None self.dynamic_range = None self.pass_context: Dict = {} self.num_args = len(self.prim_func.params) - self.function_handle = None self.num_output_args: int = ( 1 # todo(lei): should be analyzed from the prim_func. ) + self.kernel_name_generator: Optional[BaseKernelNameGenerator] = ( + self.get_kernel_name_generator()) self.lib_generator = LibraryGenerator(self.arch) self.wrapper = TIRWrapper(self.arch) self.lib = None - def get_source(self, target: Target = None) -> str: + def get_kernel_name_generator(self) -> Optional[BaseKernelNameGenerator]: + return None + + def get_source(self, target: Optional[Target] = None, kenrel_only=False) -> str: if target is None: target = self.target + if self.lib_generator.lib_code is not None and not kenrel_only: + return self.lib_generator.lib_code if self.rt_mod is None: self._build_runtime_module(target) return self.rt_mod.imported_modules[0].get_source() if self.rt_mod else None @@ -88,7 +123,7 @@ def _build_runtime_module(self, target: Target): # Check if the platform is CUDA and we have an optimized function if self.arch.platform == "CUDA": - if self.optimized_func is None: + if self.optimized_mod is None: return None @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) @@ -96,17 +131,17 @@ def tvm_callback_cuda_postproc(code, _): return self.post_process(code) try: - # Use a specific TVM pass context for CUDA platforms with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, **self.pass_context }): - rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) + rt_mod = tvm.build(self.optimized_mod, target=target) except Exception: # noqa: F841 logger.debug( - "Failed to build optimized function for CUDA target with default schedule, Please consider enable hardware aware tuning!" - ) + BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format(self.__class__.__name__, target, + "optimized", + "Failed to build optimized module")) else: # For non-CUDA platforms or when no optimized function is available, build with the primary function rt_mod = tvm.build(self.prim_func, target=target, name=self.name) @@ -117,14 +152,14 @@ def tvm_callback_cuda_postproc(code, _): # Initialize a time evaluator with the built module, specifying the device and the number of runs self.time_evaluator = rt_mod.time_evaluator( rt_mod.entry_name, self.arch.device, number=10) - self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) if self.arch.platform == "CUDA": try: is_dynamic = ( - self.dynamic_range is not None and len(self.optimized_func.functions) > 1) - self.wrapper.assign_optimized_module(self.optimized_func) - wrapped_source = self.wrapper.wrap(self.get_source(target), is_dynamic) + self.dynamic_range is not None and len(self.optimized_mod.functions) > 1) + self.wrapper.assign_optimized_module(self.optimized_mod) + wrapped_source = self.wrapper.wrap( + self.get_source(target, kenrel_only=True), is_dynamic) self.lib_generator.update_lib_code(wrapped_source) self.lib_generator.compile_lib() self.lib = self.lib_generator.load_lib() @@ -153,14 +188,25 @@ def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule return optimized_mod return None + def _update_optimized_mod(self, optimized_mod: IRModule): + self.optimized_mod = optimized_mod + def _build_default_module(self, target: Target): try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None + scheduled_mod = self.apply_default_schedule(self.prim_func_mod, target) + assert len(scheduled_mod.get_global_vars()) == 1, ( + "The optimized module should only have one global variable for default schedule.") + assert "main" in scheduled_mod, ( + "The optimized module should have a function named 'main' for default schedule.") + default_kernal_name = self.kernel_name_generator.generate() + func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) + optimized_mod = tvm.IRModule({default_kernal_name: func}) + self._update_optimized_mod(optimized_mod) + except Exception as apply_schedule_error: + self.optimized_mod = None logger.warning( - "[BitBLAS][Warning] Apply default schedule failed. Please perform hardware-aware tuning manually." - ) + APPLY_SCHEDULE_FAILED_MESSAGE.format(self.__class__.__name__, target, "default", + apply_schedule_error)) self._build_runtime_module(target) @@ -171,12 +217,13 @@ def apply_fast_tuning(self, func: PrimFunc, target: Target, topk: int = 20, - parallel_build=True) -> IRModule: + parallel_build=True) -> Tuple[IRModule, Hint]: _, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build) - if best is not None: - return best.sch.mod + # annotate the best pass context + # TODO(lei): actually we should remove this by enable pass through + # annotation in the func's attribute. self.pass_context = best.config.pass_context - return None + return ((best.sch.mod, best.config) if best is not None else (None, None)) def apply_fast_tuning_with_dynamic_range( self, @@ -186,25 +233,39 @@ def apply_fast_tuning_with_dynamic_range( dynamic_range: Dict[str, List[int]] = None, ): optimized_mod = fast_tune_with_dynamic_range( - func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range) + func, + target, + topk=topk, + parallel_build=True, + dynamic_range=dynamic_range, + kernel_name_generator=self.kernel_name_generator) if optimized_mod is not None: return optimized_mod return None def hardware_aware_finetune(self, topk: int = 20, - target: tvm.target.Target = None, + target: Optional[tvm.target.Target] = None, parallel_build=True): if target is None: target = self.target dynamic_range = self.dynamic_range func = self.prim_func if dynamic_range is not None: - self.optimized_func = self.apply_fast_tuning_with_dynamic_range( + self.optimized_mod = self.apply_fast_tuning_with_dynamic_range( func, target, topk, dynamic_range) else: - self.optimized_func = self.apply_fast_tuning( + scheduled_mod, best_hint = self.apply_fast_tuning( func, target, topk, parallel_build=parallel_build) + assert len(scheduled_mod.get_global_vars()) == 1, ( + "The optimized module should only have one global variable for default schedule.") + assert "main" in scheduled_mod, ( + "The optimized module should have a function named 'main' for default schedule.") + default_kernal_name = self.kernel_name_generator.generate(best_hint) + func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) + optimized_mod = tvm.IRModule({default_kernal_name: func}) + self._update_optimized_mod(optimized_mod) + self._build_runtime_module(self.target) def get_profile_tensors(self, dynamic_symbolic_constraints: Optional[Dict] = None): @@ -315,7 +376,6 @@ def update_func(self, func: PrimFunc): def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.rt_mod = rt_mod self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10) - self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) if srcpath is not None: assert self.lib_generator is not None, "lib_generator is not initialized" @@ -336,7 +396,12 @@ def _select_implementation(self) -> IRModule: @property def prim_func(self): - return self.prim_func_mod["main"] + if len(self.prim_func_mod.get_global_vars()) == 1: + return self.prim_func_mod[self.prim_func_mod.get_global_vars()[0]] + elif "main" in self.prim_func_mod: + return self.prim_func_mod["main"] + else: + raise ValueError("Unable to determine primary function.") @property def srcpath(self): diff --git a/bitblas/wrapper/__init__.py b/bitblas/wrapper/__init__.py deleted file mode 100644 index 1d87f8020..000000000 --- a/bitblas/wrapper/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .general import CUDASourceWrapper, CUDASourceWrapperWithDynamic # noqa: F401 diff --git a/bitblas/wrapper/general.py b/bitblas/wrapper/general.py deleted file mode 100644 index 4e7c65c2c..000000000 --- a/bitblas/wrapper/general.py +++ /dev/null @@ -1,519 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -from typing import Optional, List, Dict, Union -from tvm import IRModule -from bitblas import TileDevice -from tvm.runtime import ndarray -from bitblas.utils import match_global_kernel -import re -import ctypes -import os -import tempfile -import subprocess -import logging -from tvm.driver import lower -from tvm.target import Target - -logger = logging.getLogger(__name__) - -_TYPE_MAP = { - "float32": "float", - "float16": "half", - "bfloat16": "__nv_bfloat16", - "e4m3_float8": "__nv_fp8_e4m3", - "e5m2_float8": "__nv_fp8_e5m2", - "float64": "double", - "int64": "int64_t", - "int32": "int", - "uint32": "unsigned int", - "bool": "int8_t", - "int8": "int8_t", - "uint8": "uint8_t", - "int16": "int16_t", - "uchar": "uint8_t", -} - - -def get_annotated_device_mod(mod: IRModule, target: Target): - """ - Lower the given IRModule and create a device module for the specified target. - - Parameters: - - mod: The input IRModule. - - target: The compilation target. - - Returns: - - A device module ready for execution. - """ - input_mod = lower(mod) - target_input_mod = {target: input_mod} - annotated_mods = {} - runtime = None - target_host = None - for tgt, mod in target_input_mod.items(): - if not isinstance(tgt, (str, Target)): - raise ValueError("The key of inputs must be str or " - "Target when inputs is dict.") - if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be Schedule, IRModule, " - "or dict of str to IRModule.") - annotated_mods[tgt] = mod.with_attr("runtime", runtime) - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - if not target_host: - for tar, _ in annotated_mods.items(): - device_type = ndarray.device(tar.kind.name, 0).device_type - if device_type == ndarray.cpu(0).device_type: - target_host = tar - break - if not target_host: - target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - for target, mod in annotated_mods.items(): - mixed_mod_passes = tvm.get_global_func("driver.mixed_mod_passes") - device_mod_passes = tvm.get_global_func("driver.device_mod_passes") - mod = mixed_mod_passes(mod, target)(mod) - device_mod = device_mod_passes(mod, target)(mod) - return device_mod - - -def get_thread_block_information(mod: IRModule): - """ - Extracts the thread block and grid dimensions for the reduction block within a given IRModule. - - Parameters: - - mod: The input IRModule from which to extract thread block and grid information. - - Returns: - A tuple containing two lists: - - The first list contains the dimensions of the thread block (threadIdx.x, threadIdx.y, threadIdx.z). - - The second list contains the dimensions of the grid (blockIdx.x, blockIdx.y, blockIdx.z). - """ - - # Initialize the schedule from the IRModule - sch = tvm.tir.Schedule(mod) - - # Get the root block and its child blocks - root_block = sch.get_block("root") - child_blocks = sch.get_child_blocks(root_block) - - # Initialize default block and grid dimensions (1, 1, 1) - block_dims, grid_dims = [1, 1, 1], [1, 1, 1] - - for block in child_blocks: - # Get the loops surrounding the main block - loops = sch.get_loops(block) - - # Iterate over each loop to extract thread and block bindings - for loop in loops: - stmt = sch.get(loop) - thread_binding = stmt.thread_binding - extent = int(stmt.extent) - - # Skip loops without thread binding - if thread_binding: - if "threadIdx" in thread_binding.thread_tag: - block_dims["xyz".index(thread_binding.thread_tag[-1])] = extent - elif "blockIdx" in thread_binding.thread_tag: - grid_dims["xyz".index(thread_binding.thread_tag[-1])] = extent - - return block_dims, grid_dims - - -class CUDASourceWrapper(object): - - def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): - self.mod = optimized_mod - self.arch = arch - self.source = source - self.function_name: Optional[str] = None - self.dynamic_smem_buf: Optional[int] = None - self.block_info: Union[List[int], Dict] = [1, 1, 1] - self.grid_info: Union[List[int], Dict] = [1, 1, 1] - self.parse_source_information() - self.srcpath: Optional[str] = None - self.libpath: Optional[str] = None - self.lib_code: Optional[str] = self.update_lib_code(source) - - def load_lib(self): - return ctypes.CDLL(self.libpath) - - def remove_lib(self): - if self.libpath: - os.remove(self.libpath) - self.libpath = None - - def compile_lib(self, timeout: float = None): - arch = self.arch - src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) - compute_version = arch.compute_capability - libpath = src.name.replace(".cu", ".so") - - command = [ - "nvcc", - "-std=c++17", - "-Xcudafe", - "--diag_suppress=177", - "--compiler-options", - "'-fPIC'", - "-lineinfo", - "--shared", - src.name, - "-lcuda", - "-gencode", - f"arch=compute_{compute_version},code=sm_{compute_version}", - "-o", - libpath, - ] - src.write(self.lib_code) - src.flush() - try: - ret = subprocess.run(command, timeout=timeout) - except subprocess.TimeoutExpired: - logger.warning(f"Compilation Timeout! {command}") - return None - if ret.returncode != 0: - logger.warning(f"Compilation Failed! {command}") - return None - self.srcpath = src.name - self.libpath = libpath - - def parse_source_information(self): - device_mod = get_annotated_device_mod(self.mod, self.arch.target) - assert (len(device_mod.functions) == 1 - ), "Only support one function in the module for static shape kernel." - for g_var, func in device_mod.functions.items(): - self.function_name = g_var.name_hint - attrs = func.attrs - if "dyn_shared_memory_buf" in attrs: - self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) - if "thread_extent" in attrs: - thread_extent = attrs["thread_extent"] - for tag, extent in thread_extent.items(): - if "threadIdx" in tag: - self.block_info["xyz".index(tag[-1])] = extent - elif "blockIdx" in tag: - self.grid_info["xyz".index(tag[-1])] = extent - - def get_dynamic_symbolic_set(self, prim_func): - # Determine the set of dynamic symbols used in the function - dynamic_symbolic_set = set() - for param in prim_func.params: - buffer = prim_func.buffer_map[param] - for dim in buffer.shape: - if isinstance(dim, tvm.tir.Var): - dynamic_symbolic_set.add(dim.name) - return dynamic_symbolic_set - - def get_cuda_init_func(self): - # Initialize an empty string for the CUDA function call - call_str = """""" - # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call - if self.dynamic_smem_buf is not None: - call_str = """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(self.function_name, self.dynamic_smem_buf) - # Format the initialization function using the call_str - init_funcs = """ - extern "C" void init() {{ - {} - }} - """.format(call_str) - return init_funcs - - def update_lib_code(self, code: str): - # Update the library code with the given code string - self.lib_code = code - # Find the index of the global kernel function in the code - index = match_global_kernel(code) - # Extract the declaration of the function starting from the found index - declaration = code[index:].split(";")[0] - - function_name = self.function_name - # Get the CUDA initialization function - init_func = self.get_cuda_init_func() - - # Locate the opening brace of the function to insert arguments - index = code.index("{", index) - function_args = [] - # Populate the function arguments from the primary function's parameters and buffers - for param in self.prim_func.params: - buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.name, - "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", - }) - - dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) - # Add dynamic symbolic parameters as integers to the function arguments - for dyn_sym in dynamic_symbolic_set: - function_args.append({"name": dyn_sym, "type": "int"}) - - function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) - # Format the function arguments for declaration - def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) - - def func_call_args(s, function_args): - # Extract the function call arguments matching the function definition - pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" - matches = re.findall(pattern, s) - call_args = [] - for match in matches: - for arg in function_args: - if arg["name"] == match: - call_args.append(match) - return call_args - - call_args = ", ".join(func_call_args(declaration, function_args)) - block_info, grid_info = self.block_info, self.grid_info - - def legalize_c(p): - # Convert TIR expressions to legal C expressions - # Directly convert to string since the special case handling - # does not alter the string representation for `tvm.tir.Var` and `IntImm`. - # Replace Python's floor division operator with C's division operator - if isinstance(p, tvm.tir.IntImm): - p = int(p) - return str(p).replace("//", "/") - - # Prepare the block and grid dimensions for the CUDA kernel launch - block_str = "dim3({}, {}, {})".format( - legalize_c(block_info[0]), - legalize_c(block_info[1]), - legalize_c(block_info[2]), - ) - grid_str = "dim3({}, {}, {})".format( - legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) - # Determine the shared memory size, defaulting to 0 if not specified - smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf - # Format the CUDA kernel launch string - if len(dynamic_symbolic_set) != 0: - call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0]) - else: - call_str = "" - call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, - smem_str, call_args) - # Create the host function wrapper for the CUDA kernel - host_func = """ - extern "C" void call({}) {{ - {} - }} - """.format(def_args, call_str) - # Combine the source, initialization function, and host function to form the complete library code - lib_code = self.source + init_func + host_func - return lib_code - - @property - def prim_func(self): - return self.mod["main"] - - -class CUDASourceWrapperWithDynamic(CUDASourceWrapper): - - def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): - super().__init__(optimized_mod, source, arch) - - def get_cuda_init_func(self): - # Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory - call_str = """""" - # Iterate over functions and their dynamic shared memory requirements - for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): - if dynamic_smem_buf is not None: - # Format the cudaFuncSetAttribute call for dynamic shared memory - call_str += """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(function_name, dynamic_smem_buf) - # Define the init function that will set the attributes for each kernel - init_funcs = """ -extern "C" void init() {{ - {} -}} - """.format(call_str) - return init_funcs - - def create_dispatch_func(self, code, function_informations): - # Extract the set of dynamic symbolic names used in the primary function - dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) - - # Find the location of the global kernel function in the code - index = match_global_kernel(code) - - # Analyze the function declaration to prepare for argument extraction - dummy_declaration = code[index:].split(";")[0] - - function_name = self.function_name - - # Identify the start of the function body to insert arguments - index = code.index("{", index) - function_args = [] - # Collect function arguments based on primary function's parameters and buffer mappings - for param in self.prim_func.params: - buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.name, - "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", - }) - # Add dynamic symbols as integer arguments - for dyn_sym in dynamic_symbolic_set: - function_args.append({"name": dyn_sym, "type": "int"}) - - function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) - - # Format the argument definitions for function declaration - def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) - - def func_call_args(s: str, function_args): - # Extract and clean the function call arguments to match the declaration - pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" - matches = re.findall(pattern, s) - call_args = [] - for match in matches: - match = re.sub(r"\d+", "", match) # Remove numbers - match = re.sub(r"_", "", match) # Remove underscores - for arg in function_args: - if arg["name"] == match: - call_args.append(match) - return call_args - - call_args = ", ".join(func_call_args(dummy_declaration, function_args)) - - def legalize_c(p): - # Convert TIR expressions to legal C expressions - # Directly convert to string since the special case handling - # does not alter the string representation for `tvm.tir.Var` and `IntImm`. - # Replace Python's floor division operator with C's division operator - if isinstance(p, tvm.tir.IntImm): - p = int(p) - return str(p).replace("//", "/") - - last_range = 0 - num_items = len(function_informations) - _call_str = """""" - for function_name, info in function_informations.items(): - # Prepare block and grid configurations for kernel launches - block_info, grid_info = info["block_info"], info["grid_info"] - block_str = "dim3({}, {}, {})".format( - legalize_c(block_info[0]), - legalize_c(block_info[1]), - legalize_c(block_info[2]), - ) - grid_str = "dim3({}, {}, {})".format( - legalize_c(grid_info[0]), - legalize_c(grid_info[1]), - legalize_c(grid_info[2]), - ) - # Handle dynamic shared memory specification - smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"]) - opt_shapes = info["opt_shapes"] - # Generate conditional kernel launch code based on dynamic symbolic ranges - (symbolic,) = list(dynamic_symbolic_set) - range_str = opt_shapes[symbolic] - if last_range == 0: - call_str = "if ({} == 0) return; \n".format(symbolic,) - call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - symbolic, - range_str, - function_name, - grid_str, - block_str, - smem_str, - call_args, - ) - else: - call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - symbolic, - range_str, - function_name, - grid_str, - block_str, - smem_str, - call_args, - ) - if last_range == num_items - 1: - call_str += ( - "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - function_name, grid_str, block_str, smem_str, call_args)) - last_range += 1 - _call_str += call_str - - # Wrap the kernel dispatch logic in an external C function - host_func = """ -extern "C" void call({}) {{ - {} -}} - """.format(def_args, _call_str) - return host_func - - def parse_source_information(self): - # Parse device module to extract execution configurations for each function - device_mod = get_annotated_device_mod(self.mod, self.arch.target) - block_info_map = {} - grid_info_map = {} - dynamic_smem_buf_map = {} - for g_var, func in device_mod.functions.items(): - # Default block and grid configurations - block_info = [1, 1, 1] - grid_info = [1, 1, 1] - function_name = g_var.name_hint - attrs = func.attrs - dynamic_smem_buf = None - if "dyn_shared_memory_buf" in attrs: - dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) - if "thread_extent" in attrs: - # Extract block and grid sizes from thread extents - thread_extent = attrs["thread_extent"] - for tag, extent in thread_extent.items(): - if "threadIdx" in tag: - block_info["xyz".index(tag[-1])] = extent - elif "blockIdx" in tag: - grid_info["xyz".index(tag[-1])] = extent - # Map the extracted configurations to each function - block_info_map[function_name] = block_info - grid_info_map[function_name] = grid_info - dynamic_smem_buf_map[function_name] = dynamic_smem_buf - # Store the mappings for use in code generation - self.block_info = block_info_map - self.grid_info = grid_info_map - self.dynamic_smem_buf = dynamic_smem_buf_map - - def update_lib_code(self, code: str): - # Organize function information for code generation - function_informations = {} - for g_var, func in self.mod.functions.items(): - if g_var.name_hint == "main": - continue - function_name = g_var.name_hint - attrs = func.attrs - assert "opt_shapes" in attrs - opt_shapes = attrs["opt_shapes"] - function_informations[function_name] = { - "function_name": function_name, - "opt_shapes": opt_shapes, - "block_info": self.block_info[function_name], - "grid_info": self.grid_info[function_name], - "dynamic_smem_buf": self.dynamic_smem_buf[function_name], - } - - def compare_map_objects(map_obj): - comparable_representation = list(map_obj.values()) - return comparable_representation - - function_informations = dict( - sorted( - function_informations.items(), - key=lambda item: compare_map_objects(item[1]["opt_shapes"]))) - - self.lib_code = code - - # Generate the initialization and dispatch functions - init_func = self.get_cuda_init_func() - host_func = self.create_dispatch_func(code, function_informations) - # Concatenate source code with generated code segments - lib_code = self.source + init_func + host_func - return lib_code - - @property - def prim_func(self): - return self.mod["main"] diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py index 22c134b12..f65ce8066 100644 --- a/testing/python/builder/test_backend_tir_builder.py +++ b/testing/python/builder/test_backend_tir_builder.py @@ -39,8 +39,9 @@ def matmul_backend_code_wrap( ) matmul = Matmul(config=matmul_config, enable_tuning=False) backend = TIRWrapper(arch=matmul.arch) - backend.assign_optimized_module(matmul.optimized_func) - wrapped_code = backend.wrap(matmul.get_source(), is_dynamic=isinstance(M, list)) + backend.assign_optimized_module(matmul.optimized_mod) + is_dynamic = (matmul.dynamic_range is not None and len(matmul.optimized_mod.functions) > 1) + wrapped_code = backend.wrap(matmul.get_source(kenrel_only=True), is_dynamic=is_dynamic) assert "void call" in wrapped_code