In [1]:
import dace
import numpy as np
from dace.transformation.interstate import GPUTransformSDFG
from dace.frontend.common import op_repository as oprepo

In [2]:
def find_map_by_param(sdfg: dace.SDFG, pname: str) -> dace.nodes.MapEntry:
   """ Finds the first map entry node by the given parameter name. """
   return next(n for n, _ in sdfg.all_nodes_recursive()
               if isinstance(n, dace.nodes.MapEntry) and pname in n.params)

In [3]:
@oprepo.replaces('sync_threads')
def sync_threads(pv, sdfg: dace.SDFG, state: dace.SDFGState):
    state.add_tasklet(name='syncronize_threads',
                      inputs={},
                      outputs={},
                      code='__syncthreads();', 
                      language=dace.Language.CPP)

In [34]:
@dace.program
def myprog(a: dace.float64[512]):
    for i in dace.map[0:1]:
        for j, k in dace.map[0:16, 0:32]:
            tid = j * 32 + k
            a[tid] = a[tid] + a[tid + 512]
            sync_threads()
            a[tid] = a[tid] + a[tid + 256]
            sync_threads()
            a[tid] = a[tid] + a[tid + 128]
            sync_threads()
            a[tid] = a[tid] + a[tid + 64]
            sync_threads()
            a[tid] = a[tid] + a[tid + 32]
            a[tid] = a[tid] + a[tid + 16]
            a[tid] = a[tid] + a[tid + 8]
            a[tid] = a[tid] + a[tid + 4]
            a[tid] = a[tid] + a[tid + 2]
            a[tid] = a[tid] + a[tid + 1]
    return a[0]

In [38]:
@dace.program
def myprog(a: dace.float64[512]):
    for i in dace.map[0:1]:
        for j, k in dace.map[0:16, 0:32]:
            stride = 512
            tid = j * 32 + k
            while stride > 32:
                a[tid] = a[tid] + a[tid + stride]
                stride = stride / 2
                sync_threads()
            if j == 0:
                while stride > 0:
                    a[k] = a[k] + a[k + stride]
                    stride = stride / 2
    return a[0]

In [39]:
sdfg = myprog.to_sdfg()
block_map = find_map_by_param(sdfg, 'i')
block_map.schedule = dace.ScheduleType.GPU_Device
thread_map = find_map_by_param(sdfg, 'j')
thread_map.schedule = dace.ScheduleType.GPU_ThreadBlock
sdfg.apply_transformations(GPUTransformSDFG, {'sequential_innermaps': False})

1

In [40]:
# Test
a = np.random.rand(512)
# a = np.ones(512)
res = np.sum(a, axis=0)
b = sdfg(a)
# assert np.allclose(b, res)

In [41]:
a, b

(array([250.91934827, 250.26071447, 249.37264627, 248.64128125,
        248.06805125, 247.57614678, 247.15083171, 246.42891363,
        246.35138861, 245.8618359 , 245.29984103, 245.16182761,
        244.84782738, 244.30066875, 244.12194172, 243.28236622,
        242.32265563, 241.89621109, 241.86780399, 241.18839601,
        240.90865537, 240.71913986, 239.76379656, 239.54573116,
        238.84353453, 237.86230695, 237.8485916 , 237.65431272,
        237.09573572, 236.19073506, 236.18099672, 235.21777073,
        234.30951066, 233.95756362, 233.10572793, 232.32747776,
        231.50317283, 231.21760421, 230.61775783, 230.39730228,
        229.51561662, 229.14017532, 228.80801817, 228.45067707,
        228.02997877, 227.4887928 , 227.35008637, 226.37529628,
        225.88757762, 225.65839285, 225.62448257, 225.37779991,
        225.14396422, 224.62117918, 224.06190803, 223.54262619,
        223.08154236, 222.30469764, 221.4578489 , 220.51719888,
        219.83041947, 219.15715759, 218.

In [42]:
res

250.91934826783591

In [38]:
import collections
collections.Counter(a)

Counter({257.84836470826644: 1,
         129.4808740863368: 1,
         64.56853807856378: 1,
         65.87340319561612: 1,
         33.91914850777209: 1,
         31.423018208056995: 1,
         32.26991388304218: 1,
         34.298532411420624: 1,
         15.52414695241976: 1,
         14.857997365430524: 1,
         14.124229524049412: 1,
         16.50812537498463: 1,
         15.968900079953162: 1,
         15.189506074373028: 1,
         15.651466281128373: 1,
         18.119972231472495: 1,
         6.685490923185624: 1,
         9.122273193917096: 1,
         9.137353092249903: 1,
         7.027423399192143: 1,
         8.409830948071658: 1,
         8.406107905369169: 1,
         6.511916173397078: 1,
         7.608074486291687: 1,
         8.650985400645654: 1,
         7.655246956801416: 1,
         7.7479162082369655: 1,
         6.973888672228636: 1,
         7.449137893353553: 1,
         7.641333570724848: 1,
         8.284780385501433: 1,
         8.077478964777276: 1