diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 91e88133c..a1bc95f39 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -39,9 +39,10 @@ from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401 -from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401 from .module import Linear # noqa: F401 +import warnings +import functools import logging from tqdm import tqdm @@ -89,4 +90,26 @@ def _init_logger(): _init_logger() + +def deprecated(reason): + """ + This is a decorator which can be used to mark functions as deprecated. + It will result in a warning being emitted when the function is used. + """ + + def decorator(func): + + @functools.wraps(func) + def new_func(*args, **kwargs): + warnings.warn( + f"Call to deprecated function {func.__name__} ({reason}).", + category=DeprecationWarning, + stacklevel=2) + return func(*args, **kwargs) + + return new_func + + return decorator + + __version__ = "0.0.1.dev13" diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 4575fa363..2033b8f75 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -2264,10 +2264,11 @@ def get_idx(): lop3_intrin_info["compute"], ) # Assume the grouped K is the last dim of the scaling - grouped_k = sch.get(bf).reads[1].buffer.shape[-1] - # TODO(lei): This is a hack to get the loop extent - loop_extent = 8 if out_dtype == "float16" else 16 - sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k) + if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: + grouped_k = sch.get(bf).reads[1].buffer.shape[-1] + # TODO(lei): This is a hack to get the loop extent + loop_extent = 8 if out_dtype == "float16" else 16 + sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k) import_source.append(lop3_intrin_info["c_source"]) def tensorize_init_store_compute(): diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py index a3fbba213..f148097d6 100644 --- a/bitblas/module/__init__.py +++ b/bitblas/module/__init__.py @@ -40,6 +40,24 @@ def unpack_qzeros(qzeros, bits): return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1) +def unpack_qweight(qweight, bits): + qweight = qweight.view(torch.int8) + elems_per_int8 = 8 // bits + unpacked_weight = torch.zeros( + (qweight.shape[0], qweight.shape[1] * elems_per_int8), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + for col in range(unpacked_weight.shape[1]): + i = col % elems_per_int8 + unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> (bits * i)) + + # Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303 + # NOTE: It appears that casting after the `unpacked_zeros + 1` is important. + return torch.bitwise_and(unpacked_weight, 2**bits - 1) + + class Linear(nn.Module): opt_M = [1, 16, 32, 64, 128, 256, 512] STORAGE_DTYPE = "int8" # assume int8 storage @@ -279,8 +297,9 @@ def load_and_transform_weight( def repack_from_gptq(self, gptq_module): # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed. qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE) + intweight = unpack_qweight(qweight, self.bits).contiguous() if self.bitblas_matmul.weight_transform is not None: - qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() + qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda() self.qweight = qweight # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed. scales = gptq_module.scales.T.contiguous().view(self.torch_dtype) diff --git a/bitblas/ops/__init__.py b/bitblas/ops/__init__.py index 4fa456477..a132a83b2 100644 --- a/bitblas/ops/__init__.py +++ b/bitblas/ops/__init__.py @@ -1,8 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from .operator import Operator, OperatorConfig # noqa: F401 -from .matmul import Matmul, MatmulConfig # noqa: F401 -from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401 +from .general_matmul import Matmul, MatmulConfig # noqa: F401 from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401 from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig # noqa: F401 from .quant_compress import QuantCompress, QuantCompressConfig # noqa: F401 diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 184da0b0a..7d99d9628 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -13,6 +13,7 @@ from bitblas.utils.target_detector import auto_detect_nvidia_target from dataclasses import dataclass from ..ladder_permutate import LadderPermutate, LadderPermutateConfig +from ..quant_compress import QuantCompress, QuantCompressConfig from ..lop3_permutate import LOP3Permutate, LOP3PermutateConfig import logging import torch @@ -292,6 +293,7 @@ def dispatch_tir(self, # create permutate_opertors self.ladder_permutate_a = self._assign_ladder_permutate_a(target, enable_tuning) self.ladder_permutate_b = self._assign_ladder_permutate_b(target, enable_tuning) + self.weight_compress = self._assign_weight_compress(target, enable_tuning) self.lop3_permutate = self._assign_lop3_permutate(target, enable_tuning) # create cpu weight executors self.input_executors = self._create_input_executors() @@ -338,11 +340,14 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): del enable_tuning if self.propagate_b: + # weight transform should be done in the unpacked level + # otherwise the bit trick should be applied and that is + # too complex to be implemented in the ladder permutation. ladder_permutate_config = LadderPermutateConfig( M=self.N, N=self.K, datatype=self.A_dtype, - dequantize_bits=self.bit, + dequantize_bits=-1, storage_dtype=self.storage_dtype, propagate_kind="B", transpose_matrix=self.layout == "nt", @@ -354,6 +359,25 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): ) return None + def _assign_weight_compress(self, target: Target, enable_tuning: bool): + # unused variables + del target + del enable_tuning + + require_compress: bool = self.bit in [1, 2, 4] + if require_compress: + quant_compress_config = QuantCompressConfig( + M=self.N, + N=self.K, + input_dtype=self.storage_dtype, + storage_dtype=self.storage_dtype, + dequantize_bits=self.bit) + return QuantCompress( + config=quant_compress_config, + target=tvm.target.Target("llvm"), + ) + return None + def _assign_lop3_permutate(self, target: Target, enable_tuning: bool): # unused variables del target @@ -381,10 +405,12 @@ def _create_input_executors(self): def _create_weight_executors(self): weight_executors = OPExecutorCPU() - if self.fast_decoding: - weight_executors.append(self.lop3_permutate) if self.propagate_b is not TransformKind.NonTransform: weight_executors.append(self.ladder_permutate_b) + if self.weight_compress is not None: + weight_executors.append(self.weight_compress) + if self.fast_decoding: + weight_executors.append(self.lop3_permutate) return weight_executors def _select_implementation(self): @@ -452,10 +478,6 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): return self.weight_transform(weight.cpu()).cuda().contiguous() return weight - from bitblas.quantization import general_compress - import torch - import numpy as np - source_format, bit = self.source_format, self.bit # Process integer source format @@ -464,20 +486,13 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): assert not self.with_zeros, "zeros should be False for int source format" maxq = 2**(bit - 1) # Clamp weight values to be within the quantizable range and adjust - weight = torch.clamp(weight, -maxq, maxq).int() + maxq + weight = torch.clamp(weight, -maxq, maxq).char() + maxq elif source_format in ["fp_e5m2", "fp_e4m3"]: weight = weight.view(torch.int8) - weight = weight.int() else: # For non-integer formats, simply convert weights to integers - weight = weight.int() - - np_storage_dtype = getattr(np, self.storage_dtype) - - weight = general_compress( - weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) - - weight = torch.from_numpy(weight).cuda().contiguous() + # And assume weight is in the range of [-128, 127] for int8 + weight = weight.char() # Apply an optional weight transformation if specified if self.weight_transform is not None: diff --git a/bitblas/ops/ladder_permutate/__init__.py b/bitblas/ops/ladder_permutate/__init__.py index 6644705cd..d09ee6dac 100644 --- a/bitblas/ops/ladder_permutate/__init__.py +++ b/bitblas/ops/ladder_permutate/__init__.py @@ -5,6 +5,7 @@ from ..operator import Operator from .ladder_permutate_impl import select_implementation from dataclasses import dataclass +import torch @dataclass(frozen=True) @@ -57,6 +58,23 @@ def _select_implementation(self): target_instruction=self.target_instruction, ) + def forward(self, inp, out=None): + if out is None: + out_shape, out_dtype = self.retrieve_output_shape() + out = torch.zeros(out_shape, dtype=out_dtype).to(inp.device) + self.torch_func(inp, out) + return out + + def retrieve_output_shape(self): + """ + Retrieve the output shape of the operator + """ + func = self.prim_func + param = func.params[-1] + assert param in func.buffer_map, f"param {param} not in buffer_map" + arg = func.buffer_map[param] + return [int(i) for i in arg.shape], getattr(torch, arg.dtype) + @property def M(self): return self.config.M diff --git a/bitblas/ops/lop3_permutate/__init__.py b/bitblas/ops/lop3_permutate/__init__.py index 10c452b3d..19c4b0eea 100644 --- a/bitblas/ops/lop3_permutate/__init__.py +++ b/bitblas/ops/lop3_permutate/__init__.py @@ -42,11 +42,23 @@ def _select_implementation(self): dequantize_bits=self.dequantize_bits, ) - def forward(self, weight, res): + def forward(self, inp, out=None): + out_shape = inp.shape + out_dtype = inp.dtype + if out is None: + # lop3 transform does not change the shape of the input tensor + out = torch.zeros(out_shape, dtype=out_dtype) # reinterpret the input tensor to int32 format - args = [arg.view(torch.int32) for arg in [weight, res]] + shape_2dim = self.retrieve_2d_weight_shape() + args = [arg.view(inp.dtype).view(shape_2dim).view(torch.int32) for arg in [inp, out]] self.torch_func(*args) - return args[-1].view(weight.dtype) + return args[-1].view(out_dtype).view(out_shape) + + def retrieve_2d_weight_shape(self): + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + elems_per_byte = storage_nbit // self.dequantize_bits + weight_shape = (self.M, self.N // elems_per_byte) + return weight_shape @property def M(self): diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py deleted file mode 100644 index e515a264c..000000000 --- a/bitblas/ops/matmul.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -import numpy as np -from tvm.target import Target -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from typing import List, Union, Optional, Any, Tuple -from .operator import Operator, TransformKind -from .impl.matmul_impl import select_implementation -from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -import logging - -logger = logging.getLogger(__name__) - - -class TransformExecutorCPU: - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) - - -@dataclass(frozen=True) -class MatmulConfig: - M: Union[int, Tuple[int]] - N: int - K: int - in_dtype: str = "float16" - out_dtype: str = "float16" - accum_dtype: str = "float16" - with_bias: bool = False - # layout of matrix A and B - # "nn": C[i, j] = A[i, k] * B[k, j] - # "nt": C[i, j] = A[i, k] * B[j, k] - layout: str = "nt" - # weight transformation kind of matrix A - propagate_a: TransformKind = TransformKind.NonTransform - # weight transformation kind of matrix B - propagate_b: TransformKind = TransformKind.NonTransform - - def __post_init__(self): - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) - if isinstance(self.propagate_a, bool): - object.__setattr__( - self, - "propagate_a", - (TransformKind.IntraWarpTransform - if self.propagate_a else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_a, int): - object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) - - if isinstance(self.propagate_b, bool): - object.__setattr__( - self, - "propagate_b", - (TransformKind.IntraWarpTransform - if self.propagate_b else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_b, int): - object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) - - -class Matmul(Operator): - - def __init__( - self, - config: MatmulConfig, - name: str = "matmul", - target: Union[str, Target] = "cuda", - enable_tuning: bool = False, - from_database: bool = False, - ): - super().__init__(name, config, target) - target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") - - if isinstance(self.M, Tuple): - self.dynamic_range = {"m": self.M} - self.update_func(self.prim_func.with_attrs({"opt_shapes": self.dynamic_range})) - else: - self.dynamic_range = None - - if not from_database: - self._build_default_module(target) - - if self.propagate_a: - assert (self.propagate_a is - TransformKind.NonTransform), "Currently only support NonTransform for input" - ladder_permutate_config = LadderPermutateConfig( - M=self.M, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="A", - transpose_matrix=False, - transform_kind=self.propagate_a, - ) - self.ladder_permutate_a = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_a = None - - if self.propagate_b: - ladder_permutate_config = LadderPermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="B", - transpose_matrix=(self.layout == "nt"), - transform_kind=self.propagate_b, - ) - self.ladder_permutate_b = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_b = None - - input_executors = TransformExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_a) - - self.input_executors = input_executors - - weight_executors = TransformExecutorCPU() - if self.ladder_permutate_b is not None: - weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - with_bias=self.with_bias, - layout=self.layout, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def post_process(self, code: str) -> str: - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - def _profile_latency_with_dynamic_range(self) -> List: - func = self.prim_func_mod["main"] - device = self.arch.device - - def var_warpper(v, m): - if isinstance(v, tvm.tir.Var): - assert "opt_shapes" in func.attrs - assert v.name in func.attrs["opt_shapes"] - return m - elif isinstance(v, tvm.tir.IntImm): - return v.value - else: - raise RuntimeError("Not supported type: ", type(v)) - - benchmark_latencies = [] - for m in self.dynamic_range["m"]: - profile_tensors = [] - for param in func.params: - if param not in func.buffer_map: - # in case of dynamic symbolic may in params - continue - arg = func.buffer_map[param] - profile_tensors.append( - tvm.nd.array( - np.random.uniform(0, 1, - [var_warpper(i, m) for i in arg.shape]).astype(arg.dtype), - device=device, - )) - latency = self.time_evaluator(*profile_tensors).mean * 1e3 - benchmark_latencies.append({"m": m, "latency": latency}) - # ms - return benchmark_latencies - - def forward(self, *args) -> Any: - if self.lib is None: - self._forward_from_torch_func(*args) - dynamic_symbolic = [] - if self.dynamic_range is not None: - # assume we only have one dynamic range - m = args[0].shape[0] - dynamic_symbolic.append(m) - self._forward_from_prebuild_lib(*args, *dynamic_symbolic) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def K(self): - return self.config.K - - @property - def in_dtype(self): - return self.config.in_dtype - - @property - def out_dtype(self): - return self.config.out_dtype - - @property - def accum_dtype(self): - return self.config.accum_dtype - - @property - def layout(self): - return self.config.layout - - @property - def with_bias(self): - return self.config.with_bias - - @property - def propagate_a(self): - return self.config.propagate_a - - @property - def propagate_b(self): - return self.config.propagate_b - - @property - def input_transform(self): - return self.input_executors if self.input_executors.size else None - - @property - def weight_transform(self): - return self.weight_executors if self.weight_executors.size else None - - -__all__ = ["Matmul", "MatmulConfig"] diff --git a/bitblas/ops/matmul_dequantize.py b/bitblas/ops/matmul_dequantize.py deleted file mode 100644 index 6971547b0..000000000 --- a/bitblas/ops/matmul_dequantize.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -from tvm.target import Target -from bitblas.base.arch.cuda import CUDA -from typing import Any, List, Literal, Optional, Tuple, Union -from .operator import Operator, TransformKind -from .impl.matmul_dequantize_impl import select_implementation -from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig -import logging - -logger = logging.getLogger(__name__) - - -class OPExecutorCPU: - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) - - -@dataclass(frozen=True) -class MatmulWeightOnlyDequantizeConfig: - M: Union[int, Tuple[int]] - N: int - K: int - in_dtype: str = "float16" - out_dtype: str = "float16" - accum_dtype: str = "float16" - bit: int = 4 - storage_dtype: str = "int8" - # documents for source_format: - # the format of the source data, which can be "int", "uint", "fp", "nf" - # "int": dequantize_weight = (target)((int)(quantize_weight - fixed_zero_point)) * scale - # where the fixed_zero_point is 2^(bit - 1) - 1 - # "uint": dequantize_weight = (target)((uint)(quantize_weight - zero_point)) * scale - # where the zero_point is manually set by zeros tensor - # "fp": dequantize_weight = (quantize_weight - zero_point) * scale - # "nf": dequantize_weight = (lut[quantize_weight] - zero_point) * scale - source_format: Literal["int", "uint", "fp", "nf"] = "int" - with_scaling: bool = False - with_zeros: bool = False - group_size: int = -1 - fast_decoding: bool = False - with_bias: bool = False - propagate_a: TransformKind = TransformKind.NonTransform - propagate_b: TransformKind = TransformKind.NonTransform - layout: str = "nt" - # documents for zeros_mode: - # original: target = (dequantize_weight - zero_point) * scale - # rescale: target = dequantize_weight * scale - zero_point - # quantized: target = (dequantize_weight - dequantize_zeros) * scale - # The auto-gptq framework prefer "quantized" and "original" for alignment with cuda. - zeros_mode: Literal["original", "rescale", "quantized"] = "original" - - def __post_init__(self): - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) - if isinstance(self.propagate_a, bool): - object.__setattr__( - self, - "propagate_a", - (TransformKind.IntraWarpTransform - if self.propagate_a else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_a, int): - object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) - - if isinstance(self.propagate_b, bool): - object.__setattr__( - self, - "propagate_b", - (TransformKind.IntraWarpTransform - if self.propagate_b else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_b, int): - object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) - - -class MatmulWeightOnlyDequantize(Operator): - - def __init__( - self, - config: MatmulWeightOnlyDequantizeConfig, - name: str = "matmul_weight_only_dequantize", - target: Target = "cuda", - enable_tuning: bool = False, - from_database: bool = False, - ): - super().__init__(name, config, target) - - target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") - - self.arch = CUDA(target) - - if isinstance(self.M, Tuple): - self.dynamic_range = {"m": self.M} - self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( - {"opt_shapes": self.dynamic_range}) - else: - self.dynamic_range = None - - if not from_database: - self._build_default_module(target) - - if self.propagate_a: - ladder_permutate_config = LadderPermutateConfig( - M=self.M, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="A", - transpose_matrix=False, - transform_kind=self.propagate_a, - ) - self.ladder_permutate_a = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_a = None - - if self.propagate_b: - ladder_permutate_config = LadderPermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - propagate_kind="B", - transpose_matrix=self.layout == "nt", - transform_kind=self.propagate_b, - ) - self.ladder_permutate_b = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_b = None - - if self.fast_decoding: - lop3_permutate_config = LOP3PermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - ) - self.lop3_permutate = LOP3Permutate( - config=lop3_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.lop3_permutate = None - - input_executors = OPExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_a) - self.input_executors = input_executors - - weight_executors = OPExecutorCPU() - if self.lop3_permutate is not None: - weight_executors.append(self.lop3_permutate) - - if self.ladder_permutate_b is not None: - weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - bit=self.bit, - storage_dtype=self.storage_dtype, - source_format=self.source_format, - with_scaling=self.with_scaling, - with_zeros=self.with_zeros, - group_size=self.group_size, - fast_decoding=self.fast_decoding, - with_bias=self.with_bias, - layout=self.layout, - zeros_mode=self.zeros_mode, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def post_process(self, code: str) -> str: - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - def retrieve_weight_shape(self): - return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] - - def forward(self, *args) -> Any: - if self.lib is None: - self._forward_from_torch_func(*args) - dynamic_symbolic = [] - if self.dynamic_range is not None: - # assume we only have one dynamic range - m = args[0].shape[0] - dynamic_symbolic.append(m) - self._forward_from_prebuild_lib(*args, *dynamic_symbolic) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def K(self): - return self.config.K - - @property - def in_dtype(self): - return self.config.in_dtype - - @property - def out_dtype(self): - return self.config.out_dtype - - @property - def accum_dtype(self): - return self.config.accum_dtype - - @property - def bit(self): - return self.config.bit - - @property - def storage_dtype(self): - return self.config.storage_dtype - - @property - def source_format(self): - return self.config.source_format - - @property - def with_scaling(self): - return self.config.with_scaling - - @property - def with_zeros(self): - return self.config.with_zeros - - @property - def group_size(self): - return self.config.group_size - - @property - def fast_decoding(self): - return self.config.fast_decoding - - @property - def with_bias(self): - return self.config.with_bias - - @property - def propagate_a(self): - return self.config.propagate_a - - @property - def propagate_b(self): - return self.config.propagate_b - - @property - def layout(self): - return self.config.layout - - @property - def zeros_mode(self): - return self.config.zeros_mode - - @property - def input_transform(self): - return self.input_executors if self.input_executors.size else None - - @property - def weight_transform(self): - return self.weight_executors if self.weight_executors.size else None diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 29d384302..f6fa4cca0 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -13,7 +13,6 @@ from bitblas.base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy from bitblas.base.arch import get_arch -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch from bitblas.builder.wrapper import TIRWrapper from bitblas.builder.lib_generator import LibraryGenerator from dataclasses import dataclass @@ -371,7 +370,6 @@ def is_none(self): def forward(self, weight): inputs = [weight] for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) inputs = [op.forward(*inputs)] return inputs[-1] diff --git a/integration/pytorch/bitblas_quant_linear.py b/integration/pytorch/bitblas_quant_linear.py index c0cdac611..6e6610c1f 100644 --- a/integration/pytorch/bitblas_quant_linear.py +++ b/integration/pytorch/bitblas_quant_linear.py @@ -182,7 +182,7 @@ def pack(self, linear, scales, zeros=None): (w[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to(torch.int)[:, None]) intweight = torch.cat(intweight, dim=1) intweight = intweight.contiguous() - intweight = intweight.cpu().numpy().astype(np.int8) + intweight = intweight.cpu().to(torch.int8) # quantize to 4bit qw_np = general_compress(intweight, source_bits=self.bits, storage_dtype=np.int8) # do interleave for fast type conversion diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index f329a146e..3adacaa8c 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -98,7 +98,7 @@ def correctness_weight_only_dequantize( inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) intweight = inputs[1] - intweight = intweight.cpu().numpy().astype(np.int8) + intweight = intweight.cpu().to(torch.int8) if source_format == "int": intweight = intweight + maxq if with_zeros: @@ -109,15 +109,13 @@ def correctness_weight_only_dequantize( ref_result = ref_result + bias_tensor with torch.no_grad(): - qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) - qw_torch = torch.from_numpy(qw_np).cuda() permuted_inputs = [] permuted_inputs.append(inputs[0]) if linear_bitblas.bitblas_matmul.weight_transform is not None: permuted_inputs.append( - linear_bitblas.bitblas_matmul.weight_transform(qw_torch.cpu()).cuda()) + linear_bitblas.bitblas_matmul.weight_transform(intweight.cpu()).cuda()) else: - permuted_inputs.append(qw_torch) + permuted_inputs.append(inputs[1]) linear_bitblas.qweight.data = permuted_inputs[-1].clone() if with_scaling: if group_size == -1: diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py index 62808e2a7..354914d22 100644 --- a/testing/python/operators/test_general_matmul_ops.py +++ b/testing/python/operators/test_general_matmul_ops.py @@ -155,7 +155,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) intweight = inputs[1] - intweight = intweight.cpu().numpy().astype(np.int8) + intweight = intweight.cpu().to(torch.int8) if source_format == "int": intweight = intweight + maxq if with_zeros: @@ -165,14 +165,12 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) if with_bias: ref_result = ref_result + bias - qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) - qw_torch = torch.from_numpy(qw_np).cuda() permuted_inputs = [] permuted_inputs.append(inputs[0]) if matmul.weight_transform is not None: - permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda()) + permuted_inputs.append(matmul.weight_transform(intweight.cpu()).cuda()) else: - permuted_inputs.append(qw_torch) + permuted_inputs.append(intweight) if with_scaling: if group_size == -1: group_size = K