<a href="https://colab.research.google.com/github/profteachkids/CHE4071_Fall2023/blob/main/adjoint_optim_LevelConcTemp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [361]:
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64',True)
from plotly.subplots import make_subplots
from scipy.optimize import root, minimize
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss

In [479]:
s1T=90.  #C
s2T=50.
s3T=70.

s1C=2. #mol/L
s2C=1.
s3C=0.

rho=1 #kg/L
Cp=4190 #J/kg

hCT0 = jnp.array([1., 1.2, 75.])
hCTsp = jnp.array([1.1, 1.3, 72])

scale = jnp.abs(hCTsp-hCT0)
A = 1.  #m2
a = 0.05  #m2
tend = 50

In [480]:
sq2g = (2*9.81)**0.5
def rhs(t, hCT, qs):
    h,C,T = hCT
    q1, q2, q3 = qs
    qout = sq2g*(h**0.5)*a
    dh=(q1+q2+q3 - qout)/A
    dC = (q1*s1C + q2*s2C + q3*s3C - qout*C - C*A*dh)/(h*A)          # dCV = V dC + C dV = Ah dC + CA dh = q1*c1 + q2*c2 + q3*c3 - qout*C
    # d(rho Cp V T) = rho*Cp*(q1*s1T + q2*s2T + q3*s3T) - rho*Cp*qout*T
    # V dT + T dV = (q1*s1T + q2*s2T + q3*s3T) - qout*T
    dT = ((q1*s1T + q2*s2T + q3*s3T) - qout*T - T*A*dh)/(h*A)
    return jnp.array([dh, dC, dT])

rhs_jac=jax.jacobian(rhs,1)

In [481]:
qs0=root(lambda qs: rhs(0., hCT0, qs), jnp.ones(3)).x

In [482]:
qs = root(lambda qs: rhs(0., hCTsp, qs), jnp.ones(3)).x

In [483]:
res=solve_ivp(rhs, (0,tend), hCT0, args=(qs,), dense_output=True, method='LSODA', jac=jax.jacobian(rhs,1))

In [484]:
tplot=jnp.linspace(0,tend,250)
h,C,T = res.sol(tplot)
fig=make_subplots(rows=1,cols=3)
fig.add_scatter(x=tplot,y=h,mode='lines',row=1,col=1,name='h')
fig.add_scatter(x=tplot,y=C,mode='lines',row=1,col=2,name='C')
fig.add_scatter(x=tplot,y=T,mode='lines',row=1,col=3,name='T')
fig.update_layout(width=1200, height=400, template='plotly_dark')

In [485]:
Np=15
qt=jnp.linspace(0,tend,Np)
qst0=jnp.tile(qs.reshape(-1,1),(1,Np))
interp_vec = jax.vmap(jnp.interp, (None, None, 0))

def rhs_control(t, hCT, qst_flat):
    h,C,T = hCT
    q1, q2, q3 = interp_vec(t, qt, jnp.reshape(qst_flat, (3,Np)))
    qout = sq2g*(h**0.5)*a
    dh=(q1+q2+q3 - qout)/A
    dC = (q1*s1C + q2*s2C + q3*s3C - qout*C - C*A*dh)/(h*A)
    dT = ((q1*s1T + q2*s2T + q3*s3T) - qout*T - T*A*dh)/(h*A)
    return jnp.array([dh, dC, dT])

rhs_control_jac=jax.jit(jax.jacobian(rhs_control,1))

In [486]:
res_control=solve_ivp(rhs_control, (0,tend), hCT0, args=(qst0.ravel(),), dense_output=True, method='LSODA', jac=rhs_control_jac)

In [487]:
tplot=jnp.linspace(0,tend,250)
h,C,T = res_control.sol(tplot)
fig=make_subplots(rows=1,cols=3)
fig.add_scatter(x=tplot,y=h,mode='lines',row=1,col=1,name='h')
fig.add_scatter(x=tplot,y=C,mode='lines',row=1,col=2,name='C')
fig.add_scatter(x=tplot,y=T,mode='lines',row=1,col=3,name='T')
fig.update_layout(width=1200, height=400, template='plotly_dark')

In [514]:
#x: hCT
#p: qst_flat
gt, gw = leggauss(100)
gt=tend/2*gt + tend/2


@jax.jit
def jax_abs(x):
    return jnp.abs(x)
    # return jnp.sqrt(x**2+1e-10)

def f(t, x, p):
    # return jnp.sum(jax_abs(x-hCTsp[:,None])/scale[:,None])
    return jnp.sum(jax_abs(x[2]-hCTsp[2])/scale[2] + jax_abs(x[0]-hCTsp[0])/scale[0] + jax_abs(x[1]-hCTsp[1])/scale[1])

adj_f = jax.jit(f)   #Nt
adj_h = jax.jit(lambda t, x,xp, p: xp-rhs_control(t,x,p))

In [515]:
adj_h_gradp = jax.jit(jax.jacobian(adj_h,3))   #Nx, Nt, Np
adj_h_gradx = jax.jit(jax.jacobian(adj_h, 1))
adj_h_gradxp = 1.
adj_f_gradx = jax.jit(jax.jacobian(f,1))
adj_f_gradp = jax.jit(jax.jacobian(f,2))   #Nt, Np

In [516]:
rhs_control=jax.jit(rhs_control)

def obj(p):
    hCT=solve_ivp(rhs_control, (0,tend), hCT0, method='LSODA', dense_output=True, jac=rhs_control_jac, args=(p,)).sol
    interror=tend/2*jnp.sum(gw*adj_f(gt,hCT(gt),p))
    print(interror)

    def adj_ode(t, L, p):
        x=hCT(t)
        xp = rhs_control(t,x, p)
        return (adj_f_gradx(t, x, p) + adj_h_gradx(t, x, xp, p).T @ L)

    adj_ode_jac=jax.jacobian(adj_ode,1)

    L0=jnp.zeros(3)
    Lsol=solve_ivp(adj_ode, (tend,0), L0, method='LSODA', dense_output=True,args=(p,)).sol

    def dpF_integrand(t, p):
        x=hCT(t)
        xp = rhs_control(t,x, p)
        L = Lsol(t)
        return adj_f_gradp(t, x,p)+jnp.einsum('xt, xtp -> tp', L, adj_h_gradp(t,x,xp,p))

    dpF=tend/2*jnp.einsum('t, tp -> p', gw,dpF_integrand(gt, p))

    return interror, dpF

In [517]:
res=minimize(obj,qst0.ravel(), jac=True, bounds=[(0,1.)]*qst0.size,
         options=dict(maxiter=100))
res

3018.8870283118404
2174831.8051018
3311.7948483783316
3016.662533836306
3014.470963285189
3091.3871271944904
3012.402079148776
3012.398253363167
3014.76782729938
2948.755342958281
2961.9424490695474
2925.436881406047
2821.402146272961
2783.85839681079
2707.367025708992
2650.333101003856
2622.8181208798605
2594.5878865531777
2529.5417087344986
2496.9470784854943
2486.7764069504965
2445.473671973552
2427.866847722516
2491.902781384383
2425.251067491812
2425.247746141753
2393.7627755699928
2369.5348695090765
2326.7532921360967
2292.397954539043
2328.671328494207
2292.252023633652
2292.251774785249
2297.84551199548
2292.0215373408573
2292.021151479317
2290.6102468207723
2288.8885943761593
2267.204409469437
2248.41290777693
2237.764266738108
2234.342294318673
2217.971416509486
2209.835140651946
2195.866796624297
2189.7070303462533
2177.0496142503794
2195.86993103177
2176.4904759944316
2176.4896869919435
2170.82255009469
2146.551926069436
2146.520301305598
2134.1384184074905
2267.32131283911

  message: ABNORMAL_TERMINATION_IN_LNSRCH
  success: False
   status: 2
      fun: 1920.507080527915
        x: [ 1.191e-01  1.152e-01 ...  3.899e-02  3.858e-02]
      nit: 75
      jac: [-2.732e+01  3.072e+01 ...  3.352e+01 -7.615e-01]
     nfev: 293
     njev: 293
 hess_inv: <45x45 LbfgsInvHessProduct with dtype=float64>

In [518]:
res_optim=solve_ivp(rhs_control, (0,tend), hCT0, args=(res.x,), dense_output=True, method='LSODA', jac=rhs_control_jac)
tplot=jnp.linspace(0,tend,250)
h,C,T = res_optim.sol(tplot)
fig=make_subplots(rows=1,cols=4)
fig.add_scatter(x=tplot,y=h,mode='lines',row=1,col=1,name='h')
fig.add_scatter(x=tplot,y=C,mode='lines',row=1,col=2,name='C')
fig.add_scatter(x=tplot,y=T,mode='lines',row=1,col=3,name='T')
fig.add_scatter(x=tplot,y=interp_vec(tplot,qt,res.x.reshape(3,Np))[0,:], row=1, col=4, name='Q1')
fig.add_scatter(x=tplot,y=interp_vec(tplot,qt,res.x.reshape(3,Np))[1,:], row=1, col=4, name='Q2')
fig.add_scatter(x=tplot,y=interp_vec(tplot,qt,res.x.reshape(3,Np))[2,:], row=1, col=4, name='Q3')

fig.update_layout(width=1200, height=400, template='plotly_dark')