In [1]:
import jax
from jax import numpy as jnp
import numpy as np
from scipy import constants
from jax.experimental.ode import odeint
from matplotlib import pyplot as plt

In [2]:

class Lagrangian(object):
    """
    Lagrangian class for defining and solving the equations of motion in a Lagrangian mechanics system.

    Attributes:
        potentials (list[tuple[callable, dict]]): A list of potential energy functions and their parameters.

    Methods:
        __init__(self, potentials: list[tuple[callable, dict]] = []) -> None:
            Constructor for the Lagrangian class.

        __call__(self, q: jnp.ndarray, q_t: jnp.ndarray, mass) -> float:
            Calculates the Lagrangian of the system given the generalized coordinates, their velocities, and the mass.

        eom(self, q, q_t, mass):
            Solves the equations of motion for the system.

        integrate_eom(self, q0, q_t0, mass, time_span):
            Integrates the equations of motion over a specified time span.

    """

    def __init__(self, potentials: list[tuple[callable, dict]] = []) -> None:
        """
        Constructor for the Lagrangian class.

        Parameters:
            potentials (list[tuple[callable, dict]]): A list of potential energy functions and their parameters.
        """
        self.potentials = potentials

    def __call__(self, q: jnp.ndarray, q_t: jnp.ndarray, mass) -> float:
        """
        Calculates the Lagrangian of the system given the generalized coordinates, their velocities and the mass.

        Parameters:
            q (jnp.ndarray): Generalized coordinates.
            q_t (jnp.ndarray): Generalized velocities.
            mass: Mass of the system.

        Returns:
            float: The Lagrangian of the system.
        """
        @jax.jit
        def kinetic_energy(q_t: jnp.ndarray, mass: jnp.ndarray):
            return 0.5 * mass * jnp.linalg.norm(q_t)
        
        T = jnp.sum(jax.jit(kinetic_energy)(q_t, mass))
        V = jnp.sum(jnp.array([pot_fn(q, q_t, mass, **pot_params) for pot_fn, pot_params in self.potentials]))
        return T - V    

    def eom(self, q: jnp.ndarray, q_t: jnp.ndarray, mass) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Solves the equations of motion for the system.

        Parameters:
            q (jnp.ndarray): Generalized coordinates.
            q_t (jnp.ndarray): Generalized velocities.
            mass: Mass of the system.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing updated velocities and accelerations.
        """
        dLdq = jax.jacfwd(self.__call__, 0)(q, q_t, mass)
        dLdq_t_dq = jax.jacfwd(jax.jacfwd(self.__call__, 1), 0)(q, q_t, mass)
        H = jax.hessian(self.__call__, 1)(q, q_t, mass)

        # print("dLdq", dLdq.shape, dLdq) # correct 
        # print("dLdq_t_dq", dLdq_t_dq.shape, dLdq_t_dq) # correct
        # print("H", H.shape, H) # correct

        dot1 = jnp.tensordot(dLdq_t_dq, q_t, axes=((2, 3), (0, 1)))
        # print("dot1", dot1.shape, dot1)

        q_tt = jnp.tensordot(jnp.linalg.pinv(H), dLdq - dot1, axes=((2, 3), (1, 0)))
        # print("q_tt", q_tt)

        # dot1 = jnp.tensordot(dLdq_t_dq, q_t, axes=((2, 3), (0, 1)))
        # q_tt = jnp.tensordot(jnp.linalg.pinv(H), dLdq - dot1, axes=((2, 3), (1, 0)))

        return q_t, q_tt

    def eom_int(self, q0: jnp.ndarray, q_t0: jnp.ndarray, mass, t_span) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Integrates the equations of motion over a specified time span.

        Parameters:
            q0 (jnp.ndarray): Initial generalized coordinates.
            q_t0 (jnp.ndarray): Initial velocities corresponding to the generalized coordinates.
            mass: Mass of the system.
            t_span: Time span for integration.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing integrated generalized coordinates and velocities.
        """
        # Flatten the initial conditions for odeint
        initial_conditions_flat = jnp.concatenate([q0, q_t0])

        # Define a function for odeint
        def dynamics(y, t):
            q, q_t = jnp.split(y, 2)
            q_t, q_tt = self.eom(q, q_t, mass)
            result = jnp.concatenate([q_t, q_tt])  # Flatten the result
            return result

        # Use odeint to integrate the equations of motion
        result = odeint(dynamics, initial_conditions_flat, t_span)

        # Reshape the result to get q and q_t separately
        q, q_t = jnp.split(result, 2, axis=1)

        return q, q_t


    def draw_trajectory(self, q0: jnp.ndarray, q_t0: jnp.ndarray, mass, t_span) -> None:
        """
        Draws the trajectories of the system without constraints.

        Parameters:
            q0 [jnp.ndarray]: List of initial generalized coordinates for each trajectory.
            q_t0 [jnp.ndarray]: List of initial velocities for each trajectory.
            mass (List): List of masses for each trajectory.
            t_span: Time span for integration.

        Returns:
            None
        """

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.set_title("Trajectories of the motion")
        ax.set_xlabel("X-axis")
        ax.set_ylabel("Y-axis")
        ax.set_zlabel("Z-axis")

        positions, _ = self.eom_int(q0, q_t0, mass, t_span)

        for body in range(len(positions[0])):  # Iterate over n_bodies
            ax.scatter(positions[:, body, 0], positions[:, body, 1], positions[:, body, 2])

        plt.show()

In [3]:
@jax.jit
def gravity(x: jnp.ndarray, x_t: jnp.ndarray, mass, g) -> jnp.ndarray:
    return g * mass * x[:, 2]

In [4]:
# Example usage
nbodies = 2
ndim = 3

# setting the ic
m1 = 1.0
m2 = 2.0
x0 = 3 * jnp.ones(3)
x1 = jnp.array([2.0, -1.0, 5.0])
x_t0 = jnp.ones(3)
x_t1 = jnp.array([0.0, 0.0, 1.0])

q = jnp.array([x0, x1])
q_t = jnp.array([x_t0, x_t1])
mass = jnp.array([m1, m2])

print("q", q)
print("q_t", q_t)
print("mass", mass)

q [[ 3.  3.  3.]
 [ 2. -1.  5.]]
q_t [[1. 1. 1.]
 [0. 0. 1.]]
mass [1. 2.]


In [5]:
# Create the lagrangian
L = Lagrangian(potentials=[(gravity, {'g': 9.81})])
print(L(q, q_t, mass))

-124.53001


In [6]:
print(L.eom(q, q_t, mass))

(Array([[1., 1., 1.],
       [0., 0., 1.]], dtype=float32), Array([[104.64   , 104.63999, 104.63998],
       [  0.     ,   0.     , -34.88   ]], dtype=float32))


In [None]:
L.draw_trajectory(q, q_t, mass, jnp.linspace(0., 2., 100))