In [1]:
from tinygrad.densetensor import DenseTensor
from tinygrad.sparsetensor import SparseTensor
import numpy as np

%load_ext autoreload
%autoreload 2

DEVICE:GPU


In [2]:
x_init = np.random.randn(1,3).astype(np.float32)
x2_init = np.random.randn(3).astype(np.float32)
U_init = np.random.randn(3,3).astype(np.float32)
V_init = np.random.randn(3,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)

In [3]:
x = DenseTensor(x_init)
W = DenseTensor(W_init)
m = DenseTensor(m_init)
out = x.dot(W).relu()
out = out.logsoftmax()
out = out.mul(m).add(m).sum()
out.backward()

out.cpu().data, x

T/G: <DenseTensor <GPUBuffer with shape (1, 3)> with grad None> <GPUBuffer with shape (1, 3)>
T/G: <DenseTensor <GPUBuffer with shape (1, 3)> with grad None> <GPUBuffer with shape (1, 3)>
T/G: <DenseTensor <GPUBuffer with shape (1, 3)> with grad None> <GPUBuffer with shape (1, 3)>
T/G: <DenseTensor <GPUBuffer with shape (1, 3)> with grad None> <GPUBuffer with shape (1, 3)>
T/G: <DenseTensor <GPUBuffer with shape (1, 3)> with grad <GPUBuffer with shape (1, 3)>> <GPUBuffer with shape (1, 3)>
T/G: <DenseTensor <GPUBuffer with shape (1, 3)> with grad None> <GPUBuffer with shape (1, 3)>
T/G: <DenseTensor <GPUBuffer with shape (1, 1)> with grad None> <GPUBuffer with shape (1, 1)>
T/G: <DenseTensor <GPUBuffer with shape (1, 1)> with grad None> <GPUBuffer with shape (1, 1)>
T/G: <DenseTensor <GPUBuffer with shape (1, 1)> with grad None> <GPUBuffer with shape (1, 1)>
T/G: <DenseTensor <GPUBuffer with shape (1, 1)> with grad None> <GPUBuffer with shape (1, 1)>
T/G: <DenseTensor <GPUBuffer with s

(array([0.15753908], dtype=float32),
 <DenseTensor <GPUBuffer with shape (1, 3)> with grad <GPUBuffer with shape (1, 3)>>)

x2 = DenseTensor(x2_init)#.gpu()
W = SparseTensor(W_init)
out = W.dot(x2).relu().sum()

out.backward()

out.cpu().data, x

In [4]:
import numpy as np
import pyopencl as cl

mf = cl.mem_flags

In [5]:
dim1 = 8
dim2 = 8
dim3 = 8
bs = 4

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx,
        properties=cl.command_queue_properties.PROFILING_ENABLE)

sparsity = 0.2

a = np.zeros((dim1,dim2))
b = np.random.rand(4,dim3).astype(np.float32)

a.shape, b.shape

((8, 8), (4, 8))

In [6]:
def fill_sparse(mat, sparsity=0.1):
    indices = np.array(range(mat.shape[1]))
    nrows = int(mat.shape[1]*sparsity)
    for row in range(mat.shape[0]):
        lim = nrows #+ int(np.random.random()*3)
        mat[row][np.random.permutation(indices)[:lim]] = np.random.random(lim)
    return mat

a = fill_sparse(a, 0.9).astype(np.float32)
#b = fill_sparse(b, sparsity)

In [7]:
a

array([[0.        , 0.52328503, 0.85276353, 0.941733  , 0.7301068 ,
        0.8106221 , 0.0631992 , 0.51675314],
       [0.5375395 , 0.        , 0.68684566, 0.43293473, 0.5705436 ,
        0.25251487, 0.61363715, 0.47675598],
       [0.        , 0.32384282, 0.5876471 , 0.4220046 , 0.446311  ,
        0.01131043, 0.9438752 , 0.48222694],
       [0.8779282 , 0.        , 0.3875374 , 0.9887559 , 0.72527486,
        0.62135303, 0.4417264 , 0.32487348],
       [0.24411957, 0.04483144, 0.844277  , 0.14873913, 0.3800302 ,
        0.        , 0.7476068 , 0.9878039 ],
       [0.91075796, 0.0012288 , 0.3646228 , 0.24154148, 0.19601001,
        0.        , 0.7926516 , 0.88013357],
       [0.02267831, 0.12424071, 0.683086  , 0.        , 0.64678925,
        0.6833284 , 0.64419883, 0.5820931 ],
       [0.50808465, 0.02975472, 0.12269319, 0.9034385 , 0.46145305,
        0.24297139, 0.        , 0.77481776]], dtype=float32)

In [8]:
b

array([[0.93212026, 0.8944856 , 0.7935633 , 0.87782544, 0.17367436,
        0.9549761 , 0.16357593, 0.999538  ],
       [0.8284154 , 0.28757167, 0.78423697, 0.07313304, 0.26050532,
        0.4531351 , 0.27547118, 0.17851022],
       [0.5093762 , 0.3891016 , 0.14989905, 0.84568554, 0.89133835,
        0.6859337 , 0.10950577, 0.72377515],
       [0.43113735, 0.9447981 , 0.50534743, 0.17087796, 0.19416347,
        0.5170911 , 0.5719422 , 0.46454972]], dtype=float32)

In [9]:
x2_init.T

array([-1.1239634 ,  0.5884606 , -0.21550849], dtype=float32)

In [10]:
mult = a.dot(b.T)
mult.shape

(8, 4)

In [11]:
mult.shape

(8, 4)

In [12]:
def to_data(mat):
    all_rows = []
    all_idxs = []
    all_nnzs = []
    for row in range(mat.shape[0]):
        rowdata = []
        colidxs = []
        all_nnzs.append(0)
        for col in range(mat.shape[1]):
            val = mat[row][col]
            if val != 0:
                rowdata.append(val)
                colidxs.append(col)
                all_nnzs[-1] += 1
        all_rows.append(rowdata)
        all_idxs.append(colidxs)
    
    ellwidth = int(np.sqrt(np.max(all_nnzs))+1)**2
    #all_rows = np.array(all_rows)#.astype(np.float32).flatten()           
    for row in range(mat.shape[0]):
        #print(row, all_rows)
        all_rows[row] = np.array(all_rows[row])
        all_rows[row].resize(ellwidth)
        all_idxs[row] = np.array(all_idxs[row])
        all_idxs[row].resize(ellwidth)
        print(all_idxs[row])
    all_rows = np.array(all_rows).astype(np.float32).flatten()
    print(all_idxs)
    all_idxs = np.array(all_idxs).astype(np.uint32).flatten()
    all_nnzs = np.array(all_nnzs).astype(np.uint32)
    
    return all_rows, all_idxs, all_nnzs, ellwidth

In [13]:
def to_dense(data, cols, nnzs, ellw, shape):
    out = np.zeros(shape)
    for row in range(shape[0]):
        for icol in range(nnzs[row]):
            out[row,cols[row*ellw+icol]] = data[row*ellw+icol]
    return out

In [14]:
adata, acols, annz, ellwa = to_data(a)
adata, acols, annz, ellwa

[1 2 3 4 5 6 7 0 0]
[0 2 3 4 5 6 7 0 0]
[1 2 3 4 5 6 7 0 0]
[0 2 3 4 5 6 7 0 0]
[0 1 2 3 4 6 7 0 0]
[0 1 2 3 4 6 7 0 0]
[0 1 2 4 5 6 7 0 0]
[0 1 2 3 4 5 7 0 0]
[array([1, 2, 3, 4, 5, 6, 7, 0, 0]), array([0, 2, 3, 4, 5, 6, 7, 0, 0]), array([1, 2, 3, 4, 5, 6, 7, 0, 0]), array([0, 2, 3, 4, 5, 6, 7, 0, 0]), array([0, 1, 2, 3, 4, 6, 7, 0, 0]), array([0, 1, 2, 3, 4, 6, 7, 0, 0]), array([0, 1, 2, 4, 5, 6, 7, 0, 0]), array([0, 1, 2, 3, 4, 5, 7, 0, 0])]


(array([0.52328503, 0.85276353, 0.941733  , 0.7301068 , 0.8106221 ,
        0.0631992 , 0.51675314, 0.        , 0.        , 0.5375395 ,
        0.68684566, 0.43293473, 0.5705436 , 0.25251487, 0.61363715,
        0.47675598, 0.        , 0.        , 0.32384282, 0.5876471 ,
        0.4220046 , 0.446311  , 0.01131043, 0.9438752 , 0.48222694,
        0.        , 0.        , 0.8779282 , 0.3875374 , 0.9887559 ,
        0.72527486, 0.62135303, 0.4417264 , 0.32487348, 0.        ,
        0.        , 0.24411957, 0.04483144, 0.844277  , 0.14873913,
        0.3800302 , 0.7476068 , 0.9878039 , 0.        , 0.        ,
        0.91075796, 0.0012288 , 0.3646228 , 0.24154148, 0.19601001,
        0.7926516 , 0.88013357, 0.        , 0.        , 0.02267831,
        0.12424071, 0.683086  , 0.64678925, 0.6833284 , 0.64419883,
        0.5820931 , 0.        , 0.        , 0.50808465, 0.02975472,
        0.12269319, 0.9034385 , 0.46145305, 0.24297139, 0.77481776,
        0.        , 0.        ], dtype=float32),

In [15]:
adatat, acolst, annzt, ellwat = to_data(a.T)
adatat, acolst, annzt, ellwat

[1 3 4 5 6 7 0 0 0]
[0 2 4 5 6 7 0 0 0]
[0 1 2 3 4 5 6 7 0]
[0 1 2 3 4 5 7 0 0]
[0 1 2 3 4 5 6 7 0]
[0 1 2 3 6 7 0 0 0]
[0 1 2 3 4 5 6 0 0]
[0 1 2 3 4 5 6 7 0]
[array([1, 3, 4, 5, 6, 7, 0, 0, 0]), array([0, 2, 4, 5, 6, 7, 0, 0, 0]), array([0, 1, 2, 3, 4, 5, 6, 7, 0]), array([0, 1, 2, 3, 4, 5, 7, 0, 0]), array([0, 1, 2, 3, 4, 5, 6, 7, 0]), array([0, 1, 2, 3, 6, 7, 0, 0, 0]), array([0, 1, 2, 3, 4, 5, 6, 0, 0]), array([0, 1, 2, 3, 4, 5, 6, 7, 0])]


(array([0.5375395 , 0.8779282 , 0.24411957, 0.91075796, 0.02267831,
        0.50808465, 0.        , 0.        , 0.        , 0.52328503,
        0.32384282, 0.04483144, 0.0012288 , 0.12424071, 0.02975472,
        0.        , 0.        , 0.        , 0.85276353, 0.68684566,
        0.5876471 , 0.3875374 , 0.844277  , 0.3646228 , 0.683086  ,
        0.12269319, 0.        , 0.941733  , 0.43293473, 0.4220046 ,
        0.9887559 , 0.14873913, 0.24154148, 0.9034385 , 0.        ,
        0.        , 0.7301068 , 0.5705436 , 0.446311  , 0.72527486,
        0.3800302 , 0.19601001, 0.64678925, 0.46145305, 0.        ,
        0.8106221 , 0.25251487, 0.01131043, 0.62135303, 0.6833284 ,
        0.24297139, 0.        , 0.        , 0.        , 0.0631992 ,
        0.61363715, 0.9438752 , 0.4417264 , 0.7476068 , 0.7926516 ,
        0.64419883, 0.        , 0.        , 0.51675314, 0.47675598,
        0.48222694, 0.32487348, 0.9878039 , 0.88013357, 0.5820931 ,
        0.77481776, 0.        ], dtype=float32),

In [16]:
adense = to_dense(adata, acols, annz, ellwa, a.shape)

In [17]:
adenset = to_dense(adatat, acolst, annzt, ellwat, a.T.shape)

In [18]:
adense

array([[0.        , 0.52328503, 0.85276353, 0.941733  , 0.73010677,
        0.8106221 , 0.0631992 , 0.51675314],
       [0.53753948, 0.        , 0.68684566, 0.43293473, 0.57054359,
        0.25251487, 0.61363715, 0.47675598],
       [0.        , 0.32384282, 0.58764708, 0.42200461, 0.446311  ,
        0.01131043, 0.94387519, 0.48222694],
       [0.8779282 , 0.        , 0.38753739, 0.98875588, 0.72527486,
        0.62135303, 0.44172639, 0.32487348],
       [0.24411957, 0.04483144, 0.84427702, 0.14873913, 0.38003021,
        0.        , 0.74760681, 0.98780388],
       [0.91075796, 0.0012288 , 0.3646228 , 0.24154148, 0.19601001,
        0.        , 0.79265159, 0.88013357],
       [0.02267831, 0.12424071, 0.68308598, 0.        , 0.64678925,
        0.68332839, 0.64419883, 0.58209312],
       [0.50808465, 0.02975472, 0.12269319, 0.90343851, 0.46145305,
        0.24297139, 0.        , 0.77481776]])

In [19]:
adenset.T == adense

array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])

In [20]:
a

array([[0.        , 0.52328503, 0.85276353, 0.941733  , 0.7301068 ,
        0.8106221 , 0.0631992 , 0.51675314],
       [0.5375395 , 0.        , 0.68684566, 0.43293473, 0.5705436 ,
        0.25251487, 0.61363715, 0.47675598],
       [0.        , 0.32384282, 0.5876471 , 0.4220046 , 0.446311  ,
        0.01131043, 0.9438752 , 0.48222694],
       [0.8779282 , 0.        , 0.3875374 , 0.9887559 , 0.72527486,
        0.62135303, 0.4417264 , 0.32487348],
       [0.24411957, 0.04483144, 0.844277  , 0.14873913, 0.3800302 ,
        0.        , 0.7476068 , 0.9878039 ],
       [0.91075796, 0.0012288 , 0.3646228 , 0.24154148, 0.19601001,
        0.        , 0.7926516 , 0.88013357],
       [0.02267831, 0.12424071, 0.683086  , 0.        , 0.64678925,
        0.6833284 , 0.64419883, 0.5820931 ],
       [0.50808465, 0.02975472, 0.12269319, 0.9034385 , 0.46145305,
        0.24297139, 0.        , 0.77481776]], dtype=float32)

In [21]:
a == adense

array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])

In [22]:
a.shape

(8, 8)

In [23]:
adata.shape, acols.shape, annz.shape, ellwa

((72,), (72,), (8,), 9)

In [24]:
#acols = acols.astype(np.uint32)
#annz = annz.astype(np.uint32)

In [25]:
adata, acols, annz, b

(array([0.52328503, 0.85276353, 0.941733  , 0.7301068 , 0.8106221 ,
        0.0631992 , 0.51675314, 0.        , 0.        , 0.5375395 ,
        0.68684566, 0.43293473, 0.5705436 , 0.25251487, 0.61363715,
        0.47675598, 0.        , 0.        , 0.32384282, 0.5876471 ,
        0.4220046 , 0.446311  , 0.01131043, 0.9438752 , 0.48222694,
        0.        , 0.        , 0.8779282 , 0.3875374 , 0.9887559 ,
        0.72527486, 0.62135303, 0.4417264 , 0.32487348, 0.        ,
        0.        , 0.24411957, 0.04483144, 0.844277  , 0.14873913,
        0.3800302 , 0.7476068 , 0.9878039 , 0.        , 0.        ,
        0.91075796, 0.0012288 , 0.3646228 , 0.24154148, 0.19601001,
        0.7926516 , 0.88013357, 0.        , 0.        , 0.02267831,
        0.12424071, 0.683086  , 0.64678925, 0.6833284 , 0.64419883,
        0.5820931 , 0.        , 0.        , 0.50808465, 0.02975472,
        0.12269319, 0.9034385 , 0.46145305, 0.24297139, 0.77481776,
        0.        , 0.        ], dtype=float32),

In [26]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)
b_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b.T)

prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void matmul(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            __global  float* vector_x,    // INPUT
                            __global  float* vector_y    // OUTPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);
      uint gid2 = get_global_id(1);

      uint nnz    = rowNnz[gid];
      uint baseidx = gid2*nrows;
      float sum = 0;
      for (uint i = 0; i < nnz; i++) {
        uint index   = (gid * ellwidth) + i;
        uint col     = colIdx[index];
        float aval  = matData[index];
        float xval  = vector_x[baseidx+col];
        //printf("aval, xval: %.2f,%.2f: (%i,%i) \\n", aval, xval, col, index);
        sum  += aval * xval;
      }
      printf("SUM/NNZ: %.2f %i \\n", sum, nnz);
      vector_y[baseidx+gid] = sum;
    }""").build()

In [27]:
a.shape, b.shape

((8, 8), (4, 8))

In [28]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [29]:
rows = a.shape[0]

In [30]:
mult = mult.astype(np.float32)

In [31]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.matmul  # Use this Kernel object for repeated calls
knl(queue, [rows,bs], None, adata_buf, acols_buf, annzs_buf, np.uint32(ellwa), b_buf, res_buf)

res_np = np.empty_like(b.T)
cl.enqueue_copy(queue, res_np, res_buf)

SUM/NNZ: 2.72 7 
SUM/NNZ: 1.84 7 
SUM/NNZ: 1.43 7 
SUM/NNZ: 2.70 7 
SUM/NNZ: 1.53 7 
SUM/NNZ: 1.62 7 
SUM/NNZ: 1.70 7 
SUM/NNZ: 2.19 7 
SUM/NNZ: 1.92 7 
SUM/NNZ: 1.47 7 
SUM/NNZ: 1.53 7 
SUM/NNZ: 1.61 7 
SUM/NNZ: 1.56 7 
SUM/NNZ: 1.52 7 
SUM/NNZ: 1.59 7 
SUM/NNZ: 1.04 7 
SUM/NNZ: 1.56 7 
SUM/NNZ: 1.53 7 
SUM/NNZ: 1.05 7 
SUM/NNZ: 1.75 7 
SUM/NNZ: 1.37 7 
SUM/NNZ: 1.48 7 
SUM/NNZ: 1.35 7 
SUM/NNZ: 0.96 7 
SUM/NNZ: 3.40 7 
SUM/NNZ: 2.34 7 
SUM/NNZ: 1.85 7 
SUM/NNZ: 3.11 7 
SUM/NNZ: 2.24 7 
SUM/NNZ: 2.39 7 
SUM/NNZ: 2.13 7 
SUM/NNZ: 2.48 7 


<pyopencl._cl.NannyEvent at 0x7faa88291950>

In [32]:
res_buf

<pyopencl._cl.Buffer at 0x7faa88276900>

In [33]:
res_np

array([[3.3992476 , 1.5552958 , 2.7155886 , 1.9233913 ],
       [2.3432944 , 1.5328158 , 1.8369108 , 1.4666219 ],
       [1.8511676 , 1.0523294 , 1.4289346 , 1.5314095 ],
       [3.1101434 , 1.753692  , 2.697643  , 1.6089851 ],
       [2.2438438 , 1.3693928 , 1.5296862 , 1.5599351 ],
       [2.3948452 , 1.4849818 , 1.6218513 , 1.5196328 ],
       [2.126435  , 1.3497163 , 1.6993622 , 1.5901372 ],
       [2.4772716 , 0.96037614, 2.1915672 , 1.038724  ]], dtype=float32)

In [34]:
mult

array([[3.3992476 , 1.5552958 , 2.7155886 , 1.9233913 ],
       [2.3432944 , 1.5328158 , 1.8369108 , 1.4666219 ],
       [1.8511676 , 1.0523294 , 1.4289346 , 1.5314095 ],
       [3.1101434 , 1.753692  , 2.697643  , 1.6089851 ],
       [2.2438438 , 1.3693928 , 1.5296862 , 1.5599351 ],
       [2.3948452 , 1.4849818 , 1.6218513 , 1.5196328 ],
       [2.126435  , 1.3497163 , 1.6993622 , 1.5901372 ],
       [2.4772716 , 0.96037614, 2.1915672 , 1.038724  ]], dtype=float32)

In [86]:
res_np==mult

array([[ True,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]])

In [35]:
res_np.shape

(8, 4)

In [36]:
mult.shape

(8, 4)

In [37]:
(res_np-mult).sum()

0.0

## Weight update kernel

In [38]:
bs = 4

In [39]:
dim = 8
topk = 2

x = np.random.rand(bs,dim).astype(np.float32)
y = np.random.rand(bs,dim).astype(np.float32)
x.shape,y.shape, topk

((4, 8), (4, 8), 2)

x_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=y)
x_cp_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*topk*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate2(__global  float* x,     // INPUT MATRIX DATA
                             __global  float* y,    // INPUT
                             __global  float* xout,    // INPUT
                             uint topk,
                             __global  uint* xoutidx,    // INPUT
                             __global  uint* youtidx    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint n = get_global_size(0);
      uint bs = get_global_size(1);
      uint gid2 = get_global_id(1);

      uint idx = n*gid2+gid;

      float valx = x[idx];
      float valy = y[idx];
      uint posx = 0;
      uint posy = 0;
      for (uint i = 0; i < n; i++) {
        uint idx2 = n*gid2+i;
        float tempval = x[idx2];
        float tempval2 = y[idx2];
        bool larger = tempval > valx;
        bool larger2 = tempval2 > valy;

        barrier(CLK_GLOBAL_MEM_FENCE);
        posx += (larger)?1:0;
        posy += (larger2)?1:0;
        barrier(CLK_GLOBAL_MEM_FENCE);
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
      //printf("posx:%i", posx);
      if (posx < topk) {
        xoutidx[posx+topk*gid2] = gid;
      }
      if (posy < topk) {
        youtidx[posy+topk*gid2] = gid;
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
      if (gid < topk) {
        for (uint j=0; j<topk; j++) {
          float res = x[xoutidx[gid+topk*gid2]+gid2*n] * y[youtidx[j+topk*gid2]+gid2*n];
          //printf("\\nJ:%i  gid:%i", j, gid);
          //printf("\\nRES:%.2f - %i - %i -  %.2f - %.2f",res, xoutidx[gid+topk*gid2], youtidx[j+topk*gid2], x[xoutidx[gid+topk*gid2]+gid2*n], y[youtidx[j+topk*gid2]+gid2*n]);
          barrier(CLK_GLOBAL_MEM_FENCE);
          xout[gid2*topk*topk+j*topk+gid] = res;
          barrier(CLK_GLOBAL_MEM_FENCE);
          
        }
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
    }""").build()

In [40]:
x_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=y)
x_cp_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*topk*4)
x_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
y_idx_buf = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
#x_cp_buft = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*topk*4)
#x_idx_buft = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)
#y_idx_buft = cl.Buffer(ctx, mf.WRITE_ONLY, bs*topk*4)

prg = cl.Program(ctx, """
    // sorts x and y in ascending order and returns sorted indices
    __kernel void genwupdate2(__global  float* x,     // INPUT MATRIX DATA
                             __global  float* y,    // INPUT
                             __global  float* xout,    // INPUT
                             uint topk,
                             uint bs,
                             __global  uint* xoutidx,    // INPUT
                             __global  uint* youtidx    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint n = get_global_size(0);
      //uint bs = get_global_size(1);
      //uint gid2 = get_global_id(1);

      for (uint gid2=0; gid2<bs; gid2++){
        uint idx = n*gid2+gid;

        float valx = x[idx];
        float valy = y[idx];
        uint posx = 0;
        uint posy = 0;
        for (uint i = 0; i < n; i++) {
          uint idx2 = n*gid2+i;
          float tempval = x[idx2];
          float tempval2 = y[idx2];
          bool larger = tempval > valx;
          bool larger2 = tempval2 > valy;

          posx += (larger)?1:0;
          posy += (larger2)?1:0;
        }
        //printf("posx:%i", posx);
        if (posx < topk) {
        xoutidx[posx+topk*gid2] = gid;
        }
        if (posy < topk) {
          youtidx[posy+topk*gid2] = gid;
        }
      }
      for (uint gid2=0; gid2<bs; gid2++){
        if (gid < topk) {
          for (uint j=0; j<topk; j++) {
            float res = x[xoutidx[gid+topk*gid2]+gid2*n] * y[youtidx[j+topk*gid2]+gid2*n];
            //printf("\\nJ:%i  gid:%i", j, gid);
            //printf("\\nRES:%.2f - %i - %i -  %.2f - %.2f",res, xoutidx[gid+topk*gid2], youtidx[j+topk*gid2], x[xoutidx[gid+topk*gid2]+gid2*n], y[youtidx[j+topk*gid2]+gid2*n]);
            //barrier(CLK_GLOBAL_MEM_FENCE);
            xout[gid2*topk*topk+j*topk+gid] = res;
          }
        }
      }
    }""").build()

In [41]:
knl = prg.genwupdate2  # Use this Kernel object for repeated calls
evt = knl(queue, [dim], None, x_buf, y_buf, x_cp_buf, np.uint32(topk), np.uint32(bs), x_idx_buf, y_idx_buf)

#evt.wait()
resx = np.zeros(bs*topk*topk).astype(np.float32)
resxidx = np.zeros(bs*topk).astype(np.uint32)
resyidx = np.zeros(bs*topk).astype(np.uint32)

cl.enqueue_copy(queue, resx, x_cp_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)

<pyopencl._cl.NannyEvent at 0x7faa8829ed10>

knl(queue, [dim], None, y_buf, x_buf, x_cp_buft, np.uint32(topk), np.uint32(bs), x_idx_buft, y_idx_buft)

#evt.wait()
resx = np.zeros(bs*topk*topk).astype(np.float32)
resxidx = np.zeros(bs*topk).astype(np.uint32)
resyidx = np.zeros(bs*topk).astype(np.uint32)

cl.enqueue_copy(queue, resx, x_cp_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf)
cl.enqueue_copy(queue, resyidx, y_idx_buf)

In [42]:
x

array([[0.7007082 , 0.78030187, 0.81460166, 0.01791268, 0.6809253 ,
        0.23861434, 0.36688572, 0.5888133 ],
       [0.13344142, 0.31516224, 0.10078434, 0.90952665, 0.12859394,
        0.30775154, 0.73491913, 0.6020403 ],
       [0.20044023, 0.2932102 , 0.81177634, 0.21738477, 0.6369448 ,
        0.692946  , 0.35219213, 0.02217519],
       [0.68943363, 0.8550104 , 0.8628079 , 0.19288574, 0.28852978,
        0.78029066, 0.5219705 , 0.8287392 ]], dtype=float32)

In [43]:
y

array([[0.96517766, 0.22530651, 0.89500284, 0.5050394 , 0.6814487 ,
        0.6945643 , 0.81712884, 0.02768122],
       [0.1881376 , 0.24455145, 0.7569934 , 0.05124188, 0.5786078 ,
        0.86848   , 0.9589475 , 0.55534047],
       [0.13901085, 0.09989852, 0.13279662, 0.5788994 , 0.8643667 ,
        0.2441829 , 0.7013198 , 0.90401286],
       [0.6082316 , 0.68015873, 0.20494615, 0.212279  , 0.904701  ,
        0.04224334, 0.53773874, 0.942847  ]], dtype=float32)

In [44]:
x.shape, y.shape

((4, 8), (4, 8))

In [45]:
resx

array([0.78623533, 0.7531299 , 0.7290708 , 0.69837236, 0.87218827,
       0.70474887, 0.7899057 , 0.63826257, 0.73385626, 0.6264321 ,
       0.70167243, 0.59895945, 0.8134959 , 0.806144  , 0.7805832 ,
       0.77352875], dtype=float32)

In [46]:
resx.reshape(bs,topk,topk)

array([[[0.78623533, 0.7531299 ],
        [0.7290708 , 0.69837236]],

       [[0.87218827, 0.70474887],
        [0.7899057 , 0.63826257]],

       [[0.73385626, 0.6264321 ],
        [0.70167243, 0.59895945]],

       [[0.8134959 , 0.806144  ],
        [0.7805832 , 0.77352875]]], dtype=float32)

In [47]:
resxidx

array([2, 1, 3, 6, 2, 5, 2, 1], dtype=uint32)

In [48]:
resyidx

array([0, 2, 6, 5, 7, 4, 7, 4], dtype=uint32)

In [49]:
idx = 1
xy0 = x[idx].reshape(dim,1)*y[idx]
xy0.shape

(8, 8)

In [50]:
xy0[3][7]

0.505097

### update vals add dense

In [95]:
matadd = np.random.randn(*a.shape)
matadd

array([[ 1.0807365 ,  1.20313213,  1.57455516, -2.22959989, -1.41064819,
        -0.18065373, -0.49028942, -0.91055222],
       [ 0.66621009,  0.69902972,  0.20639654,  0.12547115, -0.07135084,
         0.36121905, -0.58889072, -0.68000575],
       [ 1.27063674, -0.98910569,  1.62714557, -0.16510102, -0.02893584,
         0.5267675 ,  1.05340445, -0.27819108],
       [-0.45478888,  0.61715777, -0.57410355, -0.48966723,  1.42577486,
         1.20475918,  1.10852877, -1.94279439],
       [-1.14225069, -1.02980018, -1.99348521,  0.87215483, -0.37836665,
        -0.38715072,  0.1239996 ,  0.39889502],
       [-0.28734859,  0.32789885, -1.78896243, -0.57212434,  0.55981111,
         0.42809908, -0.00338386, -1.20385249],
       [ 0.57072756, -1.88208767,  1.74096982,  0.12021744, -0.66028105,
         0.02187457, -0.5204857 ,  0.76994923],
       [-1.92552138, -1.62916861,  0.07977981,  0.14730435,  0.65091849,
         1.11089586,  1.04600782, -0.33432753]])

In [159]:
a_added = a + matadd

In [190]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)
add_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=matadd)

prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void matmul(__global  float* matData,     // INPUT MATRIX DATA
                            __global  uint*  colIdx,
                            __global  uint*  rowNnz,
                            uint   ellwidth,
                            uint   awidth,
                            __global  float* vector_x    // INPUT
                            ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint nrows = get_global_size(0);

      uint nnz    = rowNnz[gid];
      uint baseidxs = gid*ellwidth;
      uint baseidxd = gid*awidth;
      
      for (uint i=0; i<awidth; i++) {
        float addval = vector_x[baseidxd+i];
        if (gid==0)
          printf("\\nADD VAL:%.2f idx:%i/%i  col:%i", addval, baseidxs+i, baseidxd+i, colIdx[i]);
        if (addval == 0) {
          continue;
        }
        if (i == colIdx[i]) {
          matData[baseidxs+i] += addval;
        } else {
          if (i > colIdx[i])
            break;
          for (uint j=nnz+1; j>i; j--) {
            colIdx[j] = colIdx[j-1];
            matData[j] = matData[j-1];
          }
          rowNnz[gid] += 1;
          nnz = rowNnz[gid];
          matData[baseidxs+i] = addval;
          if (nnz >= ellwidth)
            break;
        }
      }
    }""").build()

In [191]:
a.shape, b.shape

((8, 8), (4, 8))

In [192]:
res = np.zeros(a.shape[0]).astype(np.float32)
#res

In [193]:
rows = a.shape[0]

In [194]:
mult = mult.astype(np.float32)

In [195]:
res_buf = cl.Buffer(ctx, mf.WRITE_ONLY, b.nbytes)
knl = prg.matmul  # Use this Kernel object for repeated calls
knl(queue, [rows], None, adata_buf, acols_buf, annzs_buf, np.uint32(ellwa),np.uint32(a.shape[1]), b_buf)

<pyopencl._cl.Event at 0x7faa60039590>


ADD VAL:0.93 idx:0/0  col:1
ADD VAL:0.89 idx:1/1  col:1
ADD VAL:0.79 idx:2/2  col:2
ADD VAL:0.88 idx:3/3  col:3
ADD VAL:0.17 idx:4/4  col:4
ADD VAL:0.95 idx:5/5  col:5
ADD VAL:0.16 idx:6/6  col:6
ADD VAL:1.00 idx:7/7  col:7

In [186]:
matadd[0][0]

1.080736504937086

In [187]:
data_res = np.empty_like(adata)
cols_res = np.empty_like(acols)
nnzs_res = np.empty_like(annz)
cl.enqueue_copy(queue, data_res, adata_buf)
cl.enqueue_copy(queue, cols_res, acols_buf)
cl.enqueue_copy(queue, nnzs_res, annzs_buf)

<pyopencl._cl.NannyEvent at 0x7faa4768aea0>

In [188]:
adenseadd = to_dense(data_res, cols_res, nnzs_res, ellwa, a.shape)
adenseadd

array([[0.00000000e+00, 1.41777062e+00, 1.64632678e+00, 1.81955838e+00,
        9.03781116e-01, 1.76559818e+00, 2.26775140e-01, 1.51629114e+00],
       [1.78510219e-01, 0.00000000e+00, 9.74417329e-01, 1.21717167e+00,
        6.43676639e-01, 5.13020158e-01, 1.06677222e+00, 7.52227187e-01],
       [7.23775148e-01, 5.09376228e-01, 9.76748705e-01, 5.71903646e-01,
        1.29199648e+00, 9.02648807e-01, 1.62980890e+00, 5.91732681e-01],
       [4.64549720e-01, 0.00000000e+00, 1.33233547e+00, 1.49410331e+00,
        8.96152854e-01, 8.15516472e-01, 9.58817482e-01, 8.96815658e-01],
       [2.44119570e-01, 4.48314399e-02, 8.44277024e-01, 1.48739129e-01,
        3.80030215e-01, 0.00000000e+00, 7.47606814e-01, 9.87803876e-01],
       [9.10757959e-01, 1.22879690e-03, 3.64622802e-01, 2.41541475e-01,
        1.96010008e-01, 0.00000000e+00, 7.92651594e-01, 8.80133569e-01],
       [2.26783101e-02, 1.24240711e-01, 6.83085978e-01, 0.00000000e+00,
        6.46789253e-01, 6.83328390e-01, 6.44198835e-01, 5.

In [141]:
a

array([[0.        , 0.52328503, 0.85276353, 0.941733  , 0.7301068 ,
        0.8106221 , 0.0631992 , 0.51675314],
       [0.5375395 , 0.        , 0.68684566, 0.43293473, 0.5705436 ,
        0.25251487, 0.61363715, 0.47675598],
       [0.        , 0.32384282, 0.5876471 , 0.4220046 , 0.446311  ,
        0.01131043, 0.9438752 , 0.48222694],
       [0.8779282 , 0.        , 0.3875374 , 0.9887559 , 0.72527486,
        0.62135303, 0.4417264 , 0.32487348],
       [0.24411957, 0.04483144, 0.844277  , 0.14873913, 0.3800302 ,
        0.        , 0.7476068 , 0.9878039 ],
       [0.91075796, 0.0012288 , 0.3646228 , 0.24154148, 0.19601001,
        0.        , 0.7926516 , 0.88013357],
       [0.02267831, 0.12424071, 0.683086  , 0.        , 0.64678925,
        0.6833284 , 0.64419883, 0.5820931 ],
       [0.50808465, 0.02975472, 0.12269319, 0.9034385 , 0.46145305,
        0.24297139, 0.        , 0.77481776]], dtype=float32)

In [143]:
matadd

array([[ 1.0807365 ,  1.20313213,  1.57455516, -2.22959989, -1.41064819,
        -0.18065373, -0.49028942, -0.91055222],
       [ 0.66621009,  0.69902972,  0.20639654,  0.12547115, -0.07135084,
         0.36121905, -0.58889072, -0.68000575],
       [ 1.27063674, -0.98910569,  1.62714557, -0.16510102, -0.02893584,
         0.5267675 ,  1.05340445, -0.27819108],
       [-0.45478888,  0.61715777, -0.57410355, -0.48966723,  1.42577486,
         1.20475918,  1.10852877, -1.94279439],
       [-1.14225069, -1.02980018, -1.99348521,  0.87215483, -0.37836665,
        -0.38715072,  0.1239996 ,  0.39889502],
       [-0.28734859,  0.32789885, -1.78896243, -0.57212434,  0.55981111,
         0.42809908, -0.00338386, -1.20385249],
       [ 0.57072756, -1.88208767,  1.74096982,  0.12021744, -0.66028105,
         0.02187457, -0.5204857 ,  0.76994923],
       [-1.92552138, -1.62916861,  0.07977981,  0.14730435,  0.65091849,
         1.11089586,  1.04600782, -0.33432753]])

In [144]:
a_added

array([[ 1.08073650e+00,  1.72641716e+00,  2.42731869e+00,
        -1.28786689e+00, -6.80541420e-01,  6.29968368e-01,
        -4.27090225e-01, -3.93799079e-01],
       [ 1.20374958e+00,  6.99029719e-01,  8.93242197e-01,
         5.58405877e-01,  4.99192750e-01,  6.13733917e-01,
         2.47464332e-02, -2.03249774e-01],
       [ 1.27063674e+00, -6.65262868e-01,  2.21479265e+00,
         2.56903588e-01,  4.17375154e-01,  5.38077934e-01,
         1.99727965e+00,  2.04035860e-01],
       [ 4.23139314e-01,  6.17157771e-01, -1.86566158e-01,
         4.99088655e-01,  2.15104972e+00,  1.82611221e+00,
         1.55025516e+00, -1.61792091e+00],
       [-8.98131118e-01, -9.84968741e-01, -1.14920819e+00,
         1.02089396e+00,  1.66356791e-03, -3.87150721e-01,
         8.71606410e-01,  1.38669890e+00],
       [ 6.23409373e-01,  3.29127643e-01, -1.42433963e+00,
        -3.30582868e-01,  7.55821118e-01,  4.28099077e-01,
         7.89267733e-01, -3.23718925e-01],
       [ 5.93405869e-01, -1.757846

In [189]:
adenseadd

array([[0.00000000e+00, 1.41777062e+00, 1.64632678e+00, 1.81955838e+00,
        9.03781116e-01, 1.76559818e+00, 2.26775140e-01, 1.51629114e+00],
       [1.78510219e-01, 0.00000000e+00, 9.74417329e-01, 1.21717167e+00,
        6.43676639e-01, 5.13020158e-01, 1.06677222e+00, 7.52227187e-01],
       [7.23775148e-01, 5.09376228e-01, 9.76748705e-01, 5.71903646e-01,
        1.29199648e+00, 9.02648807e-01, 1.62980890e+00, 5.91732681e-01],
       [4.64549720e-01, 0.00000000e+00, 1.33233547e+00, 1.49410331e+00,
        8.96152854e-01, 8.15516472e-01, 9.58817482e-01, 8.96815658e-01],
       [2.44119570e-01, 4.48314399e-02, 8.44277024e-01, 1.48739129e-01,
        3.80030215e-01, 0.00000000e+00, 7.47606814e-01, 9.87803876e-01],
       [9.10757959e-01, 1.22879690e-03, 3.64622802e-01, 2.41541475e-01,
        1.96010008e-01, 0.00000000e+00, 7.92651594e-01, 8.80133569e-01],
       [2.26783101e-02, 1.24240711e-01, 6.83085978e-01, 0.00000000e+00,
        6.46789253e-01, 6.83328390e-01, 6.44198835e-01, 5.

In [145]:
a_added == res_np

array([[False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False]])

### update vals

In [51]:
adata_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adata)
acols_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acols)
annzs_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annz)

In [52]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint gid2 = get_global_id(1);
      uint topk = get_global_size(0);
      uint bs = get_global_size(1);
      uint baseupdateidx = topk*topk*gid2;
      uint baseidxidx = topk*gid2;
      uint col = updateyidx[baseidxidx+gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint row = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [53]:
knl = prg.addvals  # Use this Kernel object for repeated calls
knl(queue, [topk,bs], None, adata_buf, acols_buf, annzs_buf, np.float32(1), np.uint32(ellwa), x_cp_buf, x_idx_buf, y_idx_buf)

resa = np.empty_like(adata)
resaidx = np.zeros(acols.shape).astype(np.uint32)
resannz = np.zeros(annz.shape).astype(np.uint32)

cl.enqueue_copy(queue, resa, adata_buf)
cl.enqueue_copy(queue, resaidx, acols_buf)
cl.enqueue_copy(queue, resannz, annzs_buf)

<pyopencl._cl.NannyEvent at 0x7faa8829c400>


UPDATE[2,4]: 0.701672
UPDATE[2,4]: 0.780583
UPDATE[3,5]: 0.789906
INSERT[2,0]: 0.79
UPDATE[3,6]: 0.872188
UPDATE[2,7]: 0.733856
UPDATE[2,7]: 0.813496
UPDATE[2,2]: 0.729071
UPDATE[6,5]: 0.638263
UPDATE[5,4]: 0.598959
UPDATE[1,4]: 0.773529
UPDATE[1,0]: 0.753130
UPDATE[6,6]: 0.704749
UPDATE[5,7]: 0.626432
UPDATE[1,7]: 0.806144
UPDATE[1,2]: 0.698372

In [54]:
resa.shape, resaidx.shape, resannz.shape, ellwa, a.T.shape

((72,), (72,), (8,), 9, (8, 8))

In [55]:
adenseadd = to_dense(resa, resaidx, resannz, ellwa, a.T.shape)
adenseadd

array([[ 0.        ,  0.52328503,  0.85276353,  0.941733  ,  0.73010677,
         0.8106221 ,  0.0631992 ,  0.51675314],
       [-0.21559042,  0.        , -0.0115267 ,  0.43293473, -0.20298517,
         0.25251487,  0.61363715, -0.32938802],
       [-0.78623533,  0.32384282, -0.1414237 ,  0.42200461, -0.33427221,
         0.01131043,  0.94387519, -1.06512523],
       [ 0.8779282 ,  0.        ,  0.38753739,  0.98875588,  0.72527486,
        -0.1685527 , -0.43046188,  0.32487348],
       [ 0.24411957,  0.04483144,  0.84427702,  0.14873913,  0.38003021,
         0.        ,  0.74760681,  0.98780388],
       [ 0.91075796,  0.0012288 ,  0.3646228 ,  0.24154148, -0.40294945,
         0.        ,  0.79265159,  0.25370145],
       [ 0.02267831,  0.12424071,  0.68308598,  0.        ,  0.64678925,
         0.04506582, -0.06055003,  0.58209312],
       [ 0.50808465,  0.02975472,  0.12269319,  0.90343851,  0.46145305,
         0.24297139,  0.        ,  0.77481776]])

In [56]:
adenseadd - adense

array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [-0.7531299 ,  0.        , -0.69837236,  0.        , -0.77352875,
         0.        ,  0.        , -0.806144  ],
       [-0.78623533,  0.        , -0.72907078,  0.        , -0.7805832 ,
         0.        ,  0.        , -1.54735216],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        -0.78990573, -0.87218827,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        , -0.59895946,
         0.        ,  0.        , -0.62643212],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        -0.63826257, -0.70474887,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ]])

In [57]:
adenseadd == adense

array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [False,  True, False,  True, False,  True,  True, False],
       [False,  True, False,  True, False,  True,  True, False],
       [ True,  True,  True,  True,  True, False, False,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True, False,  True,  True, False],
       [ True,  True,  True,  True,  True, False, False,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])

In [58]:
ellwa

9

In [59]:
adata2 = adata.reshape(-1, ellwa)
adata2

array([[0.52328503, 0.85276353, 0.941733  , 0.7301068 , 0.8106221 ,
        0.0631992 , 0.51675314, 0.        , 0.        ],
       [0.5375395 , 0.68684566, 0.43293473, 0.5705436 , 0.25251487,
        0.61363715, 0.47675598, 0.        , 0.        ],
       [0.32384282, 0.5876471 , 0.4220046 , 0.446311  , 0.01131043,
        0.9438752 , 0.48222694, 0.        , 0.        ],
       [0.8779282 , 0.3875374 , 0.9887559 , 0.72527486, 0.62135303,
        0.4417264 , 0.32487348, 0.        , 0.        ],
       [0.24411957, 0.04483144, 0.844277  , 0.14873913, 0.3800302 ,
        0.7476068 , 0.9878039 , 0.        , 0.        ],
       [0.91075796, 0.0012288 , 0.3646228 , 0.24154148, 0.19601001,
        0.7926516 , 0.88013357, 0.        , 0.        ],
       [0.02267831, 0.12424071, 0.683086  , 0.64678925, 0.6833284 ,
        0.64419883, 0.5820931 , 0.        , 0.        ],
       [0.50808465, 0.02975472, 0.12269319, 0.9034385 , 0.46145305,
        0.24297139, 0.77481776, 0.        , 0.        ]],

In [60]:
resa = resa.reshape(-1, ellwa)
resa

array([[ 0.52328503,  0.85276353,  0.941733  ,  0.7301068 ,  0.8106221 ,
         0.0631992 ,  0.51675314,  0.        ,  0.        ],
       [-0.21559042, -0.0115267 ,  0.43293473, -0.20298517,  0.25251487,
         0.61363715, -0.32938802,  0.        ,  0.        ],
       [-0.78623533,  0.32384282, -0.1414237 ,  0.4220046 , -0.3342722 ,
         0.01131043,  0.9438752 , -1.0651252 ,  0.        ],
       [ 0.8779282 ,  0.3875374 ,  0.9887559 ,  0.72527486, -0.1685527 ,
        -0.43046188,  0.32487348,  0.        ,  0.        ],
       [ 0.24411957,  0.04483144,  0.844277  ,  0.14873913,  0.3800302 ,
         0.7476068 ,  0.9878039 ,  0.        ,  0.        ],
       [ 0.91075796,  0.0012288 ,  0.3646228 ,  0.24154148, -0.40294945,
         0.7926516 ,  0.25370145,  0.        ,  0.        ],
       [ 0.02267831,  0.12424071,  0.683086  ,  0.64678925,  0.04506582,
        -0.06055003,  0.5820931 ,  0.        ,  0.        ],
       [ 0.50808465,  0.02975472,  0.12269319,  0.9034385 ,  0

In [61]:
resa - adata2

array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [-0.7531299 , -0.69837236,  0.        , -0.77352875,  0.        ,
         0.        , -0.806144  ,  0.        ,  0.        ],
       [-1.1100781 , -0.26380426, -0.5634283 , -0.02430639, -0.34558263,
        -0.93256474,  0.46164826, -1.0651252 ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        , -0.7899057 ,
        -0.87218827,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        , -0.59895945,
         0.        , -0.6264321 ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        , -0.63826257,
        -0.70474887,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0

In [62]:
acols

array([1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 2, 3, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4,
       5, 6, 7, 0, 0, 0, 2, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3, 4, 6, 7, 0,
       0, 0, 1, 2, 3, 4, 6, 7, 0, 0, 0, 1, 2, 4, 5, 6, 7, 0, 0, 0, 1, 2,
       3, 4, 5, 7, 0, 0], dtype=uint32)

In [63]:
resaidx

array([1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 2, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3,
       4, 5, 6, 7, 0, 0, 2, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3, 4, 6, 7, 0,
       0, 0, 1, 2, 3, 4, 6, 7, 0, 0, 0, 1, 2, 4, 5, 6, 7, 0, 0, 0, 1, 2,
       3, 4, 5, 7, 0, 0], dtype=uint32)

In [64]:
resannz

array([7, 7, 8, 7, 7, 7, 7, 7], dtype=uint32)

In [65]:
annz

array([7, 7, 7, 7, 7, 7, 7, 7], dtype=uint32)

### update vals2

In [66]:
adatat_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=adatat)
acolst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=acolst)
annzst_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=annzt)

In [67]:
prg = cl.Program(ctx, """
// Every global_id_0 works on a row
    __kernel void addvals(__global  float* matData,     // INPUT MATRIX DATA
                         __global  uint*  colIdx,
                         __global  uint*  rowNnz,
                         float lr,
                         uint   ellwidth,
                         __global  float* updatevals,    // INPUT
                         __global  uint* updatexidx,
                         __global  uint* updateyidx
                         ) { // LOCAL SHARED BUFFER
      uint gid = get_global_id(0);
      uint gid2 = get_global_id(1);
      uint topk = get_global_size(0);
      uint bs = get_global_size(1);
      uint baseupdateidx = topk*topk*gid2;
      uint baseidxidx = topk*gid2;
      uint row = updateyidx[baseidxidx+gid];

      for (uint i=0; i<topk; i++) {
        float val = updatevals[baseupdateidx+gid*topk+i];
        uint col = updatexidx[baseidxidx+i];
        for (uint i=0; i<rowNnz[row]; i++) {
          uint idx = row*ellwidth+i;
          if (colIdx[idx] >= col) {
            //printf("\\nFOUND:%i/%i  - idx:%i", colIdx[idx], col, idx);
            if (colIdx[idx] == col) {
              matData[idx] += -val*lr;
              printf("\\nUPDATE[%i,%i]: %f", row,col, val);
              break;
            } else {
              // insert new column
              printf("\\nINSERT[%i,%i]: %.2f", row,col, val);
              for (uint j=rowNnz[row]+1; j>i; j--) {
                uint idx2 = row*ellwidth+j;
                matData[idx2] = matData[idx2-1];
                colIdx[idx2] = colIdx[idx2-1];
              }
              matData[idx] = -val*lr;
              colIdx[idx] = col;
              rowNnz[row] += 1;
              break;
            }
          }
        }
        if (rowNnz[row] >= ellwidth) {
          break;
        }
      }
    }""").build()

In [68]:
knl = prg.addvals  # Use this Kernel object for repeated calls
knl(queue, [topk,bs], None, adatat_buf, acolst_buf, annzst_buf, np.float32(1), np.uint32(ellwat), x_cp_buf, x_idx_buf, y_idx_buf)

resat = np.empty_like(adatat)
resaidxt = np.zeros(acolst.shape).astype(np.uint32)
resannzt = np.zeros(annzt.shape).astype(np.uint32)

cl.enqueue_copy(queue, resat, adatat_buf)
cl.enqueue_copy(queue, resaidxt, acolst_buf)
cl.enqueue_copy(queue, resannzt, annzst_buf)


UPDATE[7,2]: 0.733856
UPDATE[4,2]: 0.701672
UPDATE[7,2]: 0.813496
UPDATE[4,2]: 0.780583
UPDATE[6,3]: 0.872188
UPDATE[5,3]: 0.789906
INSERT[0,2]: 0.79
UPDATE[5,6]: 0.638263
UPDATE[7,1]: 0.806144
UPDATE[4,1]: 0.773529
UPDATE[7,5]: 0.626432
UPDATE[4,5]: 0.598959
UPDATE[2,2]: 0.729071
UPDATE[6,6]: 0.704749
UPDATE[0,1]: 0.753130
UPDATE[2,1]: 0.698372

<pyopencl._cl.NannyEvent at 0x7faa882407c0>

In [69]:
ellwa

9

In [70]:
resat.shape, resaidxt.shape, resannzt.shape

((72,), (72,), (8,))

In [71]:
adenseaddt = to_dense(resat, resaidxt, resannzt, ellwat, a.T.shape)
adenseaddt

array([[ 0.        , -0.21559042, -0.78623533,  0.8779282 ,  0.24411957,
         0.91075796,  0.02267831,  0.50808465],
       [ 0.52328503,  0.        ,  0.32384282,  0.        ,  0.04483144,
         0.0012288 ,  0.12424071,  0.02975472],
       [ 0.85276353, -0.0115267 , -0.1414237 ,  0.38753739,  0.84427702,
         0.3646228 ,  0.68308598,  0.12269319],
       [ 0.941733  ,  0.43293473,  0.42200461,  0.98875588,  0.14873913,
         0.24154148,  0.        ,  0.90343851],
       [ 0.73010677, -0.20298517, -0.33427221,  0.72527486,  0.38003021,
        -0.40294945,  0.64678925,  0.46145305],
       [ 0.8106221 ,  0.25251487,  0.01131043, -0.1685527 ,  0.        ,
         0.        ,  0.04506582,  0.24297139],
       [ 0.0631992 ,  0.61363715,  0.94387519, -0.43046188,  0.74760681,
         0.79265159, -0.06055003,  0.        ],
       [ 0.51675314, -0.32938802, -0.33126894,  0.32487348,  0.98780388,
         0.25370145,  0.58209312,  0.77481776]])

In [72]:
adenseadd == adenseaddt.T

array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True, False],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])

In [73]:
adata2t = adatat.reshape(-1, ellwat)
adata2t

array([[0.5375395 , 0.8779282 , 0.24411957, 0.91075796, 0.02267831,
        0.50808465, 0.        , 0.        , 0.        ],
       [0.52328503, 0.32384282, 0.04483144, 0.0012288 , 0.12424071,
        0.02975472, 0.        , 0.        , 0.        ],
       [0.85276353, 0.68684566, 0.5876471 , 0.3875374 , 0.844277  ,
        0.3646228 , 0.683086  , 0.12269319, 0.        ],
       [0.941733  , 0.43293473, 0.4220046 , 0.9887559 , 0.14873913,
        0.24154148, 0.9034385 , 0.        , 0.        ],
       [0.7301068 , 0.5705436 , 0.446311  , 0.72527486, 0.3800302 ,
        0.19601001, 0.64678925, 0.46145305, 0.        ],
       [0.8106221 , 0.25251487, 0.01131043, 0.62135303, 0.6833284 ,
        0.24297139, 0.        , 0.        , 0.        ],
       [0.0631992 , 0.61363715, 0.9438752 , 0.4417264 , 0.7476068 ,
        0.7926516 , 0.64419883, 0.        , 0.        ],
       [0.51675314, 0.47675598, 0.48222694, 0.32487348, 0.9878039 ,
        0.88013357, 0.5820931 , 0.77481776, 0.        ]],

In [74]:
resat = resat.reshape(-1, ellwat)
resat

array([[-0.21559042, -0.78623533,  0.8779282 ,  0.24411957,  0.91075796,
         0.02267831,  0.50808465,  0.        ,  0.        ],
       [ 0.52328503,  0.32384282,  0.04483144,  0.0012288 ,  0.12424071,
         0.02975472,  0.        ,  0.        ,  0.        ],
       [ 0.85276353, -0.0115267 , -0.1414237 ,  0.3875374 ,  0.844277  ,
         0.3646228 ,  0.683086  ,  0.12269319,  0.        ],
       [ 0.941733  ,  0.43293473,  0.4220046 ,  0.9887559 ,  0.14873913,
         0.24154148,  0.9034385 ,  0.        ,  0.        ],
       [ 0.7301068 , -0.20298517, -0.3342722 ,  0.72527486,  0.3800302 ,
        -0.40294945,  0.64678925,  0.46145305,  0.        ],
       [ 0.8106221 ,  0.25251487,  0.01131043, -0.1685527 ,  0.04506582,
         0.24297139,  0.        ,  0.        ,  0.        ],
       [ 0.0631992 ,  0.61363715,  0.9438752 , -0.43046188,  0.7476068 ,
         0.7926516 , -0.06055003,  0.        ,  0.        ],
       [ 0.51675314, -0.32938802, -0.33126894,  0.32487348,  0

In [75]:
resat - adata2t

array([[-0.7531299 , -1.6641636 ,  0.6338086 , -0.6666384 ,  0.88807964,
        -0.48540634,  0.50808465,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        , -0.69837236, -0.7290708 ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        , -0.77352875, -0.7805832 ,  0.        ,  0.        ,
        -0.59895945,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , -0.7899057 , -0.63826257,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , -0.87218827,  0.        ,
         0.        , -0.70474887,  0.        ,  0.        ],
       [ 0.        , -0.806144  , -0.8134959 ,  0.        ,  0

In [76]:
acols

array([1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 2, 3, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4,
       5, 6, 7, 0, 0, 0, 2, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3, 4, 6, 7, 0,
       0, 0, 1, 2, 3, 4, 6, 7, 0, 0, 0, 1, 2, 4, 5, 6, 7, 0, 0, 0, 1, 2,
       3, 4, 5, 7, 0, 0], dtype=uint32)

In [77]:
resaidx

array([1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 2, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3,
       4, 5, 6, 7, 0, 0, 2, 3, 4, 5, 6, 7, 0, 0, 0, 1, 2, 3, 4, 6, 7, 0,
       0, 0, 1, 2, 3, 4, 6, 7, 0, 0, 0, 1, 2, 4, 5, 6, 7, 0, 0, 0, 1, 2,
       3, 4, 5, 7, 0, 0], dtype=uint32)

In [78]:
resannz

array([7, 7, 8, 7, 7, 7, 7, 7], dtype=uint32)

In [79]:
annz

array([7, 7, 7, 7, 7, 7, 7, 7], dtype=uint32)

# OTHER

import numpy as np
import pyopencl as cl

mf = cl.mem_flags

dim = 16
topk = 4

x = np.random.rand(dim).astype(np.float32)
y = np.random.rand(dim).astype(np.float32)
x.shape,y.shape

dim1 = 4
dim2 = 8
dim3 = 1

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx,
        properties=cl.command_queue_properties.PROFILING_ENABLE)

sparsity = 0.2

a = np.zeros((dim1,dim2))
b = np.random.rand(dim2,dim3).flatten().astype(np.float32)

a.shape, b.shape

In [80]:
x_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=y)
val_out_buf = cl.Buffer(ctx, mf.READ_WRITE, 4*topk*topk)
x_idx_buf = cl.Buffer(ctx, mf.READ_WRITE, topk*4)
y_idx_buf = cl.Buffer(ctx, mf.READ_WRITE, topk*4)

prg = cl.Program(ctx, """
// Every global_id_0 works on a row
__kernel void genwupdate2(__global  float* x,     // INPUT MATRIX DATA
                         __global  float* y,    // INPUT
                         __global  float* xout,    // INPUT
                         uint topk,
                         __global  uint* xoutidx,    // INPUT
                         __global  uint* youtidx    // INPUT
                        ) { // LOCAL SHARED BUFFER
  uint gid = get_global_id(0);
  uint n = get_global_size(0);
  
  xout[gid] = x[gid];
  xoutidx[gid] = gid;
  youtidx[gid] = gid;
  
  float valx = x[gid];
  float valy = y[gid];
  uint posx = 0;
  uint posy = 0;
  for (uint i = 0; i < n; i++) {
    float tempval = x[i];
    float tempval2 = y[i];
    bool larger = tempval > valx;
    bool larger2 = tempval2 > valy;
      
    posx += (larger)?1:0;
    posy += (larger2)?1:0;
  }
  //printf("posx:%i", posx);
  if (posx < topk) {
    xoutidx[posx] = gid;
  }
  if (posy < topk) {
    youtidx[posy] = gid;
  }
  if (gid < topk) {
    uint i = gid;
    for (uint j=0; j<topk; j++) {
      xout[gid*topk+j] = x[xoutidx[gid]] * y[youtidx[j]];
    }
  }
}""").build()

In [81]:
knl = prg.genwupdate2  # Use this Kernel object for repeated calls
event = knl(queue, [dim,], None, x_buf, y_buf, val_out_buf, np.uint32(topk), x_idx_buf, y_idx_buf)

#event.wait()
val_out = np.zeros(topk*topk).astype(np.float32)
resxidx = np.zeros(topk).astype(np.uint32)
resyidx = np.zeros(topk).astype(np.uint32)

cl.enqueue_copy(queue, val_out, val_out_buf)
cl.enqueue_copy(queue, resxidx, x_idx_buf, wait_for=[event])
cl.enqueue_copy(queue, resyidx, y_idx_buf)

<pyopencl._cl.NannyEvent at 0x7faa88240900>

In [82]:
val_out

array([0.78623533, 0.7290708 , 0.7531299 , 0.69837236], dtype=float32)

In [83]:
resxidx

array([2, 1], dtype=uint32)

In [84]:
resyidx

array([0, 2], dtype=uint32)

In [85]:
asdf

NameError: name 'asdf' is not defined

In [None]:
from __future__ import division

KERNEL_CODE = """
// Thread block size
#define BLOCK_SIZE %(block_size)d
// Matrix dimensions
// (chosen as multiples of the thread block size for simplicity)
#define WA %(w_a)d // Matrix A width
#define HA %(h_a)d // Matrix A height
#define WB %(w_b)d // Matrix B width
#define HB WA  // Matrix B height
#define WC WB  // Matrix C width
#define HC HA  // Matrix C height
/*
 * Copyright 1993-2009 NVIDIA Corporation.  All rights reserved.
 *
 * NVIDIA Corporation and its licensors retain all intellectual property and
 * proprietary rights in and to this software and related documentation.
 * Any use, reproduction, disclosure, or distribution of this software
 * and related documentation without an express license agreement from
 * NVIDIA Corporation is strictly prohibited.
 *
 * Please refer to the applicable NVIDIA end user license agreement (EULA)
 * associated with this source code for terms and conditions that govern
 * your use of this NVIDIA software.
 *
 */
/* Matrix multiplication: C = A * B.
 * Device code.
 */
#define AS(j, i) As[i + j * BLOCK_SIZE]
#define BS(j, i) Bs[i + j * BLOCK_SIZE]
////////////////////////////////////////////////////////////////////////////////
//! Matrix multiplication on the device: C = A * B
//! WA is A's width and WB is B's width
////////////////////////////////////////////////////////////////////////////////
__kernel __attribute__((reqd_work_group_size(16,16,1))) 
void
matrixMul( __global float* C, __global float* A, __global float* B)
{
    __local float As[BLOCK_SIZE*BLOCK_SIZE];
    __local float Bs[BLOCK_SIZE*BLOCK_SIZE];
    // Block index
    int bx = get_group_id(0);
    int by = get_group_id(1);
    // Thread index
    int tx = get_local_id(0);
    int ty = get_local_id(1);
    // Index of the first sub-matrix of A processed by the block
    int aBegin = WA * BLOCK_SIZE * by;
    // Index of the last sub-matrix of A processed by the block
    int aEnd   = aBegin + WA - 1;
    // Step size used to iterate through the sub-matrices of A
    int aStep  = BLOCK_SIZE;
    // Index of the first sub-matrix of B processed by the block
    int bBegin = BLOCK_SIZE * bx;
    // Step size used to iterate through the sub-matrices of B
    int bStep  = BLOCK_SIZE * WB;
    // Csub is used to store the element of the block sub-matrix
    // that is computed by the thread
    float Csub = 0.0f;
    // Loop over all the sub-matrices of A and B
    // required to compute the block sub-matrix
    for (int a = aBegin, b = bBegin;
             a <= aEnd;
             a += aStep, b += bStep) {
        // Load the matrices from device memory
        // to shared memory; each thread loads
        // one element of each matrix
        AS(ty, tx) = A[a + WA * ty + tx];
        BS(ty, tx) = B[b + WB * ty + tx];
        // Synchronize to make sure the matrices are loaded
        barrier(CLK_LOCAL_MEM_FENCE);
        // Multiply the two matrices together;
        // each thread computes one element
        // of the block sub-matrix
        for (int k = 0; k < BLOCK_SIZE; ++k)
            Csub += AS(ty, k) * BS(k, tx);
        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    // Write the block sub-matrix to device memory;
    // each thread writes one element
    C[get_global_id(1) * get_global_size(0) + get_global_id(0)] = Csub;
}
"""
