diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 2320ffd39..81404179d 100755 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -660,9 +660,9 @@ def cpu(self): self.SCB = SCB return self - def xpu(self): + def xpu(self, device): # we store the 8-bit rows-major weight - B = self.data.contiguous().to(torch.float16).xpu() + B = self.data.contiguous().to(torch.float16).xpu(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) if CBt is not None: del CBt @@ -700,11 +700,11 @@ def to(self, *args, **kwargs): return self.cpu() elif device.type == "xpu": if self.data.dtype == torch.int8: - self.data = self.data.contiguous().xpu() + self.data = self.data.contiguous().xpu(device) self.CB = self.data return self else: - return self.xpu() + return self.xpu(device) else: new_param = Int8Params( super().to(device=device, dtype=dtype, non_blocking=non_blocking),