From fcc5bb9dd148634be77d83407c352cb97dd710d9 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Thu, 4 Jan 2024 15:55:49 -0600 Subject: [PATCH 1/8] Minimal changes for fp32 4bit storage from BNB commit 8278fca --- bitsandbytes/functional.py | 8 ++++---- bitsandbytes/nn/modules.py | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index e17e70c4b..c6f18c250 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -607,7 +607,7 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState qs_dict: based on state_dict, with only relevant keys, striped of prefixes. - item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. + item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. """ # unpacking tensor with non-tensor components @@ -802,7 +802,7 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - + absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) @@ -931,7 +931,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz if out is None: - out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) + out = torch.zeros(((n+1)//8, 1), dtype=torch.float32, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -1626,7 +1626,7 @@ def gemv_4bit( ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - if B.dtype == torch.uint8: + if B.dtype == torch.float32: if A.dtype == torch.float16: lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) elif A.dtype == torch.bfloat16: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 7cce82b91..0d56925b2 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -142,7 +142,7 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): - def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit": + def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4', module: Optional["Linear4bit"] = None) -> "Params4bit": if data is None: data = torch.empty(0) @@ -152,6 +152,7 @@ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_ self.quant_type = quant_type self.quant_state = quant_state self.data = data + self.module = module return self @classmethod @@ -169,6 +170,8 @@ def cuda(self, device): w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) self.data = w_4bit self.quant_state = quant_state + if self.module is not None: + self.module.quant_state = quant_state return self @@ -205,10 +208,11 @@ class Linear4bit(nn.Linear): def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None): super().__init__(input_features, output_features, bias, device) - self.weight = 
Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type) + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, module=self) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False + self.quant_state = None def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -243,7 +247,15 @@ def forward(self, x: torch.Tensor): self.bias.data = self.bias.data.to(x.dtype) if getattr(self.weight, 'quant_state', None) is None: - print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') + if getattr(self, 'quant_state', None) is not None: + # the quant state got lost when the parameter got converted. This happens for example for fsdp + # since we registered the module, we can recover the state here + assert self.weight.shape[1] == 1 + if not isinstance(self.weight, Params4bit): + self.weight = Params4bit(self.weight) + self.weight.quant_state = self.quant_state + else: + print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') if not self.compute_type_is_set: self.set_compute_type(x) self.compute_type_is_set = True From c51403c4a70a56cd19413668102fd1e5515180bd Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 5 Jan 2024 14:54:31 -0600 Subject: [PATCH 2/8] Params4bit with selectable storage dtype --- bitsandbytes/functional.py | 19 ++++++++++--------- bitsandbytes/nn/modules.py | 30 +++++++++++++++++++++--------- tests/test_linear4bit.py | 34 ++++++++++++++++++++++++++++++---- 3 files changed, 61 insertions(+), 22 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c6f18c250..3c85f878f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -884,13 +884,13 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) -def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) -def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') +def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: +def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> Tensor: """ Quantize tensor A in blocks of 4-bit values. @@ -903,7 +903,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz absmax : torch.Tensor The absmax values. out : torch.Tensor - The output tensor (8-bit). + The output tensor. blocksize : int The blocksize used in quantization. 
quant_type : str @@ -912,7 +912,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz Returns ------- torch.Tensor: - The 8-bit tensor with packed 4-bit values. + Tensor with packed 4-bit values. tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ @@ -931,7 +931,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz if out is None: - out = torch.zeros(((n+1)//8, 1), dtype=torch.float32, device=A.device) + mod = dtype2bytes[quant_storage] * 2 + out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -985,7 +986,7 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = Parameters ---------- A : torch.Tensor - The input 8-bit tensor (packed 4-bit values). + The input tensor (packed 4-bit values). quant_state : QuantState object with quantisation stats, incl. absmax values, original tensor shape and original dtype. absmax : torch.Tensor diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 0d56925b2..42bb0e377 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -142,7 +142,17 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): - def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4', module: Optional["Linear4bit"] = None) -> "Params4bit": + def __new__( + cls, + data: Optional[torch.Tensor] = None, + requires_grad=True, + quant_state: QuantState = None, + blocksize: int = 64, + compress_statistics: bool = True, + quant_type: str = 'fp4', + quant_storage: torch.dtype = torch.uint8, + module: Optional["Linear4bit"] = None + ) -> "Params4bit": if data is None: data = torch.empty(0) @@ -151,6 +161,7 @@ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_ self.compress_statistics = compress_statistics self.quant_type = quant_type self.quant_state = quant_state + self.quant_storage = quant_storage self.data = data self.module = module return self @@ -167,7 +178,7 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], def cuda(self, device): w = self.data.contiguous().half().cuda(device) - w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type, quant_storage=self.quant_storage) self.data = w_4bit self.quant_state = quant_state if self.module is not None: @@ -206,13 +217,14 @@ def to(self, *args, **kwargs): class Linear4bit(nn.Linear): - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None): super().__init__(input_features, output_features, bias, device) - self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, module=self) + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, 
quant_storage=quant_storage, module=self) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False self.quant_state = None + self.quant_storage = quant_storage def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -252,7 +264,7 @@ def forward(self, x: torch.Tensor): # since we registered the module, we can recover the state here assert self.weight.shape[1] == 1 if not isinstance(self.weight, Params4bit): - self.weight = Params4bit(self.weight) + self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) self.weight.quant_state = self.quant_state else: print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') @@ -273,8 +285,8 @@ def forward(self, x: torch.Tensor): class LinearFP4(Linear4bit): - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None): - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device) + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device) class LinearNF4(Linear4bit): @@ -288,8 +300,8 @@ class LinearNF4(Linear4bit): Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. ''' - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None): - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device) + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device) class Int8Params(torch.nn.Parameter): diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 67d299dea..b17bb531d 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -8,13 +8,19 @@ import bitsandbytes as bnb +storage = { + 'uint8': torch.uint8, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, + 'float32': torch.float32 +} @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") @pytest.mark.parametrize( - "quant_type, compress_statistics, bias", - list(product(["nf4", "fp4"], [False, True], [False, True])), + "quant_type, compress_statistics, bias, quant_storage", + list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])), ) -def test_linear_serialization(quant_type, compress_statistics, bias): +def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage): original_dtype = torch.float16 compute_dtype = None device = "cuda" @@ -32,7 +38,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias): quant_type=quant_type, device="meta", ) - new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False) + new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False) linear_q.weight = new_weight if bias: linear_q.bias = 
torch.nn.Parameter(linear.bias) @@ -65,6 +71,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias): # MATCHING a, b = linear_q.weight, linear_q2.weight + # Quantizing original layer with specified quant_storage type + linear_qs = bnb.nn.Linear4bit( + linear.in_features, + linear.out_features, + bias=bias, + compute_dtype=compute_dtype, + compress_statistics=compress_statistics, + quant_type=quant_type, + quant_storage=storage[quant_storage], + device="meta", + ) + linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage]) + if bias: + linear_qs.bias = torch.nn.Parameter(linear.bias) + linear_qs = linear_qs.to(device) + assert a.device == b.device assert a.dtype == b.dtype assert torch.equal(a, b) @@ -96,9 +118,13 @@ def test_linear_serialization(quant_type, compress_statistics, bias): x = torch.rand(42, layer_shape[0], device=device) a = linear_q(x) b = linear_q2(x) + c = linear_qs(x) assert a.device == b.device assert a.dtype == b.dtype + assert a.device == c.device + assert a.dtype == c.dtype assert torch.equal(a, b) + assert torch.equal(a, c) # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias with TemporaryDirectory() as tmpdir: From 83c538a1181d2c554922da8d5c2b148b1d582230 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Mon, 8 Jan 2024 14:15:36 -0600 Subject: [PATCH 3/8] possible fix for double quantizing linear weight & quant storage dtype --- bitsandbytes/nn/modules.py | 13 +++++++++---- tests/test_linear4bit.py | 8 ++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 42bb0e377..b2508a0c8 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -162,6 +162,7 @@ def __new__( self.quant_type = quant_type self.quant_state = quant_state self.quant_storage = quant_storage + self.bnb_quantized = False self.data = data self.module = module return self @@ -174,18 +175,22 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], self.blocksize = self.quant_state.blocksize self.compress_statistics = self.quant_state.nested self.quant_type = self.quant_state.quant_type + self.bnb_quantized = True return self - def cuda(self, device): + def _quantize(self, device): w = self.data.contiguous().half().cuda(device) w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type, quant_storage=self.quant_storage) self.data = w_4bit self.quant_state = quant_state if self.module is not None: self.module.quant_state = quant_state - + self.bnb_quantized = True return self + def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): + return self.to(device='cuda' if device is None else device, non_blocking=non_blocking) + @overload def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: ... @@ -201,8 +206,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) 
-> T: def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): - return self.cuda(device) + if (device is not None and device.type == "cuda" and not self.bnb_quantized): + return self._quantize(device) else: if self.quant_state is not None: self.quant_state.to(device) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index b17bb531d..f6be79a84 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -126,6 +126,14 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert torch.equal(a, b) assert torch.equal(a, c) + # Test moving to CPU and back to GPU + linear_q2.to('cpu') + linear_q2.to(device) + d = linear_qs(x) + assert c.dtype == d.dtype + assert c.device == d.device + assert torch.equal(c, d) + # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias with TemporaryDirectory() as tmpdir: state_path_4bit = os.path.join(tmpdir, "state_4bit.pth") From 359bfc9dfdc8d744506889f4f1f6213f7a4b3b10 Mon Sep 17 00:00:00 2001 From: Kerem Turgutlu Date: Tue, 9 Jan 2024 15:27:45 +0000 Subject: [PATCH 4/8] minor fixes in Params4bit for peft tests --- bitsandbytes/functional.py | 2 +- bitsandbytes/nn/modules.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 3c85f878f..712ade22b 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1627,7 +1627,7 @@ def gemv_4bit( ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - if B.dtype == torch.float32: + if B.dtype in [torch.uint8, torch.bfloat16, torch.float32]: if A.dtype == torch.float16: lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) elif A.dtype == torch.bfloat16: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index b2508a0c8..8f39f6d82 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -151,7 +151,8 @@ def __new__( compress_statistics: bool = True, quant_type: str = 'fp4', quant_storage: torch.dtype = torch.uint8, - module: Optional["Linear4bit"] = None + module: Optional["Linear4bit"] = None, + bnb_quantized: bool = False ) -> "Params4bit": if data is None: data = torch.empty(0) @@ -162,7 +163,7 @@ def __new__( self.quant_type = quant_type self.quant_state = quant_state self.quant_storage = quant_storage - self.bnb_quantized = False + self.bnb_quantized = bnb_quantized self.data = data self.module = module return self @@ -224,7 +225,7 @@ class Linear4bit(nn.Linear): def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None): super().__init__(input_features, output_features, bias, device) - self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self) + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self, bnb_quantized=False) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False From 112741ab4624ad8400681d9d863d2c90a9a0d9e3 Mon Sep 17 00:00:00 2001 
From: Kerem Turgutlu Date: Tue, 9 Jan 2024 15:41:18 +0000 Subject: [PATCH 5/8] remove redundant --- bitsandbytes/nn/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 8f39f6d82..6fbe4a8f8 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -225,7 +225,7 @@ class Linear4bit(nn.Linear): def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None): super().__init__(input_features, output_features, bias, device) - self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self, bnb_quantized=False) + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False From 9e5a15b03f65bb60e6663b69d6824737fdeaa15d Mon Sep 17 00:00:00 2001 From: Kerem Turgutlu Date: Fri, 12 Jan 2024 14:41:24 +0000 Subject: [PATCH 6/8] add float16 --- bitsandbytes/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 712ade22b..4ce634384 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1627,7 +1627,7 @@ def gemv_4bit( ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - if B.dtype in [torch.uint8, torch.bfloat16, torch.float32]: + if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) elif A.dtype == torch.bfloat16: From 4aecbdfc0b584d24ea12f78944db761cb4264c77 Mon Sep 17 00:00:00 2001 From: Kerem Turgutlu Date: Fri, 12 Jan 2024 17:42:41 +0000 Subject: [PATCH 7/8] update test --- tests/test_functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index f825c14df..9b1d0ef25 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2370,7 +2370,8 @@ def test_normal_map_tree(): @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) @pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed']) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) -def test_gemv_4bit(dtype, storage_type, double_quant, kind): +@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32']) +def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): for dim in [128, 256, 512, 1024]: #for dim in [4*1024]: #for dim in [1*16]: @@ -2399,7 +2400,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind): A = torch.randn(1, dim, dtype=dtype, device='cuda') B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) + qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage) C3 = 
torch.matmul(A, B.t()) C2 = F.gemv_4bit(A, qB.t(), state=state) A.requires_grad = True From 89ebdb2ba2cc26861172ba4fae02f24dd0177f3c Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 16 Jan 2024 11:06:49 -0600 Subject: [PATCH 8/8] Remove float16 quant cast as there are fp32, bf16, & fp16 quant kernels --- bitsandbytes/nn/modules.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 6fbe4a8f8..0b1dc5c6f 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -141,7 +141,6 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): - def __new__( cls, data: Optional[torch.Tensor] = None, @@ -180,8 +179,9 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], return self def _quantize(self, device): - w = self.data.contiguous().half().cuda(device) - w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type, quant_storage=self.quant_storage) + w = self.data.contiguous().cuda(device) + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, + quant_type=self.quant_type, quant_storage=self.quant_storage) self.data = w_4bit self.quant_state = quant_state if self.module is not None:
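
Usage note (not part of the patches above): taken together, this series lets Params4bit / Linear4bit pack their 4-bit weights into a caller-selected storage dtype (uint8, float16, bfloat16, or float32), which is what FSDP needs in order to flatten and shard the quantized parameters alongside regular floating-point ones. Below is a minimal sketch of the new quant_storage argument, assuming these patches are applied on top of a CUDA-enabled bitsandbytes build; the layer size and dtypes are illustrative only.

    import torch
    import bitsandbytes as bnb

    # Store the packed NF4 weight in float32 elements instead of uint8 so that
    # frameworks requiring a uniform parameter dtype (e.g. FSDP) can shard it.
    linear = bnb.nn.Linear4bit(
        4096, 4096,
        bias=False,
        compute_dtype=torch.bfloat16,
        quant_type="nf4",
        quant_storage=torch.float32,   # argument added by this series
    )
    linear = linear.to("cuda")         # Params4bit._quantize runs on the move to CUDA

    x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")
    y = linear(x)

    # The packed weight tensor now has the requested storage dtype; one float32
    # element holds eight 4-bit values, versus two per uint8 element.
    print(linear.weight.dtype)   # torch.float32

The same argument is exposed on the functional API (bnb.functional.quantize_4bit / quantize_fp4 / quantize_nf4), as exercised by the test_gemv_4bit and test_linear_serialization changes in patches 2/8 and 7/8.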