diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 69d39277b..c36fb68a6 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -694,11 +694,10 @@ def to(self, *args, **kwargs): ) # If we had already quantized, move the statistics appropriately. - if is_quantized and device is not None: - if self.CB is not None: - new_param.CB = new_param.data + if is_quantized: + new_param.CB = new_param.data - if self.SCB is not None: + if self.SCB is not None and device is not None: new_param.SCB = self.SCB.to(device) return new_param