From 1d1cd727ef7f9e4044d1b44521e20ca76ecb7eb1 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 26 Oct 2023 05:35:17 +0800 Subject: [PATCH] [AutoTP] Make AutoTP work when num_heads not divisible by number of workers (#4011) * allow number of heads not divisible by number of ranks * get num_heads from model config, more robust * simplify logic where num_head itself is sharded * name tweaks * make code more robust where num_attention_heads may not be defined in model_config * support num_key_value_heads < num_attention_heads which is used by llama2 * add test for 5 ranks * change odd rank # to 3 to avoid test skip * add get_shard_size function * modify sharding mechanism according to latest auto TP * fix accuracy issue * fix format * skip tests with fusedqkv * remove skip of fusedqkv tests * skip test fusedqkv with odd number of ranks * support model with n_heads in model_config * fix TestInjectionPolicy::test[fp32-t5] * fix uneven_heads on some fusedqkv types (#12) * odd support fusedqkv * fix format and clear text * better fix when activation size cannot be divided by number of heads * move tp_shard.py under module_inject * Add get_num_kv_heads in tp_shard.py * Refine according to comments * remove old comment * fix bug in getting num_kv_heads * support uneven sharding of lm_head tensor parallel --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com> Co-authored-by: mzl Co-authored-by: Michael Wyatt Co-authored-by: Michael Wyatt --- deepspeed/module_inject/auto_tp.py | 32 ++++++++++----- .../module_inject/auto_tp_model_utils.py | 9 +++-- deepspeed/module_inject/fusedqkv_utils.py | 20 +++++----- deepspeed/module_inject/layers.py | 8 ++-- deepspeed/module_inject/replace_module.py | 11 +++++- deepspeed/module_inject/tp_shard.py | 39 +++++++++++++++++++ tests/unit/inference/test_inference.py | 32 +++++++++++++++ 7 files changed, 121 insertions(+), 30 deletions(-) create mode 100644 deepspeed/module_inject/tp_shard.py diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 2e348de63454..50fab4cced37 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -14,6 +14,7 @@ from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce from deepspeed.accelerator import get_accelerator from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list class ReplaceWithTensorSlicing: @@ -312,8 +313,9 @@ def _replace(self, child, name, conv_linear_layer): if self.conv_linear_layer: child.weight.data = child.weight.data.transpose(-1, -2).contiguous() - data = child.weight.data.split( - (weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1) + data = child.weight.data.split(get_shard_size_list( + weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size), + dim=1) data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach() del data @@ -342,14 +344,15 @@ def _replace(self, child, name, conv_linear_layer): module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to( get_accelerator().current_device_name()) else: - data = child.weight.data.split((weight_shape[0]) // self.mp_size, + data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size), dim=1 if self.conv_linear_layer else 0) 
data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach() del data if child.bias is not None: - bias_data = child.bias.data.split( - (weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0) + bias_data = child.bias.data.split(get_shard_size_list( + weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size), + dim=0) bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False) del bias_data @@ -366,13 +369,13 @@ def _slice_embedding(self, child, name, conv_linear_layer): mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) if hasattr(child.weight, 'ds_tensor'): - data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1) + data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1) else: - data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1) + data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1) data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) data = torch.nn.parameter.Parameter(data, requires_grad=False) - new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size) + new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size)) new_embedding.weight.data.copy_(data) setattr(child, "replaced", True) return new_embedding @@ -386,8 +389,7 @@ def update_mp_params(self, child): ]: if hasattr(child, param): param_val = getattr(child, param) - assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})" - setattr(child, param, param_val // self.mp_size) + setattr(child, param, get_shard_size(param_val, self.mp_size)) setattr(child, "replaced", True) def update_linear_policies(self): @@ -442,6 +444,16 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''): self._replace_module(child, name, class_name) return r_module + def get_model_num_kv_heads(self, config): + num_kv_heads = None + kv_head_names = ['num_key_value_heads', 'num_attention_heads', 'n_heads'] + for name in kv_head_names: + if hasattr(config, name): + num_kv_heads = getattr(config, name) + if num_kv_heads != None: + break + return num_kv_heads + def _replace_last_linear_module(self, r_module): if hasattr(r_module, "lm_head"): name = "lm_head" diff --git a/deepspeed/module_inject/auto_tp_model_utils.py b/deepspeed/module_inject/auto_tp_model_utils.py index ce52cd6a6250..51e52e3258dd 100644 --- a/deepspeed/module_inject/auto_tp_model_utils.py +++ b/deepspeed/module_inject/auto_tp_model_utils.py @@ -6,6 +6,7 @@ from deepspeed import comm as dist import torch from typing import Optional +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: @@ -51,8 +52,8 @@ def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] alibi = slopes[..., None] * arange_tensor if dist.is_initialized(): - num_heads_per_rank = int(num_heads / dist.get_world_size()) - offset = dist.get_rank() * num_heads_per_rank + num_heads_per_rank = get_shard_size(num_heads, dist.get_world_size()) + offset = 
sum(get_shard_size_list(num_heads, dist.get_world_size())[0:dist.get_rank()]) alibi = alibi.view(batch_size, num_heads, 1, seq_length) alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) @@ -72,8 +73,8 @@ def build_mpt_atten_bias_tensor(self, prefix_mask=prefix_mask, sequence_id=sequence_id) if dist.is_initialized(): - num_heads_per_rank = int(self.config.n_heads / dist.get_world_size()) - offset = dist.get_rank() * num_heads_per_rank + num_heads_per_rank = get_shard_size(self.config.n_heads, dist.get_world_size()) + offset = sum(get_shard_size_list(self.config.n_heads, dist.get_world_size())[0:dist.get_rank()]) attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :] return attn_bias, attention_mask diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index 3ddf8f4404ac..2e8f6b5917ed 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch from deepspeed.utils.logging import warning_once +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads import re @@ -39,18 +40,19 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index): def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py - #TODO: assert num_heads % (mp_size*codegen_mp_num) == 0 + assert get_num_kv_heads() % ( + mp_size * codegen_mp_num) == 0, "codgen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0" #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) shape = input.shape - dst_shape = shape[0] // mp_size + dst_shape = get_shard_size(shape[0], mp_size) num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1]) #num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :] src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1)) src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split] - split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size, 0, 1) + split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1) tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1) return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape] @@ -59,18 +61,16 @@ def _glm_type_transpose(input, mp_size): #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) shape = input.shape - dst_shape = shape[0] // mp_size src_split = torch.split(input, shape[0] // 3, dim=0) - split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size) - tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0) - - return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape] + split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size)) + return split_fusedqkv[gpu_index] def _bloom_type_transpose(input, mp_size): shape = input.shape - dst_shape = shape[0] // mp_size - return input[gpu_index * dst_shape:(gpu_index + 1) * dst_shape] + + split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0) + return split_fusedqkv[gpu_index] def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None): diff --git a/deepspeed/module_inject/layers.py 
b/deepspeed/module_inject/layers.py index 7a565560dec9..969826ad0289 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter from deepspeed.accelerator import get_accelerator +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list class LinearAllreduce(nn.Module): @@ -47,10 +48,9 @@ def __init__( self.world_size = world_size def forward(self, input): - assert input.shape[ - -1] % self.world_size == 0, 'Please ensure that self.world_size is divisible by input.shape[-1]' - input_shard = input.shape[-1] // self.world_size - output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard], + input_shard_size = get_shard_size(input.shape[-1], self.world_size) + input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank]) + output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size], self.weight.transpose(-1, -2)) if self.mp_group is not None: dist.inference_all_reduce(output, group=self.mp_group) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 8666372fa3f4..fe32378613c9 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -16,6 +16,7 @@ from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading from deepspeed import comm as dist +from deepspeed.module_inject.tp_shard import set_num_kv_heads from .load_checkpoint import load_model_with_checkpoint import time @@ -271,10 +272,16 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): # 2. Set the tensor parallelism config _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group) - # 3. Set linear policies + # 3. Try to get num_key_heads from model_config.num_key_value_heads + num_kv_heads = _autotp.get_model_num_kv_heads(model_config) + + # 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division + set_num_kv_heads(num_kv_heads) + + # 5. Set linear policies _autotp.update_linear_policies() - # 4. Replace modules + # 6. Replace modules if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears: return _autotp._replace_last_linear_module(module) return _autotp._replace_module(module) diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py new file mode 100644 index 000000000000..8e2fa78d883f --- /dev/null +++ b/deepspeed/module_inject/tp_shard.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed import comm as dist +global num_kv_heads + + +def set_num_kv_heads(num): + global num_kv_heads + num_kv_heads = num + + +def get_num_kv_heads(): + global num_kv_heads + return num_kv_heads + + +def get_shard_size(total_size, mp_size, rank=None): + global num_kv_heads + # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division + if num_kv_heads != None: + if (rank == None): + rank = dist.get_rank() + my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0) + return total_size * my_slices // num_kv_heads + else: + if total_size % mp_size == 0: + return total_size // mp_size + else: + assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})" + + +def get_shard_size_list(total_size, mp_size): + shard_sizes = [] + for i in range(mp_size): + shard_sizes.append(get_shard_size(total_size, mp_size, i)) + return shard_sizes diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 894f040be207..6b5588d8a1f7 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -558,6 +558,38 @@ def test( print(local_rank, "deepspeed", ds_output) assert assert_fn(bs_output, ds_output) + @pytest.mark.world_size(3) + def test_odd_world_size( + self, + model_w_task, + query, + inf_kwargs, + assert_fn, + dtype, + ): + invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) + if invalid_test_msg: + pytest.skip(invalid_test_msg) + + model, task = model_w_task + if model == "Salesforce/codegen-350M-mono": + pytest.skip("codegen does not supported by odd world_size") + local_rank = int(os.getenv("LOCAL_RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "3")) + + pipe = pipeline(task, + model=model, + device=torch.device(get_accelerator().device_name(local_rank)), + framework="pt") + bs_output = pipe(query, **inf_kwargs) + + pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) + ds_output = pipe(query, **inf_kwargs) + + print(local_rank, "baseline", bs_output) + print(local_rank, "deepspeed", ds_output) + assert assert_fn(bs_output, ds_output) + @pytest.mark.nightly @pytest.mark.parametrize(
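The heart of the patch is the head-granularity shard arithmetic added in deepspeed/module_inject/tp_shard.py: when the number of KV heads is known, the first (num_kv_heads % mp_size) ranks each take one extra head, and every per-rank tensor size is derived from that head count. Below is a minimal standalone sketch of that arithmetic; rank and num_kv_heads are passed explicitly here instead of being read from deepspeed.comm and the module-level global, so it runs without a distributed backend.

    def get_shard_size(total_size, mp_size, rank, num_kv_heads=None):
        # Head-granularity sharding: the first (num_kv_heads % mp_size) ranks
        # own one extra head, and the tensor is sized proportionally.
        if num_kv_heads is not None:
            my_heads = num_kv_heads // mp_size + (1 if rank < num_kv_heads % mp_size else 0)
            return total_size * my_heads // num_kv_heads
        # Fallback: without a head count, only even division is allowed.
        assert total_size % mp_size == 0, \
            f"total_size ({total_size}) must be divisible by mp_size ({mp_size})"
        return total_size // mp_size

    def get_shard_size_list(total_size, mp_size, num_kv_heads=None):
        return [get_shard_size(total_size, mp_size, r, num_kv_heads) for r in range(mp_size)]

For roughly Llama-2-70B shapes (hidden_size 8192, 64 attention heads, 8 KV heads) on 3 ranks, the KV heads split as [3, 3, 2], so the query projection shards as [3072, 3072, 2048] and each KV projection as [384, 384, 256]; every size stays a multiple of the 128-dim head, which is why sharding is keyed to num_key_value_heads rather than num_attention_heads.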
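Once shard sizes can differ, a rank's offset is no longer rank * shard_size; it is the sum of the sizes owned by lower ranks. That is the sum(get_shard_size_list(...)[0:rank]) pattern the patch applies to the Bloom ALiBi tensor, the MPT attention bias, and the LmHeadLinearAllreduce input slice. A small sketch of the pattern follows; the shard_slice helper is illustrative rather than DeepSpeed API, and it reuses get_shard_size_list from the sketch above.

    import torch

    def shard_slice(tensor, dim, mp_size, rank, num_kv_heads=None):
        sizes = get_shard_size_list(tensor.shape[dim], mp_size, num_kv_heads)
        offset = sum(sizes[:rank])            # uneven shards -> offset is a prefix sum
        return tensor.narrow(dim, offset, sizes[rank])

    # 25 heads across 3 ranks: shards of 9, 8 and 8 heads that tile the tensor exactly.
    x = torch.arange(25.0).unsqueeze(-1)      # [25, 1], one row per head
    shards = [shard_slice(x, 0, 3, r, num_kv_heads=25) for r in range(3)]
    print([s.shape[0] for s in shards])       # [9, 8, 8]
    assert torch.equal(torch.cat(shards, dim=0), x)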
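In auto_tp.py the change is mostly mechanical: each (size // mp_size) chunk size passed to split() becomes the per-rank size list, since torch.split also accepts a list of section lengths. A sketch of that partitioning, again assuming the get_shard_size_list sketch above and purely hypothetical shapes:

    import torch

    hidden, num_heads, mp_size = 4096, 32, 3         # hypothetical model shape
    weight = torch.randn(hidden, hidden)             # e.g. an attention output projection

    sizes = get_shard_size_list(hidden, mp_size, num_kv_heads=num_heads)
    print(sizes)                                     # [1408, 1408, 1280]
    shards = torch.split(weight, sizes, dim=0)       # one (possibly uneven) shard per rank
    assert sum(s.shape[0] for s in shards) == hidden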
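Fused QKV weights need one more step: the [3*out, ...] tensor is first cut into its q, k and v thirds, each third is split with the per-rank size list, and every rank re-fuses its own three slices. The sketch below mirrors what _glm_type_transpose achieves via split_by_qkvlist_and_refuse (a helper not shown in this diff), so treat it as an illustration of the resulting layout rather than the exact implementation; it also reuses get_shard_size_list from the first sketch.

    import torch

    def shard_fused_qkv(fused, mp_size, rank, num_kv_heads):
        per_proj = fused.shape[0] // 3                       # q, k, v stacked along dim 0
        q, k, v = torch.split(fused, per_proj, dim=0)
        sizes = get_shard_size_list(per_proj, mp_size, num_kv_heads)
        # Each rank keeps its own q/k/v slices and fuses them back together.
        return torch.cat([torch.split(t, sizes, dim=0)[rank] for t in (q, k, v)], dim=0)

    fused = torch.randn(3 * 100, 64)                         # toy fused weight: 25 heads of dim 4
    shards = [shard_fused_qkv(fused, 3, r, num_kv_heads=25) for r in range(3)]
    print([s.shape[0] for s in shards])                      # [108, 96, 96] = 3 * [36, 32, 32]
    assert sum(s.shape[0] for s in shards) == fused.shape[0]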