In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
__DEVICE = torch.device("cuda")
# __DEVICE = torch.device("cpu")
__INPUT_DIM = 762
__OUTPUT_DIM = 4096
__BATCH_SIZE = 128
__SPARSITY = 0.96

In [3]:

@torch.no_grad()
def get_ffi_structure(mod: nn.Linear, sparsity: float) -> nn.Linear:
    n_zeros = int(mod.weight.numel() * (sparsity))
    n_zeros_per_neuron = n_zeros // mod.weight.shape[0]
    for idx, neuron in enumerate(mod.weight):
        rand_idx = torch.randperm(n=len(neuron))
        mod.weight[idx, rand_idx[:n_zeros_per_neuron-1]] = 0
    assert_ffi(mod)
    print_sparsity(mod)
    return mod

def assert_ffi(mod: nn.Linear):
    ffi = (mod.weight[0]!=0).sum()
    for n in mod.weight:
        assert (n!=0).sum()==ffi

def print_sparsity(mod: nn.Linear):
    print(f"Mod sparsity: {1-((mod.weight!=0).sum()/mod.weight.numel()).item():.4f}")

In [4]:
linear = nn.Linear(in_features=__INPUT_DIM, out_features=__OUTPUT_DIM, device=__DEVICE)
sparse_linear = get_ffi_structure(linear, __SPARSITY)

Mod sparsity: 0.9580


In [5]:
(sparse_linear.weight[0]!=0).sum()

tensor(32, device='cuda:0')

In [6]:
x = torch.rand(size=(__BATCH_SIZE, __INPUT_DIM), device=__DEVICE)
x.shape

torch.Size([128, 762])

In [7]:
class FFILinearNaive(nn.Module):
    def __init__(
        self,
        module: nn.Module,
        dtype: torch.typename = torch.float32,
        transpose: bool = True,
        vectorize: bool = False,
        index_dtype: torch.typename = torch.int32,
    ):
        super().__init__()
        if dtype is None:
            dtype = module.weight.dtype

        self.transpose = transpose
        with torch.no_grad():
            fine_grained_idx = (module.weight != 0).to(
                torch.bool
            )
            _, self.input_mask = fine_grained_idx.nonzero(as_tuple=True)
            self.input_mask = self.input_mask.reshape(
                shape=(module.weight.shape[0], -1)
            ).to(index_dtype)
            weight = module.weight.detach().type(dtype)
            weight = torch.clone(
                weight[fine_grained_idx]
                .reshape(shape=(weight.shape[0], -1))
                .detach()
                .type(dtype)
            )
            # padding to multiple of 4
            if vectorize:
                pad = (
                    self.input_mask.shape[1] + 3
                ) // 4 * 4 - self.input_mask.shape[1]
                self.input_mask = F.pad(self.input_mask, [0, pad])
                weight = F.pad(weight, [0, pad])

            self.condensed_weight = nn.Parameter(
                weight,
                requires_grad=False,
            )

            if hasattr(module, "bias"):
                self.bias = nn.Parameter(
                    torch.clone(
                        module.bias.detach().type(dtype)
                    ),
                    requires_grad=False,
                )
            else:
                self.register_parameter("bias", None)
            self.to(module.weight.device)
            self.input_mask.to(module.weight.device)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = torch.empty(size=(input.shape[0], self.condensed_weight.shape[0]), device=x.device)
        output_size, nnz_el_per_neuron = self.input_mask.shape
        for batch in range(input.shape[0]):
            for out in range(output_size):
                output[batch, out] = self.bias[out]
                for index in range(nnz_el_per_neuron):
                    output[batch, out] += input[batch, self.input_mask[out, index]] * self.condensed_weight[out, index]
        return output

In [8]:
ffi_naive = FFILinearNaive(sparse_linear)
# sparse_linear(x).allclose(ffi_naive(x), atol=1e-07)

In [9]:
class FFILinearVmap(FFILinearNaive):
    def __init__(
        self,
        module: nn.Module,
        dtype: torch.typename = torch.float32,
        transpose: bool = True,
        vectorize: bool = False,
        index_dtype: torch.typename = torch.int32,
    ):
        super().__init__(module, dtype, transpose, vectorize, index_dtype)
    
    def batch_kernel(self, input, input_masks, weights, biases):
        return torch.vmap(self.output_kernel, in_dims=(None, 0, 0, 0))(input, input_masks, weights, biases)

    def output_kernel(self, input, input_mask, weight, bias):
        return bias + torch.sum(input[input_mask] * weight)

    # @override
    def forward(self, input: torch.Tensor):
        return torch.vmap(self.batch_kernel, in_dims=(0, None, None, None), out_dims=0)(x, self.input_mask, self.condensed_weight, self.bias)

ffi_vmap = FFILinearVmap(sparse_linear, vectorize=True)
ffi_vmap(x).allclose(sparse_linear(x), atol=1e-06)

True

In [10]:
# ffi_compiled = torch.compile(ffi_vmap, mode="reduce-overhead")
ffi_compiled = torch.compile(ffi_vmap, mode="max-autotune")
for _ in range(10):
    _ = ffi_compiled(x)

In [11]:
N_ITERS = 10

def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):
        _ = fn()
    start.record()
    for _ in range(N_ITERS):
        result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

In [12]:
with torch.no_grad():
    print("Dense:", timed(lambda: linear(x))[1])
    # print("FFI_Naive:", timed(lambda: ffi_naive(x))[1])
    print("FFI_VMAP:", timed(lambda: ffi_vmap(x))[1])
    print("FFI_Compiled:", timed(lambda: ffi_compiled(x))[1])

Dense: 0.0010373120307922363
FFI_VMAP: 0.006914048194885254
FFI_Compiled: 0.006918144226074219


In [13]:
ffi_vmap.condensed_weight.shape

torch.Size([4096, 32])

In [14]:
ffi_vmap.input_mask[0]

tensor([ 19,  26,  27,  30,  79,  84, 137, 157, 164, 190, 285, 290, 330, 418,
        431, 434, 447, 487, 555, 564, 583, 603, 607, 630, 637, 657, 662, 697,
        702, 708, 712, 737], device='cuda:0', dtype=torch.int32)

In [24]:
## Triton
import triton
import triton.language as tl
import os
import pdb

os.environ["TRITON_INTERPRET"] = "0"
torch.backends.cuda.matmul.allow_tf32 = False
@triton.jit
def ffi_triton(
    # Output / Input pointers
    output_p,
    input_p,
    # Pointers to weights, bias, and mask
    weight_p,
    bias_p,
    mask_p,
    # number of outputs channels, input channels, and batch_size
    n_out: tl.constexpr,
    n_in: tl.constexpr,
    n_batch: tl.constexpr,
    n_weights: tl.constexpr,
    # Number of nnz_el per neuron and grid block size
    FFI_PER_NEURON: tl.constexpr,
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
):
    block_idx_x, block_idx_y = tl.program_id(axis=0), tl.program_id(axis=1)
    batch_idx = block_idx_x * BLOCK_SIZE_X
    unit_idx = block_idx_y * BLOCK_SIZE_Y

    batch_offsets = batch_idx + tl.arange(0, BLOCK_SIZE_X)

    unit_offsets = unit_idx + tl.arange(0, BLOCK_SIZE_Y)
    unit_mask = unit_offsets < n_out * n_batch

    mask_offsets = tl.expand_dims(
        unit_offsets, 1
    ) * FFI_PER_NEURON + tl.expand_dims(tl.arange(0, FFI_PER_NEURON), 0)
    mask_mask = mask_offsets < n_weights
    weight_offsets = (
        tl.expand_dims(tl.arange(0, FFI_PER_NEURON), 1)
        + tl.expand_dims(unit_offsets, 0) * FFI_PER_NEURON
    )
    weight_mask = weight_offsets < n_weights

    output_offsets = tl.expand_dims(
        batch_offsets * BLOCK_SIZE_X, 1
    ) + tl.expand_dims(unit_offsets, 0)
    output_mask = output_offsets < n_batch * n_out

    bias = tl.load(bias_p + unit_offsets, mask=unit_mask)
    output = tl.load(output_p + output_offsets, output_mask)
    output += bias
    weights = tl.load(weight_p + weight_offsets, mask=weight_mask)
    mask = tl.load(mask_p + mask_offsets, mask=mask_mask)

    input_offset = tl.expand_dims(batch_offsets, 1) * n_in + mask
    input_mask = input_offset < n_in * n_batch
    inputs = tl.load(input_p + input_offset, input_mask)
    output+=tl.dot(inputs, weights, allow_tf32=False)
    tl.store(output_p + output_offsets, output, output_mask)


# output = torch.zeros(size=(__BATCH_SIZE, __OUTPUT_DIM), device=__DEVICE)
output = torch.zeros(size=(32, 32), device=__DEVICE)
input = x[:32]
weight = ffi_vmap.condensed_weight.T[:,:32] # shape is now ffi, output
bias = ffi_vmap.bias[:32]
mask = ffi_vmap.input_mask[:, :32]
n_elements = input.numel()
n_weights = weight.numel()
FFI_PER_NEURON = weight.shape[0]  # first dim is now ffi
BLOCK_SIZE_X=32
BLOCK_SIZE_Y=32
n_out = __OUTPUT_DIM
n_in = __INPUT_DIM
n_batch = __BATCH_SIZE
grid = triton.cdiv(input.shape[0], BLOCK_SIZE_X), triton.cdiv(weight.shape[1], BLOCK_SIZE_Y)  # 32 threads per grid block 
print(grid)
ffi_triton[grid](output, input, weight, bias, mask, n_out, n_in, n_batch, n_weights, FFI_PER_NEURON, BLOCK_SIZE_X, BLOCK_SIZE_Y)

(1, 1)


<triton.compiler.compiler.CompiledKernel at 0x7f13e6c44370>

In [29]:
output

tensor([[ 0.1232, -0.0895,  0.0676,  ..., -0.0246, -0.0354,  0.1017],
        [ 0.0760,  0.0181,  0.0712,  ...,  0.0764, -0.0776,  0.0032],
        [ 0.0138,  0.0364,  0.0373,  ..., -0.0172, -0.0130, -0.0309],
        ...,
        [ 0.0413,  0.0363,  0.0045,  ...,  0.0211, -0.0203,  0.0484],
        [ 0.0308, -0.0146,  0.0493,  ...,  0.0342, -0.1280,  0.0263],
        [ 0.0217, -0.0331,  0.0143,  ...,  0.0113, -0.0115,  0.0393]],
       device='cuda:0')

In [27]:
ffi_compiled(x)

tensor([[ 0.1232,  0.0130,  0.0503,  ...,  0.0044,  0.1327, -0.0037],
        [ 0.0611,  0.0181,  0.0692,  ..., -0.0472, -0.0061,  0.0579],
        [ 0.0146,  0.0054,  0.0373,  ..., -0.0691,  0.0511,  0.0396],
        ...,
        [ 0.0471, -0.0160,  0.0714,  ...,  0.0009,  0.0314, -0.0287],
        [ 0.0535, -0.0362,  0.0113,  ..., -0.0411,  0.0465,  0.0367],
        [ 0.0417, -0.0522,  0.0295,  ..., -0.0239,  0.0364,  0.0108]],
       device='cuda:0')

In [21]:
output.shape

torch.Size([128, 4096])

In [26]:
8160/4

2040.0

In [22]:
(output!=0).sum()

tensor(8160, device='cuda:0')

In [None]:
class FixedFanInCuda(nn.Module):
    def __init__(
        self,
        module: nn.Module,
        dtype: torch.typename = torch.float32,
        transpose: bool = True,
        vectorize: bool = False,
        index_dtype: torch.typename = torch.int32,
    ):
        super().__init__()
        if dtype is None:
            dtype = module.weight.dtype

        self.transpose = transpose
        with torch.no_grad():
            active_neuron_idx = module.weight.sum(dim=1) != 0
            fine_grained_idx = (module.weight[active_neuron_idx] != 0).to(
                torch.bool
            )
            _, self.input_mask = fine_grained_idx.nonzero(as_tuple=True)
            self.input_mask = self.input_mask.reshape(
                shape=(module.weight[active_neuron_idx].shape[0], -1)
            ).to(index_dtype)
            weight = module.weight[active_neuron_idx].detach().type(dtype)
            weight = torch.clone(
                weight[fine_grained_idx]
                .reshape(shape=(weight.shape[0], -1))
                .detach()
                .type(dtype)
            )
            # padding to multiple of 4
            if vectorize:
                pad = (
                    self.input_mask.shape[1] + 3
                ) // 4 * 4 - self.input_mask.shape[1]
                self.input_mask = F.pad(self.input_mask, [0, pad])
                weight = F.pad(weight, [0, pad])

            self.condensed_weight = nn.Parameter(
                weight,
                requires_grad=False,
            )

            if hasattr(module, "bias"):
                self.bias = nn.Parameter(
                    torch.clone(
                        module.bias[active_neuron_idx].detach().type(dtype)
                    ),
                    requires_grad=False,
                )
            else:
                self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return ffi_mul(
            input,
            self.condensed_weight,
            self.input_mask,
            self.bias,
            transpose=self.transpose,
        )

In [None]:
class CondensedLinearFineGrained(nn.Module):
    def __init__(
        self, module: nn.Module, dtype: torch.typename = torch.float32
    ):
        super().__init__()
        if dtype is None:
            dtype = module.weight.dtype
        with torch.no_grad():
            active_neuron_idx = module.weight.sum(dim=1) != 0
            fine_grained_idx = (module.weight[active_neuron_idx] != 0).to(
                torch.bool
            )
            _, self.input_mask = fine_grained_idx.nonzero(as_tuple=True)
            self.input_mask = self.input_mask.reshape(
                shape=(module.weight[active_neuron_idx].shape[0], -1)
            )
            self.input_mask = self.input_mask.to(torch.int32)
            weight = module.weight[active_neuron_idx].detach().type(dtype)
            self.condensed_weight = nn.Parameter(
                torch.clone(
                    weight[fine_grained_idx]
                    .reshape(shape=(weight.shape[0], -1))
                    .detach()
                    .type(dtype)
                ),
                requires_grad=False,
            )
            if hasattr(module, "bias"):
                self.bias = nn.Parameter(
                    torch.clone(
                        module.bias[active_neuron_idx].detach().type(dtype)
                    ),
                    requires_grad=False,
                )
            else:
                self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return (
            torch.sum(
                self.condensed_weight * input[..., self.input_mask],
                dim=input.dim(),
            )
            + self.bias
        )