In [1]:
import random
import math
import timeit
import numpy as np
import matplotlib.pyplot as plt
import gt4py as gt
import gt4py.cartesian.gtscript as gtscript

In [6]:
from common import (
    initialize_fields,
    plot_field,
    array_to_gt_storage
)

In [8]:
# Grid extents handed to initialize_fields as (NX, NY, NZ).
NX = NY = 128
NZ = 80
# Number of stencil applications used in the timing run.
N_ITER = 50

In [11]:
# Create the input/output field pair with the project helper from `common`,
# using mode="square"; the cells below inspect the layout it returns.
in_field, out_field = initialize_fields(NX, NY, NZ, mode="square")

In [14]:
# Check the axis order returned by initialize_fields (the leading axis has length NZ).
in_field.shape

(80, 128, 128)

In [13]:
# Preview the shape with the first and last axes exchanged; the result is only
# displayed, not stored, so in_field keeps its original layout.
in_field.swapaxes(0,2).shape

(128, 128, 80)

In [10]:
# Check the element type; the stencil definition below declares np.float64 fields.
in_field.dtype

dtype('float64')

# Accelerate with GT4Py

A GT4Py stencil is compiled for one specific backend, so the backend must be chosen before calling `gtscript.stencil`.

In [None]:
def gt4py_1D_row_def(
    in_field: gtscript.Field[np.float64],
    out_field: gtscript.Field[np.float64],
):
    # GTScript stencil definition: for every point, out_field receives half the
    # difference between in_field one step forward along the first axis and
    # in_field at the point itself. This function is not called directly; it is
    # passed to gtscript.stencil, which compiles it for a backend.

    from __gtscript__ import PARALLEL, computation, interval
    
    # PARALLEL declares that vertical levels have no ordering dependency;
    # interval(...) selects the whole vertical extent of the compute domain.
    with computation(PARALLEL), interval(...):
        # Apply 1D stencil in gt4py style
        # Offsets are relative to the current point: [1, 0, 0] is the next
        # point along the first axis, [0, 0, 0] is the current point.
        out_field = 0.5 * (in_field[1, 0, 0] - in_field[0, 0, 0])

def gt4py_1D_row_apply(in_field, out_field, N_ITER=1):
    """Apply the compiled 1D stencil N_ITER times with a periodic halo on axis 0.

    The last index along axis 0 is treated as a halo entry: before every
    stencil call it is overwritten with a copy of index 0, so the forward
    difference at the last computed point wraps around. The two buffers are
    swapped between iterations, so BOTH arguments are modified in place.

    Args:
        in_field: Field holding the initial values. Must expose ``shape`` and
            support ``field[i, :, :]`` indexing; its axis-0 length includes
            the one halo entry.
        out_field: Field of the same shape, used as the second buffer.
        N_ITER: Number of stencil applications.

    Returns:
        The buffer that holds the result of the last stencil application:
        the caller's ``out_field`` for odd N_ITER, the caller's ``in_field``
        for even N_ITER. If N_ITER is 0 (or negative) nothing is computed and
        ``out_field`` is returned untouched.
    """
    # Number of computed points along axis 0; this is also the halo index.
    n_interior = in_field.shape[0] - 1

    # The stencil runs over every point except the halo entry on axis 0.
    origin = (0, 0, 0)
    domain = (n_interior, in_field.shape[1], in_field.shape[2])

    # `step` instead of `iter` so the builtin is not shadowed.
    for step in range(N_ITER):
        # Periodic boundary: the halo entry mirrors the first entry so that
        # the [1, 0, 0] offset read at the last computed point is valid.
        in_field[n_interior, :, :] = in_field[0, :, :]

        gt4py_1D_row_stencil(
            in_field=in_field,
            out_field=out_field,
            origin=origin,
            domain=domain,
        )

        if step < N_ITER - 1:
            # Swap buffers: this pass's output is the next pass's input.
            in_field, out_field = out_field, in_field
        else:
            # Last pass: the stencil never writes the halo entry, so refresh it
            # on the RESULT buffer. (Refreshing in_field here, as before, only
            # repeated the update done at the top of this iteration.)
            out_field[n_interior, :, :] = out_field[0, :, :]

    return out_field


In [None]:
# Backend name used to compile the stencil here and passed to
# array_to_gt_storage in the timing cell.
backend = "numpy"
# Compile the stencil definition for that backend. gt4py_1D_row_apply looks the
# compiled stencil up under this global name, so this cell must run before it is called.
gt4py_1D_row_stencil = gtscript.stencil(backend=backend, definition=gt4py_1D_row_def)

In [None]:
%%timeit
# Everything in this cell body is timed, including field creation and the
# storage conversion, not only the stencil iterations.
# NOTE(review): presumably the +1 provides the extra entry that
# gt4py_1D_row_apply uses as its halo (index shape[0] - 1). The earlier shape
# check shows initialize_fields returning the NZ axis first, so confirm that
# the axis order reaching the stencil is the intended one.
in_field, out_field = initialize_fields(NX+1, NY, NZ, mode="horizontal-bars")
# NOTE(review): the return value is discarded. If array_to_gt_storage returns
# new storages instead of converting its arguments in place, the call below
# still receives the original arrays; verify against common.py.
array_to_gt_storage(in_field, out_field, backend=backend, index=(0, 0, 0))
gt4py_1D_row_apply(in_field, out_field, N_ITER=N_ITER)

## Copy Stencil

In [None]:
def gt4py_copy(in_field: gtscript.Field[]):