# Finite Differences with MIRGE

In this code-along, we will put together a (very) simple finite-difference wave equation solver using the MIRGE machinery.

First, we need to import the ingredients:

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import arraycontext
import pyopencl as cl
from pytools.obj_array import make_obj_array

In [21]:
%load_ext gvmagic

In [2]:
cl_ctx = cl.create_some_context(interactive=True)
queue = cl.CommandQueue(cl_ctx)

To get started, we'll need a simple mesh infrastructure.

Observe how `actx` is being used:

In [4]:
class Mesh:
    def __init__(self, size_x, size_y, resolution):
        self.size_x = size_x
        self.size_y = size_y
        self.resolution = resolution
        self.nx, self.ny = int(size_x*resolution), int(size_y*resolution)

        self.x = actx.np.linspace(
                0, size_x, self.nx, endpoint=False).reshape(self.nx, 1)
        self.y = actx.np.linspace(
                0, size_y, self.ny, endpoint=False).reshape(1, self.ny)
        self.hx = actx.to_numpy(self.x[1, 0] - self.x[0, 0])
        self.hy = actx.to_numpy(self.y[0, 1] - self.y[0, 0])

    def plot(self, f, **kwargs):
        f = actx.to_numpy(f)
        return plt.imshow(f.T[::-1], extent=(0, self.size_x, 0, self.size_y),
                          **kwargs)

    def set_plot_data(self, img, f):
        f = actx.to_numpy(f)
        img.set_data(f.T[::-1])

    def zeros(self):
        return actx.zeros((self.nx, self.ny), dtype=np.float64)

    def norm(self, u):
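        # Discrete L2 norm: weight each squared sample by the cell area
        # (size_x/nx) * (size_y/ny) before taking the square root.
        return actx.np.sqrt(actx.np.sum(abs(u)**2) * (
                self.size_x * self.size_y / self.nx / self.ny))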

mesh = Mesh(size_x=6, size_y=4, resolution=64)

In [5]:
f = (mesh.x - mesh.size_x/2)**2 + (mesh.y - mesh.size_y/2)**2
mesh.plot(f)

Next up, define and test a derivative operator along the $x$ axis:
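
One possible approach is sketched below in plain NumPy, so that it runs without an array context. The name `d_dx` and the test grid are assumptions made for this sketch: it uses a second-order centered difference in the interior and first-order one-sided differences at the two ends, and tests against $\sin(x)$, whose derivative is $\cos(x)$.

```python
import numpy as np

def d_dx(u, hx):
    """Approximate du/dx along axis 0 (the x axis) of a 2D array."""
    # Second-order centered difference at interior points.
    interior = (u[2:] - u[:-2]) / (2*hx)
    # First-order one-sided differences at the two boundary rows.
    left = (u[1:2] - u[0:1]) / hx
    right = (u[-1:] - u[-2:-1]) / hx
    return np.concatenate((left, interior, right), axis=0)

# Same grid layout as the Mesh class: x varies along axis 0, y along axis 1.
nx, ny = 384, 256
x = np.linspace(0, 6, nx, endpoint=False).reshape(nx, 1)
y = np.linspace(0, 4, ny, endpoint=False).reshape(1, ny)
hx = x[1, 0] - x[0, 0]

u = np.sin(x) + 0*y          # broadcast to shape (nx, ny)
du = d_dx(u, hx)

# Compare against the exact derivative away from the boundary rows.
err = np.max(np.abs(du[1:-1] - np.cos(x[1:-1])))
print(err)
```

The function is written without in-place assignment, in the same slice-and-concatenate style as the `laplace` function used later in this code-along.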

We now turn to time-dependent simulation. For this, we need a simple Runge-Kutta scheme (the classical fourth-order method, RK4):

In [7]:
def rk4_step(y, t, h, f):
    k1 = f(t, y)
    k2 = f(t+h/2, y + h/2*k1)
    k3 = f(t+h/2, y + h/2*k2)
    k4 = f(t+h, y + h*k3)
    return y + h/6*(k1 + 2*k2 + 2*k3 + k4)
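
As a quick sanity check, the stepper can be tried on a scalar problem with a known solution. The sketch below repeats `rk4_step` so that it runs on its own, and integrates $y' = -y$, $y(0) = 1$ up to $t = 1$, where the exact answer is $e^{-1}$. The test problem and step size are choices made for this illustration.

```python
import math

def rk4_step(y, t, h, f):
    k1 = f(t, y)
    k2 = f(t+h/2, y + h/2*k1)
    k3 = f(t+h/2, y + h/2*k2)
    k4 = f(t+h, y + h*k3)
    return y + h/6*(k1 + 2*k2 + 2*k3 + k4)

# Integrate y' = -y from t = 0 to t = 1 in ten steps.
y, t, h = 1.0, 0.0, 0.1
for _ in range(10):
    y = rk4_step(y, t, h, lambda t, y: -y)
    t += h

error = abs(y - math.exp(-1))
print(y, error)
```

Because `rk4_step` only uses `+`, `*`, and `/`, the same function works unchanged on scalars, NumPy arrays, and the object arrays of mesh fields used below.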

We will use a Gaussian "bump" function as the source term:

In [8]:
bump = actx.np.exp(
        -80*((mesh.x - mesh.size_x/2)**2 + (mesh.y - mesh.size_y/2)**2))
mesh.plot(bump)

Next, let us move closer to an actual wave equation solver. To do so, define:

- A function `laplace(mesh, u, boundary_val)`
- The right-hand side `rhs(t, s)`, where `s` holds the state `(u, du_dt)`. Use a time-dependent source term $\sin(30t)\cdot \text{bump}$.
- A function realizing a whole time step: `tstep(t, s)`
- A time step `dt` obeying a CFL condition.
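
To make the CFL item concrete, here is a rough estimate, assuming the five-point Laplacian with zero boundary values and the RK4 stepper from above. The eigenvalues of that discrete Laplacian are bounded in magnitude by $4/h_x^2 + 4/h_y^2$, so the wave system oscillates at frequencies up to $\omega_{\max} = \sqrt{4/h_x^2 + 4/h_y^2}$. Classical RK4 is stable on the imaginary axis for $|\omega\,\Delta t| \le 2\sqrt{2}$. The helper names below are made up for this sketch.

```python
import numpy as np

def rk4_amplification(z):
    """Stability polynomial of classical RK4 applied to y' = lambda*y, z = lambda*dt."""
    return 1 + z + z**2/2 + z**3/6 + z**4/24

def max_stable_dt(hx, hy):
    """Upper bound on dt from the eigenvalue estimate for the 5-point Laplacian."""
    omega_max = np.sqrt(4/hx**2 + 4/hy**2)
    return 2*np.sqrt(2) / omega_max

# Amplification just inside and just outside the imaginary-axis limit 2*sqrt(2).
inside = abs(rk4_amplification(2.8j))
outside = abs(rk4_amplification(2.9j))

# Grid spacing for a resolution of 64 points per unit length.
hx = hy = 1/64
dt_max = max_stable_dt(hx, hy)
print(inside, outside, dt_max)
```

On a square grid with $h_x = h_y = h$ the bound reduces to $\Delta t \le h$, so a time step equal to the grid spacing sits right at the edge of this estimate.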

In [9]:
def laplace(mesh, u, boundary_val):
    padded_x = actx.np.concatenate((
        actx.np.full((1, mesh.ny), boundary_val),
        u,
        actx.np.full((1, mesh.ny), boundary_val),
        ), axis=0)
    padded_y = actx.np.concatenate((
        actx.np.full((mesh.nx, 1), boundary_val),
        u,
        actx.np.full((mesh.nx, 1), boundary_val),
        ), axis=1)
    return (
            (padded_x[2:] - 2*u + padded_x[:-2])/mesh.hx**2
            +
            (padded_y[:, 2:] - 2*u + padded_y[:, :-2])/mesh.hy**2
            )


# TODO: Remember to show that rhs is only called *once*
def rhs(t, s):
    u, du_dt = s

    return make_obj_array([
        du_dt + actx.np.sin(30*t) * bump,
        laplace(mesh, u, 0)
        ])

def tstep(t, s):
    return rk4_step(s, t, dt, rhs)

dt = min(mesh.hx, mesh.hy)
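
The stencil in `laplace` can be sanity-checked on a function whose Laplacian is known. The sketch below is a NumPy-only port (the name `laplace_np` and the small test grid are assumptions for this illustration) applied to $u = x^2 + y^2$, for which $\Delta u = 4$. Second differences are exact for quadratics, so interior values should match up to rounding. The outermost rows and columns differ because the padding imposes `boundary_val` there.

```python
import numpy as np

def laplace_np(u, hx, hy, boundary_val):
    """NumPy port of the five-point Laplacian with constant boundary padding."""
    nx, ny = u.shape
    padded_x = np.concatenate((
        np.full((1, ny), boundary_val),
        u,
        np.full((1, ny), boundary_val),
        ), axis=0)
    padded_y = np.concatenate((
        np.full((nx, 1), boundary_val),
        u,
        np.full((nx, 1), boundary_val),
        ), axis=1)
    return (
            (padded_x[2:] - 2*u + padded_x[:-2])/hx**2
            +
            (padded_y[:, 2:] - 2*u + padded_y[:, :-2])/hy**2
            )

nx, ny = 48, 32
x = np.linspace(0, 6, nx, endpoint=False).reshape(nx, 1)
y = np.linspace(0, 4, ny, endpoint=False).reshape(1, ny)
hx = x[1, 0] - x[0, 0]
hy = y[0, 1] - y[0, 0]

u = x**2 + y**2
lap = laplace_np(u, hx, hy, boundary_val=0.0)

# Ignore the outermost rows/columns, which see the padded boundary value.
interior_err = np.max(np.abs(lap[1:-1, 1:-1] - 4))
print(interior_err)
```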

Here's our initial condition:

In [10]:
state = make_obj_array([mesh.zeros(), mesh.zeros()])
t = 0.

Next, take a few time steps (code up a single step and hit Ctrl-Enter repeatedly):

What do you observe? Can you fix the issue? (Again, use a Ctrl-Enter loop):

In [12]:
t = 0.
state = make_obj_array([mesh.zeros(), mesh.zeros()])

What do you notice about step-to-step time? Can you fix this issue?

In [14]:
t = 0.
state = make_obj_array([mesh.zeros(), mesh.zeros()])

Here's an animation to convince you that a PDE got solved:

In [17]:
t = 0.
state = make_obj_array([mesh.zeros(), mesh.zeros()])

fig = plt.figure()
img = mesh.plot(state[0], vmin=-0.02, vmax=0.02)

def update(frame):
    # Advance the solution by 10 time steps per animation frame.
    global t, state
    for _ in range(10):
        state = tstep_compiled(actx.from_numpy(t), state)
        t += dt
    mesh.set_plot_data(img, state[0])

import matplotlib.animation as animation
ani = animation.FuncAnimation(fig=fig, func=update, frames=40, interval=30)

from IPython.display import HTML
html = HTML(ani.to_jshtml())
plt.clf()
html

For the remainder of this code-along, we will investigate various questions relating to our solver.

- First, how often did `rhs` get called above? (investigate this by modifying the code above)
    - What are the implications of that?
    - What if there are `if` statements in your code?
    
- Next, can we see the C code that was generated "under the hood"?
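
On the first question above, the idea of a function being called only once can be illustrated with a toy tracer. This is a hypothetical miniature written for this illustration, not the actual pytato or arraycontext machinery. The function is run a single time on a placeholder, the arithmetic it performs is recorded as an expression graph, and later calls replay that graph without re-running the Python function.

```python
class Expr:
    """A node in a tiny expression graph (toy stand-in, not a pytato class)."""
    def __init__(self, op, args):
        self.op, self.args = op, args

    def __add__(self, other):
        return Expr("add", (self, other))

    def __mul__(self, other):
        return Expr("mul", (self, other))

    def __rmul__(self, other):
        return Expr("mul", (other, self))

def evaluate(expr, env):
    """Replay a recorded graph with concrete values for its variables."""
    if not isinstance(expr, Expr):
        return expr                      # a plain Python number
    if expr.op == "var":
        return env[expr.args[0]]
    a, b = (evaluate(arg, env) for arg in expr.args)
    return a + b if expr.op == "add" else a * b

calls = []                               # records every real call to f

def f(u):
    calls.append(1)
    return 2*u + u*u

def toy_compile(func):
    # func runs exactly once here, on a symbolic placeholder.
    graph = func(Expr("var", ("u",)))
    return lambda u: evaluate(graph, {"u": u})

f_compiled = toy_compile(f)
results = [f_compiled(v) for v in (1.0, 2.0, 3.0)]
print(results, len(calls))
```

In a tracing setup like this toy one, a Python-level `if` inside `f` runs only while the graph is being recorded, so it cannot pick a different branch on each later call.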

- Can we see the data flow graph that was obtained?

In [34]:
from pytools.graphviz import show_dot
show_dot(dot_codes[-1])

In [33]:
%dotstr dot_codes[-1]

In [25]:
%dotstr "dpi=20;\n"+dot_codes[-1]

- Can we influence the code that gets generated? Perhaps change which results end up in temporary variables?

- Can we estimate cost?

- Can we transform the code for efficient execution on a GPU?

- How do we write robust transform code?

In [43]:
# exercise for the reader :)
# You can tag array axes in pytato, and those tags survive on loop variables ("inames") in loopy.
# You can also tag arrays.

- How can we express distributed-memory computation?

In [None]:
# exercise for the reader :)
# The necessary tools in pytato exist!