In [1]:
import torch
import torch.nn.functional as F

In [2]:
# Seed torch's global RNG so every torch.randn call below (inputs, router weight,
# expert weights) reproduces the same values on Restart & Run All.
SEED = 42
# The trailing ';' suppresses the echoed torch.Generator repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x1b6ed086330>

In [3]:
# Problem sizes for this toy MoE layer.
batch_size = 2
q_seq_len = 3
hidden_size = 4
num_experts = 3
num_topk = 2  # number of experts each token is routed to

# Derived sizes.
num_tokens = batch_size * q_seq_len  # tokens after flattening batch and sequence dims
total_num_tokens = num_tokens * num_topk  # (token, slot) pairs = rows of the scatter buffer
average_num_tokens_per_expert = total_num_tokens // num_experts  # only printed below

summary_lines = [
    "参数设置:",
    f"  batch_size = {batch_size}, q_seq_len = {q_seq_len}, hidden_size = {hidden_size}",
    f"  num_experts = {num_experts}, num_topk = {num_topk}",
    f"  num_tokens = {num_tokens}, total_num_tokens = {total_num_tokens}",
    f"  average_num_tokens_per_expert = {average_num_tokens_per_expert}\n",
]
print("\n".join(summary_lines))

参数设置:
  batch_size = 2, q_seq_len = 3, hidden_size = 4
  num_experts = 3, num_topk = 2
  num_tokens = 6, total_num_tokens = 12
  average_num_tokens_per_expert = 4



In [4]:
# 1. Initialise the input activations, shape [batch_size, q_seq_len, hidden_size], in bf16.
input_shape = (batch_size, q_seq_len, hidden_size)
hidden_states = torch.randn(*input_shape, dtype=torch.bfloat16)

print("1. 原始输入 hidden_states:")
print(f"  形状: {hidden_states.shape}, dtype: {hidden_states.dtype}")
print(hidden_states, "\n")

1. 原始输入 hidden_states:
  形状: torch.Size([2, 3, 4]), dtype: torch.bfloat16
tensor([[[-0.8086, -1.5312,  0.4062,  0.1719],
         [-0.2471,  0.2041, -0.8789, -0.3867],
         [ 0.5664,  0.2363,  0.4863,  1.1719]],

        [[ 1.4531, -0.8906,  0.1543,  0.8242],
         [-2.1719,  1.3516,  0.2754, -0.1128],
         [-0.7969,  1.3438,  0.3750, -1.1328]]], dtype=torch.bfloat16) 



In [5]:
# 2. Flatten batch and sequence dims so each token is one row: [num_tokens, hidden_size].
reshaped_states = hidden_states.view(num_tokens, -1)

print("2. 重塑后的 hidden_states:")
print(f"  形状: {reshaped_states.shape}")
print(reshaped_states, "\n")

2. 重塑后的 hidden_states:
  形状: torch.Size([6, 4])
tensor([[-0.8086, -1.5312,  0.4062,  0.1719],
        [-0.2471,  0.2041, -0.8789, -0.3867],
        [ 0.5664,  0.2363,  0.4863,  1.1719],
        [ 1.4531, -0.8906,  0.1543,  0.8242],
        [-2.1719,  1.3516,  0.2754, -0.1128],
        [-0.7969,  1.3438,  0.3750, -1.1328]], dtype=torch.bfloat16) 



In [6]:
# 3. Gating GEMM: project every token onto one logit per expert.
# Router weight [hidden_size, num_experts], drawn from the seeded global RNG.
gating_weight = torch.randn(hidden_size, num_experts, dtype=torch.bfloat16)
print(gating_weight, "\n")

# [num_tokens, hidden_size] @ [hidden_size, num_experts] -> [num_tokens, num_experts]
router_logits = reshaped_states @ gating_weight

print("3. 门控计算后的 router_logits:")
print(f"  形状: {router_logits.shape}, dtype: {router_logits.dtype}")
print(router_logits, "\n")

tensor([[ 1.3516,  0.6875, -0.3281],
        [ 0.7969,  0.2812,  0.0562],
        [ 0.5234, -0.2383, -0.0498],
        [ 0.5273, -0.0085,  0.7305]], dtype=torch.bfloat16) 

3. 门控计算后的 router_logits:
  形状: torch.Size([6, 3]), dtype: torch.bfloat16
tensor([[-2.0156, -1.0859,  0.2852],
        [-0.8359,  0.1001, -0.1465],
        [ 1.8281,  0.3301,  0.6602],
        [ 1.7734,  0.7031,  0.0674],
        [-1.7734, -1.1797,  0.6914],
        [-0.4082, -0.2500, -0.5078]], dtype=torch.bfloat16) 



In [7]:
# 4. Softmax over the expert axis (the last dim of the 2-D logits).
# Upcast to float32 first so the normalisation is not computed in bf16.
routing_weights = router_logits.float().softmax(dim=-1)

print("4. Softmax 后的 routing_weights (float32):")
print(f"  形状: {routing_weights.shape}, dtype: {routing_weights.dtype}")
print(routing_weights, "\n")

4. Softmax 后的 routing_weights (float32):
  形状: torch.Size([6, 3]), dtype: torch.float32
tensor([[0.0740, 0.1875, 0.7385],
        [0.1804, 0.4601, 0.3595],
        [0.6517, 0.1457, 0.2027],
        [0.6560, 0.2249, 0.1191],
        [0.0686, 0.1243, 0.8071],
        [0.3250, 0.3807, 0.2942]]) 



In [8]:
# 5. Top-k selection + renormalisation.
# Keep the num_topk most probable experts per token (values come back sorted descending).
topk_probs, selected_experts = torch.topk(routing_weights, k=num_topk, dim=-1)
print("未进行归一化, 选择专家后的routing_weights: ")
print(topk_probs, "\n")

# Renormalise in the incoming dtype (float32 from step 4) so each token's kept
# weights sum to 1, then cast to bf16 to match the activations.
routing_weights = (topk_probs / topk_probs.sum(dim=-1, keepdim=True)).to(torch.bfloat16)

print("5. Top-k 选择和归一化后的结果:")
print("  routing_weights:")
print(f"    形状: {routing_weights.shape}, dtype: {routing_weights.dtype}")
print(routing_weights)
print("\n  selected_experts:")
print(f"    形状: {selected_experts.shape}, dtype: {selected_experts.dtype}")
print(selected_experts, "\n")

未进行归一化, 选择专家后的routing_weights: 
tensor([[0.7385, 0.1875],
        [0.4601, 0.3595],
        [0.6517, 0.2027],
        [0.6560, 0.2249],
        [0.8071, 0.1243],
        [0.3807, 0.3250]]) 

5. Top-k 选择和归一化后的结果:
  routing_weights:
    形状: torch.Size([6, 2]), dtype: torch.bfloat16
tensor([[0.7969, 0.2021],
        [0.5625, 0.4395],
        [0.7617, 0.2373],
        [0.7461, 0.2559],
        [0.8672, 0.1338],
        [0.5391, 0.4609]], dtype=torch.bfloat16)

  selected_experts:
    形状: torch.Size([6, 2]), dtype: torch.int64
tensor([[2, 1],
        [1, 2],
        [0, 2],
        [0, 1],
        [2, 1],
        [1, 0]]) 



In [9]:
# 6. Histogram + index computation (the "permute" step of MoE routing).
#
# Worked example with seed 42 (matches the Top-k output above):
#   selected_experts   = [[2, 1], [1, 2], [0, 2], [0, 1], [2, 1], [1, 0]]
#   flat_experts       = [2, 1, 1, 2, 0, 2, 0, 1, 2, 1, 1, 0]
#   expert_token_count = [3, 5, 4]  # expert0 gets 3 token copies, expert1 5, expert2 4
#   expert_offsets     = [0, 3, 8]  # expert0 -> rows 0..2, expert1 -> rows 3..7, expert2 -> rows 8..11
#   token_offsets[0]   = [8, 3]     # token 0's copies for experts 2 and 1 land in rows 8 and 3
#                                   # of the scatter buffer (prepared_tokens, built in step 7)

# Flatten [num_tokens, num_topk] into one token-major list of chosen expert ids.
flat_experts = selected_experts.view(-1)
print(flat_experts)

# Number of (token, slot) pairs routed to each expert; minlength keeps a zero
# entry for any expert that no token selected.
expert_token_count = torch.bincount(flat_experts, minlength=num_experts)

print("6. 直方图统计:")
print(f" 每个专家被选到次数统计 expert_token_count: {expert_token_count.tolist()}\n")

# Exclusive prefix sum of the counts = first row of each expert's contiguous segment.
expert_offsets = torch.zeros(num_experts, dtype=torch.int64)
expert_offsets[1:] = torch.cumsum(expert_token_count, dim=0)[:-1]
print("expert_offsets: ", expert_offsets)

# Give every (token, slot) pair a destination row. Visiting tokens in order and
# bumping a per-expert write cursor keeps each expert's rows in token order.
# The cursor is a copy, so expert_offsets still holds the segment starts that
# step 8 uses to slice prepared_tokens.
write_cursor = expert_offsets.clone()
token_offsets = torch.zeros_like(selected_experts)
for i in range(num_tokens):
    for j in range(num_topk):
        expert_idx = selected_experts[i, j]  # expert chosen for slot j of token i
        token_offsets[i, j] = write_cursor[expert_idx]
        write_cursor[expert_idx] += 1

# Every row 0..total_num_tokens-1 of the scatter buffer must be claimed exactly once.
assert torch.equal(token_offsets.view(-1).sort().values, torch.arange(total_num_tokens))

print("  token_offsets:")
print(f"    形状[tokens, topk]: {token_offsets.shape}, dtype: {token_offsets.dtype}")
print(token_offsets, "\n")

tensor([2, 1, 1, 2, 0, 2, 0, 1, 2, 1, 1, 0])
6. 直方图统计:
 每个专家被选到次数统计 expert_token_count: [3, 5, 4]

expert_offsets:  tensor([0, 3, 8])
  token_offsets:
    形状[tokens, topk]: torch.Size([6, 2]), dtype: torch.int64
tensor([[ 8,  3],
        [ 4,  9],
        [ 0, 10],
        [ 1,  5],
        [11,  6],
        [ 7,  2]]) 



'\ntoken_offsets:\n\ntensor([[ 8,  3],\n        [ 4,  9],\n        [ 0, 10],\n        [ 1,  5],\n        [11,  6],\n        [ 7,  2]])\n含义：\n[8, 3]说明: 在num_tokens中token0被专家2, 1处理, 在allocated_tokens中token的位置为8, 3。\n\n\n'

In [10]:
# 7. Scatter: copy every token into each of the num_topk rows assigned to it in
# step 6, so that each expert's inputs end up contiguous in prepared_tokens.
prepared_tokens = torch.zeros(total_num_tokens, hidden_size, dtype=torch.bfloat16)

# Row i*num_topk + j of the repeated input is token i, which lines up with the
# row-major flattening of token_offsets[i, j]. The destination rows are pairwise
# distinct (each cursor value is handed out once), so a single indexed write
# matches the per-element loop.
replicated_tokens = reshaped_states.repeat_interleave(num_topk, dim=0)
prepared_tokens[token_offsets.reshape(-1)] = replicated_tokens

print("7. 分散后的 prepared_tokens:")
print(f"  形状: {prepared_tokens.shape}, dtype: {prepared_tokens.dtype}")
print(prepared_tokens, "\n")

7. 分散后的 prepared_tokens:
  形状: torch.Size([12, 4]), dtype: torch.bfloat16
tensor([[ 0.5664,  0.2363,  0.4863,  1.1719],
        [ 1.4531, -0.8906,  0.1543,  0.8242],
        [-0.7969,  1.3438,  0.3750, -1.1328],
        [-0.8086, -1.5312,  0.4062,  0.1719],
        [-0.2471,  0.2041, -0.8789, -0.3867],
        [ 1.4531, -0.8906,  0.1543,  0.8242],
        [-2.1719,  1.3516,  0.2754, -0.1128],
        [-0.7969,  1.3438,  0.3750, -1.1328],
        [-0.8086, -1.5312,  0.4062,  0.1719],
        [-0.2471,  0.2041, -0.8789, -0.3867],
        [ 0.5664,  0.2363,  0.4863,  1.1719],
        [-2.1719,  1.3516,  0.2754, -0.1128]], dtype=torch.bfloat16) 



In [11]:
# 8. Grouped GEMM: run each expert's two-matmul FFN on its contiguous slice of tokens.
# Per-expert weights: [num_experts, H, 2H] then [num_experts, 2H, H]. Both are drawn
# from the seeded RNG in this order; keep the order to reproduce the outputs.
expert_weights1 = torch.randn(num_experts, hidden_size, hidden_size * 2, dtype=torch.bfloat16)
expert_weights2 = torch.randn(num_experts, hidden_size * 2, hidden_size, dtype=torch.bfloat16)

# Expert e owns rows [expert_offsets[e], expert_offsets[e] + expert_token_count[e]) of prepared_tokens.
expert_inputs = [
    prepared_tokens[expert_offsets[e]: expert_offsets[e] + expert_token_count[e]]
    for e in range(num_experts)
]

print(expert_inputs, "\n")

# Expert compute: matmul -> GELU -> matmul. Experts with no tokens contribute an empty [0, H] block.
expert_outputs = []
for e, expert_in in enumerate(expert_inputs):
    if expert_token_count[e] <= 0:
        expert_outputs.append(torch.zeros(0, hidden_size, dtype=torch.bfloat16))
        continue
    expert_hidden = F.gelu(expert_in @ expert_weights1[e])
    expert_outputs.append(expert_hidden @ expert_weights2[e])

# Concatenating the per-expert blocks in expert order restores the scatter-buffer row order.
output_tokens = torch.cat(expert_outputs, dim=0)
print("8. 专家处理后的 output_tokens:")
print(f"  形状: {output_tokens.shape}, dtype: {output_tokens.dtype}")
print(output_tokens, "\n")

[tensor([[ 0.5664,  0.2363,  0.4863,  1.1719],
        [ 1.4531, -0.8906,  0.1543,  0.8242],
        [-0.7969,  1.3438,  0.3750, -1.1328]], dtype=torch.bfloat16), tensor([[-0.8086, -1.5312,  0.4062,  0.1719],
        [-0.2471,  0.2041, -0.8789, -0.3867],
        [ 1.4531, -0.8906,  0.1543,  0.8242],
        [-2.1719,  1.3516,  0.2754, -0.1128],
        [-0.7969,  1.3438,  0.3750, -1.1328]], dtype=torch.bfloat16), tensor([[-0.8086, -1.5312,  0.4062,  0.1719],
        [-0.2471,  0.2041, -0.8789, -0.3867],
        [ 0.5664,  0.2363,  0.4863,  1.1719],
        [-2.1719,  1.3516,  0.2754, -0.1128]], dtype=torch.bfloat16)] 

8. 专家处理后的 output_tokens:
  形状: torch.Size([12, 4]), dtype: torch.bfloat16
tensor([[-1.3672,  1.6328, -1.1641,  4.1562],
        [-1.8750, -0.5273,  2.5156,  2.3594],
        [ 4.0938, -0.1484, -2.5781,  1.1016],
        [ 1.5391, -0.4023, -0.4805,  0.2188],
        [ 1.5859, -0.2832,  0.9688,  1.2266],
        [ 9.6875,  9.8125,  1.0234, -2.7500],
        [-0.2676,  3.57

In [12]:
# 9. Gather + weighted combine: for every token, fetch the expert outputs of its
# num_topk copies and sum them weighted by the renormalised routing weights.
#
# output_tokens[token_offsets] -> [num_tokens, num_topk, hidden_size]
gathered_outputs = output_tokens[token_offsets]
# [num_tokens, num_topk] -> [num_tokens, num_topk, 1] so it broadcasts over hidden_size.
combine_weights = routing_weights.float().unsqueeze(-1)
# Accumulate in float32 and round to bf16 once at the end. Accumulating with `+=`
# directly in bf16 rounds after every product and every add, and that error grows with num_topk.
final_results = (combine_weights * gathered_outputs.float()).sum(dim=1).to(torch.bfloat16)

print("9. 收集后的 final_results:")
print(f"  形状[num_tokens, hidden_size]: {final_results.shape}, dtype: {final_results.dtype}")
print(final_results, "\n")

9. 收集后的 final_results:
  形状[num_tokens, hidden_size]: torch.Size([6, 4]), dtype: torch.bfloat16
tensor([[-1.2969, -0.6250,  0.4922,  1.7266],
        [ 1.4844, -0.2598,  0.3340,  0.4336],
        [ 0.1328,  0.3047, -1.1328,  3.2812],
        [ 1.0859,  2.1250,  2.1406,  1.0547],
        [-2.0469, -1.2031,  4.4062, -0.0352],
        [ 1.0469,  0.7266, -1.9375,  2.9688]], dtype=torch.bfloat16) 



In [13]:
# 10. Restore the original [batch_size, q_seq_len, hidden_size] layout.
final_output = final_results.view(batch_size, q_seq_len, -1)

print("10. 最终输出 final_output:")
print(f"  形状: {final_output.shape}")
print(final_output)

10. 最终输出 final_output:
  形状: torch.Size([2, 3, 4])
tensor([[[-1.2969, -0.6250,  0.4922,  1.7266],
         [ 1.4844, -0.2598,  0.3340,  0.4336],
         [ 0.1328,  0.3047, -1.1328,  3.2812]],

        [[ 1.0859,  2.1250,  2.1406,  1.0547],
         [-2.0469, -1.2031,  4.4062, -0.0352],
         [ 1.0469,  0.7266, -1.9375,  2.9688]]], dtype=torch.bfloat16)
