# Calculating Fluxes

In [7]:
import autoroot
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import equinox as eqx
import finitediffx as fdx
from jaxtyping import Array, Float
# from jaxsw._src.domain.base import Domain
from jaxsw._src.domain.base_v2 import Domain, init_domain_1d
import math
import torch
import torch.nn.functional as F
import einops
import seaborn as sns

jax.config.update("jax_enable_x64", True)
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Let's look at the typical material derivative

$$
\frac{Dq}{Dt} := \partial_t q + \vec{u}\cdot\nabla q
$$



where the velocity vector is given by

$$
\vec{u} = 
\left[ u, v \right]^\top
$$

In particular, we are interested in the advection term which is defined as:

$$
\begin{aligned}
\text{Advection}: &= \vec{u} \cdot \nabla q = u\partial_x q + v\partial_y q \\
&= \nabla \cdot (\vec{u}q) = \partial_x (uq) + \partial_y (vq) 
\end{aligned}
$$

where the second equation is the conservative form.

The problem comes into play when we are on a staggered grid. The velocities are on one grid and the tracer is on a difference grid. The velocity

$$
\begin{aligned}
\text{Variable}: \psi &\in\Omega_\psi \\
\text{Zonal Velocity}: u &\in\Omega_u \\
\text{Meridional Velocity}: v &\in\Omega_v \\
\text{Tracer}: q &\in\Omega_q \\
\end{aligned}
$$



$$
\begin{aligned}
\Omega_\psi &\in[N_x,N_y] && && [x_0,y_0] && && [x_1,y_1] \\ 
\Omega_u &\in[N_x,N_y-1] && && [x_0,y_0+\frac{1}{2}dy] && && [x_1,y_1-\frac{1}{2}dy]  \\ 
\Omega_v &\in[N_x-1,N_y] && && [x_0+\frac{1}{2}dx,y_0] && && [x_1-\frac{1}{2}dx,y_1]  \\ 
\Omega_q &\in[N_x-1,N_y-1] && && [x_0+\frac{1}{2}dx,y_0+\frac{1}{2}dy] && && [x_1-\frac{1}{2}dx,y_1-\frac{1}{2}dy] 
\end{aligned}
$$







So if $q$ is on a staggered domain.

In [16]:

# x = torch.arange(1, 36+1)
num = 16
num_sqrt = math.sqrt(num)
assert num_sqrt.is_integer(), num_sqrt

# x = torch.randn(size=(num,))
u = torch.randn(size=(num,))
x = torch.arange(0, num)
print(x.shape, num_sqrt)
x = einops.rearrange(x, "(Nx Ny) -> Nx Ny", Nx=int(num_sqrt), Ny=int(num_sqrt))
u = einops.rearrange(u, "(Nx Ny) -> Nx Ny", Nx=int(num_sqrt), Ny=int(num_sqrt))

dim = 0
u = u[1:]
N = x.shape[dim]
assert u.shape[dim] == N - 1

torch.Size([16]) 4.0


In [50]:
x

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

In [51]:
from jaxsw._src.domain.mask import Mask

In [57]:
mask = jnp.ones(x.shape)
mask = mask.at[0].set(0.0)
mask = mask.at[-1].set(0.0)
mask = mask.at[:, 0].set(0.0)
mask = mask.at[:, -1].set(0.0)

masks = Mask.init_mask(mask, variable="q")

In [58]:
def linear2(qm, qp):
    """
    2-points linear reconstruction:

    qm--x--qp

    """
    return 0.5 * (qm + qp)
    
def linear3_left(qm, q0, qp):
    """
    3-points linear left-biased stencil reconstruction:

    qm-----q0--x--qp

    """
    return -1./6.*qm + 5./6.*q0 + 1./3.*qp

def linear3_right(qm, q0, qp):
    """
    3-points linear left-biased stencil reconstruction:

    qp--x--q0-----qm

    """
    return linear3_left(qp, q0, qm)

def linear4(qmm, qm, qp, qpp):
    """
    4-points linear reconstruction:

    qmm-----qm--x--qp-----qpp

    """
    return -1./12.*qmm + 7./12.*qm + 7./12.*qp - 1./12.*qpp

def weno3z(qm, q0, qp):
    """
    3-points non-linear left-biased stencil reconstruction:

    qm-----q0--x--qp

    An improved weighted essentially non-oscillatory scheme for hyperbolic
    conservation laws, Borges et al, Journal of Computational Physics 227 (2008).
    """
    eps = 1e-14

    qi1 = -1./2.*qm + 3./2.*q0
    qi2 = 1./2.*(q0 + qp)

    beta1 = (q0-qm)**2
    beta2 = (qp-q0)**2
    tau = torch.abs(beta2-beta1)

    g1, g2 = 1./3., 2./3.
    w1 = g1 * (1. + tau / (beta1 + eps))
    w2 = g2 * (1. + tau / (beta2 + eps))

    qi_weno3 = (w1*qi1 + w2*qi2) / (w1 + w2)

    return qi_weno3


def flux_3pts(q, u, dim):
    """
    Flux computation for staggerded variables q and u, with solid boundaries.
    Upwind-biased stencil:
      - 3 points inside domain.
      - 1 point near boundaries.

    Args:
        q: tracer field to interpolate, torch.Tensor, shape[dim] = n
        u: transport velocity, torch.Tensor, shape[dim] = n-1
        dim: dimension along which computations are done

    Returns:
        flux: tracer flux computed on u points, torch.Tensor, shape[dim] = n-1
    """
    n = q.shape[dim]

    # q-interpolation: 3-points inside domain

    qm, q0, qp = q.narrow(dim, 0, n-2), q.narrow(dim, 1, n-2), q.narrow(dim, 2, n-2)

    qi_left_in = linear3_left(qm, q0, qp)
    qi_right_in = linear3_left(qp, q0, qm)

    # q-interpolation: 2-points on boundaries
    qi_0 = linear2(q.narrow(dim, 0, 1), q.narrow(dim, 1, 1))
    qi_m1 = linear2(q.narrow(dim, -2, 1), q.narrow(dim, -1, 1))

    qi_left = torch.cat([
        qi_0, qi_left_in.narrow(dim, 0, n-3), qi_m1], dim=dim)
    qi_right = torch.cat([
        qi_0, qi_right_in.narrow(dim, 1, n-3), qi_m1], dim=dim)

    # positive and negative parts of velocity
    u_pos = F.relu(u)
    u_neg = u - u_pos

    # upwind flux computation
    flux = u_pos * qi_left + u_neg * qi_right
    return qi_left, qi_right, flux


[0, 1, 2, 3, 4, 5, 6]
Interior
- [0, 1, 2, 3, 4]
- [1, 2, 3, 4, 5]
- [2, 3, 4, 5, 6]
--> [1 2 3 4 5]
Left Bd
- [0 1]
- [1 2]
--> [0.5 1.5]
right bd
- [5 6]
- [4 5]
--> [4.5, 5.5]



In [164]:


def flux_3pts_mask(q, u, dim, mask_u_d1, mask_u_d2plus):
    n = q.shape[dim]
    pad1 = () if dim == -1 else (0,0)
    pad2 = (0,0) if dim == -1 else ()
    print(pad1, pad2)
    
    qm, q0, qp = q.narrow(dim, 0, n-2), q.narrow(dim, 1, n-2), \
                 q.narrow(dim, 2, n-2)
    
    print("LEFT!")
    qi3_left = weno3z(qm, q0, qp)
    print(qi3_left.shape)
    print(pad1+(1,0)+pad2)
    qi3_left = F.pad(qi3_left, pad1+(1,0)+pad2)
    print(qi3_left)
    print(qi3_left.shape)
    # ===============
    print("RIGHT!")
    qi3_right = weno3z(qp, q0, qm)
    print(qi3_right.shape)
    print(pad1+(0,1)+pad2)
    qi3_right = F.pad(qi3_right, pad1+(0,1)+pad2)
    print(qi3_right.shape)
    # qi2 = linear2(q.narrow(dim, 0, n-1), q.narrow(dim, 1, n-1))
    
    

    u_pos = F.relu(u)
    u_neg = u - u_pos
    # NORMAL FLUX
    a = (u_pos*q.narrow(dim, 0, n-1) + u_neg*q.narrow(dim, 1, n-1))
    
    print(a)
          
    # FANCY FLUX
    b = (u_pos*qi3_left + u_neg*qi3_right)
    
    print(b)
    
    print("FLUXXX!")
    print(a.shape)
    print(b.shape)
    flux = mask_u_d1 * a + mask_u_d2plus * b

    return qi3_left, qi3_right, flux

In [165]:
x.shape, u.shape, mask.shape, masks.u_distbound1[1:-1].shape

(torch.Size([4, 4]), torch.Size([3, 4]), (4, 4), (3, 4))

In [172]:
# qi_left_torch, qi_right_torch, flux = flux_3pts(x, u, 0)
qi_left_torch, qi_right_torch, flux = flux_3pts_mask(
    x[None, None, ...], u[None, None, ...], -2, 
    torch.Tensor(np.asarray(masks.u_distbound1[1:-1][None, None, ...])),
    torch.Tensor(np.asarray(masks.u_distbound2plus[1:-1][None, None, ...]))
)

(0, 0) ()
LEFT!
torch.Size([1, 1, 2, 4])
(0, 0, 1, 0)
tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 6.0000,  7.0000,  8.0000,  9.0000],
          [10.0000, 11.0000, 12.0000, 13.0000]]]])
torch.Size([1, 1, 3, 4])
RIGHT!
torch.Size([1, 1, 2, 4])
(0, 0, 0, 1)
torch.Size([1, 1, 3, 4])
tensor([[[[ -6.2200,  -0.8780,  -0.5934,   2.2774],
          [ -0.2971,   2.0050,   6.8547,  -3.2690],
          [-18.1282,  11.4857, -23.2584,  -5.7873]]]])
tensor([[[[-3.1100, -0.5268, -0.3956,  0.0000],
          [-0.2228,  2.8070,  9.1395, -2.6746],
          [ 0.0000, 14.0381,  0.0000,  0.0000]]]])
FLUXXX!
torch.Size([1, 1, 3, 4])
torch.Size([1, 1, 3, 4])


In [173]:
flux

tensor([[[[-0.0000, -0.0000, -0.0000, 0.0000],
          [-0.0000, 2.0050, 6.8547, -0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]]]])

In [122]:
x.shape, u.T.shape, mask.shape, masks.v_distbound1[:, 1:-1].shape

(torch.Size([4, 4]), torch.Size([4, 3]), (4, 4), (4, 3))

In [174]:
# qi_left_torch, qi_right_torch, flux = flux_3pts(x, u, 0)
qi_left_torch, qi_right_torch, flux = flux_3pts_mask(
    x[None, None, ...], u.T[None, None, ...], -1, 
    torch.Tensor(np.asarray(masks.v_distbound1[None, None, :, 1:-1])),
    torch.Tensor(np.asarray(masks.v_distbound2plus[None, None, :, 1:-1]))
)

() (0, 0)
LEFT!
torch.Size([1, 1, 4, 2])
(1, 0, 0, 0)
tensor([[[[ 0.0000,  1.5000,  2.5000],
          [ 0.0000,  5.5000,  6.5000],
          [ 0.0000,  9.5000, 10.5000],
          [ 0.0000, 13.5000, 14.5000]]]])
torch.Size([1, 1, 4, 3])
RIGHT!
torch.Size([1, 1, 4, 2])
(0, 1, 0, 0)
torch.Size([1, 1, 4, 3])
tensor([[[[ -1.5550,  -0.0743,  -4.5321],
          [ -0.8780,   2.0050,   7.6572],
          [ -0.8902,  10.2820, -18.2745],
          [  9.1097,  -4.1605,  -5.7873]]]])
tensor([[[[-0.7775, -0.0557,  0.0000],
          [-0.7902,  2.2055,  8.2953],
          [-0.8407, 10.8532,  0.0000],
          [ 0.0000, -4.0119,  0.0000]]]])
FLUXXX!
torch.Size([1, 1, 4, 3])
torch.Size([1, 1, 4, 3])


In [175]:
flux

tensor([[[[-0.0000, -0.0000,  0.0000],
          [-0.0000,  2.0050,  0.0000],
          [-0.0000, 10.2820,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]])

In [21]:
import jax.numpy as jnp
import kernex as kex
import jax
import numpy as np

In [141]:
from jaxsw._src.operators.functional.interp import weno as interp_weno
from jaxsw._src.operators.functional.interp import linear as interp_linear
from jaxsw._src.operators.functional.interp import flux as F_flux



In [45]:
x[0]

tensor([0, 1, 2, 3])

In [48]:
qi_left, qi_right = interp_3pt(x[0].numpy(), dim=0)

[0 1] [1 2] [2 3]
[1.5 2.5] [0.5 1.5]
[0.5] [2.5]


In [44]:
qi_left, qi_right

(Array([0.5, 1.5, 2.5], dtype=float64, weak_type=True),
 Array([0.5, 1.5, 2.5], dtype=float64, weak_type=True))

In [167]:
import typing as tp
import functools as ft

def interp_3pt_mask(q: Array, dim: int) -> tp.Union[Array]:
    """creates the stencils for the upwind scheme
    - 3 pts inside domain
    - 1 pt near boundaries
    Args:
        q (Array): 
            Size = [Nx,Ny]
        dim (int): ONLY 0 or 1! 
    """

    # get number of points
    num_pts = q.shape[dim]
    
    # get padding
    if dim == 0:
        pad_left = ((1,0),(0,0))
        pad_right = ((0,1),(0,0))
    elif dim == 1:
        pad_left = ((0,0),(1,0))
        pad_right = ((0,0),(0,1))
    else:
        msg = f"Dims should be between 0 and 1!"
        msg += f"\nDims: {dim}"
        raise ValueError(msg)

    # define slicers
    dyn_slicer = ft.partial(jax.lax.dynamic_slice_in_dim, axis=dim)

    # interior slices
    q0 = dyn_slicer(q, 0, num_pts-2)
    q1 = dyn_slicer(q, 1, num_pts-2)
    q2 = dyn_slicer(q, 2, num_pts-2)
    
    # DO WENO Interpolation
    qi_left_interior = interp_weno.weno_3pts_improved(q0, q1, q2)
    qi_right_interior = interp_weno.weno_3pts_improved(q2, q1, q0)
    
    
    qi_left_interior = jnp.pad(qi_left_interior, pad_width=pad_left)
    qi_right_interior = jnp.pad(qi_right_interior,pad_width=pad_right)
    
    return qi_left_interior, qi_right_interior


def tracer_flux_3pt_mask(
    q: Array, u: Array, dim: int, 
    u_mask1: Array, 
    u_mask2: Array
    
):
    
    # 1 point flux
    qi_left_i_1pt, qi_right_i_1pt = F_flux.interp_1pt(q=q, dim=dim)
    
    # 3 point flux
    qi_left_i_3pt, qi_right_i_3pt = interp_3pt_mask(q=q, dim=dim)
    
    # calculate +ve and -ve points
    u_pos, u_neg = F_flux.plusminus(u)
    
    # calculate upwind flux
    flux_1pt = (u_pos * qi_left_i_1pt + u_neg * qi_right_i_1pt) 
    flux_3pt = (u_pos * qi_left_i_3pt + u_neg * qi_right_i_3pt) 
    
    return flux_1pt * u_mask1 + flux_3pt * u_mask2

* interp_3pt_weno
* interp_1pt
* u plusmins
* flux + interp_1pt + mask
* flux + interp3pt
* sum fluxes

In [171]:
x_ = x.numpy()

q_flux_on_u = tracer_flux_3pt_mask(
    x_, 
    u=u.numpy(),
    dim=0, 
    u_mask1=masks.u_distbound1[1:-1],
    u_mask2=masks.u_distbound2plus[1:-1]
)
q_flux_on_u

Array([[-0.        , -0.        , -0.        ,  0.        ],
       [-0.        ,  2.00500751,  6.85465431, -0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],      dtype=float64)

In [176]:
q_flux_on_v = tracer_flux_3pt_mask(
    x_, 
    u=u.numpy().T,
    dim=1, 
    u_mask1=masks.v_distbound1[:, 1:-1],
    u_mask2=masks.v_distbound2plus[:, 1:-1]
)
q_flux_on_v

Array([[-0.        , -0.        ,  0.        ],
       [-0.        ,  2.00500751,  0.        ],
       [-0.        , 10.28198147,  0.        ],
       [ 0.        , -0.        ,  0.        ]], dtype=float64)