Sync states for npu fsdp (#2113)
jq460494839 committed Nov 8, 2023
1 parent 5e0eb0d commit 217e1a2
Showing 1 changed file with 11 additions and 2 deletions.
src/accelerate/utils/dataclasses.py (13 changes: 11 additions & 2 deletions)
@@ -32,7 +32,7 @@

 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_xpu_available
+from .imports import is_cuda_available, is_npu_available, is_xpu_available
 from .versions import compare_versions


@@ -932,7 +932,16 @@ def __post_init__(self):
         self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1

         if self.sync_module_states:
-            device = torch.cuda.current_device() if not is_xpu_available() else torch.xpu.current_device()
+            if is_npu_available():
+                device = torch.npu.current_device()
+            elif is_cuda_available():
+                device = torch.cuda.current_device()
+            elif is_xpu_available():
+                device = torch.xpu.current_device()
+            else:
+                raise RuntimeError(
+                    "There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'."
+                )
             self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)

     @staticmethod
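
The crux of the change is the probe order: NPU is checked before CUDA and XPU, so Ascend machines no longer fall through to torch.cuda.current_device(). A minimal standalone sketch of the same precedence follows; it assumes only that torch.npu / torch.xpu are attached to torch once their respective backends (e.g. the torch_npu extension) are loaded, and current_accelerator_device is a hypothetical helper, not part of Accelerate:

import torch

def current_accelerator_device() -> int:
    """Hypothetical helper mirroring the diff's precedence: NPU, then CUDA, then XPU."""
    # torch.npu exists only once the Ascend torch_npu extension has been imported.
    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.npu.current_device()
    if torch.cuda.is_available():
        return torch.cuda.current_device()
    # torch.xpu is provided by XPU-enabled PyTorch builds (or via Intel's extension).
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.current_device()
    raise RuntimeError("There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'.")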
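
As for why a device index is needed here at all: with sync_module_states=True, FSDP broadcasts rank 0's weights to the other ranks, and param_init_fn is what first materializes parameters that were constructed on the meta device onto the local accelerator. A small illustration of that materialization step, written so it also runs on CPU-only machines (the CPU fallback is purely for the sketch):

import torch
import torch.nn as nn

# Pick a real device if one is present; fall back to CPU just for this sketch.
device = torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu")

# Parameters built under the meta device have shapes but no storage.
with torch.device("meta"):
    layer = nn.Linear(16, 16)
assert layer.weight.is_meta

# The lambda from the diff: allocate (uninitialized) storage on the target device.
param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
param_init_fn(layer)
assert not layer.weight.is_meta  # materialized; FSDP would now broadcast rank 0's values into it

Note that to_empty leaves the values uninitialized on purpose: under sync_module_states the real values arrive via the broadcast from rank 0, so initializing them locally would be wasted work.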
