Universal ckp fixes (#4588)
Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
3 people committed Nov 9, 2023
1 parent 1d1a20c commit 8ad187d
Showing 7 changed files with 75 additions and 35 deletions.
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/constants.py
@@ -57,7 +57,7 @@
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2

# Vocabulary padding
VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
VOCAB_TENSOR = 'vocab_tensor'
PADDED_VOCAB_SIZE = 'padded_vocab_size'
ORIGINAL_VOCAB_SIZE = 'original_vocab_size'

60 changes: 41 additions & 19 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -26,8 +26,7 @@
PARAM_SHAPES,
PARAM,
CAT_DIM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
ORIGINAL_VOCAB_SIZE,
VOCAB_TENSOR,
UNIVERSAL_CHECKPOINT_INFO,
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
@@ -55,6 +54,10 @@ def parse_arguments():
parser.add_argument('--keep_temp_folder',
action='store_true',
help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
parser.add_argument('--no_strict',
dest='strict',
action='store_false',
help='Do not perform validity checks on converted checkpoint.')
args = parser.parse_args()
print(f'args = {args}')
return args
@@ -149,15 +152,8 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
return slices


def _get_vocab_divisibility_padding_tensor(universal_checkpoint_info, padded_vocab_tensor):
original_vocab_size = universal_checkpoint_info.get(ORIGINAL_VOCAB_SIZE)
if padded_vocab_tensor.shape[0] > original_vocab_size:
return padded_vocab_tensor[-1]
else:
return torch.zeros(padded_vocab_tensor.shape[1])


def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):

name, shape = name_and_shape
slice_base_path = os.path.join(slice_dir, name)
param_base_path = os.path.join(dir, name)
@@ -167,39 +163,53 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
vocabulary_parameters)

def get_matched_pattern(patterns_, name_):
matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
assert len(matched_) <= 1, f'Got more than one matching patterns={matched_} for {name_}'
if matched_:
pattern_ = matched_[0]
unmatched_patterns.discard(pattern_)
return pattern_
return None

for state in ("fp32", "exp_avg", "exp_avg_sq"):
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
final_path = os.path.join(param_base_path, f"{state}.pt")

#print(f"Expected shape: {shape}")
#print(f"Fragment sizes:", list(frag.shape for frag in slices))
ckpt_dict = {}
if any(re.match(pattern, name) for pattern in replicated_parameters):
if get_matched_pattern(replicated_parameters, name):
if len(slices) > 1:
assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
param = slices[0]
# print(f'replicate {name} using first slice')
elif any(re.match(pattern, name) for pattern in parameters_to_average):
elif get_matched_pattern(parameters_to_average, name):
param = sum(slices) / len(slices)
# print(f'merge {name} using average')
else:
cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0
cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
# print(f"merge {name} with CAT DIM: {cat_dim}")
param = torch.cat(slices, dim=cat_dim)
ckpt_dict[CAT_DIM] = cat_dim

if any(re.match(pattern, name) for pattern in vocabulary_parameters):
if get_matched_pattern(vocabulary_parameters, name):
#print(f"Before {param.shape=}")
# strip padding
#param = _strip_vocab_padding(ds_checkpoint, param)
ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor(
universal_checkpoint_info, param)
original_vocab_size = universal_checkpoint_info['original_vocab_size']
param = param[:original_vocab_size, :]
ckpt_dict[VOCAB_TENSOR] = True
#print(f"After {param.shape=}")

#print(f"Final shape: {param.shape}")
ckpt_dict[PARAM] = param
_save_checkpoint(final_path, ckpt_dict)

return unmatched_patterns


def _get_chunks(l, n):
for i in range(0, len(l), n):
@@ -208,10 +218,13 @@ def _get_chunks(l, n):

def _do_parallel_work(do_work, work_chunks, num_workers):
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
pool.map(do_work, batch)
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
return results


def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
@@ -232,7 +245,16 @@ def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
#pprint(work_chunks)
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
_do_parallel_work(do_work, work_chunks, args.num_merge_workers)
unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)

# verify that all patterns were used
# if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
sets = [set(lst) for lst in unmatched_patterns_lists]
unmatched_patterns = list(set.intersection(*sets))
if args.strict:
assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices'
elif unmatched_patterns:
print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')


def _save_optimizer_state(args, ds_checkpoint):
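Note on the strict check in _merge_tp_slice_files above: each merge worker returns the set of patterns it never matched, and a pattern is reported as unused only if it appears in every worker's set, i.e. the intersection. A minimal standalone sketch of that reduction, with made-up pattern names that are not taken from the diff:

    # Illustrative sketch of the unused-pattern reduction; pattern names are hypothetical.
    unmatched_per_worker = [
        ['word_embeddings.weight', 'dense_4h_to_h.bias'],  # patterns worker 0 never matched
        ['dense_4h_to_h.bias'],                            # patterns worker 1 never matched
    ]
    sets = [set(lst) for lst in unmatched_per_worker]
    unused_everywhere = set.intersection(*sets)
    # Only a pattern matched by no worker survives the intersection.
    assert unused_everywhere == {'dense_4h_to_h.bias'}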
15 changes: 5 additions & 10 deletions deepspeed/checkpoint/universal_checkpoint.py
@@ -6,7 +6,7 @@
import os
import torch
import types
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_DIVISIBILITY_PADDING_TENSOR, CAT_DIM)
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM)


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
@@ -43,21 +43,16 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
vocab_divisibility_padding_tensor = ckpt_dict.get(VOCAB_DIVISIBILITY_PADDING_TENSOR, None)
if vocab_divisibility_padding_tensor is not None:
is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False)
if is_vocab_tensor:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
padded_target_vocab_size = self.shape[0] * tp_world_size
assert padded_target_vocab_size >= full_hp_param.shape[0], \
f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}'
if padded_target_vocab_size > full_hp_param.shape[0]:
# Need to expand
padding_size = padded_target_vocab_size - full_hp_param.shape[0]
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
else:
# Need to shrink or keep the same
full_hp_param = full_hp_param[:padded_target_vocab_size, :]

full_param_numel = full_hp_param.numel()
tp_slice_numel = self.numel()
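A minimal sketch of the re-padding path in load_hp_checkpoint_state above, using illustrative tensor sizes: the universal checkpoint stores the vocab tensor without padding, and the loader either zero-pads it up to shape[0] * tp_world_size or slices it down to that size.

    # Sketch with hypothetical sizes; only the pad/slice logic mirrors the diff.
    import torch

    full_hp_param = torch.randn(50257, 1024)     # padding-free vocab tensor from the universal checkpoint
    per_rank_rows, tp_world_size = 12576, 4      # hypothetical target TP sharding
    padded_target_vocab_size = per_rank_rows * tp_world_size   # 50304

    if padded_target_vocab_size > full_hp_param.shape[0]:
        padding_size = padded_target_vocab_size - full_hp_param.shape[0]
        # Pad rows at the bottom with zeros; the pad spec is (left, right, top, bottom).
        full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
    else:
        full_hp_param = full_hp_param[:padded_target_vocab_size, :]

    assert full_hp_param.shape[0] == padded_target_vocab_size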
8 changes: 8 additions & 0 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -237,6 +237,14 @@ def model_parallel_cuda_manual_seed(seed):
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)


def model_parallel_reconfigure_tp_seed(seed):
global mpu
tp_rank = bwc_tensor_model_parallel_rank(mpu)
model_parallel_seed = seed + 2718 + tp_rank
with _CUDA_RNG_STATE_TRACKER.fork():
get_accelerator().manual_seed(model_parallel_seed)


def get_partition_start(item):
global mp_rank, mp_size, mp_group
size = item.numel()
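The new model_parallel_reconfigure_tp_seed follows the same seed derivation as model_parallel_cuda_manual_seed: each tensor-parallel rank uses seed + 2718 + tp_rank, so re-seeding after a change in TP degree keeps the per-rank RNG streams distinct. A trivial sketch of just the derivation, with a hypothetical TP degree and no RNG tracker involved:

    # Sketch: per-rank seed derivation only.
    base_seed = 1234
    tp_degree = 4                                  # hypothetical tensor-parallel degree
    per_rank_seeds = [base_seed + 2718 + tp_rank for tp_rank in range(tp_degree)]
    print(per_rank_seeds)                          # [3952, 3953, 3954, 3955]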
2 changes: 0 additions & 2 deletions deepspeed/runtime/bf16_optimizer.py
@@ -75,7 +75,6 @@ def __init__(self,
self.fp32_groups_gradient_flat_partition = []
self.fp32_groups_has_gradients = []

self.step_count = 0
self.group_paddings = []

if self.using_real_optimizer:
@@ -252,7 +251,6 @@ def step(self, closure=None):
self.update_lp_params()

self.clear_hp_grads()
self.step_count += 1

def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
"""Perform a backward pass and copy the low-precision gradients to the
20 changes: 20 additions & 0 deletions deepspeed/runtime/engine.py
@@ -2726,6 +2726,8 @@ def load_checkpoint(self,

if self.load_universal_checkpoint():
self.optimizer.update_lp_params()
if load_zero_checkpoint:
self.update_optimizer_step(step=client_states['iteration'] + 1)

return load_path, client_states

@@ -2903,6 +2905,24 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
return True

def update_optimizer_step(self, step):

def set_step(d):
if isinstance(d['step'], torch.Tensor):
d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
else:
d['step'] = step

optimizer = self.optimizer
base_optimizer = optimizer.optimizer
state = base_optimizer.state
for group in optimizer.param_groups:
if 'step' in group:
set_step(group)
for p in group['params']:
if p in state and len(state[p]) > 0 and 'step' in state[p]:
set_step(state[p])

def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
zero_ckpt_names = []
for dp_rank in range(dp_world_size):
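update_optimizer_step above rewrites the 'step' entry both in the param groups and in the base optimizer's per-parameter state, preserving the tensor vs. integer representation. A standalone sketch applying the same set_step logic to a plain torch.optim.Adam rather than the DeepSpeed wrapper:

    # Sketch: reset 'step' in a plain Adam's state, mirroring set_step in the diff.
    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.Adam(model.parameters())
    model(torch.randn(2, 4)).sum().backward()
    opt.step()                                     # populates per-parameter state, including 'step'

    new_step = 1001
    for group in opt.param_groups:
        for p in group['params']:
            state = opt.state[p]
            if state and 'step' in state:
                if isinstance(state['step'], torch.Tensor):
                    state['step'] = torch.tensor(new_step, dtype=state['step'].dtype,
                                                 device=state['step'].device)
                else:
                    state['step'] = new_step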
3 changes: 0 additions & 3 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -107,7 +107,6 @@ def __init__(self,

self.clip_grad = clip_grad
self.norm_type = 2
self.step_count = 0

if required_torch_version(max_version=0.4):
self.clip_grad_norm = torch.nn.utils.clip_grad_norm
@@ -289,8 +288,6 @@ def step(self, closure=None):

self.timers.log(STEP_TIMERS)

self.step_count += 1

return self.overflow

def _get_norm_with_moe_layers(self, all_groups_norm):
