From 09a348c92a6a85db099c00b71f5a488cce9cdb93 Mon Sep 17 00:00:00 2001
From: Yejing-Lai <55339926+Yejing-Lai@users.noreply.github.com>
Date: Thu, 16 Nov 2023 09:46:11 +0800
Subject: [PATCH] fix imbalance autotp issue (#31)

---
 deepspeed/module_inject/auto_tp.py  | 15 +++++++++------
 deepspeed/module_inject/layers.py   |  4 ++--
 deepspeed/module_inject/tp_shard.py | 13 ++++++++-----
 3 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 4e993876fbe8..d930c9f2f413 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -121,7 +121,9 @@ class Loading():
 
     def is_load_module(module):
         load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
-        load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "RMSNorm"]
+        load_layer_names = [
+            "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "RMSNorm"
+        ]
         return module.__class__ in load_layers or module._get_name() in load_layer_names
 
     def load_buffer(module, state_dict, prefix):
@@ -314,7 +316,7 @@ def _replace(self, child, name, conv_linear_layer):
             if self.conv_linear_layer:
                 child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
             data = child.weight.data.split(get_shard_size_list(
-                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size),
+                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                            dim=1)
             data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
             del data
@@ -344,14 +346,14 @@ def _replace(self, child, name, conv_linear_layer):
                     module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
                         get_accelerator().current_device_name())
             else:
-                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size),
+                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                                dim=1 if self.conv_linear_layer else 0)
                 data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
                 del data
 
                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
-                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size),
+                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                       dim=0)
                     bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
@@ -369,9 +371,10 @@ def _slice_embedding(self, child, name, conv_linear_layer):
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
 
         if hasattr(child.weight, 'ds_tensor'):
-            data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
+            data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name),
+                                                     dim=1)
         else:
-            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
+            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name), dim=1)
         data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
         data = torch.nn.parameter.Parameter(data, requires_grad=False)
 
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index 969826ad0289..5fb957faa80c 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -48,8 +48,8 @@ def __init__(
         self.world_size = world_size
 
     def forward(self, input):
-        input_shard_size = get_shard_size(input.shape[-1], self.world_size)
-        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank])
+        input_shard_size = get_shard_size(input.shape[-1], self.world_size, "lm_head")
+        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size, "lm_head")[0:self.rank])
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index 77756e43c62f..cf5cd4aafe04 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -17,20 +17,23 @@ def get_num_kv_heads():
     return num_kv_heads
 
 
-def get_shard_size(total_size, mp_size, rank=None):
+def get_shard_size(total_size, mp_size, name=None, rank=None):
     global num_kv_heads
+    last_linear = ["lm_head", "embed_out"]
     # When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
     if rank == None:
         rank = dist.get_rank()
-    if num_kv_heads != None and total_size % num_kv_heads == 0:
+    if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
+            name) not in last_linear:
         my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
         return total_size * my_slices // num_kv_heads
     else:
-        return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
+        grain_size = total_size // 64
+        return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
 
 
-def get_shard_size_list(total_size, mp_size):
+def get_shard_size_list(total_size, mp_size, name=None):
     shard_sizes = []
     for i in range(mp_size):
-        shard_sizes.append(get_shard_size(total_size, mp_size, i))
+        shard_sizes.append(get_shard_size(total_size, mp_size, name, i))
     return shard_sizes
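
The core of the fix is the new rule in tp_shard.py: layers whose name matches "mlp" or the final
linear ("lm_head"/"embed_out") no longer go through the KV-head split, and the fallback path now
splits in 64-element grains, assigning the leftover grains to the lower ranks. The snippet below is
a minimal, self-contained sketch of that rule for illustration only, not code from the patch: rank
and num_kv_heads are passed in directly so it runs without torch.distributed, the sketch_* names
are invented, and the sizes (11008, 4096) and layer names in the example are assumed values.

# Sketch of the sharding rule introduced above (assumptions noted in the lead-in).
def sketch_get_shard_size(total_size, mp_size, rank, num_kv_heads=None, name=None):
    last_linear = ["lm_head", "embed_out"]
    # Split along KV heads when possible, except for MLP weights and the final linear layer.
    if num_kv_heads is not None and total_size % num_kv_heads == 0 \
            and "mlp" not in str(name) and str(name) not in last_linear:
        my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
        return total_size * my_slices // num_kv_heads
    # Otherwise split in grains of 64 elements, giving the remainder grains to the first ranks.
    grain_size = total_size // 64
    return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64


def sketch_get_shard_size_list(total_size, mp_size, num_kv_heads=None, name=None):
    return [sketch_get_shard_size(total_size, mp_size, r, num_kv_heads, name) for r in range(mp_size)]


if __name__ == "__main__":
    # MLP intermediate dim 11008 over 3 ranks: 172 grains of 64 -> 58/57/57 grains per rank.
    print(sketch_get_shard_size_list(11008, 3, num_kv_heads=32, name="mlp.down_proj"))   # [3712, 3648, 3648]
    # Attention projection of width 4096 with 32 KV heads over 3 ranks: 11/11/10 heads per rank.
    print(sketch_get_shard_size_list(4096, 3, num_kv_heads=32, name="self_attn.o_proj"))  # [1408, 1408, 1280]

In both cases the per-rank sizes sum to the original dimension, which is what LmHeadLinearAllreduce
in layers.py relies on when it computes its shard offset with the "lm_head" name.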