swift/llm/rlhf.py (15 changes: 9 additions & 6 deletions)
@@ -18,26 +18,29 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     if is_generation:
         logger.warning(f"Please check if args.template_type: '{args.template_type}' is correct.")
 
-    msg = {}
-    model, ref_model, template, callbacks = prepare_model_template_train(args)
     kwargs = {}
     if args.rlhf_type == 'ppo':
         from copy import deepcopy
-        reward_model_args = deepcopy(args)
+        reward_model_args, value_model_args = deepcopy(args), deepcopy(args)
         args_to_modified = ['model_id_or_path', 'model_type', 'model_revision']
-        for arg in args_to_modified:
-            setattr(reward_model_args, arg, getattr(args, f'reward_{arg}'))
+        for model_args in [reward_model_args, value_model_args]:
+            for arg in args_to_modified:
+                setattr(model_args, arg, getattr(args, f'reward_{arg}'))
         reward_model_args.ref_model_free = True  # avoid to create ref model
+        value_model_args.ref_model_free = True
         reward_model, _, _, _ = prepare_model_template_train(reward_model_args)
         reward_model.requires_grad_(False).eval()
 
         reward_model = get_model_with_value_head(reward_model)  # add and load value head
         # hack here to customize the value model
-        value_model, _, _, _ = prepare_model_template_train(reward_model_args)
+        value_model, _, _, _ = prepare_model_template_train(value_model_args)
         value_model = get_model_with_value_head(value_model)
         kwargs['reward_model'] = reward_model
         kwargs['value_model'] = value_model
 
+    msg = {}
+    model, ref_model, template, callbacks = prepare_model_template_train(args)
+
     with TrainerFactory.patch_template(args, template):
         train_dataset, val_dataset = prepare_dataset(args, template, msg)
 
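Separately, the freeze-and-wrap pattern around reward_model and value_model in this hunk is the usual TRL-style PPO setup. A rough sketch of what these calls amount to, using trl's AutoModelForCausalLMWithValueHead as an assumed analogue of swift's get_model_with_value_head (whose implementation is outside this diff) and a placeholder checkpoint path:

from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

reward_ckpt = 'path/to/reward-checkpoint'  # placeholder, not a real checkpoint

# Freeze the reward model: no gradients, eval-mode layers. requires_grad_()
# returns the module itself, so it can be chained with .eval() as in the diff.
reward_model = AutoModelForCausalLM.from_pretrained(reward_ckpt)
reward_model.requires_grad_(False).eval()

# The value model starts from the same weights but stays trainable and gains
# a scalar value head, which is roughly what get_model_with_value_head provides.
value_model = AutoModelForCausalLMWithValueHead.from_pretrained(reward_ckpt)

Both objects are then passed along via kwargs['reward_model'] and kwargs['value_model'], as the hunk shows.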