In [1]:
import torch
from sonicmoe import MoE, KernelBackendMoE

# Seed torch's RNG so the random input below (and any torch-RNG based weight
# initialisation inside MoE) is identical on every "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

NUM_TOKENS = 32768   # number of tokens, i.e. rows of `x`
HIDDEN_SIZE = 4096   # model width; shared by the layer and by the last dim of `x`

# Create MoE layer
moe = MoE(
    num_experts=128,           # Number of experts
    num_experts_per_tok=8,     # Top-k experts per token
    hidden_size=HIDDEN_SIZE,   # Hidden dimension
    intermediate_size=1536,    # Expert intermediate size
    is_glu=True,               # Whether to use GLU (e.g. SwiGLU) activation
    add_bias=False,            # Add bias to linear layers
    std=0.02,                  # Weight initialization std
).to(device="cuda", dtype=torch.bfloat16)

# Forward pass: the input width is tied to HIDDEN_SIZE so it cannot drift from the layer config.
x = torch.randn(NUM_TOKENS, HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16)
output, aux_loss = moe(x, kernel_backend_moe=KernelBackendMoE.sonicmoe)


ninja: no work to do.


In [3]:
from sonicmoe.functional import TC_Softmax_Topk_Router_Function, count_cumsum, TC_topk_router_metadata
import torch.nn.functional as F

# Routing hyper-parameters, read back from the layer built in the first cell.
E = moe.num_experts
K = moe.top_k

# Router projection weight and the raw (pre-softmax) logits it produces for `x`.
router_w = moe.router.weight
router_logits = F.linear(x, router_w)

In [None]:
# Shapes / dtypes observed for the tensors assigned below, listed in assignment order:
# topk_scores: torch.Size([32768, 8]) torch.float32 cuda:0 - [seq, topk]
# topk_indices: torch.Size([32768, 8]) torch.int32 cuda:0 - [seq, topk]
# expert_frequency: torch.Size([128]) torch.int32 cuda:0 - [NumTotalExperts]
# expert_offset: torch.Size([129]) torch.int32 cuda:0 - [NumTotalExperts + 1]
# x_gather_idx: torch.Size([262144]) torch.int32 cuda:0 - [topk * seq]
# s_scatter_idx: torch.Size([262144]) torch.int32 cuda:0 - [topk * seq]
# s_reverse_scatter_idx: torch.Size([262144]) torch.int32 cuda:0 - [topk * seq]
# topk_token_offset: torch.Size([32769]) torch.int32 cuda:0 - [seq + 1]


topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, router_w.size(0), K)
expert_frequency, expert_offset = count_cumsum(topk_indices.view(-1), router_w.size(0), do_cumsum=True)
# Puts a leading 0 in front of expert_offset (which is why it has NumTotalExperts + 1 entries above)
expert_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx, topk_token_offset = TC_topk_router_metadata(
    topk_indices, expert_offset, K
)

In [2]:

def forward_token_choice_rounding(
    x: torch.Tensor, router_w: torch.Tensor, E, K, Mtile, routing
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Token-choice top-K routing with per-expert token counts rounded to a multiple of ``Mtile``.

    Stage 1 (token choice): every token picks its ``K`` highest-probability experts.
    Stage 2 (expert choice): for every expert, tokens are ranked so that the tokens
    that picked it come first (ordered by their renormalised score) followed by all
    other tokens (ordered by their softmax probability). Each expert then keeps the
    first ``n_e`` tokens of that ranking, where ``n_e`` is its stage-1 token count
    rounded to a multiple of ``Mtile``.

    Args:
        x: 2-D activations ``[T, D]``.
        router_w: Router weight handed to ``F.linear`` (presumably ``[E, D]``).
        E: Number of experts.
        K: Number of experts each token selects in stage 1.
        Mtile: Tile size the per-expert token counts are rounded to. This value is
            honoured as passed; it is no longer overwritten with 128 inside the function.
        routing: Rounding mode, one of ``"down"`` (floor), ``"up"`` (ceil) or
            ``"nr"`` (nearest, via ``torch.round``).

    Returns:
        ``(scores, selected_T, selected_E)``, three flat tensors of equal length with
        one entry per kept (token, expert) pair, ordered by ascending token index:
        the routing score (dtype of ``x``), the int32 token index and the int32
        expert index. Scores are the renormalised top-K values for stage-1 pairs and
        the plain softmax probabilities for pairs added by rounding up.

    Raises:
        NotImplementedError: If ``routing`` is not one of the supported modes.
    """
    # Reject unknown modes before doing any work.
    if routing not in ("down", "up", "nr"):
        raise NotImplementedError()

    T, _ = x.shape  # x must be 2-D: [tokens, hidden]
    device = x.device
    dtype = x.dtype

    router_logits = F.linear(x, router_w)
    router_scores = F.softmax(router_logits, dim=-1, dtype=torch.float32).to(dtype)

    # first sorting, similar to TC
    topk_values, topk_indices = router_scores.topk(K, dim=-1)

    # Tokens per expert after stage 1 (presumably int32 - depends on count_cumsum).
    expert_freq = count_cumsum(topk_indices.view(-1), E, do_cumsum=True)[0]

    if routing == "down":
        expert_freq_rounded = expert_freq // Mtile * Mtile
    elif routing == "up":
        # May exceed T; the rank mask below then simply keeps every token.
        expert_freq_rounded = (torch.ceil(expert_freq / Mtile) * Mtile).type(torch.int32)
    else:  # "nr"
        # NOTE: torch.round rounds exact halves to the nearest even integer.
        expert_freq_rounded = torch.round(expert_freq / Mtile).type(torch.int32) * Mtile

    # Renormalise the K chosen probabilities so they sum to 1 per token and write
    # them back; non-chosen entries keep their original softmax probability.
    topk_values /= topk_values.sum(dim=-1, keepdim=True)
    router_scores.scatter_(-1, topk_indices, topk_values)

    # Ranking key: non-chosen entries are shifted to <= 0 (order preserved), chosen
    # entries keep their positive renormalised score, so chosen pairs always rank first.
    router_TC_EC_combined_val = router_scores.detach().clone()
    router_TC_EC_combined_val -= 1  # make sure EC's score is lower than TC & EC keeps the score order
    router_TC_EC_combined_val.scatter_(1, topk_indices, topk_values)  # mask out original TC score

    # second sorting, similar to EC: row r of column e is the token ranked r-th for expert e
    tokens_by_rank = router_TC_EC_combined_val.argsort(dim=0, descending=True).int()  # type: ignore

    # Keep rank r for expert e iff r < rounded count of e (broadcasts to [T, E]).
    rank = torch.arange(T, device=device, dtype=torch.int32)[:, None]
    expert_freq_mask = rank < expert_freq_rounded[None, :]  # type: ignore

    selected_T = tokens_by_rank[expert_freq_mask]
    selected_E = torch.arange(E, device=device, dtype=torch.int32)[None, :].expand(T, -1)[expert_freq_mask]  # type: ignore

    # implicit assumption: selected_T should be sorted in my reduction code
    selected_T_order = selected_T.argsort()
    selected_T = selected_T[selected_T_order]
    selected_E = selected_E[selected_T_order]

    # Gather with int64 indices (accepted by every PyTorch version); the returned
    # index tensors themselves stay int32.
    return router_scores[selected_T.long(), selected_E.long()].contiguous(), selected_T, selected_E


def forward_topk(x: torch.Tensor, router_w: torch.Tensor, E, K) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Plain top-K routing, returned as flat (score, token, expert) triples.

    Args:
        x: Activations whose first dimension is the token dimension ``T``.
        router_w: Router weight handed to ``F.linear``.
        E: Number of experts. Unused here; kept so the signature lines up with
            ``forward_token_choice_rounding``.
        K: Number of experts selected per token.

    Returns:
        Three flat tensors of length ``T * K``, token-major (entries
        ``[t * K, (t + 1) * K)`` belong to token ``t``):
        the float32 routing scores, the int32 token index and the int32 expert index.
    """
    T = x.shape[0]

    router_logits = F.linear(x, router_w)

    # Pick the K largest logits per token, then softmax over those K logits only,
    # so each token's K scores sum to 1.
    top_logits, topk_indices = router_logits.topk(K, dim=1)
    router_scores = F.softmax(top_logits, dim=-1, dtype=torch.float32)

    # Token id for every flattened slot. Built on x's device instead of a hard-coded
    # "cuda" so the function also works on CPU or on a non-default GPU.
    token_indices = torch.arange(T, device=x.device, dtype=torch.int32).repeat_interleave(K)

    return (
        router_scores.view(-1),
        token_indices,
        topk_indices.view(-1).int(),
    )



In [None]:
# Flat routing triples from forward_topk, one entry per (token, expert) pair:
# router_scores_selected: torch.Size([262144]) torch.float32 cuda:0 - [topk * seq]
# selected_T: torch.Size([262144]) torch.int32 cuda:0 - [topk * seq] -> token index of each pair (i.e. entries [0-7] are 0 because they all belong to token 0)
# selected_E: torch.Size([262144]) torch.int32 cuda:0 - [topk * seq] -> expert index of each pair
router_scores_selected, selected_T, selected_E = forward_topk(x, router_w.detach(), E, K)
# Alternative router (disabled). NOTE: `Mtile` and `routing` are not defined in the cells
# shown here and must be set before re-enabling this call.
# router_scores_selected, selected_T, selected_E = forward_token_choice_rounding(
#     x, router_w.detach(), E, K, Mtile, routing
# )

In [None]:
# Inspect the expert-index tensor: print its shape/dtype/device, then display its values.
t = selected_E
print(t.shape, t.dtype, t.device)
t

torch.Size([262144]) torch.int32 cuda:0


tensor([ 83,  90,  80,  ..., 101,  57,  13], device='cuda:0',
       dtype=torch.int32)

: 