In [2]:
import torch
import torch.nn.functional as F

# Prefer the second GPU (the original target) but only when it really exists:
# torch.cuda.is_available() is also True on single-GPU machines, where
# "cuda:1" is an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

n_embd = 8  # width of the toy activations created for the example call

# Expert id of every routed token, already grouped by expert:
# 5 tokens for expert 0, 2 for expert 1, 1 for expert 2.
selected_experts_sorted = torch.tensor(
    [0, 0, 0, 0, 0, 1, 1, 2], dtype=torch.long, device=device
)

def _pad_to_blocks(x_sorted: torch.Tensor,
                   selected_experts_sorted: torch.Tensor,
                   block_size: int = 4,
                   num_experts: "int | None" = None):
    """
    Pad every expert's run of tokens up to a multiple of ``block_size``.

    Standalone version. Assumes ``selected_experts_sorted`` is grouped
    (ascending by expert id) and that the rows of ``x_sorted`` follow the
    same order.

    Args:
        x_sorted: ``(num_tokens, n_embd)`` activations, grouped by expert.
        selected_experts_sorted: ``(num_tokens,)`` long tensor holding the
            expert id of each row of ``x_sorted``.
        block_size: every expert's token count is rounded up to a multiple
            of this value.
        num_experts: total number of experts. When ``None`` (default) it is
            inferred as ``max(selected_experts_sorted) + 1``; that inference
            cannot see trailing experts that received no tokens, so pass the
            real count whenever the result must have one entry per expert.

    Returns:
        x_padded: ``(sum(tokens_per_expert_padded), n_embd)`` tensor; each
            expert's tokens sit at the start of its padded run, the
            remaining rows are zeros.
        tokens_per_expert_padded: ``(num_experts,)`` long tensor of padded
            per-expert token counts.
        unpad_indices: ``(num_tokens,)`` long tensor such that
            ``x_padded[unpad_indices]`` equals ``x_sorted``.

    Raises:
        ValueError: if ``num_experts`` is given but smaller than the largest
            expert id present plus one.
    """
    device = x_sorted.device
    n_embd = x_sorted.size(-1)

    # Smallest expert count that can hold every id actually present.
    if selected_experts_sorted.numel():
        min_num_experts = int(selected_experts_sorted.max().item()) + 1
    else:
        min_num_experts = 0

    if num_experts is None:
        num_experts = min_num_experts
    elif num_experts < min_num_experts:
        raise ValueError(
            f"num_experts={num_experts} is too small: "
            f"expert id {min_num_experts - 1} is present in selected_experts_sorted"
        )

    # Histogram of tokens per expert, then round each count up to block_size.
    tokens_per_expert = torch.zeros(num_experts, dtype=torch.long, device=device)
    ones = torch.ones_like(selected_experts_sorted, dtype=torch.long)
    tokens_per_expert.scatter_add_(0, selected_experts_sorted, ones)
    tokens_per_expert_padded = ((tokens_per_expert + block_size - 1) // block_size) * block_size

    # Start offset of every expert's run, before and after padding
    # (leading 0 prepended so index e gives the start of expert e).
    cumsum_original = F.pad(tokens_per_expert.cumsum(0), (1, 0))
    cumsum_padded = F.pad(tokens_per_expert_padded.cumsum(0), (1, 0))
    total_padded_tokens = int(cumsum_padded[-1].item())

    x_padded = torch.zeros((total_padded_tokens, n_embd), dtype=x_sorted.dtype, device=device)

    # A token keeps its rank inside its expert's run; only the run start moves.
    original_positions = torch.arange(len(x_sorted), device=device)
    rank_in_expert = original_positions - cumsum_original[selected_experts_sorted]
    padded_positions = rank_in_expert + cumsum_padded[selected_experts_sorted]
    x_padded[padded_positions] = x_sorted

    unpad_indices = padded_positions
    return x_padded, tokens_per_expert_padded, unpad_indices

# Toy activations: one row per routed token, in the same expert-grouped order
# as selected_experts_sorted.
x_sorted = torch.randn(selected_experts_sorted.numel(), n_embd, device=device)

# example call (uses the default block_size of 4)
x_padded, tokens_per_expert_padded, unpad_indices = _pad_to_blocks(
    x_sorted, selected_experts_sorted
)

# Sanity checks on the example: the padded buffer has exactly one row per
# padded slot, and gathering with unpad_indices recovers the original tokens.
assert x_padded.size(0) == int(tokens_per_expert_padded.sum())
assert torch.equal(x_padded[unpad_indices], x_sorted)


In [3]:
import torch
import torch.nn.functional as F

# Prefer the second GPU (the original target) but only when it really exists:
# torch.cuda.is_available() is also True on single-GPU machines, where
# "cuda:1" is an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Setup parameters
num_experts = 3
block_size = 4
d_ffn = 12  # FFN dimension (will need 3 blocks of size 4)

# Stand-in for the padded per-expert token counts a padding step would return.
# Let's say expert 0 got 8 tokens, expert 1 got 8, expert 2 got 4
# (all multiples of block_size).
tokens_per_expert_padded = torch.tensor([8, 8, 4], dtype=torch.long, device=device)

def _create_sparse_indices(tokens_per_expert_padded, num_experts, block_size, d_ffn):
    """Enumerate the (token-block, FFN-block) tiles owned by each expert.

    Args:
        tokens_per_expert_padded: 1-D integer tensor with one padded token
            count per expert (each count is floor-divided by ``block_size``).
        num_experts: number of experts; must match the length of
            ``tokens_per_expert_padded``.
        block_size: edge length of a tile along both the token and FFN axes.
        d_ffn: FFN width; rounded up to whole tiles.

    Returns:
        Three int32 tensors with one entry per tile, ordered expert by expert
        and, within an expert, token-block-major:
        ``row_indices`` (global token-block index),
        ``weight_col_indices`` (FFN-block index offset by the expert's slot),
        ``output_col_indices`` (FFN-block index within the expert).
    """
    dev = tokens_per_expert_padded.device

    def run_starts(counts):
        # Exclusive prefix sum: where each expert's run begins.
        return F.pad(counts.cumsum(0)[:-1], (1, 0))

    ffn_tiles = (d_ffn + block_size - 1) // block_size
    token_tiles = tokens_per_expert_padded // block_size
    tiles_per_expert = token_tiles * ffn_tiles

    # Expert that owns each tile, e.g. [0, 0, ..., 1, 1, ..., 2, ...].
    owner = torch.arange(num_experts, device=dev).repeat_interleave(tiles_per_expert)

    # Position of each tile inside its owner's run, split into (row, col).
    flat = torch.arange(owner.numel(), device=dev)
    local = flat - run_starts(tiles_per_expert)[owner]
    tile_row = local // ffn_tiles
    tile_col = local % ffn_tiles

    row_indices = run_starts(token_tiles)[owner] + tile_row
    weight_col_indices = owner * ffn_tiles + tile_col
    output_col_indices = tile_col

    return row_indices.int(), weight_col_indices.int(), output_col_indices.int()

# Run the helper on the toy setup of this cell and show the three index tensors.
row_idx, weight_col_idx, output_col_idx = _create_sparse_indices(
    tokens_per_expert_padded=tokens_per_expert_padded,
    num_experts=num_experts,
    block_size=block_size,
    d_ffn=d_ffn,
)

# One print call, newline-separated: same text as three consecutive prints.
print(
    f"\nrow_indices: {row_idx}",
    f"weight_col_indices: {weight_col_idx}",
    f"output_col_indices: {output_col_idx}",
    sep="\n",
)


row_indices: tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4], device='cuda:1',
       dtype=torch.int32)
weight_col_indices: tensor([0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8], device='cuda:1',
       dtype=torch.int32)
output_col_indices: tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], device='cuda:1',
       dtype=torch.int32)


In [4]:
"""PyTorch implementation of what we're trying to do in dsd"""

import torch
import torch.nn as nn
import torch.nn.functional as F

# Fall back to the CPU so this reference also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

num_tokens = 512
hidden_size = 768
d_ffn = 1536
num_experts = 8
num_active_experts = 2
block_size = 16
num_tokens_per_expert = num_tokens // num_experts  # not realistic! but this is a toy scenario

x = torch.randn(num_tokens, d_ffn, device=device, dtype=torch.bfloat16)  # this is x *after* w1

# w2 of all experts stacked along dim 0:
# rows [e * d_ffn, (e + 1) * d_ffn) belong to expert e.
w2 = torch.randn(d_ffn * num_experts, hidden_size, device=device, dtype=torch.bfloat16)

# Now, let's say that all the tokens get distributed to each expert evenly.
# Then every expert gets num_tokens / num_experts = num_tokens_per_expert (64)
# consecutive tokens. Let's just do it in a for loop, but we know we'd be doing
# it in parallel in triton.

# Zero-initialised output; in this toy loop every row is assigned exactly once.
out_tensor = torch.zeros(num_tokens, hidden_size, device=device, dtype=torch.bfloat16)
for expert_idx in range(num_experts):
    x_bottom_index = expert_idx * num_tokens_per_expert
    x_top_index = (expert_idx + 1) * num_tokens_per_expert
    w2_bottom_index = expert_idx * d_ffn
    w2_top_index = (expert_idx + 1) * d_ffn
    x_expert = x[x_bottom_index:x_top_index]
    w2_expert = w2[w2_bottom_index:w2_top_index]

    # Accumulate the K-tiles in float32 and round to bfloat16 only once when
    # storing. Summing d_ffn / block_size = 96 partial products directly in
    # bfloat16 rounds after every tile and drifts away from the true product;
    # a float32 accumulator is also what a Triton tl.dot loop normally uses.
    output_block = torch.zeros(num_tokens_per_expert, hidden_size, device=device, dtype=torch.float32)
    # With these sizes d_ffn (1536) is a multiple of block_size (16), so the
    # clamp below never triggers; it is kept so the loop stays correct for a
    # ragged last tile, which is where a Triton kernel would need a mask.
    for k in range(0, d_ffn, block_size):
        k_end = min(k + block_size, d_ffn)

        x_tile = x_expert[:, k:k_end]
        w2_tile = w2_expert[k:k_end, :]

        output_block += x_tile.float() @ w2_tile.float()

    out_tensor[x_bottom_index:x_top_index] = output_block.to(out_tensor.dtype)


In [1]:
import os
import torch
from model import MoeMLP, GPTConfig

# Seed the global RNG so repeated runs of this cell are comparable.
torch.manual_seed(0)

# Small configuration for a single forward pass through the MoE MLP.
config = GPTConfig()
config.n_embd = 768
config.num_experts = 8
config.num_experts_per_tok = 2
config.n_ctx = 8
config.block_k = 16
config.block_size = 16

model = MoeMLP(config).cuda().bfloat16()

# One batch of n_ctx tokens, each n_embd wide (1 x 8 x 768).
x = torch.randn(1, config.n_ctx, config.n_embd, device='cuda', dtype=torch.bfloat16)

out = model(x)

tensor([16, 16, 16, 16, 16, 16,  0, 16], device='cuda:0')


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

block_size = 16
n_embd = 768
num_experts = 8
num_experts_per_tok = 2
seq_len = n_ctx = 8
d_ffn = (n_embd * 4) // num_experts_per_tok

# Padded per-expert token counts (expert 6 received no tokens). Falls back to
# the CPU when no GPU is present so the cell still runs.
tokens_per_expert_padded = torch.tensor(
    [16, 16, 16, 16, 16, 16, 0, 16],
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

import time

def create_sparse_indices(tokens_per_expert_padded):
    """Build row/column block indices for the block-sparse MoE layout.

    Reads the module-level settings ``block_size``, ``d_ffn``, ``seq_len``,
    ``num_experts_per_tok`` and ``num_experts``.

    Args:
        tokens_per_expert_padded: 1-D integer tensor with one padded token
            count per expert.

    Returns:
        ``(row_indices, weight_col_indices, output_col_indices)`` as int32
        tensors. Their length is the larger of the actual block count and a
        bound computed from the module-level settings; entries at or beyond
        the actual block count are set to 0.

    Side effects:
        Prints the actual total number of blocks.
    """
    dev = tokens_per_expert_padded.device

    # Blocks along the token axis (per expert) and along the FFN axis (shared).
    token_blocks = tokens_per_expert_padded // block_size
    ffn_blocks = (d_ffn + block_size - 1) // block_size
    blocks_per_expert = token_blocks * ffn_blocks

    # Number of blocks really in use for this input.
    total_blocks = blocks_per_expert.sum()
    print(total_blocks)

    # Allocation size: never smaller than the bound derived from the settings.
    # NOTE(review): the result is still a tensor computed from the input, so
    # the arange length below is data-dependent - confirm that is acceptable
    # wherever a fixed shape is required.
    static_token_blocks = (seq_len * num_experts_per_tok) // block_size
    static_bound = num_experts * static_token_blocks * ffn_blocks
    max_blocks = torch.clamp(total_blocks, min=static_bound)

    # Flat block ids [0, max_blocks).
    flat = torch.arange(max_blocks, device=dev)

    # Owner expert of every flat id: first expert whose inclusive running
    # total exceeds the id (experts with zero blocks are skipped this way);
    # ids past the end are clamped onto the last expert.
    running_total = blocks_per_expert.cumsum(0)
    owner = torch.searchsorted(running_total, flat, right=True)
    owner = owner.clamp(max=num_experts - 1)

    # Offset of each block inside its owner's run.
    run_start = F.pad(running_total[:-1], (1, 0))
    local = flat - run_start[owner]

    # Raw index tensors.
    token_block_start = F.pad(token_blocks.cumsum(0)[:-1], (1, 0))
    rows = token_block_start[owner] + (local // ffn_blocks)
    weight_cols = owner * ffn_blocks + (local % ffn_blocks)
    output_cols = local % ffn_blocks

    # Zero the slots that do not correspond to a real block.
    unused = ~(flat < total_blocks)
    rows = rows.masked_fill(unused, 0)
    weight_cols = weight_cols.masked_fill(unused, 0)
    output_cols = output_cols.masked_fill(unused, 0)

    # Kernels typically want int32 indices
    return rows.int(), weight_cols.int(), output_cols.int()

# Build the index tensors for the padded counts defined earlier in this cell.
# `out` is the (row_indices, weight_col_indices, output_col_indices) tuple of
# int32 tensors; the call also prints the total number of blocks in use.
out = create_sparse_indices(tokens_per_expert_padded)

tensor(672, device='cuda:0')
