# Demo: Interpolation



In this demo, we will look at interpolation and how we can use it to easily move between the different domains of the Arakawa C-grid.
In general, the methods here can be used with Cartesian and rectilinear grids.

**Note**: Curvilinear grids are currently outside the scope of this project, although it would be nice to have methods that can handle them in the future.

In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # don't preallocate most of the GPU memory
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # first gpu


import jax
# these settings only take effect if applied before JAX runs any computation!
from jax import config
config.update("jax_enable_x64", True)  # use float64 by default
config.update('jax_platform_name', 'cpu')  # run on the CPU

In [1]:
import autoroot
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float
from jax import config
config.update("jax_enable_x64", True)
from finitevolx._src.masks.masks import (
    MaskGrid
)
import matplotlib.pyplot as plt
import numpy as np
import functools as ft

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Domains

In [2]:
# define number of points
Nx, Ny = 128, 128

# define resolution
dx, dy = 5e3, 5e3

# define limits
xmin, ymin = 0.0, 0.0
Lx, Ly = (Nx - 1) * dx, (Ny - 1) * dy
xmax, ymax = Lx, Ly

In [3]:
from finitevolx._src.domain.domain import Domain
from finitevolx.functional import stagger_domain

# grid at center

h_domain = Domain(xmin=(xmin, ymin), xmax=(xmax, ymax), Lx=(Lx,Ly), Nx=(Nx,Ny), dx=(dx,dy))
# grid on x-axis
u_domain = stagger_domain(h_domain, direction=("outer", None), stagger=(True, False))

h_domain, u_domain

(Domain(
   xmin=(0.0, 0.0),
   xmax=(635000.0, 635000.0),
   dx=(5000.0, 5000.0),
   Nx=(128, 128),
   Lx=(635000.0, 635000.0),
   ndim=2
 ),
 Domain(
   xmin=(-2500.0, 0.0),
   xmax=(637500.0, 635000.0),
   dx=(5000.0, 5000.0),
   Nx=(129, 128),
   Lx=(640000.0, 635000.0),
   ndim=2
 ))
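The printed domains make the staggering concrete: the u points sit half a grid spacing to the left of the h points, and the "outer" stagger adds one extra point so that every cell center has a face on either side. The NumPy sketch below rebuilds the x coordinates from the printed `xmin`, `Nx`, and `dx` values; it is an illustration reconstructed from that output, not how `stagger_domain` is implemented.

```python
import numpy as np

# Grid parameters from the "Domains" cell above
Nx, dx, xmin = 128, 5e3, 0.0

# h points (cell centers): Nx points starting at xmin
x_h = xmin + dx * np.arange(Nx)

# u points (x-faces), as suggested by the printed u_domain:
# shifted by -dx/2, with one extra point (Nx + 1 in total)
x_u = (xmin - 0.5 * dx) + dx * np.arange(Nx + 1)

print(x_h[0], x_h[-1])                # 0.0 635000.0
print(x_u[0], x_u[-1])                # -2500.0 637500.0
print(x_u.size, (x_u.size - 1) * dx)  # 129 640000.0  (Lx = (Nx - 1) * dx)

# every h point lies exactly halfway between two neighboring u points
assert np.allclose(x_h, 0.5 * (x_u[:-1] + x_u[1:]))
```

Because each h point is the midpoint of two u points, moving `u` onto the h grid amounts to a two-point average along x, which is why all the methods below agree.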

In [4]:
rng = np.random.RandomState(123)

# initialize a random field on the u domain (x-faces)
u: Float[Array, "Nx Ny"] = jnp.asarray(rng.randn(*u_domain.Nx))


## From Scratch

Here, we have quite a few functions that are useful for interpolating between grid points.
One option is to do it from scratch, manually keeping track of where the grid points are.
This can be done with functions like `avg_pool`, where the user simply describes the shape of the kernel: moving `u` from the x-faces onto the cell centers, for instance, is an average over a `(2, 1)` window.

Under the hood, this uses `kernex`, a library designed for *convolution-like* operations that run very fast on CPU, GPU, and TPU.
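To make "describe the shape of the kernel" concrete: a `(2, 1)` window with stride `(1, 1)` and `"valid"` padding averages each pair of neighboring points along the first (x) axis, shrinking `(129, 128)` to `(128, 128)`. The sketch below reproduces the arithmetic-mean case with `jax.lax.reduce_window` as a stand-in for the `kernex` backend; `avg_pool_sketch` is a hypothetical helper, not part of finitevolx.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np

def avg_pool_sketch(u, kernel_size=(2, 1), stride=(1, 1)):
    """Arithmetic-mean pooling with "valid" padding (windows never overhang the edges).

    Hypothetical stand-in using jax.lax.reduce_window, not finitevolx's kernex backend.
    """
    window_sum = jax.lax.reduce_window(
        u,
        0.0,            # initial value of each window sum
        jax.lax.add,    # reduction applied inside each window
        kernel_size,    # window shape
        stride,         # step between windows
        "VALID",        # no padding
    )
    return window_sum / math.prod(kernel_size)

u = jnp.asarray(np.random.RandomState(123).randn(129, 128))  # same shape as u above
u_on_h = avg_pool_sketch(u)
print(u_on_h.shape)  # (128, 128)

# a (2, 1) window is exactly the mean of neighboring rows
assert jnp.allclose(u_on_h, 0.5 * (u[:-1, :] + u[1:, :]), atol=1e-6)
```

Changing `kernel_size` to `(2, 2)` would instead average 2 x 2 blocks, giving a `(128, 127)` output.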

In [5]:
from finitevolx import avg_pool

In [6]:
kernel_size = (2,1)
stride = (1,1)
padding = "valid"
mean_fn = "arithmetic"  # alternatives: "geometric", "harmonic", "quadratic"
u_on_h_scratch = avg_pool(
    u, kernel_size=kernel_size, stride=stride, padding=padding, mean_fn=mean_fn
)

There are some other goodies here, like choosing the stride and the mean function.
In all honesty, I have not seen a use case for a mean function other than the arithmetic mean.
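For reference, the four `mean_fn` options presumably correspond to the textbook two-point means below. This is a sketch using the standard definitions; we assume, but have not verified, that finitevolx implements them this way.

```python
import jax.numpy as jnp

# Textbook two-point means of neighboring values a and b.
# Assumption: finitevolx's mean_fn options follow these definitions.
def arithmetic(a, b):
    return 0.5 * (a + b)

def geometric(a, b):
    return jnp.sqrt(a * b)                # NaN if a and b have opposite signs

def harmonic(a, b):
    return 2.0 / (1.0 / a + 1.0 / b)      # infinite when a == -b

def quadratic(a, b):
    return jnp.sqrt(0.5 * (a**2 + b**2))  # root mean square

a, b = jnp.array(1.0), jnp.array(4.0)
for mean in (arithmetic, geometric, harmonic, quadratic):
    print(mean.__name__, float(mean(a, b)))
```

For positive inputs these satisfy harmonic ≤ geometric ≤ arithmetic ≤ quadratic. For a signed field such as the random `u` in this demo, the geometric and harmonic means can return NaN or blow up, so the arithmetic mean is the safe default.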

#### Syntactic Sugar

There are some convenience functions that let us move between grids with a simpler syntax:

* `x_avg_1D`
* `x_avg_2D`, `y_avg_2D`, `center_avg_2D`
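Assuming these helpers compute the same neighbor averages as `avg_pool` (the equality check below confirms this for `x_avg_2D`), their arithmetic-mean versions can be sketched in a few lines of `jax.numpy`. The names `x_avg`, `y_avg`, and `center_avg` are hypothetical, the axis conventions for the y and center cases are assumed, and the real functions additionally take a `mean_fn` argument.

```python
import jax.numpy as jnp

# Hypothetical arithmetic-mean versions; axis conventions are assumed
# (axis 0 = x, axis 1 = y, matching the (Nx, Ny) layout used in this demo).
def x_avg(u):
    # average neighbors along x: (Nx, Ny) -> (Nx - 1, Ny)
    return 0.5 * (u[1:, :] + u[:-1, :])

def y_avg(u):
    # average neighbors along y: (Nx, Ny) -> (Nx, Ny - 1)
    return 0.5 * (u[:, 1:] + u[:, :-1])

def center_avg(u):
    # average the four surrounding points: (Nx, Ny) -> (Nx - 1, Ny - 1)
    return 0.25 * (u[1:, 1:] + u[1:, :-1] + u[:-1, 1:] + u[:-1, :-1])

q = jnp.arange(12.0).reshape(4, 3)
print(x_avg(q).shape, y_avg(q).shape, center_avg(q).shape)  # (3, 3) (4, 2) (3, 2)

# the four-point average is just an x-average of a y-average
assert jnp.allclose(center_avg(q), x_avg(y_avg(q)))
```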

In [7]:
from finitevolx import x_avg_2D

In [8]:
u_on_h_ss = x_avg_2D(u=u, mean_fn=mean_fn)

In [9]:
np.testing.assert_array_almost_equal(u_on_h_scratch, u_on_h_ss)

## Generalized Average

In [10]:
from finitevolx.functional import domain_interpolation_2D


# define interpolation params
method = "linear"
target_domain = h_domain
source_domain = u_domain

# using generalized function
u_on_h_gen = domain_interpolation_2D(
    u=u, source_domain=source_domain, target_domain=target_domain,
    method=method)

In [11]:
np.testing.assert_array_almost_equal(u_on_h_scratch, u_on_h_gen)
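Why does the generalized linear interpolation match the two-point average? Assuming `domain_interpolation_2D` with `method="linear"` performs linear interpolation on the domain coordinates (suggested by the name, not checked against its source), each h point lies exactly halfway between two u points, so both interpolation weights are 1/2. The sketch below swaps in `jnp.interp`, applied column by column with `jax.vmap`, to show the collapse.

```python
import jax
import jax.numpy as jnp
import numpy as np

Nx, Ny, dx = 128, 128, 5e3
x_h = dx * jnp.arange(Nx)                  # h points (cell centers)
x_u = -0.5 * dx + dx * jnp.arange(Nx + 1)  # u points (x-faces), outer stagger

u = jnp.asarray(np.random.RandomState(123).randn(Nx + 1, Ny))

# Linear interpolation along x, applied to every column (fixed y index) of u
interp_column = lambda col: jnp.interp(x_h, x_u, col)
u_on_h = jax.vmap(interp_column, in_axes=1, out_axes=1)(u)

# Each h point sits exactly halfway between two u points, so linear
# interpolation collapses to the two-point arithmetic mean
u_on_h_avg = 0.5 * (u[:-1, :] + u[1:, :])
print(u_on_h.shape, bool(jnp.allclose(u_on_h, u_on_h_avg, atol=1e-5)))  # (128, 128) True
```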

### Cartesian Grid

When the grid spacing is constant, the Cartesian interpolator is much faster than the generalized one (see the speed test below).
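One plausible reason for the speed-up (we have not inspected `cartesian_interpolator_2D` itself): with constant spacing, the cell containing a target point follows from a single floor division, whereas a general rectilinear grid needs a search such as `jnp.searchsorted`. A hypothetical 1D sketch of the uniform-spacing shortcut, checked against `jnp.interp`:

```python
import jax.numpy as jnp

def interp_uniform_1d(x, x0, dx, fp):
    """Linear interpolation of fp, sampled at x0 + dx * k (k = 0, ..., len(fp) - 1), onto x.

    Hypothetical helper: the left neighbor's index is computed directly, with no search.
    Outside the sampled range it extrapolates linearly (jnp.interp clamps instead).
    """
    n = fp.shape[0]
    s = (x - x0) / dx                                       # fractional index
    i = jnp.clip(jnp.floor(s).astype(jnp.int32), 0, n - 2)  # left neighbor
    w = s - i                                               # weight of right neighbor
    return (1.0 - w) * fp[i] + w * fp[i + 1]

# u-point coordinates (source) and h-point coordinates (target) from this demo
x0, dx = -2500.0, 5000.0
xp = x0 + dx * jnp.arange(129)
x = dx * jnp.arange(128)
fp = jnp.sin(0.1 * jnp.arange(129.0))

fast = interp_uniform_1d(x, x0, dx, fp)
reference = jnp.interp(x, xp, fp)  # general routine: searches for each interval
print(bool(jnp.allclose(fast, reference, atol=1e-5)))  # True
```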

In [12]:
from finitevolx.functional import cartesian_interpolator_2D


# define interpolation params
target_domain = h_domain
source_domain = u_domain


# using cartesian grid
u_on_h_cart = cartesian_interpolator_2D(
    u=u, source_domain=source_domain, target_domain=target_domain,
)


In [13]:
np.testing.assert_array_almost_equal(u_on_h_scratch, u_on_h_cart)

### Speed Test

#### Scratch

In [14]:

kernel_size = (2,1)
stride = (1,1)
padding = "valid"
mean_fn = "arithmetic"  # alternatives: "geometric", "harmonic", "quadratic"

fn_scratch = ft.partial(
    avg_pool,
    kernel_size=kernel_size,
    stride=stride, 
    padding=padding, 
    mean_fn=mean_fn
)
fn_scratch_jitted = jax.jit(fn_scratch)

In [15]:
%timeit fn_scratch(u).block_until_ready()  # run
%timeit fn_scratch_jitted(u).block_until_ready()  # run

8.14 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
99.3 µs ± 4.08 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### Generalized

In [16]:

fn_gen = ft.partial(
    domain_interpolation_2D,
    source_domain=u_domain,
    target_domain=h_domain,
    method="linear"
)

fn_gen_jitted = jax.jit(fn_gen)

In [17]:
%timeit fn_gen(u).block_until_ready()  # run
%timeit fn_gen_jitted(u).block_until_ready()  # run

2.06 ms ± 32.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
359 µs ± 18.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Cartesian Grid

In [18]:
fn_cart = ft.partial(
    cartesian_interpolator_2D,
    source_domain=u_domain,
    target_domain=h_domain
)
fn_cart_jitted = jax.jit(fn_cart)

In [19]:
%timeit fn_cart(u).block_until_ready()  # run
%timeit fn_cart_jitted(u).block_until_ready()  # run

1.2 ms ± 55 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
83.6 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
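JAX dispatches work asynchronously and compiles a jitted function on its first call, so fair timings must wait for the results (hence `.block_until_ready()` above) and keep compilation out of the measurement. A minimal standalone helper that makes the warm-up explicit might look like the following; `bench` is a hypothetical name and `x_avg` is only a stand-in workload.

```python
import time

import jax
import jax.numpy as jnp

def bench(fn, *args, n_warmup=1, n_runs=100):
    """Mean wall-clock seconds per call, excluding the warm-up (compiling) calls."""
    for _ in range(n_warmup):
        jax.block_until_ready(fn(*args))  # first call of a jitted fn triggers compilation
    start = time.perf_counter()
    for _ in range(n_runs):
        jax.block_until_ready(fn(*args))  # wait for asynchronous dispatch to finish
    return (time.perf_counter() - start) / n_runs

@jax.jit
def x_avg(u):
    # stand-in workload: two-point average along x
    return 0.5 * (u[1:, :] + u[:-1, :])

u = jnp.ones((129, 128))
print(f"{bench(x_avg, u) * 1e6:.1f} µs per call")
```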
