<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/cubic_hermite_ode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
jax.config.update('jax_enable_x64',True)
import jax.numpy as jnp
import numpy as np
from scipy.integrate import solve_ivp
from plotly.subplots import make_subplots
from scipy.optimize import root

In [None]:
def rhs(t, y):
    """Right-hand side of the system y1' = y2, y2' = -5*y1.

    ``t`` is accepted for the ``solve_ivp``/``jax.jacobian`` calling
    convention but is not used.  Returns a length-2 ``jnp`` array.
    """
    u, v = y
    du = v
    dv = -5 * u
    return jnp.array([du, dv])

In [None]:
# Reference solution: implicit Radau with the exact JAX Jacobian of rhs w.r.t. y,
# tight tolerances, and dense output so it can be evaluated at arbitrary times.
tend = 5.
solution = solve_ivp(
    rhs,
    (0, tend),
    [0., 1.],
    method='Radau',
    jac=jax.jacobian(rhs, 1),
    dense_output=True,
    atol=1e-8,
    rtol=1e-8,
)
# solve_ivp does not raise on failure; it reports it via `success`/`message`.
assert solution.success, solution.message
sol = solution.sol

In [None]:
# Plot both components of the dense Radau solution over [0, tend].
t = np.linspace(0, tend, 100)
y1, y2 = sol(t)
fig = make_subplots()
fig.add_scatter(x=t, y=y1, mode='lines', name='y1')
fig.add_scatter(x=t, y=y2, mode='lines', name='y2')
fig.update_layout(width=800, height=600, template='plotly_dark',
                  xaxis_title='t', yaxis_title='y')

In [None]:
def pmid(p1,t1,p0,t0,rhs):
    dt=(t1-t0)
    m0=rhs(t0,p0)*dt
    m1=rhs(t1,p1)*dt
    a=(2*p0 + m0 - 2*p1 + m1)
    b=(-3*p0 + 3*p1 - 2*m0 - m1)
    t=0.5
    pmid =a*t**3+ b*t**2+ m0 * t + p0
    pmid_prime = (3*a*t**2 + b*2*t + m0)/dt
    return pmid_prime - rhs(t*dt,pmid)

def q75(p1,t1,p0,t0,rhs):
    dt=(t1-t0)
    m0=rhs(t0,p0)*dt
    m1=rhs(t1,p1)*dt
    a=(2*p0 + m0 - 2*p1 + m1)
    b=(-3*p0 + 3*p1 - 2*m0 - m1)
    t=0.75
    p =a*t**3+ b*t**2+ m0 * t + p0
    p_prime = (3*a*t**2 + b*2*t + m0)/dt
    return p_prime - rhs(t*dt,p)

In [None]:
p0 = np.array([0., 1.])  # initial state [y1, y2] at t = 0, same as the solve_ivp y0
dt_step = 0.02  # size of the single collocation step
# Find the end-of-step state p1 that zeroes the midpoint residual, starting from p0.
step = root(pmid, p0, args=(dt_step, p0, 0, rhs))
# root() does not raise on non-convergence; refuse to use an unconverged step.
assert step.success, step.message
p1 = step.x



In [None]:
# Residual at s = 0.75, a point the midpoint root solve does not enforce.
q75(p1=p1, t1=0.02, p0=p0, t0=0, rhs=rhs)

Array([1.5622396e-08, 1.5619792e-06], dtype=float64)

In [None]:
# Difference between the dense Radau reference and the one-step Hermite result at t = 0.02.
y_ref = sol(0.02)
y_ref - p1

array([ 1.06510377e-10, -3.43610140e-10])