In [1]:
import dace
import gt4py
import numpy as np

from gt4py import gtscript
from gt4py.gtscript import Field

In [2]:
@gtscript.as_sdfg
def diffusion_step(c_in: Field[np.float32], c_out: Field[np.float32], delc: Field[np.float32], coeff: np.float32, dt: np.float32):
    with computation(PARALLEL), interval(1, -1):
        delc = (
            -6.0 * c_in[0, 0, 0]
            + c_in[1, 0, 0]
            + c_in[0, 1, 0]
            + c_in[0, 0, 1]
            + c_in[-1, 0, 0]
            + c_in[0, -1, 0]
            + c_in[0, 0, -1]
        )
        c_out = c_in + dt * coeff * delc
        delc = delc if delc>=0 else -delc

In [3]:

dt = np.float32(0.1)
coeff = np.float32(0.5)
domain = (10, 10, 10)

I = domain[0]
J = domain[1]
K = domain[2]

@dace.program
def swap(x, y):
    return y, x
    
@dace.program
def diffusion_periodic(c: dace.float32[I, J, K], threshold: dace.float32, maxiter: dace.int64):
    c_haloed = np.empty((I + 2, J + 2, K + 2), dtype=np.float32)
    c_tmp = np.empty_like(c_haloed)
    delc = np.empty((I, J, K+2), dtype=np.float32)
    delc[...] = threshold+1.0
    c_haloed[1:I + 1, 1:J + 1, 1:K + 1] = c

    it = 0
    tmp_c_tmp = np.empty_like(c_tmp[1:I + 1, 1:J + 1, 0:K + 2])
    tmp_delc = np.empty((I, J, K-2), dtype=np.float32)
    tmp_delc[...] = delc[:,:,1:K-1]
    
    while tmp_delc.max() > threshold and it<maxiter:
        # set periodic BC:
        c_haloed[0, 1:J + 1, 1:K + 1] = c_haloed[I, 1:J + 1, 1:K + 1]
        c_haloed[I + 1, 1:J + 1, 1:K + 1] = c_haloed[1, 1:J + 1, 1:K + 1]
        c_haloed[1:I + 1, 0, 1:K + 1] = c_haloed[1:I + 1, J, 1:K + 1]
        c_haloed[1:I + 1, J + 1, 1:K + 1] = c_haloed[1:I + 1, 1, 1:K + 1]
        c_haloed[1:I + 1, 1:J + 1, 0] = c_haloed[1:I + 1, 1:J + 1, K]
        c_haloed[1:I + 1, 1:J + 1, K + 1] = c_haloed[1:I + 1, 1:J + 1, 1]

        # preparing output slice
        tmp_c_tmp[...] = c_tmp[1:I + 1, 1:J + 1, 0:K + 2]
        
        # stencil call
        diffusion_step(c_in=c_haloed, c_out=tmp_c_tmp, delc=delc, coeff=coeff, dt=dt)
        
        # writing back results
        c_tmp[1:I + 1, 1:J + 1, 0:K+2] = tmp_c_tmp
        tmp_delc[...] = delc[:,:,1:K-1]

        # swap
        c_tmp[:], c_haloed[:] = swap(c_tmp, c_haloed)
        it += 1
    c[...] = c_haloed[1:I + 1, 1:J + 1, 1:K + 1]


In [5]:
c = np.zeros(domain, dtype='float32')
c[0:2, 0:2, 0:2] = 1.0

# diffusion_callable = diffusion_periodic.compile()

res = diffusion_periodic(c, threshold=1e-8, maxiter=99999999)
print(c.min(), c.max())
diffusion_periodic.to_sdfg().save('tmp.sdfg')

0.007999964 0.008000025


'ac8c225df1689053edb7076813fbfbd1f205f3c2dc204765298695ac69d45bfa'