<a href="https://colab.research.google.com/github/profteachkids/CHE4071_Fall2023/blob/main/adjoint_optim_LevelConcTemp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64',True)
from plotly.subplots import make_subplots
from scipy.optimize import root, minimize
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss

In [167]:
s1T=90.  #C
s2T=50.
s3T=70.

s1C=2. #mol/L
s2C=1.
s3C=0.

rho=1 #kg/L
Cp=4190 #J/kg

hCT0 = jnp.array([1., 1.2, 75.])
hCTsp = jnp.array([1.1, 1.3, 72])

scale = jnp.abs(hCTsp-hCT0)
A = 1.  #m2
a = 0.05  #m2
tend = 50.

In [168]:
sq2g = (2*9.81)**0.5
def rhs(t, hCT, qs):
    h,C,T = hCT
    q1, q2, q3 = qs
    qout = sq2g*(h**0.5)*a
    dh=(q1+q2+q3 - qout)/A
    dC = (q1*s1C + q2*s2C + q3*s3C - qout*C - C*A*dh)/(h*A)          # dCV = V dC + C dV = Ah dC + CA dh = q1*c1 + q2*c2 + q3*c3 - qout*C
    # d(rho Cp V T) = rho*Cp*(q1*s1T + q2*s2T + q3*s3T) - rho*Cp*qout*T
    # V dT + T dV = (q1*s1T + q2*s2T + q3*s3T) - qout*T
    dT = ((q1*s1T + q2*s2T + q3*s3T) - qout*T - T*A*dh)/(h*A)
    return jnp.array([dh, dC, dT])

rhs_jac=jax.jacobian(rhs,1)

In [169]:
qs0=root(lambda qs: rhs(0., hCT0, qs), jnp.ones(3)).x

In [170]:
qs = root(lambda qs: rhs(0., hCTsp, qs), jnp.ones(3)).x

In [171]:
res=solve_ivp(rhs, (0,tend), hCT0, args=(qs,), dense_output=True, method='LSODA', jac=jax.jacobian(rhs,1))

In [172]:
tplot=jnp.linspace(0,tend,250)
h,C,T = res.sol(tplot)
fig=make_subplots(rows=1,cols=3)
fig.add_scatter(x=tplot,y=h,mode='lines',row=1,col=1,name='h')
fig.add_scatter(x=tplot,y=C,mode='lines',row=1,col=2,name='C')
fig.add_scatter(x=tplot,y=T,mode='lines',row=1,col=3,name='T')
fig.update_layout(width=1200, height=400, template='plotly_dark')

In [264]:
Np=20
qt=jnp.linspace(0,tend,Np)
qst0=jnp.tile(qs.reshape(-1,1),(1,Np))
interp_vec = jax.vmap(jnp.interp, (None, None, 0))

def rhs_control(t, hCT, qst_flat):
    h,C,T = hCT
    q1, q2, q3 = interp_vec(t, qt, jnp.reshape(qst_flat, (3,Np)))
    qout = sq2g*(h**0.5)*a
    dh=(q1+q2+q3 - qout)/A
    dC = (q1*s1C + q2*s2C + q3*s3C - qout*C - C*A*dh)/(h*A)
    dT = ((q1*s1T + q2*s2T + q3*s3T) - qout*T - T*A*dh)/(h*A)
    return jnp.array([dh, dC, dT])

rhs_control_jac=jax.jit(jax.jacobian(rhs_control,1))

In [265]:
res_control=solve_ivp(rhs_control, (0,tend), hCT0, args=(qst0.ravel(),), dense_output=True, method='LSODA', jac=rhs_control_jac)

In [266]:
tplot=jnp.linspace(0,tend,250)
h,C,T = res_control.sol(tplot)
fig=make_subplots(rows=1,cols=3)
fig.add_scatter(x=tplot,y=h,mode='lines',row=1,col=1,name='h')
fig.add_scatter(x=tplot,y=C,mode='lines',row=1,col=2,name='C')
fig.add_scatter(x=tplot,y=T,mode='lines',row=1,col=3,name='T')
fig.update_layout(width=1200, height=400, template='plotly_dark')

In [267]:
#x: hCT
#p: qst_flat
gt, gw = leggauss(100)
gt=tend/2*gt + tend/2


@jax.jit
def jax_abs(x):
    return jnp.sqrt(x**2+1e-10)


def f(t, x, p):
    return jnp.sum((jax_abs(x[2]-hCTsp[2])/scale[2] + jax_abs(x[0]-hCTsp[0])/scale[0] + jax_abs(x[1]-hCTsp[1])/scale[1]))

adj_f = jax.jit(f)   #Nt
adj_h = jax.jit(lambda t, x,xp, p: xp-rhs_control(t,x,p))

In [268]:
adj_h_gradp = jax.jit(jax.jacobian(adj_h,3))   #Nx, Nt, Np
adj_h_gradx = jax.jit(jax.jacobian(adj_h, 1))
adj_h_gradxp = 1.
adj_f_gradx = jax.jit(jax.jacobian(f,1))
adj_f_gradp = jax.jit(jax.jacobian(f,2))   #Nt, Np

In [269]:
rhs_control=jax.jit(rhs_control)

def obj(p):
    hCT=solve_ivp(rhs_control, (0,tend), hCT0, method='LSODA', dense_output=True, jac=rhs_control_jac, args=(p,)).sol
    interror=tend/2*jnp.sum(gw*adj_f(gt,hCT(gt),p))
    print(interror)

    def adj_ode(t, L, p):
        x=hCT(t)
        xp = rhs_control(t,x, p)
        return (adj_f_gradx(t, x, p) + adj_h_gradx(t, x, xp, p).T @ L)

    adj_ode_jac=jax.jacobian(adj_ode,1)

    L0=jnp.zeros(3)
    Lsol=solve_ivp(adj_ode, (tend,0), L0, method='LSODA', dense_output=True,args=(p,)).sol

    def dpF_integrand(t, p):
        x=hCT(t)
        xp = rhs_control(t,x, p)
        L = Lsol(t)
        return adj_f_gradp(t, x,p)+jnp.einsum('xt, xtp -> tp', L, adj_h_gradp(t,x,xp,p))

    dpF=tend/2*jnp.einsum('t, tp -> p', gw,dpF_integrand(gt, p))

    return interror, dpF

In [270]:
res=minimize(obj,qst0.ravel(), jac=True, bounds=[(0,1.)]*qst0.size,
         options=dict(maxiter=100))
res

3018.9371936225584
2184020.172437113
3325.0110298876457
3016.760319456243
3016.756335981612
3098.3327578878225
3014.7591527414947
3014.755484792393
3018.080006013553
2952.70667459307
2987.5980293197626
2938.447098481636
2831.058429986414
2773.3632879561783
2706.917006087696
2641.6388412389592
2576.4876394114262
2541.564369511458
2588.254926793483
2541.321421298933
2541.3201137948563
2517.5309177202316
2505.3072924963053
2435.448235260682
2393.6080644985987
2359.1940867407816
2339.11221849579
2335.018166546574
2282.900423303109
2257.4160412016777
2232.674888426542
2226.492262725045
2205.6456100938653
2323.0117209882715
2204.965854173981
2204.964751500525
2228.448279162386
2204.129155238096
2204.1278039719637
2197.331187064939
2182.984629365886
2190.7003948633537
2182.6573114827293
2182.6566846887304
2175.4928214506394
2172.5541359127947
2176.0713806634426
2171.082463975324
2171.0797948233203
2165.6145368500615
2157.045779961604
2140.4790623391937
2136.4673506860345
2111.4912431053212
21

  message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
  success: True
   status: 0
      fun: 1839.756611251546
        x: [ 1.204e-01  1.196e-01 ...  3.884e-02  3.850e-02]
      nit: 63
      jac: [-2.188e+01  3.373e+00 ... -9.043e+01 -2.317e+01]
     nfev: 139
     njev: 139
 hess_inv: <60x60 LbfgsInvHessProduct with dtype=float64>

In [272]:
res_optim=solve_ivp(rhs_control, (0,tend), hCT0, args=(res.x,), dense_output=True, method='LSODA', jac=rhs_control_jac)
tplot=jnp.linspace(0,tend,250)
h,C,T = res_optim.sol(tplot)
fig=make_subplots(rows=1,cols=4)
fig.add_scatter(x=tplot,y=h,mode='lines',row=1,col=1,name='h')
fig.add_scatter(x=tplot,y=C,mode='lines',row=1,col=2,name='C')
fig.add_scatter(x=tplot,y=T,mode='lines',row=1,col=3,name='T')
fig.add_scatter(x=tplot,y=interp_vec(tplot,qt,res.x.reshape(3,Np))[0,:], row=1, col=4, name='Q1')
fig.add_scatter(x=tplot,y=interp_vec(tplot,qt,res.x.reshape(3,Np))[1,:], row=1, col=4, name='Q2')
fig.add_scatter(x=tplot,y=interp_vec(tplot,qt,res.x.reshape(3,Np))[2,:], row=1, col=4, name='Q3')

fig.update_layout(width=1200, height=400, template='plotly_dark')