Skip to content

Commit

Permalink
Merge branch 'master' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase committed Jun 19, 2024
2 parents 3cca313 + 4000cee commit e25fb52
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 10 deletions.
6 changes: 5 additions & 1 deletion deepspeed/autotuning/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import copy

from numpy import BUFSIZE
import numpy
if numpy.__version__ < '2.0.0':
from numpy import BUFSIZE
else:
from numpy._core.umath import BUFSIZE
import json
import subprocess
import sys
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/elasticity/elastic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from typing import Any, Dict, Optional, Tuple
from datetime import datetime
from torch.distributed.elastic.agent.server.api import _get_socket_with_port
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.distributed.elastic.metrics import put_metric
from torch.distributed.elastic.agent.server.api import (
RunResult,
Expand Down Expand Up @@ -48,7 +48,7 @@ def _set_master_addr_port(store: Store,
master_port: Optional[int],
local_addr: Optional[str] = None):
if master_port is None:
sock = _get_socket_with_port()
sock = get_free_port()
with closing(sock):
master_port = sock.getsockname()[1]

Expand Down
26 changes: 19 additions & 7 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


Expand Down Expand Up @@ -134,7 +134,7 @@ def is_load_module(module):
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
"Phi3RMSNorm"
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

Expand Down Expand Up @@ -331,6 +331,16 @@ def _replace(self, child, name, conv_linear_layer):
# For mixtral-7x8b, need to skip MoE gate linear replace.
if name == "block_sparse_moe.gate":
return child
# For Yuan model
if 'Yuan' in str(self.module):
if 'v_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), True)
return LinearLayer(weight=weight, bias=bias)
elif 'o_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# for phi3.
if 'gate_up_proj' in name:
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
Expand Down Expand Up @@ -412,11 +422,13 @@ def _slice_embedding(self, child, name, conv_linear_layer):
def update_mp_params(self, child):
if getattr(child, "replaced", False) == True:
return
for param in [
"n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
"all_head_size", "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads",
"d_model"
]:
param_list = [
"n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", "all_head_size",
"embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads", "d_model"
]
for param in param_list:
if "Yuan" in str(child) and 'embed_dim' in param_list:
param_list.remove('embed_dim')
if hasattr(child, param):
param_val = getattr(child, param)
setattr(child, param, get_shard_size(param_val, self.mp_size))
Expand Down
55 changes: 55 additions & 0 deletions deepspeed/module_inject/fusedqkv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,61 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
return _bloom_type_transpose(src, mp_size)


# For share qk type:
# q = [q1,...,q_{n/4}, q_{n/2+1},...,q_{3n/4}, k1,...,k_{n/4}, k_{n/2+1},...,k_{3n/4}]
# k = [q_{n/4+1},...,q_{n/2}, q_{3n/4+1},...,qn, k_{n/4+1},...,k_{n/2}, k{3n/4+1},...,kn]
# Avoid modifying the modeling code. We adjust the value and oproj weight to fit this qk type.
def shard_value_with_share_qk(
weight,
bias,
rank,
world_size,
shard_value=True # True -> shard_value; False -> shard_oproj
):
if shard_value:
total_size = weight.shape[0]
weight_cat_dim = 0
else:
total_size = weight.shape[1]
weight_cat_dim = 1
num_heads = get_num_kv_heads()
head_dim = total_size // num_heads
assert (num_heads % world_size == 0)
if world_size > num_heads // 2:
RuntimeError(f"world_size {world_size} is larger than half of num_heads {num_heads}")
head_per_rank = num_heads // world_size
q_head_start = rank * head_per_rank
# mapping q_head to v_head
v_head_ids = []
i = 0
# mapping neighbor q_head to v_head
while i < head_per_rank:
v_head_ids.append(q_head_start // 2)
q_head_start += 2
i = i + 2

# mapping neighbor k_head to v_head
v_head_ids.extend([i + num_heads // 2 for i in v_head_ids])
sharded_weight = []
sharded_bias = []
for head_id in v_head_ids:
if shard_value:
sharded_weight.append(weight[head_id * head_dim:(head_id + 1) * head_dim])
if bias is not None:
sharded_bias.append(bias.data[head_id * head_dim:(head_id + 1) * head_dim])
else:
sharded_weight.append(weight[:, head_id * head_dim:(head_id + 1) * head_dim])
sharded_weight = torch.cat(sharded_weight, dim=weight_cat_dim)
if bias is not None:
if shard_value:
sharded_bias = torch.cat(sharded_bias, dim=0)
else:
bias = bias / float(world_size)
return torch.nn.Parameter(sharded_weight), torch.nn.Parameter(sharded_bias)
else:
return torch.nn.Parameter(sharded_weight), None


# For phi3 with chunk mlp, adjust the weight order.
def shard_chunk_mlp(
weight,
Expand Down
62 changes: 62 additions & 0 deletions deepspeed/module_inject/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,68 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


class TensorParallelConv2d(nn.Module):

def __init__(self, conv, rank, world_size, shard_by_oc):
super().__init__()
self.rank = rank
self.world_size = world_size
self.shard_by_oc = shard_by_oc
self.shard_weights(conv)

# Split along the input/output channel depending on whether it is the last conv layer.
def shard_weights(self, conv):
if self.shard_by_oc:
total_size = conv.weight.shape[0]
else:
total_size = conv.weight.shape[1]
bias_data = None
cols_per_rank = [0]
for i in range(self.world_size - 1, -1, -1):
cols = total_size // self.world_size
if i < total_size % self.world_size:
cols += 1
cols_per_rank.append(cols_per_rank[-1] + cols)
weight_data = conv.weight.data
if self.shard_by_oc:
# not last conv layer, split output channel
weight_data = weight_data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
if conv.bias is not None:
bias_data = conv.bias.data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
else:
# last conv layer, split input channel
weight_data = weight_data[:, cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
if conv.bias is not None:
bias_data = conv.bias.data / float(self.world_size)
self.conv = nn.Conv2d(weight_data.shape[1], weight_data.shape[0], conv.kernel_size, conv.stride, conv.padding,
conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode)
self.conv.weight = torch.nn.Parameter(weight_data)
if conv.bias is not None:
self.conv.bias = torch.nn.Parameter(bias_data)
del conv

def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.conv(input)


class TensorParallelOcShardConv2d(TensorParallelConv2d):

def __init__(self, conv, rank, world_size):
super().__init__(conv, rank, world_size, True)


class TensorParallelIcShardConv2d(TensorParallelConv2d):

def __init__(self, conv, rank, world_size):
super().__init__(conv, rank, world_size, False)

def forward(self, input: torch.Tensor) -> torch.Tensor:
out = self.conv(input)
if self.world_size > 1:
dist.inference_all_reduce(out)
return out


class LinearAllreduce(nn.Module):

def __init__(self, weight, bias=None, mp_group=None):
Expand Down
27 changes: 27 additions & 0 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from deepspeed.accelerator import get_accelerator
from .replace_policy import replace_policies, generic_policies
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
Expand Down Expand Up @@ -340,6 +341,28 @@ def set_lm_head(module):
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
return module

def conv2d_parallel_shard_weights(model, rank, world_size):
# add conv policy
shard_oc_name = ["conv1"]
shard_ic_name = ["conv2"]
for name, sub_m in model.named_children():
for l_name, l_sub_m in sub_m.named_children():
if l_name in shard_oc_name:
TPConv2d = TensorParallelOcShardConv2d(
l_sub_m,
rank,
world_size,
)
setattr(sub_m, l_name, TPConv2d)
if l_name in shard_ic_name:
TPConv2d = TensorParallelIcShardConv2d(
l_sub_m,
rank,
world_size,
)
setattr(sub_m, l_name, TPConv2d)
conv2d_parallel_shard_weights(sub_m, rank, world_size)

if checkpoint_dict is not None and not config.replace_with_kernel_inject:
# AutoTP shard loading
checkpoint = checkpoint_dict["checkpoints"]
Expand All @@ -354,6 +377,10 @@ def set_lm_head(module):
pbar.update(1)
gc.collect()
replaced_module = set_lm_head(replaced_module)
# conv2d tp module replace
# Now is for yuan model. Add model list and conv policy to decide whether to replace conv.
if 'Yuan' in str(replaced_module):
conv2d_parallel_shard_weights(replaced_module, dist.get_rank(), dist.get_world_size())
else:
replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
Expand Down

0 comments on commit e25fb52

Please sign in to comment.