From 119f6962db6c0f1b3bd329ef85e022cbaf7725e3 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 2 Dec 2024 16:32:20 +0000
Subject: [PATCH] fix cpu nf4

Signed-off-by: jiqing-feng
---
 bitsandbytes/autograd/_functions.py | 3 ++-
 bitsandbytes/nn/modules.py          | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 9765def05..e188479f6 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -579,7 +579,8 @@ def matmul_4bit(
     assert quant_state is not None
     if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
         if getattr(quant_state, "ipex", False):
-            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            B = B.t() if len(B.shape) == 2 else B
+            out = F.gemv_4bit(A, B, out, state=quant_state)
             if bias is not None:
                 out += bias
             return out
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 2159c21e4..66f14edf7 100755
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -508,7 +508,8 @@ def forward(self, x: torch.Tensor):
             x = x.to(self.compute_dtype)
 
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
+        weight = self.weight.t() if len(self.weight.shape) == 2 else self.weight
+        out = bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state)
 
         out = out.to(inp_dtype)
 