<a href="https://colab.research.google.com/github/chetools/CHE4071_Spring2026/blob/main/TubularReactorBandBroadening2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!wget -N -q https://raw.githubusercontent.com/chetools/chetools/main/tools/che5.ipynb -O che5.ipynb
%run che5.ipynb

In [2]:
# Compact array display: 5 digits of floating-point precision and wide lines,
# so long state vectors wrap less when printed.
jnp.set_printoptions(linewidth=240, precision=5)

In [7]:
# Model parameters: the tubular reactor is approximated as N equal tanks in series.
N = 200                 # number of tanks in series
q = 100.                # volumetric flow rate, L/min
Cain = 1.               # feed concentration, mol/L
Tin = 350.              # feed temperature, K
totV = 100.             # total reactor volume, L
V = totV/N              # volume of each tank, L
rho = 1e3               # density, g/L
C = 0.239               # heat capacity, J/(g K)
negHr = 5e4             # negative heat of reaction, J/mol
ER = 8750.              # Arrhenius temperature used as k0*exp(-ER/T), K
k0 = 7.2e10             # pre-exponential factor, 1/min
UAtot = 5e4             # total heat-transfer coefficient times area, J/(min K)
UA = UAtot/N            # share of UAtot assigned to each tank, J/(min K)

Tc = 300.               # temperature the UA heat-transfer term exchanges with, K
Ca0 = 0.5               # initial concentration in every tank, mol/L
T_init = 350.           # initial temperature in every tank, K

# Feed concentration name used by the rate functions; tied to Cain so the two
# can never disagree.
cin = Cain              # mol/L

# Uniform initial profiles built from the scalars above, so editing Ca0 or
# T_init actually changes the initial condition. T0 is the per-tank array
# consumed by the solver.
c0 = np.full(N, Ca0)
T0 = np.full(N, T_init)
tend = 2                # end of simulation, min

In [8]:
def rhs(t, vec):
    """NumPy right-hand side of the tanks-in-series ODE system.

    ``vec`` holds the stacked state ``[c_0..c_{N-1}, T_0..T_{N-1}]``. Every
    tank receives the outlet of the tank before it; tank 0 receives the feed
    ``(cin, Tin)``. ``t`` is not used but is part of the ``solve_ivp``
    signature. Returns the time derivatives in the same stacked layout.
    Reads the module-level parameters (q, V, k0, ER, UA, Tc, rho, C, negHr,
    cin, Tin).
    """
    conc, temp = np.split(vec, 2)
    rate_const = k0*np.exp(-ER/temp)

    # Inlet value seen by each tank: the feed for tank 0, otherwise the
    # previous tank's outlet.
    conc_up = np.concatenate(([cin], conc[:-1]))
    temp_up = np.concatenate(([Tin], temp[:-1]))

    dconc = q*(conc_up - conc)/V - rate_const*conc
    dtemp = (q*(temp_up - temp)/V
             - UA*(temp - Tc)/(rho*V*C)
             + negHr*rate_const*conc/(rho*C))
    return np.concatenate((dconc, dtemp))

In [9]:
@jax.jit
def rhs_jax(t, vec):
    """JAX right-hand side of the tanks-in-series ODE system (jit-compiled).

    ``vec`` holds the stacked state ``[c_0..c_{N-1}, T_0..T_{N-1}]``; tank 0 is
    fed at ``(cin, Tin)`` and every other tank by its predecessor. ``t`` is not
    used but is required by ``solve_ivp``. Built only from pure array ops so it
    can be traced by ``jax.jit`` and differentiated by ``jax.jacobian``.

    NOTE: jit bakes the module-level parameters (q, V, k0, ER, ...) in as
    constants at the first call; re-run this cell after editing them.
    NOTE(review): JAX works in float32 unless x64 is enabled — presumably
    che5.ipynb enables it; confirm, since Radau relies on an accurate Jacobian.
    """
    conc, temp = jnp.split(vec, 2)
    rate_const = k0*jnp.exp(-ER/temp)

    # Inlet value seen by each tank: the feed for tank 0, otherwise the
    # previous tank's outlet.
    conc_up = jnp.concatenate((jnp.array([cin]), conc[:-1]))
    temp_up = jnp.concatenate((jnp.array([Tin]), temp[:-1]))

    dconc = q*(conc_up - conc)/V - rate_const*conc
    dtemp = (q*(temp_up - temp)/V
             - UA*(temp - Tc)/(rho*V*C)
             + negHr*rate_const*conc/(rho*C))
    return jnp.concatenate((dconc, dtemp))


# Exact Jacobian d(rhs)/d(vec) (argument index 1), compiled for Radau's `jac=`.
rhs_jax_jac = jax.jit(jax.jacobian(rhs_jax, argnums=1))

In [10]:
# Integrate the tanks-in-series model with the implicit Radau method, passing
# the jit-compiled RHS and its exact Jacobian.
N_times = 100                              # number of snapshot times to plot
tplot = np.linspace(0, tend, N_times)      # includes both endpoints 0 and tend
# Spacing of tplot: N_times points span N_times - 1 intervals. (tend/N_times
# would make every dt-based time label slightly wrong, ending at 1.98 not 2.)
dt = tend/(N_times - 1)
res = sp.integrate.solve_ivp(rhs_jax, (0, tend), np.r_[c0, T0], method='Radau',
                             dense_output=True, jac=rhs_jax_jac)
# Fail loudly instead of plotting a partial/garbage solution.
if not res.success:
    raise RuntimeError(f"solve_ivp failed: {res.message}")
# Stacked state [c; T] at each plot time, shape (2N, N_times).
c_profiles = res.sol(tplot)

In [11]:
# The dense solution stacks [c; T] row-wise with shape (2N, len(tplot));
# halve it by rows into concentration and temperature profiles.
cplot, Tplot = np.vsplit(res.sol(tplot), 2)

In [12]:
# Concentration (left) and temperature (right) along the tank train, one trace
# per snapshot time; the two traces of a snapshot share one legend entry.
fig = make_subplots(rows=1, cols=2)
tank_index = np.arange(N)
for t_index, t_snap in enumerate(tplot):
    # Label with the actual snapshot time rather than a separately derived step.
    label = f't={t_snap:.2f}'
    group = str(t_index)
    fig.add_scatter(x=tank_index, y=cplot[:, t_index], row=1, col=1,
                    name=label, legendgroup=group)
    fig.add_scatter(x=tank_index, y=Tplot[:, t_index], row=1, col=2,
                    name=label, legendgroup=group, showlegend=False)
fig.update_xaxes(title_text='tank index')
fig.update_yaxes(title_text='concentration (mol/L)', row=1, col=1)
fig.update_yaxes(title_text='temperature (K)', row=1, col=2)
fig