In [18]:
import torch
import sten

In [19]:
import spatha
import spatha_sddmm

In [20]:
from native_scripting import compile
import functools
import ctypes
import time
import math
from heapq import nlargest

In [21]:
try:
    cache = functools.cache
except AttributeError:
    cache = functools.lru_cache(maxsize=None)

#Weight candidate selection for removal

In [22]:
@cache
def group_n_m2(dense_shape, dense_dtype, n, m, tileM):
    nrows = dense_shape[0]
    ncols = dense_shape[1]

    m_fixed = 4
    A_num_cols_sp_pad = round_up((round_up(ncols, m)/m)*n, 16)
    indexes_cols = A_num_cols_sp_pad//n*m_fixed
    assert dense_dtype in (torch.float32, torch.float64)
    dtype = "float" if dense_dtype == torch.float32 else "double"
    lib = compile(
        f"""
        #include <iostream>
        #include <algorithm>
        #include <utility>
        #include <cstdlib>
        #include <cstdio>
        #include <cmath>
        #include <functional>
        #include <tuple>
        #include <vector>
        #include <numeric>
        #include <chrono>

        using namespace std;

        int int_ceil(int x, int y){{
            return (x - 1) / y + 1;
        }}

        extern "C" void func({dtype}* dense, {dtype}* sparse, int* masks, int *columns){{
            int bm_m   = {nrows}/{tileM};
            int mcol_k = {ncols}/{m};
            int mcol_k_p = int_ceil({ncols},{m});
            int m_fixed = 4;

            std::vector<{dtype}> v(m_fixed, 0);
            std::vector<int> vx(m_fixed, 0);
            {dtype} max=0, total=0;

            std::vector<size_t> indices(v.size());
            std::iota(indices.begin(), indices.end(), 0);

            for(int bm_i=0; bm_i<bm_m; bm_i++){{
                int t_bm_i   = bm_i*{tileM}*{ncols};
                for(int mcol_i=0; mcol_i<mcol_k; mcol_i++){{
                    int t_mcol_i = mcol_i*{m};
                    max = 0;

                    std::vector<int> cols_max;
                    cols_max.resize(m_fixed, 0);
                    std::vector<int> masks_max;
                    masks_max.resize({tileM}*{m}, 0);

                    for(int col_i=0; col_i<{m}; col_i++){{
                        vx[0]=col_i;
                        for(int col_j=col_i+1; col_j<{m}; col_j++){{
                            vx[1]=col_j;
                            for(int col_k=col_j+1; col_k<{m}; col_k++){{
                                vx[2]=col_k;
                                for(int col_w=col_k+1; col_w<{m}; col_w++){{
                                    vx[3]=col_w;
                                    total=0;

                                    std::vector<int> mask({tileM}*{m}, 0);

                                    for(int row_i=0; row_i<{tileM}; row_i++){{
                                        int t_row_i  = row_i*{ncols};
                                        v[0]=dense[t_bm_i + t_row_i + t_mcol_i + col_i];
                                        v[1]=dense[t_bm_i + t_row_i + t_mcol_i + col_j];
                                        v[2]=dense[t_bm_i + t_row_i + t_mcol_i + col_k];
                                        v[3]=dense[t_bm_i + t_row_i + t_mcol_i + col_w];

                                        std::partial_sort(indices.begin(), indices.begin() + {n}, indices.end(), [&](size_t A, size_t B) {{
                                                    return v[A] > v[B];}});

                                        for(int id=0; id<{n}; id++){{
                                            total += dense[t_bm_i + t_row_i + t_mcol_i + vx[indices[id]]];

                                            mask[row_i*{m} + vx[indices[id]]] = 1;
                                        }}
                                    }}

                                    if(total>max){{
                                        max = total;
                                        std::copy(mask.begin(), mask.end(), masks_max.begin());

                                        std::sort(vx.begin(), vx.begin() + m_fixed);
                                        std::copy(vx.begin(), vx.end(), cols_max.begin());
                                    }}
                                }}
                            }}
                        }}
                    }}

                    for(int i=0; i<{tileM}; i++){{
                        for(int j=0; j<{m}; j++){{
                            int drop = masks_max[i*{m} + j];
                            masks[t_bm_i  + i*{ncols}+ t_mcol_i + j]  = drop;
                            sparse[t_bm_i + i*{ncols}+ t_mcol_i + j] *= drop;
                        }}
                    }}
                    for(int i=0; i<m_fixed; i++){{
                        columns[bm_i*{indexes_cols} + mcol_i*m_fixed + i] =
                        cols_max[i];
                    }}
                }}
            }}

            int remainder = {ncols}%{m};

            if (remainder>0){{
                int d_rows={tileM}, d_cols=remainder;

                if(remainder<3){{
                    for(int i=0; i<{nrows}; i++){{
                        for(int j={ncols}-remainder; j<{ncols}; j++){{
                            masks[i*{ncols}+j] = 1;
                        }}
                    }}
                    for(int bm_i=0; bm_i<bm_m; bm_i++){{
                        for(int j=0; j<m_fixed; j++){{
                            columns[bm_i*{indexes_cols} + mcol_k*m_fixed + j] = j;
                        }}
                    }}
                }} else if(remainder==3){{
                    v[3] = -1;
                    for(int bm_i=0; bm_i<bm_m; bm_i++){{
                        int t_bm_i   = bm_i*{tileM}*{ncols};
                        for(int mcol_i=mcol_k; mcol_i<mcol_k_p; mcol_i++){{
                            max = 0;
                            int t_mcol_i = mcol_i*{m};

                            std::vector<int> cols_max(m_fixed, 0);
                            std::vector<int> masks_max({tileM}*remainder, 0);

                            for(int col_i=0; col_i<remainder; col_i++){{
                                vx[0]=col_i;
                                for(int col_j=col_i+1; col_j<remainder; col_j++){{
                                    vx[1]=col_j;
                                    for(int col_k=col_j+1; col_k<remainder; col_k++){{
                                        vx[2]=col_k;
                                        total=0;
                                        std::vector<int> mask({tileM}*remainder, 0);

                                        for(int row_i=0; row_i<{tileM}; row_i++){{
                                            int t_row_i  = row_i*{ncols};
                                            v[0]=dense[t_bm_i + t_row_i + t_mcol_i + col_i];
                                            v[1]=dense[t_bm_i + t_row_i + t_mcol_i + col_j];
                                            v[2]=dense[t_bm_i + t_row_i + t_mcol_i + col_k];

                                            std::partial_sort(indices.begin(), indices.begin() + {n}, indices.end(), [&](size_t A, size_t B) {{
                                                        return v[A] > v[B]; }});

                                            for(int id=0; id<{n}; id++){{
                                                total += dense[t_bm_i + t_row_i + t_mcol_i + vx[indices[id]]];

                                                mask[row_i*remainder + vx[indices[id]]] = 1;
                                            }}
                                        }}

                                        if(total>max){{
                                            max = total;
                                            std::copy(mask.begin(), mask.end(), masks_max.begin());

                                            std::sort(vx.begin(), vx.begin() + remainder);//m_fixed
                                            std::copy(vx.begin(), vx.end(), cols_max.begin());
                                        }}
                                    }}
                                }}
                            }}

                            for(int i=0; i<{tileM}; i++){{
                                for(int j=0; j<remainder; j++){{
                                    int drop = masks_max[i*remainder + j];

                                    masks[t_bm_i  + i*{ncols}+ t_mcol_i + j]  = drop;
                                    sparse[t_bm_i + i*{ncols}+ t_mcol_i + j] *= drop;
                                }}
                            }}
                            for(int i=0; i<remainder; i++){{
                                columns[bm_i*{indexes_cols} + mcol_i*m_fixed + i] =
                                cols_max[i];
                            }}
                        }}
                    }}
                }} else {{
                    for(int bm_i=0; bm_i<bm_m; bm_i++){{
                        int t_bm_i   = bm_i*{tileM}*{ncols};
                        for(int mcol_i=mcol_k; mcol_i<mcol_k_p; mcol_i++){{
                            max = 0;
                            int t_mcol_i = mcol_i*{m};

                            std::vector<int> cols_max(m_fixed, 0);
                            std::vector<int> masks_max({tileM}*remainder, 0);

                            for(int col_i=0; col_i<remainder; col_i++){{
                                vx[0]=col_i;
                                for(int col_j=col_i+1; col_j<remainder; col_j++){{
                                    vx[1]=col_j;
                                    for(int col_k=col_j+1; col_k<remainder; col_k++){{
                                        vx[2]=col_k;
                                        for(int col_w=col_k+1; col_w<remainder; col_w++){{
                                            vx[3]=col_w;
                                            total=0;
                                            std::vector<int> mask({tileM}*remainder, 0);

                                            for(int row_i=0; row_i<{tileM}; row_i++){{
                                                int t_row_i  = row_i*{ncols};
                                                v[0]=dense[t_bm_i + t_row_i + t_mcol_i + col_i];
                                                v[1]=dense[t_bm_i + t_row_i + t_mcol_i + col_j];
                                                v[2]=dense[t_bm_i + t_row_i + t_mcol_i + col_k];
                                                v[3]=dense[t_bm_i + t_row_i + t_mcol_i + col_w];

                                                std::partial_sort(indices.begin(), indices.begin() + {n}, indices.end(), [&](size_t A, size_t B) {{
                                                            return v[A] > v[B]; }});

                                                for(int id=0; id<{n}; id++){{
                                                    total += dense[t_bm_i + t_row_i + t_mcol_i + vx[indices[id]]];

                                                    mask[row_i*remainder + vx[indices[id]]] = 1;
                                                }}
                                            }}

                                            if(total>max){{
                                                max = total;
                                                std::copy(mask.begin(), mask.end(), masks_max.begin());

                                                std::sort(vx.begin(), vx.begin() + m_fixed);
                                                std::copy(vx.begin(), vx.end(), cols_max.begin());
                                            }}
                                        }}
                                    }}
                                }}
                            }}

                            for(int i=0; i<{tileM}; i++){{
                                for(int j=0; j<remainder; j++){{
                                    int drop = masks_max[i*remainder + j];

                                    masks[t_bm_i  + i*{ncols}+ t_mcol_i + j]  = drop;
                                    sparse[t_bm_i + i*{ncols}+ t_mcol_i + j] *= drop;
                                }}
                            }}
                            for(int i=0; i<m_fixed; i++){{
                                columns[bm_i*{indexes_cols} + mcol_i*m_fixed + i] =
                                cols_max[i];
                            }}
                        }}
                    }}
                }}
            }}
        }}
        """,
    )
    lib.func.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
    ]
    return lib.func

#VENOM to dense

In [23]:
@cache
def to_dense(dense_shape, dense_dtype, n, m, tileM):
    nrows = dense_shape[0]
    ncols = dense_shape[1]

    A_size = nrows*ncols
    density = n/m

    brow = 4 #this->brow = brow_;
    mbrow = 32 #this->mbrow = mbrow_;

    bm   = tileM
    # !IMPORTANT! constants because of architecture constraints
    m_fixed = 4
    bits_elem_meta=2
    mrow_m = 2
    bits_elem_cols=8
    brow_fixed = 16
    nelems=32//bits_elem_meta #(sizeof(uint)*8)=32
    nelems_col = nelems//mrow_m

    A_num_cols_sp = (ncols/m)*n
    A_num_cols_sp_pad_nm = (round_up(ncols, m)/m)*n
    A_num_cols_sp_pad = round_up((round_up(ncols, m)/m)*n, 16)
    A_nnz = nrows*A_num_cols_sp_pad

    assert dense_dtype in (torch.float32, torch.float64)
    dtype = "float" if dense_dtype == torch.float32 else "double"
    lib = compile(
        f"""
        #include <iostream>
        #include <algorithm>
        #include <utility>
        #include <cstdlib>
        #include <cstdio>
        #include <cmath>
        #include <functional>
        #include <tuple>
        #include <vector>
        #include <numeric>
        #include <chrono>

        using namespace std;


        extern "C" void func3({dtype}* hA_dense, {dtype}* hA_values, int *hA_columns, int *hA_metadata){{
            //this->hA_dense.resize(this->A_size, 0);

            // general variables N:M format
            int bm_m = {nrows}/{bm};
            int mbrow_m = {bm}/{mbrow};
            int mbrow_m2 = {mbrow}/{brow_fixed};
            int brow_m = {brow_fixed}/{brow};
            // metadata
            int mcol_kk = {nelems}/{mrow_m}/{n};
            int mcol_k = {A_num_cols_sp_pad}/{n}/mcol_kk;
            // indices
            int col_kk = mcol_kk;
            int col_k = {A_num_cols_sp_pad}/{n}/col_kk;

            uint indexes[{nelems}];
            uint columns[col_kk*{m_fixed}];

            for(int bm_i=0; bm_i<bm_m; bm_i++){{
                for(int mbrow_i=0; mbrow_i<mbrow_m; mbrow_i++){{
                    for(int mbrow_i2=0; mbrow_i2<mbrow_m2; mbrow_i2++){{
                        for(int brow_i=0; brow_i<brow_m; brow_i++){{
                            for(int mcol_i=0; mcol_i<mcol_k; mcol_i++){{
                                //read columns indexes
                                for(int col_i=0; col_i<col_kk; col_i++){{
                                    for(int col_ii=0; col_ii<{m_fixed}; col_ii++){{
                                        columns[col_i*{m_fixed} + col_ii] =
                                        hA_columns[bm_i*col_k*col_kk*{m_fixed} + mcol_i*col_kk*{m_fixed} + col_i*{m_fixed} + col_ii];
                                    }}
                                }}
                                // read metadata
                                for(int mbrow_ii=0; mbrow_ii<({brow}/{mrow_m}); mbrow_ii++){{
                                    for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                        for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                            for (int n_i=0; n_i<{n}; n_i++) {{
                                                indexes[
                                                    mbrow_iii*{n} +
                                                    mcol_ii*{mrow_m}*{n} +
                                                    n_i] =
                                                (((hA_metadata[
                                                    bm_i*mcol_k*{bm}/{mrow_m} +
                                                    mbrow_i*mcol_k*{mbrow}/{mrow_m} +
                                                    mbrow_i2*{brow_fixed}/{mrow_m} +
                                                    brow_i*{brow}/{mrow_m}  +
                                                    mcol_i*{mbrow}/{mrow_m} +
                                                    mbrow_ii]) >> (mbrow_iii*({nelems}/{mrow_m})*{bits_elem_meta}+mcol_ii*{n}*{bits_elem_meta}+n_i*{bits_elem_meta})) & 0x3);
                                            }}
                                        }}
                                    }}

                                    for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                        for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                            for(int n_i=0; n_i<{n}; n_i++){{
                                                unsigned int index = columns[mcol_ii*{m_fixed} + indexes[mcol_ii*{mrow_m}*{n}+mbrow_iii*{n}+n_i]];

                                                if((mcol_i*{m}*mcol_kk + mcol_ii*{m} + index) < {ncols}){{
                                                    hA_dense[
                                                        bm_i*{bm}*{ncols} +
                                                        mbrow_i*{mbrow}*{ncols} +
                                                        mbrow_i2*{brow_fixed}*{ncols} +
                                                        brow_i*{brow}*{ncols} +
                                                        mcol_i*{m}*mcol_kk +
                                                        mbrow_ii*{mrow_m}*{ncols} +
                                                        mcol_ii*{m} +
                                                        mbrow_iii*{ncols} +
                                                        index] =
                                                    hA_values[
                                                        bm_i*{bm}*{A_num_cols_sp_pad} +
                                                        mbrow_i*{mbrow}*{A_num_cols_sp_pad}+
                                                        mbrow_i2*{brow_fixed}*{A_num_cols_sp_pad}+
                                                        brow_i*{brow}*{nelems}/{mrow_m}+
                                                        mcol_i*{brow_fixed}*{nelems}/{mrow_m} +
                                                        mbrow_ii*{mrow_m}*{n} +
                                                        mcol_ii*{n}*{brow} +
                                                        mbrow_iii*{n} +
                                                        n_i];
                                                }}
                                            }}
                                        }}
                                    }}
                                }}
                            }}
                        }}
                    }}
                }}
            }}
        }}
        """,
    )
    lib.func3.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
    ]
    return lib.func3

#Dense to VENOM

In [24]:
@cache
def to_sparse_sr_nm(dense_shape, dense_dtype, n, m, tileM):
    nrows = dense_shape[0]
    ncols = dense_shape[1]

    brow = 4 #this->brow = brow_;
    mbrow = 32 #this->mbrow = mbrow_;

    bm   = tileM
    # !IMPORTANT! constants because of architecture constraints
    m_fixed = 4
    bits_elem_meta=2
    mrow_m = 2
    bits_elem_cols=8
    brow_fixed = 16
    nelems=32//bits_elem_meta #(sizeof(uint)*8)=32
    nelems_col = nelems//mrow_m

    A_num_cols_sp = (ncols//m)*n
    A_num_cols_sp_pad_nm = (round_up(ncols, m)/m)*n
    A_num_cols_sp_pad = round_up((round_up(ncols, m)/m)*n, 16)
    A_nnz = nrows*A_num_cols_sp_pad

    assert dense_dtype in (torch.float32, torch.float64)
    dtype = "float" if dense_dtype == torch.float32 else "double"
    lib = compile(
        f"""
        #include <iostream>
        #include <algorithm>
        #include <utility>
        #include <cstdlib>
        #include <cstdio>
        #include <cmath>
        #include <functional>
        #include <tuple>
        #include <vector>
        #include <numeric>
        #include <chrono>

        using namespace std;


        extern "C" void func2({dtype}* sparse, int* masks, {dtype}* hA_values, int *hA_columns, int *hA_metadata){{

            int bm_m = {nrows}/{bm};
            int mbrow_m = {bm}/{mbrow};
            int mbrow_m2 = {mbrow}/{brow_fixed};
            int brow_m = {brow_fixed}/{brow};
            // metadata
            int mcol_kk = {nelems}/{mrow_m}/{n};
            int mcol_k = {A_num_cols_sp_pad}/{n}/mcol_kk;
            // indices
            int col_kk = mcol_kk;
            int col_k = {A_num_cols_sp_pad}/{n}/col_kk;

            {dtype} values[{nelems}];
            uint indexes[{nelems}];
            uint columns[col_kk*{m_fixed}];

            int max_idx = 0;

            for(int bm_i=0; bm_i<bm_m; bm_i++){{
                for(int mbrow_i=0; mbrow_i<mbrow_m; mbrow_i++){{
                    for(int mbrow_i2=0; mbrow_i2<mbrow_m2; mbrow_i2++){{
                        for(int brow_i=0; brow_i<brow_m; brow_i++){{
                            for(int mcol_i=0; mcol_i<mcol_k; mcol_i++){{
                                for(int col_i=0; col_i<col_kk; col_i++){{
                                    for(int col_ii=0; col_ii<{m_fixed}; col_ii++){{
                                        columns[col_i*{m_fixed} + col_ii] =
                                        hA_columns[bm_i*col_k*col_kk*{m_fixed} + mcol_i*col_kk*{m_fixed} + col_i*{m_fixed} + col_ii];
                                    }}
                                }}
                                for(int mbrow_ii=0; mbrow_ii<({brow}/{mrow_m}); mbrow_ii++){{
                                    for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                        for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                            int pos=0;
                                            for(int n_i=0; n_i<{m_fixed}; n_i++){{
                                                unsigned int index = columns[mcol_ii*{m_fixed} + n_i];

                                                if((mcol_i*{m}*mcol_kk + mcol_ii*{m} + index) < {ncols}){{
                                                    int nnz = masks[
                                                            bm_i*{bm}*{ncols} +
                                                            mbrow_i*{mbrow}*{ncols} +
                                                            mbrow_i2*{brow_fixed}*{ncols} +
                                                            brow_i*{brow}*{ncols} +
                                                            mcol_i*{m}*mcol_kk +
                                                            mbrow_ii*{mrow_m}*{ncols} +
                                                            mcol_ii*{m} +
                                                            mbrow_iii*{ncols} +
                                                            index];

                                                    if(nnz != 0){{
                                                        indexes[
                                                            mbrow_iii*{n} +
                                                            mcol_ii*{mrow_m}*{n} +
                                                            pos] = n_i;

                                                        values[
                                                            mcol_ii*{mrow_m}*{n} +
                                                            mbrow_iii*{n} +
                                                            pos] =
                                                        sparse[
                                                            bm_i*{bm}*{ncols} +
                                                            mbrow_i*{mbrow}*{ncols} +
                                                            mbrow_i2*{brow_fixed}*{ncols} +
                                                            brow_i*{brow}*{ncols} +
                                                            mcol_i*{m}*mcol_kk +
                                                            mbrow_ii*{mrow_m}*{ncols} +
                                                            mcol_ii*{m} +
                                                            mbrow_iii*{ncols} +
                                                            index];

                                                        pos+=1;
                                                    }}
                                                }} else {{
                                                    if(n_i<2){{
                                                        indexes[
                                                            mbrow_iii*{n} +
                                                            mcol_ii*{mrow_m}*{n} +
                                                            pos] = 0;

                                                        values[
                                                            mcol_ii*{mrow_m}*{n} +
                                                            mbrow_iii*{n} +
                                                            pos] = 0;

                                                        pos+=1;
                                                    }}
                                                }}
                                            }}
                                        }}
                                    }}
                                    // write metadata
                                    unsigned int meta=0;
                                    for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                        for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                            for (int n_i=0; n_i<{n}; n_i++) {{

                                                int idx = bm_i*{bm}*{A_num_cols_sp_pad} +
                                                        mbrow_i*{mbrow}*{A_num_cols_sp_pad}+
                                                        mbrow_i2*{brow_fixed}*{A_num_cols_sp_pad}+
                                                        brow_i*{brow}*{nelems}/{mrow_m}+
                                                        mcol_i*{brow_fixed}*{nelems}/{mrow_m} +
                                                        mbrow_ii*{mrow_m}*{n} +
                                                        mcol_ii*{n}*{brow} +
                                                        mbrow_iii*{n} +
                                                        n_i;

                                                max_idx = (idx>max_idx)?(idx):(max_idx);

                                                hA_values[
                                                        idx] =
                                                values[
                                                    mcol_ii*{mrow_m}*{n} +
                                                    mbrow_iii*{n} +
                                                    n_i];

                                                unsigned int tmp = indexes[
                                                            mbrow_iii*{n} +
                                                            mcol_ii*{mrow_m}*{n} +
                                                            n_i];
                                                meta |= (tmp << (mbrow_iii*({nelems}/{mrow_m})*{bits_elem_meta}+mcol_ii*{n}*{bits_elem_meta}+n_i*{bits_elem_meta}));
                                            }}
                                        }}
                                    }}
                                    hA_metadata[bm_i*mcol_k*{bm}/{mrow_m} +
                                                mbrow_i*mcol_k*{mbrow}/{mrow_m} +
                                                mbrow_i2*{brow_fixed}/{mrow_m} +
                                                brow_i*{brow}/{mrow_m}  +
                                                mcol_i*{mbrow}/{mrow_m} +
                                                mbrow_ii] = meta;
                                }}
                            }}
                        }}
                    }}
                }}
            }}
            cout << "max_idx: " << max_idx << endl;
        }}
        """,
    )
    lib.func2.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
    ]
    return lib.func2

/*************************************************/

In [25]:
def round_up(x,y):
    return math.ceil(x/y)*y

#Input definition

In [26]:
input = torch.randn(64, 64, requires_grad=True)
weights = torch.randn(64, 64, requires_grad=True)
#b = torch.ones(6print_size, 64, requires_grad=True)
mask = torch.ones(64, 64, requires_grad=True)
grad_d = torch.randn(64, 64)
# Known tensors to check operations
reference_tensor = torch.ones(64, 64)
#for index in range(0, 256*256):
#    reference_tensor[index//256][index%256] = index

#print(reference_tensor)
result_forward_dense = torch.mm(torch.add(input, weights), mask)


#STen dense2sparse + sparse2dense

In [27]:
class SparseVNMTensor:
    def __init__(self, v_, n_, m_,  dense_ = None, mask_ = None, columns_ = None, values_ = None, metadata_ = None):
        self.v = v_
        self.n = n_
        self.m = m_
        self.nnz = 0
        if (dense_ is not None):
            self.nrows, self.ncols = dense_.shape
        elif (mask_ is not None):
            self.nrows, self.ncols = mask_.shape
        else:
            self.nrows = None
            self.ncols = None
        #self.dense = dense_
        self.mask = mask_
        self.columns = columns_        
        #self.data = None
        if (values_ is None):
            self.values = None
        elif( values_.dtype == torch.float16):
            self.values = values_.float()
        else:
            self.values = values_
        #self.values = None
        self.metadata = metadata_
        if (dense_ is not None and mask_ is not None):
            # Compress only if dense tensor was provided.
            self.to_sparse_sr_nm(dense_, mask_) 
            # Otherwise already compressed data was provided.
            
        #self.dense = None
        
    
    def to_sparse_sr_nm(self, dense_, mask_):
        impl_builder = (
            to_sparse_sr_nm
            )
        func = impl_builder(
                dense_.shape,
                dense_.dtype,
                self.n,
                self.m,
                self.v
            )
        
        self.nrows, self.ncols = dense_.shape
        A_num_cols_sp_pad = round_up((round_up(self.ncols, self.m)/self.m)*self.n, 16)
        self.nnz = self.nrows*A_num_cols_sp_pad
        m_fixed = 4
        mrow_m = 2
        bits_elem_meta=2

        nelems = 32//bits_elem_meta #32=(sizeof(uint)*8)
        nelems_col = nelems//mrow_m

        self.values = torch.zeros(self.nrows * A_num_cols_sp_pad, dtype=dense_.dtype)
        self.metadata = torch.zeros(self.nrows//mrow_m * A_num_cols_sp_pad//nelems_col, dtype=torch.int32)
        self.mask = mask_

        func(dense_.data_ptr(), mask_.data_ptr(), self.values.data_ptr(), self.columns.data_ptr(), self.metadata.data_ptr())     

    def to_dense(self):
        impl_builder = (
            to_dense
            )
        func = impl_builder(
                (self.nrows, self.ncols),
                torch.float32, #self.values.dtype,
                self.n,
                self.m,
                self.v
            )
        # initialize with ones
        #dense = torch.ones((self.nrows, self.ncols), dtype=torch.float32, device='cpu') #self.values.dtype
        self.values=self.values.cpu()
        # uncomment to keep initial values
        #func(dense.data_ptr(), self.values.cpu().to(dtype=torch.float32).data_ptr(), self.columns.cpu().data_ptr(), self.metadata.cpu().data_ptr())
        dense = torch.zeros((self.nrows, self.ncols), dtype=self.values.dtype, device="cpu")
        #print("Dense device", dense.device, "values device", self.values.device, "columns device", self.columns.device, "metadata device", self.metadata.device)
        #print("Dense shape", dense.shape, "values shape", self.values.shape, "columns shape", self.columns.shape, "metadata shape", self.metadata.shape)
        """ print("to_dense", dense.dtype, self.values.dtype, self.columns.dtype, self.metadata.dtype)
        print("to_dense", dense.device, self.values.device, self.columns.device, self.metadata.device) """
        func(dense.data_ptr(), self.values.data_ptr(), self.columns.data_ptr(), self.metadata.data_ptr())

        return dense.half()

#VENOM Sparsifier class

In [28]:
class NMVectorSparsifier:
    def __init__(self, n, m, v):
        self.n = n
        self.m = m
        self.v = v

    @staticmethod
    def get_random_mask(tensor, m, v):
        mask = torch.zeros(tensor.shape, dtype=tensor.dtype)
        m_tmp = torch.cat( (torch.tensor([1,0,1,0]), torch.zeros(m-4)), 0 )
        mask = mask.reshape(-1, v, m) + m_tmp
        mask = mask.reshape(tensor.shape)

        return mask

    def __call__(self, tensor, grad_fmt=None):
        # random pruning (cuSparseLt-like approach) -> mask, columns
        nrows, ncols = tensor.shape
        columns = torch.zeros(nrows//self.v, ncols//self.m*4, dtype=torch.int32)
        columns = columns.reshape((-1,4)) + torch.tensor([0,1,2,3], dtype=torch.int32)
        columns = columns.reshape((nrows//self.v, ncols//self.m*4))

        mask = NMVectorSparsifier.get_random_mask(tensor, self.m, self.v)

        sparse_mtx = sten.SparseTensorWrapper.wrapped_from_dense(
            SparseVNMTensor(self.n, self.m, self.v, tensor, mask, columns, tensor.device),
            tensor,
            grad_fmt,
        )

        return sparse_mtx
        

#Pruning function 

In [29]:
def nm_vector_mask_sparsify(tensor, v, n, m):
    #print("nm_vector_mask_sparsify", v, n, m)
    

    impl_builder = (
                group_n_m2
                )
    nrows, ncols = tensor.shape
    A_num_cols_sp_pad = round_up((round_up(ncols, m)/m)*n, 16)            
    bm_m   = nrows//v
    mcol_k_p = math.ceil(ncols/m)
    m_fixed = 4
    
    # Structures represent sparse data
    masks = torch.zeros(tensor.shape, dtype=torch.int32)    
    columns = torch.zeros(nrows//v * A_num_cols_sp_pad//n*m_fixed,dtype=torch.int32)
    
    if len(tensor.shape) == 2:
        tensor_temp = tensor.cpu().detach().abs()
        sparse = tensor_temp.clone()        

        func = impl_builder(
                    tensor_temp.shape,
                    tensor_temp.dtype,
                    n,
                    m,
                    v
                )
        func(tensor_temp.data_ptr(), sparse.data_ptr(), masks.data_ptr(), columns.data_ptr())

    else:
        raise NotImplementedError("Only support layers of dimension 2 or 4")

    return masks, columns

#Sten sparsifier definition

In [30]:
@sten.register_sparsifier_implementation(
    sparsifier=NMVectorSparsifier, inp=torch.Tensor, out=SparseVNMTensor
)
def torch_tensor_to_srnm_random_fraction(sparsifier, tensor, grad_fmt=None):
    #print("inside NMVectorSparsifier sparsifier")
    #print(tensor.dtype)
    masks, columns = nm_vector_mask_sparsify(tensor, sparsifier.v, sparsifier.n, sparsifier.m)
    return sten.SparseTensorWrapper.wrapped_from_dense(
        SparseVNMTensor(sparsifier.v, sparsifier.n, sparsifier.m, dense_=tensor, mask_=masks, columns_=columns),
        tensor,
        grad_fmt,
    )

In [31]:
class CustomAutograd(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

In [32]:
#class MMWithCustomGrad(torch.nn.Module):
#    def __init__(self, original: torch.nn.Linear, v:int, n:int, m:int):
#        super().__init__()
#        #self.weights = torch.nn.Parameter(torch.rand(10))
#        self.bias = original.bias
#        self._v = v
#        self._n = n
#        self._m = m
#        self.weigth = original.weight
#        #self.bias = torch.zeros(original.bias.shape, dtype=original.bias.dtype, device=original.bias.device)
#
#        # Convert weights from original module to SrNM
#        w = VenomSparsifier(n, m, v)(original.weight).wrapped_tensor
#
#        self.weight_values = torch.nn.Parameter(w.values)
#        self.weight_columns = w.columns
#        self.weight_metadata = w.metadata
#
#        self.nrows_sp = w.nrows
#        self.ncols_sp = w.ncols
#        self.nnz      = w.nnz
#        
#        
#        self.fn = VenomSpmmGrad.apply
#
#    def forward(self, x):
#        pass

/*********************/

In [33]:
tiny_test = False

v=32; n=2; m=8
BM=32
BN=32
BK=32
WM=32
WN=32
WK=32
MM=16
MN=8
MK=32
NSTAGE=2


if tiny_test:
    tiny_test_A = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
    tiny_test_B = torch.tensor([[5., 6.], [7., 8.]], requires_grad=True)

    # A con 0s, dispersa, pero en expresión densa (0s polo medio)
    print("Tiny A: Shape:", tiny_test_A.shape, "\n", tiny_test_A, "\nGrad:", tiny_test_A.grad)
    print("Tiny B: Shape:", tiny_test_B.shape, "\n", tiny_test_B, "\nGrad:", tiny_test_B.grad)
    tiny_mm_output = torch.mm(tiny_test_A, tiny_test_B)
    print("Tiny A · Tiny B: shape", tiny_mm_output.shape, "\n", tiny_mm_output)
    tiny_mm_output.retain_grad()
    tiny_loss = torch.sum(tiny_mm_output); 
    #print("Tiny loss:", tiny_loss)
    print("tiny_mm_output.grad:", tiny_mm_output.grad)
    tiny_loss.backward()
    #tiny_mm_output.backward()
    # Printear matrices A, B, e loss, a ver cal cambia
    print("Tiny A: Shape:", tiny_test_A.shape, "\n", tiny_test_A, "\nGrad:", tiny_test_A.grad)
    print("Tiny B: Shape:", tiny_test_B.shape, "\n", tiny_test_B, "\nGrad:", tiny_test_B.grad)
    print("tiny_mm_output:", tiny_mm_output)
    print("tiny_mm_output.grad:", tiny_mm_output.grad)


In [None]:

small_sten_test = True
print_size = 16
if small_sten_test:
    #import numpy as np
    # Definir matrices que funcionen con SrNMTensors. Mínimo de 32 por dimension
    small_test_A = torch.arange(32 * 32).reshape(32, 32).float().requires_grad_()
    A_masks, A_columns = nm_vector_mask_sparsify(small_test_A, v, n, m)
    sparse_small_test_A = SparseVNMTensor(v, n, m, dense_=small_test_A, mask_=A_masks, columns_=A_columns)
    masked_small_test_A = sparse_small_test_A.to_dense().float().requires_grad_()
    #print("sparse_small_test_A: ", sparse_small_test_A)
    small_test_B = torch.arange(32 * 32).reshape(32, 32).float().requires_grad_()
    # Run normal torch.mm to get correct results.
    #small_mm_output = torch.mm(small_test_A, small_test_B)
    small_mm_output = torch.mm(masked_small_test_A, small_test_B)
    #print("MM result: Shape:", small_mm_output.shape, "\n", small_mm_output[:print_size, :print_size])
    # Run backwards, clear gradients just in case it has something stored
    masked_small_test_A.grad = None
    small_test_B.grad = None
    small_loss = torch.sum(small_mm_output); 
    small_loss.backward()
    #print("Small A: Shape:", masked_small_test_A.shape, "\n", masked_small_test_A[:print_size, :print_size], "\nGrad:", None if masked_small_test_A.grad is None else masked_small_test_A.grad[:print_size, :print_size])
    #print("Small B: Shape:", small_test_B.shape, "\n", small_test_B[:print_size, :print_size], "\nGrad:", None if small_test_B.grad is None else small_test_B.grad[:print_size, :print_size])
    
    custom_mm = sten.sparsified_op(
        orig_op=torch.mm,
        out_fmt=(
            (sten.KeepAll(), torch.Tensor,
            sten.KeepAll(), torch.Tensor),
        ),
        grad_out_fmt=(
            (sten.KeepAll(), torch.Tensor,
            sten.KeepAll(), torch.Tensor),
        ),
    )
    
    #Aplicar sten, definir forward/backwards que faga o que corresponda 
    @sten.register_fwd_op_impl(
        operator=custom_mm,
        inp=(torch.Tensor, torch.Tensor),
        out=[(sten.KeepAll, torch.Tensor)],
    )
    def torch_mm_fwd_impl(ctx, inputs, output_sparsifiers):
        A_operand_sparse, B_operand = inputs
        print("Forward on custom sten torch.mm")
        ctx.save_for_backward(A_operand_sparse, B_operand)
        #ctx.save_for_backward(input_matrix, weights, bias)
        
        return A_operand_sparse.to_dense() @ B_operand
    
    @sten.register_bwd_op_impl(
        operator=custom_mm,
        grad_out=[torch.Tensor],
        grad_inp=(
            (sten.KeepAll, torch.Tensor),
            (sten.KeepAll, torch.Tensor),
        ),
        inp=(torch.Tensor, torch.Tensor ),
    )
    def torch_mm_bwd_impl(ctx, grad_outputs, input_sparsifiers):
        print("Backward on custom sten torch.mm")
        [grad_output] = grad_outputs
        A_operand, B_operand = ctx.saved_tensors
        
        grad_A = grad_output @ B_operand.T
        grad_B = A_operand.T @ grad_output
        
        #print("grad_A: Shape:", grad_A.shape, "\n", grad_A[:print_size, :print_size])
        #print("grad_B: Shape:", grad_B.shape, "\n", grad_B[:print_size, :print_size])
        return grad_A, grad_B
    # con algebra densa (operador @)
    #....
    #torch_mm_output = torch.mm(A, B)# Con operación forward disperso
    small_sten_test_A = torch.arange(32 * 32).reshape(32, 32).float().requires_grad_()
    sparse_sten_small_test_A = SparseVNMTensor(v, n, m, dense_=small_sten_test_A, mask_=A_masks, columns_=A_columns)
    small_sten_test_B = torch.arange(32 * 32).reshape(32, 32).float().requires_grad_()
    masked_small_sten_test_A = sparse_sten_small_test_A.to_dense().float().requires_grad_()
    small_sten_mm_output = custom_mm(masked_small_sten_test_A, small_sten_test_B)
    print("Custom forward completed.\nOutput:", small_sten_mm_output)
    # Run backwards, clear gradients just in case it has something stored
    masked_small_sten_test_A.grad = None
    small_sten_test_B.grad = None 
    small_sten_loss = torch.sum(small_sten_mm_output); 
    #loss.backwards() # Co backwards disperso
    small_sten_loss.backward()
    # Comparar resultados.
    print("\nA operands to the torch.mm are equal?", torch.allclose(masked_small_test_A, masked_small_sten_test_A))
    #print("masked_small_test_A corner:\n", masked_small_test_A[:print_size, :print_size])
    #print("masked_small_sten_test_A corner:\n", masked_small_sten_test_A[:print_size, :print_size])
    
    print("B operands to the torch.mm are equal?",torch.allclose(small_test_B, small_sten_test_B) )
    #print("small_test_B corner:\n", small_test_B[:print_size, :print_size])
    #print("small_sten_test_B corner:\n", small_sten_test_B[:print_size, :print_size])
    
    print("Torch.mm results are equal?", torch.allclose(small_mm_output, small_sten_mm_output))
    #print("Result of unmodified torch.mm: Shape:", small_mm_output.shape, "\n", small_mm_output[:print_size, :print_size])
    #print("Result of modified torch.mm: Shape:", small_sten_mm_output.shape, "\n", small_sten_mm_output[:print_size, :print_size])
    
    print("Gradients of the A operand are equal?", torch.allclose(masked_small_test_A.grad, masked_small_sten_test_A.grad))
    #print("\nGradients of the A operand to the unmodified torch.mm (corner only):\n", None if masked_small_test_A.grad is None else masked_small_test_A.grad[:print_size, :print_size])
    #print("\nGradients of the A operand to the modified torch.mm  (corner only):\n", None if masked_small_sten_test_A.grad is None else masked_small_sten_test_A.grad[:print_size, :print_size])
    
    print("Gradients of the B operand are equal?", torch.allclose(small_test_B.grad, small_sten_test_B.grad))
    #print("\nGradients of the B operand to the unmodified torch.mm (corner only):\n", None if small_test_B.grad is None else small_test_B.grad[:print_size, :print_size])
    #print("\nGradients of the B operand to the modified torch.mm (corner only):\n", None if small_sten_test_B.grad is None else small_sten_test_B.grad[:print_size, :print_size])
    
    
    # Probar hasta que coincidan



max_idx: 511
max_idx: 511
Custom forward completed. Output: tensor([[ 108160.,  108308.,  108456.,  ...,  112452.,  112600.,  112748.],
        [ 259712.,  260116.,  260520.,  ...,  271428.,  271832.,  272236.],
        [ 411264.,  411924.,  412584.,  ...,  430404.,  431064.,  431724.],
        ...,
        [4503168., 4510740., 4518312.,  ..., 4722756., 4730328., 4737900.],
        [4654720., 4662548., 4670376.,  ..., 4881732., 4889560., 4897388.],
        [4806272., 4814356., 4822440.,  ..., 5040708., 5048792., 5056876.]],
       grad_fn=<SparseOperatorDispatcherBackward>)
Backward on custom sten torch.mm


ValueError: too many values to unpack (expected 2)

In [None]:



sparse_add = sten.sparsified_op(
    orig_op=torch.add,
    out_fmt=(
        (sten.KeepAll(), torch.Tensor,
         NMVectorSparsifier(v,n,m), SparseVNMTensor),
    ),
    grad_out_fmt=(
        (sten.KeepAll(), torch.Tensor,
         NMVectorSparsifier(v,n,m), SparseVNMTensor),
    ),
)

In [None]:
sparse_tensor = sparse_add(input, weights)
#print(sparse_tensor)

max_idx: 0


# Sparsity checking helpers

In [None]:

def check_VNM(v:int, n:int, m:int, m_fixed:int, tensor: torch.Tensor, verbose:bool=False):
    shape = tensor.shape

    blocks = (shape[0]/v)*(shape[1]/m)
    ok_blocks = 0
    invalid_blocks = 0

    for row_block in range(0, shape[0], v):
        for column_block in range(0, shape[1], m):
            #print(f'Block {row_block}:{row_block+v}, {column_block}:{column_block+m}: ', end="")
            any_row_invalid = False
            non_zero_columns = []
            invalid_rows = 0
            for row in range(row_block, row_block+v):
                non_zero_position = 0
                empty_positions = 0
                for column in range(column_block, column_block+m):
                    if row < shape[0] and column < shape[1]:
                        if tensor[row][column] == 0:
                            empty_positions+=1
                        else:
                            non_zero_position += 1
                            if column not in non_zero_columns:
                                non_zero_columns.append(column)
                if non_zero_position > n:
                    any_row_invalid = True
                    invalid_rows += 1
            if invalid_rows > 0:
                print(f'\t{invalid_rows} rows with  more than {n} non zero positions,')

            #print(f'{len(non_zero_columns)} columns with some non zero element')
            if len(non_zero_columns) > m_fixed:
                any_row_invalid = True
                print(f'\tBlock has {len(non_zero_columns)} columns with some non zero element')
                print(f'\tError: more than the {m_fixed} columns with some non zero elements allowed.')

            if any_row_invalid:
                invalid_blocks+=1
                print(f'\tBlock {row_block}:{row_block+v}, {column_block}:{column_block+m} does not meet requirements')
            else:
                #print(f'\tBlock {row_block}:{row_block+64}, {column_block}:{column_block+10} meets requirements')
                ok_blocks+=1
            if verbose: print(f'{ok_blocks+invalid_blocks}/{blocks} processed blocks. {ok_blocks} valid ones, {invalid_blocks} invalid. Processing block {row_block}:{row_block+v}, {column_block}:{column_block+m}                                             ', end="\r")
    if not verbose: print(f'{ok_blocks+invalid_blocks}/{blocks} processed blocks. {ok_blocks} valid ones, {invalid_blocks} invalid.')
    print("")

def check_sparsity(tensor: torch.Tensor):

    zero_elements = 0
    non_zero_elements = 0
    total_elements = 0

    for row in range(tensor.shape[0]):
        for column in range(tensor.shape[1]):
            total_elements+=1
            if tensor[row][column] == 0:
                zero_elements+=1
            else:
                non_zero_elements+=1

    print(f'Tensor calculated sparsity: {zero_elements/total_elements}')


#Forward function

In [None]:
@sten.register_fwd_op_impl(
    operator=torch.mm,
    inp=(SparseVNMTensor, torch.Tensor),
    out=[(sten.KeepAll, torch.Tensor)],
)
def sparse_torch_add_fwd_impl(ctx, inputs, output_sparsifiers):
    weights, input_matrix = inputs
    #ctx.save_for_backward(weights, input_matrix)

    bias = torch.ones(weights.wrapped_tensor.nrows)*2
    
    ctx.save_for_backward(input_matrix, weights, bias)
    #ctx.save_for_backward(input_matrix, 
    #                      weights.wrapped_tensor.values.to(dtype=torch.half).cuda(), 
    #                      weights.wrapped_tensor.columns.cuda(), 
    #                      weights.wrapped_tensor.metadata.cuda(), bias)

    output = spatha.spmm_128x64x32_32x64x32_16x8x32_2(
                          weights.wrapped_tensor.metadata.cuda(),                    # m-indices
                          weights.wrapped_tensor.columns.cuda(),                     # col-loc
                          weights.wrapped_tensor.values.to(dtype=torch.half).cuda(), # values
                          input_matrix.to(dtype=torch.half).cuda(),                       # rhs_matrix
                          bias.to(dtype=torch.half).cuda(),                         # bias
                          weights.wrapped_tensor.nrows,                              # A_num_rows
                          weights.wrapped_tensor.ncols,                              # A_num_cols
                          input_matrix.shape[1],                                          # B_num_cols
                          weights.wrapped_tensor.v,                              # vec_length
                          weights.wrapped_tensor.n,                                  # n
                          weights.wrapped_tensor.m,                                  # m
                          weights.wrapped_tensor.nnz,                                # nnz
                          0,                                                        # seed
                          32,                                                       # mbrow
                          4                                                         # brow
                          )
    return output

torch_mm_output = torch.mm(sparse_add(input, weights), mask)
# torch_mm_output = (A+B)@C
# A+B sumado con sparse_add para que la salida de la suma (D) sea dispersa.
# De esta forma, D@C tiene un operando disperso, que es lo que hace spmm.

max_idx: 0


#Compute masked results (dense computation)

In [None]:
dense = torch.from_numpy(sparse_tensor.wrapped_tensor.to_dense().cpu().to(dtype=torch.half).detach().numpy() @ mask.detach().numpy()).to(device="cuda:0").to(dtype=torch.half)
#print(dense)

bias = torch.ones((dense.shape))*2
dense+=bias.to(dtype=torch.half).cuda()

#Check correctness

In [None]:
print("sparse.T\n", torch_mm_output.T)
print("dense\n", dense)

print( torch.allclose(torch_mm_output.T,dense) )

sparse.T
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16,
       grad_fn=<PermuteBackward0>)
dense
 tensor([[2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        ...,
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.]], device='cuda:0', dtype=torch.float16)
False


# SPMM and SDDMM kernels using sten

Replicate small test that defines a custom mm operator using sten that works on sparse tensors, with forward and backwards pass.



In [None]:
kernels_test = True
print_size = 16

if kernels_test:
    A_operand_shape = [64, 64]
    B_operand_shape = [64, 64]
    
    # Define custom operator that works on sparse tensors with sten
    # Definir matrices que funcionen con SrNMTensors. Mínimo de 32 por dimension
    kernel_test_A = torch.arange(A_operand_shape[0] * A_operand_shape[1]).reshape(A_operand_shape[0], A_operand_shape[1]).float().requires_grad_()
    # Sparsify to SparseVNMTensor
    A_masks, A_columns = nm_vector_mask_sparsify(kernel_test_A, v, n, m)
    sparse_kernel_test_A = SparseVNMTensor(v, n, m, dense_=kernel_test_A, mask_=A_masks, columns_=A_columns)
    #Redensify with 0s to compute reference results with standard torch.mm
    masked_kernel_test_A = sparse_kernel_test_A.to_dense().float().requires_grad_() 
    #print("sparse_small_test_A: ", sparse_small_test_A)
    kernel_test_B = torch.arange(B_operand_shape[0] * B_operand_shape[1]).reshape(B_operand_shape[0], B_operand_shape[1]).float().requires_grad_()
    # Run normal torch.mm to get correct results.
    kernel_mm_output = torch.mm(masked_kernel_test_A, kernel_test_B)
    #print("MM result: Shape:", kernel_mm_output.shape, "\n", kernel_mm_output[:print_size, :print_size])
    # Run backwards, clear gradients just in case it has something stored
    masked_kernel_test_A.grad = None
    kernel_test_B.grad = None
    kernel_loss = torch.sum(kernel_mm_output); 
    kernel_loss.backward()
    #print("Small A: Shape:", masked_kernel_test_A.shape, "\n", masked_kernel_test_A[:print_size, :print_size], "\nGrad:", None if masked_kernel_test_A.grad is None else masked_kernel_test_A.grad[:print_size, :print_size])
    #print("Small B: Shape:", kernel_test_B.shape, "\n", kernel_test_B[:print_size, :print_size], "\nGrad:", None if kernel_test_B.grad is None else kernel_test_B.grad[:print_size, :print_size])
    
    custom_sparse_mm = sten.sparsified_op(
        orig_op=torch.mm,
        out_fmt=(
            (sten.KeepAll(), torch.Tensor,
            #NMVectorSparsifier(v,n,m), SparseVNMTensor),
            sten.KeepAll(), torch.Tensor),
        ),
        grad_out_fmt=(
            (sten.KeepAll(), torch.Tensor,
            #NMVectorSparsifier(v,n,m), SparseVNMTensor),
            sten.KeepAll(), torch.Tensor),
        ),
    )
    # NMVectorSparsifier(v,n,m), SparseVNMTensor, para out_fmt
    # SparseVNMTensor, torch.Tensor Para register_fwd_op_impl
    #Aplicar sten, definir forward/backwards que faga o que corresponda 
    @sten.register_fwd_op_impl(
        operator=torch.mm,
        inp=(SparseVNMTensor, torch.Tensor),
        out=[(sten.KeepAll, torch.Tensor)],
    )
    def torch_mm_fwd_impl(ctx, inputs, output_sparsifiers):
        A_operand_sparse, B_operand = inputs
        print("Forward on custom sparse torch.mm with sten")
        ctx.save_for_backward(A_operand_sparse, B_operand)
        
        forward_bias = torch.ones(A_operand_sparse.wrapped_tensor.nrows)*2
        
        # Use SPMM. SPMM(B, A) performs A@B with B being in VENOM's sparse format. Or is it SPMM(A, B) -> A@B? Which is compressed, A or B.
        # Assume SPMM(A, B) -> A@B with A being sparse.
        # May cause problems for the layer if the weights are the sparse ones, since the operation should be input @ weights, but weights is the sparse one.
        output = spatha.spmm_128x64x32_32x64x32_16x8x32_2(
                        A_operand_sparse.wrapped_tensor.metadata.cuda(),                    # A m-indices
                        A_operand_sparse.wrapped_tensor.columns.cuda(),                     # A col-loc
                        A_operand_sparse.wrapped_tensor.values.to(dtype=torch.half).cuda(), # A values
                        B_operand.to(dtype=torch.half).cuda(),                              # B, rhs_matrix
                        forward_bias.to(dtype=torch.half).cuda(),                           # bias
                        A_operand_sparse.wrapped_tensor.nrows,                              # A_num_rows
                        A_operand_sparse.wrapped_tensor.ncols,                              # A_num_cols
                        B_operand.shape[1],                                                 # B_num_cols
                        A_operand_sparse.wrapped_tensor.v,                                  # vec_length, V
                        A_operand_sparse.wrapped_tensor.n,                                  # n
                        A_operand_sparse.wrapped_tensor.m,                                  # m
                        A_operand_sparse.wrapped_tensor.nnz,                                # nnz
                        0,                                                                  # seed
                        32,                                                                 # mbrow
                        4                                                                   # brow
                        )
        print("SPMM output:", output)
        return output
        #return A_operand_sparse.to_dense() @ B_operand
    
    @sten.register_bwd_op_impl(
        operator=torch.mm,
        grad_out=[torch.Tensor],
        grad_inp=(
            (sten.KeepAll, torch.Tensor),
            (sten.KeepAll, torch.Tensor),
        ),
        inp=(torch.Tensor, torch.Tensor ),
    )
    def torch_mm_bwd_impl(ctx, grad_outputs, input_sparsifiers):
        print("Backward on custom sten torch.mm")
        [grad_output] = grad_outputs
        A_operand, B_operand = ctx.saved_tensors
        
        # Compute gradients for A operand, as grad_output @ B_operand.T. Use SPMM
        grad_A = grad_output @ B_operand.T
        
        # Compute gradients for B operand, as A_operand.T @ grad_output, with the same sparsity as B_operand.
        grad_B = A_operand.T @ grad_output
        
        return grad_A, grad_B
    # con algebra densa (operador @)
    #....
    #torch_mm_output = torch.mm(A, B)# Con operación forward disperso
    small_sten_test_A = torch.arange(32 * 32).reshape(32, 32).float().requires_grad_()
    sparse_sten_small_test_A = SparseVNMTensor(v, n, m, dense_=small_sten_test_A, mask_=A_masks, columns_=A_columns)
    small_sten_test_B = torch.arange(32 * 32).reshape(32, 32).float().requires_grad_()
    masked_small_sten_test_A = sparse_sten_small_test_A.to_dense().float().requires_grad_()
    small_sten_mm_output = custom_sparse_mm(masked_small_sten_test_A, small_sten_test_B)
    print("Custom forward completed")
    # Run backwards, clear gradients just in case it has something stored
    masked_small_sten_test_A.grad = None
    small_sten_test_B.grad = None 
    small_sten_loss = torch.sum(small_sten_mm_output); 
    #loss.backwards() # Co backwards disperso
    small_sten_loss.backward()
    # Comparar resultados.
    print("\nA operands to the torch.mm are equal?", torch.allclose(masked_kernel_test_A, masked_small_sten_test_A))
    #print("masked_kernel_test_A corner:\n", masked_kernel_test_A[:print_size, :print_size])
    #print("masked_small_sten_test_A corner:\n", masked_small_sten_test_A[:print_size, :print_size])
    
    print("B operands to the torch.mm are equal?",torch.allclose(kernel_test_B, small_sten_test_B) )
    #print("kernel_test_B corner:\n", kernel_test_B[:print_size, :print_size])
    #print("small_sten_test_B corner:\n", small_sten_test_B[:print_size, :print_size])
    
    print("Torch.mm results are equal?", torch.allclose(small_mm_output, small_sten_mm_output))
    #print("Result of unmodified torch.mm: Shape:", small_mm_output.shape, "\n", small_mm_output[:print_size, :print_size])
    #print("Result of modified torch.mm: Shape:", small_sten_mm_output.shape, "\n", small_sten_mm_output[:print_size, :print_size])
    
    print("Gradients of the A operand are equal?", torch.allclose(masked_kernel_test_A.grad, masked_small_sten_test_A.grad))
    #print("\nGradients of the A operand to the unmodified torch.mm (corner only):\n", None if masked_kernel_test_A.grad is None else masked_kernel_test_A.grad[:print_size, :print_size])
    #print("\nGradients of the A operand to the modified torch.mm  (corner only):\n", None if masked_small_sten_test_A.grad is None else masked_small_sten_test_A.grad[:print_size, :print_size])
    
    print("Gradients of the B operand are equal?", torch.allclose(kernel_test_B.grad, small_sten_test_B.grad))
    #print("\nGradients of the B operand to the unmodified torch.mm (corner only):\n", None if kernel_test_B.grad is None else kernel_test_B.grad[:print_size, :print_size])
    #print("\nGradients of the B operand to the modified torch.mm (corner only):\n", None if small_sten_test_B.grad is None else small_sten_test_B.grad[:print_size, :print_size])
    
    


max_idx: 1023
max_idx: 511
Custom forward completed
Backward on custom sten torch.mm


ValueError: too many values to unpack (expected 2)


Prepare inputs for all tests

In [None]:
# Rectangular matrices
#A_matrix_shape = [256, 128]
#B_matrix_shape = [128, 256]

# Square matrices
A_matrix_shape = [64, 64]
B_matrix_shape = [64, 64]

torch.set_printoptions(linewidth=1000, threshold=10000)

In [None]:

#A_matrix_new = torch.randn(A_matrix_shape[0], A_matrix_shape[1], requires_grad=True)
A_matrix_new = torch.ones(A_matrix_shape[0], A_matrix_shape[1], requires_grad=True)
B_matrix_new = torch.randn(B_matrix_shape[0], B_matrix_shape[1], requires_grad=True)

add_matrix = torch.randn(A_matrix_new.shape[0], A_matrix_new.shape[1], requires_grad=True)
sparse_tensor_new = sparse_add(A_matrix_new, add_matrix)

print("A_matrix_new shape:", A_matrix_new.shape, "B_matrix_new shape:", B_matrix_new.shape)



max_idx: 0
A_matrix_new shape: torch.Size([64, 64]) B_matrix_new shape: torch.Size([64, 64])


In [None]:
# Load known B matrix from file
import numpy as np
file_matrix = []
file = "Reference_b_matrix_64x64.txt" if B_matrix_shape[0]==64 else "Reference_b_matrix_32x64.txt"
with open(file, "r") as file:
    for line in file:
        line = line.strip()
        if not line:
            continue
        row = line.strip().replace('[', '').replace(']', '').split()
        file_matrix.append([float(num) for num in row])
file_matrix = np.array(file_matrix)
A_matrix_new = torch.from_numpy(file_matrix)# Load into A for B·A suspicion
B_matrix_new = torch.ones(A_matrix_shape[0], A_matrix_shape[1], requires_grad=True)

print("Loaded_matrix: shape:", A_matrix_new.shape, "\n", A_matrix_new)
#B_matrix_new = B_matrix_new.T
#print("B_matrix_new: shape:", B_matrix_new.shape, "\n", B_matrix_new)
# GEMM in C version seems to be doing AxB.T

print("Operands shapes: A:", A_matrix_new.shape, "B:", B_matrix_new.shape)

# Create sparse vector to get a mask of the same size as the resulting product
placeholder_output_tensor = torch.randn(A_matrix_new.shape[1], B_matrix_new.shape[0])
placeholder_tensor_mask, placeholder_tensor_columns = nm_vector_mask_sparsify(placeholder_output_tensor, v, n, m)
#sparse_tensor_new = SparseVNMTensor(v, n, m, dense_=placeholder_output_tensor, 
#                                    mask_=placeholder_tensor_mask, columns_=placeholder_tensor_columns)




Loaded_matrix: shape: torch.Size([64, 64]) 
 tensor([[ 0.,  2.,  1.,  0.,  1.,  4., -1., -4., -2., -3.,  4.,  2., -2., -3., -2.,  3.,  4.,  3.,  0., -4., -1.,  3.,  1., -4., -1.,  0., -3.,  3., -4.,  0.,  2.,  0.,  4., -2., -4., -2.,  4., -3., -2., -4., -4.,  4.,  2., -2.,  3., -4.,  4.,  3.,  3., -1.,  1.,  4., -4., -4.,  2., -1.,  0.,  4.,  4.,  1., -3.,  1.,  3., -4.],
        [ 1.,  1., -1.,  1.,  2., -1., -1.,  3., -3.,  3.,  3.,  4.,  3.,  0.,  2., -1.,  4., -1., -3.,  4., -3.,  1., -4.,  1., -2.,  4.,  4., -1.,  1.,  3.,  0., -3., -3.,  1.,  0.,  4.,  4.,  3.,  0., -4.,  2., -2.,  3., -2.,  3.,  0.,  0.,  0.,  1., -1., -3.,  2.,  4., -3., -3., -2., -4., -4., -1., -1.,  1.,  3., -1.,  3.],
        [-3.,  3.,  0.,  4.,  0.,  4.,  2.,  4.,  4.,  0.,  4.,  0., -4., -1.,  2., -1.,  0.,  3.,  4., -2.,  2.,  3., -2.,  1.,  1.,  1.,  4., -2., -2., -2.,  3., -1.,  4., -2., -2., -1.,  4., -4., -2.,  1.,  0., -3., -3., -2., -1.,  3.,  0.,  1.,  0., -1.,  3., -3.,  4., -4.,  0.,  0.,  2.,  

In [None]:
#weight = sparse_add(a, b)
#input =  torch.randn(256, 256, requires_grad=True)
#grad_output = grad_d

Define sparse backwards operation. Input is two tensors and the sparse matrix

In [None]:

"""operator=torch.mm,
    operator=torch.mm,
    grad_out=[torch.Tensor],
    grad_inp=(
        (sten.KeepAll, torch.Tensor),
        (sten.KeepAll, torch.Tensor),
    ),
    inp=(SparseVNMTensor, torch.Tensor ),
"""
"""
operator=torch.nn.functional.linear,
    grad_out=None,
    grad_inp=None,
    inp=[torch.Tensor, SparseVNMTensor]
    """
@sten.register_bwd_op_impl(
    operator=torch.mm,
    grad_out=[torch.Tensor],
    grad_inp=(
        (sten.KeepAll, torch.Tensor),
        (sten.KeepAll, torch.Tensor),
    ),
    inp=(SparseVNMTensor, torch.Tensor ),
)
def torch_mm_bwd_impl(ctx, grad_outputs, input_sparsifiers):
    
    input, weights, bias = ctx.saved_tensors
    print(f"weights type: {type(weights)} ")
    
    [grad_output] = grad_outputs
    
    #   grad_input = grad_output @ weigths.T -> spmm(grad_output, weigths.T)
    # Since the sparse matrix is the weights one, and needs to be transposed, 
    # this operations is changed to the following
    #    grad_input = (weigths @ grad_output.T).T -> spmm(weigths * grad_output.T).T
    # Now the weights can be given directly, in the compressed sparse format.
    transposed_grad_output = grad_output.T
    grad_input = spatha.spmm_128x64x32_32x64x32_16x8x32_2(
                          weights.wrapped_tensor.metadata.cuda(),  # metadata
                          weights.wrapped_tensor.columns.cuda(),   # indices
                          weights.wrapped_tensor.values.to(dtype=torch.half).cuda(),    # values
                          transposed_grad_output.to(dtype=torch.half).cuda(),           # rhs_matrix
                          bias.to(dtype=torch.half).cuda(),             # bias
                          weights.wrapped_tensor.nrows,         # A_num_rows
                          weights.wrapped_tensor.ncols,         # A_num_cols
                          transposed_grad_output.shape[1],          # B_num_cols
                          weights.wrapped_tensor.v,                # V
                          weights.wrapped_tensor.n,                # N
                          weights.wrapped_tensor.m,                # M
                          weights.wrapped_tensor.nnz,              # nnz
                          0,                # seed
                          32,               # mbrow
                          4                 # brow
                          ).T
    
    # grad_weights = torch.from_numpy(input.T @ grad_outputs)
    # grad_input2 = torch.from_numpy(input1.wrapped_tensor.data.transpose() @ grad_output)
    
    #transposed_flattened_input = torch.flatten(input, start_dim=0, end_dim=-2).T
    transposed_flattened_input = torch.flatten(input, start_dim=0, end_dim=-2).T
    flattened_grad_output = torch.flatten(grad_output, start_dim=0, end_dim=-2)
    #flattened_grad_output = flattened_grad_output.T # Transpose required due to internal working of sddmm kernel.
    
    # External inputs for testing
    transposed_flattened_input = A_matrix_new
    flattened_grad_output = B_matrix_new
    
    #print("A_operand: Shape: ", transposed_flattened_input.shape, "\n", transposed_flattened_input)
    #print("B_operand: Shape: ", flattened_grad_output.shape, "\n", flattened_grad_output)
    #print("B_operand transposed back: Shape: ", flattened_grad_output.T.shape, "\n", flattened_grad_output.T)
    
    compressed_grad_weights = spatha_sddmm.sddmm(
                          transposed_flattened_input.to(dtype=torch.half).cuda(),   # lhs operand
                          flattened_grad_output.to(dtype=torch.half).cuda(),   # rhs operand
                          weights.wrapped_tensor.metadata.cuda(),    # metada for output sparsity distribution
                          #sparse_tensor.wrapped_tensor.values.to(dtype=torch.half).cuda(),    # Values for output sparsity distribution
                          weights.wrapped_tensor.columns.cuda(),           # indices for output sparsity distribution
                          weights.wrapped_tensor.nrows, # Since LHS is transposed, LHS.shape[0] is LHS.shape[1]
                          weights.wrapped_tensor.ncols,         
                          flattened_grad_output.shape[1],         
#                          sddmm_C_matrix.wrapped_tensor.v,          
                          weights.wrapped_tensor.n,                # N
                          weights.wrapped_tensor.m,                # M
                          weights.wrapped_tensor.nnz,              # nnz
                          0,                # seed
                          32,               # mbrow
                          4                 # brow
                          )
    # Create a SparseVNMTensor from weights SparseVNMTensor data to densify sddmm output.
    grad_weights = SparseVNMTensor(weights.wrapped_tensor.v, weights.wrapped_tensor.n, weights.wrapped_tensor.m, 
                              mask_=weights.wrapped_tensor.mask, columns_=weights.wrapped_tensor.columns, 
                              values_=compressed_grad_weights, metadata_=weights.wrapped_tensor.metadata)
    
    #print(f"compressed_grad_weights type: {type(compressed_grad_weights)} and dtype: {compressed_grad_weights.dtype}")
    #print(f"grad_weights SparseVNMTensor type: {type(grad_weights)} and data: {grad_weights}")
    densified_grad_weights = grad_weights.to_dense()
    #dense_grad_weights = input.T.to(dtype=torch.half).cuda() @ grad_output.to(dtype=torch.half).cuda()
    #print("densified_grad_weights: Shape: ", densified_grad_weights.shape, "\n", densified_grad_weights)
    
    # Compute dense version and check they are the same
    dense_grad_weights = torch.from_numpy(transposed_flattened_input.detach().numpy() @ flattened_grad_output.detach().numpy()
                                      ).to(device="cuda:0").to(dtype=torch.half)
    #print("dense_grad_weights: transposed_flattened_input @ B_matrixflattened_grad_output_new: shape:", 
    #      dense_grad_weights.shape,"\n", dense_grad_weights)
    masked_dense_grad_weights = torch.multiply(dense_grad_weights, weights.wrapped_tensor.mask.to(device="cuda:0"))
    #print("masked_dense: shape:", masked_dense_grad_weights.shape, "\n", masked_dense_grad_weights)
    
    
    
    
    print( torch.allclose(densified_grad_weights.cpu(), masked_dense_grad_weights.cpu()) )
    #return grad_input, grad_weights 
    #return grad_input.float().cpu(), dense_grad_weights.cpu()
    
    return grad_input.float().cpu(), densified_grad_weights.cpu()
    #densified_grad_weights.float().cpu()
    #return torch.zeros(256, 256), torch.zeros(256, 256)
    #return sten.SparseTensorWrapper.wrapped_from_dense(
    #    SparseVNMTensor(sparsifier.v, sparsifier.n, sparsifier.m, dense_=tensor, mask_=masks, columns_=columns),
    #    tensor,
    #    grad_fmt,
    #)
    




Compute dense version for grad_input: operation grad_outputs @ weights.T

In [None]:
#[grad_output] = grad_d
# Operation should be grad_d @ weights.T, but in order to match C benchmark results, second operand has to be transposed.
dense_grad_input = torch.from_numpy(grad_d.detach().numpy() @ weights.detach().numpy()
                                    ).to(device="cuda:0").to(dtype=torch.half)

#print(dense_grad_input)

#bias = torch.ones((dense.shape))*2
#dense+=bias.to(dtype=torch.half).cuda()

Compute sparse version of the operation input.T @ grad_output as sddmm(input.T,  grad_output, sparse) through backward execution of the operation.

Uses sddmm.

This is done by calling the backwards method that computes two outputs, grad_input and grad_weights. Only grad_weigths uses sddmm operand, but the full backwards operation is needed to use the sparse versions through sten.


In [None]:
#torch_mm_output.retain_grad()
loss = torch.sum(torch_mm_output); 
print("weights:", loss)
#print("loss:", loss)
#print("torch_mm_output grad before backwards:", torch_mm_output.grad)
loss.backward()
#torch_mm_output.backward(grad_d.to(device="cuda:0").to(dtype=torch.half))
print("Sparse after backwards:", loss)
#print("sparse grad after backwards:", sparse.grad.shape, "\n", sparse.grad)
#print("sparse_grad_input:", sparse_grad_input, "sparse_grad_weights:", sparse_grad_weights)
#print("sparse_grad_input dtype:", sparse_grad_input.dtype, "sparse_grad_weights dtype:", sparse_grad_weights.dtype)

weights: tensor(0., device='cuda:0', dtype=torch.float16, grad_fn=<SumBackward0>)
weights type: <class 'sten.sten.SparseTensorWrapper'> 
True
max_idx: 0
Sparse after backwards: tensor(0., device='cuda:0', dtype=torch.float16, grad_fn=<SumBackward0>)


# Isolated kernel verification.
Use dense computation and direct call to sddmm kernel to test correct execution

Set matrix sizes

Compute dense version for grad_weights: input.T @ grad_output

In [None]:



#dense_grad_weights = torch.from_numpy(input.T.detach().numpy() @ grad_d.detach().numpy()
#                                      ).to(device="cuda:0").to(dtype=torch.half)

print("A_matrix_new: shape:", A_matrix_new.shape,"\n", A_matrix_new)
print("B_matrix_new: shape:", B_matrix_new.shape,"\n", B_matrix_new)

dense_grad_weights = torch.from_numpy(A_matrix_new.detach().numpy() @ B_matrix_new.detach().numpy()
                                      ).to(device="cuda:0").to(dtype=torch.half)
#positions_dense_grad_weights =  torch.from_numpy(reference_tensor.T.detach().numpy() @ grad_d.detach().numpy()
#                                      ).to(device="cuda:0").to(dtype=torch.half)
print("dense_grad_weights: A_matrix_new @ B_matrix_new: shape:", dense_grad_weights.shape,"\n", dense_grad_weights)


# Apply mask
print("sparse_tensor.wrapped_tensor.mask: ", sparse_tensor_new.wrapped_tensor.mask.shape)
#print("dense_grad_weights: ", dense_grad_weights.shape)
#padded_mask =  torch.nn.functional.pad(sparse_tensor_new.wrapped_tensor.mask, (0, 128), "constant", 1)
#masked_dense_grad_weights = torch.multiply(dense_grad_weights, padded_mask.to(device="cuda:0"))
print("Checking mask tensor for correct sparsity.")
check_VNM(v, n, m, n+2, sparse_tensor_new.wrapped_tensor.mask)
masked_dense_grad_weights = torch.multiply(dense_grad_weights, sparse_tensor_new.wrapped_tensor.mask.to(device="cuda:0"))
#positions_masked_dense_grad_weights = torch.multiply(positions_dense_grad_weights, sparse_tensor.wrapped_tensor.mask.to(device="cuda:0"))
print("Checking masked dense result tensor for correct sparsity.")
check_VNM(v, n, m, n+2, masked_dense_grad_weights)

print("masked_dense: shape:", masked_dense_grad_weights.shape, "\n", masked_dense_grad_weights)

#print("dense_grad_weights:\n", positions_dense_grad_weights)
#print("masked_dense:\n", positions_masked_dense_grad_weights)


A_matrix_new: shape: torch.Size([64, 64]) 
 tensor([[ 0.,  2.,  1.,  0.,  1.,  4., -1., -4., -2., -3.,  4.,  2., -2., -3., -2.,  3.,  4.,  3.,  0., -4., -1.,  3.,  1., -4., -1.,  0., -3.,  3., -4.,  0.,  2.,  0.,  4., -2., -4., -2.,  4., -3., -2., -4., -4.,  4.,  2., -2.,  3., -4.,  4.,  3.,  3., -1.,  1.,  4., -4., -4.,  2., -1.,  0.,  4.,  4.,  1., -3.,  1.,  3., -4.],
        [ 1.,  1., -1.,  1.,  2., -1., -1.,  3., -3.,  3.,  3.,  4.,  3.,  0.,  2., -1.,  4., -1., -3.,  4., -3.,  1., -4.,  1., -2.,  4.,  4., -1.,  1.,  3.,  0., -3., -3.,  1.,  0.,  4.,  4.,  3.,  0., -4.,  2., -2.,  3., -2.,  3.,  0.,  0.,  0.,  1., -1., -3.,  2.,  4., -3., -3., -2., -4., -4., -1., -1.,  1.,  3., -1.,  3.],
        [-3.,  3.,  0.,  4.,  0.,  4.,  2.,  4.,  4.,  0.,  4.,  0., -4., -1.,  2., -1.,  0.,  3.,  4., -2.,  2.,  3., -2.,  1.,  1.,  1.,  4., -2., -2., -2.,  3., -1.,  4., -2., -2., -1.,  4., -4., -2.,  1.,  0., -3., -3., -2., -1.,  3.,  0.,  1.,  0., -1.,  3., -3.,  4., -4.,  0.,  0.,  2.,  0

Compute sparse version of the operation input.T @ grad_output as sddmm(input.T,  grad_output, sparse). 
sddmm(left_operand, right_operand, sparsity_guide).

In [None]:
# Input data declaration
#transposed_flattened_input = torch.flatten(input, start_dim=0, end_dim=-2).T
#flattened_grad_output = torch.flatten(grad_d, start_dim=0, end_dim=-2)
#print("Input tensor:\n", input)
#print("Grad_output tensor:\n", grad_d)
#print("Sparse tensor metadata:\n", sparse_tensor.wrapped_tensor.metadata)
#print("Sparse tensor values:\n", sparse_tensor.wrapped_tensor.values)
#print("Sparse tensor columns:\n", sparse_tensor.wrapped_tensor.columns)




#A_matrix = input.to(dtype=torch.half).cuda()    # Correct would be input.T
#B_matrix = grad_d.to(dtype=torch.half).cuda() # Correct would be grad_d

A_matrix = A_matrix_new.to(dtype=torch.half).cuda()
B_matrix = B_matrix_new.to(dtype=torch.half).cuda()

#A_matrix = torch.ones(256, 256).to(dtype=torch.half).cuda() 
#B_matrix = torch.ones(256, 256).to(dtype=torch.half).cuda() 

print("A_matrix shape:", A_matrix.shape)
print("B_matrix shape:", B_matrix.shape)
print("A_matrix: shape:", A_matrix.shape, "\n", A_matrix)
print("B_matrix: shape:", B_matrix.shape, "\n", B_matrix)
#sddmm_C_matrix = sparse_tensor
sddmm_C_matrix = sparse_tensor_new

sddmm_output = spatha_sddmm.sddmm(
                          A_matrix,   # lhs operand
                          B_matrix,   # rhs operand
                          sddmm_C_matrix.wrapped_tensor.metadata.cuda(),    # metada for output sparsity distribution
                          #sparse_tensor.wrapped_tensor.values.to(dtype=torch.half).cuda(),    # Values for output sparsity distribution
                          sddmm_C_matrix.wrapped_tensor.columns.cuda(),           # indices for output sparsity distribution
                          sddmm_C_matrix.wrapped_tensor.nrows, # Since LHS is transposed, LHS.shape[0] is LHS.shape[1]
                          sddmm_C_matrix.wrapped_tensor.ncols,         
                          B_matrix.shape[1],         
#                          sddmm_C_matrix.wrapped_tensor.v,          
                          sddmm_C_matrix.wrapped_tensor.n,                # N
                          sddmm_C_matrix.wrapped_tensor.m,                # M
                          sddmm_C_matrix.wrapped_tensor.nnz,              # nnz
                          0,                # seed
                          32,               # mbrow
                          4                 # brow
                          )
print("SDDMM output tensor: shape:", sddmm_output.shape, "\n", sddmm_output)
# Densify sddmm_output to compare
#compressed_sddmm_output = SparseVNMTensor(sparse_tensor.wrapped_tensor.v, sparse_tensor.wrapped_tensor.n, sparse_tensor.wrapped_tensor.m, 
#                              mask_=sparse_tensor.wrapped_tensor.mask, columns_=sparse_tensor.wrapped_tensor.columns, 
#                              values_=sddmm_output, metadata_=sparse_tensor.wrapped_tensor.metadata)
compressed_sddmm_output = SparseVNMTensor(sddmm_C_matrix.wrapped_tensor.v, sddmm_C_matrix.wrapped_tensor.n, sddmm_C_matrix.wrapped_tensor.m, 
                              mask_=sddmm_C_matrix.wrapped_tensor.mask, columns_=sddmm_C_matrix.wrapped_tensor.columns, 
                              values_=sddmm_output, metadata_=sddmm_C_matrix.wrapped_tensor.metadata)
dense_sddmm_output = compressed_sddmm_output.to_dense().cuda()
check_VNM(v, n, m, n+2, dense_sddmm_output)
print("Densified output tensor: shape:", dense_sddmm_output.shape,"\n", dense_sddmm_output)

A_matrix shape: torch.Size([64, 64])
B_matrix shape: torch.Size([64, 64])
A_matrix: shape: torch.Size([64, 64]) 
 

tensor([[ 0.,  2.,  1.,  0.,  1.,  4., -1., -4., -2., -3.,  4.,  2., -2., -3., -2.,  3.,  4.,  3.,  0., -4., -1.,  3.,  1., -4., -1.,  0., -3.,  3., -4.,  0.,  2.,  0.,  4., -2., -4., -2.,  4., -3., -2., -4., -4.,  4.,  2., -2.,  3., -4.,  4.,  3.,  3., -1.,  1.,  4., -4., -4.,  2., -1.,  0.,  4.,  4.,  1., -3.,  1.,  3., -4.],
        [ 1.,  1., -1.,  1.,  2., -1., -1.,  3., -3.,  3.,  3.,  4.,  3.,  0.,  2., -1.,  4., -1., -3.,  4., -3.,  1., -4.,  1., -2.,  4.,  4., -1.,  1.,  3.,  0., -3., -3.,  1.,  0.,  4.,  4.,  3.,  0., -4.,  2., -2.,  3., -2.,  3.,  0.,  0.,  0.,  1., -1., -3.,  2.,  4., -3., -3., -2., -4., -4., -1., -1.,  1.,  3., -1.,  3.],
        [-3.,  3.,  0.,  4.,  0.,  4.,  2.,  4.,  4.,  0.,  4.,  0., -4., -1.,  2., -1.,  0.,  3.,  4., -2.,  2.,  3., -2.,  1.,  1.,  1.,  4., -2., -2., -2.,  3., -1.,  4., -2., -2., -1.,  4., -4., -2.,  1.,  0., -3., -3., -2., -1.,  3.,  0.,  1.,  0., -1.,  3., -3.,  4., -4.,  0.,  0.,  2.,  0.,  0.,  2.,  0., -2., -4., -1.],
        [-

#Check correctness

In [None]:
#print( torch.allclose(sparse_grad_input, dense_grad_input) )

print( torch.allclose(dense_sddmm_output, masked_dense_grad_weights) )

#print("dense_sddmm_output:\n", dense_sddmm_output)
#print("masked_dense_grad_weights:\n", masked_dense_grad_weights)


True
