[Fix] Fix AMP in Ascend and support using NPUJITCompile environment #994

Merged · 2 commits · Mar 13, 2023
4 changes: 3 additions & 1 deletion mmengine/device/utils.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
 from typing import Optional
 
 import torch
@@ -39,7 +40,8 @@ def is_npu_available() -> bool:
 
         # Enable operator support for dynamic shape and
         # binary operator support on the NPU.
-        torch.npu.set_compile_mode(jit_compile=False)
+        npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
+        torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
     except Exception:
         return False
     return hasattr(torch, 'npu') and torch.npu.is_available()
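
The key change above is that `jit_compile` is no longer hard-coded to `False`; it now follows the `NPUJITCompile` environment variable. A minimal usage sketch (assuming an Ascend machine with `torch_npu` installed; nothing below other than `NPUJITCompile` and `is_npu_available` is introduced by this PR):

# Sketch: enable NPU JIT compilation via the environment variable this
# PR introduces. It must be set before is_npu_available() first runs,
# because that is where torch.npu.set_compile_mode() is applied.
import os

os.environ['NPUJITCompile'] = '1'  # any non-empty value counts as truthy

from mmengine.device import is_npu_available

if is_npu_available():
    # torch.npu.set_compile_mode(jit_compile=True) has now taken effect.
    print('Ascend NPU ready with JIT compilation enabled')

One caveat worth noting: `bool(os.getenv('NPUJITCompile', False))` treats any non-empty string as truthy, so `NPUJITCompile=0` would also enable JIT compilation; leave the variable unset to keep the default binary/dynamic-shape operator mode.
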
4 changes: 4 additions & 0 deletions mmengine/runner/amp.py
@@ -126,6 +126,10 @@ def autocast(device_type: Optional[str] = None,
 
     elif device_type == 'mlu':
         pass
+
+    elif device_type == 'npu':
+        pass
+
     else:
         # Device like MPS does not support fp16 training or testing.
         # If an inappropriate device is set and fp16 is enabled, an error
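
On the runner side, the new `npu` branch means `autocast` now treats Ascend like `mlu`: the call falls through instead of hitting the error path in the `else` branch shown above. A hedged caller-side sketch (only meaningful on an Ascend machine with `torch_npu` available; `model` and `data` are placeholders, not part of this PR):

import torch

from mmengine.runner import autocast

model = torch.nn.Linear(4, 2)
data = torch.randn(2, 4)

# With this patch, device_type='npu' dispatches into its own branch
# rather than raising an unsupported-device error.
with autocast(device_type='npu', enabled=True):
    out = model(data)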