In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange
from moe import sdd_kernel
from model import MoeMLP
@dataclass
class ToyMoEConfig:
    """Small configuration object handed to ``MoeMLP`` in this debugging notebook.

    The per-field notes describe how this notebook itself uses each value.
    Fields the notebook never reads directly are marked as assumptions.
    """
    # Size of the last dimension of the toy input tensor built in this notebook.
    n_embd: int = 128
    # Number of experts; the analysis cells iterate over range(num_experts).
    num_experts: int = 4
    # k passed to torch.topk over the router weights in the analysis cell.
    num_experts_per_tok: int = 2
    # When True, the analysis cell rescales the top-k router weights so they
    # sum to 1 along the last dimension.
    norm_topk_prob: bool = True
    # Not read in this notebook; presumably toggles bias terms inside MoeMLP
    # -- TODO confirm against model.py.
    bias: bool = True
    # Not read in this notebook; presumably a dropout probability consumed by
    # MoeMLP -- TODO confirm against model.py.
    dropout: float = 0.0
# --- Setup: config, reproducible toy input, model, and one forward pass ---
config = ToyMoEConfig()
# print(f"Config: {config}")

# Seed torch's global RNG *before* any random tensor or module is created.
# Without this, the toy input (and any parameters MoeMLP draws from the global
# RNG) differ on every Restart & Run All, so the routing decisions and every
# printout in the cells below are not reproducible.
SEED = 0
torch.manual_seed(SEED)

# Create toy input of shape (batch_size, seq_len, n_embd) directly on the GPU.
batch_size = 4
seq_len = 8
x = torch.randn(batch_size, seq_len, config.n_embd, device='cuda')
# print(f"\nInput shape: {x.shape}")

# Create MoeMLP instance and move it to the same device as the input.
moe_mlp = MoeMLP(config).cuda()
# print(f"\nMoeMLP d_ffn: {moe_mlp.d_ffn}")
# print(f"w1 shape: {moe_mlp.w1.shape}")
# print(f"w2 shape: {moe_mlp.w2.shape}")

# Test forward pass. The module returns three values: the block-sparse
# activation, the router logits, and a dict of intermediate debug tensors.
print("\n=== Forward Pass ===")
block_sparse, router_logits, debug_info = moe_mlp(x)
print("\nForward pass completed!")
print(f"Router logits shape: {router_logits.shape}")
print(f"Block sparse shape: {block_sparse.shape}")

# Extract debug variables; these names are reused by the analysis cells below.
col_indices_ptr = debug_info['col_indices_ptr']
row_indices_ptr = debug_info['row_indices_ptr']
selected_experts = debug_info['selected_experts']
selected_experts_sorted = debug_info['selected_experts_sorted']

# Show a small corner of each output rather than the full tensors.
print(f"\nRouter logits sample:\n{router_logits[:3, :4]}")
print(f"Block sparse sample:\n{block_sparse[:5, :5]}")

In [None]:
# Debug: check which experts the router selected in the forward pass.
print("=== EXPERT USAGE ANALYSIS ===")
print("col_indices_ptr:", col_indices_ptr)

# Flatten the selections captured by the forward pass so that every selected
# expert id is one entry of a 1-D tensor.
selected_experts_flat = selected_experts.reshape(-1)

print("selected_experts_sorted (first 20):", selected_experts_sorted[:20])
print("selected_experts_flat (first 20):", selected_experts_flat[:20])

print("\nExpert distribution:")
# Count all experts with a single bincount and one host transfer instead of a
# masked reduction per expert. Working with Python ints also keeps the
# printout clean ("5" rather than "tensor(5, device='cuda:0')").
expert_counts = torch.bincount(
    selected_experts_flat.long(), minlength=config.num_experts
).tolist()
total_selected = selected_experts_flat.numel()
for i in range(config.num_experts):
    count = expert_counts[i]
    # Guard the percentage against an empty selection.
    share = count / total_selected * 100 if total_selected else 0.0
    print(f"  Expert {i}: {count} blocks ({share:.1f}%)")

print("\nRouter logits sample (first 5 tokens):")
print(router_logits[:5])

print("\nSelected experts (first 10 tokens):")
print(selected_experts[:10])

# Recompute router weights for display only. Softmax and top-k both operate
# over the last (expert) dimension; the original mixed dim=1 with dim=-1.
router_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
router_weights_topk, _ = torch.topk(router_weights, config.num_experts_per_tok, dim=-1)
if config.norm_topk_prob:
    # Renormalise so the kept weights sum to 1 along the last dimension
    # (out-of-place, so re-running the cell cannot compound the division).
    router_weights_topk = router_weights_topk / router_weights_topk.sum(dim=-1, keepdim=True)

print("\nRouter weights (first 5 tokens):")
print(router_weights_topk[:5])


In [None]:
# Debug: analyse which columns of block_sparse actually carry values.
print("=== BLOCK_SPARSE PATTERN ANALYSIS ===")
print(f"block_sparse shape: {block_sparse.shape}")

# A column counts as non-zero if any row has a non-zero entry in it.
nonzero_cols = torch.any(block_sparse != 0, dim=0)
# Reduce once and pull the result to the host so it prints as a plain int and
# is not recomputed for the branch below.
num_nonzero_cols = int(nonzero_cols.sum().item())
print(f"Non-zero columns: {num_nonzero_cols} out of {block_sparse.shape[1]}")

if num_nonzero_cols > 0:
    # torch.nonzero returns indices in ascending order for a 1-D input, so the
    # first and last entries bound the populated column range.
    nonzero_col_idx = torch.nonzero(nonzero_cols).flatten()
    first_nonzero = nonzero_col_idx[0].item()
    last_nonzero = nonzero_col_idx[-1].item()
    print(f"First non-zero col: {first_nonzero}")
    print(f"Last non-zero col: {last_nonzero}")
    print(f"Non-zero range: columns {first_nonzero} to {last_nonzero}")

# Check weight matrix structure
print(f"\nWeight matrix w1 shape: {moe_mlp.w1.shape}")
print(f"d_ffn per expert: {moe_mlp.d_ffn}")
print(f"Total experts: {config.num_experts}")
print(f"Expected w1 width: {moe_mlp.d_ffn * config.num_experts}")

# Show which column range of w1 would belong to each expert if experts are
# laid out contiguously, d_ffn columns each (the assumption this notebook is
# checking).
print("\nExpert weight matrix ranges:")
for i in range(config.num_experts):
    start_col = i * moe_mlp.d_ffn
    end_col = (i + 1) * moe_mlp.d_ffn
    print(f"  Expert {i}: columns {start_col} to {end_col-1}")

# Basic summary statistics, converted to Python scalars for clean output.
print("\nblock_sparse stats:")
print(f"  Non-zero elements: {torch.count_nonzero(block_sparse).item()}")
print(f"  Min value: {block_sparse.min().item():.6f}")
print(f"  Max value: {block_sparse.max().item():.6f}")
print(f"  Mean: {block_sparse.mean().item():.6f}")
print(f"  Std: {block_sparse.std().item():.6f}")


In [None]:
# Debug: compare the column offsets the kernel would compute with where each
# expert's weights are expected to live in w1.
print("=== KERNEL INDEXING DEBUG ===")

# Show the relationship between col_indices_ptr and the actual expert columns in w1
print(f"BLOCK_N: {moe_mlp.BLOCK_N}")
print(f"col_indices_ptr: {col_indices_ptr}")

# Convert the ids to Python ints once. Iterating a tensor yields 0-dim
# tensors, which print as "tensor(..)" and make the "Match" expression below
# evaluate to either a tensor or a bool depending on short-circuiting.
# torch.as_tensor accepts both a tensor and a plain sequence.
block_expert_ids = torch.as_tensor(col_indices_ptr).tolist()

for i, expert_id in enumerate(block_expert_ids):
    expected_col_start = expert_id * moe_mlp.BLOCK_N  # This is what the kernel calculates
    actual_expert_start = expert_id * moe_mlp.d_ffn   # This is where expert weights actually are
    offsets_agree = (
        expected_col_start == actual_expert_start
        and moe_mlp.BLOCK_N == moe_mlp.d_ffn
    )
    print(f"Block {i}: expert_id={expert_id}")
    print(f"  Kernel will use columns {expected_col_start} to {expected_col_start + moe_mlp.BLOCK_N - 1}")
    print(f"  Expert {expert_id} weights are at columns {actual_expert_start} to {actual_expert_start + moe_mlp.d_ffn - 1}")
    print(f"  Match: {offsets_agree}")

# Test with a simple manual calculation to verify the kernel is working
print("\n=== MANUAL VERIFICATION ===")
print(f"If all tokens use expert 0, we should see non-zeros in columns 0 to {moe_mlp.d_ffn-1}")
print(f"If all tokens use expert 1, we should see non-zeros in columns {moe_mlp.d_ffn} to {2*moe_mlp.d_ffn-1}")

# Check the actual pattern: an expert's region is active when any entry in its
# d_ffn-wide column slice of block_sparse is non-zero.
active_regions = []
for i in range(config.num_experts):
    start_col = i * moe_mlp.d_ffn
    end_col = (i + 1) * moe_mlp.d_ffn
    region_has_values = bool(torch.any(block_sparse[:, start_col:end_col] != 0).item())
    active_regions.append(region_has_values)
    print(f"Expert {i} region (cols {start_col}-{end_col-1}): {'ACTIVE' if region_has_values else 'INACTIVE'}")

print(f"\nActive expert regions: {[i for i, active in enumerate(active_regions) if active]}")


In [None]:
# Print the FULL block_sparse tensor (as requested)
print("=== FULL BLOCK_SPARSE TENSOR ===")

# Lift the summarisation threshold and widen lines for this dump only.
# Print options are global kernel state, so restore them in `finally`:
# otherwise an error or interrupt during the (potentially very large) print
# would leave every later cell printing full tensors.
torch.set_printoptions(threshold=float('inf'), linewidth=200)
try:
    print("block_sparse (full tensor):")
    print(block_sparse)
finally:
    # Reset print options to default
    torch.set_printoptions(profile="default")
