Run multi_adapter_rl_v2.py with multiple gpus error #820

Closed
ASY246 opened this issue Sep 26, 2023 · 3 comments

Comments

@ASY246

ASY246 commented Sep 26, 2023

In the examples folder, I can run the script multi_adapter_rl_v2.py successfully on a single GPU with python3 multi_adapter_rl_v2.py.

But when I run it on multiple GPUs with
accelerate launch --config_file=../accelerate_configs/multi_gpu.yaml multi_adapter_rl_v2.py
I get the following error:

[2023-09-26 21:14:08] Traceback (most recent call last):
[2023-09-26 21:14:08]   File "/opt/ml/code/examples/scripts/multi_adapter_rl_v2.py", line 167, in <module>
[2023-09-26 21:14:08]     raw_rewards = ppo_trainer.model.compute_reward_score(**inputs)
[2023-09-26 21:14:08]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1630, in __getattr__
[2023-09-26 21:14:08]     raise AttributeError("'{}' object has no attribute '{}'".format(
[2023-09-26 21:14:08] AttributeError: 'DistributedDataParallel' object has no attribute 'compute_reward_score'

So maybe the model gets wrapped differently in multi-GPU mode?
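
The traceback is consistent with that guess. When the script is started with accelerate launch and a multi-GPU config, PPOTrainer prepares the model through accelerate, which wraps it in torch.nn.parallel.DistributedDataParallel (the class named in the error). DDP keeps the original model in its `.module` attribute and is itself an `nn.Module`, whose `__getattr__` only falls back to registered parameters, buffers and submodules. Custom methods such as `compute_reward_score` are therefore not forwarded, while `forward()` still is. The minimal sketch below reproduces the same AttributeError on CPU; `ToyWrapper` is a hypothetical stand-in for DDP (real DDP needs an initialized process group) and `Policy` is a toy model, not trl's AutoModelForCausalLMWithValueHead.

```python
import torch
import torch.nn as nn


class Policy(nn.Module):
    """Toy model with a custom method, standing in for the value-head model."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)

    def compute_reward_score(self, x):
        # Hypothetical custom method; only reachable on the unwrapped model.
        return self.linear(x)


class ToyWrapper(nn.Module):
    """Hypothetical stand-in for DDP: stores the wrapped model as the
    submodule `module` and forwards only forward() calls to it."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


wrapped = ToyWrapper(Policy())
x = torch.randn(2, 4)

print(wrapped(x).shape)  # forward() works through the wrapper
try:
    wrapped.compute_reward_score(x)
except AttributeError as err:
    print(err)  # 'ToyWrapper' object has no attribute 'compute_reward_score'

# Reaching through `.module` finds the custom method again.
print(wrapped.module.compute_reward_score(x).shape)
```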

Here is the Python script I used:

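from peft import LoraConfig
from tqdm import tqdm
from transformers import LlamaTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler

# Assumed to be defined elsewhere in the full script (not shown): script_args, current_device,
# create_and_prepare_dataset, input_min_text_length, input_max_text_length
lora_config = LoraConfig(r=16, lora_alpha=32, bias="none", task_type="CAUSAL_LM", lora_dropout=0.1, modules_to_save=["scores"])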
nf4_config = None

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    script_args.model_name,
    device_map={"": current_device},
    peft_config=lora_config,
    quantization_config=nf4_config,
    reward_adapter=script_args.rm_adapter,
    use_safetensors=script_args.use_safetensors,
)

tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name, input_max_text_length=1024)
tokenizer.pad_token = tokenizer.eos_token
dataset = create_and_prepare_dataset(tokenizer)

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

config = PPOConfig(
    model_name=script_args.model_name,
    log_with=script_args.log_with,
    learning_rate=1e-5,
    batch_size=2,
    mini_batch_size=1,
    gradient_accumulation_steps=2,
    optimize_cuda_cache=True,
    seed=script_args.seed,
    use_score_scaling=script_args.use_score_scaling,
    use_score_norm=script_args.use_score_norm,
    score_clip=script_args.score_clip
)
ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=dataset["train"],
    data_collator=collator
)

generation_kwargs = {
    "top_k": 0.0,
    "top_p": 0.9,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    question_tensors = batch["input_ids"]
    response_tensors = ppo_trainer.generate(
        question_tensors,
        length_sampler=LengthSampler(input_min_text_length, input_max_text_length),
        return_prompt=False,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
    
    #### Compute reward score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt").to(ppo_trainer.accelerator.device)
    raw_rewards = ppo_trainer.model.compute_reward_score(**inputs)
    rewards = [raw_rewards[i, -1, 0] for i in range(len(raw_rewards))]  # take the last token; this must match how the reward model was trained

    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
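
The reward line above reads the score at index -1 and notes that this must match the reward-model training. One thing worth double-checking (an assumption, not something reported in this thread): with padding=True, index -1 is the last real token only when the tokenizer pads on the left; with right padding, shorter texts in the batch end in pad tokens. Below is a minimal sketch that reads the score at the last non-padding position from the attention mask, which works for either padding side. `last_token_rewards` is a hypothetical helper, the tensors are toy values, and `raw_rewards` is assumed to have shape (batch, seq_len, 1), as the indexing in the script implies.

```python
import torch


def last_token_rewards(raw_rewards: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the score at the last non-padding position of each sequence.

    raw_rewards:    (batch, seq_len, 1) per-token scores (shape assumed from the script).
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding.
    """
    seq_len = attention_mask.size(1)
    # Flip so the last real token becomes the first 1; argmax returns the first maximal index.
    last_idx = seq_len - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)
    batch_idx = torch.arange(raw_rewards.size(0), device=raw_rewards.device)
    return raw_rewards[batch_idx, last_idx, 0]


# Toy batch: the second sequence is right-padded by two positions.
raw_rewards = torch.tensor([[[0.1], [0.2], [0.3], [0.9]],
                            [[0.4], [0.8], [0.0], [0.0]]])
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

print(raw_rewards[:, -1, 0])                            # naive: 0.9 and 0.0 (a pad position)
print(last_token_rewards(raw_rewards, attention_mask))  # 0.9 and 0.8
```

In the script this would stand in for the list comprehension, e.g. rewards = list(last_token_rewards(raw_rewards, inputs["attention_mask"])), provided compute_reward_score really returns per-token scores in that shape. Whichever position is used, it should be the same one the reward model was trained on.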
@younesbelkada
Collaborator

Hi @ASY246
I think for multi-GPU we need a small check to make sure the model is properly unwrapped from DDP.
Can you replace the line:

 raw_rewards = ppo_trainer.model.compute_reward_score(**inputs)

with:

 raw_rewards = ppo_trainer.accelerator.unwrap_model(model).compute_reward_score(**inputs)

Let me know if this works!
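
For reference, Accelerator.unwrap_model hands back the original model from inside the wrappers accelerate applies, so methods defined on that model become reachable again. The loop below is a simplified sketch of the idea, not accelerate's implementation (which covers more wrapper types): it peels off DistributedDataParallel or DataParallel layers through their `.module` attribute. `unwrap` and `Policy` are hypothetical names. Because DDP stores the very object it wraps, the script's own `model` variable should already be the unwrapped model, so passing either `model` or `ppo_trainer.model` to unwrap_model should give the same object.

```python
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

# Wrapper types that keep the original model in their `.module` attribute.
WRAPPER_TYPES = (DistributedDataParallel, DataParallel)


def unwrap(model: nn.Module) -> nn.Module:
    """Simplified sketch: peel off wrappers until the original model remains."""
    while isinstance(model, WRAPPER_TYPES):
        model = model.module
    return model


class Policy(nn.Module):
    """Toy model with a custom method (hypothetical stand-in for the value-head model)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def compute_reward_score(self, x):
        return self.linear(x)


inner = Policy()
# DataParallel can be constructed without a process group, unlike DDP,
# and hides custom methods in the same way.
wrapped = DataParallel(inner)

print(hasattr(wrapped, "compute_reward_score"))  # False: not forwarded by the wrapper
print(unwrap(wrapped) is inner)                  # True: the original object comes back
print(unwrap(inner) is inner)                    # True: unwrapped models pass through unchanged
```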

@ASY246
Author

ASY246 commented Sep 27, 2023

@younesbelkada Thanks for your response. I replaced the line with
raw_rewards = ppo_trainer.accelerator.unwrap_model(model).compute_reward_score(**inputs)
and now I get another error:

[2023-09-27 20:33:48] Traceback (most recent call last):
[2023-09-27 20:33:48]   File "/opt/ml/code/examples/scripts/multi_adapter_rl_v2.py", line 186, in <module>
[2023-09-27 20:33:48]     stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
[2023-09-27 20:33:48]   File "/usr/lib/python3.10/contextlib.py", line 79, in inner
[2023-09-27 20:33:48]     return func(*args, **kwds)
[2023-09-27 20:33:48]   File "/opt/ml/code/trl/trainer/ppo_trainer.py", line 754, in step
[2023-09-27 20:33:48]     logprobs, logits, vpreds, _ = self.batched_forward_pass(
[2023-09-27 20:33:48]   File "/usr/lib/python3.10/contextlib.py", line 79, in inner
[2023-09-27 20:33:48]     return func(*args, **kwds)
[2023-09-27 20:33:48]   File "/opt/ml/code/trl/trainer/ppo_trainer.py", line 949, in batched_forward_pass
[2023-09-27 20:33:48]     logits, _, values = model(**input_kwargs)
[2023-09-27 20:33:48]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
[2023-09-27 20:33:48]     return self._call_impl(*args, **kwargs)
[2023-09-27 20:33:48]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
[2023-09-27 20:33:48]     return forward_call(*args, **kwargs)
[2023-09-27 20:33:48]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1513, in forward
[2023-09-27 20:33:48]     inputs, kwargs = self._pre_forward(*inputs, **kwargs)
[2023-09-27 20:33:48]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1401, in _pre_forward
[2023-09-27 20:33:48]     if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
[2023-09-27 20:33:48] RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by 
[2023-09-27 20:33:48] making sure all `forward` function outputs participate in calculating loss. 
[2023-09-27 20:33:48] If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
[2023-09-27 20:33:48] Parameter indices which did not receive grad for rank 3: 162
[2023-09-27 20:33:48]  In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error

It seems that some parameters are not used in the loss computation, but why does the script work on a single GPU and fail only in multi-GPU mode?
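
A likely reason for the difference: without DDP there is no gradient reducer, so a trainable parameter that never receives a gradient is simply left alone by the optimizer. Under DDP, gradients are all-reduced across ranks, and by default every parameter with requires_grad=True is expected to get a gradient in each backward pass. One that does not (index 162 on rank 3 above) triggers this error unless find_unused_parameters=True is set, as the message suggests. With accelerate, that flag is usually passed as accelerate.DistributedDataParallelKwargs(find_unused_parameters=True) in the Accelerator's kwargs_handlers; how to route it through PPOTrainer depends on the trl version and is not covered in this thread. The sketch below runs on one CPU and lists trainable parameters that received no gradient after a backward pass, which can help identify what index 162 refers to (TORCH_DISTRIBUTED_DEBUG=DETAIL, mentioned in the error, prints the names directly). `TwoHeads` is a toy model with one deliberately unused head, and numbering only the requires_grad=True parameters in registration order is an assumption about how DDP counts them.

```python
import torch
import torch.nn as nn


class TwoHeads(nn.Module):
    """Toy model: `unused_head` is trainable but never touched in forward()."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(4, 4)
        self.used_head = nn.Linear(4, 1)
        self.unused_head = nn.Linear(4, 1)

    def forward(self, x):
        return self.used_head(torch.relu(self.body(x)))


def params_without_grad(model: nn.Module):
    """Return (index, name) for trainable parameters whose .grad is still None.

    Indices count only parameters with requires_grad=True, in registration order.
    """
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    return [(i, n) for i, (n, p) in enumerate(trainable) if p.grad is None]


model = TwoHeads()
loss = model(torch.randn(8, 4)).mean()
loss.backward()
print(params_without_grad(model))
# [(4, 'unused_head.weight'), (5, 'unused_head.bias')]
```

If the reported parameter is not meant to be trained, setting requires_grad=False on it should also avoid the error, since DDP only tracks parameters that require gradients.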

github-actions bot commented Nov 1, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions bot closed this as completed Nov 9, 2023