In [None]:
import torch
import stk.random
import stk.ops

# Use standard block size of 128
block_size = 128

# Ensure dimensions are multiples of block_size
m = 1024  # 1024 = 8 * 128 ✓
n = 2048  # 2048 = 16 * 128 ✓ 
hidden_size = 512  # 512 = 4 * 128 ✓

# Create the topology
sparsity = 0.5
topo = stk.random.mask(m, n, sparsity, block_size).to('cuda')

# First operation: sdd (dense × dense → sparse)
a = torch.randn(m, hidden_size, device='cuda')
w1 = torch.randn(hidden_size, n, device='cuda')
block_sparse = stk.ops.sdd(a, w1, topo)

# Second operation: dsd (sparse × dense → dense)
w2 = torch.randn(n, hidden_size, device='cuda')
output = stk.ops.dsd(block_sparse, w2) #dsd outputs a tensor! very nice. 

print(f"Output shape: {output.shape}")  # Should be (m, hidden_size)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import stk.random
import stk.ops
from einops import rearrange

# Setup
block_size = 128
m, n = 1024, 2048
hidden_size = 512
num_experts = 8 
num_experts_per_tok = 2
norm_topk_prob = True
batch_size = 4

router = nn.Linear(hidden_size, num_experts, bias=False,device='cuda')

def route_tokens(x_flat):
    """Route tokens to experts and compute weights."""
    router_logits = router(x_flat)
    router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    router_weights, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
    
    if norm_topk_prob:
        router_weights /= router_weights.sum(dim=-1, keepdim=True)
    
    router_weights = router_weights.to(x_flat.dtype)
    return router_weights, selected_experts, router_logits

def sort_by_expert(x_flat, router_weights, selected_experts):
    """Replicate tokens for each expert and sort by expert assignment."""
    x_rep = x_flat.repeat_interleave(num_experts_per_tok, dim=0)
    selected_experts_rep = selected_experts.reshape(-1)
    router_weights_rep = router_weights.reshape(-1, 1)
    
    expert_sort_indices = torch.argsort(selected_experts_rep, stable=True)
    x_sorted = x_rep[expert_sort_indices]
    selected_experts_sorted = selected_experts_rep[expert_sort_indices]
    router_weights_sorted = router_weights_rep[expert_sort_indices]
    
    inv_expert_sort_indices = torch.empty_like(expert_sort_indices)
    
    return x_sorted, selected_experts_sorted, router_weights_sorted, inv_expert_sort_indices

def make_topology(): #???
    pass



# Create topology (this doesn't need gradients)
# topo = stk.random.mask(m, n, sparsity, block_size).to('cuda')

# Create input and weights WITH gradient tracking
x = torch.randn(batch_size, m, hidden_size, device='cuda', requires_grad=True)
w1 = torch.randn(hidden_size, n, device='cuda', requires_grad=True)
w2 = torch.randn(n, hidden_size, device='cuda', requires_grad=True)

x_flat = rearrange(x, 'batch seq hidden -> (batch seq) hidden')

router_weights, selected_experts, router_logits = route_tokens(x_flat)

x_sorted, selected_experts_sorted, router_weights_sorted, inv_indices = sort_by_expert(x_flat, router_weights, selected_experts)

print("selected_experts:\n", selected_experts)
print("selected_experts_sorted:\n", selected_experts_sorted)
# topo = make_topology()#??

# # Forward pass
# sparse_hidden = stk.ops.sdd(x, w1, topo)  # x @ w1 with sparse output
# output = stk.ops.dsd(sparse_hidden, w2)   # sparse @ w2 with dense output

# # Create a simple loss
# target = torch.randn_like(output)
# loss = torch.nn.functional.mse_loss(output, target)

# print(f"Loss: {loss.item():.4f}")

# # Backward pass
# loss.backward()

# # Check gradients
# print(f"\nGradients computed:")
# print(f"x.grad shape: {x.grad.shape}, norm: {x.grad.norm().item():.4f}")
# print(f"w1.grad shape: {w1.grad.shape}, norm: {w1.grad.norm().item():.4f}")
# print(f"w2.grad shape: {w2.grad.shape}, norm: {w2.grad.norm().item():.4f}")

selected_experts:
 tensor([[3, 5],
        [2, 5],
        [7, 2],
        ...,
        [2, 3],
        [0, 3],
        [4, 2]], device='cuda:0')
selected_experts_sorted:
 tensor([0, 0, 0,  ..., 7, 7, 7], device='cuda:0')
router_weights_sorted:
 tensor([[0.5580],
        [0.6590],
        [0.6197],
        ...,
        [0.3672],
        [0.5748],
        [0.4858]], device='cuda:0', grad_fn=<IndexBackward0>)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import stk.ops
import stk.matrix
from einops import rearrange

# Setup
block_size = 128
m, n = 1024, 2048
hidden_size = 512
num_experts = 8 
num_experts_per_tok = 2
expert_capacity = hidden_size  # Each expert has hidden_size dimension
norm_topk_prob = True
batch_size = 4

router = nn.Linear(hidden_size, num_experts, bias=False, device='cuda')

def route_tokens(x_flat):
    """Route tokens to experts and compute weights."""
    router_logits = router(x_flat)
    router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    router_weights, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
    
    if norm_topk_prob:
        router_weights /= router_weights.sum(dim=-1, keepdim=True)
    
    router_weights = router_weights.to(x_flat.dtype)
    return router_weights, selected_experts, router_logits

def sort_by_expert(x_flat, router_weights, selected_experts):
    """Replicate tokens for each expert and sort by expert assignment."""
    x_rep = x_flat.repeat_interleave(num_experts_per_tok, dim=0)
    selected_experts_rep = selected_experts.reshape(-1)
    router_weights_rep = router_weights.reshape(-1, 1)
    
    expert_sort_indices = torch.argsort(selected_experts_rep, stable=True)
    x_sorted = x_rep[expert_sort_indices]
    selected_experts_sorted = selected_experts_rep[expert_sort_indices]
    router_weights_sorted = router_weights_rep[expert_sort_indices]
    
    # Compute inverse indices for unsort
    inv_expert_sort_indices = torch.empty_like(expert_sort_indices)
    inv_expert_sort_indices[expert_sort_indices] = torch.arange(
        len(expert_sort_indices), device=expert_sort_indices.device
    )
    
    return x_sorted, selected_experts_sorted, router_weights_sorted, inv_expert_sort_indices


def pad_to_blocks( x_sorted, selected_experts_sorted):
    """Pad each expert's tokens to multiples of block_size and track unpadding indices."""
    device = x_sorted.device
    num_tokens = x_sorted.shape[0]
    token_dim = x_sorted.shape[-1]

    # Use self.num_experts and self.block_size directly
    min_tokens_or_experts = min(num_tokens, num_experts)
    max_blocks = min_tokens_or_experts + (num_tokens - min_tokens_or_experts) // block_size
    capacity_tokens = max_blocks * block_size

    # Per-expert counts via scatter_add (compile-safe; avoids bincount)
    tokens_per_expert = torch.zeros(num_experts, dtype=torch.long, device=device)
    ones = torch.ones_like(selected_experts_sorted, dtype=torch.long)
    tokens_per_expert.scatter_add_(0, selected_experts_sorted, ones)

    # Round each expert up to a multiple of block_size
    tokens_per_expert_padded = ((tokens_per_expert + block_size - 1) // block_size) * block_size

    # Exclusive-prefix sums (orig vs padded) for placement
    offset_original = F.pad(tokens_per_expert.cumsum(0), (1, 0))
    offset_padded  = F.pad(tokens_per_expert_padded.cumsum(0), (1, 0))

    # Allocate fixed capacity once; the tail is never indexed
    x_padded = x_sorted.new_zeros((capacity_tokens, token_dim))

    
    # Map each sorted token to its padded position
    token_idx = torch.arange(num_tokens, device=x_sorted.device)
    print(token_idx)
    idx_within_expert = token_idx - offset_original[selected_experts_sorted]
    print(tokens_per_expert_padded)
    unpad_indices = idx_within_expert + offset_padded[selected_experts_sorted]

    # Scatter the actual tokens into their padded slots
    x_padded[unpad_indices] = x_sorted

    # Return exactly what you wanted
    return x_padded, tokens_per_expert_padded, unpad_indices


def create_sparse_indices(tokens_per_expert_padded):
    device = tokens_per_expert_padded.device

    # Compute blocks per expert (vectorized)
    num_token_blocks_per_expert = tokens_per_expert_padded // block_size
    blocks_per_expert = num_token_blocks_per_expert * _num_ffn_blocks

    # Convert to Python int ONCE to avoid dynamic shapes in torch.compile
    # This single .item() call prevents thousands of shape guard checks
    total_blocks = int(blocks_per_expert.sum())

    indices = torch.arange(total_blocks, device=device, dtype=torch.long)

    # Single cumsum for expert assignment
    cumsum = blocks_per_expert.cumsum(0)
    expert_ids = torch.searchsorted(cumsum, indices, right=True).clamp(max=self.num_experts - 1)

    # Fuse within-expert index computation
    cumsum_padded = F.pad(cumsum[:-1], (1, 0))
    within_expert_idx = indices - cumsum_padded[expert_ids]

    # Compute row indices (combine operations)
    token_block_cumsum = num_token_blocks_per_expert.cumsum(0)
    token_block_offset = F.pad(token_block_cumsum[:-1], (1, 0))

    # Fast modulo and division using bit operations when possible
    # Since _num_ffn_blocks is often a power of 2, we can optimize
    within_expert_block = within_expert_idx // self._num_ffn_blocks
    within_expert_ffn = within_expert_idx % self._num_ffn_blocks

    row_indices = token_block_offset[expert_ids] + within_expert_block
    weight_col_indices = expert_ids * self._num_ffn_blocks + within_expert_ffn
    output_col_indices = within_expert_ffn

    return row_indices.int(), weight_col_indices.int(), output_col_indices.int()

# Your existing code
x = torch.randn(batch_size, m, hidden_size, device='cuda', requires_grad=True)
x_flat = rearrange(x, 'batch seq hidden -> (batch seq) hidden')

router_weights, selected_experts, router_logits = route_tokens(x_flat)
x_sorted, selected_experts_sorted, router_weights_sorted, inv_indices = sort_by_expert(
    x_flat, router_weights, selected_experts
)

x_padded, tokens_per_expert_padded, unpad_indices = pad_to_blocks(x_sorted, selected_experts_sorted)


# Create the topology
# topology, padded_tokens_per_expert = make_topology(


# print(f"Topology shape: {topology.size()}")
print(f"Tokens per expert: {torch.bincount(selected_experts_sorted, minlength=num_experts)}")
print(f"Padded blocks per expert: {tokens_per_expert_padded//128}")

# # Now you can use the topology for sparse matmul
# # Create expert weights (all experts stacked)
# all_expert_weights = torch.randn(
#     num_experts * expert_capacity, 
#     n, 
#     device='cuda',
#     requires_grad=True
# )

# # Pad x_sorted to match topology expectations
# total_padded = padded_tokens_per_expert.sum().item()
# if x_sorted.shape[0] < total_padded:
#     padding = total_padded - x_sorted.shape[0]
#     x_padded = torch.cat([
#         x_sorted,
#         torch.zeros(padding, hidden_size, device='cuda')
#     ])
# else:
#     x_padded = x_sorted
# print(x_padded.shape)
# print(all_expert_weights.shape)
# # Sparse matrix multiply with expert weights
# w1 = torch.randn(hidden_size, n, device='cuda', requires_grad=True)
# w2 = torch.randn(n, hidden_size, device='cuda', requires_grad=True)

# result_sparse = stk.ops.sdd(x_padded, w1, topology)
# result = stk.ops.to_dense(result_sparse)

# print(f"Result shape: {result.shape}")

tensor([   0,    1,    2,  ..., 8189, 8190, 8191], device='cuda:0')
tensor([1152, 1024, 1024, 1152, 1024, 1024, 1024, 1152], device='cuda:0')
Tokens per expert: tensor([1069, 1011,  995, 1034, 1021, 1021, 1000, 1041], device='cuda:0')
Padded blocks per expert: tensor([9, 8, 8, 9, 8, 8, 8, 9], device='cuda:0')
