In [1]:
from tinygrad.densetensor import DenseTensor
from tinygrad.sparsetensor import SparseTensor
import numpy as np

%load_ext autoreload
%autoreload 2

DEVICE:GPU


In [2]:
x_init = np.random.randn(2,6).astype(np.float32)
x2_init = np.random.randn(3).astype(np.float32)
U_init = np.random.randn(3,3).astype(np.float32)
V_init = np.random.randn(3,3).astype(np.float32)
W_init = np.random.randn(6,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)

x = DenseTensor(x_init)
W = DenseTensor(W_init)
m = DenseTensor(m_init)
out = x.dot(W).relu()
out = out.logsoftmax()
out = out.mul(m).add(m).sum()
out.backward()

out.cpu().data, x

x2 = DenseTensor(x2_init)#.gpu()
W = SparseTensor(W_init)
out = W.dot(x2).relu().sum()

out.backward()

out.cpu().data, x

In [3]:
import numpy as np
import pyopencl as cl

mf = cl.mem_flags

In [682]:
dim1 = 784
dim2 = 8
dim3 = 10
topkx = 5
topky = 8
topk  = topkx
bs = dim3

np.random.seed(9)

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx,
        properties=cl.command_queue_properties.PROFILING_ENABLE)

sparsity = 0.4

a = np.zeros((dim1,dim2))
b = np.zeros((dim2,dim3)).astype(np.float32)

a.shape, b.shape

((784, 8), (8, 10))

In [683]:
x_init = np.random.randn(dim1,dim3).astype(np.float32)
w_init = np.random.randn(dim2,dim3).astype(np.float32)

In [684]:
w_init

array([[-0.60463536, -1.1710321 , -0.07425962, -0.89242464, -0.43796238,
         0.5658415 , -0.48279902, -0.2975719 , -1.143396  , -0.6425492 ],
       [ 1.5146434 , -0.5019981 ,  0.6814257 ,  1.1993246 , -0.70170444,
        -1.1829906 , -0.9982734 , -0.03202078, -0.77119285, -0.6904869 ],
       [ 0.12832837, -0.36010188,  0.2039541 , -0.57304317, -0.50078034,
        -0.38172   ,  0.57351166, -1.0273042 ,  1.3401899 ,  0.16892587],
       [ 1.831064  ,  1.2889572 ,  0.0827049 , -1.3255974 ,  0.8816386 ,
        -1.2041618 , -0.69006664, -1.6419983 ,  0.13103886,  1.7757125 ],
       [-2.2099125 , -1.1403371 ,  0.40859553, -0.33127317,  0.194872  ,
         0.47998413, -1.0421513 , -1.4831108 , -2.1041229 , -0.75974137],
       [ 0.39532152, -0.29778364, -0.1652074 ,  0.2927815 ,  0.5866176 ,
         0.38426524,  0.37534708,  2.0616708 ,  0.13078961, -0.4044718 ],
       [ 0.77458316,  1.0405055 ,  0.311686  , -0.916232  , -0.8365159 ,
         0.5338632 , -0.68742096,  0.12771639

In [685]:
def fill_sparse(mat, sparsity=0.5):
    indices = np.array(range(mat.shape[1]))
    nrows = int(mat.shape[1]*sparsity)
    for row in range(mat.shape[0]):
        lim = nrows #+ int(np.random.random()*3)
        mat[row][np.random.permutation(indices)[:lim]] = np.random.random(lim)
    return mat

a = fill_sparse(a, sparsity).astype(np.float32)
b = fill_sparse(b, sparsity).astype(np.float32)

In [686]:
a

array([[0.        , 0.20136862, 0.        , ..., 0.        , 0.24664757,
        0.        ],
       [0.        , 0.08636539, 0.        , ..., 0.2248708 , 0.        ,
        0.61742616],
       [0.8060224 , 0.        , 0.02815218, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.2683797 , 0.        , ..., 0.        , 0.        ,
        0.14286678],
       [0.        , 0.        , 0.85706204, ..., 0.2967158 , 0.5745791 ,
        0.        ],
       [0.        , 0.        , 0.22261621, ..., 0.41236356, 0.5468153 ,
        0.        ]], dtype=float32)

In [687]:
b

array([[0.51228654, 0.        , 0.89396334, 0.        , 0.583282  ,
        0.37977967, 0.        , 0.        , 0.        , 0.        ],
       [0.86082464, 0.        , 0.        , 0.29537702, 0.        ,
        0.        , 0.        , 0.51280856, 0.6324379 , 0.        ],
       [0.16401087, 0.83399713, 0.796908  , 0.98693824, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.12299865, 0.        , 0.        , 0.        , 0.99917537,
        0.1828657 , 0.        , 0.45594448, 0.        , 0.        ],
       [0.        , 0.45390037, 0.        , 0.8534569 , 0.        ,
        0.78406423, 0.        , 0.19776365, 0.        , 0.        ],
       [0.1423029 , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.6472351 , 0.        , 0.31429735, 0.19298612],
       [0.7907769 , 0.        , 0.        , 0.        , 0.712911  ,
        0.        , 0.40409806, 0.96907926, 0.        , 0.        ],
       [0.        , 0.5134478 , 0.       

In [688]:
x2_init.T

array([ 1.0165602 ,  0.7194302 , -0.04339715], dtype=float32)

In [689]:
mult = a.dot(b)
mult.shape

(784, 10)

In [690]:
mult.shape

(784, 10)

In [691]:
def to_data(mat):
    all_rows = []
    all_idxs = []
    all_nnzs = []
    for row in range(mat.shape[0]):
        rowdata = []
        colidxs = []
        all_nnzs.append(0)
        for col in range(mat.shape[1]):
            val = mat[row][col]
            if val != 0:
                rowdata.append(val)
                colidxs.append(col)
                all_nnzs[-1] += 1
        all_rows.append(rowdata)
        all_idxs.append(colidxs)
    
    ellwidth = min(int(np.sqrt(np.max(all_nnzs))+1)**2, mat.shape[1])
    ellwidth = mat.shape[1]
    #all_rows = np.array(all_rows)#.astype(np.float32).flatten()           
    for row in range(mat.shape[0]):
        #print(row, all_rows)
        all_rows[row] = np.array(all_rows[row])
        all_rows[row].resize(ellwidth)
        all_idxs[row] = np.array(all_idxs[row])
        all_idxs[row].resize(ellwidth)
        #print(all_idxs[row])
    all_rows = np.array(all_rows)
    all_idxs = np.array(all_idxs)
    all_nnzs = np.array(all_nnzs)
    
#     while (not all_rows[:,-1].any()):
#         all_rows = all_rows[:,:-1]
#         all_idxs = all_idxs[:,:-1]
#         ellwidth -= 1
        
    
    all_rows = np.array(all_rows).astype(np.float32).flatten()
    all_idxs = np.array(all_idxs).astype(np.uint32).flatten()
    
    all_nnzs = np.array(all_nnzs).astype(np.uint32)
    
    
    return all_rows, all_idxs, all_nnzs, ellwidth

In [692]:
def to_dense(data, cols, nnzs, ellw, shape):
    out = np.zeros(shape)
    for row in range(shape[0]):
        for icol in range(nnzs[row]):
            out[row,cols[row*ellw+icol]] = data[row*ellw+icol]
    return out

In [693]:
wdata, wcols, wnnz, ellww = to_data(w_init)
wdata, wcols, wnnz, ellww

(array([-0.60463536, -1.1710321 , -0.07425962, -0.89242464, -0.43796238,
         0.5658415 , -0.48279902, -0.2975719 , -1.143396  , -0.6425492 ,
         1.5146434 , -0.5019981 ,  0.6814257 ,  1.1993246 , -0.70170444,
        -1.1829906 , -0.9982734 , -0.03202078, -0.77119285, -0.6904869 ,
         0.12832837, -0.36010188,  0.2039541 , -0.57304317, -0.50078034,
        -0.38172   ,  0.57351166, -1.0273042 ,  1.3401899 ,  0.16892587,
         1.831064  ,  1.2889572 ,  0.0827049 , -1.3255974 ,  0.8816386 ,
        -1.2041618 , -0.69006664, -1.6419983 ,  0.13103886,  1.7757125 ,
        -2.2099125 , -1.1403371 ,  0.40859553, -0.33127317,  0.194872  ,
         0.47998413, -1.0421513 , -1.4831108 , -2.1041229 , -0.75974137,
         0.39532152, -0.29778364, -0.1652074 ,  0.2927815 ,  0.5866176 ,
         0.38426524,  0.37534708,  2.0616708 ,  0.13078961, -0.4044718 ,
         0.77458316,  1.0405055 ,  0.311686  , -0.916232  , -0.8365159 ,
         0.5338632 , -0.68742096,  0.12771639, -0.8

In [694]:
wdatat, wcolst, wnnzt, ellwwt = to_data(w_init.T)
wdatat, wcolst, wnnzt, ellwwt

(array([-0.60463536,  1.5146434 ,  0.12832837,  1.831064  , -2.2099125 ,
         0.39532152,  0.77458316,  1.3192611 , -1.1710321 , -0.5019981 ,
        -0.36010188,  1.2889572 , -1.1403371 , -0.29778364,  1.0405055 ,
         1.1386915 , -0.07425962,  0.6814257 ,  0.2039541 ,  0.0827049 ,
         0.40859553, -0.1652074 ,  0.311686  , -0.34535363, -0.89242464,
         1.1993246 , -0.57304317, -1.3255974 , -0.33127317,  0.2927815 ,
        -0.916232  , -0.00841912, -0.43796238, -0.70170444, -0.50078034,
         0.8816386 ,  0.194872  ,  0.5866176 , -0.8365159 ,  0.7912638 ,
         0.5658415 , -1.1829906 , -0.38172   , -1.2041618 ,  0.47998413,
         0.38426524,  0.5338632 , -0.30207726, -0.48279902, -0.9982734 ,
         0.57351166, -0.69006664, -1.0421513 ,  0.37534708, -0.68742096,
         0.47890145, -0.2975719 , -0.03202078, -1.0273042 , -1.6419983 ,
        -1.4831108 ,  2.0616708 ,  0.12771639, -1.021893  , -1.143396  ,
        -0.77119285,  1.3401899 ,  0.13103886, -2.1

In [695]:
adata, acols, annz, ellwa = to_data(a)
adata, acols, annz, ellwa

(array([0.20136862, 0.45337066, 0.24664757, ..., 0.        , 0.        ,
        0.        ], dtype=float32),
 array([1, 4, 6, ..., 0, 0, 0], dtype=uint32),
 array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 

In [696]:
adatat, acolst, annzt, ellwat = to_data(a.T)
adatat, acolst, annzt, ellwat

(array([0.8060224 , 0.04902843, 0.06091982, ..., 0.        , 0.        ,
        0.        ], dtype=float32),
 array([2, 3, 6, ..., 0, 0, 0], dtype=uint32),
 array([301, 305, 283, 273, 297, 311, 294, 288], dtype=uint32),
 784)

In [697]:
bdata, bcols, bnnz, ellwb = to_data(b)
bdata, bcols, bnnz, ellwb

(array([0.51228654, 0.89396334, 0.583282  , 0.37977967, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.86082464, 0.29537702, 0.51280856, 0.6324379 , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.16401087, 0.83399713, 0.796908  , 0.98693824, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.12299865, 0.99917537, 0.1828657 , 0.45594448, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.45390037, 0.8534569 , 0.78406423, 0.19776365, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.1423029 , 0.6472351 , 0.31429735, 0.19298612, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.7907769 , 0.712911  , 0.40409806, 0.96907926, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.5134478 , 0.2370818 , 0.86837345, 0.34

In [698]:
bdatat, bcolst, bnnzt, ellwbt = to_data(b.T)
adatat, bcolst, bnnzt, ellwbt

(array([0.8060224 , 0.04902843, 0.06091982, ..., 0.        , 0.        ,
        0.        ], dtype=float32),
 array([0, 1, 2, 3, 5, 6, 0, 0, 2, 4, 7, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0,
        0, 0, 1, 2, 4, 7, 0, 0, 0, 0, 0, 3, 6, 7, 0, 0, 0, 0, 0, 3, 4, 7,
        0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 1, 3, 4, 6, 0, 0, 0, 0, 1, 5,
        0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0], dtype=uint32),
 array([6, 3, 2, 4, 4, 4, 2, 4, 2, 1], dtype=uint32),
 8)

In [699]:
adense = to_dense(adata, acols, annz, ellwa, a.shape)

In [700]:
adenset = to_dense(adatat, acolst, annzt, ellwat, a.T.shape)

In [701]:
bdense = to_dense(bdata, bcols, bnnz, ellwb, b.shape)

In [702]:
bdenset = to_dense(bdatat, bcolst, bnnzt, ellwbt, b.T.shape)

In [703]:
adense

array([[0.        , 0.20136862, 0.        , ..., 0.        , 0.24664757,
        0.        ],
       [0.        , 0.08636539, 0.        , ..., 0.2248708 , 0.        ,
        0.61742616],
       [0.80602241, 0.        , 0.02815218, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.26837969, 0.        , ..., 0.        , 0.        ,
        0.14286678],
       [0.        , 0.        , 0.85706204, ..., 0.2967158 , 0.57457912,
        0.        ],
       [0.        , 0.        , 0.22261621, ..., 0.41236356, 0.54681528,
        0.        ]])

In [704]:
adenset.T == adense

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]])

In [705]:
bdenset.T == bdense

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]])

In [706]:
a

array([[0.        , 0.20136862, 0.        , ..., 0.        , 0.24664757,
        0.        ],
       [0.        , 0.08636539, 0.        , ..., 0.2248708 , 0.        ,
        0.61742616],
       [0.8060224 , 0.        , 0.02815218, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.2683797 , 0.        , ..., 0.        , 0.        ,
        0.14286678],
       [0.        , 0.        , 0.85706204, ..., 0.2967158 , 0.5745791 ,
        0.        ],
       [0.        , 0.        , 0.22261621, ..., 0.41236356, 0.5468153 ,
        0.        ]], dtype=float32)

In [707]:
a == adense

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]])

In [708]:
a.shape

(784, 8)

In [709]:
adata.shape, acols.shape, annz.shape, ellwa

((6272,), (6272,), (784,), 8)

In [710]:
#acols = acols.astype(np.uint32)
#annz = annz.astype(np.uint32)

In [711]:
adata, acols, annz, b

(array([0.20136862, 0.45337066, 0.24664757, ..., 0.        , 0.        ,
        0.        ], dtype=float32),
 array([1, 4, 6, ..., 0, 0, 0], dtype=uint32),
 array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 

## MatMul (Sparse-Dense)

adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)

prg = cl.Program(ctx, """
    // SPARSE x DENSE
    __kernel void matmul2(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            uint ncols,
                            __global  float* vector_x,    // INPUT
                            __global  float* vector_y    // OUTPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);
      

      uint nnz    = rowNnz[gid];
      float sum = 0;
      for (uint gid2 = 0; gid2 < ncols; gid2++) {
        for (uint i = 0; i < nnz; i++) {
          uint index   = (gid * ellwidth) + i;
          uint col     = colIdx[index];
          float aval  = matData[index];
          float xval  = vector_x[col*ncols+gid2];
          //if (gid==0 && gid2==2)
          //  printf("aval, xval: %.2f,%.2f: (%i,%i) \\n", aval, xval, col, index);
          sum  += aval * xval;
        }
        //printf("SUM/NNZ: %.2f %i \\n", sum, nnz);
        vector_y[gid*ncols+gid2] = sum;
      }
    }""").build()

In [712]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)

prg = cl.Program(ctx, """
    // SPARSE x DENSE
    __kernel void matmul2(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            uint ncols,
                            __global  float* vector_x,    // INPUT
                            __global  float* vector_y    // OUTPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);
      

      uint nnz    = rowNnz[gid];
      
      for (uint gid2 = 0; gid2 < ncols; gid2++) {
        float sum = 0;
        for (uint i = 0; i < nnz; i++) {
          uint index   = (gid * ellwidth) + i;
          uint col     = colIdx[index];
          float aval  = matData[index];
          uint xidx = col*ncols+gid2;
          float xval  = vector_x[xidx];
          if (gid==0 && gid2==1)
            printf("aval, xval: %.2f,%.2f: (%i,%i) - %i \\n", aval, xval, col, index, xidx);
          sum  += aval * xval;
        }
        //printf("SUM/NNZ: %.2f %i \\n", sum, nnz);
        vector_y[gid*ncols+gid2] = sum;
      }
    }""").build()

In [35]:
a.shape, b.shape

((32, 64), (64, 10))

In [36]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [37]:
rows = a.shape[0]

In [38]:
mult = mult.astype(np.float32)

In [39]:
outshape = (a.shape[0], b.shape[1])
outshape

(32, 10)

In [40]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod(outshape)*4)
knl = prg.matmul2  # Use this Kernel object for repeated calls
knl(queue, [outshape[0]], None, adata_buf, acols_buf, annzs_buf, np.uint32(ellwa), np.uint32(outshape[1]), b_buf, res_buf)

res_np = np.zeros(outshape).astype(np.float32)
cl.enqueue_copy(queue, res_np, res_buf)

aval, xval: 0.82,0.00: (0,0) - 1 
aval, xval: 0.48,0.00: (3,1) - 31 
aval, xval: 0.91,0.02: (4,2) - 41 
aval, xval: 0.88,0.41: (7,3) - 71 
aval, xval: 0.87,0.00: (11,4) - 111 
aval, xval: 0.26,0.00: (14,5) - 141 
aval, xval: 0.36,0.30: (15,6) - 151 
aval, xval: 0.41,0.00: (17,7) - 171 
aval, xval: 0.62,0.00: (20,8) - 201 
aval, xval: 0.47,0.00: (21,9) - 211 
aval, xval: 0.22,0.00: (28,10) - 281 
aval, xval: 0.59,0.65: (29,11) - 291 
aval, xval: 0.44,0.88: (31,12) - 311 
aval, xval: 0.47,0.55: (32,13) - 321 
aval, xval: 0.69,0.00: (33,14) - 331 
aval, xval: 0.38,0.00: (36,15) - 361 
aval, xval: 0.81,0.03: (37,16) - 371 
aval, xval: 0.98,0.00: (38,17) - 381 


<pyopencl._cl.NannyEvent at 0x7ff154347220>

aval, xval: 0.45,0.00: (39,18) - 391 
aval, xval: 0.29,0.45: (42,19) - 421 
aval, xval: 0.04,0.00: (44,20) - 441 
aval, xval: 0.28,0.15: (46,21) - 461 
aval, xval: 0.42,0.00: (50,22) - 501 
aval, xval: 0.39,0.00: (54,23) - 541 
aval, xval: 0.23,0.65: (61,24) - 611 


In [41]:
(res_np-mult).sum()

5.9604645e-07

In [42]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [43]:
b

array([[0.9414536 , 0.        , 0.        , 0.90892494, 0.        ,
        0.        , 0.05297933, 0.        , 0.7021168 , 0.        ],
       [0.        , 0.        , 0.14630048, 0.24825948, 0.        ,
        0.        , 0.        , 0.7792777 , 0.31232256, 0.        ],
       [0.        , 0.03283854, 0.27791733, 0.8543185 , 0.        ,
        0.        , 0.8433972 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33419287, 0.19840643, 0.        ,
        0.3856511 , 0.23348303, 0.        , 0.        , 0.        ],
       [0.        , 0.01996238, 0.9816661 , 0.        , 0.        ,
        0.        , 0.        , 0.5056456 , 0.        , 0.01917347],
       [0.        , 0.        , 0.        , 0.        , 0.49159616,
        0.0810385 , 0.64290637, 0.16045028, 0.        , 0.        ],
       [0.        , 0.2878976 , 0.        , 0.        , 0.7058473 ,
        0.        , 0.41645932, 0.44598925, 0.        , 0.        ],
       [0.31479523, 0.40940058, 0.       

In [44]:
res_buf

<pyopencl._cl.Buffer at 0x7ff1541fd770>

In [45]:
res_np

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [46]:
mult

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [47]:
res_np==mult

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True, 

In [48]:
res_np.shape

(32, 10)

In [49]:
mult.shape

(32, 10)

## MatMul (dense * sparse)

In [50]:
bdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bdata)
bcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bcols)
bnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bnnz)
bdatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bdatat)
bcolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bcolst)
bnnzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bnnzt)
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a)

prg = cl.Program(ctx, """
    // DENSE x SPARSE
    __kernel void matmul(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            uint   mwidth,
                            uint   ncols,
                            __global  float* vector_x,    // INPUT
                            __global  float* vector_y    // OUTPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);

      for (uint gid2 = 0; gid2 < ncols; gid2++) {
        uint nnz = rowNnz[gid2];
        float sum = 0;
        for (uint i = 0; i < nnz; i++) {
          uint index   = (gid2 * ellwidth) + i;
          uint col     = colIdx[index];
          float aval  = matData[index];
          float xval  = vector_x[gid*mwidth+col];
          sum  += aval * xval;
          if (gid==0 && gid2==0)
            printf("aval, xval: %.2f,%.2f - %.2f: (%i,%i) \\n", aval, xval, sum, col, index);
        }
        //printf("SUM/NNZ: %.2f %i \\n", sum, nnz);
        vector_y[gid*ncols+gid2] = sum;
      }
    }""").build()

In [51]:
a.shape, b.shape

((32, 64), (64, 10))

In [52]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [53]:
rows = a.shape[0]

In [54]:
a.shape, b.shape

((32, 64), (64, 10))

In [55]:
mult = a.dot(b)
mult = mult.astype(np.float32)

In [56]:
outshape = np.array([a.shape[0], b.shape[1]])
outshape

array([32, 10])

In [57]:
b.T

array([[0.9414536 , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.31479523, 0.11005293, 0.06780045,
        0.        , 0.16214237, 0.        , 0.6445239 , 0.50232446,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.17604527, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.42811713, 0.11871076,
        0.        , 0.        , 0.        , 0.6127449 , 0.41137826,
        0.82947165, 0.04573116, 0.        , 0.        , 0.04850706,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.6346003 , 0.        , 0.        , 0.69351566, 0.9864582 ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.2653583 , 0.        , 0.2264393 ,
        0.10093832, 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.03283854, 0.        , 0.01996238,
        0.        , 0.2878976 , 0.40940058, 0.5290864 , 0.0

In [58]:
a.T

array([[0.82272685, 0.57626116, 0.        , ..., 0.8716895 , 0.        ,
        0.        ],
       [0.        , 0.2289813 , 0.48130327, ..., 0.10489579, 0.        ,
        0.6832084 ],
       [0.        , 0.16428949, 0.        , ..., 0.        , 0.        ,
        0.9385753 ],
       ...,
       [0.22595333, 0.27008578, 0.        , ..., 0.        , 0.5132175 ,
        0.02876937],
       [0.        , 0.        , 0.12290299, ..., 0.39542586, 0.9358725 ,
        0.        ],
       [0.        , 0.9652075 , 0.        , ..., 0.        , 0.37984058,
        0.        ]], dtype=float32)

In [59]:
outshape.T

array([32, 10])

In [60]:
b.shape, outshape

((64, 10), array([32, 10]))

In [61]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod(outshape)*4)
knl = prg.matmul  # Use this Kernel object for repeated calls
knl(queue, [outshape.T[0]], None, bdatat_buf, bcolst_buf, bnnzst_buf, np.uint32(ellwbt), np.uint32(b.shape[0]), np.uint32(outshape.T[1]), a_buf, res_buf)

res_np = np.zeros(outshape).astype(np.float32)
cl.enqueue_copy(queue, res_np, res_buf)

<pyopencl._cl.NannyEvent at 0x7ff1543393b0>

aval, xval: 0.94,0.82 - 0.77: (0,0) 
aval, xval: 0.31,0.88 - 1.05: (7,1) 
aval, xval: 0.11,0.00 - 1.05: (8,2) 
aval, xval: 0.07,0.00 - 1.05: (9,3) 
aval, xval: 0.16,0.87 - 1.19: (11,4) 
aval, xval: 0.64,0.00 - 1.19: (13,5) 
aval, xval: 0.50,0.26 - 1.32: (14,6) 
aval, xval: 0.18,0.62 - 1.43: (20,7) 
aval, xval: 0.43,0.22 - 1.53: (28,8) 
aval, xval: 0.12,0.59 - 1.60: (29,9) 
aval, xval: 0.61,0.69 - 2.02: (33,10) 
aval, xval: 0.41,0.00 - 2.02: (34,11) 
aval, xval: 0.83,0.00 - 2.02: (35,12) 
aval, xval: 0.05,0.38 - 2.04: (36,13) 
aval, xval: 0.05,0.45 - 2.06: (39,14) 
aval, xval: 0.63,0.00 - 2.06: (45,15) 
aval, xval: 0.69,0.00 - 2.06: (48,16) 
aval, xval: 0.99,0.00 - 2.06: (49,17) 
aval, xval: 0.27,0.00 - 2.06: (57,18) 
aval, xval: 0.23,0.00 - 2.06: (59,19) 
aval, xval: 0.10,0.00 - 2.06: (60,20) 


In [62]:
(res_np-mult).sum()

5.9604645e-07

In [63]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [64]:
b

array([[0.9414536 , 0.        , 0.        , 0.90892494, 0.        ,
        0.        , 0.05297933, 0.        , 0.7021168 , 0.        ],
       [0.        , 0.        , 0.14630048, 0.24825948, 0.        ,
        0.        , 0.        , 0.7792777 , 0.31232256, 0.        ],
       [0.        , 0.03283854, 0.27791733, 0.8543185 , 0.        ,
        0.        , 0.8433972 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33419287, 0.19840643, 0.        ,
        0.3856511 , 0.23348303, 0.        , 0.        , 0.        ],
       [0.        , 0.01996238, 0.9816661 , 0.        , 0.        ,
        0.        , 0.        , 0.5056456 , 0.        , 0.01917347],
       [0.        , 0.        , 0.        , 0.        , 0.49159616,
        0.0810385 , 0.64290637, 0.16045028, 0.        , 0.        ],
       [0.        , 0.2878976 , 0.        , 0.        , 0.7058473 ,
        0.        , 0.41645932, 0.44598925, 0.        , 0.        ],
       [0.31479523, 0.40940058, 0.       

In [65]:
res_buf

<pyopencl._cl.Buffer at 0x7ff15433f0e0>

In [66]:
res_np

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [67]:
mult

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [68]:
res_np==mult

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True, 

In [69]:
res_np-mult

array([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
   

In [70]:
res_np.shape

(32, 10)

In [71]:
mult.shape

(32, 10)

## MatMul2 (dense * sparse)

wdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wdata)
wcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wcols)
wnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wnnz)
wdatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wdatat)
wcolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wcolst)
wnnzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wnnzt)
x_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=x_init)

prg = cl.Program(ctx, """
    // DENSE x SPARSE-T
    __kernel void matmul2(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            uint   mwidth,
                            uint   ncols,
                            __global  float* vector_x,    // INPUT
                            __global  float* vector_y    // OUTPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);
      uint nnz = rowNnz[gid];

      for (uint gid2 = 0; gid2 < ncols; gid2++) {
        float sum = 0;
        for (uint i = 0; i < nnz; i++) {
          uint index   = (gid * ellwidth) + i;
          uint col     = colIdx[index];
          float aval  = matData[index];
          float xval  = vector_x[gid2*ncols+col];
          sum  += aval * xval;
          if (gid==0 && gid2==1)
            printf("aval, xval: %.2f,%.2f - %.2f: (%i,%i) \\n", aval, xval, sum, col, index);
        }
        //printf("SUM/NNZ: %.2f %i \\n", sum, nnz);
        vector_y[gid2*ncols+gid] = sum;
      }
    }""").build()

In [72]:
wdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wdata)
wcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wcols)
wnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wnnz)
wdatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wdatat)
wcolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wcolst)
wnnzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=wnnzt)
x_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=x_init)

prg = cl.Program(ctx, """
    // DENSE x SPARSE-T
    __kernel void matmul2(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            uint   mwidth,
                            uint   ncols0,
                            __global  float* vector_x,    // INPUT
                            __global  float* vector_y    // OUTPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);
      uint nnz = rowNnz[gid];
      uint gid2 = get_global_id(1);
      uint ncols = get_global_size(1);

      float sum = 0;
      for (uint i = 0; i < nnz; i++) {
        uint index   = (gid2 * ellwidth) + i;
        uint col     = colIdx[index];
        float aval  = matData[index];
        float xval  = vector_x[gid*mwidth+col];
        sum  += aval * xval;
        if (gid==1 && gid2==0) {
          printf("aval, xval: %.2f,%.2f - %.2f: (%i,%i) \\n", aval, xval, sum, col, index);
        }
      }
      //printf("SUM/NNZ: %.2f %i \\n", sum, nnz);
      vector_y[gid*ncols+gid2] = sum;
    }""").build()

In [73]:
outshape

array([32, 10])

In [74]:
w_init.shape, x_init.shape
w = w_init
x = x_init

In [75]:
res = np.zeros(w.shape[0]).astype(np.float32)
#res

In [76]:
rows = w.shape[0]

In [77]:
mult = mult.astype(np.float32)

In [78]:
outshape = np.array([x.shape[0], w.shape[0]])
outshape

array([32, 64])

In [79]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod(outshape)*4)
knl = prg.matmul2  # Use this Kernel object for repeated calls
knl(queue, outshape, None, wdata_buf, wcols_buf, wnnzs_buf, np.uint32(ellww), np.uint32(w.shape[1]), np.uint32(x.shape[1]), x_buf, res_buf)

res_np = np.zeros(outshape).astype(np.float32)
cl.enqueue_copy(queue, res_np, res_buf)

aval, xval: 0.37,0.64 - 0.23: (0,0) 
aval, xval: -0.03,1.74 - 0.17: (1,1) 
aval, xval: 0.08,0.30 - 0.20: (2,2) 
aval, xval: -0.21,0.71 - 0.05: (3,3) 
aval, xval: -0.50,1.82 - -0.86: (4,4) 
aval, xval: 0.30,0.43 - -0.73: (5,5) 
aval, xval: -0.18,1.54 - -1.01: (6,6) 
aval, xval: -1.70,-0.90 - 0.53: (7,7) 
aval, xval: -0.55,-0.14 - 0.60: (8,8) 
aval, xval: 0.30,1.30 - 0.99: (9,9) 


<pyopencl._cl.NannyEvent at 0x7ff1542db1d0>

In [80]:
mult = x.dot(w_init.T)
mult.shape

(32, 64)

In [81]:
mult

array([[ 1.0198762 , -2.8557298 , -0.7232425 , ..., -3.2115216 ,
         2.913457  , -2.3689373 ],
       [ 0.9885826 ,  5.6755376 , -3.0078568 , ...,  4.609541  ,
        -4.1741986 ,  5.5915403 ],
       [-1.0711612 ,  3.290888  ,  1.636742  , ..., -0.17965953,
         2.1418145 , -2.2195456 ],
       ...,
       [-1.1189016 , -5.7348213 ,  0.99644226, ..., -1.9031762 ,
        -0.62815833, -0.45603287],
       [-1.8187716 ,  1.9929719 ,  0.54578674, ...,  6.7044945 ,
        -5.146358  ,  0.7609284 ],
       [ 1.6192342 , -0.93305635, -3.1302233 , ...,  3.6590157 ,
         3.3437638 , -4.381602  ]], dtype=float32)

In [82]:
res_np

array([[ 1.0198761 , -2.8557298 , -0.7232425 , ..., -3.2115214 ,
         2.913457  , -2.3689373 ],
       [ 0.9885827 ,  5.6755376 , -3.0078568 , ...,  4.609541  ,
        -4.174198  ,  5.5915403 ],
       [-1.0711612 ,  3.290888  ,  1.636742  , ..., -0.17965946,
         2.1418145 , -2.2195456 ],
       ...,
       [-1.1189016 , -5.734821  ,  0.99644226, ..., -1.9031763 ,
        -0.62815833, -0.456033  ],
       [-1.8187717 ,  1.992972  ,  0.54578674, ...,  6.704494  ,
        -5.146358  ,  0.76092845],
       [ 1.6192341 , -0.93305606, -3.130223  , ...,  3.6590157 ,
         3.3437638 , -4.3816023 ]], dtype=float32)

In [83]:
x

array([[ 1.10855466e-03, -2.89544076e-01, -1.11606634e+00,
        -1.28827570e-02, -3.78361464e-01, -4.81135368e-01,
        -1.51733112e+00, -4.90871996e-01, -2.40680575e-01,
        -6.47947431e-01],
       [ 6.35891080e-01,  1.74011731e+00,  2.96682209e-01,
         7.07503676e-01,  1.82281578e+00,  4.30769026e-01,
         1.54272962e+00, -9.00721192e-01, -1.37125015e-01,
         1.29757905e+00],
       [ 6.75271153e-01,  3.19581181e-02,  9.18145895e-01,
         3.80509466e-01,  5.16367495e-01, -3.55239451e-01,
         2.08776996e-01,  3.28411072e-01, -4.98224765e-01,
        -2.09177685e+00],
       [-8.25877413e-02,  2.45518255e+00, -2.67211008e+00,
        -9.13279295e-01, -2.27314353e-01,  2.69315392e-01,
         1.13046122e+00,  1.04239750e+00,  1.30381048e+00,
         1.38940072e+00],
       [-6.56452596e-01, -5.62572964e-02, -4.99902606e-01,
         4.36419368e-01, -3.75813037e-01, -9.23061609e-01,
         1.91725028e+00, -1.50302842e-01, -6.38729751e-01,
         8.

In [84]:
w

array([[ 0.36527893, -0.03430624,  0.07568537, -0.20548421, -0.49928266,
         0.29659814, -0.17751434, -1.7038512 , -0.5477791 ,  0.29694587],
       [-0.3838112 ,  0.6774855 ,  1.7037388 , -1.3552694 ,  3.2101295 ,
        -0.47412255, -0.1520954 ,  0.5938055 , -1.3909731 ,  0.09679974],
       [ 0.02328284, -0.52373195,  0.2959063 , -0.6676638 ,  0.04768714,
        -1.580838  ,  0.5743469 ,  0.93183094,  1.7980465 , -0.7189207 ],
       [ 0.61396384, -0.26823863,  0.93568426, -1.5383533 , -0.6345008 ,
        -0.0785874 ,  1.5962073 , -1.6852865 ,  0.72588986, -0.04414608],
       [ 0.3679531 , -0.54433686, -1.5003314 ,  1.275841  , -1.0670458 ,
         0.15271302, -1.042708  ,  0.26798233,  1.182636  ,  0.3193706 ],
       [-0.62229687, -0.12888767,  0.35747743, -1.4602345 ,  0.5396376 ,
         0.1378612 , -0.30455658,  0.5800385 ,  0.5298659 ,  0.02696081],
       [-0.19980964, -0.5811444 ,  0.41489   ,  1.6072612 ,  0.5776702 ,
         0.74300027, -1.2662393 , -1.53522   

In [85]:
(res_np-mult).sum()

-4.2980537e-06

In [86]:
mult

array([[ 1.0198762 , -2.8557298 , -0.7232425 , ..., -3.2115216 ,
         2.913457  , -2.3689373 ],
       [ 0.9885826 ,  5.6755376 , -3.0078568 , ...,  4.609541  ,
        -4.1741986 ,  5.5915403 ],
       [-1.0711612 ,  3.290888  ,  1.636742  , ..., -0.17965953,
         2.1418145 , -2.2195456 ],
       ...,
       [-1.1189016 , -5.7348213 ,  0.99644226, ..., -1.9031762 ,
        -0.62815833, -0.45603287],
       [-1.8187716 ,  1.9929719 ,  0.54578674, ...,  6.7044945 ,
        -5.146358  ,  0.7609284 ],
       [ 1.6192342 , -0.93305635, -3.1302233 , ...,  3.6590157 ,
         3.3437638 , -4.381602  ]], dtype=float32)

In [87]:
res_np==mult

array([[False,  True,  True, ..., False,  True,  True],
       [False,  True,  True, ...,  True, False,  True],
       [ True,  True,  True, ..., False,  True,  True],
       ...,
       [ True, False,  True, ..., False,  True, False],
       [False, False,  True, ..., False,  True, False],
       [False, False, False, ...,  True,  True, False]])

In [88]:
res_np-mult

array([[-1.1920929e-07,  0.0000000e+00,  0.0000000e+00, ...,
         2.3841858e-07,  0.0000000e+00,  0.0000000e+00],
       [ 5.9604645e-08,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  4.7683716e-07,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         7.4505806e-08,  0.0000000e+00,  0.0000000e+00],
       ...,
       [ 0.0000000e+00,  4.7683716e-07,  0.0000000e+00, ...,
        -1.1920929e-07,  0.0000000e+00, -1.1920929e-07],
       [-1.1920929e-07,  1.1920929e-07,  0.0000000e+00, ...,
        -4.7683716e-07,  0.0000000e+00,  5.9604645e-08],
       [-1.1920929e-07,  2.9802322e-07,  2.3841858e-07, ...,
         0.0000000e+00,  0.0000000e+00, -4.7683716e-07]], dtype=float32)

In [89]:
res_np.shape

(32, 64)

In [90]:
mult.shape

(32, 64)

In [91]:
asdf

NameError: name 'asdf' is not defined

## MatMul (dense * sparse) NEW

bdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bdata)
bcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bcols)
bnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bnnz)
bdatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bdatat)
bcolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bcolst)
bnnzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bnnzt)
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a)

prg = cl.Program(ctx, """
    // DENSE x SPARSE
    __kernel void matmulnew(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            uint   mwidth,
                            __global  float* vector_x,    // INPUT
                            __global  float* vector_y    // OUTPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);
      uint gid2 = get_global_id(1);
      uint ncols = get_global_size(1);
      uint nnz = rowNnz[gid2];
      float sum = 0;
      for (uint i = 0; i < nnz; i++) {
        uint index   = (gid2 * ellwidth) + i;
        uint col     = colIdx[index];
        float aval  = matData[index];
        float xval  = vector_x[gid*mwidth+col];
        vector_y[gid2*nrows+gid] += aval * xval;
        if (gid==0 && gid2==0)
          printf("aval, xval: %.2f,%.2f - %.2f: (%i,%i) \\n", aval, xval, sum, col, index);
        //printf("SUM/NNZ: %.2f %i \\n", sum, nnz);
        
      }
      
    }""").build()

In [None]:
mult = a.dot(b)
mult

In [None]:
bdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=data)
bcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bcols)
bnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bnnz)
bdatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bdatat)
bcolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bcolst)
bnnzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bnnzt)
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a)

prg = cl.Program(ctx, """
    // DENSE x SPARSE
    __kernel void matmulnew(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            uint   mwidth,
                            __global  float* vector_x,    // INPUT
                            __global  float* vector_y    // OUTPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);
      uint gid2 = get_global_id(1);
      uint ncols = get_global_size(1);
      uint nnz = rowNnz[gid2];
      float sum = 0;
      for (uint i = 0; i < nnz; i++) {
        uint index   = (gid2 * ellwidth) + i;
        uint col     = colIdx[index];
        float aval  = matData[index];
        float xval  = vector_x[gid*mwidth+col];
        sum  += aval * xval;
        if (gid==1 && gid2==0)
          printf("aval, xval: %.2f,%.2f - %.2f: (%i,%i) \\n", aval, xval, sum, col, index);
        //printf("SUM/NNZ: %.2f %i \\n", sum, nnz);
      }
      vector_y[gid2*ncols+gid] = sum;
    }""").build()

In [None]:
a.shape, b.shape

In [92]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [93]:
rows = a.shape[0]

In [94]:
mult = mult.astype(np.float32)

In [95]:
outshape = np.array([a.shape[0], b.shape[1]])
outshape

array([32, 10])

In [96]:
b.T

array([[0.9414536 , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.31479523, 0.11005293, 0.06780045,
        0.        , 0.16214237, 0.        , 0.6445239 , 0.50232446,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.17604527, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.42811713, 0.11871076,
        0.        , 0.        , 0.        , 0.6127449 , 0.41137826,
        0.82947165, 0.04573116, 0.        , 0.        , 0.04850706,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.6346003 , 0.        , 0.        , 0.69351566, 0.9864582 ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.2653583 , 0.        , 0.2264393 ,
        0.10093832, 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.03283854, 0.        , 0.01996238,
        0.        , 0.2878976 , 0.40940058, 0.5290864 , 0.0

In [97]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [98]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod(outshape)*4)
knl = prg.matmulnew  # Use this Kernel object for repeated calls
knl(queue, outshape, None, bdatat_buf, bcolst_buf, bnnzst_buf, np.uint32(ellwbt), np.uint32(b.shape[0]), a_buf, res_buf)

res_np = np.zeros((outshape[0],)).astype(np.float32)
print(res_np.shape)
cl.enqueue_copy(queue, res_np, res_buf)

AttributeError: 'matmulnew' was not found as a program info attribute or as a kernel name

In [99]:
(res_np-mult.T).sum()

ValueError: operands could not be broadcast together with shapes (32,64) (64,32) 

In [100]:
res_buf

<pyopencl._cl.Buffer at 0x7ff154347630>

In [101]:
res_np.T

array([[ 1.0198761 ,  0.9885827 , -1.0711612 , ..., -1.1189016 ,
        -1.8187717 ,  1.6192341 ],
       [-2.8557298 ,  5.6755376 ,  3.290888  , ..., -5.734821  ,
         1.992972  , -0.93305606],
       [-0.7232425 , -3.0078568 ,  1.636742  , ...,  0.99644226,
         0.54578674, -3.130223  ],
       ...,
       [-3.2115214 ,  4.609541  , -0.17965946, ..., -1.9031763 ,
         6.704494  ,  3.6590157 ],
       [ 2.913457  , -4.174198  ,  2.1418145 , ..., -0.62815833,
        -5.146358  ,  3.3437638 ],
       [-2.3689373 ,  5.5915403 , -2.2195456 , ..., -0.456033  ,
         0.76092845, -4.3816023 ]], dtype=float32)

In [102]:
mult

array([[ 1.0198762 , -2.8557298 , -0.7232425 , ..., -3.2115216 ,
         2.913457  , -2.3689373 ],
       [ 0.9885826 ,  5.6755376 , -3.0078568 , ...,  4.609541  ,
        -4.1741986 ,  5.5915403 ],
       [-1.0711612 ,  3.290888  ,  1.636742  , ..., -0.17965953,
         2.1418145 , -2.2195456 ],
       ...,
       [-1.1189016 , -5.7348213 ,  0.99644226, ..., -1.9031762 ,
        -0.62815833, -0.45603287],
       [-1.8187716 ,  1.9929719 ,  0.54578674, ...,  6.7044945 ,
        -5.146358  ,  0.7609284 ],
       [ 1.6192342 , -0.93305635, -3.1302233 , ...,  3.6590157 ,
         3.3437638 , -4.381602  ]], dtype=float32)

In [103]:
res_np-mult.T

ValueError: operands could not be broadcast together with shapes (32,64) (64,32) 

In [104]:
res_np.shape

(32, 64)

In [105]:
mult.shape

(32, 64)

In [106]:
asdf

NameError: name 'asdf' is not defined

# Matmult Dense Dense

In [107]:
b_buf2 = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a)

prg = cl.Program(ctx, """
    // multilplies x by y WITH Y TRANSPOSED INDEXING
    __kernel void matmul0(__global  float* x,     // INPUT MATRIX DATA
                          __global  float* y,    // INPUT
                          __global  float* res,    // INPUT
                          uint msize
                          ) { // LOCAL SHARED BUFFER
      uint isize = get_global_size(0);
      uint osize = get_global_size(1);
      int gidx = get_global_id(0); // row
      int gidy = get_global_id(1); // col

      float ret = 0.0;
      for (int i = 0; i < msize; i++) {
        uint xidx = gidx*msize+i; 
        float xval = x[xidx];
        uint yidx = osize*i+gidy;
        float yval = y[yidx];
        ret += xval*yval;
        if (gidx==0 && gidy==0)
          printf("\\nmult: %.2f x %.2f - %.2f  -- %i/%i", xval, yval, res, xidx, yidx);
      }

      //if (gidx==0&&gidy==0)
      //  printf("\\nsum:%.2f", ret);
      res[gidx * osize + gidy] = ret;
    }""").build()

In [108]:
a.shape, b.shape

((32, 64), (64, 10))

In [109]:
rows = a.shape[0]

In [110]:
mult = mult.astype(np.float32)

In [111]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
knl = prg.matmul0  # Use this Kernel object for repeated calls
knl(queue, [rows,b.shape[1]], None, a_buf, b_buf2, res_buf, np.uint32(a.shape[1]))

res_np = np.zeros([rows,b.shape[1]]).astype(np.float32)
cl.enqueue_copy(queue, res_np, res_buf)


mult: 0.82 x 0.94 - 0.00  -- 0/0
mult: 0.00 x 0.00 - 0.00  -- 1/10
mult: 0.00 x 0.00 - 0.00  -- 2/20
mult: 0.48 x 0.00 - 0.00  -- 3/30
mult: 0.91 x 0.00 - 0.00  -- 4/40
mult: 0.00 x 0.00 - 0.00  -- 5/50
mult: 0.00 x 0.00 - 0.00  -- 6/60
mult: 0.88 x 0.31 - 0.00  -- 7/70
mult: 0.00 x 0.11 - 0.00  -- 8/80
mult: 0.00 x 0.07 - 0.00  -- 9/90
mult: 0.00 x 0.00 - 0.00  -- 10/100
mult: 0.87 x 0.16 - 0.00  -- 11/110
mult: 0.00 x 0.00 - 0.00  -- 12/120
mult: 0.00 x 0.64 - 0.00  -- 13/130
mult: 0.26 x 0.50 - 0.00  -- 14/140
mult: 0.36 x 0.00 - 0.00  -- 15/150
mult: 0.00 x 0.00 - 0.00  -- 16/160
mult: 0.41 x 0.00 - 0.00  -- 17/170
mult: 0.00 x 0.00 - 0.00  -- 18/180
mult: 0.00 x 0.00 - 0.00  -- 19/190
mult: 0.62 x 0.18 - 0.00  -- 20/200
mult: 0.47 x 0.00 - 0.00  -- 21/210
mult: 0.00 x 0.00 - 0.00  -- 22/220
mult: 0.00 x 0.00 - 0.00  -- 23/230
mult: 0.00 x 0.00 - 0.00  -- 24/240
mult: 0.00 x 0.00 - 0.00  -- 25/250
mult: 0.00 x 0.00 - 0.00  -- 26/260
mult: 0.00 x 0.00 - 0.00  -- 27/270
mult: 0.22 x

<pyopencl._cl.NannyEvent at 0x7ff15428a9f0>

 0.43 - 0.00  -- 28/280
mult: 0.59 x 0.12 - 0.00  -- 29/290
mult: 0.00 x 0.00 - 0.00  -- 30/300
mult: 0.44 x 0.00 - 0.00  -- 31/310
mult: 0.47 x 0.00 - 0.00  -- 32/320
mult: 0.69 x 0.61 - 0.00  -- 33/330
mult: 0.00 x 0.41 - 0.00  -- 34/340
mult: 0.00 x 0.83 - 0.00  -- 35/350
mult: 0.38 x 0.05 - 0.00  -- 36/360
mult: 0.81 x 0.00 - 0.00  -- 37/370
mult: 0.98 x 0.00 - 0.00  -- 38/380
mult: 0.45 x 0.05 - 0.00  -- 39/390
mult: 0.00 x 0.00 - 0.00  -- 40/400
mult: 0.00 x 0.00 - 0.00  -- 41/410
mult: 0.29 x 0.00 - 0.00  -- 42/420
mult: 0.00 x 0.00 - 0.00  -- 43/430
mult: 0.04 x 0.00 - 0.00  -- 44/440
mult: 0.00 x 0.63 - 0.00  -- 45/450
mult: 0.28 x 0.00 - 0.00  -- 46/460
mult: 0.00 x 0.00 - 0.00  -- 47/470
mult: 0.00 x 0.69 - 0.00  -- 48/480
mult: 0.00 x 0.99 - 0.00  -- 49/490
mult: 0.42 x 0.00 - 0.00  -- 50/500
mult: 0.00 x 0.00 - 0.00  -- 51/510
mult: 0.00 x 0.00 - 0.00  -- 52/520
mult: 0.00 x 0.00 - 0.00  -- 53/530
mult: 0.39 x 0.00 - 0.00  -- 54/540
mult: 0.00 x 0.00 - 0.00  -- 55/550
mult

In [112]:
(res_np-mult.T).sum()

ValueError: operands could not be broadcast together with shapes (32,10) (64,32) 

In [113]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [114]:
b

array([[0.9414536 , 0.        , 0.        , 0.90892494, 0.        ,
        0.        , 0.05297933, 0.        , 0.7021168 , 0.        ],
       [0.        , 0.        , 0.14630048, 0.24825948, 0.        ,
        0.        , 0.        , 0.7792777 , 0.31232256, 0.        ],
       [0.        , 0.03283854, 0.27791733, 0.8543185 , 0.        ,
        0.        , 0.8433972 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33419287, 0.19840643, 0.        ,
        0.3856511 , 0.23348303, 0.        , 0.        , 0.        ],
       [0.        , 0.01996238, 0.9816661 , 0.        , 0.        ,
        0.        , 0.        , 0.5056456 , 0.        , 0.01917347],
       [0.        , 0.        , 0.        , 0.        , 0.49159616,
        0.0810385 , 0.64290637, 0.16045028, 0.        , 0.        ],
       [0.        , 0.2878976 , 0.        , 0.        , 0.7058473 ,
        0.        , 0.41645932, 0.44598925, 0.        , 0.        ],
       [0.31479523, 0.40940058, 0.       

In [115]:
res_np

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [116]:
a.dot(b)

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [117]:
res_np==mult

  res_np==mult


False

In [118]:
res_np.shape

(32, 10)

In [119]:
mult.shape

(32, 64)

# Matmult Dense Transposed

In [120]:
b

array([[0.9414536 , 0.        , 0.        , 0.90892494, 0.        ,
        0.        , 0.05297933, 0.        , 0.7021168 , 0.        ],
       [0.        , 0.        , 0.14630048, 0.24825948, 0.        ,
        0.        , 0.        , 0.7792777 , 0.31232256, 0.        ],
       [0.        , 0.03283854, 0.27791733, 0.8543185 , 0.        ,
        0.        , 0.8433972 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33419287, 0.19840643, 0.        ,
        0.3856511 , 0.23348303, 0.        , 0.        , 0.        ],
       [0.        , 0.01996238, 0.9816661 , 0.        , 0.        ,
        0.        , 0.        , 0.5056456 , 0.        , 0.01917347],
       [0.        , 0.        , 0.        , 0.        , 0.49159616,
        0.0810385 , 0.64290637, 0.16045028, 0.        , 0.        ],
       [0.        , 0.2878976 , 0.        , 0.        , 0.7058473 ,
        0.        , 0.41645932, 0.44598925, 0.        , 0.        ],
       [0.31479523, 0.40940058, 0.       

In [121]:
c=np.zeros(b.T.shape)
bt = b.T
for row in range(bt.shape[0]):
    for col in range(bt.shape[1]):
        c[row][col] = bt[row][col]

In [122]:
b_buf2 = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a)

prg = cl.Program(ctx, """
    // multilplies x by y WITH Y TRANSPOSED INDEXING
    __kernel void matmul0(__global  float* x,     // INPUT MATRIX DATA
                          __global  float* y,    // INPUT
                          __global  float* res,    // INPUT
                          uint msize
                          ) { // LOCAL SHARED BUFFER
      uint isize = get_global_size(0);
      uint osize = get_global_size(1);
      int gidx = get_global_id(0); // row
      int gidy = get_global_id(1); // col

      float ret = 0.0;
      for (int i = 0; i < msize; i++) {
        uint xidx = gidx*msize+i;
        float xval = x[xidx];
        uint yidx = msize*gidy+i;
        float yval = y[yidx];
        ret += xval*yval;
        if (gidx==0 && gidy==0)
          printf("\\nmult: %.2f x %.2f - %.2f  -- %i/%i", xval, yval, res, xidx, yidx);
      }

      //if (gidx==0&&gidy==0)
      //  printf("\\nsum:%.2f", ret);
      res[gidx * osize + gidy] = ret;
    }""").build()

In [123]:
a.shape, b.T.shape

((32, 64), (10, 64))

In [124]:
rows = a.shape[0]

In [125]:
mult = mult.astype(np.float32)

In [126]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
knl = prg.matmul0  # Use this Kernel object for repeated calls
knl(queue, [rows,b.shape[1]], None, a_buf, b_buf2, res_buf, np.uint32(a.shape[1]))

res_np = np.zeros([rows,b.shape[1]]).astype(np.float32)
cl.enqueue_copy(queue, res_np, res_buf)

<pyopencl._cl.NannyEvent at 0x7ff15428a770>


mult: 0.82 x 0.94 - 0.00  -- 0/0
mult: 0.00 x 0.00 - 0.00  -- 1/1
mult: 0.00 x 0.00 - 0.00  -- 2/2
mult: 0.48 x 0.00 - 0.00  -- 3/3
mult: 0.91 x 0.00 - 0.00  -- 4/4
mult: 0.00 x 0.00 - 0.00  -- 5/5
mult: 0.00 x 0.00 - 0.00  -- 6/6
mult: 0.88 x 0.31 - 0.00  -- 7/7
mult: 0.00 x 0.11 - 0.00  -- 8/8
mult: 0.00 x 0.07 - 0.00  -- 9/9
mult: 0.00 x 0.00 - 0.00  -- 10/10
mult: 0.87 x 0.16 - 0.00  -- 11/11
mult: 0.00 x 0.00 - 0.00  -- 12/12
mult: 0.00 x 0.64 - 0.00  -- 13/13
mult: 0.26 x 0.50 - 0.00  -- 14/14
mult: 0.36 x 0.00 - 0.00  -- 15/15
mult: 0.00 x 0.00 - 0.00  -- 16/16
mult: 0.41 x 0.00 - 0.00  -- 17/17
mult: 0.00 x 0.00 - 0.00  -- 18/18
mult: 0.00 x 0.00 - 0.00  -- 19/19
mult: 0.62 x 0.18 - 0.00  -- 20/20
mult: 0.47 x 0.00 - 0.00  -- 21/21
mult: 0.00 x 0.00 - 0.00  -- 22/22
mult: 0.00 x 0.00 - 0.00  -- 23/23
mult: 0.00 x 0.00 - 0.00  -- 24/24
mult: 0.00 x 0.00 - 0.00  -- 25/25
mult: 0.00 x 0.00 - 0.00  -- 26/26
mult: 0.00 x 0.00 - 0.00  -- 27/27
mult: 0.22 x 0.43 - 0.00  -- 28/28
mult

In [127]:
(res_np-mult).sum()

ValueError: operands could not be broadcast together with shapes (32,10) (32,64) 

In [128]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [129]:
b

array([[0.9414536 , 0.        , 0.        , 0.90892494, 0.        ,
        0.        , 0.05297933, 0.        , 0.7021168 , 0.        ],
       [0.        , 0.        , 0.14630048, 0.24825948, 0.        ,
        0.        , 0.        , 0.7792777 , 0.31232256, 0.        ],
       [0.        , 0.03283854, 0.27791733, 0.8543185 , 0.        ,
        0.        , 0.8433972 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33419287, 0.19840643, 0.        ,
        0.3856511 , 0.23348303, 0.        , 0.        , 0.        ],
       [0.        , 0.01996238, 0.9816661 , 0.        , 0.        ,
        0.        , 0.        , 0.5056456 , 0.        , 0.01917347],
       [0.        , 0.        , 0.        , 0.        , 0.49159616,
        0.0810385 , 0.64290637, 0.16045028, 0.        , 0.        ],
       [0.        , 0.2878976 , 0.        , 0.        , 0.7058473 ,
        0.        , 0.41645932, 0.44598925, 0.        , 0.        ],
       [0.31479523, 0.40940058, 0.       

In [130]:
res_np

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [131]:
a.dot(b)

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [132]:
res_np==mult

  res_np==mult


False

In [133]:
res_np.shape

(32, 10)

In [134]:
mult.shape

(32, 64)

# Matmult Transposed Dense

In [135]:
mult = a.dot(b)

In [136]:
c=np.zeros(a.T.shape)
at = a.T
for row in range(at.shape[0]):
    for col in range(at.shape[1]):
        c[row][col] = at[row][col]

In [137]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)

prg = cl.Program(ctx, """
    // multilplies x TRANSPOSED by y (dense-dense)
    __kernel void matmul0(__global  float* x,     // INPUT MATRIX DATA
                          __global  float* y,    // INPUT
                          __global  float* res,    // INPUT
                          uint msize,
                          uint isize
                          ) { // LOCAL SHARED BUFFER
      uint osize = get_global_size(0);
      int gidy = get_global_id(0); // row
      
      for (uint gidx = 0; gidx < isize; gidx++) {
        float ret = 0.0;
        for (uint i = 0; i < msize; i++) {
          uint xidx = i*isize+gidx;
          float xval = x[xidx];
          uint yidx = osize*i+gidy;
          float yval = y[yidx];
          ret += xval*yval;
          if (gidx==0 && gidy==0)
            printf("\\nmult: %.2f x %.2f - %.2f  -- %i/%i", xval, yval, ret, xidx, yidx);
        }
        //if (gidx==0&&gidy==0)
        //  printf("\\nsum:%.2f", ret);
        res[gidx * osize + gidy] = ret;
      }
    }""").build()

In [138]:
a.shape, b.shape

((32, 64), (64, 10))

In [139]:
rows = a.shape[0]

In [140]:
mult = mult.astype(np.float32)
mult.shape

(32, 10)

In [141]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
knl = prg.matmul0  # Use this Kernel object for repeated calls
knl(queue, [b.shape[1]], None, a_buf, b_buf, res_buf, np.uint32(a.shape[1]), np.uint32(rows))

res_np = np.zeros([rows,b.shape[1]]).astype(np.float32)
cl.enqueue_copy(queue, res_np, res_buf)

<pyopencl._cl.NannyEvent at 0x7ff1542894f0>


mult: 0.82 x 0.94 - 0.77  -- 0/0
mult: 0.00 x 0.00 - 0.77  -- 32/10
mult: 0.00 x 0.00 - 0.77  -- 64/20
mult: 0.48 x 0.00 - 0.77  -- 96/30
mult: 0.91 x 0.00 - 0.77  -- 128/40
mult: 0.00 x 0.00 - 0.77  -- 160/50
mult: 0.00 x 0.00 - 0.77  -- 192/60
mult: 0.88 x 0.31 - 1.05  -- 224/70
mult: 0.00 x 0.11 - 1.05  -- 256/80
mult: 0.00 x 0.07 - 1.05  -- 288/90
mult: 0.00 x 0.00 - 1.05  -- 320/100
mult: 0.87 x 0.16 - 1.19  -- 352/110
mult: 0.00 x 0.00 - 1.19  -- 384/120
mult: 0.00 x 0.64 - 1.19  -- 416/130
mult: 0.26 x 0.50 - 1.32  -- 448/140
mult: 0.36 x 0.00 - 1.32  -- 480/150
mult: 0.00 x 0.00 - 1.32  -- 512/160
mult: 0.41 x 0.00 - 1.32  -- 544/170
mult: 0.00 x 0.00 - 1.32  -- 576/180
mult: 0.00 x 0.00 - 1.32  -- 608/190
mult: 0.62 x 0.18 - 1.43  -- 640/200
mult: 0.47 x 0.00 - 1.43  -- 672/210
mult: 0.00 x 0.00 - 1.43  -- 704/220
mult: 0.00 x 0.00 - 1.43  -- 736/230
mult: 0.00 x 0.00 - 1.43  -- 768/240
mult: 0.00 x 0.00 - 1.43  -- 800/250
mult: 0.00 x 0.00 - 1.43  -- 832/260
mult: 0.00 x 0.0

In [142]:
(res_np-mult).sum()

5.9604645e-07

In [143]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [144]:
b

array([[0.9414536 , 0.        , 0.        , 0.90892494, 0.        ,
        0.        , 0.05297933, 0.        , 0.7021168 , 0.        ],
       [0.        , 0.        , 0.14630048, 0.24825948, 0.        ,
        0.        , 0.        , 0.7792777 , 0.31232256, 0.        ],
       [0.        , 0.03283854, 0.27791733, 0.8543185 , 0.        ,
        0.        , 0.8433972 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33419287, 0.19840643, 0.        ,
        0.3856511 , 0.23348303, 0.        , 0.        , 0.        ],
       [0.        , 0.01996238, 0.9816661 , 0.        , 0.        ,
        0.        , 0.        , 0.5056456 , 0.        , 0.01917347],
       [0.        , 0.        , 0.        , 0.        , 0.49159616,
        0.0810385 , 0.64290637, 0.16045028, 0.        , 0.        ],
       [0.        , 0.2878976 , 0.        , 0.        , 0.7058473 ,
        0.        , 0.41645932, 0.44598925, 0.        , 0.        ],
       [0.31479523, 0.40940058, 0.       

In [145]:
res_np

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [146]:
a.dot(b)

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [147]:
res_np==mult

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True, 

In [148]:
res_np.shape

(32, 10)

In [149]:
mult.shape

(32, 10)

# Matmult Transposed Dense (SPR) - Get Topk - NEW

In [789]:
topkx = 10
topky = 784

In [790]:
mult.shape

(784, 10)

In [822]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)
x_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, a.shape[0]*4)
y_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.shape[1]*4)
xs_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topkx*4)
ys_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topky*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topkx*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topky*4)

prg0 = cl.Program(ctx, """
    // multilplies x TRANSPOSED by y (dense-dense)
    __kernel void gettopkx(__global  float* x,      // INPUT MATRIX DATA
                          __global  float* xsum,    // INPUT
                          __global  uint*  youtidx, // OUT
                          uint topky,
                          uint msize
                          ) { // LOCAL SHARED BUFFER
      uint isize = get_global_size(0);
      int gidx = get_global_id(0); // row
      
      // get topk
      xsum[gidx] = 0;
      for (uint i=0; i<msize; i++) {
        float val = x[i*isize+gidx];
        //if (gid == 0) {
        //  printf("\\nADD VALx: %.2f - %i", val, i*msize+gid);
        //}
        xsum[gidx] += val;
      }
      
      float valx = xsum[gidx];
      uint posx = 0;
      for (uint i = 0; i < isize; i++) {
        float tempval = fabs(xsum[i]);
        bool larger = (tempval > fabs(valx)) || (fabs(tempval) == fabs(valx) && i < gidx);
        posx += (larger)?1:0;
      }
      if (posx < topky) {
        youtidx[posx] = gidx;
      }
    }""").build()

prg = cl.Program(ctx, """
    // multilplies x TRANSPOSED by y (dense-dense)
    __kernel void gettopky(__global  float* y,      // INPUT
                          __global  float* ysum,    // INPUT
                          __global  uint*  xoutidx, // OUT
                          uint topkx,
                          uint msize
                          ) { // LOCAL SHARED BUFFER
      uint osize = get_global_size(0);
      int gidy = get_global_id(0); // row
      
      ysum[gidy] = 0;
      for (uint i=0; i<msize; i++) {
        float val = y[i*osize+gidy];
        ysum[gidy] += val;
      }
      //barrier(CLK_GLOBAL_MEM_FENCE);
      float valy = ysum[gidy];
      uint posy = 0;
      for (uint i = 0; i < osize; i++) {
        float tempval = fabs(ysum[i]);
        bool larger = (tempval > fabs(valy)) || (fabs(tempval) == fabs(valy) && i < gidy);
        posy += (larger)?1:0;
      }
      if (posy < topkx) {
        xoutidx[posy] = gidy;
      }
    }""").build()

In [823]:
prg2 = cl.Program(ctx, """
    // multilplies x TRANSPOSED by y (dense-dense)
    __kernel void sortuints(__global  uint* x,      // INPUT MATRIX DATA
                            __global  uint* xs      // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint isize = get_global_size(0);
      int gidx = get_global_id(0); // row
      
      uint val = x[gidx];
      uint posx = 0;
      for (uint i = 0; i < isize; i++) {
        uint tempval = x[i];
        bool smaller = (tempval < val) || (tempval == val && i < gidx);
        posx += (smaller)?1:0;
      }
      xs[posx] = x[gidx];
    }""").build()

In [824]:
topkx, topky

(10, 784)

In [825]:
rows = a.shape[0]
cols = b.shape[1]
rows, cols

(784, 10)

In [826]:
mult = mult.astype(np.float32)
mult.shape

(784, 10)

In [827]:
knlx = prg0.gettopkx  # Use this Kernel object for repeated calls
knly = prg.gettopky  # Use this Kernel object for repeated calls
sort = prg2.sortuints  # Use this Kernel object for repeated calls
knlx(queue, [rows], None, a_buf, x_sum_buf, y_idx_buf, np.uint32(topky), np.uint32(a.shape[1]))
knly(queue, [cols], None, b_buf, y_sum_buf, x_idx_buf, np.uint32(topkx), np.uint32(a.shape[1]))
sort(queue, [topkx], None, x_idx_buf, xs_idx_buf)
sort(queue, [topky], None, y_idx_buf, ys_idx_buf)

xsum = np.zeros(a.shape[0]).astype(np.float32)
ysum = np.zeros(b.shape[1]).astype(np.float32)
xidxcols = np.zeros(topkx).astype(np.uint32)
yidxcols = np.zeros(topky).astype(np.uint32)

In [828]:
cl.enqueue_copy(queue, xsum, x_sum_buf)
cl.enqueue_copy(queue, ysum, y_sum_buf)
cl.enqueue_copy(queue, xidxcols, xs_idx_buf)
cl.enqueue_copy(queue, yidxcols, ys_idx_buf)

<pyopencl._cl.NannyEvent at 0x7ff154140860>

In [829]:
yidxcols

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
       156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
       169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 18

In [830]:
a.sum(axis=1)

array([0.90138686, 0.92866236, 1.8153857 , 0.49910665, 1.9698949 ,
       2.1719928 , 1.364083  , 1.6108192 , 1.4910023 , 1.3507786 ,
       1.3591615 , 1.7076448 , 1.7041619 , 1.4870937 , 1.5615715 ,
       1.369683  , 2.2404904 , 1.5142088 , 1.2764547 , 1.6087517 ,
       0.63693   , 2.4837143 , 1.5524993 , 1.1814958 , 2.0415869 ,
       1.7602775 , 1.5345978 , 0.927653  , 1.6702607 , 1.6856377 ,
       1.880811  , 1.8519274 , 1.7408102 , 0.6228417 , 2.0433335 ,
       1.3911488 , 2.015398  , 2.6073415 , 2.7258937 , 1.7534881 ,
       2.402585  , 1.5583308 , 1.8388636 , 1.2620342 , 2.0273886 ,
       2.1319015 , 1.9676442 , 0.8557663 , 1.541347  , 1.7683222 ,
       1.0179391 , 2.2274227 , 1.0678444 , 1.8316495 , 1.9448184 ,
       2.5434635 , 0.7626756 , 0.77479506, 1.7022212 , 1.1181333 ,
       1.0256431 , 2.6158853 , 2.0894628 , 1.9667482 , 2.1732457 ,
       1.7323526 , 1.9612563 , 1.561662  , 1.4636836 , 1.4231989 ,
       1.5639279 , 1.2725452 , 1.4059178 , 1.5115663 , 1.10279

In [831]:
xsum

array([1.7687612 , 3.8299246 , 1.3313504 , 0.        , 0.4014198 ,
       1.3125952 , 1.0801463 , 1.5092187 , 0.        , 1.3219044 ,
       0.40355822, 1.4326475 , 2.549487  , 0.38626537, 1.6578603 ,
       1.138812  , 0.8584388 , 0.13806182, 0.61311436, 0.81969976,
       0.        , 2.3148003 , 0.78477144, 0.31231064, 1.1314069 ,
       0.8360748 , 0.2960138 , 1.6086118 , 0.        , 2.3569152 ,
       0.65107363, 0.6576212 , 1.2227106 , 0.56125623, 0.90429944,
       0.21665967, 0.3985087 , 0.        , 2.6146166 , 1.6308575 ,
       0.        , 0.9734551 , 0.107081  , 0.0804159 , 0.20392291,
       0.69010115, 0.39016965, 1.4641045 , 0.87043756, 0.        ,
       1.2277867 , 1.6929588 , 1.7205766 , 2.6219375 , 0.        ,
       1.7949355 , 1.0819263 , 0.36141634, 0.6526225 , 1.3092796 ,
       0.9797343 , 2.3917086 , 1.6724448 , 0.6832084 , 0.38345355,
       1.9219778 , 0.9254752 , 0.6608299 , 1.6344087 , 1.1944346 ,
       1.6143335 , 0.76337427, 2.9167967 , 0.4967291 , 0.68125

In [832]:
ysum

array([2.5932004 , 1.8013453 , 1.6908714 , 2.372854  , 3.1637416 ,
       1.6946486 , 1.0513332 , 2.1355958 , 0.94673526, 0.19298612],
      dtype=float32)

In [833]:
xidxcols

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint32)

In [834]:
yidxcols

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
       156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
       169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 18

# Matmult Transposed Dense (SPRNEW)

In [562]:
c=np.zeros(a.T.shape)
at = a.T
for row in range(at.shape[0]):
    for col in range(at.shape[1]):
        c[row][col] = at[row][col]

In [563]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)

prg = cl.Program(ctx, """
    // multilplies x TRANSPOSED by y (dense-dense)
    __kernel void matmul0(__global  float* x,      // INPUT MATRIX DATA
                          __global  float* y,      // INPUT
                          __global  uint* xidx,   // INPUT YIDX
                          __global  uint* yidx,   // INPUT YIDX
                          __global  float* resdata,// OUT
                          __global  uint*  rescols,
                          __global  uint*  resnnzs,
                          uint topkx,
                          uint ellw,
                          uint isize,
                          uint msize,
                          uint osize
                          ) { // LOCAL SHARED BUFFER  
                          
      uint topky = get_global_size(0);
      uint gidx = yidx[get_global_id(0)]; // row
      
      for (uint gidy0 = 0; gidy0 < topkx; gidy0++) {
        uint gidy = xidx[gidy0];
        float ret = 0.0;
        uint i;
        for (i = 0; i < msize; i++) {
          uint xidx = i*isize+gidx;
          float xval = x[xidx];
          uint yidx = osize*i+gidy;
          float yval = y[yidx];
          ret += xval*yval;
          //if (gidx==0 && gidy==0)
          //  printf("\\nmult: %.2f x %.2f - %.2f  -- %i/%i", xval, yval, ret, xidx, yidx);
        }
        //if (gidx==0&&gidy==0)
        //  printf("\\nsum:%.2f", ret);
        
        // add for 
        uint nnz = resnnzs[gidx];
        for (i = 0; i < nnz; i++) {
          if (rescols[i] >= gidy) {
            break;
          }
          for (uint j = nnz; j >= i; j--) {
            //resdata[j+1] = resdata[j];
          }
        }
        resdata[gidx * ellw + gidy0] = ret;
        rescols[gidx * ellw + gidy0] = gidy;
        resnnzs[gidx] += 1;
      }
    }""").build()

In [564]:
a.shape, b.shape

((32, 64), (64, 10))

In [565]:
rows = a.shape[0]

In [566]:
mult = mult.astype(np.float32)
mult.shape

(32, 10)

In [567]:
topky, topkx

(32, 10)

In [568]:
resdata_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,topkx])*4)
rescols_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,topkx])*4)
resnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=np.zeros(rows))

knl = prg.matmul0  # Use this Kernel object for repeated calls
knl(queue, [topky], None, a_buf, b_buf, xs_idx_buf, ys_idx_buf, resdata_buf, rescols_buf, resnnzs_buf, np.uint32(topkx), np.uint32(topkx), np.uint32(a.shape[0]), np.uint32(a.shape[1]), np.uint32(b.shape[1]))

resdata = np.zeros(a.shape[0]*topkx).astype(np.float32)
rescols = np.zeros(a.shape[0]*topkx).astype(np.uint32)
resnnzs = np.zeros(a.shape[0]).astype(np.uint32)
cl.enqueue_copy(queue, resdata, resdata_buf)
cl.enqueue_copy(queue, rescols, rescols_buf)
cl.enqueue_copy(queue, resnnzs, resnnzs_buf)

<pyopencl._cl.NannyEvent at 0x7ff1540fe450>

In [569]:
resdata

array([2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
       1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ,
       2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
       3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ,
       1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
       2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ,
       1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
       1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ,
       1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
       1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ,
       1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
       1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ,
       1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
       1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ,
       1.6907904 , 2.3283677 , 1.7894993 , 2.1847265 , 2.85510

In [570]:
rescols

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
       2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
       4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
       6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
       8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
       2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
       4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
       6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
       8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
       2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
       4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
       6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0,

In [571]:
resnnzs

array([10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
       10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
      dtype=uint32)

In [572]:
res_np = to_dense(resdata, rescols, resnnzs, topkx, mult.shape)

In [573]:
(res_np-mult).sum()

5.960464477539062e-07

In [574]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [575]:
b

array([[0.9414536 , 0.        , 0.        , 0.90892494, 0.        ,
        0.        , 0.05297933, 0.        , 0.7021168 , 0.        ],
       [0.        , 0.        , 0.14630048, 0.24825948, 0.        ,
        0.        , 0.        , 0.7792777 , 0.31232256, 0.        ],
       [0.        , 0.03283854, 0.27791733, 0.8543185 , 0.        ,
        0.        , 0.8433972 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33419287, 0.19840643, 0.        ,
        0.3856511 , 0.23348303, 0.        , 0.        , 0.        ],
       [0.        , 0.01996238, 0.9816661 , 0.        , 0.        ,
        0.        , 0.        , 0.5056456 , 0.        , 0.01917347],
       [0.        , 0.        , 0.        , 0.        , 0.49159616,
        0.0810385 , 0.64290637, 0.16045028, 0.        , 0.        ],
       [0.        , 0.2878976 , 0.        , 0.        , 0.7058473 ,
        0.        , 0.41645932, 0.44598925, 0.        , 0.        ],
       [0.31479523, 0.40940058, 0.       

In [576]:
res_np

array([[2.05889988, 1.86816645, 4.81181717, 2.25293875, 2.10965443,
        1.3379606 , 2.90000367, 2.04914141, 2.58978033, 1.50316036],
       [2.73214602, 2.55654955, 3.71485209, 1.96740305, 2.78127933,
        3.90701485, 3.27775383, 2.29690385, 1.92299247, 2.65563965],
       [1.92874503, 3.16009569, 2.65028286, 1.81787264, 3.2470448 ,
        2.28542805, 2.63659263, 2.51668644, 1.07090104, 2.06057048],
       [1.06318498, 4.17379856, 2.22772551, 1.25156748, 4.18276644,
        1.29158247, 1.3805095 , 3.41988373, 1.02567685, 2.77550673],
       [1.18339908, 1.85575962, 3.05342436, 1.07224607, 3.49056172,
        1.643646  , 3.45331311, 3.8086679 , 0.65044206, 2.05637264],
       [1.04616213, 3.55668378, 3.21012974, 1.57240605, 2.28090835,
        1.85009015, 1.33866227, 4.39697981, 1.56214321, 1.77257848],
       [1.51664436, 2.50481415, 1.34177089, 3.52348685, 2.7652483 ,
        1.13953555, 2.25747991, 4.70517206, 2.24308538, 1.81945622],
       [1.69079041, 2.32836771, 1.7894992

In [577]:
a.dot(b)

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [578]:
res_np==mult

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True, 

In [579]:
res_np.shape

(32, 10)

In [580]:
mult.shape

(32, 10)

# Matmult Transposed Dense (SPR-T OUT NEW)

In [581]:
c=np.zeros(a.T.shape)
at = a.T
for row in range(at.shape[0]):
    for col in range(at.shape[1]):
        c[row][col] = at[row][col]

In [582]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)

prg = cl.Program(ctx, """
    // multilplies x TRANSPOSED by y (dense-dense)
    __kernel void matmul0t(__global  float* x,      // INPUT MATRIX DATA
                          __global  float* y,      // INPUT
                          __global  uint* xidx,   // INPUT YIDX
                          __global  uint* yidx,   // INPUT YIDX
                          __global  float* resdata,// OUT
                          __global  uint*  rescols,
                          __global  uint*  resnnzs,
                          uint topky,
                          uint ellw,
                          uint isize,
                          uint msize,
                          uint osize
                          ) { // LOCAL SHARED BUFFER
      uint topkx = get_global_size(0);
      uint gidy = xidx[get_global_id(0)]; // row
      
      for (uint gidx0 = 0; gidx0 < topky; gidx0++) {
        uint gidx = yidx[gidx0];
        float ret = 0.0;
        uint i;
        for (i = 0; i < msize; i++) {
          uint xidx = i*isize+gidx;
          float xval = x[xidx];
          uint yidx = osize*i+gidy;
          float yval = y[yidx];
          ret += xval*yval;
          if (gidx==0 && gidy==0)
            printf("\\nmult: %.2f x %.2f - %.2f  -- %i/%i", xval, yval, ret, gidx, gidy,i);
        }
        //if (gidx==0&&gidy==0)
        //  printf("\\nsum:%.2f", ret);
        
        // add for 
        uint nnz = resnnzs[gidx];
        for (i = 0; i < nnz; i++) {
          if (rescols[i] >= gidy) {
            break;
          }
          for (uint j = nnz; j >= i; j--) {
            //resdata[j+1] = resdata[j];
          }
        }
        resdata[gidy * ellw + gidx0] = ret;
        rescols[gidy * ellw + gidx0] = gidx;
        resnnzs[gidy] += 1;
      }
    }""").build()

In [583]:
a.shape, b.shape

((32, 64), (64, 10))

In [584]:
rows = a.shape[0]
cols = b.shape[1]

In [585]:
mult = mult.astype(np.float32)
mult.shape

(32, 10)

In [586]:
resdatat_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([cols,topky])*4)
rescolst_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([cols,topky])*4)
resnnzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=np.zeros(cols))

knl = prg.matmul0t  # Use this Kernel object for repeated calls
knl(queue, [topkx], None, a_buf, b_buf, xs_idx_buf, ys_idx_buf, resdatat_buf, rescolst_buf, resnnzst_buf, np.uint32(topky), np.uint32(topky), np.uint32(a.shape[0]), np.uint32(a.shape[1]), np.uint32(b.shape[1]))

resdatat = np.zeros(cols*topky).astype(np.float32)
rescolst = np.zeros(cols*topky).astype(np.uint32)
resnnzst = np.zeros(cols).astype(np.uint32)
cl.enqueue_copy(queue, resdatat, resdatat_buf)
cl.enqueue_copy(queue, rescolst, rescolst_buf)
cl.enqueue_copy(queue, resnnzst, resnnzst_buf)

<pyopencl._cl.NannyEvent at 0x7ff154110d10>


mult: 0.82 x 0.94 - 0.77  -- 0/0
mult: 0.00 x 0.00 - 0.77  -- 0/0
mult: 0.00 x 0.00 - 0.77  -- 0/0
mult: 0.48 x 0.00 - 0.77  -- 0/0
mult: 0.91 x 0.00 - 0.77  -- 0/0
mult: 0.00 x 0.00 - 0.77  -- 0/0
mult: 0.00 x 0.00 - 0.77  -- 0/0
mult: 0.88 x 0.31 - 1.05  -- 0/0
mult: 0.00 x 0.11 - 1.05  -- 0/0
mult: 0.00 x 0.07 - 1.05  -- 0/0
mult: 0.00 x 0.00 - 1.05  -- 0/0
mult: 0.87 x 0.16 - 1.19  -- 0/0
mult: 0.00 x 0.00 - 1.19  -- 0/0
mult: 0.00 x 0.64 - 1.19  -- 0/0
mult: 0.26 x 0.50 - 1.32  -- 0/0
mult: 0.36 x 0.00 - 1.32  -- 0/0
mult: 0.00 x 0.00 - 1.32  -- 0/0
mult: 0.41 x 0.00 - 1.32  -- 0/0
mult: 0.00 x 0.00 - 1.32  -- 0/0
mult: 0.00 x 0.00 - 1.32  -- 0/0
mult: 0.62 x 0.18 - 1.43  -- 0/0
mult: 0.47 x 0.00 - 1.43  -- 0/0
mult: 0.00 x 0.00 - 1.43  -- 0/0
mult: 0.00 x 0.00 - 1.43  -- 0/0
mult: 0.00 x 0.00 - 1.43  -- 0/0
mult: 0.00 x 0.00 - 1.43  -- 0/0
mult: 0.00 x 0.00 - 1.43  -- 0/0
mult: 0.00 x 0.00 - 1.43  -- 0/0
mult: 0.22 x 0.43 - 1.53  -- 0/0
mult: 0.59 x 0.12 - 1.60  -- 0/0
mult: 0.0

In [587]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [588]:
b.T

array([[0.9414536 , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.31479523, 0.11005293, 0.06780045,
        0.        , 0.16214237, 0.        , 0.6445239 , 0.50232446,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.17604527, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.42811713, 0.11871076,
        0.        , 0.        , 0.        , 0.6127449 , 0.41137826,
        0.82947165, 0.04573116, 0.        , 0.        , 0.04850706,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.6346003 , 0.        , 0.        , 0.69351566, 0.9864582 ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.2653583 , 0.        , 0.2264393 ,
        0.10093832, 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.03283854, 0.        , 0.01996238,
        0.        , 0.2878976 , 0.40940058, 0.5290864 , 0.0

In [589]:
resdatat

array([2.0588999 , 2.732146  , 1.928745  , 1.063185  , 1.1833991 ,
       1.0461621 , 1.5166444 , 1.6907904 , 1.2141374 , 2.0491679 ,
       1.4538853 , 2.1386282 , 2.2865992 , 1.9966788 , 0.48399195,
       2.0954764 , 1.5846765 , 0.49207467, 2.5375338 , 2.0678236 ,
       1.8444016 , 2.6537714 , 0.6337883 , 1.3287122 , 0.8400132 ,
       0.95923513, 0.8192116 , 3.6204023 , 1.0327702 , 1.928727  ,
       1.5923332 , 1.0044551 , 1.8681664 , 2.5565495 , 3.1600957 ,
       4.1737986 , 1.8557596 , 3.5566838 , 2.5048141 , 2.3283677 ,
       2.9877534 , 3.8442254 , 3.52397   , 2.708146  , 3.9609175 ,
       2.9262302 , 2.1673245 , 2.4307942 , 2.661023  , 2.325397  ,
       2.4817715 , 1.8836694 , 2.2461612 , 1.1732535 , 3.3086834 ,
       2.5962925 , 2.0994346 , 2.4799027 , 1.8272312 , 1.9968115 ,
       2.8972733 , 1.5045325 , 3.816273  , 3.294622  , 4.811817  ,
       3.714852  , 2.6502829 , 2.2277255 , 3.0534244 , 3.2101297 ,
       1.3417709 , 1.7894993 , 3.1018186 , 5.2590375 , 3.42697

In [590]:
rescolst

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,
        2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
       19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,
        4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
       21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,  4,  5,
        6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
       23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,  4,  5,  6,  7,
        8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
       25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
       10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
       27, 28, 29, 30, 31,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
       29, 30, 31,  0,  1

In [591]:
resnnzst

array([32, 32, 32, 32, 32, 32, 32, 32, 32, 32], dtype=uint32)

In [592]:
res_np = to_dense(resdatat, rescolst, resnnzst, topky, mult.T.shape)
res_np.T

array([[2.05889988, 1.86816645, 4.81181717, 2.25293875, 2.10965443,
        1.3379606 , 2.90000367, 2.04914141, 2.58978033, 1.50316036],
       [2.73214602, 2.55654955, 3.71485209, 1.96740305, 2.78127933,
        3.90701485, 3.27775383, 2.29690385, 1.92299247, 2.65563965],
       [1.92874503, 3.16009569, 2.65028286, 1.81787264, 3.2470448 ,
        2.28542805, 2.63659263, 2.51668644, 1.07090104, 2.06057048],
       [1.06318498, 4.17379856, 2.22772551, 1.25156748, 4.18276644,
        1.29158247, 1.3805095 , 3.41988373, 1.02567685, 2.77550673],
       [1.18339908, 1.85575962, 3.05342436, 1.07224607, 3.49056172,
        1.643646  , 3.45331311, 3.8086679 , 0.65044206, 2.05637264],
       [1.04616213, 3.55668378, 3.21012974, 1.57240605, 2.28090835,
        1.85009015, 1.33866227, 4.39697981, 1.56214321, 1.77257848],
       [1.51664436, 2.50481415, 1.34177089, 3.52348685, 2.7652483 ,
        1.13953555, 2.25747991, 4.70517206, 2.24308538, 1.81945622],
       [1.69079041, 2.32836771, 1.7894992

In [593]:
(res_np.T-mult).sum()

5.960464477539062e-07

In [366]:
res_np.shape

(10, 32)

In [367]:
a.dot(b)

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [368]:
mult - res_np.T

array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.

In [369]:
res_np.T==mult

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True, 

In [199]:
res_np.T

array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 2.55654955, 3.71485209, 0.        , 2.78127933,
        0.        , 3.27775383, 2.29690385, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.       

# Matmult Transposed Dense (SPR)

In [200]:
c=np.zeros(a.T.shape)
at = a.T
for row in range(at.shape[0]):
    for col in range(at.shape[1]):
        c[row][col] = at[row][col]

In [201]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)

prg = cl.Program(ctx, """
    // multilplies x TRANSPOSED by y (dense-dense)
    __kernel void matmul0(__global  float* x,      // INPUT MATRIX DATA
                          __global  float* y,      // INPUT
                          __global  float* resdata,// OUT
                          __global  uint*  rescols,
                          __global  uint*  resnnzs,
                          uint ellw,
                          uint msize,
                          uint osize
                          ) { // LOCAL SHARED BUFFER
      uint isize = get_global_size(0);
      int gidx = get_global_id(0); // row
      
      resnnzs[gidx] = 0;
      
      for (uint gidy = 0; gidy < osize; gidy++) {
        float ret = 0.0;
        uint i;
        for (i = 0; i < msize; i++) {
          uint xidx = i*isize+gidx;
          float xval = x[xidx];
          uint yidx = osize*i+gidy;
          float yval = y[yidx];
          ret += xval*yval;
          if (gidx==0 && gidy==0)
            printf("\\nmult: %.2f x %.2f - %.2f  -- %i/%i", xval, yval, ret, xidx, yidx);
        }
        //if (gidx==0&&gidy==0)
        //  printf("\\nsum:%.2f", ret);
        
        // add for 
        uint nnz = resnnzs[gidx];
        for (i = 0; i < nnz; i++) {
          if (rescols[i] >= gidy) {
            break;
          }
          for (uint j = nnz; j >= i; j--) {
            //resdata[j+1] = resdata[j];
          }
        }
        resdata[gidx * ellw + i] = ret;
        rescols[gidx * ellw + i] = gidy;
        resnnzs[gidx] += 1;
      }
    }""").build()

In [202]:
a.shape, b.shape

((32, 64), (64, 10))

In [203]:
rows = a.shape[0]

In [204]:
mult = mult.astype(np.float32)
mult.shape

(32, 10)

In [205]:
resdata_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
rescols_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
resnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows])*4)

knl = prg.matmul0  # Use this Kernel object for repeated calls
knl(queue, [rows], None, a_buf, b_buf, resdata_buf, rescols_buf, resnnzs_buf, np.uint32(b.shape[1]), np.uint32(a.shape[1]), np.uint32(b.shape[1]))

resdata = np.zeros(a.shape[0]*b.shape[1]).astype(np.float32)
rescols = np.zeros(a.shape[0]*b.shape[1]).astype(np.uint32)
resnnzs = np.zeros(a.shape[0]).astype(np.uint32)
cl.enqueue_copy(queue, resdata, resdata_buf)
cl.enqueue_copy(queue, rescols, rescols_buf)
cl.enqueue_copy(queue, resnnzs, resnnzs_buf)


mult: 0.82 x 0.94 - 0.77  -- 0/0
mult: 0.00 x 0.00 - 0.77  -- 32/10
mult: 0.00 x 0.00 - 0.77  -- 64/20
mult: 0.48 x 0.00 - 0.77  -- 96/30
mult: 0.91 x 0.00 - 0.77  -- 128/40
mult: 0.00 x 0.00 - 0.77  -- 160/50
mult: 0.00 x 0.00 - 0.77  -- 192/60
mult: 0.88 x 0.31 - 1.05  -- 224/70
mult: 0.00 x 0.11 - 1.05  -- 256/80
mult: 0.00 x 0.07 - 1.05  -- 288/90
mult: 0.00 x 0.00 - 1.05  -- 320/100
mult: 0.87 x 0.16 - 1.19  -- 352/110
mult: 0.00 x 0.00 - 1.19  -- 384/120
mult: 0.00 x 0.64 - 1.19  -- 416/130
mult: 0.26 x 0.50 - 1.32  -- 448/140
mult: 0.36 x 0.00 - 1.32  -- 480/150
mult: 0.00 x 0.00 - 1.32  -- 512/160
mult: 0.41 x 0.00 - 1.32  -- 544/170
mult: 0.00 x 0.00 - 1.32  -- 576/180
mult: 0.00 x 0.00 - 1.32  -- 608/190
mult: 0.62 x 0.18 - 1.43  -- 640/200
mult: 0.47 x 0.00 - 1.43  -- 672/210
mult: 0.00 x 0.00 - 1.43  -- 704/220
mult: 0.00 x 0.00 - 1.43  -- 736/230
mult: 0.00 x 0.00 - 1.43  -- 768/240
mult: 0.00 x 0.00 - 1.43  -- 800/250
mult: 0.00 x 0.00 - 1.43  -- 832/260
mult: 0.00 x 0.0

<pyopencl._cl.NannyEvent at 0x7ff154226f90>

0 - 1.43  -- 864/270
mult: 0.22 x 0.43 - 1.53  -- 896/280
mult: 0.59 x 0.12 - 1.60  -- 928/290
mult: 0.00 x 0.00 - 1.60  -- 960/300
mult: 0.44 x 0.00 - 1.60  -- 992/310
mult: 0.47 x 0.00 - 1.60  -- 1024/320
mult: 0.69 x 0.61 - 2.02  -- 1056/330
mult: 0.00 x 0.41 - 2.02  -- 1088/340
mult: 0.00 x 0.83 - 2.02  -- 1120/350
mult: 0.38 x 0.05 - 2.04  -- 1152/360
mult: 0.81 x 0.00 - 2.04  -- 1184/370
mult: 0.98 x 0.00 - 2.04  -- 1216/380
mult: 0.45 x 0.05 - 2.06  -- 1248/390
mult: 0.00 x 0.00 - 2.06  -- 1280/400
mult: 0.00 x 0.00 - 2.06  -- 1312/410
mult: 0.29 x 0.00 - 2.06  -- 1344/420
mult: 0.00 x 0.00 - 2.06  -- 1376/430
mult: 0.04 x 0.00 - 2.06  -- 1408/440
mult: 0.00 x 0.63 - 2.06  -- 1440/450
mult: 0.28 x 0.00 - 2.06  -- 1472/460
mult: 0.00 x 0.00 - 2.06  -- 1504/470
mult: 0.00 x 0.69 - 2.06  -- 1536/480
mult: 0.00 x 0.99 - 2.06  -- 1568/490
mult: 0.42 x 0.00 - 2.06  -- 1600/500
mult: 0.00 x 0.00 - 2.06  -- 1632/510
mult: 0.00 x 0.00 - 2.06  -- 1664/520
mult: 0.00 x 0.00 - 2.06  -- 1696

In [206]:
resdata

array([2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
       1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ,
       2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
       3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ,
       1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
       2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ,
       1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
       1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ,
       1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
       1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ,
       1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
       1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ,
       1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
       1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ,
       1.6907904 , 2.3283677 , 1.7894993 , 2.1847265 , 2.85510

In [207]:
rescols

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
       2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
       4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
       6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
       8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
       2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
       4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
       6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
       8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
       2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
       4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
       6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0,

In [208]:
resnnzs

array([10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
       10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
      dtype=uint32)

In [209]:
res_np = to_dense(resdata, rescols, resnnzs, b.shape[1], mult.shape)

In [210]:
(res_np-mult).sum()

5.960464477539062e-07

In [211]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [212]:
b

array([[0.9414536 , 0.        , 0.        , 0.90892494, 0.        ,
        0.        , 0.05297933, 0.        , 0.7021168 , 0.        ],
       [0.        , 0.        , 0.14630048, 0.24825948, 0.        ,
        0.        , 0.        , 0.7792777 , 0.31232256, 0.        ],
       [0.        , 0.03283854, 0.27791733, 0.8543185 , 0.        ,
        0.        , 0.8433972 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33419287, 0.19840643, 0.        ,
        0.3856511 , 0.23348303, 0.        , 0.        , 0.        ],
       [0.        , 0.01996238, 0.9816661 , 0.        , 0.        ,
        0.        , 0.        , 0.5056456 , 0.        , 0.01917347],
       [0.        , 0.        , 0.        , 0.        , 0.49159616,
        0.0810385 , 0.64290637, 0.16045028, 0.        , 0.        ],
       [0.        , 0.2878976 , 0.        , 0.        , 0.7058473 ,
        0.        , 0.41645932, 0.44598925, 0.        , 0.        ],
       [0.31479523, 0.40940058, 0.       

In [213]:
res_np

array([[2.05889988, 1.86816645, 4.81181717, 2.25293875, 2.10965443,
        1.3379606 , 2.90000367, 2.04914141, 2.58978033, 1.50316036],
       [2.73214602, 2.55654955, 3.71485209, 1.96740305, 2.78127933,
        3.90701485, 3.27775383, 2.29690385, 1.92299247, 2.65563965],
       [1.92874503, 3.16009569, 2.65028286, 1.81787264, 3.2470448 ,
        2.28542805, 2.63659263, 2.51668644, 1.07090104, 2.06057048],
       [1.06318498, 4.17379856, 2.22772551, 1.25156748, 4.18276644,
        1.29158247, 1.3805095 , 3.41988373, 1.02567685, 2.77550673],
       [1.18339908, 1.85575962, 3.05342436, 1.07224607, 3.49056172,
        1.643646  , 3.45331311, 3.8086679 , 0.65044206, 2.05637264],
       [1.04616213, 3.55668378, 3.21012974, 1.57240605, 2.28090835,
        1.85009015, 1.33866227, 4.39697981, 1.56214321, 1.77257848],
       [1.51664436, 2.50481415, 1.34177089, 3.52348685, 2.7652483 ,
        1.13953555, 2.25747991, 4.70517206, 2.24308538, 1.81945622],
       [1.69079041, 2.32836771, 1.7894992

In [214]:
a.dot(b)

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [215]:
res_np==mult

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True, 

In [216]:
res_np.shape

(32, 10)

In [217]:
mult.shape

(32, 10)

# Matmult Transposed Dense (SPR-T OUT)

In [218]:
c=np.zeros(a.T.shape)
at = a.T
for row in range(at.shape[0]):
    for col in range(at.shape[1]):
        c[row][col] = at[row][col]

In [219]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)

prg = cl.Program(ctx, """
    // multilplies x TRANSPOSED by y (dense-dense)
    __kernel void matmul0t(__global  float* x,      // INPUT MATRIX DATA
                          __global  float* y,      // INPUT
                          __global  float* resdata,// OUT
                          __global  uint*  rescols,
                          __global  uint*  resnnzs,
                          uint ellw,
                          uint msize,
                          uint isize
                          ) { // LOCAL SHARED BUFFER
      uint osize = get_global_size(0);
      int gidy = get_global_id(0); // row
      
      resnnzs[gidy] = 0;
      for (uint gidx = 0; gidx < isize; gidx++) {
        float ret = 0.0;
        uint i;
        for (i = 0; i < msize; i++) {
          uint xidx = i*isize+gidx;
          float xval = x[xidx];
          uint yidx = osize*i+gidy;
          float yval = y[yidx];
          ret += xval*yval;
          if (gidx==0 && gidy==0)
            printf("\\nmult: %.2f x %.2f - %.2f  -- %i/%i", xval, yval, ret, xidx, yidx);
        }
        //if (gidx==0&&gidy==0)
        //  printf("\\nsum:%.2f", ret);
        
        // add for 
        uint nnz = resnnzs[gidy];
        for (i = 0; i < nnz; i++) {
          if (rescols[i] >= gidx) {
            break;
          }
          for (uint j = nnz; j >= i; j--) {
            //resdata[j+1] = resdata[j];
          }
        }
        resdata[gidy * ellw + i] = ret;
        rescols[gidy * ellw + i] = gidx;
        resnnzs[gidy] += 1;
      }
    }""").build()

In [220]:
a.shape, b.shape

((32, 64), (64, 10))

In [221]:
rows = a.shape[0]
cols = b.shape[1]

In [222]:
mult = mult.astype(np.float32)
mult.shape

(32, 10)

In [223]:
resdatat_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([cols,rows])*4)
rescolst_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([cols,rows])*4)
resnnzst_buf = cl.Buffer(ctx, mf.READ_WRITE, cols*4)

knl = prg.matmul0t  # Use this Kernel object for repeated calls
knl(queue, [cols], None, a_buf, b_buf, resdatat_buf, rescolst_buf, resnnzst_buf, np.uint32(rows), np.uint32(a.shape[1]), np.uint32(rows))

resdatat = np.zeros(cols*rows).astype(np.float32)
rescolst = np.zeros(cols*rows).astype(np.uint32)
resnnzst = np.zeros(cols).astype(np.uint32)
cl.enqueue_copy(queue, resdatat, resdatat_buf)
cl.enqueue_copy(queue, rescolst, rescolst_buf)
cl.enqueue_copy(queue, resnnzst, resnnzst_buf)


mult: 0.82 x 0.94 - 0.77  -- 0/0
mult: 0.00 x 0.00 - 0.77  -- 32/10
mult: 0.00 x 0.00 - 0.77  -- 64/20
mult: 0.48 x 0.00 - 0.77  -- 96/30
mult: 0.91 x 0.00 - 0.77  -- 128/40
mult: 0.00 x 0.00 - 0.77  -- 160/50
mult: 0.00 x 0.00 - 0.77  -- 192/60
mult: 0.88 x 0.31 - 1.05  -- 224/70
mult: 0.00 x 0.11 - 1.05  -- 256/80
mult: 0.00 x 0.07 - 1.05  -- 288/90
mult: 0.00 x 0.00 - 1.05  -- 320/100
mult: 0.87 x 0.16 - 1.19  -- 352/110
mult: 0.00 x 0.00 - 1.19  -- 384/120
mult: 0.00 x 0.64 - 1.19  -- 416/130
mult: 0.26 x 0.50 - 1.32  -- 448/140
mult: 0.36 x 0.00 - 1.32  -- 480/150
mult: 0.00 x 0.00 - 1.32  -- 512/160
mult: 0.41 x 0.00 - 1.32  -- 544/170
mult: 0.00 x 0.00 - 1.32  -- 576/180
mult: 0.00 x 0.00 - 1.32  -- 608/190
mult: 0.62 x 0.18 - 1.43  -- 640/200
mult: 0.47 x 0.00 - 1.43  -- 672/210
mult: 0.00 x 0.00 - 1.43  -- 704/220
mult: 0.00 x 0.00 - 1.43  -- 736/230
mult: 0.00 x 0.00 - 1.43  -- 768/240
mult: 0.00 x 0.00 - 1.43  -- 800/250
mult: 0.00 x 0.00 - 1.43  -- 832/260
mult: 0.00 x 0.0

<pyopencl._cl.NannyEvent at 0x7ff1542101d0>

0 - 1.43  -- 864/270
mult: 0.22 x 0.43 - 1.53  -- 896/280
mult: 0.59 x 0.12 - 1.60  -- 928/290
mult: 0.00 x 0.00 - 1.60  -- 960/300
mult: 0.44 x 0.00 - 1.60  -- 992/310
mult: 0.47 x 0.00 - 1.60  -- 1024/320
mult: 0.69 x 0.61 - 2.02  -- 1056/330
mult: 0.00 x 0.41 - 2.02  -- 1088/340
mult: 0.00 x 0.83 - 2.02  -- 1120/350
mult: 0.38 x 0.05 - 2.04  -- 1152/360
mult: 0.81 x 0.00 - 2.04  -- 1184/370
mult: 0.98 x 0.00 - 2.04  -- 1216/380
mult: 0.45 x 0.05 - 2.06  -- 1248/390
mult: 0.00 x 0.00 - 2.06  -- 1280/400
mult: 0.00 x 0.00 - 2.06  -- 1312/410
mult: 0.29 x 0.00 - 2.06  -- 1344/420
mult: 0.00 x 0.00 - 2.06  -- 1376/430
mult: 0.04 x 0.00 - 2.06  -- 1408/440
mult: 0.00 x 0.63 - 2.06  -- 1440/450
mult: 0.28 x 0.00 - 2.06  -- 1472/460
mult: 0.00 x 0.00 - 2.06  -- 1504/470
mult: 0.00 x 0.69 - 2.06  -- 1536/480
mult: 0.00 x 0.99 - 2.06  -- 1568/490
mult: 0.42 x 0.00 - 2.06  -- 1600/500
mult: 0.00 x 0.00 - 2.06  -- 1632/510
mult: 0.00 x 0.00 - 2.06  -- 1664/520
mult: 0.00 x 0.00 - 2.06  -- 1696

In [224]:
resdatat

array([2.0588999 , 2.732146  , 1.928745  , 1.063185  , 1.1833991 ,
       1.0461621 , 1.5166444 , 1.6907904 , 1.2141374 , 2.0491679 ,
       1.4538853 , 2.1386282 , 2.2865992 , 1.9966788 , 0.48399195,
       2.0954764 , 1.5846765 , 0.49207467, 2.5375338 , 2.0678236 ,
       1.8444016 , 2.6537714 , 0.6337883 , 1.3287122 , 0.8400132 ,
       0.95923513, 0.8192116 , 3.6204023 , 1.0327702 , 1.928727  ,
       1.5923332 , 1.0044551 , 1.8681664 , 2.5565495 , 3.1600957 ,
       4.1737986 , 1.8557596 , 3.5566838 , 2.5048141 , 2.3283677 ,
       2.9877534 , 3.8442254 , 3.52397   , 2.708146  , 3.9609175 ,
       2.9262302 , 2.1673245 , 2.4307942 , 2.661023  , 2.325397  ,
       2.4817715 , 1.8836694 , 2.2461612 , 1.1732535 , 3.3086834 ,
       2.5962925 , 2.0994346 , 2.4799027 , 1.8272312 , 1.9968115 ,
       2.8972733 , 1.5045325 , 3.816273  , 3.294622  , 4.811817  ,
       3.714852  , 2.6502829 , 2.2277255 , 3.0534244 , 3.2101297 ,
       1.3417709 , 1.7894993 , 3.1018186 , 5.2590375 , 3.42697

In [225]:
rescolst

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,
        2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
       19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,
        4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
       21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,  4,  5,
        6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
       23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,  4,  5,  6,  7,
        8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
       25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
       10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
       27, 28, 29, 30, 31,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
       29, 30, 31,  0,  1

In [226]:
resnnzst

array([32, 32, 32, 32, 32, 32, 32, 32, 32, 32], dtype=uint32)

In [227]:
res_np = to_dense(resdatat, rescolst, resnnzst, a.shape[0], mult.T.shape)
res_np.T

array([[2.05889988, 1.86816645, 4.81181717, 2.25293875, 2.10965443,
        1.3379606 , 2.90000367, 2.04914141, 2.58978033, 1.50316036],
       [2.73214602, 2.55654955, 3.71485209, 1.96740305, 2.78127933,
        3.90701485, 3.27775383, 2.29690385, 1.92299247, 2.65563965],
       [1.92874503, 3.16009569, 2.65028286, 1.81787264, 3.2470448 ,
        2.28542805, 2.63659263, 2.51668644, 1.07090104, 2.06057048],
       [1.06318498, 4.17379856, 2.22772551, 1.25156748, 4.18276644,
        1.29158247, 1.3805095 , 3.41988373, 1.02567685, 2.77550673],
       [1.18339908, 1.85575962, 3.05342436, 1.07224607, 3.49056172,
        1.643646  , 3.45331311, 3.8086679 , 0.65044206, 2.05637264],
       [1.04616213, 3.55668378, 3.21012974, 1.57240605, 2.28090835,
        1.85009015, 1.33866227, 4.39697981, 1.56214321, 1.77257848],
       [1.51664436, 2.50481415, 1.34177089, 3.52348685, 2.7652483 ,
        1.13953555, 2.25747991, 4.70517206, 2.24308538, 1.81945622],
       [1.69079041, 2.32836771, 1.7894992

In [228]:
(res_np.T-mult).sum()

5.960464477539062e-07

In [229]:
res_np.shape

(10, 32)

In [230]:
a.dot(b)

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [231]:
mult - res_np.T

array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.

In [232]:
res_np.T==mult

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True, 

# Matmult Dense Transposed2

In [233]:
b_buf2 = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a)

prg = cl.Program(ctx, """
    // multilplies x by y WITH Y TRANSPOSED INDEXING
    __kernel void matmul0(__global  float* x,     // INPUT MATRIX DATA
                          __global  float* y,    // INPUT
                          __global  float* res,    // INPUT
                          uint msize,
                          uint osize
                          ) { // LOCAL SHARED BUFFER
      uint isize = get_global_size(0);
      // osize = get_global_size(1);
      int gidx = get_global_id(0); // col
      // int gidy = get_global_id(1); // row

      for (uint gidy = 0; gidy < osize; gidy++) {
        float ret = 0.0;
        for (uint i = 0; i < msize; i++) {
          ret += x[gidx*msize+i]*y[i*osize+gidy];
          if (gidx==0 && gidy==0)
            printf("\\nmult: %.2f x %.2f - %.2f", x[gidx*msize+i],y[i*msize+gidy], ret);
        }

        //if (gidx==0&&gidy==0)
        //  printf("\\nsum:%.2f", ret);
        res[gidx * osize + gidy] = ret;
      }
    }""").build()

In [234]:
a.shape, b.shape

((32, 64), (64, 10))

In [235]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [236]:
rows = a.shape[0]

In [237]:
mult = mult.astype(np.float32)

In [238]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
knl = prg.matmul0  # Use this Kernel object for repeated calls
knl(queue, [rows], None, a_buf, b_buf2, res_buf, np.uint32(a.shape[1]), np.uint32(b.shape[1]))

res_np = np.zeros([rows,b.shape[1]]).astype(np.float32)
cl.enqueue_copy(queue, res_np, res_buf)


mult: 0.82 x 0.94 - 0.77
mult: 0.00 x 0.71 - 0.77
mult: 0.00 x 0.00 - 0.77
mult: 0.48 x 0.00 - 0.77
mult: 0.91 x 0.00 - 0.77
mult: 0.00 x 0.00 - 0.77
mult: 0.00 x 0.00 - 0.77
mult: 0.88 x 0.00 - 1.05
mult: 0.00 x 0.89 - 1.05
mult: 0.00 x 0.07 - 1.05
mult: 0.00 x 2.06 - 1.05
mult: 0.87 x 2.77 - 1.19
mult: 0.00 x 2.01 - 1.19
mult: 0.00 x 1.34 - 1.19
mult: 0.26 x 3.57 - 1.32
mult: 0.36 x -1.38 - 1.32
mult: 0.00 x 12.77 - 1.32
mult: 0.41 x 0.92 - 1.32
mult: 0.00 x 8.32 - 1.32
mult: 0.00 x 1.25 - 1.32
mult: 0.62 x 0.00 - 1.43
mult: 0.47 x -3.80 - 1.43
mult: 0.00 x 0.00 - 1.43
mult: 0.00 x 1.13 - 1.43
mult: 0.00 x 0.00 - 1.43
mult: 0.00 x -1.41 - 1.43
mult: 0.00 x 0.00 - 1.43
mult: 0.00 x -0.32 - 1.43
mult: 0.22 x 2.06 - 1.53
mult: 0.59 x 2.77 - 1.60
mult: 0.00 x 2.01 - 1.60
mult: 0.44 x 1.34 - 1.60
mult: 0.47 x 3.57 - 1.60
mult: 0.69 x 0.00 - 2.02
mult: 0.00 x 0.00 - 2.02
mult: 0.00 x 0.00 - 2.02
mult: 0.38 x 2.06 - 2.04
mult: 0.81 x 4.81 - 2.04
mult: 0.98 x 2.11 - 2.04
mult: 0.45 x 2.90 -

<pyopencl._cl.NannyEvent at 0x7ff1541a6180>

 2.06
mult: 0.00 x 2.59 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.29 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.04 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.28 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.42 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.39 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.23 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06
mult: 0.00 x 0.00 - 2.06

In [239]:
(res_np-mult).sum()

5.9604645e-07

In [240]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [241]:
b.T

array([[0.9414536 , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.31479523, 0.11005293, 0.06780045,
        0.        , 0.16214237, 0.        , 0.6445239 , 0.50232446,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.17604527, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.42811713, 0.11871076,
        0.        , 0.        , 0.        , 0.6127449 , 0.41137826,
        0.82947165, 0.04573116, 0.        , 0.        , 0.04850706,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.6346003 , 0.        , 0.        , 0.69351566, 0.9864582 ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.2653583 , 0.        , 0.2264393 ,
        0.10093832, 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.03283854, 0.        , 0.01996238,
        0.        , 0.2878976 , 0.40940058, 0.5290864 , 0.0

In [242]:
a[0]

array([0.82272685, 0.        , 0.        , 0.47684366, 0.9107656 ,
       0.        , 0.        , 0.87656933, 0.        , 0.        ,
       0.        , 0.87043756, 0.        , 0.        , 0.26380414,
       0.36056653, 0.        , 0.40566143, 0.        , 0.        ,
       0.62320876, 0.474961  , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.21526212, 0.5885546 ,
       0.        , 0.44464737, 0.47268388, 0.69147617, 0.        ,
       0.        , 0.38345355, 0.80597454, 0.97980535, 0.44897598,
       0.        , 0.        , 0.28649956, 0.        , 0.04318999,
       0.        , 0.2821837 , 0.        , 0.        , 0.        ,
       0.4212168 , 0.        , 0.        , 0.        , 0.39205793,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.22595333, 0.        , 0.        ], dtype=float32)

In [243]:
b.T[0]

array([0.9414536 , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.31479523, 0.11005293, 0.06780045,
       0.        , 0.16214237, 0.        , 0.6445239 , 0.50232446,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.17604527, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.42811713, 0.11871076,
       0.        , 0.        , 0.        , 0.6127449 , 0.41137826,
       0.82947165, 0.04573116, 0.        , 0.        , 0.04850706,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.6346003 , 0.        , 0.        , 0.69351566, 0.9864582 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.2653583 , 0.        , 0.2264393 ,
       0.10093832, 0.        , 0.        , 0.        ], dtype=float32)

In [244]:
res_buf

<pyopencl._cl.Buffer at 0x7ff1541a3bd0>

In [245]:
res_np

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [246]:
a.dot(b)

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [247]:
res_np==mult

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True, 

In [248]:
res_np.shape

(32, 10)

In [249]:
mult.shape

(32, 10)

## Weight update kernel

In [250]:
bs = 4

In [251]:
a

array([[0.82272685, 0.        , 0.        , ..., 0.22595333, 0.        ,
        0.        ],
       [0.57626116, 0.2289813 , 0.16428949, ..., 0.27008578, 0.        ,
        0.9652075 ],
       [0.        , 0.48130327, 0.        , ..., 0.        , 0.12290299,
        0.        ],
       ...,
       [0.8716895 , 0.10489579, 0.        , ..., 0.        , 0.39542586,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.5132175 , 0.9358725 ,
        0.37984058],
       [0.        , 0.6832084 , 0.9385753 , ..., 0.02876937, 0.        ,
        0.        ]], dtype=float32)

In [252]:
dim = 8

x = np.random.rand(bs,dim).astype(np.float32)
y = np.random.rand(bs,dim).astype(np.float32)
x.shape,y.shape, topk

((4, 8), (4, 8), 5)

x_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=y)
x_cp_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*topk*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate2(__global  float* x,     // INPUT MATRIX DATA
                             __global  float* y,    // INPUT
                             __global  float* xout,    // INPUT
                             uint topk,
                             __global  uint* xoutidx,    // INPUT
                             __global  uint* youtidx    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint n = get_global_size(0);
      uint bs = get_global_size(1);
      uint gid2 = get_global_id(1);

      uint idx = n*gid2+gid;

      float valx = x[idx];
      float valy = y[idx];
      uint posx = 0;
      uint posy = 0;
      for (uint i = 0; i < n; i++) {
        uint idx2 = n*gid2+i;
        float tempval = x[idx2];
        float tempval2 = y[idx2];
        bool larger = tempval > valx;
        bool larger2 = tempval2 > valy;

        barrier(CLK_GLOBAL_MEM_FENCE);
        posx += (larger)?1:0;
        posy += (larger2)?1:0;
        barrier(CLK_GLOBAL_MEM_FENCE);
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
      //printf("posx:%i", posx);
      if (posx < topk) {
        xoutidx[posx+topk*gid2] = gid;
      }
      if (posy < topk) {
        youtidx[posy+topk*gid2] = gid;
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
      if (gid < topk) {
        for (uint j=0; j<topk; j++) {
          float res = x[xoutidx[gid+topk*gid2]+gid2*n] * y[youtidx[j+topk*gid2]+gid2*n];
          //printf("\\nJ:%i  gid:%i", j, gid);
          //printf("\\nRES:%.2f - %i - %i -  %.2f - %.2f",res, xoutidx[gid+topk*gid2], youtidx[j+topk*gid2], x[xoutidx[gid+topk*gid2]+gid2*n], y[youtidx[j+topk*gid2]+gid2*n]);
          barrier(CLK_GLOBAL_MEM_FENCE);
          xout[gid2*topk*topk+j*topk+gid] = res;
          barrier(CLK_GLOBAL_MEM_FENCE);
          
        }
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
    }""").build()

In [253]:
x_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=y)
x_cp_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*topk*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
#x_cp_buft = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*topk*4)
#x_idx_buft = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
#y_idx_buft = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate2(__global  float* x,     // INPUT MATRIX DATA
                             __global  float* y,    // INPUT
                             __global  float* xout,    // INPUT
                             uint topk,
                             uint bs,
                             __global  uint* xoutidx,    // INPUT
                             __global  uint* youtidx    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint n = get_global_size(0);
      //uint bs = get_global_size(1);
      //uint gid2 = get_global_id(1);

      for (uint gid2=0; gid2<bs; gid2++){
        uint idx = n*gid2+gid;

        float valx = x[idx];
        float valy = y[idx];
        uint posx = 0;
        uint posy = 0;
        for (uint i = 0; i < n; i++) {
          uint idx2 = n*gid2+i;
          float tempval = x[idx2];
          float tempval2 = y[idx2];
          bool larger = tempval > valx;
          bool larger2 = tempval2 > valy;

          posx += (larger)?1:0;
          posy += (larger2)?1:0;
        }
        //printf("posx:%i", posx);
        if (posx < topk) {
        xoutidx[posx+topk*gid2] = gid;
        }
        if (posy < topk) {
          youtidx[posy+topk*gid2] = gid;
        }
      }
      for (uint gid2=0; gid2<bs; gid2++){
        if (gid < topk) {
          for (uint j=0; j<topk; j++) {
            float res = x[xoutidx[gid+topk*gid2]+gid2*n] * y[youtidx[j+topk*gid2]+gid2*n];
            //printf("\\nJ:%i  gid:%i", j, gid);
            //printf("\\nRES:%.2f - %i - %i -  %.2f - %.2f",res, xoutidx[gid+topk*gid2], youtidx[j+topk*gid2], x[xoutidx[gid+topk*gid2]+gid2*n], y[youtidx[j+topk*gid2]+gid2*n]);
            //barrier(CLK_GLOBAL_MEM_FENCE);
            xout[gid2*topk*topk+j*topk+gid] = res;
          }
        }
      }
    }""").build()

In [254]:
knl = prg.genwupdate2  # Use this Kernel object for repeated calls
evt = knl(queue, [dim], None, x_buf, y_buf, x_cp_buf, np.uint32(topk), np.uint32(bs), x_idx_buf, y_idx_buf)

#evt.wait()
resx = np.zeros(bs*topk*topk).astype(np.float32)
resxidx = np.zeros(bs*topk).astype(np.uint32)
resyidx = np.zeros(bs*topk).astype(np.uint32)

cl.enqueue_copy(queue, resx, x_cp_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)

<pyopencl._cl.NannyEvent at 0x7ff1541bfa90>

knl(queue, [dim], None, y_buf, x_buf, x_cp_buft, np.uint32(topk), np.uint32(bs), x_idx_buft, y_idx_buft)

#evt.wait()
resx = np.zeros(bs*topk*topk).astype(np.float32)
resxidx = np.zeros(bs*topk).astype(np.uint32)
resyidx = np.zeros(bs*topk).astype(np.uint32)

cl.enqueue_copy(queue, resx, x_cp_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)

In [255]:
x

array([[0.9823076 , 0.5675079 , 0.4899438 , 0.42544496, 0.21516374,
        0.0106969 , 0.73530084, 0.91274136],
       [0.87523365, 0.69470036, 0.227531  , 0.3580261 , 0.07031833,
        0.60395014, 0.20008752, 0.15327545],
       [0.6115412 , 0.8403613 , 0.06753407, 0.8783888 , 0.57819855,
        0.02901758, 0.3717504 , 0.12619837],
       [0.11297344, 0.17878655, 0.75727826, 0.9547149 , 0.78555155,
        0.08058742, 0.7716587 , 0.20285805]], dtype=float32)

In [256]:
y

array([[0.600012  , 0.51608175, 0.40237623, 0.6606863 , 0.20992446,
        0.47771522, 0.53684914, 0.6093566 ],
       [0.9005503 , 0.17203258, 0.48102152, 0.5812099 , 0.6695809 ,
        0.19274941, 0.44359848, 0.79617304],
       [0.4759505 , 0.08721701, 0.12529917, 0.08470217, 0.64269125,
        0.29062477, 0.52881134, 0.22973938],
       [0.39131835, 0.75165856, 0.67415816, 0.98627734, 0.688432  ,
        0.6221831 , 0.22869836, 0.81559634]], dtype=float32)

In [257]:
x.shape, y.shape

((4, 8), (4, 8))

In [258]:
resx

array([0.6489972 , 0.60303575, 0.4858032 , 0.37494472, 0.32369918,
       0.5985756 , 0.55618495, 0.4480604 , 0.34581468, 0.2985505 ,
       0.58939636, 0.54765576, 0.44118932, 0.34051156, 0.29397216,
       0.527351  , 0.49000442, 0.39474562, 0.30466613, 0.2630259 ,
       0.50695103, 0.47104916, 0.37947536, 0.29288048, 0.25285107,
       0.7881919 , 0.6256126 , 0.5438875 , 0.3224205 , 0.20490311,
       0.6968374 , 0.5531017 , 0.48084882, 0.28505072, 0.18115404,
       0.5860397 , 0.46515808, 0.40439346, 0.23972742, 0.15235041,
       0.50869447, 0.40376672, 0.3510218 , 0.20808831, 0.13224328,
       0.42100623, 0.3341658 , 0.290513  , 0.17221825, 0.10944731,
       0.5645328 , 0.5400929 , 0.3930322 , 0.37160316, 0.23892073,
       0.46450198, 0.4443926 , 0.32338992, 0.30575794, 0.19658583,
       0.4180696 , 0.39997038, 0.29106334, 0.2751939 , 0.1769348 ,
       0.25528154, 0.24422981, 0.17772903, 0.16803882, 0.10803988,
       0.20180051, 0.19306408, 0.1404951 , 0.13283499, 0.08540

In [259]:
resx.reshape(bs,topk,topk)

array([[[0.6489972 , 0.60303575, 0.4858032 , 0.37494472, 0.32369918],
        [0.5985756 , 0.55618495, 0.4480604 , 0.34581468, 0.2985505 ],
        [0.58939636, 0.54765576, 0.44118932, 0.34051156, 0.29397216],
        [0.527351  , 0.49000442, 0.39474562, 0.30466613, 0.2630259 ],
        [0.50695103, 0.47104916, 0.37947536, 0.29288048, 0.25285107]],

       [[0.7881919 , 0.6256126 , 0.5438875 , 0.3224205 , 0.20490311],
        [0.6968374 , 0.5531017 , 0.48084882, 0.28505072, 0.18115404],
        [0.5860397 , 0.46515808, 0.40439346, 0.23972742, 0.15235041],
        [0.50869447, 0.40376672, 0.3510218 , 0.20808831, 0.13224328],
        [0.42100623, 0.3341658 , 0.290513  , 0.17221825, 0.10944731]],

       [[0.5645328 , 0.5400929 , 0.3930322 , 0.37160316, 0.23892073],
        [0.46450198, 0.4443926 , 0.32338992, 0.30575794, 0.19658583],
        [0.4180696 , 0.39997038, 0.29106334, 0.2751939 , 0.1769348 ],
        [0.25528154, 0.24422981, 0.17772903, 0.16803882, 0.10803988],
        [0.20180

In [260]:
resxidx

array([0, 7, 6, 1, 2, 0, 1, 5, 3, 2, 3, 1, 0, 4, 6, 3, 4, 6, 2, 7],
      dtype=uint32)

In [261]:
resyidx

array([3, 7, 0, 6, 1, 0, 7, 4, 3, 2, 4, 6, 0, 5, 7, 3, 7, 1, 4, 2],
      dtype=uint32)

In [262]:
idx = 1
xy0 = x[idx].reshape(dim,1)*y[idx]
xy0.shape

(8, 8)

In [263]:
xy0[3][7]

0.28505072

## Weight update kernel new

In [264]:
b

array([[0.9414536 , 0.        , 0.        , 0.90892494, 0.        ,
        0.        , 0.05297933, 0.        , 0.7021168 , 0.        ],
       [0.        , 0.        , 0.14630048, 0.24825948, 0.        ,
        0.        , 0.        , 0.7792777 , 0.31232256, 0.        ],
       [0.        , 0.03283854, 0.27791733, 0.8543185 , 0.        ,
        0.        , 0.8433972 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33419287, 0.19840643, 0.        ,
        0.3856511 , 0.23348303, 0.        , 0.        , 0.        ],
       [0.        , 0.01996238, 0.9816661 , 0.        , 0.        ,
        0.        , 0.        , 0.5056456 , 0.        , 0.01917347],
       [0.        , 0.        , 0.        , 0.        , 0.49159616,
        0.0810385 , 0.64290637, 0.16045028, 0.        , 0.        ],
       [0.        , 0.2878976 , 0.        , 0.        , 0.7058473 ,
        0.        , 0.41645932, 0.44598925, 0.        , 0.        ],
       [0.31479523, 0.40940058, 0.       

In [265]:
c=np.zeros(a.T.shape)
at = a.T
for row in range(at.shape[0]):
    for col in range(at.shape[1]):
        c[row][col] = at[row][col]

In [266]:
at

array([[0.82272685, 0.57626116, 0.        , ..., 0.8716895 , 0.        ,
        0.        ],
       [0.        , 0.2289813 , 0.48130327, ..., 0.10489579, 0.        ,
        0.6832084 ],
       [0.        , 0.16428949, 0.        , ..., 0.        , 0.        ,
        0.9385753 ],
       ...,
       [0.22595333, 0.27008578, 0.        , ..., 0.        , 0.5132175 ,
        0.02876937],
       [0.        , 0.        , 0.12290299, ..., 0.39542586, 0.9358725 ,
        0.        ],
       [0.        , 0.9652075 , 0.        , ..., 0.        , 0.37984058,
        0.        ]], dtype=float32)

In [267]:
c

array([[0.82272685, 0.57626116, 0.        , ..., 0.8716895 , 0.        ,
        0.        ],
       [0.        , 0.2289813 , 0.48130327, ..., 0.10489579, 0.        ,
        0.68320841],
       [0.        , 0.16428949, 0.        , ..., 0.        , 0.        ,
        0.93857533],
       ...,
       [0.22595333, 0.27008578, 0.        , ..., 0.        , 0.51321751,
        0.02876937],
       [0.        , 0.        , 0.12290299, ..., 0.39542586, 0.9358725 ,
        0.        ],
       [0.        , 0.96520752, 0.        , ..., 0.        , 0.37984058,
        0.        ]])

a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)
x_cp_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topk*topk*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topk*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topk*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate3(__global  float* x,     // INPUT MATRIX DATA
                             __global  float* y,    // INPUT
                             uint topk,
                             uint msize,
                             __global  float* xout,    // INPUT
                             __global  uint* xoutidx,    // INPUT
                             __global  uint* youtidx    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint n = get_global_size(0);
      uint bs = get_global_size(1);
      uint gid2 = get_global_id(1);



      uint idx = n*gid2+gid;
      float valx = x[idx];
      uint posx = 0;
      for (uint i = 0; i < n; i++) {
        uint idx2 = n*gid2+i;
        float tempval = x[idx2];
        bool larger = tempval > valx;
        posx += (larger)?1:0;
      }
      
      uint idxy = n*gid2+gid;
      float valy = y[idx];
      uint posy = 0;
      for (uint i = 0; i < n; i++) {
        uint idx2 = n*gid2+i;
        float tempval2 = y[idx2];
        bool larger2 = tempval2 > valy;
        posy += (larger2)?1:0;
      }
      
      if (posx < topk) {
        xoutidx[posx+topk*gid2] = idx;
      }
      if (posy < topk) {
        youtidx[posy+topk*gid2] = idxy;
      }
      return;
      if (gid < topk) {
        for (uint j=0; j<topk; j++) {
          float res = x[xoutidx[gid+topk*gid2]+gid2*msize] * y[youtidx[j+topk*gid2]+gid2*msize];
          printf("\\nJ:%i  gid:(%i,%i)", j, gid, gid2);
          printf("\\nRES:%.2f - %i - %i -  %.2f - %.2f",res, xoutidx[gid+topk*gid2], youtidx[j+topk*gid2], x[xoutidx[gid+topk*gid2]+gid2*n], y[youtidx[j+topk*gid2]+gid2*n]);
          //barrier(CLK_GLOBAL_MEM_FENCE);
          xout[gid2*topk*topk+j*topk+gid] = res;
        }
      }
    }""").build()

In [268]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)
x_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, a.shape[0]*4)
y_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.shape[1]*4)
x_cp_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topk*topk*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topk*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topk*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate3(__global  float* x,     // INPUT MATRIX DATA
                              __global  float* y,    // INPUT
                              __global  float* xsum,    // INPUT
                              __global  float* ysum,    // INPUT
                              uint isize,
                              uint msize,
                              uint osize,
                              uint topk,
                              __global  float* xout,
                              __global  uint* xoutidx,   
                              __global  uint* youtidx    
                              ) { 
      uint gid = get_global_id(0);
      
      // get for a: sum axis0  b: sum axis1 then get topk
      ///////////////////////////////////////////////////
      if (gid < isize) {
        xsum[gid] = 0;
        for (uint i=0; i<msize; i++) {
          float val = x[i*isize+gid];
          //if (gid == 0) {
          //  printf("\\nADD VALx: %.2f - %i", val, i*msize+gid);
          //}
          xsum[gid] += val;
        }
        
        float valx = xsum[gid];
        uint posx = 0;
        for (uint i = 0; i < isize; i++) {
          float tempval = xsum[i];
          bool larger = tempval > valx;
          posx += (larger)?1:0;
        }
        if (posx < topk) {
          xoutidx[posx] = gid;
        }
      }
      
      if (gid < osize) {
        ysum[gid] = 0;
        for (uint i=0; i<msize; i++) {
          float val = y[i*osize+gid];
          //if (gid == 0) {
          //  printf("\\nADD VALx: %.2f - %i", val, gid*osize+i);
          //}
          ysum[gid] += val;
        }
        
        float valy = ysum[gid];
        uint posy = 0;
        for (uint i = 0; i < osize; i++) {
          float tempval = ysum[i];
          bool larger = tempval > valy;
          posy += (larger)?1:0;
        }
      
        if (posy < topk) {
          youtidx[posy] = gid;
        }
      }
      
      if (gid < topk) {
        float valx = xoutidx[gid];
        uint posx = 0;
        for (uint i = 0; i < topk; i++) {
          float tempval = xoutidx[i];
          bool larger = tempval < valx;
          posx += (larger)?1:0;
        }
        xoutidx[posx] = valx;
        
        float valy = youtidx[gid];
        uint posy = 0;
        for (uint i = 0; i < topk; i++) {
          float tempval = youtidx[i];
          bool larger = tempval < valy;
          posy += (larger)?1:0;
        }
        youtidx[posy] = valy;
      }
      
      // only calc matrix multiplications for used grads
      ///////////////////////////////////////////////////
      if (gid < topk) {
        uint idxx = xoutidx[gid];
        for (uint j=0; j<topk; j++) {
          uint idxy = youtidx[j];
          xout[j*topk+gid] = 0;
          for (uint k=0; k<msize; k++) {
            uint xidx2 = isize*k+idxx;
            uint yidx2 = osize*k+idxy;
            xout[j*topk+gid] += x[xidx2] * y[yidx2];
            //if (gid == 0 && j == 1)
            //  printf("\\n ADD VAL:%.2f,%.2f - (%i,%i) - (%i,%i,%i)", x[xidx2], y[yidx2], idxx, idxy, gid, j, k);
          }
        }
      }
    }""").build()

In [269]:
a.shape, b.shape

((32, 64), (64, 10))

In [270]:
rows = a.shape[0]
msize = a.shape[1]

In [271]:
cols = b.shape[1]

In [272]:
mult = a.dot(b)

In [273]:
mult = mult.astype(np.float32)

In [274]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
knl = prg.genwupdate3  # Use this Kernel object for repeated calls
evt = knl(queue, [max(rows,cols)], None, a_buf, b_buf, x_sum_buf, y_sum_buf, np.uint32(rows), np.uint32(msize),np.uint32(cols), np.uint32(topk), x_cp_buf, x_idx_buf, y_idx_buf)

resxsum = np.zeros(a.shape[0]).astype(np.float32)
resysum = np.zeros(b.shape[1]).astype(np.float32)
resx = np.zeros(topk*topk).astype(np.float32)
resxidx = np.zeros(topk).astype(np.uint32)
resyidx = np.zeros(topk).astype(np.uint32)

cl.enqueue_copy(queue, resxsum, x_sum_buf)
cl.enqueue_copy(queue, resysum, y_sum_buf)
cl.enqueue_copy(queue, resx, x_cp_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)

<pyopencl._cl.NannyEvent at 0x7ff15428a5e0>

In [275]:
resx.reshape(topk,topk)

array([[2.5565495, 2.9877534, 3.8442254, 2.4307942, 2.5962925],
       [3.714852 , 3.1018186, 5.2590375, 3.4558418, 1.9194646],
       [2.7812793, 3.6487942, 2.5204127, 1.5231056, 5.310211 ],
       [3.2777538, 2.6615262, 2.089722 , 3.3150718, 3.0769377],
       [2.2969038, 3.3801854, 2.0366   , 3.5051725, 1.7588413]],
      dtype=float32)

In [276]:
resxsum

array([12.76748 , 14.331857, 12.688714, 12.07191 , 13.300738, 12.708889,
       12.793668, 11.246282, 14.091072, 15.491286, 11.640156, 11.718621,
       12.305711, 13.083982, 13.09949 , 15.235169, 13.033628, 11.034429,
       12.793446, 11.830815, 13.510073, 11.045941, 11.79129 , 14.116871,
       11.174588, 13.510384, 12.001365, 13.343121, 12.215005, 11.441105,
       13.634347, 12.86991 ], dtype=float32)

In [277]:
resysum

array([ 8.32111 , 12.586111, 12.924564, 12.377018, 14.048852, 11.051554,
       14.963384, 14.75957 ,  7.215166, 10.108348], dtype=float32)

In [278]:
a.sum(axis=1)

array([12.76748  , 14.331858 , 12.688716 , 12.07191  , 13.300737 ,
       12.708889 , 12.793668 , 11.246283 , 14.09107  , 15.491286 ,
       11.640156 , 11.718621 , 12.30571  , 13.0839815, 13.099489 ,
       15.23517  , 13.03363  , 11.0344305, 12.7934475, 11.830814 ,
       13.510073 , 11.045944 , 11.791291 , 14.11687  , 11.174588 ,
       13.510384 , 12.001363 , 13.343121 , 12.215002 , 11.441103 ,
       13.634346 , 12.86991  ], dtype=float32)

In [279]:
b.sum(axis=0)

array([ 8.32111 , 12.586111, 12.924564, 12.377018, 14.048852, 11.051554,
       14.963384, 14.75957 ,  7.215166, 10.108348], dtype=float32)

In [280]:
mult

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [281]:
resxidx

array([ 1,  8,  9, 15, 23], dtype=uint32)

In [282]:
resyidx

array([1, 2, 4, 6, 7], dtype=uint32)

In [283]:
idx = 1
xy0 = x[idx].reshape(dim,1)*y[idx]
xy0.shape

(8, 8)

In [284]:
xy0[0][0]

0.7881919

## Weight update kernel new2 (sparse ouput)

In [285]:
c=np.zeros(a.T.shape)
at = a.T
for row in range(at.shape[0]):
    for col in range(at.shape[1]):
        c[row][col] = at[row][col]

In [286]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)
x_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, a.shape[0]*4)
y_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.shape[1]*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topkx*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topky*4)
sdata_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*topkx*4)
sidxs_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*topkx*4)
snnzs_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*4)
sdatat_buf = cl.Buffer(ctx, mf.READ_WRITE, b.shape[1]*topky*4)
sidxst_buf = cl.Buffer(ctx, mf.READ_WRITE, b.shape[1]*topky*4)
snnzst_buf = cl.Buffer(ctx, mf.READ_WRITE, b.shape[1]*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate4(__global  float* x,     // INPUT MATRIX DATA
                              __global  float* y,    // INPUT
                              __global  float* xsum,    // INPUT
                              __global  float* ysum,    // INPUT
                              uint isize,
                              uint msize,
                              uint osize,
                              uint topkx,
                              uint topky,
                              __global  uint*  xoutidx,
                              __global  uint*  youtidx,
                              __global  float* matData,     // OUTPUT MATRIX DATA
                              __global  uint*  colIdx,
                              __global  uint*  rowNnz,
                              __global  float* matDatat,    // OUTPUT MATRIX DATA
                              __global  uint*  colIdxt,
                              __global  uint*  rowNnzt
                              ) {
      uint gid = get_global_id(0);

      // get for a: sum axis0  b: sum axis1 then get topk
      ///////////////////////////////////////////////////
      if (gid < isize) {
        xsum[gid] = 0;
        for (uint i=0; i<msize; i++) {
          float val = x[i*isize+gid];
          //if (gid == 0) {
          //  printf("\\nADD VALx: %.2f - %i", val, i*msize+gid);
          //}
          xsum[gid] += val;
        }

        float valx = xsum[gid];
        uint posx = 0;
        for (uint i = 0; i < isize; i++) {
          float tempval = fabs(xsum[i]);
          bool larger = tempval > fabs(valx);
          posx += (larger)?1:0;
        }
        if (posx < topky) {
          youtidx[posx] = gid;
        }
      }

      if (gid < osize) {
        ysum[gid] = 0;
        for (uint i=0; i<msize; i++) {
          float val = y[i*osize+gid];
          //if (gid == 0) {
          //  printf("\\nADD VALx: %.2f - %i", val, gid*osize+i);
          //}
          ysum[gid] += val;
        }

        float valy = ysum[gid];
        uint posy = 0;
        for (uint i = 0; i < osize; i++) {
          float tempval = fabs(ysum[i]);
          bool larger = tempval > fabs(valy);
          posy += (larger)?1:0;
        }

        if (posy < topkx) {
          xoutidx[posy] = gid;
        }
      }

      if (gid < topkx) {
        float valx = xoutidx[gid];
        uint posx = 0;
        for (uint i = 0; i < topkx; i++) {
          float tempval = xoutidx[i];
          bool larger = tempval < valx;
          posx += (larger)?1:0;
        }
        xoutidx[gid] = gid;
      }

      if (gid < topky) {
        float valy = youtidx[gid];
        uint posy = 0;
        for (uint i = 0; i < topky; i++) {
          float tempval = youtidx[i];
          bool larger = tempval < valy;
          posy += (larger)?1:0;
        }
        youtidx[gid] = gid;
      }

      // only calc matrix multiplications for used grads
      ///////////////////////////////////////////////////
      if (gid < isize) {
        for (uint i=0; i<topkx; i++) {
          matData[gid*topkx+i] = 0;
          colIdx[gid*topkx+i] = 0;
        }
        rowNnz[gid] = 0;
      }
      if (gid < osize) {
        for (uint i=0; i<topky; i++) {
          matDatat[gid*topky+i] = 0;
          colIdxt[gid*topky+i] = 0;
        }
        rowNnzt[gid] = 0;
      }


      if (gid < topkx) {
        uint idxx = xoutidx[gid];
        for (uint j=0; j<topky; j++) {
          uint idxy = youtidx[j];
          //printf("\\nIDXX:%i  IDXY:%i", idxx, idxy);
          for (uint k=0; k<msize; k++) {
            uint xidx2 = isize*k+idxy;
            uint yidx2 = osize*k+idxx;
            uint colidx = idxy;
            matDatat[idxx*topky+j] += x[xidx2] * y[yidx2];
            colIdxt[idxx*topky+j] = idxy;
            if (gid == 0)
              printf("\\n ADD VAL:%.2f,%.2f - (%i,%i) - (%i,%i,%i)", x[xidx2], y[yidx2], idxx, idxy, gid, j, k);
          }
          rowNnzt[idxx] += 1;
        }
      }
      if (gid < topky) {
        uint idxx = youtidx[gid];
        for (uint j=0; j<topkx; j++) {
          uint idxy = xoutidx[j];
          //printf("\\nIDXX:%i  IDXY:%i", idxx, idxy);
          for (uint k=0; k<msize; k++) {
            uint xidx2 = isize*k+idxx;
            uint yidx2 = osize*k+idxy;
            uint colidx = idxy;
            matData[idxx*topkx+j] += x[xidx2] * y[yidx2];
            colIdx[idxx*topkx+j] = idxy;
            if (gid == 0)
              printf("\\n ADD VAL:%.2f,%.2f - (%i,%i) - (%i,%i,%i)", x[xidx2], y[yidx2], idxx, idxy, gid, j, k);
          }
          rowNnz[idxx] += 1;
        }
      }
    }""").build()

In [287]:
a.shape, b.shape

((32, 64), (64, 10))

In [288]:
rows = a.shape[0]
msize = a.shape[1]

In [289]:
cols = b.shape[1]

In [290]:
mult = a.dot(b)

In [291]:
mult = mult.astype(np.float32)

In [292]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
knl = prg.genwupdate4  # Use this Kernel object for repeate/duald calls
evt = knl(queue, [max(rows,cols)], None, a_buf, b_buf, x_sum_buf, y_sum_buf, np.uint32(rows), np.uint32(msize),np.uint32(cols), 
          np.uint32(topkx),np.uint32(topky), x_idx_buf, y_idx_buf, sdata_buf, sidxs_buf, snnzs_buf, sdatat_buf, sidxst_buf, snnzst_buf)


 ADD VAL:0.82,0.94 - (0,0) - (0,0,0)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,1)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,2)
 ADD VAL:0.48,0.00 - (0,0) - (0,0,3)
 ADD VAL:0.91,0.00 - (0,0) - (0,0,4)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,5)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,6)
 ADD VAL:0.88,0.31 - (0,0) - (0,0,7)
 ADD VAL:0.00,0.11 - (0,0) - (0,0,8)
 ADD VAL:0.00,0.07 - (0,0) - (0,0,9)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,10)
 ADD VAL:0.87,0.16 - (0,0) - (0,0,11)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,12)
 ADD VAL:0.00,0.64 - (0,0) - (0,0,13)
 ADD VAL:0.26,0.50 - (0,0) - (0,0,14)
 ADD VAL:0.36,0.00 - (0,0) - (0,0,15)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,16)
 ADD VAL:0.41,0.00 - (0,0) - (0,0,17)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,18)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,19)
 ADD VAL:0.62,0.18 - (0,0) - (0,0,20)
 ADD VAL:0.47,0.00 - (0,0) - (0,0,21)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,22)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,23)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,24)
 ADD VAL:0.00,0.00 - (0,0) - (0,0,25)
 ADD VAL:0.00,0.00 - 

In [293]:
resxsum = np.zeros(a.shape[0]).astype(np.float32)
resysum = np.zeros(b.shape[1]).astype(np.float32)
resxidx = np.zeros(topkx).astype(np.uint32)
resyidx = np.zeros(topky).astype(np.uint32)
resxdat = np.zeros(a.shape[0]*topkx).astype(np.float32)
resxcol = np.zeros(a.shape[0]*topkx).astype(np.uint32)
resxnnz = np.zeros(a.shape[0]).astype(np.uint32)
resxdatt = np.zeros(b.shape[1]*topky).astype(np.float32)
resxcolt = np.zeros(b.shape[1]*topky).astype(np.uint32)
resxnnzt = np.zeros(b.shape[1]).astype(np.uint32)

cl.enqueue_copy(queue, resxsum, x_sum_buf)
cl.enqueue_copy(queue, resysum, y_sum_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)
cl.enqueue_copy(queue, resxdat, sdata_buf)
cl.enqueue_copy(queue, resxcol, sidxs_buf)
cl.enqueue_copy(queue, resxnnz, snnzs_buf)
cl.enqueue_copy(queue, resxdatt, sdatat_buf)
cl.enqueue_copy(queue, resxcolt, sidxst_buf)
cl.enqueue_copy(queue, resxnnzt, snnzst_buf)


 ADD VAL:0.00,0.00 - (0,5) - (0,5,50)
 ADD VAL:0.32,0.00 - (0,5) - (0,5,51)
 ADD VAL:0.00,0.00 - (0,5) - (0,5,52)
 ADD VAL:0.73,0.00 - (0,5) - (0,5,53)
 ADD VAL:0.56,0.00 - (0,5) - (0,5,54)
 ADD VAL:0.00,0.00 - (0,5) - (0,5,55)
 ADD VAL:0.00,0.00 - (0,5) - (0,5,56)
 ADD VAL:0.06,0.27 - (0,5) - (0,5,57)
 ADD VAL:0.85,0.00 - (0,5) - (0,5,58)
 ADD VAL:0.01,0.23 - (0,5) - (0,5,59)
 ADD VAL:0.00,0.10 - (0,5) - (0,5,60)
 ADD VAL:0.60,0.00 - (0,5) - (0,5,61)
 ADD VAL:0.00,0.00 - (0,5) - (0,5,62)
 ADD VAL:0.00,0.00 - (0,5) - (0,5,63)
 ADD VAL:0.98,0.94 - (0,6) - (0,6,0)
 ADD VAL:0.58,0.00 - (0,6) - (0,6,1)
 ADD VAL:0.55,0.00 - (0,6) - (0,6,2)
 ADD VAL:0.00,0.00 - (0,6) - (0,6,3)
 ADD VAL:0.43,0.00 - (0,6) - (0,6,4)
 ADD VAL:0.00,0.00 - (0,6) - (0,6,5)
 ADD VAL:0.28,0.00 - (0,6) - (0,6,6)
 ADD VAL:0.00,0.31 - (0,6) - (0,6,7)
 ADD VAL:0.74,0.11 - (0,6) - (0,6,8)
 ADD VAL:0.00,0.07 - (0,6) - (0,6,9)
 ADD VAL:0.00,0.00 - (0,6) - (0,6,10)
 ADD VAL:0.00,0.16 - (0,6) - (0,6,11)
 ADD VAL:0.92,0.00 - 

<pyopencl._cl.NannyEvent at 0x7ff154171d10>

## results

In [294]:
topkx, topky

(5, 8)

In [295]:
mult.shape

(32, 10)

In [296]:
mult

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [297]:
resxdatt.reshape(b.shape[1],topky).T

array([[2.0588999, 1.8681664, 4.811817 , 2.2529387, 2.1096544, 0.       ,
        0.       , 0.       , 0.       , 0.       ],
       [2.732146 , 2.5565495, 3.714852 , 1.967403 , 2.7812793, 0.       ,
        0.       , 0.       , 0.       , 0.       ],
       [1.928745 , 3.1600957, 2.6502829, 1.8178726, 3.2470448, 0.       ,
        0.       , 0.       , 0.       , 0.       ],
       [1.063185 , 4.1737986, 2.2277255, 1.2515675, 4.1827664, 0.       ,
        0.       , 0.       , 0.       , 0.       ],
       [1.1833991, 1.8557596, 3.0534244, 1.0722461, 3.4905617, 0.       ,
        0.       , 0.       , 0.       , 0.       ],
       [1.0461621, 3.5566838, 3.2101297, 1.572406 , 2.2809083, 0.       ,
        0.       , 0.       , 0.       , 0.       ],
       [1.5166444, 2.5048141, 1.3417709, 3.5234869, 2.7652483, 0.       ,
        0.       , 0.       , 0.       , 0.       ],
       [1.6907904, 2.3283677, 1.7894993, 2.1847265, 2.8551002, 0.       ,
        0.       , 0.       , 0.     

In [298]:
resxcol.reshape(a.shape[0],topkx)

array([[0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]], dtype=uint32)

In [299]:
resxnnz.reshape(a.shape[0])

array([5, 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=uint32)

In [300]:
resxdatt.reshape(b.shape[1],topky)

array([[2.0588999, 2.732146 , 1.928745 , 1.063185 , 1.1833991, 1.0461621,
        1.5166444, 1.6907904],
       [1.8681664, 2.5565495, 3.1600957, 4.1737986, 1.8557596, 3.5566838,
        2.5048141, 2.3283677],
       [4.811817 , 3.714852 , 2.6502829, 2.2277255, 3.0534244, 3.2101297,
        1.3417709, 1.7894993],
       [2.2529387, 1.967403 , 1.8178726, 1.2515675, 1.0722461, 1.572406 ,
        3.5234869, 2.1847265],
       [2.1096544, 2.7812793, 3.2470448, 4.1827664, 3.4905617, 2.2809083,
        2.7652483, 2.8551002],
       [0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       ],
       [0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       ],
       [0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       ],
       [0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       ],
       [0.       , 0.       , 0.       , 0.       , 0. 

In [301]:
resxcolt.reshape(b.shape[1],topky)

array([[0, 1, 2, 3, 4, 5, 6, 7],
       [0, 1, 2, 3, 4, 5, 6, 7],
       [0, 1, 2, 3, 4, 5, 6, 7],
       [0, 1, 2, 3, 4, 5, 6, 7],
       [0, 1, 2, 3, 4, 5, 6, 7],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0]], dtype=uint32)

In [302]:
resxnnzt.reshape(b.shape[1])

array([8, 8, 8, 8, 8, 0, 0, 0, 0, 0], dtype=uint32)

In [303]:
resdense = to_dense(resxdat, resxcol, resxnnz, topkx, mult.shape)
resdense

array([[2.05889988, 1.86816645, 4.81181717, 2.25293875, 2.10965443,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [2.73214602, 2.55654955, 3.71485209, 1.96740305, 2.78127933,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [1.92874503, 3.16009569, 2.65028286, 1.81787264, 3.2470448 ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [1.06318498, 4.17379856, 2.22772551, 1.25156748, 4.18276644,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [1.18339908, 1.85575962, 3.05342436, 1.07224607, 3.49056172,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [1.04616213, 3.55668378, 3.21012974, 1.57240605, 2.28090835,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [1.51664436, 2.50481415, 1.34177089, 3.52348685, 2.7652483 ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [1.69079041, 2.32836771, 1.7894992

In [304]:
resdenset = to_dense(resxdatt, resxcolt, resxnnzt, topky, mult.T.shape)
resdenset

array([[2.05889988, 2.73214602, 1.92874503, 1.06318498, 1.18339908,
        1.04616213, 1.51664436, 1.69079041, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        ],
       [1.86816645, 2.55654955, 3.16009569, 4.17379856, 1.85575962,
        3.55668378, 2.50481415, 2.32836771, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        ],
       [4.81181717, 3.71485209, 2.65028286, 2.22772551, 3.05342436,
        3.21012974, 1.34177089, 1.78949928, 0.    

In [305]:
resdense == resdenset.T

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True, 

## comp

In [306]:
resxsum

array([12.76748 , 14.331857, 12.688714, 12.07191 , 13.300738, 12.708889,
       12.793668, 11.246282, 14.091072, 15.491286, 11.640156, 11.718621,
       12.305711, 13.083982, 13.09949 , 15.235169, 13.033628, 11.034429,
       12.793446, 11.830815, 13.510073, 11.045941, 11.79129 , 14.116871,
       11.174588, 13.510384, 12.001365, 13.343121, 12.215005, 11.441105,
       13.634347, 12.86991 ], dtype=float32)

In [307]:
a.sum(axis=1)

array([12.76748  , 14.331858 , 12.688716 , 12.07191  , 13.300737 ,
       12.708889 , 12.793668 , 11.246283 , 14.09107  , 15.491286 ,
       11.640156 , 11.718621 , 12.30571  , 13.0839815, 13.099489 ,
       15.23517  , 13.03363  , 11.0344305, 12.7934475, 11.830814 ,
       13.510073 , 11.045944 , 11.791291 , 14.11687  , 11.174588 ,
       13.510384 , 12.001363 , 13.343121 , 12.215002 , 11.441103 ,
       13.634346 , 12.86991  ], dtype=float32)

In [308]:
resysum

array([ 8.32111 , 12.586111, 12.924564, 12.377018, 14.048852, 11.051554,
       14.963384, 14.75957 ,  7.215166, 10.108348], dtype=float32)

In [309]:
b.sum(axis=0)

array([ 8.32111 , 12.586111, 12.924564, 12.377018, 14.048852, 11.051554,
       14.963384, 14.75957 ,  7.215166, 10.108348], dtype=float32)

In [310]:
mult

array([[2.0588999 , 1.8681664 , 4.811817  , 2.2529387 , 2.1096544 ,
        1.3379606 , 2.9000037 , 2.0491414 , 2.5897803 , 1.5031604 ],
       [2.732146  , 2.5565495 , 3.714852  , 1.967403  , 2.7812793 ,
        3.9070148 , 3.2777538 , 2.2969038 , 1.9229925 , 2.6556396 ],
       [1.928745  , 3.1600957 , 2.6502829 , 1.8178726 , 3.2470448 ,
        2.285428  , 2.6365926 , 2.5166864 , 1.070901  , 2.0605705 ],
       [1.063185  , 4.1737986 , 2.2277255 , 1.2515675 , 4.1827664 ,
        1.2915825 , 1.3805095 , 3.4198837 , 1.0256768 , 2.7755067 ],
       [1.1833991 , 1.8557596 , 3.0534244 , 1.0722461 , 3.4905617 ,
        1.643646  , 3.453313  , 3.808668  , 0.65044206, 2.0563726 ],
       [1.0461621 , 3.5566838 , 3.2101297 , 1.572406  , 2.2809083 ,
        1.8500901 , 1.3386623 , 4.39698   , 1.5621432 , 1.7725785 ],
       [1.5166444 , 2.5048141 , 1.3417709 , 3.5234869 , 2.7652483 ,
        1.1395355 , 2.25748   , 4.705172  , 2.2430854 , 1.8194562 ],
       [1.6907904 , 2.3283677 , 1.7894993

In [311]:
resxidx

array([0, 1, 2, 3, 4], dtype=uint32)

In [312]:
resyidx

array([0, 1, 2, 3, 4, 5, 6, 7], dtype=uint32)

In [313]:
asdf

NameError: name 'asdf' is not defined

## Weight update kernel new3 (sparse ouput)

In [None]:
c=np.zeros(a.T.shape)
at = a.T
for row in range(at.shape[0]):
    for col in range(at.shape[1]):
        c[row][col] = at[row][col]

In [None]:
a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=c.astype(np.float32))
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b)
x_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, a.shape[0]*4)
y_sum_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.shape[1]*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topkx*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, topky*4)
sdata_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*topkx*4)
sidxs_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*topkx*4)
snnzs_buf = cl.Buffer(ctx, mf.READ_WRITE, a.shape[0]*4)
sdatat_buf = cl.Buffer(ctx, mf.READ_WRITE, b.shape[1]*topky*4)
sidxst_buf = cl.Buffer(ctx, mf.READ_WRITE, b.shape[1]*topky*4)
snnzst_buf = cl.Buffer(ctx, mf.READ_WRITE, b.shape[1]*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate4(__global  float* x,     // INPUT MATRIX DATA
                              __global  float* y,    // INPUT
                              __global  float* xsum,    // INPUT
                              __global  float* ysum,    // INPUT
                              uint isize,
                              uint msize,
                              uint osize,
                              uint topkx,
                              uint topky,
                              __global  uint*  xoutidx,
                              __global  uint*  youtidx,
                              __global  float* matData,     // OUTPUT MATRIX DATA
                              __global  uint*  colIdx,
                              __global  uint*  rowNnz,
                              __global  float* matDatat,    // OUTPUT MATRIX DATA
                              __global  uint*  colIdxt,
                              __global  uint*  rowNnzt
                              ) {
      uint gid = get_global_id(0);

      // get for a: sum axis0  b: sum axis1 then get topk
      ///////////////////////////////////////////////////
      if (gid < isize) {
        xsum[gid] = 0;
        for (uint i=0; i<msize; i++) {
          float val = x[i*isize+gid];
          //if (gid == 0) {
          //  printf("\\nADD VALx: %.2f - %i", val, i*msize+gid);
          //}
          xsum[gid] += val;
        }

        float valx = xsum[gid];
        uint posx = 0;
        for (uint i = 0; i < isize; i++) {
          float tempval = fabs(xsum[i]);
          bool larger = tempval > fabs(valx);
          posx += (larger)?1:0;
        }
        if (posx < topky) {
          youtidx[posx] = gid;
        }
      }

      if (gid < osize) {
        ysum[gid] = 0;
        for (uint i=0; i<msize; i++) {
          float val = y[i*osize+gid];
          //if (gid == 0) {
          //  printf("\\nADD VALx: %.2f - %i", val, gid*osize+i);
          //}
          ysum[gid] += val;
        }

        float valy = ysum[gid];
        uint posy = 0;
        for (uint i = 0; i < osize; i++) {
          float tempval = fabs(ysum[i]);
          bool larger = tempval > fabs(valy);
          posy += (larger)?1:0;
        }

        if (posy < topkx) {
          xoutidx[posy] = gid;
        }
      }

      if (gid < topkx) {
        float valx = xoutidx[gid];
        uint posx = 0;
        for (uint i = 0; i < topkx; i++) {
          float tempval = xoutidx[i];
          bool larger = tempval < valx;
          posx += (larger)?1:0;
        }
        xoutidx[gid] = gid;
      }

      if (gid < topky) {
        float valy = youtidx[gid];
        uint posy = 0;
        for (uint i = 0; i < topky; i++) {
          float tempval = youtidx[i];
          bool larger = tempval < valy;
          posy += (larger)?1:0;
        }
        youtidx[gid] = gid;
      }

      // only calc matrix multiplications for used grads
      ///////////////////////////////////////////////////
      if (gid < isize) {
        for (uint i=0; i<topkx; i++) {
          matData[gid*topkx+i] = 0;
          colIdx[gid*topkx+i] = 0;
        }
        rowNnz[gid] = 0;
      }
      if (gid < osize) {
        for (uint i=0; i<topky; i++) {
          matDatat[gid*topky+i] = 0;
          colIdxt[gid*topky+i] = 0;
        }
        rowNnzt[gid] = 0;
      }


      if (gid < topkx) {
        uint idxx = xoutidx[gid];
        for (uint j=0; j<topky; j++) {
          uint idxy = youtidx[j];
          //printf("\\nIDXX:%i  IDXY:%i", idxx, idxy);
          for (uint k=0; k<msize; k++) {
            uint xidx2 = isize*k+idxy;
            uint yidx2 = osize*k+idxx;
            uint colidx = idxy;
            matDatat[idxx*topky+j] += x[xidx2] * y[yidx2];
            colIdxt[idxx*topky+j] = idxy;
            if (gid == 0)
              printf("\\n ADD VAL:%.2f,%.2f - (%i,%i) - (%i,%i,%i)", x[xidx2], y[yidx2], idxx, idxy, gid, j, k);
          }
          rowNnzt[idxx] += 1;
        }
      }
      if (gid < topky) {
        uint idxx = youtidx[gid];
        for (uint j=0; j<topkx; j++) {
          uint idxy = xoutidx[j];
          //printf("\\nIDXX:%i  IDXY:%i", idxx, idxy);
          for (uint k=0; k<msize; k++) {
            uint xidx2 = isize*k+idxx;
            uint yidx2 = osize*k+idxy;
            uint colidx = idxy;
            matData[idxx*topkx+j] += x[xidx2] * y[yidx2];
            colIdx[idxx*topkx+j] = idxy;
            if (gid == 0)
              printf("\\n ADD VAL:%.2f,%.2f - (%i,%i) - (%i,%i,%i)", x[xidx2], y[yidx2], idxx, idxy, gid, j, k);
          }
          rowNnz[idxx] += 1;
        }
      }
    }""").build()

In [None]:
a.shape, b.shape

In [None]:
rows = a.shape[0]
msize = a.shape[1]

In [None]:
cols = b.shape[1]

In [None]:
mult = a.dot(b)

In [None]:
mult = mult.astype(np.float32)

In [None]:
res_buf = cl.Buffer(ctx, mf.READ_WRITE, np.prod([rows,b.shape[1]])*4)
knl = prg.genwupdate4  # Use this Kernel object for repeate/duald calls
evt = knl(queue, [max(rows,cols)], None, a_buf, b_buf, x_sum_buf, y_sum_buf, np.uint32(rows), np.uint32(msize),np.uint32(cols), 
          np.uint32(topkx),np.uint32(topky), x_idx_buf, y_idx_buf, sdata_buf, sidxs_buf, snnzs_buf, sdatat_buf, sidxst_buf, snnzst_buf)

In [None]:
resxsum = np.zeros(a.shape[0]).astype(np.float32)
resysum = np.zeros(b.shape[1]).astype(np.float32)
resxidx = np.zeros(topkx).astype(np.uint32)
resyidx = np.zeros(topky).astype(np.uint32)
resxdat = np.zeros(a.shape[0]*topkx).astype(np.float32)
resxcol = np.zeros(a.shape[0]*topkx).astype(np.uint32)
resxnnz = np.zeros(a.shape[0]).astype(np.uint32)
resxdatt = np.zeros(b.shape[1]*topky).astype(np.float32)
resxcolt = np.zeros(b.shape[1]*topky).astype(np.uint32)
resxnnzt = np.zeros(b.shape[1]).astype(np.uint32)

cl.enqueue_copy(queue, resxsum, x_sum_buf)
cl.enqueue_copy(queue, resysum, y_sum_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)
cl.enqueue_copy(queue, resxdat, sdata_buf)
cl.enqueue_copy(queue, resxcol, sidxs_buf)
cl.enqueue_copy(queue, resxnnz, snnzs_buf)
cl.enqueue_copy(queue, resxdatt, sdatat_buf)
cl.enqueue_copy(queue, resxcolt, sidxst_buf)
cl.enqueue_copy(queue, resxnnzt, snnzst_buf)

## results

In [None]:
topkx, topky

In [None]:
mult.shape

In [None]:
mult

In [None]:
resxdatt.reshape(b.shape[1],topky).T

In [None]:
resxcol.reshape(a.shape[0],topkx)

In [None]:
resxnnz.reshape(a.shape[0])

In [None]:
resxdatt.reshape(b.shape[1],topky)

In [None]:
resxcolt.reshape(b.shape[1],topky)

In [None]:
resxnnzt.reshape(b.shape[1])

In [None]:
resdense = to_dense(resxdat, resxcol, resxnnz, topkx, mult.shape)
resdense

In [None]:
resdenset = to_dense(resxdatt, resxcolt, resxnnzt, topky, mult.T.shape)
resdenset

In [None]:
resdense == resdenset.T

## comp

In [None]:
resxsum

In [None]:
a.sum(axis=1)

In [None]:
resysum

In [None]:
b.sum(axis=0)

In [None]:
mult

In [None]:
resxidx

In [None]:
resyidx

In [None]:
asdf

## Prune Weights

In [None]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)

prg = cl.Program(ctx, """
    // prunes weights smaller than a constant C
    __kernel void prune(__global  float* matData,     // INPUT MATRIX DATA
                        __global  uint*  colIdx,
                        __global  uint*  rowNnz,
                        uint ellw,
                        float pruneval) { 
      uint gid = get_global_id(0);
      
      uint nnzs = rowNnz[gid];
      for (uint i=0; i<nnzs; i++) {
        uint idx = ellw * gid + i;
        float val = matData[idx];
        printf("\\nDATA:%.2f - %.2f", matData[idx], pruneval);
        if(fabs(val)<pruneval) {
          printf("\\nPRUNE(%i): %.2f", gid, matData[idx]);
          for (uint j=i; j<=nnzs-1; j++) {
            uint idx2 = ellw * gid + j;
            matData[idx2] = matData[idx2+1];
            colIdx[idx2] = colIdx[idx2+1];
          }
          matData[ellw*gid+nnzs] = 0;
          colIdx[ellw*gid+nnzs] = 0;
          rowNnz[gid] -= 1;
          nnzs = rowNnz[gid];
        }
      }
    }""").build()

In [None]:
a.shape

In [None]:
rows = a.shape[0]
cols = a.shape[1]

pruneval = .35

In [None]:
knl = prg.prune  # Use this Kernel object for repeated calls
evt = knl(queue, [rows,], None, adata_buf, acols_buf, annzs_buf, np.uint32(ellwa), np.float32(pruneval))

In [None]:
resxdat = np.zeros(adata.shape).astype(np.float32)
resxcol = np.zeros(acols.shape).astype(np.uint32)
resxnnz = np.zeros(annz.shape).astype(np.uint32)

cl.enqueue_copy(queue, resxdat, adata_buf)
cl.enqueue_copy(queue, resxcol, acols_buf)
cl.enqueue_copy(queue, resxnnz, annzs_buf)

In [None]:
a

In [None]:
adata.reshape((4,-1))

In [None]:
acols.reshape((4,-1))

In [None]:
resxdat.reshape((4,-1))

In [None]:
resxcol.reshape((4,-1))

In [None]:
resxnnz

## results

In [None]:
mult.T

In [None]:
resxdat.reshape(a.shape[0],topk)

In [None]:
resxcol.reshape(a.shape[0],topk)

In [None]:
resxnnz.reshape(a.shape[0])

In [None]:
resxdatt.reshape(b.shape[1],topk)

In [None]:
resxcolt.reshape(b.shape[1],topk)

In [None]:
resxnnzt.reshape(b.shape[1])

### Update Vals (add sparse)

In [None]:
b.shape

In [None]:
randadd = np.random.rand(*b.shape)

In [None]:
randdata, randcols, randnnz, randellw = to_data(randadd)
bdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bdata)
bcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bcols)
bnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=bnnz)
randdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=randdata)
randcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=randcols)
randnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=randnnz)


prg = cl.Program(ctx, """
    // Every global_id_0 works on a row
    __kernel void adddense(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            float  lr,
                            uint   ellwidth,
                            __global  float* matDataAdd,     // INPUT MATRIX DATA
                            __global  uint*  colIdxAdd,
                            __global  uint*  rowNnzAdd,
                            uint ellwidthAdd
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);

      uint nnz    = rowNnz[gid];

      uint baseidxs = gid*ellwidth;
      uint baseidxd = gid*ellwidthAdd;

      uint nnzadd = rowNnzAdd[gid];
      
      uint m = 0;
      for (uint i=0; i<nnzadd; i++) {
        float addval = matDataAdd[baseidxd+i] * lr;
        uint addcol = colIdxAdd[baseidxd+i];
        
        if (addval == 0.0) {
          //printf("\\nZERO VAL, CONT: %.2f - %i", addval, gid);
          continue;
        }

        uint refcol = colIdx[baseidxs+i];
        m = 0;
        while (refcol < addcol && (i+m) < nnz) {
          m += 1;
          refcol = colIdx[baseidxs+i+m];
        }

        //if (gid == 0)
        //  printf("\\nADD VAL:%.2f  ADDCOL:%i  ref:(%i)  gid/i/m:(%i/%i%i)", addval, addcol, refcol, gid,i,m);
        
        if (addcol == refcol) {
          matData[baseidxs+i+m] += addval;
          //if (gid == 0)
          //  printf("\\nINCREMENT: %.2f",addval);
          continue;
        } else {
          //if (gid == 0)
          //  printf("\\nADD: %.2f %i-%i",addval, addcol, refcol);
          if (rowNnz[gid] >= ellwidth) {
            break;
          }

          for (uint j=nnz; j>i+m; j--) {
            //printf("\\nMOVE:%.2f", matData[baseidx+j-1]);
            colIdx[baseidxs+j] = colIdx[baseidxs+j-1];
            matData[baseidxs+j] = matData[baseidxs+j-1];
          }
          rowNnz[gid] += 1;
          nnz = rowNnz[gid];
        
          //if (gid == 0)
          //  printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, addcol);
          matData[baseidxs+i+m] = addval;
          colIdx[baseidxs+i+m] = addcol;
        }
      }
    }""").build()

In [None]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [None]:
rows = b.shape[0]

In [None]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.adddense  # Use this Kernel object for repeated calls
knl(queue, [rows], None, bdata_buf, bcols_buf, bnnzs_buf, np.float32(1), np.uint32(ellwb), 
    randdata_buf, randcols_buf, randnnzs_buf, np.uint32(randellw))

In [None]:
data_res = np.empty_like(bdata)
cols_res = np.empty_like(bcols)
nnzs_res = np.empty_like(bnnz)
cl.enqueue_copy(queue, data_res, bdata_buf, is_blocking=True)
cl.enqueue_copy(queue, cols_res, bcols_buf, is_blocking=True)
cl.enqueue_copy(queue, nnzs_res, bnnzs_buf, is_blocking=True)

In [None]:
b

In [None]:
bcols

In [None]:
data_res

In [None]:
cols_res

In [None]:
nnzs_res

In [None]:
randadd

In [None]:
adenseadd = to_dense(data_res, cols_res, nnzs_res, ellwb, b.shape)
adenseadd

In [None]:
baseadd = (b+randadd)
baseadd

In [None]:
adenseadd - baseadd

In [None]:
(adenseadd - baseadd).sum()

### Update Vals (add sparset)

In [None]:
multt=np.zeros(mult.T.shape)

for row in range(multt.shape[0]):
    for col in range(multt.shape[1]):
        multt[row][col] = mult[col][row]

In [None]:
multdata, multcols, multnnz, multellw = to_data(multt)
multdata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=multdata)
multcols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=multcols)
multnnzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=multnnz)

In [None]:
a.shape, b.shape

In [None]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [None]:
rows = mult.T.shape[0]

In [None]:
mult = mult.astype(np.float32)

In [None]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.adddense  # Use this Kernel object for repeated calls
knl(queue, [rows], None, multdata_buf, multcols_buf, multnnzs_buf, np.float32(1), np.uint32(multellw), 
    sdatat_buf, sidxst_buf, snnzst_buf, np.uint32(topk))

In [None]:
mult.T

In [None]:
data_res = np.empty_like(multdata)
cols_res = np.empty_like(multcols)
nnzs_res = np.empty_like(multnnz)
cl.enqueue_copy(queue, data_res, multdata_buf, is_blocking=True)
cl.enqueue_copy(queue, cols_res, multcols_buf, is_blocking=True)
cl.enqueue_copy(queue, nnzs_res, multnnzs_buf, is_blocking=True)

In [None]:
multt-data_res.reshape(multt.shape)

In [None]:
nnzs_res

In [None]:
adenseaddt = to_dense(data_res, cols_res, nnzs_res, multellw, multt.shape)
adenseaddt

In [None]:
multt-adenseaddt

In [None]:
adenseaddt

In [None]:
adenseadd.T == adenseaddt

### Update Vals (add topk to sparse)

In [None]:
matadd = np.random.randn(*a.shape).astype(np.float32)
matadd

In [None]:
a

In [None]:
a_added = a + matadd

In [None]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)
add_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=matadd)

prg = cl.Program(ctx, """
    // Every global_id_0 works on a row
    __kernel void adddense(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            float  lr,
                            uint   ellwidth,
                            uint   awidth,
                            __global  float* vector_x    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);

      uint nnz    = rowNnz[gid];
      uint baseidxs = gid*ellwidth;
      uint baseidxd = gid*awidth;
      
      for (uint i=0; i<awidth; i++) {
        float addval = vector_x[baseidxd+i];
        //if (gid==1)
        //  printf("\\nADD VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[baseidxs+i]);
        if (addval == 0) {
          continue;
        }
        if (i == colIdx[baseidxs+i]) {
          matData[baseidxs+i] += addval;
        } else {
          if (rowNnz[gid] >= ellwidth) {
            break;
          }
          if (i > colIdx[baseidxs+i]) {
            rowNnz[gid] += 1;
            //if (gid==1)
            //  printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
            matData[baseidxs+i] = addval;
            colIdx[baseidxs+i] = i;
            continue;
          }
          for (uint j=nnz; j>i; j--) {
            //printf("\\nMOVE:%.2f", matData[baseidx+j-1]);
            colIdx[baseidxs+j] = colIdx[baseidxs+j-1];
            matData[baseidxs+j] = matData[baseidxs+j-1];
          }
          rowNnz[gid] += 1;
          nnz = rowNnz[gid];
          //if (gid==1)
          //  printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
          matData[baseidxs+i] = addval;
          colIdx[baseidxs+i] = i;
          if (nnz >= ellwidth)
            break;
        }
      }
    }""").build()

In [None]:
a.shape, b.shape

In [None]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [None]:
rows = a.shape[0]

In [None]:
mult = mult.astype(np.float32)

In [None]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.adddense  # Use this Kernel object for repeated calls
knl(queue, [rows], None, adata_buf, acols_buf, annzs_buf, np.float32(1), np.uint32(ellwa),np.uint32(a.shape[1]), add_buf)

In [None]:
matadd[0][0]

In [None]:
data_res = np.empty_like(adata)
cols_res = np.empty_like(acols)
nnzs_res = np.empty_like(annz)
cl.enqueue_copy(queue, data_res, adata_buf)
cl.enqueue_copy(queue, cols_res, acols_buf)
cl.enqueue_copy(queue, nnzs_res, annzs_buf)

In [None]:
adenseadd = to_dense(data_res, cols_res, nnzs_res, ellwa, a.shape)
adenseadd

In [None]:
a

In [None]:
matadd

In [None]:
a_added

In [None]:
adenseadd

In [None]:
adenseadd == a_added

### update vals

In [None]:
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)
add_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=matadd)

prg = cl.Program(ctx, """
    // Every global_id_0 works on a row
    __kernel void adddenset(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            float  lr,
                            uint   ellwidth,
                            uint   aheight,
                            __global  float* vector_x    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint ncols = get_global_size(0);

      uint nnz    = rowNnz[gid];
      uint baseidxs = gid*ellwidth;
      
      for (uint i=0; i<aheight; i++) {
        if (nnz > ellwidth)
            break;
        uint baseidxd = i*ncols+gid;
        float addval = vector_x[baseidxd];
        //if (gid==1)
        //  printf("\\nADD VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[baseidxs+i]);
        if (addval == 0) {
          continue;
        }
        if (i == colIdx[baseidxs+i]) {
          printf("\\nADD VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
          matData[baseidxs+i] += addval;
        } else {
          if (rowNnz[gid] >= ellwidth) {
            break;
          }
          if (i > colIdx[baseidxs+i]) {
            rowNnz[gid] += 1;
            //if (gid==1)
            //  printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
            matData[baseidxs+i] = addval;
            colIdx[baseidxs+i] = i;
            continue;
          }
          for (uint j=nnz; j>i; j--) {
            //printf("\\nMOVE:%.2f", matData[baseidx+j-1]);
            colIdx[baseidxs+j] = colIdx[baseidxs+j-1];
            matData[baseidxs+j] = matData[baseidxs+j-1];
          }
          rowNnz[gid] += 1;
          nnz = rowNnz[gid];
          //if (gid==1)
          //  printf("\\nSET VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
          matData[baseidxs+i] = addval;
          colIdx[baseidxs+i] = i;
        }
      }
    }""").build()

In [None]:
a.shape, b.shape

In [None]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [None]:
cols = a.shape[1]

In [None]:
mult = mult.astype(np.float32)

In [None]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.adddenset  # Use this Kernel object for repeated calls
knl(queue, [cols], None, adatat_buf, acolst_buf, annzst_buf, np.float32(1), np.uint32(ellwat),np.uint32(a.T.shape[1]), add_buf)

In [None]:
matadd[0][0]

In [None]:
datat_res = np.empty_like(adatat)
colst_res = np.empty_like(acolst)
nnzst_res = np.empty_like(annzt)
cl.enqueue_copy(queue, datat_res, adatat_buf)
cl.enqueue_copy(queue, colst_res, acolst_buf)
cl.enqueue_copy(queue, nnzst_res, annzst_buf)

In [None]:
adenseaddt = to_dense(datat_res, colst_res, nnzst_res, ellwat, a.T.shape).T
adenseaddt

In [None]:
a

In [None]:
matadd

In [None]:
a_added

In [None]:
adenseaddt == a_added

### Make Random

In [None]:
rand = SparseTensor.uniform(2,4)
rand

In [None]:
rand.to_numpy()

In [None]:
rand.data

### update vals

In [None]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)

In [None]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint gid2 = get_global_id(1);
      uint topk = get_global_size(0);
      uint bs = get_global_size(1);
      uint baseupdateidx = topk*topk*gid2;
      uint baseidxidx = topk*gid2;
      uint col = updateyidx[baseidxidx+gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint row = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [None]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint topk = get_global_size(0);
      uint col = updateyidx[gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint row = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];m
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [None]:
knl = prg.addvals  # Use this Kernel object for repeated calls
knl(queue, [topk,1], None, adata_buf, acols_buf, annzs_buf, np.float32(1), np.uint32(ellwa), x_cp_buf, x_idx_buf, y_idx_buf)

resa = np.empty_like(adata)
resaidx = np.zeros(acols.shape).astype(np.uint32)
resannz = np.zeros(annz.shape).astype(np.uint32)

cl.enqueue_copy(queue, resa, adata_buf)
cl.enqueue_copy(queue, resaidx, acols_buf)
cl.enqueue_copy(queue, resannz, annzs_buf)

In [None]:
resa.shape, resaidx.shape, resannz.shape, ellwa, a.T.shape

In [None]:
adenseadd = to_dense(resa, resaidx, resannz, ellwa, a.shape)
adenseadd

In [None]:
adenseadd - adense

In [None]:
adenseadd == adense

In [None]:
ellwa

In [None]:
adata2 = adata.reshape(-1, ellwa)
adata2

In [None]:
resa = resa.reshape(-1, ellwa)
resa

In [None]:
resa - adata2

In [None]:
acols

In [None]:
resaidx

In [None]:
resannz

In [None]:
annz

### update vals2

In [None]:
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)

In [None]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint gid2 = get_global_id(1);
      uint topk = get_global_size(0);
      uint bs = get_global_size(1);
      uint baseupdateidx = topk*topk*gid2;
      uint baseidxidx = topk*gid2;
      uint row = updateyidx[baseidxidx+gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint col = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [None]:
knl = prg.addvals  # Use this Kernel object for repeated calls
knl(queue, [topk,bs], None, adatat_buf, acolst_buf, annzst_buf, np.float32(1), np.uint32(ellwat), x_cp_buf, x_idx_buf, y_idx_buf)

resat = np.empty_like(adatat)
resaidxt = np.zeros(acolst.shape).astype(np.uint32)
resannzt = np.zeros(annzt.shape).astype(np.uint32)

cl.enqueue_copy(queue, resat, adatat_buf)
cl.enqueue_copy(queue, resaidxt, acolst_buf)
cl.enqueue_copy(queue, resannzt, annzst_buf)

In [None]:
ellwa

In [None]:
resat.shape, resaidxt.shape, resannzt.shape

In [None]:
adenseaddt = to_dense(resat, resaidxt, resannzt, ellwat, a.T.shape)
adenseaddt

In [None]:
adenseadd == adenseaddt.T

In [None]:
adata2t = adatat.reshape(-1, ellwat)
adata2t

In [None]:
resat = resat.reshape(-1, ellwat)
resat

In [None]:
resat - adata2t

In [None]:
acols

In [None]:
resaidx

In [None]:
resannz

In [None]:
annz

# OTHER

import numpy as np
import pyopencl as cl

mf = cl.mem_flags

dim = 16
topk = 4

x = np.random.rand(dim).astype(np.float32)
y = np.random.rand(dim).astype(np.float32)
x.shape,y.shape

dim1 = 4
dim2 = 8
dim3 = 1

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx,
        properties=cl.command_queue_properties.PROFILING_ENABLE)

sparsity = 0.2

a = np.zeros((dim1,dim2))
b = np.random.rand(dim2,dim3).flatten().astype(np.float32)

a.shape, b.shape

In [None]:
x_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=y)
val_out_buf = cl.Buffer(ctx, mf.READ_WRITE, 4*topk*topk)
x_idx_buf = cl.Buffer(ctx, mf.READ_WRITE, topk*4)
y_idx_buf = cl.Buffer(ctx, mf.READ_WRITE, topk*4)

prg = cl.Program(ctx, """
// Every global_id_0 works on a row
__kernel void genwupdate2(__global  float* x,     // INPUT MATRIX DATA
                         __global  float* y,    // INPUT
                         __global  float* xout,    // INPUT
                         uint topk,
                         __global  uint* xoutidx,    // INPUT
                         __global  uint* youtidx    // INPUT
                        ) { // LOCAL SHARED BUFFER
  uint gid = get_global_id(0);
  uint n = get_global_size(0);
  
  xout[gid] = x[gid];
  xoutidx[gid] = gid;
  youtidx[gid] = gid;
  
  float valx = x[gid];
  float valy = y[gid];
  uint posx = 0;
  uint posy = 0;
  for (uint i = 0; i < n; i++) {
    float tempval = x[i];
    float tempval2 = y[i];
    bool larger = tempval > valx;
    bool larger2 = tempval2 > valy;
      
    posx += (larger)?1:0;
    posy += (larger2)?1:0;
  }
  //printf("posx:%i", posx);
  if (posx < topk) {
    xoutidx[posx] = gid;
  }
  if (posy < topk) {
    youtidx[posy] = gid;
  }
  if (gid < topk) {
    uint i = gid;
    for (uint j=0; j<topk; j++) {
      xout[gid*topk+j] = x[xoutidx[gid]] * y[youtidx[j]];
    }
  }
}""").build()

In [None]:
knl = prg.genwupdate2  # Use this Kernel object for repeated calls
event = knl(queue, [dim,], None, x_buf, y_buf, val_out_buf, np.uint32(topk), x_idx_buf, y_idx_buf)

#event.wait()
val_out = np.zeros(topk*topk).astype(np.float32)
resxidx = np.zeros(topk).astype(np.uint32)
resyidx = np.zeros(topk).astype(np.uint32)

cl.enqueue_copy(queue, val_out, val_out_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf, wait_for=[event])
cl.enqueue_copy(queue, resyidx, y_idx_buf)

In [None]:
val_out

In [None]:
resxidx

In [None]:
resyidx

In [None]:
asdf

In [None]:
from __future__ import division

KERNEL_CODE = """
// Thread block size
#define BLOCK_SIZE %(block_size)d
// Matrix dimensions
// (chosen as multiples of the thread block size for simplicity)
#define WA %(w_a)d // Matrix A width
#define HA %(h_a)d // Matrix A height
#define WB %(w_b)d // Matrix B width
#define HB WA  // Matrix B height
#define WC WB  // Matrix C width
#define HC HA  // Matrix C height
/*
 * Copyright 1993-2009 NVIDIA Corporation.  All rights reserved.
 *
 * NVIDIA Corporation and its licensors retain all intellectual property and
 * proprietary rights in and to this software and related documentation.
 * Any use, reproduction, disclosure, or distribution of this software
 * and related documentation without an express license agreement from
 * NVIDIA Corporation is strictly prohibited.
 *
 * Please refer to the applicable NVIDIA end user license agreement (EULA)
 * associated with this source code for terms and conditions that govern
 * your use of this NVIDIA software.
 *
 */
/* Matrix multiplication: C = A * B.
 * Device code.
 */
#define AS(j, i) As[i + j * BLOCK_SIZE]
#define BS(j, i) Bs[i + j * BLOCK_SIZE]
////////////////////////////////////////////////////////////////////////////////
//! Matrix multiplication on the device: C = A * B
//! WA is A's width and WB is B's width
////////////////////////////////////////////////////////////////////////////////
__kernel __attribute__((reqd_work_group_size(16,16,1))) 
void
matrixMul( __global float* C, __global float* A, __global float* B)
{
    __local float As[BLOCK_SIZE*BLOCK_SIZE];
    __local float Bs[BLOCK_SIZE*BLOCK_SIZE];
    // Block index
    int bx = get_group_id(0);
    int by = get_group_id(1);
    // Thread index
    int tx = get_local_id(0);
    int ty = get_local_id(1);
    // Index of the first sub-matrix of A processed by the block
    int aBegin = WA * BLOCK_SIZE * by;
    // Index of the last sub-matrix of A processed by the block
    int aEnd   = aBegin + WA - 1;
    // Step size used to iterate through the sub-matrices of A
    int aStep  = BLOCK_SIZE;
    // Index of the first sub-matrix of B processed by the block
    int bBegin = BLOCK_SIZE * bx;
    // Step size used to iterate through the sub-matrices of B
    int bStep  = BLOCK_SIZE * WB;
    // Csub is used to store the element of the block sub-matrix
    // that is computed by the thread
    float Csub = 0.0f;
    // Loop over all the sub-matrices of A and B
    // required to compute the block sub-matrix
    for (int a = aBegin, b = bBegin;
             a <= aEnd;
             a += aStep, b += bStep) {
        // Load the matrices from device memory
        // to shared memory; each thread loads
        // one element of each matrix
        AS(ty, tx) = A[a + WA * ty + tx];
        BS(ty, tx) = B[b + WB * ty + tx];
        // Synchronize to make sure the matrices are loaded
        barrier(CLK_LOCAL_MEM_FENCE);
        // Multiply the two matrices together;
        // each thread computes one element
        // of the block sub-matrix
        for (int k = 0; k < BLOCK_SIZE; ++k)
            Csub += AS(ty, k) * BS(k, tx);
        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    // Write the block sub-matrix to device memory;
    // each thread writes one element
    C[get_global_id(1) * get_global_size(0) + get_global_id(0)] = Csub;
}
"""


In [314]:
a2 = np.random.rand(4,4)

In [315]:
a2

array([[0.37732439, 0.3744614 , 0.19944962, 0.88976546],
       [0.78897049, 0.22014409, 0.80173979, 0.12590939],
       [0.95648153, 0.70116229, 0.62388817, 0.76697481],
       [0.13811882, 0.434799  , 0.22015705, 0.31398815]])

In [316]:
a2.sum(axis=1)

array([1.84100087, 1.93676375, 3.04850681, 1.10706301])

In [317]:
b2 = np.random.rand(4,4)

In [318]:
b2

array([[0.31342742, 0.37951072, 0.37931032, 0.08043915],
       [0.41802246, 0.4204377 , 0.05660451, 0.74130855],
       [0.90380912, 0.83411823, 0.68734703, 0.8587971 ],
       [0.35209215, 0.94245315, 0.63776965, 0.94046972]])

In [319]:
b2.sum(axis=0)

array([1.98735114, 2.5765198 , 1.76103151, 2.62101452])

In [320]:
matmul = a2.dot(b2)
matmul

array([[0.7683409 , 1.30556316, 0.86887575, 1.31602731],
       [1.1082616 , 1.17938911, 0.94310044, 1.03360458],
       [1.42681075, 1.90102439, 1.32047321, 1.8538261 ],
       [0.53457868, 0.71477959, 0.42857788, 0.81779695]])