Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] allow Accelerator to prepare models in eval mode for XPU&CPU #2426

Merged
merged 1 commit into from
Feb 9, 2024

Conversation

faaany
Copy link
Contributor

@faaany faaany commented Feb 8, 2024

Problem

When trying to run the nlp_example.py on Intel GPUs and CPUs, the prepare function in the following code will complain with the following:

Traceback (most recent call last):
  File "/soft/fanli/accelerate/examples/nlp_example.py", line 209, in <module>
    main()
  File "/soft/fanli/accelerate/examples/nlp_example.py", line 205, in main
    training_function(config, args)
  File "/soft/fanli/accelerate/examples/nlp_example.py", line 154, in training_function
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
  File "/soft/fanli/accelerate/src/accelerate/accelerator.py", line 1217, in prepare
    args = self._prepare_ipex(*args)
  File "/soft/fanli/accelerate/src/accelerate/accelerator.py", line 1762, in _prepare_ipex
    model, optimizer = torch.xpu.optimize(
  File "/home/fan/anaconda3/envs/study2/lib/python3.10/site-packages/intel_extension_for_pytorch/xpu/utils.py", line 237, in optimize
    return frontend.optimize(
  File "/home/fan/anaconda3/envs/study2/lib/python3.10/site-packages/intel_extension_for_pytorch/frontend.py", line 339, in optimize
    assert optimizer is None, "The optimizer should not be given for inference mode"
AssertionError: The optimizer should not be given for inference mode

This is a bug because the ipex.optimize function expects the model to be in training mode, otherwise, it will assume that the user is doing inference as shown in this line.

Another thing I noticed is that the dtype passed to ipex.optimize() is fp32, but ipex.optimize() expects the dtype to be either bf16 or fp16 and the default value is None as stated here. So if no mixed_precision is used for training (currently only bf16 is supported), the dtype should keep the same with the default None value.

What does this PR do?

Fix the bug in _prepare_ipex and improve the dtype passed to ipex.optimize() or torch.xpu.optimize(). With this fix, the example code can now run on CPU and XPU both in single and distributed modes.

Who can review?

@muellerzr or @sywangyi

@faaany
Copy link
Contributor Author

faaany commented Feb 8, 2024

@yao-matrix

@faaany faaany changed the title [FIX] enable nlp_example.py to run on XPU for both single and distributed modes [FIX] allow accelerator.prepare to work for models in eval mode on XPU (single&distributed) Feb 8, 2024
@faaany faaany changed the title [FIX] allow accelerator.prepare to work for models in eval mode on XPU (single&distributed) [FIX] allow Accelerator to prepare models in eval mode for XPU&CPU Feb 8, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

@muellerzr muellerzr merged commit 9c1d5ba into huggingface:main Feb 9, 2024
21 of 23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants