In [1]:
import triton
import triton.language as tl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [2]:
from collections import namedtuple
from dataclasses import dataclass

@dataclass
class MaskedWeight():
    mask: torch.Tensor
    weight: torch.Tensor

def get_const_fan_w_and_M(num_neurons: int = 10, dense_fan_in: int = 100, sparsity: float = 0.1) -> MaskedWeight:
    m = _generate_mask(num_neurons, dense_fan_in, sparsity)
    w = torch.rand(size=(num_neurons, dense_fan_in), dtype=torch.float32)
    w = w * m
    return MaskedWeight(m, w)

def _generate_mask(num_neurons, dense_fan_in, sparsity):
    m = torch.zeros(size=(num_neurons, dense_fan_in), dtype=torch.bool)
    num_ones = (m.numel() * sparsity).__floor__()
    sparse_fan_in = num_ones // num_neurons
    for neuron_idx, neuron in enumerate(m):
        ones_idx = torch.randperm(len(neuron))
        m[neuron_idx][ones_idx[:sparse_fan_in]] = True
    return m

def get_input(num_features: int = 10, num_batches: int = 1) -> torch.Tensor:
    return torch.rand(size=(num_batches, num_features))
    
    
mw = get_const_fan_w_and_M(3,10,0.33)
input = get_input(10,1)
(input@mw.weight.T).shape

torch.Size([1, 3])

In [6]:
input @ nn.Linear(in_features=10, out_features=3).weight.data.T

tensor([[-0.3149, -0.2630, -0.0904]])

In [13]:
import torch

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
                 # NOTE: `constexpr` so it can be used as a shape value.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # tl.device_print(f"pid", pid)
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
    
def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keywords arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output


torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(
    f'The maximum difference between torch and triton is '
    f'{torch.max(torch.abs(output_torch - output_triton))}'
)

tensor([1.3713, 1.3076, 0.4940,  ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
The maximum difference between torch and triton is 0.0


In [18]:
## Row major ordering
def main(
    num_features = 10, num_neurons = 10, sparsity = 0.1, num_batches=1
):
    mw = get_const_fan_w_and_M(num_neurons, num_features, sparsity)
    x = get_input(num_features, num_batches)
    mw.weight = mw.weight.to("cuda:0")
    x = x.to("cuda:0")
    mw.mask = mw.mask.to("cuda:0")
    try:
        x@mw.weight.T
    except Exception as e:
        raise e 
    _META = dict(BLOCK_SIZE=mw.weight.shape[1])
    nm_matmul(_META, W_T=mw.weight.T, x=x, mask_T=mw.mask.T)
    

def nm_matmul(_META: dict[str, any], W_T: torch.Tensor, x: torch.Tensor, mask_T :torch.Tensor) -> torch.Tensor:
    assert W_T.is_cuda and x.is_cuda and mask_T.is_cuda
    M,K = x.shape
    K,N = W_T.shape
    y: torch.Tensor = torch.empty((M,N), dtype=x.dtype, device=x.device)
    debug_m = torch.empty_like(mask_T, device=mask_T.device)
    grid = (M,) # 1D grid
    grid = lambda meta: (M,)
    abc = _nm_matmul[grid](  # We index each kernel by number of neurons in W 
        W_T, x, y, mask_T, debug_m,  # pointers to tensors
        M, N, K,  # Shape of tensors ( (MxK * KxN = MxN) )
        stride_wk=W_T.stride(0), stride_wn=W_T.stride(1),
        stride_xm=x.stride(0), stride_xk=x.stride(1),
        stride_ym=y.stride(0), stride_yn=y.stride(1),
        # BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_K,
        # BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=None,
    )
    print(mask_T)
    print(debug_m)    


@triton.jit
def _nm_matmul(
    w_ptr,  # Shape (in_features * num_neurons)
    x_ptr,  # Shape (num_batches, in_features)
    y_ptr,  # Shape (num_batches * num_neurons)
    mask_ptr,  # Shape (in_features * num_neurons)
    debug_ptr,
    M,
    N,
    K,
    stride_wk, stride_wn,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    # BLOCK_SIZE: tl.constexpr
    # BLOCK_SIZE_M,
    # BLOCK_SIZE_N,
    # BLOCK_SIZE_K,
    # GROUP_SIZE_M,
):
    # Multiplication in form X@W.T
    pid = tl.program_id(axis=0)  # 1D launch grid
    tl.device_print("pid:", pid)
    row_start = mask_ptr * tl.constexpr(10)
    mask = tl.load(mask_ptr)
    tl.store(debug_ptr, mask)
    # tl.static_print(mask)
    # non_zero_w = tl.load(w_ptr, mask=mask[row_idx])
    # tl.device_print(non_zero_w)
    return
    # row_start_ptr = w_ptr+row_idx*w_n_cols,

    # block_start = pid * 
    # mask = tl.load(m_ptr)
    # w_sp = tl.load(w_ptr, mask=m_ptr)
    # tl.store(w_ptr @ x_ptr)
    # print(type(w_ptr))
    
main()

pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
pid:0
tensor([[False, False, False, False, False,  True,  True, False, False, False],
        [False,  True, False, False,  True, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, 

In [None]:
## Group Ordering

def main(
    num_features = 10, num_neurons = 10, sparsity = 0.1, num_batches=1
):
    mw = get_const_fan_w_and_M(num_neurons, num_features, sparsity)
    x = get_input(num_features, num_batches)
    mw.weight = mw.weight.to("cuda:0")
    x = x.to("cuda:0")
    mw.mask = mw.mask.to("cuda:0")
    try:
        x@mw.weight.T
    except Exception as e:
        raise e 
    _META = dict(BLOCK_SIZE=mw.weight.shape[1])
    nm_matmul(_META, W=mw.weight, x=x, m=m)
    

def nm_matmul(_META: dict[str, any], W: torch.Tensor, x: torch.Tensor, m:torch.Tensor) -> torch.Tensor:
    assert W.is_cuda and x.is_cuda and m.is_cuda
    # num_elements = y.numel()
    # grid = lambda meta: (triton.cdiv(num_elements, _META["BLOCK_SIZE"]))  # ceiling division yields number of "programs" to be launched
    # BLOCK_SIZE = triton.next_power_of_2(W.shape[1])
    # num_warps = 4
    M,K = x.shape
    K,N = W.T.shape
    y: torch.Tensor = torch.empty_like((M,N), dtype=x.dtype, device=x.device)
    BLOCK_SIZE_M = 8
    BLOCK_SIZE_N = 8
    BLOCK_SIZE_K = 8
    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]))  # 2D grid
    _nm_matmul[grid](  # We index each kernel by number of neurons in W 
        W.T, x, y,  # pointers to tensors
        M, N, K,  # Shape of tensors ( (MxK * KxN = MxN) )
        stride_wk=W.T.stride(0), stride_wn=W.T.stride(1),
        stride_xm=x.stride(0), stride_xk=x.stride(1),
        stride_ym=y.stride(0), stride_yn=y.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_K,
        BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=None,
    )
    


@triton.jit
def _nm_matmul(
    w_ptr,  # Shape (in_features * num_neurons)
    x_ptr,  # Shape (num_batches, in_features)
    y_ptr,  # Shape (num_batches * num_neurons)
    M,
    N,
    K,
    stride_wk, stride_wn,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    BLOCK_SIZE_K,
    GROUP_SIZE_M,
):
    # Multiplication in form X@W.T
    row_idx = tl.program_id(axis=0)  # 1D launch grid
    row_start_ptr = w_ptr+row_idx*w_n_cols,

    block_start = pid * 
    mask = tl.load(m_ptr)
    w_sp = tl.load(w_ptr, mask=m_ptr)
    tl.store(w_ptr @ x_ptr)
    print(type(w_ptr))
    
main()

SyntaxError: invalid syntax (3339701006.py, line 64)