# custom_transformer.py
r"""
Adaption to act as the MLP layer using an MoE MLP layer in transformer.
"""
import torch
import torch.nn as nn

from custom_layers import FMoE
from linear import FMoELinear


class _Expert(nn.Module):
r"""
An expert using 2 FMoELinear modules to speed up the computation of experts
within one worker.
"""
def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
super().__init__()
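        # Two expert-batched linear layers: project d_model -> d_hidden and
        # back, for all experts hosted on this worker.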
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r"""
        First expand the input to 4h (the hidden size is configurable, but is
        called h4 for convenience), then apply the activation, and finally
        shrink back to h.
        """
x = self.htoh4(inp, fwd_expert_count)
x = self.activation(x)
x = self.h4toh(x, fwd_expert_count)
return x


class FMoETransformerMLP(FMoE):
    r"""
    A complete MoE MLP module in a Transformer block.

    * `activation` is the activation function used in the MLP of each expert.
    * `d_hidden` is the hidden dimension of the MLP in each expert.

    Any additional keyword arguments are forwarded to the underlying `FMoE`
    layer.
    """
def __init__(
self,
num_expert=32,
d_model=1024,
d_hidden=4096,
activation=torch.nn.GELU(),
expert_dp_comm="none",
        expert_rank=0,
        hyper_size=None,
**kwargs
):
        super().__init__(
            num_expert=num_expert, d_model=d_model, hyper_size=hyper_size, **kwargs
        )
self.experts = _Expert(
num_expert, d_model, d_hidden, activation, rank=expert_rank
)
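        # Mark expert parameters with the given communication group so a
        # distributed data-parallel wrapper knows how to synchronize them.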
self.mark_parallel_comm(expert_dp_comm)

    def forward(self, inp: torch.Tensor):
        r"""
        Wrap the FMoE forward pass with the reshaping needed to process
        inputs with arbitrary leading dimensions.
        """
original_shape = inp.shape
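        # Flatten all leading dimensions into one token dimension so FMoE
        # sees a 2-D (num_tokens, d_model) batch.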
inp = inp.reshape(-1, self.d_model)
output = super().forward(inp)
return output.reshape(original_shape)
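

if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming the default gate of the local
    # `custom_layers.FMoE` runs in a single, non-distributed process; real
    # deployments typically require distributed initialization first.
    mlp = FMoETransformerMLP(num_expert=4, d_model=128, d_hidden=512)
    x = torch.randn(2, 16, 128)  # (batch, sequence length, d_model)
    y = mlp(x)
    assert y.shape == x.shape
    print("output shape:", tuple(y.shape))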