<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/DiffraxOptimalControl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install diffrax

Collecting diffrax
  Downloading diffrax-0.4.0-py3-none-any.whl (159 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 159.8/159.8 kB 5.3 MB/s eta 0:00:00
Collecting equinox>=0.10.4 (from diffrax)
  Downloading equinox-0.10.10-py3-none-any.whl (132 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 132.3/132.3 kB 15.6 MB/s eta 0:00:00
Collecting jaxtyping>=0.2.20 (from equinox>=0.10.4->diffrax)
  Downloading jaxtyping-0.2.20-py3-none-any.whl (24 kB)
Collecting typeguard>=2.13.3 (from jaxtyping>=0.2.20->equinox>=0.10.4->diffrax)
  Downloading typeguard-4.0.0-py3-none-any.whl (33 kB)
Installing collected packages: typeguard, jaxtyping, equinox, diffrax
Successfully installed diffrax-0.4.0 equinox-0.10.10 jaxtyping-0.2.20 typeguard-4.0.0


In [2]:
# NOTE(review): Dopri5 is imported but not used in the visible cells (Kvaerno5 is the solver).
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt, Kvaerno5, PIDController, LinearInterpolation
import jax.numpy as jnp
import jax
from plotly.subplots import make_subplots
from scipy.optimize import minimize
# Switch JAX to float64. This must run before any JAX arrays are created; the model
# constants and control knots in the next cell are built after it, so they are float64
# (the cost evaluated later reports dtype=float64).
jax.config.update("jax_enable_x64", True)
import numpy as np

- the fwd and bwd functions take an extra `perturbed` argument, which indicates which primals actually need a gradient. You can use this to skip computing the gradient for any unperturbed value. (You can also safely just ignore this if you wish.)
- `None` was previously passed to indicate a symbolic zero gradient for all objects that weren't inexact arrays, but all inexact arrays always had an array-valued gradient. Now, `None` may also be passed to indicate that an inexact array has a symbolic zero gradient.


In [12]:
# --- Two-tank model parameters ---
# NOTE(review): g = 9.81 suggests SI units (m, m^2, m^3/s, s) — confirm.
g = 9.81                 # gravitational acceleration
a = 0.01                 # outlet orifice area (same for both tanks)
A = 1.                   # cross-sectional area of each tank (divides net flow in dhdt)
qin_initial_ss = 0.1     # constant inflow to tank 1, also the initial steady-state flow
sq2g = jnp.sqrt(2*g)     # Torricelli factor: outflow q = a * sqrt(2*g) * sqrt(h)
# Steady-state level for inflow qin:  qin = a*sqrt(2*g*h)  =>  h = qin^2 / (a^2 * 2*g).
# Both tanks start at that level, so with zero control the system is at rest.
h10 = h20 = qin_initial_ss**2/(a**2)/(2*g)
loss0 = 0.               # initial value of the accumulated |hsp - h2| cost state
eint0 = 0.               # NOTE(review): not used in the visible cells
hsp = 1.1*h20            # tank-2 level setpoint: 10% above the initial steady state
tend = 1000.             # end of the simulation / optimization horizon

# --- Control parameterization ---
n_control = 30           # number of evenly spaced control knots over [0, tend]
tcontrol = jnp.linspace(0., tend, n_control)
ycontrol = jnp.full_like(tcontrol, 0.01)   # initial guess for the control flow at every knot

def dhdt(t, hvec, control_interp):
    """Right-hand side of the two-tank ODE, in diffrax's (t, y, args) form.

    State ``hvec = [h1, h2, cost]``:
      * h1, h2 -- levels of tank 1 and tank 2;
      * cost   -- running integral of |hsp - h2| (does not affect the dynamics).

    Tank 1 receives the constant inflow ``qin_initial_ss`` plus the control flow
    ``control_interp.evaluate(t)``, and drains into tank 2. Tank 2 drains through
    its own orifice. Both outflows use q = a * sqrt(2 g) * sqrt(h).

    Args:
        t: current time.
        hvec: current state (h1, h2, cost).
        control_interp: object exposing ``.evaluate(t)`` for the control flow
            (supplied by diffeqsolve as ``args``).

    Returns:
        jnp array with d/dt of (h1, h2, cost).
    """
    h1, h2, _ = hvec  # the accumulated cost is an output only
    outflow_1 = a * sq2g * h1**0.5
    outflow_2 = a * sq2g * h2**0.5
    inflow_1 = qin_initial_ss + control_interp.evaluate(t)
    tracking_error = hsp - h2
    return jnp.array([
        (inflow_1 - outflow_1) / A,
        (outflow_1 - outflow_2) / A,
        jnp.abs(tracking_error),
    ])


solver = Kvaerno5()          # implicit 5th-order Runge-Kutta (ESDIRK) solver
dhdt_term = ODETerm(dhdt)
# The control is piecewise linear, so the vector field has kinks at the knot times.
# jump_ts tells the step-size controller to land exactly on each knot rather than step across it.
stepsize = PIDController(
    rtol=1e-5,
    atol=1e-8,
    jump_ts=tcontrol,
)
def loss(ycontrol):
    """Accumulated tracking cost for a vector of control-knot values.

    Builds a piecewise-linear control through ``(tcontrol, ycontrol)`` and integrates
    the two-tank model from t=0 to ``tend``. Returns the final value of the third
    state: the integral of |hsp - h2| over the horizon, since it starts at ``loss0``.

    Args:
        ycontrol: control-flow values at the knot times ``tcontrol``.

    Returns:
        Scalar JAX array with the accumulated cost.
    """
    initial_state = jnp.array([h10, h20, loss0])
    solution = diffeqsolve(
        dhdt_term,
        solver,
        t0=0,
        t1=tend,
        dt0=0.1,
        y0=initial_state,
        args=LinearInterpolation(tcontrol, ycontrol),
        stepsize_controller=stepsize,
    )
    # No SaveAt is passed, so diffrax's default keeps only the state at t1: ys has one row.
    final_state = solution.ys[0]
    return final_state[-1]

# Compiled cost and its reverse-mode gradient with respect to the control knots.
grad_jit = jax.jit(jax.grad(loss))
loss_jit = jax.jit(loss)

In [13]:
# Simulate both tanks under the constant initial control guess and plot the levels
# against the tank-2 setpoint.
tplot = jnp.linspace(0, tend, 500)
control_interp = LinearInterpolation(tcontrol, ycontrol)
# Named distinctly: `sol` is later reused for the scipy OptimizeResult.
sol_initial = diffeqsolve(dhdt_term, solver, t0=0, t1=tend, dt0=0.1, y0=jnp.array([h10, h20, loss0]),
                          args=control_interp, saveat=SaveAt(ts=tplot), stepsize_controller=stepsize)
fig = make_subplots()
fig.add_scatter(x=sol_initial.ts, y=sol_initial.ys[:, 0], mode='lines', name='Tank1')
fig.add_scatter(x=sol_initial.ts, y=sol_initial.ys[:, 1], mode='lines', name='Tank2')
fig.add_hline(y=hsp, line_dash='dash', annotation_text='Tank2 setpoint')
fig.update_layout(width=600, height=600, template='plotly_dark', title='Initial control guess',
                  xaxis_title='time', yaxis_title='level')

In [14]:
initial_cost = loss_jit(ycontrol)  # accumulated |hsp - h2| under the initial control guess
initial_cost

Array(427.09262725, dtype=float64, weak_type=True)

In [15]:
initial_grad = grad_jit(ycontrol)  # sensitivity of that cost to each control knot
initial_grad

Array([ 162.07736355,  886.62802035, 1764.32246062, 2639.00767499,
       3408.5685644 , 3653.33378201, 3836.89463362, 3793.08168867,
       3790.81011852, 3783.67653356, 3771.10545712, 3752.23711393,
       3725.87891223, 3690.45061382, 3643.92343096, 3583.75578935,
       3506.83198299, 3409.41516536, 3287.13406497, 3135.0348765 ,
       2947.74795883, 2719.84614278, 2446.5117523 , 2124.68881335,
       1754.98387927, 1344.70560699,  912.61702855,  496.24007144,
        162.32340236,   11.52215803], dtype=float64)

In [18]:
# Optimize the control knots with SLSQP, keeping each knot's flow within [0, 0.03].
qcontrol_bounds = (0., 0.03)
bounds = [qcontrol_bounds] * len(ycontrol)   # one (min, max) pair per knot
# value_and_grad shares one forward ODE solve between the cost and its gradient.
# With separate fun/jac callables, every optimizer iterate paid for an extra forward solve.
value_and_grad_jit = jax.jit(jax.value_and_grad(loss))
sol = minimize(value_and_grad_jit, ycontrol, jac=True, bounds=bounds, method='SLSQP')
# SLSQP can stop without converging (iteration limit, line-search failure); report it
# instead of silently plotting whatever point it stopped at.
print(f"SLSQP: {sol.message} | success={sol.success}, iterations={sol.nit}, final cost={float(sol.fun):.6g}")

# Re-simulate with the optimized control and plot the levels next to the control signal.
tplot = jnp.linspace(0, tend, 500)
control_interp = LinearInterpolation(tcontrol, sol.x)
diffsol = diffeqsolve(dhdt_term, solver, t0=0, t1=tend, dt0=0.1, y0=jnp.array([h10, h20, loss0]),
                      args=control_interp, saveat=SaveAt(ts=tplot), stepsize_controller=stepsize)
fig = make_subplots(rows=1, cols=2, subplot_titles=('Tank levels', 'Control flow'))
fig.add_scatter(x=diffsol.ts, y=diffsol.ys[:, 0], mode='lines', row=1, col=1, name='Tank1')
fig.add_scatter(x=diffsol.ts, y=diffsol.ys[:, 1], mode='lines', row=1, col=1, name='Tank2')
fig.add_hline(y=hsp, line_dash='dash', row=1, col=1, annotation_text='Tank2 setpoint')
fig.add_scatter(x=diffsol.ts, y=control_interp.evaluate(diffsol.ts), mode='lines', row=1, col=2, name='ControlInterp')
fig.add_scatter(x=tcontrol, y=sol.x, mode='markers', row=1, col=2, name='ControlPoint')
fig.update_xaxes(title_text='time')
fig.update_yaxes(title_text='level', row=1, col=1)
fig.update_yaxes(title_text='control flow', row=1, col=2)
fig.update_layout(width=1200, height=600, template='plotly_dark')