in split by rank function, cache the sizes so on backward there is not an extra call
lucidrains committed Sep 10, 2023
1 parent 2e272df commit 666d2fd
Showing 2 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.23',
+  version = '0.0.24',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
16 changes: 12 additions & 4 deletions st_moe_pytorch/distributed.py
@@ -85,16 +85,24 @@ def __init__(self, *, dim = 0):
     def forward(self, x, sizes = None):
         return AllGatherFunction.apply(x, self.dim, sizes)
 
-class SplitByRank(Function):
+class SplitByRankFunction(Function):
     @staticmethod
     def forward(ctx, x):
         rank = dist.get_rank()
-        return x[rank]
+        out = x[rank]
+
+        if isinstance(x, tuple):
+            sizes = tuple(map(lambda t: t.shape[0], x))
+        else:
+            sizes = (x.shape[1],) * x.shape[0]
+
+        ctx.sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
+        return out
 
     @staticmethod
     def backward(ctx, grads):
         grads = rearrange(grads, '... -> 1 ...')
-        grads = all_gather_variable_dim(grads)
+        grads = all_gather_variable_dim(grads, sizes = ctx.sizes)
         return grads
 
-split_by_rank = SplitByRank.apply
+split_by_rank = SplitByRankFunction.apply
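
For orientation (not part of the commit): the change follows the standard torch.autograd.Function idiom of caching values on ctx during forward so backward can reuse them, here the per-rank sizes, presumably sparing all_gather_variable_dim the extra gather it would otherwise need to learn each rank's size. Below is a minimal, self-contained sketch of that caching pattern, using a hypothetical SplitFirstRow function standing in for the distributed code.

import torch
from torch.autograd import Function

class SplitFirstRow(Function):
    @staticmethod
    def forward(ctx, x):
        # cache the original row count on ctx so backward can reuse it
        # (in the commit, the cached values are the per-rank sizes)
        ctx.num_rows = x.shape[0]
        return x[0]

    @staticmethod
    def backward(ctx, grad_out):
        # rebuild a full-sized gradient from the cached size, no extra lookup
        grad = grad_out.new_zeros((ctx.num_rows, *grad_out.shape))
        grad[0] = grad_out
        return grad

x = torch.randn(4, 3, requires_grad = True)
SplitFirstRow.apply(x).sum().backward()
assert x.grad.shape == x.shape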
