# Calculating Fluxes

In [1]:
import autoroot
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import equinox as eqx
import finitediffx as fdx
from jaxtyping import Array, Float
# from jaxsw._src.domain.base import Domain
from jaxsw._src.domain.base_v2 import Domain, init_domain_1d
import math
import torch
import torch.nn.functional as F
import einops
import seaborn as sns

jax.config.update("jax_enable_x64", True)
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%load_ext autoreload
%autoreload 2

Let's look at the typical material derivative

$$
\frac{Dq}{Dt} := \partial_t q + \vec{u}\cdot\nabla q
$$



where the velocity vector is given by

$$
\vec{u} = 
\left[ u, v \right]^\top
$$

In particular, we are interested in the advection term which is defined as:

$$
\begin{aligned}
\text{Advection}: &= \vec{u} \cdot \nabla q = u\partial_x q + v\partial_y q \\
&= \nabla \cdot (\vec{u}q) = \partial_x (uq) + \partial_y (vq) 
\end{aligned}
$$

where the second equation is the conservative form.

The problem comes into play when we are on a staggered grid. The velocities are on one grid and the tracer is on a difference grid. The velocity

$$
\begin{aligned}
\text{Variable}: \psi &\in\Omega_\psi \\
\text{Zonal Velocity}: u &\in\Omega_u \\
\text{Meridional Velocity}: v &\in\Omega_v \\
\text{Tracer}: q &\in\Omega_q \\
\end{aligned}
$$



$$
\begin{aligned}
\Omega_\psi &\in[N_x,N_y] && && [x_0,y_0] && && [x_1,y_1] \\ 
\Omega_u &\in[N_x,N_y-1] && && [x_0,y_0+\frac{1}{2}dy] && && [x_1,y_1-\frac{1}{2}dy]  \\ 
\Omega_v &\in[N_x-1,N_y] && && [x_0+\frac{1}{2}dx,y_0] && && [x_1-\frac{1}{2}dx,y_1]  \\ 
\Omega_q &\in[N_x-1,N_y-1] && && [x_0+\frac{1}{2}dx,y_0+\frac{1}{2}dy] && && [x_1-\frac{1}{2}dx,y_1-\frac{1}{2}dy] 
\end{aligned}
$$







So if $q$ is on a staggered domain.

In [2]:

# x = torch.arange(1, 36+1)
num = 25
num_sqrt = math.sqrt(num)
assert num_sqrt.is_integer(), num_sqrt

# x = torch.randn(size=(num,))
u = torch.randn(size=(num,))
x = torch.arange(0, num)
print(x.shape, num_sqrt)
x = einops.rearrange(x, "(Nx Ny) -> Nx Ny", Nx=int(num_sqrt), Ny=int(num_sqrt))
u = einops.rearrange(u, "(Nx Ny) -> Nx Ny", Nx=int(num_sqrt), Ny=int(num_sqrt))

dim = 0
u = u[1:]
N = x.shape[dim]
assert u.shape[dim] == N - 1

torch.Size([25]) 5.0


In [3]:
x

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])

In [4]:
from jaxsw._src.masks import Mask

In [5]:
mask = jnp.ones(x.shape)
mask = mask.at[0].set(0.0)
mask = mask.at[-1].set(0.0)
mask = mask.at[:, 0].set(0.0)
mask = mask.at[:, -1].set(0.0)

masks = Mask.init_mask(mask, variable="q")

In [6]:
def linear2(qm, qp):
    """
    2-points linear reconstruction:

    qm--x--qp

    """
    return 0.5 * (qm + qp)
    
def linear3_left(qm, q0, qp):
    """
    3-points linear left-biased stencil reconstruction:

    qm-----q0--x--qp

    """
    return -1./6.*qm + 5./6.*q0 + 1./3.*qp

def linear3_right(qm, q0, qp):
    """
    3-points linear left-biased stencil reconstruction:

    qp--x--q0-----qm

    """
    return linear3_left(qp, q0, qm)

def linear4(qmm, qm, qp, qpp):
    """
    4-points linear reconstruction:

    qmm-----qm--x--qp-----qpp

    """
    return -1./12.*qmm + 7./12.*qm + 7./12.*qp - 1./12.*qpp

def weno3z(qm, q0, qp):
    """
    3-points non-linear left-biased stencil reconstruction:

    qm-----q0--x--qp

    An improved weighted essentially non-oscillatory scheme for hyperbolic
    conservation laws, Borges et al, Journal of Computational Physics 227 (2008).
    """
    eps = 1e-14

    qi1 = -1./2.*qm + 3./2.*q0
    qi2 = 1./2.*(q0 + qp)

    beta1 = (q0-qm)**2
    beta2 = (qp-q0)**2
    tau = torch.abs(beta2-beta1)

    g1, g2 = 1./3., 2./3.
    w1 = g1 * (1. + tau / (beta1 + eps))
    w2 = g2 * (1. + tau / (beta2 + eps))

    qi_weno3 = (w1*qi1 + w2*qi2) / (w1 + w2)

    return qi_weno3


def flux_3pts(q, u, dim):
    """
    Flux computation for staggerded variables q and u, with solid boundaries.
    Upwind-biased stencil:
      - 3 points inside domain.
      - 1 point near boundaries.

    Args:
        q: tracer field to interpolate, torch.Tensor, shape[dim] = n
        u: transport velocity, torch.Tensor, shape[dim] = n-1
        dim: dimension along which computations are done

    Returns:
        flux: tracer flux computed on u points, torch.Tensor, shape[dim] = n-1
    """
    n = q.shape[dim]

    # q-interpolation: 3-points inside domain

    qm, q0, qp = q.narrow(dim, 0, n-2), q.narrow(dim, 1, n-2), q.narrow(dim, 2, n-2)

    qi_left_in = linear3_left(qm, q0, qp)
    qi_right_in = linear3_left(qp, q0, qm)

    # q-interpolation: 2-points on boundaries
    qi_0 = linear2(q.narrow(dim, 0, 1), q.narrow(dim, 1, 1))
    qi_m1 = linear2(q.narrow(dim, -2, 1), q.narrow(dim, -1, 1))

    qi_left = torch.cat([
        qi_0, qi_left_in.narrow(dim, 0, n-3), qi_m1], dim=dim)
    qi_right = torch.cat([
        qi_0, qi_right_in.narrow(dim, 1, n-3), qi_m1], dim=dim)

    # positive and negative parts of velocity
    u_pos = F.relu(u)
    u_neg = u - u_pos

    # upwind flux computation
    flux = u_pos * qi_left + u_neg * qi_right
    return qi_left, qi_right, flux


[0, 1, 2, 3, 4, 5, 6]
Interior
- [0, 1, 2, 3, 4]
- [1, 2, 3, 4, 5]
- [2, 3, 4, 5, 6]
--> [1 2 3 4 5]
Left Bd
- [0 1]
- [1 2]
--> [0.5 1.5]
right bd
- [5 6]
- [4 5]
--> [4.5, 5.5]



In [7]:


def flux_3pts_mask(q, u, dim, mask_u_d1, mask_u_d2plus):
    n = q.shape[dim]
    pad1 = () if dim == -1 else (0,0)
    pad2 = (0,0) if dim == -1 else ()
    print(pad1, pad2)
    
    qm, q0, qp = q.narrow(dim, 0, n-2), q.narrow(dim, 1, n-2), \
                 q.narrow(dim, 2, n-2)
    
    print("LEFT!")
    qi3_left = weno3z(qm, q0, qp)
    print(qi3_left.shape)
    print(pad1+(1,0)+pad2)
    qi3_left = F.pad(qi3_left, pad1+(1,0)+pad2)
    print(qi3_left)
    print(qi3_left.shape)
    # ===============
    print("RIGHT!")
    qi3_right = weno3z(qp, q0, qm)
    print(qi3_right.shape)
    print(pad1+(0,1)+pad2)
    qi3_right = F.pad(qi3_right, pad1+(0,1)+pad2)
    print(qi3_right.shape)
    # qi2 = linear2(q.narrow(dim, 0, n-1), q.narrow(dim, 1, n-1))
    
    

    u_pos = F.relu(u)
    u_neg = u - u_pos
    # NORMAL FLUX
    a = (u_pos*q.narrow(dim, 0, n-1) + u_neg*q.narrow(dim, 1, n-1))
    
    print(a)
          
    # FANCY FLUX
    b = (u_pos*qi3_left + u_neg*qi3_right)
    
    print(b)
    
    print("FLUXXX!")
    print(a.shape)
    print(b.shape)
    flux = mask_u_d1 * a + mask_u_d2plus * b

    return qi3_left, qi3_right, flux

In [8]:
x.shape, u.shape, mask.shape, masks.u.distbound1[1:-1].shape

(torch.Size([5, 5]), torch.Size([4, 5]), (5, 5), (4, 5))

In [9]:
# qi_left_torch, qi_right_torch, flux = flux_3pts(x, u, 0)
qi_left_torch, qi_right_torch, flux = flux_3pts_mask(
    x[None, None, ...], u[None, None, ...], -2, 
    torch.Tensor(np.asarray(masks.u.distbound1[1:-1][None, None, ...])),
    torch.Tensor(np.asarray(masks.u.distbound2plus[1:-1][None, None, ...]))
)

(0, 0) ()
LEFT!
torch.Size([1, 1, 3, 5])
(0, 0, 1, 0)
tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 7.5000,  8.5000,  9.5000, 10.5000, 11.5000],
          [12.5000, 13.5000, 14.5000, 15.5000, 16.5000],
          [17.5000, 18.5000, 19.5000, 20.5000, 21.5000]]]])
torch.Size([1, 1, 4, 5])
RIGHT!
torch.Size([1, 1, 3, 5])
(0, 0, 0, 1)
torch.Size([1, 1, 4, 5])
tensor([[[[ -0.9822,  -6.4643,   0.3690,  -1.1223,  -0.5799],
          [ -1.2434,   1.2616, -22.5542, -14.2985, -26.7411],
          [ -4.0552, -14.7635,   6.7217,  17.5426,   8.9709],
          [  8.8488,  -1.2903,  -7.1756, -10.1598,  -5.8775]]]])
tensor([[[[ -0.4911,  -3.7708,   0.0000,  -0.7716,  -0.4188],
          [ -0.9326,   1.7872, -17.8554, -11.5488, -21.9659],
          [ -3.3793, -12.4567,   8.1220,  20.9161,  10.5728],
          [ 10.3236,   0.0000,   0.0000,   0.0000,   0.0000]]]])
FLUXXX!
torch.Size([1, 1, 4, 5])
torch.Size([1, 1, 4, 5])


  torch.Tensor(np.asarray(masks.u.distbound1[1:-1][None, None, ...])),


In [10]:
flux

tensor([[[[ -0.0000,  -0.0000,   0.0000,  -0.0000,  -0.0000],
          [ -0.0000,   1.2616, -22.5542, -14.2985,  -0.0000],
          [ -0.0000, -14.7635,   6.7217,  17.5426,   0.0000],
          [  0.0000,   0.0000,   0.0000,   0.0000,   0.0000]]]])

In [11]:
x.shape, u.T.shape, mask.shape, masks.v.distbound1[:, 1:-1].shape

(torch.Size([5, 5]), torch.Size([5, 4]), (5, 5), (5, 4))

In [12]:
# qi_left_torch, qi_right_torch, flux = flux_3pts(x, u, 0)
qi_left_torch, qi_right_torch, flux = flux_3pts_mask(
    x[None, None, ...], u.T[None, None, ...], -1, 
    torch.Tensor(np.asarray(masks.v.distbound1[None, None, :, 1:-1])),
    torch.Tensor(np.asarray(masks.v.distbound2plus[None, None, :, 1:-1]))
)

() (0, 0)
LEFT!
torch.Size([1, 1, 5, 3])
(1, 0, 0, 0)
tensor([[[[ 0.0000,  1.5000,  2.5000,  3.5000],
          [ 0.0000,  6.5000,  7.5000,  8.5000],
          [ 0.0000, 11.5000, 12.5000, 13.5000],
          [ 0.0000, 16.5000, 17.5000, 18.5000],
          [ 0.0000, 21.5000, 22.5000, 23.5000]]]])
torch.Size([1, 1, 5, 4])
RIGHT!
torch.Size([1, 1, 5, 3])
(0, 1, 0, 0)
torch.Size([1, 1, 5, 4])
tensor([[[[ -0.1964,  -0.2487,  -0.8110,   1.7698],
          [ -6.4643,   1.2616,  -7.3817,  -0.5530],
          [  1.8452, -22.5542,   6.7217,  -4.5663],
          [ -2.2447, -18.6980,  22.9403,  -8.3929],
          [ -1.3532, -42.0217,  14.0971,  -5.8775]]]])
tensor([[[[ -0.0982,  -0.1865,  -0.6759,   2.0647],
          [ -5.9256,   1.3667,  -6.9204,   0.0000],
          [  0.0000, -21.6144,   7.0018,   0.0000],
          [ -2.1745, -18.1481,  23.6150,   0.0000],
          [ -1.3210, -41.0667,  14.4175,   0.0000]]]])
FLUXXX!
torch.Size([1, 1, 5, 4])
torch.Size([1, 1, 5, 4])


In [13]:
flux

tensor([[[[ -0.0000,  -0.0000,  -0.0000,   0.0000],
          [ -0.0000,   1.2616,  -7.3817,   0.0000],
          [  0.0000, -22.5542,   6.7217,   0.0000],
          [ -0.0000, -18.6980,  22.9403,   0.0000],
          [ -0.0000,  -0.0000,   0.0000,   0.0000]]]])

In [14]:
import jax.numpy as jnp
import kernex as kex
import jax
import numpy as np

In [15]:
from jaxsw._src.operators.functional.interp import weno as interp_weno
from jaxsw._src.operators.functional.interp import linear as interp_linear
from jaxsw._src.operators.functional.interp import flux as F_flux

* interp_3pt_weno
* interp_1pt
* u plusmins
* flux + interp_1pt + mask
* flux + interp3pt
* sum fluxes

---

**Methods**:

* `dim`: 0 (x-axis), 1 (y-axis)
* `num_pts`: 1, 3, 5
* `method`: linear, weno, wenoz
* `mask` (depends): `u_mask1`, `u_mask2`, `umask_2plus`, `umask_3plus`

### 1pt

In [34]:
x_ = x.numpy()

flux_methods = {}

flux_methods["1pt_linear"] = F_flux.tracer_flux_1pt(
    x_,
    u=u.numpy(),
    dim=0,
)

In [38]:
flux_methods["1pt_linear"]

Array([[ -0.9821613 ,  -6.4642696 ,   0.36903504,  -1.1223415 ,
         -0.57993454],
       [ -1.2434431 ,   1.2615503 , -22.554178  , -14.298466  ,
        -26.741076  ],
       [ -4.055218  , -14.763467  ,   6.721692  ,  17.542566  ,
          8.970864  ],
       [  8.848809  ,  -1.2902595 ,  -7.175615  , -10.159798  ,
         -5.877451  ]], dtype=float32)

### 3pts

In [36]:
x_ = x.numpy()



method = "linear"
flux_methods[f"3pt_{method}"] = F_flux.tracer_flux_3pt(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method
)

method = "weno"
flux_methods[f"3pt_{method}"] = F_flux.tracer_flux_3pt(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method
)

method = "wenoz"
flux_methods[f"3pt_{method}"] = F_flux.tracer_flux_3pt(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method
)

In [40]:
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["3pt_weno"])
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["3pt_wenoz"])

In [47]:
flux_methods["3pt_linear"]

Array([[ -0.49108064,  -3.770824  ,   0.8303288 ,  -0.7716098 ,
         -0.41884163],
       [ -0.9325823 ,   1.7871963 , -17.85539   , -11.548761  ,
        -21.965883  ],
       [ -3.3793488 , -12.456676  ,   8.122045  ,  20.916138  ,
         10.572804  ],
       [ 10.323611  ,  -1.1366571 ,  -6.3602037 ,  -9.055471  ,
         -5.2652164 ]], dtype=float32)

### 5pts

In [41]:
x_ = x.numpy()

flux_methods_5pt = {}

method = "linear"
flux_methods[f"5pt_{method}"] = F_flux.tracer_flux_5pt(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method
)

method = "weno"
flux_methods[f"5pt_{method}"] = F_flux.tracer_flux_5pt(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method
)

method = "wenoz"
flux_methods[f"5pt_{method}"] = F_flux.tracer_flux_5pt(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method
)

In [45]:
# check if all the same
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["3pt_weno"])
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["3pt_wenoz"])

In [46]:
# check if same as 5pt method
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["5pt_linear"])
np.testing.assert_array_almost_equal(flux_methods["3pt_weno"], flux_methods["5pt_weno"])
np.testing.assert_array_almost_equal(flux_methods["3pt_wenoz"], flux_methods["5pt_wenoz"])

### Masks

### 1pt

In [49]:
x_ = x.numpy()

flux_methods = {}

flux_methods["1pt_linear"] = F_flux.tracer_flux_1pt_mask(
    x_,
    u=u.numpy(),
    dim=0,
    u_mask1=masks.u.distbound1[1:-1]
)

In [50]:
flux_methods["1pt_linear"]

Array([[ -0.        ,  -0.        ,   0.        ,  -0.        ,
         -0.        ],
       [ -0.        ,   1.26155031, -22.55417824, -14.29846573,
         -0.        ],
       [ -0.        , -14.76346684,   6.72169209,  17.5425663 ,
          0.        ],
       [  0.        ,  -0.        ,  -0.        ,  -0.        ,
         -0.        ]], dtype=float64)

### 3pts

In [53]:
x_ = x.numpy()



method = "linear"
flux_methods[f"3pt_{method}"] = F_flux.tracer_flux_3pt_mask(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method,
    u_mask1=masks.u.distbound1[1:-1],
    u_mask2plus=masks.u.distbound2plus[1:-1],
)

method = "weno"
flux_methods[f"3pt_{method}"] = F_flux.tracer_flux_3pt_mask(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method,
    u_mask1=masks.u.distbound1[1:-1],
    u_mask2plus=masks.u.distbound2plus[1:-1],
)

method = "wenoz"
flux_methods[f"3pt_{method}"] = F_flux.tracer_flux_3pt_mask(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method,
    u_mask1=masks.u.distbound1[1:-1],
    u_mask2plus=masks.u.distbound2plus[1:-1],
)

In [54]:
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["3pt_weno"])
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["3pt_wenoz"])

In [55]:
flux_methods["3pt_linear"]

Array([[ -0.        ,  -0.        ,   0.        ,  -0.        ,
         -0.        ],
       [ -0.        ,   1.26155031, -22.55417824, -14.29846573,
         -0.        ],
       [ -0.        , -14.76346684,   6.72169209,  17.5425663 ,
          0.        ],
       [  0.        ,   0.        ,   0.        ,   0.        ,
          0.        ]], dtype=float64)

### 5pts

In [56]:
x_ = x.numpy()

flux_methods_5pt = {}

method = "linear"
flux_methods[f"5pt_{method}"] = F_flux.tracer_flux_5pt_mask(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method,
    u_mask1=masks.u.distbound1[1:-1],
    u_mask2=masks.u.distbound2[1:-1],
    u_mask3plus=masks.u.distbound3plus[1:-1],
)

method = "weno"
flux_methods[f"5pt_{method}"] = F_flux.tracer_flux_5pt_mask(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method,
    u_mask1=masks.u.distbound1[1:-1],
    u_mask2=masks.u.distbound2[1:-1],
    u_mask3plus=masks.u.distbound3plus[1:-1],
)

method = "wenoz"
flux_methods[f"5pt_{method}"] = F_flux.tracer_flux_5pt_mask(
    x_, 
    u=u.numpy(),
    dim=0, 
    method=method,
    u_mask1=masks.u.distbound1[1:-1],
    u_mask2=masks.u.distbound2[1:-1],
    u_mask3plus=masks.u.distbound3plus[1:-1],
)

In [57]:
# check if all the same
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["3pt_weno"])
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["3pt_wenoz"])

In [58]:
# check if same as 5pt method
np.testing.assert_array_almost_equal(flux_methods["3pt_linear"], flux_methods["5pt_linear"])
np.testing.assert_array_almost_equal(flux_methods["3pt_weno"], flux_methods["5pt_weno"])
np.testing.assert_array_almost_equal(flux_methods["3pt_wenoz"], flux_methods["5pt_wenoz"])