From 64b2d1698e70520cf4fbee74438f39b67aa57938 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Thu, 23 Jan 2025 22:03:13 -0800 Subject: [PATCH] add device index --- bitsandbytes/nn/modules.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 2320ffd39..81404179d 100755 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -660,9 +660,9 @@ def cpu(self): self.SCB = SCB return self - def xpu(self): + def xpu(self, device): # we store the 8-bit rows-major weight - B = self.data.contiguous().to(torch.float16).xpu() + B = self.data.contiguous().to(torch.float16).xpu(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) if CBt is not None: del CBt @@ -700,11 +700,11 @@ def to(self, *args, **kwargs): return self.cpu() elif device.type == "xpu": if self.data.dtype == torch.int8: - self.data = self.data.contiguous().xpu() + self.data = self.data.contiguous().xpu(device) self.CB = self.data return self else: - return self.xpu() + return self.xpu(device) else: new_param = Int8Params( super().to(device=device, dtype=dtype, non_blocking=non_blocking),