# In [14]:

from __future__ import division
from numba import cuda, float32
import numpy
import math
import time

# Threads per block along each dimension of the 2D kernel launch: every
# block is TPB x TPB threads (32 x 32 = 1024 threads per block).
# As a module-level constant it is also usable inside kernels wherever a
# compile-time size is required (e.g. the shape of a shared-memory array).
# NOTE(review): 1024 threads per block is the upper limit on most CUDA
# devices -- confirm for the target GPU before raising this value.
TPB = 32

@cuda.jit
def calc(ntmax, Nr, Npre, stdp, dw):
    """
    Write the sign of every ``stdp`` entry into the matching ``dw`` entry.

    The kernel is meant to be launched on a 2D grid; the thread at absolute
    grid position ``(x, y)`` handles the single element ``[x, y]``:

    * ``dw[x, y] = 1``  where ``stdp[x, y] > 0``
    * ``dw[x, y] = -1`` where ``stdp[x, y] < 0``
    * ``dw[x, y]`` is left untouched where ``stdp[x, y]`` is neither
      (zero or NaN).

    Parameters
    ----------
    ntmax : int
        Number of time steps. The update does not depend on the step, so it
        is applied once when ``ntmax > 0`` and skipped entirely otherwise.
    Nr, Npre : int
        Number of rows / columns to process. They are clamped to the actual
        array shapes so an oversized value cannot index out of bounds.
    stdp : 2D array
        Input values; only read.
    dw : 2D array
        Output array; updated in place.
    """
    x, y = cuda.grid(2)

    # Valid extent: never beyond what the caller asked for, and never beyond
    # either array.
    rows = min(Nr, min(stdp.shape[0], dw.shape[0]))
    cols = min(Npre, min(stdp.shape[1], dw.shape[1]))

    # The grid is rounded up to whole blocks, so some threads fall outside
    # the arrays. A thread is out of range if EITHER coordinate is too large.
    if ntmax <= 0 or x >= rows or y >= cols:
        return

    # Each thread touches only its own element: no shared memory and no
    # block-level synchronisation are needed, and no two threads write the
    # same location.
    value = stdp[x, y]
    if value > 0:
        dw[x, y] = 1
    elif value < 0:
        dw[x, y] = -1

# Wall-clock timing of the whole run. Note that this interval includes array
# creation, host<->device transfers and (on the first call) Numba's JIT
# compilation of the kernel, not just the kernel execution itself.
start = time.time()

# Problem size: number of time steps and the (rows, columns) of the arrays.
ntmax = 200
Nr = 200
Npre = 200

# Initialize the data arrays: stdp holds standard-normal samples, sw and dw
# start out as zeros.
stdp = numpy.random.normal(0, 1, [Nr, Npre])
sw = numpy.zeros([Nr, Npre])
dw = numpy.zeros([Nr, Npre])

# Copy the arrays to the device.
# The scalar copies are not used by the launch below (scalars are passed to
# the kernel by value). NOTE(review): kept only in case other cells refer to
# these names -- drop them if nothing does.
ntmax_global_mem = cuda.to_device(ntmax)
Nr_global_mem = cuda.to_device(Nr)
Npre_global_mem = cuda.to_device(Npre)
stdp_global_mem = cuda.to_device(stdp)
dw_global_mem = cuda.to_device(dw)

# Configure the blocks: enough TPB x TPB blocks to cover every element of dw,
# rounding up in each dimension. Grid axis 0 pairs with block axis 0 and
# grid axis 1 with block axis 1.
threadsperblock = (TPB, TPB)
blockspergrid_x = int(math.ceil(dw.shape[0] / threadsperblock[0]))
blockspergrid_y = int(math.ceil(dw.shape[1] / threadsperblock[1]))
blockspergrid = (blockspergrid_x, blockspergrid_y)

# Start the kernel on the DEVICE arrays. Launching with the host arrays would
# make the kernel work on temporary device copies, leaving dw_global_mem
# unchanged, so the copy_to_host() below would return stale zeros.
calc[blockspergrid, threadsperblock](
    ntmax, Nr, Npre, stdp_global_mem, dw_global_mem
)

# Fetch the result; this copy waits for the kernel to finish.
res = dw_global_mem.copy_to_host()
# Keep the host-side dw in sync with the computed result as well.
dw[:] = res

end = time.time()

run_time = end - start

print("Total time = {}".format(run_time))

# Output:
# Total time = 2.002309560775757
