deepspeed/utils/zero_to_fp32.py (24 additions, 12 deletions)
@@ -191,7 +191,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
     return zero_stage, world_size, fp32_flat_groups
 
 
-def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
     """
     Returns fp32 state_dict reconstructed from ds checkpoint
 
@@ -211,9 +211,11 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
     print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
 
     if zero_stage <= 2:
-        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                          exclude_frozen_parameters)
     elif zero_stage == 3:
-        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                          exclude_frozen_parameters)
 
 
 def _zero2_merge_frozen_params(state_dict, zero_model_states):
@@ -326,7 +328,8 @@ def zero2_align(x):
     print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                               exclude_frozen_parameters):
     state_dict = OrderedDict()
 
     # buffers
@@ -335,7 +338,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     if debug:
         print(f"added {len(buffers)} buffers")
 
-    _zero2_merge_frozen_params(state_dict, zero_model_states)
+    if not exclude_frozen_parameters:
+        _zero2_merge_frozen_params(state_dict, zero_model_states)
 
     _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

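Note: parameters frozen with ``requires_grad = False`` never enter the optimizer, so their fp32 values are recovered from the saved model states rather than from the partitioned optimizer states; the new flag simply skips that merge step (here for ZeRO-2, below for ZeRO-3). A minimal sketch of how parameters end up frozen in the first place; the model and optimizer are illustrative, not part of this PR:

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
    for p in model[0].parameters():
        p.requires_grad = False  # freeze the first layer, e.g. a pretrained backbone

    # Only trainable parameters reach the optimizer, and therefore ZeRO's
    # partitioned fp32 optimizer states.
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)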
@@ -444,7 +448,8 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                               exclude_frozen_parameters):
     state_dict = OrderedDict()
 
     # buffers
@@ -453,7 +458,8 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     if debug:
         print(f"added {len(buffers)} buffers")
 
-    _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+    if not exclude_frozen_parameters:
+        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
 
     _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
 
@@ -465,7 +471,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     return state_dict
 
 
-def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
@@ -474,6 +480,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
     Args:
         - ``checkpoint_dir``: path to the desired checkpoint folder
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+        - ``exclude_frozen_parameters``: exclude frozen parameters from the reconstructed ``state_dict``
 
     Returns:
         - pytorch ``state_dict``
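With the new keyword argument, callers can reconstruct a trainable-only state_dict directly from Python. A usage sketch, assuming the module is importable as ``deepspeed.utils.zero_to_fp32`` and that ``model`` is the same architecture the checkpoint was trained with; the paths are illustrative:

    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    # Reconstruct fp32 weights, dropping anything that was frozen during training.
    state_dict = get_fp32_state_dict_from_zero_checkpoint("output/checkpoints",
                                                          tag="global_step14",
                                                          exclude_frozen_parameters=True)

    # Frozen parameters are absent from the result, so a strict load would fail.
    model.load_state_dict(state_dict, strict=False)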
@@ -511,10 +518,10 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
     if not os.path.isdir(ds_checkpoint_dir):
         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
 
-    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
 
 
-def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
@@ -523,9 +530,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
     Args:
         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+        - ``exclude_frozen_parameters``: exclude frozen parameters from the saved ``state_dict``
     """
 
-    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
     print(f"Saving fp32 state dict to {output_file}")
     torch.save(state_dict, output_file)

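The file-writing wrapper gains the same switch. A sketch of converting a checkpoint into a trainable-only ``pytorch_model.bin`` and reading it back without DeepSpeed; paths are illustrative:

    import torch
    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    convert_zero_checkpoint_to_fp32_state_dict("output/checkpoints",
                                               "output/pytorch_model.bin",
                                               exclude_frozen_parameters=True)

    state_dict = torch.load("output/pytorch_model.bin")  # plain torch, no DeepSpeed required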
@@ -584,9 +592,13 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
                         type=str,
                         default=None,
                         help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
     args = parser.parse_args()
 
     debug = args.debug
 
-    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
+    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
+                                               args.output_file,
+                                               tag=args.tag,
+                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
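From the command line the same conversion would look roughly like this; the positional argument names are inferred from ``args.checkpoint_dir`` and ``args.output_file`` above, so the exact invocation is an assumption:

    python zero_to_fp32.py output/checkpoints output/pytorch_model.bin --exclude_frozen_parameters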