diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index f42c2085b..ddc40cfa6 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -165,8 +165,6 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], return self def cuda(self, device): - if self.quant_state is not None: - return self w = self.data.contiguous().half().cuda(device) w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) self.data = w_4bit