sequence parallel: inside text split (hpcaitech#6)
linsj20 committed Mar 20, 2024
1 parent ad9e332 commit 4b7b83a
Showing 9 changed files with 31 additions and 57 deletions.
6 changes: 1 addition & 5 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -660,10 +660,6 @@ def __init__(
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
self.use_all_to_all_sequence_parallel = (
self.model.shard_config.enable_sequence_parallelism
and self.model.shard_config.sequence_parallelism_mode == "all_to_all"
)
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(
@@ -684,7 +680,6 @@ def __init__(
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
enable_sequence_parallel=self.use_all_to_all_sequence_parallel,
)

def sync_dp_grads(self):
@@ -1098,6 +1093,7 @@ def __init__(
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
zero_stage=zero_stage,
)
self.amp_config = dict(
initial_scale=initial_scale,
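Note: with this change the plugin no longer tracks an all-to-all flag for the optimizer; it simply forwards zero_stage into the ShardConfig. A minimal usage sketch, under the assumption that the plugin on this branch accepts the keyword arguments shown below (the constructor itself is not part of this diff):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# colossalai.launch_from_torch(...) must have been called first; the plugin builds process groups.
# Hypothetical configuration: "all_to_all" sequence parallelism together with ZeRO stage 1.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)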
12 changes: 6 additions & 6 deletions colossalai/shardformer/layer/_operation.py
@@ -756,7 +756,7 @@ def backward(ctx, grad_output):
grad_output = grad_output * dist.get_world_size(ctx.process_group)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.process_group)
return _gather(grad_output, ctx.dim, ctx.process_group), None, None
return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None


class _ReduceForward(torch.autograd.Function):
@@ -819,7 +819,7 @@ def backward(ctx, grad_output):
grad_output = grad_output * dist.get_world_size(ctx.process_group)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.process_group)
return _split(grad_output, ctx.dim, ctx.process_group), None, None
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None


class _AllToAll(torch.autograd.Function):
@@ -1020,12 +1020,12 @@ def matmul_gather_forward_reducescatter_backward(
)


def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)


def split_forward_gather_backward(input_, dim, process_group):
return _SplitForwardGatherBackward.apply(input_, dim, process_group)
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)


def reduce_forward(input_, process_group):
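Note: both wrappers now take an optional grad_scale that is passed straight through to the autograd functions, which is also why their backward methods return one extra None (four inputs, four gradients). A sketch of how a caller might use the new argument; shard_sequence/unshard_sequence, hidden_states, and sp_group are placeholders, not part of this commit:

from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)

def shard_sequence(hidden_states, sp_group):
    # hidden_states: [batch, seq_len, hidden]; sp_group: the sequence-parallel process group.
    # Forward: keep only this rank's chunk of dim 1.
    # Backward: scale the incoming gradient down by the group size, then all-gather it.
    return split_forward_gather_backward(hidden_states, 1, sp_group, grad_scale="down")

def unshard_sequence(local_states, sp_group):
    # Forward: all-gather the sequence chunks back into the full sequence.
    # Backward: scale the incoming gradient up by the group size, then split it.
    return gather_forward_split_backward(local_states, 1, sp_group, grad_scale="up")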
19 changes: 6 additions & 13 deletions colossalai/shardformer/modeling/gpt2.py
@@ -853,8 +853,6 @@ def forward(

# use variable seq_len to replace input_shape[-1]
seq_len = input_shape[-1]
if sp_mode in ["ring", "all_to_all"]:
seq_len *= sp_size

if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_len)
@@ -866,8 +864,6 @@ def forward(
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if sp_mode in ["ring", "all_to_all"]:
past_length *= sp_size
if position_ids is None:
position_ids = torch.arange(past_length, seq_len + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
@@ -876,9 +872,6 @@ def forward(
if sp_mode in ["ring", "all_to_all"]:
position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)]

if sp_mode in ["ring", "all_to_all"]:
attention_mask = _gather(attention_mask, 1, sp_group)

# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
@@ -917,12 +910,12 @@ def forward(
head_mask = self.get_head_mask(head_mask, self.config.n_layer)

if inputs_embeds is None:
if sp_mode in ["ring"]:
input_ids = _gather(input_ids, 1, sp_group)
inputs_embeds = self.wte(input_ids)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
else:
inputs_embeds = self.wte(input_ids)
inputs_embeds = self.wte(input_ids)
if sp_mode == "ring":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down')

position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds

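Note: the GPT-2 embedding is now always computed on the full input_ids, and the sequence split happens afterwards inside the model ("inside text split"), with 'down' gradient scaling in the all_to_all case. A toy single-process illustration of what the split does, with made-up sizes:

import torch

sp_size, rank = 4, 1                      # hypothetical sequence-parallel size and rank
inputs_embeds = torch.randn(2, 32, 768)   # [batch, full_seq_len, hidden], i.e. self.wte(input_ids)

# split_forward_gather_backward keeps chunk `rank` along dim 1 in forward and
# all-gathers the gradient in backward; grad_scale='down' also divides that
# gradient by sp_size.
local_embeds = torch.chunk(inputs_embeds, sp_size, dim=1)[rank]
assert local_embeds.shape == (2, 8, 768)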
26 changes: 11 additions & 15 deletions colossalai/shardformer/modeling/llama.py
@@ -694,7 +694,7 @@ def forward(
return forward


def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, zero_stage=0):
logger = logging.get_logger(__name__)

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -804,10 +804,6 @@ def forward(
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

# sp: modify seq_length when using sequence parallel
if sp_mode in ["ring", "all_to_all"]:
seq_length *= sp_size

seq_length_with_past = seq_length
past_key_values_length = 0

@@ -827,13 +823,12 @@ def forward(
position_ids = position_ids.view(-1, seq_length).long()

if inputs_embeds is None:
if sp_mode == "ring":
input_ids = _gather(input_ids, 1, sp_group)
inputs_embeds = self.embed_tokens(input_ids)
input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)]
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
else:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids)

if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down')

# TODO use_distributed_mask
use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] else False
@@ -864,8 +859,6 @@ def forward(
attention_mask = _gather(attention_mask, 1, sp_group)

hidden_states = inputs_embeds
if sp_mode == "split_gather":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)

if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
if use_cache:
@@ -922,7 +915,10 @@ def custom_forward(*inputs):
hidden_states = self.norm(hidden_states)

# Todo: Maybe this line can be optimized
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up")
if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0):
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all" and zero_stage in [1, 2]:
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up")

# add hidden states from the last decoder layer
if output_hidden_states:
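Note: the final gather over the sequence dimension now depends on both sp_mode and zero_stage. The extra grad_scale="up" under ZeRO stage 1/2 presumably balances the unconditional flat_grads /= world_size in the low-level optimizer (changed below), assuming the sequence-parallel group and the ZeRO group have the same size. A toy numeric check under that assumption:

world_size = 4
g = 1.0                  # gradient contribution produced on one rank
g = g * world_size       # grad_scale="up" in the gather's backward
g = g / world_size       # flat_grads /= self._world_size in the ZeRO optimizer
assert g == 1.0          # net effect: gradient magnitude preserved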
4 changes: 3 additions & 1 deletion colossalai/shardformer/policies/gpt2.py
@@ -109,7 +109,9 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel_mode": sp_mode,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/llama.py
@@ -121,7 +121,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group),
"forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, self.shard_config.zero_stage),
},
policy=policy,
target_key=LlamaModel,
1 change: 1 addition & 0 deletions colossalai/shardformer/shard/shard_config.py
@@ -37,6 +37,7 @@ class ShardConfig:
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
sequence_parallelism_mode: str = None
zero_stage: int = 0
enable_sequence_overlap: bool = False
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# pipeline_parallel_size: int
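Note: ShardConfig now carries the ZeRO stage so that model-specific forwards (such as the LLaMA one above) can choose the right gradient scaling. A stripped-down toy of the config plumbing, not the real class (the real ShardConfig also holds process groups and runs validation):

from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class ToyShardConfig:
    # Only the fields relevant to this commit are reproduced here.
    enable_sequence_parallelism: bool = False
    sequence_parallelism_mode: Optional[str] = None
    zero_stage: int = 0          # new field, default 0
    enable_sequence_overlap: bool = False
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)

cfg = ToyShardConfig(enable_sequence_parallelism=True,
                     sequence_parallelism_mode="all_to_all",
                     zero_stage=1)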
5 changes: 1 addition & 4 deletions colossalai/zero/low_level/low_level_optim.py
@@ -77,10 +77,8 @@ def __init__(
forced_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
master_weights: bool = True, # master weights
enable_sequence_parallel: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._enable_sequence_parallel = enable_sequence_parallel

self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._logger = get_dist_logger()
@@ -300,8 +298,7 @@ def _run_reduction(self):

if self.moe_extra_dp_pg is None:
flat_grads = self._bucket_store.get_flatten_grad()
if not self._enable_sequence_parallel:
flat_grads /= self._world_size
flat_grads /= self._world_size
else:
# record moe and non moe param
moe_list = []
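Note: with enable_sequence_parallel removed, the reduction path is uniform again: gradient buckets are always averaged over the world size, and any sequence-parallel correction is handled by the grad_scale arguments inside the model. A simplified stand-in for the reduction step (not the library code, which also covers reduce-scatter and MoE buckets):

import torch
import torch.distributed as dist

def run_reduction_sketch(flat_grads: torch.Tensor, dp_group) -> torch.Tensor:
    # Always average over the data-parallel world size, then sum across ranks.
    flat_grads = flat_grads / dist.get_world_size(dp_group)
    dist.all_reduce(flat_grads, group=dp_group)
    return flat_grads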
13 changes: 1 addition & 12 deletions tests/test_shardformer/test_model/_utils.py
@@ -159,18 +159,7 @@ def _criterion(outputs, inputs):

shard_test_data = {}
for k, v in data.items():
if k not in ["input_ids", "attention_mask"]:
shard_test_data[k] = data[k].clone()
else:
# todo: check the correctness of using dim=-1: to be compatible with date_gen_for_double_heads()
shard_test_data[k] = (
torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=-1)[
dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group)
]
if booster.plugin.shard_config.enable_sequence_parallelism
and booster.plugin.shard_config.sequence_parallelism_mode in ["ring", "all_to_all"]
else data[k].clone()
)
shard_test_data[k] = data[k].clone()
unshard_test_data = {}
for k, v in data.items():
unshard_test_data[k] = data[k].clone()
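Note: since the split now happens inside the model forward, the sharded and unsharded runs receive identical, unsplit inputs; the two copy loops above are equivalent to:

shard_test_data = {k: v.clone() for k, v in data.items()}
unshard_test_data = {k: v.clone() for k, v in data.items()}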
