# Calculating Fluxes

In [1]:
import autoroot
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import equinox as eqx
import finitediffx as fdx
from jaxtyping import Array, Float
# from jaxsw._src.domain.base import Domain
from jaxsw._src.domain.base_v2 import Domain, init_domain_1d
import math
import torch
import torch.nn.functional as F
import einops
import seaborn as sns

jax.config.update("jax_enable_x64", True)
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%load_ext autoreload
%autoreload 2

In [2]:
from jaxsw._src.operators.functional.interp.flux import tracer_flux

Let's look at the typical material derivative

$$
\frac{Dq}{Dt} := \partial_t q + \vec{u}\cdot\nabla q
$$



where the velocity vector is given by

$$
\vec{u} = 
\left[ u, v \right]^\top
$$

In particular, we are interested in the advection term which is defined as:

$$
\begin{aligned}
\text{Advection}: &= \vec{u} \cdot \nabla q = u\partial_x q + v\partial_y q \\
&= \nabla \cdot (\vec{u}q) = \partial_x (uq) + \partial_y (vq) 
\end{aligned}
$$

where the second equation is the conservative form.

The problem comes into play when we are on a staggered grid. The velocities are on one grid and the tracer is on a difference grid. The velocity

$$
\begin{aligned}
\text{Variable}: \psi &\in\Omega_\psi \\
\text{Zonal Velocity}: u &\in\Omega_u \\
\text{Meridional Velocity}: v &\in\Omega_v \\
\text{Tracer}: q &\in\Omega_q \\
\end{aligned}
$$



$$
\begin{aligned}
\Omega_\psi &\in[N_x,N_y] && && [x_0,y_0] && && [x_1,y_1] \\ 
\Omega_u &\in[N_x,N_y-1] && && [x_0,y_0+\frac{1}{2}dy] && && [x_1,y_1-\frac{1}{2}dy]  \\ 
\Omega_v &\in[N_x-1,N_y] && && [x_0+\frac{1}{2}dx,y_0] && && [x_1-\frac{1}{2}dx,y_1]  \\ 
\Omega_q &\in[N_x-1,N_y-1] && && [x_0+\frac{1}{2}dx,y_0+\frac{1}{2}dy] && && [x_1-\frac{1}{2}dx,y_1-\frac{1}{2}dy] 
\end{aligned}
$$







So if $q$ is on a staggered domain.

In [4]:

# x = torch.arange(1, 36+1)
num = 25
num_sqrt = math.sqrt(num)
assert num_sqrt.is_integer(), num_sqrt

# x = torch.randn(size=(num,))
u = torch.randn(size=(num,))
x = torch.arange(0, num)
print(x.shape, num_sqrt)
x = einops.rearrange(x, "(Nx Ny) -> Nx Ny", Nx=int(num_sqrt), Ny=int(num_sqrt))
u = einops.rearrange(u, "(Nx Ny) -> Nx Ny", Nx=int(num_sqrt), Ny=int(num_sqrt))

dim = 0
u = u[1:]
N = x.shape[dim]
assert u.shape[dim] == N - 1

torch.Size([25]) 5.0


In [5]:
x, u

(tensor([[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19],
         [20, 21, 22, 23, 24]]),
 tensor([[ 0.6527,  1.8054, -0.0150,  0.1841, -1.6945],
         [-0.9101, -0.0365,  0.4091, -1.2957,  1.5384],
         [-1.8070, -1.5308,  0.5493,  0.7109, -0.5754],
         [-1.0096, -0.8631,  0.9273, -0.4084,  0.9295]]))

In [8]:
def linear2(qm, qp):
    """
    2-points linear reconstruction:

    qm--x--qp

    """
    return 0.5 * (qm + qp)
    
def linear3_left(qm, q0, qp):
    """
    3-points linear left-biased stencil reconstruction:

    qm-----q0--x--qp

    """
    return -1./6.*qm + 5./6.*q0 + 1./3.*qp

def linear3_right(qm, q0, qp):
    """
    3-points linear left-biased stencil reconstruction:

    qp--x--q0-----qm

    """
    return linear3_left(qp, q0, qm)

def linear4(qmm, qm, qp, qpp):
    """
    4-points linear reconstruction:

    qmm-----qm--x--qp-----qpp

    """
    return -1./12.*qmm + 7./12.*qm + 7./12.*qp - 1./12.*qpp

[0, 1, 2, 3, 4, 5, 6]
Interior
- [0, 1, 2, 3, 4]
- [1, 2, 3, 4, 5]
- [2, 3, 4, 5, 6]
--> [1 2 3 4 5]
Left Bd
- [0 1]
- [1 2]
--> [0.5 1.5]
right bd
- [5 6]
- [4 5]
--> [4.5, 5.5]



In [9]:
def flux_3pts(x, u, dim: int=0):

    N = x.shape[dim]

    print(x.shape)
    print(x)
    print("-"*40)
    qm = x.narrow(dim, 0, N-2)
    q0 = x.narrow(dim, 1, N-2)
    qp = x.narrow(dim, 2, N-2)

    qi_left_interior = linear3_left(qm, q0, qp)
    qi_right_interior = linear3_left(qp, q0, qm)

    print(qi_left_interior)
    print(qi_right_interior)

    # left boundaries
    qm = x.narrow(dim, 0, 1)
    qp = x.narrow(dim, 1, 1)
    qi_left_bd = linear2(qm, qp)
    print(qi_left_bd)

    # right boundaries
    qm = x.narrow(dim, -2, 1)
    qp = x.narrow(dim, -1, 1)
    qi_right_bd = linear2(qm, qp)

    print(qi_right_bd)

    print("-"*40)
    print(qi_left_interior.narrow(dim,0,N-3).shape)
    print(qi_left_interior.narrow(dim,0,N-3))
    print(qi_left_bd.shape, qi_right_bd.shape)
    qi_left = torch.cat([
        qi_left_bd, qi_left_interior.narrow(dim, 0, N-3), qi_right_bd], dim=dim)

    qi_right = torch.cat([
        qi_left_bd, qi_right_interior.narrow(dim, 1, N-3), qi_right_bd], dim=dim)
    
    # positive and negative parts of velocity
    u_pos = F.relu(u)
    u_neg = u - u_pos

    # upwind flux computation
    flux = u_pos * qi_left + u_neg * qi_right
    return qi_left, qi_right, flux

In [10]:
qi_left_torch, qi_right_torch, flux = flux_3pts(x, u, 0)

torch.Size([5, 5])
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])
----------------------------------------
tensor([[ 7.5000,  8.5000,  9.5000, 10.5000, 11.5000],
        [12.5000, 13.5000, 14.5000, 15.5000, 16.5000],
        [17.5000, 18.5000, 19.5000, 20.5000, 21.5000]])
tensor([[ 2.5000,  3.5000,  4.5000,  5.5000,  6.5000],
        [ 7.5000,  8.5000,  9.5000, 10.5000, 11.5000],
        [12.5000, 13.5000, 14.5000, 15.5000, 16.5000]])
tensor([[2.5000, 3.5000, 4.5000, 5.5000, 6.5000]])
tensor([[17.5000, 18.5000, 19.5000, 20.5000, 21.5000]])
----------------------------------------
torch.Size([2, 5])
tensor([[ 7.5000,  8.5000,  9.5000, 10.5000, 11.5000],
        [12.5000, 13.5000, 14.5000, 15.5000, 16.5000]])
torch.Size([1, 5]) torch.Size([1, 5])


In [11]:
import jax.numpy as jnp
import kernex as kex
import jax
import numpy as np

In [12]:
x_ = x.numpy()
u_ = u.numpy()

In [13]:



@kex.kmap(kernel_size=(N-2,N),padding='valid')
def get_3x3_patches(x):
    # returns 5x5x3x3 array
    return x

In [14]:
@kex.kmap(kernel_size=(2,1),padding='valid')
def x_linear_interp_2pt(x):
    # returns 5x5x3x3 array
    return 0.5 * (x[0] + x[1])

@kex.kmap(kernel_size=(3,1),padding='valid')
def get_linear_3_stencil_forward(x):
    # returns 5x5x3x3 array
    return -1./6.*x[0] + 5./6.*x[1] + 1./3.*x[2]


@kex.kmap(kernel_size=(3,1),padding='valid')
def get_linear_3_stencil_backward(x):
    # returns 5x5x3x3 array
    return -1./6.*x[2] + 5./6.*x[1] + 1./3.*x[0]

In [15]:
import typing as tp
from jaxtyping import Array
import functools as ft
from jax.nn import relu

def plusminus(u: Array, way: int=1) -> tp.Union[Array]:
    u_pos = relu(float(way) * u)
    u_neg = u - u_pos
    return u_pos, u_neg

def interp_1pt(q: Array, dim: int) -> tp.Union[Array]:
    """creates the stencils for the upwind scheme
    - 1 pts inside domain & boundary
    """
    # get number of points
    num_pts = q.shape[dim]

    # define slicers
    dyn_slicer = ft.partial(jax.lax.dynamic_slice_in_dim, axis=dim)

    qi_left = dyn_slicer(q, 0, num_pts-1)
    qi_right = dyn_slicer(q, 1, num_pts-1)
    

    return qi_left, qi_right

def interp_3pt(q: Array, dim: int) -> tp.Union[Array]:
    """creates the stencils for the upwind scheme
    - 3 pts inside domain
    - 1 pt near boundaries
    """

    # get number of points
    num_pts = q.shape[dim]

    # define slicers
    dyn_slicer = ft.partial(jax.lax.dynamic_slice_in_dim, axis=dim)

    # interior slices
    q0 = dyn_slicer(q, 0, num_pts-2)
    q1 = dyn_slicer(q, 1, num_pts-2)
    q2 = dyn_slicer(q, 2, num_pts-2)

    qi_left_interior = linear3_left(q0, q1, q2)
    qi_right_interior = linear3_right(q0, q1, q2)

    # left boundary slices
    q0 = dyn_slicer(q, 0, 1)
    q1 = dyn_slicer(q, 1, 1)
    qi_left_bd = linear2(q0, q1)

    # right boundary slices
    q0 = dyn_slicer(q, -1, 1)
    q1 = dyn_slicer(q, -2, 1)
    qi_right_bd = linear2(q0, q1)

    # concatenate each
    qi_left = jnp.concatenate([
        qi_left_bd,
        dyn_slicer(qi_left_interior, 0, num_pts-3),
        qi_right_bd
    ])

    qi_right = jnp.concatenate([
        qi_left_bd,
        dyn_slicer(qi_right_interior, 1, num_pts-3),
        qi_right_bd
    ])

    return qi_left, qi_right


def tracer_flux(u: Array, a: Array, dim: int, num_pts: int=1) -> Array:
    """Flux computation for staggered variables q and u with
    solid boundaries. Typically used for calculating the flux
    Advection Scheme:
        ∇ ⋅ (uq)        
    
    Args:
        q (Array): tracer field to interpolate
            shape[dim] = N
        u (Array): transport velocity
            shape[dim] = N-1
        dim (int): dimension along which computations are done
        num_pts (int): the number of points for the flux computation
            options = (1, 3, 5)
    
    Returns:
        flux (Array): tracer flux computed on u points
            shape[dim] = N -1
            
    """
    
    # calculate flux
    if num_pts == 1:
        ui_left, ui_right = interp_1pt(u, dim=dim)
    elif num_pts == 3:
        ui_left, ui_right = interp_3pt(u, dim=dim)
    elif num_pts == 5:
        msg = "5pt method is not implemented yet"
        raise NotImplementedError(msg)
    else:
        msg = "Unrecognized method: {num_pts}"
        msg +="\nMust be 1, 3, or 5"
        raise ValueError(msg)
    
    # calculate +ve and -ve points
    a_pos, a_neg = plusminus(a)
    
    # calculate upwind flux
    flux = a_pos * ui_left + a_neg * ui_right
    
    return flux

In [17]:
qi_left_jax, qi_right_jax = interp_3pt(x_, 0)
flux_ = tracer_flux(x_, u_, 0)
# x_.shape, qi_left.shape, qi_right.shape

In [19]:
out = tracer_flux(x_, u_, 0, num_pts=1)
out_ = tracer_flux(x_, u_, 0, num_pts=3)
# tracer_flux(x_, u_, 0, num_pts=5)

In [20]:
out

Array([[  0.        ,   1.805351  ,  -0.10503372,   0.552448  ,
        -15.2502165 ],
       [ -9.10059   ,  -0.40109357,   2.8634002 , -16.844027  ,
         13.845664  ],
       [-27.104927  , -24.492468  ,   6.591754  ,   9.242194  ,
        -10.933088  ],
       [-20.192703  , -18.125383  ,  15.764215  ,  -9.392781  ,
         17.66041   ]], dtype=float32)

In [82]:
out_

Array([[ -1.6798853 ,  -2.191656  ,  -0.62947637,  -3.63754   ,
         -2.3527443 ],
       [ -9.007757  ,  -7.7067084 ,   4.294481  , -14.842717  ,
          7.8252954 ],
       [ -8.781178  , -16.217318  ,  -2.7002711 , -12.441315  ,
         27.955723  ],
       [-10.226277  ,  12.631437  ,  31.171684  , -23.959835  ,
         -9.450634  ]], dtype=float32)

In [72]:
np.testing.assert_array_almost_equal(qi_left_jax, qi_left_torch.numpy())
np.testing.assert_array_almost_equal(qi_right_jax, qi_right_torch.numpy())
np.testing.assert_array_almost_equal(flux_, flux.numpy())

In [169]:
def flux_3pt_jax(q, dim: int):
    
    qi_left_interior = get_linear_3_stencil_forward(q)[...,0]
    qi_right_interior = get_linear_3_stencil_backward(q)[...,0]

    print(qi_left_interior)
    print(qi_right_interior)
    
    front = jax.lax.slice_in_dim(x_, None, 2, axis=dim)
    back = jax.lax.slice_in_dim(x_, -2, None, axis=dim)
    qi_left_bd = x_linear_interp_2pt(front)
    qi_right_bd = x_linear_interp_2pt(back)

    print(qi_left_bd)
    print(qi_right_bd)

    qi_left = jnp.concatenate([
        qi_left_bd[...,0], 
        jax.lax.slice_in_dim(qi_left_interior, None, -1, axis=dim),
        qi_right_bd[...,0]
    ])

    qi_right = jnp.concatenate([
        qi_left_bd[...,0],
        jax.lax.slice_in_dim(qi_left_interior, 1, None, axis=dim),
        qi_right_bd[...,0]
    ])

    print(qi_left)
    print(qi_right)

    return None

In [170]:
flux_3pt_jax(x_, 0)

[[10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21.]
 [22. 23. 24. 25. 26. 27.]
 [28. 29. 30. 31. 32. 33.]]
[[ 3.9999995  4.9999995  6.         6.9999995  7.999999   9.       ]
 [10.        10.999999  12.        13.        13.999999  15.       ]
 [16.        17.        18.        18.999998  20.        21.       ]
 [21.999998  23.        24.        24.999998  26.        27.       ]]
[[[4.]
  [5.]
  [6.]
  [7.]
  [8.]
  [9.]]]
[[[28.]
  [29.]
  [30.]
  [31.]
  [32.]
  [33.]]]
[[ 4.  5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21.]
 [22. 23. 24. 25. 26. 27.]
 [28. 29. 30. 31. 32. 33.]]
[[ 4.  5.  6.  7.  8.  9.]
 [16. 17. 18. 19. 20. 21.]
 [22. 23. 24. 25. 26. 27.]
 [28. 29. 30. 31. 32. 33.]
 [28. 29. 30. 31. 32. 33.]]


In [30]:
get_linear_3_stencil(x_).shape

(4, 6, 1)

In [24]:
qm, q0, qp = get_3x3_patches(x_).squeeze()

In [25]:
linear3_left(qm, q0, qp)

Array([[10., 11., 12., 13., 14., 15.],
       [16., 17., 18., 19., 20., 21.],
       [22., 23., 24., 25., 26., 27.],
       [28., 29., 30., 31., 32., 33.]], dtype=float32, weak_type=True)

In [46]:
get_3x3_patches(x_)[0,0], get_3x3_patches(x_)[1,0], get_3x3_patches(x_)[2,0]

(Array([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15]], dtype=int32),
 Array([[ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20]], dtype=int32),
 Array([[11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25]], dtype=int32))

In [13]:
x.narrow(0, 0, 3-1).shape, x.narrow(0, 1, 3-1).shape, x.narrow(0, 2, 3-1).shape

(torch.Size([2, 5]), torch.Size([2, 5]), torch.Size([2, 5]))