[AutoTP] Make AutoTP work when num_heads not divisible by number of workers (microsoft#4011)

* allow number of heads not divisible by number of ranks

* get num_heads from model config, more robust

* simplify logic where num_heads itself is sharded

* name tweaks

* make code more robust where num_attention_heads may not be defined in model_config

* support num_key_value_heads < num_attention_heads which is used by llama2

* add test for 5 ranks

* change odd rank # to 3 to avoid test skip

* add get_shard_size function

* modify sharding mechanism according to latest auto TP

* fix accuracy issue

* fix format

* skip tests with fusedqkv

* remove skip of fusedqkv tests

* skip test fusedqkv with odd number of ranks

* support model with n_heads in model_config

* fix TestInjectionPolicy::test[fp32-t5]

* fix uneven_heads on some fusedqkv types (microsoft#12)

* support fusedqkv with an odd number of ranks

* fix format and clear text

* better fix when activation size cannot be divided by number of heads

* move tp_shard.py under module_inject

* Add get_num_kv_heads in tp_shard.py

* Refine according to comments

* remove old comment

* fix bug in getting num_kv_heads

* support uneven tensor-parallel sharding of lm_head

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
Co-authored-by: mzl <mingzhi.liu@intel.com>
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
7 people authored and amaurya committed Feb 17, 2024
1 parent 60d980d commit 1d1cd72
Showing 7 changed files with 121 additions and 30 deletions.
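The core idea of the change: rather than requiring every tensor dimension to divide evenly by the tensor-parallel world size, shards are sized per rank so that each rank owns a whole number of attention (or KV) heads, with the first num_heads % world_size ranks taking one extra head. Below is a minimal standalone sketch of that distribution rule in plain Python; the helper name distribute_heads is illustrative and not part of the PR.

# Sketch of the uneven head distribution this PR implements in get_shard_size:
# the first (num_heads % world_size) ranks each receive one extra head.
def distribute_heads(num_heads: int, world_size: int) -> list:
    base, extra = divmod(num_heads, world_size)
    return [base + (1 if rank < extra else 0) for rank in range(world_size)]

print(distribute_heads(32, 3))  # [11, 11, 10]
print(distribute_heads(5, 3))   # [2, 2, 1]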
32 changes: 22 additions & 10 deletions deepspeed/module_inject/auto_tp.py
@@ -14,6 +14,7 @@
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


class ReplaceWithTensorSlicing:
@@ -312,8 +313,9 @@ def _replace(self, child, name, conv_linear_layer):

if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = child.weight.data.split(
(weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1)
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size),
dim=1)
data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
del data

Expand Down Expand Up @@ -342,14 +344,15 @@ def _replace(self, child, name, conv_linear_layer):
module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
get_accelerator().current_device_name())
else:
data = child.weight.data.split((weight_shape[0]) // self.mp_size,
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size),
dim=1 if self.conv_linear_layer else 0)
data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(
(weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0)
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size),
dim=0)
bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
@@ -366,13 +369,13 @@ def _slice_embedding(self, child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

if hasattr(child.weight, 'ds_tensor'):
data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1)
data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
else:
data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1)
data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
data = torch.nn.parameter.Parameter(data, requires_grad=False)

new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size)
new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size))
new_embedding.weight.data.copy_(data)
setattr(child, "replaced", True)
return new_embedding
Expand All @@ -386,8 +389,7 @@ def update_mp_params(self, child):
]:
if hasattr(child, param):
param_val = getattr(child, param)
assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
setattr(child, param, param_val // self.mp_size)
setattr(child, param, get_shard_size(param_val, self.mp_size))
setattr(child, "replaced", True)

def update_linear_policies(self):
@@ -442,6 +444,16 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
self._replace_module(child, name, class_name)
return r_module

def get_model_num_kv_heads(self, config):
num_kv_heads = None
kv_head_names = ['num_key_value_heads', 'num_attention_heads', 'n_heads']
for name in kv_head_names:
if hasattr(config, name):
num_kv_heads = getattr(config, name)
if num_kv_heads != None:
break
return num_kv_heads

def _replace_last_linear_module(self, r_module):
if hasattr(r_module, "lm_head"):
name = "lm_head"
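The mechanical pattern throughout auto_tp.py is that torch.Tensor.split now receives a list of per-rank sizes from get_shard_size_list instead of a single integer, so shards no longer have to be equal. A hedged sketch of that pattern follows, with the shard sizes hard-coded for illustration rather than computed by DeepSpeed.

import torch

# Sketch: splitting a 10-row weight across 3 ranks with uneven shard sizes.
# In the PR these sizes come from get_shard_size_list(weight_shape[0], mp_size).
weight = torch.randn(10, 4)
shard_sizes = [4, 3, 3]                   # must sum to weight.shape[0]
shards = weight.split(shard_sizes, dim=0)
rank = 1                                  # each rank keeps only its own shard
my_shard = shards[rank].clone()
print([tuple(s.shape) for s in shards])   # [(4, 4), (3, 4), (3, 4)]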
9 changes: 5 additions & 4 deletions deepspeed/module_inject/auto_tp_model_utils.py
@@ -6,6 +6,7 @@
from deepspeed import comm as dist
import torch
from typing import Optional
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -51,8 +52,8 @@ def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
num_heads_per_rank = get_shard_size(num_heads, dist.get_world_size())
offset = sum(get_shard_size_list(num_heads, dist.get_world_size())[0:dist.get_rank()])
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
@@ -72,8 +73,8 @@ def build_mpt_atten_bias_tensor(self,
prefix_mask=prefix_mask,
sequence_id=sequence_id)
if dist.is_initialized():
num_heads_per_rank = int(self.config.n_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
num_heads_per_rank = get_shard_size(self.config.n_heads, dist.get_world_size())
offset = sum(get_shard_size_list(self.config.n_heads, dist.get_world_size())[0:dist.get_rank()])
attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
return attn_bias, attention_mask

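With uneven shards, a rank's offset into the head dimension is no longer rank * num_heads_per_rank; it is the sum of the shard sizes of all lower ranks, which is exactly what the new sum(get_shard_size_list(...)[0:rank]) expression computes for the alibi and MPT attention-bias slices. A small sketch of that offset arithmetic, assuming the per-rank head counts are already known:

# Sketch: per-rank (offset, size) pairs for slicing along the head dimension
# when head counts differ between ranks.
heads_per_rank = [2, 2, 1]                # e.g. 5 heads over 3 ranks

def head_slice(rank):
    offset = sum(heads_per_rank[:rank])   # heads owned by lower ranks
    return offset, heads_per_rank[rank]

print([head_slice(r) for r in range(3)])  # [(0, 2), (2, 2), (4, 1)]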
20 changes: 10 additions & 10 deletions deepspeed/module_inject/fusedqkv_utils.py
@@ -4,6 +4,7 @@
# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads
import re


@@ -39,18 +40,19 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):

def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
# codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
#TODO: assert num_heads % (mp_size*codegen_mp_num) == 0
assert get_num_kv_heads() % (
mp_size * codegen_mp_num) == 0, "codgen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0"
#input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

shape = input.shape
dst_shape = shape[0] // mp_size
dst_shape = get_shard_size(shape[0], mp_size)
num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

#num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :]
src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size, 0, 1)
split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1)
tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
@@ -59,18 +61,16 @@ def _glm_type_transpose(input, mp_size):
#input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

shape = input.shape
dst_shape = shape[0] // mp_size
src_split = torch.split(input, shape[0] // 3, dim=0)

split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size)
tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0)

return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size))
return split_fusedqkv[gpu_index]

def _bloom_type_transpose(input, mp_size):
shape = input.shape
dst_shape = shape[0] // mp_size
return input[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
return split_fusedqkv[gpu_index]

def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):

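For the bloom- and glm-type fused QKV weights, the rewrite first splits the fused tensor into per-rank chunks using the uneven size list and then indexes the chunk for gpu_index, replacing the old gpu_index * dst_shape slicing. A hedged sketch of the bloom-style path with hard-coded sizes (in the real code they come from get_shard_size_list):

import torch

# Sketch: bloom-style fused QKV weight sharded unevenly across 3 ranks.
# 15 rows stand in for a [3*hidden, hidden] fused weight; the sizes
# [6, 6, 3] correspond to 5 KV heads split as [2, 2, 1] heads per rank.
fused_qkv = torch.randn(15, 8)
shard_sizes = [6, 6, 3]
gpu_index = 2
my_qkv = fused_qkv.split(shard_sizes, dim=0)[gpu_index]
print(my_qkv.shape)                       # torch.Size([3, 8])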
8 changes: 4 additions & 4 deletions deepspeed/module_inject/layers.py
@@ -10,6 +10,7 @@

from torch.nn.parameter import Parameter
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


class LinearAllreduce(nn.Module):
@@ -47,10 +48,9 @@ def __init__(
self.world_size = world_size

def forward(self, input):
assert input.shape[
-1] % self.world_size == 0, 'Please ensure that self.world_size is divisible by input.shape[-1]'
input_shard = input.shape[-1] // self.world_size
output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard],
input_shard_size = get_shard_size(input.shape[-1], self.world_size)
input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank])
output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
self.weight.transpose(-1, -2))
if self.mp_group is not None:
dist.inference_all_reduce(output, group=self.mp_group)
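LmHeadLinearAllreduce.forward now slices its input activation with the same per-rank offset and size before the matmul, which is why the old divisibility assert could be removed. A sketch of that slicing under assumed shapes (the all-reduce of the partial outputs is omitted):

import torch

# Sketch: each rank multiplies its slice of the hidden dimension with its
# lm_head weight shard; partial results are later summed by an all-reduce.
hidden, vocab = 10, 7
shard_sizes = [4, 3, 3]                        # uneven split of hidden
rank = 0
offset = sum(shard_sizes[:rank])
x = torch.randn(2, 5, hidden)                  # [batch, seq, hidden]
w_shard = torch.randn(vocab, shard_sizes[rank])
partial = torch.matmul(x[:, :, offset:offset + shard_sizes[rank]],
                       w_shard.transpose(-1, -2))
print(partial.shape)                           # torch.Size([2, 5, 7])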
11 changes: 9 additions & 2 deletions deepspeed/module_inject/replace_module.py
@@ -16,6 +16,7 @@
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads

from .load_checkpoint import load_model_with_checkpoint
import time
@@ -271,10 +272,16 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
# 2. Set the tensor parallelism config
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

# 3. Set linear policies
# 3. Try to get num_key_heads from model_config.num_key_value_heads
num_kv_heads = _autotp.get_model_num_kv_heads(model_config)

# 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
set_num_kv_heads(num_kv_heads)

# 5. Set linear policies
_autotp.update_linear_policies()

# 4. Replace modules
# 6. Replace modules
if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
return _autotp._replace_last_linear_module(module)
return _autotp._replace_module(module)
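The wiring in replace_module.py discovers the KV-head count from the model config (trying num_key_value_heads, then num_attention_heads, then n_heads) and publishes it through set_num_kv_heads before any linear policy runs. A hedged sketch of that attribute-probing order with dummy config objects; probe_num_kv_heads mirrors get_model_num_kv_heads but is not the DeepSpeed function itself.

from types import SimpleNamespace

# Sketch of the config-attribute probing order used to find the KV-head count.
def probe_num_kv_heads(config):
    for name in ('num_key_value_heads', 'num_attention_heads', 'n_heads'):
        if hasattr(config, name):
            return getattr(config, name)
    return None

llama2_cfg = SimpleNamespace(num_attention_heads=64, num_key_value_heads=8)
mpt_cfg = SimpleNamespace(n_heads=16)
print(probe_num_kv_heads(llama2_cfg))  # 8  (GQA: fewer KV heads than attention heads)
print(probe_num_kv_heads(mpt_cfg))     # 16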
39 changes: 39 additions & 0 deletions deepspeed/module_inject/tp_shard.py
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist
global num_kv_heads


def set_num_kv_heads(num):
global num_kv_heads
num_kv_heads = num


def get_num_kv_heads():
global num_kv_heads
return num_kv_heads


def get_shard_size(total_size, mp_size, rank=None):
global num_kv_heads
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
if num_kv_heads != None:
if (rank == None):
rank = dist.get_rank()
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return total_size * my_slices // num_kv_heads
else:
if total_size % mp_size == 0:
return total_size // mp_size
else:
assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})"


def get_shard_size_list(total_size, mp_size):
shard_sizes = []
for i in range(mp_size):
shard_sizes.append(get_shard_size(total_size, mp_size, i))
return shard_sizes
32 changes: 32 additions & 0 deletions tests/unit/inference/test_inference.py
@@ -558,6 +558,38 @@ def test(
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)

@pytest.mark.world_size(3)
def test_odd_world_size(
self,
model_w_task,
query,
inf_kwargs,
assert_fn,
dtype,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
pytest.skip(invalid_test_msg)

model, task = model_w_task
if model == "Salesforce/codegen-350M-mono":
pytest.skip("codegen does not supported by odd world_size")
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "3"))

pipe = pipeline(task,
model=model,
device=torch.device(get_accelerator().device_name(local_rank)),
framework="pt")
bs_output = pipe(query, **inf_kwargs)

pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
ds_output = pipe(query, **inf_kwargs)

print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)


@pytest.mark.nightly
@pytest.mark.parametrize(
