In [4]:
from typing import Optional

import torch
import torch.nn.functional as F

# Prefer the second GPU ("cuda:1") when the machine actually has one; otherwise
# fall back to the first GPU, and to the CPU when CUDA is unavailable. Asking for
# "cuda:1" on a single-GPU machine would fail with an invalid device ordinal.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
n_embd = 8  # feature width of each token row
# Expert id per token, already grouped in ascending order:
# expert 0 receives 5 tokens, expert 1 receives 2, expert 2 receives 1.
selected_experts_sorted = torch.tensor(
    [0, 0, 0, 0, 0, 1, 1, 2], dtype=torch.long, device=device
)

def _pad_to_blocks(x_sorted: torch.Tensor,
                   selected_experts_sorted: torch.Tensor,
                   block_size: int = 4):
    """
    Standalone version. Assumes selected_experts_sorted is grouped (ascending by expert id).
    Returns: x_padded, tokens_per_expert_padded, unpad_indices
    """
    device = x_sorted.device
    n_embd = x_sorted.size(-1)
    num_experts = int(selected_experts_sorted.max().item()) + 1 if selected_experts_sorted.numel() else 0
    tokens_per_expert = torch.zeros(num_experts, dtype=torch.long, device=device)
    ones = torch.ones_like(selected_experts_sorted, dtype=torch.long)

    tokens_per_expert.scatter_add_(0, selected_experts_sorted, ones)
    tokens_per_expert_padded = ((tokens_per_expert + block_size - 1) // block_size) * block_size

    cumsum_original = F.pad(tokens_per_expert.cumsum(0), (1, 0))
    cumsum_padded   = F.pad(tokens_per_expert_padded.cumsum(0), (1, 0))
    total_padded_tokens = cumsum_padded[-1]

    x_padded = torch.zeros((total_padded_tokens, n_embd), dtype=x_sorted.dtype, device=device)

    original_positions = torch.arange(len(x_sorted),device=x_sorted.device)
    padded_positions = (original_positions - cumsum_original[selected_experts_sorted]) + cumsum_padded[selected_experts_sorted]
    x_padded[padded_positions] = x_sorted

    unpad_indices = padded_positions
    return x_padded, tokens_per_expert_padded, unpad_indices

# Seeded generator so the demo input is identical on every Restart & Run All,
# without touching the global RNG state.
rng = torch.Generator(device=device).manual_seed(0)
x_sorted = torch.randn(
    selected_experts_sorted.numel(), n_embd, generator=rng, device=device
)

# example call
x_padded, tokens_per_expert_padded, unpad_indices = _pad_to_blocks(
    x_sorted, selected_experts_sorted
)

# Sanity checks: the padded buffer has one row per padded token slot, and
# gathering with unpad_indices recovers the original rows exactly.
assert x_padded.shape == (int(tokens_per_expert_padded.sum()), n_embd)
assert torch.equal(x_padded[unpad_indices], x_sorted)


In [19]:
import torch
import torch.nn.functional as F

# Prefer the second GPU ("cuda:1") when the machine actually has one; otherwise
# fall back to the first GPU, and to the CPU when CUDA is unavailable. Asking for
# "cuda:1" on a single-GPU machine would fail with an invalid device ordinal.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

# Setup parameters
num_experts = 3
block_size = 4
d_ffn = 12  # FFN dimension (will need 3 blocks of size 4)

# Stand-in for the padded per-expert token counts that _pad_to_blocks returns.
# Here expert 0 got 8 tokens, expert 1 got 8 and expert 2 got 4, so every
# count is already a multiple of block_size.
tokens_per_expert_padded = torch.tensor([8, 8, 4], dtype=torch.long, device=device)

def _create_sparse_indices(tokens_per_expert_padded, num_experts, block_size, d_ffn):
    """Create row and column indices for sparse blocks."""
    device = tokens_per_expert_padded.device
    num_token_blocks_per_expert = tokens_per_expert_padded // block_size
    num_ffn_blocks = (d_ffn + block_size - 1) // block_size
    
    blocks_per_expert = num_token_blocks_per_expert * num_ffn_blocks
    
    expert_ids = torch.repeat_interleave(
        torch.arange(num_experts, device=device),
        blocks_per_expert
    )
    within_expert_block_idx = torch.arange(len(expert_ids),device=device) - F.pad(blocks_per_expert.cumsum(0)[:-1], (1,0))[expert_ids]
    

    token_block_offset = F.pad(num_token_blocks_per_expert.cumsum(0)[:-1], (1, 0))
    row_indices = token_block_offset[expert_ids] + (within_expert_block_idx // num_ffn_blocks)
    weight_col_indices = expert_ids * num_ffn_blocks + (within_expert_block_idx % num_ffn_blocks)
    output_col_indices = within_expert_block_idx % num_ffn_blocks
    
    return row_indices.int(), weight_col_indices.int(), output_col_indices.int()

# Build the block indices for the example configuration above: the padded counts
# [8, 8, 4] give [2, 2, 1] token blocks per expert, and d_ffn = 12 gives 3 FFN
# blocks, so 15 blocks are expected in total.
row_idx, weight_col_idx, output_col_idx = _create_sparse_indices(
    tokens_per_expert_padded, num_experts, block_size, d_ffn
)

# Print the three index tensors; the leading "\n" puts a blank line before the first one.
print(f"\nrow_indices: {row_idx}")
print(f"weight_col_indices: {weight_col_idx}")
print(f"output_col_indices: {output_col_idx}")


row_indices: tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4], device='cuda:1',
       dtype=torch.int32)
weight_col_indices: tensor([0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8], device='cuda:1',
       dtype=torch.int32)
output_col_indices: tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], device='cuda:1',
       dtype=torch.int32)
