diff --git a/benchmark/operators/benchmark_bitblas_matmul.py b/benchmark/operators/benchmark_bitblas_matmul.py
index 9a40db4ad..5743e112e 100644
--- a/benchmark/operators/benchmark_bitblas_matmul.py
+++ b/benchmark/operators/benchmark_bitblas_matmul.py
@@ -1,48 +1,123 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-import bitblas
+
 from bitblas.utils.target_detector import auto_detect_nvidia_target
 from bitblas import Matmul, MatmulConfig
 import argparse
-import json
+from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
+from bitblas.base.arch import CUDA
+from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
+from bitblas.base.utils import apply_and_build

-bitblas.set_log_level("DEBUG")
 # Initialize the parser
 parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.")

 # Add arguments to the parser
+
 parser.add_argument(
     "--target",
     type=str,
     default=auto_detect_nvidia_target(),
-    help="Specify the target device for benchmarking.")
+    help="Specify the target device for benchmarking.",
+)
 parser.add_argument(
-    "--backend",
-    type=str,
-    default="tir",
-    choices=["tir", "tl"],  # Replace with actual modes if applicable
-    help="Specify the mode for calculating zeros.")
-
-parser.add_argument("--verbose", type=bool, default=True, help="Enable verbose logging.")
-
-# [A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode]
-default_test_shapes = json.dumps([
-    # ["MatmulConfig", "Matmul", [1, 16384, 16384, "float16", "int4", "float16", "float16", "nt", False, None, False, False, None]]
-    [
-        "MatmulConfig", "Matmul",
-        [
-            16384, 16384, 16384, "float16", "float16", "float16", "float16", "nt", False, None,
-            False, False, None
-        ]
-    ]
-])
+    "--M",
+    type=int,
+    default=16384,
+    help="Number of rows in matrix A.",
+)

 parser.add_argument(
-    "--test_shapes",
+    "--N",
+    type=int,
+    default=16384,
+    help="Number of columns in the output matrix C.",
+)
+
+parser.add_argument(
+    "--K",
+    type=int,
+    default=16384,
+    help="Reduction dimension K (number of columns in matrix A).",
+)
+
+parser.add_argument(
+    "--A_dtype",
+    type=str,
+    default="float16",
+    choices=[
+        "float16",
+        "float32",
+        "float64",
+        "int32",
+        "int8",
+    ],  # Assuming these are the valid choices
+    help="Data type of activation A.",
+)
+parser.add_argument(
+    "--W_dtype",
+    type=str,
+    default="int4",
+    choices=[
+        "float16",
+        "float32",
+        "float64",
+        "int32",
+        "int8",
+        "int4",
+        "int2",
+        "int1",
+        "nf4",
+        "fp4_e2m1",
+    ],  # Assuming these are the valid choices
+    help="Data type of weight W.",
+)
+parser.add_argument(
+    "--accum_dtype",
     type=str,
-    default=default_test_shapes,
-    help="JSON string defining test shapes. 
Example format: '[[\"MatmulConfig\", \"Matmul\", [1,16384,16384,\"float16\",\"int4\",\"float16\",\"float16\",\"nt\",false,null,false,false,null]]]'" + default="float16", + choices=["float16", "int32"], # Assuming these are the valid choices + help="Data type for accumulation.", +) +parser.add_argument( + "--out_dtype", + type=str, + default="float16", + choices=[ + "float16", + "float32", + "int32", + "int8", + ], # Assuming these are the valid choices + help="Data type for output.", +) +parser.add_argument( + "--layout", + type=str, + default="nt", + choices=["nt", "nn"], # Assuming these are the valid choices + help="Matrix layout, 'nt' for non-transpose A and transpose W.", +) +parser.add_argument("--with_bias", action="store_true", help="Include bias in the benchmark.") +parser.add_argument( + "--with_scaling", + action="store_true", + help="Include scaling factor in the quantization.", +) +parser.add_argument( + "--group_size", type=int, default=None, help="Group size for grouped quantization.") +parser.add_argument("--with_zeros", action="store_true", help="Include zeros in the quantization.") +parser.add_argument( + "--zeros_mode", + type=str, + default=None, + choices=[ + "original", + "rescale", + "quantized", + ], # Replace with actual modes if applicable + help="Specify the mode for calculating zeros.", ) # Parse the arguments @@ -50,18 +125,41 @@ # Assign arguments to variables target = args.target -backend = args.backend -verbose = args.verbose - -parsed_test_shapes = json.loads(args.test_shapes) -name_to_class = {"MatmulConfig": MatmulConfig, "Matmul": Matmul} - -test_shapes = [] -for item in parsed_test_shapes: - config_class_name, operator_class_name, input_args = item - config_class = name_to_class[config_class_name] - operator_class = name_to_class[operator_class_name] - test_shapes.append((config_class, operator_class, tuple(input_args))) +M, N, K = args.M, args.N, args.K +group_size = args.group_size +A_dtype = args.A_dtype +W_dtype = args.W_dtype +accum_dtype = args.accum_dtype +out_dtype = args.out_dtype +layout = args.layout +with_bias = args.with_bias +group_size = args.group_size +with_scaling = args.with_scaling +with_zeros = args.with_zeros +zeros_mode = args.zeros_mode + +test_shapes = [ + # square test + ( + MatmulConfig, + Matmul, + ( + M, + N, + K, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), +] benchmark_sets = [] benchmark_sets.extend(test_shapes) @@ -71,59 +169,39 @@ benchmark_results = {} for config, operator, input_args in benchmark_sets: config = config(*input_args) - print(f"Running benchmark for {operator.__name__} with config: {config}") - op_inst = operator(config, target=target, enable_tuning=True, backend=backend) - kernel_latency = op_inst.profile_latency() - if op_inst.input_transform is not None: - kernel_latency += op_inst.ladder_permutate_a.profile_latency() - - print("Time cost of {} is: {:.3f} ms".format(str(config), kernel_latency)) - - if verbose: - print(op_inst.scheduled_ir_module) - print(op_inst.get_source()) - - profile_config = { - f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { - "BitBLAS_top20_latency": kernel_latency, - } - } - - benchmark_results.update(profile_config) - -# Define headers for the table -headers = [ - "PrimFunc", - "Input Arguments", - "BitBLAS Top20 Latency", -] - -col_widths = [0, 0, 0] -for config, values in benchmark_results.items(): - args = config.split("-") - func_name = args[0] - input_args = 
"-".join(args[1:]) - col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), col_widths[0]) - col_widths[1] = max((max(len(str(headers[1])), len(input_args)) + 2, col_widths[1])) - col_widths[2] = max( - max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, - col_widths[2]) - break - -for i, header in enumerate(headers): - headers[i] = header.ljust(col_widths[i]) - -print("".join(headers)) - -print("-" * sum(col_widths)) - -for config, values in benchmark_results.items(): - args = config.split("-") - func_name = args[0] - input_args = "-".join(args[1:]) - row = [ - func_name, - input_args, - f"{values['BitBLAS_top20_latency']:.3f} ms", - ] - print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)])) + matmul = operator(config, target=target, enable_tuning=False) + func = matmul.prim_func + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except Exception: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + static_configs = [] + for config in configs: + static_config = config + static_config.shared_scope = "shared" + static_configs.append(static_config) + dynamic_configs = [] + for config in configs: + dynamic_config = config + dynamic_config.shared_scope = "shared.dyn" + dynamic_configs.append(dynamic_config) + + _, best_static = apply_and_build(func, static_configs, arch, parallel_build=True) + + _, best_dynamic = apply_and_build(func, dynamic_configs, arch, parallel_build=True) + benchmark_results[input_args] = ( + best_static.latency, + best_dynamic.latency, + best_static.latency - best_dynamic.latency, + ) + +for key, value in benchmark_results.items(): + print( + f"Input arguments: {key}, Static latency: {value[0]}, Dynamic latency: {value[1]}, Difference: {value[2]}" + ) diff --git a/benchmark/operators/benchmark_bitblas_op.py b/benchmark/operators/benchmark_bitblas_op.py new file mode 100644 index 000000000..9a40db4ad --- /dev/null +++ b/benchmark/operators/benchmark_bitblas_op.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import bitblas +from bitblas.utils.target_detector import auto_detect_nvidia_target +from bitblas import Matmul, MatmulConfig +import argparse +import json + +bitblas.set_log_level("DEBUG") +# Initialize the parser +parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.") + +# Add arguments to the parser +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), + help="Specify the target device for benchmarking.") + +parser.add_argument( + "--backend", + type=str, + default="tir", + choices=["tir", "tl"], # Replace with actual modes if applicable + help="Specify the mode for calculating zeros.") + +parser.add_argument("--verbose", type=bool, default=True, help="Enable verbose logging.") + +# [A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode] +default_test_shapes = json.dumps([ + # ["MatmulConfig", "Matmul", [1, 16384, 16384, "float16", "int4", "float16", "float16", "nt", False, None, False, False, None]] + [ + "MatmulConfig", "Matmul", + [ + 16384, 16384, 16384, "float16", "float16", "float16", "float16", "nt", False, None, + False, False, None + ] + ] +]) + +parser.add_argument( + "--test_shapes", + type=str, + default=default_test_shapes, + help="JSON string defining test shapes. Example format: '[[\"MatmulConfig\", \"Matmul\", [1,16384,16384,\"float16\",\"int4\",\"float16\",\"float16\",\"nt\",false,null,false,false,null]]]'" +) + +# Parse the arguments +args = parser.parse_args() + +# Assign arguments to variables +target = args.target +backend = args.backend +verbose = args.verbose + +parsed_test_shapes = json.loads(args.test_shapes) +name_to_class = {"MatmulConfig": MatmulConfig, "Matmul": Matmul} + +test_shapes = [] +for item in parsed_test_shapes: + config_class_name, operator_class_name, input_args = item + config_class = name_to_class[config_class_name] + operator_class = name_to_class[operator_class_name] + test_shapes.append((config_class, operator_class, tuple(input_args))) + +benchmark_sets = [] +benchmark_sets.extend(test_shapes) + +# fmt:on + +benchmark_results = {} +for config, operator, input_args in benchmark_sets: + config = config(*input_args) + print(f"Running benchmark for {operator.__name__} with config: {config}") + op_inst = operator(config, target=target, enable_tuning=True, backend=backend) + kernel_latency = op_inst.profile_latency() + if op_inst.input_transform is not None: + kernel_latency += op_inst.ladder_permutate_a.profile_latency() + + print("Time cost of {} is: {:.3f} ms".format(str(config), kernel_latency)) + + if verbose: + print(op_inst.scheduled_ir_module) + print(op_inst.get_source()) + + profile_config = { + f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { + "BitBLAS_top20_latency": kernel_latency, + } + } + + benchmark_results.update(profile_config) + +# Define headers for the table +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Latency", +] + +col_widths = [0, 0, 0] +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), col_widths[0]) + col_widths[1] = max((max(len(str(headers[1])), len(input_args)) + 2, col_widths[1])) + col_widths[2] = max( + max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, + col_widths[2]) + break + +for i, header in enumerate(headers): + headers[i] = header.ljust(col_widths[i]) + 
+print("".join(headers))
+
+print("-" * sum(col_widths))
+
+for config, values in benchmark_results.items():
+    args = config.split("-")
+    func_name = args[0]
+    input_args = "-".join(args[1:])
+    row = [
+        func_name,
+        input_args,
+        f"{values['BitBLAS_top20_latency']:.3f} ms",
+    ]
+    print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]))
diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index c560afd0e..723391777 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -376,8 +376,6 @@ def __init__(
         # by implementing all the operators in the tl backend.
         if config.A_dtype in ["int4", "uint4"]:
             backend = "tl"
-        if source_format in ["nf"]:
-            backend = "tir"

         super().__init__(name, config, target, backend)

diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/base.py b/bitblas/ops/general_matmul/tilelang/dequantize/base.py
index 3d27f0703..ce8582e4a 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/base.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/base.py
@@ -73,11 +73,22 @@ def __repr__(self) -> str:
         return f"{cls_name}({field_str})"

     def __post_init__(self):
-        # Validate the matrix transpose settings
-        assert (self.trans_A is False), "Currently only support Matrix A not transposed"
-        assert (self.trans_B is True), "Currently only support Matrix B transposed"
-
         # Legalize group_size
         if self.with_scaling and self.group_size == -1:
             object.__setattr__(self, "group_size", self.K)
+
+        # Legalization Check
+        # 1. Validate the matrix transpose settings
+        assert (self.trans_A is False), "Currently only supports non-transposed matrix A"
+        assert (self.trans_B is True), "Currently only supports transposed matrix B"
+        # 2. Validate the fast decoding settings
+        if self.fast_decoding:
+            # TODO(lei): Fast decoding could likely also be leveraged for the nf4 format.
+            assert self.source_format in {"uint", "int"}, (
+                "Fast decoding only supports uint/int source formats")
+        # 3. 
Validate the quant config + if self.source_format == "nf": + assert not self.with_scaling, "NF format does not support scaling" + assert not self.fast_decoding, "NF format does not support fast decoding" return diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py index 963a03b66..4f0f3b0c1 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py @@ -124,7 +124,7 @@ def apply_config( A_shape = (M, K) B_shape = (N, K // storage_nbit * num_bits) - LUT_shape = (group_size, K // storage_nbit * num_bits) + LUT_shape = (1 << num_bits,) Scale_shape = (N, K // group_size) Zeros_shape = (N, K // group_size) Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) @@ -211,6 +211,7 @@ def main( zeros_local, dequant_qzeros_local, B_dequantize_local, + LUT, Scale, Zeros, Qzeros, @@ -267,6 +268,7 @@ def _normal_dequant( zeros_local: T.Buffer, dequant_qzeros_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, @@ -286,6 +288,8 @@ def _normal_dequant( in_dtype = self.in_dtype group_size = self.group_size storage_dtype = self.storage_dtype + source_format = self.source_format + is_lut = source_format == "nf" storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) (local_scale_size,) = scale_local.shape @@ -296,90 +300,102 @@ def _normal_dequant( def _normal_dequant_impl( compressed_weight_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, ): - if with_scaling: - for v in T.vectorized(0, local_scale_size): - vi = ni - vj = kr * local_size + v - scale_local[v] = scale_buffer[ - pid_n * stride_n + vi, - (k * stride_k + vj) // group_size, - ] - - if with_scaling and with_zeros: - if zeros_mode in ["original", "rescale"]: - for v in T.vectorized(0, local_zeros_size): + if is_lut: + for v in T.serial(0, local_size): + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + "int32" # default index dtype + ) + dequant_weight_local[v] = lut_buffer[index] + else: + if with_scaling: + for v in T.vectorized(0, local_scale_size): vi = ni vj = kr * local_size + v - zeros_local[v] = zeros_buffer[ + scale_local[v] = scale_buffer[ pid_n * stride_n + vi, (k * stride_k + vj) // group_size, ] - elif zeros_mode == "quantized": - for v in T.vectorized(0, local_qzeros_size): - vi = ni - vj = kr * local_size + v - dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit)( - num_bits, - qzeros_buffer[ - (k * stride_k + vj) // group_size, - (pid_n * stride_n + vi) // num_elems_per_byte, - ], - (pid_n * stride_n + vi) % num_elems_per_byte, - dtype=storage_dtype, - ) - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") - for v in T.serial(0, local_size): - if not with_scaling: - dequant_weight_local[v] = self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) - elif not with_zeros: - dequant_weight_local[v] = ( - self._decode_func( + if with_scaling and with_zeros: + if zeros_mode in ["original", "rescale"]: + for v in 
T.vectorized(0, local_zeros_size): + vi = ni + vj = kr * local_size + v + zeros_local[v] = zeros_buffer[ + pid_n * stride_n + vi, + (k * stride_k + vj) // group_size, + ] + elif zeros_mode == "quantized": + for v in T.vectorized(0, local_qzeros_size): + vi = ni + vj = kr * local_size + v + dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( + storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + (k * stride_k + vj) // group_size, + (pid_n * stride_n + vi) // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + + for v in T.serial(0, local_size): + if not with_scaling: + dequant_weight_local[v] = self._decode_func( num_bits, compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) * scale_local[v // group_size]) - elif zeros_mode == "original": - dequant_weight_local[v] = (self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) - zeros_local[v // group_size]) * scale_local[v // group_size] - elif zeros_mode == "rescale": - dequant_weight_local[v] = ( - self._decode_func( + ) + elif not with_zeros: + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_local[v // group_size]) + elif zeros_mode == "original": + dequant_weight_local[v] = (self._decode_func( num_bits, compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) * scale_local[v // group_size] - zeros_local[v // group_size]) - elif zeros_mode == "quantized": - dequant_weight_local[v] = (self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - zero=dequant_qzeros_local[v // group_size], - dtype=in_dtype, - )) * scale_local[v // group_size] - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + ) - zeros_local[v // group_size]) * scale_local[v // group_size] + elif zeros_mode == "rescale": + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_local[v // group_size] - zeros_local[v // group_size]) + elif zeros_mode == "quantized": + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros_local[v // group_size], + dtype=in_dtype, + )) * scale_local[v // group_size] + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") return _normal_dequant_impl( compressed_weight_local, dequant_weight_local, + lut_buffer, scale_buffer, zeros_buffer, qzeros_buffer, @@ -512,6 +528,7 @@ def dequantize( zeros_local: T.Buffer, dequant_qzeros_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, @@ -551,6 +568,7 @@ def dequantize( zeros_local, dequant_qzeros_local, dequant_weight_local, + lut_buffer, scale_buffer, zeros_buffer, qzeros_buffer, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py index 4ca802608..0fda0b2ad 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py +++ 
b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py @@ -130,6 +130,7 @@ def _normal_dequant( zeros_local: T.Buffer, dequant_qzeros_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, @@ -150,6 +151,8 @@ def _normal_dequant( in_dtype = self.in_dtype group_size = self.group_size storage_dtype = self.storage_dtype + source_format = self.source_format + is_lut = source_format == "nf" storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) (local_scale_size,) = scale_local.shape @@ -160,98 +163,110 @@ def _normal_dequant( def _normal_dequant_impl( compressed_weight_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, ): - if with_scaling: - for v in T.vectorized(0, local_scale_size): - # TODO: Enhance all to index2coord - index = i * threads * local_size + tx * local_size + v - vi = index // stride_k - vj = index % stride_k - scale_local[v] = scale_buffer[ - pid_n * stride_n + vi, - (k * stride_k + vj) // group_size, - ] - - if with_scaling and with_zeros: - if zeros_mode in ["original", "rescale"]: - for v in T.vectorized(0, local_zeros_size): + if is_lut: + for v in T.serial(0, local_size): + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + "int32" # default index dtype + ) + dequant_weight_local[v] = lut_buffer[index] + else: + if with_scaling: + for v in T.vectorized(0, local_scale_size): + # TODO: Enhance all to index2coord index = i * threads * local_size + tx * local_size + v vi = index // stride_k vj = index % stride_k - zeros_local[v] = zeros_buffer[ + scale_local[v] = scale_buffer[ pid_n * stride_n + vi, (k * stride_k + vj) // group_size, ] - elif zeros_mode == "quantized": - for v in T.vectorized(0, local_qzeros_size): - index = i * threads * local_size + tx * local_size + v - vi = index // stride_k - vj = index % stride_k - dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit)( - num_bits, - qzeros_buffer[ - (k * stride_k + vj) // group_size, - (pid_n * stride_n + vi) // num_elems_per_byte, - ], - (pid_n * stride_n + vi) % num_elems_per_byte, - dtype=storage_dtype, - ) - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") - for v in T.serial(0, local_size): - index = i * threads * local_size + tx * local_size + v - vi = index // stride_k - vj = index % stride_k - if not with_scaling: - dequant_weight_local[v] = self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) - elif not with_zeros: - # Scaling only - dequant_weight_local[v] = ( - self._decode_func( + if with_scaling and with_zeros: + if zeros_mode in ["original", "rescale"]: + for v in T.vectorized(0, local_zeros_size): + index = i * threads * local_size + tx * local_size + v + vi = index // stride_k + vj = index % stride_k + zeros_local[v] = zeros_buffer[ + pid_n * stride_n + vi, + (k * stride_k + vj) // group_size, + ] + elif zeros_mode == "quantized": + for v in T.vectorized(0, local_qzeros_size): + index = i * threads * local_size + tx * local_size + v + vi = index // stride_k + vj = index % stride_k + dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( + storage_type, storage_nbit)( + 
num_bits, + qzeros_buffer[ + (k * stride_k + vj) // group_size, + (pid_n * stride_n + vi) // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + + for v in T.serial(0, local_size): + index = i * threads * local_size + tx * local_size + v + vi = index // stride_k + vj = index % stride_k + if not with_scaling: + dequant_weight_local[v] = self._decode_func( num_bits, compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) * scale_local[v // group_size]) - elif zeros_mode == "original": - dequant_weight_local[v] = (self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) - zeros_local[v // group_size]) * scale_local[v // group_size] - elif zeros_mode == "rescale": - dequant_weight_local[v] = ( - self._decode_func( + ) + elif not with_zeros: + # Scaling only + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_local[v // group_size]) + elif zeros_mode == "original": + dequant_weight_local[v] = (self._decode_func( num_bits, compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) * scale_local[v // group_size] - zeros_local[v // group_size]) - elif zeros_mode == "quantized": - dequant_weight_local[v] = (self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - zero=dequant_qzeros_local[v // group_size], - dtype=in_dtype, - )) * scale_local[v // group_size] - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + ) - zeros_local[v // group_size]) * scale_local[v // group_size] + elif zeros_mode == "rescale": + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_local[v // group_size] - zeros_local[v // group_size]) + elif zeros_mode == "quantized": + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros_local[v // group_size], + dtype=in_dtype, + )) * scale_local[v // group_size] + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") return _normal_dequant_impl( compressed_weight_local, dequant_weight_local, + lut_buffer, scale_buffer, zeros_buffer, qzeros_buffer, @@ -386,6 +401,7 @@ def dequantize( zeros_local: T.Buffer, dequant_qzeros_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, @@ -422,9 +438,9 @@ def dequantize( ) else: return self._normal_dequant(compressed_weight_local, scale_local, zeros_local, - dequant_qzeros_local, dequant_weight_local, scale_buffer, - zeros_buffer, qzeros_buffer, local_size, pid_n, tx, k, i, - stride_n, stride_k, threads) + dequant_qzeros_local, dequant_weight_local, lut_buffer, + scale_buffer, zeros_buffer, qzeros_buffer, local_size, + pid_n, tx, k, i, stride_n, stride_k, threads) @property def num_elems_per_byte(self): @@ -564,7 +580,7 @@ def apply_config( A_shape = (M, K) B_shape = (N, K // storage_nbit * num_bits) - LUT_shape = (group_size, K // storage_nbit * num_bits) + LUT_shape = (1 << num_bits,) Scale_shape = (N, K // group_size) Zeros_shape = (N, K // group_size) Qzeros_shape = ((K // 
group_size), N // storage_nbit * num_bits) @@ -679,6 +695,7 @@ def general_shared_dequant_matmul( zeros_local, dequant_qzeros_local, B_dequantize_local, + LUT, Scale, Zeros, Qzeros, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py index e0da07aa7..847eb49fd 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py @@ -132,6 +132,7 @@ def _normal_dequant( zeros_local: T.Buffer, dequant_qzeros_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, @@ -152,6 +153,8 @@ def _normal_dequant( in_dtype = self.in_dtype group_size = self.group_size storage_dtype = self.storage_dtype + source_format = self.source_format + is_lut = source_format == "nf" storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) (local_scale_size,) = scale_local.shape @@ -162,98 +165,110 @@ def _normal_dequant( def _normal_dequant_impl( compressed_weight_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, ): - if with_scaling: - for v in T.vectorized(0, local_scale_size): - # TODO: Enhance all to index2coord - index = i * threads * local_size + tx * local_size + v - vi = index // stride_k - vj = index % stride_k - scale_local[v] = scale_buffer[ - pid_n * stride_n + vi, - (k * stride_k + vj) // group_size, - ] - - if with_scaling and with_zeros: - if zeros_mode in ["original", "rescale"]: - for v in T.vectorized(0, local_zeros_size): + if is_lut: + for v in T.serial(0, local_size): + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + "int32" # default index dtype + ) + dequant_weight_local[v] = lut_buffer[index] + else: + if with_scaling: + for v in T.vectorized(0, local_scale_size): + # TODO: Enhance all to index2coord index = i * threads * local_size + tx * local_size + v vi = index // stride_k vj = index % stride_k - zeros_local[v] = zeros_buffer[ + scale_local[v] = scale_buffer[ pid_n * stride_n + vi, (k * stride_k + vj) // group_size, ] - elif zeros_mode == "quantized": - for v in T.vectorized(0, local_qzeros_size): - index = i * threads * local_size + tx * local_size + v - vi = index // stride_k - vj = index % stride_k - dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit)( - num_bits, - qzeros_buffer[ - (k * stride_k + vj) // group_size, - (pid_n * stride_n + vi) // num_elems_per_byte, - ], - (pid_n * stride_n + vi) % num_elems_per_byte, - dtype=storage_dtype, - ) - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") - for v in T.serial(0, local_size): - index = i * threads * local_size + tx * local_size + v - vi = index // stride_k - vj = index % stride_k - if not with_scaling: - dequant_weight_local[v] = self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) - elif not with_zeros: - # Scaling only - dequant_weight_local[v] = ( - self._decode_func( + if with_scaling and with_zeros: + if zeros_mode in ["original", "rescale"]: + for v in T.vectorized(0, local_zeros_size): + index = 
i * threads * local_size + tx * local_size + v + vi = index // stride_k + vj = index % stride_k + zeros_local[v] = zeros_buffer[ + pid_n * stride_n + vi, + (k * stride_k + vj) // group_size, + ] + elif zeros_mode == "quantized": + for v in T.vectorized(0, local_qzeros_size): + index = i * threads * local_size + tx * local_size + v + vi = index // stride_k + vj = index % stride_k + dequant_qzeros_local[v] = _tir_packed_to_unsigned_convert( + storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + (k * stride_k + vj) // group_size, + (pid_n * stride_n + vi) // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + + for v in T.serial(0, local_size): + index = i * threads * local_size + tx * local_size + v + vi = index // stride_k + vj = index % stride_k + if not with_scaling: + dequant_weight_local[v] = self._decode_func( num_bits, compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) * scale_local[v // group_size]) - elif zeros_mode == "original": - dequant_weight_local[v] = (self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) - zeros_local[v // group_size]) * scale_local[v // group_size] - elif zeros_mode == "rescale": - dequant_weight_local[v] = ( - self._decode_func( + ) + elif not with_zeros: + # Scaling only + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_local[v // group_size]) + elif zeros_mode == "original": + dequant_weight_local[v] = (self._decode_func( num_bits, compressed_weight_local[v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) * scale_local[v // group_size] - zeros_local[v // group_size]) - elif zeros_mode == "quantized": - dequant_weight_local[v] = (self._decode_func( - num_bits, - compressed_weight_local[v // num_elems_per_byte], - v % num_elems_per_byte, - zero=dequant_qzeros_local[v // group_size], - dtype=in_dtype, - )) * scale_local[v // group_size] - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + ) - zeros_local[v // group_size]) * scale_local[v // group_size] + elif zeros_mode == "rescale": + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_local[v // group_size] - zeros_local[v // group_size]) + elif zeros_mode == "quantized": + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros_local[v // group_size], + dtype=in_dtype, + )) * scale_local[v // group_size] + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") return _normal_dequant_impl( compressed_weight_local, dequant_weight_local, + lut_buffer, scale_buffer, zeros_buffer, qzeros_buffer, @@ -388,6 +403,7 @@ def dequantize( zeros_local: T.Buffer, dequant_qzeros_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, @@ -424,9 +440,9 @@ def dequantize( ) else: return self._normal_dequant(compressed_weight_local, scale_local, zeros_local, - dequant_qzeros_local, dequant_weight_local, scale_buffer, - zeros_buffer, qzeros_buffer, local_size, pid_n, tx, k, i, - stride_n, stride_k, threads) + 
dequant_qzeros_local, dequant_weight_local, lut_buffer, + scale_buffer, zeros_buffer, qzeros_buffer, local_size, + pid_n, tx, k, i, stride_n, stride_k, threads) @property def num_elems_per_byte(self): @@ -576,7 +592,7 @@ def apply_config( A_shape = (M, K) B_shape = (N, K // storage_nbit * num_bits) - LUT_shape = (group_size, K // storage_nbit * num_bits) + LUT_shape = (1 << num_bits,) Scale_shape = (N, K // group_size) Zeros_shape = (N, K // group_size) Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) @@ -667,6 +683,7 @@ def general_shared_dequant_matmul( zeros_local, dequant_qzeros_local, B_dequantize_local, + LUT, Scale, Zeros, Qzeros, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py index 61b539ee8..98cbc297e 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py @@ -370,6 +370,12 @@ def general_dequant_matmul( T.clear(C_frag) + if enable_split_k: # noqa: SIM102 + if bz == 0: + for i, j in T.Parallel(block_M, block_N): + m, n = by * block_M + i, bx * block_N + j + C[m, n] = T.cast(0, out_dtype) + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=num_stages): T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) @@ -394,6 +400,7 @@ def general_dequant_matmul( zeros_local, dequant_qzeros_local, B_dequantize_local, + LUT, Scale, Zeros, Qzeros, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py index dcf235d18..56d052eda 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py @@ -113,7 +113,7 @@ def apply_config( micro_size_y, micro_size_k // num_elems_per_byte, ) - LUT_shape = (group_size, K // num_elems_per_byte) + LUT_shape = (1 << num_bits,) Scale_shape = (N, K // group_size) Zeros_shape = (N, K // group_size) Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) @@ -148,7 +148,7 @@ def apply_config( micro_size_y, ) - shared_scope = "shared" + shared_scope = "shared.dyn" import_source: Optional[str] = None func_name: str = "" @@ -200,10 +200,6 @@ def check_require_cache(): cache_write_required = check_require_cache() - vec_load_qb = 16 - if block_N * block_K // num_elems_per_byte // threads < vec_load_qb: - vec_load_qb = block_N * block_K // num_elems_per_byte // threads - @T.prim_func def general_dequant_matmul( A: T.Buffer(A_shape, in_dtype), @@ -241,6 +237,12 @@ def general_dequant_matmul( T.clear(C_frag) + if enable_split_k: # noqa: SIM102 + if bz == 0: + for i, j in T.Parallel(block_M, block_N): + m, n = by * block_M + i, bx * block_N + j + C[m, n] = T.cast(0, out_dtype) + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=num_stages): if is_a_smooth: @@ -311,6 +313,7 @@ def general_dequant_matmul( self._normal_dequant( B_frag, B_dequantize_frag, + LUT, Scale, Zeros, Qzeros, @@ -393,6 +396,7 @@ def _normal_dequant( self, compressed_weight_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, @@ -414,6 +418,8 @@ def _normal_dequant( in_dtype = 
self.in_dtype group_size = self.group_size storage_dtype = self.storage_dtype + source_format = self.source_format + is_lut = source_format == "nf" storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) micro_size_k = mma_emitter.micro_size_k @@ -423,6 +429,7 @@ def _normal_dequant( def _normal_dequant_impl( compressed_weight_local: T.Buffer, dequant_weight_local: T.Buffer, + lut_buffer: T.Buffer, scale_buffer: T.Buffer, zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, @@ -444,65 +451,77 @@ def _normal_dequant_impl( matrix_name="B", group_size=group_size, ) - if not with_scaling: - dequant_weight_local[j * local_size + v] = self._decode_func( + if is_lut: + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( num_bits, compressed_weight_local[j * local_size // num_elems_per_byte + v // num_elems_per_byte], v % num_elems_per_byte, - dtype=in_dtype, + "int32" # default index dtype ) - elif not with_zeros: - dequant_weight_local[j * local_size + v] = ( - self._decode_func( + dequant_weight_local[j * local_size + v] = lut_buffer[index] + else: + if not with_scaling: + dequant_weight_local[j * local_size + v] = self._decode_func( num_bits, compressed_weight_local[j * local_size // num_elems_per_byte + v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) * scale_buffer[remaped_i, remaped_j]) - elif zeros_mode == "original": - dequant_weight_local[j * local_size + v] = (self._decode_func( - num_bits, - compressed_weight_local[j * local_size // num_elems_per_byte + - v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, - ) - zeros_buffer[remaped_i, remaped_j]) * scale_buffer[remaped_i, remaped_j] - elif zeros_mode == "rescale": - dequant_weight_local[j * local_size + v] = ( - self._decode_func( + ) + elif not with_zeros: + dequant_weight_local[j * local_size + v] = ( + self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[remaped_i, remaped_j]) + elif zeros_mode == "original": + dequant_weight_local[j * local_size + v] = (self._decode_func( num_bits, compressed_weight_local[j * local_size // num_elems_per_byte + v // num_elems_per_byte], v % num_elems_per_byte, dtype=in_dtype, - ) * scale_buffer[remaped_i, remaped_j] - - zeros_buffer[remaped_i, remaped_j]) - elif zeros_mode == "quantized": - dequant_qzeros = _tir_packed_to_unsigned_convert( - storage_type, storage_nbit)( + ) - zeros_buffer[remaped_i, remaped_j]) * scale_buffer[remaped_i, + remaped_j] + elif zeros_mode == "rescale": + dequant_weight_local[j * local_size + v] = ( + self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[remaped_i, remaped_j] - + zeros_buffer[remaped_i, remaped_j]) + elif zeros_mode == "quantized": + dequant_qzeros = _tir_packed_to_unsigned_convert( + storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + remaped_i, + remaped_j // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) + + dequant_weight_local[j * local_size + v] = (self._decode_func( num_bits, - qzeros_buffer[ - remaped_i, - remaped_j // num_elems_per_byte, - ], - (pid_n * stride_n + vi) % num_elems_per_byte, - dtype=storage_dtype, - ) - - dequant_weight_local[j * local_size + v] = (self._decode_func( - 
num_bits, - compressed_weight_local[j * local_size // num_elems_per_byte + - v // num_elems_per_byte], - v % num_elems_per_byte, - zero=dequant_qzeros, - dtype=in_dtype, - )) * scale_buffer[remaped_i, remaped_j] + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros, + dtype=in_dtype, + )) * scale_buffer[remaped_i, remaped_j] return _normal_dequant_impl( compressed_weight_local, dequant_weight_local, + lut_buffer, scale_buffer, zeros_buffer, qzeros_buffer, diff --git a/bitblas/tl/mma_layout.py b/bitblas/tl/mma_layout.py index 443da90eb..aad3ff955 100644 --- a/bitblas/tl/mma_layout.py +++ b/bitblas/tl/mma_layout.py @@ -120,12 +120,13 @@ def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False): dtype = shared_buf.dtype shape = shared_buf.shape - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if is_smooth or not can_swizzle: + can_swizzle = shape[-1] * DataType(dtype).bits % 512 == 0 + if is_smooth or (not can_swizzle): return T.Layout(shape, lambda *args: args) - def transform_func(i, j): + def transform_func(*args): + i, j = args[-2:] new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] + return [*args[:-2], new_warp_i, new_warp_j] return T.Layout(shape, transform_func) diff --git a/testing/python/operators/test_general_matmul_ops_nf4.py b/testing/python/operators/test_general_matmul_ops_nf4.py new file mode 100644 index 000000000..4b3e2b9d7 --- /dev/null +++ b/testing/python/operators/test_general_matmul_ops_nf4.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +import bitblas.testing +from bitblas import MatmulConfig, Matmul +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) + + +def get_codegen_result(ops): + code = ops.get_source() + return code + + +def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype): + import torch + torch.random.manual_seed(0) + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False) + print(matmul.scheduled_ir_module) + + input_shape = (M, K) + weight_shape = (N, K) + output_shape = (M, N) + input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5 + + _, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype] + maxq = 2**(bit - 1) + weight_tensor = torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda() + output_tensor = torch.empty(output_shape, dtype=torch.float16).cuda() + + intweight = weight_tensor + lut = matmul.lut + assert lut is not None + ref_weight = torch.zeros_like(intweight, dtype=torch.float16) + for j in range(intweight.shape[0]): + for k in range(intweight.shape[1]): + ref_weight[j, k] = lut[intweight[j, k]] + + intweight = intweight.cpu().to(torch.int8) + ref_result = torch.matmul(input_tensor, ref_weight.t().to(torch.float16)) + permuted_inputs = [] + permuted_inputs.append(input_tensor) + permuted_inputs.append(matmul.weight_transform(intweight.cpu()).cuda()) + + permuted_inputs.append(output_tensor) + matmul(*permuted_inputs[:-1], output=permuted_inputs[-1]) + print(permuted_inputs[-1]) + print(ref_result) + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0) + + +def test_matmul_torch_forward(): + matmul_torch_forward(1, 1024, 1024, "float16", "nf4", "float16", "float16") + 
matmul_torch_forward(768, 768, 768, "float16", "nf4", "float16", "float16") + + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py deleted file mode 100644 index 077d2ec48..000000000 --- a/testing/python/operators/test_general_matmul_splitk_ops.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import bitblas -import bitblas.testing -from bitblas.ops.general_matmul_splitk import MatmulWithSplitK, MatmulConfigWithSplitK - - -def get_codegen_result(ops): - code = ops.get_source() - return code - - -# fmt: off -def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, - group_size, with_scaling, with_zeros, zeros_mode): - - matmul_config = MatmulConfigWithSplitK( - M=M, - N=N, - K=K, - A_dtype=A_dtype, - W_dtype=W_dtype, - accum_dtype=accum_dtype, - out_dtype=out_dtype, - layout=layout, - with_bias=with_bias, - group_size=group_size, - with_scaling=with_scaling, - with_zeros=with_zeros, - zeros_mode=zeros_mode, - propagate_a=False, - propagate_b=False, - ) - matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) - assert get_codegen_result(matmul) - - -def test_matmul_codegen_default(): - matmul_codegen_default(1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, - -1, False, False, None) - matmul_codegen_default(16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, - -1, False, False, None) - - -def matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, - layout, with_bias, group_size, with_scaling, with_zeros, - zeros_mode): - import torch - torch.random.manual_seed(0) - matmul_config = MatmulConfigWithSplitK( - k_split=SplitK, - M=M, - N=N, - K=K, - A_dtype=A_dtype, - W_dtype=W_dtype, - accum_dtype=accum_dtype, - out_dtype=out_dtype, - layout=layout, - with_bias=with_bias, - group_size=group_size, - with_scaling=with_scaling, - with_zeros=with_zeros, - zeros_mode=zeros_mode, - propagate_a=False, - propagate_b=False, - ) - matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) - - input_shape = (M, K) - weight_shape = (N, K) if layout == "nt" else (K, N) - inputs = [] - inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) - inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5) - - output_bitblas = matmul.forward(*inputs) - output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) - bitblas.testing.torch_assert_close(output_torch, output_bitblas, rtol=1e-2, atol=1e-1, max_mismatched_ratio=1e-2) - - -def test_matmul_torch_forward_consistent(): - matmul_torch_forward_consistent(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", - "nt", False, -1, False, False, None) - matmul_torch_forward_consistent(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", - "nt", False, -1, False, False, None) - - -def matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, - with_bias, group_size, with_scaling, with_zeros, zeros_mode): - import torch - torch.random.manual_seed(0) - matmul_config = MatmulConfigWithSplitK( - k_split=SplitK, - M=[1, 16], - N=N, - K=K, - A_dtype=A_dtype, - W_dtype=W_dtype, - accum_dtype=accum_dtype, - out_dtype=out_dtype, - layout=layout, - with_bias=with_bias, - group_size=group_size, - 
with_scaling=with_scaling, - with_zeros=with_zeros, - zeros_mode=zeros_mode, - propagate_a=False, - propagate_b=False, - ) - matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) - - input_shape = (M, K) - weight_shape = (N, K) if layout == "nt" else (K, N) - - def map_torch_type(intype): - - typemap = { - 'e4m3_float8': torch.float8_e4m3fn, - 'e5m2_float8': torch.float8_e5m2, - } - if intype in typemap: - return typemap[intype] - else: - return getattr(torch, intype) - - numpytype_a = map_torch_type(A_dtype) - numpytype_b = map_torch_type(W_dtype) - - torch_a = torch.rand(M * K).uniform_(-1, 1).reshape(input_shape).type(numpytype_a).cuda() - torch_b = torch.rand(N * K).uniform_(-1, 1).reshape(weight_shape).type(numpytype_b).cuda() - ref_out = torch.matmul(torch_a.to(torch.float32), - torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul( - torch_a.to(torch.float32), torch_b.to(torch.float32)) - ref_out = ref_out.to(torch.float16) - bitblas_out = torch.empty_like(ref_out) - matmul.forward(torch_a, torch_b) - print("torch_ref_out", ref_out) - print("bitblas_out", bitblas_out) - - matmul.forward(torch_a, torch_b, output=bitblas_out) - print("torch_ref_out", ref_out) - print("bitblas_out", bitblas_out) - - matmul.forward(torch_a, torch_b, output=bitblas_out) - print("torch_ref_out", ref_out) - print("bitblas_out", bitblas_out) - - torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1) - - -@bitblas.testing.requires_cuda_compute_version(8, 9) -def test_matmul_torch_forward_fp8e4m3(): - matmul_torch_forward_fp8e4m3(1, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", - "float16", "nt", False, -1, False, False, None) - matmul_torch_forward_fp8e4m3(4, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", - "float16", "nt", False, -1, False, False, None) - - -# fmt: on -if __name__ == "__main__": - bitblas.testing.main() diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py index 6fd669789..42e449056 100644 --- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py +++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py @@ -413,4 +413,5 @@ def test_assert_tl_matmul_weight_only_transform(): if __name__ == "__main__": - bitblas.testing.main() + # bitblas.testing.main() + test_assert_tl_matmul_weight_only_transform() diff --git a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py index d343d8078..660aaad89 100644 --- a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py @@ -1116,4 +1116,5 @@ def test_assert_tl_matmul_with_ladder_input_weight_transform(): if __name__ == "__main__": - bitblas.testing.main() + # bitblas.testing.main() + test_assert_tl_matmul_with_ladder_input_weight_transform()