<a href="https://colab.research.google.com/github/profteachkids/CHE2064_Spring2022/blob/main/FlattenWrapJax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [46]:
# Download the course's `che` helper module (provides che.Props, used below)
# into the working directory; -N re-downloads only when the remote copy is newer.
# NOTE(review): this tracks the `main` branch, so results can change whenever
# che.py changes upstream — consider pinning a specific commit for reproducibility.
!wget -N -q https://raw.githubusercontent.com/profteachkids/chetools/main/tools/che.py

In [47]:
import che
from collections import namedtuple
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import expit,logit
from scipy.optimize import root
from functools import partial

In [48]:
class dot_dict(dict):
    """A dict whose items can also be read, written and deleted as attributes.

    ``d.key`` is equivalent to ``d['key']``. A missing key raises
    ``AttributeError`` on attribute access (and ``KeyError`` on item access),
    so ``hasattr``, ``getattr(d, name, default)`` and ``copy.deepcopy`` work
    as they do for ordinary objects.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Translating the
        # KeyError is required by Python's attribute protocol: hasattr/getattr
        # defaults and copy's probing for __deepcopy__ only catch AttributeError.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

In [49]:
class Unk:
    """Marker base class for unknowns that dtox packs into one flat vector.

    dtox detects unknowns with ``isinstance(v, Unk)`` and expects each one to
    provide ``size`` (number of flat slots), ``flatten()`` and
    ``unflatten(xx)``.
    """

class Range(Unk):
    """Scalar unknown constrained to the open interval (lo, hi).

    The solver works on an unbounded variable: ``flatten`` maps the value
    through ``logit((x - lo) / (hi - lo))`` and ``unflatten`` maps it back with
    ``expit``, so every real iterate corresponds to a value strictly inside
    (lo, hi).

    Args:
        value: initial guess; must lie strictly between ``lo`` and ``hi``.
        lo, hi: interval bounds, with ``lo < hi``.

    Raises:
        ValueError: if ``value`` is not strictly inside (lo, hi). A value on a
            bound would flatten to +/-inf (and ``lo == hi`` to NaN), which
            would otherwise propagate silently into the solver.
    """

    def __init__(self, value, lo, hi):
        v = np.asarray(value)
        if not np.all((v > lo) & (v < hi)):
            raise ValueError(
                f'Range value {value!r} must lie strictly between '
                f'lo={lo!r} and hi={hi!r}')
        self.lo = lo
        self.hi = hi
        self.diff = hi - lo
        self.x = value
        self.shape = (1,)  # unflatten always yields a length-1 array
        self.size = 1      # slots occupied in the flat solver vector

    def flatten(self):
        """Return the unbounded (logit-space) representation, shape (1,)."""
        return jnp.ravel(logit((self.x - self.lo) / self.diff))

    def unflatten(self, xx):
        """Map a flat slice ``xx`` back to a value in (lo, hi), shape (1,)."""
        return expit(xx.reshape(self.shape)) * self.diff + self.lo


class RangeArray(Range):
    """Array of unknowns, each constrained elementwise to its (lo, hi) interval.

    ``value``, ``lo`` and ``hi`` are broadcast against each other. Scalar
    bounds can therefore be shared by every element, and a scalar ``value``
    can serve as a common initial guess for array bounds. ``flatten`` /
    ``unflatten`` are inherited from Range and operate elementwise.

    Raises:
        ValueError: if any element of ``value`` is not strictly inside its
            (lo, hi) interval (it would flatten to +/-inf or NaN).
    """

    def __init__(self, value, lo, hi):
        self.lo = np.atleast_1d(lo)
        self.hi = np.atleast_1d(hi)
        self.diff = self.hi - self.lo
        value = np.atleast_1d(value)
        # flatten() broadcasts value against lo/hi, so shape/size must describe
        # the broadcast result; otherwise dtox reserves too few slots.
        shape = np.broadcast(value, self.lo, self.hi).shape
        if value.shape != shape:
            value = np.broadcast_to(value, shape).copy()
        if not np.all((value > self.lo) & (value < self.hi)):
            raise ValueError(
                f'RangeArray values {value!r} must lie strictly between '
                f'lo={self.lo!r} and hi={self.hi!r}')
        self.x = value
        self.shape = value.shape
        self.size = value.size

class Comp(Unk):
    """Composition (fraction vector) unknown whose entries stay positive and sum to one.

    ``n`` fractions are represented by ``n - 1`` unbounded log-ratios
    ``log(x_i / x_last)``. ``unflatten`` inverts this with a softmax over
    ``[xx, 0]``, so every solver iterate maps back to a valid composition.

    Args:
        value: initial fractions, flattened to 1-D. Entries must be positive.
            They need not sum to one, because the flatten/unflatten round trip
            normalizes them.
    """

    def __init__(self, value):
        self.x = np.asarray(value).reshape(-1)
        self.size = self.x.size - 1  # last fraction is implied by the others

    def __repr__(self):
        return f'{self.x}'

    def flatten(self):
        """Return the n-1 log-ratios ``log(x_i / x_last)``."""
        return jnp.log(self.x[:-1]) - jnp.log(self.x[-1])

    def unflatten(self, xx):
        """Map n-1 log-ratios back to n fractions summing to one.

        Softmax over ``[xx, 0]`` equals ``exp(xx) / (1 + sum(exp(xx)))`` for the
        first n-1 entries and ``1 / (1 + sum(exp(xx)))`` for the last. It
        subtracts the maximum first, so large log-ratios do not overflow to
        inf/inf = NaN. The last fraction is computed directly rather than as
        ``1 - sum(...)``, which cancels catastrophically when it is tiny.
        """
        return jax.nn.softmax(jnp.append(xx, 0.0))



In [50]:
Unk_Tuple = namedtuple('Unk_Tuple', ['keys', 'start', 'end', 'unk'])

def dtox(d):
    """Split a parameter dict into constants and one flat vector of unknowns.

    Every value of ``d`` that is an ``Unk`` (Range, RangeArray, Comp, ...) is
    flattened into its own slice of a 1-D vector ``x``; every other value is a
    constant. The same ``Unk`` object stored under several keys shares a
    single slice and is unflattened back into all of those keys.

    Args:
        d: mapping of names to constants and ``Unk`` instances. Not modified.

    Returns:
        (wrap, x, d2, xtod):
          wrap(f): turns ``f(dict) -> residuals`` into ``g(x) -> residuals``,
              e.g. for jax.jacobian / scipy.optimize.root. ``g`` builds its
              dict from ``d2``, so later edits to d2's constants are seen.
          x: initial flat vector (jax array) built from each ``flatten()``.
          d2: dot_dict of the constants plus the unknowns unflattened at ``x``.
          xtod(x, d2=None): returns a NEW dot_dict. It copies ``d2`` (default:
              the returned d2) and overwrites every unknown's key(s) with the
              value unflattened from ``x``. The dict passed in is never
              mutated, so JAX tracers cannot leak into shared dicts.
    """
    consts = dot_dict()  # non-Unk entries of d, in d's order
    unks = {}            # id(unk) -> Unk_Tuple(keys, start, end, unk)
    size = 0

    for k, v in d.items():
        if not isinstance(v, Unk):
            consts[k] = v
            continue
        if id(v) in unks:
            # Same Unk object under another key: share its slice of x.
            unks[id(v)].keys.append(k)
            continue
        unks[id(v)] = Unk_Tuple([k], size, size + v.size, v)
        size += v.size

    x = np.zeros(size)
    for entry in unks.values():
        x[entry.start:entry.end] = entry.unk.flatten()

    def xtod(x, d2=None):
        out = dot_dict(base if d2 is None else d2)
        for entry in unks.values():
            value = entry.unk.unflatten(x[entry.start:entry.end])
            for key in entry.keys:
                out[key] = value
        return out

    # Constants plus unknowns at the initial guess; returned to the caller as d2
    # and used as wrap's default dict.
    base = xtod(x, consts)

    def wrap(f):
        def wrapped(x, d2):
            # xtod copies d2, so repeated (or traced) calls never write into it.
            return f(xtod(x, d2))

        return partial(wrapped, d2=base)

    return wrap, jnp.asarray(x), base, xtod

In [51]:
# Property package for the three components. eqs() below calls its Hl, Hv,
# NRTL_gamma and Pvap methods.
# NOTE(review): the component order here presumably fixes the order of every
# composition vector (Fz, Vy, Lx) — confirm in che.Props.
p = che.Props(['Ethanol','Isopropanol', 'Water'])

In [183]:
# Static parameters (total feed, feed mole fractions, feed temperature and flash
# pressure), followed by the unknowns wrapped in Unk subclasses (Comp, Range).
# dtox flattens the unknowns in insertion order (Vy, Lx, flashT, Vtot, Ltot),
# which fixes the layout of x and sol.x below.
d=dot_dict()
d.Ftot=10 # Total feed moles
d.Fz = np.array([1/3, 1/3, 1/3]) # Equimolar feed composition
d.FT = 450. # Feed temperature (presumably K — confirm against che's units)

d.flashP= 101325 # Flash drum pressure (1 atm if the unit is Pa — confirm)

d.Vy = Comp(d.Fz) # Guess vapor/liquid composition equal to feed
d.Lx = Comp(d.Fz) # Comp keeps the fractions positive and summing to one
d.flashT = Range(400., 300., d.FT)  # Flash temperature: guess 400, bounded to (300, FT)
d.Vtot = Range(3.97163878, 0., d.Ftot)  # Total vapor moles in (0, Ftot); guess is ~40% of feed (not half)
d.Ltot = Range(6.02836122, 0., d.Ftot)  # Total liquid moles in (0, Ftot); guess chosen so Vtot + Ltot = Ftot

In [184]:
# Flatten the unknowns in d. x is the initial flat vector; d2 holds the
# constants plus the unknowns evaluated at x; wrap turns f(dict) into f(x);
# xtod maps a flat vector back to a dict.
wrap, x, d2, xtod  = dtox(d)

In [185]:
def eqs(c):
    """Residual vector of the flash-drum model for the parameter dict ``c``.

    Rows, in order:
      * component mole balances   F - V - L              (one per component)
      * energy balance            H_feed - H_vap - H_liq (no heat-duty term)
      * VLE: Raoult's law with NRTL activity-coefficient correction,
        liquid-side x*gamma*Psat minus vapor-side y*P    (one per component)

    Every row is zero at the solution.
    """
    # Component moles in each stream: mole fractions * total moles.
    vap = c.Vy * c.Vtot
    liq = c.Lx * c.Ltot
    feed = c.Fz * c.Ftot
    mole_balance = feed - vap - liq

    # Stream enthalpies from the property package: the feed via p.Hl at FT, the
    # products at the flash temperature via p.Hv (vapor) and p.Hl (liquid).
    h_feed = p.Hl(nL=feed, T=c.FT)
    h_vap = p.Hv(nV=vap, T=c.flashT)
    h_liq = p.Hl(nL=liq, T=c.flashT)
    energy = (h_feed - h_vap - h_liq)

    # Phase equilibrium: liquid-side vs vapor-side fugacity for each component.
    gamma = p.NRTL_gamma(c.Lx, c.flashT)
    f_liq = c.Lx * gamma * p.Pvap(c.flashT)
    f_vap = c.Vy * c.flashP
    vle = (f_liq - f_vap)

    residuals = [mole_balance, jnp.atleast_1d(energy), vle]
    return jnp.concatenate(residuals)


In [186]:
# eqs_jax(x) evaluates eqs on the flat unknown vector x; wrap supplies the
# constants from d2.
eqs_jax = wrap(eqs)

In [187]:
# Residuals at the initial guess. The first three (mole-balance) rows are ~0
# because Vy = Lx = Fz and Vtot + Ltot = Ftot initially; the energy and VLE
# rows are not.
eqs_jax(x)

DeviceArray([-4.44089210e-16, -4.44089210e-16, -8.88178420e-16,
             -5.54296099e+04,  1.45994899e+05,  1.39338184e+05,
              1.17611093e+05], dtype=float64)

In [188]:
# Exact Jacobian of the residuals with respect to x, via JAX automatic
# differentiation; passed to root so it need not estimate one numerically.
jac = jax.jacobian(eqs_jax)

In [189]:
# Solve eqs_jax(x) = 0 from the initial guess x, supplying the exact JAX
# Jacobian instead of letting root approximate it.
sol=root(eqs_jax, x, jac=jac)
print(sol)
# Stop here rather than silently unflattening a non-converged iterate below.
assert sol.success, sol.message
solx=sol.x

    fjac: array([[-3.07308281e-05,  1.55516705e-05,  1.51791576e-05,
        -6.73012378e-01, -6.07023792e-01,  3.27320726e-01,
         2.67278124e-01],
       [-1.16835633e-05, -2.62452565e-05,  3.79288198e-05,
         5.03327011e-01, -3.93181039e-01, -2.79619021e-01,
         7.16856884e-01],
       [-3.59668264e-05,  5.92325038e-06,  3.00435760e-05,
        -5.37956557e-01,  4.55014013e-01, -5.87395323e-01,
         3.98160425e-01],
       [-4.53897849e-04, -2.24735096e-04,  6.78632946e-04,
         6.57190933e-02,  5.19512129e-01,  6.85303419e-01,
         5.06109331e-01],
       [-5.44718635e-01, -2.54386206e-01,  7.99104840e-01,
        -3.80533639e-05, -4.26864653e-04, -5.51565066e-04,
        -4.83024831e-04],
       [-6.39595653e-01,  7.42325225e-01, -1.99676383e-01,
         4.11182797e-05,  5.90684542e-06,  6.11448883e-06,
         4.07263351e-06],
       [-5.42400959e-01, -6.19871653e-01, -5.67062900e-01,
        -2.30509700e-06, -3.31138649e-07, -3.42779173e-07,
        

In [190]:
# Unflatten the solution into a readable dict. Pass a copy of d so that d keeps
# its Range/Comp objects even if xtod writes into the dict it is given;
# otherwise re-running dtox(d) would see plain arrays instead of unknowns.
xtod(solx, dot_dict(d))

{'FT': 450.0,
 'Ftot': 10,
 'Fz': array([0.33333333, 0.33333333, 0.33333333]),
 'Ltot': DeviceArray([6.02836122], dtype=float64),
 'Lx': DeviceArray([0.32122647, 0.32919072, 0.34958281], dtype=float64),
 'Vtot': DeviceArray([3.97163878], dtype=float64),
 'Vy': DeviceArray([0.35170977, 0.33962121, 0.30866903], dtype=float64),
 'flashP': 101325,
 'flashT': DeviceArray([352.85497499], dtype=float64)}

In [158]:
# Flattened solution vector (layout: Vy log-ratios, Lx log-ratios, flashT,
# Vtot, Ltot in logit space).
# NOTE(review): the saved output below is from In [158], an earlier execution
# than the solve in In [189]; re-run the notebook to refresh it.
solx

array([-1.64109939e+06,  1.96647580e-01,  1.10024157e-01,  5.58387643e-03,
       -1.65399224e-02, -1.55001875e+00,  1.55060261e+00])