# First lets consider case 1

# Getting ingredients to compute for matrix S

In [7]:
from kinetic_solver_v3 import load_data as load_data_3, process_data as process_data_3, plot_save, system_KE as system_KE_3
import argparse 
from scipy.integrate import solve_ivp
import time
import numpy as np
import pandas as pd

# Input
parser = argparse.ArgumentParser(
    description='Perform kinetic modelling given the free energy profile and mechanism detail')

parser.add_argument(
    "-i",
    help="input reaction profiles in csv"
    )

parser.add_argument("-a", 
                    help="manually add an input reaction profile in csv", 
                    action="append")

parser.add_argument(
    "-c",
    "--c",
    type=str,
    default="c0.txt",
    help="text file containing initial concentration of all species [[INTn], [Rn], [Pn]]")

parser.add_argument(
    "-r",
    "--r",
    type=str,
    default="Rp.txt",
    help="reactant position matrix")

parser.add_argument(
    "-p",
    "--p",
    type=str,
    default="Pp.txt",
    help="product position matrix")

parser.add_argument(
    "-rn",
    "--rn",
    type=str,
    default="rxn_network,csv",
    help="reaction network matrix")

parser.add_argument(
    "--Time",
    type=float,
    default=1e8,
    help="total reaction time (s)")

parser.add_argument(
    "-t",
    "--t",
    type=float,
    default=298.15,
    help="temperature (K)")

parser.add_argument(
    "-de",
    "--de",
    type=str,
    default="LSODA",
    help="Integration method to use (odesolver)")

dir = "test_cases/1/"

args = parser.parse_args(['-i', f"{dir}/reaction_data.csv", 
                          '-c', f"{dir}/c0.txt",
                          "-r", f"{dir}/Rp.txt",
                          "-rn", f"{dir}/rxn_network.csv",
                          "-p", f"{dir}/Pp.txt",
                          ])

initial_conc, Rp, Pp, t_span, temperature, method, energy_profile_all,\
    dgr_all, coeff_TS_all, rxn_network, n_INT_all = load_data_3(args)

k_forward_all, k_reverse_all = process_data_3(energy_profile_all, dgr_all, coeff_TS_all, temperature)

In [2]:
dydt = system_KE_3(
    k_forward_all,
    k_reverse_all,
    rxn_network,
    Rp,
    Pp,
    n_INT_all,
    initial_conc)

In [5]:
result_solve_ivp = solve_ivp(
    dydt,
    t_span,
    initial_conc,
    method="BDF",
    dense_output=True,
    # first_step=first_step,
    # max_step=100,
    rtol=1e-3,
    atol=1e-6,
    jac=dydt.jac,
    timeout = 2
)
result_solve_ivp

'Shiki'

In [8]:
from kinetic_solver_v4 import load_data as load_data_4, process_data as process_data_4, plot_save, system_KE as system_KE_4

initial_conc, Rp, Pp, t_span, temperature, method, energy_profile_all,\
    dgr_all, coeff_TS_all, rxn_network_all, n_INT_all = load_data_4(args)

k_forward_all, k_reverse_all = process_data_4(energy_profile_all, dgr_all, coeff_TS_all, temperature)

In [9]:
dydt = system_KE_4(
        k_forward_all,
        k_reverse_all,
        rxn_network_all,
        Rp,
        Pp,
        n_INT_all,
        initial_conc
        )

result_solve_ivp = solve_ivp(
    dydt,
    t_span,
    initial_conc,
    method="BDF",
    dense_output=True,
    # first_step=first_step,
    # max_step=max_step,
    rtol=1e-3,
    atol=1e-6,
    jac=dydt.jac,
)
plot_save(result_solve_ivp, rxn_network, Rp, Pp, dir)

[0.2 0.  0.  0.  0.6 0.5 0. ]
[ 2.00000000e-01 -6.29586890e-20 -1.72706491e-05 -1.96827928e-11
  5.99982729e-01  5.00000000e-01  0.00000000e+00]
[0.2 0.  0.  0.  0.6 0.5 0. ]
[ 2.00000000e-01 -4.59300429e-19 -1.25993991e-04 -1.43591223e-10
  5.99874006e-01  5.00000000e-01  0.00000000e+00]
[ 2.00000000e-01 -4.59300429e-19 -1.25971664e-04 -1.44416813e-10
  5.99874028e-01  5.00000000e-01 -2.78276415e-28]
[ 2.00000000e-01 -9.18600857e-19 -2.51943337e-04 -2.88833334e-10
  5.99748057e-01  5.00000000e-01 -5.59752778e-28]
[ 2.00000000e-01 -9.18600857e-19 -2.51917529e-04 -2.89787644e-10
  5.99748082e-01  5.00000000e-01 -8.83322383e-28]
[ 2.00000000e-01 -5.51160514e-18 -1.51137618e-03 -1.74349556e-09
  5.98488624e-01  5.00000000e-01 -6.93812607e-27]
[ 2.00000000e-01 -5.51160514e-18 -1.50910405e-03 -1.82751262e-09
  5.98490896e-01  5.00000000e-01 -3.56186819e-26]
[ 2.00000000e-01 -1.01046094e-17 -2.76629863e-03 -3.36494000e-09
  5.97233701e-01  5.00000000e-01 -7.36086537e-26]
[ 2.00000000e-01 -1.

ValueError: setting an array element with a sequence.

In [6]:
from kinetic_solver_v4 import dINT_dt, pad_network

n_INT_tot = np.sum(n_INT_all)
rxn_network = rxn_network_all[:n_INT_tot, :n_INT_tot]
Rp_, _ = pad_network(Rp, n_INT_all, rxn_network)
Pp_, _ = pad_network(Pp, n_INT_all, rxn_network)
k = rxn_network.shape[0] + Rp_[0].shape[1] + Pp_[0].shape[1]

dIdt_fn = dINT_dt(k_forward_all, k_reverse_all, rxn_network_all, Rp_, Pp_, n_INT_all)
dIdt_fn(np.array(initial_conc))

array([[-5.00094360e+04,  0.00000000e+00,  0.00000000e+00,
         1.35881500e-02,  0.00000000e+00,  8.49613440e+06,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [-4.59437521e+07,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [-2.71819645e+02,  1.16193191e+01,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [-2.48445736e+08,  0.00000000e+00,  6.21243799e+12,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [-2.71819645e+02,  1.16193191e+01,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.

In [6]:
def pad_network(X, n_INT_all, rxn_network):

    X_ = np.array_split(X, np.cumsum(n_INT_all)[:-1])
    # the first profile is assumed to be full, skipped
    insert_idx = []
    for i in range(1, len(n_INT_all)):  # n_profile - 1
        # pitfall
        if np.all(rxn_network[np.cumsum(n_INT_all)[i - 1]:np.cumsum(n_INT_all)[i], 0] == 0):
            cp_idx = np.where(rxn_network[np.cumsum(n_INT_all)[
                                i - 1]:np.cumsum(n_INT_all)[i], :][0] == -1)
            tmp_idx = cp_idx[0][0].copy()
            all_idx = [tmp_idx]
            while tmp_idx != 0:
                tmp_idx = np.where((rxn_network[tmp_idx, :] == -1))[0][0]
                all_idx.insert(0, tmp_idx)
            X_[i] = np.insert(X_[i], 0, X[all_idx], axis=0)
            insert_idx.append(all_idx)

        else:
            all_idx = []
            for j in range(rxn_network.shape[0]):
                if j >= np.cumsum(n_INT_all)[
                        i - 1] and j <= np.cumsum(n_INT_all)[i]:
                    continue
                elif np.any(rxn_network[np.cumsum(n_INT_all)[i - 1]:\
                    np.cumsum(n_INT_all)[i], j]):
                    X_[i] = np.insert(X_[i], j, X[j], axis=0)
                    all_idx.append(j)
            insert_idx.append(all_idx)
                    
    return X_, insert_idx


def add_rate(
        y,
        k_forward_all,
        k_reverse_all,
        rxn_network,
        Rp_,
        Pp_,
        a,
        cn,
        n_INT_all):
    """generate reaction rate at step a, cycle n

    Parameters
    ----------
    y : array-like
        concentration of all species under consideration (in kcal/mol)
    k_forward_all : array-like
        forward reaction rate constant
    k_forward_all : array-like
        reverse reaction rate constant
    rxn_network : array-like
        reaction network matrix
    Rp_: array-like
        reactant reaction coordinate matrix
    Pp_: array-like
        product reaction coordinate matrix
    a : int
        index of the elementary step (note that the last step is 0)
    cn : int
        index of the cycle (start at 0)
    n_INT_all : array-like
        number of state in each cycle

    Returns
    -------
    rate : float
        reaction rate at step a, cycle n
    """
        
    y_INT = []
    tmp = y[:np.sum(n_INT_all)]
    y_INT = np.array_split(tmp, np.cumsum(n_INT_all)[:-1])
    # Rp_ and Pp_ were initially designed to be all positive
    Rp_ = [np.abs(i) for i in Rp_]
    Pp_ = [np.abs(i) for i in Pp_]
    
    # the first profile is assumed to be full, skipped
    for i in range(1, len(k_forward_all)):  # n_profile - 1
        # pitfall
        if np.all(rxn_network[np.cumsum(n_INT_all)[i - 1]:np.cumsum(n_INT_all)[i], 0] == 0):
            cp_idx = np.where(rxn_network[np.cumsum(n_INT_all)[
                              i - 1]:np.cumsum(n_INT_all)[i], :][0] == -1)
            tmp_idx = cp_idx[0][0].copy()
            all_idx = [tmp_idx]
            while tmp_idx != 0:
                tmp_idx = np.where((rxn_network[tmp_idx, :] == -1))[0][0]
                all_idx.insert(0, tmp_idx)
            y_INT[i] = np.insert(y_INT[i], 0, tmp[all_idx])

        else:
            for j in range(rxn_network.shape[0]):
                if j >= np.cumsum(n_INT_all)[
                        i - 1] and j <= np.cumsum(n_INT_all)[i]:
                    continue
                else:
                    if np.any(rxn_network[np.cumsum(n_INT_all)[
                              i - 1]:np.cumsum(n_INT_all)[i], j]):
                        y_INT[i] = np.insert(y_INT[i], j, tmp[j])
                        
    y_R = np.array(y[np.sum(n_INT_all):np.sum(n_INT_all) + Rp_[0].shape[1]])
    y_P = np.array(y[np.sum(n_INT_all) + Rp_[0].shape[1]:])

    idx1 = np.where(Rp_[cn][a - 1] != 0)[0]
    if idx1.size == 0:
        sui = 1
    else:
        rate_tmp = np.where(Rp_[cn][a - 1] != 0, y_R ** Rp_[cn][a - 1], 0)
        zero_indices = np.where(rate_tmp == 0)[0]
        rate_tmp = np.delete(rate_tmp, zero_indices)
        if len(rate_tmp) == 0:
            sui = 0
        else:
            sui = np.prod(rate_tmp)
            
    # cn, a-1
    rate_1 = k_forward_all[cn][a - 1] * sui

    idx2 = np.where(Pp_[cn][a - 1] != 0)[0]
    if idx2.size == 0:
        sui = 1
    else:
        rate_tmp = np.where(Pp_[cn][a - 1] != 0, y_P ** Pp_[cn][a - 1], 0)
        zero_indices = np.where(rate_tmp == 0)[0]
        rate_tmp = np.delete(rate_tmp, zero_indices)
        if len(rate_tmp) == 0:
            sui = 0
        else:
            sui = np.prod(rate_tmp)
    
    # cn, a  
    rate_2 = k_reverse_all[cn][a - 1] * sui

    return rate_1, rate_2

# jaxifying add_rate, dINT, dX

In [13]:
# slicing index

import jax.numpy as jnp
from jax import jit
from jax.lax import dynamic_slice

# 1 indice
def slice_array(x, i):
    return dynamic_slice(x, (i, 0), (1, x.shape[1]))

jit_slice_array = jit(slice_array)

i = 5
x = jnp.ones((10, 10))
result = jit_slice_array(x, i)


# 2 indices
def slice_array(x, i, j):
    return dynamic_slice(x, (i, j), (1, 1))

jit_slice_array = jit(slice_array)

i = 5
j = 3
x = jnp.ones((10, 10))
result = jit_slice_array(x, i, j)
result

Array([[1.]], dtype=float64)

In [3]:
import jax.numpy as jnp
from jax import jit, jacfwd

@jit
def add_rate(
        y,
        k_forward_all,
        k_reverse_all,
        rxn_network,
        Rp_,
        Pp_,
        a,
        cn,
        n_INT_all):
        
    y_INT = []
    tmp = y[:jnp.sum(n_INT_all)] #*
    y_INT = jnp.array_split(tmp, jnp.cumsum(n_INT_all)[:-1]) 
    # Rp_ and Pp_ were initially designed to be all positive
    Rp_ = [jnp.abs(i) for i in Rp_]
    Pp_ = [jnp.abs(i) for i in Pp_]
    
    # the first profile is assumed to be full, skipped
    for i in range(1, len(k_forward_all)):  # n_profile - 1
        # pitfall
        if jnp.all(rxn_network[jnp.cumsum(n_INT_all)[i - 1]:jnp.cumsum(n_INT_all)[i], 0] == 0): #*
            cp_idx = jnp.where(rxn_network[jnp.cumsum(n_INT_all)[
                            i - 1]:jnp.cumsum(n_INT_all)[i], :][0] == -1) #*
            tmp_idx = cp_idx[0][0].copy()
            all_idx = [tmp_idx]
            while tmp_idx != 0:
                tmp_idx = jnp.where((rxn_network[tmp_idx, :] == -1))[0][0] #*
                all_idx.insert(0, tmp_idx)
            y_INT[i] = jnp.insert(y_INT[i], 0, tmp[all_idx])

        else:
            for j in range(rxn_network.shape[0]):
                if j >= jnp.cumsum(n_INT_all)[
                        i - 1] and j <= jnp.cumsum(n_INT_all)[i]: #*
                    continue
                else:
                    if jnp.any(rxn_network[jnp.cumsum(n_INT_all)[
                            i - 1]:jnp.cumsum(n_INT_all)[i], j]):
                        y_INT[i] = jnp.insert(y_INT[i], j, tmp[j]) #*

    y_R = jnp.array(y[jnp.sum(n_INT_all):jnp.sum(n_INT_all) + Rp_[0].shape[1]]) #*
    y_P = jnp.array(y[jnp.sum(n_INT_all) + Rp_[0].shape[1]:]) #*

    idx1 = jnp.where(Rp_[cn][a - 1] != 0)[0]
    if idx1.size == 0:
        sui = 1
    else:
        rate_tmp = jnp.where(Rp_[cn][a - 1] != 0, y_R ** Rp_[cn][a - 1], 0)
        zero_indices = jnp.where(rate_tmp == 0)[0]
        rate_tmp = jnp.delete(rate_tmp, zero_indices)
        if len(rate_tmp) == 0:
            sui = 0
        else:
            sui = jnp.prod(rate_tmp)

    # cn, a-1
    rate_1 = k_forward_all[cn][a - 1] * sui


    idx2 = jnp.where(Pp_[cn][a - 1] != 0)[0]
    if idx2.size == 0:
        sui = 1
    else:
        rate_tmp = jnp.where(Pp_[cn][a - 1] != 0, y_P ** Pp_[cn][a - 1], 0)
        zero_indices = jnp.where(rate_tmp == 0)[0]
        rate_tmp = jnp.delete(rate_tmp, zero_indices)
        if len(rate_tmp) == 0:
            sui = 0
        else:
            sui = jnp.prod(rate_tmp)
    
    # cn, a  
    rate_2 = k_reverse_all[cn][a - 1] * sui

    return rate_1, rate_2

In [10]:
@jit
def test(rxn_network, n_INT_all):
    return jnp.all(rxn_network[jnp.cumsum(n_INT_all)[2 - 1]:jnp.cumsum(n_INT_all)[2], 0] == 0)
test(rxn_network, n_INT_all)

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

In [6]:
y = initial_conc.copy()
add_rate(
        y,
        k_forward_all,
        k_reverse_all,
        rxn_network,
        Rp,
        Pp,
        0,
        0,
        n_INT_all)

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

In [48]:
df_network = pd.read_csv("test_cases/2A/rxn_network.csv")
df_network.fillna(0, inplace=True)

# process reaction network
rxn_network_all = df_network.to_numpy()[:, 1:]
rxn_network_all = rxn_network_all.astype(np.int32)
states = df_network.columns[1:].tolist()
nR = len([s for s in states if s.lower().startswith('r') and 'INT' not in s])
nP = len([s for s in states if s.lower().startswith('p') and 'INT' not in s])
    
n_INT_tot = rxn_network_all.shape[1] - nR - nP
rxn_network = rxn_network_all[:n_INT_tot, :n_INT_tot]

n_INT_all = []
x = 1
for i in range(1, rxn_network.shape[1]):
    if rxn_network[i, i - 1] == -1:
        x += 1
    elif rxn_network[i, i - 1] != -1:
        n_INT_all.append(x)
        x = 1
n_INT_all.append(x)
n_INT_all = np.array(n_INT_all)

Rp = rxn_network_all[:n_INT_tot, n_INT_tot:n_INT_tot+nR]
Pp = rxn_network_all[:n_INT_tot, n_INT_tot+nR:]

Rp_, _ = pad_network(Rp, n_INT_all, rxn_network)
Pp_, idx_insert = pad_network(Pp, n_INT_all, rxn_network)

y = np.array(initial_conc)

# INT rate
rate_INT = None
for i in range(n_INT_tot):

    # first lets just consider at INT0; this should return the array of size k
    rate_array = np.zeros(rxn_network_all.shape[1])
    self_rate = 0

    for i, j in enumerate(rxn_network_all[i, :]):
        if j == 0: continue
        else:        
            # the prod and react are already included implicitly
            if i > n_INT_tot-1:
                rate_array[i] = 0
            
            #TODO, assigned the elements correctly; taking cn, a into consideration
            else: 
                mori = np.cumsum(n_INT_all)
                cn = np.searchsorted(mori, i, side='right')
                incr = 0
                a = i
                if cn > 0:
                    incr = 0
                    if np.all(rxn_network[np.cumsum(n_INT_all)[
                                cn - 1]:np.cumsum(n_INT_all)[cn], 0] == 0):
                        cp_idx = np.where(rxn_network[np.cumsum(n_INT_all)[
                                            cn - 1]:np.cumsum(n_INT_all)[cn], :][0] == -1)
                        tmp_idx = cp_idx[0][0].copy()
                        incr += 1
                        while tmp_idx != 0:
                            tmp_idx = np.where((rxn_network[tmp_idx, :] == -1))[0][0]
                            incr += 1

                    else:
                        for j in range(rxn_network.shape[0]):
                            if j >= np.cumsum(n_INT_all)[
                                    cn - 1] and j <= np.cumsum(n_INT_all)[cn]:
                                continue
                            else:
                                if np.any(rxn_network[np.cumsum(n_INT_all)[
                                            cn - 1]:np.cumsum(n_INT_all)[cn], j]):
                                    incr += 1
                    a = i - mori[cn-1] + incr
                rate_1, rate_2 = add_rate(y, k_forward_all, k_reverse_all, rxn_network, Rp_, Pp_, a, cn, n_INT_all)
                
                if j < 0: 
                    self_rate-=rate_1
                    rate_array[i] = rate_2
                elif j > 0:
                    self_rate+=rate_1
                    rate_array[i] = -rate_2
    rate_array[0] = self_rate 
    if rate_INT is None: rate_INT = rate_array
    else: rate_INT = np.vstack([rate_INT, rate_array])

# TODo, R and P 

# the function

In [1]:
def dINT_dt(y,
        k_forward_all,
        k_reverse_all,
        rxn_network_all,
        Rp_,
        Pp_,
        n_INT_all):
    
    # INT rate
    rxn_network = rxn_network_all[:n_INT_tot, :n_INT_tot]
    rate_INT = None
    n_INT_tot = np.sum(n_INT_all)
    for n in range(n_INT_tot):

        # first lets just consider at INT0; this should return the array of size k
        rate_array = np.zeros(rxn_network_all.shape[1])
        self_rate = 0
        for i, j in enumerate(rxn_network_all[n, :]):
            if j == 0: continue
            else:        
                # the prod and react are already included implicitly
                if i > n_INT_tot-1:
                    rate_array[i] = 0
                
                #TODO, assigned the elements correctly; taking cn, a into consideration
                else: 
                    mori = np.cumsum(n_INT_all)
                    cn = np.searchsorted(mori, i, side='right')
                    incr = 0
                    a = i
                    if cn > 0:
                        incr = 0
                        if np.all(rxn_network[np.cumsum(n_INT_all)[
                                    cn - 1]:np.cumsum(n_INT_all)[cn], 0] == 0):
                            cp_idx = np.where(rxn_network[np.cumsum(n_INT_all)[
                                                cn - 1]:np.cumsum(n_INT_all)[cn], :][0] == -1)
                            tmp_idx = cp_idx[0][0].copy()
                            incr += 1
                            while tmp_idx != 0:
                                tmp_idx = np.where((rxn_network[tmp_idx, :] == -1))[0][0]
                                incr += 1

                        else:
                            for r in range(rxn_network.shape[0]):
                                if r >= np.cumsum(n_INT_all)[
                                        cn - 1] and r <= np.cumsum(n_INT_all)[cn]:
                                    continue
                                else:
                                    if np.any(rxn_network[np.cumsum(n_INT_all)[
                                                cn - 1]:np.cumsum(n_INT_all)[cn], r]):
                                        incr += 1
                        a = i - mori[cn-1] + incr
                    # if -1, already correct. If 1, a+=1
                    rate_1, rate_2 = add_rate(y, k_forward_all, k_reverse_all, rxn_network, Rp_, Pp_, a, cn, n_INT_all)             
                    if j < 0: 
                        rate_1, rate_2 = add_rate(y, k_forward_all, k_reverse_all, rxn_network, Rp_, Pp_, a, cn, n_INT_all)
                        self_rate-=rate_1
                        rate_array[i] = rate_2
                    elif j > 0:
                        try: 
                            rate_1, rate_2 = add_rate(y, k_forward_all, k_reverse_all, rxn_network, Rp_, Pp_, a+1, cn, n_INT_all)
                        except IndexError as e:
                            rate_1, rate_2 = add_rate(y, k_forward_all, k_reverse_all, rxn_network, Rp_, Pp_, 0, cn, n_INT_all)
                            self_rate-=rate_2
                            rate_array[i] = rate_1
        rate_array[0] = self_rate 
        if rate_INT is None: rate_INT = rate_array
        else: rate_INT = np.vstack([rate_INT, rate_array])
    return rate_INT

def dX_dt(y,
        k_forward_all,
        k_reverse_all,
        rxn_network_all,
        X,
        Rp_,
        Pp_,
        n_INT_all):

    rate_X = None
    n_INT_tot = np.sum(n_INT_all)
    rxn_network = rxn_network_all[:n_INT_tot, :n_INT_tot]
    rate_INT = None
    
    for n in range(Rp.shape[1]):
        # first lets just consider at INT0; this should return the array of size k
        rate_array = np.zeros(rxn_network_all.shape[1])  
        for i, j in enumerate(X[:,n]):
            if j == 0: continue
            else:        
                # the prod and react are already included implicitly
                if i > n_INT_tot-1:
                    rate_array[i] = 0
                
                #TODO, assigned the elements correctly; taking cn, a into consideration
                else: 
                    mori = np.cumsum(n_INT_all)
                    cn = np.searchsorted(mori, i, side='right')
                    incr = 0
                    a = i
                    if cn > 0:
                        incr = 0
                        if np.all(rxn_network[np.cumsum(n_INT_all)[
                                    cn - 1]:np.cumsum(n_INT_all)[cn], 0] == 0):
                            cp_idx = np.where(rxn_network[np.cumsum(n_INT_all)[
                                                cn - 1]:np.cumsum(n_INT_all)[cn], :][0] == -1)
                            tmp_idx = cp_idx[0][0].copy()
                            incr += 1
                            while tmp_idx != 0:
                                tmp_idx = np.where((rxn_network[tmp_idx, :] == -1))[0][0]
                                incr += 1

                        else:
                            for r in range(rxn_network.shape[0]):
                                if r >= np.cumsum(n_INT_all)[
                                        cn - 1] and r <= np.cumsum(n_INT_all)[cn]:
                                    continue
                                else:
                                    if np.any(rxn_network[np.cumsum(n_INT_all)[
                                                cn - 1]:np.cumsum(n_INT_all)[cn], r]):
                                        incr += 1
                        a = i - mori[cn-1] + incr 
                    a += 1
                    
                    try:
                        rate_1, rate_2 = add_rate(y, k_forward_all, k_reverse_all, rxn_network, Rp_, Pp_, a, cn, n_INT_all)    
                    except IndexError as e:
                        rate_1, rate_2 = add_rate(y, k_forward_all, k_reverse_all, rxn_network, Rp_, Pp_, 0, cn, n_INT_all)  
               
                    rate_array[i] = np.sign(j)*rate_1

                    if i + 1 >= n_INT_all[cn]: 

                        rate_array[0] = -np.sign(j)*rate_2
                    else: rate_array[i+1] = -np.sign(j)*rate_2
        if rate_X is None: rate_X = rate_array
        else: rate_X = np.vstack([rate_X, rate_array])

    return rate_X


def system_KE(
        k_forward_all,
        k_reverse_all,
        rxn_network,
        Rp_,
        Pp_,
        n_INT_all,
        initial_conc,
        jac_method="ag"):
    """"Forming the system of DE for kinetic modelling, inspried by get_dydt from overreact module

    Returns
    -------
    dydt : callable
        Reaction rate function. The actual reaction rate constants employed
        are stored in the attribute `k` of the returned function. If JAX is
        available, the attribute `jac` will hold the Jacobian function of
        `dydt`
    """
    k = rxn_network.shape[0] + Rp_[0].shape[1] + Pp_[0].shape[1]

    
    # to enforce boundary condition and the contraint
    #TODO when violated, assigning y and dydt could be better than this
    def bound_decorator(bounds):
        def decorator(func):
            def wrapper(t, y):

                dy_dt = func(t, y)

                for i in range(len(y)):
                    if y[i] < bounds[i][0]:
                        dy_dt[i] += (bounds[i][0] - y[i])/2
                        y[i] = bounds[i][0] 
                    elif y[i] > bounds[i][1]:
                        dy_dt[i] -= (y[i] - bounds[i][1])/2
                        y[i] = bounds[i][1] 
                        dy_dt[i] = 0
  
                return dy_dt
            return wrapper
        return decorator

    tolerance = 0.01
    boundary = []
    for i in range(k):
        if i == 0: boundary.append((0-tolerance, initial_conc[0]+tolerance))
        elif i >= rxn_network.shape[0] and i < rxn_network.shape[0] + Rp_[0].shape[1]:
            boundary.append((0-tolerance, initial_conc[i]+tolerance))
        else: boundary.append((0-tolerance, np.sum(initial_conc)+tolerance))

    S = jnp.asarray(S)
    @bound_decorator(boundary)
    def _dydt(t, y):
        return jnp.dot(S,y)

    def _jac(t,y):
        return jacfwd(lambda _y: _dydt(t, _y))(y)

    _dydt.jac = _jac
        
    return _dydt
    
