Skip to content

Commit

Permalink
Fix auto TP for duplicate modules with different gems (#2784)
Browse files Browse the repository at this point in the history
* Fix auto TP for duplicate modules with different gems

* precommit and comments

* Comment

* Combine gem list of same named modules

* remove duplicates from gem_list before updating policy

* Add module attribute with name variation for ProphetNet

---------

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
  • Loading branch information
molly-smith and jeffra committed Feb 15, 2023
1 parent cc1054d commit 46784cb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
22 changes: 19 additions & 3 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,24 @@ def get_layers(parent, module):
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + [parent + "." + key]
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm':
elif isinstance(submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
return layer_list

def update_policy_list(policy_list, new_module, new_gems):
if len(policy_list):
for i, policy in enumerate(policy_list):
# if module already exists in policy, combine gems and remove duplicates
if policy[0] == type(new_module):
new_gems = set(new_gems + policy[1])
policy_list[i] = tuple([type(new_module), new_gems])
return policy_list
policy_list.append(tuple([type(new_module), new_gems]))
return policy_list

def tp_parser(model):
policy_list = []
module_list = []
Expand All @@ -60,7 +72,9 @@ def tp_parser(model):
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + ["." + key]
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm':
elif isinstance(
submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
Expand All @@ -70,7 +84,9 @@ def tp_parser(model):
gem_list = gem_list + [layer_list[i - 1]]
elif 'out_proj' in layer:
gem_list = gem_list + [layer]
layer_list = []
if gem_list != []:
policy_list.append(tuple([type(module), gem_list]))
gem_list = list(set(gem_list))
policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
gem_list = []
return policy_list
2 changes: 2 additions & 0 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,8 @@ def update_mp_params(child):
child.num_heads = child.num_heads // mp_size
if hasattr(child, 'num_attention_heads'):
child.num_attention_heads = child.num_attention_heads // mp_size
if hasattr(child, 'num_attn_heads'):
child.num_attn_heads = child.num_attn_heads // mp_size
if hasattr(child, 'all_head_size'):
child.all_head_size = child.all_head_size // mp_size
if hasattr(child, 'embed_dim'):
Expand Down

0 comments on commit 46784cb

Please sign in to comment.