Skip to content
Merged

NF4 #1432

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ def matmul_4bit(
assert quant_state is not None
if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
B = B.t() if len(B.shape) == 2 else B
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
Expand Down
3 changes: 2 additions & 1 deletion bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,8 @@ def forward(self, x: torch.Tensor):
x = x.to(self.compute_dtype)

bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
weight = self.weight.t() if len(self.weight.shape) == 2 else self.weight
out = bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state)

out = out.to(inp_dtype)

Expand Down