<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/DistillationBlockTridiagonal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!wget -N -q https://raw.githubusercontent.com/profteachkids/chetools/main/tools/che5.ipynb -O che5.ipynb
!pip install importnb

Collecting importnb
  Downloading importnb-2023.1.7-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.9/42.9 kB[0m [31m914.4 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: importnb
Successfully installed importnb-2023.1.7


In [235]:
import jax
jax.config.update('jax_enable_x64',True)
import jax.numpy as jnp
import numpy as np
from scipy.optimize import root, root_scalar, minimize, bracket, minimize_scalar
from functools import partial
from plotly.subplots import make_subplots
np.set_printoptions(precision=4, linewidth=200)
jnp.set_printoptions(precision=4, linewidth=200)

In [3]:
from importnb import Notebook
with Notebook():
    from che5 import Props

In [276]:
@jax.jit
def bubbleT_eq(T, x, P):
    return jnp.sum(x*p.NRTL_gamma(x,T)*p.Pvap(T)/P) - 1.

bubbleT_gradT=jax.jit(jax.grad(bubbleT_eq, 0))
bubbleT_gradT2=jax.jit(jax.grad(bubbleT_gradT, 0))

def bubbleT_calc(x,P):
    Tguess = jnp.sum(x*p.Tb(P))

    res=root_scalar(lambda T: bubbleT_eq(T,x,P), x0=Tguess,
                    fprime=lambda T: bubbleT_gradT(T,x,P),
                    fprime2=lambda T: bubbleT_gradT2(T,x,P), method='halley')
    if not(res.converged):
        print(res)
    bubbleT=res.root
    return bubbleT, x*p.NRTL_gamma(x,bubbleT)*p.Pvap(bubbleT)/P



@jax.jit
def dewT_eq(vec, y, P):
    T=vec[0]
    x=vec[1:]
    return jnp.r_[jnp.sum(y*P/p.NRTL_gamma(x,T)/p.Pvap(T))-1.,
                            jnp.sum(x)-1.]

dewT_jac=jax.jit(jax.jacobian(dewT_eq,0))

def dewT_calc(y,P):

    bubbleT = bubbleT_calc(y,P)[0]

    res=minimize(lambda x: 0., x0=np.r_[bubbleT*2., y], method='SLSQP',
                 bounds=np.r_[ np.c_[bubbleT, None], np.c_[jnp.zeros_like(y), jnp.ones_like(y)]],
                 constraints=dict(type='eq', fun=lambda vec: dewT_eq(vec, y, P), jac=lambda vec: dewT_jac(vec, y,P)),
                 tol=1e-12,
                 options=dict(maxiter=1000))
    if not(res.success):
        print(res)
    return res.x[0], res.x[1:]

@jax.jit
def flash_eq(vec, q, P):
    T=vec[0]
    x=vec[1:]
    K = p.NRTL_gamma(x,T)*p.Pvap(T)/P
    return jnp.r_[z/x - (q+K*(1-q)),
                    jnp.sum(x)-1.]

flash_eq_jac=jax.jit(jax.jacobian(flash_eq))

def flash(F, q, P):

    Ftot=jnp.sum(F)
    z=F/Ftot

    bubbleT=bubbleT_calc(z,P)[0]
    dewT=dewT_calc(z,P)[0]

    res=minimize(lambda x: 0., x0=np.r_[(bubbleT+dewT)/2,z], method='SLSQP',
                 bounds=np.r_[np.c_[bubbleT,dewT], np.c_[np.zeros_like(F), np.ones_like(F)]],
                 constraints=dict(type='eq', fun=lambda vec: flash_eq(vec,q,P), jac=lambda vec: flash_eq_jac(vec,q,P)),
                 options=dict(maxiter=10000))
    if not(res.success):
        print(res)
    return res.x[0], q*Ftot*res.x[1:]

In [277]:
@partial(jax.jit, static_argnums=(6,))
def eq(vec2, vec, vec1, F, FH=0, Q=0.,kind=3):
    T1 = vec1[0] #tray below
    T2 = vec2[0] #tray above
    T = vec[0]
    v1, l1 = jnp.split(vec1[1:],2)
    v2, l2 = jnp.split(vec2[1:],2)
    v, l = jnp.split(vec[1:],2)

    vtot, ltot = jnp.sum(v), jnp.sum(l)

    x, y = l/ltot, v/vtot

    EQUIL = x*p.NRTL_gamma(x,T)*p.Pvap(T)/P - y

    if kind==1: #reboiler
        ENTHALPY = ltot-B
        MB = (l2 - l - v)
    elif kind==2: #condenser
        ENTHALPY = ltot-R*D
        MB = (v1 - l - v - x*D)

    elif kind==3:  #tray
        ENTHALPY = (FH + Q + p.Hl(l2, T2) + p.Hv(v1, T1) - p.Hl(l, T) - p.Hv(v, T))/1e5
        MB = (F + v1 + l2 - l - v)

    return jnp.r_[jnp.atleast_1d(ENTHALPY), MB, EQUIL]

eq_reboiler=partial(eq, kind=1)
eq_condenser=partial(eq, kind=2)
eq_tray=partial(eq, kind=3)
eq_condenser_jac = jax.jacobian(eq_condenser,(0,1,2))
eq_reboiler_jac = jax.jacobian(eq_reboiler,(0,1,2))
eq_tray_jac = jax.jacobian(eq_tray, (0,1,2))

In [287]:
p=Props(['Ethanol','Water'])
P=101325.
NT=10
NF=2
F = jnp.array([1., 1.])
NC=F.size
Fzeros = jnp.zeros(NC)
Ftot = jnp.sum(F)
z=F/Ftot

q = 0.999
FT, FL=flash(F,q,P)
FV = F-FL
FH = p.Hl(FL,FT)+p.Hv(FV,FT)

D = Ftot/2
B = Ftot-D
R = 2

Vtot_rec = (R+1)*D
Ltot_rec = R*D

Vtot_strip = Vtot_rec - jnp.sum(FV)
Ltot_strip = Ltot_rec + jnp.sum(FL)

vec_zeros = jnp.zeros(NC*2+1)
Vtot = np.r_[np.repeat(Vtot_strip,NF),np.repeat(Vtot_rec,NT-NF+1),0.]
Ltot = np.r_[B, np.repeat(Ltot_strip,NF), np.repeat(Ltot_rec,NT-NF+1)]

dewT, dewx=dewT_calc(z,P)
bubbleT, bubbley=bubbleT_calc(z,P)

In [288]:
TFR = np.zeros(NT+2)
TFR=np.linspace(dewT,bubbleT,NT+2)
Lguess=np.zeros((NT+2,NC))

Lguess[:NF+1] = dewx*Ltot[:NF+1,None]
Lguess[NF+1:] = bubbley*Ltot[NF+1:,None]
Vguess=np.zeros_like(Lguess)

Vguess[:NF+1]= Lguess[:NF+1]-B*dewx
Vguess[NF+1:]=Lguess[NF+1:]+D*bubbley
Vguess[0]=Vguess[1]
Vguess[-1]=1e-6*Vguess[-2]

vecguess = jnp.c_[TFR, Vguess, Lguess]


Fs = np.zeros((NT+2,2*NC+1))
Cs = np.zeros((NT+2,2*NC+1,2*NC+1))
dvec = jnp.zeros((NT+2,2*NC+1))
mask = np.ones_like(vecguess)
mask[-1,1:NC+1]=0

In [289]:
@jax.jit
def evalF(vecguess):
    eqFs = jnp.zeros((NT+2,2*NC+1))
    eqFs=eqFs.at[NT+1].set(eq_condenser(vec_zeros,vecguess[NT+1],vecguess[NT], Fzeros))
    for i in range(NT,NF,-1):
        eqFs=eqFs.at[i].set(eq_tray(vecguess[i+1], vecguess[i], vecguess[i-1], F=Fzeros))
    eqFs=eqFs.at[NF].set(eq_tray(vecguess[NF+1], vecguess[NF], vecguess[NF-1], F=F, FH=FH))
    for i in range(NF-1,0,-1):
        eqFs=eqFs.at[i].set(eq_tray(vecguess[i+1], vecguess[i], vecguess[i-1], F=Fzeros))
    eqFs=eqFs.at[0].set(eq_reboiler(vecguess[1],vecguess[0], vec_zeros, F=Fzeros))

    return eqFs

@jax.jit
def norm_evalF_t(t, vecguess, dvec):
    return jnp.linalg.norm(evalF(vecguess + t*dvec))

In [290]:

for iter in range(25):
    eqFs=evalF(vecguess)
    print(np.linalg.norm(eqFs))
    _,eqB,eqC = eq_condenser_jac(vec_zeros,vecguess[NT+1],vecguess[NT], Fzeros)

    Binv = np.linalg.inv(eqB)
    Cs[NT+1]=Binv @ eqC
    Fs[NT+1]=Binv @ eqFs[NT+1]

    for i in range(NT,NF,-1):
        eqA,eqB,eqC=eq_tray_jac(vecguess[i+1], vecguess[i], vecguess[i-1], F=Fzeros)
        bacinv = np.linalg.inv(eqB-eqA @ Cs[i+1])
        Cs[i]=bacinv @ eqC
        Fs[i]=bacinv @ (eqFs[i] - eqA @ Fs[i+1])

    eqA,eqB,eqC=eq_tray_jac(vecguess[NF+1], vecguess[NF], vecguess[NF-1], F=F, FH=FH)
    bacinv = np.linalg.inv(eqB-eqA @ Cs[NF+1])
    Cs[NF]=bacinv @ eqC
    Fs[NF]=bacinv @ (eqFs[NF] - eqA @ Fs[NF+1])

    for i in range(NF-1,0,-1):
        eqA,eqB,eqC=eq_tray_jac(vecguess[i+1], vecguess[i], vecguess[i-1], F=Fzeros)
        bacinv = np.linalg.inv(eqB-eqA @ Cs[i+1])
        Cs[i]=bacinv @ eqC
        Fs[i]=bacinv @ (eqFs[i] - eqA @ Fs[i+1])

    eqA,eqB,_ = eq_reboiler_jac(vecguess[1],vecguess[0], vec_zeros, F=Fzeros)
    bacinv = np.linalg.inv(eqB-eqA @ Cs[1])
    Fs[0]=bacinv @ (eqFs[0] - eqA @ Fs[1])

    dvec=dvec.at[0].set(-Fs[0])
    for i in range(1,NT+2):
        Fs[i] = (Fs[i]-Cs[i]@Fs[i-1])
        dvec=dvec.at[i].set(-Fs[i])

    xa,xb,xc, *_=bracket(lambda t: norm_evalF_t(t, vecguess, dvec), 0., 1.)
    t=minimize_scalar(lambda t: norm_evalF_t(t, vecguess, dvec), bracket=(xa, xb, xc)).x

    vecguess = vecguess + t*dvec




2.5552691981827556
0.7934712120706939
0.2984306048009731
0.12821064978471874
0.050723252473517326
0.021954177120762662
0.008638338356272611
0.003726858192053467
0.0014611037060126466
0.0006294297813400047
0.0002464507345635697
0.0001061148387344082
4.1532391982846275e-05
1.787988317332505e-05
6.997007000878069e-06
3.012147871428368e-06
1.1784669680038422e-06
5.073862037159879e-07
1.985256449192085e-07
8.563827486392896e-08
3.3578744319082284e-08
1.4428563512814369e-08
5.750209912834871e-09
2.1096614495207e-09
9.855844856467231e-10


In [291]:
x=vecguess[:,-NC:]/np.sum(vecguess[:,-NC:],axis=1)[:,None]
y=vecguess[:,1:NC+1]/np.sum(vecguess[:,1:NC+1],axis=1)[:,None]
print(x)
print()
print(y)

[[0.1892 0.8108]
 [0.4429 0.5571]
 [0.5198 0.4802]
 [0.5864 0.4136]
 [0.6319 0.3681]
 [0.6666 0.3334]
 [0.6951 0.3049]
 [0.72   0.28  ]
 [0.7429 0.2571]
 [0.765  0.235 ]
 [0.7873 0.2127]
 [0.8108 0.1892]]

[[0.5296 0.4704]
 [0.6316 0.3684]
 [0.6621 0.3379]
 [0.6921 0.3079]
 [0.715  0.285 ]
 [0.7339 0.2661]
 [0.7504 0.2496]
 [0.7656 0.2344]
 [0.7803 0.2197]
 [0.7952 0.2048]
 [0.8108 0.1892]
 [0.828  0.172 ]]


In [292]:
McCabeXY=np.c_[np.repeat(x[:-1,0],2),np.repeat(y[:-1,0],2)]
McCabeXY[1:,1]=McCabeXY[:-1,1]
McCabeXY[0,1]=McCabeXY[0,0]
McCabeXY=np.r_[McCabeXY, np.atleast_2d(np.repeat(McCabeXY[-1,1],2))]

In [293]:
xplot=np.linspace(0,1,50)
yplot=[]
for x in xplot:
    T,y=bubbleT_calc( np.array([x,1-x]),P)
    yplot.append(y[0])

In [294]:
fig=make_subplots()
fig.add_scatter(x=McCabeXY[:,0],y=McCabeXY[:,1],mode='lines')
fig.add_scatter(x=xplot, y=yplot, mode='lines',line_color='gray')
fig.add_scatter(x=McCabeXY[::2,0],y=McCabeXY[::2,1],mode='lines',line_color='gray')
fig.update_xaxes(range=(0,1))
fig.update_yaxes(range=(0,1))
fig.update_layout(width=600,height=600,template='plotly_dark', showlegend=False)