Commit 7b353f4

Merge branch 'feature/shardformer' into pp_tp_zero1
flybird11111 committed Aug 30, 2023
2 parents 74be5f0 + d367b88 commit 7b353f4
Showing 7 changed files with 134 additions and 120 deletions.
87 changes: 61 additions & 26 deletions colossalai/shardformer/layer/_operation.py
@@ -291,12 +291,13 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap

input_parallel = _gather(input_, dim, process_group)

@@ -312,37 +313,70 @@ def backward(ctx, grad_output):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap

# TODO: overlap SP input with gradient computation
input_parallel = _gather(input_, dim, process_group)
if not overlap:
input_parallel = _gather(input_, dim, process_group)

total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_reduce_scatter:
handle.wait()

# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
else:
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]

# do all-gather in an async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()

if ctx.async_grad_reduce_scatter:
handle.wait()
# do reduce-scatter in an async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()

return output, grad_weight, grad_bias, None, None, None
return output, grad_weight, grad_bias, None, None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -510,9 +544,10 @@ def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim)
async_grad_reduce_scatter, dim, overlap)


def gather_forward_split_backward(input_, dim, process_group):
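
For orientation, below is a minimal sketch (not the committed code) of the overlap that the new `overlap=True` backward branch performs: the all-gather of the sequence-sharded input is hidden behind the grad_input/grad_bias computation, and the reduce-scatter of grad_input is hidden behind the grad_weight matmul. It assumes an already-initialized process group whose backend supports reduce_scatter (e.g. NCCL), a Conv1D-style weight of shape (in_features, out_features), and a sequence dimension divisible by the world size.

import torch
import torch.distributed as dist


def overlapped_linear_backward(input_, weight, grad_output, process_group, dim=1):
    """Sketch of the overlapped backward: hide the input all-gather behind
    grad_input/grad_bias, and the grad_input reduce-scatter behind grad_weight."""
    world_size = dist.get_world_size(process_group)

    # Launch the all-gather of the local (sequence-sharded) input asynchronously.
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)

    # Gradients that do not need the gathered input are computed while it is in flight.
    grad_input = grad_output.matmul(weight.T)
    grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
    grad_bias = grad_output_2d.sum(dim=0)

    # Prepare the reduce-scatter buffers, then wait for the gathered input.
    input_list = [c.contiguous() for c in torch.chunk(grad_input, world_size, dim=dim)]
    output = torch.empty_like(input_)
    gather_handle.wait()

    # Reduce-scatter grad_input asynchronously while grad_weight is computed
    # from the fully gathered input.
    rs_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
    input_parallel = torch.cat(tensor_list, dim=dim).reshape(-1, weight.shape[0])
    grad_weight = input_parallel.t().matmul(grad_output_2d)
    rs_handle.wait()

    return output, grad_weight, grad_bias

The extra cost over the non-overlap path is the temporary tensor_list holding the gathered shards; in exchange, both collectives run concurrently with matmuls.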
4 changes: 3 additions & 1 deletion colossalai/shardformer/layer/qkv_fused_linear.py
@@ -177,6 +177,7 @@ def __init__(self,
async_communication: bool = False,
gather_output: bool = False,
seq_parallel: bool = False,
overlap: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
@@ -190,6 +191,7 @@ def __init__(self,
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
@@ -308,7 +310,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
if self.seq_parallel:
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True, 1)
self.process_group, True, 1, self.overlap)
else:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
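
A usage sketch only: everything beyond the kwargs visible in this commit (the from_native_module helper, the Conv1D import path, shapes) is an assumption, not part of the diff.

import torch
from transformers.pytorch_utils import Conv1D

from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col

# Assumptions: torch.distributed is already initialized; process_group=None is
# taken to mean the default (world) group, as in the unit test later in this diff.
native_qkv = Conv1D(3 * 768, 768).cuda()          # GPT-2 style fused QKV projection
qkv_col = GPT2FusedLinearConv1D_Col.from_native_module(
    native_qkv,
    process_group=None,
    gather_output=False,
    n_fused=3,             # query / key / value fused in one weight
    seq_parallel=True,     # gather-forward / reduce-scatter-backward path
    overlap=True,          # the new flag: overlap the backward all-gather
)

# With seq_parallel=True the layer expects the local sequence shard and
# all-gathers it along dim=1 before the matmul.
hidden_shard = torch.randn(2, 128, 768, device="cuda")
out = qkv_col(hidden_shard)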
19 changes: 0 additions & 19 deletions colossalai/shardformer/policies/base_policy.py
@@ -226,22 +226,3 @@ def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
end_idx = num_layers_per_stage_accumulated[stage + 1]

return [start_idx, end_idx]

def append_seq_parallel_to_policy(
self,
suffix_list: List[str],
module_policy_description: ModulePolicyDescription,
):
r"""
Append the sequence parallel policy to the policy for the given key.
Args:
suffix_list (List[str]): the suffix list of the module to be parallelized
policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
"""

for sub_description in module_policy_description.sub_module_replacement:
if (sub_description.suffix in suffix_list):
if sub_description.kwargs is None:
sub_description.kwargs = {}
sub_description.kwargs["seq_parallel"] = True
94 changes: 50 additions & 44 deletions colossalai/shardformer/policies/gpt2.py
@@ -37,7 +37,8 @@ def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model

policy = {}

use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
@@ -50,47 +51,54 @@ def module_policy(self):
),
])

policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
policy[GPT2Block] = ModulePolicyDescription(
attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel": use_sequence_parallel,
}),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel": use_sequence_parallel,
}),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])

# optimization configuration
if self.shard_config.enable_fused_normalization:
@@ -126,8 +134,6 @@ def module_policy(self):

if self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])

return policy
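
The sequence-parallel kwargs are now injected directly in the policy instead of being patched in afterwards by append_seq_parallel_to_policy. A configuration sketch, assuming the ShardConfig/ShardFormer entry points; only the two enable_sequence_* field names are taken from this diff, the rest is illustrative.

import torch.distributed as dist
from transformers import GPT2LMHeadModel

from colossalai.shardformer import ShardConfig, ShardFormer

# Assumption: torch.distributed is already initialized; use the world group
# as the tensor-parallel group for this sketch.
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,   # drives `use_sequence_parallel` above
    enable_sequence_overlap=True,       # drives the new `overlap` kwarg
)

model = GPT2LMHeadModel.from_pretrained("gpt2")
sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(model)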

26 changes: 13 additions & 13 deletions colossalai/shardformer/policies/opt.py
@@ -103,21 +103,21 @@ def module_policy(self):
target_key=OPTDecoderLayer)

# use flash attention
# if self.shard_config.enable_flash_attention:
# self.append_or_create_method_replacement(description={
# 'forward': get_opt_flash_attention_forward(),
# },
# policy=policy,
# target_key=OPTAttention)
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(description={
'forward': get_opt_flash_attention_forward(),
},
policy=policy,
target_key=OPTAttention)

# use jit fused operator
# if self.shard_config.enable_jit_fused:
# self.append_or_create_method_replacement(description={
# 'forward': get_jit_fused_opt_decoder_layer_forward(),
# 'dropout_add': get_jit_fused_dropout_add_func(),
# },
# policy=policy,
# target_key=OPTDecoderLayer)
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_opt_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=OPTDecoderLayer)

return policy
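
With these blocks un-commented, the OPT policy again registers method_replacement entries when enable_flash_attention / enable_jit_fused are set. A toy illustration (not the ColossalAI sharder) of what such an entry conceptually asks for:

import types

import torch.nn as nn


def apply_method_replacement(module: nn.Module, replacement: dict) -> None:
    # Bind each replacement function as a method of the target module, e.g.
    # {'forward': get_opt_flash_attention_forward()} applied to every
    # OPTAttention instance found in the model.
    for name, new_fn in replacement.items():
        setattr(module, name, types.MethodType(new_fn, module))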

@@ -53,7 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
return rearanged_tensor


def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
@@ -62,7 +62,8 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
process_group=None,
gather_output=True,
seq_parallel=seq_parallel,
n_fused=3)
n_fused=3,
overlap=overlap)

assert linear.weight.shape == torch.Size([48, 192])
assert linear.bias.shape == torch.Size([192])
@@ -129,8 +130,9 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):

@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel)
@parameterize('overlap', [True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
check_linear_conv_1d_row(lazy_init, seq_parallel)
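
The new overlap axis only adds the value True, so (assuming the stacked @parameterize decorators take the cartesian product of their argument lists) the column check now runs over these combinations:

from itertools import product

# Assumed coverage: 2 (lazy_init) x 2 (seq_parallel) x 1 (overlap) = 4 runs per worker.
for lazy_init, seq_parallel, overlap in product([False, True], [False, True], [True]):
    print(f"check_linear_conv_1d_col(lazy_init={lazy_init}, seq_parallel={seq_parallel}, overlap={overlap})")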


14 changes: 1 addition & 13 deletions tests/test_shardformer/test_model/test_shard_whisper.py
@@ -41,19 +41,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group

# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3

if org_model.__class__.__name__ == 'WhisperModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)


# unwrap the model
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
whisper = org_model.model