diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 32aa1fbb934b..4385331b112e 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -41,12 +41,24 @@ def get_layers(parent, module): for key, submodule in module._modules.items(): if isinstance(submodule, nn.Linear): layer_list = layer_list + [parent + "." + key] - elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm': + elif isinstance(submodule, + nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm': layer_list = layer_list + ["ln"] else: layer_list = layer_list + AutoTP.get_layers(key, submodule) return layer_list + def update_policy_list(policy_list, new_module, new_gems): + if len(policy_list): + for i, policy in enumerate(policy_list): + # if module already exists in policy, combine gems and remove duplicates + if policy[0] == type(new_module): + new_gems = set(new_gems + policy[1]) + policy_list[i] = tuple([type(new_module), new_gems]) + return policy_list + policy_list.append(tuple([type(new_module), new_gems])) + return policy_list + def tp_parser(model): policy_list = [] module_list = [] @@ -60,7 +72,9 @@ def tp_parser(model): for key, submodule in module._modules.items(): if isinstance(submodule, nn.Linear): layer_list = layer_list + ["." + key] - elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm': + elif isinstance( + submodule, + nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm': layer_list = layer_list + ["ln"] else: layer_list = layer_list + AutoTP.get_layers(key, submodule) @@ -70,7 +84,9 @@ def tp_parser(model): gem_list = gem_list + [layer_list[i - 1]] elif 'out_proj' in layer: gem_list = gem_list + [layer] + layer_list = [] if gem_list != []: - policy_list.append(tuple([type(module), gem_list])) + gem_list = list(set(gem_list)) + policy_list = AutoTP.update_policy_list(policy_list, module, gem_list) gem_list = [] return policy_list diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 028cbbb38014..4096a5ece64a 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -463,6 +463,8 @@ def update_mp_params(child): child.num_heads = child.num_heads // mp_size if hasattr(child, 'num_attention_heads'): child.num_attention_heads = child.num_attention_heads // mp_size + if hasattr(child, 'num_attn_heads'): + child.num_attn_heads = child.num_attn_heads // mp_size if hasattr(child, 'all_head_size'): child.all_head_size = child.all_head_size // mp_size if hasattr(child, 'embed_dim'):