In [1]:
import torch

In [4]:
# Demo: build an additive causal attention mask, then restrict it to a sliding
# window so each query position only sees its `sliding_window` most recent keys
# (itself included).
sliding_window = 3
input_ids = torch.randint(0, 10, (1, 5))
dtype = torch.float
bsz, tgt_len = 1, 5

# Start fully masked (most negative finite value of `dtype`), then zero out the
# causal part: entry (i, j) becomes 0 whenever j < i + 1, i.e. j <= i.
mask_cond = torch.arange(tgt_len)
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
print(mask)

# `triu(..., diagonal)` keeps entries with j - i >= diagonal, so `1 - triu` is 1
# exactly where j <= i - sliding_window: the keys that fell out of the window.
diagonal = 0 - sliding_window + 1
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
print(context_mask)

# `masked_fill_` is documented to take a boolean mask; convert the 0/1 int
# tensor explicitly instead of relying on an int mask being accepted.
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
print(mask)


tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]])
tensor([[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0]], dtype=torch.int32)
tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
        [-3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
        [-3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00]])


In [None]:
# Placeholder so the `config: MixtralConfig` annotation on the class below
# resolves in this cell; no real config class is bound to this name here.
MixtralConfig = None
import torch.nn as nn
from transformers.activations import ACT2FN
class MixtralBLockSparseTop2MLP(nn.Module):
    """Gated feed-forward network used as a single expert.

    Computes ``w2(act_fn(w1(x)) * w3(x))`` with three bias-free linear layers:
    ``w1`` and ``w3`` map ``hidden_dim -> ffn_dim`` and ``w2`` maps back
    ``ffn_dim -> hidden_dim``.
    """

    def __init__(self, config: MixtralConfig):
        """Build the three projections and look up the activation.

        Args:
            config: object exposing ``intermediate_size``, ``hidden_size`` and
                ``hidden_act`` (a key of ``ACT2FN``).
        """
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        model_width, inner_width = self.hidden_dim, self.ffn_dim
        # Registration order (w1, w2, w3) is kept so parameter ordering and
        # initialisation order stay the same.
        self.w1 = nn.Linear(model_width, inner_width, bias=False)
        self.w2 = nn.Linear(inner_width, model_width, bias=False)
        self.w3 = nn.Linear(model_width, inner_width, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        """Apply the gated MLP to ``hidden_states`` (last dim ``hidden_dim``)."""
        gate = self.act_fn(self.w1(hidden_states))
        up = self.w3(hidden_states)
        # Element-wise gating, then project back to the model width.
        return self.w2(gate * up)

In [None]:
import torch.nn.functional as F
class MixtralSparseMoeBlock(nn.Module):
    """
    Sparse mixture-of-experts feed-forward block.

    A linear router scores every token against every expert; each token is
    processed by its ``top_k`` highest-scoring experts and the expert outputs
    are summed, weighted by the renormalized router probabilities.

    This implementation is
    strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations
    in terms of block-sparse operations to accommodate imbalanced
    assignments of tokens to experts, whereas standard MoE either
    (1) drop tokens at the cost of reduced performance or (2) set
    capacity factor to number of experts and thus waste computation
    and memory on padding.
    """

    def __init__(self, config):
        """Create the router and the list of experts.

        Args:
            config: object exposing ``hidden_size``, ``intermediate_size``,
                ``num_local_experts`` and ``num_experts_per_tok``; it is also
                passed unchanged to every expert's constructor.
        """
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Route each token to its top-k experts and mix their outputs.

        Args:
            hidden_states: tensor of shape (batch_size, sequence_length, hidden_dim).

        Returns:
            A pair ``(final_hidden_states, router_logits)`` where
            ``final_hidden_states`` has shape (batch_size, sequence_length, hidden_dim)
            and ``router_logits`` has shape (batch_size * sequence_length, n_experts).
        """
        # Flatten so that every row is the MLP input feature of one token.
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        # One score per expert for every token; the top-k experts are picked below.
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        # torch.topk is differentiable with respect to the selected values.
        # Both outputs have shape (batch * sequence_length, top_k).
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalize the kept probabilities so they sum to 1 per token; they are
        # used to weight the expert outputs below.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be solicited
        # (bs * sl, topk, n_experts) -> (n_experts, topk, bs * sl);
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            # expert_mask[expert_idx] has shape (topk, bs * sl) (see the permute above).
            # torch.where returns the row indices first and the column indices second:
            #   - top_x (columns) are the tokens routed to this expert;
            #   - idx (rows) is, for each of those tokens, the rank this expert
            #     holds among the token's top-k choices.
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                # No token was routed to this expert.
                continue

            # Index directly with the tensors returned by torch.where; this gives
            # the same rows as indexing with Python lists but avoids converting
            # the indices to lists on every expert.
            # hidden_states is (bs * sl, hidden_dim), so top_x selects the token rows.
            current_state = hidden_states[top_x]
            # routing_weights is (bs * sl, topk): top_x picks the token and idx the
            # top-k slot, giving this expert's weight for each routed token. The
            # trailing None adds a broadcast dimension over hidden_dim.
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # Accumulate into the output rows of the routed tokens; a token's
            # contributions from its different experts add up.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

In [6]:
import torch
# Demo of the expert mask: 5 tokens, each assigned 2 experts out of 8.
selected_experts = torch.randint(0, 8, (5, 2))
print(selected_experts)

# One-hot over the expert axis: (tokens, topk, experts).
one_hot_experts = torch.nn.functional.one_hot(selected_experts, num_classes=8)

# Reorder to (experts, topk, tokens) so indexing by expert gives a
# (topk, tokens) grid of the tokens routed to that expert.
expert_mask = one_hot_experts.permute(2, 1, 0)
print(expert_mask)
print(expert_mask.shape)
print(one_hot_experts)

# For expert 1: (top-k slot indices, token indices); both are empty when no
# token selected that expert.
print(torch.where(expert_mask[1]))

tensor([[5, 2],
        [5, 0],
        [5, 4],
        [6, 4],
        [6, 5]])
tensor([[[0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [0, 0, 1, 1, 0]],

        [[1, 1, 1, 0, 0],
         [0, 0, 0, 0, 1]],

        [[0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]])
torch.Size([8, 2, 5])
tensor([[[0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 1, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 1, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 1, 0, 0]]])
(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
