
[BUG]: Chat: about load_state_dict of critic and reward_model #4031

@HaomiaoPan

Description


🐛 Describe the bug

[train_prompts.py](https://github.com/hpcaitech/ColossalAI/blob/main/applications/Chat/examples/train_prompts.py):

        if rm_model_name == 'gpt2':
            reward_model = GPTRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'bloom':
            reward_model = BLOOMRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'opt':
            reward_model = OPTRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'llama':
            reward_model = LlamaRM(pretrained=args.rm_pretrain)
        elif rm_model_name == 'roberta':
            reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

        if args.rm_path is not None:
            reward_model.load_state_dict(state_dict)
        if rm_model_name == 'gpt2':
            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'bloom':
            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'opt':
            critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'llama':
            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        elif rm_model_name == 'roberta':
            critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
        else:
            raise ValueError(f'Unsupported reward model "{rm_model_name}"')

        if args.rm_path is not None:
            critic.load_state_dict(state_dict)
            del state_dict

The critic and the reward_model are initialized from the same checkpoint (`args.rm_pretrain` / `args.rm_path`), but only the critic is constructed with a `lora_rank` parameter. When I run train_prompts.sh with both `--lora_rank` and `--rm_path`, loading the state dict fails with missing keys: `"model.h.0.self_attention.query_key_value.lora_A"`, `"model.h.0.self_attention.query_key_value.lora_B"`, .....
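The mismatch can be reproduced outside ColossalAI. A minimal sketch (toy module, not the actual `GPTCritic` code, with parameter names chosen to mirror the error message): a checkpoint saved from a model without LoRA adapters has no `lora_A`/`lora_B` entries, so loading it into a LoRA-enabled model with the default `strict=True` raises, while `strict=False` loads the overlapping keys and reports the rest.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Toy stand-in for a LoRA-wrapped layer: base weight plus
    low-rank adapters lora_A / lora_B."""

    def __init__(self, dim: int = 4, rank: int = 2):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim, dim))
        self.lora_A = nn.Parameter(torch.zeros(rank, dim))
        self.lora_B = nn.Parameter(torch.zeros(dim, rank))


# Checkpoint from a model trained *without* LoRA (lora_rank=0):
# it only contains the base weight.
ckpt = {"weight": torch.zeros(4, 4)}

model = LoRALinear()
# strict=True (the default) would raise RuntimeError here;
# strict=False returns the non-matching keys for inspection instead.
result = model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)  # ['lora_A', 'lora_B']
```

Whether `strict=False` is a safe workaround here depends on whether freshly initialized LoRA adapters on top of the loaded reward-model weights are acceptable for the critic; that is a judgment for the maintainers, not something this sketch settles.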

Environment

No response
