From 60d980d521d7a44518d579fbe00f755a1b66e4cb Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 25 Oct 2023 16:25:23 -0400 Subject: [PATCH] Enable universal checkpoint for zero stage 1 (#4516) * Enable uni_ckpt for z1 * Remove logging fix to seperate PR. Relocate conversion script to avoid logging circular import issue * Formatting fix * PR feedback * Handle replicated params * Detect bf16_optimizer * Docs * Fix docs --- deepspeed/checkpoint/constants.py | 2 + deepspeed/checkpoint/ds_to_universal.py | 303 ++++++++++++++++++ deepspeed/runtime/engine.py | 3 + deepspeed/runtime/pipe/engine.py | 11 +- deepspeed/runtime/pipe/module.py | 25 +- deepspeed/runtime/zero/stage3.py | 6 +- deepspeed/runtime/zero/stage_1_and_2.py | 47 +-- docs/code-docs/source/model-checkpointing.rst | 11 + 8 files changed, 370 insertions(+), 38 deletions(-) create mode 100755 deepspeed/checkpoint/ds_to_universal.py diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index e7bc67a0265e..75bb8d4d6c8f 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -21,6 +21,7 @@ ZERO_STAGE = 'zero_stage' CLIP_GRAD = 'clip_grad' FP32_WEIGHT_KEY = "fp32" +LOSS_SCALER = 'loss_scaler' ######################################### # Module checkpoint keys @@ -69,3 +70,4 @@ PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns' PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns' PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns' +TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns' diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py new file mode 100755 index 000000000000..7fb96ce98e29 --- /dev/null +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from functools import partial +import argparse +import glob +import itertools +import multiprocessing +import os +import re +import shutil +import torch +import tqdm +# from pprint import pprint + +from deepspeed.checkpoint import DeepSpeedCheckpoint +from deepspeed.checkpoint import ( + OPTIMIZER_STATE_DICT, + BASE_OPTIMIZER_STATE, + SINGLE_PARTITION_OF_FP32_GROUPS, + PARAM_SLICE_MAPPINGS, + PARAM_SHAPES, + PARAM, + CAT_DIM, + VOCAB_DIVISIBILITY_PADDING_TENSOR, + ORIGINAL_VOCAB_SIZE, + UNIVERSAL_CHECKPOINT_INFO, + VOCABULARY_PARAMETER_PATTERNS, + PIPELINE_REPLICATED_PARAMETER_PATTERNS, + TP_REPLICATED_PARAMETER_PATTERNS, + PARAMETER_TO_AVERAGE_PATTERNS, + PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, +) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_folder', type=str, required=True, help='Input DeepSpeed Checkpoint folder') + parser.add_argument('--output_folder', type=str, required=True, help='Output DeepSpeed checkpoint folder') + parser.add_argument('--num_extract_workers', + default=4, + type=int, + help='How many parallel processes to extract zero shards') + parser.add_argument( + '--num_merge_workers', + default=2, + type=int, + help= + 'How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers))' + ) + parser.add_argument('--keep_temp_folder', + action='store_true', + help='Preserve temporary folder of intermediate checkpoint slice files. 
Useful for debugging.') + args = parser.parse_args() + print(f'args = {args}') + return args + + +def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree): + path_list = [] + iter_folder = f'iter_{iteration:07d}' + for i in range(0, tp_degree): + path_list.append([]) + for j in range(0, pp_degree): + rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}' + ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt') + path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path)) + + return path_list + + +def _save_checkpoint(file_path, chkpt_sd): + dir, _ = os.path.split(file_path) + os.makedirs(dir, exist_ok=True) + torch.save(chkpt_sd, file_path) + + +def extract_zero_shards(dir, ds_checkpoint, indices_3D): + pp_index, tp_index, dp_index = indices_3D + sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index) + + # pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}") + + optim_sd = sd[OPTIMIZER_STATE_DICT] + param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS] + universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) + pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, []) + # print(f'{pipeline_replicated_params=}') + + # dict + state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"] + # list + fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS] + param_groups_cnt = len(state_groups) + + for param_group_id in range(param_groups_cnt): + + flat_state = dict( + exp_avg=state_groups[param_group_id]["exp_avg"], + exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"], + fp32=fp32_groups[param_group_id], + ) + + for name, fragment_mapping in param_slice_mappings[param_group_id].items(): + if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params): + # Skip tied weights that are replicated in first and last pp stages + continue + + # pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}") + for state_key in flat_state.keys(): + dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name, + fragment_mapping.start, fragment_mapping.numel) + + +cnt = 0 + + +def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel): + + global cnt # temp hack + + param_base_path = os.path.join(dir, param_name, str(tp_index)) + os.makedirs(param_base_path, exist_ok=True) + + cnt += 1 + counter = f"{dp_index:0>2d}" + + path = os.path.join(param_base_path, f"{state_name}.{counter}") + + #print(f"{param_name}: {offset}: {numel} => {path}") + + t = state_flat_tensor.narrow(0, offset, numel).clone() + _save_checkpoint(path, t) + + +def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape): + slices = [] + for tp_index in range(tp_degree): + prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") + paths = sorted(list(glob.glob(f"{prefix_path}.*"))) + shards = [torch.load(p) for p in paths] + slice = torch.cat(shards, dim=0).reshape(slice_shape) + slices.append(slice) + + return slices + + +def _get_vocab_divisibility_padding_tensor(universal_checkpoint_info, padded_vocab_tensor): + original_vocab_size = universal_checkpoint_info.get(ORIGINAL_VOCAB_SIZE) + if padded_vocab_tensor.shape[0] > original_vocab_size: + return padded_vocab_tensor[-1] + else: + return torch.zeros(padded_vocab_tensor.shape[1]) + + +def merge_tp_slices(ds_checkpoint, 
dir, slice_dir, tp_degree, name_and_shape): + name, shape = name_and_shape + slice_base_path = os.path.join(slice_dir, name) + param_base_path = os.path.join(dir, name) + + universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) + replicated_parameters = universal_checkpoint_info.get(TP_REPLICATED_PARAMETER_PATTERNS, []) + parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, []) + parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, []) + vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, []) + for state in ("fp32", "exp_avg", "exp_avg_sq"): + slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape) + final_path = os.path.join(param_base_path, f"{state}.pt") + + #print(f"Expected shape: {shape}") + #print(f"Fragment sizes:", list(frag.shape for frag in slices)) + ckpt_dict = {} + if any(re.match(pattern, name) for pattern in replicated_parameters): + if len(slices) > 1: + assert all([slices[0].equal(other_slice) for other_slice in slices[1:]]) + param = slices[0] + # print(f'replicate {name} using first slice') + elif any(re.match(pattern, name) for pattern in parameters_to_average): + param = sum(slices) / len(slices) + # print(f'merge {name} using average') + else: + cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0 + # print(f"merge {name} with CAT DIM: {cat_dim}") + param = torch.cat(slices, dim=cat_dim) + ckpt_dict[CAT_DIM] = cat_dim + + if any(re.match(pattern, name) for pattern in vocabulary_parameters): + #print(f"Before {param.shape=}") + # strip padding + #param = _strip_vocab_padding(ds_checkpoint, param) + ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor( + universal_checkpoint_info, param) + #print(f"After {param.shape=}") + + #print(f"Final shape: {param.shape}") + ckpt_dict[PARAM] = param + _save_checkpoint(final_path, ckpt_dict) + + +def _get_chunks(l, n): + for i in range(0, len(l), n): + yield l[i:i + n] + + +def _do_parallel_work(do_work, work_chunks, num_workers): + pool = multiprocessing.Pool(num_workers) + for batch in tqdm.tqdm(work_chunks): + pool.map(do_work, batch) + pool.close() + pool.join() + + +def _extract_zero_shard_files(args, ds_checkpoint, temp_dir): + _3d_range_list = list( + itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree), + range(ds_checkpoint.dp_degree))) + # pprint(f'{_3d_range_list=}') + work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers)) + # pprint(f'{work_chunks=}') + + # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0]) + do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint) + _do_parallel_work(do_work, work_chunks, args.num_extract_workers) + + +def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir): + work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers)) + #pprint(work_chunks) + zero_output_folder = os.path.join(args.output_folder, "zero") + do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree) + _do_parallel_work(do_work, work_chunks, args.num_merge_workers) + + +def _save_optimizer_state(args, ds_checkpoint): + sharded_states = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS] + sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0) + + optim_sd = sd[OPTIMIZER_STATE_DICT] 
+ output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states} + zero_output_folder = os.path.join(args.output_folder, "zero") + output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt") + _save_checkpoint(output_file_path, output_sd) + + +def _check_for_required_state(ds_checkpoint): + universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) + assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.' + + +def main(): + print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint') + + args = parse_arguments() + print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}') + + ds_checkpoint = DeepSpeedCheckpoint(args.input_folder) + _check_for_required_state(ds_checkpoint) + + iteration = ds_checkpoint.get_iteration() + #_create_latest_file(args.output_folder, iteration) + checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree, + ds_checkpoint.pp_degree) + + slice_shapes = [] + for mp_rank_file in ds_checkpoint.mp_rank_files: + mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu')) + slice_shapes += mp_sd[PARAM_SHAPES] + + # fix back to normal flat dict, merge duplicates for tp>1 + slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items()) + temp_dir = os.path.join(args.output_folder, 'tmp') + + print('*** 1. Extracting ZeRO fragments') + _extract_zero_shard_files(args, ds_checkpoint, temp_dir) + + print('*** 2. Merging slices') + _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir) + + print('*** 3. Saving common optimizer states') + _save_optimizer_state(args, ds_checkpoint) + + if not args.keep_temp_folder: + shutil.rmtree(temp_dir, ignore_errors=True) + + # Copy mp* files into output folder + for f in glob.glob(os.path.join(args.input_folder, 'mp*')): + shutil.copy2(f, args.output_folder) + + # Update latest to output folder + checkpoint_root_folder, step_folder = os.path.split(args.output_folder) + latest_file = os.path.join(checkpoint_root_folder, 'latest_universal') + with open(latest_file, "w") as f: + f.write(step_folder) + + print('*** Done!') + + +if __name__ == "__main__": + main() diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 8de7b3c4307c..faf453d2bca3 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2758,6 +2758,9 @@ def load_checkpoint(self, if self._optimizer_has_ckpt_event_epilogue(): self.optimizer.checkpoint_event_epilogue() + if self.load_universal_checkpoint(): + self.optimizer.update_lp_params() + return load_path, client_states def _load_checkpoint(self, diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index bfe01f6db2fc..ef94ca87f6bb 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -11,6 +11,7 @@ from deepspeed.utils import logger, timeit from deepspeed.utils.timer import ThroughputTimer from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE from deepspeed.utils.timer import FORWARD_MICRO_TIMER, FORWARD_GLOBAL_TIMER, BACKWARD_MICRO_TIMER, \ @@ -76,6 +77,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): self.has_bool_tensors = has_bool_tensors self.eval_return_logits = False self.outputs = None + # 
BF16 Optimizer is hardcoded for fp32 gradient accumulation + self.using_bf16_optimizer = type(self.optimizer) == BF16_Optimizer # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB self.pipeline_enable_backward_allreduce = True @@ -260,14 +263,14 @@ def _exec_reduce_tied_grads(self): weight_group_list = self.module.get_tied_weights_and_groups() for weight, group in weight_group_list: - grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad + grad = weight._hp_grad if self.using_bf16_optimizer else weight.grad dist.all_reduce(grad, group=group) # @timeit def _exec_reduce_grads(self): self._force_grad_boundary = True if self.pipeline_enable_backward_allreduce: - if self.bfloat16_enabled(): + if self.using_bf16_optimizer: # PP+BF16 work for ZeRO Stage 1 self._bf16_reduce_grads() else: @@ -759,7 +762,7 @@ def _exec_backward_pass(self, buffer_id): part_grad = None #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') - if self.bfloat16_enabled() and not self.is_last_stage(): + if self.using_bf16_optimizer and not self.is_last_stage(): # manually call because we don't call optimizer.backward() self.optimizer.clear_lp_grads() @@ -771,7 +774,7 @@ def _exec_backward_pass(self, buffer_id): else: torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, )) - if self.bfloat16_enabled() and not self.is_last_stage(): + if self.using_bf16_optimizer and not self.is_last_stage(): # manually call because we don't call optimizer.backward() self.optimizer.update_hp_grads(clear_lp_grads=False) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 7495834fa9e4..45d420bea291 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -76,11 +76,11 @@ def build(self, log=False): class TiedLayerSpec(LayerSpec): - def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr='weight', **module_kwargs): + def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr=['weight'], **module_kwargs): super().__init__(typename, *module_args, **module_kwargs) self.key = key self.forward_fn = forward_fn - self.tied_weight_attr = tied_weight_attr + self.tied_weight_attr = [tied_weight_attr] if type(tied_weight_attr) == str else tied_weight_attr class PipelineModule(nn.Module): @@ -424,23 +424,26 @@ def _partition_layers(self, method='uniform'): def allreduce_tied_weight_gradients(self): '''All reduce the gradients of the tied weights between tied stages''' for key, comm in self.tied_comms.items(): - weight = getattr(self.tied_modules[key], comm['weight_attr']) - dist.all_reduce(weight.grad, group=comm['group']) + for attr_name in comm['weight_attr']: + weight = getattr(self.tied_modules[key], attr_name) + dist.all_reduce(weight.grad, group=comm['group']) def get_tied_weights_and_groups(self): weight_group_list = [] for key, comm in self.tied_comms.items(): - weight = getattr(self.tied_modules[key], comm['weight_attr']) - weight_group_list.append((weight, comm['group'])) + for attr_name in comm['weight_attr']: + weight = getattr(self.tied_modules[key], attr_name) + weight_group_list.append((weight, comm['group'])) return weight_group_list def _synchronize_tied_weights(self): for key, comm in self.tied_comms.items(): - dist.broadcast( - getattr(comm['module'], comm['weight_attr']), - src=min(comm['ranks']), - group=comm['group'], - ) + for attr_name in comm['weight_attr']: + dist.broadcast( + 
getattr(comm['module'], attr_name), + src=min(comm['ranks']), + group=comm['group'], + ) def _index_tied_modules(self): ''' Build communication structures for tied modules. ''' diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index bd658649e981..b116276b72e2 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -24,7 +24,7 @@ from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper -from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE +from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER from deepspeed.accelerator import get_accelerator import time # Toggle this to true to enable correctness test @@ -2315,7 +2315,7 @@ def _clear_fp32_optimizer_param_groups(self): def _rigid_state_dict(self): state_dict = {} state_dict[ZERO_STAGE] = ZeroStageEnum.weights - state_dict['loss_scaler'] = self.loss_scaler + state_dict[LOSS_SCALER] = self.loss_scaler state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow state_dict[PARTITION_COUNT] = self.partition_count @@ -2419,7 +2419,7 @@ def _restore_base_optimizer_state(self, all_state_dict): def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): # I think it should actually be ok to reload the optimizer before the model. - self.loss_scaler = state_dict['loss_scaler'] + self.loss_scaler = state_dict[LOSS_SCALER] self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] self.overflow = state_dict['overflow'] diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 0fefb4e7dd21..a0c03bcaf05e 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -26,7 +26,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.utils import logger, timeit -from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, +from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) from deepspeed.utils import link_hp_params @@ -1304,7 +1304,6 @@ def reduce_ipg_grads(self): Multiple gradient reduction is currently not supported" self.params_already_reduced[param_id] = True - if self.partition_gradients: if not self.is_param_in_current_partition[param_id]: if self.overlap_comm and self.contiguous_gradients is False: @@ -2046,7 +2045,7 @@ def state_dict(self): torch.save(checkpoint, "saved.pth") """ state_dict = {} - state_dict['loss_scaler'] = self.loss_scaler + state_dict[LOSS_SCALER] = self.loss_scaler state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow state_dict[CLIP_GRAD] = self.clip_grad @@ -2192,9 +2191,14 @@ def param_groups(self): def _load_hp_checkpoint_state(self, checkpoint_dir): checkpoint_dir = os.path.join(checkpoint_dir, "zero") + optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") + assert os.path.isfile( + optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' 
+ optim_sd = torch.load(optim_state_path) + self._load_global_state(optim_sd) + tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) tp_world_size = self.mpu.get_slice_parallel_world_size() - for i, _ in enumerate(self.optimizer.param_groups): for lp in self.bit16_groups[i]: if lp._hp_mapping is not None: @@ -2202,6 +2206,24 @@ def _load_hp_checkpoint_state(self, checkpoint_dir): lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, tp_world_size) + def _load_global_state(self, sd): + self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) + self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) + self.overflow = sd.get('overflow', self.overflow) + self.clip_grad = sd.get(CLIP_GRAD, self.clip_grad) + + ckpt_version = sd.get(DS_VERSION, False) + assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" + ckpt_version = pkg_version.parse(ckpt_version) + + # zero stage 1 mode + if not self.partition_gradients: + required_version = pkg_version.parse("0.3.17") + error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \ + "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \ + "please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json." + assert required_version <= ckpt_version, f"Old version: {ckpt_version} {error_str}" + def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False): r"""Loading ZeRO checkpoint @@ -2232,22 +2254,7 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l # I think it should actually be ok to reload the optimizer before the model. dp_rank = dist.get_rank(group=self.dp_process_group) current_rank_sd = state_dict_list[dp_rank] - self.loss_scaler = current_rank_sd.get('loss_scaler', self.loss_scaler) - self.dynamic_loss_scale = current_rank_sd.get('dynamic_loss_scale', self.dynamic_loss_scale) - self.overflow = current_rank_sd.get('overflow', self.overflow) - self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad) - - ckpt_version = current_rank_sd.get(DS_VERSION, False) - assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed" - ckpt_version = pkg_version.parse(ckpt_version) - - # zero stage 1 mode - if not self.partition_gradients: - required_version = pkg_version.parse("0.3.17") - error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \ - "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \ - "please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json." 
-        assert required_version <= ckpt_version, f"Old version: {ckpt_version} {error_str}"
+        self._load_global_state(current_rank_sd)
 
         ckpt_is_rigid = isinstance(current_rank_sd[BASE_OPTIMIZER_STATE], dict)
 
diff --git a/docs/code-docs/source/model-checkpointing.rst b/docs/code-docs/source/model-checkpointing.rst
index 7491529ba06f..85f7c7947a60 100644
--- a/docs/code-docs/source/model-checkpointing.rst
+++ b/docs/code-docs/source/model-checkpointing.rst
@@ -43,3 +43,14 @@ The following code snippet illustrates this functionality for creating a Hugging
     ds_engine, _, _, _ = deepspeed.initialize(model=model, config_params=ds_config)
     lean_state_dict = deepspeed.checkpoint.utils.clone_tensors_for_torch_save(ds_engine.module.state_dict())
     ds_engine.module.save_pretrained("lean_after", state_dict=lean_state_dict)
+
+
+
+Universal Checkpoints (under development)
+------------------------------------------
+Parallelism techniques such as ZeRO data parallelism (DP), Tensor parallelism (TP), and Pipeline parallelism (PP), which shard model and/or
+optimizer states, make it difficult to resume training with a checkpoint that was created on a different number of GPUs. DeepSpeed provides the
+Universal Checkpoint mechanism to address this problem. Universal Checkpoints give users the flexibility of changing the number of GPUs when training
+with 3D (TP, PP, and DP) parallelism, and enable more efficient use of elastic training hardware. The easiest way to get started with
+using Universal Checkpoints is to consult the `Megatron-DeepSpeed `_
+and `BLOOM `_ examples.
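A minimal usage sketch of the conversion step added by this PR, built only from the arguments defined in ``parse_arguments()`` above; ``$CKPT_ROOT`` and the ``global_step1000`` checkpoint tag are hypothetical placeholders, and the worker counts shown are simply the script defaults::

    # Convert an existing ZeRO stage 1 DeepSpeed checkpoint, assumed here to
    # live under $CKPT_ROOT/global_step1000, into the universal layout that
    # _load_hp_checkpoint_state() expects.
    python deepspeed/checkpoint/ds_to_universal.py \
        --input_folder  $CKPT_ROOT/global_step1000 \
        --output_folder $CKPT_ROOT/global_step1000_universal \
        --num_extract_workers 4 \
        --num_merge_workers 2

The script extracts per-rank ZeRO fragments into a temporary folder, merges them into per-parameter ``fp32``, ``exp_avg`` and ``exp_avg_sq`` slices under ``<output_folder>/zero`` (plus a shared ``zero/optimizer_state.pt``), copies the ``mp*`` model files, and records the output step folder in a ``latest_universal`` file in the parent directory. Loading the converted checkpoint is gated by ``load_universal_checkpoint()`` on the engine side (see the ``engine.py`` hunk); how that setting is enabled is not covered by this patch.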