# Particle in a rotationally symmetric potential

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from sympy import dsolve, symbols, Function, Eq
import sympy
import polars as pl
import torch
import numpy as np
import gif
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import math
from scipy.integrate import solve_ivp
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController
import jax.numpy as jnp
from jaxtyping import Array

## Lagrangian

$$ L = T - V = \frac{1}{2} m \left( \dot{x}^2 + \dot{y}^2 \right) - V(r) $$

$$ \frac{\partial L}{\partial x} = \frac{d}{dt} \frac{\partial L}{\partial \dot{x}} $$

$$ \frac{\partial L}{\partial y} = \frac{d}{dt} \frac{\partial L}{\partial \dot{y}} $$

$$ -\frac{dV}{dr} \frac{dr}{dx} = m \ddot{x} $$

$$ -\frac{dV}{dr} \frac{dr}{dy} = m \ddot{y} $$

$$ r = \sqrt{x^2 + y^2} $$

In [None]:
# Cartesian coordinates as plain SymPy symbols.
x, y = symbols("x y")

# Differentiate r = sqrt(x**2 + y**2) with respect to x to obtain dr/dx.
sympy.sqrt(x**2 + y**2).diff(x)

$$ \frac{dr}{dx} = \frac{x}{\sqrt{x^2 + y^2}} $$

This gets messy in Cartesian coordinates, so it is better to switch from $L(x,y,t)$ to polar coordinates, $L(r,\phi,t)$.

In [None]:
# Derivative of sin(x) with respect to x, evaluated symbolically.
sympy.sin(x).diff(x)  # type: ignore

$$ x = r \cos \phi $$

$$ y = r \sin \phi $$

$$ \dot{x} = \dot{r} \cos \phi - r \dot{\phi} \sin \phi $$

$$ \dot{y} = \dot{r} \sin \phi + r \dot{\phi} \cos \phi $$

$$ L = T - V = \frac{1}{2} m \left( \dot{x}^2 + \dot{y}^2 \right) - V(r) $$

$$ L = T - V = \frac{1}{2} m \left( \left( \dot{r} \cos \phi - r \dot{\phi} \sin \phi \right)^2 + \left( \dot{r} \sin \phi + r \dot{\phi} \cos \phi \right)^2 \right) - V(r) $$

In [None]:
# Time symbol plus r(t) and phi(t) as undefined SymPy functions of time.
t = symbols("t")
q_r, phi = symbols("r phi", cls=Function)

# Simplify x_dot**2 + y_dot**2 with x = r cos(phi) and y = r sin(phi):
# the first squared term is x_dot, the second is y_dot.
sympy.simplify(
    (q_r(t).diff(t) * sympy.cos(phi(t)) - q_r(t) * phi(t).diff(t) * sympy.sin(phi(t)))
    ** 2
    + (q_r(t).diff(t) * sympy.sin(phi(t)) + q_r(t) * phi(t).diff(t) * sympy.cos(phi(t)))
    ** 2
)

$$ L = T - V = \frac{1}{2} m \left( r^2 \dot\phi^2 + \dot{r}^2 \right) - V(r) $$

$$ \frac{\partial L}{\partial r} = \frac{d}{dt} \frac{\partial L}{\partial \dot{r}} $$

$$ \frac{\partial L}{\partial \phi} = \frac{d}{dt} \frac{\partial L}{\partial \dot{\phi}} $$

$$ mr\dot{\phi}^2 - \frac{dV}{dr} = m\ddot{r} $$

$$ 0 = d_t (mr^2\dot{\phi}) $$

$$ mr^2\dot{\phi} = l \rightarrow \dot{\phi} = \frac{l}{mr^2} $$

where $l$ is the angular momentum

so 

$$ mr\dot{\phi}^2 - \frac{dV}{dr} = m\ddot{r} $$

becomes

$$ mr \left( \frac{l}{mr^2} \right)^2 - \frac{dV}{dr} = m\ddot{r} $$

$$ \frac{l^2}{mr^3} - \frac{dV}{dr} = m\ddot{r} $$

In [None]:
# Symbols for the angular momentum `l` and the mass `m`.
l, m = symbols("l m")
# Undefined SymPy function named "dVdr"; it is only created here, not applied.
dVdr = symbols("dVdr", cls=Function)

In [None]:
# Radial equation m * r'' - l**2 / (m * r**3) = 0.
# Note that no dV/dr term is included in this expression.
diffeq = Eq(m * q_r(t).diff(t, t) - l**2 / m / q_r(t) ** 3, 0)
diffeq

In [None]:
# Ask SymPy to solve the radial ODE symbolically and display the result.
res = dsolve(diffeq)
res

That symbolic result is not very useful, so let's rewrite the radial equations of motion and integrate them numerically instead.

$$ mr\dot{\phi}^2 - \frac{dV}{dr} = m\ddot{r} $$

$$ 0 = d_t (mr^2\dot{\phi}) \rightarrow \dot{\phi} = \frac{L_z}{mr^2} $$

$$ \ddot{r} = \frac{L_z^2}{m^2r^3} - \frac{1}{m} \frac{dV}{dr} $$

https://medium.com/@benjaminkhelyer/a-zoo-of-differential-equation-solvers-in-python-mostly-odes-bcb071a33450

In [None]:
def get_Lz(m: float, r: float, phi_dot: float) -> float:
    """Return the angular momentum L_z = m * r**2 * phi_dot.

    Args:
        m: Mass of the particle.
        r: Radial distance.
        phi_dot: Angular velocity d(phi)/dt.

    Returns:
        The product m * r**2 * phi_dot.
    """
    moment_of_inertia = m * r**2
    return moment_of_inertia * phi_dot


def get_dVdr_quadratic(r: float) -> float:
    """Return dV/dr = 2 * r, the radial derivative of the potential V(r) = r**2.

    Args:
        r: Radial distance.

    Returns:
        Twice the radial distance.
    """
    slope = 2 * r
    return slope


def particle_central_potential_ode(
    t: float, z: tuple[float, float, float], m: float, Lz: float
) -> tuple[float, float, float]:
    """Right-hand side of the first-order system for the state (r, dr/dt, phi).

    Implements
        d2r/dt2  = Lz**2 / (m**2 * r**3) - (1 / m) * dV/dr
        dphi/dt  = Lz / (m * r**2)
    with dV/dr taken from `get_dVdr_quadratic`.

    Args:
        t: Current time (not used by the equations).
        z: State as (r, r_dot, phi).
        m: Mass of the particle.
        Lz: Angular momentum, treated as a fixed parameter.

    Returns:
        The time derivatives (r_dot, r_ddot, phi_dot).
    """
    radius, radial_velocity, _angle = z
    potential_slope = get_dVdr_quadratic(radius)

    # Centrifugal term minus the force from the potential.
    centrifugal = Lz**2 / m**2 / radius**3
    radial_acceleration = centrifugal - 1 / m * potential_slope

    angular_velocity = Lz / m / radius**2
    return radial_velocity, radial_acceleration, angular_velocity

In [None]:
# Integration window [ti, tf] and the 101 output times at which the solution is sampled.
ti, tf = 0, 10
t_span = [ti, tf]
t_eval = torch.linspace(ti, tf, 101)

# Initial state: radius, radial velocity, angle and angular velocity.
r0 = 1.0
r_dot0 = 2.0
phi0 = 0.0
phi_dot0 = 1.0

# The angular momentum is computed once from the initial state and passed to
# the ODE as a fixed parameter alongside the mass.
m = 1.0
Lz = get_Lz(m, r0, phi_dot0)
params = (m, Lz)
sol = solve_ivp(
    particle_central_potential_ode,
    t_span,
    [r0, r_dot0, phi0],
    args=params,
    t_eval=t_eval,
)

In [None]:
# Rows of sol.y follow the order of the initial state: r, dr/dt, phi.
t_sol = sol.t
r_sol = sol.y[0, :]
r_dot_sol = sol.y[1, :]
phi_sol = sol.y[2, :]

# Top panel: radius and radial velocity; bottom panel: angle. Both share the time axis.
fig, axs = plt.subplots(nrows=2, sharex=True)
ax = axs[0]
ax.plot(t_sol, r_sol, label="r")
ax.plot(t_sol, r_dot_sol, label="dr/dt")
ax.legend()

ax = axs[1]
ax.plot(t_sol, phi_sol, label="phi")
ax.set(xlabel="t", ylabel=r"$\phi$")

plt.tight_layout()

$$ x = r \cos \phi $$

$$ y = r \sin \phi $$

In [None]:
# Convert the polar solution to Cartesian coordinates: x = r cos(phi), y = r sin(phi).
trajectory = pl.DataFrame(
    {
        "t": t_sol,
        "x": r_sol * np.cos(phi_sol),
        "y": r_sol * np.sin(phi_sol),
    }
)

# Scatter plot of the sampled positions in the x-y plane.
fig, ax = plt.subplots()
ax.scatter(data=trajectory, x="x", y="y")
ax.set(xlabel="x", ylabel="y")
plt.tight_layout()

In [None]:
def animate_trajectory(trajectory: pl.DataFrame, title: str, filename: str):
    """Render the (x, y) path stored in `trajectory` as an animated GIF.

    One frame is drawn per row: the path travelled up to and including that
    row as a light-blue line, plus a blue dot at the current position.

    Args:
        trajectory: Data frame with columns "x" and "y".
        title: Title placed on every frame.
        filename: Output path handed to `gif.save`.
    """
    xs = trajectory["x"].to_list()
    ys = trajectory["y"].to_list()

    # Axis limits are fixed for all frames: the data range padded by 1 on each side.
    x_limits = (min(xs) - 1, max(xs) + 1)
    y_limits = (min(ys) - 1, max(ys) + 1)

    @gif.frame
    def draw_step(step: int) -> Figure:
        figure, axes = plt.subplots()

        upto = step + 1
        axes.plot(xs[:upto], ys[:upto], color="lightblue", linewidth=2)
        axes.plot(xs[step], ys[step], "bo", markersize=8)

        axes.set_xlim(*x_limits)
        axes.set_ylim(*y_limits)

        axes.set_xlabel("x")
        axes.set_ylabel("y")

        axes.set_title(title)

        return figure

    rendered = [draw_step(step) for step in range(len(xs))]

    gif.save(rendered, filename, duration=300)


# Write the Lagrangian trajectory animation to a GIF file.
animate_trajectory(
    trajectory,
    "Particle in central potential - Lagrangian",
    "trajectory-central-potential-lagrangian.gif",
)

## Hamiltonian

$$ L = T - V = \frac{1}{2} m \left( r^2 \dot\phi^2 + \dot{r}^2 \right) - V(r) $$

$$ L = \frac{1}{2} m \left( q_r^2 \dot{q}_\phi^2 + \dot{q}_r^2 \right) - V(q_r) $$

using 

$$ p = \partial_{\dot{q}} L $$

we get

$$ p_r = \partial_{\dot{q}_r} L =  m \dot{q}_r $$

$$ p_\phi = \partial_{\dot{q}_\phi} L = m q_r^2 \dot{q}_\phi $$

so

$$ \dot{q}_r = \frac{p_r}{m} $$

and 

$$ \dot{q}_\phi = \frac{p_\phi}{m q_r^2} $$

$$ H = \sum_i p_i \dot{q}_i - L $$

$$ H = p_r \dot{q}_r + p_\phi \dot{q}_\phi - \frac{1}{2} m \left( q_r^2 \dot{q}_\phi^2 + \dot{q}_r^2 \right) + V(q_r) $$

inserting the canonical momenta for the velocities found above

$$ H = \frac{p_r^2}{m} + \frac{p_\phi^2}{m q_r^2} - \frac{p_\phi^2}{2 m q_r^2} - \frac{p_r^2}{2m} + V(q_r) $$

$$ H = \frac{p_r^2}{2 m} + \frac{p_\phi^2}{2 m q_r^2} + V(q_r) $$

$$ \dot{q}_i = \partial_{p_i} H $$

$$ -\dot{p}_i = \partial_{q_i} H $$

$\dot{q}$ relationships

$$ \dot{q}_r = \partial_{p_r} H = \frac{p_r}{m} $$

$$ \dot{q}_\phi = \partial_{p_\phi} H = \frac{p_\phi}{m q_r^2} $$

$\dot{p}$ relationships (derivatives of $H$ with respect to $q$)

$$ -\dot{p}_r = \partial_{q_r} H =  - \frac{p^2_\phi}{m q_r^3} + \partial_{q_r} V(q_r) $$

$$ -\dot{p}_\phi = \partial_{q_\phi} H = 0 $$

In [None]:
def get_dVdr_quadratic(r: float) -> float:
    """Return dV/dr = 2 * r, the radial derivative of the potential V(r) = r**2.

    Args:
        r: Radial distance.

    Returns:
        Twice the radial distance.
    """
    slope = 2 * r
    return slope


def particle_central_potential_ode_hamiltonian(
    t: float, z: tuple[float, float, float, float], m: float
) -> tuple[float, float, float, float]:
    """Hamilton's equations for the state (q_r, q_phi, p_r, p_phi).

    Implements
        dq_r/dt   = p_r / m
        dq_phi/dt = p_phi / (m * q_r**2)
        dp_r/dt   = p_phi**2 / (m * q_r**3) - dV/dq_r
        dp_phi/dt = 0
    with dV/dq_r taken from `get_dVdr_quadratic`.

    Args:
        t: Current time (not used by the equations).
        z: State as (q_r, q_phi, p_r, p_phi).
        m: Mass of the particle.

    Returns:
        The time derivatives in the same order as the state.
    """
    radius, _angle, radial_momentum, angular_momentum = z
    potential_slope = get_dVdr_quadratic(radius)

    return (
        radial_momentum / m,
        angular_momentum / m / radius**2,
        angular_momentum**2 / m / radius**3 - potential_slope,
        0,  # p_phi is conserved: H does not depend on q_phi
    )

In [None]:
# Integration window [ti, tf] and the 1001 output times at which the solution is sampled.
ti, tf = 0, 100
t_span = [ti, tf]
t_eval = torch.linspace(ti, tf, 1001)

# The mass is assigned before its first use so this cell does not depend on
# an `m` left behind by an earlier cell.
m = 1.0

# Initial generalized coordinates.
q_r_0 = 1.0
q_phi_0 = 0.0

# Initial generalized velocities.
q_r_dot_0 = 1.0
q_phi_dot_0 = 0.1

# Canonical momenta from the velocities:
#   p_r = m * dq_r/dt  and  p_phi = m * q_r**2 * dq_phi/dt.
# p_r must be built from the radial *velocity*, not the radial position.
p_r_0 = m * q_r_dot_0
p_phi_0 = m * q_r_0**2 * q_phi_dot_0

params = (m,)
initial_values = (q_r_0, q_phi_0, p_r_0, p_phi_0)
sol_ham = solve_ivp(
    particle_central_potential_ode_hamiltonian,
    t_span,
    initial_values,
    args=params,
    t_eval=t_eval,
)

In [None]:
# %%timeit
# sol_ham = solve_ivp(particle_central_potential_ode_hamiltonian, t_span, initial_values, args=params, t_eval=t_eval)

In [None]:
# Rows of sol_ham.y follow the order of the initial values: q_r, q_phi, p_r, p_phi.
t_sol = sol_ham.t
q_r_sol = sol_ham.y[0, :]
q_phi_sol = sol_ham.y[1, :]
p_r_sol = sol_ham.y[2, :]
p_phi_sol = sol_ham.y[3, :]

# Top panel: radial coordinate and momentum; bottom panel: angle. Both share the time axis.
fig, axs = plt.subplots(nrows=2, sharex=True)
ax = axs[0]
ax.plot(t_sol, q_r_sol, label="$q_r$")
ax.plot(t_sol, p_r_sol, label="$p_r$")
ax.legend()

ax = axs[1]
ax.plot(t_sol, q_phi_sol, label=r"$q_\phi$")
ax.set(xlabel="t", ylabel=r"$\phi$")

plt.tight_layout()

In [None]:
# Convert the polar solution to Cartesian coordinates: x = q_r cos(q_phi), y = q_r sin(q_phi).
trajectory = pl.DataFrame(
    {
        "t": t_sol,
        "x": q_r_sol * np.cos(q_phi_sol),
        "y": q_r_sol * np.sin(q_phi_sol),
    }
)

# Scatter plot of the sampled positions, drawn with small markers (s=0.2).
fig, ax = plt.subplots()
ax.scatter(data=trajectory, x="x", y="y", s=0.2)
ax.set(xlabel="x", ylabel="y")
plt.tight_layout()

In [None]:
# animate_trajectory(trajectory, 'Particle in central potential - Hamiltonian', "trajectory-central-potential-hamiltonian.gif")

same as above but with `jax` / `diffrax`

In [None]:
def particle_central_potential_ode_hamiltonian_diffrax(t, z, args):
    """Hamilton's equations for (q_r, q_phi, p_r, p_phi) in the (t, y, args) call form.

    Same equations as the SciPy variant, but the mass arrives packed in
    `args` as a one-element sequence.

    Args:
        t: Current time (not used by the equations).
        z: State as (q_r, q_phi, p_r, p_phi).
        args: One-element sequence holding the mass `m`.

    Returns:
        The time derivatives in the same order as the state.
    """
    (m,) = args
    radius, _angle, radial_momentum, angular_momentum = z
    potential_slope = get_dVdr_quadratic(radius)

    return (
        radial_momentum / m,
        angular_momentum / m / radius**2,
        angular_momentum**2 / m / radius**3 - potential_slope,
        0,  # p_phi is conserved: H does not depend on q_phi
    )


# Wrap the Hamiltonian right-hand side as a diffrax ODE term and pick the solver.
term = ODETerm(particle_central_potential_ode_hamiltonian_diffrax)
solver = Dopri5()
# Save the solution at the times in `t_eval`, converted from a torch tensor to a plain list.
saveat = SaveAt(ts=t_eval.detach().numpy().tolist())
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
# Initial step size: the spacing between the first two entries of `t_eval`.
dt0 = t_eval[1].item() - t_eval[0].item()
# Reuses `ti`, `tf`, `initial_values` and `m` from the earlier SciPy cell.
sol_jax = diffeqsolve(
    term,
    solver,
    t0=ti,
    t1=tf,
    dt0=dt0,
    y0=initial_values,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    args=(m,),
)

In [None]:
# %%timeit
# sol_jax = diffeqsolve(term,
#                   solver,
#                   t0=ti,
#                   t1=tf,
#                   dt0=dt0,
#                   y0=initial_values,
#                   saveat=saveat,
#                   stepsize_controller=stepsize_controller,
#                   args=(m,))

In [None]:
# Display the saved time points of the diffrax solution.
sol_jax.ts

In [None]:
# Display the saved state values of the diffrax solution.
sol_jax.ys

In [None]:
# Narrow the optional `ys` attribute before indexing into it.
assert sol_jax.ys is not None

# `ys` is indexed per state component in the order of `initial_values`
# (q_r, q_phi, p_r, p_phi) -- presumably it mirrors the tuple structure of y0; confirm.
t_sol = sol_jax.ts
q_r_sol = sol_jax.ys[0]
q_phi_sol = sol_jax.ys[1]
p_r_sol = sol_jax.ys[2]
p_phi_sol = sol_jax.ys[3]

# Top panel: radial coordinate and momentum; bottom panel: angle. Both share the time axis.
fig, axs = plt.subplots(nrows=2, sharex=True)
ax = axs[0]
ax.plot(t_sol, q_r_sol, label="$q_r$")
ax.plot(t_sol, p_r_sol, label="$p_r$")
ax.legend()

ax = axs[1]
ax.plot(t_sol, q_phi_sol, label=r"$q_\phi$")
ax.set(xlabel="t", ylabel=r"$\phi$")

plt.tight_layout()

In [None]:
# Convert the polar solution to Cartesian coordinates; each column is turned
# into a plain Python list before it is handed to polars.
trajectory = pl.DataFrame(
    {
        "t": t_sol.tolist(),
        "x": (q_r_sol * np.cos(q_phi_sol)).tolist(),
        "y": (q_r_sol * np.sin(q_phi_sol)).tolist(),
    }
)

# Scatter plot of the sampled positions, drawn with small markers (s=0.2).
fig, ax = plt.subplots()
ax.scatter(data=trajectory, x="x", y="y", s=0.2)
ax.set(xlabel="x", ylabel="y")
plt.tight_layout()

## With gravity

$V(q_r) \rightarrow V(q_r,q_\phi) $ 

where the height of the particle is $ h = q_r \sin(q_\phi) $, so that

$$ V(q_r,q_\phi) = q_r^2 + g m q_r \sin(q_\phi) $$

$$ H = \frac{p_r^2}{2 m} + \frac{p_\phi^2}{2 m q_r^2} + V(q_r, q_\phi) $$

The definitions of $p_r$ and $p_\phi$ and the equations for $\dot{q}_i$ are unchanged,

but the equations for $\dot{p}_i$ now pick up the gravity terms:

$$ \dot{p}_r = - \partial_{q_r} H =  \frac{p^2_\phi}{m q_r^3} - \partial_{q_r} V(q_r, q_\phi) $$

with 

$$ \partial_{q_r} V(q_r, q_\phi) = 2 q_r + g m \sin(q_\phi) $$

and 

$$ \dot{p}_\phi = - \partial_{q_\phi} H = - \partial_{q_\phi} V(q_r, q_\phi) $$

with

$$ \partial_{q_\phi} V(q_r, q_\phi) = g m q_r \cos(q_\phi) $$

In [None]:
def get_dVdqr_quadratic_with_gravity(
    q_r: float, q_phi: float, m: float, g: float
) -> Array:
    """Return dV/dq_r = 2 * q_r + g * m * sin(q_phi).

    This is the radial derivative of V = q_r**2 + g * m * q_r * sin(q_phi).

    Args:
        q_r: Radial coordinate.
        q_phi: Angular coordinate.
        m: Mass of the particle.
        g: Gravitational acceleration.

    Returns:
        The radial derivative of the potential.
    """
    quadratic_part = 2 * q_r
    gravity_part = g * m * jnp.sin(q_phi)
    return quadratic_part + gravity_part


def get_dVdphi_with_gravity(q_r: float, q_phi: float, m: float, g: float) -> Array:
    """Return dV/dq_phi = g * m * q_r * cos(q_phi).

    This is the angular derivative of V = q_r**2 + g * m * q_r * sin(q_phi).

    Args:
        q_r: Radial coordinate.
        q_phi: Angular coordinate.
        m: Mass of the particle.
        g: Gravitational acceleration.

    Returns:
        The angular derivative of the potential.
    """
    prefactor = g * m * q_r
    return prefactor * jnp.cos(q_phi)


def particle_central_potential_plus_gravity_ode_hamiltonian_diffrax(t, z, args):
    """Hamilton's equations for (q_r, q_phi, p_r, p_phi) with the gravity potential.

    Implements
        dq_r/dt   = p_r / m
        dq_phi/dt = p_phi / (m * q_r**2)
        dp_r/dt   = p_phi**2 / (m * q_r**3) - dV/dq_r
        dp_phi/dt = -dV/dq_phi
    with the potential derivatives taken from
    `get_dVdqr_quadratic_with_gravity` and `get_dVdphi_with_gravity`.

    Args:
        t: Current time (not used by the equations).
        z: State as (q_r, q_phi, p_r, p_phi).
        args: Pair (m, g) of mass and gravitational acceleration.

    Returns:
        The time derivatives in the same order as the state.
    """
    m, g = args
    radius, angle, radial_momentum, angular_momentum = z

    radial_slope = get_dVdqr_quadratic_with_gravity(radius, angle, m, g)
    angular_slope = get_dVdphi_with_gravity(radius, angle, m, g)

    return (
        radial_momentum / m,
        angular_momentum / m / radius**2,
        angular_momentum**2 / m / radius**3 - radial_slope,
        -angular_slope,  # gravity breaks rotational symmetry, so p_phi now changes
    )

In [None]:
# Gravitational acceleration used in the potential.
g = 10.0

# Integration window [ti, tf] and the 401 output times at which the solution is sampled.
ti, tf = 0, 10
t_span = [ti, tf]
t_eval = torch.linspace(ti, tf, 401)

# Initial generalized coordinates; the angle is 2/4 of a full turn of 2*pi, i.e. pi.
q_r_0 = 2.0
q_phi_0 = 2 / 4 * 2 * math.pi

# Initial generalized velocities.
q_r_dot_0 = 1.0
q_phi_dot_0 = 2.0

# Canonical momenta: p_r = m * dq_r/dt and p_phi = m * q_r**2 * dq_phi/dt.
# NOTE: `m` is not assigned in this cell; it is read from an earlier cell.
p_r_0 = q_r_dot_0 * m
p_phi_0 = m * q_r_0**2 * q_phi_dot_0

initial_values = (q_r_0, q_phi_0, p_r_0, p_phi_0)

In [None]:
# Cartesian coordinates of the starting point: x = q_r cos(q_phi), y = q_r sin(q_phi).
x_0, y_0 = q_r_0 * math.cos(q_phi_0), q_r_0 * math.sin(q_phi_0)

In [None]:
# Wrap the gravity-augmented Hamiltonian right-hand side as a diffrax ODE term.
term = ODETerm(particle_central_potential_plus_gravity_ode_hamiltonian_diffrax)
solver = Dopri5()
# Save the solution at the times in `t_eval`, converted from a torch tensor to a plain list.
saveat = SaveAt(ts=t_eval.detach().numpy().tolist())
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
# Initial step size: the spacing between the first two entries of `t_eval`.
dt0 = t_eval[1].item() - t_eval[0].item()
# The mass and the gravitational acceleration are passed to the vector field via `args`.
sol_jax = diffeqsolve(
    term,
    solver,
    t0=ti,
    t1=tf,
    dt0=dt0,
    y0=initial_values,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    args=(m, g),
)

In [None]:
# Narrow the optional `ys` attribute before indexing into it.
assert sol_jax.ys is not None

# `ys` is indexed per state component in the order of `initial_values`
# (q_r, q_phi, p_r, p_phi) -- presumably it mirrors the tuple structure of y0; confirm.
t_sol = sol_jax.ts
q_r_sol = sol_jax.ys[0]
q_phi_sol = sol_jax.ys[1]
p_r_sol = sol_jax.ys[2]
p_phi_sol = sol_jax.ys[3]

# Top panel: radial coordinate and momentum; bottom panel: angular coordinate
# and momentum. Both panels share the time axis.
fig, axs = plt.subplots(nrows=2, sharex=True)
ax = axs[0]
ax.plot(t_sol, q_r_sol, label="$q$")
ax.plot(t_sol, p_r_sol, label="$p$")
ax.legend()
ax.set(ylabel=r"$r$")

ax = axs[1]
ax.plot(t_sol, q_phi_sol, label=r"$q$")
ax.plot(t_sol, p_phi_sol, label=r"$p$")
ax.legend(title="variable")
ax.set(xlabel="t", ylabel=r"$\phi$")

plt.tight_layout()

In [None]:
# Narrow the optional time array before calling `.tolist()` on it.
assert t_sol is not None

# Convert the polar solution to Cartesian coordinates; each column is turned
# into a plain Python list before it is handed to polars.
trajectory = pl.DataFrame(
    {
        "t": t_sol.tolist(),
        "x": (q_r_sol * jnp.cos(q_phi_sol)).tolist(),
        "y": (q_r_sol * jnp.sin(q_phi_sol)).tolist(),
    }
)

# Axis limits: the data range padded by 1 on each side.
x_min = trajectory["x"].min() - 1
x_max = trajectory["x"].max() + 1

y_min = trajectory["y"].min() - 1
y_max = trajectory["y"].max() + 1

# Trajectory as small dots, with the starting point marked by an "x".
fig, ax = plt.subplots()
ax.scatter(data=trajectory, x="x", y="y", s=0.2)
ax.scatter(x=[x_0], y=[y_0], marker="x")
ax.set(xlabel="x", ylabel="y", xlim=(x_min, x_max), ylim=(y_min, y_max))
plt.tight_layout()

In [None]:
# Number of rows in the trajectory; the animation draws one frame per row.
len(trajectory)

In [None]:
# Write the gravity trajectory animation to a GIF file.
animate_trajectory(
    trajectory,
    "Particle in central potential with gravity - Hamiltonian",
    "trajectory-central-potential-plus-gravity-hamiltonian.gif",
)