Fix qlora deploy #1224
Merged 11 commits on Jun 26, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -502,7 +502,7 @@ The complete list of supported models and datasets can be found at [Supported Mo
|------------------------------------------------|------------------------------------------------------------------------|--------------------|----------------------------------------|------------------------------------------- |
| Qwen<br>Qwen1.5<br>Qwen2 | [Tongyi Qwen 1.0 and 1.5 series models](https://github.com/QwenLM) | Chinese<br>English | 0.5B-110B<br>including quantized versions | base model<br>chat model<br>MoE model<br>code model |
| ChatGLM2<br>ChatGLM3<br>Codegeex2<br>GLM4 | [Zhipu ChatGLM series models](https://github.com/THUDM) | Chinese<br>English | 6B-9B | base model<br>chat model<br>code model<br>long text model |
| Baichuan/Baichuan2 | [Baichuan 1 and Baichuan 2](https://github.com/baichuan-inc) | Chinese<br>English | 7B-13B<br>including quantized versions | base model<br>chat model |
| Baichuan<br>Baichuan2 | [Baichuan 1 and Baichuan 2](https://github.com/baichuan-inc) | Chinese<br>English | 7B-13B<br>including quantized versions | base model<br>chat model |
| Yuan2 | [Langchao Yuan series models](https://github.com/IEIT-Yuan) | Chinese<br>English | 2B-102B | instruct model |
| XVerse | [XVerse series models](https://github.com/xverse-ai) | Chinese<br>English | 7B-65B | base model<br>chat model<br>long text model<br>MoE model |
| LLaMA2 | [LLaMA2 series models](https://github.com/facebookresearch/llama) | English | 7B-70B<br>including quantized versions | base model<br>chat model |
23 changes: 12 additions & 11 deletions swift/llm/deploy.py
@@ -157,7 +157,7 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
kwargs[key] = new_value

generation_config = VllmGenerationConfig(**kwargs)
if generation_config.use_beam_search is True and request.stream is True:
if generation_config.use_beam_search and request.stream:
error_msg = 'Streaming generation does not support beam search.'
raise ValueError(error_msg)
tokenizer = template.tokenizer
@@ -391,16 +391,17 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq

created_time = int(time.time())
adapter_kwargs = {}
if request.model != _args.model_type:
adapter_names = None
for lora_req in _args.lora_request_list:
if lora_req.lora_name == request.model:
adapter_names = request.model
break
assert adapter_names is not None
adapter_kwargs['adapter_names'] = [adapter_names]
elif isinstance(model, PeftModel):
adapter_kwargs['adapter_names'] = ['-']
if _args.lora_request_list is not None:
if request.model != _args.model_type:
adapter_names = None
for lora_req in _args.lora_request_list:
if lora_req.lora_name == request.model:
adapter_names = request.model
break
assert adapter_names is not None
adapter_kwargs['adapter_names'] = [adapter_names]
elif isinstance(model, PeftModel):
adapter_kwargs['adapter_names'] = ['-'] # use base model

async def _generate_full():
generation_info = {}
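For reference, the adapter-routing logic above extracted as a standalone sketch; LoRARequest, the argument names, and the example values are stand-ins (assumptions), but the control flow mirrors the diff. The new lora_request_list guard lets a deployment of a pre-quantized (GPTQ/AWQ) model, for which infer.py below now sets lora_request_list to None, fall through without tripping the adapter-name assertion.

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class LoRARequest:
    lora_name: str
    lora_local_path: str


def select_adapter_kwargs(request_model: str, model_type: str,
                          lora_request_list: Optional[List[LoRARequest]],
                          is_peft_model: bool) -> Dict[str, list]:
    adapter_kwargs: Dict[str, list] = {}
    if lora_request_list is not None:       # new guard: only route adapters when LoRA switching is enabled
        if request_model != model_type:     # the request names a LoRA adapter rather than the base model
            adapter_names = None
            for lora_req in lora_request_list:
                if lora_req.lora_name == request_model:
                    adapter_names = request_model
                    break
            assert adapter_names is not None, f'unknown adapter: {request_model}'
            adapter_kwargs['adapter_names'] = [adapter_names]
        elif is_peft_model:                 # base model requested on a PeftModel -> disable adapters
            adapter_kwargs['adapter_names'] = ['-']
    return adapter_kwargs


# A qlora/GPTQ deployment now sets lora_request_list to None (see infer.py below),
# so requests fall through with no adapter_names and use the loaded weights directly.
print(select_adapter_kwargs('default-lora', 'qwen-7b-chat', None, True))  # -> {}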
10 changes: 9 additions & 1 deletion swift/llm/export.py
@@ -121,7 +121,15 @@ def llm_export(args: ExportArguments) -> None:
logger.info('Saving quantized weights...')
model_cache_dir = model.model_dir
save_checkpoint(
None, template.tokenizer, model_cache_dir, args.ckpt_dir, args.quant_output_dir, dtype=args.dtype)
None,
template.tokenizer,
model_cache_dir,
args.ckpt_dir,
args.quant_output_dir,
sft_args_kwargs={
'dtype': args.dtype,
'quant_method': args.quant_method
})
logger.info(f'Successfully quantized the model and saved in {args.quant_output_dir}.')
args.ckpt_dir = args.quant_output_dir

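The export path now records quant_method (alongside dtype) in the checkpoint metadata via the new sft_args_kwargs argument, so a later infer/deploy run can recover how the weights were quantized. A minimal sketch of the merge that save_checkpoint performs (see the infer.py hunk below); the sft_args.json filename and the helper name here are assumptions.

import json
import os
from typing import Any, Dict


def update_sft_args(ckpt_dir: str, target_dir: str, sft_args_kwargs: Dict[str, Any]) -> None:
    old_sft_args_path = os.path.join(ckpt_dir, 'sft_args.json')
    new_sft_args_path = os.path.join(target_dir, 'sft_args.json')
    if not os.path.exists(old_sft_args_path):
        return
    with open(old_sft_args_path, 'r', encoding='utf-8') as f:
        res = json.load(f)
    res['sft_type'] = 'full'                 # exported weights are complete, no adapter needed
    for k in ['dtype', 'quant_method']:      # only keys explicitly provided are overwritten
        v = sft_args_kwargs.get(k)
        if v is not None:
            res[k] = v
    with open(new_sft_args_path, 'w', encoding='utf-8') as f:
        json.dump(res, f, ensure_ascii=False, indent=2)


# e.g. for a GPTQ/AWQ export:
# update_sft_args(args.ckpt_dir, args.quant_output_dir,
#                 {'dtype': args.dtype, 'quant_method': args.quant_method})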
27 changes: 18 additions & 9 deletions swift/llm/infer.py
@@ -16,8 +16,8 @@
from swift.utils import (append_to_jsonl, get_logger, get_main, get_model_info, read_multi_line, seed_everything,
show_layers)
from .utils import (DeployArguments, InferArguments, Template, get_additional_saved_files, get_dataset,
get_model_tokenizer, get_template, inference, inference_stream, is_adapter, sample_dataset,
set_generation_config)
get_model_tokenizer, get_template, inference, inference_stream, is_adapter, is_quant_model,
sample_dataset, set_generation_config)

logger = get_logger()

@@ -29,6 +29,7 @@ def save_checkpoint(model: Optional[PreTrainedModel],
target_dir: str,
*,
save_safetensors: bool = True,
sft_args_kwargs: Dict[str, Any],
**kwargs) -> None:
if model is not None:
model.save_pretrained(target_dir, safe_serialization=save_safetensors)
@@ -75,9 +76,10 @@ def save_checkpoint(model: Optional[PreTrainedModel],
with open(old_sft_args_path, 'r', encoding='utf-8') as f:
res = json.load(f)
res['sft_type'] = 'full'
dtype = kwargs.get('dtype')
if dtype is not None:
res['dtype'] = dtype
for k in ['dtype', 'quant_method']:
v = sft_args_kwargs.get(k)
if v is not None:
res[k] = v
with open(new_sft_args_path, 'w', encoding='utf-8') as f:
json.dump(res, f, ensure_ascii=False, indent=2)

@@ -89,8 +91,8 @@ def merge_lora(args: InferArguments,
logger.info(f'replace_if_exists: {replace_if_exists}')
assert args.ckpt_dir is not None, 'args.ckpt_dir is not specified.'
assert args.sft_type in ('lora', 'adalora', 'longlora'), 'Only supports lora series models'
for s in ['int4', 'int8', 'awq']:
assert s not in args.model_type, f'{s} model is not supported'
assert not is_quant_model(
args.model_type), f'{args.model_type} is a quantized model and does not support merge-lora.'
if args.quantization_bit != 0:
logger.warning('It is not recommended to merge quantized models, '
'as this can result in performance degradation')
@@ -117,7 +119,7 @@
args.ckpt_dir,
merged_lora_path,
save_safetensors=args.save_safetensors,
dtype=args.dtype)
sft_args_kwargs={'dtype': args.dtype})
logger.info(f'Successfully merged LoRA and saved in {merged_lora_path}.')
logger.info("Setting args.sft_type: 'full'")
logger.info(f'Setting args.ckpt_dir: {merged_lora_path}')
@@ -180,6 +182,7 @@ def prepare_model_template(args: InferArguments,
model_kwargs,
model_id_or_path=model_id_or_path,
revision=args.model_revision,
quant_method=args.quant_method,
**kwargs)
if verbose:
logger.info(f'model_config: {model.config}')
@@ -207,7 +210,13 @@
f'args.max_model_len: {args.max_model_len}, model.max_model_len: {model.max_model_len}')
# Preparing LoRA
if is_adapter(args.sft_type) and args.ckpt_dir is not None:
if is_quant_model(args.model_type, model):
# gptq awq does not support lora switching
args.lora_request_list = None
logger.warning('The current model does not support LoRA switching. '
f'Setting args.lora_request_list: {args.lora_request_list}')
if isinstance(args, DeployArguments) and args.lora_request_list is not None:
logger.info(f'args.lora_request_list: {args.lora_request_list}')
for lora_request in args.lora_request_list:
model = Swift.from_pretrained(
model, lora_request.lora_local_path, lora_request.lora_name, inference_mode=True)
@@ -499,7 +508,7 @@ def llm_infer(args: InferArguments) -> Dict[str, List[Dict[str, Any]]]:
kwargs['tools'] = tools
kwargs['truncation_strategy'] = args.truncation_strategy
if args.infer_backend == 'vllm':
assert args.stream is True
assert args.stream
if args.verbose:
print(f"[QUERY]{data['query']}\n[RESPONSE]", end='')
gen = inference_stream_vllm(llm_engine, template, [kwargs], lora_request=lora_request)
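merge_lora and prepare_model_template now rely on a shared is_quant_model helper instead of substring checks, and LoRA switching is disabled for pre-quantized models. The helper's implementation is not part of this diff; a plausible sketch, assuming it inspects both the model_type name and the is_gptq/is_awq/is_aqlm flags that loading sets on the model (the real code in swift/llm/utils may differ).

from typing import Optional


def is_quant_model(model_type: Optional[str] = None, model=None) -> bool:
    # pre-quantized checkpoints usually carry the method in their model_type name ...
    if model_type is not None:
        for key in ['int4', 'int8', 'awq', 'gptq', 'aqlm']:
            if key in model_type:
                return True
    # ... and/or are flagged on the loaded model (cf. the is_gptq/is_awq/is_aqlm loop removed from sft.py)
    if model is not None:
        for key in ['gptq', 'awq', 'aqlm']:
            if getattr(model, f'is_{key}', None):
                return True
    return False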
8 changes: 2 additions & 6 deletions swift/llm/rlhf.py
@@ -94,12 +94,6 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
kwargs['use_flash_attn'] = args.use_flash_attn
if args.local_repo_path:
kwargs['local_repo_path'] = args.local_repo_path
if args.quant_method == 'awq':
kwargs['is_awq'] = True
elif args.quant_method == 'aqlm':
kwargs['is_aqlm'] = True
elif args.quant_method == 'gptq':
kwargs['is_gptq'] = True

if args.rope_scaling:
kwargs['rope_scaling'] = args.rope_scaling
@@ -111,6 +105,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
model_kwargs,
model_id_or_path=args.model_id_or_path,
revision=args.model_revision,
quant_method=args.quant_method,
is_training=True,
**kwargs)
logger.info(f'model_config: {model.config}')
Expand Down Expand Up @@ -155,6 +150,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
model_kwargs,
model_id_or_path=args.ref_model_id_or_path,
revision=args.model_revision,
quant_method=args.quant_method,
**kwargs)
else:
ref_model = None
13 changes: 6 additions & 7 deletions swift/llm/sft.py
@@ -100,25 +100,24 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
kwargs['use_flash_attn'] = args.use_flash_attn
if args.local_repo_path:
kwargs['local_repo_path'] = args.local_repo_path
if args.quant_method == 'awq':
kwargs['is_awq'] = True
elif args.quant_method == 'aqlm':
kwargs['is_aqlm'] = True
elif args.quant_method == 'gptq':
kwargs['is_gptq'] = True

if args.rope_scaling:
kwargs['rope_scaling'] = args.rope_scaling
kwargs['max_length'] = args.max_length

model, tokenizer = get_model_tokenizer(
args.model_type,
args.torch_dtype,
model_kwargs,
model_id_or_path=args.model_id_or_path,
revision=args.model_revision,
quant_method=args.quant_method,
is_training=True,
**kwargs)
for k in ['gptq', 'awq', 'aqlm']:
if getattr(model, f'is_{k}', None):
args.quant_method = k
logger.info(f'Setting args.quant_method: {args.quant_method}')
break
logger.info(f'model_config: {model.config}')
generation_config = GenerationConfig(
max_new_tokens=args.max_new_tokens,
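sft.py and rlhf.py now pass quant_method directly to get_model_tokenizer rather than expanding it into is_awq/is_aqlm/is_gptq kwargs at every call site, and sft.py drops the loop that re-derived args.quant_method from the loaded model. If the loader still honors the old boolean kwargs internally, the mapping could look like this hedged sketch (not the actual ms-swift code).

from typing import Dict, Optional


def _quant_method_to_kwargs(quant_method: Optional[str]) -> Dict[str, bool]:
    kwargs: Dict[str, bool] = {}
    if quant_method in ('awq', 'aqlm', 'gptq'):
        kwargs[f'is_{quant_method}'] = True   # e.g. 'gptq' -> {'is_gptq': True}
    return kwargs


assert _quant_method_to_kwargs('gptq') == {'is_gptq': True}
assert _quant_method_to_kwargs(None) == {}    # other methods and unquantized models take other paths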
6 changes: 3 additions & 3 deletions swift/llm/utils/__init__.py
@@ -21,9 +21,9 @@
get_template, register_template)
from .utils import (LazyLLMDataset, LLMDataset, dataset_map, download_dataset, find_all_linears, find_embedding,
find_ln, get_max_model_len, get_time_info, history_to_messages, inference, inference_stream,
is_vllm_available, limit_history_length, messages_join_observation, messages_to_history,
print_example, safe_tokenizer_decode, set_generation_config, sort_by_max_length, stat_dataset,
to_device)
is_quant_model, is_vllm_available, limit_history_length, messages_join_observation,
messages_to_history, print_example, safe_tokenizer_decode, set_generation_config,
sort_by_max_length, stat_dataset, to_device)

try:
if is_vllm_available():
27 changes: 14 additions & 13 deletions swift/llm/utils/argument.py
@@ -30,7 +30,7 @@
from .model import (MODEL_MAPPING, dtype_mapping, get_additional_saved_files, get_default_lora_target_modules,
get_default_template_type)
from .template import TEMPLATE_MAPPING
from .utils import is_vllm_available
from .utils import is_quant_model, is_vllm_available

logger = get_logger()

@@ -675,15 +675,15 @@ def load_from_checkpoint(self) -> None:
with open(sft_args_path, 'r', encoding='utf-8') as f:
sft_args = json.load(f)
imported_keys = [
'model_type', 'model_revision', 'quantization_bit', 'dtype', 'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type',
'bnb_4bit_use_double_quant', 'model_id_or_path'
'model_type', 'model_revision', 'quant_method', 'quantization_bit', 'dtype', 'bnb_4bit_comp_dtype',
'bnb_4bit_quant_type', 'bnb_4bit_use_double_quant', 'model_id_or_path'
]

for key in imported_keys:
value = getattr(self, key)
if key in {'dtype', 'bnb_4bit_comp_dtype'} and value != 'AUTO':
continue
if key in {'model_type', 'model_revision', 'model_id_or_path'} and value is not None:
if key in {'model_type', 'model_revision', 'model_id_or_path', 'quant_method'} and value is not None:
continue
setattr(self, key, sft_args.get(key))

@@ -820,8 +820,9 @@ def __post_init__(self) -> None:
'lora does not support `freeze_parameters`, please set `--sft_type full`')
assert len(self.additional_trainable_parameters) == 0, (
'lora does not support `additional_trainable_parameters`, please set `--sft_type full`')
if 'int4' in self.model_type or 'int8' in self.model_type or 'awq' in self.model_type:
assert self.quantization_bit == 0, 'int4, int8 or awq models do not need to be quantized again.'
if is_quant_model(self.model_type):
assert self.quantization_bit == 0, (
f'{self.model_type} is already a quantized model and does not need to be quantized again.')
if self.learning_rate is None:
self.learning_rate = 1e-4
if self.save_only_model is None:
@@ -1026,7 +1027,7 @@ def _init_training_args(self) -> None:
self.training_args = training_args

def _handle_pai_compat(self) -> None:
assert is_pai_training_job() is True
assert is_pai_training_job()
logger.info('Handle pai compat...')
pai_tensorboard_dir = get_pai_tensorboard_dir()
if self.logging_dir is None and pai_tensorboard_dir is not None:
@@ -1075,7 +1076,8 @@ class InferArguments(ArgumentsBase):
model_name: List[str] = field(default_factory=lambda: [None, None], metadata={'help': "e.g. ['小黄', 'Xiao Huang']"})
model_author: List[str] = field(
default_factory=lambda: [None, None], metadata={'help': "e.g. ['魔搭', 'ModelScope']"})
quant_method: Literal['bnb', 'hqq', 'eetq'] = None
# 'awq', 'gptq', 'aqlm' are used for inference on pre-quantized models.
quant_method: Literal['bnb', 'hqq', 'eetq', 'awq', 'gptq', 'aqlm'] = None
quantization_bit: Literal[0, 1, 2, 3, 4, 8] = 0 # hqq: 1,2,3,4,8. bnb: 4,8
hqq_axis: Literal[0, 1] = 0
hqq_dynamic_config_path: Optional[str] = None
@@ -1211,14 +1213,13 @@ def handle_infer_backend(self):
if not support_vllm:
logger.warning(f'vllm not support `{self.model_type}`')
if self.sft_type == 'lora' and not self.vllm_enable_lora:
assert self.merge_lora is True, ('To use VLLM, you need to provide the complete weight parameters. '
'Please set `--merge_lora true`.')
assert self.merge_lora, ('To use VLLM, you need to provide the complete weight parameters. '
'Please set `--merge_lora true`.')
if (self.infer_backend == 'vllm' and self.vllm_enable_lora
or self.infer_backend == 'pt' and isinstance(self, DeployArguments) and self.sft_type == 'lora'):
assert self.ckpt_dir is not None
self.lora_modules.append(f'default-lora={self.ckpt_dir}')
self.lora_request_list = _parse_lora_modules(self.lora_modules, self.infer_backend == 'vllm')
logger.info(f'args.lora_request_list: {self.lora_request_list}')

template_info = TEMPLATE_MAPPING[self.template_type]
if self.num_beams != 1:
@@ -1236,7 +1237,7 @@ def load_from_ckpt_dir(self) -> None:
with open(sft_args_path, 'r', encoding='utf-8') as f:
sft_args = json.load(f)
imported_keys = [
'model_type', 'model_revision', 'sft_type', 'template_type', 'system', 'quantization_bit',
'model_type', 'model_revision', 'sft_type', 'template_type', 'system', 'quant_method', 'quantization_bit',
'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type', 'bnb_4bit_use_double_quant', 'rope_scaling'
]
if self.load_dataset_config:
@@ -1248,7 +1249,7 @@
value = getattr(self, key)
if key in {'dataset', 'val_dataset'} and len(value) > 0:
continue
if key in {'dataset_test_ratio', 'system'} and value is not None:
if key in {'dataset_test_ratio', 'system', 'quant_method'} and value is not None:
continue
setattr(self, key, sft_args.get(key))
