In [1]:
from tinygrad.densetensor import DenseTensor
from tinygrad.sparsetensor import SparseTensor
import numpy as np

%load_ext autoreload
%autoreload 2

DEVICE:GPU


In [2]:
x_init = np.random.randn(1,3).astype(np.float32)
x2_init = np.random.randn(3).astype(np.float32)
U_init = np.random.randn(3,3).astype(np.float32)
V_init = np.random.randn(3,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)

In [3]:
x = DenseTensor(x_init)
W = DenseTensor(W_init)
m = DenseTensor(m_init)
out = x.dot(W).relu()
out = out.logsoftmax()
out = out.mul(m).add(m).sum()
out.backward()

out.cpu().data, x

(array([-3.1604867], dtype=float32),
 <DenseTensor <GPUBuffer with shape (1, 3)> with grad <GPUBuffer with shape (1, 3)>>)

x2 = DenseTensor(x2_init)#.gpu()
W = SparseTensor(W_init)
out = W.dot(x2).relu().sum()

out.backward()

out.cpu().data, x

In [4]:
import numpy as np
import pyopencl as cl

mf = cl.mem_flags

In [5]:
dim1 = 8
dim2 = 8
dim3 = 8
bs = 4

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx,
        properties=cl.command_queue_properties.PROFILING_ENABLE)

sparsity = 0.2

a = np.zeros((dim1,dim2))
b = np.random.rand(4,dim3).astype(np.float32)

a.shape, b.shape

((8, 8), (4, 8))

In [6]:
def fill_sparse(mat, sparsity=0.1):
    indices = np.array(range(mat.shape[1]))
    nrows = int(mat.shape[1]*sparsity)
    for row in range(mat.shape[0]):
        lim = nrows #+ int(np.random.random()*3)
        mat[row][np.random.permutation(indices)[:lim]] = np.random.random(lim)
    return mat

a = fill_sparse(a, 0.9).astype(np.float32)
#b = fill_sparse(b, sparsity)

In [7]:
a

array([[0.16961966, 0.5964418 , 0.5171895 , 0.74336094, 0.48799506,
        0.32877994, 0.        , 0.7093185 ],
       [0.39712888, 0.510253  , 0.        , 0.64409375, 0.67125344,
        0.57573944, 0.46538112, 0.3581028 ],
       [0.13251007, 0.9755912 , 0.23229332, 0.32986483, 0.07646387,
        0.44998506, 0.        , 0.3866979 ],
       [0.0481182 , 0.08422235, 0.42488804, 0.        , 0.5962461 ,
        0.3434182 , 0.10857241, 0.74565107],
       [0.        , 0.5541187 , 0.24228822, 0.63623714, 0.4712802 ,
        0.5582224 , 0.23590986, 0.18378867],
       [0.08376146, 0.10682349, 0.46714184, 0.53649616, 0.17192142,
        0.5313124 , 0.        , 0.8377917 ],
       [0.77677315, 0.3933695 , 0.48234862, 0.31966466, 0.83224964,
        0.20107165, 0.29474154, 0.        ],
       [0.19884802, 0.        , 0.62628335, 0.02330758, 0.38636607,
        0.5560139 , 0.7838986 , 0.30145356]], dtype=float32)

In [8]:
b

array([[0.5758158 , 0.83606327, 0.8954073 , 0.89344084, 0.58846396,
        0.0793886 , 0.7087831 , 0.37386844],
       [0.18805529, 0.97089976, 0.01458754, 0.43385914, 0.71578693,
        0.9443288 , 0.45588118, 0.6926704 ],
       [0.94750524, 0.04013106, 0.63719   , 0.95751846, 0.6438051 ,
        0.6866626 , 0.32253084, 0.69727325],
       [0.44304255, 0.3809702 , 0.26089582, 0.30254954, 0.62685245,
        0.03873152, 0.86441535, 0.6877234 ]], dtype=float32)

In [9]:
x2_init.T

array([-1.9369417 , -0.40373948, -1.6533662 ], dtype=float32)

In [10]:
mult = a.dot(b.T)
mult.shape

(8, 4)

In [11]:
mult.shape

(8, 4)

In [12]:
def to_data(mat):
    all_rows = []
    all_idxs = []
    all_nnzs = []
    for row in range(mat.shape[0]):
        rowdata = []
        colidxs = []
        all_nnzs.append(0)
        for col in range(mat.shape[1]):
            val = mat[row][col]
            if val != 0:
                rowdata.append(val)
                colidxs.append(col)
                all_nnzs[-1] += 1
        all_rows.append(rowdata)
        all_idxs.append(colidxs)
    
    ellwidth = int(np.sqrt(np.max(all_nnzs))+1)**2
    #all_rows = np.array(all_rows)#.astype(np.float32).flatten()           
    for row in range(mat.shape[0]):
        #print(row, all_rows)
        all_rows[row] = np.array(all_rows[row])
        all_rows[row].resize(ellwidth)
        all_idxs[row] = np.array(all_idxs[row])
        all_idxs[row].resize(ellwidth)
        print(all_idxs[row])
    all_rows = np.array(all_rows).astype(np.float32).flatten()
    print(all_idxs)
    all_idxs = np.array(all_idxs).astype(np.uint32).flatten()
    all_nnzs = np.array(all_nnzs).astype(np.uint32)
    
    return all_rows, all_idxs, all_nnzs, ellwidth

In [13]:
def to_dense(data, cols, nnzs, ellw, shape):
    out = np.zeros(shape)
    for row in range(shape[0]):
        for icol in range(nnzs[row]):
            out[row,cols[row*ellw+icol]] = data[row*ellw+icol]
    return out

In [14]:
adata, acols, annz, ellwa = to_data(a)
adata, acols, annz, ellwa

[0 1 2 3 4 5 7 0 0]
[0 1 3 4 5 6 7 0 0]
[0 1 2 3 4 5 7 0 0]
[0 1 2 4 5 6 7 0 0]
[1 2 3 4 5 6 7 0 0]
[0 1 2 3 4 5 7 0 0]
[0 1 2 3 4 5 6 0 0]
[0 2 3 4 5 6 7 0 0]
[array([0, 1, 2, 3, 4, 5, 7, 0, 0]), array([0, 1, 3, 4, 5, 6, 7, 0, 0]), array([0, 1, 2, 3, 4, 5, 7, 0, 0]), array([0, 1, 2, 4, 5, 6, 7, 0, 0]), array([1, 2, 3, 4, 5, 6, 7, 0, 0]), array([0, 1, 2, 3, 4, 5, 7, 0, 0]), array([0, 1, 2, 3, 4, 5, 6, 0, 0]), array([0, 2, 3, 4, 5, 6, 7, 0, 0])]


(array([0.16961966, 0.5964418 , 0.5171895 , 0.74336094, 0.48799506,
        0.32877994, 0.7093185 , 0.        , 0.        , 0.39712888,
        0.510253  , 0.64409375, 0.67125344, 0.57573944, 0.46538112,
        0.3581028 , 0.        , 0.        , 0.13251007, 0.9755912 ,
        0.23229332, 0.32986483, 0.07646387, 0.44998506, 0.3866979 ,
        0.        , 0.        , 0.0481182 , 0.08422235, 0.42488804,
        0.5962461 , 0.3434182 , 0.10857241, 0.74565107, 0.        ,
        0.        , 0.5541187 , 0.24228822, 0.63623714, 0.4712802 ,
        0.5582224 , 0.23590986, 0.18378867, 0.        , 0.        ,
        0.08376146, 0.10682349, 0.46714184, 0.53649616, 0.17192142,
        0.5313124 , 0.8377917 , 0.        , 0.        , 0.77677315,
        0.3933695 , 0.48234862, 0.31966466, 0.83224964, 0.20107165,
        0.29474154, 0.        , 0.        , 0.19884802, 0.62628335,
        0.02330758, 0.38636607, 0.5560139 , 0.7838986 , 0.30145356,
        0.        , 0.        ], dtype=float32),

In [15]:
adatat, acolst, annzt, ellwat = to_data(a.T)
adatat, acolst, annzt, ellwat

[0 1 2 3 5 6 7 0 0]
[0 1 2 3 4 5 6 0 0]
[0 2 3 4 5 6 7 0 0]
[0 1 2 4 5 6 7 0 0]
[0 1 2 3 4 5 6 7 0]
[0 1 2 3 4 5 6 7 0]
[1 3 4 6 7 0 0 0 0]
[0 1 2 3 4 5 7 0 0]
[array([0, 1, 2, 3, 5, 6, 7, 0, 0]), array([0, 1, 2, 3, 4, 5, 6, 0, 0]), array([0, 2, 3, 4, 5, 6, 7, 0, 0]), array([0, 1, 2, 4, 5, 6, 7, 0, 0]), array([0, 1, 2, 3, 4, 5, 6, 7, 0]), array([0, 1, 2, 3, 4, 5, 6, 7, 0]), array([1, 3, 4, 6, 7, 0, 0, 0, 0]), array([0, 1, 2, 3, 4, 5, 7, 0, 0])]


(array([0.16961966, 0.39712888, 0.13251007, 0.0481182 , 0.08376146,
        0.77677315, 0.19884802, 0.        , 0.        , 0.5964418 ,
        0.510253  , 0.9755912 , 0.08422235, 0.5541187 , 0.10682349,
        0.3933695 , 0.        , 0.        , 0.5171895 , 0.23229332,
        0.42488804, 0.24228822, 0.46714184, 0.48234862, 0.62628335,
        0.        , 0.        , 0.74336094, 0.64409375, 0.32986483,
        0.63623714, 0.53649616, 0.31966466, 0.02330758, 0.        ,
        0.        , 0.48799506, 0.67125344, 0.07646387, 0.5962461 ,
        0.4712802 , 0.17192142, 0.83224964, 0.38636607, 0.        ,
        0.32877994, 0.57573944, 0.44998506, 0.3434182 , 0.5582224 ,
        0.5313124 , 0.20107165, 0.5560139 , 0.        , 0.46538112,
        0.10857241, 0.23590986, 0.29474154, 0.7838986 , 0.        ,
        0.        , 0.        , 0.        , 0.7093185 , 0.3581028 ,
        0.3866979 , 0.74565107, 0.18378867, 0.8377917 , 0.30145356,
        0.        , 0.        ], dtype=float32),

In [128]:
adense = to_dense(adata, acols, annz, ellwa, a.shape)

In [136]:
adenset = to_dense(adatat, acolst, annzt, ellwat, a.T.shape)

In [137]:
adense

array([[0.16961966, 0.59644181, 0.5171895 , 0.74336094, 0.48799506,
        0.32877994, 0.        , 0.70931852],
       [0.39712888, 0.51025301, 0.        , 0.64409375, 0.67125344,
        0.57573944, 0.46538112, 0.3581028 ],
       [0.13251007, 0.97559118, 0.23229332, 0.32986483, 0.07646387,
        0.44998506, 0.        , 0.38669789],
       [0.0481182 , 0.08422235, 0.42488804, 0.        , 0.59624612,
        0.34341821, 0.10857241, 0.74565107],
       [0.        , 0.55411869, 0.24228822, 0.63623714, 0.47128019,
        0.55822241, 0.23590986, 0.18378867],
       [0.08376146, 0.10682349, 0.46714184, 0.53649616, 0.17192142,
        0.53131241, 0.        , 0.83779168],
       [0.77677315, 0.3933695 , 0.48234862, 0.31966466, 0.83224964,
        0.20107165, 0.29474154, 0.        ],
       [0.19884802, 0.        , 0.62628335, 0.02330758, 0.38636607,
        0.55601388, 0.78389859, 0.30145356]])

In [138]:
adenset.T == adense

array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])

In [20]:
a

array([[0.16961966, 0.5964418 , 0.5171895 , 0.74336094, 0.48799506,
        0.32877994, 0.        , 0.7093185 ],
       [0.39712888, 0.510253  , 0.        , 0.64409375, 0.67125344,
        0.57573944, 0.46538112, 0.3581028 ],
       [0.13251007, 0.9755912 , 0.23229332, 0.32986483, 0.07646387,
        0.44998506, 0.        , 0.3866979 ],
       [0.0481182 , 0.08422235, 0.42488804, 0.        , 0.5962461 ,
        0.3434182 , 0.10857241, 0.74565107],
       [0.        , 0.5541187 , 0.24228822, 0.63623714, 0.4712802 ,
        0.5582224 , 0.23590986, 0.18378867],
       [0.08376146, 0.10682349, 0.46714184, 0.53649616, 0.17192142,
        0.5313124 , 0.        , 0.8377917 ],
       [0.77677315, 0.3933695 , 0.48234862, 0.31966466, 0.83224964,
        0.20107165, 0.29474154, 0.        ],
       [0.19884802, 0.        , 0.62628335, 0.02330758, 0.38636607,
        0.5560139 , 0.7838986 , 0.30145356]], dtype=float32)

In [21]:
a == adense

array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])

In [22]:
a.shape

(8, 8)

In [23]:
adata.shape, acols.shape, annz.shape, ellwa

((72,), (72,), (8,), 9)

In [24]:
#acols = acols.astype(np.uint32)
#annz = annz.astype(np.uint32)

In [25]:
adata, acols, annz, b

(array([0.16961966, 0.5964418 , 0.5171895 , 0.74336094, 0.48799506,
        0.32877994, 0.7093185 , 0.        , 0.        , 0.39712888,
        0.510253  , 0.64409375, 0.67125344, 0.57573944, 0.46538112,
        0.3581028 , 0.        , 0.        , 0.13251007, 0.9755912 ,
        0.23229332, 0.32986483, 0.07646387, 0.44998506, 0.3866979 ,
        0.        , 0.        , 0.0481182 , 0.08422235, 0.42488804,
        0.5962461 , 0.3434182 , 0.10857241, 0.74565107, 0.        ,
        0.        , 0.5541187 , 0.24228822, 0.63623714, 0.4712802 ,
        0.5582224 , 0.23590986, 0.18378867, 0.        , 0.        ,
        0.08376146, 0.10682349, 0.46714184, 0.53649616, 0.17192142,
        0.5313124 , 0.8377917 , 0.        , 0.        , 0.77677315,
        0.3933695 , 0.48234862, 0.31966466, 0.83224964, 0.20107165,
        0.29474154, 0.        , 0.        , 0.19884802, 0.62628335,
        0.02330758, 0.38636607, 0.5560139 , 0.7838986 , 0.30145356,
        0.        , 0.        ], dtype=float32),

In [26]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b.T)

prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void matmul(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            __global  float* vector_x,    // INPUT
                            __global  float* vector_y    // OUTPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);
      uint gid2 = get_global_id(1);

      uint nnz    = rowNnz[gid];
      uint baseidx = gid2*nrows;
      float sum = 0;
      for (uint i = 0; i < nnz; i++) {
        uint index   = (gid * ellwidth) + i;
        uint col     = colIdx[index];
        float aval  = matData[index];
        float xval  = vector_x[baseidx+col];
        //printf("aval, xval: %.2f,%.2f: (%i,%i) \\n", aval, xval, col, index);
        sum  += aval * xval;
      }
      printf("SUM/NNZ: %.2f %i \\n", sum, nnz);
      vector_y[baseidx+gid] = sum;
    }""").build()

In [27]:
a.shape, b.shape

((8, 8), (4, 8))

In [28]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [29]:
rows = a.shape[0]

In [30]:
mult = mult.astype(np.float32)

In [31]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.matmul  # Use this Kernel object for repeated calls
knl(queue, [rows,bs], None, adata_buf, acols_buf, annzs_buf, np.uint32(ellwa), b_buf, res_buf)

res_np = np.empty_like(b.T)
cl.enqueue_copy(queue, res_np, res_buf)

<pyopencl._cl.NannyEvent at 0x7f2860330450>

SUM/NNZ: 2.30 7 
SUM/NNZ: 2.14 7 
SUM/NNZ: 1.62 7 
SUM/NNZ: 1.21 7 
SUM/NNZ: 1.81 7 
SUM/NNZ: 1.49 7 
SUM/NNZ: 2.21 7 
SUM/NNZ: 1.64 7 
SUM/NNZ: 2.09 7 
SUM/NNZ: 2.33 7 
SUM/NNZ: 1.87 7 
SUM/NNZ: 1.41 7 
SUM/NNZ: 1.92 7 
SUM/NNZ: 1.56 7 
SUM/NNZ: 1.59 7 
SUM/NNZ: 1.42 7 
SUM/NNZ: 2.26 7 
SUM/NNZ: 2.24 7 
SUM/NNZ: 1.26 7 
SUM/NNZ: 1.49 7 
SUM/NNZ: 1.68 7 
SUM/NNZ: 1.95 7 
SUM/NNZ: 2.13 7 
SUM/NNZ: 1.70 7 
SUM/NNZ: 1.47 7 
SUM/NNZ: 1.66 7 
SUM/NNZ: 0.92 7 
SUM/NNZ: 1.16 7 
SUM/NNZ: 1.11 7 
SUM/NNZ: 1.07 7 
SUM/NNZ: 1.50 7 
SUM/NNZ: 1.41 7 


In [32]:
res_buf

<pyopencl._cl.Buffer at 0x7f2860329400>

In [33]:
res_np

array([[2.3020377, 2.092142 , 2.2605045, 1.4686613],
       [2.1351898, 2.3339   , 2.2407806, 1.6568408],
       [1.6199633, 1.8661438, 1.256422 , 0.9220849],
       [1.2124329, 1.4140898, 1.4943258, 1.157969 ],
       [1.8062348, 1.9168991, 1.6767919, 1.1141719],
       [1.4917258, 1.564151 , 1.9547005, 1.0665153],
       [2.2082796, 1.5936826, 2.1341536, 1.5008274],
       [1.6359242, 1.4244308, 1.7033539, 1.407205 ]], dtype=float32)

In [34]:
mult

array([[2.3020377, 2.092142 , 2.2605045, 1.4686613],
       [2.1351898, 2.3339   , 2.2407806, 1.6568408],
       [1.6199633, 1.8661438, 1.256422 , 0.9220849],
       [1.2124329, 1.4140898, 1.4943258, 1.157969 ],
       [1.8062348, 1.9168991, 1.6767919, 1.1141719],
       [1.4917258, 1.564151 , 1.9547005, 1.0665153],
       [2.2082796, 1.5936826, 2.1341536, 1.5008274],
       [1.6359242, 1.4244308, 1.7033539, 1.407205 ]], dtype=float32)

In [35]:
res_np.shape

(8, 4)

In [36]:
mult.shape

(8, 4)

In [37]:
(res_np-mult).sum()

0.0

## Weight update kernel

In [38]:
bs = 4

In [39]:
dim = 8
topk = 2

x = np.random.rand(bs,dim).astype(np.float32)
y = np.random.rand(bs,dim).astype(np.float32)
x.shape,y.shape, topk

((4, 8), (4, 8), 2)

x_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=y)
x_cp_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*topk*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate2(__global  float* x,     // INPUT MATRIX DATA
                             __global  float* y,    // INPUT
                             __global  float* xout,    // INPUT
                             uint topk,
                             __global  uint* xoutidx,    // INPUT
                             __global  uint* youtidx    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint n = get_global_size(0);
      uint bs = get_global_size(1);
      uint gid2 = get_global_id(1);

      uint idx = n*gid2+gid;

      float valx = x[idx];
      float valy = y[idx];
      uint posx = 0;
      uint posy = 0;
      for (uint i = 0; i < n; i++) {
        uint idx2 = n*gid2+i;
        float tempval = x[idx2];
        float tempval2 = y[idx2];
        bool larger = tempval > valx;
        bool larger2 = tempval2 > valy;

        barrier(CLK_GLOBAL_MEM_FENCE);
        posx += (larger)?1:0;
        posy += (larger2)?1:0;
        barrier(CLK_GLOBAL_MEM_FENCE);
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
      //printf("posx:%i", posx);
      if (posx < topk) {
        xoutidx[posx+topk*gid2] = gid;
      }
      if (posy < topk) {
        youtidx[posy+topk*gid2] = gid;
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
      if (gid < topk) {
        for (uint j=0; j<topk; j++) {
          float res = x[xoutidx[gid+topk*gid2]+gid2*n] * y[youtidx[j+topk*gid2]+gid2*n];
          //printf("\\nJ:%i  gid:%i", j, gid);
          //printf("\\nRES:%.2f - %i - %i -  %.2f - %.2f",res, xoutidx[gid+topk*gid2], youtidx[j+topk*gid2], x[xoutidx[gid+topk*gid2]+gid2*n], y[youtidx[j+topk*gid2]+gid2*n]);
          barrier(CLK_GLOBAL_MEM_FENCE);
          xout[gid2*topk*topk+j*topk+gid] = res;
          barrier(CLK_GLOBAL_MEM_FENCE);
          
        }
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
    }""").build()

In [40]:
x_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=y)
x_cp_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*topk*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
#x_cp_buft = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*topk*4)
#x_idx_buft = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
#y_idx_buft = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate2(__global  float* x,     // INPUT MATRIX DATA
                             __global  float* y,    // INPUT
                             __global  float* xout,    // INPUT
                             uint topk,
                             uint bs,
                             __global  uint* xoutidx,    // INPUT
                             __global  uint* youtidx    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint n = get_global_size(0);
      //uint bs = get_global_size(1);
      //uint gid2 = get_global_id(1);

      for (uint gid2=0; gid2<bs; gid2++){
        uint idx = n*gid2+gid;

        float valx = x[idx];
        float valy = y[idx];
        uint posx = 0;
        uint posy = 0;
        for (uint i = 0; i < n; i++) {
          uint idx2 = n*gid2+i;
          float tempval = x[idx2];
          float tempval2 = y[idx2];
          bool larger = tempval > valx;
          bool larger2 = tempval2 > valy;

          posx += (larger)?1:0;
          posy += (larger2)?1:0;
        }
        //printf("posx:%i", posx);
        if (posx < topk) {
        xoutidx[posx+topk*gid2] = gid;
        }
        if (posy < topk) {
          youtidx[posy+topk*gid2] = gid;
        }
      }
      for (uint gid2=0; gid2<bs; gid2++){
        if (gid < topk) {
          for (uint j=0; j<topk; j++) {
            float res = x[xoutidx[gid+topk*gid2]+gid2*n] * y[youtidx[j+topk*gid2]+gid2*n];
            //printf("\\nJ:%i  gid:%i", j, gid);
            //printf("\\nRES:%.2f - %i - %i -  %.2f - %.2f",res, xoutidx[gid+topk*gid2], youtidx[j+topk*gid2], x[xoutidx[gid+topk*gid2]+gid2*n], y[youtidx[j+topk*gid2]+gid2*n]);
            //barrier(CLK_GLOBAL_MEM_FENCE);
            xout[gid2*topk*topk+j*topk+gid] = res;
          }
        }
      }
    }""").build()

In [41]:
knl = prg.genwupdate2  # Use this Kernel object for repeated calls
evt = knl(queue, [dim], None, x_buf, y_buf, x_cp_buf, np.uint32(topk), np.uint32(bs), x_idx_buf, y_idx_buf)

#evt.wait()
resx = np.zeros(bs*topk*topk).astype(np.float32)
resxidx = np.zeros(bs*topk).astype(np.uint32)
resyidx = np.zeros(bs*topk).astype(np.uint32)

cl.enqueue_copy(queue, resx, x_cp_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)

<pyopencl._cl.NannyEvent at 0x7f2860344810>

knl(queue, [dim], None, y_buf, x_buf, x_cp_buft, np.uint32(topk), np.uint32(bs), x_idx_buft, y_idx_buft)

#evt.wait()
resx = np.zeros(bs*topk*topk).astype(np.float32)
resxidx = np.zeros(bs*topk).astype(np.uint32)
resyidx = np.zeros(bs*topk).astype(np.uint32)

cl.enqueue_copy(queue, resx, x_cp_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)

In [42]:
x

array([[0.61572224, 0.86994046, 0.1817479 , 0.13391666, 0.8094523 ,
        0.95450836, 0.63647723, 0.9892399 ],
       [0.188922  , 0.5824302 , 0.5276773 , 0.24744661, 0.49230814,
        0.64255136, 0.6712292 , 0.58414316],
       [0.9054736 , 0.702004  , 0.3757838 , 0.09323785, 0.43336618,
        0.07780908, 0.45926866, 0.5590805 ],
       [0.47826648, 0.67794526, 0.63452655, 0.48684815, 0.33555534,
        0.46694845, 0.7139826 , 0.5869166 ]], dtype=float32)

In [43]:
y

array([[0.9814869 , 0.26549807, 0.577078  , 0.82610697, 0.03952895,
        0.27001873, 0.80651575, 0.6598663 ],
       [0.8499231 , 0.02502976, 0.03331728, 0.02704707, 0.55400455,
        0.82277703, 0.22432825, 0.5286772 ],
       [0.87205654, 0.23948503, 0.29893118, 0.86796916, 0.88664794,
        0.8474162 , 0.378136  , 0.60928994],
       [0.34143072, 0.5386484 , 0.41787806, 0.91316026, 0.2801959 ,
        0.8082875 , 0.96954644, 0.16627187]], dtype=float32)

In [44]:
x.shape, y.shape

((4, 8), (4, 8))

In [45]:
resx

array([0.970926  , 0.9368375 , 0.81721795, 0.788526  , 0.57049316,
       0.5461192 , 0.55227196, 0.5286765 , 0.8028363 , 0.6224304 ,
       0.78962415, 0.6121872 , 0.6922393 , 0.6572994 , 0.6519805 ,
       0.6190727 ], dtype=float32)

In [46]:
resx.reshape(bs,topk,topk)

array([[[0.970926  , 0.9368375 ],
        [0.81721795, 0.788526  ]],

       [[0.57049316, 0.5461192 ],
        [0.55227196, 0.5286765 ]],

       [[0.8028363 , 0.6224304 ],
        [0.78962415, 0.6121872 ]],

       [[0.6922393 , 0.6572994 ],
        [0.6519805 , 0.6190727 ]]], dtype=float32)

In [47]:
resxidx

array([7, 5, 6, 5, 0, 1, 6, 1], dtype=uint32)

In [48]:
resyidx

array([0, 3, 0, 5, 4, 0, 6, 3], dtype=uint32)

In [49]:
idx = 1
xy0 = x[idx].reshape(dim,1)*y[idx]
xy0.shape

(8, 8)

In [50]:
xy0[3][7]

0.1308194

### update vals

In [51]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)

In [52]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint gid2 = get_global_id(1);
      uint topk = get_global_size(0);
      uint bs = get_global_size(1);
      uint baseupdateidx = topk*topk*gid2;
      uint baseidxidx = topk*gid2;
      uint col = updateyidx[baseidxidx+gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint row = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [53]:
knl = prg.addvals  # Use this Kernel object for repeated calls
knl(queue, [topk,bs], None, adata_buf, acols_buf, annzs_buf, np.float32(1), np.uint32(ellwa), x_cp_buf, x_idx_buf, y_idx_buf)

resa = np.empty_like(adata)
resaidx = np.zeros(acols.shape).astype(np.uint32)
resannz = np.zeros(annz.shape).astype(np.uint32)

cl.enqueue_copy(queue, resa, adata_buf)
cl.enqueue_copy(queue, resaidx, acols_buf)
cl.enqueue_copy(queue, resannz, annzs_buf)


UPDATE[6,0]: 0.570493
UPDATE[7,0]: 0.970926
UPDATE[0,0]: 0.789624
UPDATE[6,3]: 0.651981
UPDATE[7,3]: 0.817218
UPDATE[6,5]: 0.552272
UPDATE[0,4]: 0.802836
UPDATE[6,6]: 0.692239
UPDATE[5,0]: 0.936837
UPDATE[5,0]: 0.546119
UPDATE[1,0]: 0.612187
UPDATE[1,3]: 0.619073
UPDATE[5,3]: 0.788526
UPDATE[5,5]: 0.528677
UPDATE[1,4]: 0.622430
UPDATE[1,6]: 0.657299

<pyopencl._cl.NannyEvent at 0x7f28602cb1d0>

In [54]:
resa.shape, resaidx.shape, resannz.shape, ellwa, a.T.shape

((72,), (72,), (8,), 9, (8, 8))

In [55]:
adenseadd = to_dense(resa, resaidx, resannz, ellwa, a.T.shape)
adenseadd

array([[-0.62000448,  0.59644181,  0.5171895 ,  0.74336094, -0.31484124,
         0.32877994,  0.        ,  0.70931852],
       [-0.21505833,  0.51025301,  0.        ,  0.02502108,  0.04882306,
         0.57573944, -0.19191828,  0.3581028 ],
       [ 0.13251007,  0.97559118,  0.23229332,  0.32986483,  0.07646387,
         0.44998506,  0.        ,  0.38669789],
       [ 0.0481182 ,  0.08422235,  0.42488804,  0.        ,  0.59624612,
         0.34341821,  0.10857241,  0.74565107],
       [ 0.        ,  0.55411869,  0.24228822,  0.63623714,  0.47128019,
         0.55822241,  0.23590986,  0.18378867],
       [-1.39919519,  0.10682349,  0.46714184, -0.25202984,  0.17192142,
         0.0026359 ,  0.        ,  0.83779168],
       [ 0.20627999,  0.3933695 ,  0.48234862, -0.33231586,  0.83224964,
        -0.35120031, -0.39749774,  0.        ],
       [-0.77207798,  0.        ,  0.62628335, -0.79391038,  0.38636607,
         0.55601388,  0.78389859,  0.30145356]])

In [56]:
adenseadd - adense

array([[-0.78962414,  0.        ,  0.        ,  0.        , -0.8028363 ,
         0.        ,  0.        ,  0.        ],
       [-0.61218721,  0.        ,  0.        , -0.61907268, -0.62243038,
         0.        , -0.6572994 ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [-1.48295666,  0.        ,  0.        , -0.788526  ,  0.        ,
        -0.52867651,  0.        ,  0.        ],
       [-0.57049316,  0.        ,  0.        , -0.65198052,  0.        ,
        -0.55227196, -0.69223928,  0.        ],
       [-0.970926  ,  0.        ,  0.        , -0.81721797,  0.        ,
         0.        ,  0.        ,  0.        ]])

In [57]:
adenseadd == adense

array([[False,  True,  True,  True, False,  True,  True,  True],
       [False,  True,  True, False, False,  True, False,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [False,  True,  True, False,  True, False,  True,  True],
       [False,  True,  True, False,  True, False, False,  True],
       [False,  True,  True, False,  True,  True,  True,  True]])

In [58]:
ellwa

9

In [59]:
adata2 = adata.reshape(-1, ellwa)
adata2

array([[0.16961966, 0.5964418 , 0.5171895 , 0.74336094, 0.48799506,
        0.32877994, 0.7093185 , 0.        , 0.        ],
       [0.39712888, 0.510253  , 0.64409375, 0.67125344, 0.57573944,
        0.46538112, 0.3581028 , 0.        , 0.        ],
       [0.13251007, 0.9755912 , 0.23229332, 0.32986483, 0.07646387,
        0.44998506, 0.3866979 , 0.        , 0.        ],
       [0.0481182 , 0.08422235, 0.42488804, 0.5962461 , 0.3434182 ,
        0.10857241, 0.74565107, 0.        , 0.        ],
       [0.5541187 , 0.24228822, 0.63623714, 0.4712802 , 0.5582224 ,
        0.23590986, 0.18378867, 0.        , 0.        ],
       [0.08376146, 0.10682349, 0.46714184, 0.53649616, 0.17192142,
        0.5313124 , 0.8377917 , 0.        , 0.        ],
       [0.77677315, 0.3933695 , 0.48234862, 0.31966466, 0.83224964,
        0.20107165, 0.29474154, 0.        , 0.        ],
       [0.19884802, 0.62628335, 0.02330758, 0.38636607, 0.5560139 ,
        0.7838986 , 0.30145356, 0.        , 0.        ]],

In [60]:
resa = resa.reshape(-1, ellwa)
resa

array([[-0.6200045 ,  0.5964418 ,  0.5171895 ,  0.74336094, -0.31484124,
         0.32877994,  0.7093185 ,  0.        ,  0.        ],
       [-0.21505833,  0.510253  ,  0.02502108,  0.04882306,  0.57573944,
        -0.19191828,  0.3581028 ,  0.        ,  0.        ],
       [ 0.13251007,  0.9755912 ,  0.23229332,  0.32986483,  0.07646387,
         0.44998506,  0.3866979 ,  0.        ,  0.        ],
       [ 0.0481182 ,  0.08422235,  0.42488804,  0.5962461 ,  0.3434182 ,
         0.10857241,  0.74565107,  0.        ,  0.        ],
       [ 0.5541187 ,  0.24228822,  0.63623714,  0.4712802 ,  0.5582224 ,
         0.23590986,  0.18378867,  0.        ,  0.        ],
       [-1.3991952 ,  0.10682349,  0.46714184, -0.25202984,  0.17192142,
         0.0026359 ,  0.8377917 ,  0.        ,  0.        ],
       [ 0.20628   ,  0.3933695 ,  0.48234862, -0.33231586,  0.83224964,
        -0.3512003 , -0.39749774,  0.        ,  0.        ],
       [-0.772078  ,  0.62628335, -0.7939104 ,  0.38636607,  0

In [61]:
resa - adata2

array([[-0.78962415,  0.        ,  0.        ,  0.        , -0.8028363 ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [-0.6121872 ,  0.        , -0.6190727 , -0.6224304 ,  0.        ,
        -0.6572994 ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [-1.4829566 ,  0.        ,  0.        , -0.788526  ,  0.        ,
        -0.5286765 ,  0.        ,  0.        ,  0.        ],
       [-0.57049316,  0.        ,  0.        , -0.6519805 ,  0.        ,
        -0.55227196, -0.6922393 ,  0.        ,  0.        ],
       [-0.970926  ,  0.        , -0.81721795,  0.        ,  0

In [62]:
acols

array([0, 1, 2, 3, 4, 5, 7, 0, 0, 0, 1, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3,
       4, 5, 7, 0, 0, 0, 1, 2, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0,
       0, 0, 1, 2, 3, 4, 5, 7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 2, 3,
       4, 5, 6, 7, 0, 0], dtype=uint32)

In [63]:
resaidx

array([0, 1, 2, 3, 4, 5, 7, 0, 0, 0, 1, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3,
       4, 5, 7, 0, 0, 0, 1, 2, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0,
       0, 0, 1, 2, 3, 4, 5, 7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 2, 3,
       4, 5, 6, 7, 0, 0], dtype=uint32)

In [64]:
resannz

array([7, 7, 7, 7, 7, 7, 7, 7], dtype=uint32)

In [65]:
annz

array([7, 7, 7, 7, 7, 7, 7, 7], dtype=uint32)

### update vals2

In [192]:
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)

In [193]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint gid2 = get_global_id(1);
      uint topk = get_global_size(0);
      uint bs = get_global_size(1);
      uint baseupdateidx = topk*topk*gid2;
      uint baseidxidx = topk*gid2;
      uint row = updateyidx[baseidxidx+gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint col = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [194]:
knl = prg.addvals  # Use this Kernel object for repeated calls
knl(queue, [topk,bs], None, adatat_buf, acolst_buf, annzst_buf, np.float32(1), np.uint32(ellwat), x_cp_buf, x_idx_buf, y_idx_buf)

resat = np.empty_like(adatat)
resaidxt = np.zeros(acolst.shape).astype(np.uint32)
resannzt = np.zeros(annzt.shape).astype(np.uint32)

cl.enqueue_copy(queue, resat, adatat_buf)
cl.enqueue_copy(queue, resaidxt, acolst_buf)
cl.enqueue_copy(queue, resannzt, annzst_buf)

<pyopencl._cl.NannyEvent at 0x7f28286988b0>


UPDATE[0,6]: 0.570493
UPDATE[6,6]: 0.692239
UPDATE[0,7]: 0.970926
UPDATE[3,7]: 0.817218
UPDATE[4,0]: 0.802836
UPDATE[0,0]: 0.789624
UPDATE[5,6]: 0.552272
UPDATE[3,6]: 0.651981
UPDATE[0,5]: 0.936837
UPDATE[3,5]: 0.788526
UPDATE[4,1]: 0.622430
UPDATE[0,1]: 0.612187
UPDATE[0,5]: 0.546119
UPDATE[6,1]: 0.657299
UPDATE[5,5]: 0.528677
UPDATE[3,1]: 0.619073

In [195]:
ellwa

9

In [196]:
resat.shape, resaidxt.shape, resannzt.shape

((72,), (72,), (8,))

In [197]:
adenseaddt = to_dense(resat, resaidxt, resannzt, ellwat, a.T.shape)
adenseaddt

array([[-0.62000448, -0.21505833,  0.13251007,  0.0481182 ,  0.        ,
        -1.39919519,  0.20627999, -0.77207798],
       [ 0.59644181,  0.51025301,  0.97559118,  0.08422235,  0.55411869,
         0.10682349,  0.3933695 ,  0.        ],
       [ 0.5171895 ,  0.        ,  0.23229332,  0.42488804,  0.24228822,
         0.46714184,  0.48234862,  0.62628335],
       [ 0.74336094,  0.02502108,  0.32986483,  0.        ,  0.63623714,
        -0.25202984, -0.33231586, -0.79391038],
       [-0.31484124,  0.04882306,  0.07646387,  0.59624612,  0.47128019,
         0.17192142,  0.83224964,  0.38636607],
       [ 0.32877994,  0.57573944,  0.44998506,  0.34341821,  0.55822241,
         0.0026359 , -0.35120031,  0.55601388],
       [ 0.        , -0.19191828,  0.        ,  0.10857241,  0.23590986,
         0.        , -0.39749774,  0.78389859],
       [ 0.70931852,  0.3581028 ,  0.38669789,  0.74565107,  0.18378867,
         0.83779168,  0.        ,  0.30145356]])

In [199]:
adenseadd == adenseaddt.T

array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])

In [200]:
adata2t = adatat.reshape(-1, ellwat)
adata2t

array([[0.16961966, 0.39712888, 0.13251007, 0.0481182 , 0.08376146,
        0.77677315, 0.19884802, 0.        , 0.        ],
       [0.5964418 , 0.510253  , 0.9755912 , 0.08422235, 0.5541187 ,
        0.10682349, 0.3933695 , 0.        , 0.        ],
       [0.5171895 , 0.23229332, 0.42488804, 0.24228822, 0.46714184,
        0.48234862, 0.62628335, 0.        , 0.        ],
       [0.74336094, 0.64409375, 0.32986483, 0.63623714, 0.53649616,
        0.31966466, 0.02330758, 0.        , 0.        ],
       [0.48799506, 0.67125344, 0.07646387, 0.5962461 , 0.4712802 ,
        0.17192142, 0.83224964, 0.38636607, 0.        ],
       [0.32877994, 0.57573944, 0.44998506, 0.3434182 , 0.5582224 ,
        0.5313124 , 0.20107165, 0.5560139 , 0.        ],
       [0.46538112, 0.10857241, 0.23590986, 0.29474154, 0.7838986 ,
        0.        , 0.        , 0.        , 0.        ],
       [0.7093185 , 0.3581028 , 0.3866979 , 0.74565107, 0.18378867,
        0.8377917 , 0.30145356, 0.        , 0.        ]],

In [201]:
resat = resat.reshape(-1, ellwat)
resat

array([[-0.6200045 , -0.21505833,  0.13251007,  0.0481182 , -1.3991952 ,
         0.20628   , -0.772078  ,  0.        ,  0.        ],
       [ 0.5964418 ,  0.510253  ,  0.9755912 ,  0.08422235,  0.5541187 ,
         0.10682349,  0.3933695 ,  0.        ,  0.        ],
       [ 0.5171895 ,  0.23229332,  0.42488804,  0.24228822,  0.46714184,
         0.48234862,  0.62628335,  0.        ,  0.        ],
       [ 0.74336094,  0.02502108,  0.32986483,  0.63623714, -0.25202984,
        -0.33231586, -0.7939104 ,  0.        ,  0.        ],
       [-0.31484124,  0.04882306,  0.07646387,  0.5962461 ,  0.4712802 ,
         0.17192142,  0.83224964,  0.38636607,  0.        ],
       [ 0.32877994,  0.57573944,  0.44998506,  0.3434182 ,  0.5582224 ,
         0.0026359 , -0.3512003 ,  0.5560139 ,  0.        ],
       [-0.19191828,  0.10857241,  0.23590986, -0.39749774,  0.7838986 ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.7093185 ,  0.3581028 ,  0.3866979 ,  0.74565107,  0

In [202]:
resat - adata2t

array([[-0.78962415, -0.6121872 ,  0.        ,  0.        , -1.4829566 ,
        -0.57049316, -0.970926  ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        , -0.6190727 ,  0.        ,  0.        , -0.788526  ,
        -0.6519805 , -0.81721795,  0.        ,  0.        ],
       [-0.8028363 , -0.6224304 ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        -0.5286765 , -0.55227196,  0.        ,  0.        ],
       [-0.6572994 ,  0.        ,  0.        , -0.6922393 ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0

In [203]:
acols

array([0, 1, 2, 3, 4, 5, 7, 0, 0, 0, 1, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3,
       4, 5, 7, 0, 0, 0, 1, 2, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0,
       0, 0, 1, 2, 3, 4, 5, 7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 2, 3,
       4, 5, 6, 7, 0, 0], dtype=uint32)

In [204]:
resaidx

array([0, 1, 2, 3, 4, 5, 7, 0, 0, 0, 1, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3,
       4, 5, 7, 0, 0, 0, 1, 2, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0,
       0, 0, 1, 2, 3, 4, 5, 7, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 2, 3,
       4, 5, 6, 7, 0, 0], dtype=uint32)

In [205]:
resannz

array([7, 7, 7, 7, 7, 7, 7, 7], dtype=uint32)

In [206]:
annz

array([7, 7, 7, 7, 7, 7, 7, 7], dtype=uint32)

# OTHER

import numpy as np
import pyopencl as cl

mf = cl.mem_flags

dim = 16
topk = 4

x = np.random.rand(dim).astype(np.float32)
y = np.random.rand(dim).astype(np.float32)
x.shape,y.shape

dim1 = 4
dim2 = 8
dim3 = 1

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx,
        properties=cl.command_queue_properties.PROFILING_ENABLE)

sparsity = 0.2

a = np.zeros((dim1,dim2))
b = np.random.rand(dim2,dim3).flatten().astype(np.float32)

a.shape, b.shape

In [None]:
x_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=y)
val_out_buf = cl.Buffer(ctx, mf.READ_WRITE, 4*topk*topk)
x_idx_buf = cl.Buffer(ctx, mf.READ_WRITE, topk*4)
y_idx_buf = cl.Buffer(ctx, mf.READ_WRITE, topk*4)

prg = cl.Program(ctx, """
// Every global_id_0 works on a row
__kernel void genwupdate2(__global  float* x,     // INPUT MATRIX DATA
                         __global  float* y,    // INPUT
                         __global  float* xout,    // INPUT
                         uint topk,
                         __global  uint* xoutidx,    // INPUT
                         __global  uint* youtidx    // INPUT
                        ) { // LOCAL SHARED BUFFER
  uint gid = get_global_id(0);
  uint n = get_global_size(0);
  
  xout[gid] = x[gid];
  xoutidx[gid] = gid;
  youtidx[gid] = gid;
  
  float valx = x[gid];
  float valy = y[gid];
  uint posx = 0;
  uint posy = 0;
  for (uint i = 0; i < n; i++) {
    float tempval = x[i];
    float tempval2 = y[i];
    bool larger = tempval > valx;
    bool larger2 = tempval2 > valy;
      
    posx += (larger)?1:0;
    posy += (larger2)?1:0;
  }
  //printf("posx:%i", posx);
  if (posx < topk) {
    xoutidx[posx] = gid;
  }
  if (posy < topk) {
    youtidx[posy] = gid;
  }
  if (gid < topk) {
    uint i = gid;
    for (uint j=0; j<topk; j++) {
      xout[gid*topk+j] = x[xoutidx[gid]] * y[youtidx[j]];
    }
  }
}""").build()

In [None]:
knl = prg.genwupdate2  # Use this Kernel object for repeated calls
event = knl(queue, [dim,], None, x_buf, y_buf, val_out_buf, np.uint32(topk), x_idx_buf, y_idx_buf)

#event.wait()
val_out = np.zeros(topk*topk).astype(np.float32)
resxidx = np.zeros(topk).astype(np.uint32)
resyidx = np.zeros(topk).astype(np.uint32)

cl.enqueue_copy(queue, val_out, val_out_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf, wait_for=[event])
cl.enqueue_copy(queue, resyidx, y_idx_buf)

In [None]:
val_out

In [None]:
resxidx

In [None]:
resyidx

In [None]:
asdf

In [None]:
from __future__ import division

KERNEL_CODE = """
// Thread block size
#define BLOCK_SIZE %(block_size)d
// Matrix dimensions
// (chosen as multiples of the thread block size for simplicity)
#define WA %(w_a)d // Matrix A width
#define HA %(h_a)d // Matrix A height
#define WB %(w_b)d // Matrix B width
#define HB WA  // Matrix B height
#define WC WB  // Matrix C width
#define HC HA  // Matrix C height
/*
 * Copyright 1993-2009 NVIDIA Corporation.  All rights reserved.
 *
 * NVIDIA Corporation and its licensors retain all intellectual property and
 * proprietary rights in and to this software and related documentation.
 * Any use, reproduction, disclosure, or distribution of this software
 * and related documentation without an express license agreement from
 * NVIDIA Corporation is strictly prohibited.
 *
 * Please refer to the applicable NVIDIA end user license agreement (EULA)
 * associated with this source code for terms and conditions that govern
 * your use of this NVIDIA software.
 *
 */
/* Matrix multiplication: C = A * B.
 * Device code.
 */
#define AS(j, i) As[i + j * BLOCK_SIZE]
#define BS(j, i) Bs[i + j * BLOCK_SIZE]
////////////////////////////////////////////////////////////////////////////////
//! Matrix multiplication on the device: C = A * B
//! WA is A's width and WB is B's width
////////////////////////////////////////////////////////////////////////////////
__kernel __attribute__((reqd_work_group_size(16,16,1))) 
void
matrixMul( __global float* C, __global float* A, __global float* B)
{
    __local float As[BLOCK_SIZE*BLOCK_SIZE];
    __local float Bs[BLOCK_SIZE*BLOCK_SIZE];
    // Block index
    int bx = get_group_id(0);
    int by = get_group_id(1);
    // Thread index
    int tx = get_local_id(0);
    int ty = get_local_id(1);
    // Index of the first sub-matrix of A processed by the block
    int aBegin = WA * BLOCK_SIZE * by;
    // Index of the last sub-matrix of A processed by the block
    int aEnd   = aBegin + WA - 1;
    // Step size used to iterate through the sub-matrices of A
    int aStep  = BLOCK_SIZE;
    // Index of the first sub-matrix of B processed by the block
    int bBegin = BLOCK_SIZE * bx;
    // Step size used to iterate through the sub-matrices of B
    int bStep  = BLOCK_SIZE * WB;
    // Csub is used to store the element of the block sub-matrix
    // that is computed by the thread
    float Csub = 0.0f;
    // Loop over all the sub-matrices of A and B
    // required to compute the block sub-matrix
    for (int a = aBegin, b = bBegin;
             a <= aEnd;
             a += aStep, b += bStep) {
        // Load the matrices from device memory
        // to shared memory; each thread loads
        // one element of each matrix
        AS(ty, tx) = A[a + WA * ty + tx];
        BS(ty, tx) = B[b + WB * ty + tx];
        // Synchronize to make sure the matrices are loaded
        barrier(CLK_LOCAL_MEM_FENCE);
        // Multiply the two matrices together;
        // each thread computes one element
        // of the block sub-matrix
        for (int k = 0; k < BLOCK_SIZE; ++k)
            Csub += AS(ty, k) * BS(k, tx);
        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    // Write the block sub-matrix to device memory;
    // each thread writes one element
    C[get_global_id(1) * get_global_size(0) + get_global_id(0)] = Csub;
}
"""
