Closed
Labels
bug (Something isn't working)
Description
🐛 Describe the bug
[train_prompts.py](https://github.com/hpcaitech/ColossalAI/blob/main/applications/Chat/examples/train_prompts.py)
if rm_model_name == 'gpt2':
    reward_model = GPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'bloom':
    reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'opt':
    reward_model = OPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'llama':
    reward_model = LlamaRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'roberta':
    reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
else:
    raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
    reward_model.load_state_dict(state_dict)

if rm_model_name == 'gpt2':
    critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'bloom':
    critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'opt':
    critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'llama':
    critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'roberta':
    critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
    raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
    critic.load_state_dict(state_dict)
    del state_dict

`critic` and `reward_model` are initialized the same way and load the same `state_dict`, but only `critic` receives `lora_rank`. When I run `train_prompts.sh` with `--lora_rank` and `--rm_path`, loading the state dict fails with: Missing keys in state_dict : "model.h.0.self_attention.query_key_value.lora_A", "model.h.0.self_attention.query_key_value.lora_B", .....
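
To make the mismatch concrete, here is a minimal pure-Python sketch of the key comparison that a strict state-dict load effectively performs. It does not use ColossalAI or PyTorch. The helper `diff_state_dict_keys` and the short key lists are hypothetical, modelled only on the key names in the error message above. It assumes the checkpoint at `--rm_path` holds no LoRA parameters.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists for a model/checkpoint pair.

    missing:    keys the model expects but the checkpoint does not provide
    unexpected: keys the checkpoint provides but the model does not have
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical parameter names, taken from the error message in this report.
base = "model.h.0.self_attention.query_key_value"

# Assumption: the checkpoint was saved without LoRA parameters.
checkpoint_keys = [f"{base}.weight", f"{base}.bias"]

# reward_model is built without lora_rank, so it has no LoRA parameters.
reward_model_keys = [f"{base}.weight", f"{base}.bias"]

# critic is built with lora_rank > 0, so it also expects lora_A / lora_B.
critic_keys = reward_model_keys + [f"{base}.lora_A", f"{base}.lora_B"]

rm_missing, rm_unexpected = diff_state_dict_keys(reward_model_keys, checkpoint_keys)
critic_missing, critic_unexpected = diff_state_dict_keys(critic_keys, checkpoint_keys)

print("reward_model missing:", rm_missing)
print("critic missing:", critic_missing)
```

Under these assumptions the reward model loads cleanly while the critic reports the two LoRA keys as missing, which matches the error I see. Two possible workarounds, neither confirmed as the intended fix: build both models with the same `lora_rank`, or call `critic.load_state_dict(state_dict, strict=False)` so that PyTorch returns the missing keys instead of raising.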
Environment
No response