[Fix] Make autocast compatible with mps #587

Merged · 5 commits · Oct 18, 2022
13 changes: 11 additions & 2 deletions mmengine/runner/amp.py
@@ -121,9 +121,18 @@ def autocast(device_type: Optional[str] = None,
         assert dtype == torch.bfloat16, (
             'In CPU autocast, only support `torch.bfloat16` dtype')

     elif device_type == 'mlu':
         pass
     else:
-        raise ValueError('User specified autocast device_type must be '
-                         F'cuda or cpu, but got {device_type}')
+        # Device like MPS does not support fp16 training or testing.
+        # If an inappropriate device is set and fp16 is enabled, an error
+        # will be thrown.
+        if enabled is False:
+            yield
+            return
+        else:
+            raise ValueError('User specified autocast device_type must be '
+                             f'cuda or cpu, but got {device_type}')

     with torch.autocast(
             device_type=device_type,
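With this change, wrapping code in a disabled autocast on an unsupported backend such as mps simply yields without ever entering `torch.autocast`, while `enabled=True` still raises a `ValueError`. Below is a minimal sketch of the resulting behaviour, assuming the patch is applied and that passing `device_type='mps'` explicitly takes the same branch as resolving the device via `get_device()`:

import torch
import torch.nn as nn

from mmengine.runner import autocast

# Disabled autocast on mps: the context manager just yields, so the
# computation runs in ordinary fp32 and torch.autocast is never called.
with autocast(device_type='mps', enabled=False):
    out = nn.Conv2d(1, 1, 1)(torch.randn(1, 1, 1, 1))
assert out.dtype == torch.float32

# Enabling autocast on mps is rejected explicitly.
try:
    with autocast(device_type='mps', enabled=True):
        pass
except ValueError as err:
    print(err)  # User specified autocast device_type must be cuda or cpu, ...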
23 changes: 23 additions & 0 deletions tests/test_runner/test_amp.py
@@ -4,6 +4,8 @@
 import torch
 import torch.nn as nn

+import mmengine
+from mmengine.device import get_device
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -56,3 +58,24 @@ def test_autocast(self):
                     layer = nn.Conv2d(1, 1, 1).to(device)
                     res = layer(torch.randn(1, 1, 1, 1).to(device))
                     self.assertEqual(res.dtype, torch.float32)
+
+        # Test mps
+        if digit_version(TORCH_VERSION) >= digit_version('1.12.0'):
+            mmengine.runner.amp.get_device = lambda: 'mps'
+            with autocast(enabled=False):
+                layer = nn.Conv2d(1, 1, 1)
+                res = layer(torch.randn(1, 1, 1, 1))
+                self.assertEqual(res.dtype, torch.float32)
+
+            with self.assertRaisesRegex(ValueError,
+                                        'User specified autocast device_type'):
+                with autocast(enabled=True):
+                    pass
+        # Native pytorch does not support mlu, here we simply test autocast
+        # will call `torch.autocast`, which will be overridden by mlu version
+        # pytorch
+        mmengine.runner.amp.get_device = lambda: 'mlu'
+        with self.assertRaises(RuntimeError):
+            with autocast(enabled=False):
+                pass
+        mmengine.runner.amp.get_device = get_device
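A note on the stubbing pattern above: the test assigns directly to `mmengine.runner.amp.get_device` and restores the real function by hand at the end. An equivalent way to express the same idea (an illustrative alternative, not part of this PR) is `unittest.mock.patch`, which undoes the replacement automatically even if an assertion fails:

from unittest import mock

import torch
import torch.nn as nn

from mmengine.runner import autocast


def test_autocast_mps_branch():
    # Replace the get_device symbol that amp.py looks up, so the mps branch
    # runs on any machine; the original function is restored on exit.
    with mock.patch('mmengine.runner.amp.get_device', return_value='mps'):
        with autocast(enabled=False):
            res = nn.Conv2d(1, 1, 1)(torch.randn(1, 1, 1, 1))
        assert res.dtype == torch.float32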