# Convenience Kinetics with Jax

From _Bringing metabolic networks to life: convenience rate law and thermodynamic constraints_, Liebermeister and Klipp (2006)
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1781438/

In [1]:
import projectpath

from dataclasses import dataclass
from typing import Iterable, Mapping, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np

from kb import kb
from model.reaction_network import ReactionNetwork
from model.core import Molecule, Reaction, Pathway

ArrayT = Union[np.ndarray, jnp.ndarray]

KB = kb.configure_kb()

jax.config.update('jax_enable_x64', True)

### General Rate Law, as defined in the paper:
$$
v(a,b) = E
\frac{
    k_{+}^{cat} \prod\limits_i \tilde{a}_i
    + k_{-}^{cat} \prod\limits_j \tilde{b}_j
}{
    \prod\limits_i (1 + \tilde{a}_i)
    + \prod\limits_j (1 + \tilde{b}_j)
    - 1
}
$$

In [2]:
glycolysis = KB.find(KB.pathways, 'glycolysis')[0]
network = ReactionNetwork(glycolysis.steps)

In [3]:
# Build up a list of substrate and product indices for each reaction.
width = 0
ragged_indices = []
for reaction in network.reactions():
    reaction_indices = [[], []]
    for reactant, count in reaction.stoichiometry.items():
        idx = network.reactant_index(reactant)
        if count < 0:
            reaction_indices[0].extend([idx] * -count)
        else:
            reaction_indices[1].extend([idx] * count)

    width = max(width, len(reaction_indices[0]), len(reaction_indices[1]))
    ragged_indices.append(reaction_indices)

ragged_indices

[[[0], [1]],
 [[2, 1], [3, 4, 5]],
 [[4, 6], [1, 7]],
 [[4], [8, 9]],
 [[8], [9]],
 [[8, 10, 7], [11, 5, 12]],
 [[13, 2], [11, 3]],
 [[14], [13]],
 [[14], [6, 15]],
 [[2, 16], [3, 5, 15]],
 [[2, 6, 16], [17, 7, 15, 5, 5]],
 [[18, 10, 16], [19, 20, 12]]]

In [4]:
# Build a regularized array of indices, padded with -1, and a corresponding mask padded with 0.
indices = -np.ones((len(ragged_indices), 2, width), dtype=int)
mask = np.zeros((len(ragged_indices), 2, width), dtype=int)
for i, reaction_indices in enumerate(ragged_indices):
    indices[i, 0, :len(reaction_indices[0])] = reaction_indices[0]
    indices[i, 1, :len(reaction_indices[1])] = reaction_indices[1]
    mask[i, 0, :len(reaction_indices[0])] = 1
    mask[i, 1, :len(reaction_indices[1])] = 1
print(indices)
print(mask)

[[[ 0 -1 -1 -1 -1]
  [ 1 -1 -1 -1 -1]]

 [[ 2  1 -1 -1 -1]
  [ 3  4  5 -1 -1]]

 [[ 4  6 -1 -1 -1]
  [ 1  7 -1 -1 -1]]

 [[ 4 -1 -1 -1 -1]
  [ 8  9 -1 -1 -1]]

 [[ 8 -1 -1 -1 -1]
  [ 9 -1 -1 -1 -1]]

 [[ 8 10  7 -1 -1]
  [11  5 12 -1 -1]]

 [[13  2 -1 -1 -1]
  [11  3 -1 -1 -1]]

 [[14 -1 -1 -1 -1]
  [13 -1 -1 -1 -1]]

 [[14 -1 -1 -1 -1]
  [ 6 15 -1 -1 -1]]

 [[ 2 16 -1 -1 -1]
  [ 3  5 15 -1 -1]]

 [[ 2  6 16 -1 -1]
  [17  7 15  5  5]]

 [[18 10 16 -1 -1]
  [19 20 12 -1 -1]]]
[[[1 0 0 0 0]
  [1 0 0 0 0]]

 [[1 1 0 0 0]
  [1 1 1 0 0]]

 [[1 1 0 0 0]
  [1 1 0 0 0]]

 [[1 0 0 0 0]
  [1 1 0 0 0]]

 [[1 0 0 0 0]
  [1 0 0 0 0]]

 [[1 1 1 0 0]
  [1 1 1 0 0]]

 [[1 1 0 0 0]
  [1 1 0 0 0]]

 [[1 0 0 0 0]
  [1 0 0 0 0]]

 [[1 0 0 0 0]
  [1 1 0 0 0]]

 [[1 1 0 0 0]
  [1 1 1 0 0]]

 [[1 1 1 0 0]
  [1 1 1 1 1]]

 [[1 1 1 0 0]
  [1 1 1 0 0]]]


In [5]:
# All taken from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4912430/, except where starred
# Maybe this should be in the KB somewhere...
POOLS = {KB.get(KB.compounds, met_id): conc for met_id, conc in [
    ('2pg', 9.18e-02),
    ('3pg', 1.54),
    ('6pg', 3.77),
    ('6pgdl', 1.), # ************
    ('Ery.D.4P', 4.90e-02),
    ('Fru.D.6P', 2.52),
    ('Fru.D.bis16', 15.2),
    ('Glc.D.6P', 7.88),
    ('Rib.D.5P', 7.87e-01),
    ('Rul.D.5P', 1.12e-01),
    ('Sed.D.7P', 8.82e-01),
    ('Xul.D.5P', 1.81e-01),
    ('accoa', 6.06e-01),
    ('acon', 1.61e-02),
    ('adp', 5.55e-01),
    ('akg', 4.43e-01),
    ('amp', 2.81e-01),
    ('atp', 9.63),
    ('cit', 1.96),
    ('co2', 7.52e-02),
    ('coa', 1.37),
    ('dhap', 3.06),
    ('dpg', 1.65e-02),
    ('fum', 2.88e-01),
    ('gap', 2.71e-01),
    ('glx', 0.1),  # **********
    ('h+', 1e-7),
    ('h2o', 1.0),  # Activity of the solvent is defined as 1
    ('icit', 3.67e-02),
    ('kdpg', 0.01),  # **********
    ('mal.L', 1.68),
    ('nad.ox', 2.55),
    ('nad.red', 8.36e-02),
    ('nadp.ox', 2.08e-03),
    ('nadp.red', 1.21e-01),
    ('oaa', 4.87e-04),
    ('pep', 1.84e-01),
    ('pi', 23.9),
    ('pyr', 3.66),
    ('q.ox', 0.01),  # **********
    ('q.red', 0.01),  # **********
    ('succ', 5.69e-01),
    ('succcoa', 2.33e-01),
]}

state = network.reactant_vector(POOLS)
state

array([7.88e+00, 2.52e+00, 9.63e+00, 5.55e-01, 1.52e+01, 1.00e-07,
       1.00e+00, 2.39e+01, 2.71e-01, 3.06e+00, 2.55e+00, 1.65e-02,
       8.36e-02, 1.54e+00, 9.18e-02, 1.84e-01, 3.66e+00, 2.81e-01,
       1.37e+00, 6.06e-01, 7.52e-02])

In [6]:
# Generate some random kinetic data
prng = jax.random.PRNGKey(0)
prng, prng_ = jax.random.split(prng)
kcat_ = jax.random.uniform(prng_, indices.shape[:2])
prng, prng_ = jax.random.split(prng)
km_ = jax.random.uniform(prng_, indices.shape) * mask + 1
print(kcat_, km_)



[[0.24393938 0.98864717]
 [0.79771909 0.86229213]
 [0.30505736 0.35540929]
 [0.97067327 0.08206506]
 [0.71806872 0.15587295]
 [0.70148372 0.43307884]
 [0.37127029 0.33883736]
 [0.79413118 0.04328649]
 [0.91175046 0.72198283]
 [0.63780652 0.22756648]
 [0.44202183 0.3890605 ]
 [0.17855327 0.44526287]] [[[1.7330887  1.         1.         1.         1.        ]
  [1.84511803 1.         1.         1.         1.        ]]

 [[1.68000414 1.4106932  1.         1.         1.        ]
  [1.88453739 1.06395593 1.36094188 1.         1.        ]]

 [[1.86311185 1.41219498 1.         1.         1.        ]
  [1.08700861 1.81891399 1.         1.         1.        ]]

 [[1.83816509 1.         1.         1.         1.        ]
  [1.34808971 1.67732023 1.         1.         1.        ]]

 [[1.25781949 1.         1.         1.         1.        ]
  [1.06395669 1.         1.         1.         1.        ]]

 [[1.73577349 1.01549016 1.91770028 1.         1.        ]
  [1.84548954 1.99723361 1.61885689 1.  

In [7]:
kcats = {reaction: tuple(v.tolist()) for reaction, v in zip(network.reactions(), kcat_)}
kms = {
    reaction: {
        network.reactant(idx): km
        for idx, km in zip(reaction_indices.ravel(), reaction_kms.ravel().tolist())
        if idx >= 0
    }
    for reaction, reaction_indices, reaction_kms in zip(network.reactions(), indices, km_)
}
for rxn, data in kms.items():
    print(f'{rxn.id}: [{", ".join(f"{mol.id}: {km}" for mol, km in data.items())}]')

pgi: [Glc.D.6P: 1.73308870290881, Fru.D.6P: 1.845118028680578]
pfk: [atp: 1.6800041426858345, Fru.D.6P: 1.4106931966597775, adp: 1.8845373875380222, Fru.D.bis16: 1.0639559349130323, h+: 1.360941883732501]
fbp: [Fru.D.bis16: 1.863111852460446, h2o: 1.4121949840515609, Fru.D.6P: 1.087008606641008, pi: 1.8189139870287792]
fba: [Fru.D.bis16: 1.83816508655738, gap: 1.3480897075806795, dhap: 1.6773202300618018]
tpi: [gap: 1.2578194870841028, dhap: 1.06395669472636]
gapdh: [gap: 1.735773490376081, nad.ox: 1.0154901613098066, pi: 1.9177002810527666, dpg: 1.8454895449793587, h+: 1.997233605782102, nad.red: 1.6188568909964942]
pgk: [3pg: 1.4495036494695925, atp: 1.8434500296356866, dpg: 1.310381086550176, adp: 1.248563312555589]
gpm.indep: [2pg: 1.2020587643039475, 3pg: 1.6468976747200537]
eno: [2pg: 1.1759124551051658, h2o: 1.989126170369294, pep: 1.6723482997507577]
pyk: [atp: 1.5936904245159553, pyr: 1.765441106928819, adp: 1.8637088638990469, h+: 1.2421153611883478, pep: 1.4752950199836787]


In [8]:
def kcat_array(kcats):
    return np.array([kcats.get(reaction, (0, 0)) for reaction in network.reactions()])

def km_array(kms):
    # Result shape is (#rxns, 2, width)
    result = np.ones((network.shape[1], 2, width))
    # Use self.indices_ as the source of truth; do not rely on iterating over reaction.stoichiometry.items().
    for i, rxn_indices in enumerate(indices):
        reaction = network.reaction(i)
        if reaction in kms:
            for j, side in enumerate(rxn_indices):
                for k, reactant_idx in enumerate(side):
                    if reactant_idx >= 0:
                        result[i, j, k] = kms[reaction].get(network.reactant(reactant_idx), 1)
    return result


In [9]:
print(np.sum(np.square(kcat_ - kcat_array(kcats))))
print(np.sum(np.square(km_ - km_array(kms))))

0.0
0.11959516277754548


In [10]:
network.reaction(10)

Reaction(_id='pps', name='Phosphoenolpyruvate Synthetase', shorthand='PPS', description=None, aka=None, xrefs={KEGG:R00199, EC:2.7.9.2, ECOCYC:PEPSYNTH-RXN, GO:0008986, RHEA:11364, METACYC:PEPSYNTH-RXN}, stoichiometry={Molecule [atp] adenosine 5'-triphosphate: -1, Molecule [h2o] water: -1, Molecule [pyr] pyruvate: -1, Molecule [amp] adenosine 5'-monophosphate: 1, Molecule [pi] phosphate: 1, Molecule [pep] phosphoenolpyruvate: 1, Molecule [h+] proton: 2}, catalyst=Molecule [PpsA] , reversible=False)

In [11]:
km_[10]

DeviceArray([[1.39288675, 1.76409595, 1.85526158, 1.        , 1.        ],
             [1.62408382, 1.73038468, 1.72420759, 1.56365822, 1.21783288]],            dtype=float64)

In [12]:
km_array(kms)[10]

array([[1.39288675, 1.76409595, 1.85526158, 1.        , 1.        ],
       [1.62408382, 1.73038468, 1.72420759, 1.21783288, 1.21783288]])

In [13]:
# $\tilde{a} = a_i / {km}^a_i for all i; \tilde{b} = b_j / {km}^b_j for all j$, padded with ones as neccessary
state_norm = jnp.append(state, 1)[indices] / km_
state_norm

DeviceArray([[[4.54679555e+00, 1.00000000e+00, 1.00000000e+00,
               1.00000000e+00, 1.00000000e+00],
              [1.36576629e+00, 1.00000000e+00, 1.00000000e+00,
               1.00000000e+00, 1.00000000e+00]],

             [[5.73212872e+00, 1.78635582e+00, 1.00000000e+00,
               1.00000000e+00, 1.00000000e+00],
              [2.94501984e-01, 1.42863059e+01, 7.34785234e-08,
               1.00000000e+00, 1.00000000e+00]],

             [[8.15839370e+00, 7.08117513e-01, 1.00000000e+00,
               1.00000000e+00, 1.00000000e+00],
              [2.31828891e+00, 1.31397087e+01, 1.00000000e+00,
               1.00000000e+00, 1.00000000e+00]],

             [[8.26911582e+00, 1.00000000e+00, 1.00000000e+00,
               1.00000000e+00, 1.00000000e+00],
              [2.01025198e-01, 1.82433858e+00, 1.00000000e+00,
               1.00000000e+00, 1.00000000e+00]],

             [[2.15452219e-01, 1.00000000e+00, 1.00000000e+00,
               1.00000000e+00, 1.00000000

In [14]:
# $k_{+}^{cat} \prod_i{\tilde{a}_i} + k_{-}^{cat} \prod_j{\tilde{b}_j}$
numerator = jnp.sum(kcat_ * jnp.prod(state_norm, axis=-1), axis=-1)
numerator

DeviceArray([ 2.45940347,  8.1683418 , 12.58869747,  8.05670604,
              0.60300898,  3.42748425,  2.0624625 ,  0.10112382,
              0.11111283,  7.9898593 ,  3.41749531,  0.90746114],            dtype=float64)

In [15]:
# $\prod_i{(1 + \tilde{a}_i)} + \prod_j{(1 + \tilde{b}_j)} - 1$
denominator = jnp.sum(jnp.prod(state_norm + mask, axis=-1), axis=-1) - 1
denominator

DeviceArray([ 6.91256184, 37.54626087, 61.56325138, 11.66121762,
              4.09150913, 54.71047545, 13.29907594,  2.01146047,
              1.74613848, 22.10245968, 55.09025619, 20.9081069 ],            dtype=float64)

In [16]:
prng, prng_ = jax.random.split(prng)
enzyme_conc = jax.random.uniform(prng_, indices.shape[:1])
rates = enzyme_conc * numerator / denominator
rates

DeviceArray([0.32628979, 0.21069843, 0.09611057, 0.65048012, 0.13741619,
             0.0020999 , 0.02806377, 0.02456297, 0.00987514, 0.16441616,
             0.05763836, 0.00126864], dtype=float64)

In [17]:
@dataclass
class ReactionKinetics:
    reaction: Reaction
    kcat_f: float
    kcat_b: float
    km: Mapping[Molecule, float]


class ConvenienceKinetics:
    def __init__(self, 
                 network: ReactionNetwork,
                 kinetics: Optional[Iterable[ReactionKinetics]] = None):
        self.network = network
        
        # Build up a list of substrate and product indices for each reaction.
        width = 0
        ragged_indices = []
        for reaction in network.reactions():
            reaction_indices = [[], []]
            for reactant, count in reaction.stoichiometry.items():
                idx = network.reactant_index(reactant)
                if count < 0:
                    reaction_indices[0].extend([idx] * -count)
                else:
                    reaction_indices[1].extend([idx] * count)

            width = max(width, len(reaction_indices[0]), len(reaction_indices[1]))
            ragged_indices.append(reaction_indices)
            
        # Build a regularized array of indices, padded with -1, and a corresponding mask padded with 0.
        indices = -np.ones((len(ragged_indices), 2, width), dtype=int)
        mask = np.zeros((len(ragged_indices), 2, width), dtype=int)
        for i, reaction_indices in enumerate(ragged_indices):
            indices[i, 0, :len(reaction_indices[0])] = reaction_indices[0]
            indices[i, 1, :len(reaction_indices[1])] = reaction_indices[1]
            mask[i, 0, :len(reaction_indices[0])] = 1
            mask[i, 1, :len(reaction_indices[1])] = 1

        self.width_ = width
        self.indices_ = indices
        self.mask_ = mask
        
        # Save kinetic parameters per reaction, if given
        if kinetics is not None:
            self.kcats_, self.kms_ = self.param_arrays(kinetics)
        else:
            self.kcats_ = None
            self.kms_ = None

    def param_arrays(self, kinetics: Iterable[ReactionKinetics]) -> Tuple[np.ndarray, np.ndarray]:
        """Processes ReactionKinetics structure into parameter arrays used for rate calculations."""
        kcats = np.zeros((self.network.shape[1], 2))
        kms = np.ones((self.network.shape[1], 2, self.width_))
        
        for reaction_kinetics in kinetics:
            i = self.network.reaction_index(reaction_kinetics.reaction)
            if i is not None:
                kcats[i, 0] = reaction_kinetics.kcat_f
                kcats[i, 1] = reaction_kinetics.kcat_b

                # Use self.indices_ as the source of truth for what value belongs where.
                for j, side in enumerate(self.indices_[i]):
                    for k, reactant_idx in enumerate(side):
                        if reactant_idx >= 0:
                            kms[i, j, k] = reaction_kinetics.km.get(self.network.reactant(reactant_idx), 1)
        return kcats, kms
        
    def reaction_rates(self,
                       state: ArrayT,
                       enzyme_conc: ArrayT,
                       kcats: Optional[ArrayT] = None,
                       kms: Optional[ArrayT] = None) -> ArrayT:
        """Calculates current reaction rates using the convenience kinetics formula."""
        # Use kinetic parameters as supplied, or fall back to configred intrinsic values
        if kcats is None:
            kcats = self.kcats_
        if kms is None:
            kms = self.kms_
        
        # $\tilde{a} = a_i / {km}^a_i for all i; \tilde{b} = b_j / {km}^b_j for all j$, padded with ones as neccessary
        # Appending [1] to the state vector means any index of -1 translates to unity, i.e. a no-op for multiplication.
        state_norm = jnp.append(state, 1)[self.indices_] / kms

        # $k_{+}^{cat} \prod_i{\tilde{a}_i} + k_{-}^{cat} \prod_j{\tilde{b}_j}$
        numerator = jnp.sum(kcats * jnp.prod(state_norm, axis=-1), axis=-1)
        
        # $\prod_i{(1 + \tilde{a}_i)} + \prod_j{(1 + \tilde{b}_j)} - 1$
        # state_norm + mask means (1 + \tilde{a}_i) for all real values, and 1 (i.e. a no-op for multiplication) for all padded values.
        denominator = jnp.sum(jnp.prod(state_norm + self.mask_, axis=-1), axis=-1) - 1
        
        rates = enzyme_conc * numerator / denominator
        return rates

    def dstate_dt(self,
                  state: ArrayT,
                  enzyme_conc: ArrayT,
                  kcats: Optional[ArrayT] = None,
                  kms: Optional[ArrayT] = None) -> ArrayT:
        return network.s_matrix @ self.reaction_rates(state, enzyme_conc, kcats, kms)

In [18]:
kinetics = [ReactionKinetics(rxn, kcat_f, kcat_b, kms[rxn]) for rxn, (kcat_f, kcat_b) in kcats.items()]
ck = ConvenienceKinetics(network, kinetics)

In [19]:
ck.reaction_rates(state, enzyme_conc)

DeviceArray([0.32628979, 0.21069843, 0.09611057, 0.65048012, 0.13741619,
             0.0020999 , 0.02806377, 0.02456297, 0.00987514, 0.16441616,
             0.05763836, 0.00126864], dtype=float64)

In [20]:
ck.dstate_dt(state, enzyme_conc)

DeviceArray([-0.32628979,  0.21170194, -0.46081672,  0.40317835,
             -0.53589227,  0.49249121, -0.1438738 ,  0.15164904,
              0.51096403,  0.78789631, -0.00336854,  0.03016367,
              0.00336854, -0.0035008 , -0.03443811,  0.23192966,
             -0.22332315,  0.05763836, -0.00126864,  0.00126864,
              0.00126864], dtype=float64)

## Alternative Approach to the Computation

- True data shape for $k_m$ and $\tilde{a}$ is (#rxns, 2, {ragged})
- First approach pads the ragged dimension with ones, then uses jnp.prod to collapse to (#rxns, 2)
- Alternative strings everything into a 1d array, then uses segment_prod and reshape to get the same (#rxn2, 2) result

In [21]:
class ConvenienceKinetics2:
    def __init__(self, 
                 network: ReactionNetwork,
                 kinetics: Optional[Mapping[Reaction, ReactionKinetics]] = None):
        self.network = network
        
        # Build up a list of substrate and product indices for each reaction.
        # Collapse into a 1d vector, with corresponding segment id vector to identify the original groupings.
        indices = []
        segment_ids = []
        for i, reaction in enumerate(network.reactions()):
            reaction_indices = [[], []]
            for reactant, count in reaction.stoichiometry.items():
                idx = network.reactant_index(reactant)
                if count < 0:
                    reaction_indices[0].extend([idx] * -count)
                else:
                    reaction_indices[1].extend([idx] * count)

            indices.extend(reaction_indices[0])
            segment_ids.extend([2 * i] * len(reaction_indices[0]))
            indices.extend(reaction_indices[1])
            segment_ids.extend([2 * i + 1] * len(reaction_indices[1]))

        self.indices_ = np.array(indices)
        self.segment_ids_ = np.array(segment_ids)
        
        # Save kinetic parameters per reaction, if given
        if kinetics is not None:
            self.kcats_, self.kms_ = self.param_arrays(kinetics)
        else:
            self.kcats_ = None
            self.kms_ = None

    def param_arrays(self, kinetics: Mapping[Reaction, ReactionKinetics]) -> Tuple[np.ndarray, np.ndarray]:
        """Processes ReactionKinetics structure into parameter arrays used for rate calculations."""
        kcats = []
        for reaction in self.network.reactions():
            # Require that each reaction's kinetics are included, i.e. allow this to throw an error if not.
            reaction_kinetics = kinetics[reaction]
            kcats.append([reaction_kinetics.kcat_f, reaction_kinetics.kcat_b])
        
        # Easiest to build the Km array colinear with self.indices_, inferring reaction from segment_id.
        kms = []
        for reactant_index, segment_id in zip(self.indices_, self.segment_ids_):
            # Segment ids for reaction i are 2i (substrates) and 2i + 1 (products).
            i = int(segment_id / 2)
            reaction = self.network.reaction(i)
            # Require that each reaction's kinetics are included, i.e. allow this to throw an error if not.
            reaction_kinetics = kinetics[reaction]
            kms.append(reaction_kinetics.km[self.network.reactant(reactant_index)])

        return jnp.array(kcats, dtype=np.float64), jnp.array(kms, dtype=np.float64)

    def reaction_rates(self,
                       state: ArrayT,
                       enzyme_conc: ArrayT,
                       kcats: Optional[ArrayT] = None,
                       kms: Optional[ArrayT] = None) -> ArrayT:
        """Calculates current reaction rates using the convenience kinetics formula."""
        # Use kinetic parameters as supplied, or fall back to configred intrinsic values
        if kcats is None:
            kcats = self.kcats_
        if kms is None:
            kms = self.kms_
        
        # $\tilde{a} = a_i / {km}^a_i for all i; \tilde{b} = b_j / {km}^b_j for all j$.
        state_norm = state[self.indices_] / kms

        # $k_{+}^{cat} \prod_i{\tilde{a}_i} + k_{-}^{cat} \prod_j{\tilde{b}_j}$
        num_rxns = self.network.shape[1]
        numerator = jnp.sum(kcats * jax.ops.segment_prod(state_norm, segment_ids=self.segment_ids_, num_segments=num_rxns * 2).reshape((num_rxns, 2)), axis=-1)
        
        # $\prod_i{(1 + \tilde{a}_i)} + \prod_j{(1 + \tilde{b}_j)} - 1$
        denominator = jnp.sum(jax.ops.segment_prod(state_norm + 1, segment_ids=self.segment_ids_, num_segments=num_rxns * 2).reshape((num_rxns, 2)), axis=-1) - 1
        
        rates = enzyme_conc * numerator / denominator
        return rates

    def dstate_dt(self,
                  state: ArrayT,
                  enzyme_conc: ArrayT,
                  kcats: Optional[ArrayT] = None,
                  kms: Optional[ArrayT] = None) -> ArrayT:
        return network.s_matrix @ self.reaction_rates(state, enzyme_conc, kcats, kms)

In [22]:
ck2 = ConvenienceKinetics2(network, {rk.reaction: rk for rk in kinetics})

In [23]:
ck.reaction_rates(state, enzyme_conc) - ck2.reaction_rates(state, enzyme_conc)

DeviceArray([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
              0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
              0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
              0.0000000e+00, -6.9388939e-18,  0.0000000e+00],            dtype=float64)

Numerical instability(?) - all the inputs are the same, but the sum-of-products comes out different in the 14th decimal place, for one reaction

In [24]:
%timeit ck.reaction_rates(state, enzyme_conc)
%timeit ck2.reaction_rates(state, enzyme_conc)

894 µs ± 6.16 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
2.11 ms ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
fn1 = jax.jit(ck.reaction_rates)
fn2 = jax.jit(ck2.reaction_rates)
%timeit fn1(state, enzyme_conc)
%timeit fn2(state, enzyme_conc)

4.61 µs ± 14.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
5.83 µs ± 481 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


All that and it's clearly not faster even on a CPU. Stick with the initial implementation.

## Regulation

- Parameters ($K_I$ and $K_A$) can be stored semantically with ReactionKinetics
- Simple generalized equations in the Convenience Kinetics paper
  - Activation: $v' = \frac{d}{d + K_A}v$
    - Interpretation: unbound enzyme is purely inactive
  - Inhibition: $v' = \frac{K_I}{d + K_I}v$
    - Interpretation: bound enzyme is purely inactive
- Both can be expressed using the same convention of $\tilde{d} = \frac{d}{K}$, which then looks a lot like the other convenience kinetics terms:
  - Activation: $v' = \frac{\tilde{d}}{\tilde{d} + 1}v$
  - Inhibition: $v' = \frac{1}{\tilde{d} + 1}v$


In [79]:
@dataclass
class ReactionKinetics:
    reaction: Reaction
    kcat_f: float
    kcat_b: float
    km: Mapping[Molecule, float]
    ka: Mapping[Molecule, float]
    ki: Mapping[Molecule, float]

class Ligands:
    def __init__(self, network: ReactionNetwork, ligand_lists: Iterable[Iterable[Molecule]], constants: Iterable[Mapping[Molecule, float]]):
        self.network = network

        ragged_indices = []
        width = 0
        for ligand_list in ligand_lists:
            indices = [network.reactant_index(ligand) for ligand in ligand_list]
            ragged_indices.append(indices)
            width = max(width, len(indices))
        
        # -1 as a default index lets us deref an array by appending a default value
        padded_indices = np.full((len(ragged_indices), width), -1, dtype=int)
        mask = np.zeros((len(ragged_indices), width), dtype=int)
        for i, indices in enumerate(ragged_indices):
            padded_indices[i, :len(indices)] = indices
            mask[i, :len(indices)] = 1
            
        self.width = width
        self.indices = padded_indices
        self.mask = mask
        
        self.constants = self.constants_array(constants, default=1.0)

    def constants_array(self, constants: Iterable[Mapping[Molecule, float]], default: float = 0.0) -> np.ndarray:
        result = np.full(self.indices.shape, default)
        for i, (row_indices, row_constants) in enumerate(zip(self.indices, constants)):
            for j, ligand_index in enumerate(row_indices):
                if self.mask[i, j]:
                    result[i, j] = row_constants.get(self.network.reactant(ligand_index), default)
        return result
    
    def occupancy(self, state: ArrayT, constants: Optional[ArrayT] = None, default: float = 1.0) -> jnp.ndarray:
        if constants is None:
            constants = self.constants
        # Appending [default] to the state vector means any index of -1 derefs to the default value.
        return jnp.append(state, default)[self.indices] / constants

    
class ConvenienceKinetics:
    def __init__(self,
                 network: ReactionNetwork,
                 kinetics: Mapping[Reaction, ReactionKinetics]):
        self.network = network
        self.kcats = np.array([[kinetics[reaction].kcat_f, kinetics[reaction].kcat_b] for reaction in network.reactions()])
        
        # Substrates and products represented by one Ligands set each, with Km's.
        substrates = []
        products = []
        constants = []
        for reaction in network.reactions():
            substrates.append([])
            products.append([])
            for reactant, count in reaction.stoichiometry.items():
                if count < 0:
                    substrates[-1].extend([reactant] * -count)
                else:
                    products[-1].extend([reactant] * count)

            constants.append(kinetics[reaction].km)

        self.substrates = Ligands(network, substrates, constants)
        self.products = Ligands(network, products, constants)
        
        # Activators and inhibitors represented by one Ligands set each, with Ka's or Ki's respectively.
        activators = []
        kas = []
        inhibitors = []
        kis = []
        for reaction in network.reactions():
            reaction_kinetics = kinetics[reaction]
            activators.append(reaction_kinetics.ka.keys())
            kas.append(reaction_kinetics.ka)
            inhibitors.append(reaction_kinetics.ki.keys())
            kis.append(reaction_kinetics.ki)
        
        self.activators = Ligands(network, activators, kas)
        self.inhibitors = Ligands(network, inhibitors, kis)

    def reaction_rates(self,
                       state: ArrayT,
                       enzyme_conc: ArrayT) -> ArrayT:
        kcats = self.kcats

        # $\tilde{a} = a_i / {km}^a_i for all i; \tilde{b} = b_j / {km}^b_j for all j$, padded with ones as necessary.
        occupancy_s = self.substrates.occupancy(state, default=1)
        occupancy_p = self.products.occupancy(state, default=1)

        # $k_{+}^{cat} \prod_i{\tilde{a}_i} + k_{-}^{cat} \prod_j{\tilde{b}_j}$.
        numerator = kcats[:, 0] * jnp.prod(occupancy_s, axis=-1) + kcats[:, 1] * jnp.prod(occupancy_p, axis=-1)

        # $\prod_i{(1 + \tilde{a}_i)} + \prod_j{(1 + \tilde{b}_j)} - 1$
        denominator = jnp.prod(occupancy_s * self.substrates.mask + 1, axis=-1) + jnp.prod(occupancy_p * self.products.mask + 1, axis=-1) - 1
        
        occupancy_a = self.activators.occupancy(state, default=1)
        occupancy_i = self.inhibitors.occupancy(state, default=0)
        activation = jnp.prod(occupancy_a / (occupancy_a * self.activators.mask + 1), axis=-1)
        inhibition = jnp.prod(1 / (occupancy_i + 1), axis=-1)

        return enzyme_conc * numerator / denominator * activation * inhibition

    def dstate_dt(self,
                  state: ArrayT,
                  enzyme_conc: ArrayT) -> ArrayT:
        return self.network.s_matrix @ self.reaction_rates(state, enzyme_conc)



In [80]:
k3 = {k.reaction: ReactionKinetics(
    reaction=k.reaction,
    kcat_f=k.kcat_f,
    kcat_b=k.kcat_b,
    km=k.km,
    ka={},
    ki={}
) for k in kinetics}
# Set ATP as an inhibitor of PFK
k3[glycolysis.steps[1]].ki[KB.get(KB.compounds, 'atp')] = 5.
ck3 = ConvenienceKinetics(network, k3)

In [81]:
ck3.inhibitors.constants

array([[1.],
       [5.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]])

In [82]:
ck3.reaction_rates(state, enzyme_conc)

DeviceArray([0.32628979, 0.07200903, 0.09611057, 0.65048012, 0.13741619,
             0.0020999 , 0.02806377, 0.02456297, 0.00987514, 0.16441616,
             0.05763836, 0.00126864], dtype=float64)

## Thermodynamically independent system parameters

Independent parameters
- $\Delta{G}_f$ per reactant
- velocity constant per reaction (drives kcat_f and kcat_b together)
- $K_M$ per (reaction, reactant) pair

The paper focuses a lot on the first set, though this seems to me as something largely 'known', at least within some uncertainty. The latter two seem the most likely to be unknown, and therefore needing to be fit for some model.

Ultimately it boils down to treating the forward and back kcats as dependent on everything else, instead of as 'free' parameters. This makes sense, it's a question of what you do with it.

Ok, defer this for now, see how far we get with supplied kinetic paramaters