In [16]:
import torch
import sten

In [17]:
from native_scripting import compile
import functools
import ctypes
import time
import math
from heapq import nlargest

In [18]:
# Pick the memoization decorator: functools.cache where this Python provides
# it, otherwise an unbounded lru_cache.
if hasattr(functools, "cache"):
    cache = functools.cache
else:
    cache = functools.lru_cache(maxsize=None)

In [19]:
@cache
def venom2dense(dense_shape, dense_dtype, n, m, tileM):
    """Build (and memoize) a native kernel that expands VENOM-compressed data to dense.

    The matrix shape, the N:M parameters and the tile height are baked into the
    generated C++ source as literals, so a separate library is compiled for each
    distinct argument tuple; ``@cache`` reuses it on later calls.

    Args:
        dense_shape: ``(nrows, ncols)`` of the dense matrix. Must be hashable
            because of ``@cache``.
        dense_dtype: ``torch.float32`` or ``torch.float64``; selects the C++
            element type (``float`` / ``double``).
        n: number of kept values per group, as used in the generated index math.
        m: group width in dense columns.
        tileM: tile height in rows (``bm`` in the generated code).

    Returns:
        The ctypes function ``func3(hA_dense, hA_values, hA_columns, hA_metadata)``
        taking four raw pointers. It writes entries of ``hA_values`` into
        ``hA_dense`` at positions decoded from the 2-bit fields of ``hA_metadata``
        and the column table ``hA_columns``; targets whose column would be
        ``>= ncols`` are skipped. The kernel does not clear ``hA_dense`` first.

    Raises:
        AssertionError: if ``dense_dtype`` is not float32/float64.
    """
    nrows = dense_shape[0]
    ncols = dense_shape[1]

    # A_size and density are not referenced by the generated source.
    A_size = nrows*ncols
    density = n/m

    brow = 4 #this->brow = brow_;
    mbrow = 32 #this->mbrow = mbrow_;

    bm   = tileM
    # !IMPORTANT! constants because of architecture constraints
    m_fixed = 4
    bits_elem_meta=2
    mrow_m = 2
    bits_elem_cols=8
    brow_fixed = 16
    # Number of 2-bit metadata fields that fit in one 32-bit word.
    nelems=32//bits_elem_meta #(sizeof(uint)*8)=32
    nelems_col = nelems//mrow_m

    # Only A_num_cols_sp_pad is interpolated into the kernel below. round_up()
    # returns an int (it goes through math.ceil), so the value formats as an
    # integer literal in the C++ source. round_up is defined in a later cell and
    # must exist by the time this function is called.
    A_num_cols_sp = (ncols/m)*n
    A_num_cols_sp_pad_nm = (round_up(ncols, m)/m)*n
    A_num_cols_sp_pad = round_up((round_up(ncols, m)/m)*n, 16)
    A_nnz = nrows*A_num_cols_sp_pad

    assert dense_dtype in (torch.float32, torch.float64)
    dtype = "float" if dense_dtype == torch.float32 else "double"
    # Doubled braces in the f-string are literal C++ braces; single-brace fields
    # are replaced by the Python values computed above.
    lib = compile(
        f"""
        #include <iostream>
        #include <algorithm>
        #include <utility>
        #include <cstdlib>
        #include <cstdio>
        #include <cmath>
        #include <functional>
        #include <tuple>
        #include <vector>
        #include <numeric>
        #include <chrono>

        using namespace std;


        extern "C" void func3({dtype}* hA_dense, {dtype}* hA_values, int *hA_columns, int *hA_metadata){{
            //this->hA_dense.resize(this->A_size, 0);

            // general variables N:M format
            int bm_m = {nrows}/{bm};
            int mbrow_m = {bm}/{mbrow};
            int mbrow_m2 = {mbrow}/{brow_fixed};
            int brow_m = {brow_fixed}/{brow};
            // metadata
            int mcol_kk = {nelems}/{mrow_m}/{n};
            int mcol_k = {A_num_cols_sp_pad}/{n}/mcol_kk;
            // indices
            int col_kk = mcol_kk;
            int col_k = {A_num_cols_sp_pad}/{n}/col_kk;

            uint indexes[{nelems}];
            uint columns[col_kk*{m_fixed}];

            for(int bm_i=0; bm_i<bm_m; bm_i++){{
                for(int mbrow_i=0; mbrow_i<mbrow_m; mbrow_i++){{
                    for(int mbrow_i2=0; mbrow_i2<mbrow_m2; mbrow_i2++){{
                        for(int brow_i=0; brow_i<brow_m; brow_i++){{
                            for(int mcol_i=0; mcol_i<mcol_k; mcol_i++){{
                                //read columns indexes
                                for(int col_i=0; col_i<col_kk; col_i++){{
                                    for(int col_ii=0; col_ii<{m_fixed}; col_ii++){{
                                        columns[col_i*{m_fixed} + col_ii] =
                                        hA_columns[bm_i*col_k*col_kk*{m_fixed} + mcol_i*col_kk*{m_fixed} + col_i*{m_fixed} + col_ii];
                                    }}
                                }}
                                // read metadata
                                for(int mbrow_ii=0; mbrow_ii<({brow}/{mrow_m}); mbrow_ii++){{
                                    for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                        for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                            for (int n_i=0; n_i<{n}; n_i++) {{
                                                indexes[
                                                    mbrow_iii*{n} +
                                                    mcol_ii*{mrow_m}*{n} +
                                                    n_i] =
                                                (((hA_metadata[
                                                    bm_i*mcol_k*{bm}/{mrow_m} +
                                                    mbrow_i*mcol_k*{mbrow}/{mrow_m} +
                                                    mbrow_i2*{brow_fixed}/{mrow_m} +
                                                    brow_i*{brow}/{mrow_m}  +
                                                    mcol_i*{mbrow}/{mrow_m} +
                                                    mbrow_ii]) >> (mbrow_iii*({nelems}/{mrow_m})*{bits_elem_meta}+mcol_ii*{n}*{bits_elem_meta}+n_i*{bits_elem_meta})) & 0x3);
                                            }}
                                        }}
                                    }}

                                    for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                        for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                            for(int n_i=0; n_i<{n}; n_i++){{
                                                unsigned int index = columns[mcol_ii*{m_fixed} + indexes[mcol_ii*{mrow_m}*{n}+mbrow_iii*{n}+n_i]];

                                                if((mcol_i*{m}*mcol_kk + mcol_ii*{m} + index) < {ncols}){{
                                                    hA_dense[
                                                        bm_i*{bm}*{ncols} +
                                                        mbrow_i*{mbrow}*{ncols} +
                                                        mbrow_i2*{brow_fixed}*{ncols} +
                                                        brow_i*{brow}*{ncols} +
                                                        mcol_i*{m}*mcol_kk +
                                                        mbrow_ii*{mrow_m}*{ncols} +
                                                        mcol_ii*{m} +
                                                        mbrow_iii*{ncols} +
                                                        index] =
                                                    hA_values[
                                                        bm_i*{bm}*{A_num_cols_sp_pad} +
                                                        mbrow_i*{mbrow}*{A_num_cols_sp_pad}+
                                                        mbrow_i2*{brow_fixed}*{A_num_cols_sp_pad}+
                                                        brow_i*{brow}*{nelems}/{mrow_m}+
                                                        mcol_i*{brow_fixed}*{nelems}/{mrow_m} +
                                                        mbrow_ii*{mrow_m}*{n} +
                                                        mcol_ii*{n}*{brow} +
                                                        mbrow_iii*{n} +
                                                        n_i];
                                                }}
                                            }}
                                        }}
                                    }}
                                }}
                            }}
                        }}
                    }}
                }}
            }}
        }}
        """,
    )
    # Every argument is passed as an untyped address (callers use tensor.data_ptr()).
    lib.func3.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
    ]
    return lib.func3

In [20]:
@cache
def dense2venom(dense_shape, dense_dtype, n, m, tileM):
    """Build (and memoize) a native kernel that compresses a masked dense matrix.

    The matrix shape, the N:M parameters and the tile height are baked into the
    generated C++ source as literals, so a separate library is compiled for each
    distinct argument tuple; ``@cache`` reuses it on later calls.

    Args:
        dense_shape: ``(nrows, ncols)`` of the dense matrix. Must be hashable
            because of ``@cache``.
        dense_dtype: ``torch.float32`` or ``torch.float64``; selects the C++
            element type (``float`` / ``double``).
        n: number of kept values per group, as used in the generated index math.
        m: group width in dense columns.
        tileM: tile height in rows (``bm`` in the generated code).

    Returns:
        The ctypes function
        ``func2(sparse, masks, hA_values, hA_columns, hA_metadata)`` taking five
        raw pointers. ``sparse`` and ``masks`` are read with dense row-major
        indexing; ``masks`` is read as ``int`` and any non-zero entry marks a kept
        element. ``hA_columns`` is read; ``hA_values`` and ``hA_metadata`` are
        written. The kernel also prints ``max_idx: <k>`` to stdout, the largest
        ``hA_values`` index it wrote.

    Raises:
        AssertionError: if ``dense_dtype`` is not float32/float64.
    """
    nrows = dense_shape[0]
    ncols = dense_shape[1]

    brow = 4 #this->brow = brow_;
    mbrow = 32 #this->mbrow = mbrow_;

    bm   = tileM
    # !IMPORTANT! constants because of architecture constraints
    m_fixed = 4
    bits_elem_meta=2
    mrow_m = 2
    bits_elem_cols=8
    brow_fixed = 16
    # Number of 2-bit metadata fields that fit in one 32-bit word.
    nelems=32//bits_elem_meta #(sizeof(uint)*8)=32
    nelems_col = nelems//mrow_m

    # Only A_num_cols_sp_pad is interpolated into the kernel below. round_up()
    # returns an int (it goes through math.ceil), so the value formats as an
    # integer literal in the C++ source. round_up is defined in a later cell and
    # must exist by the time this function is called.
    A_num_cols_sp = (ncols//m)*n
    A_num_cols_sp_pad_nm = (round_up(ncols, m)/m)*n
    A_num_cols_sp_pad = round_up((round_up(ncols, m)/m)*n, 16)
    A_nnz = nrows*A_num_cols_sp_pad

    assert dense_dtype in (torch.float32, torch.float64)
    dtype = "float" if dense_dtype == torch.float32 else "double"
    # Doubled braces in the f-string are literal C++ braces; single-brace fields
    # are replaced by the Python values computed above.
    # NOTE(review): the scratch arrays `values` / `indexes` in the kernel are not
    # initialised; slots are only written for mask hits (or the n_i<2 padding
    # branch), so a group with fewer than n mask hits appears to emit stale
    # scratch contents - confirm callers always supply exactly n hits per group.
    lib = compile(
        f"""
        #include <iostream>
        #include <algorithm>
        #include <utility>
        #include <cstdlib>
        #include <cstdio>
        #include <cmath>
        #include <functional>
        #include <tuple>
        #include <vector>
        #include <numeric>
        #include <chrono>

        using namespace std;


        extern "C" void func2({dtype}* sparse, int* masks, {dtype}* hA_values, int *hA_columns, int *hA_metadata){{

            int bm_m = {nrows}/{bm};
            int mbrow_m = {bm}/{mbrow};
            int mbrow_m2 = {mbrow}/{brow_fixed};
            int brow_m = {brow_fixed}/{brow};
            // metadata
            int mcol_kk = {nelems}/{mrow_m}/{n};
            int mcol_k = {A_num_cols_sp_pad}/{n}/mcol_kk;
            // indices
            int col_kk = mcol_kk;
            int col_k = {A_num_cols_sp_pad}/{n}/col_kk;

            {dtype} values[{nelems}];
            uint indexes[{nelems}];
            uint columns[col_kk*{m_fixed}];

            int max_idx = 0;

            for(int bm_i=0; bm_i<bm_m; bm_i++){{
                for(int mbrow_i=0; mbrow_i<mbrow_m; mbrow_i++){{
                    for(int mbrow_i2=0; mbrow_i2<mbrow_m2; mbrow_i2++){{
                        for(int brow_i=0; brow_i<brow_m; brow_i++){{
                            for(int mcol_i=0; mcol_i<mcol_k; mcol_i++){{
                                for(int col_i=0; col_i<col_kk; col_i++){{
                                    for(int col_ii=0; col_ii<{m_fixed}; col_ii++){{
                                        columns[col_i*{m_fixed} + col_ii] =
                                        hA_columns[bm_i*col_k*col_kk*{m_fixed} + mcol_i*col_kk*{m_fixed} + col_i*{m_fixed} + col_ii];
                                    }}
                                }}
                                for(int mbrow_ii=0; mbrow_ii<({brow}/{mrow_m}); mbrow_ii++){{
                                    for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                        for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                            int pos=0;
                                            for(int n_i=0; n_i<{m_fixed}; n_i++){{
                                                unsigned int index = columns[mcol_ii*{m_fixed} + n_i];

                                                if((mcol_i*{m}*mcol_kk + mcol_ii*{m} + index) < {ncols}){{
                                                    int nnz = masks[
                                                            bm_i*{bm}*{ncols} +
                                                            mbrow_i*{mbrow}*{ncols} +
                                                            mbrow_i2*{brow_fixed}*{ncols} +
                                                            brow_i*{brow}*{ncols} +
                                                            mcol_i*{m}*mcol_kk +
                                                            mbrow_ii*{mrow_m}*{ncols} +
                                                            mcol_ii*{m} +
                                                            mbrow_iii*{ncols} +
                                                            index];

                                                    if(nnz != 0){{
                                                        indexes[
                                                            mbrow_iii*{n} +
                                                            mcol_ii*{mrow_m}*{n} +
                                                            pos] = n_i;

                                                        values[
                                                            mcol_ii*{mrow_m}*{n} +
                                                            mbrow_iii*{n} +
                                                            pos] =
                                                        sparse[
                                                            bm_i*{bm}*{ncols} +
                                                            mbrow_i*{mbrow}*{ncols} +
                                                            mbrow_i2*{brow_fixed}*{ncols} +
                                                            brow_i*{brow}*{ncols} +
                                                            mcol_i*{m}*mcol_kk +
                                                            mbrow_ii*{mrow_m}*{ncols} +
                                                            mcol_ii*{m} +
                                                            mbrow_iii*{ncols} +
                                                            index];

                                                        pos+=1;
                                                    }}
                                                }} else {{
                                                    if(n_i<2){{
                                                        indexes[
                                                            mbrow_iii*{n} +
                                                            mcol_ii*{mrow_m}*{n} +
                                                            pos] = 0;

                                                        values[
                                                            mcol_ii*{mrow_m}*{n} +
                                                            mbrow_iii*{n} +
                                                            pos] = 0;

                                                        pos+=1;
                                                    }}
                                                }}
                                            }}
                                        }}
                                    }}
                                    // write metadata
                                    unsigned int meta=0;
                                    for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                        for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                            for (int n_i=0; n_i<{n}; n_i++) {{

                                                int idx = bm_i*{bm}*{A_num_cols_sp_pad} +
                                                        mbrow_i*{mbrow}*{A_num_cols_sp_pad}+
                                                        mbrow_i2*{brow_fixed}*{A_num_cols_sp_pad}+
                                                        brow_i*{brow}*{nelems}/{mrow_m}+
                                                        mcol_i*{brow_fixed}*{nelems}/{mrow_m} +
                                                        mbrow_ii*{mrow_m}*{n} +
                                                        mcol_ii*{n}*{brow} +
                                                        mbrow_iii*{n} +
                                                        n_i;

                                                max_idx = (idx>max_idx)?(idx):(max_idx);

                                                hA_values[
                                                        idx] =
                                                values[
                                                    mcol_ii*{mrow_m}*{n} +
                                                    mbrow_iii*{n} +
                                                    n_i];

                                                unsigned int tmp = indexes[
                                                            mbrow_iii*{n} +
                                                            mcol_ii*{mrow_m}*{n} +
                                                            n_i];
                                                meta |= (tmp << (mbrow_iii*({nelems}/{mrow_m})*{bits_elem_meta}+mcol_ii*{n}*{bits_elem_meta}+n_i*{bits_elem_meta}));
                                            }}
                                        }}
                                    }}
                                    hA_metadata[bm_i*mcol_k*{bm}/{mrow_m} +
                                                mbrow_i*mcol_k*{mbrow}/{mrow_m} +
                                                mbrow_i2*{brow_fixed}/{mrow_m} +
                                                brow_i*{brow}/{mrow_m}  +
                                                mcol_i*{mbrow}/{mrow_m} +
                                                mbrow_ii] = meta;
                                }}
                            }}
                        }}
                    }}
                }}
            }}
            cout << "max_idx: " << max_idx << endl;
        }}
        """,
    )
    # Every argument is passed as an untyped address (callers use tensor.data_ptr()).
    lib.func2.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
    ]
    return lib.func2

In [21]:
def round_up(x, y):
    """Round ``x`` up to the nearest multiple of ``y``.

    Args:
        x: value to round; int or float.
        y: step to round to; must be non-zero.

    Returns:
        The smallest multiple of ``y`` that is ``>= x`` (for positive ``y``).
        With an integer ``y`` the result is always an ``int``, including when
        ``x`` is a float, because the float path goes through ``math.ceil``.

    Raises:
        ZeroDivisionError: if ``y`` is zero.
    """
    if isinstance(x, int) and isinstance(y, int):
        # Exact ceiling division. Going through float division (x / y) loses
        # precision for large ints and can return a value smaller than x.
        return -(-x // y) * y
    return math.ceil(x / y) * y

In [22]:
class SparseVNMTensor:
    """Container for a dense matrix compressed with the ``dense2venom`` kernel.

    Attributes set by the constructor:
        n, m, v: the sparsity parameters forwarded to the kernel builders
            (``v`` is passed as ``tileM``).
        dense: float32 CPU copy of the input matrix.
        device: the ``device_`` argument, stored as given.
        columns: the column-index table passed in by the caller.
        mask: the mask passed in by the caller, stored as given.
        values, metadata: CPU buffers filled by ``to_venom``.
        nrows, ncols, nnz: dense shape and the length of ``values``.
    """

    def __init__(self, n_, m_, v_, dense_, mask_, columns_, device_):
        """Store the parameters and immediately compress ``dense_``.

        Args:
            n_, m_, v_: sparsity parameters (see class docstring).
            dense_: 2-D tensor to compress; copied to CPU as float32.
            mask_: tensor with the same shape as ``dense_``; non-zero entries
                mark the elements to keep.
            columns_: column-index table read by the native kernels.
            device_: device recorded on the instance.
        """
        self.n = n_
        self.m = m_
        self.v = v_
        self.nnz = 0
        self.nrows = None
        self.ncols = None

        self.dense = dense_.cpu().to(dtype=torch.float32)
        self.device = device_

        self.columns = columns_
        self.values = None
        self.metadata = None

        self.mask = mask_

        # Reuse the float32 host copy instead of converting the input a second time.
        self.to_venom(self.dense, mask_.cpu())

    def to_venom(self, dense_, mask_):
        """Fill ``self.values`` / ``self.metadata`` from ``dense_`` and ``mask_``.

        Args:
            dense_: 2-D float32/float64 tensor holding the dense values.
            mask_: tensor of the same shape; any non-zero entry marks a kept
                element.
        """
        func = dense2venom(dense_.shape, dense_.dtype, self.n, self.m, self.v)

        self.nrows, self.ncols = dense_.shape
        A_num_cols_sp_pad = round_up((round_up(self.ncols, self.m)/self.m)*self.n, 16)
        self.nnz = self.nrows*A_num_cols_sp_pad
        mrow_m = 2
        bits_elem_meta = 2

        nelems = 32//bits_elem_meta
        nelems_col = nelems//mrow_m

        # The kernel indexes raw memory, so every buffer must live on the host,
        # be contiguous, and have exactly the element type the C++ signature
        # declares. These conversions are no-ops for already-conforming tensors.
        dense_host = dense_.to(device="cpu").contiguous()
        # The kernel reads the mask as `int*`: hand it real int32 flags rather
        # than the bit pattern of whatever dtype the mask was built with.
        mask_host = (mask_ != 0).to(device="cpu", dtype=torch.int32).contiguous()
        columns_host = self.columns.to(device="cpu", dtype=torch.int32).contiguous()

        # `values` must match the element type the kernel was compiled for.
        self.values = torch.zeros(self.nrows * A_num_cols_sp_pad, dtype=dense_host.dtype, device="cpu")
        self.metadata = torch.zeros(self.nrows//mrow_m * A_num_cols_sp_pad//nelems_col, dtype=torch.int32, device="cpu")

        func(
            dense_host.data_ptr(),
            mask_host.data_ptr(),
            self.values.data_ptr(),
            columns_host.data_ptr(),
            self.metadata.data_ptr(),
        )

    def to_dense(self):
        """Expand the compressed data back to a dense matrix.

        Returns:
            A half-precision ``(nrows, ncols)`` tensor on ``cuda:0``; positions
            the kernel does not write stay zero.
        """
        # Compile for the dtype actually stored in `values` so the kernel's
        # element type always matches the buffers it is given.
        func = venom2dense(
            (self.nrows, self.ncols),
            self.values.dtype,
            self.n,
            self.m,
            self.v,
        )
        # Start from zeros: the kernel only writes the kept positions.
        dense = torch.zeros((self.nrows, self.ncols), dtype=self.values.dtype, device="cpu", requires_grad=True)
        columns_host = self.columns.to(device="cpu", dtype=torch.int32).contiguous()

        func(dense.data_ptr(), self.values.data_ptr(), columns_host.data_ptr(), self.metadata.data_ptr())

        # NOTE(review): the target device is hard-coded and `self.device` is
        # never consulted - kept as-is because callers rely on a cuda:0 result.
        return dense.to(device="cuda:0").half()

In [23]:
class NMVectorSparsifier:
    """Callable that wraps a dense 2-D tensor as a VENOM sparse tensor.

    Holds the three sparsity parameters and, when called, builds the column
    table and mask that ``SparseVNMTensor`` needs.
    """

    def __init__(self, n, m, v):
        self.n, self.m, self.v = n, m, v

    @staticmethod
    def get_random_mask(tensor, m, v):
        """Return a keep-mask shaped like ``tensor``.

        Despite the name, the mask is deterministic: within every run of ``m``
        consecutive elements, positions 0 and 2 are set to 1 and the rest to 0.
        The result is a CPU tensor whose dtype follows torch type promotion of
        ``tensor.dtype`` with the float32 pattern.
        """
        blank = torch.zeros(tensor.shape, dtype=tensor.dtype)
        keep_pattern = torch.cat((torch.tensor([1, 0, 1, 0]), torch.zeros(m - 4)), 0)
        # Broadcast the length-m pattern over every (v, m) tile, then restore the shape.
        return (blank.reshape(-1, v, m) + keep_pattern).reshape(tensor.shape)

    def __call__(self, tensor, grad_fmt=None):
        """Sparsify ``tensor`` and return it wrapped by ``sten``."""
        nrows, ncols = tensor.shape

        # One group of four column slots per (row-tile, m-wide column group),
        # each group initialised to the identity selection 0, 1, 2, 3.
        columns_shape = (nrows // self.v, ncols // self.m * 4)
        slot_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
        columns = torch.zeros(columns_shape, dtype=torch.int32).reshape((-1, 4))
        columns = (columns + slot_ids).reshape(columns_shape)

        mask = NMVectorSparsifier.get_random_mask(tensor, self.m, self.v)
        vnm_tensor = SparseVNMTensor(
            self.n, self.m, self.v, tensor, mask, columns, tensor.device
        )

        return sten.SparseTensorWrapper.wrapped_from_dense(vnm_tensor, tensor, grad_fmt)

In [24]:
# Show two decimal places when tensors are printed in the cells below.
torch.set_printoptions(precision=2)

# Half-precision test problem on cuda:0: a (256, 128) activation batch and a
# Linear(128 -> 512) layer whose weight is sparsified in a later cell.
# NOTE(review): no RNG seed is set, so the numbers change on every run; `input`
# also shadows the Python builtin of the same name.
input = torch.randn(256, 128, requires_grad=True, dtype=torch.half, device="cuda:0")
weight = torch.nn.Linear(128, 512, bias=True, dtype=torch.half, device="cuda:0")

In [25]:
# Sparsity configuration, read as module-level globals by later cells
# (SrnmSpmm and the SDDMM / decompression cells).
# v is forwarded to the kernel builders as the tile height (tileM);
# n and m are the N:M parameters.
v=64
n=2
m=8

In [26]:
import spatha_sddmm

In [27]:
# Overwritten by VenomLinearFunction.backward with the grad_output it receives,
# so later cells can feed that gradient to the SDDMM kernel.
global_grad_output = None

In [28]:
class VenomLinearFunction(torch.autograd.Function):
    """Linear layer ``input @ weight.T + bias`` with a hand-written backward pass.

    The backward pass also stores the incoming ``grad_output`` in the
    module-level ``global_grad_output`` and prints diagnostics, so the notebook
    can compare the dense weight gradient against the SDDMM result.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute the linear output and save the operands for backward.

        Args:
            input: activations whose last dimension matches ``weight``'s columns.
            weight: ``(out_features, in_features)`` matrix.
            bias: optional ``(out_features,)`` vector.
        """
        ctx.save_for_backward(input, weight, bias)

        output = torch.matmul(input, weight.t())
        if bias is not None:
            output += bias

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias)``.

        A gradient is ``None`` when the corresponding argument does not need one
        (or, for ``bias``, when no bias was supplied).
        """
        input, weight, bias = ctx.saved_tensors

        # Expose the upstream gradient to the rest of the notebook.
        global global_grad_output
        global_grad_output = grad_output

        print(input.device, weight.device, grad_output.device)

        grad_input = grad_weight = grad_bias = None
        dense_grad_weight = None

        # Operands follow grad_output's device rather than a hard-coded one, so
        # the function also works on CPU or on a GPU other than cuda:0.
        target_device = grad_output.device

        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight.to(target_device)

        if ctx.needs_input_grad[1]:
            input_on_device = input.to(target_device)
            grad_weight = grad_output.t() @ input_on_device

            # Independent einsum formulation, used only to cross-check grad_weight.
            dense_grad_weight = torch.einsum("...m,...n->mn", grad_output, input_on_device)
            print("dense_grad_weight", dense_grad_weight[:4, :4])

        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        print(ctx.needs_input_grad[0], ctx.needs_input_grad[1], ctx.needs_input_grad[2])
        # The weight diagnostics only exist when a weight gradient was computed;
        # printing them unconditionally crashed when the weight was frozen.
        if grad_weight is not None:
            print("grad_weight", grad_weight[:4, :4])
            print(torch.allclose(dense_grad_weight, grad_weight, atol=0.05))

        return grad_input, grad_weight, grad_bias

In [29]:
class SrnmSpmm(torch.nn.Module):
    """Module that replaces a ``torch.nn.Linear`` with its sparsified weight.

    The original weight is sparsified with ``NMVectorSparsifier``; the
    compressed buffers are kept on ``cuda:0`` and the forward pass runs
    ``VenomLinearFunction`` on the densified weight.
    """

    def __init__(self, original: torch.nn.Linear, n_=None, m_=None, v_=None):
        """Sparsify ``original.weight`` and register the resulting parameters.

        Args:
            original: the linear layer to convert; its bias is reused as-is.
            n_, m_, v_: sparsity parameters. Each one left as ``None`` falls
                back to the notebook-level global ``n`` / ``m`` / ``v``, which
                is what this class always used before the arguments existed.
        """
        super().__init__()

        n_ = n if n_ is None else n_
        m_ = m if m_ is None else m_
        v_ = v if v_ is None else v_

        # `.wrapped_tensor` is read below for values/columns/metadata/mask, i.e.
        # it is expected to be the SparseVNMTensor built by the sparsifier.
        self.w = NMVectorSparsifier(n_, m_, v_)(original.weight).wrapped_tensor

        self.values = torch.nn.Parameter(self.w.values.to(device="cuda:0").half())
        self.columns = self.w.columns.to(device="cuda:0")
        self.metadata = self.w.metadata.to(device="cuda:0")

        self.bias = original.bias

        # Dense reconstruction of the sparsified weight; this is the tensor the
        # forward pass actually multiplies with.
        self.dense = torch.nn.Parameter(self.w.to_dense())
        self.mask = self.w.mask

        self.nrows_sp = self.w.nrows
        self.ncols_sp = self.w.ncols
        self.nnz      = self.w.nnz

    def forward(self, input):
        """Apply the linear layer using the densified sparse weight."""
        return VenomLinearFunction.apply(input, self.dense, self.bias)

In [30]:
# Convert the dense Linear layer into its sparsified counterpart. Construction
# runs the dense2venom kernel, which prints the "max_idx: ..." line shown below.
sparse_weight = SrnmSpmm(weight)

print( type(sparse_weight) )

max_idx: 16383
<class '__main__.SrnmSpmm'>


In [31]:
# Forward pass through the custom autograd function (VenomLinearFunction).
sparse = sparse_weight(input)

In [32]:
# Reference result computed with plain tensor ops on the same densified weight.
dense = input @ sparse_weight.dense.T + sparse_weight.bias

In [33]:
# Replace the Linear layer's weight with the sparsified (densified-back) weight
# so that torch.nn.Linear can serve as a second reference.
weight.weight = torch.nn.Parameter(sparse_weight.dense)

In [34]:
# Second reference: the stock Linear forward with the swapped-in weight.
dense2 = weight(input)

In [35]:
torch.allclose(dense, dense2, atol=0.001)

True

In [36]:
torch.allclose(sparse, dense2, atol=0.001)

True

In [37]:
dense

tensor([[-0.52,  0.13, -0.01,  ...,  0.20,  0.11, -0.07],
        [-0.31, -0.13, -0.22,  ...,  0.22,  0.18, -0.12],
        [-0.09, -0.03, -0.06,  ...,  0.01,  0.26,  0.31],
        ...,
        [-0.12,  0.23, -0.68,  ..., -0.37,  0.18, -0.05],
        [ 0.26,  0.21,  0.22,  ..., -0.57,  0.01,  0.00],
        [-0.08, -0.52,  0.37,  ...,  0.10,  0.27, -0.09]], device='cuda:0',
       dtype=torch.float16, grad_fn=<AddBackward0>)

In [38]:
dense2

tensor([[-0.52,  0.13, -0.01,  ...,  0.20,  0.11, -0.07],
        [-0.31, -0.13, -0.22,  ...,  0.22,  0.18, -0.12],
        [-0.09, -0.03, -0.06,  ...,  0.01,  0.26,  0.31],
        ...,
        [-0.12,  0.23, -0.68,  ..., -0.37,  0.18, -0.05],
        [ 0.26,  0.21,  0.22,  ..., -0.57,  0.01,  0.00],
        [-0.08, -0.52,  0.37,  ...,  0.10,  0.27, -0.09]], device='cuda:0',
       dtype=torch.float16, grad_fn=<AddmmBackward0>)

In [39]:
#print( dense[:8, :8] )

In [40]:
#print( sparse.device, sparse.dtype, dense.device, dense.dtype)

In [41]:
print( torch.allclose(sparse.half().cuda(), dense) )

True


In [42]:
print(sparse.sum())

tensor(74.12, device='cuda:0', dtype=torch.float16, grad_fn=<SumBackward0>)


In [43]:
print(input.grad)

None


In [44]:
print(sparse_weight.dense.grad)

None


In [45]:
sparse.sum().backward()

cuda:0 cuda:0 cuda:0
dense_grad_weight tensor([[  4.60, -21.77,  34.81,  16.61],
        [  4.60, -21.77,  34.81,  16.61],
        [  4.60, -21.77,  34.81,  16.61],
        [  4.60, -21.77,  34.81,  16.61]], device='cuda:0',
       dtype=torch.float16)
True True True
grad_weight tensor([[  4.60, -21.77,  34.81,  16.61],
        [  4.60, -21.77,  34.81,  16.61],
        [  4.60, -21.77,  34.81,  16.61],
        [  4.60, -21.77,  34.81,  16.61]], device='cuda:0',
       dtype=torch.float16)
True


In [46]:
print(input.grad[:8, :8])

tensor([[-1.26,  0.00,  1.15,  0.00,  0.00,  0.00,  0.00,  0.00],
        [-1.26,  0.00,  1.15,  0.00,  0.00,  0.00,  0.00,  0.00],
        [-1.26,  0.00,  1.15,  0.00,  0.00,  0.00,  0.00,  0.00],
        [-1.26,  0.00,  1.15,  0.00,  0.00,  0.00,  0.00,  0.00],
        [-1.26,  0.00,  1.15,  0.00,  0.00,  0.00,  0.00,  0.00],
        [-1.26,  0.00,  1.15,  0.00,  0.00,  0.00,  0.00,  0.00],
        [-1.26,  0.00,  1.15,  0.00,  0.00,  0.00,  0.00,  0.00],
        [-1.26,  0.00,  1.15,  0.00,  0.00,  0.00,  0.00,  0.00]],
       device='cuda:0', dtype=torch.float16)


In [47]:
print(sparse_weight.dense.grad[:8, :8])

tensor([[  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18]],
       device='cuda:0', dtype=torch.float16)


In [48]:
print(global_grad_output)

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16)


In [49]:
print(global_grad_output.dtype, global_grad_output.device)
print(input.dtype, input.device)
print(sparse_weight.metadata.dtype, sparse_weight.metadata.device)
print(sparse_weight.columns.dtype, sparse_weight.columns.device)

torch.float16 cuda:0
torch.float16 cuda:0
torch.int32 cuda:0
torch.int32 cuda:0


In [50]:
# Transpose both SDDMM operands and make them contiguous.
# NOTE(review): this cell rebinds the same names, so it is not idempotent -
# running it a second time transposes the tensors back. Re-run the backward
# cell before re-running this one.
global_grad_output = global_grad_output.T.contiguous()
input = input.T.contiguous()

In [51]:
print(  type(global_grad_output),    
        type(input),                 
        type(sparse_weight.metadata),
        type(sparse_weight.columns), 
        type(sparse_weight.nrows_sp),
        type(sparse_weight.ncols_sp),
        type(n),                     
        type(m),                     
        type(sparse_weight.nnz),     
        type(0),                     
        type(32),                    
        type(4)                      
        )

<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'int'> <class 'int'> <class 'int'> <class 'int'> <class 'int'> <class 'int'> <class 'int'> <class 'int'>


In [52]:
# Log the scalar arguments and tensor properties that are about to be handed to
# the SDDMM kernel.
print(sparse_weight.nrows_sp, 
      sparse_weight.ncols_sp, 
      input.shape[1], 
      n, 
      m, 
      sparse_weight.nnz)
print(global_grad_output.dtype, input.dtype)
print(global_grad_output.device, input.device)
print(global_grad_output.shape, 
            input.shape, 
            sparse_weight.metadata.shape, 
            sparse_weight.columns.shape)

# Weight gradient computed directly in compressed form, using the metadata and
# column table of the sparsified weight. Both dense operands were transposed in
# an earlier cell.
compressed_grad_weights = spatha_sddmm.sddmm(
                            global_grad_output,     # A_matrix
                            input,                  # B_matrix
                            sparse_weight.metadata, # C_metadata
                            sparse_weight.columns,  # C_indices
                            sparse_weight.nrows_sp, # C_num_rows
                            sparse_weight.ncols_sp, # C_num_cols    
                            input.shape[1],         # presumably the shared inner dimension of A and B - TODO confirm against spatha_sddmm
                            n,                      # N
                            m,                      # M
                            sparse_weight.nnz,      # nnz
                            0,                      # seed
                            32,                     # mbrow
                            4                       # brow
                            )
print(global_grad_output.shape)
print(input.shape)
print(compressed_grad_weights.shape)
print(sparse_weight.nrows_sp, 
      sparse_weight.ncols_sp,
      input.shape[1])

512 128 256 2 8 16384
torch.float16 torch.float16
cuda:0 cuda:0
torch.Size([512, 256]) torch.Size([128, 256]) torch.Size([1024]) torch.Size([8, 64])
torch.Size([512, 256])
torch.Size([128, 256])
torch.Size([512, 32])
512 128 256


In [53]:
print(compressed_grad_weights[:8, :8])

tensor([[ 4.60, 34.81,  4.60, 34.81,  4.60, 34.81,  4.60, 34.81],
        [ 4.60, 34.81,  4.60, 34.81,  4.60, 34.81,  4.60, 34.81],
        [ 4.60, 34.81,  4.60, 34.81,  4.60, 34.81,  4.60, 34.81],
        [ 4.60, 34.81,  4.60, 34.81,  4.60, 34.81,  4.60, 34.81],
        [20.72, -1.66, 20.72, -1.66, 20.72, -1.66, 20.72, -1.66],
        [20.72, -1.66, 20.72, -1.66, 20.72, -1.66, 20.72, -1.66],
        [20.72, -1.66, 20.72, -1.66, 20.72, -1.66, 20.72, -1.66],
        [20.72, -1.66, 20.72, -1.66, 20.72, -1.66, 20.72, -1.66]],
       device='cuda:0', dtype=torch.float16)


In [54]:
# Bring the compressed gradient and the sparse-format tables to the host; the
# native venom2dense kernel works on CPU memory through raw pointers, and it is
# compiled here for float32 data.
compressed_grad_weights = compressed_grad_weights.float().cpu()
columns = sparse_weight.columns.cpu()
metadata = sparse_weight.metadata.cpu()

# Same kernel lookup that SparseVNMTensor.to_dense performs.
impl_builder = (
            venom2dense
            )
func = impl_builder(
        (sparse_weight.nrows_sp, sparse_weight.ncols_sp),
        torch.float32, 
        n,
        m,
        v
    )

# Zero-filled destination: the kernel only writes the kept positions.
# NOTE(review): this rebinds `dense`, which an earlier cell used for the forward
# reference output.
dense = torch.zeros((sparse_weight.nrows_sp, sparse_weight.ncols_sp), dtype=compressed_grad_weights.dtype, device="cpu", requires_grad=True)

func(dense.data_ptr(), compressed_grad_weights.data_ptr(), columns.data_ptr(), metadata.data_ptr())

print(dense[0:8, 0:8])

tensor([[ 4.60,  0.00, 34.81,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 4.60,  0.00, 34.81,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 4.60,  0.00, 34.81,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 4.60,  0.00, 34.81,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 4.60,  0.00, 34.81,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 4.60,  0.00, 34.81,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 4.60,  0.00, 34.81,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 4.60,  0.00, 34.81,  0.00,  0.00,  0.00,  0.00,  0.00]],
       grad_fn=<SliceBackward0>)


In [55]:
print(sparse_weight.dense.grad[:8, :8])

tensor([[  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18],
        [  4.60, -21.77,  34.81,  16.61, -10.14,  -2.64, -24.31,   0.18]],
       device='cuda:0', dtype=torch.float16)


In [56]:
# The decompressed SDDMM result must equal the dense autograd gradient with the
# pruned positions zeroed out by the mask.
torch.equal(dense, sparse_weight.dense.grad.cpu()*sparse_weight.mask)

True

In [57]:
sparse_weight.dense.grad.shape

torch.Size([512, 128])

In [58]:
print(dense.shape)

torch.Size([512, 128])
