From 597749a128a5d2340071e66a9c740a2cb5302b43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Storhaug?=
Date: Fri, 19 Jan 2024 17:33:34 +0100
Subject: [PATCH 1/3] Add only_trainable_params argument to zero_to_fp32.py
 script

---
 deepspeed/utils/zero_to_fp32.py | 29 +++++++++++++++++------------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index 49b846633d6e..7f02ac608a9d 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -191,7 +191,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
     return zero_stage, world_size, fp32_flat_groups
 
 
-def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, only_trainable_params):
     """
     Returns fp32 state_dict reconstructed from ds checkpoint
 
@@ -211,9 +211,9 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
     print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
 
     if zero_stage <= 2:
-        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, only_trainable_params)
     elif zero_stage == 3:
-        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, only_trainable_params)
 
 
 def _zero2_merge_frozen_params(state_dict, zero_model_states):
@@ -326,7 +326,7 @@ def zero2_align(x):
     print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, only_trainable_params):
     state_dict = OrderedDict()
 
     # buffers
@@ -335,7 +335,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer
     if debug:
         print(f"added {len(buffers)} buffers")
 
-    _zero2_merge_frozen_params(state_dict, zero_model_states)
+    if not only_trainable_params:
+        _zero2_merge_frozen_params(state_dict, zero_model_states)
 
     _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
 
@@ -444,7 +445,7 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, only_trainable_params):
     state_dict = OrderedDict()
 
     # buffers
@@ -453,7 +454,8 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
     if debug:
         print(f"added {len(buffers)} buffers")
 
-    _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+    if not only_trainable_params:
+      _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
 
     _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
 
@@ -465,7 +467,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
     return state_dict
 
 
-def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, only_trainable_params=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
@@ -474,6 +476,7 @@
     Args:
         - ``checkpoint_dir``: path to the desired checkpoint folder
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+        - ``only_trainable_params``: only merge trainable parameters
 
     Returns:
         - pytorch ``state_dict``
@@ -511,10 +514,10 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, only_trai
     if not os.path.isdir(ds_checkpoint_dir):
         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
 
-    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, only_trainable_params)
 
 
-def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, only_trainable_params=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
@@ -523,9 +526,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=
         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
         - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+        - ``only_trainable_params``: only merge trainable parameters
     """
 
-    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, only_trainable_params)
     print(f"Saving fp32 state dict to {output_file}")
     torch.save(state_dict, output_file)
 
@@ -584,9 +588,10 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
                         type=str,
                         default=None,
                         help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
e.g., global_step1") + parser.add_argument("--only_trainable_params", action='store_true', help="only merge trainable parameters") parser.add_argument("-d", "--debug", action='store_true', help="enable debug") args = parser.parse_args() debug = args.debug - convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag) + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag, only_trainable_params=args.only_trainable_params) From 1add4f8b6c8ad013ee16b3cbf6bfe16d8a5ff811 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Storhaug?= Date: Wed, 24 Jan 2024 15:11:00 +0100 Subject: [PATCH 2/3] Update only_trainable_params arg in zero_to_fp32.py to exclude_frozen_parameters --- deepspeed/utils/zero_to_fp32.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 7f02ac608a9d..fd85714c4b39 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -191,7 +191,7 @@ def parse_optim_states(files, ds_checkpoint_dir): return zero_stage, world_size, fp32_flat_groups -def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, only_trainable_params): +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): """ Returns fp32 state_dict reconstructed from ds checkpoint @@ -211,9 +211,9 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, only_trainable_ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') if zero_stage <= 2: - return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, only_trainable_params) + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters) elif zero_stage == 3: - return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, only_trainable_params) + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters) def _zero2_merge_frozen_params(state_dict, zero_model_states): @@ -326,7 +326,7 @@ def zero2_align(x): print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, only_trainable_params): +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters): state_dict = OrderedDict() # buffers @@ -335,7 +335,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer if debug: print(f"added {len(buffers)} buffers") - if not only_trainable_params: + if not exclude_frozen_parameters: _zero2_merge_frozen_params(state_dict, zero_model_states) _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) @@ -445,7 +445,7 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, only_trainable_params): +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters): state_dict = OrderedDict() # buffers @@ -454,7 +454,7 @@ def 
_get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer if debug: print(f"added {len(buffers)} buffers") - if not only_trainable_params: + if not exclude_frozen_parameters: _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) @@ -467,7 +467,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer return state_dict -def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, only_trainable_params=False): +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example @@ -476,7 +476,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, only_trai Args: - ``checkpoint_dir``: path to the desired checkpoint folder - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` - - ``only_trainable_params``: only merge trainable parameters + - ``exclude_frozen_parameters``: exclude frozen parameters Returns: - pytorch ``state_dict`` @@ -514,10 +514,10 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, only_trai if not os.path.isdir(ds_checkpoint_dir): raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") - return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, only_trainable_params) + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) -def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, only_trainable_params=False): +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. @@ -526,10 +526,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag= - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` - - ``only_trainable_params``: only merge trainable parameters + - ``exclude_frozen_parameters``: exclude frozen parameters """ - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, only_trainable_params) + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) print(f"Saving fp32 state dict to {output_file}") torch.save(state_dict, output_file) @@ -588,10 +588,10 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): type=str, default=None, help="checkpoint tag used as a unique identifier for checkpoint. 
e.g., global_step1") - parser.add_argument("--only_trainable_params", action='store_true', help="only merge trainable parameters") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") parser.add_argument("-d", "--debug", action='store_true', help="enable debug") args = parser.parse_args() debug = args.debug - convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag, only_trainable_params=args.only_trainable_params) + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag, exclude_frozen_parameters=args.exclude_frozen_parameters) From 2e16e7286b341cab4c7714516b69a3f1f89dfe01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Storhaug?= Date: Thu, 8 Feb 2024 17:11:16 +0100 Subject: [PATCH 3/3] Fix formatting --- deepspeed/utils/zero_to_fp32.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index fd85714c4b39..24cc342e78d1 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -211,9 +211,11 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') if zero_stage <= 2: - return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters) + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) elif zero_stage == 3: - return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters) + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) def _zero2_merge_frozen_params(state_dict, zero_model_states): @@ -326,7 +328,8 @@ def zero2_align(x): print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters): +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): state_dict = OrderedDict() # buffers @@ -445,7 +448,8 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters): +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): state_dict = OrderedDict() # buffers @@ -455,7 +459,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer print(f"added {len(buffers)} buffers") if not exclude_frozen_parameters: - _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) @@ -594,4 +598,7 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): debug = args.debug - convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag, exclude_frozen_parameters=args.exclude_frozen_parameters) + 
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters)
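
For reference, the net effect of this series is a new --exclude_frozen_parameters flag on the script's command line and a matching exclude_frozen_parameters keyword on the Python API. A minimal usage sketch follows; the checkpoint directory and tag below are illustrative assumptions, not values taken from the patches:

    # Sketch of the API after this series; "./checkpoint" and "global_step14"
    # are hypothetical placeholders -- substitute your own paths and tag.
    from deepspeed.utils.zero_to_fp32 import (
        convert_zero_checkpoint_to_fp32_state_dict,
        get_fp32_state_dict_from_zero_checkpoint,
    )

    # Write a single consolidated fp32 state_dict file, skipping frozen parameters.
    convert_zero_checkpoint_to_fp32_state_dict("./checkpoint",
                                               "./checkpoint/pytorch_model.bin",
                                               tag="global_step14",
                                               exclude_frozen_parameters=True)

    # Or reconstruct the consolidated state_dict in memory instead.
    state_dict = get_fp32_state_dict_from_zero_checkpoint("./checkpoint",
                                                          exclude_frozen_parameters=True)

The CLI equivalent is "python zero_to_fp32.py ./checkpoint pytorch_model.bin --exclude_frozen_parameters". Note that a state_dict produced this way omits the frozen parameters, so loading it into a model that still contains them will typically require model.load_state_dict(state_dict, strict=False).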