diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 80727de16c7d..adf0dd544058 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -54,6 +54,7 @@
     accelerate_disk_offload,
     accelerate_dispatch,
     check_and_set_device_map,
+    expand_device_map,
     find_tied_parameters,
     init_empty_weights,
 )
@@ -5298,18 +5299,6 @@ def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
     return model
 
 
-def expand_device_map(device_map, param_names):
-    """
-    Expand a device map to return the correspondence parameter name to device.
-    """
-    new_device_map = {}
-    for module, device in device_map.items():
-        new_device_map.update(
-            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
-        )
-    return new_device_map
-
-
 def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
     """Check if the device is an accelerator. We need this function, as device_map can be "disk" as well, which is not
     a proper `torch.device`.
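
Note on the change: the local `expand_device_map` helper is deleted from `modeling_utils.py` and the name is instead imported at the top of the file; the import hunk suggests it now lives alongside the other accelerate-integration helpers. For readers unfamiliar with it, here is a minimal, self-contained sketch of the helper's behaviour, mirroring the matching rules of the removed code. The model layout, parameter names, and device assignments are invented for illustration and are not part of the diff.

```python
# Standalone sketch of the removed helper; the example layout is hypothetical.

def expand_device_map(device_map, param_names):
    """Expand a module-level device map into a {parameter name: device} map."""
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            # A parameter matches a module if its name equals the module name,
            # is nested under it ("<module>." prefix), or if the module key is
            # the catch-all empty string "".
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map


param_names = [
    "embed_tokens.weight",
    "layers.0.attn.q_proj.weight",
    "layers.1.attn.q_proj.weight",
    "lm_head.weight",
]
device_map = {"embed_tokens": 0, "layers.0": 0, "layers.1": 1, "lm_head": "cpu"}

print(expand_device_map(device_map, param_names))
# {'embed_tokens.weight': 0,
#  'layers.0.attn.q_proj.weight': 0,
#  'layers.1.attn.q_proj.weight': 1,
#  'lm_head.weight': 'cpu'}
```

The catch-all key `""` matches every parameter, so a `device_map` of `{"": "cuda:0"}` expands to a map placing all parameters on the first GPU.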