# Des ODE avec JAX ?

In [None]:
%matplotlib ipympl
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import jit, vmap
from dataclasses import dataclass
from diffrax import ODETerm, Dopri5, SaveAt, PIDController, diffeqsolve

## Problème

On veut intégrer l'ODE suivante:


$$
\left \lbrace 
\begin{split}
M_1  \ddot x_1 + D_1  \dot x_1 + K_1  x_1 + K_{12} (x_1 - x_2) = F_{d1} \sin (w_d t)  \\
M_2  \ddot x_2 + D_2  \dot x_2 + K_2  x_2 + K_{12} (x_2 - x_1) = F_{d2} \sin (w_d t) 
\end{split}
\right .
$$

:::{note}
Il est possible de simplifier grandement ce système en l'adimensionnant, mais on choisit de ne pas le faire ici.
:::

On la traduit comme suit en Python:

1. On crée une classe pour stocker les paramètres du problème:

In [None]:
@dataclass
class CoupleLinearResonatorParams:
    """
    Parameters for the coupled linear resonator ODE system.

    Two masses M1 and M2 each feel their own restoring spring (force -K1*x1,
    resp. -K2*x2) and their own damping (-D1*v1, resp. -D2*v2). They are linked
    by a single coupling spring K12, which acts on the relative displacement
    (x1 - x2). Mass i is driven by the force Fdi * sin(wd * t). The driving
    frequency wd is shared by both forces.

    NOTE(review): a plain (non-frozen) dataclass is unhashable. This matters if an
    instance is passed where JAX/equinox expect a hashable static argument (e.g. as
    diffrax ``args``) -- confirm before relying on that.
    """

    M1: float = 1.0  # Mass 1
    M2: float = 1.0  # Mass 2
    K1: float = 1.0  # Stiffness of mass 1's own spring (restoring force -K1 * x1)
    K2: float = 1.0  # Stiffness of mass 2's own spring (restoring force -K2 * x2)
    K12: float = 0.5  # Coupling spring constant between mass 1 and mass 2
    D1: float = 0.2  # Damping coefficient of mass 1
    D2: float = 0.2  # Damping coefficient of mass 2
    wd: float = 0.5  # Driving (angular) frequency, shared by both forces
    Fd1: float = 0.5  # Driving force amplitude on mass 1
    Fd2: float = 0.5  # Driving force amplitude on mass 2


# In this example only mass 1 is driven.
ode_params = CoupleLinearResonatorParams(Fd2=0.0)
ode_params

In [None]:
def coupled_linear_resonator_ode(t, X, params: CoupleLinearResonatorParams):
    """Right-hand side of the coupled linear resonator, as a first-order system.

    Args:
        t: time at which the derivative is evaluated.
        X: state vector ``[x1, v1, x2, v2]`` (positions and velocities of both masses).
        params: masses, stiffnesses, damping coefficients and forcing parameters.

    Returns:
        ``jnp.array([dx1/dt, dv1/dt, dx2/dt, dv2/dt])``.
    """
    x1, v1, x2, v2 = X
    p = params

    # Both masses are driven at the same frequency, so the sine is shared.
    drive = jnp.sin(p.wd * t)

    # Each acceleration = (own spring + coupling spring + damping + forcing) / mass.
    # The term order matches the written equations, so results are bit-for-bit identical.
    accel_1 = (-p.K1 * x1 - p.K12 * (x1 - x2) - p.D1 * v1 + p.Fd1 * drive) / p.M1
    accel_2 = (-p.K2 * x2 - p.K12 * (x2 - x1) - p.D2 * v2 + p.Fd2 * drive) / p.M2

    # d(position)/dt is simply the velocity component of the state.
    return jnp.array([v1, accel_1, v2, accel_2])


ode = coupled_linear_resonator_ode
X = jnp.array([0.0, 0.0, 0.0, 0.0])  # Sample state [x1, v1, x2, v2]: both masses at rest
t = 0.2  # Arbitrary evaluation time for this sanity check (not an integration start time)
dXdt = ode(t, X, ode_params)

# With both masses at rest only the forcing terms remain:
# dX/dt = [0, Fd1 sin(wd t) / M1, 0, Fd2 sin(wd t) / M2].
sin_wdt = np.sin(ode_params.wd * t)
np.testing.assert_allclose(
    dXdt,
    [
        0.0,
        ode_params.Fd1 * sin_wdt / ode_params.M1,
        0.0,
        ode_params.Fd2 * sin_wdt / ode_params.M2,
    ],
    rtol=1e-5,  # JAX computes in float32 by default; the reference is float64
    atol=1e-7,
)
dXdt

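OK, ça marche : le second membre s'évalue sans erreur et renvoie le vecteur $[\dot x_1, \dot v_1, \dot x_2, \dot v_2]$ attendu. On passe maintenant à l'intégration avec `diffrax`.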

In [None]:
# Integrate the coupled resonator with diffrax's adaptive Dopri5 solver.
# diffrax calls the vector field as f(t, y, args). The parameters are captured by closure
# instead of being passed through `args`.
# NOTE(review): diffeqsolve is JIT-compiled with equinox, which treats non-array leaves of
# `args` as static and must hash them. A plain (non-frozen) @dataclass instance is
# unhashable -- confirm before switching to `args=ode_params`.
term = ODETerm(lambda t, X, args: ode(t, X, ode_params))
solver = Dopri5()
saveat = SaveAt(ts=[0.0, 1.0, 2.0, 3.0])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

X0 = jnp.zeros(4)  # Initial state [x1, v1, x2, v2]: both masses at rest
sol = diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=3.0,  # must cover the last requested save time
    dt0=0.1,  # initial step; the PID controller adapts it afterwards
    y0=X0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

# One row per requested save time, one column per state component [x1, v1, x2, v2].
assert sol.ys.shape == (4, 4)
sol.ys