fix imbalance autotp issue #31

Merged: 1 commit, merged Nov 16, 2023
15 changes: 9 additions & 6 deletions deepspeed/module_inject/auto_tp.py
@@ -121,7 +121,9 @@ class Loading():
 
     def is_load_module(module):
         load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
-        load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "RMSNorm"]
+        load_layer_names = [
+            "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "RMSNorm"
+        ]
         return module.__class__ in load_layers or module._get_name() in load_layer_names
 
     def load_buffer(module, state_dict, prefix):
@@ -314,7 +316,7 @@ def _replace(self, child, name, conv_linear_layer):
             if self.conv_linear_layer:
                 child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
             data = child.weight.data.split(get_shard_size_list(
-                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size),
+                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                            dim=1)
             data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
             del data
@@ -344,14 +346,14 @@ def _replace(self, child, name, conv_linear_layer):
                     module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
                         get_accelerator().current_device_name())
             else:
-                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size),
+                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                                dim=1 if self.conv_linear_layer else 0)
                 data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
                 del data
 
                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
-                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size),
+                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                       dim=0)
                     bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
@@ -369,9 +371,10 @@ def _slice_embedding(self, child, name, conv_linear_layer):
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
 
         if hasattr(child.weight, 'ds_tensor'):
-            data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
+            data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name),
+                                                     dim=1)
         else:
-            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
+            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name), dim=1)
         data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
         data = torch.nn.parameter.Parameter(data, requires_grad=False)
 
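To see what threading name into get_shard_size_list changes in practice, here is a minimal sketch in plain PyTorch (not DeepSpeed code; all sizes are assumed for the example): torch.split accepts a list of per-rank sizes, so the weight shards no longer need to be equal.

import torch

# Assumed: 4096 output features sharded across 3 ranks; for a KV projection,
# get_shard_size_list may now return uneven, head-aligned sizes such as these.
shard_sizes = [1536, 1536, 1024]
weight = torch.randn(4096, 4096)
shards = weight.split(shard_sizes, dim=0)   # one shard per rank, possibly uneven
assert [s.shape[0] for s in shards] == shard_sizes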
4 changes: 2 additions & 2 deletions deepspeed/module_inject/layers.py
@@ -48,8 +48,8 @@ def __init__(
         self.world_size = world_size
 
     def forward(self, input):
-        input_shard_size = get_shard_size(input.shape[-1], self.world_size)
-        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank])
+        input_shard_size = get_shard_size(input.shape[-1], self.world_size, "lm_head")
+        input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size, "lm_head")[0:self.rank])
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
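Hard-coding "lm_head" here routes the final projection through the near-even fallback path in get_shard_size. A small sketch of the offset arithmetic with assumed numbers (chosen to match what get_shard_size_list(4096, 3, "lm_head") would return under the new 64-grain rule):

# Assumed: hidden size 4096, world size 3; the 64-grain fallback yields these sizes.
shard_sizes = [1408, 1344, 1344]
rank = 2
input_shard_size = shard_sizes[rank]             # 1344
input_shard_offset = sum(shard_sizes[0:rank])    # 1408 + 1344 = 2752
# forward() then multiplies input[:, :, 2752:4096] by this rank's weight shard.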
13 changes: 8 additions & 5 deletions deepspeed/module_inject/tp_shard.py
@@ -17,20 +17,23 @@ def get_num_kv_heads():
     return num_kv_heads
 
 
-def get_shard_size(total_size, mp_size, rank=None):
+def get_shard_size(total_size, mp_size, name=None, rank=None):
     global num_kv_heads
+    last_linear = ["lm_head", "embed_out"]
     # When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
     if rank == None:
         rank = dist.get_rank()
-    if num_kv_heads != None and total_size % num_kv_heads == 0:
+    if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
+            name) not in last_linear:
         my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
         return total_size * my_slices // num_kv_heads
     else:
-        return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
+        grain_size = total_size // 64
+        return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
 
 
-def get_shard_size_list(total_size, mp_size):
+def get_shard_size_list(total_size, mp_size, name=None):
     shard_sizes = []
     for i in range(mp_size):
-        shard_sizes.append(get_shard_size(total_size, mp_size, i))
+        shard_sizes.append(get_shard_size(total_size, mp_size, name, i))
     return shard_sizes

Review comment from the repository owner, on the new if num_kv_heads != None ... condition: what is the evaluation of this line if name == None?
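For illustration, here is a self-contained re-implementation of the new sharding arithmetic as a pure function (the helper name shard_size, the layer names, and all sizes are assumptions for the example, not part of the PR). It also shows the name == None case the reviewer asks about: str(None) is "None", so both string checks pass and the KV-head path is still taken, matching the old behavior.

def shard_size(total_size, mp_size, rank, num_kv_heads=None, name=None):
    # Mirrors the PR's get_shard_size logic, written as a pure function for clarity.
    last_linear = ["lm_head", "embed_out"]
    # If name is None, str(None) == "None": "mlp" is not in "None" and "None" is
    # not in last_linear, so the KV-head branch below is still taken.
    if (num_kv_heads is not None and total_size % num_kv_heads == 0
            and "mlp" not in str(name) and str(name) not in last_linear):
        # Give each rank a whole number of KV heads; earlier ranks absorb the remainder.
        my_slices = num_kv_heads // mp_size + (1 if rank < num_kv_heads % mp_size else 0)
        return total_size * my_slices // num_kv_heads
    # Otherwise shard in 64-element grains so every shard stays 64-aligned.
    grain_size = total_size // 64
    return (grain_size // mp_size + (1 if rank < grain_size % mp_size) else 0) * 64 if False else \
        (grain_size // mp_size + (1 if rank < grain_size % mp_size else 0)) * 64

# 8 KV heads, hidden size 4096, 3-way tensor parallelism: heads split 3/3/2.
assert [shard_size(4096, 3, r, num_kv_heads=8, name="k_proj") for r in range(3)] == [1536, 1536, 1024]
# lm_head falls through to the 64-grain path and is sharded near-evenly.
assert [shard_size(4096, 3, r, num_kv_heads=8, name="lm_head") for r in range(3)] == [1408, 1344, 1344]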