diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index c859bcc8cb2..e5de09794ed 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1354,7 +1354,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif device_placement and not has_hf_device_map:
             model = model.to(self.device)
 
-        if self.native_amp:
+        if self.native_amp and self.distributed_type != DistributedType.FSDP:
             model._original_forward = model.forward
             model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
             autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py
index 3634a7ed320..ba2f0535dfc 100644
--- a/src/accelerate/commands/config/cluster.py
+++ b/src/accelerate/commands/config/cluster.py
@@ -386,9 +386,9 @@ def get_cluster_input():
             error_message="Please enter yes or no.",
         )
         fsdp_config["fsdp_sync_module_states"] = _ask_field(
-            "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [yes/NO]: ",
+            "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ",
             _convert_yes_no_to_bool,
-            default=False,
+            default=True,
             error_message="Please enter yes or no.",
         )
 
diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py
index 0e51cdfa9e4..fcabd13a78b 100644
--- a/src/accelerate/commands/launch.py
+++ b/src/accelerate/commands/launch.py
@@ -526,7 +526,7 @@ def launch_command_parser(subparsers=None):
     )
     fsdp_args.add_argument(
         "--fsdp_sync_module_states",
-        default="false",
+        default="true",
         type=str,
         help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0."
         " (useful only when `use_fsdp` flag is passed).",
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index e51f069f309..57294885546 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -859,7 +859,7 @@ class FullyShardedDataParallelPlugin:
         },
     )
     sync_module_states: bool = field(
-        default=False,
+        default=True,
         metadata={
             "help": "If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0 "
             "to ensure they are the same across all ranks after initialization"
@@ -874,14 +874,7 @@
     )
 
     def __post_init__(self):
-        from torch.distributed.fsdp.fully_sharded_data_parallel import (
-            BackwardPrefetch,
-            CPUOffload,
-            FullOptimStateDictConfig,
-            FullStateDictConfig,
-            ShardingStrategy,
-            StateDictType,
-        )
+        from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, ShardingStrategy
 
         prefix = "FSDP_"
         if self.sharding_strategy is None:
@@ -900,18 +893,14 @@ def __post_init__(self):
 
         if self.state_dict_type is None:
             state_dict_type_policy = os.environ.get(prefix + "STATE_DICT_TYPE", "FULL_STATE_DICT")
-            self.state_dict_type = StateDictType(FSDP_STATE_DICT_TYPE.index(state_dict_type_policy) + 1)
-
-            if self.state_dict_type == StateDictType.FULL_STATE_DICT:
-                if self.state_dict_config is None:
-                    self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
-                if self.optim_state_dict_config is None:
-                    self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
-
+            self.set_state_dict_type(state_dict_type_policy)
         self.use_orig_params = strtobool(os.environ.get(prefix + "USE_ORIG_PARAMS", "False")) == 1
-        self.sync_module_states = strtobool(os.environ.get(prefix + "SYNC_MODULE_STATES", "False")) == 1
+        self.sync_module_states = strtobool(os.environ.get(prefix + "SYNC_MODULE_STATES", "True")) == 1
         self.forward_prefetch = strtobool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1
 
+        if self.sync_module_states:
+            self.param_init_fn = lambda x: x.to_empty(device=torch.cuda.current_device(), recurse=False)
+
     @staticmethod
     def get_module_class_from_name(module, name):
         """
@@ -976,6 +965,21 @@ def set_mixed_precision(self, mixed_precision):
         if self.mixed_precision_policy is None:
             self.mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
 
+    def set_state_dict_type(self, state_dict_type_policy):
+        from torch.distributed.fsdp.fully_sharded_data_parallel import (
+            FullOptimStateDictConfig,
+            FullStateDictConfig,
+            StateDictType,
+        )
+
+        self.state_dict_type = StateDictType(FSDP_STATE_DICT_TYPE.index(state_dict_type_policy) + 1)
+
+        if self.state_dict_type == StateDictType.FULL_STATE_DICT:
+            if self.state_dict_config is None:
+                self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+            if self.optim_state_dict_config is None:
+                self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
+
 
 @dataclass
 class MegatronLMPlugin: