From b02b757c91d2e834381cce7aee2edad4ddb4a4c8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 23 Jan 2025 15:08:21 +0000 Subject: [PATCH 1/2] new matmul8bit Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6440ab1b5..3a6a79fbd 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -563,6 +563,28 @@ def backward(ctx, grad_output): return grad_A, grad_B, None, grad_bias, None +class MatMul8bitFp(torch.autograd.Function): + # For Intel CPU and XPU, the double quant has many unsafe operations which will break the finetune. + # We'd like to use dequant + matmul to run finetune currently. + + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): + CB = B.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)).t() + output = torch.matmul(A, CB).to(A.dtype) + ctx.state = state + ctx.dtype_A = A.dtype + ctx.grad_shape = A.shape + return output + + @staticmethod + def backward(ctx, grad_output): + state = ctx.state + CB = state.CB.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + + return grad_A, None, None, None, None + + def matmul( A: torch.Tensor, B: torch.Tensor, @@ -574,6 +596,8 @@ def matmul( state = state or MatmulLtState() if threshold > 0.0: state.threshold = threshold + if A.device.type in ("cpu", "xpu") and state.is_training: + return MatMul8bitFp.apply(A, B, out, bias, state) return MatMul8bitLt.apply(A, B, out, bias, state) From f072403a18d2714444dde05c502492f2ca0b15c3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 23 Jan 2025 15:29:48 +0000 Subject: [PATCH 2/2] fix cxb Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git
a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 3a6a79fbd..9de5a8924 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -579,7 +579,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): @staticmethod def backward(ctx, grad_output): state = ctx.state - CB = state.CB.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + B = state.CxB if state.CxB is not None else state.CB + CB = B.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) return grad_A, None, None, None, None