diff --git a/3rdparty/tvm b/3rdparty/tvm index c441882e2..6daecacc7 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c441882e2372deeb33d0eaefd62a133d482ac669 +Subproject commit 6daecacc73c8c8fdea1b9732891e1d4a5ebbf818 diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 91e88133c..a1bc95f39 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -39,9 +39,10 @@ from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401 -from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401 from .module import Linear # noqa: F401 +import warnings +import functools import logging from tqdm import tqdm @@ -89,4 +90,26 @@ def _init_logger(): _init_logger() + +def deprecated(reason): + """ + This is a decorator which can be used to mark functions as deprecated. + It will result in a warning being emitted when the function is used. + """ + + def decorator(func): + + @functools.wraps(func) + def new_func(*args, **kwargs): + warnings.warn( + f"Call to deprecated function {func.__name__} ({reason}).", + category=DeprecationWarning, + stacklevel=2) + return func(*args, **kwargs) + + return new_func + + return decorator + + __version__ = "0.0.1.dev13" diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index 9e6fff9ee..468498fbd 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -117,83 +117,92 @@ def _check_small_tile(td: TileDict): return True return False - if not _check_small_tile(td): - return None + if _check_small_tile(td): + + smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) + rstep_map = td.rstep_map.copy() + + def _optimize(node, rstep): + all_steps = self.get_node_reduce_step_candidates(node) + # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k] + for k in all_steps: + all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) + if any([v == [] for v in all_steps.values()]): + return rstep + + def _shared_memory_usage(td: TileDict): + return node.footprint(td.output_tile, new_rstep_map, + td.tensor_strides_map[node]) + + def _score(rstep_id): + rstep = { + k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis + } + score = 0 + shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) + input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) + for i, input_buffer in enumerate(input_buffers): + score += coalesced_factor(shape[i], input_buffer.shape) + return score + + def _enlarge(rstep_id): + candidates = [] + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + if len(candidates) == 0: + return None + return max(candidates, key=lambda x: x[1])[0] + + cur_rstep_id = { + k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis + } + new_rstep_map = rstep_map.copy() + while True: + new_rstep_id = _enlarge(cur_rstep_id) + if new_rstep_id is None: + break + new_rstep_map = { + k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] + for k in node.raxis + } + old_rstep_map = td.rstep_map + td.rstep_map = new_rstep_map + smem_usage, _ = _shared_memory_usage(td) + td.rstep_map = old_rstep_map + if smem_usage > smem_limit: + break + else: + cur_rstep_id = new_rstep_id + rstep = { + k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis + } + return rstep - smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) - rstep_map = td.rstep_map.copy() + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _optimize(node, rstep_map) + rstep_map = rstep - def _optimize(node, rstep): - all_steps = self.get_node_reduce_step_candidates(node) - # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k] - for k in all_steps: - all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) - if any([v == [] for v in all_steps.values()]): - return rstep + td.rstep_map = rstep_map + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) - def _shared_memory_usage(td: TileDict): - return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node]) + if self.block_reduction_depth is not None: - def _score(rstep_id): - rstep = { - k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis - } - score = 0 - shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) - input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) - for i, input_buffer in enumerate(input_buffers): - score += coalesced_factor(shape[i], input_buffer.shape) - return score - - def _enlarge(rstep_id): - candidates = [] - for ax in rstep_id: - if rstep_id[ax] + 1 == len(all_steps[ax]): - continue - r = rstep_id.copy() - r[ax] += 1 - candidates.append((r, _score(r))) - if len(candidates) == 0: - return None - return max(candidates, key=lambda x: x[1])[0] - - cur_rstep_id = { - k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis - } - new_rstep_map = rstep_map.copy() - while True: - new_rstep_id = _enlarge(cur_rstep_id) - if new_rstep_id is None: - break - new_rstep_map = { - k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis - } - old_rstep_map = td.rstep_map - td.rstep_map = new_rstep_map - smem_usage, _ = _shared_memory_usage(td) - td.rstep_map = old_rstep_map - if smem_usage > smem_limit: - break - else: - cur_rstep_id = new_rstep_id - rstep = { - k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis - } - return rstep + def _expand_with_tags(rstep): + new_rstep = {k: v * self.block_reduction_depth for k, v in rstep.items()} + return new_rstep + + rstep_map = td.rstep_map.copy() + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _expand_with_tags(rstep_map) + rstep_map = rstep + td.rstep_map = rstep_map - for node in self.ordered_nodes: - if len(node.raxis) > 0: - rstep = _optimize(node, rstep_map) - rstep_map = rstep - - # if is_block_reduction: - # # If block reduction, we should constrain the max value is 64 - # # Otherwise it will introduce an issue of cuda invalid args. - # MAX_REDUCE_K = 64 - # for k in rstep_map: - # rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K) - td.rstep_map = rstep_map - td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) return def get_node_reduce_step_candidates(self, node): @@ -318,12 +327,15 @@ def _score(node, thread): # small is better # smem capacity # TODO: This is a dummy mul which avoid reusing some shared memory. # Should be removed in the future. - if td.smem_cost > (self.arch.smem_cap * 1.3): + if td.smem_cost > (self.arch.smem_cap): info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \ " use dynamic shared memory." logger.info(info_message) codegen_dict.shared_scope = "shared.dyn" + # Or assume we always use shared memory + # codegen_dict.shared_scope = "shared.dyn" + codegen_dict.complete_config(node) codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size) codegen_dict.arch = self.arch diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 210c560a1..1d0889fa3 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -622,14 +622,16 @@ def check_last_trait(region: List[Range]): # Analysis Block Reduction Optimization # Currently, we only support block reduction depth 2 for small M # When the func is a dequantize like ops, we should consider the M + require_block_reduce = False if hasattr(func.attrs, "dequantize_info"): for arg in func.params: inp_shape = func.buffer_map[arg].shape M = inp_shape[0] if isinstance(M, tir.IntImm) and M <= 128: - tags["block_reduction_depth"] = 2 + require_block_reduce = True break - + if require_block_reduce and check_sm_version(target.arch) == 80: + tags["block_reduction_depth"] = 2 return tags (main_block,) = reduction_blocks diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 4575fa363..2033b8f75 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -2264,10 +2264,11 @@ def get_idx(): lop3_intrin_info["compute"], ) # Assume the grouped K is the last dim of the scaling - grouped_k = sch.get(bf).reads[1].buffer.shape[-1] - # TODO(lei): This is a hack to get the loop extent - loop_extent = 8 if out_dtype == "float16" else 16 - sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k) + if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: + grouped_k = sch.get(bf).reads[1].buffer.shape[-1] + # TODO(lei): This is a hack to get the loop extent + loop_extent = 8 if out_dtype == "float16" else 16 + sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k) import_source.append(lop3_intrin_info["c_source"]) def tensorize_init_store_compute(): diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py index a3fbba213..f148097d6 100644 --- a/bitblas/module/__init__.py +++ b/bitblas/module/__init__.py @@ -40,6 +40,24 @@ def unpack_qzeros(qzeros, bits): return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1) +def unpack_qweight(qweight, bits): + qweight = qweight.view(torch.int8) + elems_per_int8 = 8 // bits + unpacked_weight = torch.zeros( + (qweight.shape[0], qweight.shape[1] * elems_per_int8), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + for col in range(unpacked_weight.shape[1]): + i = col % elems_per_int8 + unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> (bits * i)) + + # Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303 + # NOTE: It appears that casting after the `unpacked_zeros + 1` is important. + return torch.bitwise_and(unpacked_weight, 2**bits - 1) + + class Linear(nn.Module): opt_M = [1, 16, 32, 64, 128, 256, 512] STORAGE_DTYPE = "int8" # assume int8 storage @@ -279,8 +297,9 @@ def load_and_transform_weight( def repack_from_gptq(self, gptq_module): # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed. qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE) + intweight = unpack_qweight(qweight, self.bits).contiguous() if self.bitblas_matmul.weight_transform is not None: - qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() + qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda() self.qweight = qweight # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed. scales = gptq_module.scales.T.contiguous().view(self.torch_dtype) diff --git a/bitblas/ops/__init__.py b/bitblas/ops/__init__.py index 4fa456477..a132a83b2 100644 --- a/bitblas/ops/__init__.py +++ b/bitblas/ops/__init__.py @@ -1,8 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from .operator import Operator, OperatorConfig # noqa: F401 -from .matmul import Matmul, MatmulConfig # noqa: F401 -from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401 +from .general_matmul import Matmul, MatmulConfig # noqa: F401 from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401 from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig # noqa: F401 from .quant_compress import QuantCompress, QuantCompressConfig # noqa: F401 diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 184da0b0a..7d99d9628 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -13,6 +13,7 @@ from bitblas.utils.target_detector import auto_detect_nvidia_target from dataclasses import dataclass from ..ladder_permutate import LadderPermutate, LadderPermutateConfig +from ..quant_compress import QuantCompress, QuantCompressConfig from ..lop3_permutate import LOP3Permutate, LOP3PermutateConfig import logging import torch @@ -292,6 +293,7 @@ def dispatch_tir(self, # create permutate_opertors self.ladder_permutate_a = self._assign_ladder_permutate_a(target, enable_tuning) self.ladder_permutate_b = self._assign_ladder_permutate_b(target, enable_tuning) + self.weight_compress = self._assign_weight_compress(target, enable_tuning) self.lop3_permutate = self._assign_lop3_permutate(target, enable_tuning) # create cpu weight executors self.input_executors = self._create_input_executors() @@ -338,11 +340,14 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): del enable_tuning if self.propagate_b: + # weight transform should be done in the unpacked level + # otherwise the bit trick should be applied and that is + # too complex to be implemented in the ladder permutation. ladder_permutate_config = LadderPermutateConfig( M=self.N, N=self.K, datatype=self.A_dtype, - dequantize_bits=self.bit, + dequantize_bits=-1, storage_dtype=self.storage_dtype, propagate_kind="B", transpose_matrix=self.layout == "nt", @@ -354,6 +359,25 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): ) return None + def _assign_weight_compress(self, target: Target, enable_tuning: bool): + # unused variables + del target + del enable_tuning + + require_compress: bool = self.bit in [1, 2, 4] + if require_compress: + quant_compress_config = QuantCompressConfig( + M=self.N, + N=self.K, + input_dtype=self.storage_dtype, + storage_dtype=self.storage_dtype, + dequantize_bits=self.bit) + return QuantCompress( + config=quant_compress_config, + target=tvm.target.Target("llvm"), + ) + return None + def _assign_lop3_permutate(self, target: Target, enable_tuning: bool): # unused variables del target @@ -381,10 +405,12 @@ def _create_input_executors(self): def _create_weight_executors(self): weight_executors = OPExecutorCPU() - if self.fast_decoding: - weight_executors.append(self.lop3_permutate) if self.propagate_b is not TransformKind.NonTransform: weight_executors.append(self.ladder_permutate_b) + if self.weight_compress is not None: + weight_executors.append(self.weight_compress) + if self.fast_decoding: + weight_executors.append(self.lop3_permutate) return weight_executors def _select_implementation(self): @@ -452,10 +478,6 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): return self.weight_transform(weight.cpu()).cuda().contiguous() return weight - from bitblas.quantization import general_compress - import torch - import numpy as np - source_format, bit = self.source_format, self.bit # Process integer source format @@ -464,20 +486,13 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): assert not self.with_zeros, "zeros should be False for int source format" maxq = 2**(bit - 1) # Clamp weight values to be within the quantizable range and adjust - weight = torch.clamp(weight, -maxq, maxq).int() + maxq + weight = torch.clamp(weight, -maxq, maxq).char() + maxq elif source_format in ["fp_e5m2", "fp_e4m3"]: weight = weight.view(torch.int8) - weight = weight.int() else: # For non-integer formats, simply convert weights to integers - weight = weight.int() - - np_storage_dtype = getattr(np, self.storage_dtype) - - weight = general_compress( - weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) - - weight = torch.from_numpy(weight).cuda().contiguous() + # And assume weight is in the range of [-128, 127] for int8 + weight = weight.char() # Apply an optional weight transformation if specified if self.weight_transform is not None: diff --git a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py index 17d22dcfe..a86f6469a 100644 --- a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py +++ b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py @@ -515,7 +515,7 @@ def matmul_nt_dequantize_b_propagate_b( fast_decoding=False, with_bias=False, zeros_mode="original", - transform_kind: Union[int, TransformKind] = TransformKind.NonTransform, + transform_kind: Union[int, TransformKind] = TransformKind.IntraWarpTransform, ): if isinstance(transform_kind, int): transform_kind = TransformKind(transform_kind) @@ -699,8 +699,8 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b( fast_decoding=False, with_bias=False, zeros_mode="original", - transform_kind_input: Union[int, TransformKind] = TransformKind.NonTransform, - transform_kind_weight: Union[int, TransformKind] = TransformKind.NonTransform, + transform_kind_input: Union[int, TransformKind] = TransformKind.IntraWarpTransform, + transform_kind_weight: Union[int, TransformKind] = TransformKind.IntraWarpTransform, ): if isinstance(transform_kind_input, int): transform_kind_input = TransformKind(transform_kind_input) diff --git a/bitblas/ops/ladder_permutate/__init__.py b/bitblas/ops/ladder_permutate/__init__.py index 6644705cd..d09ee6dac 100644 --- a/bitblas/ops/ladder_permutate/__init__.py +++ b/bitblas/ops/ladder_permutate/__init__.py @@ -5,6 +5,7 @@ from ..operator import Operator from .ladder_permutate_impl import select_implementation from dataclasses import dataclass +import torch @dataclass(frozen=True) @@ -57,6 +58,23 @@ def _select_implementation(self): target_instruction=self.target_instruction, ) + def forward(self, inp, out=None): + if out is None: + out_shape, out_dtype = self.retrieve_output_shape() + out = torch.zeros(out_shape, dtype=out_dtype).to(inp.device) + self.torch_func(inp, out) + return out + + def retrieve_output_shape(self): + """ + Retrieve the output shape of the operator + """ + func = self.prim_func + param = func.params[-1] + assert param in func.buffer_map, f"param {param} not in buffer_map" + arg = func.buffer_map[param] + return [int(i) for i in arg.shape], getattr(torch, arg.dtype) + @property def M(self): return self.config.M diff --git a/bitblas/ops/lop3_permutate/__init__.py b/bitblas/ops/lop3_permutate/__init__.py index 10c452b3d..19c4b0eea 100644 --- a/bitblas/ops/lop3_permutate/__init__.py +++ b/bitblas/ops/lop3_permutate/__init__.py @@ -42,11 +42,23 @@ def _select_implementation(self): dequantize_bits=self.dequantize_bits, ) - def forward(self, weight, res): + def forward(self, inp, out=None): + out_shape = inp.shape + out_dtype = inp.dtype + if out is None: + # lop3 transform does not change the shape of the input tensor + out = torch.zeros(out_shape, dtype=out_dtype) # reinterpret the input tensor to int32 format - args = [arg.view(torch.int32) for arg in [weight, res]] + shape_2dim = self.retrieve_2d_weight_shape() + args = [arg.view(inp.dtype).view(shape_2dim).view(torch.int32) for arg in [inp, out]] self.torch_func(*args) - return args[-1].view(weight.dtype) + return args[-1].view(out_dtype).view(out_shape) + + def retrieve_2d_weight_shape(self): + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + elems_per_byte = storage_nbit // self.dequantize_bits + weight_shape = (self.M, self.N // elems_per_byte) + return weight_shape @property def M(self): diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py deleted file mode 100644 index e515a264c..000000000 --- a/bitblas/ops/matmul.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -import numpy as np -from tvm.target import Target -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from typing import List, Union, Optional, Any, Tuple -from .operator import Operator, TransformKind -from .impl.matmul_impl import select_implementation -from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -import logging - -logger = logging.getLogger(__name__) - - -class TransformExecutorCPU: - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) - - -@dataclass(frozen=True) -class MatmulConfig: - M: Union[int, Tuple[int]] - N: int - K: int - in_dtype: str = "float16" - out_dtype: str = "float16" - accum_dtype: str = "float16" - with_bias: bool = False - # layout of matrix A and B - # "nn": C[i, j] = A[i, k] * B[k, j] - # "nt": C[i, j] = A[i, k] * B[j, k] - layout: str = "nt" - # weight transformation kind of matrix A - propagate_a: TransformKind = TransformKind.NonTransform - # weight transformation kind of matrix B - propagate_b: TransformKind = TransformKind.NonTransform - - def __post_init__(self): - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) - if isinstance(self.propagate_a, bool): - object.__setattr__( - self, - "propagate_a", - (TransformKind.IntraWarpTransform - if self.propagate_a else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_a, int): - object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) - - if isinstance(self.propagate_b, bool): - object.__setattr__( - self, - "propagate_b", - (TransformKind.IntraWarpTransform - if self.propagate_b else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_b, int): - object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) - - -class Matmul(Operator): - - def __init__( - self, - config: MatmulConfig, - name: str = "matmul", - target: Union[str, Target] = "cuda", - enable_tuning: bool = False, - from_database: bool = False, - ): - super().__init__(name, config, target) - target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") - - if isinstance(self.M, Tuple): - self.dynamic_range = {"m": self.M} - self.update_func(self.prim_func.with_attrs({"opt_shapes": self.dynamic_range})) - else: - self.dynamic_range = None - - if not from_database: - self._build_default_module(target) - - if self.propagate_a: - assert (self.propagate_a is - TransformKind.NonTransform), "Currently only support NonTransform for input" - ladder_permutate_config = LadderPermutateConfig( - M=self.M, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="A", - transpose_matrix=False, - transform_kind=self.propagate_a, - ) - self.ladder_permutate_a = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_a = None - - if self.propagate_b: - ladder_permutate_config = LadderPermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="B", - transpose_matrix=(self.layout == "nt"), - transform_kind=self.propagate_b, - ) - self.ladder_permutate_b = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_b = None - - input_executors = TransformExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_a) - - self.input_executors = input_executors - - weight_executors = TransformExecutorCPU() - if self.ladder_permutate_b is not None: - weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - with_bias=self.with_bias, - layout=self.layout, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def post_process(self, code: str) -> str: - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - def _profile_latency_with_dynamic_range(self) -> List: - func = self.prim_func_mod["main"] - device = self.arch.device - - def var_warpper(v, m): - if isinstance(v, tvm.tir.Var): - assert "opt_shapes" in func.attrs - assert v.name in func.attrs["opt_shapes"] - return m - elif isinstance(v, tvm.tir.IntImm): - return v.value - else: - raise RuntimeError("Not supported type: ", type(v)) - - benchmark_latencies = [] - for m in self.dynamic_range["m"]: - profile_tensors = [] - for param in func.params: - if param not in func.buffer_map: - # in case of dynamic symbolic may in params - continue - arg = func.buffer_map[param] - profile_tensors.append( - tvm.nd.array( - np.random.uniform(0, 1, - [var_warpper(i, m) for i in arg.shape]).astype(arg.dtype), - device=device, - )) - latency = self.time_evaluator(*profile_tensors).mean * 1e3 - benchmark_latencies.append({"m": m, "latency": latency}) - # ms - return benchmark_latencies - - def forward(self, *args) -> Any: - if self.lib is None: - self._forward_from_torch_func(*args) - dynamic_symbolic = [] - if self.dynamic_range is not None: - # assume we only have one dynamic range - m = args[0].shape[0] - dynamic_symbolic.append(m) - self._forward_from_prebuild_lib(*args, *dynamic_symbolic) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def K(self): - return self.config.K - - @property - def in_dtype(self): - return self.config.in_dtype - - @property - def out_dtype(self): - return self.config.out_dtype - - @property - def accum_dtype(self): - return self.config.accum_dtype - - @property - def layout(self): - return self.config.layout - - @property - def with_bias(self): - return self.config.with_bias - - @property - def propagate_a(self): - return self.config.propagate_a - - @property - def propagate_b(self): - return self.config.propagate_b - - @property - def input_transform(self): - return self.input_executors if self.input_executors.size else None - - @property - def weight_transform(self): - return self.weight_executors if self.weight_executors.size else None - - -__all__ = ["Matmul", "MatmulConfig"] diff --git a/bitblas/ops/matmul_dequantize.py b/bitblas/ops/matmul_dequantize.py deleted file mode 100644 index 6971547b0..000000000 --- a/bitblas/ops/matmul_dequantize.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -from tvm.target import Target -from bitblas.base.arch.cuda import CUDA -from typing import Any, List, Literal, Optional, Tuple, Union -from .operator import Operator, TransformKind -from .impl.matmul_dequantize_impl import select_implementation -from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig -import logging - -logger = logging.getLogger(__name__) - - -class OPExecutorCPU: - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) - - -@dataclass(frozen=True) -class MatmulWeightOnlyDequantizeConfig: - M: Union[int, Tuple[int]] - N: int - K: int - in_dtype: str = "float16" - out_dtype: str = "float16" - accum_dtype: str = "float16" - bit: int = 4 - storage_dtype: str = "int8" - # documents for source_format: - # the format of the source data, which can be "int", "uint", "fp", "nf" - # "int": dequantize_weight = (target)((int)(quantize_weight - fixed_zero_point)) * scale - # where the fixed_zero_point is 2^(bit - 1) - 1 - # "uint": dequantize_weight = (target)((uint)(quantize_weight - zero_point)) * scale - # where the zero_point is manually set by zeros tensor - # "fp": dequantize_weight = (quantize_weight - zero_point) * scale - # "nf": dequantize_weight = (lut[quantize_weight] - zero_point) * scale - source_format: Literal["int", "uint", "fp", "nf"] = "int" - with_scaling: bool = False - with_zeros: bool = False - group_size: int = -1 - fast_decoding: bool = False - with_bias: bool = False - propagate_a: TransformKind = TransformKind.NonTransform - propagate_b: TransformKind = TransformKind.NonTransform - layout: str = "nt" - # documents for zeros_mode: - # original: target = (dequantize_weight - zero_point) * scale - # rescale: target = dequantize_weight * scale - zero_point - # quantized: target = (dequantize_weight - dequantize_zeros) * scale - # The auto-gptq framework prefer "quantized" and "original" for alignment with cuda. - zeros_mode: Literal["original", "rescale", "quantized"] = "original" - - def __post_init__(self): - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) - if isinstance(self.propagate_a, bool): - object.__setattr__( - self, - "propagate_a", - (TransformKind.IntraWarpTransform - if self.propagate_a else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_a, int): - object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) - - if isinstance(self.propagate_b, bool): - object.__setattr__( - self, - "propagate_b", - (TransformKind.IntraWarpTransform - if self.propagate_b else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_b, int): - object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) - - -class MatmulWeightOnlyDequantize(Operator): - - def __init__( - self, - config: MatmulWeightOnlyDequantizeConfig, - name: str = "matmul_weight_only_dequantize", - target: Target = "cuda", - enable_tuning: bool = False, - from_database: bool = False, - ): - super().__init__(name, config, target) - - target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") - - self.arch = CUDA(target) - - if isinstance(self.M, Tuple): - self.dynamic_range = {"m": self.M} - self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( - {"opt_shapes": self.dynamic_range}) - else: - self.dynamic_range = None - - if not from_database: - self._build_default_module(target) - - if self.propagate_a: - ladder_permutate_config = LadderPermutateConfig( - M=self.M, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="A", - transpose_matrix=False, - transform_kind=self.propagate_a, - ) - self.ladder_permutate_a = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_a = None - - if self.propagate_b: - ladder_permutate_config = LadderPermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - propagate_kind="B", - transpose_matrix=self.layout == "nt", - transform_kind=self.propagate_b, - ) - self.ladder_permutate_b = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_b = None - - if self.fast_decoding: - lop3_permutate_config = LOP3PermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - ) - self.lop3_permutate = LOP3Permutate( - config=lop3_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.lop3_permutate = None - - input_executors = OPExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_a) - self.input_executors = input_executors - - weight_executors = OPExecutorCPU() - if self.lop3_permutate is not None: - weight_executors.append(self.lop3_permutate) - - if self.ladder_permutate_b is not None: - weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - bit=self.bit, - storage_dtype=self.storage_dtype, - source_format=self.source_format, - with_scaling=self.with_scaling, - with_zeros=self.with_zeros, - group_size=self.group_size, - fast_decoding=self.fast_decoding, - with_bias=self.with_bias, - layout=self.layout, - zeros_mode=self.zeros_mode, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def post_process(self, code: str) -> str: - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - def retrieve_weight_shape(self): - return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] - - def forward(self, *args) -> Any: - if self.lib is None: - self._forward_from_torch_func(*args) - dynamic_symbolic = [] - if self.dynamic_range is not None: - # assume we only have one dynamic range - m = args[0].shape[0] - dynamic_symbolic.append(m) - self._forward_from_prebuild_lib(*args, *dynamic_symbolic) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def K(self): - return self.config.K - - @property - def in_dtype(self): - return self.config.in_dtype - - @property - def out_dtype(self): - return self.config.out_dtype - - @property - def accum_dtype(self): - return self.config.accum_dtype - - @property - def bit(self): - return self.config.bit - - @property - def storage_dtype(self): - return self.config.storage_dtype - - @property - def source_format(self): - return self.config.source_format - - @property - def with_scaling(self): - return self.config.with_scaling - - @property - def with_zeros(self): - return self.config.with_zeros - - @property - def group_size(self): - return self.config.group_size - - @property - def fast_decoding(self): - return self.config.fast_decoding - - @property - def with_bias(self): - return self.config.with_bias - - @property - def propagate_a(self): - return self.config.propagate_a - - @property - def propagate_b(self): - return self.config.propagate_b - - @property - def layout(self): - return self.config.layout - - @property - def zeros_mode(self): - return self.config.zeros_mode - - @property - def input_transform(self): - return self.input_executors if self.input_executors.size else None - - @property - def weight_transform(self): - return self.weight_executors if self.weight_executors.size else None diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 29d384302..f6fa4cca0 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -13,7 +13,6 @@ from bitblas.base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy from bitblas.base.arch import get_arch -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch from bitblas.builder.wrapper import TIRWrapper from bitblas.builder.lib_generator import LibraryGenerator from dataclasses import dataclass @@ -371,7 +370,6 @@ def is_none(self): def forward(self, weight): inputs = [weight] for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) inputs = [op.forward(*inputs)] return inputs[-1] diff --git a/integration/pytorch/bitblas_quant_linear.py b/integration/pytorch/bitblas_quant_linear.py index c0cdac611..6e6610c1f 100644 --- a/integration/pytorch/bitblas_quant_linear.py +++ b/integration/pytorch/bitblas_quant_linear.py @@ -182,7 +182,7 @@ def pack(self, linear, scales, zeros=None): (w[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to(torch.int)[:, None]) intweight = torch.cat(intweight, dim=1) intweight = intweight.contiguous() - intweight = intweight.cpu().numpy().astype(np.int8) + intweight = intweight.cpu().to(torch.int8) # quantize to 4bit qw_np = general_compress(intweight, source_bits=self.bits, storage_dtype=np.int8) # do interleave for fast type conversion diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index f329a146e..3adacaa8c 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -98,7 +98,7 @@ def correctness_weight_only_dequantize( inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) intweight = inputs[1] - intweight = intweight.cpu().numpy().astype(np.int8) + intweight = intweight.cpu().to(torch.int8) if source_format == "int": intweight = intweight + maxq if with_zeros: @@ -109,15 +109,13 @@ def correctness_weight_only_dequantize( ref_result = ref_result + bias_tensor with torch.no_grad(): - qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) - qw_torch = torch.from_numpy(qw_np).cuda() permuted_inputs = [] permuted_inputs.append(inputs[0]) if linear_bitblas.bitblas_matmul.weight_transform is not None: permuted_inputs.append( - linear_bitblas.bitblas_matmul.weight_transform(qw_torch.cpu()).cuda()) + linear_bitblas.bitblas_matmul.weight_transform(intweight.cpu()).cuda()) else: - permuted_inputs.append(qw_torch) + permuted_inputs.append(inputs[1]) linear_bitblas.qweight.data = permuted_inputs[-1].clone() if with_scaling: if group_size == -1: diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py index 62808e2a7..354914d22 100644 --- a/testing/python/operators/test_general_matmul_ops.py +++ b/testing/python/operators/test_general_matmul_ops.py @@ -155,7 +155,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) intweight = inputs[1] - intweight = intweight.cpu().numpy().astype(np.int8) + intweight = intweight.cpu().to(torch.int8) if source_format == "int": intweight = intweight + maxq if with_zeros: @@ -165,14 +165,12 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) if with_bias: ref_result = ref_result + bias - qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) - qw_torch = torch.from_numpy(qw_np).cuda() permuted_inputs = [] permuted_inputs.append(inputs[0]) if matmul.weight_transform is not None: - permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda()) + permuted_inputs.append(matmul.weight_transform(intweight.cpu()).cuda()) else: - permuted_inputs.append(qw_torch) + permuted_inputs.append(intweight) if with_scaling: if group_size == -1: group_size = K