In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange
import math

# Simple config for toy MoE testing
@dataclass
class ToyMoEConfig:
    """Hyperparameters for a deliberately tiny mixture-of-experts setup.

    Every field has a default, so ``ToyMoEConfig()`` yields a ready-to-use
    toy configuration.
    """

    # Embedding dimension; kept small for the toy example.
    n_embd: int = 32
    # Total number of experts.
    num_experts: int = 8
    # Top-k: number of active experts per token.
    num_experts_per_tok: int = 2
    # Whether the selected top-k probabilities are normalized.
    norm_topk_prob: bool = True
    bias: bool = True
    dropout: float = 0.0

# Instantiate the default toy configuration and echo it.
config = ToyMoEConfig()
print(f"Config: {config}")

# Toy input dimensions; the sequence length is intentionally small.
batch_size, seq_len = 4, 8
n_embd = config.n_embd

# Random input laid out as [batch_size, seq_len, n_embd].
x = torch.randn(batch_size, seq_len, n_embd)
print(f"Input shape: {x.shape}")

# Merge the batch and sequence axes into a single token axis, giving
# [num_tokens, n_embd] (the layout the MoeMLP is said to expect).
num_tokens = batch_size * seq_len
x_flat = rearrange(x, 'batch seq hidden -> (batch seq) hidden')
print(f"\nFlattened input shape: {x_flat.shape}")
print(f"num_tokens = {num_tokens}")

# Simplified stand-in router: a bias-free linear layer that maps each token
# embedding to one logit per expert.
router = nn.Linear(config.n_embd, config.num_experts, bias=False)

# Score every token against every expert -> [num_tokens, num_experts].
router_logits = router(x_flat)
print(f"Router logits shape: {router_logits.shape}")

# Softmax over the expert axis (computed in float32), then keep the k largest
# probabilities per token together with the ids of the experts they belong to.
router_weights, selected_experts = torch.topk(
    F.softmax(router_logits, dim=1, dtype=torch.float),
    config.num_experts_per_tok,
    dim=-1,
)

# Optionally rescale the kept probabilities in place so that each token's
# top-k weights sum to 1.
if config.norm_topk_prob:
    router_weights /= router_weights.sum(dim=-1, keepdim=True)

# Both results are [num_tokens, top_k].
print(f"Router weights shape: {router_weights.shape}")
print(f"Selected experts shape: {selected_experts.shape}")
print(f"Selected experts:\n{selected_experts}")

# One-hot encode the selected expert ids so each (token, slot) pair marks the
# expert it activates -> [num_tokens, top_k, num_experts].
expert_mask = F.one_hot(selected_experts, config.num_experts)
print(f"Expert mask shape: {expert_mask.shape}")

# Reverse the axis order (n k e -> e k n), following the pattern the original
# code is noted to use -> [num_experts, top_k, num_tokens].
expert_mask_rearranged = rearrange(expert_mask, 'n k e -> e k n')
print(f"Rearranged expert mask shape: {expert_mask_rearranged.shape}")

# Random weight matrices whose expert dimension is folded into the FFN axis
# (noted as mirroring the MoeMLP class): every expert contributes a
# 4 * n_embd wide slice.
ffn_total_width = config.num_experts * 4 * config.n_embd
w1 = torch.randn(config.n_embd, ffn_total_width)   # [n_embd, num_experts * 4 * n_embd]
w2 = torch.randn(ffn_total_width, config.n_embd)   # [num_experts * 4 * n_embd, n_embd]

print(f"w1 shape: {w1.shape}")
print(f"w2 shape: {w2.shape}")

# Candidate problem sizes for the sdd_kernel call.
print(f"\n=== Parameters for sdd_kernel call ===")
print(f"M (num_tokens): {num_tokens}")
# NOTE(review): the notebook author flags this N as wrong in the original
# code; the kernel is not shown here, so that claim is unverified.
print(f"N (ffn_hidden * num_experts): {ffn_total_width}")
print(f"K (hidden_size): {config.n_embd}")

# NOTE(review): hypothesis only - N might instead scale with
# num_experts_per_tok rather than the total expert count. Confirm against the
# kernel before relying on either value.
print(f"N (corrected, ffn_hidden * num_experts_per_tok): {4 * config.n_embd * config.num_experts_per_tok}")

# Count tokens per expert.
# torch.topk never returns the same index twice within one row, so each token
# selects an expert at most once. Counting how often every expert id occurs in
# the flattened selection therefore equals the number of tokens routed to that
# expert. `minlength` keeps a zero entry for experts that received no tokens.
tokens_per_expert = torch.bincount(
    selected_experts.flatten(), minlength=config.num_experts
)

print(f"\nTokens per expert: {tokens_per_expert}")
# Cross-check of the counts above: every token contributes exactly
# num_experts_per_tok assignments, so the total should be
# num_tokens * num_experts_per_tok.
print(f"Total active assignments: {tokens_per_expert.sum()}")

# Create the indices needed for sparse kernel call
# We need to figure out which blocks are active and their positions

print("=== Creating sparse indices for kernel ===")

# There is one active block per (token, selected expert) pair, enumerated
# token-major: all top-k choices of token 0 first, then token 1, and so on.

# Row index of a block is its token index, so each token index is repeated
# once per selected expert: [0, 0, 1, 1, ...] for top-2.
row_indices = torch.arange(num_tokens, dtype=torch.long).repeat_interleave(
    config.num_experts_per_tok
)

# Column index of a block is the expert id (which expert block). Flattening
# the [num_tokens, top_k] selection row-major gives the same token-major
# order as row_indices. clone() keeps col_indices independent of
# selected_experts' storage.
col_indices = selected_experts.flatten().clone()

print(f"Number of active blocks: {len(row_indices)}")
print(f"Row indices (which token): {row_indices}")
print(f"Col indices (which expert): {col_indices}")

# Show the mapping: one line per active block with its token, expert and
# routing weight.
print(f"\nActive blocks mapping:")
# Flatten the [num_tokens, top_k] weights once, outside the loop; row-major
# flattening lines the weights up with the token-major block order of
# row_indices / col_indices.
flat_weights = router_weights.flatten()
for i, (token_idx, expert_idx, weight) in enumerate(
    zip(row_indices, col_indices, flat_weights)
):
    print(f"  Block {i}: Token {token_idx} -> Expert {expert_idx} (weight: {weight:.4f})")

# Sketch of the tensor that would be handed to the kernel as x_ptr.
print(f"\n=== Tensor reshaping for kernel ===")
print(f"Original x_flat shape: {x_flat.shape}")
print(f"Need to permute/gather tokens according to active blocks...")

# Gather example: pick one row of x_flat per active block, so a token's
# embedding appears once for every expert it was routed to.
# Result: [num_active_blocks, n_embd].
active_tokens = x_flat[row_indices]
print(f"Active tokens shape: {active_tokens.shape}")
print(f"First few active tokens:\n{active_tokens[:5]}")
# Summary: What the kernel needs and stride calculations
print("=== Summary for sdd_kernel call ===")

# Placeholder sparse output tensor, [num_tokens, num_experts * 4 * n_embd].
# It is allocated up front so that its strides can be read from the tensor
# itself; its shape is reported at the end of this section.
sparse_output = torch.zeros(num_tokens, config.num_experts * 4 * config.n_embd)

# Read the strides (in elements) from the tensors instead of re-deriving them
# from config arithmetic. Hard-coded values silently assume a contiguous
# row-major layout and would go stale if any tensor were transposed or sliced.

# x_ptr: active_tokens [num_active_blocks, n_embd]
stride_xm, stride_xk = active_tokens.stride()   # between rows, between columns

# w1_ptr: weight matrix [n_embd, num_experts * 4 * n_embd]
stride_wk, stride_wn = w1.stride()              # between rows, between columns

# output_ptr: sparse output [num_tokens, num_experts * 4 * n_embd]
# But we only fill the active blocks!
stride_om, stride_on = sparse_output.stride()   # between rows, between columns

print(f"Strides:")
print(f"  stride_xm={stride_xm}, stride_xk={stride_xk}")
print(f"  stride_wk={stride_wk}, stride_wn={stride_wn}")
print(f"  stride_om={stride_om}, stride_on={stride_on}")

print(f"\nKernel parameters:")
print(f"  x_ptr: active_tokens (shape: {active_tokens.shape})")
print(f"  w1_ptr: w1 (shape: {w1.shape})")
print(f"  output_ptr: sparse output tensor (need to create)")
print(f"  row_indices_ptr: {row_indices.shape} -> {row_indices}")
print(f"  col_indices_ptr: {col_indices.shape} -> {col_indices}")
# NOTE(review): M is reported as num_tokens, but active_tokens (the proposed
# x_ptr) has one row per active block, i.e. num_tokens * num_experts_per_tok
# rows. Which of the two the kernel expects is not visible here - confirm.
print(f"  M={num_tokens}")
print(f"  N={4 * config.n_embd * config.num_experts} (or maybe {4 * config.n_embd * config.num_experts_per_tok}?)")
print(f"  K={config.n_embd}")

print(f"\n🚧 TODO: Figure out the correct output tensor shape and how to handle sparsity!")
print(f"The current MoeMLP.forward() has incomplete tensor reshaping before the kernel call.")

# Report the placeholder sparse output tensor to show what it might look like
print(f"\nSparse output tensor shape: {sparse_output.shape}")
print(f"Most entries will be zero - only active expert blocks will be filled!")


Config: ToyMoEConfig(n_embd=32, num_experts=8, num_experts_per_tok=2, norm_topk_prob=True, bias=True, dropout=0.0)
Input shape: torch.Size([4, 8, 32])

Flattened input shape: torch.Size([32, 32])
num_tokens = 32
Router logits shape: torch.Size([32, 8])
Router weights shape: torch.Size([32, 2])
Selected experts shape: torch.Size([32, 2])
Selected experts:
tensor([[4, 6],
        [5, 7],
        [2, 0],
        [4, 5],
        [5, 6],
        [1, 5],
        [7, 2],
        [3, 5],
        [4, 6],
        [4, 3],
        [6, 1],
        [2, 3],
        [2, 0],
        [2, 5],
        [7, 2],
        [7, 2],
        [4, 3],
        [6, 0],
        [4, 0],
        [5, 3],
        [0, 5],
        [7, 3],
        [3, 4],
        [0, 1],
        [1, 5],
        [2, 7],
        [0, 4],
        [3, 6],
        [2, 0],
        [2, 7],
        [2, 6],
        [1, 7]])
Expert mask shape: torch.Size([32, 2, 8])
Rearranged expert mask shape: torch.Size([8, 2, 32])
w1 shape: torch.Size([32, 1024])
w2