In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
__DEVICE = torch.device("cuda")
# __DEVICE = torch.device("cpu")
__INPUT_DIM = 762
__OUTPUT_DIM = 4096
__BATCH_SIZE = 128
__SPARSITY = 0.96

In [4]:

@torch.no_grad()
def get_ffi_structure(mod: nn.Linear, sparsity: float) -> nn.Linear:
    n_zeros = int(mod.weight.numel() * (sparsity))
    n_zeros_per_neuron = n_zeros // mod.weight.shape[0]
    for idx, neuron in enumerate(mod.weight):
        rand_idx = torch.randperm(n=len(neuron))
        mod.weight[idx, rand_idx[:n_zeros_per_neuron-1]] = 0
    assert_ffi(mod)
    print_sparsity(mod)
    return mod

def assert_ffi(mod: nn.Linear):
    ffi = (mod.weight[0]!=0).sum()
    for n in mod.weight:
        assert (n!=0).sum()==ffi

def print_sparsity(mod: nn.Linear):
    print(f"Mod sparsity: {1-((mod.weight!=0).sum()/mod.weight.numel()).item():.4f}")

In [5]:
linear = nn.Linear(in_features=__INPUT_DIM, out_features=__OUTPUT_DIM, device=__DEVICE)
sparse_linear = get_ffi_structure(linear, __SPARSITY)

Mod sparsity: 0.9580


In [6]:
(sparse_linear.weight[0]!=0).sum()

tensor(32, device='cuda:0')

In [7]:
x = torch.rand(size=(__BATCH_SIZE, __INPUT_DIM), device=__DEVICE)
x.shape

torch.Size([128, 762])

In [8]:
class FFILinearNaive(nn.Module):
    def __init__(
        self,
        module: nn.Module,
        dtype: torch.typename = torch.float32,
        transpose: bool = True,
        vectorize: bool = False,
        index_dtype: torch.typename = torch.int32,
    ):
        super().__init__()
        if dtype is None:
            dtype = module.weight.dtype

        self.transpose = transpose
        with torch.no_grad():
            fine_grained_idx = (module.weight != 0).to(
                torch.bool
            )
            _, self.input_mask = fine_grained_idx.nonzero(as_tuple=True)
            self.input_mask = self.input_mask.reshape(
                shape=(module.weight.shape[0], -1)
            ).to(index_dtype)
            weight = module.weight.detach().type(dtype)
            weight = torch.clone(
                weight[fine_grained_idx]
                .reshape(shape=(weight.shape[0], -1))
                .detach()
                .type(dtype)
            )
            # padding to multiple of 4
            if vectorize:
                pad = (
                    self.input_mask.shape[1] + 3
                ) // 4 * 4 - self.input_mask.shape[1]
                self.input_mask = F.pad(self.input_mask, [0, pad])
                weight = F.pad(weight, [0, pad])

            self.condensed_weight = nn.Parameter(
                weight,
                requires_grad=False,
            )

            if hasattr(module, "bias"):
                self.bias = nn.Parameter(
                    torch.clone(
                        module.bias.detach().type(dtype)
                    ),
                    requires_grad=False,
                )
            else:
                self.register_parameter("bias", None)
            self.to(module.weight.device)
            self.input_mask.to(module.weight.device)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = torch.empty(size=(input.shape[0], self.condensed_weight.shape[0]), device=x.device)
        output_size, nnz_el_per_neuron = self.input_mask.shape
        for batch in range(input.shape[0]):
            for out in range(output_size):
                output[batch, out] = self.bias[out]
                for index in range(nnz_el_per_neuron):
                    output[batch, out] += input[batch, self.input_mask[out, index]] * self.condensed_weight[out, index]
        return output

In [9]:
ffi_naive = FFILinearNaive(sparse_linear)
# sparse_linear(x).allclose(ffi_naive(x), atol=1e-07)

In [10]:
class FFILinearVmap(FFILinearNaive):
    def __init__(
        self,
        module: nn.Module,
        dtype: torch.typename = torch.float32,
        transpose: bool = True,
        vectorize: bool = False,
        index_dtype: torch.typename = torch.int32,
    ):
        super().__init__(module, dtype, transpose, vectorize, index_dtype)
    
    def batch_kernel(self, input, input_masks, weights, biases):
        return torch.vmap(self.output_kernel, in_dims=(None, 0, 0, 0))(input, input_masks, weights, biases)

    def output_kernel(self, input, input_mask, weight, bias):
        return bias + torch.sum(input[input_mask] * weight)

    # @override
    def forward(self, input: torch.Tensor):
        return torch.vmap(self.batch_kernel, in_dims=(0, None, None, None), out_dims=0)(input, self.input_mask, self.condensed_weight, self.bias)

ffi_vmap = FFILinearVmap(sparse_linear, vectorize=True)
ffi_vmap(x).allclose(sparse_linear(x), atol=1e-06)

True

In [11]:
# ffi_compiled = torch.compile(ffi_vmap, mode="reduce-overhead")
ffi_compiled = torch.compile(ffi_vmap, mode="max-autotune")
for _ in range(10):
    _ = ffi_compiled(x)

In [12]:
N_ITERS = 10

def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):
        _ = fn()
    start.record()
    for _ in range(N_ITERS):
        result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

In [13]:
with torch.no_grad():
    print("Dense:", timed(lambda: linear(x))[1])
    # print("FFI_Naive:", timed(lambda: ffi_naive(x))[1])
    print("FFI_VMAP:", timed(lambda: ffi_vmap(x))[1])
    print("FFI_Compiled:", timed(lambda: ffi_compiled(x))[1])

Dense: 0.0010383360385894775
FFI_VMAP: 0.006908927917480469
FFI_Compiled: 0.006913023948669434


In [14]:
ffi_vmap.condensed_weight.shape

torch.Size([4096, 32])

In [15]:
ffi_vmap.input_mask[0]

tensor([ 22,  36,  58,  71,  77, 112, 156, 175, 219, 233, 238, 289, 291, 293,
        321, 328, 351, 427, 435, 446, 455, 488, 503, 505, 506, 529, 537, 614,
        634, 702, 716, 753], device='cuda:0', dtype=torch.int32)

In [31]:
ffi_compiled.input_mask.shape

torch.Size([4096, 32])

In [30]:
x[:32, ffi_compiled.input_mask[]

torch.Size([32, 762])

In [None]:
## Triton
os.environ["TRITON_INTERPRET"] = "1"

import triton
import triton.language as tl
import os
import pdb

torch.backends.cuda.matmul.allow_tf32 = False
@triton.jit
def ffi_triton(
    # Output / Input pointers
    output_p,
    input_p,
    # Pointers to weights, bias, and mask
    weight_p,
    bias_p,
    mask_p,
    # number of outputs channels, input channels, and batch_size
    n_out: tl.constexpr,
    n_in: tl.constexpr,
    n_batch: tl.constexpr,
    n_weights: tl.constexpr,
    # Number of nnz_el per neuron and grid block size
    FFI_PER_NEURON: tl.constexpr,
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
):
    block_idx_x, block_idx_y = tl.program_id(axis=0), tl.program_id(axis=1)
    batch_idx = block_idx_x * BLOCK_SIZE_X
    unit_idx = block_idx_y * BLOCK_SIZE_Y
    # We want this kernel invocation to compute the first 32 batches and 32 outputs

    batch_offsets = batch_idx + tl.arange(0, BLOCK_SIZE_X)
    batch_mask = batch_offsets < n_batch
    # Load up input from batches up to batch_idx * block size

    unit_offsets = unit_idx + tl.arange(0, BLOCK_SIZE_Y)
    unit_mask = unit_offsets < n_out * n_batch
    # output units up ot block_size_y * current idx. 
    
    mask_offsets = tl.expand_dims(
        unit_offsets, 1
    ) * FFI_PER_NEURON + tl.expand_dims(tl.arange(0, FFI_PER_NEURON), 0)
    # We need to load masks for the current output units and all FFI weights
    mask_mask = mask_offsets < n_weights
    weight_offsets = (
        tl.expand_dims(tl.arange(0, FFI_PER_NEURON), 1)
        + tl.expand_dims(unit_offsets, 0) * FFI_PER_NEURON
    )
    # Similarly, we load weights for current output units and all FFI weights, but weights have been transposed
    weight_mask = weight_offsets < n_weights

    output_offsets = tl.expand_dims(
        batch_offsets * BLOCK_SIZE_X, 1
    ) + tl.expand_dims(unit_offsets, 0)
    output_mask = output_offsets < n_batch * n_out
    
    # Now, we need to broadcast input by mask and iteratively write out to output???
    for _ in range(0, BLOCK_SIZE_Y):
        

    bias = tl.load(bias_p + unit_offsets, mask=unit_mask)
    output = tl.load(output_p + output_offsets, output_mask)
    output += bias
    weights = tl.load(weight_p + weight_offsets, mask=weight_mask)
    mask = tl.load(mask_p + mask_offsets, mask=mask_mask)

    input_offset = tl.expand_dims(batch_offsets, 1) * n_in + mask
    input_mask = input_offset < n_in * n_batch
    inputs = tl.load(input_p + input_offset, input_mask)
    print(inputs)
    print(weights)
    print(bias)
    # output+=tl.dot(inputs, weights, allow_tf32=False)
    # tl.store(output_p + output_offsets, output, output_mask)


# output = torch.zeros(size=(__BATCH_SIZE, __OUTPUT_DIM), device=__DEVICE)
output = torch.zeros(size=(32, 32), device=__DEVICE)
input = x[:32]
weight = ffi_vmap.condensed_weight.T[:,:32] # shape is now ffi, output
bias = ffi_vmap.bias[:32]
mask = ffi_vmap.input_mask[:, :32]
n_elements = input.numel()
n_weights = weight.numel()
FFI_PER_NEURON = weight.shape[0]  # first dim is now ffi
BLOCK_SIZE_X=32
BLOCK_SIZE_Y=32
n_out = __OUTPUT_DIM
n_in = __INPUT_DIM
n_batch = __BATCH_SIZE
grid = triton.cdiv(input.shape[0], BLOCK_SIZE_X), triton.cdiv(weight.shape[1], BLOCK_SIZE_Y)  # 32 threads per grid block 
print(grid)
ffi_triton[grid](output, input, weight, bias, mask, n_out, n_in, n_batch, n_weights, FFI_PER_NEURON, BLOCK_SIZE_X, BLOCK_SIZE_Y)

In [33]:
## Triton
os.environ["TRITON_INTERPRET"] = "1"

import triton
import triton.language as tl
import os
import pdb

torch.backends.cuda.matmul.allow_tf32 = False
@triton.jit
def ffi_triton(
    # Output / Input pointers
    output_p,
    input_p,
    # Pointers to weights, bias, and mask
    weight_p,
    bias_p,
    mask_p,
    # number of outputs channels, input channels, and batch_size
    n_out: tl.constexpr,
    n_in: tl.constexpr,
    n_batch: tl.constexpr,
    n_weights: tl.constexpr,
    # Number of nnz_el per neuron and grid block size
    FFI_PER_NEURON: tl.constexpr,
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
):
    block_idx_x, block_idx_y = tl.program_id(axis=0), tl.program_id(axis=1)
    batch_idx = block_idx_x * BLOCK_SIZE_X
    unit_idx = block_idx_y * BLOCK_SIZE_Y

    batch_offsets = batch_idx + tl.arange(0, BLOCK_SIZE_X)

    unit_offsets = unit_idx + tl.arange(0, BLOCK_SIZE_Y)
    unit_mask = unit_offsets < n_out * n_batch

    mask_offsets = tl.expand_dims(
        unit_offsets, 1
    ) * FFI_PER_NEURON + tl.expand_dims(tl.arange(0, FFI_PER_NEURON), 0)
    mask_mask = mask_offsets < n_weights
    weight_offsets = (
        tl.expand_dims(tl.arange(0, FFI_PER_NEURON), 1)
        + tl.expand_dims(unit_offsets, 0) * FFI_PER_NEURON
    )
    weight_mask = weight_offsets < n_weights

    output_offsets = tl.expand_dims(
        batch_offsets * BLOCK_SIZE_X, 1
    ) + tl.expand_dims(unit_offsets, 0)
    output_mask = output_offsets < n_batch * n_out

    bias = tl.load(bias_p + unit_offsets, mask=unit_mask)
    output = tl.load(output_p + output_offsets, output_mask)
    output += bias
    weights = tl.load(weight_p + weight_offsets, mask=weight_mask)
    mask = tl.load(mask_p + mask_offsets, mask=mask_mask)

    input_offset = tl.expand_dims(batch_offsets, 1) * n_in + mask
    input_mask = input_offset < n_in * n_batch
    inputs = tl.load(input_p + input_offset, input_mask)
    print(inputs)
    print(weights)
    print(bias)
    # output+=tl.dot(inputs, weights, allow_tf32=False)
    # tl.store(output_p + output_offsets, output, output_mask)


# output = torch.zeros(size=(__BATCH_SIZE, __OUTPUT_DIM), device=__DEVICE)
output = torch.zeros(size=(32, 32), device=__DEVICE)
input = x[:32]
weight = ffi_vmap.condensed_weight.T[:,:32] # shape is now ffi, output
bias = ffi_vmap.bias[:32]
mask = ffi_vmap.input_mask[:, :32]
n_elements = input.numel()
n_weights = weight.numel()
FFI_PER_NEURON = weight.shape[0]  # first dim is now ffi
BLOCK_SIZE_X=32
BLOCK_SIZE_Y=32
n_out = __OUTPUT_DIM
n_in = __INPUT_DIM
n_batch = __BATCH_SIZE
grid = triton.cdiv(input.shape[0], BLOCK_SIZE_X), triton.cdiv(weight.shape[1], BLOCK_SIZE_Y)  # 32 threads per grid block 
print(grid)
ffi_triton[grid](output, input, weight, bias, mask, n_out, n_in, n_batch, n_weights, FFI_PER_NEURON, BLOCK_SIZE_X, BLOCK_SIZE_Y)

(1, 1)
[[0.28271836 0.20015647 0.01879376 ... 0.22647431 0.6467432  0.1392649 ]
 [0.94857585 0.8641366  0.8750439  ... 0.739289   0.43764648 0.61978006]
 [0.53988975 0.5642546  0.39189672 ... 0.8288116  0.8856129  0.17456578]
 ...
 [0.31094283 0.85903955 0.26372623 ... 0.83719045 0.82869154 0.7847737 ]
 [0.7128974  0.5074023  0.6337079  ... 0.52218926 0.1635361  0.33997655]
 [0.12492803 0.5038944  0.60762745 ... 0.33695993 0.93491524 0.4055208 ]]
[[ 0.0261482  -0.00396388 -0.0059791  ...  0.02332326 -0.00391128
  -0.03106531]
 [ 0.02015613 -0.0085853   0.01332369 ...  0.02736109  0.03410102
   0.0103325 ]
 [ 0.00322715 -0.03347414 -0.00471053 ... -0.00461934  0.03265392
   0.00083682]
 ...
 [ 0.00548304 -0.00404497  0.00299763 ... -0.03259914 -0.00229209
   0.03327067]
 [ 0.00554703 -0.03012603  0.0340563  ...  0.00013082 -0.03185382
  -0.03150038]
 [ 0.02553615 -0.00307607  0.03510918 ...  0.02092469 -0.00323277
  -0.03095286]]


[[ 22  36  58 ... 702 716 753]
 [ 35  84 127 ... 630 707 709]
 [ 48  63  74 ... 651 667 727]
 ...
 [  4  10  36 ... 741 745 752]
 [  1  15  61 ... 689 703 746]
 [ 11  83  89 ... 679 701 758]]
[constexpr[32], constexpr[32]]
[[   22    36    58 ...   702   716   753]
 [  797   846   889 ...  1392  1469  1471]
 [ 1572  1587  1598 ...  2175  2191  2251]
 ...
 [22102 22108 22134 ... 22839 22843 22850]
 [22861 22875 22921 ... 23549 23563 23606]
 [23633 23705 23711 ... 24301 24323 24380]]
[constexpr[32], constexpr[32]]
[[0.28271836 0.20015647 0.01879376 ... 0.22647431 0.6467432  0.1392649 ]
 [0.94857585 0.8641366  0.8750439  ... 0.739289   0.43764648 0.61978006]
 [0.53988975 0.5642546  0.39189672 ... 0.8288116  0.8856129  0.17456578]
 ...
 [0.31094283 0.85903955 0.26372623 ... 0.83719045 0.82869154 0.7847737 ]
 [0.7128974  0.5074023  0.6337079  ... 0.52218926 0.1635361  0.33997655]
 [0.12492803 0.5038944  0.60762745 ... 0.33695993 0.93491524 0.4055208 ]]
[[0.28271836 0.20015647 0.01879376 ...

Traceback (most recent call last):
  File "/home/mike/accel-snn/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_2200351/2953876125.py", line 85, in <module>
    ffi_triton[grid](output, input, weight, bias, mask, n_out, n_in, n_batch, n_weights, FFI_PER_NEURON, BLOCK_SIZE_X, BLOCK_SIZE_Y)
  File "/home/mike/accel-snn/.venv/lib/python3.10/site-packages/triton/runtime/interpreter.py", line 511, in __call__
    self.fn(**args)
  File "/tmp/ipykernel_2200351/2953876125.py", line 64, in ffi_triton
    print(bias)
  File "/tmp/ipykernel_2200351/2953876125.py", line 64, in ffi_triton
    print(bias)
  File "_pydevd_bundle/pydevd_cython.pyx", line 1457, in _pydevd_bundle.pydevd_cython.SafeCallWrapper.__call__
  File "_pydevd_bundle/pydevd_cython.pyx", line 701, in _pydevd_bundle.pydevd_cython.PyDBFrame.trace_dispatch
  File "_pydevd_bundle/pydevd_cython.pyx", line 1152, in 

In [None]:
# 

In [20]:
output

tensor([[ 0.0471,  0.0032,  0.1054,  ..., -0.0047,  0.0056, -0.0140],
        [ 0.1094, -0.0223,  0.0566,  ...,  0.0500,  0.0528, -0.0418],
        [ 0.1216, -0.0431,  0.0571,  ..., -0.0619,  0.0593, -0.0010],
        ...,
        [ 0.1315, -0.0106,  0.1529,  ...,  0.0126,  0.0128, -0.0392],
        [ 0.0667, -0.0203,  0.0803,  ..., -0.0217,  0.0624, -0.0229],
        [ 0.1379, -0.0187,  0.0572,  ...,  0.0335,  0.0508, -0.0147]],
       device='cuda:0')

In [18]:
ffi_compiled(x)

tensor([[ 0.0471, -0.0391,  0.0182,  ...,  0.0366,  0.0762,  0.0381],
        [ 0.0930, -0.0223,  0.0981,  ...,  0.0280,  0.0720,  0.0642],
        [ 0.1260, -0.0361,  0.0571,  ...,  0.0188,  0.0824, -0.0046],
        ...,
        [ 0.1381,  0.0008,  0.0923,  ...,  0.0060,  0.1120,  0.0091],
        [ 0.1127, -0.0196,  0.1390,  ...,  0.0488,  0.0676,  0.0181],
        [ 0.1553,  0.0510,  0.1138,  ...,  0.0455,  0.1527,  0.0146]],
       device='cuda:0')

In [21]:
output.shape

torch.Size([128, 4096])

In [26]:
8160/4

2040.0

In [22]:
(output!=0).sum()

tensor(8160, device='cuda:0')

In [None]:
class FixedFanInCuda(nn.Module):
    def __init__(
        self,
        module: nn.Module,
        dtype: torch.typename = torch.float32,
        transpose: bool = True,
        vectorize: bool = False,
        index_dtype: torch.typename = torch.int32,
    ):
        super().__init__()
        if dtype is None:
            dtype = module.weight.dtype

        self.transpose = transpose
        with torch.no_grad():
            active_neuron_idx = module.weight.sum(dim=1) != 0
            fine_grained_idx = (module.weight[active_neuron_idx] != 0).to(
                torch.bool
            )
            _, self.input_mask = fine_grained_idx.nonzero(as_tuple=True)
            self.input_mask = self.input_mask.reshape(
                shape=(module.weight[active_neuron_idx].shape[0], -1)
            ).to(index_dtype)
            weight = module.weight[active_neuron_idx].detach().type(dtype)
            weight = torch.clone(
                weight[fine_grained_idx]
                .reshape(shape=(weight.shape[0], -1))
                .detach()
                .type(dtype)
            )
            # padding to multiple of 4
            if vectorize:
                pad = (
                    self.input_mask.shape[1] + 3
                ) // 4 * 4 - self.input_mask.shape[1]
                self.input_mask = F.pad(self.input_mask, [0, pad])
                weight = F.pad(weight, [0, pad])

            self.condensed_weight = nn.Parameter(
                weight,
                requires_grad=False,
            )

            if hasattr(module, "bias"):
                self.bias = nn.Parameter(
                    torch.clone(
                        module.bias[active_neuron_idx].detach().type(dtype)
                    ),
                    requires_grad=False,
                )
            else:
                self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return ffi_mul(
            input,
            self.condensed_weight,
            self.input_mask,
            self.bias,
            transpose=self.transpose,
        )

In [None]:
class CondensedLinearFineGrained(nn.Module):
    def __init__(
        self, module: nn.Module, dtype: torch.typename = torch.float32
    ):
        super().__init__()
        if dtype is None:
            dtype = module.weight.dtype
        with torch.no_grad():
            active_neuron_idx = module.weight.sum(dim=1) != 0
            fine_grained_idx = (module.weight[active_neuron_idx] != 0).to(
                torch.bool
            )
            _, self.input_mask = fine_grained_idx.nonzero(as_tuple=True)
            self.input_mask = self.input_mask.reshape(
                shape=(module.weight[active_neuron_idx].shape[0], -1)
            )
            self.input_mask = self.input_mask.to(torch.int32)
            weight = module.weight[active_neuron_idx].detach().type(dtype)
            self.condensed_weight = nn.Parameter(
                torch.clone(
                    weight[fine_grained_idx]
                    .reshape(shape=(weight.shape[0], -1))
                    .detach()
                    .type(dtype)
                ),
                requires_grad=False,
            )
            if hasattr(module, "bias"):
                self.bias = nn.Parameter(
                    torch.clone(
                        module.bias[active_neuron_idx].detach().type(dtype)
                    ),
                    requires_grad=False,
                )
            else:
                self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return (
            torch.sum(
                self.condensed_weight * input[..., self.input_mask],
                dim=input.dim(),
            )
            + self.bias
        )