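# Boundary Conditions

This notebook shows how the grid-averaging operators handle padding, and how ghost cells implement periodic, Neumann and Dirichlet boundary conditions in 1D and 2D.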

In [41]:
import autoroot
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from jaxtyping import Array
import einops
import finitediffx as fdx
from jaxsw._src.operators.functional import grid as F_grid
from jaxsw._src.boundaries import functional as F_bc


# Plot styling: seaborn defaults at a compact "talk" scale.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit JAX dtypes; JAX defaults to 32-bit otherwise, and the arrays
# below would then be int32/float32 instead of int64/float64.
jax.config.update("jax_enable_x64", True)

# Notebook magics: inline figures, and automatically reload edited modules
# (e.g. the jaxsw package) before executing code.
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## One-Dimensional

In [42]:
# Sample 1D field: the integers 1, 2, ..., 10.
u = jnp.arange(1, 10 + 1)
u

Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int64)

In [43]:
# Average adjacent points along x, e.g. (1 + 2) / 2 = 1.5. With the default
# padding the result has one fewer element than u (9 vs 10, see output).
F_grid.x_average_1D(u)

Array([1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5], dtype=float64)

### Padding

**Same Padding**: the output keeps the same length as the input (here the extra value is appended on the right-hand side).

In [44]:
# padding="same": the output keeps u's length (10). The final value 5. equals
# (10 + 0) / 2, which is consistent with a zero pad on the right only.
F_grid.x_average_1D(u, padding="same")

Array([1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ], dtype=float64)

**Custom Padding**: one cell on each side

In [5]:
# Explicit ((left, right),) pad widths for the single axis: one cell on each
# side gives 11 outputs; the end values 0.5 and 5. are consistent with
# zero-valued padding.
F_grid.x_average_1D(u, padding=((1, 1),))

Array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ], dtype=float64)

**Custom Padding**: left-hand side only

In [6]:
# Pad the left side only: 10 outputs, and the leading 0.5 = (0 + 1) / 2 is
# consistent with a zero-valued pad.
F_grid.x_average_1D(u, padding=((1, 0),))

Array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5], dtype=float64)

**Custom Padding Modes**: beyond choosing pad widths

Inevitably, there are many ways we may want to fill the padded cells. Some examples include:

* Constant values (a fixed boundary value)
* Symmetric boundaries
* Wrap for periodic conditions

In [7]:
# Constant padding: one ghost cell on each side, filled with a fixed value.
# `constant_values` only applies to mode="constant"; when trying the other
# modes ("linear_ramp", "reflect", "wrap", "symmetric"), call jnp.pad
# without it.
mode = "constant"
constant_values = (100, 100)
u_pad = jnp.pad(
    u,
    pad_width=(1, 1),
    mode=mode,
    constant_values=constant_values,
)
u_pad

Array([100,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10, 100], dtype=int64)

#### Periodic Boundary Conditions

Ghost cells take their values from the opposite end of the domain (`mode="wrap"`).

In [18]:
# Periodic padding: "wrap" fills each ghost cell from the opposite end of u.
mode = "wrap"
u_periodic = jnp.pad(u, pad_width=(1, 1), mode=mode)
u_periodic

Array([10,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  1], dtype=int64)

In [19]:
# Library helper for periodic ghost cells. Assert that it matches
# jnp.pad(mode="wrap") so Restart & Run All checks the equivalence.
u_periodic = F_bc.apply_periodic_pad_1D(u)
assert jnp.array_equal(u_periodic, jnp.pad(u, pad_width=(1, 1), mode="wrap"))
u_periodic

Array([10,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  1], dtype=int64)

In [25]:
# Differences of the periodic (ghost-padded) array. Interior entries are 1;
# the entries at and next to the ghost cells pick up the 10 -> 1 jump where
# the ramp wraps around.
fdx.difference(u_periodic)

Array([-9., -4.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -4., -9.],      dtype=float64)

#### Neumann Boundaries

Ghost cells copy the adjacent edge value, so the gradient across each boundary is zero.

In [26]:
# Neumann (zero-gradient) ghost cells repeat the adjacent edge value, which is
# exactly jnp.pad's "edge" mode -- no manual .at[].set() fix-up needed.
u_neumann = jnp.pad(u, pad_width=(1, 1), mode="edge")
u_neumann

Array([ 1,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 10], dtype=int64)

In [28]:
# Library helper for Neumann (zero-gradient) ghost cells. Assert that it
# matches repeating the edge values, i.e. jnp.pad(mode="edge").
u_neumann = F_bc.apply_neumann_pad_1D(u)
assert jnp.array_equal(u_neumann, jnp.pad(u, pad_width=(1, 1), mode="edge"))
u_neumann

Array([ 1,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 10], dtype=int64)

In [29]:
# Differences of the Neumann-padded array: the first and last entries are 0
# because each ghost cell equals its neighbour (zero gradient at the boundary).
fdx.difference(u_neumann)

Array([0. , 0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.5, 0. ],      dtype=float64)

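#### Dirichlet Boundaries

Ghost cells are set to minus the adjacent interior value, so the average of each ghost/interior pair (the value at the boundary face) is zero.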

In [38]:
# pad the array with empty values
u_dirichlet = jnp.pad(u, pad_width=((1, 1)), mode="empty")

# modify values manually
u_dirichlet = u_dirichlet.at[0].set(-u_dirichlet[1])
u_dirichlet = u_dirichlet.at[-1].set(-u_dirichlet[-2])

u_dirichlet

Array([ -1,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10, -10], dtype=int64)

In [31]:
# Library helper for Dirichlet ghost cells. Assert that it matches the manual
# odd reflection (ghost = -adjacent interior value).
u_dirichlet = F_bc.apply_dirichlet_pad_1D(u)
assert jnp.array_equal(u_dirichlet, jnp.concatenate([-u[:1], u, -u[-1:]]))
u_dirichlet

Array([ -1,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10, -10], dtype=int64)

In [36]:
# Use the same finite-difference operator as the periodic and Neumann cells so
# the three boundary conditions are compared like for like.
fdx.difference(u_dirichlet)

Array([  2. ,   1.5,   1. ,   1. ,   1. ,   1. ,   1. ,   1. ,   1. ,
         1. ,  -9.5, -20. ], dtype=float64)

## Two-Dimensional

In [40]:
# 2D field of shape (Nx, Ny) = (10, 15): every column is the 1D ramp 1..10.
# It is displayed transposed so each printed row is one copy of the ramp.
u = einops.repeat(jnp.arange(1, 11), "Nx -> Nx Ny", Ny=15)
u.T

Array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]], dtype=int64)

In [64]:
# Average adjacent points along x for the 2D field (default padding): each
# row shrinks from 10 to 9 values. The transpose is only for display.
F_grid.x_average_2D(u).T

Array([[1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]], dtype=float64)

In [65]:
# padding=(0, 0): no padding along either axis, which gives the same result as
# the default-padding cell above.
F_grid.x_average_2D(u, padding=(0, 0)).T

Array([[1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]], dtype=float64)

In [66]:
# padding=(1, 0): judging by the output, this pads 1 cell on *both* sides along
# x and none along y, so each row grows from 10 to 11 values (0.5 ... 5.).
F_grid.x_average_2D(u, padding=(1, 0)).T

Array([[0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 5. ]],      dtype=float64)

In [70]:
# Per-axis ((left, right), (left, right)) pad widths: 1 cell on the left of x
# only and none along y, giving 10 values per row starting at 0.5.
F_grid.x_average_2D(u, padding=((1, 0), (0, 0))).T

Array([[0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5],
       [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]], dtype=float64)