Skip to content

Commit

Permalink
support mode 2 sp in gpt2 (hpcaitech#5)
Browse files Browse the repository at this point in the history
* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* refactor ring implementation

* support mode 2 sp in gpt2
  • Loading branch information
linsj20 authored and KKZ20 committed Feb 14, 2024
1 parent f1d4e18 commit 8bb59a2
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 81 deletions.
242 changes: 171 additions & 71 deletions colossalai/shardformer/layer/_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,47 +169,57 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None, None


def _AllgatherLinear(input_, weight, process_group):
def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
# currently only support one single tensor as output
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)

input_shape = input_.shape
weight_shape = weight.shape

output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
#output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]

# initialization of ring communication
input_shape[1]
recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
recv_tensor = input_.clone()
send_tensor = input_.clone()

recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
handles = dist.batch_isend_irecv([send_op, recv_op])
recv_tensors = {}
send_tensors = {}
for k, v in input_to_gather.items():
recv_tensors[k] = v.clone()
send_tensors[k] = v.clone()

def communicate_step():
comm_ops = []
for k in recv_tensors:
comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
return dist.batch_isend_irecv(comm_ops)

def switch_step():
for k in recv_tensors:
tmp_tensor = send_tensors[k]
send_tensors[k] = recv_tensors[k]
recv_tensors[k] = tmp_tensor

output_tensors = []

handles = communicate_step()
# first round: special case, retrive from local tensor
output_tensors[0] = F.linear(input_, weight)
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
handle.wait()

tmp_tensor = send_tensor
send_tensor = recv_tensor
recv_tensor = tmp_tensor
switch_step()

recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
handles = dist.batch_isend_irecv([recv_op, send_op])
handles = communicate_step()

# actual computation
output_tensors[i + 1] = F.linear(send_tensor, weight)
output_tensors.append(func(**send_tensors, **input_local))

# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
output_tensors[group_size - 1] = F.linear(recv_tensor, weight)
return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1)
output_tensors.append(func(**recv_tensors, **input_local))

return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)


class _GatherForwardReduceScatterBackward(torch.autograd.Function):
Expand Down Expand Up @@ -249,6 +259,41 @@ def backward(ctx, grad_output):
return output, None, None


class _GatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
"""

@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.process_group = process_group
ctx.dim = dim

return _gather(input_, dim, process_group)

@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group

# do reduce-scatter
new_shape = list(grad_output.shape)
assert (
new_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
dist.reduce_scatter(output, grad_list, group=process_group)

return output, None, None


class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Expand All @@ -260,19 +305,35 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap

if bias is not None:
input_parallel = _gather(input_, dim, process_group)
output = F.linear(input_parallel, weight, bias)
if ring is True:
input_to_gather = {}
input_local = {}
input_to_gather['input'] = input_
input_local['weight'] = weight

output = _ring_as_gather(
F.linear,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
)

if bias is not None:
output += bias
else:
output = _AllgatherLinear(input_, weight, process_group)
input_parallel = _gather(input_, dim, process_group)
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)

return output

Expand Down Expand Up @@ -376,34 +437,43 @@ def backward(ctx, grad_output):
# wait until reduce-scatter finished
reducescatter_handle.wait()

return output, grad_weight, grad_bias, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None


def _ReduceScatterLinear(input_, weight, process_group):
def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1):
# currently only support one single tensor as output
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)

input_shape = input_.shape

# initialization of ring communication
# communicate(e.g.): 0->1->2->3
# compute(e.g.): 3->2->1->0
input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1))
input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
input_tensors.reverse()
recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
input_tensors = []
for _ in range(group_size):
input_tensors.append({})
for k, v in input_to_reducescatter.items():
input_shape = v.shape
assert input_shape[reducescatter_dim] % group_size == 0
_input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
for i in range(group_size):
input_tensors[i][k] = _input_tensors[i]
input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
input_tensors.reverse()

# first round: special case, no reduce operation
output_tensor = F.linear(input_tensors[0], weight)
output_tensor = func(**input_tensors[0], **input_local)
recv_tensor = output_tensor.clone()
send_tensor = output_tensor.clone()
recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
handles = dist.batch_isend_irecv([recv_op, send_op])

def communicate_step():
recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
return dist.batch_isend_irecv([recv_op, send_op])

handles = communicate_step()
# first round: special case, retrive from local tensor
for i in range(group_size - 2):
# actual computation
output_tensor = F.linear(input_tensors[i + 1], weight)
output_tensor = func(**input_tensors[i + 1], **input_local)

for handle in handles:
handle.wait()
Expand All @@ -413,12 +483,10 @@ def _ReduceScatterLinear(input_, weight, process_group):
send_tensor = output_tensor
output_tensor = tmp_tensor

recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
handles = dist.batch_isend_irecv([recv_op, send_op])
handles = communicate_step()

# final round: special case, no need to send/recv again
output_tensor = F.linear(input_tensors[group_size - 1], weight)
output_tensor = func(**input_tensors[-1], **input_local)
for handle in handles:
handle.wait()
output_tensor += recv_tensor
Expand All @@ -436,27 +504,44 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, dim):
def forward(ctx, input_, weight, bias, process_group, dim, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.dim = dim
if bias is not None:
partial_output = F.linear(input_, weight, bias)

if ring is True:
input_to_reducescatter = {}
input_local = {}
input_to_reducescatter['input'] = input_
input_local['weight'] = weight

if bias is not None:
input_to_reducescatter['bias'] = bias

output = _ring_as_reducescatter(
F.linear,
input_to_reducescatter=input_to_reducescatter,
input_local=input_local,
process_group=process_group,
)
else:
return _ReduceScatterLinear(input_, weight, process_group)
if bias is not None:
partial_output = F.linear(input_, weight, bias)
else:
partial_output = F.linear(input_, weight)

output_shape = list(partial_output.shape)
assert (
output_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
output_shape = list(partial_output.shape)
assert (
output_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)

output_list = [
item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
dist.reduce_scatter(output, output_list, group=process_group)
output_list = [
item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
dist.reduce_scatter(output, output_list, group=process_group)

return output

Expand Down Expand Up @@ -484,7 +569,7 @@ def backward(ctx, grad_output):
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

return grad_input, grad_weight, grad_bias, None, None
return grad_input, grad_weight, grad_bias, None, None, None


class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
Expand Down Expand Up @@ -533,17 +618,32 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap

input_parallel = _gather(input_, dim, process_group)
if ring is True:
input_to_gather = {}
input_local = {}
input_to_gather['input'] = input_
input_local['other'] = weight

output = _ring_as_gather(
torch.matmul,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
gather_dim=dim
)

else:
input_parallel = _gather(input_, dim, process_group)

output = torch.matmul(input_parallel, weight)
output = torch.matmul(input_parallel, weight)

if bias is not None:
output = output + bias
Expand Down Expand Up @@ -624,7 +724,7 @@ def backward(ctx, grad_output):
# wait until reduce-scatter finished
reducescatter_handle.wait()

return output, grad_weight, grad_bias, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
Expand Down Expand Up @@ -877,10 +977,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre


def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)


Expand All @@ -892,15 +992,15 @@ def reducescatter_forward_gather_backward(input_, process_group, dim):
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim)
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)


def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)


Expand Down
Loading

0 comments on commit 8bb59a2

Please sign in to comment.