Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for ram efficient loading of model with FSDP #1777

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
elif device_placement and not has_hf_device_map:
model = model.to(self.device)

if self.native_amp:
if self.native_amp and self.distributed_type != DistributedType.FSDP:
model._original_forward = model.forward
model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
Expand Down
4 changes: 2 additions & 2 deletions src/accelerate/commands/config/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,9 @@ def get_cluster_input():
error_message="Please enter yes or no.",
)
fsdp_config["fsdp_sync_module_states"] = _ask_field(
"Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [yes/NO]: ",
"Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ",
_convert_yes_no_to_bool,
default=False,
default=True,
error_message="Please enter yes or no.",
)

Expand Down
2 changes: 1 addition & 1 deletion src/accelerate/commands/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def launch_command_parser(subparsers=None):
)
fsdp_args.add_argument(
"--fsdp_sync_module_states",
default="false",
default="true",
type=str,
help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0."
" (useful only when `use_fsdp` flag is passed).",
Expand Down
40 changes: 22 additions & 18 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ class FullyShardedDataParallelPlugin:
},
)
sync_module_states: bool = field(
default=False,
default=True,
metadata={
"help": "If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0 "
"to ensure they are the same across all ranks after initialization"
Expand All @@ -874,14 +874,7 @@ class FullyShardedDataParallelPlugin:
)

def __post_init__(self):
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, ShardingStrategy

prefix = "FSDP_"
if self.sharding_strategy is None:
Expand All @@ -900,18 +893,14 @@ def __post_init__(self):

if self.state_dict_type is None:
state_dict_type_policy = os.environ.get(prefix + "STATE_DICT_TYPE", "FULL_STATE_DICT")
self.state_dict_type = StateDictType(FSDP_STATE_DICT_TYPE.index(state_dict_type_policy) + 1)

if self.state_dict_type == StateDictType.FULL_STATE_DICT:
if self.state_dict_config is None:
self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
if self.optim_state_dict_config is None:
self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)

self.set_state_dict_type(state_dict_type_policy)
self.use_orig_params = strtobool(os.environ.get(prefix + "USE_ORIG_PARAMS", "False")) == 1
self.sync_module_states = strtobool(os.environ.get(prefix + "SYNC_MODULE_STATES", "False")) == 1
self.sync_module_states = strtobool(os.environ.get(prefix + "SYNC_MODULE_STATES", "True")) == 1
self.forward_prefetch = strtobool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1

if self.sync_module_states:
self.param_init_fn = lambda x: x.to_empty(device=torch.cuda.current_device(), recurse=False)

@staticmethod
def get_module_class_from_name(module, name):
"""
Expand Down Expand Up @@ -976,6 +965,21 @@ def set_mixed_precision(self, mixed_precision):
if self.mixed_precision_policy is None:
self.mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)

def set_state_dict_type(self, state_dict_type_policy):
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
FullStateDictConfig,
StateDictType,
)

self.state_dict_type = StateDictType(FSDP_STATE_DICT_TYPE.index(state_dict_type_policy) + 1)

if self.state_dict_type == StateDictType.FULL_STATE_DICT:
if self.state_dict_config is None:
self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
if self.optim_state_dict_config is None:
self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)


@dataclass
class MegatronLMPlugin:
Expand Down
Loading