From 666d2fd41734f25e517b25e980419b80e023d00f Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 10 Sep 2023 07:50:42 -0700
Subject: [PATCH] in split by rank function, cache the sizes so on backwards
 there is not an extra call

---
 setup.py                      |  2 +-
 st_moe_pytorch/distributed.py | 16 ++++++++++++----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/setup.py b/setup.py
index a98ad3a..44cdf40 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.23',
+  version = '0.0.24',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
diff --git a/st_moe_pytorch/distributed.py b/st_moe_pytorch/distributed.py
index 3ccb2cc..3b55d51 100644
--- a/st_moe_pytorch/distributed.py
+++ b/st_moe_pytorch/distributed.py
@@ -85,16 +85,24 @@ def __init__(self, *, dim = 0):
     def forward(self, x, sizes = None):
         return AllGatherFunction.apply(x, self.dim, sizes)
 
-class SplitByRank(Function):
+class SplitByRankFunction(Function):
     @staticmethod
     def forward(ctx, x):
         rank = dist.get_rank()
-        return x[rank]
+        out = x[rank]
+
+        if isinstance(x, tuple):
+            sizes = tuple(map(lambda t: t.shape[0], x))
+        else:
+            sizes = (x.shape[1],) * x.shape[0]
+
+        ctx.sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
+        return out
 
     @staticmethod
     def backward(ctx, grads):
         grads = rearrange(grads, '... -> 1 ...')
-        grads = all_gather_variable_dim(grads)
+        grads = all_gather_variable_dim(grads, sizes = ctx.sizes)
         return grads
 
-split_by_rank = SplitByRank.apply
+split_by_rank = SplitByRankFunction.apply
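
The idea behind the change above is to compute the per-rank sizes once in forward, cache them on the autograd context, and reuse them in backward so the all-gather does not have to rediscover them. Below is a minimal, single-process sketch of that caching pattern, not the library's distributed code itself; the class name SplitFirstRow is purely illustrative, whereas the real SplitByRankFunction splits by dist.get_rank() and regathers with all_gather_variable_dim.

# illustrative sketch only: cache metadata on ctx in forward, reuse it in backward
import torch
from torch.autograd import Function

class SplitFirstRow(Function):
    @staticmethod
    def forward(ctx, x):
        out = x[0]

        # cache the sizes derived from the forward input on the autograd context,
        # mirroring ctx.sizes in the patch
        ctx.sizes = tuple(t.shape[0] for t in x)
        return out

    @staticmethod
    def backward(ctx, grads):
        # backward reuses ctx.sizes instead of recomputing them from scratch
        full = grads.new_zeros((len(ctx.sizes), *grads.shape))
        full[0] = grads
        return full

x = torch.randn(4, 3, requires_grad = True)
SplitFirstRow.apply(x).sum().backward()
print(x.grad[0])   # ones; the remaining rows receive zero gradient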