# Charged particle in a magnetic field

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sympy
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController
import jax.numpy as jnp
import matplotlib.pyplot as plt
import polars as pl
from matplotlib.figure import Figure
import gif

## Maths

### Lagrangian

$$ T = \frac{1}{2} m \sum_i \dot{q}_i^2 $$

$$ V = - \frac{e}{c} \vec{A} \cdot \vec{v}$$

$$ L = T - V = \frac{1}{2} m \sum_i \dot{q}_i^2 + \frac{e}{c} \vec{A} \cdot \vec{v} $$

In [None]:
# Scalar symbols: time t, plus the model constants e and c (the e/c coupling),
# b (strength of the vector potential) and m (mass).
t, e, c, b, m = sympy.symbols("t e c b m")
# Undefined functions of time for the coordinates and momenta.
q1f, q2f, p1f, p2f = sympy.symbols("q1f q2f p1f p2f", cls=sympy.Function)
# Plain (time-independent) symbols, used later so the expressions can be lambdified.
q1, q2, p1, p2, q1_dot, q2_dot = sympy.symbols("q1 q2 p1 p2 q1_dot q2_dot")

# Coordinates q_i(t) and their time derivatives.
q1t, q2t = q1f(t), q2f(t)
q1t_dot, q2t_dot = q1t.diff(t), q2t.diff(t)

# Momenta p_i(t) and their time derivatives.
p1t, p2t = p1f(t), p2f(t)
p1t_dot, p2t_dot = p1t.diff(t), p2t.diff(t)

In [None]:
# Kinetic energy T = (1/2) m (q1_dot^2 + q2_dot^2).
# Use an exact Rational instead of the float 0.5: a Python float turns into a
# SymPy Float and propagates Float coefficients (0.5, 1.0, ...) through the
# momenta, the Legendre transform and the simplify() calls below.
T = sympy.Rational(1, 2) * m * (q1t_dot**2 + q2t_dot**2)
T

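Choose the vector potential

$$ \vec{A} = (0, b q_1, 0) $$

which describes a uniform magnetic field $\vec{B} = \nabla \times \vec{A} = (0, 0, b)$ along $z$, so that $\vec{A} \cdot \vec{v} = b q_1 \dot{q}_2$.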

In [None]:
# Velocity-dependent potential V = -(e/c) A.v with A = (0, b*q1, 0),
# hence A.v = b * q1 * q2_dot.
V = -(e * b / c) * q1t * q2t_dot
V

Plugging this into the Lagrangian gives

$$ L = \frac{1}{2} m \sum_i \dot{q}_i^2 + \frac{e}{c} b q_1 \dot{q}_2 $$

In [None]:
# Lagrangian L = T - V; the magnetic coupling enters with a plus sign
# because V itself already carries the minus sign.
L = -V + T
L

The canonical momenta follow from

$$ p = \partial_{\dot{q}} L $$

In [None]:
# Canonical momentum conjugate to q1: p1 = dL/d(q1_dot).
p1t_L = sympy.diff(L, q1t_dot)
p1t_L

In [None]:
# Canonical momentum conjugate to q2: p2 = dL/d(q2_dot) (includes the e/c*b*q1 term).
p2t_L = sympy.diff(L, q2t_dot)
p2t_L

Express $\dot{q}$ in terms of $q$ and $p$ by inverting these relations

In [None]:
# The relations p_i = dL/d(q_i_dot), written as expressions equal to zero.
p_eq1, p_eq2 = p1t - p1t_L, p2t - p2t_L

# Solve for the velocities; the result is a dict keyed by the derivative
# objects q1t_dot / q2t_dot (it is indexed that way when building H_pq).
q_dot_qp = sympy.solve([p_eq1, p_eq2], [q1t_dot, q2t_dot])
q_dot_qp

The Hamiltonian is the Legendre transform of $L$:

$$ H = \sum_i p_i \dot{q}_i - L $$

In [None]:
# Legendre transform H = sum_i p_i * q_i_dot - L (still written in terms of q_dot).
H = sympy.Add(*(mom * vel for mom, vel in ((p1t, q1t_dot), (p2t, q2t_dot)))) - L
H

Replace $\dot{q}$ with its expression in terms of $q$ and $p$

In [None]:
# Eliminate the velocities (substituted in order: q1_dot, then q2_dot) so that
# H depends on the q's and p's only, then simplify.
H_pq = sympy.simplify(
    H.subs([(q1t_dot, q_dot_qp[q1t_dot]), (q2t_dot, q_dot_qp[q2t_dot])])
)
H_pq

Compute Hamilton's equations of motion

$$ \dot{q}_i = \partial_{p_i} H $$

$$ \dot{p}_i = -\partial_{q_i} H $$

In [None]:
# Hamilton's equation for q1: q1_dot = dH/dp1.
q1t_dot_H = sympy.diff(H_pq, p1t)
q1t_dot_H

In [None]:
# Hamilton's equation for q2: q2_dot = dH/dp2.
q2t_dot_H = sympy.diff(H_pq, p2t)
q2t_dot_H

In [None]:
# Hamilton's equation for p1: p1_dot = -dH/dq1.
p1t_dot_H = -sympy.diff(H_pq, q1t)
p1t_dot_H

In [None]:
# Hamilton's equation for p2: p2_dot = -dH/dq2.
p2t_dot_H = -sympy.diff(H_pq, q2t)
p2t_dot_H

Replace the functions $q_i(t)$ and $p_i(t)$ (and likewise their derivatives) with plain symbols $q_i$, $p_i$, so the expressions can be lambdified

In [None]:
# Replace every time-dependent function with its plain symbol so lambdify only
# sees Symbols. Substituting the full set, not just the functions that happen to
# appear today, keeps this correct if the vector potential is changed.
q1t_dot_H_subs = q1t_dot_H.subs({q1t: q1, q2t: q2, p1t: p1, p2t: p2})
q1t_dot_H_subs

In [None]:
# Replace every time-dependent function with its plain symbol so lambdify only
# sees Symbols. Substituting the full set, not just the functions that happen to
# appear today, keeps this correct if the vector potential is changed.
q2t_dot_H_subs = q2t_dot_H.subs({q1t: q1, q2t: q2, p1t: p1, p2t: p2})
q2t_dot_H_subs

In [None]:
# Replace every time-dependent function with its plain symbol so lambdify only
# sees Symbols. Substituting the full set, not just the functions that happen to
# appear today, keeps this correct if the vector potential is changed.
p1t_dot_H_subs = p1t_dot_H.subs({q1t: q1, q2t: q2, p1t: p1, p2t: p2})
p1t_dot_H_subs

In [None]:
# Replace every time-dependent function with its plain symbol so lambdify only
# sees Symbols. The original copied p2t_dot_H unchanged, which only works while
# H does not depend on q2 (p2 conserved); substituting here keeps it correct
# for other choices of the vector potential.
p2t_dot_H_subs = p2t_dot_H.subs({q1t: q1, q2t: q2, p1t: p1, p2t: p2})
p2t_dot_H_subs

Create JAX functions from the equations of motion for $q$ and $p$

In [None]:
# All four right-hand sides share one argument order: parameters (e, c, b, m)
# first, then the state (q1, q2, p1, p2). ode_to_sole relies on this order.
q1t_dot_H_jax, q2t_dot_H_jax, p1t_dot_H_jax, p2t_dot_H_jax = [
    sympy.lambdify((e, c, b, m, q1, q2, p1, p2), rhs_expr, "jax")
    for rhs_expr in (q1t_dot_H_subs, q2t_dot_H_subs, p1t_dot_H_subs, p2t_dot_H_subs)
]

Create JAX functions that compute the initial $p$ from the initial $q$ and $\dot{q}$

In [None]:
# Drop the time dependence so the momentum can be lambdified. Every function is
# substituted (not just the ones currently present), and the derivative objects
# are replaced first so they are swapped as a whole before their inner
# functions q_i(t) are replaced.
p1_L = p1t_L.subs([(q1t_dot, q1_dot), (q2t_dot, q2_dot), (q1t, q1), (q2t, q2)])
p1_L

In [None]:
# Drop the time dependence so the momentum can be lambdified. Every function is
# substituted (not just the ones currently present), and the derivative objects
# are replaced first so they are swapped as a whole before their inner
# functions q_i(t) are replaced.
p2_L = p2t_L.subs([(q1t_dot, q1_dot), (q2t_dot, q2_dot), (q1t, q1), (q2t, q2)])
p2_L

In [None]:
# Initial-momentum helpers; argument order: (e, c, b, m, q1, q2, q1_dot, q2_dot).
p1_L_jax, p2_L_jax = [
    sympy.lambdify((e, c, b, m, q1, q2, q1_dot, q2_dot), momentum_expr, "jax")
    for momentum_expr in (p1_L, p2_L)
]

## Numeric ODE solution

In [None]:
def ode_to_sole(t, z, args):
    """Vector field of Hamilton's equations, in the form diffrax's ODETerm expects.

    Args:
        t: Current time (not used by the body).
        z: State tuple ``(q_1, q_2, p_1, p_2)``.
        args: Parameter tuple ``(e, c, b, m)``.

    Returns:
        Tuple ``(q_1_dot, q_2_dot, p_1_dot, p_2_dot)`` in the same layout as ``z``.
    """
    charge, light_speed, field, mass = args
    q_1, q_2, p_1, p_2 = z
    # Argument order matches the lambdified functions: parameters, then state.
    call_args = (charge, light_speed, field, mass, q_1, q_2, p_1, p_2)
    rhs_functions = (q1t_dot_H_jax, q2t_dot_H_jax, p1t_dot_H_jax, p2t_dot_H_jax)
    return tuple(rhs(*call_args) for rhs in rhs_functions)


# Model parameters, all set to 1.
e_val = c_val = b_val = m_val = 1.0
params = (e_val, c_val, b_val, m_val)

# Start at the origin with unit velocity along q_1.
q_1_0_val = q_2_0_val = 0.0
q_1_dot_0_val, q_2_dot_0_val = 1.0, 0.0

# Convert the initial velocities into canonical momenta.
p_1_0_val, p_2_0_val = (
    momentum_fn(
        e_val, c_val, b_val, m_val, q_1_0_val, q_2_0_val, q_1_dot_0_val, q_2_dot_0_val
    )
    for momentum_fn in (p1_L_jax, p2_L_jax)
)

initial_values = (q_1_0_val, q_2_0_val, p_1_0_val, p_2_0_val)
params, initial_values

In [None]:
ti, tf = 0, 10
t_span = [ti, tf]
# 101 evenly spaced output times; their spacing is used as the initial step guess.
t_eval = jnp.linspace(ti, tf, 101)
dt0 = t_eval[1].item() - t_eval[0].item()

# Explicit Runge-Kutta (Dopri5) with an adaptive PID step-size controller.
term, solver = ODETerm(ode_to_sole), Dopri5()
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
saveat = SaveAt(ts=t_eval)

sol = diffeqsolve(
    term,
    solver,
    args=params,
    y0=initial_values,
    t0=ti,
    t1=tf,
    dt0=dt0,
    stepsize_controller=stepsize_controller,
    saveat=saveat,
)

In [None]:
assert sol.ys is not None
assert sol.ts is not None
t_sol = sol.ts
# sol.ys has the same tuple layout as y0: (q_1, q_2, p_1, p_2).
q_1_sol, q_2_sol, p_1_sol, p_2_sol = sol.ys

# One panel per degree of freedom, each showing q_i and p_i against time.
fig, axs = plt.subplots(nrows=2, sharex=True)
panels = (
    (q_1_sol, "$q_1$", p_1_sol, "$p_1$"),
    (q_2_sol, r"$q_2$", p_2_sol, r"$p_2$"),
)
for ax, (q_vals, q_label, p_vals, p_label) in zip(axs, panels):
    ax.plot(t_sol, q_vals, label=q_label)
    ax.plot(t_sol, p_vals, label=p_label)
    ax.legend()
# Shared x axis: only the bottom panel (last one in the loop) gets the label.
ax.set(xlabel="t")
plt.tight_layout()

In [None]:
# Position in the (q_1, q_2) plane over time, as plain Python lists.
trajectory = pl.DataFrame(
    dict(
        t=t_sol.tolist(),
        x1=q_1_sol.tolist(),
        y1=q_2_sol.tolist(),
    )
)

fig, ax = plt.subplots()
ax.scatter(data=trajectory, x="x1", y="y1")
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.tight_layout()

In [None]:
def animate_trajectory(
    trajectory: pl.DataFrame,
    title: str,
    filename: str,
    margin: float = 1.0,
    duration: int = 300,
):
    """Render the (x1, y1) path stored in *trajectory* as an animated GIF.

    One frame is drawn per row: the path travelled so far as a light-blue line
    plus the current position as a blue marker. Axis limits are the same for
    every frame: the data range padded by *margin* on each side.

    Args:
        trajectory: Table with ``x1`` and ``y1`` columns, one row per frame.
        title: Title shown on every frame.
        filename: Output path passed to ``gif.save``.
        margin: Padding added around the data range (default 1.0, as before).
        duration: Value forwarded as ``duration`` to ``gif.save`` (default 300,
            as before).

    Raises:
        ValueError: If *trajectory* has no rows.
    """
    x1_vals = trajectory["x1"].to_list()
    y1_vals = trajectory["y1"].to_list()
    if not x1_vals:
        # min()/max() below would otherwise fail with a less helpful message.
        raise ValueError("trajectory is empty; nothing to animate")

    # Fixed limits so the view does not jump between frames.
    x_min, x_max = min(x1_vals) - margin, max(x1_vals) + margin
    y_min, y_max = min(y1_vals) - margin, max(y1_vals) + margin

    @gif.frame
    def plot_frame(i: int) -> Figure:
        """Draw frame *i*: path up to and including point *i*, and point *i* itself."""
        fig, ax = plt.subplots()

        ax.plot(x1_vals[: i + 1], y1_vals[: i + 1], color="lightblue", linewidth=2)
        ax.plot(x1_vals[i], y1_vals[i], "bo", markersize=8, label="1")

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

        ax.set_xlabel("x")
        ax.set_ylabel("y")

        ax.set_title(title)

        return fig

    frames = [plot_frame(i) for i in range(len(x1_vals))]
    gif.save(frames, filename, duration=duration)


# NOTE(review): "chared" looks like a typo for "charged"; the filename is kept
# unchanged in case the GIF is referenced elsewhere.
animate_trajectory(
    trajectory,
    "Charged particle in magnetic field",
    "trajectory-chared-particle-magnetic-field.gif",
)