<a href="https://colab.research.google.com/github/AviSoori1x/makeMoE/blob/main/makeMoE_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### 从头开始的稀疏专家混合语言模型，灵感来自（并在很大程度上基于）Andrej Karpathy的makemore（[https://github.com/karpathy/makemore）](https://github.com/karpathy/makemore%EF%BC%89) :)

这是一个从头开始的稀疏专家混合语言模型的实现。这个项目的灵感来自于Andrej Karpathy的项目'makemore'，并且从该实现中借用了大部分可重用的组件。和makemore一样，makeMoE也是一个自回归的字符级语言模型，但是使用了前面提到的稀疏专家混合架构。

和makemore一样，pytorch是唯一的要求（所以希望这个从头开始的声明是合理的）。

与makemore架构相比的重要变化

- 使用稀疏专家混合代替孤立的前馈神经网络。
- 实施Top-k门控和嘈杂的Top-k门控。
- 初始化 - 这里使用了Kaiming He初始化，但这个笔记本的重点是可操作性，所以您可以替换为Xavier Glorot等，并进行尝试。

与makemore相比未改变的部分

- 数据集、预处理（标记化）以及Andrej最初选择的语言建模任务-生成类似莎士比亚文本
- 因果自注意力实施
- 训练循环
- 推理逻辑

在此实现中引用的出版物：

- 专家混合：https://arxiv.org/pdf/2401.04088.pdf
- 极大规模神经网络：稀疏门控专家混合层：https://arxiv.org/pdf/1701.06538.pdf

本笔记本详细介绍了整个模型架构的直观理解以及所有内容的整合方式。

代码是在Databricks上完全开发的，使用了一台A100进行计算。如果您在Databricks上运行此代码，可以在您选择的云提供商中轻松扩展到任意大的GPU集群中，没有任何问题。

我选择使用mlflow（在Databricks上预安装。您可以在其他地方轻松安装pip）因为我发现它有助于跟踪和记录所有必要的指标。这完全是可选的。

请注意，该实现强调可读性和可操作性而不是性能，因此有许多方法可以改进。请尝试并告诉我。

![mixture of experts overview](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/moe.png)

In [None]:
#Using mlflow is entirely optional. I personally like to use MLFlow to track and log everything. If you're using Databricks, it comes pre-installed.
%pip install mlflow

In [1]:
#Import the necessary packages and set seed for reproducibility. For this notebook, pytorch is all you need
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)
#Optional
import mlflow

Next few sections, downloading the data, preprocessing it and self attention are directly from makemore. I have elaborated a little on self attention and added visual aids to understand the process a bit better.

In [2]:
# Downloading the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

--2024-01-29 16:50:31--  https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 0.0.0.0, ::
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|0.0.0.0|:443... failed: Connection refused.
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|::|:443... failed: Connection refused.


In [3]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [5]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



Tokenization

In [6]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [7]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [8]:
# let's now encode the entire text dataset and store it into a torch.Tensor
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [9]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [10]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [11]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [12]:
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

In [13]:
ix = torch.randint(len(data) - block_size, (batch_size,))
ix

tensor([250930, 237205, 974116, 383898])

In [14]:
x = torch.stack([data[i:i+block_size] for i in ix])
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
x

tensor([[42,  1, 58, 46, 59, 57,  1, 21],
        [54, 56, 47, 43, 57, 58, 11,  0],
        [49, 47, 52, 45, 12,  1, 58, 46],
        [58, 46, 53, 59, 58,  1, 56, 43]])

In [15]:
y

tensor([[ 1, 58, 46, 59, 57,  1, 21,  1],
        [56, 47, 43, 57, 58, 11,  0, 37],
        [47, 52, 45, 12,  1, 58, 46, 53],
        [46, 53, 59, 58,  1, 56, 43, 42]])

下面的代码块清楚地展示了预测的自回归性质，以及上下文是在一个一维的令牌排列（在这种情况下是字符）上滚动的窗口。

In [16]:
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[ 6,  0, 14, 43, 44, 53, 56, 43],
        [39,  1, 42, 59, 43,  1, 39, 52],
        [47, 41, 43,  1, 39, 52, 42,  1],
        [53, 44,  1, 50, 43, 58,  1, 58]])
targets:
torch.Size([4, 8])
tensor([[ 0, 14, 43, 44, 53, 56, 43,  1],
        [ 1, 42, 59, 43,  1, 39, 52, 42],
        [41, 43,  1, 39, 52, 42,  1, 42],
        [44,  1, 50, 43, 58,  1, 58, 46]])
----
when input is [6] the target: 0
when input is [6, 0] the target: 14
when input is [6, 0, 14] the target: 43
when input is [6, 0, 14, 43] the target: 44
when input is [6, 0, 14, 43, 44] the target: 53
when input is [6, 0, 14, 43, 44, 53] the target: 56
when input is [6, 0, 14, 43, 44, 53, 56] the target: 43
when input is [6, 0, 14, 43, 44, 53, 56, 43] the target: 1
when input is [39] the target: 1
when input is [39, 1] the target: 42
when input is [39, 1, 42] the target: 59
when input is [39, 1, 42, 59] the target: 43
when input is [39, 1, 42, 59, 43] the target: 1
when input is [39, 1, 42, 59, 4

### 理解因果缩放点积自注意力的直觉

此代码来自于Andrej Karpathy的优秀makemore存储库，链接在仓库中

![scaled dot product self attention](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/self_attention.png)

提供的代码演示了自注意力机制和基本概念，特别关注经典的缩放点积自注意力。在这个变体中，查询、键和值矩阵都来自同一个输入序列。为了确保自回归语言生成过程的完整性，特别是在仅有解码器的模型中，代码实现了掩码。这种掩码技术至关重要，因为它隐藏了当前标记位置之后的任何信息，从而将模型的注意力集中在序列之前的部分。这样的注意机制被称为因果自注意力。值得注意的是，稀疏专家混合模型不仅限于仅有解码器的Transformer架构。实际上，这个领域的许多重要工作，特别是Shazeer等人的工作，都围绕着T5架构展开，该架构包含了Transformer模型中的编码器和解码器组件。

In [17]:
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1) #B,T,T

v = value(x) #B,T,H
out = wei @ v # (B,T,T) @ (B,T,H) -> (B,T,H)
#The output from this final matrix product is subsequently passsed through a linear layer as shown in the diagram above

out.shape

torch.Size([4, 8, 16])

将代码进行泛化和模块化，用于因果自注意力和多头因果自注意力。多头自注意力并行应用多个注意力头，每个头都专注于通道的不同部分（嵌入维度）。

In [18]:
#Causal scaled dot product self-Attention Head

n_embd = 64
n_head = 4
n_layer = 4
head_size = 16
dropout = 0.1

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

In [19]:
#Multi-Headed Self Attention
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


In [20]:
#Confirming that what's output from multi head attention is the original embedding size
B,T,C = 4,8,64 # batch, time, channels
x = torch.randn(B,T,C)
mha = MultiHeadAttention(4,16)
mha(x).shape

torch.Size([4, 8, 64])

### 创建专家模块 比如一个简单的MLP

在稀疏的专家混合（MoE）架构中，每个Transformer块中的自注意机制保持不变。然而，在每个块的结构中发生了一个显著的改变：标准的前馈神经网络被几个稀疏激活的前馈网络（称为专家）所取代。所谓"稀疏激活"是指将序列中的每个标记仅路由到有限数量的专家中，通常是可用总数的一到两个。这种修改允许对输入数据的不同部分进行专门处理，使模型能够高效地处理更广泛的复杂性。

![experts](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/experts.png)

In [21]:
#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

通过示例理解 Top-k 门控机制直觉

![top k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/topk.png)

门控网络，也称为路由器，确定每个令牌从多头注意力中发送到哪个专家网络。让我们考虑一个简单的例子：假设有4个专家，并且将令牌路由到前2个专家。首先，我们通过线性层将令牌输入到门控网络中。该层将输入张量从形状（2，4，32）——表示（批量大小，令牌，n_embed，其中n_embed是输入的通道维度）——投射到新形状（2，4，4），对应于（批量大小，令牌，num_experts），其中num_experts是专家网络的数量。在此之后，我们确定最高的k=2个值及其在最后一个维度上的索引。

In [22]:
#Understanding how gating works
num_experts = 4
top_k=2
n_embed=32


#Example multi-head attention output for a simple illustrative example, consider n_embed=32, context_length=4 and batch_size=2
mh_output = torch.randn(2, 4, n_embed)

topkgate_linear = nn.Linear(n_embed, num_experts) # nn.Linear(32, 4)

logits = topkgate_linear(mh_output)
top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)  # Get top-k experts
top_k_logits, top_k_indices

(tensor([[[ 0.9558,  0.1610],
          [ 0.8659, -0.1494],
          [ 0.8765,  0.7202],
          [ 0.9496, -0.6609]],
 
         [[ 0.4419, -0.2500],
          [ 1.2602,  0.8430],
          [ 0.8570,  0.7822],
          [ 0.7376,  0.2561]]], grad_fn=<TopkBackward0>),
 tensor([[[2, 0],
          [3, 2],
          [3, 0],
          [1, 2]],
 
         [[1, 3],
          [1, 2],
          [1, 2],
          [0, 1]]]))

获取稀疏门控输出，只保留各自索引上最大的 k 个值。用“-inf”填充其余部分，并通过softmax激活函数。这样可以将“-inf”值推向零，使得前两个值更加突出且总和为1。这种总和为1的效果有助于对专家输出进行加权。

In [23]:
zeros = torch.full_like(logits, float('-inf')) #full_like clones a tensor and fills it with a specified value (like infinity) for masking or calculations.
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
sparse_logits

tensor([[[ 0.1610,    -inf,  0.9558,    -inf],
         [   -inf,    -inf, -0.1494,  0.8659],
         [ 0.7202,    -inf,    -inf,  0.8765],
         [   -inf,  0.9496, -0.6609,    -inf]],

        [[   -inf,  0.4419,    -inf, -0.2500],
         [   -inf,  1.2602,  0.8430,    -inf],
         [   -inf,  0.8570,  0.7822,    -inf],
         [ 0.7376,  0.2561,    -inf,    -inf]]], grad_fn=<ScatterBackward0>)

In [24]:
gating_output= F.softmax(sparse_logits, dim=-1)
gating_output

tensor([[[0.3111, 0.0000, 0.6889, 0.0000],
         [0.0000, 0.0000, 0.2660, 0.7340],
         [0.4610, 0.0000, 0.0000, 0.5390],
         [0.0000, 0.8335, 0.1665, 0.0000]],

        [[0.0000, 0.6664, 0.0000, 0.3336],
         [0.0000, 0.6028, 0.3972, 0.0000],
         [0.0000, 0.5187, 0.4813, 0.0000],
         [0.6181, 0.3819, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)

### 将上述代码泛化和模块化，并添加嘈杂的前k个门控以实现负载均衡

In [25]:
# First define the top k router module
class TopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear =nn.Linear(n_embed, num_experts)

    def forward(self, mh_ouput):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.linear(mh_output)
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices



In [26]:
#Testing this out:
num_experts = 4
top_k = 2
n_embd = 32

mh_output = torch.randn(2, 4, n_embd)  # Example input
top_k_gate = TopkRouter(n_embd, num_experts, top_k)
gating_output, indices = top_k_gate(mh_output)
gating_output.shape, gating_output, indices
#And it works!!

(torch.Size([2, 4, 4]),
 tensor([[[0.0000, 0.4249, 0.5751, 0.0000],
          [0.3467, 0.6533, 0.0000, 0.0000],
          [0.3970, 0.0000, 0.6030, 0.0000],
          [0.0000, 0.0000, 0.7713, 0.2287]],
 
         [[0.4043, 0.5957, 0.0000, 0.0000],
          [0.0000, 0.5281, 0.0000, 0.4719],
          [0.7053, 0.0000, 0.2947, 0.0000],
          [0.0000, 0.4602, 0.5398, 0.0000]]], grad_fn=<SoftmaxBackward0>),
 tensor([[[2, 1],
          [1, 0],
          [2, 0],
          [2, 3]],
 
         [[1, 0],
          [1, 3],
          [0, 2],
          [2, 1]]]))

虽然最近发布的mixture paper没有提到，我认为嘈杂的前k个门是训练MoE模型的重要工具。基本上，您不希望所有令牌都发送到同一组“喜爱的”专家。您希望在开发和探索之间保持良好的平衡。为了负载平衡，有助于将标准正态噪声添加到门控线性层的对数中。这样可以使训练更加高效。

![noisy top-k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/noisytopkgating.png)

In [27]:
#Changing the above to accomodate noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

In [28]:
#Testing this out, again:
num_experts = 8
top_k = 2
n_embd = 16

mh_output = torch.randn(2, 4, n_embd)  # Example input
noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
gating_output, indices = noisy_top_k_gate(mh_output)
gating_output.shape, gating_output, indices
#It works!!

(torch.Size([2, 4, 8]),
 tensor([[[0.0000, 0.4304, 0.0000, 0.5696, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.4109, 0.0000, 0.0000, 0.0000, 0.5891],
          [0.0000, 0.3708, 0.0000, 0.0000, 0.0000, 0.0000, 0.6292, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.4542, 0.0000, 0.5458, 0.0000]],
 
         [[0.0000, 0.0000, 0.1747, 0.0000, 0.0000, 0.0000, 0.8253, 0.0000],
          [0.0000, 0.0000, 0.5008, 0.4992, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.6160, 0.0000, 0.3840, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1952, 0.0000, 0.8048]]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[[3, 1],
          [7, 3],
          [6, 1],
          [6, 4]],
 
         [[6, 2],
          [2, 3],
          [4, 6],
          [7, 5]]]))


### 创建一个稀疏的 MoE 模块


这个过程的主要方面涉及到门控网络的输出。在获取这些结果后，将前k个值选择性地与相应的前k个专家的输出相乘，形成加权和，这构成了SparseMoe块的输出。这个过程的关键和具有挑战性的部分是避免不必要的乘法。只对top_k的专家进行前向传递，并计算这个加权和是非常重要的。对每个专家都进行前向传递将违背使用稀疏MoE的目的，因为它将不再是稀疏的。

In [29]:
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output
                # We need to scatter_add the weighted outputs to their original positions in the batch
                final_output.masked_scatter_(expert_mask.unsqueeze(-1), weighted_output)

        return final_output.view_as(x)




In [30]:
import torch
import torch.nn as nn

#Let's test this out
num_experts = 8
top_k = 2
n_embd = 16
dropout=0.1

mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output
sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)
print(final_output)

Shape of the final output: torch.Size([4, 8, 16])
tensor([[[ 6.1727e-02, -1.1298e-01, -4.1579e-02, -2.8002e-03,  0.0000e+00,
           7.7357e-02,  3.6960e-02, -1.0697e-01,  1.2526e-02,  1.0065e-02,
           2.5494e-02,  5.7258e-02,  3.1733e-02, -1.2904e-01,  3.1911e-02,
          -9.2942e-02],
         [-2.6267e-02, -2.1804e-01,  0.0000e+00,  0.0000e+00,  2.9232e-01,
           2.2580e-01,  0.0000e+00, -2.6606e-01, -6.0905e-02, -6.9949e-03,
           1.9965e-01,  2.1709e-01, -1.2841e-01, -2.2774e-01,  9.0908e-02,
          -0.0000e+00],
         [-9.4502e-02, -3.3213e-01,  4.8503e-02,  2.1059e-02,  2.4934e-01,
           6.4678e-02, -9.2514e-02, -7.5233e-02,  1.4475e-01,  0.0000e+00,
          -5.1047e-02, -7.9036e-02,  4.8226e-02, -0.0000e+00, -1.0702e-01,
           2.8969e-03],
         [ 1.3531e-01, -2.6609e-02,  0.0000e+00,  2.2722e-01, -1.2779e-01,
           3.2449e-02,  2.2538e-03, -6.9114e-02,  9.8225e-02, -8.9362e-02,
           1.6459e-01,  1.1035e-01,  7.0970e-02, -2.5

强调一下，重要的是要认识到从路由器/门控网络输出的前k个专家的幅度，如上面的代码所示，也是很重要的。这些前k个索引标识了被激活的专家，并且这些前k个维度中的值的大小确定了它们各自的权重。这种加权求和的概念在下面的图表中进一步强调。

![sparse MoE](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/sparseMoEfinal.png)

### Putting it all together

In [31]:
#First defining hyperparameters and boiler plate code. Imports and data preparation code is repeated for convenience
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 400
head_size = 16
n_embed = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2
# ------------

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [32]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

#Multi-Headed Self Attention
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


In [33]:
#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

#noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

#Now create the sparse mixture of experts module
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output
                # We need to scatter_add the weighted outputs to their original positions in the batch
                final_output.masked_scatter_(expert_mask.unsqueeze(-1), weighted_output)

        return final_output.view_as(x)

In [34]:
#First create a self attention + mixture of experts block, that may be repeated several number of times
#Copy pasting key architecture variables for clarity

class Block(nn.Module):
    """ Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """

    def __init__(self, n_embed, n_head, num_experts, top_k):
        # n_embed: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.smoe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.smoe(self.ln2(x))
        return x

In [35]:
#Finally putting it all together to crease a sparse mixture of experts language model
class SparseMoELanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed) # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

这里使用了Kaiming He初始化，因为专家中存在ReLU激活函数。如果你愿意，可以尝试使用更常用于transformers的Glorot初始化。Jeremy Howard的Fastai第二部分有一节讲座，非常出色地从头开始实现了这些内容：https://course.fast.ai/Lessons/lesson17.html

In [36]:

def kaiming_init_weights(m):
    if isinstance (m, (nn.Linear)):
        init.kaiming_normal_(m.weight)

In [37]:
model = SparseMoELanguageModel()
model.apply(kaiming_init_weights)

SparseMoELanguageModel(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(32, 128)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (smoe): SparseMoE(
        (router): NoisyTopkRouter(
          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)
          (noise_linear): Linear(in_features=128, out_features=8, bias=True)
        )
        (experts): ModuleList(
          (0-7): 8 x Expert(
            (net): Sequential(
             

我使用mlflow来追踪和记录我关心的指标和训练超参数。下一个单元格中的训练循环包括了这个mlflow代码。如果你希望在不使用mlflow的情况下进行训练，后续单元格中的代码没有mlflow代码。然而，我发现在进行实验时，跟踪参数和指标非常方便。

In [40]:
#Using MLFlow
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
#mlflow.set_experiment("makeMoE")
with mlflow.start_run():
    #If you use mlflow.autolog() this will be automatically logged. I chose to explicitly log here for completeness
    params = {"batch_size": batch_size , "block_size" : block_size, "max_iters": max_iters, "eval_interval": eval_interval,
              "learning_rate": learning_rate, "device": device, "eval_iters": eval_iters, "dropout" : dropout, "num_experts": num_experts, "top_k": top_k }
    mlflow.log_params(params)
    for iter in range(max_iters):

        # every once in a while evaluate the loss on train and val sets
        if iter % eval_interval == 0 or iter == max_iters - 1:
            losses = estimate_loss()
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

            try:
                train_loss = float(losses['train'])
                val_loss = float(losses['val'])
                mlflow.log_metric("train_loss", train_loss, step=int(iter))
                mlflow.log_metric("val_loss", val_loss, step=int(iter))
            except Exception as e:
                print(f"Error logging metrics at step {iter}: {e}")
                print(f"Train loss: {losses['train']}, Validation loss: {losses['val']}")

        # sample a batch of data
        xb, yb = get_batch('train')

        # evaluate the loss
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

8.996545 M parameters
step 0: train loss 5.3201, val loss 5.3154
step 100: train loss 2.7373, val loss 2.7475
step 200: train loss 2.5406, val loss 2.5388
step 300: train loss 2.4292, val loss 2.4455
step 400: train loss 2.3461, val loss 2.3537
step 500: train loss 2.2797, val loss 2.2935
step 600: train loss 2.2166, val loss 2.2304
step 700: train loss 2.1614, val loss 2.1915
step 800: train loss 2.1121, val loss 2.1541
step 900: train loss 2.0498, val loss 2.1045
step 1000: train loss 2.0205, val loss 2.0841
step 1100: train loss 1.9952, val loss 2.0577
step 1200: train loss 1.9508, val loss 2.0309
step 1300: train loss 1.9423, val loss 2.0409
step 1400: train loss 1.9017, val loss 2.0014
step 1500: train loss 1.8696, val loss 1.9871
step 1600: train loss 1.8472, val loss 1.9692
step 1700: train loss 1.8281, val loss 1.9634
step 1800: train loss 1.8143, val loss 1.9350
step 1900: train loss 1.7997, val loss 1.9247
step 2000: train loss 1.7784, val loss 1.9185
step 2100: train loss 1.

记录训练和验证损失可以很好地指示训练进展情况。根据图表显示，我可能应该在大约4500步时停止（当验证损失略微上升时）。

![mlflow_dash](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/mlflow_dash.png)

## 可选：不使用MLflow进行训练

In [None]:
#Not using MLflow
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [41]:
# generate from the model. Not great. Not too bad either
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


DesparencAtibles them, must
To let an ride not them, whole cast be our victery for a mode how;
And brother's of mercy atting Her:
O, pirate not were me,'r bught, thus.
A longergive:n they one lags to both-and slire,
The worn brought the came to the har
book couls the wour of their
or shallinaty house eyes,
Which hath state my arch will to give toicit by birmed the burwer.

FaLz:
This lest in the than mought dostleful dower soul with a recise,
If thee your boding made and heer come
And his in the four the com shors on flosse.

DUCHESS OF 'OF VINESET:
Prhate the werbous the world mad!

VIRGARY:
I come than,
We slet's my wifing hour mind eat
wick stoment, busting your my of will
The did you be?

Where IOLA:
Ay, by shere and and of satiment know,
Ter baniody that jein, chancied this stimpets heart cievervy.
Can bedwizader king him commillantate-perch
Of their is spink friend our rovening mfore his dame in to tower,
And by yether's we; when beind alough's worst it place, thou head, but you