[LLM] Support QLoRA on CPU device (#442)
* added qlora support on cpu device.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>
XinyuYe-Intel committed Nov 21, 2023
1 parent d2bd4d9 commit adb109b
Showing 13 changed files with 501 additions and 200 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/script/formatScan/nlp_dict.txt
@@ -2408,3 +2408,5 @@ Sudhanshu
Tripathi
akarX
dalvishruti
NormalFloat
backpropagates
63 changes: 62 additions & 1 deletion docs/qloracpu.md
@@ -1 +1,62 @@
TBD
# QLoRA on CPU

1. [Introduction](#introduction)
2. [Examples](#examples)

2.1. [Python API](#python-api)

2.2. [Neural Chat Example](#neural-chat-example)

## Introduction
[QLoRA](https://arxiv.org/abs/2305.14314) is an efficient finetuning approach that reduces the memory usage of finetuning Large Language Models (LLMs): it backpropagates gradients through a frozen, quantized LLM into Low Rank Adapters (LoRA). The upstream implementation currently only supports finetuning on CUDA devices, so we have developed the necessary API to support QLoRA on CPU devices, where 4-bit NormalFloat (NF4), Float4 (FP4), INT4 and INT8 are the supported data types for LLM quantization.
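
To make the mechanism concrete, here is a minimal, self-contained sketch (not the library implementation) of backpropagating through a frozen, quantized base weight into trainable low-rank adapters; all tensor names and sizes are illustrative only.

```python
import torch

# Illustrative only: a frozen "quantized" base weight (stood in for here by a
# bfloat16 tensor) plus a trainable low-rank (LoRA) update. Gradients flow
# only into lora_A and lora_B; the base weight receives none.
torch.manual_seed(0)
in_features, out_features, rank = 16, 16, 4

W_frozen = torch.randn(out_features, in_features).to(torch.bfloat16)  # stand-in for the 4-bit weight
lora_A = torch.nn.Parameter(torch.randn(rank, in_features) * 0.01)
lora_B = torch.nn.Parameter(torch.zeros(out_features, rank))
scaling = 1.0

x = torch.randn(2, in_features)
# Forward pass: frozen (dequantized) path + low-rank adapter path.
y = x @ W_frozen.to(x.dtype).t() + (x @ lora_A.t() @ lora_B.t()) * scaling
y.sum().backward()

print(lora_A.grad is not None, lora_B.grad is not None)  # True True
print(W_frozen.grad)                                     # None: the base weight stays frozen
```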

## Examples

### Python API

```python
import torch
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

# Load the model with 4-bit weight-only quantization on the CPU.
model = AutoModelForCausalLM.from_pretrained(
    'decapoda-research/llama-7b-hf',
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)

# Prepare the quantized model for k-bit (QLoRA) training.
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=True
)
model.gradient_checkpointing_enable()

# Attach LoRA adapters to the frozen base model.
peft_config = LoraConfig(
    r=8,
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, peft_config)
```
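
The resulting `model` is a standard PEFT model and can be finetuned with the usual Hugging Face `Trainer` API. The sketch below is illustrative only: `your_tokenized_dataset` and the output directory are placeholders, and the hyperparameters simply mirror the Neural Chat example that follows.

```python
from transformers import Trainer, TrainingArguments, default_data_collator

training_args = TrainingArguments(
    output_dir="./llama_qlora_cpu",       # placeholder output path
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    num_train_epochs=3,
    learning_rate=1e-4,
    bf16=True,                            # matches torch_dtype=torch.bfloat16 above
    gradient_checkpointing=True,
    no_cuda=True,                         # keep finetuning on the CPU
    logging_steps=100,
    save_strategy="epoch",
)

trainer = Trainer(
    model=model,                          # the PEFT-wrapped 4-bit model from above
    args=training_args,
    train_dataset=your_tokenized_dataset, # placeholder: any tokenized causal-LM dataset
    data_collator=default_data_collator,
)
trainer.train()
```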

### Neural Chat Example

To use QLoRA in Neural Chat on a CPU device, simply add the `--qlora` argument to the normal [Neural Chat Fine-tuning Example](https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/neural_chat/examples/finetuning/instruction) command, for example as below.

```bash
python finetune_clm.py \
--model_name_or_path "meta-llama/Llama-2-7b" \
--bf16 True \
--dataset_name /path/to/alpaca_data.json \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 1 \
--do_train \
--learning_rate 1e-4 \
--num_train_epochs 3 \
--logging_steps 100 \
--save_total_limit 2 \
--overwrite_output_dir \
--log_level info \
--save_strategy epoch \
--output_dir ./llama_peft_finetuned_model \
--peft lora \
--use_fast_tokenizer false \
--no_cuda \
--qlora \
--max_train_samples 500
```
73 changes: 54 additions & 19 deletions intel_extension_for_transformers/llm/finetuning/finetuning.py
@@ -32,6 +32,7 @@
PeftConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_kbit_training
)
from peft.tuners.adaption_prompt import AdaptionPromptConfig
from transformers import (
@@ -245,32 +246,39 @@ def load_tokenizer(self, model_args):
def finetune(self):
model_args, data_args, training_args, finetune_args = \
self.model_args, self.data_args, self.training_args, self.finetune_args
if not (is_bitsandbytes_available() and torch.cuda.is_available() and training_args.device.type == "cuda"):
if training_args.device.type != "cpu" and \
not (is_bitsandbytes_available() and torch.cuda.is_available() and training_args.device.type == "cuda"):
finetune_args.qlora = False
self.device_map = None
self.bitsandbytes_quant_config = None
self.load_in_4bit = False
self.load_in_8bit = False
if finetune_args.qlora:
# finetune_args.lora_all_linear = True
object.__setattr__(training_args, "gradient_checkpointing", True)
object.__setattr__(training_args, "ddp_find_unused_parameters", False)
finetune_args.peft = "lora"
compute_dtype = (
torch.float16 if training_args.fp16 else
(torch.bfloat16 if training_args.bf16 else torch.float32)
)
self.device_map = "auto"
self.bitsandbytes_quant_config = BitsAndBytesConfig(
load_in_4bit=finetune_args.bits == 4,
load_in_8bit=finetune_args.bits == 8,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=finetune_args.double_quant,
bnb_4bit_quant_type=finetune_args.quant_type,
)
self.load_in_4bit = finetune_args.bits == 4
self.load_in_8bit = finetune_args.bits == 8
if training_args.device.type == "cuda":
self.device_map = "auto"
self.bitsandbytes_quant_config = BitsAndBytesConfig(
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=finetune_args.double_quant,
bnb_4bit_quant_type=finetune_args.quant_type,
)
if finetune_args.bits not in [4, 8]:
raise NotImplementedError(
f"Unsupported bits {finetune_args.bits}, only support 4 and 8 now."
)
else:
self.device_map = None
self.bitsandbytes_quant_config = None

config = self.load_model_config(self.model_args)
if config.architectures[0].endswith("ForCausalLM"):
@@ -288,10 +296,14 @@ def finetune(self):
def find_all_linear_names(self, model):
cls = torch.nn.Linear
if self.finetune_args.qlora:
if self.finetune_args.bits == 8:
cls = bnb.nn.Linear8bitLt
elif self.finetune_args.bits == 4:
cls = bnb.nn.Linear4bit
if self.training_args.device.type == "cuda":
if self.finetune_args.bits == 8:
cls = bnb.nn.Linear8bitLt
elif self.finetune_args.bits == 4:
cls = bnb.nn.Linear4bit
elif self.training_args.device.type == "cpu":
from intel_extension_for_transformers.llm.quantization.nn.modules import QuantizedLinearQBits
cls = QuantizedLinearQBits

lora_module_names = set()
for name, module in model.named_modules():
@@ -330,6 +342,12 @@ def finetune_clm(self, model_args, data_args, training_args, finetune_args, conf
torch.float16 if training_args.fp16 else
(torch.bfloat16 if training_args.bf16 else torch.float32)
)
kwargs = {}
if finetune_args.qlora and training_args.device.type == "cpu":
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
kwargs['use_llm_runtime'] = False
else:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -342,7 +360,16 @@ def finetune_clm(self, model_args, data_args, training_args, finetune_args, conf
trust_remote_code=True if model_args.trust_remote_code else None,
torch_dtype=model_dtype,
low_cpu_mem_usage=True,
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
**kwargs
)
if finetune_args.qlora:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=training_args.gradient_checkpointing
)
if training_args.gradient_checkpointing:
model.gradient_checkpointing_enable()
if not (re.search("mpt", model_args.model_name_or_path, re.IGNORECASE) or
re.search("neural-chat-7b-v1", model_args.model_name_or_path, re.IGNORECASE) or
re.search("starcoder", model_args.model_name_or_path, re.IGNORECASE)):
@@ -351,7 +378,6 @@ def finetune_clm(self, model_args, data_args, training_args, finetune_args, conf
raise ValueError(
"Must provide model_name_or_path to load a pretrained CausalLM model."
)

# add special tokens
if data_args.special_tokens:
additional_special_tokens = {
@@ -746,6 +772,12 @@ def preprocess_logits_for_metrics(logits, labels):
torch.float16 if training_args.fp16 else
(torch.bfloat16 if training_args.bf16 else torch.float32)
)
kwargs = {}
if finetune_args.qlora and training_args.device.type == "cpu":
from intel_extension_for_transformers.transformers.modeling import AutoModelForSeq2SeqLM
kwargs['use_llm_runtime'] = False
else:
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -756,6 +788,9 @@ def preprocess_logits_for_metrics(logits, labels):
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
torch_dtype=model_dtype,
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
**kwargs
)
model.resize_token_embeddings(len(tokenizer))
else:
@@ -16,4 +16,4 @@
# limitations under the License.


from .functions import matmul_4bit
from .functions import matmul_kbit
@@ -25,37 +25,42 @@
def prod(iterable):
return reduce(operator.mul, iterable, 1)

class MatMul4Bit(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

class MatMulKBit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=None):
def forward(ctx, A, B, out=None, bias=None, compute_dtype=None, weight_dtype=None):
# # 1. Dequantize
# B_dequant = torch.zeros(out.shape[-1], A.shape[-1], dtype=torch.float)
# torch.ops.weight_only_jblasop.qbits_dequantize(
# B, B_dequant, True, compute_dtype, weight_dtype)
# B_dequant = B_dequant.to(dtype=A.dtype)

# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
ctx.B = B # B_dequant
ctx.bias = bias
B_shape = state[1]
B_shape = (out.shape[-1], A.shape[-1]) # B_dequant.shape
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)


# 1. Dequantize
# 2. MatmulnN
# torch.ops.weight_only_jblasop.jblas_symqdq_weight(B, False, 4, 32) # TODO: replace with dequantize
output = torch.nn.functional.linear(A, B.to(A.dtype), bias)
# 2. Matmul
# output = torch.nn.functional.linear(A, B_dequant, bias)
torch.ops.weight_only_jblasop.qbits_linear(
A, B.data, bias, out, out.shape[-1], bias is not None, compute_dtype, weight_dtype
)
output = out

# 3. Save state
ctx.state = state
ctx.compute_dtype, ctx.weight_dtype = compute_dtype, weight_dtype
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
# B_dequant.dtype

if any(ctx.needs_input_grad[:2]):
ctx.tensors = (A, B)
ctx.tensors = (A, B) # B_dequant
else:
ctx.tensors = (None, None)

@@ -67,26 +72,30 @@ def backward(ctx, grad_output):
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
req_gradA, _, _, req_gradBias, _, _ = ctx.needs_input_grad
A, B = ctx.tensors
state = ctx.state

grad_A, grad_B, grad_bias = None, None, None

B_dequant = torch.zeros(grad_output.shape[-1], A.shape[-1], dtype=torch.float)
torch.ops.weight_only_jblasop.qbits_dequantize(
B, B_dequant, True, ctx.compute_dtype, ctx.weight_dtype)
B = B_dequant

if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
# torch.ops.weight_only_jblasop.jblas_symqdq_weight(B, False, 4, 32) # TODO: replace with dequantize
if req_gradA: grad_A = torch.matmul(grad_output, B.to(grad_output.dtype))

return grad_A, grad_B, None, grad_bias, None
return grad_A, grad_B, None, grad_bias, None, None

def matmul_4bit(A: Tensor, B: Tensor, quant_state: List = None, out: Tensor = None, bias=None, do_dequant=True):
# assert quant_state is not None
def matmul_kbit(A: Tensor, B: Tensor, bias, out, compute_dtype, weight_dtype, do_dequant=False):
if do_dequant:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
return MatMulKBit.apply(A, B, out, bias, compute_dtype, weight_dtype)
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state) # TODO: replace with 4bit matmul
torch.ops.weight_only_jblasop.qbits_linear(
A, B.data, bias, out, out.shape[-1], bias is not None, compute_dtype, weight_dtype
)
return out
