In [None]:
import numpy as np
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController, Tsit5
from scipy.integrate import odeint, solve_ivp
import diffrax

In [None]:

def get_dcdts(c_first=False):
    """Build the right-hand side of the nitrogen-pool / isotope ODE system.

    State vector ``c`` (length 10, in this order)::

        c_xNH3, c_xNO3, c_xNO2, c_xNOrg, c_xN2,   # pool sizes
        c_ANH3, c_ANO3, c_ANO2, c_ANOrg, c_AN2    # per-pool tracer value

    (presumably ``A`` is the isotopic abundance of each pool - TODO confirm).
    The ``A`` equations use the mixing form
    ``dA/dt = sum_over_inflows (A_source - A) * r_inflow / x``.

    Parameters
    ----------
    c_first : bool, default False
        ``True``  -> the returned callable is ``f(c, t, ...)``
        (``scipy.integrate.odeint`` order).
        ``False`` -> the returned callable is ``f(t, c, ...)``
        (``scipy.integrate.solve_ivp`` / diffrax order).

    Returns
    -------
    callable
        ``dcdt_func`` returning a list with the 10 derivatives in state order.
        Its extra arguments ``ks`` (11 rate constants) and ``k_kinetics``
        (11 flags) may be given either as two trailing positional arguments
        (scipy ``args=(ks, k_kinetics)``) or packed into one argument
        (diffrax ``args=(ks, k_kinetics)``). ``k_kinetics[i] == 1`` makes
        reaction ``i`` concentration-dependent; any other value makes its
        rate the constant ``ks[i]``.

    Note: the ``A`` equations divide by the pool sizes, so a pool of size 0
    is not supported.
    """

    def dcdt_func(c, t, *args):
        # Normalise the two leading positional arguments to (state, time).
        if not c_first:
            c, t = t, c

        # diffrax hands its ``args`` over as a single object; scipy spreads them.
        if len(args) == 1:
            (args,) = args
        ks, k_kinetics = args

        c_xNH3, c_xNO3, c_xNO2, c_xNOrg, c_xN2, c_ANH3, c_ANO3, c_ANO2, c_ANOrg, c_AN2 = c

        # Reaction fluxes. The trailing comment is the pool transfer implied by
        # the mass balances below.
        r1 = ks[0] * c_xN2 if k_kinetics[0] == 1 else ks[0]            # N2 -> 2 NH3
        r2 = ks[1] * c_xNH3 if k_kinetics[1] == 1 else ks[1]           # NH3 -> NO2
        r3 = ks[2] * c_xNO2 if k_kinetics[2] == 1 else ks[2]           # NO2 -> NO3
        r4 = ks[3] * c_xNO3 if k_kinetics[3] == 1 else ks[3]           # NO3 -> NO2
        r5 = ks[4] * c_xNO2 if k_kinetics[4] == 1 else ks[4]           # 2 NO2 -> N2
        # NOTE(review): r6 consumes NH3 + NO2 (see dc_xNH3, dc_xNO2 and the
        # c_ANO2 * c_ANH3 source term in dc_AN2), yet its rate depends on
        # c_xNO3 - confirm whether c_xNH3 * c_xNO2 was intended.
        r6 = ks[5] * c_xNO2 * c_xNO3 if k_kinetics[5] == 1 else ks[5]  # NH3 + NO2 -> N2
        r7 = ks[6] * c_xNO3 if k_kinetics[6] == 1 else ks[6]           # NO3 -> NH3
        r8 = ks[7] * c_xNO3 if k_kinetics[7] == 1 else ks[7]           # NO3 -> NOrg
        r9 = ks[8] * c_xNH3 if k_kinetics[8] == 1 else ks[8]           # NH3 -> NOrg
        r10 = ks[9] * c_xNOrg if k_kinetics[9] == 1 else ks[9]         # NOrg -> NH3
        r11 = ks[10] * c_xNOrg if k_kinetics[10] == 1 else ks[10]      # NOrg -> NO3

        # Pool-size balances (total N = NH3 + NO3 + NO2 + NOrg + 2*N2 is conserved).
        dc_xNH3 = 2 * r1 + r7 + r10 - r2 - r6 - r9
        dc_xNO3 = r3 - r7 - r4 - r8 + r11
        dc_xNO2 = r2 + r4 - r3 - r6 - 2 * r5
        dc_xNOrg = r8 + r9 - r10 - r11
        dc_xN2 = r5 + r6 - r1

        # Tracer mixing: each pool moves toward the composition of its inflows.
        dc_ANH3 = (2 * r1 * (c_AN2 - c_ANH3) + (c_ANO3 - c_ANH3) * r7 + (c_ANOrg - c_ANH3) * r10) / c_xNH3
        # NO3 is fed by NO2 through r3 (r2 is NH3 -> NO2 and never reaches NO3)
        # and by NOrg through r11 - the same inflows as dc_xNO3.
        dc_ANO3 = ((c_ANO2 - c_ANO3) * r3 + (c_ANOrg - c_ANO3) * r11) / c_xNO3
        dc_ANO2 = ((c_ANH3 - c_ANO2) * r2 + (c_ANO3 - c_ANO2) * r4) / c_xNO2
        dc_ANOrg = ((c_ANO3 - c_ANOrg) * r8 + (c_ANH3 - c_ANOrg) * r9) / c_xNOrg
        dc_AN2 = ((c_ANO2 - c_AN2) * r5 + (c_ANO2 * c_ANH3 - c_AN2) * r6) / c_xN2

        return [dc_xNH3, dc_xNO3, dc_xNO2, dc_xNOrg, dc_xN2, dc_ANH3, dc_ANO3, dc_ANO2, dc_ANOrg, dc_AN2]

    return dcdt_func

In [None]:
import core

db_csv_path = "dataset/data.csv"
idata_save_path = "odes-exp04-idata-4-number-1core-c0number-halfnormks-from-core.py-success.dt"

# Untouched copy of the measured data and its initial state.
dataset_ori = core.MyDataset(db_csv_path)
df_ori = dataset_ori.get_df()
cct_names, rates_names, error_names = dataset_ori.get_var_col_names()
c0 = df_ori[cct_names].iloc[0].values

# Assume every reaction uses its concentration-dependent ("first-order") rate
# law: k_kinetics[i] == 1 selects it in get_dcdts, any other value a constant rate.
k_kinetics = np.repeat(1, 11).astype(np.uint8)
# k_kinetics = np.array([0,0,0,0,1,1,0,0,1,1,0]).astype(np.uint8)
ks = np.array([0.00071942, 0.00269696, 0.00498945, 0.00444931, 0.00571299, 0.00801272, 0.00131931, 0.00319959, 0.00415571, 0.00228432, 0.00177611])
#  =======================================================

# t_eval = np.linspace(0.5, 150, 8)
t_eval = np.array([0.5, 48, 96, 144])
# Derive the integration window from t_eval so solve_ivp never rejects
# t_eval values that fall outside a stale hard-coded span.
t_span = (t_eval[0], t_eval[-1])

# Second dataset instance; set_as_sim_dataset presumably replaces its data with
# a simulation from core's ODE model (TODO confirm in core).
dataset = core.MyDataset(db_csv_path)
df = dataset.get_df()
cct_names, rates_names, error_names = dataset.get_var_col_names()
c0 = df[cct_names].iloc[0].values
dataset.set_as_sim_dataset(core.dcdt_func_for_odeint, t_eval, c0, args=(ks, k_kinetics))
df = dataset.get_df()

# Integrate the local model with odeint (LSODA) and solve_ivp (default RK45).
y = odeint(get_dcdts(c_first=True), y0=c0, t=t_eval, args=(ks, k_kinetics))
y_s = solve_ivp(get_dcdts(), t_span=t_span, y0=c0, t_eval=t_eval, args=(ks, k_kinetics))
assert y_s.success, y_s.message

print(y)
# solve_ivp stores states as (n_states, n_times); odeint as (n_times, n_states).
print(y_s.y.T - y)



In [None]:
# diffrax calls the vector field as f(t, y, args) with ``args`` as ONE object,
# whereas get_dcdts(c_first=False) returns f(t, c, ks, k_kinetics); the adapter
# spreads ``args``. (get_dcdts(c_first=True) is the odeint order and would
# swap time and state here.)
_rhs_t_first = get_dcdts(c_first=False)
term = ODETerm(lambda t, state, args: _rhs_t_first(t, state, *args))
solver = diffrax.Heun()
saveat = SaveAt(ts=t_eval)
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

# ks / k_kinetics go in as plain Python lists (.tolist()) so the
# ``k_kinetics[i] == 1`` checks in the vector field compare Python ints.
sol = diffeqsolve(
    term,
    solver,
    t0=t_eval[0],
    t1=t_eval[-1],
    dt0=0.1,
    y0=list(c0),
    saveat=saveat,
    args=[ks.tolist(), k_kinetics.tolist()],
    stepsize_controller=stepsize_controller,
)

print(sol.ts)
# sol.ys is a list with one array per state (y0 was a list); stack and
# transpose to odeint's (n_times, n_states) layout.
y2 = np.array(sol.ys).T
# Max absolute deviation from the odeint solution (a signed sum lets errors cancel).
print(np.max(np.abs(y2 - y)))

In [None]:
import time

N_RUNS = 1000

# Draw every parameter set up front with a fixed seed: both solvers then
# integrate exactly the same problems, RNG cost stays out of the timed loops,
# and the globals ks / y / y_s from earlier cells are not overwritten.
bench_rng = np.random.default_rng(0)
ks_samples = bench_rng.random((N_RUNS, 11)) / 100

# Build the right-hand sides once; rebuilding them per iteration is loop-invariant work.
rhs_t_first = get_dcdts()              # f(t, c, ...) for solve_ivp
rhs_c_first = get_dcdts(c_first=True)  # f(c, t, ...) for odeint
bench_t_span = (t_eval[0], t_eval[-1])

start = time.perf_counter()
for ks_sample in ks_samples:
    solve_ivp(rhs_t_first, t_span=bench_t_span, y0=c0, t_eval=t_eval, args=(ks_sample, k_kinetics), method='RK45')
solve_ivp_seconds = time.perf_counter() - start

start = time.perf_counter()
for ks_sample in ks_samples:
    odeint(rhs_c_first, y0=c0, t=t_eval, args=(ks_sample, k_kinetics))
odeint_seconds = time.perf_counter() - start

print(f"solve_ivp (RK45): {solve_ivp_seconds:.3f} s for {N_RUNS} runs")
print(f"odeint (LSODA):   {odeint_seconds:.3f} s for {N_RUNS} runs")

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5


def vector_field(t, y, args):
    """Lotka-Volterra right-hand side in diffrax's ``f(t, y, args)`` form.

    ``y`` is ``(prey, predator)`` and ``args`` is ``(alpha, beta, gamma, delta)``:

        d_prey     = alpha * prey - beta * prey * predator
        d_predator = delta * prey * predator - gamma * predator

    ``t`` is unused (the system is autonomous). Returns ``(d_prey, d_predator)``.
    """
    alpha, beta, gamma, delta = args
    prey, predator = y
    prey_rate = alpha * prey - beta * prey * predator
    predator_rate = delta * prey * predator - gamma * predator
    return prey_rate, predator_rate


# Lotka-Volterra demo: model parameters (alpha, beta, gamma, delta),
# initial (prey, predator) populations and the integration window.
args = (0.1, 0.02, 0.4, 0.03)
y0 = (10.0, 10.0)
t0, t1, dt0 = 0, 140, 0.1

ly_term = ODETerm(vector_field)
solver = Tsit5()
# Store the solution at 1000 evenly spaced times for plotting.
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))
sol = diffeqsolve(ly_term, solver, t0=t0, t1=t1, dt0=dt0, y0=y0, args=args, saveat=saveat)

In [None]:
# Prey and predator trajectories from the Lotka-Volterra solve.
fig, ax = plt.subplots()
ax.plot(sol.ts, sol.ys[0], label="Prey")
ax.plot(sol.ts, sol.ys[1], label="Predator")
ax.set(title="Lotka-Volterra predator-prey model", xlabel="t", ylabel="Population")
ax.legend()
plt.show()
