# Convenience Kinetics with Jax

From _Bringing metabolic networks to life: convenience rate law and thermodynamic constraints_, Liebermeister and Klipp (2006)
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1781438/

In [1]:
import projectpath

from dataclasses import dataclass
from typing import Iterable, Mapping, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np

from mosmo.knowledge import kb
from mosmo.model import Molecule, Reaction, Pathway, ReactionNetwork

jax.config.update('jax_enable_x64', True)

KB = kb.configure_kb()
ArrayT = Union[np.ndarray, jnp.ndarray]

### General Rate Law, as defined in the paper:
$$
v(a,b) = E
\frac{
    k_{+}^{cat} \prod\limits_i \tilde{a}_i
    - k_{-}^{cat} \prod\limits_j \tilde{b}_j
}{
    \prod\limits_i (1 + \tilde{a}_i)
    + \prod\limits_j (1 + \tilde{b}_j)
    - 1
}
$$

With 'tilde' notation 
$$
\tilde{a}_i = \frac{a_i}{k^M_{a_i}}
$$
for any concentration vector $a$

In [2]:
# Look up the glycolysis pathway in the KB (taking the first match) and build a ReactionNetwork from its steps.
glycolysis = KB.find(KB.pathways, 'glycolysis')[0]
network = ReactionNetwork(glycolysis.steps)

In [3]:
# For every reaction, collect reactant indices split by side: [substrate indices, product indices].
# A reactant with stoichiometric coefficient n is listed |n| times on its side.
ragged_indices = []
for reaction in network.reactions:
    substrate_idx, product_idx = [], []
    for reactant, count in reaction.stoichiometry.items():
        position = network.reactants.index_of(reactant)
        side = substrate_idx if count < 0 else product_idx
        side.extend([position] * abs(count))
    ragged_indices.append([substrate_idx, product_idx])

# Longest substrate or product list over all reactions; used as the padded width below.
width = max((len(side) for pair in ragged_indices for side in pair), default=0)

ragged_indices

[[[0], [1]],
 [[2, 1], [3, 4, 5]],
 [[4, 6], [1, 7]],
 [[4], [8, 9]],
 [[8], [9]],
 [[8, 10, 7], [11, 5, 12]],
 [[13, 2], [11, 3]],
 [[14], [13]],
 [[14], [6, 15]],
 [[2, 16], [3, 5, 15]],
 [[2, 6, 16], [17, 7, 15, 5, 5]],
 [[18, 10, 16], [19, 20, 12]]]

In [4]:
# Pack the ragged index lists into rectangular (#rxns, 2, width) arrays.
# Unused slots hold -1 in `indices` and 0 in `mask`; real entries have mask 1.
padded_shape = (len(ragged_indices), 2, width)
indices = np.full(padded_shape, -1, dtype=int)
mask = np.zeros(padded_shape, dtype=int)
for i, sides in enumerate(ragged_indices):
    for j, side in enumerate(sides):
        indices[i, j, :len(side)] = side
        mask[i, j, :len(side)] = 1
print(indices)
print(mask)

[[[ 0 -1 -1 -1 -1]
  [ 1 -1 -1 -1 -1]]

 [[ 2  1 -1 -1 -1]
  [ 3  4  5 -1 -1]]

 [[ 4  6 -1 -1 -1]
  [ 1  7 -1 -1 -1]]

 [[ 4 -1 -1 -1 -1]
  [ 8  9 -1 -1 -1]]

 [[ 8 -1 -1 -1 -1]
  [ 9 -1 -1 -1 -1]]

 [[ 8 10  7 -1 -1]
  [11  5 12 -1 -1]]

 [[13  2 -1 -1 -1]
  [11  3 -1 -1 -1]]

 [[14 -1 -1 -1 -1]
  [13 -1 -1 -1 -1]]

 [[14 -1 -1 -1 -1]
  [ 6 15 -1 -1 -1]]

 [[ 2 16 -1 -1 -1]
  [ 3  5 15 -1 -1]]

 [[ 2  6 16 -1 -1]
  [17  7 15  5  5]]

 [[18 10 16 -1 -1]
  [19 20 12 -1 -1]]]
[[[1 0 0 0 0]
  [1 0 0 0 0]]

 [[1 1 0 0 0]
  [1 1 1 0 0]]

 [[1 1 0 0 0]
  [1 1 0 0 0]]

 [[1 0 0 0 0]
  [1 1 0 0 0]]

 [[1 0 0 0 0]
  [1 0 0 0 0]]

 [[1 1 1 0 0]
  [1 1 1 0 0]]

 [[1 1 0 0 0]
  [1 1 0 0 0]]

 [[1 0 0 0 0]
  [1 0 0 0 0]]

 [[1 0 0 0 0]
  [1 1 0 0 0]]

 [[1 1 0 0 0]
  [1 1 1 0 0]]

 [[1 1 1 0 0]
  [1 1 1 1 1]]

 [[1 1 1 0 0]
  [1 1 1 0 0]]]


In [5]:
# Metabolite pool concentrations, keyed by KB molecule.
# All taken from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4912430/, except where starred.
# NOTE(review): units are not stated here (presumably mM) -- confirm against the cited source.
# Maybe this should be in the KB somewhere...
POOLS = {KB(met_id): conc for met_id, conc in [
    ('2pg', 9.18e-02),
    ('3pg', 1.54),
    ('6pg', 3.77),
    ('6pgdl', 1.), # * starred: not taken from the cited source
    ('Ery.D.4P', 4.90e-02),
    ('Fru.D.6P', 2.52),
    ('Fru.D.bis16', 15.2),
    ('Glc.D.6P', 7.88),
    ('Rib.D.5P', 7.87e-01),
    ('Rul.D.5P', 1.12e-01),
    ('Sed.D.7P', 8.82e-01),
    ('Xul.D.5P', 1.81e-01),
    ('accoa', 6.06e-01),
    ('acon', 1.61e-02),
    ('adp', 5.55e-01),
    ('akg', 4.43e-01),
    ('amp', 2.81e-01),
    ('atp', 9.63),
    ('cit', 1.96),
    ('co2', 7.52e-02),
    ('coa', 1.37),
    ('dhap', 3.06),
    ('dpg', 1.65e-02),
    ('fum', 2.88e-01),
    ('gap', 2.71e-01),
    ('glx', 0.1),  # * starred: not taken from the cited source
    ('h+', 1e-7),
    ('h2o', 1.0),  # Activity of the solvent is defined as 1
    ('icit', 3.67e-02),
    ('kdpg', 0.01),  # * starred: not taken from the cited source
    ('mal.L', 1.68),
    ('nad.ox', 2.55),
    ('nad.red', 8.36e-02),
    ('nadp.ox', 2.08e-03),
    ('nadp.red', 1.21e-01),
    ('oaa', 4.87e-04),
    ('pep', 1.84e-01),
    ('pi', 23.9),
    ('pyr', 3.66),
    ('q.ox', 0.01),  # * starred: not taken from the cited source
    ('q.red', 0.01),  # * starred: not taken from the cited source
    ('succ', 5.69e-01),
    ('succcoa', 2.33e-01),
]}

# Pack the pool values into a state vector ordered like network.reactants.
state = network.reactants.pack(POOLS)
state

array([7.88e+00, 2.52e+00, 9.63e+00, 5.55e-01, 1.52e+01, 1.00e-07,
       1.00e+00, 2.39e+01, 2.71e-01, 3.06e+00, 2.55e+00, 1.65e-02,
       8.36e-02, 1.54e+00, 9.18e-02, 1.84e-01, 3.66e+00, 2.81e-01,
       1.37e+00, 6.06e-01, 7.52e-02])

In [6]:
# Generate some random kinetic data from a fixed seed, so the values are reproducible.
prng = jax.random.PRNGKey(0)
# Split off a fresh subkey before each draw; `prng` carries the remaining stream forward.
prng, prng_ = jax.random.split(prng)
# One (forward, backward) kcat pair per reaction, shape (#rxns, 2), uniform in [0, 1).
kcat_ = jax.random.uniform(prng_, indices.shape[:2])
prng, prng_ = jax.random.split(prng)
# Km per padded slot: uniform in [1, 2) for real reactants, exactly 1 for padding (mask == 0).
km_ = jax.random.uniform(prng_, indices.shape) * mask + 1
print(kcat_, km_)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


[[0.24393938 0.98864717]
 [0.79771909 0.86229213]
 [0.30505736 0.35540929]
 [0.97067327 0.08206506]
 [0.71806872 0.15587295]
 [0.70148372 0.43307884]
 [0.37127029 0.33883736]
 [0.79413118 0.04328649]
 [0.91175046 0.72198283]
 [0.63780652 0.22756648]
 [0.44202183 0.3890605 ]
 [0.17855327 0.44526287]] [[[1.7330887  1.         1.         1.         1.        ]
  [1.84511803 1.         1.         1.         1.        ]]

 [[1.68000414 1.4106932  1.         1.         1.        ]
  [1.88453739 1.06395593 1.36094188 1.         1.        ]]

 [[1.86311185 1.41219498 1.         1.         1.        ]
  [1.08700861 1.81891399 1.         1.         1.        ]]

 [[1.83816509 1.         1.         1.         1.        ]
  [1.34808971 1.67732023 1.         1.         1.        ]]

 [[1.25781949 1.         1.         1.         1.        ]
  [1.06395669 1.         1.         1.         1.        ]]

 [[1.73577349 1.01549016 1.91770028 1.         1.        ]
  [1.84548954 1.99723361 1.61885689 1.  

In [7]:
# Repackage the raw parameter arrays as dictionaries: kcats keyed by reaction, Km values keyed by
# reaction and then by molecule. A molecule that fills several slots of one reaction (e.g. 2 H+)
# keeps only the Km of its last slot, because later entries overwrite earlier ones in the dict.
kcats = {}
kms = {}
for reaction, kcat_pair, index_block, km_block in zip(network.reactions, kcat_, indices, km_):
    kcats[reaction] = tuple(kcat_pair.tolist())
    km_by_molecule = {}
    for idx, km in zip(index_block.ravel(), km_block.ravel().tolist()):
        if idx >= 0:
            km_by_molecule[network.reactants[idx]] = km
    kms[reaction] = km_by_molecule

for rxn, data in kms.items():
    entries = ", ".join(f"{mol.id}: {km}" for mol, km in data.items())
    print(f'{rxn.id}: [{entries}]')

pgi: [Glc.D.6P: 1.73308870290881, Fru.D.6P: 1.845118028680578]
pfk: [atp: 1.6800041426858345, Fru.D.6P: 1.4106931966597775, adp: 1.8845373875380222, Fru.D.bis16: 1.0639559349130323, h+: 1.360941883732501]
fbp: [Fru.D.bis16: 1.863111852460446, h2o: 1.4121949840515609, Fru.D.6P: 1.087008606641008, pi: 1.8189139870287792]
fba: [Fru.D.bis16: 1.83816508655738, gap: 1.3480897075806795, dhap: 1.6773202300618018]
tpi: [gap: 1.2578194870841028, dhap: 1.06395669472636]
gapdh: [gap: 1.735773490376081, nad.ox: 1.0154901613098066, pi: 1.9177002810527666, dpg: 1.8454895449793587, h+: 1.997233605782102, nad.red: 1.6188568909964942]
pgk: [3pg: 1.4495036494695925, atp: 1.8434500296356866, dpg: 1.310381086550176, adp: 1.248563312555589]
gpm.indep: [2pg: 1.2020587643039475, 3pg: 1.6468976747200537]
eno: [2pg: 1.1759124551051658, h2o: 1.989126170369294, pep: 1.6723482997507577]
pyk: [atp: 1.5936904245159553, pyr: 1.765441106928819, adp: 1.8637088638990469, h+: 1.2421153611883478, pep: 1.4752950199836787]


In [8]:
def kcat_array(kcats):
    """Builds a (#rxns, 2) array of (forward, backward) kcats in network reaction order.

    Reactions missing from `kcats` get (0, 0).
    """
    missing = (0, 0)
    rows = []
    for reaction in network.reactions:
        rows.append(kcats.get(reaction, missing))
    return np.array(rows)

def km_array(kms):
    """Builds a (#rxns, 2, width) array of Km values laid out like the notebook-level `indices` array.

    Padded slots, reactions missing from `kms`, and molecules missing from a reaction's mapping all get 1.
    """
    # Result shape is (#rxns, 2, width); every slot starts at the neutral value 1.
    result = np.ones((network.shape[1], 2, width))
    # Use `indices` as the source of truth for slot positions; do not rely on iterating over
    # reaction.stoichiometry.items().
    for (i, j, k), reactant_idx in np.ndenumerate(indices):
        if reactant_idx < 0:
            continue  # padding slot
        reaction = network.reactions[i]
        if reaction in kms:
            result[i, j, k] = kms[reaction].get(network.reactants[reactant_idx], 1)
    return result


In [9]:
# Round-trip check: sum of squared differences between the original arrays and the ones rebuilt from the dicts.
# Zero means the dict representation preserved every value.
print(np.sum(np.square(kcat_ - kcat_array(kcats))))
print(np.sum(np.square(km_ - km_array(kms))))

0.0
0.11959516277754548


In [10]:
# Inspect the reaction at index 10 to track down the non-zero Km difference above.
network.reactions[10]

[pps] ATP + H2O + pyr => AMP + Pi + PEP + 2 H+

In [11]:
# Original randomly generated Km slots for reaction 10.
km_[10]

Array([[1.39288675, 1.76409595, 1.85526158, 1.        , 1.        ],
       [1.62408382, 1.73038468, 1.72420759, 1.56365822, 1.21783288]],      dtype=float64)

In [12]:
# Km slots for reaction 10 after the round trip through the per-molecule dict.
km_array(kms)[10]

array([[1.39288675, 1.76409595, 1.85526158, 1.        , 1.        ],
       [1.62408382, 1.73038468, 1.72420759, 1.21783288, 1.21783288]])

- The Km array allows for multiple Km values, in this case for the 2 protons produced in the reaction
- But the most generalized form of convenience kinetics allows only for one Km per enzyme, substrate combination
- Microscopically, it actually makes sense we need separate binding sites for each copy of a molecule, and these binding sites can (and probably will) have different affinities
- Empirically it may be difficult to distinguish these separate affinities, so we will probably measure and report a single Km per substrate
- Leave this unresolved for now, but consider a `kms` data structure that allows for multiple affinities/Km's

In [13]:
# $\tilde{a} = a_i / {km}^a_i for all i; \tilde{b} = b_j / {km}^b_j for all j$, padded with ones as necessary.
# Appending 1 to the state vector makes the padding index -1 resolve to 1 (padded Km slots are also 1).
state_norm = jnp.append(state, 1)[indices] / km_
state_norm

Array([[[4.54679555e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00],
        [1.36576629e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00]],

       [[5.73212872e+00, 1.78635582e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00],
        [2.94501984e-01, 1.42863059e+01, 7.34785234e-08, 1.00000000e+00,
         1.00000000e+00]],

       [[8.15839370e+00, 7.08117513e-01, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00],
        [2.31828891e+00, 1.31397087e+01, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00]],

       [[8.26911582e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00],
        [2.01025198e-01, 1.82433858e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00]],

       [[2.15452219e-01, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00],
        [2.87605691e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00]],



In [14]:
# $k_{+}^{cat} \prod_i{\tilde{a}_i} - k_{-}^{cat} \prod_j{\tilde{b}_j}$
# The [[1, -1]] factor negates the product-side term so that summing over the side axis gives the difference.
numerator = jnp.sum(kcat_ * np.array([[1, -1]]) * jnp.prod(state_norm, axis=-1), axis=-1)
numerator

Array([-0.24111846,  8.16834126, -9.06400288,  7.99651328, -0.29358998,
        3.42748425,  2.05866944,  0.02017016,  0.03124248,  7.98985929,
        3.41749531,  0.90676452], dtype=float64)

In [15]:
# $\prod_i{(1 + \tilde{a}_i)} + \prod_j{(1 + \tilde{b}_j)} - 1$
# Adding `mask` gives (1 + x) for real slots and leaves padded slots at 1, a no-op for the product.
denominator = jnp.sum(jnp.prod(state_norm + mask, axis=-1), axis=-1) - 1
denominator

Array([ 6.91256184, 37.54626087, 61.56325138, 11.66121762,  4.09150913,
       54.71047545, 13.29907594,  2.01146047,  1.74613848, 22.10245968,
       55.09025619, 20.9081069 ], dtype=float64)

In [16]:
# Draw one random enzyme concentration per reaction, then apply the rate law v = E * numerator / denominator.
prng, prng_ = jax.random.split(prng)
enzyme_conc = jax.random.uniform(prng_, indices.shape[:1])
rates = enzyme_conc * numerator / denominator
rates

Array([-0.03198926,  0.21069841, -0.06920069,  0.64562029, -0.0669045 ,
        0.0020999 ,  0.02801216,  0.00489933,  0.00277667,  0.16441616,
        0.05763836,  0.00126766], dtype=float64)

In [17]:
@dataclass
class ReactionKinetics:
    """Kinetic parameters of one reaction, keyed semantically rather than by array position."""
    # The reaction these parameters belong to.
    reaction: Reaction
    # Catalytic constant in the forward direction.
    kcat_f: float
    # Catalytic constant in the backward direction.
    kcat_b: float
    # Km per reactant of the reaction (a single value per molecule).
    km: Mapping[Molecule, float]


class ConvenienceKinetics:
    """Convenience-kinetics rate law evaluated for all reactions of a network at once.

    Substrate and product indices of every reaction are stored in a padded (#rxns, 2, width) array, so
    the rate law reduces to a few array operations over the last axis.
    """

    def __init__(self,
                 network: ReactionNetwork,
                 kinetics: Optional[Iterable[ReactionKinetics]] = None):
        """Precomputes index and mask arrays for `network` and, optionally, default kinetic parameters.

        Args:
            network: the reaction network whose reactions are to be evaluated.
            kinetics: optional per-reaction parameters, stored as defaults for rate calculations.
        """
        self.network = network

        # Build up a list of substrate and product indices for each reaction.
        # A reactant with stoichiometric coefficient n is repeated |n| times on its side.
        width = 0
        ragged_indices = []
        for reaction in network.reactions:
            reaction_indices = [[], []]
            for reactant, count in reaction.stoichiometry.items():
                idx = network.reactants.index_of(reactant)
                if count < 0:
                    reaction_indices[0].extend([idx] * -count)
                else:
                    reaction_indices[1].extend([idx] * count)

            width = max(width, len(reaction_indices[0]), len(reaction_indices[1]))
            ragged_indices.append(reaction_indices)

        # Build a regularized array of indices, padded with -1, and a corresponding mask padded with 0.
        indices = -np.ones((len(ragged_indices), 2, width), dtype=int)
        mask = np.zeros((len(ragged_indices), 2, width), dtype=int)
        for i, reaction_indices in enumerate(ragged_indices):
            indices[i, 0, :len(reaction_indices[0])] = reaction_indices[0]
            indices[i, 1, :len(reaction_indices[1])] = reaction_indices[1]
            mask[i, 0, :len(reaction_indices[0])] = 1
            mask[i, 1, :len(reaction_indices[1])] = 1

        self.width_ = width
        self.indices_ = indices
        self.mask_ = mask

        # Save kinetic parameters per reaction, if given
        if kinetics is not None:
            self.kcats_, self.kms_ = self.param_arrays(kinetics)
        else:
            self.kcats_ = None
            self.kms_ = None

    def param_arrays(self, kinetics: Iterable[ReactionKinetics]) -> Tuple[np.ndarray, np.ndarray]:
        """Processes ReactionKinetics structure into parameter arrays used for rate calculations.

        Returns:
            (kcats, kms): kcats has shape (#rxns, 2) and defaults to 0; kms has shape (#rxns, 2, width)
            and defaults to 1 for padding, for reactions without kinetics, and for molecules without a Km.
        """
        kcats = np.zeros((self.network.shape[1], 2))
        kms = np.ones((self.network.shape[1], 2, self.width_))

        for reaction_kinetics in kinetics:
            i = self.network.reactions.index_of(reaction_kinetics.reaction)
            if i is not None:
                kcats[i, 0] = reaction_kinetics.kcat_f
                kcats[i, 1] = reaction_kinetics.kcat_b

                # Use self.indices_ as the source of truth for what value belongs where.
                for j, side in enumerate(self.indices_[i]):
                    for k, reactant_idx in enumerate(side):
                        if reactant_idx >= 0:
                            kms[i, j, k] = reaction_kinetics.km.get(self.network.reactants[reactant_idx], 1)
        return kcats, kms

    def reaction_rates(self,
                       state: ArrayT,
                       enzyme_conc: ArrayT,
                       kcats: Optional[ArrayT] = None,
                       kms: Optional[ArrayT] = None) -> ArrayT:
        """Calculates current reaction rates using the convenience kinetics formula.

        Args:
            state: concentration per reactant, ordered like network.reactants.
            enzyme_conc: enzyme concentration per reaction.
            kcats: optional (#rxns, 2) override of the configured kcats.
            kms: optional (#rxns, 2, width) override of the configured Km values.

        Raises:
            ValueError: if a parameter array is neither passed nor configured at construction.
        """
        # Use kinetic parameters as supplied, or fall back to configured intrinsic values
        if kcats is None:
            kcats = self.kcats_
        if kms is None:
            kms = self.kms_
        if kcats is None or kms is None:
            raise ValueError('kcats and kms must be passed explicitly or configured via `kinetics`')

        # $\tilde{a} = a_i / {km}^a_i for all i; \tilde{b} = b_j / {km}^b_j for all j$, padded with ones as necessary
        # Appending [1] to the state vector means any index of -1 translates to unity, i.e. a no-op for multiplication.
        state_norm = jnp.append(state, 1)[self.indices_] / kms

        # $k_{+}^{cat} \prod_i{\tilde{a}_i} - k_{-}^{cat} \prod_j{\tilde{b}_j}$
        # The [[1, -1]] factor negates the product-side term before summing over the side axis.
        numerator = jnp.sum(kcats * jnp.array([[1, -1]]) * jnp.prod(state_norm, axis=-1), axis=-1)

        # $\prod_i{(1 + \tilde{a}_i)} + \prod_j{(1 + \tilde{b}_j)} - 1$
        # state_norm + mask means (1 + \tilde{a}_i) for all real values, and 1 (i.e. a no-op for multiplication) for all padded values.
        denominator = jnp.sum(jnp.prod(state_norm + self.mask_, axis=-1), axis=-1) - 1

        rates = enzyme_conc * numerator / denominator
        return rates

    def dstate_dt(self,
                  state: ArrayT,
                  enzyme_conc: ArrayT,
                  kcats: Optional[ArrayT] = None,
                  kms: Optional[ArrayT] = None) -> ArrayT:
        """Returns the rate of change of each reactant: the stoichiometry matrix times the reaction rates."""
        # Use this instance's network, not whatever `network` happens to be bound to in the notebook.
        return self.network.s_matrix @ self.reaction_rates(state, enzyme_conc, kcats, kms)

In [18]:
# Wrap the per-reaction kcats and Km dicts as ReactionKinetics records and build the calculator from them.
kinetics = [ReactionKinetics(rxn, kcat_f, kcat_b, kms[rxn]) for rxn, (kcat_f, kcat_b) in kcats.items()]
ck = ConvenienceKinetics(network, kinetics)

In [19]:
# Reaction rates from the class-based implementation, for comparison with the step-by-step `rates` above.
ck.reaction_rates(state, enzyme_conc)

Array([-0.03198926,  0.21069841, -0.06920069,  0.64562029, -0.0669045 ,
        0.0020999 ,  0.02801216,  0.00489933,  0.00277667,  0.16441616,
        0.05763836,  0.00126766], dtype=float64)

In [20]:
# Net rate of change per reactant (stoichiometry matrix times reaction rates).
ck.dstate_dt(state, enzyme_conc)

Array([ 0.03198926, -0.31188836, -0.46076509,  0.40312673, -0.36572119,
        0.49249119,  0.014339  , -0.01366223,  0.7104249 ,  0.57871579,
       -0.00336756,  0.03011206,  0.00336756, -0.02311283, -0.007676  ,
        0.22483119, -0.22332218,  0.05763836, -0.00126766,  0.00126766,
        0.00126766], dtype=float64)

## Alternative Approach to the Computation

- True data shape for $k_m$ and $\tilde{a}$ is (#rxns, 2, {ragged})
- First approach pads the ragged dimension with ones, then uses jnp.prod to collapse to (#rxns, 2)
- Alternative strings everything into a 1d array, then uses segment_prod and reshape to get the same (#rxns, 2) result

In [21]:
class ConvenienceKinetics2:
    """Convenience-kinetics rate law using flat index vectors and segment products instead of padding.

    All substrate and product indices are concatenated into one 1d vector. A parallel vector of
    segment ids records which (reaction, side) group each entry belongs to.
    """

    def __init__(self,
                 network: ReactionNetwork,
                 kinetics: Optional[Mapping[Reaction, ReactionKinetics]] = None):
        """Precomputes flat index and segment-id vectors and, optionally, default kinetic parameters.

        Args:
            network: the reaction network whose reactions are to be evaluated.
            kinetics: optional parameters per reaction; if given, every reaction of the network must be present.
        """
        self.network = network

        # Build up a list of substrate and product indices for each reaction.
        # Collapse into a 1d vector, with corresponding segment id vector to identify the original groupings.
        indices = []
        segment_ids = []
        for i, reaction in enumerate(network.reactions):
            reaction_indices = [[], []]
            for reactant, count in reaction.stoichiometry.items():
                idx = network.reactants.index_of(reactant)
                if count < 0:
                    reaction_indices[0].extend([idx] * -count)
                else:
                    reaction_indices[1].extend([idx] * count)

            # Segment ids for reaction i are 2i (substrates) and 2i + 1 (products).
            indices.extend(reaction_indices[0])
            segment_ids.extend([2 * i] * len(reaction_indices[0]))
            indices.extend(reaction_indices[1])
            segment_ids.extend([2 * i + 1] * len(reaction_indices[1]))

        self.indices_ = np.array(indices)
        self.segment_ids_ = np.array(segment_ids)

        # Save kinetic parameters per reaction, if given
        if kinetics is not None:
            self.kcats_, self.kms_ = self.param_arrays(kinetics)
        else:
            self.kcats_ = None
            self.kms_ = None

    def param_arrays(self, kinetics: Mapping[Reaction, ReactionKinetics]) -> Tuple[np.ndarray, np.ndarray]:
        """Processes ReactionKinetics structure into parameter arrays used for rate calculations.

        Returns:
            (kcats, kms): kcats has shape (#rxns, 2); kms is 1d and colinear with self.indices_.

        Raises:
            KeyError: if a reaction of the network, or a Km for one of its reactants, is missing.
        """
        kcats = []
        for reaction in self.network.reactions:
            # Require that each reaction's kinetics are included, i.e. allow this to throw an error if not.
            reaction_kinetics = kinetics[reaction]
            kcats.append([reaction_kinetics.kcat_f, reaction_kinetics.kcat_b])

        # Easiest to build the Km array colinear with self.indices_, inferring reaction from segment_id.
        kms = []
        for reactant_index, segment_id in zip(self.indices_, self.segment_ids_):
            # Segment ids for reaction i are 2i (substrates) and 2i + 1 (products).
            # Integer floor division avoids a round trip through float.
            i = int(segment_id) // 2
            reaction = self.network.reactions[i]
            # Require that each reaction's kinetics are included, i.e. allow this to throw an error if not.
            reaction_kinetics = kinetics[reaction]
            kms.append(reaction_kinetics.km[self.network.reactants[reactant_index]])

        return jnp.array(kcats, dtype=np.float64), jnp.array(kms, dtype=np.float64)

    def reaction_rates(self,
                       state: ArrayT,
                       enzyme_conc: ArrayT,
                       kcats: Optional[ArrayT] = None,
                       kms: Optional[ArrayT] = None) -> ArrayT:
        """Calculates current reaction rates using the convenience kinetics formula.

        Args:
            state: concentration per reactant, ordered like network.reactants.
            enzyme_conc: enzyme concentration per reaction.
            kcats: optional (#rxns, 2) override of the configured kcats.
            kms: optional 1d override of the configured Km values, colinear with self.indices_.
        """
        # Use kinetic parameters as supplied, or fall back to configured intrinsic values
        if kcats is None:
            kcats = self.kcats_
        if kms is None:
            kms = self.kms_

        # $\tilde{a} = a_i / {km}^a_i for all i; \tilde{b} = b_j / {km}^b_j for all j$.
        state_norm = state[self.indices_] / kms

        num_rxns = self.network.shape[1]

        def side_products(values):
            # Product of `values` within each (reaction, side) segment, reshaped to (#rxns, 2).
            return jax.ops.segment_prod(
                values, segment_ids=self.segment_ids_, num_segments=num_rxns * 2).reshape((num_rxns, 2))

        # $k_{+}^{cat} \prod_i{\tilde{a}_i} - k_{-}^{cat} \prod_j{\tilde{b}_j}$
        numerator = jnp.sum(kcats * jnp.array([[1, -1]]) * side_products(state_norm), axis=-1)

        # $\prod_i{(1 + \tilde{a}_i)} + \prod_j{(1 + \tilde{b}_j)} - 1$
        denominator = jnp.sum(side_products(state_norm + 1), axis=-1) - 1

        rates = enzyme_conc * numerator / denominator
        return rates

    def dstate_dt(self,
                  state: ArrayT,
                  enzyme_conc: ArrayT,
                  kcats: Optional[ArrayT] = None,
                  kms: Optional[ArrayT] = None) -> ArrayT:
        """Returns the rate of change of each reactant: the stoichiometry matrix times the reaction rates."""
        # Use this instance's network, not whatever `network` happens to be bound to in the notebook.
        return self.network.s_matrix @ self.reaction_rates(state, enzyme_conc, kcats, kms)

In [22]:
# ConvenienceKinetics2 takes its kinetics as a mapping keyed by reaction, so re-key the list used for `ck`.
ck2 = ConvenienceKinetics2(network, {rk.reaction: rk for rk in kinetics})

In [23]:
# Element-wise difference between the two implementations; should be zero up to floating-point rounding.
ck.reaction_rates(state, enzyme_conc) - ck2.reaction_rates(state, enzyme_conc)

Array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  4.33680869e-19,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00, -6.93889390e-18,  0.00000000e+00],      dtype=float64)

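Floating-point rounding rather than a real discrepancy: all the inputs are the same, but the two implementations presumably multiply the terms in a different order, so two of the twelve reactions differ at the $10^{-18}$ level (about one unit in the last place of a float64)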

In [24]:
# Time both implementations as plain (un-jitted) calls.
%timeit ck.reaction_rates(state, enzyme_conc)
%timeit ck2.reaction_rates(state, enzyme_conc)

1.28 ms ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
3.56 ms ± 99 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
# Repeat the timing comparison with each method wrapped in jax.jit.
fn1 = jax.jit(ck.reaction_rates)
fn2 = jax.jit(ck2.reaction_rates)
%timeit fn1(state, enzyme_conc)
%timeit fn2(state, enzyme_conc)

5.9 µs ± 23.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
6.27 µs ± 23.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


All that and it's not faster even on a CPU. Stick with the initial implementation.

## Regulation

- Parameters ($K_I$ and $K_A$) can be stored semantically with ReactionKinetics
- Simple generalized equations in the Convenience Kinetics paper
  - Activation: $v' = \frac{d}{d + K_A}v$
    - Interpretation: unbound enzyme is purely inactive
  - Inhibition: $v' = \frac{K_I}{d + K_I}v$
    - Interpretation: bound enzyme is purely inactive
- Both can be expressed using the same convention of $\tilde{d} = \frac{d}{K}$, which then looks a lot like the other convenience kinetics terms:
  - Activation: $v' = \frac{\tilde{d}}{\tilde{d} + 1}v$
  - Inhibition: $v' = \frac{1}{\tilde{d} + 1}v$


In [26]:
# NOTE: this redefines (and shadows) the ReactionKinetics declared earlier in the notebook,
# adding activation and inhibition constants.
@dataclass
class ReactionKinetics:
    """Kinetic and regulatory parameters of one reaction."""
    # The reaction these parameters belong to.
    reaction: Reaction
    # Catalytic constant in the forward direction.
    kcat_f: float
    # Catalytic constant in the backward direction.
    kcat_b: float
    # Km per reactant of the reaction.
    km: Mapping[Molecule, float]
    # Activation constant (K_A) per activator molecule; empty if the reaction has no activators.
    ka: Mapping[Molecule, float]
    # Inhibition constant (K_I) per inhibitor molecule; empty if the reaction has no inhibitors.
    ki: Mapping[Molecule, float]

class Ligands:
    """A padded table of ligand indices (one row per reaction) with one binding constant per slot."""

    def __init__(self, network: ReactionNetwork, ligand_lists: Iterable[Iterable[Molecule]], constants: Iterable[Mapping[Molecule, float]]):
        """Builds the padded index and mask arrays, and the matching array of binding constants.

        Args:
            network: supplies the reactant index of every ligand.
            ligand_lists: one iterable of ligand molecules per row.
            constants: one molecule -> constant mapping per row, in the same order as ligand_lists.
        """
        self.network = network

        # Translate every ligand into its reactant index; rows may differ in length.
        ragged = [[network.reactants.index_of(ligand) for ligand in ligand_list] for ligand_list in ligand_lists]
        longest = max((len(row) for row in ragged), default=0)

        # -1 as a default index lets us deref an array by appending a default value
        table = np.full((len(ragged), longest), -1, dtype=int)
        occupied = np.zeros((len(ragged), longest), dtype=int)
        for row_number, row in enumerate(ragged):
            count = len(row)
            table[row_number, :count] = row
            occupied[row_number, :count] = 1

        self.width = longest
        self.indices = table
        self.mask = occupied

        self.constants = self.constants_array(constants, default=1.0)

    def constants_array(self, constants: Iterable[Mapping[Molecule, float]], default: float = 0.0) -> np.ndarray:
        """Lays out per-row constants in the shape of self.indices.

        Padded slots and molecules missing from a row's mapping receive `default`.
        """
        result = np.full(self.indices.shape, default)
        rows = zip(self.indices, self.mask, constants)
        for row_number, (ligand_row, mask_row, lookup) in enumerate(rows):
            # Visit only the real (unpadded) slots of this row.
            for column in np.flatnonzero(mask_row):
                molecule = self.network.reactants[ligand_row[column]]
                result[row_number, column] = lookup.get(molecule, default)
        return result

    def occupancy(self, state: ArrayT, constants: Optional[ArrayT] = None, default: float = 1.0) -> jnp.ndarray:
        """Returns concentration / constant for every slot, using the stored constants unless overridden."""
        divisor = self.constants if constants is None else constants
        # Appending [default] to the state vector means any index of -1 derefs to the default value.
        extended_state = jnp.append(state, default)
        return extended_state[self.indices] / divisor

    
class ConvenienceKinetics:
    """Convenience kinetics with multiplicative activation and inhibition terms.

    Substrates, products, activators and inhibitors are each represented by a Ligands table.
    """

    def __init__(self,
                 network: ReactionNetwork,
                 kinetics: Mapping[Reaction, ReactionKinetics]):
        """Collects per-reaction parameters into arrays; every reaction of `network` must be in `kinetics`."""
        self.network = network

        # One ReactionKinetics per reaction, in network order (a missing reaction raises KeyError).
        per_reaction = [kinetics[reaction] for reaction in network.reactions]
        self.kcats = np.array([[rk.kcat_f, rk.kcat_b] for rk in per_reaction])

        # Substrates and products represented by one Ligands set each, with Km's.
        # A reactant with stoichiometric coefficient n occupies |n| slots on its side.
        substrates = []
        products = []
        for reaction in network.reactions:
            consumed, produced = [], []
            for reactant, count in reaction.stoichiometry.items():
                if count < 0:
                    consumed.extend([reactant] * -count)
                else:
                    produced.extend([reactant] * count)
            substrates.append(consumed)
            products.append(produced)

        kms = [rk.km for rk in per_reaction]
        self.substrates = Ligands(network, substrates, kms)
        self.products = Ligands(network, products, kms)

        # Activators and inhibitors represented by one Ligands set each, with Ka's or Ki's respectively.
        kas = [rk.ka for rk in per_reaction]
        kis = [rk.ki for rk in per_reaction]
        self.activators = Ligands(network, [ka.keys() for ka in kas], kas)
        self.inhibitors = Ligands(network, [ki.keys() for ki in kis], kis)

    def reaction_rates(self,
                       state: ArrayT,
                       enzyme_conc: ArrayT) -> ArrayT:
        """Calculates reaction rates with the convenience kinetics formula, scaled by regulation terms."""
        kcat_forward = self.kcats[:, 0]
        kcat_backward = self.kcats[:, 1]

        # $\tilde{a} = a_i / {km}^a_i for all i; \tilde{b} = b_j / {km}^b_j for all j$, padded with ones as necessary.
        occ_substrates = self.substrates.occupancy(state, default=1)
        occ_products = self.products.occupancy(state, default=1)

        # $k_{+}^{cat} \prod_i{\tilde{a}_i} - k_{-}^{cat} \prod_j{\tilde{b}_j}$.
        forward = kcat_forward * jnp.prod(occ_substrates, axis=-1)
        backward = kcat_backward * jnp.prod(occ_products, axis=-1)
        numerator = forward - backward

        # $\prod_i{(1 + \tilde{a}_i)} + \prod_j{(1 + \tilde{b}_j)} - 1$
        # Multiplying by the mask zeroes padded slots, so they contribute a factor of 1.
        denominator = (jnp.prod(occ_substrates * self.substrates.mask + 1, axis=-1)
                       + jnp.prod(occ_products * self.products.mask + 1, axis=-1)
                       - 1)

        # Activation: d~ / (d~ + 1) per real slot; padded slots evaluate to 1 / (0 + 1) = 1.
        occ_activators = self.activators.occupancy(state, default=1)
        activation = jnp.prod(occ_activators / (occ_activators * self.activators.mask + 1), axis=-1)

        # Inhibition: 1 / (d~ + 1) per real slot; padded slots have occupancy 0 and evaluate to 1.
        occ_inhibitors = self.inhibitors.occupancy(state, default=0)
        inhibition = jnp.prod(1 / (occ_inhibitors + 1), axis=-1)

        return enzyme_conc * numerator / denominator * activation * inhibition

    def dstate_dt(self,
                  state: ArrayT,
                  enzyme_conc: ArrayT) -> ArrayT:
        """Returns the rate of change of each reactant: the stoichiometry matrix times the reaction rates."""
        rates = self.reaction_rates(state, enzyme_conc)
        return self.network.s_matrix @ rates



In [27]:
# Re-wrap the earlier kinetics in the extended ReactionKinetics, starting with no activators or inhibitors.
# Each reaction gets its own fresh ka/ki dicts, so the edit below affects only one reaction.
k3 = {k.reaction: ReactionKinetics(
    reaction=k.reaction,
    kcat_f=k.kcat_f,
    kcat_b=k.kcat_b,
    km=k.km,
    ka={},
    ki={}
) for k in kinetics}
# Set ATP as an inhibitor of PFK
# NOTE(review): relies on PFK being step 1 of the glycolysis pathway -- confirm if the KB ordering changes.
k3[glycolysis.steps[1]].ki[KB('atp')] = 5.
ck3 = ConvenienceKinetics(network, k3)

In [28]:
# One column (at most one inhibitor per reaction); rows without an inhibitor hold the padding default 1.
ck3.inhibitors.constants

array([[1.],
       [5.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]])

In [29]:
# Rates with regulation, then their difference from the unregulated rates computed by `ck`.
print(ck3.reaction_rates(state, enzyme_conc))
print(ck3.reaction_rates(state, enzyme_conc) - ck.reaction_rates(state, enzyme_conc))


[-0.03198926  0.07200903 -0.06920069  0.64562029 -0.0669045   0.0020999
  0.02801216  0.00489933  0.00277667  0.16441616  0.05763836  0.00126766]
[ 0.         -0.13868939  0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.        ]


## Thermodynamically independent system parameters

Independent parameters
- $\Delta{G}_f$ per reactant
- velocity constant per reaction (drives kcat_f and kcat_b together)
- $K_M$ per (enzyme, reactant) pair

The paper focuses a lot on the first set, though this seems to me as something largely 'known', at least within some uncertainty. The latter two seem the most likely to be unknown, and therefore needing to be fit for some model. Ultimately it boils down to treating the forward and back kcats as dependent on everything else, instead of as 'free' parameters.

Key: Haldane relationship connects kinetic constants to $\Delta{G}_r$
$$
-\frac{\Delta{G}_r}{RT} = ln(K_{eq}) = ln(k_{cat}^{+}) - ln(k_{cat}^{-}) + \sum\limits_i{(n_i ln({K_M}_i))}
$$

Rearranging so dependent params are on the left:
$$
ln(k_{cat}^{+}) - ln(k_{cat}^{-}) = -\frac{\Delta{G}_r}{RT} - \sum\limits_i{(n_i ln({K_M}_i))}
$$

In other words, the ratio of kcats (logarithm difference) is dictated by ΔG, modified as a function of the various Km's. ΔG is fixed, and Km's are independent system parameters (each driven by ΔG of binding). The one additional parameter introduced by the authors is the _velocity constant_, $k_V$, defined as the geometric mean of the kcats (equivalently, $ln(k_V)$ is the average of their logarithms). Given all the other parameters plus $k_V$, we can calculate $k_{cat}^{+}$ and $k_{cat}^{-}$.

We can treat $\Delta{G}_r$ as a configuration constant. Or, if we want to explore e.g. a Bayesian relationship between $\Delta{G}_r$ uncertainty and other parameters, we can make $\Delta{G}_r$ an explicit parameter for each reaction, or more rigorously make $\Delta{G}_f$ a parameter for every reactant in the system. The latter approach may take some calisthenics with equilibrator_api, which [highly discourages](https://equilibrator.readthedocs.io/en/latest/equilibrator_examples.html#Using-formation-energies-to-calculate-reaction-energies) using $\Delta{G}_f$ in general.

In [30]:
import equilibrator_api
from equilibrator_api.component_contribution import FARADAY, R, Q_

# Configure the eQuilibrator component-contribution estimator with the aqueous conditions used below.
cc = equilibrator_api.ComponentContribution()
cc.p_h = Q_(7.3)
cc.p_mg = Q_(1.5)
cc.ionic_strength = Q_("0.25M")
cc.temperature = Q_("298.15K")
# Magnitude of R*T at the configured temperature, stripped of units (presumably kJ/mol -- TODO confirm).
RT = cc.RT.m

def find_cc_met(met):
    """Returns the eQuilibrator compound for the first KEGG cross-reference of `met`, or None if it has none."""
    kegg_compounds = (
        cc.get_compound(f'KEGG:{xref.id}')
        for xref in (met.xrefs or [])
        if xref.db == 'KEGG'
    )
    # The generator is lazy, so only the first KEGG xref is ever looked up.
    return next(kegg_compounds, None)

def delta_g(reaction):
    """Estimates ΔG' of a reaction under physiological conditions, via equilibrator.

    Each metabolite in `reaction.stoichiometry` is translated with `find_cc_met` and the resulting
    {compound: count} mapping is handed to `equilibrator_api.Reaction`. No check is made that every
    metabolite could be translated: `find_cc_met` returns None for a metabolite without a KEGG xref,
    and that None is used as a key like any other.

    Args:
        reaction: an object whose `stoichiometry` maps metabolite -> signed stoichiometric count.

    Returns:
        `cc.physiological_dg_prime(...).value.m` for the translated reaction, i.e. the bare
        magnitude of the estimate's value with units stripped.
    """
    compound_counts = {}
    for met, count in reaction.stoichiometry.items():
        compound_counts[find_cc_met(met)] = count
    cc_rxn = equilibrator_api.Reaction(compound_counts)
    dg_prime = cc.physiological_dg_prime(cc_rxn)
    return dg_prime.value.m


In [31]:
# Short names for the Km arrays; the ln(K_eq) preview cell further down reuses km_s / km_p.
km_s, km_p = ck3.substrates.constants, ck3.products.constants
# ln(Km) for substrates, then products, each multiplied elementwise by its mask,
# with a blank line separating the two tables.
print(
    np.log(km_s) * ck3.substrates.mask,
    np.log(km_p) * ck3.products.mask,
    sep='\n\n',
)

[[0.54990519 0.         0.        ]
 [0.51879626 0.34408121 0.        ]
 [0.62224813 0.34514522 0.        ]
 [0.60876784 0.         0.        ]
 [0.22937966 0.         0.        ]
 [0.55145313 0.01537141 0.6511267 ]
 [0.37122119 0.61163883 0.        ]
 [0.18403572 0.         0.        ]
 [0.1620444  0.         0.        ]
 [0.46605235 0.56840058 0.        ]
 [0.3313784  0.56763835 0.6180257 ]
 [0.04697788 0.37583148 0.50017803]]

[[0.61254325 0.         0.         0.         0.        ]
 [0.63368237 0.06199398 0.30817702 0.         0.        ]
 [0.08342953 0.59823961 0.         0.         0.        ]
 [0.29868856 0.51719742 0.         0.         0.        ]
 [0.06199469 0.         0.         0.         0.        ]
 [0.61274458 0.69176303 0.48172028 0.         0.        ]
 [0.270318   0.22199354 0.         0.         0.        ]
 [0.49889332 0.         0.         0.         0.        ]
 [0.68769543 0.51422881 0.         0.         0.        ]
 [0.62256852 0.21681586 0.38885798 0.       

In [32]:
# One ΔG estimate per reaction, in the order the network lists its reactions.
dg_r = np.array(list(map(delta_g, network.reactions)))
print(dg_r)

[ -0.04985926 -16.17793843 -31.06552991  11.00128599  -6.58047196
  18.89475523  18.14929293  -2.98573943  -4.2899532   27.23492107
 -17.93582223 -32.57199834]


In [33]:
# ln(kcat+) - ln(kcat-) per reaction: -ΔG/RT minus the signed sum of ln(Km),
# i.e. plus the substrate terms and minus the product terms.
diff = (
    -dg_r / RT
    + np.sum(np.log(km_s), axis=-1)
    - np.sum(np.log(km_p), axis=-1)
)
print(diff)
# Preview of the (forward, back) split: +/- half the difference around an offset.
# 0.3 is just an example offset here; generate_kcats takes this offset as `kvs`.
print(diff[:, np.newaxis] * np.array([0.5, -0.5]) + 0.3)

[ -0.04251426   6.38863558  12.82414754  -4.64737013   2.82334298
  -8.19442821  -6.8347255    0.89022292   0.69159713 -11.18613274
   6.78395598  12.48629978]
[[ 0.27874287  0.32125713]
 [ 3.49431779 -2.89431779]
 [ 6.71207377 -6.11207377]
 [-2.02368506  2.62368506]
 [ 1.71167149 -1.11167149]
 [-3.79721411  4.39721411]
 [-3.11736275  3.71736275]
 [ 0.74511146 -0.14511146]
 [ 0.64579856 -0.04579856]
 [-5.29306637  5.89306637]
 [ 3.69197799 -3.09197799]
 [ 6.54314989 -5.94314989]]


In [34]:
def generate_kcats(
    kms_s: ArrayT,
    kms_p: ArrayT,
    kvs: ArrayT,
    dgrs: ArrayT,
    rt: Optional[float] = None,
) -> jnp.ndarray:
    """Generates thermodynamically consistent forward and back kcats, given ΔG, Km's and a velocity constant.

    Based on the Haldane relationship, -ΔG/RT = ln(K) = ln(kcat+) - ln(kcat-) + sum(n ln(Km)), with n
    negative for substrates and positive for products. The difference of the log kcats is fixed by ΔG
    and the Km's; the log velocity constant fixes their average:

        ln(kcat+) = ln(kV) + diff / 2
        ln(kcat-) = ln(kV) - diff / 2

    Args:
        kms_s: array of substrate Km values (mM), with shape (#rxns, max(#substrates)), padded with ones
            (so that padding contributes ln(1) = 0 to the sum).
        kms_p: array of product Km values (mM), with shape (#rxns, max(#products)), padded with ones.
        kvs: array of *log* velocity constants, ln(kV), with shape (#rxns,). These are added inside the
            exponent, so zeros correspond to kV = 1 (geometric mean of kcat+ and kcat- equal to 1).
        dgrs: reaction ΔGs (kilojoule / mole, mM standard), array of shape (#rxns,).
        rt: the value of RT, in the same energy units as dgrs. If None (the default), the
            notebook-level `RT` is used.

    Returns:
        An array of shape (#rxns, 2), with forward and back kcats per reaction.
    """
    if rt is None:
        # Fall back to the notebook-level constant, resolved at call time.
        rt = RT

    # $ln(k_{cat}^{+}) - ln(k_{cat}^{-}) = -\frac{\Delta{G}_r}{RT} - \sum_i{(n_i ln({K_M}_i))}$
    diffs = -dgrs / rt + jnp.sum(jnp.log(kms_s), axis=-1) - jnp.sum(jnp.log(kms_p), axis=-1)

    # e^(kvs +/- diffs/2), with forward and back stacked along the last axis.
    half_diffs = diffs / 2
    return jnp.exp(jnp.stack([kvs + half_diffs, kvs - half_diffs], axis=-1))
        

In [35]:
# All-zero kvs, sized by network.shape[1] (presumably the number of reactions — confirm), so the
# only thing separating forward and back kcats here is the ΔG / Km term.
kcats = generate_kcats(
    kms_s=ck3.substrates.constants,
    kms_p=ck3.products.constants,
    kvs=np.zeros(network.shape[1]),
    dgrs=dg_r,
)
kcats

Array([[9.78967213e-01, 1.02148467e+00],
       [2.43935265e+01, 4.09944827e-02],
       [6.09155621e+02, 1.64161663e-03],
       [9.79121077e-02, 1.02132415e+01],
       [4.10280748e+00, 2.43735541e-01],
       [1.66189095e-02, 6.01724199e+01],
       [3.27988197e-02, 3.04889020e+01],
       [1.56066414e+00, 6.40752852e-01],
       [1.41311793e+00, 7.07655021e-01],
       [3.72359243e-03, 2.68557856e+02],
       [2.97246893e+01, 3.36420674e-02],
       [5.14476506e+02, 1.94372336e-03]], dtype=float64)

In [36]:
# Haldane consistency check: RT * ln(kcat+ * prod(Km_products) / (kcat- * prod(Km_substrates))) + ΔG
# should come out at ~0 (floating-point noise only) for every reaction.
RT * np.log((kcats[:,0] * np.prod(ck3.products.constants, axis=-1)) / (kcats[:,1] * np.prod(ck3.substrates.constants, axis=1))) + dg_r

array([ 5.55111512e-17, -3.55271368e-15,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  7.10542736e-15,  3.55271368e-15,  0.00000000e+00])