In [1]:
import dace
import numpy as np
from dace.transformation.interstate import GPUTransformSDFG
from dace.frontend.common import op_repository as oprepo

In [2]:
def find_map_by_param(sdfg: dace.SDFG, pname: str) -> dace.nodes.MapEntry:
   """ Finds the first map entry node by the given parameter name. """
   return next(n for n, _ in sdfg.all_nodes_recursive()
               if isinstance(n, dace.nodes.MapEntry) and pname in n.params)

In [3]:
@oprepo.replaces('sync_threads')
def sync_threads(pv, sdfg: dace.SDFG, state: dace.SDFGState):
    state.add_tasklet(name='syncronize_threads',
                      inputs={},
                      outputs={},
                      code='__syncthreads();', 
                      language=dace.Language.CPP)

In [4]:
@dace.program
def myprog(a: dace.float64[512]):
    for i in dace.map[0:1]:
        for j, k in dace.map[0:16, 0:32]:
            stride = 512
            while stride > 0:
                t = j * 32 + k
                if t < stride:
                    a[t] = a[t] + a[t + stride]
                stride = stride / 2
                sync_threads()
    return a[0]

In [5]:
sdfg = myprog.to_sdfg()
block_map = find_map_by_param(sdfg, 'i')
block_map.schedule = dace.ScheduleType.GPU_Device
thread_map = find_map_by_param(sdfg, 'j')
thread_map.schedule = dace.ScheduleType.GPU_ThreadBlock
sdfg.apply_transformations(GPUTransformSDFG, {'sequential_innermaps': False})

1

In [6]:
# Test
a = np.random.rand(512)
# a = np.ones(512)
res = np.sum(a, axis=0)
b = sdfg(a)
# assert np.allclose(b, res)

In [7]:
a, b

(array([2.55076391e+02, 1.24986286e+02, 6.43627324e+01, 6.52783340e+01,
        3.17519197e+01, 3.09047683e+01, 3.20474821e+01, 3.14442685e+01,
        1.67330093e+01, 1.40259905e+01, 1.78978184e+01, 1.74374134e+01,
        1.88381379e+01, 1.59607372e+01, 1.35613025e+01, 1.60733240e+01,
        8.95152355e+00, 5.44723484e+00, 7.27496429e+00, 8.28965730e+00,
        6.72615547e+00, 7.54331123e+00, 9.31956979e+00, 7.59130595e+00,
        8.22263752e+00, 6.92053514e+00, 8.38729233e+00, 9.17221663e+00,
        9.85417551e+00, 7.52206869e+00, 5.82490401e+00, 8.67946533e+00,
        4.18638076e+00, 3.76904167e+00, 3.72827233e+00, 3.66346271e+00,
        3.37706668e+00, 3.58595629e+00, 4.89987038e+00, 4.20174363e+00,
        3.66559807e+00, 2.51708687e+00, 5.02863698e+00, 3.78081524e+00,
        5.37388855e+00, 3.64422907e+00, 3.01854394e+00, 3.23748225e+00,
        4.91872955e+00, 1.71776017e+00, 2.61948490e+00, 4.36319089e+00,
        3.44975175e+00, 3.59234543e+00, 4.32287324e+00, 3.900024

In [8]:
res

255.0763905717808

In [38]:
import collections
collections.Counter(a)

Counter({257.84836470826644: 1,
         129.4808740863368: 1,
         64.56853807856378: 1,
         65.87340319561612: 1,
         33.91914850777209: 1,
         31.423018208056995: 1,
         32.26991388304218: 1,
         34.298532411420624: 1,
         15.52414695241976: 1,
         14.857997365430524: 1,
         14.124229524049412: 1,
         16.50812537498463: 1,
         15.968900079953162: 1,
         15.189506074373028: 1,
         15.651466281128373: 1,
         18.119972231472495: 1,
         6.685490923185624: 1,
         9.122273193917096: 1,
         9.137353092249903: 1,
         7.027423399192143: 1,
         8.409830948071658: 1,
         8.406107905369169: 1,
         6.511916173397078: 1,
         7.608074486291687: 1,
         8.650985400645654: 1,
         7.655246956801416: 1,
         7.7479162082369655: 1,
         6.973888672228636: 1,
         7.449137893353553: 1,
         7.641333570724848: 1,
         8.284780385501433: 1,
         8.077478964777276: 1