Commit d1aa558
[Accelerator] We should not call `to` on modules that wrap `accelerate`-loaded models (#1172)

* add v1

* fix docstring
younesbelkada committed Mar 15, 2023
1 parent 41479fe commit d1aa558
Showing 1 changed file with 9 additions and 1 deletion.
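For context before the diff, here is a minimal sketch (the model name and surrounding setup are illustrative, not taken from the commit) of the situation this change guards against: a model loaded through `accelerate`'s big-model machinery carries an `hf_device_map` attribute describing where each submodule lives, and calling `.to()` on it would undo that placement.

    # Hypothetical usage sketch; "gpt2" and the setup are illustrative.
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    # `device_map="auto"` dispatches submodules across the available devices and
    # records the layout on `model.hf_device_map`.
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

    accelerator = Accelerator()
    # With this commit, `prepare()` sees `hf_device_map` on the model (or on any
    # child module) and skips the `model.to(self.device)` call instead of
    # collapsing the dispatched model onto a single device.
    model = accelerator.prepare(model)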
10 changes: 9 additions & 1 deletion src/accelerate/accelerator.py
@@ -1138,6 +1138,14 @@ def prepare_model(self, model: torch.nn.Module, device_placement=None):
             device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
         self._models.append(model)
         # We check only for models loaded with `accelerate`
+
+        # Checks if any of the child module has the attribute `hf_device_map`.
+        has_hf_device_map = False
+        for m in model.modules():
+            if hasattr(m, "hf_device_map"):
+                has_hf_device_map = True
+                break
+
         if getattr(model, "is_loaded_in_8bit", False) and getattr(model, "hf_device_map", False):
             model_devices = set(model.hf_device_map.values())
             if len(model_devices) > 1:
@@ -1158,7 +1166,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement=None):
                 raise ValueError(
                     "You can't train a model that has been loaded in 8-bit precision with CPU or disk offload."
                 )
-        elif device_placement:
+        elif device_placement and not has_hf_device_map:
             model = model.to(self.device)
 
         if self.distributed_type == DistributedType.MULTI_GPU:
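The new guard can be restated outside the diff as a small helper; this is a hedged sketch that mirrors the committed loop, and the function name is illustrative rather than part of the commit:

    import torch

    def _any_module_has_hf_device_map(model: torch.nn.Module) -> bool:
        # Mirrors the committed check: walk the module tree and report whether the
        # model itself or any child module carries an `hf_device_map`, i.e. it was
        # dispatched across devices by `accelerate`.
        return any(hasattr(m, "hf_device_map") for m in model.modules())

Checking every child module, rather than only the top-level model, is what makes the fix cover the case named in the title: a plain `nn.Module` that merely wraps an `accelerate`-loaded model has no `hf_device_map` itself, but one of its children does, and moving the wrapper with `.to()` would still break the child's device placement.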
