In [1]:
import torch
import torch.nn.functional as F

In [2]:
batch_size = 3
seqlen = 4
hidden_dim = 5
expert_num = 2 # 0 means skip
h = torch.randn(batch_size,seqlen,hidden_dim)
h

tensor([[[ 1.1009,  0.6192, -0.1429, -0.4305,  0.5437],
         [ 0.7905, -0.0539, -0.4622, -0.8850,  0.8118],
         [-0.1848,  0.5167, -0.1171,  0.9553,  0.3171],
         [-0.2319,  1.2983,  1.7737, -0.8799,  1.6616]],

        [[-0.8413,  0.8481, -1.2230,  0.3970,  0.0372],
         [ 0.5778, -0.4693,  1.5635,  1.6505, -0.0367],
         [-0.8245,  0.9551, -1.2522, -1.2334,  1.5235],
         [ 0.1954, -1.2951, -1.2750, -0.1464, -1.1960]],

        [[ 0.2232,  0.1394, -0.6916,  0.5031, -0.1735],
         [-0.4392, -0.5051, -2.0661,  0.7960, -1.8518],
         [ 0.8623,  0.6246, -0.0552, -1.2387,  0.1992],
         [ 0.5482,  0.1764, -0.2168,  0.8992,  0.0385]]])

In [7]:
FFN = torch.nn.Linear(hidden_dim,hidden_dim)

adapter_router = torch.nn.Linear(hidden_dim, expert_num)


In [10]:

flat_h = h.view(-1, h.size(-1))
logits = adapter_router(flat_h)
softmax_logits = F.softmax(logits, dim=-1)
top_k_logits, selected_experts = softmax_logits.topk(1, dim=-1)
# weighted_top_k_logits = top_k_logits / torch.sum(top_k_logits, dim=-1, keepdim=True, dtype=x.dtype)

for i in range(expert_num):
    batch_idx, nth_expert = torch.where(selected_experts == i)
    if len(batch_idx)>0:
        if i == 0: # skip FFN
            pass
        else: # compute FFN i times
            selected_tokens = flat_h[batch_idx]
            for k in range(i):
                selected_tokens = FFN(selected_tokens)
            flat_h[batch_idx] = selected_tokens
out = flat_h.view((*h.shape[:-1], flat_h.shape[-1]))

In [11]:
out

tensor([[[-0.6747,  0.7058, -0.7200, -1.3090,  0.3211],
         [ 1.6507,  1.1934, -0.8524, -0.2297, -0.1410],
         [ 0.0122, -1.6738,  0.4529, -0.8769,  0.6959],
         [-0.6717,  0.0429, -0.9673, -0.2408, -0.4398]],

        [[-0.9085,  1.0333, -1.2946, -1.2697,  1.0978],
         [ 2.3664,  1.1262,  0.2712,  0.4002, -0.3340],
         [-0.7131,  0.4895,  1.7313, -1.9211,  0.2163],
         [-0.8327, -0.1293, -0.4936, -0.0864, -0.8683]],

        [[-0.2964, -1.6552, -0.1826, -0.4211, -0.5058],
         [ 0.4085,  0.3638,  1.8576, -1.1710, -0.6801],
         [ 2.1512,  1.7554, -1.4695, -0.2422, -1.0826],
         [-0.0448,  0.7800, -0.8036,  0.2476,  0.7615]]],
       grad_fn=<ViewBackward0>)

In [17]:
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        batch_size = 3
        seqlen = 4
        hidden_dim = 5
        expert_num = 2 # 0 means skip
        self.ffn = torch.nn.Linear(hidden_dim,hidden_dim)
        self.adapter_router = torch.nn.Linear(hidden_dim, expert_num)    

    def forward(self, h):
        flat_h = h.view(-1, h.size(-1))
        logits = self.adapter_router(flat_h)
        softmax_logits = F.softmax(logits, dim=-1)
        top_k_logits, selected_experts = softmax_logits.topk(1, dim=-1)
        # weighted_top_k_logits = top_k_logits / torch.sum(top_k_logits, dim=-1, keepdim=True, dtype=x.dtype)

        for i in range(expert_num):
            batch_idx, nth_expert = torch.where(selected_experts == i)
            print('expert:',i)
            print('selected tokens',batch_idx)
            if len(batch_idx) > 0:
                if i == 0: # skip FFN
                    pass
                else: # compute FFN i times
                    selected_tokens = flat_h[batch_idx]
                    for k in range(i):
                        selected_tokens = self.ffn(selected_tokens)
                    flat_h[batch_idx] = selected_tokens
        out = flat_h.view((*h.shape[:-1], flat_h.shape[-1]))
        return out

In [18]:
model = MyModel()

In [19]:
y = model(h)

expert: 0
selected tokens tensor([ 1,  2,  3,  4,  5,  6,  7, 10])
expert: 1
selected tokens tensor([ 0,  8,  9, 11])


In [16]:
model.adapter_router.weight.backward

<bound method Tensor.backward of Parameter containing:
tensor([[-0.0718, -0.2351,  0.0207,  0.0517,  0.0917],
        [ 0.1288, -0.2517, -0.3103, -0.1274,  0.0518]], requires_grad=True)>

In [15]:
model.ffn.weight.backward

<bound method Tensor.backward of Parameter containing:
tensor([[ 0.2995, -0.1930, -0.1372, -0.1588,  0.4157],
        [-0.3692, -0.1691,  0.0688,  0.1921,  0.2069],
        [-0.1188, -0.0804,  0.1148, -0.1464,  0.3371],
        [-0.0081,  0.1726,  0.2778, -0.4046,  0.2030],
        [ 0.0882, -0.4321,  0.0098, -0.2131, -0.3184]], requires_grad=True)>

In [14]:
y.backward

<bound method Tensor.backward of tensor([[[ 1.1009,  0.6192, -0.1429, -0.4305,  0.5437],
         [ 0.7905, -0.0539, -0.4622, -0.8850,  0.8118],
         [-0.1848,  0.5167, -0.1171,  0.9553,  0.3171],
         [-0.2319,  1.2983,  1.7737, -0.8799,  1.6616]],

        [[-0.8413,  0.8481, -1.2230,  0.3970,  0.0372],
         [ 0.5778, -0.4693,  1.5635,  1.6505, -0.0367],
         [-0.8245,  0.9551, -1.2522, -1.2334,  1.5235],
         [ 0.3503,  0.1098, -0.0554, -1.0284,  0.5551]],

        [[ 0.2232,  0.1394, -0.6916,  0.5031, -0.1735],
         [-0.4392, -0.5051, -2.0661,  0.7960, -1.8518],
         [ 0.8623,  0.6246, -0.0552, -1.2387,  0.1992],
         [ 0.5482,  0.1764, -0.2168,  0.8992,  0.0385]]],
       grad_fn=<ViewBackward0>)>

In [6]:
from torchviz import make_dot

In [7]:
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render("model_graph", format='png', view=True)

ExecutableNotFound: failed to execute PosixPath('dot'), make sure the Graphviz executables are on your systems' PATH