Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix auto TP for duplicate modules with different gems #2784

Merged
merged 13 commits into from
Feb 15, 2023
Merged
22 changes: 19 additions & 3 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,24 @@ def get_layers(parent, module):
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + [parent + "." + key]
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm':
elif isinstance(submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
return layer_list

def update_policy_list(policy_list, new_module, new_gems):
if len(policy_list):
for i, policy in enumerate(policy_list):
# if module already exists in policy, combine gems and remove duplicates
if policy[0] == type(new_module):
new_gems = set(new_gems + policy[1])
policy_list[i] = tuple([type(new_module), new_gems])
return policy_list
policy_list.append(tuple([type(new_module), new_gems]))
return policy_list

def tp_parser(model):
policy_list = []
module_list = []
Expand All @@ -60,7 +72,9 @@ def tp_parser(model):
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + ["." + key]
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm':
elif isinstance(
submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
Expand All @@ -70,7 +84,9 @@ def tp_parser(model):
gem_list = gem_list + [layer_list[i - 1]]
elif 'out_proj' in layer:
gem_list = gem_list + [layer]
layer_list = []
if gem_list != []:
policy_list.append(tuple([type(module), gem_list]))
gem_list = list(set(gem_list))
policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
gem_list = []
return policy_list
2 changes: 2 additions & 0 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,8 @@ def update_mp_params(child):
child.num_heads = child.num_heads // mp_size
if hasattr(child, 'num_attention_heads'):
child.num_attention_heads = child.num_attention_heads // mp_size
if hasattr(child, 'num_attn_heads'):
child.num_attn_heads = child.num_attn_heads // mp_size
if hasattr(child, 'all_head_size'):
child.all_head_size = child.all_head_size // mp_size
if hasattr(child, 'embed_dim'):
Expand Down