fix fp16 lora bug (#227)
Jintao-Huang committed Dec 17, 2023
1 parent ead3522 commit 5db36d2
Showing 4 changed files with 21 additions and 7 deletions.
9 changes: 5 additions & 4 deletions swift/llm/sft.py
@@ -19,10 +19,10 @@
                          show_layers)
 from .utils import (LazyLLMDataset, SftArguments, Template,
                     add_self_cognition_dataset, data_collate_fn, dataset_map,
-                    find_all_linear_for_lora, get_additional_saved_files,
-                    get_dataset, get_model_tokenizer, get_template,
-                    print_example, set_generation_config, sort_by_max_length,
-                    stat_dataset)
+                    find_all_linear_for_lora, fix_fp16_trainable_bug,
+                    get_additional_saved_files, get_dataset,
+                    get_model_tokenizer, get_template, print_example,
+                    set_generation_config, sort_by_max_length, stat_dataset)
 
 logger = get_logger()
 
@@ -113,6 +113,7 @@ def llm_sft(args: SftArguments) -> str:
         else:
             model = Swift.from_pretrained(
                 model, args.resume_from_checkpoint, is_trainable=True)
+            fix_fp16_trainable_bug(model)
     elif args.sft_type == 'full':
         if args.freeze_parameters > 0:
             freeze_model_parameters(model, args.freeze_parameters)
7 changes: 4 additions & 3 deletions swift/llm/utils/__init__.py
@@ -19,6 +19,7 @@
                        Template, TemplateType, get_template, register_template)
 from .utils import (LazyLLMDataset, LLMDataset, data_collate_fn, dataset_map,
                     download_dataset, find_all_linear_for_lora,
-                    history_to_messages, inference, inference_stream,
-                    limit_history_length, messages_to_history, print_example,
-                    set_generation_config, sort_by_max_length, stat_dataset)
+                    fix_fp16_trainable_bug, history_to_messages, inference,
+                    inference_stream, limit_history_length,
+                    messages_to_history, print_example, set_generation_config,
+                    sort_by_max_length, stat_dataset)
11 changes: 11 additions & 0 deletions swift/llm/utils/utils.py
@@ -619,6 +619,17 @@ def set_generation_config(model: Module,
     model.generation_config = generation_config
 
 
+def fix_fp16_trainable_bug(model: Module) -> None:
+    # fix peft==0.7 bug
+    is_logging = False
+    for p in model.parameters():
+        if p.requires_grad and p.dtype == torch.float16:
+            if not is_logging:
+                logger.info('Convert trainable parameters from fp16 to fp32.')
+                is_logging = True
+            p.data = p.data.to(dtype=torch.float32)
+
+
 # monkey patching
 MsDataset.load = _msdataset_ddp_load
 if is_ddp_plus_mp():
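The helper touches only parameters that are both trainable and stored in fp16, so frozen fp16 base weights keep their dtype while LoRA-style trainable tensors are upcast to fp32 before training resumes. A minimal sketch of that behavior, using a stand-in nn.Linear instead of a real LoRA model ('lora_A' is a made-up name) and importing fix_fp16_trainable_bug as exported by this commit:

import torch
from torch import nn

from swift.llm.utils import fix_fp16_trainable_bug  # exported by this commit

model = nn.Linear(4, 4, bias=False).to(torch.float16)
model.weight.requires_grad_(False)  # stand-in for a frozen fp16 base weight
# stand-in for a trainable LoRA-style parameter that was loaded in fp16
model.register_parameter('lora_A',
                         nn.Parameter(torch.zeros(4, 4, dtype=torch.float16)))

fix_fp16_trainable_bug(model)

print(model.weight.dtype)  # torch.float16 -- frozen weight is left alone
print(model.lora_A.dtype)  # torch.float32 -- trainable parameter was upcast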
1 change: 1 addition & 0 deletions tests/llm/test_run.py
@@ -164,6 +164,7 @@ def test_self_cognition(self):
             model_type=ModelType.qwen_7b_chat,
             dataset=dataset,  # no dataset
             train_dataset_sample=100,
+            dtype='fp16',
             eval_steps=5,
             output_dir='output',
             lora_target_modules='ALL',
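Pinning dtype='fp16' makes the self-cognition LoRA test run through the code path this commit fixes. A hypothetical assertion (not part of the commit or the repository) that a test could use to confirm no trainable parameter is left in fp16 after the fix is applied:

import torch


def assert_no_fp16_trainable(model: torch.nn.Module) -> None:
    # Mirrors the condition checked by fix_fp16_trainable_bug: fail if any
    # parameter that requires gradients is still stored in float16.
    bad = [name for name, p in model.named_parameters()
           if p.requires_grad and p.dtype == torch.float16]
    assert not bad, f'trainable fp16 parameters remain: {bad}'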
