In [None]:
import torch
import stk.random
import stk.ops

# Demo: a block-sparse two-matmul pipeline with stk -- sdd (dense x dense -> sparse)
# followed by dsd (sparse x dense -> dense). All tensors are placed on CUDA.
assert torch.cuda.is_available(), "this notebook requires a CUDA GPU"
device = 'cuda'
# Seed torch so the inputs/weights are reproducible across Restart & Run All.
# NOTE(review): stk.random.mask may draw from NumPy's RNG instead of torch's; if so,
# seed NumPy too for a reproducible topology -- TODO confirm.
torch.manual_seed(0)

# Use standard block size of 128
block_size = 128

# Problem sizes. m x n is the block-sparse topology, so m and n must be multiples of
# block_size; hidden_size is kept block-aligned as well. Checked below instead of by comment.
m = 1024           # 8 * 128
n = 2048           # 16 * 128
hidden_size = 512  # 4 * 128
for dim_name, dim in (("m", m), ("n", n), ("hidden_size", hidden_size)):
    assert dim % block_size == 0, f"{dim_name}={dim} is not a multiple of block_size={block_size}"

# Random block-sparse topology over an (m, n) grid of block_size x block_size blocks;
# with sparsity=0.5 about half of the blocks are zero.
sparsity = 0.5
topo = stk.random.mask(m, n, sparsity, block_size).to(device)

# First operation: sdd (dense x dense -> sparse); the result follows topo's block pattern.
a = torch.randn(m, hidden_size, device=device)
w1 = torch.randn(hidden_size, n, device=device)
block_sparse = stk.ops.sdd(a, w1, topo)

# Second operation: dsd (sparse x dense -> dense); the result is a regular torch.Tensor.
w2 = torch.randn(n, hidden_size, device=device)
output = stk.ops.dsd(block_sparse, w2)

assert output.shape == (m, hidden_size), f"unexpected output shape {tuple(output.shape)}"
print(f"Output shape: {output.shape}")

In [2]:
import torch
import stk.random
import stk.ops

# Same sdd -> dsd pipeline as the forward demo, now with autograd: verify that gradients
# flow back through both stk ops to the dense input and to both weight matrices.
assert torch.cuda.is_available(), "this notebook requires a CUDA GPU"
device = 'cuda'
# Seed torch so inputs, weights and target (and hence the printed numbers) are reproducible.
# NOTE(review): stk.random.mask may draw from NumPy's RNG instead of torch's; if so,
# seed NumPy too for a reproducible topology -- TODO confirm.
torch.manual_seed(0)

# Setup
block_size = 128
m, n = 1024, 2048
hidden_size = 512
sparsity = 0.5

# Create topology (this doesn't need gradients)
topo = stk.random.mask(m, n, sparsity, block_size).to(device)

# Create input and weights WITH gradient tracking (leaf tensors, so .grad is populated).
# The weights are unit-variance randn with no 1/sqrt(fan_in) scaling, so output entries have
# variance on the order of hidden_size * n * (1 - sparsity) (~5e5 here) -- the large loss is expected.
x = torch.randn(m, hidden_size, device=device, requires_grad=True)
w1 = torch.randn(hidden_size, n, device=device, requires_grad=True)
w2 = torch.randn(n, hidden_size, device=device, requires_grad=True)

# Forward pass
sparse_hidden = stk.ops.sdd(x, w1, topo)  # x @ w1 with sparse output (topo's block pattern)
output = stk.ops.dsd(sparse_hidden, w2)   # sparse @ w2 with dense output
assert output.shape == (m, hidden_size), f"unexpected output shape {tuple(output.shape)}"

# Create a simple loss against a random target
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)

print(f"Loss: {loss.item():.4f}")

# Backward pass
loss.backward()

# Check gradients: every leaf must receive a finite gradient of its own shape.
print("\nGradients computed:")
for param_name, param in (("x", x), ("w1", w1), ("w2", w2)):
    assert param.grad is not None, f"{param_name}.grad was not populated"
    assert param.grad.shape == param.shape, f"{param_name}.grad has shape {tuple(param.grad.shape)}"
    assert torch.isfinite(param.grad).all(), f"{param_name}.grad contains non-finite values"
    print(f"{param_name}.grad shape: {param.grad.shape}, norm: {param.grad.norm().item():.4f}")

Loss: 518487.7812

Gradients computed:
x.grad shape: torch.Size([1024, 512]), norm: 2381.9978
w1.grad shape: torch.Size([512, 2048]), norm: 2482.5469
w2.grad shape: torch.Size([2048, 512]), norm: 2263.8528
