In [1]:
import numpy as np
import dace
from dace.frontend.common import op_repository as oprepo
from dace.transformation.interstate import GPUTransformSDFG

In [2]:
@oprepo.replaces('warpReduce_sum')
def warpReduce_sum(pv, sdfg: dace.SDFG, state: dace.SDFGState, x: str) -> str:
    """Frontend replacement for ``warpReduce_sum(x)`` inside ``@dace.program`` code.

    Adds to ``state`` a C++ tasklet that evaluates
    ``dace::warpReduce<dace::ReductionType::Sum, T>::reduce`` on ``x`` and stores
    the result in a freshly allocated temporary transient.

    :param pv: Program visitor that triggered the replacement (not used here).
    :param sdfg: SDFG under construction; ``x`` must be a key of ``sdfg.arrays``.
    :param state: State that receives the read node, the tasklet and the write node.
    :param x: Name of the data container passed as the call argument.
    :return: Name of the temporary transient holding the reduction result.
    """
    in_desc = sdfg.arrays[x]
    # The result container copies the input's shape, dtype and storage type.
    out_name, _ = sdfg.add_temp_transient(in_desc.shape, in_desc.dtype,
                                          in_desc.storage)
    c_type = in_desc.dtype.ctype
    reduce_code = f'''
       __out = dace::warpReduce<dace::ReductionType::Sum, {c_type}>::reduce(__a);
    '''
    tasklet = state.add_tasklet('warpReduce', {'__a'}, {'__out'}, reduce_code,
                                dace.Language.CPP)
    in_node = state.add_read(x)
    out_node = state.add_write(out_name)
    # Wire x -> __a and __out -> temporary (memlets carry no explicit subset).
    state.add_edge(in_node, None, tasklet, '__a', dace.Memlet(data=x))
    state.add_edge(tasklet, '__out', out_node, None, dace.Memlet(data=out_name))
    return out_name



In [7]:
@oprepo.replaces('sync_threads')
def sync_threads(pv, sdfg: dace.SDFG, state: dace.SDFGState):
    """Frontend replacement for ``sync_threads()`` inside ``@dace.program`` code.

    Adds to ``state`` a C++ tasklet with no input or output connectors whose body
    is ``__syncthreads();``. Returns ``None``.

    Note: ``__syncthreads()`` is a CUDA barrier for the threads of one block only;
    every thread of the block must reach it, so do not call this under a branch
    that depends on the thread index.

    NOTE(review): the tasklet has no memlet edges, so its ordering relative to the
    surrounding dataflow depends on how the frontend places it in states; confirm
    in the generated code.
    """
    barrier_code = '__syncthreads();'
    state.add_tasklet('syncronize_threads', {}, {}, barrier_code,
                      language=dace.Language.CPP)

In [None]:
# 32x32 -> 32x1 row sums using 1 thread block of 32 warps (32x32 = 1024 threads), one warp per row.
# NOTE(review): every cell below redefines `myprog`; the GPU-transform cell uses whichever
# definition was executed last.
@dace.program
def myprog(a: dace.float64[32, 32]):
    b = np.empty([32], dtype=dace.float64)
    # Outer map: thread blocks (a single block here; `i` is unused).
    for i in dace.map[0:1]:
        # Inner map: threads of the block. Assumes k maps to the lane within warp j
        # (threadIdx.x); the row-sum result depends on this. TODO confirm in generated code.
        for j,k in dace.map[0:32,0:32]:
            reduced = warpReduce_sum(a[j, k])
            # All 32 lanes of warp j write b[j]; presumes dace::warpReduce returns the
            # full sum on every lane. TODO confirm.
            b[j] = reduced
    return b

In [None]:
# 64x32 -> 64x1 row sums using 2 thread blocks of 32 warps each, one warp per row.
# Block i covers rows 32*i .. 32*i+31.
# NOTE(review): this cell redefines `myprog`; the GPU-transform cell uses whichever
# definition was executed last.
@dace.program
def myprog(a: dace.float64[64, 32]):
    b = np.empty([64], dtype=dace.float64)
    # Outer map: thread blocks.
    for i in dace.map[0:2]:
        # Inner map: warp j of block i reduces global row i*32+j; lane k holds column k.
        for j,k in dace.map[0:32,0:32]:
            reduced = warpReduce_sum(a[i*32+j, k])
            b[i*32+j] = reduced
    return b

In [8]:
# 32x64 -> 32x1 row sums using 1 thread block of 32 warps. Each thread first adds its two
# elements a[j, k] and a[j, k+32], then warp j reduces the 32 pre-added values of row j.
# NOTE(review): this cell redefines `myprog`; the GPU-transform cell uses whichever
# definition was executed last.
@dace.program
def myprog(a: dace.float64[32, 64]):
    tmp = np.empty([32,32], dtype=dace.float64)
    b = np.empty([32], dtype=dace.float64)
    for i in dace.map[0:1]:
        for j,k in dace.map[0:32,0:32]:
            # Fold the right half of the row onto the left half.
            tmp[j,k] = a[j,k]+a[j,k+32]
            # Each thread only reads back its own tmp element, so this barrier looks
            # unnecessary; it is outside any branch, so all threads reach it.
            sync_threads()
            reduced = warpReduce_sum(tmp[j, k])
            b[j] = reduced
    return b

In [62]:
# 16x64 -> 16x1 row sums using 1 thread block of 32 warps, two warps per row:
# even warp j reduces columns 0..31 of row j//2, odd warp j reduces columns 32..63.
# NOTE(review): this cell redefines `myprog`; the GPU-transform cell uses whichever
# definition was executed last.
@dace.program
def myprog(a: dace.float64[16, 64]):
    b = np.empty([16], dtype=dace.float64)
    # tmp[r, 0] / tmp[r, 1]: partial sums of the left / right half of row r.
    tmp = np.empty([16,2], dtype=dace.float64)
    for i in dace.map[0:1]:
        for j,k in dace.map[0:32,0:32]:
            if j%2 == 0:
                tmp[j//2,0] = warpReduce_sum(a[j//2, k])
            else:
                tmp[j//2,1] = warpReduce_sum(a[j//2, k+32])
            # Both warps of a row live in the same block, so a block barrier makes both
            # partials visible before they are combined. It sits outside the branches,
            # so every thread of the block reaches it.
            sync_threads()
            # One thread per row (lane 0 of the even warp) combines the two partials.
            if k==0 and j%2 == 0:
                b[j//2] = tmp[j//2,0] + tmp[j//2,1]
    return b

In [67]:
# 32x64 -> 32x1 row sums using 2 thread blocks of 32 warps, two warps per row:
# block i covers rows 16*i .. 16*i+15; within a block, even warp j reduces columns 0..31
# of row 16*i + j//2 and odd warp j reduces columns 32..63.
# NOTE(review): this cell redefines `myprog`; the GPU-transform cell uses whichever
# definition was executed last.
@dace.program
def myprog(a: dace.float64[32, 64]):
    b = np.empty([32], dtype=dace.float64)
    # tmp[r, 0] / tmp[r, 1]: partial sums of the left / right half of row r.
    tmp = np.empty([32,2], dtype=dace.float64)
    for i in dace.map[0:2]:
        for j,k in dace.map[0:32,0:32]:
            if j%2 == 0:
                tmp[i*16+j//2,0] = warpReduce_sum(a[i*16+j//2, k])
            else:
                tmp[i*16+j//2,1] = warpReduce_sum(a[i*16+j//2, k+32])
            # Both partials of a row come from the same block, so a block barrier is
            # enough; it is outside the branches, so every thread reaches it.
            sync_threads()
            # Lane 0 of the even warp writes the row total.
            if k == 0 and j%2 == 0:
                b[i*16+j//2] = tmp[i*16+j//2,0] + tmp[i*16+j//2,1]
    return b

In [75]:
# 8x2048 -> 8x1 row sums using 8 thread blocks of 32 warps, one block per row.
# Thread (j, k) pre-adds columns 32*j+k and 32*j+k+1024 (1024 threads cover 2048 columns),
# each warp reduces its 32 values into tmp1[i, j], and lane 0 of every warp adds that
# partial into b[i].
# NOTE(review): this cell redefines `myprog`; the GPU-transform cell uses whichever
# definition was executed last.
@dace.program
def myprog(a: dace.float64[8, 2048]):
    # Zero-initialized because it is accumulated with `+=` below.
    b = np.zeros([8], dtype=dace.float64)
    tmp = np.empty([8,1024], dtype=dace.float64)
    # tmp1[i, j]: sum produced by warp j of block i.
    tmp1 = np.empty([8,32], dtype=dace.float64)
    for i in dace.map[0:8]:
        for j,k in dace.map[0:32,0:32]:
            tmp[i,32*j+k] = a[i,32*j+k]+a[i,32*j+k+1024]
            sync_threads()
            tmp1[i,j] = warpReduce_sum(tmp[i, 32*j+k])
            sync_threads()
            # 32 warps add into the same b[i]; presumably DaCe turns `+=` inside the map
            # into a write-conflict-resolution (atomic) update. TODO confirm in generated code.
            if k==0:
                b[i] += tmp1[i,j]
    return b

In [95]:
# 8x2048 -> 8x1 row sums using 16 thread blocks of 32 warps (1024 threads), two blocks per row.
# Block i handles row i//2: even blocks reduce columns 0..1023, odd blocks columns 1024..2047.
# Warp j reduces 32 consecutive columns into tmp, then lane 0 of every warp adds that warp's
# partial sum to b[i//2].
# Synchronization rules this kernel follows:
#   * __syncthreads() only orders threads *within* one block, so a block must never read tmp
#     entries written by its partner block; each block consumes only its own partials.
#   * The barrier must be reached by every thread of the block, so it is never placed inside
#     a branch that depends on j or k.
# NOTE(review): this cell redefines `myprog`; the GPU-transform cell uses whichever
# definition was executed last.
@dace.program
def myprog(a: dace.float64[8, 2048]):
    # Zero-initialized because it is accumulated with `+=` below.
    b = np.zeros([8], dtype=dace.float64)
    # tmp[r, j] / tmp[r, j + 32]: sum of warp j in the even / odd block of row r.
    tmp = np.empty([8, 64], dtype=dace.float64)
    for i in dace.map[0:16]:
        for j, k in dace.map[0:32, 0:32]:
            if i % 2 == 0:
                tmp[i//2, j] = warpReduce_sum(a[i//2, 32*j + k])
            else:
                tmp[i//2, j + 32] = warpReduce_sum(a[i//2, 32*j + k + 1024])
            # Outside any j/k-dependent branch, so every thread of the block reaches it.
            sync_threads()
            if k == 0:
                # Each block adds only the partials it produced itself. `+=` from many warps
                # (and both blocks) into b[i//2] presumably becomes a DaCe write-conflict
                # (atomic) update, as in the one-block-per-row variant. TODO confirm.
                if i % 2 == 0:
                    b[i//2] += tmp[i//2, j]
                else:
                    b[i//2] += tmp[i//2, j + 32]
    return b

In [96]:
# Lower the last-defined `myprog` to an SDFG and transform it for the GPU.
# `sequential_innermaps=False` keeps the inner (j, k) map as a parallel thread-block map
# instead of making it sequential. The cell displays the value returned by
# apply_transformations.
sdfg = myprog.to_sdfg()
sdfg.apply_transformations(GPUTransformSDFG, options={'sequential_innermaps': False})

1

In [103]:
# Test: run the GPU SDFG on random input and compare against NumPy's row sums.
# The input shape is read from the SDFG's `a` argument, so this cell works for whichever
# `myprog` variant was transformed. A fixed seed makes runs (and the outputs below)
# reproducible.
rng = np.random.default_rng(0)
a = rng.random(tuple(int(extent) for extent in sdfg.arrays['a'].shape))
b = sdfg(a)
# Same tolerances as np.allclose's defaults, but reports the mismatching entries on failure.
np.testing.assert_allclose(b, np.sum(a, axis=1), rtol=1e-05, atol=1e-08)

In [98]:
# Reference row sums computed on the host with NumPy.
a.sum(axis=1)

array([1052.6558795 , 1030.41451885, 1030.71169467, 1018.44169287,
       1018.77723684, 1039.3605196 , 1029.17974785, 1027.02621788])

In [99]:
# GPU result from the test cell, for side-by-side comparison with the NumPy row sums above.
b

array([1052.6558795 , 1030.41451885, 1030.71169467, 1018.44169287,
       1018.77723684, 1039.3605196 , 1029.17974785, 1027.02621788])