In [1]:
from tinygrad.densetensor import DenseTensor
from tinygrad.sparsetensor import SparseTensor
import numpy as np

%load_ext autoreload
%autoreload 2

DEVICE:GPU


In [2]:
x_init = np.random.randn(2,6).astype(np.float32)
x2_init = np.random.randn(3).astype(np.float32)
U_init = np.random.randn(3,3).astype(np.float32)
V_init = np.random.randn(3,3).astype(np.float32)
W_init = np.random.randn(6,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)

x = DenseTensor(x_init)
W = DenseTensor(W_init)
m = DenseTensor(m_init)
out = x.dot(W).relu()
out = out.logsoftmax()
out = out.mul(m).add(m).sum()
out.backward()

out.cpu().data, x

x2 = DenseTensor(x2_init)#.gpu()
W = SparseTensor(W_init)
out = W.dot(x2).relu().sum()

out.backward()

out.cpu().data, x

In [3]:
import numpy as np
import pyopencl as cl

mf = cl.mem_flags

In [4]:
dim1 = 8
dim2 = 16
dim3 = 5
topkx = 16
topky = 5
topk  = topkx
bs = dim3

np.random.seed(9)

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx,
        properties=cl.command_queue_properties.PROFILING_ENABLE)

sparsity = 0.4

a = np.zeros((dim1,dim2))
b = np.zeros((dim2,dim3)).astype(np.float32)

a.shape, b.shape

((8, 16), (16, 5))

In [5]:
x_init = np.random.randn(dim1,dim3).astype(np.float32)
w_init = np.random.randn(dim2,dim3).astype(np.float32)

In [6]:
w_init

array([[-6.56452596e-01, -5.62572964e-02, -4.99902606e-01,
         4.36419368e-01, -3.75813037e-01],
       [-9.23061609e-01,  1.91725028e+00, -1.50302842e-01,
        -6.38729751e-01,  8.24770331e-01],
       [-1.21083879e+00, -5.03405392e-01, -7.01915681e-01,
        -1.97427106e+00, -2.65573215e+00],
       [-5.76822497e-02, -6.56186581e-01, -6.61706686e-01,
         7.69348443e-01, -8.99004877e-01],
       [ 1.69363797e+00, -1.69733524e+00, -2.79337025e+00,
        -2.26150647e-01,  3.97428840e-01],
       [ 1.65970361e+00, -4.93746817e-01, -3.76097679e-01,
        -1.69739768e-01,  2.41710639e+00],
       [-1.80884051e+00,  3.39751154e-01, -2.27297600e-02,
        -9.59997058e-01, -3.83114427e-01],
       [ 1.09529994e-01, -8.55162859e-01,  2.21606664e-04,
         6.63855076e-01,  7.49480963e-01],
       [-4.65818375e-01, -2.77439266e-01,  3.54995355e-02,
         8.48221183e-01,  1.62998557e-01],
       [ 1.20862365e+00,  5.02520800e-01, -1.58382213e+00,
         1.02303350e+00

In [7]:
def fill_sparse(mat, sparsity=0.5):
    indices = np.array(range(mat.shape[1]))
    nrows = int(mat.shape[1]*sparsity)
    for row in range(mat.shape[0]):
        lim = nrows #+ int(np.random.random()*3)
        mat[row][np.random.permutation(indices)[:lim]] = np.random.random(lim)
    return mat

a = fill_sparse(a, sparsity).astype(np.float32)
b = fill_sparse(b, sparsity).astype(np.float32)

In [8]:
a

array([[0.        , 0.68347037, 0.8035886 , 0.        , 0.        ,
        0.        , 0.        , 0.08023349, 0.8858316 , 0.        ,
        0.17861794, 0.        , 0.46201044, 0.        , 0.        ,
        0.        ],
       [0.9562663 , 0.5531271 , 0.        , 0.        , 0.5375557 ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.20370705, 0.35530138, 0.        , 0.        ,
        0.9186601 ],
       [0.        , 0.20103766, 0.        , 0.        , 0.        ,
        0.74341315, 0.        , 0.        , 0.        , 0.5787496 ,
        0.        , 0.05387774, 0.        , 0.5917517 , 0.17797045,
        0.        ],
       [0.        , 0.        , 0.        , 0.14296253, 0.9103574 ,
        0.03242496, 0.        , 0.71459514, 0.        , 0.        ,
        0.        , 0.        , 0.7448136 , 0.        , 0.        ,
        0.66635793],
       [0.        , 0.        , 0.        , 0.59893256, 0.        ,
        0.        , 0.        , 

In [9]:
b

array([[0.49395892, 0.        , 0.46451807, 0.        , 0.        ],
       [0.65864426, 0.        , 0.9786625 , 0.        , 0.        ],
       [0.31044865, 0.        , 0.02502257, 0.        , 0.        ],
       [0.        , 0.46765924, 0.        , 0.        , 0.03160795],
       [0.71082693, 0.        , 0.58997136, 0.        , 0.        ],
       [0.        , 0.        , 0.84622455, 0.5916316 , 0.        ],
       [0.46450275, 0.        , 0.        , 0.        , 0.9429863 ],
       [0.        , 0.34248227, 0.7672639 , 0.        , 0.        ],
       [0.        , 0.        , 0.9855889 , 0.        , 0.2535647 ],
       [0.        , 0.7705159 , 0.31522992, 0.        , 0.        ],
       [0.        , 0.7167741 , 0.        , 0.8238369 , 0.        ],
       [0.        , 0.91970533, 0.        , 0.7873889 , 0.        ],
       [0.        , 0.        , 0.        , 0.404568  , 0.841323  ],
       [0.        , 0.        , 0.674503  , 0.48820347, 0.        ],
       [0.18725161, 0.11607183, 0.

In [10]:
x2_init.T

array([-1.3542958 , -0.06554703,  0.41033113], dtype=float32)

In [11]:
mult = a.dot(b)
mult.shape

(8, 5)

In [12]:
mult.shape

(8, 5)

In [13]:
def to_data(mat):
    all_rows = []
    all_idxs = []
    all_nnzs = []
    for row in range(mat.shape[0]):
        rowdata = []
        colidxs = []
        all_nnzs.append(0)
        for col in range(mat.shape[1]):
            val = mat[row][col]
            if val != 0:
                rowdata.append(val)
                colidxs.append(col)
                all_nnzs[-1] += 1
        all_rows.append(rowdata)
        all_idxs.append(colidxs)
    
    ellwidth = min(int(np.sqrt(np.max(all_nnzs))+1)**2, mat.shape[1])
    #all_rows = np.array(all_rows)#.astype(np.float32).flatten()           
    for row in range(mat.shape[0]):
        #print(row, all_rows)
        all_rows[row] = np.array(all_rows[row])
        all_rows[row].resize(ellwidth)
        all_idxs[row] = np.array(all_idxs[row])
        all_idxs[row].resize(ellwidth)
        #print(all_idxs[row])
    all_rows = np.array(all_rows)
    all_idxs = np.array(all_idxs)
    all_nnzs = np.array(all_nnzs)
    
#     while (not all_rows[:,-1].any()):
#         all_rows = all_rows[:,:-1]
#         all_idxs = all_idxs[:,:-1]
#         ellwidth -= 1
        
    
    all_rows = np.array(all_rows).astype(np.float32).flatten()
    all_idxs = np.array(all_idxs).astype(np.uint32).flatten()
    
    all_nnzs = np.array(all_nnzs).astype(np.uint32)
    
    
    return all_rows, all_idxs, all_nnzs, ellwidth

In [14]:
def to_dense(data, cols, nnzs, ellw, shape):
    out = np.zeros(shape)
    for row in range(shape[0]):
        for icol in range(nnzs[row]):
            out[row,cols[row*ellw+icol]] = data[row*ellw+icol]
    return out

In [15]:
wdata, wcols, wnnz, ellww = to_data(w_init)
wdata, wcols, wnnz, ellww

(array([-6.56452596e-01, -5.62572964e-02, -4.99902606e-01,  4.36419368e-01,
        -3.75813037e-01, -9.23061609e-01,  1.91725028e+00, -1.50302842e-01,
        -6.38729751e-01,  8.24770331e-01, -1.21083879e+00, -5.03405392e-01,
        -7.01915681e-01, -1.97427106e+00, -2.65573215e+00, -5.76822497e-02,
        -6.56186581e-01, -6.61706686e-01,  7.69348443e-01, -8.99004877e-01,
         1.69363797e+00, -1.69733524e+00, -2.79337025e+00, -2.26150647e-01,
         3.97428840e-01,  1.65970361e+00, -4.93746817e-01, -3.76097679e-01,
        -1.69739768e-01,  2.41710639e+00, -1.80884051e+00,  3.39751154e-01,
        -2.27297600e-02, -9.59997058e-01, -3.83114427e-01,  1.09529994e-01,
        -8.55162859e-01,  2.21606664e-04,  6.63855076e-01,  7.49480963e-01,
        -4.65818375e-01, -2.77439266e-01,  3.54995355e-02,  8.48221183e-01,
         1.62998557e-01,  1.20862365e+00,  5.02520800e-01, -1.58382213e+00,
         1.02303350e+00, -6.53017402e-01,  5.37045121e-01, -7.97706190e-03,
         9.2

In [16]:
wdatat, wcolst, wnnzt, ellwwt = to_data(w_init.T)
wdatat, wcolst, wnnzt, ellwwt

(array([-6.56452596e-01, -9.23061609e-01, -1.21083879e+00, -5.76822497e-02,
         1.69363797e+00,  1.65970361e+00, -1.80884051e+00,  1.09529994e-01,
        -4.65818375e-01,  1.20862365e+00,  5.37045121e-01, -1.65366137e+00,
         5.61277032e-01, -1.27321005e+00,  9.70861197e-01, -5.30199170e-01,
        -5.62572964e-02,  1.91725028e+00, -5.03405392e-01, -6.56186581e-01,
        -1.69733524e+00, -4.93746817e-01,  3.39751154e-01, -8.55162859e-01,
        -2.77439266e-01,  5.02520800e-01, -7.97706190e-03,  1.36799216e+00,
         2.39725500e-01,  9.20076132e-01,  1.15472412e+00,  8.14134836e-01,
        -4.99902606e-01, -1.50302842e-01, -7.01915681e-01, -6.61706686e-01,
        -2.79337025e+00, -3.76097679e-01, -2.27297600e-02,  2.21606664e-04,
         3.54995355e-02, -1.58382213e+00,  9.24784184e-01,  2.51062457e-02,
         4.79899108e-01,  1.05491853e+00,  4.22608823e-01, -1.01988423e+00,
         4.36419368e-01, -6.38729751e-01, -1.97427106e+00,  7.69348443e-01,
        -2.2

In [17]:
adata, acols, annz, ellwa = to_data(a)
adata, acols, annz, ellwa

(array([0.68347037, 0.8035886 , 0.08023349, 0.8858316 , 0.17861794,
        0.46201044, 0.        , 0.        , 0.        , 0.9562663 ,
        0.5531271 , 0.5375557 , 0.20370705, 0.35530138, 0.9186601 ,
        0.        , 0.        , 0.        , 0.20103766, 0.74341315,
        0.5787496 , 0.05387774, 0.5917517 , 0.17797045, 0.        ,
        0.        , 0.        , 0.14296253, 0.9103574 , 0.03242496,
        0.71459514, 0.7448136 , 0.66635793, 0.        , 0.        ,
        0.        , 0.59893256, 0.9217044 , 0.16333483, 0.17167741,
        0.7605819 , 0.48675168, 0.        , 0.        , 0.        ,
        0.11415514, 0.82574075, 0.27978534, 0.25018668, 0.83127654,
        0.09929471, 0.        , 0.        , 0.        , 0.28872353,
        0.04889954, 0.6493426 , 0.33329687, 0.84441024, 0.01120785,
        0.        , 0.        , 0.        , 0.67626965, 0.35563806,
        0.69630474, 0.6753613 , 0.12287189, 0.09184685, 0.        ,
        0.        , 0.        ], dtype=float32),

In [18]:
adatat, acolst, annzt, ellwat = to_data(a.T)
adatat, acolst, annzt, ellwat

(array([0.9562663 , 0.11415514, 0.28872353, 0.67626965, 0.        ,
        0.        , 0.        , 0.        , 0.68347037, 0.5531271 ,
        0.20103766, 0.82574075, 0.04889954, 0.        , 0.        ,
        0.        , 0.8035886 , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.14296253,
        0.59893256, 0.6493426 , 0.35563806, 0.        , 0.        ,
        0.        , 0.        , 0.5375557 , 0.9103574 , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.74341315, 0.03242496, 0.27978534, 0.33329687, 0.        ,
        0.        , 0.        , 0.        , 0.69630474, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.08023349, 0.71459514, 0.9217044 , 0.25018668,
        0.        , 0.        , 0.        , 0.        , 0.8858316 ,
        0.6753613 , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.5787496 , 0.16

In [19]:
bdata, bcols, bnnz, ellwb = to_data(b)
bdata, bcols, bnnz, ellwb

(array([0.49395892, 0.46451807, 0.        , 0.        , 0.65864426,
        0.9786625 , 0.        , 0.        , 0.31044865, 0.02502257,
        0.        , 0.        , 0.46765924, 0.03160795, 0.        ,
        0.        , 0.71082693, 0.58997136, 0.        , 0.        ,
        0.84622455, 0.5916316 , 0.        , 0.        , 0.46450275,
        0.9429863 , 0.        , 0.        , 0.34248227, 0.7672639 ,
        0.        , 0.        , 0.9855889 , 0.2535647 , 0.        ,
        0.        , 0.7705159 , 0.31522992, 0.        , 0.        ,
        0.7167741 , 0.8238369 , 0.        , 0.        , 0.91970533,
        0.7873889 , 0.        , 0.        , 0.404568  , 0.841323  ,
        0.        , 0.        , 0.674503  , 0.48820347, 0.        ,
        0.        , 0.18725161, 0.11607183, 0.        , 0.        ,
        0.7250254 , 0.12346577, 0.        , 0.        ], dtype=float32),
 array([0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 1, 4, 0, 0, 0, 2, 0, 0, 2, 3,
        0, 0, 0, 4, 0, 0, 1, 2, 0, 0,

In [20]:
bdatat, bcolst, bnnzt, ellwbt = to_data(b.T)
adatat, bcolst, bnnzt, ellwbt

(array([0.9562663 , 0.11415514, 0.28872353, 0.67626965, 0.        ,
        0.        , 0.        , 0.        , 0.68347037, 0.5531271 ,
        0.20103766, 0.82574075, 0.04889954, 0.        , 0.        ,
        0.        , 0.8035886 , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.14296253,
        0.59893256, 0.6493426 , 0.35563806, 0.        , 0.        ,
        0.        , 0.        , 0.5375557 , 0.9103574 , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.74341315, 0.03242496, 0.27978534, 0.33329687, 0.        ,
        0.        , 0.        , 0.        , 0.69630474, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.08023349, 0.71459514, 0.9217044 , 0.25018668,
        0.        , 0.        , 0.        , 0.        , 0.8858316 ,
        0.6753613 , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.5787496 , 0.16

In [21]:
adense = to_dense(adata, acols, annz, ellwa, a.shape)

In [22]:
adenset = to_dense(adatat, acolst, annzt, ellwat, a.T.shape)

In [23]:
bdense = to_dense(bdata, bcols, bnnz, ellwb, b.shape)

In [24]:
bdenset = to_dense(bdatat, bcolst, bnnzt, ellwbt, b.T.shape)

In [25]:
adense

array([[0.        , 0.68347037, 0.80358863, 0.        , 0.        ,
        0.        , 0.        , 0.08023349, 0.88583159, 0.        ,
        0.17861794, 0.        , 0.46201044, 0.        , 0.        ,
        0.        ],
       [0.95626628, 0.55312711, 0.        , 0.        , 0.53755569,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.20370705, 0.35530138, 0.        , 0.        ,
        0.9186601 ],
       [0.        , 0.20103766, 0.        , 0.        , 0.        ,
        0.74341315, 0.        , 0.        , 0.        , 0.5787496 ,
        0.        , 0.05387774, 0.        , 0.59175169, 0.17797045,
        0.        ],
       [0.        , 0.        , 0.        , 0.14296253, 0.91035742,
        0.03242496, 0.        , 0.71459514, 0.        , 0.        ,
        0.        , 0.        , 0.74481362, 0.        , 0.        ,
        0.66635793],
       [0.        , 0.        , 0.        , 0.59893256, 0.        ,
        0.        , 0.        , 

In [26]:
adenset.T == adense

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  Tru

In [27]:
bdenset.T == bdense

array([[ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True]])

In [28]:
a

array([[0.        , 0.68347037, 0.8035886 , 0.        , 0.        ,
        0.        , 0.        , 0.08023349, 0.8858316 , 0.        ,
        0.17861794, 0.        , 0.46201044, 0.        , 0.        ,
        0.        ],
       [0.9562663 , 0.5531271 , 0.        , 0.        , 0.5375557 ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.20370705, 0.35530138, 0.        , 0.        ,
        0.9186601 ],
       [0.        , 0.20103766, 0.        , 0.        , 0.        ,
        0.74341315, 0.        , 0.        , 0.        , 0.5787496 ,
        0.        , 0.05387774, 0.        , 0.5917517 , 0.17797045,
        0.        ],
       [0.        , 0.        , 0.        , 0.14296253, 0.9103574 ,
        0.03242496, 0.        , 0.71459514, 0.        , 0.        ,
        0.        , 0.        , 0.7448136 , 0.        , 0.        ,
        0.66635793],
       [0.        , 0.        , 0.        , 0.59893256, 0.        ,
        0.        , 0.        , 

In [29]:
a == adense

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  Tru

In [30]:
a.shape

(8, 16)

In [31]:
adata.shape, acols.shape, annz.shape, ellwa

((72,), (72,), (8,), 9)

In [32]:
#acols = acols.astype(np.uint32)
#annz = annz.astype(np.uint32)

In [33]:
adata, acols, annz, b

(array([0.68347037, 0.8035886 , 0.08023349, 0.8858316 , 0.17861794,
        0.46201044, 0.        , 0.        , 0.        , 0.9562663 ,
        0.5531271 , 0.5375557 , 0.20370705, 0.35530138, 0.9186601 ,
        0.        , 0.        , 0.        , 0.20103766, 0.74341315,
        0.5787496 , 0.05387774, 0.5917517 , 0.17797045, 0.        ,
        0.        , 0.        , 0.14296253, 0.9103574 , 0.03242496,
        0.71459514, 0.7448136 , 0.66635793, 0.        , 0.        ,
        0.        , 0.59893256, 0.9217044 , 0.16333483, 0.17167741,
        0.7605819 , 0.48675168, 0.        , 0.        , 0.        ,
        0.11415514, 0.82574075, 0.27978534, 0.25018668, 0.83127654,
        0.09929471, 0.        , 0.        , 0.        , 0.28872353,
        0.04889954, 0.6493426 , 0.33329687, 0.84441024, 0.01120785,
        0.        , 0.        , 0.        , 0.67626965, 0.35563806,
        0.69630474, 0.6753613 , 0.12287189, 0.09184685, 0.        ,
        0.        , 0.        ], dtype=float32),

In [None]:
asdf

## Weight update kernel new2 (sparse ouput)

In [34]:
c=np.zeros(a.T.shape)
at = a.T
for row in range(at.shape[0]):
    for col in range(at.shape[1]):
        c[row][col] = at[row][col]

In [35]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)
x_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, a.shape[0]*4)
y_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.shape[1]*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topkx*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topky*4)
sdata_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*topkx*4)
sidxs_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*topkx*4)
snnzs_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*4)
sdatat_buf = cl.Buffer(ctx, mf.READ_WRITE, b.shape[1]*topky*4)
sidxst_buf = cl.Buffer(ctx, mf.READ_WRITE, b.shape[1]*topky*4)
snnzst_buf = cl.Buffer(ctx, mf.READ_WRITE, b.shape[1]*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate4(__global  float* x,     // INPUT MATRIX DATA
                              __global  float* y,    // INPUT
                              __global  float* xsum,    // INPUT
                              __global  float* ysum,    // INPUT
                              uint isize,
                              uint msize,
                              uint osize,
                              uint topkx,
                              uint topky,
                              __global  uint*  xoutidx,
                              __global  uint*  youtidx,
                              __global  float* matData,     // OUTPUT MATRIX DATA
                              __global  uint*  colIdx,
                              __global  uint*  rowNnz,
                              __global  float* matDatat,    // OUTPUT MATRIX DATA
                              __global  uint*  colIdxt,
                              __global  uint*  rowNnzt
                              ) {
      uint gid = get_global_id(0);

      // get for a: sum axis0  b: sum axis1 then get topk
      ///////////////////////////////////////////////////
      if (gid < isize) {
        xsum[gid] = 0;
        for (uint i=0; i<msize; i++) {
          float val = x[i*isize+gid];
          //if (gid == 0) {
          //  printf("\\nADD VALx: %.2f - %i", val, i*msize+gid);
          //}
          xsum[gid] += val;
        }

        float valx = xsum[gid];
        uint posx = 0;
        for (uint i = 0; i < isize; i++) {
          float tempval = fabs(xsum[i]);
          bool larger = tempval > fabs(valx);
          posx += (larger)?1:0;
        }
        if (posx < topky) {
          youtidx[posx] = gid;
        }
      }

      if (gid < osize) {
        ysum[gid] = 0;
        for (uint i=0; i<msize; i++) {
          float val = y[i*osize+gid];
          //if (gid == 0) {
          //  printf("\\nADD VALx: %.2f - %i", val, gid*osize+i);
          //}
          ysum[gid] += val;
        }

        float valy = ysum[gid];
        uint posy = 0;
        for (uint i = 0; i < osize; i++) {
          float tempval = fabs(ysum[i]);
          bool larger = tempval > fabs(valy);
          posy += (larger)?1:0;
        }

        if (posy < topkx) {
          xoutidx[posy] = gid;
        }
      }

      if (gid < topkx) {
        float valx = xoutidx[gid];
        uint posx = 0;
        for (uint i = 0; i < topkx; i++) {
          float tempval = xoutidx[i];
          bool larger = tempval < valx;
          posx += (larger)?1:0;
        }
        xoutidx[posx] = valx;
      }

      if (gid < topky) {
        float valy = youtidx[gid];
        uint posy = 0;
        for (uint i = 0; i < topky; i++) {
          float tempval = youtidx[i];
          bool larger = tempval < valy;
          posy += (larger)?1:0;
        }
        youtidx[posy] = valy;
      }

      // only calc matrix multiplications for used grads
      ///////////////////////////////////////////////////
      if (gid < isize) {
        for (uint i=0; i<topkx; i++) {
          matData[gid*topkx+i] = 0;
          colIdx[gid*topkx+i] = 0;
        }
        rowNnz[gid] = 0;
      }
      if (gid < osize) {
        for (uint i=0; i<topky; i++) {
          matDatat[gid*topky+i] = 0;
          colIdxt[gid*topky+i] = 0;
        }
        rowNnzt[gid] = 0;
      }
      
      if (gid < topky) {
        uint idxx = youtidx[gid];
        for (uint j=0; j<topkx; j++) {
          uint idxy = xoutidx[j];
          printf("\\nIDXX:%i  IDXY:%i", idxx, idxy);
          for (uint k=0; k<msize; k++) {
            uint xidx2 = isize*k+idxx;
            uint yidx2 = osize*k+idxy;
            uint colidx = idxy;
            matData[idxx*topkx+j] += x[xidx2] * y[yidx2];
            colIdx[idxx*topkx+j] = idxy;
            //if (gid == 0)
            //  printf("\\n ADD VAL:%.2f,%.2f - (%i,%i) - (%i,%i,%i)", x[xidx2], y[yidx2], idxx, idxy, gid, j, k);
          }
          rowNnz[idxx] += 1;
        }
      }


      if (gid < topkx) {
        uint idxx = xoutidx[gid];
        for (uint j=0; j<topky; j++) {
          uint idxy = youtidx[j];
          //printf("\\nIDXX:%i  IDXY:%i", idxx, idxy);
          for (uint k=0; k<msize; k++) {
            uint xidx2 = isize*k+idxy;
            uint yidx2 = osize*k+idxx;
            uint colidx = idxy;
            matDatat[idxx*topky+j] += x[xidx2] * y[yidx2];
            colIdxt[idxx*topky+j] = idxy;
            //if (gid == 0)
            //  printf("\\n ADD VAL:%.2f,%.2f - (%i,%i) - (%i,%i,%i)", x[xidx2], y[yidx2], idxx, idxy, gid, j, k);
          }
          rowNnzt[idxx] += 1;
        }
      }
      
    }""").build()

In [36]:
a.shape, b.shape

((8, 16), (16, 5))

In [37]:
rows = a.shape[0]
msize = a.shape[1]

In [38]:
cols = b.shape[1]

In [39]:
mult = a.dot(b)

In [40]:
mult = mult.astype(np.float32)

In [41]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
knl = prg.genwupdate4  # Use this Kernel object for repeate/duald calls
evt = knl(queue, [max(rows,cols)], None, a_buf, b_buf, x_sum_buf, y_sum_buf, np.uint32(rows), np.uint32(msize),np.uint32(cols), 
          np.uint32(topkx),np.uint32(topky), x_idx_buf, y_idx_buf, sdata_buf, sidxs_buf, snnzs_buf, sdatat_buf, sidxst_buf, snnzst_buf)


IDXX:0  IDXY:0
IDXX:1  IDXY:0
IDXX:3  IDXY:0
IDXX:4  IDXY:0
IDXX:7  IDXY:0
IDXX:0  IDXY:0
IDXX:1  IDXY:0
IDXX:3  IDXY:0
IDXX:4  IDXY:0
IDXX:7  IDXY:0
IDXX:0  IDXY:1
IDXX:1  IDXY:1
IDXX:3  IDXY:1
IDXX:4  IDXY:1
IDXX:7  IDXY:1
IDXX:0  IDXY:3
IDXX:1  IDXY:3
IDXX:3  IDXY:3
IDXX:4  IDXY:3
IDXX:7  IDXY:3
IDXX:0  IDXY:4
IDXX:1  IDXY:4
IDXX:3  IDXY:4
IDXX:4  IDXY:4
IDXX:7  IDXY:4
IDXX:0  IDXY:0
IDXX:1  IDXY:0
IDXX:3  IDXY:0
IDXX:4  IDXY:0
IDXX:7  IDXY:0
IDXX:0  IDXY:0
IDXX:1  IDXY:0
IDXX:3  IDXY:0
IDXX:4  IDXY:0
IDXX:7  IDXY:0
IDXX:0  IDXY:0
IDXX:1  IDXY:0
IDXX:3  IDXY:0
IDXX:4  IDXY:0
IDXX:7  IDXY:0
IDXX:0  IDXY:0
IDXX:1  IDXY:0
IDXX:3  IDXY:0
IDXX:4  IDXY:0
IDXX:7  IDXY:0
IDXX:0  IDXY:0
IDXX:1  IDXY:0
IDXX:3  IDXY:0
IDXX:4  IDXY:0
IDXX:7  IDXY:0
IDXX:0  IDXY:0
IDXX:1  IDXY:0
IDXX:3  IDXY:0
IDXX:4  IDXY:0
IDXX:7  IDXY:0
IDXX:0  IDXY:0
IDXX:1  IDXY:0
IDXX:3  IDXY:0
IDXX:4  IDXY:0
IDXX:7  IDXY:0
IDXX:0  IDXY:1
IDXX:1  IDXY:1
IDXX:3  IDXY:1
IDXX:4  IDXY:1
IDXX:7  IDXY:1
IDXX:0  IDXY:2
IDXX:1  I

In [42]:
resxsum = np.zeros(a.shape[0]).astype(np.float32)
resysum = np.zeros(b.shape[1]).astype(np.float32)
resxidx = np.zeros(topkx).astype(np.uint32)
resyidx = np.zeros(topky).astype(np.uint32)
resxdat = np.zeros(a.shape[0]*topkx).astype(np.float32)
resxcol = np.zeros(a.shape[0]*topkx).astype(np.uint32)
resxnnz = np.zeros(a.shape[0]).astype(np.uint32)
resxdatt = np.zeros(b.shape[1]*topky).astype(np.float32)
resxcolt = np.zeros(b.shape[1]*topky).astype(np.uint32)
resxnnzt = np.zeros(b.shape[1]).astype(np.uint32)

cl.enqueue_copy(queue, resxsum, x_sum_buf)
cl.enqueue_copy(queue, resysum, y_sum_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)
cl.enqueue_copy(queue, resxdat, sdata_buf)
cl.enqueue_copy(queue, resxcol, sidxs_buf)
cl.enqueue_copy(queue, resxnnz, snnzs_buf)
cl.enqueue_copy(queue, resxdatt, sdatat_buf)
cl.enqueue_copy(queue, resxcolt, sidxst_buf)
cl.enqueue_copy(queue, resxnnzt, snnzst_buf)

<pyopencl._cl.NannyEvent at 0x7f1848102d10>

## results

In [43]:
topkx, topky

(16, 5)

In [44]:
mult.shape

(8, 5)

In [45]:
mult

array([[0.6996368 , 0.15550727, 1.6236207 , 0.3340667 , 0.61331564],
       [1.8848312 , 0.18735047, 1.4160932 , 0.30414024, 0.29892322],
       [0.16573755, 0.51614475, 1.40742   , 0.7711447 , 0.        ],
       [1.130233  , 0.31159392, 1.1950791 , 0.32051137, 0.6311476 ],
       [0.49532753, 0.9677906 , 0.8187757 , 0.1351769 , 0.01893103],
       [0.60025734, 0.7728439 , 1.2898692 , 0.9285497 , 0.        ],
       [0.17482497, 0.30367106, 0.47157714, 0.544282  , 0.7309462 ],
       [0.7240762 , 0.27932334, 0.99110806, 0.09674796, 0.8390946 ]],
      dtype=float32)

In [46]:
resxdatt.reshape(b.shape[1],topky)

array([[0.6996368 , 1.8848312 , 1.130233  , 0.49532753, 0.7240762 ],
       [0.15550727, 0.18735047, 0.31159392, 0.9677906 , 0.27932334],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.3340667 , 0.30414024, 0.32051137, 0.1351769 , 0.09674796],
       [0.61331564, 0.29892322, 0.6311476 , 0.01893103, 0.8390946 ]],
      dtype=float32)

In [47]:
resxcol.reshape(a.shape[0],-1)

array([[0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4],
       [0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4],
       [0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4]], dtype=uint32)

In [48]:
resxnnz

array([16, 16,  0, 16, 16,  0,  0, 16], dtype=uint32)

In [49]:
resxnnzt.reshape(b.shape[1])

array([5, 5, 0, 5, 5], dtype=uint32)

In [50]:
resdense = to_dense(resxdat, resxcol, resxnnz, topkx, mult.shape)
resdense

array([[0.69963682, 0.15550727, 1.62362075, 0.33406669, 0.61331564],
       [1.88483119, 0.18735047, 1.41609323, 0.30414024, 0.29892322],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [1.13023305, 0.31159392, 1.19507909, 0.32051137, 0.63114762],
       [0.49532753, 0.9677906 , 0.81877571, 0.1351769 , 0.01893103],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.72407621, 0.27932334, 0.99110806, 0.09674796, 0.83909458]])

In [51]:
resdenset = to_dense(resxdatt, resxcolt, resxnnzt, topky, mult.T.shape)
resdenset

array([[0.69963682, 1.88483119, 0.        , 1.13023305, 0.49532753,
        0.        , 0.        , 0.72407621],
       [0.15550727, 0.18735047, 0.        , 0.31159392, 0.9677906 ,
        0.        , 0.        , 0.27932334],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.33406669, 0.30414024, 0.        , 0.32051137, 0.1351769 ,
        0.        , 0.        , 0.09674796],
       [0.61331564, 0.29892322, 0.        , 0.63114762, 0.01893103,
        0.        , 0.        , 0.83909458]])

In [52]:
resdenset = to_dense(resxdatt, resxcolt, resxnnzt, topky, mult.T.shape)
resdenset

array([[0.69963682, 1.88483119, 0.        , 1.13023305, 0.49532753,
        0.        , 0.        , 0.72407621],
       [0.15550727, 0.18735047, 0.        , 0.31159392, 0.9677906 ,
        0.        , 0.        , 0.27932334],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.33406669, 0.30414024, 0.        , 0.32051137, 0.1351769 ,
        0.        , 0.        , 0.09674796],
       [0.61331564, 0.29892322, 0.        , 0.63114762, 0.01893103,
        0.        , 0.        , 0.83909458]])

In [53]:
(mult - resdense).sum()

8.677268534898758

In [54]:
(mult - resdense)

array([[0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.16573755, 0.51614475, 1.40742004, 0.77114469, 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.60025734, 0.7728439 , 1.28986919, 0.92854971, 0.        ],
       [0.17482497, 0.30367106, 0.47157714, 0.54428202, 0.73094618],
       [0.        , 0.        , 0.        , 0.        , 0.        ]])

In [55]:
(mult - resdenset.T).sum()

14.721945375204086

## comp

In [56]:
resxsum

array([3.0937524, 3.5246177, 2.3468003, 3.2115116, 3.102983 , 2.4004393,
       2.1758807, 2.6182923], dtype=float32)

In [57]:
a.sum(axis=1)

array([3.0937524, 3.5246177, 2.3468003, 3.2115116, 3.102983 , 2.4004393,
       2.1758807, 2.6182923], dtype=float32)

In [58]:
resdenset.T

array([[0.69963682, 0.15550727, 0.        , 0.33406669, 0.61331564],
       [1.88483119, 0.18735047, 0.        , 0.30414024, 0.29892322],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [1.13023305, 0.31159392, 0.        , 0.32051137, 0.63114762],
       [0.49532753, 0.9677906 , 0.        , 0.1351769 , 0.01893103],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.72407621, 0.27932334, 0.        , 0.09674796, 0.83909458]])

In [59]:
resysum

array([3.5506585, 3.3332088, 5.77045  , 3.095629 , 2.0694818],
      dtype=float32)

In [60]:
b.sum(axis=0)

array([3.5506585, 3.3332088, 5.77045  , 3.095629 , 2.0694818],
      dtype=float32)

In [61]:
mult

array([[0.6996368 , 0.15550727, 1.6236207 , 0.3340667 , 0.61331564],
       [1.8848312 , 0.18735047, 1.4160932 , 0.30414024, 0.29892322],
       [0.16573755, 0.51614475, 1.40742   , 0.7711447 , 0.        ],
       [1.130233  , 0.31159392, 1.1950791 , 0.32051137, 0.6311476 ],
       [0.49532753, 0.9677906 , 0.8187757 , 0.1351769 , 0.01893103],
       [0.60025734, 0.7728439 , 1.2898692 , 0.9285497 , 0.        ],
       [0.17482497, 0.30367106, 0.47157714, 0.544282  , 0.7309462 ],
       [0.7240762 , 0.27932334, 0.99110806, 0.09674796, 0.8390946 ]],
      dtype=float32)

In [62]:
mult.T - resdenset

array([[0.        , 0.        , 0.16573755, 0.        , 0.        ,
        0.60025734, 0.17482497, 0.        ],
       [0.        , 0.        , 0.51614475, 0.        , 0.        ,
        0.7728439 , 0.30367106, 0.        ],
       [1.62362075, 1.41609323, 1.40742004, 1.19507909, 0.81877571,
        1.28986919, 0.47157714, 0.99110806],
       [0.        , 0.        , 0.77114469, 0.        , 0.        ,
        0.92854971, 0.54428202, 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.73094618, 0.        ]])

In [63]:
resxidx

array([0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=uint32)

In [64]:
resyidx

array([0, 1, 3, 4, 7], dtype=uint32)

In [None]:
asdf

## Prune Weights

In [65]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)

prg = cl.Program(ctx, """
    // prunes weights smaller than a constant C
    __kernel void prune(__global  float* matData,     // INPUT MATRIX DATA
                        __global  uint*  colIdx,
                        __global  uint*  rowNnz,
                        uint ellw,
                        float pruneval) { 
      uint gid = get_global_id(0);
      
      uint nnzs = rowNnz[gid];
      for (uint i=0; i<nnzs; i++) {
        uint idx = ellw * gid + i;
        float val = matData[idx];
        printf("\\nDATA:%.2f - %.2f", matData[idx], pruneval);
        if(fabs(val)<pruneval) {
          //printf("\\nPRUNE(%i): %.2f", gid, matData[idx]);
          for (uint j=i; j<=nnzs-1; j++) {
            uint idx2 = ellw * gid + j;
            matData[idx2] = matData[idx2+1];
            colIdx[idx2] = colIdx[idx2+1];
          }
          matData[ellw*gid+nnzs] = 0;
          colIdx[ellw*gid+nnzs] = 0;
          rowNnz[gid] -= 1;
          nnzs = rowNnz[gid];
        }
      }
    }""").build()

In [66]:
a.shape

(8, 16)

In [67]:
rows = a.shape[0]
cols = a.shape[1]

pruneval = .35

In [68]:
knl = prg.prune  # Use this Kernel object for repeated calls
evt = knl(queue, [rows,], None, adata_buf, acols_buf, annzs_buf, np.uint32(ellwa), np.float32(pruneval))


DATA:0.68 - 0.35
DATA:0.96 - 0.35
DATA:0.20 - 0.35
DATA:0.14 - 0.35
DATA:0.60 - 0.35
DATA:0.11 - 0.35
DATA:0.29 - 0.35
DATA:0.68 - 0.35
DATA:0.80 - 0.35
DATA:0.55 - 0.35
DATA:0.58 - 0.35
DATA:0.03 - 0.35
DATA:0.92 - 0.35
DATA:0.28 - 0.35
DATA:0.65 - 0.35
DATA:0.36 - 0.35
DATA:0.08 - 0.35
DATA:0.54 - 0.35
DATA:0.05 - 0.35
DATA:0.74 - 0.35
DATA:0.16 - 0.35
DATA:0.83 - 0.35
DATA:0.33 - 0.35
DATA:0.70 - 0.35
DATA:0.18 - 0.35
DATA:0.20 - 0.35
DATA:0.18 - 0.35
DATA:0.67 - 0.35
DATA:0.76 - 0.35
DATA:0.10 - 0.35
DATA:0.01 - 0.35
DATA:0.68 - 0.35
DATA:0.92 - 0.35
DATA:0.49 - 0.35
DATA:0.12 - 0.35

In [69]:
resxdat = np.zeros(adata.shape).astype(np.float32)
resxcol = np.zeros(acols.shape).astype(np.uint32)
resxnnz = np.zeros(annz.shape).astype(np.uint32)

cl.enqueue_copy(queue, resxdat, adata_buf)
cl.enqueue_copy(queue, resxcol, acols_buf)
cl.enqueue_copy(queue, resxnnz, annzs_buf)

<pyopencl._cl.NannyEvent at 0x7f18480f54a0>

In [70]:
a

array([[0.        , 0.68347037, 0.8035886 , 0.        , 0.        ,
        0.        , 0.        , 0.08023349, 0.8858316 , 0.        ,
        0.17861794, 0.        , 0.46201044, 0.        , 0.        ,
        0.        ],
       [0.9562663 , 0.5531271 , 0.        , 0.        , 0.5375557 ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.20370705, 0.35530138, 0.        , 0.        ,
        0.9186601 ],
       [0.        , 0.20103766, 0.        , 0.        , 0.        ,
        0.74341315, 0.        , 0.        , 0.        , 0.5787496 ,
        0.        , 0.05387774, 0.        , 0.5917517 , 0.17797045,
        0.        ],
       [0.        , 0.        , 0.        , 0.14296253, 0.9103574 ,
        0.03242496, 0.        , 0.71459514, 0.        , 0.        ,
        0.        , 0.        , 0.7448136 , 0.        , 0.        ,
        0.66635793],
       [0.        , 0.        , 0.        , 0.59893256, 0.        ,
        0.        , 0.        , 

In [71]:
adata.reshape((4,-1))

array([[0.68347037, 0.8035886 , 0.08023349, 0.8858316 , 0.17861794,
        0.46201044, 0.        , 0.        , 0.        , 0.9562663 ,
        0.5531271 , 0.5375557 , 0.20370705, 0.35530138, 0.9186601 ,
        0.        , 0.        , 0.        ],
       [0.20103766, 0.74341315, 0.5787496 , 0.05387774, 0.5917517 ,
        0.17797045, 0.        , 0.        , 0.        , 0.14296253,
        0.9103574 , 0.03242496, 0.71459514, 0.7448136 , 0.66635793,
        0.        , 0.        , 0.        ],
       [0.59893256, 0.9217044 , 0.16333483, 0.17167741, 0.7605819 ,
        0.48675168, 0.        , 0.        , 0.        , 0.11415514,
        0.82574075, 0.27978534, 0.25018668, 0.83127654, 0.09929471,
        0.        , 0.        , 0.        ],
       [0.28872353, 0.04889954, 0.6493426 , 0.33329687, 0.84441024,
        0.01120785, 0.        , 0.        , 0.        , 0.67626965,
        0.35563806, 0.69630474, 0.6753613 , 0.12287189, 0.09184685,
        0.        , 0.        , 0.        ]], dty

In [72]:
acols.reshape((4,-1))

array([[ 1,  2,  7,  8, 10, 12,  0,  0,  0,  0,  1,  4, 11, 12, 15,  0,
         0,  0],
       [ 1,  5,  9, 11, 13, 14,  0,  0,  0,  3,  4,  5,  7, 12, 15,  0,
         0,  0],
       [ 3,  7,  9, 11, 14, 15,  0,  0,  0,  0,  1,  5,  7, 10, 11,  0,
         0,  0],
       [ 0,  1,  3,  5, 12, 13,  0,  0,  0,  0,  3,  6,  8, 11, 15,  0,
         0,  0]], dtype=uint32)

In [73]:
resxdat.reshape((4,-1))

array([[0.68347037, 0.8035886 , 0.8858316 , 0.46201044, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.9562663 ,
        0.5531271 , 0.5375557 , 0.35530138, 0.9186601 , 0.        ,
        0.        , 0.        , 0.        ],
       [0.74341315, 0.5787496 , 0.5917517 , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.9103574 ,
        0.71459514, 0.7448136 , 0.66635793, 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.59893256, 0.9217044 , 0.17167741, 0.7605819 , 0.48675168,
        0.        , 0.        , 0.        , 0.        , 0.82574075,
        0.25018668, 0.83127654, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.04889954, 0.6493426 , 0.84441024, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.67626965,
        0.35563806, 0.69630474, 0.6753613 , 0.09184685, 0.        ,
        0.        , 0.        , 0.        ]], dty

In [74]:
resxcol.reshape((4,-1))

array([[ 1,  2,  8, 12,  0,  0,  0,  0,  0,  0,  1,  4, 12, 15,  0,  0,
         0,  0],
       [ 5,  9, 13,  0,  0,  0,  0,  0,  0,  4,  7, 12, 15,  0,  0,  0,
         0,  0],
       [ 3,  7, 11, 14, 15,  0,  0,  0,  0,  1,  7, 10,  0,  0,  0,  0,
         0,  0],
       [ 1,  3, 12,  0,  0,  0,  0,  0,  0,  0,  3,  6,  8, 15,  0,  0,
         0,  0]], dtype=uint32)

In [75]:
resxnnz

array([4, 5, 3, 4, 5, 3, 3, 5], dtype=uint32)

## results

In [76]:
mult.T

array([[0.6996368 , 1.8848312 , 0.16573755, 1.130233  , 0.49532753,
        0.60025734, 0.17482497, 0.7240762 ],
       [0.15550727, 0.18735047, 0.51614475, 0.31159392, 0.9677906 ,
        0.7728439 , 0.30367106, 0.27932334],
       [1.6236207 , 1.4160932 , 1.40742   , 1.1950791 , 0.8187757 ,
        1.2898692 , 0.47157714, 0.99110806],
       [0.3340667 , 0.30414024, 0.7711447 , 0.32051137, 0.1351769 ,
        0.9285497 , 0.544282  , 0.09674796],
       [0.61331564, 0.29892322, 0.        , 0.6311476 , 0.01893103,
        0.        , 0.7309462 , 0.8390946 ]], dtype=float32)

In [77]:
resxdat.reshape(a.shape[0],topk)

ValueError: cannot reshape array of size 72 into shape (8,16)

In [None]:
resxcol.reshape(a.shape[0],topk)

In [None]:
resxnnz.reshape(a.shape[0])

In [None]:
resxdatt.reshape(b.shape[1],topk)

In [None]:
resxcolt.reshape(b.shape[1],topk)

In [None]:
resxnnzt.reshape(b.shape[1])

### Update Vals (add sparse)

In [None]:
multdata, multcols, multnnz, multellw = to_data(mult)
multdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=multdata)
multcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=multcols)
multnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=multnnz)
sdata_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*topkx*4)
sidxs_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*topkx*4)
snnzs_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*4)

prg = cl.Program(ctx, """
    // Every global_id_0 works on a row
    __kernel void adddense(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            float  lr,
                            uint   ellwidth,
                            __global  float* matDataAdd,     // INPUT MATRIX DATA
                            __global  uint*  colIdxAdd,
                            __global  uint*  rowNnzAdd,
                            uint ellwidthAdd
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);

      uint nnz    = rowNnz[gid];
      
      uint baseidxs = gid*ellwidth;
      uint baseidxd = gid*ellwidthAdd;
      
      uint nnzadd = rowNnzAdd[gid];
      printf("\\nNNZs: %i   GID:%i", nnzadd, gid);
      
      for (uint i=0; i<nnzadd; i++) {
        float addval = matDataAdd[baseidxd+i];
        uint addcol = colIdxAdd[baseidxd+i];
        
        uint refcol = colIdx[baseidxs+i];
        uint m = 0;
        while (addcol > refcol) {
          m += 1;
          refcol = colIdx[baseidxs+i+m];
        }
        
        //printf("\\nADD VAL:%.2f  ADDCOL:%i  idxs/d:(%i/%i)  gid/i:(%i/%i)", addval, addcol, baseidxs, baseidxd, gid,i);
        if (addval == 0.0) {
          //printf("\\nZERO VAL, CONT: %.2f - %i", addval, gid);
          continue;
        }
        if (addcol == refcol) {
          matData[baseidxs+i+m] += addval;
          printf("\\nINCREMENT: %.2f",addval);
        } else {
          if (rowNnz[gid] >= ellwidth) {
            break;
          }
          if (addcol > refcol) {
            rowNnz[gid] += 1;
            printf("\\nSET VAL0:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
            matData[baseidxs+i+m] = addval;
            colIdx[baseidxs+i+m] = addcol;
            continue;
          }
          for (uint j=nnz; j>i+m; j--) {
            //printf("\\nMOVE:%.2f", matData[baseidx+j-1]);
            colIdx[baseidxs+j] = colIdx[baseidxs+j-1];
            matData[baseidxs+j] = matData[baseidxs+j-1];
          }
          rowNnz[gid] += 1;
          nnz = rowNnz[gid];
          
          printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
          matData[baseidxs+i+m] = addval;
          colIdx[baseidxs+i+m] = addcol;
          if (nnz >= ellwidth)
            break;
        }
      }
    }""").build()

In [None]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [None]:
rows = a.shape[0]

In [None]:
mult = mult.astype(np.float32)

In [None]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.adddense  # Use this Kernel object for repeated calls
knl(queue, [rows], None, multdata_buf, multcols_buf, multnnzs_buf, np.float32(1), np.uint32(multellw), 
    sdata_buf, sidxs_buf, snnzs_buf, np.uint32(topk))

In [None]:
mult

In [None]:
data_res = np.empty_like(multdata)
cols_res = np.empty_like(multcols)
nnzs_res = np.empty_like(multnnz)
cl.enqueue_copy(queue, data_res, multdata_buf, is_blocking=True)
cl.enqueue_copy(queue, cols_res, multcols_buf, is_blocking=True)
cl.enqueue_copy(queue, nnzs_res, multnnzs_buf, is_blocking=True)

In [None]:
adenseadd = to_dense(data_res, cols_res, nnzs_res, multellw, mult.shape)
adenseadd.T

In [None]:
mult-adenseadd

### Update Vals (add sparset)

In [None]:
multt=np.zeros(mult.T.shape)

for row in range(multt.shape[0]):
    for col in range(multt.shape[1]):
        multt[row][col] = mult[col][row]

In [None]:
multdata, multcols, multnnz, multellw = to_data(multt)
multdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=multdata)
multcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=multcols)
multnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=multnnz)

In [None]:
a.shape, b.shape

In [None]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [None]:
rows = mult.T.shape[0]

In [None]:
mult = mult.astype(np.float32)

In [None]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.adddense  # Use this Kernel object for repeated calls
knl(queue, [rows], None, multdata_buf, multcols_buf, multnnzs_buf, np.float32(1), np.uint32(multellw), 
    sdatat_buf, sidxst_buf, snnzst_buf, np.uint32(topk))

In [None]:
mult.T

In [None]:
data_res = np.empty_like(multdata)
cols_res = np.empty_like(multcols)
nnzs_res = np.empty_like(multnnz)
cl.enqueue_copy(queue, data_res, multdata_buf, is_blocking=True)
cl.enqueue_copy(queue, cols_res, multcols_buf, is_blocking=True)
cl.enqueue_copy(queue, nnzs_res, multnnzs_buf, is_blocking=True)

In [None]:
multt-data_res.reshape(multt.shape)

In [None]:
nnzs_res

In [None]:
adenseaddt = to_dense(data_res, cols_res, nnzs_res, multellw, multt.shape)
adenseaddt

In [None]:
multt-adenseaddt

In [None]:
adenseaddt

In [None]:
adenseadd.T == adenseaddt

### Update Vals (add topk to sparse)

In [None]:
matadd = np.random.randn(*a.shape).astype(np.float32)
matadd

In [None]:
a

In [None]:
a_added = a + matadd

In [None]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)
add_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=matadd)

prg = cl.Program(ctx, """
    // Every global_id_0 works on a row
    __kernel void adddense(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            float  lr,
                            uint   ellwidth,
                            uint   awidth,
                            __global  float* vector_x    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);

      uint nnz    = rowNnz[gid];
      uint baseidxs = gid*ellwidth;
      uint baseidxd = gid*awidth;
      
      for (uint i=0; i<awidth; i++) {
        float addval = vector_x[baseidxd+i];
        //if (gid==1)
        //  printf("\\nADD VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[baseidxs+i]);
        if (addval == 0) {
          continue;
        }
        if (i == colIdx[baseidxs+i]) {
          matData[baseidxs+i] += addval;
        } else {
          if (rowNnz[gid] >= ellwidth) {
            break;
          }
          if (i > colIdx[baseidxs+i]) {
            rowNnz[gid] += 1;
            //if (gid==1)
            //  printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
            matData[baseidxs+i] = addval;
            colIdx[baseidxs+i] = i;
            continue;
          }
          for (uint j=nnz; j>i; j--) {
            //printf("\\nMOVE:%.2f", matData[baseidx+j-1]);
            colIdx[baseidxs+j] = colIdx[baseidxs+j-1];
            matData[baseidxs+j] = matData[baseidxs+j-1];
          }
          rowNnz[gid] += 1;
          nnz = rowNnz[gid];
          //if (gid==1)
          //  printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
          matData[baseidxs+i] = addval;
          colIdx[baseidxs+i] = i;
          if (nnz >= ellwidth)
            break;
        }
      }
    }""").build()

In [None]:
a.shape, b.shape

In [None]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [None]:
rows = a.shape[0]

In [None]:
mult = mult.astype(np.float32)

In [None]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.adddense  # Use this Kernel object for repeated calls
knl(queue, [rows], None, adata_buf, acols_buf, annzs_buf, np.float32(1), np.uint32(ellwa),np.uint32(a.shape[1]), add_buf)

In [None]:
matadd[0][0]

In [None]:
data_res = np.empty_like(adata)
cols_res = np.empty_like(acols)
nnzs_res = np.empty_like(annz)
cl.enqueue_copy(queue, data_res, adata_buf)
cl.enqueue_copy(queue, cols_res, acols_buf)
cl.enqueue_copy(queue, nnzs_res, annzs_buf)

In [None]:
adenseadd = to_dense(data_res, cols_res, nnzs_res, ellwa, a.shape)
adenseadd

In [None]:
a

In [None]:
matadd

In [None]:
a_added

In [None]:
adenseadd

In [None]:
adenseadd == a_added

### update vals

In [None]:
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)
add_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=matadd)

prg = cl.Program(ctx, """
    // Every global_id_0 works on a row
    __kernel void adddenset(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            float  lr,
                            uint   ellwidth,
                            uint   aheight,
                            __global  float* vector_x    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint ncols = get_global_size(0);

      uint nnz    = rowNnz[gid];
      uint baseidxs = gid*ellwidth;
      
      for (uint i=0; i<aheight; i++) {
        if (nnz > ellwidth)
            break;
        uint baseidxd = i*ncols+gid;
        float addval = vector_x[baseidxd];
        //if (gid==1)
        //  printf("\\nADD VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[baseidxs+i]);
        if (addval == 0) {
          continue;
        }
        if (i == colIdx[baseidxs+i]) {
          printf("\\nADD VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
          matData[baseidxs+i] += addval;
        } else {
          if (rowNnz[gid] >= ellwidth) {
            break;
          }
          if (i > colIdx[baseidxs+i]) {
            rowNnz[gid] += 1;
            //if (gid==1)
            //  printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
            matData[baseidxs+i] = addval;
            colIdx[baseidxs+i] = i;
            continue;
          }
          for (uint j=nnz; j>i; j--) {
            //printf("\\nMOVE:%.2f", matData[baseidx+j-1]);
            colIdx[baseidxs+j] = colIdx[baseidxs+j-1];
            matData[baseidxs+j] = matData[baseidxs+j-1];
          }
          rowNnz[gid] += 1;
          nnz = rowNnz[gid];
          //if (gid==1)
          //  printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
          matData[baseidxs+i] = addval;
          colIdx[baseidxs+i] = i;
        }
      }
    }""").build()

In [None]:
a.shape, b.shape

In [None]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [None]:
cols = a.shape[1]

In [None]:
mult = mult.astype(np.float32)

In [None]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.adddenset  # Use this Kernel object for repeated calls
knl(queue, [cols], None, adatat_buf, acolst_buf, annzst_buf, np.float32(1), np.uint32(ellwat),np.uint32(a.T.shape[1]), add_buf)

In [None]:
matadd[0][0]

In [None]:
datat_res = np.empty_like(adatat)
colst_res = np.empty_like(acolst)
nnzst_res = np.empty_like(annzt)
cl.enqueue_copy(queue, datat_res, adatat_buf)
cl.enqueue_copy(queue, colst_res, acolst_buf)
cl.enqueue_copy(queue, nnzst_res, annzst_buf)

In [None]:
adenseaddt = to_dense(datat_res, colst_res, nnzst_res, ellwat, a.T.shape).T
adenseaddt

In [None]:
a

In [None]:
matadd

In [None]:
a_added

In [None]:
adenseaddt == a_added

### Make Random

In [None]:
rand = SparseTensor.uniform(2,4)
rand

In [None]:
rand.to_numpy()

In [None]:
rand.data

### update vals

In [None]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)

In [None]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint gid2 = get_global_id(1);
      uint topk = get_global_size(0);
      uint bs = get_global_size(1);
      uint baseupdateidx = topk*topk*gid2;
      uint baseidxidx = topk*gid2;
      uint col = updateyidx[baseidxidx+gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint row = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [None]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint topk = get_global_size(0);
      uint col = updateyidx[gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint row = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];m
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [None]:
knl = prg.addvals  # Use this Kernel object for repeated calls
knl(queue, [topk,1], None, adata_buf, acols_buf, annzs_buf, np.float32(1), np.uint32(ellwa), x_cp_buf, x_idx_buf, y_idx_buf)

resa = np.empty_like(adata)
resaidx = np.zeros(acols.shape).astype(np.uint32)
resannz = np.zeros(annz.shape).astype(np.uint32)

cl.enqueue_copy(queue, resa, adata_buf)
cl.enqueue_copy(queue, resaidx, acols_buf)
cl.enqueue_copy(queue, resannz, annzs_buf)

In [None]:
resa.shape, resaidx.shape, resannz.shape, ellwa, a.T.shape

In [None]:
adenseadd = to_dense(resa, resaidx, resannz, ellwa, a.shape)
adenseadd

In [None]:
adenseadd - adense

In [None]:
adenseadd == adense

In [None]:
ellwa

In [None]:
adata2 = adata.reshape(-1, ellwa)
adata2

In [None]:
resa = resa.reshape(-1, ellwa)
resa

In [None]:
resa - adata2

In [None]:
acols

In [None]:
resaidx

In [None]:
resannz

In [None]:
annz

### update vals2

In [None]:
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)

In [None]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint gid2 = get_global_id(1);
      uint topk = get_global_size(0);
      uint bs = get_global_size(1);
      uint baseupdateidx = topk*topk*gid2;
      uint baseidxidx = topk*gid2;
      uint row = updateyidx[baseidxidx+gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint col = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [None]:
knl = prg.addvals  # Use this Kernel object for repeated calls
knl(queue, [topk,bs], None, adatat_buf, acolst_buf, annzst_buf, np.float32(1), np.uint32(ellwat), x_cp_buf, x_idx_buf, y_idx_buf)

resat = np.empty_like(adatat)
resaidxt = np.zeros(acolst.shape).astype(np.uint32)
resannzt = np.zeros(annzt.shape).astype(np.uint32)

cl.enqueue_copy(queue, resat, adatat_buf)
cl.enqueue_copy(queue, resaidxt, acolst_buf)
cl.enqueue_copy(queue, resannzt, annzst_buf)

In [None]:
ellwa

In [None]:
resat.shape, resaidxt.shape, resannzt.shape

In [None]:
adenseaddt = to_dense(resat, resaidxt, resannzt, ellwat, a.T.shape)
adenseaddt

In [None]:
adenseadd == adenseaddt.T

In [None]:
adata2t = adatat.reshape(-1, ellwat)
adata2t

In [None]:
resat = resat.reshape(-1, ellwat)
resat

In [None]:
resat - adata2t

In [None]:
acols

In [None]:
resaidx

In [None]:
resannz

In [None]:
annz