In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.sparse as sparse
import sparseprop

In [3]:
import math
import numpy as np

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


def _gen_indx_seqs(
    fan_in: int, num_out: int, num_in: int, fan_out_const: bool
) -> torch.LongTensor:
    """
    Generates indices required by the condensed layer (LinearCondensed) for
    drawing recombination vectors from the input vector v.

    Args:
        fan_in: Number of recombination vectors, corresponding to the number of
            columns in the weight matrix of LinearCondensed.
        num_out: Length of recombination vectors, corresponding to the number of
            rows in the weight matrix of LinearCondensed.
        num_in: Length of the input vector(s).
        fan_out_const: If True, nearly constant fan-out will be ensured. Nearly,
            and not exactly, because in some cases the number of connections is
            not evenly divisible by the number of neurons.

    Returns:
        A 2d array of indices of the same shape as the weight matrix in
            LinearCondensed, namely (num_out, fan_in).
    """

    indx_seqs = np.zeros((num_out, fan_in))

    # indices of input vector
    v_inds = np.arange(num_in)

    # initializing an array of probabilities for every index of v
    # (initially uniform)
    probs = 1 / num_in * np.ones(num_in)

    for row_nr in range(num_out):
        chosen_inds = np.random.choice(
            v_inds, size=fan_in, replace=False, p=probs / sum(probs)
        )
        chosen_inds.sort()
        # update probabs only if want to control fan_out
        if fan_out_const:
            probs[chosen_inds] /= 100 * num_in

        indx_seqs[row_nr, :] = chosen_inds

    return torch.LongTensor(indx_seqs.astype(int))


class LinearCondensed(nn.Module):
    r"""Applies a special condensed matmul
    transformation to the incoming data.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    Args:
        in_features: Length of each input vector.
        out_features: Length of layer output.
        fan_in: The number of rows in the weight matrix.
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``.

    Shape:
        - Input: :math:`(*, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in\_features}`.
        - Output: :math:`(*, H_{out})` where all but the last dimension
          are the same shape as the input and
          :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{fan\in})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{fan\in}}`
        bias:   the learnable bias of the module of shape
                :math:`(\text{out\_features})`.If :attr:`bias` is ``True``, the
                values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.LinearCondensed(20, 10, 5, False)
        >>> input = torch.randn(64, 784)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([64, 10])
    """
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    fan_in: int
    weight: torch.Tensor
    indx_seqs: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        fan_in: int,
        fan_out_const: bool,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(LinearCondensed, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fan_in = fan_in
        self.weight = Parameter(
            torch.empty((out_features, fan_in), **factory_kwargs)
        )
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

        # ===== INDICES FOR RECOMBS =====
        self.indx_seqs = _gen_indx_seqs(
            fan_in=fan_in,
            num_out=out_features,
            num_in=in_features,
            fan_out_const=fan_out_const,
        )

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            dense_fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(dense_fan_in) if dense_fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = (
            torch.sum(self.weight * input[:, self.indx_seqs], axis=2)
            + self.bias
        )
        return output

    def extra_repr(self) -> str:
        return (
            "in_features={}, out_features={}, fan_in={}, fan_out_const={}, "
            "bias={}"
        ).format(
            self.in_features,
            self.out_features,
            self.fan_in,
            self.fan_out_const,
            self.bias is not None,
        )


In [4]:
# class CondensedLinear(nn.Module):
    
#     def __init__(self, dense_weight, mask, bias=None):
#         self.dense_weight = dense_weight
#         self.mask = mask
#         self.bias = bias
#         self.sparse_weight = torch.sparse.mo
        
#     def forward(self,  input: torch.Tensor) -> torch.Tensor:
#         return torch.sum(self.sparse_weight * input) + self.bias
            
        

# import copy
# class SparseLinear(nn.Module):
    
#     def __init__(self, linear_layer, bs):   
#         super().__init__()
#         self.weights = linear_layer.weight.T.to_sparse_coo()
#         # self.weights = copy.deepcopy(linear_layer.weight.T.to_sparse_coo())
#         self.bias = nn.Parameter(linear_layer.bias.expand(size=(bs, *linear_layer.bias.shape)))
#         # self.bias = linear_layer.bias

    
#     def forward(self, x):
#         return self.bias + torch.mm(x, self.weights)


class SparseLinear(nn.Module):
    
    def __init__(self, linear_layer, bs):   
        super().__init__()
        self.weight = linear_layer.weight.to_sparse_coo()
        # self.weights = copy.deepcopy(linear_layer.weight.T.to_sparse_coo())
        self.bias = nn.Parameter(linear_layer.bias.expand(size=(bs, *linear_layer.bias.shape)))
        # self.bias = linear_layer.bias

    def forward(self, x):
        # return self.bias + torch.mm(self.weight, x.T).T
        return self.bias + torch.mm(self.weight, x).T

In [10]:
import torch.utils.benchmark as benchmark

dense = nn.Linear(1024,1000)
condensed = LinearCondensed(1024, 1000, fan_in=100, fan_out_const=False)

def forward_pass(input, layer):
    layer(input)
    return
    
input = torch.rand(size=(128, 1024))
layer = dense
t_dense = benchmark.Timer(
    stmt="layer(input)",
    globals={"input": input, "layer": layer},
    num_threads=4,
    label="Dense",
)

input = torch.rand(size=(128, 1024))
layer = condensed

jit_lc = torch.jit.trace(condensed, input)
torch.jit.optimize_for_inference(jit_lc)
t_sparse = benchmark.Timer(
    stmt="layer(input)",
    globals={"input": input, "layer": jit_lc},
    num_threads=4,
    label="Sparse Linear",
)


with torch.no_grad():
    print(t_dense.timeit(100))
    print(t_sparse.timeit(100))
    


# input = torch.rand(size=(64, 1024))
# layer = accel_linear
# t_accel = benchmark.Timer(
#     stmt="layer(input)",
#     globals={"input": input, "layer": layer},
#     num_threads=4,
#     label="Accelerated Linear",
# )

<torch.utils.benchmark.utils.common.Measurement object at 0x7f6e2ff30100>
Dense
  935.75 us
  1 measurement, 100 runs , 4 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7f6e2ff30820>
Sparse Linear
  30.26 ms
  1 measurement, 100 runs , 4 threads


In [185]:
dense_linear = nn.Linear(1024, 1000)
sparse_linear = nn.Linear(1024, 1000)
dense_linear

Linear(in_features=1024, out_features=1000, bias=True)

In [186]:
torch.count_nonzero(sparse_linear.weight)

tensor(1023999)

In [187]:
sparsity=0.99
idx = torch.randperm(n=sparse_linear.weight.numel())
non_zero_idx = idx[int(len(idx)*(1-0.99)):]
with torch.no_grad():
    
    w = sparse_linear.weight
    w  = w.flatten()
    print(sparse_linear.weight.shape)
    w[non_zero_idx]=0
    w = w.reshape(dense_linear.weight.shape)
    print(sparse_linear.weight.shape)



sparse_linear.weight.count_nonzero()

torch.Size([1000, 1024])
torch.Size([1000, 1024])


tensor(10240)

In [188]:
accel_linear = accel_linear = sparseprop.modules.SparseLinear(dense_weight=sparse_linear.weight, bias=sparse_linear.bias)

In [189]:
with torch.no_grad():
    sparse_linear = SparseLinear(sparse_linear, bs = 64)
    sparse_linear.weight.requires_grad

In [191]:
sparse_linear.weight.requires_grad

False

In [192]:
# sparse_linear = SparseLinear(sparse_linear, bs = 64)

with torch.no_grad():
    sparse_linear_module = torch.jit.trace(sparse_linear, torch.rand(size=(64, 1024)).T)

RuntimeError: Unsupported value kind: Tensor

In [108]:
import torch.utils.benchmark as benchmark

def forward_pass(input, layer):
    layer(input)
    return
    
input = torch.rand(size=(64, 1024))
layer = dense_linear
t_dense = benchmark.Timer(
    stmt="layer(input)",
    globals={"input": input, "layer": layer},
    num_threads=4,
    label="Dense",
)

input = torch.rand(size=(64, 1024)).T
layer = sparse_linear
t_sparse = benchmark.Timer(
    stmt="layer(input)",
    globals={"input": input, "layer": layer},
    num_threads=4,
    label="Sparse Linear",
)


input = torch.rand(size=(64, 1024))
layer = accel_linear
t_accel = benchmark.Timer(
    stmt="layer(input)",
    globals={"input": input, "layer": layer},
    num_threads=4,
    label="Accelerated Linear",
)

In [109]:
with torch.no_grad():
    print(t_dense.timeit(100))
    print(t_sparse.timeit(100))
    print(t_accel.timeit(100))


<torch.utils.benchmark.utils.common.Measurement object at 0x7f03526d5300>
Dense
  1.05 ms
  1 measurement, 100 runs , 4 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7f0350447550>
Sparse Linear
  1.66 ms
  1 measurement, 100 runs , 4 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7f03526d5300>
Accelerated Linear
  449.00 us
  1 measurement, 100 runs , 4 threads


In [76]:
class SparseLinear(nn.Module):
    
    def __init__(self, linear_layer, bs):   
        super().__init__()
        self.weights = linear_layer.weight.to_sparse_coo()
        # self.weights = copy.deepcopy(linear_layer.weight.T.to_sparse_coo())
        self.bias = nn.Parameter(linear_layer.bias.expand(size=(bs, *linear_layer.bias.shape)))
        # self.bias = linear_layer.bias

    
    def forward(self, x):
        # return self.bias + torch.mm(x, self.weights)
        return self.bias + torch.mm(self.weights, x.T).T

In [71]:
(sparse_linear.weight @ x.T).T.shape

torch.Size([64, 1000])

In [70]:
(x @ sparse_linear.weight.T).shape

torch.Size([64, 1000])

In [66]:
x.t().shape

torch.Size([1024, 64])

In [64]:
sl.weights.shape


torch.Size([1024, 1000])

In [77]:
x = torch.rand(size=(64,1024))
with torch.no_grad():
    sl =SparseLinear(sparse_linear, bs=64)
    print(sl(x).shape)

torch.Size([64, 1000])


In [12]:
from datetime import datetime
input = torch.rand(size=(64, 1024))

with torch.no_grad():
    start = datetime.now()
    for _ in range(100):
        dense_linear(input)
    end = datetime.now() - start
    print(end)
    

0:00:00.076035


In [13]:
from datetime import datetime
input = torch.rand(size=(64, 1024))

with torch.no_grad():
    start = datetime.now()
    for _ in range(100):
        sparse_linear(input)
    end = datetime.now() - start
    print(end)
    

0:00:00.093300


In [15]:
from datetime import datetime
input = torch.rand(size=(64, 1024))

with torch.no_grad():
    start = datetime.now()
    for _ in range(100):
        accel_linear(input)
    end = datetime.now() - start
    print(end)
    

0:00:00.038281


In [108]:
import copy
class SparseLinear(nn.Module):
    
    def __init__(self, linear_layer, bs):   
        super().__init__()
        self.weights = linear_layer.weight.T.to_sparse_coo()
        # self.weights = copy.deepcopy(linear_layer.weight.T.to_sparse_coo())
        self.bias = nn.Parameter(linear_layer.bias.expand(size=(bs, *linear_layer.bias.shape)))
        # self.bias = linear_layer.bias

    
    def forward(self, x):
        return self.bias + torch.mm(x, self.weights)

In [109]:
sl.bias  + (input_sparse @ sl.weights)

tensor([[-0.0738, -0.0234,  0.0346,  ..., -0.0094,  0.0154, -0.0011],
        [-0.1085,  0.0009,  0.0709,  ..., -0.0308,  0.0147,  0.0152],
        [-0.0445,  0.0006,  0.0638,  ..., -0.0376,  0.0010, -0.0057],
        ...,
        [-0.1035,  0.0213,  0.0500,  ...,  0.0013,  0.0042,  0.0554],
        [-0.0812,  0.0079,  0.0342,  ..., -0.0241, -0.0153,  0.0018],
        [-0.0975,  0.0019,  0.0885,  ..., -0.0015, -0.0043,  0.0365]],
       grad_fn=<AddBackward0>)

In [110]:
input = torch.rand(size=(64, 1024))
input_sparse = input.to_sparse_coo()
with torch.no_grad():
    sl = SparseLinear(sparse_linear, bs=64)
    start = datetime.now()
    for _ in range(100):
        out = sl(input_sparse)
    end = datetime.now() - start
    print(end)
out.shape
print(type(out))

0:00:00.544879
<class 'torch.Tensor'>


In [20]:
with torch.no_grad():
    input.to_sparse_coo() @ sparse_linear.weight.T.to_sparse_coo()

In [10]:
w = torch.tensor(
        [
            [
                1,2,3
            ],
            [
                4,5,6,
            ],
        ]
    )
w.shape
x = torch.tensor([[1,2],[3,4],[5,6],])
x.shape

torch.Size([3, 2])