MOE: Fix gate conversion to fp32
Currently, during forward, the TopKGate gate linear layer is converted to fp32 in place.
This is forbidden because the linear layer's parameters are a view into DeepSpeed's
flat parameter buffer.
To fix it, use torch.nn.functional.linear with gate.weight.float().

Signed-off-by: Moshe Island <misland@habana.ai>
misland-habana authored and mosheisland committed Feb 21, 2024
1 parent 4ab360a commit 8f9d75c
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions deepspeed/moe/sharded_moe.py
@@ -400,7 +400,7 @@ def __init__(self,
         # Only top-1 and top-2 are supported at the moment.
         if k != 1 and k != 2:
             raise ValueError('Only top-1 and top-2 gatings are supported.')
-        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
+        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
         self.k = k
         self.capacity_factor = capacity_factor
         self.eval_capacity_factor = eval_capacity_factor
@@ -421,13 +421,11 @@ def forward(self,
         if self.wall_clock_breakdown:
             self.timers(TOPK_GATE_TIMER).start()

-        if self.wg.weight.dtype != torch.float32:
-            self.wg = self.wg.float()
         input_fp32 = input.float()
         # input jittering
         if self.noisy_gate_policy == 'Jitter' and self.training:
             input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
-        logits = self.wg(input_fp32)
+        logits = torch.nn.functional.linear(input_fp32, weight=self.wg.weight.float(), bias=None)

         if self.k == 1:
             gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,

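For context, a minimal self-contained sketch of the pattern this commit applies: compute the gate logits with torch.nn.functional.linear on a temporary fp32 copy of the weight, instead of converting the Linear module in place. The TinyGate class, dimensions, and the usage lines below are illustrative assumptions for this sketch, not part of the DeepSpeed code.

import torch
import torch.nn.functional as F


class TinyGate(torch.nn.Module):
    """Illustrative stand-in for an MoE gate; not the DeepSpeed TopKGate."""

    def __init__(self, model_dim: int, num_experts: int):
        super().__init__()
        # Keep the gate weight in the model's native dtype. Do NOT call .float()
        # on the module: under DeepSpeed its parameter storage may be a view
        # into a flat buffer that must not be reallocated.
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # .float() on the weight returns a temporary fp32 copy used only for
        # this matmul; the original parameter tensor is left untouched.
        return F.linear(x.float(), weight=self.wg.weight.float(), bias=None)


# Usage sketch: a half-precision gate (mimicking a low-precision model) still
# produces fp32 logits while its stored weight stays fp16.
gate = TinyGate(model_dim=8, num_experts=4).half()
x = torch.randn(2, 8, dtype=torch.half)
logits = gate(x)
print(logits.dtype, gate.wg.weight.dtype)  # torch.float32 torch.float16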