In [1]:
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

jax.config.update("jax_enable_x64", True)
np.set_printoptions(precision=3, suppress=True)

In [4]:
class State(eqx.Module):
    """Interface for a state that ``Simulator1D.evolve`` integrates in time.

    ``evolve`` calls ``dot`` on the state to obtain the right-hand side of the ODE,
    so a concrete state is expected to override ``dot``.
    """

    def dot(self, simulator: "Simulator1D") -> "State":
        """Return the time derivative of this state.

        ``simulator`` is the ``Simulator1D`` whose grid helpers may be used to
        compute the derivative. This base version always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError


class Simulator1D(eqx.Module):
    """Uniform 1D spatial grid together with the numerical helpers built on it.

    Stores the sample positions ``x_grid`` and offers interpolation, a centred
    finite-difference derivative, a discretised delta function and the time
    integration of a ``State``.
    """

    x_grid: jax.Array

    def __init__(self, nx: int, x_min: float, x_max: float):
        """Create ``nx`` equally spaced grid points from ``x_min`` to ``x_max``."""
        self.x_grid = jnp.linspace(x_min, x_max, nx)

    @property
    def delta(self):
        """Distance between the first two grid points (the grid spacing)."""
        first, second = self.x_grid[0], self.x_grid[1]
        return second - first

    def at(self, field, x):
        """Linearly interpolate ``field`` (sampled on ``x_grid``) at ``x``."""
        return jnp.interp(x, self.x_grid, field)

    def dx(self, field):
        """Centred finite difference of ``field`` along the grid.

        Neighbouring samples are obtained by rolling the array, so the first and
        last grid points are treated as adjacent (periodic boundary conditions).
        """
        # Other boundary treatments that were experimented with:
        #   zero ghost values:     jnp.append(field[1:], 0.0)
        #                          jnp.append(0.0, field[:-1])
        #   mirrored ghost values: jnp.append(field[1:], field[-2])
        #                          jnp.append(field[1], field[:-1])
        forward = jnp.roll(field, -1)
        backward = jnp.roll(field, 1)
        return (forward - backward) / (2 * self.delta)

    def dirac(self, x):
        """Discretised Dirac delta centred at ``x``, sampled on the grid.

        The profile is triangular: ``1 / delta`` at zero distance, decreasing
        linearly to zero at a distance of one grid spacing, and zero beyond.
        """
        spacing = self.delta
        offset = jnp.abs(self.x_grid - x)
        hat = jnp.where(offset < spacing, 1 - offset / spacing, 0.0)
        return hat / spacing

    @eqx.filter_jit
    def evolve(self, state: State, t_end: float, nt: int):
        """Integrate ``state`` from ``t = 0`` to ``t_end``; return the diffrax solution.

        The right-hand side is ``state.dot(self)``; the time argument is not used.
        The solution is saved at ``nt`` equally spaced times covering ``[0, t_end]``.
        Integration uses Tsit5 with a PID step-size controller
        (``rtol = atol = 1e-6``), an initial step of ``t_end / nt`` and at most
        ``32 * 1024`` steps.
        """

        def rhs(t, y: State, args) -> State:
            return y.dot(self)

        save_times = jnp.linspace(0.0, t_end, nt)
        return diffrax.diffeqsolve(
            diffrax.ODETerm(rhs),
            diffrax.Tsit5(),
            t0=0.0,
            t1=t_end,
            dt0=t_end / nt,
            y0=state,
            saveat=diffrax.SaveAt(ts=save_times),
            stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-6),
            max_steps=32 * 1024,
        )

In [5]:
# I tried to separate the physics (state and equation of motion) from the simulation (grid and time stepping).


class Charge(eqx.Module):
    """A plane of charge (uniform in y and z) that is held by a spring.

    The sources it generates depend on x only:
        rho(t, x, y, z) = rho(t, x) = q dirac(x - pos[0])
        J(t, x, y, z)   = J(t, x)   = q vel dirac(x - pos[0])
    """

    # Position and velocity
    pos: jax.Array
    vel: jax.Array
    # Electromagnetic properties
    m: jax.Array  # mass: relates the EM force to the acceleration
    q: jax.Array  # charge: relates the EM field to the EM force, and sets the current
    # The (YZ plane of) charge is held by a spring
    anchor: jax.Array  # anchor position of the spring
    omega: jax.Array  # angular frequency associated with the spring constant


class World(eqx.Module):
    """Charges plus the transverse electromagnetic field of an x-only world.

    All fields depend on x only:
        E(t, x, y, z) = E(t, x)
        B(t, x, y, z) = B(t, x)

    ``dot`` provides the equations of motion, so an instance can be passed to
    ``Simulator1D.evolve`` as the state.
    """

    charges: Charge  # vmapped over the charges

    # Ex only depends on the charges and a global shift (set to 0)
    Ey: jax.Array
    Ez: jax.Array
    # Bx is a constant (set to 0)
    By: jax.Array
    Bz: jax.Array

    def E_at(self, x: float, sim: "Simulator1D") -> jax.Array:
        """Electric field ``[Ex, Ey, Ez]`` at position ``x``.

        ``Ex`` is half the sum over charges of ``q * sign(x - pos[0])``; ``Ey`` and
        ``Ez`` are interpolated from the grid fields through ``sim.at``.
        """
        Ex = (
            0.5
            * jax.vmap(lambda ch: ch.q * jnp.sign(x - ch.pos[0]))(self.charges).sum()
        )
        return jnp.array([Ex, sim.at(self.Ey, x), sim.at(self.Ez, x)])

    def B_at(self, x: float, sim: "Simulator1D") -> jax.Array:
        """Magnetic field ``[0, By, Bz]`` at position ``x`` (grid fields interpolated)."""
        return jnp.array([0.0, sim.at(self.By, x), sim.at(self.Bz, x)])

    def J(self, sim: "Simulator1D") -> jax.Array:
        """Current density on the grid, summed over all charges.

        Each charge contributes ``q * vel`` spread over the grid by
        ``sim.dirac(pos[0])``. The result has one row per grid point and one
        column per velocity component.
        """
        return jax.vmap(lambda ch: ch.q * ch.vel * sim.dirac(ch.pos[0])[:, None])(
            self.charges
        ).sum(axis=0)

    def dot(self, sim: "Simulator1D") -> "World":
        """Time derivative of the state, returned as a ``World`` holding the rates.

        Charges move under the Lorentz force plus a linear restoring force towards
        their anchor. The transverse fields follow the 1D Maxwell equations, written
        with no explicit physical constants.
        """

        def dot_charge(charge: Charge) -> Charge:
            x = charge.pos[0]
            lorentz = charge.q * (
                self.E_at(x, sim) + jnp.cross(charge.vel, self.B_at(x, sim))
            )
            return Charge(
                pos=charge.vel,
                vel=lorentz / charge.m
                - charge.omega**2 * (charge.pos - charge.anchor),
                # The parameters below are constant in time, so their rate is zero.
                # zeros_like keeps the shape and dtype of each field.
                m=jnp.zeros_like(charge.m),
                q=jnp.zeros_like(charge.q),
                anchor=jnp.zeros_like(charge.anchor),
                omega=jnp.zeros_like(charge.omega),
            )

        # Evaluate the current once; both E-field equations need it.
        current = self.J(sim)

        return World(
            charges=jax.vmap(dot_charge)(self.charges),
            Ey=-sim.dx(self.Bz) - current[:, 1],
            Ez=sim.dx(self.By) - current[:, 2],
            By=sim.dx(self.Ez),
            Bz=-sim.dx(self.Ey),
        )

In [6]:
# Spatial grid: 2 * 2048 = 4096 points from x = 0 to x = 40.
sim = Simulator1D(nx=2 * 2048, x_min=0.0, x_max=40.0)

In [None]:
def envelope(x_start, x_end, x_grid=None):
    """Gaussian window that is non-zero only for ``x_start < x < x_end``.

    The interval ``[x_start, x_end]`` is mapped linearly onto ``[-1, 1]`` and the
    window is ``exp(-(u / 0.4) ** 2)`` in that coordinate, so it equals 1 at the
    midpoint of the interval. Outside the open interval the result is exactly 0.

    Args:
        x_start: left edge of the window.
        x_end: right edge of the window.
        x_grid: positions at which to evaluate the window. Defaults to the grid
            of the notebook-level simulator ``sim``.

    Returns:
        Array of window values, one per grid position.
    """
    grid = sim.x_grid if x_grid is None else jnp.asarray(x_grid)
    # Rescaled coordinate: -1 at x_start, 0 at the midpoint, +1 at x_end.
    u = ((grid - x_start) / (x_end - x_start) - 0.5) * 2.0
    window = jnp.exp(-((u / 0.4) ** 2))
    inside = (grid > x_start) & (grid < x_end)
    return jnp.where(inside, window, 0.0)


# Initial wave packet: a unit-amplitude sin(3 (x - 5)) carrier under the
# envelope that covers 0 < x < 10.
wave = 0.0 + 1.0 * envelope(0.0, 10.0) * jnp.sin(3.0 * (sim.x_grid - 5.0))

plt.plot(sim.x_grid, wave)

In [None]:
def init_charge(q: float, x0: float, m: float = 0.9, omega: float = 5.0) -> Charge:
    """Create a charge at rest at ``x = x0``, anchored at that same position.

    Args:
        q: charge.
        x0: initial x position; also the x position of the spring anchor.
        m: mass (default 0.9).
        omega: angular frequency of the spring (default 5.0).

    Returns:
        A ``Charge`` with zero velocity whose position equals its anchor.
    """
    rest_position = jnp.array([x0, 0.0, 0.0])
    return Charge(
        # Stored as arrays, matching the jax.Array annotations on Charge.
        m=jnp.array(m),
        q=jnp.asarray(q),
        pos=rest_position,
        vel=jnp.array([0.0, 0.0, 0.0]),
        omega=jnp.array(omega),
        anchor=rest_position,
    )


# The charges are created in pairs: each of the nc // 2 rest positions between
# x = 10 and x = 15 is used twice, and the charge sign alternates (+1, -1, ...),
# so every position holds one positive and one negative charge.
nc = 60
assert nc % 2 == 0, "nc must be even: charges are created in +/- pairs"
initial_state = World(
    charges=jax.vmap(init_charge)(
        (-1.0) ** jnp.arange(nc),
        jnp.repeat(jnp.linspace(10.0, 15.0, nc // 2), 2),
    ),
    # The wave packet is placed in Ey and Bz; Ez and By start at zero.
    Ey=wave,
    Ez=jnp.zeros_like(sim.x_grid),
    By=jnp.zeros_like(sim.x_grid),
    Bz=wave,
)
# Integrate up to t = 30 and save 2 * 2048 = 4096 snapshots.
sol = sim.evolve(initial_state, 30.0, 2 * 2048)
sol.stats

In [None]:
# Space-time diagram: Ey as a colour map (x horizontal, t vertical) with the
# x position of every charge drawn on top as thin black lines.
fig, ax = plt.subplots(figsize=(9, 9), dpi=150)

a = 0.5  # colour scale is clipped to [-a, a]
ax.imshow(
    sol.ys.Ey,
    extent=[sim.x_grid[0], sim.x_grid[-1], 0, sol.ts[-1]],
    origin="lower",
    cmap="bwr",
    vmin=-a,
    vmax=a,
)
ax.plot(sol.ys.charges.pos[:, :, 0], sol.ts, "k-", linewidth=0.2, alpha=0.5)
# Green reference line of slope 1, from (x = 5, t = 0) to (x = 35, t = 30).
ax.plot([5, 35], [0, 30], "g-")

ax.set_xlabel("x")
ax.set_ylabel("t")
ax.set_xlim(sim.x_grid[0], sim.x_grid[-1])
ax.set_ylim(0, sol.ts[-1])