In [1]:
from typing import NamedTuple

import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Switch JAX to 64-bit precision (it uses 32-bit types unless this flag is set).
jax.config.update("jax_enable_x64", True)
# Print NumPy arrays with 3 decimals and without scientific notation.
np.set_printoptions(precision=3, suppress=True)

In [2]:
# Spatial discretisation: nx equally spaced nodes from x_min to x_max,
# both end points included.
nx = 2048
x_min = -10.0
x_max = 10.0
x_grid = jnp.linspace(x_min, x_max, num=nx)


def at(field, x):
    """Linearly interpolate the gridded ``field`` at position(s) ``x``.

    ``field`` holds one sample per node of the module-level ``x_grid``.
    """
    return jnp.interp(x, xp=x_grid, fp=field)


def dx(field):
    """Centred first derivative of ``field`` along ``x_grid``, periodic in x.

    The result is the mean of the forward and backward one-sided differences;
    ``jnp.roll`` supplies the wrap-around neighbours at both ends of the grid.
    """
    spacing = x_grid[1] - x_grid[0]

    forward = (jnp.roll(field, -1) - field) / spacing
    backward = (field - jnp.roll(field, 1)) / spacing
    return 0.5 * (forward + backward)


def dirac(x):
    """Discrete delta function on ``x_grid`` centred at ``x``.

    A triangular ("hat") profile that falls linearly to zero one grid spacing
    away from ``x``, divided by the spacing.
    """
    spacing = x_grid[1] - x_grid[0]
    offset = jnp.abs(x_grid - x)
    hat = jnp.where(offset < spacing, 1 - offset / spacing, 0.0)
    return hat / spacing

In [3]:
# Constant x-components of the fields: Ex0 enters the x-acceleration of the
# charge, Bx0 the magnetic part of its y- and z-acceleration.
Ex0, Bx0 = 0.0, 0.0


# Time dependent quantities
class State(NamedTuple):
    """Dynamical variables of the model, in the order ``equations`` unpacks them.

    ``equations`` also returns a ``State``; there every entry holds the time
    derivative of the corresponding variable.
    """

    # Densities and fields depend on x only:
    # rho(t, x, y, z) = rho(t, x)
    # J(t, x, y, z) = J(t, x)
    # E(t, x, y, z) = E(t, x)
    # B(t, x, y, z) = B(t, x)

    # Position of the charge (scalars in `initial_state`).
    qx: jax.Array
    qy: jax.Array
    qz: jax.Array

    # Velocity of the charge; `equations` uses these as d(qx, qy, qz)/dt.
    qvx: jax.Array
    qvy: jax.Array
    qvz: jax.Array
    # Ex is not an integrated variable:
    # Ex = Ex0 + [0.5 if x > qx else -0.5]
    # Transverse field components, one sample per `x_grid` node.
    Ez: jax.Array
    By: jax.Array
    Ey: jax.Array
    Bz: jax.Array
    # Bx is not an integrated variable either:
    # Bx = Bx0


def equations(state: State) -> State:
    """Time derivative of every entry of ``state``.

    Positions advance with the velocities; velocities change under the
    Lorentz-type terms built from the fields interpolated at ``qx`` plus a
    harmonic restoring term; the gridded fields follow the 1D curl equations
    with a current localised at ``qx`` by ``dirac``.
    """
    qx, qy, qz, qvx, qvy, qvz, Ez, By, Ey, Bz = state

    # Angular frequency of the restoring term (same on all three axes); its
    # square is what multiplies the displacement.
    w = 5.0
    stiffness = w**2

    # Fields sampled at the position of the charge.
    Ey_q, Ez_q = at(Ey, qx), at(Ez, qx)
    By_q, Bz_q = at(By, qx), at(Bz, qx)

    # Current of the moving charge, spread over the neighbouring grid nodes.
    source = dirac(qx)
    Jy = source * qvy
    Jz = source * qvz

    accel_x = (Ex0 + qvy * Bz_q - qvz * By_q) - stiffness * qx
    accel_y = (Ey_q - qvx * Bz_q + qvz * Bx0) - stiffness * qy
    accel_z = (Ez_q + qvx * By_q - qvy * Bx0) - stiffness * qz

    return State(
        qx=qvx,
        qy=qvy,
        qz=qvz,
        qvx=accel_x,
        qvy=accel_y,
        qvz=accel_z,
        # (Ey, Bz) polarisation
        Ey=-dx(Bz) - Jy,
        Bz=-dx(Ey),
        # (Ez, By) polarisation
        Ez=dx(By) - Jz,
        By=dx(Ez),
    )

In [4]:
def envelope(x_start, x_end):
    """Gaussian window on ``x_grid`` restricted to the open interval (x_start, x_end).

    The Gaussian is centred on the midpoint of the interval; every node at or
    outside the interval ends is set to exactly zero.
    """
    # Rescale so that x_start maps to -1 and x_end to +1.
    u = 2.0 * ((x_grid - x_start) / (x_end - x_start) - 0.5)
    window = jnp.exp(-((u / 0.4) ** 2))
    inside = (x_grid > x_start) & (x_grid < x_end)
    return jnp.where(inside, window, 0.0)


omega = 5.0
mask = envelope(0.0, 10.0)

# Windowed sine centred at x = 5; Ey and Bz both start from this profile.
pulse = mask * jnp.sin(omega * (x_grid - 5.0))
at_rest = jnp.array(0.0)
no_field = jnp.zeros_like(x_grid)

initial_state = State(
    # charge at the origin, not moving
    qx=at_rest,
    qy=at_rest,
    qz=at_rest,
    qvx=at_rest,
    qvy=at_rest,
    qvz=at_rest,
    # (Ez, By) polarisation starts empty
    Ez=no_field,
    By=no_field,
    # (Ey, Bz) polarisation carries the pulse
    Ey=pulse,
    Bz=pulse,
)

In [None]:
# Initial field profiles. Ey and Bz are identical at t = 0, so Bz is drawn
# dashed to keep both curves visible.
plt.title("Initial state")
plt.plot(x_grid, initial_state.Ey, label="Ey")
plt.plot(x_grid, initial_state.Bz, "--", label="Bz")
plt.xlabel("x")
# Without this call the labels given above are never displayed.
plt.legend();

In [None]:
def f(t, y, args):
    """Vector field in the ``(t, y, args)`` form passed to ``diffrax.ODETerm``.

    Only the state ``y`` is used; ``t`` and ``args`` are accepted for the
    calling convention and ignored.
    """
    del t, args  # unused
    return equations(y)


t1 = 25.0

# Integrate from t = 0 to t = t1 with adaptive steps and record the state at
# 1024 equally spaced times.
term = diffrax.ODETerm(f)
solver = diffrax.Tsit5()
controller = diffrax.PIDController(rtol=1e-7, atol=1e-7)
save_times = jnp.linspace(0.0, t1, 1024)

solution = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=t1,
    dt0=0.01,
    y0=initial_state,
    saveat=diffrax.SaveAt(ts=save_times),
    stepsize_controller=controller,
    max_steps=8 * 4096,
)

ts: jax.Array = solution.ts
ys: State = solution.ys

# Displayed as the cell output.
solution.stats

In [None]:
# Space-time map of Ey with the trajectory of the charge drawn on top.
fig, ax = plt.subplots(figsize=(5, 10))

a = 1.0  # symmetric colour-scale limit for Ey
# Transposed so that t runs along the horizontal and x along the vertical
# axis, matching `extent`.
ax.imshow(
    ys.Ey.T,
    cmap="bwr",
    origin="lower",
    vmin=-a,
    vmax=a,
    extent=[0, t1, x_min, x_max],
    aspect="auto",
)

ax.set_xlabel("t")
ax.set_ylabel("x")

ax.plot(ts, ys.qx, "k-", label="qx")
# Unit-slope reference line, dashed so it cannot be mistaken for qx.
ax.plot(ts, ts - 15.0, "k--", label="x = t - 15")

ax.set_xlim(0, t1)
ax.set_ylim(x_min, x_max)
ax.set_title("How does the wave propagate?")
# Without this call the labels given above are never displayed.
ax.legend();

In [None]:
# Displacement of the charge along each axis as a function of time.
plt.title("How does the plane of charges move?")
plt.plot(ts, ys.qx, label="qx")
plt.plot(ts, ys.qy, label="qy")
plt.plot(ts, ys.qz, label="qz")
# Ey at the single grid node with index nx // 2, divided by 5 so that it fits
# the axis range used for the displacements.
plt.plot(ts, ys.Ey[:, nx // 2] / 5, label="Ey[nx // 2] / 5")
plt.xlabel("t")
plt.ylim(-0.2, 0.2)
plt.legend();