# 1D Linear Convection

In [None]:
# !pip install --upgrade jax jaxlib
# !pip install jaxtyping diffrax xarray FiniteDiffX jaxdf

In [None]:
import typing as tp
import numpy as np
import xarray as xr
import jax
import jax.numpy as jnp
import diffrax as dfx
import finitediffx as fdx
import matplotlib.pyplot as plt
import seaborn as sns
from jaxtyping import Float, Array

# Plot styling: start from seaborn's defaults, then scale fonts via the "talk" context.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit floats in JAX so the float64 arrays requested later (e.g. in
# init_u0) are honored instead of being downcast to float32.
# NOTE(review): JAX recommends setting this flag at startup, before any arrays
# are created — confirm none of the imports above create arrays first.
jax.config.update("jax_enable_x64", True)

%matplotlib inline

## Problem

Let's continue from the previous tutorial. Recall that we are working with the 1D linear convection equation:

$$
\frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = 0
$$ (pde)

For the PDE {eq}`pde`, we use a backward difference to discretize space and a forward difference (explicit Euler) to step in time.

## Geometry

In [None]:
from jaxdf.geometry import Domain

#### `jaxdf` API

In [None]:
from jaxdf.geometry import Domain

# Number of grid points and grid spacing of the 1D domain.
nx = 51
dx = 0.04

# initialize domain
# (1D jaxdf Domain: N and dx are per-dimension tuples, hence the 1-tuples.)
domain = Domain(N=(nx,), dx=(dx,))

# Inspect the basic properties of the domain object.
print(f"Size: {domain.size}")
print(f"nDims: {domain.ndim}")
print(f"Grid Size: {domain.grid.shape}")
print(f"Cell Volume: {domain.cell_volume}")
print(f"dx: {domain.dx}")
print(f"Type: {type(domain)}")

## Initial Condition

In [None]:
def init_u0(domain):
    """Build the square-wave ("hat") initial condition on ``domain``'s grid.

    The result is 1 everywhere, except on the index range
    ``[int(0.5 / dx), int(1 / dx + 1))`` along the first axis, where it is 2.
    The bounds 0.5 and 1 are converted to grid *indices* by dividing by
    ``dx``; where they land in physical coordinates depends on the domain's
    origin.

    Args:
        domain: object exposing ``grid`` (used as the shape template) and
            ``dx`` (sequence of spacings; only ``dx[0]`` is used).

    Returns:
        Array shaped like ``domain.grid``, requested as float64 (honored only
        when ``jax_enable_x64`` is on).
    """
    spacing = domain.dx[0]
    lo = int(0.5 / spacing)
    hi = int(1 / spacing + 1)
    background = jnp.ones_like(domain.grid, dtype=jnp.float64)
    return background.at[lo:hi].set(2.0)

In [None]:
# Build the initial condition on the domain grid and check what array type it is.
u_init = init_u0(domain)

print(type(u_init))

## Equation of Motion

#### `jaxdf` API

In [None]:
from jaxdf.discretization import FiniteDifferences
from jaxdf.operators import gradient


def equation_of_motion(t: Array, u: Array, args: float):
    """Right-hand side ``-c * du/dx`` of the linear convection PDE, via jaxdf.

    The signature follows diffrax's ``ODETerm`` vector-field convention
    ``(t, y, args)``.

    Args:
        t: current time (unused; the right-hand side does not depend on it).
        u: field values on the grid of the module-level ``domain``.
        args: the constant wave speed ``c`` as a scalar (passed straight
            through ``diffeqsolve(..., args=c)``), not a tuple.

    Returns:
        The grid values (``.on_grid``) of ``-c * gradient(u)``.
    """
    c = args
    # initialize spatial discretization: wrap the raw grid values as a jaxdf
    # finite-difference field on the global ``domain`` (module-level state).
    u_field = FiniteDifferences.from_grid(u, domain)

    # Stencil accuracy used by jaxdf's gradient.
    # NOTE(review): this is not necessarily the first-order backward stencil
    # used by the from-scratch version — confirm against the jaxdf docs.
    u_field.accuracy = 2

    u_rhs = -c * gradient(u_field)

    return u_rhs.on_grid

In [None]:
# Constant wave speed c, passed to the vector field as ``args``.
c = 1.0

# initialize grid
u_init = init_u0(domain)

# RHS of equation of motion (evaluated once at t=0 as a sanity check)
out = equation_of_motion(0, u_init, c)

#### From Scratch

In [None]:
from jaxdf.discretization import FiniteDifferences
from jaxdf.operators import gradient


def equation_of_motion_scratch(t: Array, u: Array, args: float):
    """Right-hand side ``-c * du/dx`` using a first-order backward difference.

    The signature follows diffrax's ``ODETerm`` vector-field convention
    ``(t, y, args)``.

    Args:
        t: current time (unused; the right-hand side does not depend on it).
        u: field values sampled on the grid of the module-level ``domain``;
            differentiated along axis 0.
        args: the constant wave speed ``c`` as a scalar (passed straight
            through ``diffeqsolve(..., args=c)``), not a tuple.

    Returns:
        ``-c * du/dx`` with ``du/dx`` from ``finitediffx.difference``.
    """
    c = args

    # First-order backward difference along axis 0 (the upwind direction when
    # c > 0), using the grid spacing of the global ``domain``.
    du_dx = fdx.difference(
        u, axis=0, accuracy=1, method="backward", step_size=domain.dx[0]
    )

    return -c * du_dx

In [None]:
# RHS of equation of motion
# (same initial field and wave speed as the jaxdf version, for a side-by-side comparison)
out_scratch = equation_of_motion_scratch(0, u_init, c)

In [None]:
fig, ax = plt.subplots()

# u_init is the field u(x, 0); the other two curves are the RHS -c du/dx of
# that field (not solutions), so label them accordingly.
ax.plot(domain.spatial_axis[0], u_init[..., 0], label="Initial Condition")
ax.plot(domain.spatial_axis[0], out[..., 0], label="RHS (JaxDF)")
ax.plot(domain.spatial_axis[0], out_scratch[..., 0], label="RHS (Scratch)")

ax.set(
    xlabel="x",
    ylabel="value",
    title=r"Initial condition and RHS $-c\,\partial u/\partial x$",
)
ax.legend()
plt.show()

## Time Stepping

In [None]:
# temporal parameters
c = 1.0  # wave speed
sigma = 0.2  # Courant number used by the CFL time-step rule below


# CFL condition
def cfl_cond(dx, c, sigma):
    """Return the CFL time step ``dt = sigma * dx / |c|``.

    Args:
        dx: grid spacing.
        c: wave speed; must be non-zero. Only its magnitude matters for the
            stability limit, so a negative ``c`` still gives a positive step.
        sigma: Courant number, required to lie in ``(0, 1]``.

    Returns:
        The time step ``sigma * dx / |c|``.

    Raises:
        ValueError: if ``sigma`` is outside ``(0, 1]`` or ``c`` is zero.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``;
    # sigma <= 0 would otherwise produce a zero or negative time step.
    if not 0.0 < sigma <= 1.0:
        raise ValueError(f"sigma must be in (0, 1], got {sigma}")
    if c == 0:
        raise ValueError("c must be non-zero to derive a CFL time step")
    return (sigma * dx) / abs(c)


dt = cfl_cond(dx=domain.dx[0], c=c, sigma=sigma)

# Simulation window. NOTE: jnp.arange excludes its end point, so ts stops short
# of t1 (ts.max() < t1); the solves below integrate only up to ts.max().
t0 = 0.0
t1 = 0.5
ts = jnp.arange(t0, t1, dt)
# Save the solution at every time in ts.
saveat = dfx.SaveAt(ts=ts)

#### `jaxdf` API

In [None]:
# Euler, Constant StepSize
# Explicit Euler in time with a fixed step dt (the forward-time part of the
# scheme); the spatial derivative comes from equation_of_motion (jaxdf).
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()


sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(equation_of_motion),
    solver=solver,
    t0=ts.min(),
    t1=ts.max(),  # last saved time, not t1 (see the time-grid cell)
    dt0=dt,
    y0=u_init,
    saveat=saveat,
    args=c,  # forwarded as the vector field's third argument
    stepsize_controller=stepsize_controller,
)

#### From Scratch

In [None]:
# Euler, Constant StepSize
# Same solver, step size, time window and initial condition as the jaxdf solve;
# only the vector field differs (first-order backward difference via finitediffx).
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()


sol_scratch = dfx.diffeqsolve(
    terms=dfx.ODETerm(equation_of_motion_scratch),
    solver=solver,
    t0=ts.min(),
    t1=ts.max(),  # last saved time, not t1 (see the time-grid cell)
    dt0=dt,
    y0=u_init,
    saveat=saveat,
    args=c,  # forwarded as the vector field's third argument
    stepsize_controller=stepsize_controller,
)

In [None]:
# Shape check of the saved jaxdf solution with singleton axes squeezed out
# (presumably (len(ts), nx) — confirm).
np.asarray(sol.ys).squeeze().shape

## Analysis

In [None]:
# Both solves share `saveat`, so their time grids must agree; the dataset below
# uses a single "time" coordinate (taken from the jaxdf solve) for both variables.
np.testing.assert_allclose(np.asarray(sol.ts), np.asarray(sol_scratch.ts))

# Drop only the trailing per-point axis (as the RHS comparison plot does with
# ``[..., 0]``) instead of ``.squeeze()``, which would also collapse the time
# axis if just one time were saved and break the ("time", "x") dims.
da_sol = xr.Dataset(
    {
        "jaxdf": (("time", "x"), np.asarray(sol.ys)[..., 0]),
        "scratch": (("time", "x"), np.asarray(sol_scratch.ys)[..., 0]),
    },
    coords={
        "x": (["x"], np.asarray(domain.spatial_axis[0])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    attrs={"pde": "linear_convection", "c": c, "sigma": sigma},
)
da_sol

In [None]:
fig, ax = plt.subplots(nrows=2, figsize=(7, 5))

# Space-time diagrams, transposed to (x, time) so that space runs along the
# vertical axis and time along the horizontal one.
da_sol.jaxdf.T.plot.pcolormesh(ax=ax[0], cmap="gray_r")
da_sol.scratch.T.plot.pcolormesh(ax=ax[1], cmap="gray_r")

ax[0].set_title("JaxDF")
ax[1].set_title("Scratch")

# Keep the top panel's axis labels/ticks from colliding with the bottom title.
fig.tight_layout()
plt.show()

In [None]:
fig, ax = plt.subplots(nrows=2, figsize=(7, 5))

# Overlay every 5th saved snapshot for each implementation.
for i in range(0, len(da_sol.time), 5):
    da_sol.jaxdf.isel(time=i).plot.line(ax=ax[0], color="gray")
    da_sol.scratch.isel(time=i).plot.line(ax=ax[1], color="gray")

# xarray titles each line plot with its selected time, so after the loop each
# panel would be labelled with only the last snapshot's time; set explicit titles.
ax[0].set_title("JaxDF (every 5th snapshot)")
ax[1].set_title("Scratch (every 5th snapshot)")

fig.tight_layout()
plt.show()