fix imbalance autotp issue (microsoft#31)
Yejing-Lai committed Nov 16, 2023
1 parent 57ff508 commit 09a348c
Showing 3 changed files with 19 additions and 13 deletions.
15 changes: 9 additions & 6 deletions deepspeed/module_inject/auto_tp.py
@@ -121,7 +121,9 @@ class Loading():
 
     def is_load_module(module):
         load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
-        load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "RMSNorm"]
+        load_layer_names = [
+            "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "RMSNorm"
+        ]
         return module.__class__ in load_layers or module._get_name() in load_layer_names
 
     def load_buffer(module, state_dict, prefix):
@@ -314,7 +316,7 @@ def _replace(self, child, name, conv_linear_layer):
             if self.conv_linear_layer:
                 child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
             data = child.weight.data.split(get_shard_size_list(
-                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size),
+                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                            dim=1)
             data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
             del data
@@ -344,14 +346,14 @@ def _replace(self, child, name, conv_linear_layer):
                     module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
                         get_accelerator().current_device_name())
             else:
-                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size),
+                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                                dim=1 if self.conv_linear_layer else 0)
                 data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
                 del data
 
                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
-                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size),
+                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                       dim=0)
                     bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
@@ -369,9 +371,10 @@ def _slice_embedding(self, child, name, conv_linear_layer):
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
 
         if hasattr(child.weight, 'ds_tensor'):
-            data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
+            data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name),
+                                                     dim=1)
         else:
-            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
+            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name), dim=1)
         data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
         data = torch.nn.parameter.Parameter(data, requires_grad=False)
 
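The _replace and _slice_embedding changes above only thread the layer name through to get_shard_size_list; the actual sharding still relies on torch.Tensor.split accepting an explicit list of section sizes, which is what makes uneven shards possible. A minimal standalone sketch of that pattern, with hypothetical dimensions (an 11008 x 4096 weight split across 3 ranks) rather than values taken from the diff:

import torch

# Hypothetical weight and shard sizes; [1408, 1344, 1344] is what the new
# 64-grain branch of get_shard_size_list would produce for 4096 columns
# split across 3 ranks.
weight = torch.randn(11008, 4096)
shard_sizes = [1408, 1344, 1344]

# torch.split accepts a list of section sizes, so shards may differ in width.
shards = weight.split(shard_sizes, dim=1)
print([tuple(s.shape) for s in shards])
# [(11008, 1408), (11008, 1344), (11008, 1344)]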
4 changes: 2 additions & 2 deletions deepspeed/module_inject/layers.py
@@ -48,8 +48,8 @@ def __init__(
         self.world_size = world_size
 
     def forward(self, input):
-        input_shard_size = get_shard_size(input.shape[-1], self.world_size)
-        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank])
+        input_shard_size = get_shard_size(input.shape[-1], self.world_size, "lm_head")
+        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size, "lm_head")[0:self.rank])
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
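For the lm_head path, the shard-size list also determines which slice of the activation each rank multiplies with its local weight shard before the all-reduce. A minimal sketch of that offset arithmetic with hypothetical sizes (hidden 4096, vocab 32000, 3 ranks), using plain torch in place of the real mp_group all-reduce:

import torch

# "lm_head" is in last_linear, so get_shard_size_list(4096, 3, "lm_head")
# takes the near-even 64-grain branch: [1408, 1344, 1344].
shard_sizes = [1408, 1344, 1344]
rank = 1

hidden = torch.randn(2, 8, 4096)                       # [batch, seq, hidden]
weight_shard = torch.randn(32000, shard_sizes[rank])   # this rank's slice of lm_head.weight

offset = sum(shard_sizes[:rank])                       # where this rank's columns start
size = shard_sizes[rank]
partial = torch.matmul(hidden[:, :, offset:offset + size], weight_shard.transpose(-1, -2))
print(tuple(partial.shape))  # (2, 8, 32000)
# In the real layer, the per-rank partial outputs are summed with an
# all-reduce over self.mp_group.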
13 changes: 8 additions & 5 deletions deepspeed/module_inject/tp_shard.py
@@ -17,20 +17,23 @@ def get_num_kv_heads():
     return num_kv_heads
 
 
-def get_shard_size(total_size, mp_size, rank=None):
+def get_shard_size(total_size, mp_size, name=None, rank=None):
     global num_kv_heads
+    last_linear = ["lm_head", "embed_out"]
     # When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
     if rank == None:
         rank = dist.get_rank()
-    if num_kv_heads != None and total_size % num_kv_heads == 0:
+    if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
+            name) not in last_linear:
         my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
         return total_size * my_slices // num_kv_heads
     else:
-        return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
+        grain_size = total_size // 64
+        return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
 
 
-def get_shard_size_list(total_size, mp_size):
+def get_shard_size_list(total_size, mp_size, name=None):
     shard_sizes = []
     for i in range(mp_size):
-        shard_sizes.append(get_shard_size(total_size, mp_size, i))
+        shard_sizes.append(get_shard_size(total_size, mp_size, name, i))
     return shard_sizes
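To make the effect of the fix concrete, here is a standalone restatement of the new rule (mirroring the diff above, not an import of the DeepSpeed module), with hypothetical Llama-style layer names and sizes. Attention projections whose width divides evenly by the kv-head count are still sharded on head boundaries, while MLP weights and the final lm_head/embed_out linear now fall back to a near-even split in multiples of 64, which removes the imbalance:

# Standalone restatement of the sharding rule above, for illustration only.
def shard_sizes(total_size, mp_size, num_kv_heads=None, name=None):
    last_linear = ["lm_head", "embed_out"]
    sizes = []
    for rank in range(mp_size):
        if (num_kv_heads is not None and total_size % num_kv_heads == 0
                and "mlp" not in str(name) and str(name) not in last_linear):
            # shard along kv-head boundaries (may be uneven across ranks)
            my_slices = num_kv_heads // mp_size + (1 if rank < num_kv_heads % mp_size else 0)
            sizes.append(total_size * my_slices // num_kv_heads)
        else:
            # near-even split in multiples of 64
            grain = total_size // 64
            sizes.append((grain // mp_size + (1 if rank < grain % mp_size else 0)) * 64)
    return sizes

# Hypothetical attention projection: 4096 wide, 32 kv heads, 3 ranks -> head-aligned shards
print(shard_sizes(4096, 3, num_kv_heads=32, name="self_attn.o_proj"))  # [1408, 1408, 1280]
# Hypothetical MLP weight: 11008 wide; previously this also took the kv-head branch
# ([3784, 3784, 3440]), now it is split near-evenly
print(shard_sizes(11008, 3, num_kv_heads=32, name="mlp.down_proj"))    # [3712, 3648, 3648]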
