diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index ef2dc8587..79ace6194 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -328,9 +328,8 @@ def _score(node, thread): # small is better # TODO: This is a dummy mul which avoid reusing some shared memory. # Should be removed in the future. if td.smem_cost > (self.arch.smem_cap): - debug_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \ - " use dynamic shared memory." - logger.debug(debug_message) + # Tile Dict: {td.output_tile} Shared memory exceeds the static capacity + # use dynamic shared memory. codegen_dict.shared_scope = "shared.dyn" codegen_dict.shared_scope = "shared.dyn" diff --git a/bitblas/builder/wrapper/tl.py b/bitblas/builder/wrapper/tl.py index e3eba9883..85a75601a 100644 --- a/bitblas/builder/wrapper/tl.py +++ b/bitblas/builder/wrapper/tl.py @@ -19,9 +19,9 @@ class TLCUDASourceWrapper(object): _TYPE_MAP = { "float32": "float", "float16": "half_t", - "bfloat16": "__nv_bfloat16", - "e4m3_float8": "__nv_fp8_e4m3", - "e5m2_float8": "__nv_fp8_e5m2", + "bfloat16": "bfloat16_t", + "e4m3_float8": "float_e4m3_t", + "e5m2_float8": "float_e5m2_t", "float64": "double", "int64": "int64_t", "int32": "int", diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py index b427a3c6f..a85be8d51 100644 --- a/bitblas/module/__init__.py +++ b/bitblas/module/__init__.py @@ -274,7 +274,7 @@ def forward(self, A, output=None): self.init_params() args = [A_void, *self.q_params] if output is None: - output = torch.empty( + output = torch.zeros( A.shape[:-1] + (self.out_features,), dtype=getattr(torch, self.bitblas_matmul.out_dtype), device=A.device) diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 811a2a61e..c560afd0e 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -15,6 +15,7 @@ 
from .tilelang.dense import select_scheduler as consistent_scheduler from .tilelang.dequantize import select_scheduler as weight_dequantize_scheduler from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from bitblas.utils import retrieve_func_from_module from bitblas.utils.target_detector import auto_detect_nvidia_target from dataclasses import dataclass from ..ladder_permutate import LadderPermutate, LadderPermutateConfig @@ -350,7 +351,7 @@ def __init__( target: Optional[Union[str, Target]] = None, enable_tuning: bool = True, from_database: bool = False, - backend: str = "tir", + backend: str = "tl", ): # if from database, we should disable default schedule # to save compilation time @@ -370,8 +371,14 @@ def __init__( self.bit = bit # This is a hack to support the int4 and uint4 + # legalize the backend (hacky implementation) + # TODO(lei): In future release we should remove + # by implementing all the operators in the tl backend. if config.A_dtype in ["int4", "uint4"]: backend = "tl" + if source_format in ["nf"]: + backend = "tir" + super().__init__(name, config, target, backend) if source_format == "int" and self.with_zeros: @@ -383,13 +390,13 @@ def __init__( if target.kind.name not in ("cuda", "hip"): raise ValueError("Currently only support cuda and hip target") - self.dispatch_tir(target, from_database, source_format, enable_tuning) + self.dispatch(target, from_database, source_format, enable_tuning) - def dispatch_tir(self, - target: Target, - from_database: bool = False, - source_format: str = "uint", - enable_tuning: bool = True): + def dispatch(self, + target: Target, + from_database: bool = False, + source_format: str = "uint", + enable_tuning: bool = True): if isinstance(self.M, Tuple): self.dynamic_range = {"m": self.M} @@ -638,7 +645,21 @@ def post_process(self, code: str) -> str: return code def retrieve_weight_shape(self): - return [int(i) for i in 
self.prim_func.buffer_map[self.prim_func.params[1]].shape] + prim_func = self.prim_func + + # retrieve from tilelang backend + if prim_func is None and self.scheduled_ir_module is not None: + prim_func = retrieve_func_from_module(self.scheduled_ir_module) + + if prim_func is None and self.is_tilelang_backend(): + # If from_database and from tilelang backend, we should construct a default module + self._update_optimized_mod(self.scheduler_with_default(self.scheduler)) + prim_func = retrieve_func_from_module(self.scheduled_ir_module) + + if prim_func is not None: + return [int(i) for i in prim_func.buffer_map[prim_func.params[1]].shape] + + raise ValueError("The weight shape is not available.") def transform_weight(self, weight, scale=None, zeros=None, bias=None): """ diff --git a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py index 190360c8f..5891acb14 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py @@ -111,8 +111,8 @@ def apply_config( A_shape = (M, K) B_shape = (N, K) - C_shape = (M, N) Bias_shape = (N,) + C_shape = (M, N) dp4a_size = 4 use_dp4a = in_dtype == "int8" and accum_dtype == "int32" @@ -121,8 +121,8 @@ def apply_config( def main( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, out_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( @@ -186,5 +186,4 @@ def __post_init__(self): # Validate the matrix transpose settings assert self.trans_A is False, "Currently only support Matrix A not transposed" assert self.trans_B is True, "Currently only support Matrix B transposed" - assert self.with_bias is False, "Currently only support without bias" return diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py 
b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index f81dd3d1d..8b197e569 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -242,7 +242,6 @@ def __post_init__(self): # Validate the matrix transpose settings assert self.trans_A is False, "Currently only support Matrix A not transposed" assert self.trans_B is True, "Currently only support Matrix B transposed" - assert self.with_bias is False, "Currently only support without bias" assert self.input_transform_kind == TransformKind.NonTransform, "Currently only support NonTransform for input" return diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py index c80d10fbb..5c44daae3 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py @@ -57,7 +57,7 @@ def check_require_cache(self) -> bool: conditions: List[bool] = [] conditions.append(False) - # Bias Add should be done in shared memory + # Bias Add should be performed in shared memory conditions.append(with_bias) return any(conditions) # Always set to False Currently @@ -172,6 +172,8 @@ def apply_config( self.accum_dtype, ) + with_bias = self.with_bias + shared_scope = "shared.dyn" block_M = block_size_x * thread_row_tiles @@ -183,6 +185,7 @@ def apply_config( C_shape = (M, N) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) + Bias_shape = (N,) threads = thread_row_tiles * thread_col_tiles local_size_a = block_M // thread_row_tiles @@ -198,6 +201,7 @@ def apply_config( def main( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), + Bias: T.Buffer(Bias_shape, out_dtype), C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( @@ -249,14 +253,22 @@ def main( else: for dp4a_idx in T.serial(dp4a_size): C_local[i * local_size_b + j] += ( - A_local[i, mk * dp4a_size + dp4a_idx] * - B_local[j, mk * dp4a_size + 
dp4a_idx]) - - for i, j in T.grid(local_size_a, local_size_b): - C[ - by * block_M + warp_m * local_size_a + i, - bx * block_N + warp_n * local_size_b + j, - ] = C_local[i * local_size_b + j] + A_local[i, + mk * dp4a_size + dp4a_idx].astype(accum_dtype) * + B_local[j, + mk * dp4a_size + dp4a_idx].astype(accum_dtype)) + + if with_bias: + for i, j in T.grid(local_size_a, local_size_b): + C_local[i * local_size_b + j] += Bias[bx * block_N + warp_n * local_size_b + + j] + + for i in T.serial(local_size_a): + for j in T.vectorized(local_size_b): + C[ + by * block_M + warp_m * local_size_a + i, + bx * block_N + warp_n * local_size_b + j, + ] = C_local[i * local_size_b + j] return self.post_process(main) @@ -264,6 +276,5 @@ def __post_init__(self): # Validate the matrix transpose settings assert self.trans_A is False, "Currently only support Matrix A not transposed" assert self.trans_B is True, "Currently only support Matrix B transposed" - assert self.with_bias is False, "Currently only support without bias" return diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 4e56a15f3..f2bb5bd4d 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -69,7 +69,7 @@ def check_require_cache(self) -> bool: conditions: List[bool] = [] conditions.append(False) - # Bias Add should be done in shared memory + # Bias Add should be performed in shared memory conditions.append(with_bias) return any(conditions) # Always set to False Currently @@ -227,8 +227,8 @@ def apply_config( def main( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, out_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): @@ -444,16 +444,15 @@ def apply_config( chunk=chunk, ) - 
# cache_write_required = self.check_require_cache() - cache_write_required = False + cache_write_required = self.check_require_cache() # Define the main kernel using the generated configuration @T.prim_func def main( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, out_dtype), + C: T.Buffer(C_shape, out_dtype), ): # Grid and thread configuration for CUDA kernel with T.Kernel( @@ -667,8 +666,8 @@ def apply_config( def main( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, out_dtype), + C: T.Buffer(C_shape, out_dtype), ): # Grid and thread configuration for CUDA kernel with T.Kernel( @@ -867,6 +866,8 @@ def apply_config( in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype assert in_dtype == "int4", "Only support int4 input" assert accum_dtype == "int32", "Only support int32 accumulation" + with_bias = self.with_bias + assert not with_bias, "Currently do not support bias" storage_dtype = "int8" # Calculate the micro size per warp using a helper function @@ -879,6 +880,8 @@ def apply_config( # Define the shapes of matrices and shared memory buffers A_shape = (M, K) B_shape = (N, K) + Bias_shape = (N,) + C_shape = (M, N) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) C_shared_shape = ( @@ -918,7 +921,8 @@ def apply_config( def main( A: T.Buffer(A_shape, storage_dtype), B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), out_dtype), + Bias: T.Buffer(Bias_shape, out_dtype), + C: T.Buffer(C_shape, out_dtype), ): # Grid and thread configuration for CUDA kernel with T.Kernel( diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py index 0d838661c..963a03b66 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py +++ 
b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py @@ -168,8 +168,8 @@ def main( Scale: T.Buffer(Scale_shape, in_dtype), Qzeros: T.Buffer(Qzeros_shape, storage_dtype), Zeros: T.Buffer(Zeros_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( T.ceildiv(N, n_partition), diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py index 8fcb53f7f..4ca802608 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py @@ -624,8 +624,8 @@ def general_shared_dequant_matmul( Scale: T.Buffer(Scale_shape, in_dtype), Qzeros: T.Buffer(Qzeros_shape, storage_dtype), Zeros: T.Buffer(Zeros_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py index 4c91bc144..e0da07aa7 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py @@ -623,8 +623,8 @@ def general_shared_dequant_matmul( Scale: T.Buffer(Scale_shape, in_dtype), Qzeros: T.Buffer(Qzeros_shape, storage_dtype), Zeros: T.Buffer(Zeros_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py 
b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py index ebbdafcc6..61b539ee8 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py @@ -212,13 +212,14 @@ def apply_config( assert chunk is not None, "chunk is required" assert num_stages is not None, "num_stages is required" + shared_scope = "shared.dyn" M = self.maybe_dynamic(self.M, "m") N, K = self.N, self.K assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" trans_A, trans_B = self.trans_A, self.trans_B assert trans_A is False, "Dequantize only implement for trans_A=False currently" - assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + assert trans_B is True, "Dequantize only implement for trans_B=True currently" in_dtype, out_dtype, accum_dtype = ( self.in_dtype, @@ -334,16 +335,17 @@ def general_dequant_matmul( Scale: T.Buffer(Scale_shape, in_dtype), Qzeros: T.Buffer(Qzeros_shape, storage_dtype), Zeros: T.Buffer(Zeros_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k_factor, threads=threads) as (bx, by, bz): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, storage_dtype) - B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) - C_shared = T.alloc_shared(C_shared_shape, out_dtype) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared( + B_dequantize_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_frag = T.alloc_local((warp_rows * 
fragement_size_a), in_dtype) B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) @@ -398,7 +400,7 @@ def general_dequant_matmul( local_size, bx, tx, - ko, + bz * T.ceildiv(splitK, block_K) + ko, i, block_N, block_K, @@ -443,25 +445,37 @@ def general_dequant_matmul( thread_bindings=tx, ) - if with_bias: - for i, j in T.Parallel(block_M, block_N): - C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] += Bias[bx * block_N + j] - - # Store results from shared memory to global memory - if enable_split_k: - for i, j in T.Parallel(block_M, block_N // 2): - T.atomic_addx2( - C[by * block_M + i, bx * block_N + j * 2], C_shared[ + if with_bias: # noqa: SIM102 + if bz == 0: # as bz is the k-dim, otherwise, bias will be added multiple times + for i, j in T.Parallel(block_M, block_N): + C_shared[ i // micro_size_x, j // micro_size_y, i % micro_size_x, j % micro_size_y, - ]) + ] += Bias[bx * block_N + j] + + # Store results from shared memory to global memory + if enable_split_k: + if DataType(out_dtype).bits == 16: + for i, j in T.Parallel(block_M, block_N // 2): + m, n = by * block_M + i, bx * block_N + j * 2 + T.atomic_addx2( + C[m, n], C_shared[ + i // micro_size_x, + (j * 2) // micro_size_y, + i % micro_size_x, + (j * 2) % micro_size_y, + ]) + else: + for i, j in T.Parallel(block_M, block_N): + T.atomic_add( + C[by * block_M + i, bx * block_N + j], C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ]) else: for i, j in T.Parallel(block_M, block_N): C[by * block_M + i, bx * block_N + j] = C_shared[ diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py index eb1b5c93e..dcf235d18 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ 
b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py @@ -212,8 +212,8 @@ def general_dequant_matmul( Scale: T.Buffer(Scale_shape, in_dtype), Qzeros: T.Buffer(Qzeros_shape, storage_dtype), Zeros: T.Buffer(Zeros_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), Bias: T.Buffer(Bias_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k_factor, @@ -302,7 +302,7 @@ def general_dequant_matmul( bx, tx, mma_emitter, - ko, + bz * T.ceildiv(splitK, block_K) + ko, ki, block_N, block_K, @@ -319,7 +319,7 @@ def general_dequant_matmul( bx, tx, mma_emitter, - ko, + bz * T.ceildiv(splitK, block_K) + ko, ki, block_N, block_K, @@ -336,14 +336,15 @@ def general_dequant_matmul( thread_bindings=tx, ) - if with_bias: - for i, j in T.Parallel(block_M, block_N): - C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] += Bias[j] + if with_bias: # noqa: SIM102 + if bz == 0: # as bz is the k-dim, otherwise, bias will be added multiple times + for i, j in T.Parallel(block_M, block_N): + C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] += Bias[bx * block_N + j] # Store results from shared memory to global memory if enable_split_k: diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 8df3b2a17..ea99490c5 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -236,7 +236,7 @@ def tvm_callback_hip_postproc(code, _): raise ValueError(f"Unsupported target: {self.arch}") return rt_mod - def scheduler_with_default(self, scheduler: BaseScheduler): + def scheduler_with_default(self, scheduler: BaseScheduler) -> Optional[IRModule]: scheduled_ir_module = IRModule.from_expr(scheduler.with_default_config()) if scheduled_ir_module is not None: self.ir_module = scheduled_ir_module @@ -501,7 +501,10 @@ def _select_scheduler(self) -> Optional[BaseScheduler]: raise 
NotImplementedError @property - def prim_func(self): + def prim_func(self) -> Optional[PrimFunc]: + if self.ir_module is None: + return None + if len(self.ir_module.get_global_vars()) == 1: return self.ir_module[self.ir_module.get_global_vars()[0]] elif "main" in self.ir_module: diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py index c97ac8f0c..84666951a 100644 --- a/testing/python/builder/test_backend_tir_builder.py +++ b/testing/python/builder/test_backend_tir_builder.py @@ -38,7 +38,8 @@ def matmul_backend_code_wrap( out_dtype=out_dtype, with_bias=with_bias, ) - matmul = Matmul(config=matmul_config, enable_tuning=False) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir") + assert matmul.is_tir_backend(), "Backend should be TIR" backend = TIRWrapper(arch=matmul.arch) backend.assign_optimized_module(matmul.scheduled_ir_module) is_dynamic = ( diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index 3bf32044b..286aff237 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -37,7 +37,8 @@ def correctness_consistent(m, in_features, out_features, bias): input_data = torch.randn(m, in_features, dtype=torch.float16).cuda() output_torch = linear_torch(input_data) output_bitblas = linear_bitblas(input_data) - + print(output_torch) + print(output_bitblas) bitblas.testing.torch_assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2) @@ -154,17 +155,15 @@ def correctness_weight_only_dequantize( with torch.no_grad(): output_bitblas = linear_bitblas(inputs[0]) - try: - rtol = 1e0 - atol = 1e0 - if zeros_mode == "original": - rtol = 1e2 - atol = 1e2 - torch.testing.assert_close(output_bitblas, ref_result, rtol=rtol, atol=atol) - except AssertionError as e: - print(ref_result, output_bitblas) - print(f"Failed with {e}") - raise e + + rtol = 1e0 + atol = 1e0 + if 
zeros_mode == "original": + rtol = 1e2 + atol = 1e2 + print(output_bitblas) + print(ref_result) + torch.testing.assert_close(output_bitblas, ref_result, rtol=rtol, atol=atol) def test_correctness_weight_only_dequantize(): diff --git a/testing/python/module/test_repack_from_gptq.py b/testing/python/module/test_repack_from_gptq.py index 3357bd336..332fe846b 100644 --- a/testing/python/module/test_repack_from_gptq.py +++ b/testing/python/module/test_repack_from_gptq.py @@ -4,19 +4,14 @@ import bitblas.testing import torch -try: - import auto_gptq # noqa: F401 -except ImportError as e: - raise ImportError("Please install auto-gptq by running `pip install auto-gptq`") from e - -from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( - QuantLinear as CudaOldQuantLinear,) - torch.manual_seed(0) bitblas.set_log_level("DEBUG") def assert_output_with_gptq(m, in_features, out_features, group_size): + from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( + QuantLinear as CudaOldQuantLinear,) + if group_size == -1: group_size = in_features _, linear, s, _ = bitblas.quantization.gen_quant4(in_features, out_features, group_size) @@ -67,6 +62,7 @@ def assert_output_with_gptq(m, in_features, out_features, group_size): torch.testing.assert_close(res_bitblas, res_cuda_old, rtol=1e-0, atol=1e-1) +@bitblas.testing.requires_package("auto_gptq") def test_assert_output_with_gptq(): assert_output_with_gptq(1, 256, 256, 64) assert_output_with_gptq(1, 256, 256, -1) diff --git a/testing/python/operators/test_general_matmul_bf16.py b/testing/python/operators/test_general_matmul_bf16.py index e083d5c43..508130d09 100644 --- a/testing/python/operators/test_general_matmul_bf16.py +++ b/testing/python/operators/test_general_matmul_bf16.py @@ -95,7 +95,7 @@ def matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtyp else: raise NotImplementedError - inputs.append(torch.rand(output_shape, dtype=getattr(torch, out_dtype)).cuda()) + 
inputs.append(torch.zeros(output_shape, dtype=getattr(torch, out_dtype)).cuda()) intweight = inputs[1] intweight = intweight.cpu().to(torch.int8) diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops_backend.py similarity index 100% rename from testing/python/operators/test_general_matmul_ops.py rename to testing/python/operators/test_general_matmul_ops_backend.py diff --git a/testing/python/operators/test_general_matmul_ops_backend_tir.py b/testing/python/operators/test_general_matmul_ops_backend_tir.py new file mode 100644 index 000000000..5ba91497a --- /dev/null +++ b/testing/python/operators/test_general_matmul_ops_backend_tir.py @@ -0,0 +1,287 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +import bitblas.testing +from bitblas import MatmulConfig, Matmul +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) + + +def get_codegen_result(ops): + code = ops.get_source() + return code + + +# fmt: off +def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir") + assert get_codegen_result(matmul) + + +def test_matmul_codegen_default(): + matmul_codegen_default(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, + False, False, None), + matmul_codegen_default(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, + -1, False, False, None), + matmul_codegen_default(1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, + False, 
None), + matmul_codegen_default(768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, + False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + False, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, + False, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + True, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + True, True, "original"), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + False, False, None), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, + False, False, None), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + True, False, None), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, + True, True, "original"), + + +def matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir") + matmul.hardware_aware_finetune(topk=10) + assert get_codegen_result(matmul) + + +def test_matmul_finetune(): + matmul_finetune(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, + False, None), + matmul_finetune(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, + False, False, None), + matmul_finetune(1, 768, 768, "int8", "int8", 
"int32", "int8", "nt", False, -1, False, False, + None), + matmul_finetune(768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, + None), + matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, + False, None), + matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, + False, None), + matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, + False, None), + matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, + True, "original"), + matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, + False, None), + matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, + False, None), + matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, + False, None), + matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, + True, "original"), + + +def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + import torch + torch.random.manual_seed(0) + import numpy as np + from bitblas.quantization import general_compress + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + propagate_a=False, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir") + + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + source_format, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype] + maxq 
= 2**(bit - 1) + zeros = maxq + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().to(torch.int8) + if source_format == "int": + intweight = intweight + maxq + if with_zeros: + inputs[1] = inputs[1] - zeros + bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() + ref_result = torch.matmul(inputs[0], + (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) + if with_bias: + ref_result = ref_result + bias + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + if matmul.weight_transform is not None: + permuted_inputs.append(matmul.weight_transform(intweight.cpu()).cuda()) + else: + permuted_inputs.append(intweight) + if with_scaling: + if group_size == -1: + group_size = K + permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + if with_zeros: + if zeros_mode == "original": + permuted_inputs.append( + torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) + elif zeros_mode == "rescale": + original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros + scaled_zeros = original_zeros * permuted_inputs[-1] + permuted_inputs.append(scaled_zeros) + elif zeros_mode == "quantized": + original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + permuted_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + if with_bias: + permuted_inputs.append(bias) + permuted_inputs.append(inputs[2]) + matmul(*permuted_inputs[:-1], output=permuted_inputs[-1]) + print(permuted_inputs[-1]) + print(ref_result) + 
def test_matmul_torch_forward():
    """Run ``matmul_torch_forward`` over a fixed table of configurations.

    Each entry is passed positionally to ``matmul_torch_forward``. Judging by
    how the arguments are used, the order is presumably
    ``(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,
    group_size, with_scaling, with_zeros, zeros_mode)`` -- confirm against the
    helper's signature before reordering anything.
    """
    cases = [
        (1, 1024, 1024, "float16", "int4", "float16", "float16", "nt", None, None, None, None,
         None),
        (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False,
         None),
        (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False,
         None),
        (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False,
         None),
        (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True,
         "original"),
        (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False,
         None),
        (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False,
         None),
        (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False,
         None),
        (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True,
         "original"),
    ]
    # Plain call statements: the original had stray trailing commas after two
    # of the calls, which turned them into discarded one-element tuples.
    for case in cases:
        matmul_torch_forward(*case)


def matmul_transform_weight(
    M,
    N,
    K,
    A_dtype,
    W_dtype,
    accum_dtype,
    out_dtype,
    with_bias,
):
    """Check a ``Matmul`` built on the "tir" backend against ``torch.matmul``.

    Builds a ``MatmulConfig`` from the arguments, instantiates ``Matmul`` with
    tuning disabled, feeds it a random float16 activation and a random int8
    weight that has been passed through ``matmul.transform_weight``, and
    asserts the operator output is close to the PyTorch reference.

    Args:
        M, N, K: Problem sizes; the activation is ``(M, K)``, the weight is
            ``(N, K)`` and the output is ``(M, N)``.
        A_dtype, W_dtype, accum_dtype, out_dtype: Data type names forwarded to
            ``MatmulConfig``. ``W_dtype`` is also looked up in
            ``matmul.BITBLAS_TRICK_DTYPE_MAP`` to obtain the weight bit width.
        with_bias: When true, a random bias of length ``N`` is added to the
            reference result and passed to the operator.

    Raises:
        AssertionError: If the operator output and the reference differ by
            more than the configured tolerances.
    """
    import torch
    torch.random.manual_seed(0)

    matmul_config = MatmulConfig(
        M=M,
        N=N,
        K=K,
        A_dtype=A_dtype,
        W_dtype=W_dtype,
        accum_dtype=accum_dtype,
        out_dtype=out_dtype,
        with_bias=with_bias,
    )
    matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir")

    input_shape = (M, K)
    weight_shape = (N, K)

    # The second entry of the dtype map is used as the weight bit width; the
    # random weights are drawn from [0, 2**(bit - 1)).
    _, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype]
    maxq = 2**(bit - 1)

    input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda()
    intweight_tensor = torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()

    # Reference result computed from the untransformed weight.
    ref_result = torch.matmul(input_tensor, intweight_tensor.t().to(torch.float16))

    bitblas_inputs = [input_tensor, matmul.transform_weight(intweight_tensor)]
    if with_bias:
        # The bias is only materialised when it is actually used.
        bias = torch.rand((N,), dtype=torch.float16).cuda()
        ref_result = ref_result + bias
        bitblas_inputs.append(bias)

    output_tensor = matmul(*bitblas_inputs)
    # NOTE(review): rtol=1e2 accepts a 100x relative error, which makes this
    # check very weak; it may have been meant as 1e-2. Left unchanged because
    # tightening it needs to be validated on real hardware.
    torch.testing.assert_close(output_tensor, ref_result, rtol=1e2, atol=1e0)


def test_matmul_transform_weight():
    """Run ``matmul_transform_weight`` for uint4/int4 weights with M=1 and M=768."""
    for m in (1, 768):
        for w_dtype in ("uint4", "int4"):
            matmul_transform_weight(m, 768, 768, "float16", w_dtype, "float16", "float16", False)


# fmt: on
if __name__ == "__main__":
    bitblas.testing.main()
A_dtype, W_dtype, accum_dty output_bitblas = matmul.forward(*inputs) output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) - torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1) + bitblas.testing.torch_assert_close(output_torch, output_bitblas, rtol=1e-2, atol=1e-1, max_mismatched_ratio=1e-2) def test_matmul_torch_forward_consistent():