In [1]:
import projectpath

import collections
from importlib import resources
import time
from typing import Iterable, Union

import escher
import ipywidgets as widgets
import jax
import jax.numpy as jnp
import bokeh.models
import bokeh.plotting as plt
import numpy as np
import pandas as pd
import panel as pn
import scipy

from mosmo.model import Molecule, Reaction, Pathway, ReactionNetwork
from mosmo.knowledge import kb
from mosmo.sim import fba_gd
import mosmo.preso.escher.pw as pw_files

jax.config.update('jax_enable_x64', True)
prng = jax.random.PRNGKey(int(time.time() * 1000))  # Gets the No GPU warning out of the way

KB = kb.configure_kb()

plt.output_notebook() #hide_banner=True)



# Start with straight glycolysis

In [2]:
glycolysis = KB.find(KB.pathways, 'glycolysis')[0]
network = ReactionNetwork(glycolysis.steps)

pd.DataFrame(
    network.s_matrix,
    columns=network.reactions.labels(),
    index=network.reactants.labels())

Unnamed: 0,PGI,PFK,FBP,FBA,TPI,GAPDH,PGK,GPMM,ENO,PYK,PPS,PDH
G6P,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
F6P,1.0,-1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ATP,0.0,-1.0,0.0,0.0,0.0,0.0,-1.0,0.0,0.0,-1.0,-1.0,0.0
ADP,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0
F16bP,0.0,1.0,-1.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
H+,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,2.0,0.0
H2O,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,-1.0,0.0
Pi,0.0,0.0,1.0,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,1.0,0.0
GAP,0.0,0.0,0.0,1.0,-1.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0
DHAP,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Get some random reaction velocities, and corresponding rates of change

In [3]:
rando_v = jax.random.normal(prng, network.shape[1:])

rando_dmdt = network.s_matrix @ rando_v

print(rando_v, rando_dmdt)

[-0.44655719  0.13081097 -0.89833388  0.61989818 -0.16647852  1.86623394
  1.0053918   0.80259854  0.46022088 -0.59591905 -0.50278548 -0.99134371] [ 0.44655719 -1.47570205 -0.03749824  0.54028372  0.40924667  0.39555489
  1.86134024 -3.2673533  -1.07985723  0.45341966 -0.87489023  2.87162573
  0.87489023 -0.20279326 -1.26281942 -0.63848365  2.09004824 -0.50278548
  0.99134371 -0.99134371 -0.99134371]


### Define intermediates and boundaries, and corresponding Steady State objective component

In [4]:
intermediates = [
    KB(mol_id)
    for mol_id in ('Fru.D.6P', 'Fru.D.bis16', 'dhap', 'gap', 'dpg', '3pg', '2pg', 'pep', 'pyr')
]
boundaries = [met for met in glycolysis.metabolites if met not in intermediates]

ss_obj = fba_gd.SteadyStateObjective(network, intermediates)
print(ss_obj.residual(rando_v, rando_dmdt))

[-1.47570205  0.40924667  0.45341966 -1.07985723  2.87162573 -0.20279326
 -1.26281942 -0.63848365  2.09004824]


### Super simple problem: target acCoA production

In [5]:
acCoA = KB('accoa')
prod_obj = fba_gd.ProductionObjective(network, {acCoA: 3})
print(prod_obj.residual(rando_v, rando_dmdt, prod_obj.params()))

[-3.99134371]


In [6]:
def fn(v):
    dmdt = network.s_matrix @ v
    return sum(jnp.sum(jnp.square(objective.residual(v, dmdt, objective.params()))) for objective in [ss_obj, prod_obj])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, rando_v, jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=network.reactions.labels())

PGI      1.500000
PFK      0.366239
FBP     -1.133761
FBA      1.500000
TPI     -1.500000
GAPDH    3.000000
PGK     -3.000000
GPMM    -3.000000
ENO      3.000000
PYK     -1.546567
PPS     -1.453433
PDH      3.000000
dtype: float64

### Add an objective component for irreversible reactions

In [7]:
rbounds_obj = fba_gd.VelocityObjective(network, {rxn: (0.0, None) for rxn in glycolysis.steps if not rxn.reversible})

def fn(v):
    dmdt = network.s_matrix @ v
    return sum(jnp.sum(jnp.square(objective.residual(v, dmdt, objective.params()))) for objective in [ss_obj, prod_obj, rbounds_obj])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, rando_v, jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=network.reactions.labels())

PGI      1.500000
PFK      1.501626
FBP      0.001626
FBA      1.500000
TPI     -1.500000
GAPDH    3.000000
PGK     -3.000000
GPMM    -3.000000
ENO      3.000000
PYK     -3.933560
PPS      0.933560
PDH      3.000000
dtype: float64

### Minimize unnecessary flux with ~~a (low-weighted, L1) regularization term~~ exclusion objectives

In [8]:
ss_obj = fba_gd.SteadyStateObjective(network, intermediates, weight=1e6)
rbounds_obj = fba_gd.VelocityObjective(network, {rxn: (0.0, None) for rxn in glycolysis.steps if not rxn.reversible}, weight=1e6)
prod_obj = fba_gd.ProductionObjective(network, {acCoA: 3}, weight=10)
# sparsity = fba_gd.VelocityObjective(network, {rxn: 0. for rxn in glycolysis.steps}, aggfun=fba_gd.l1, weight=1e-8)
exclusion1 = fba_gd.ExclusionObjective(network, [KB('pfk'), KB('fbp')])
exclusion2 = fba_gd.ExclusionObjective(network, [KB('pyk'), KB('pps')])

def fn(v):
    dmdt = network.s_matrix @ v
    return sum(jnp.sum(jnp.square(objective.residual(v, dmdt, objective.params()))) for objective in [ss_obj, rbounds_obj, prod_obj, exclusion1, exclusion2])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, rando_v, jac=jacfn, tol=1e-12)
pd.Series(np.round(soln.x, 3), index=network.reactions.labels())

PGI      1.5
PFK      1.5
FBP      0.0
FBA      1.5
TPI     -1.5
GAPDH    3.0
PGK     -3.0
GPMM    -3.0
ENO      3.0
PYK     -3.0
PPS      0.0
PDH      3.0
dtype: float64

In [9]:
diag = escher.Builder(
    # map_json=resources.read_text(files.pw, 'glycolysis_ppp_ed.json'),
    map_json=resources.read_text(pw_files, 'central_carbon.json'),
    menu='zoom',
    enable_editing=False,
    never_ask_before_quit=True,
    reaction_styles = ['abs', 'color', 'size',],
    metabolite_styles = ['color', ],
    reaction_scale = [
        {'type': 'min', 'color': '#eeeeee', 'size': 5},
        {'type': 'max', 'color': '#1f77b4', 'size': 15},
    ],
    metabolite_scale = [
        {'type': 'min', 'color': '#b30019', 'size': 5},
        {'type': 'value', 'value': 0, 'color': '#eeeeee', 'size': 5},
        {'type': 'max', 'color': '#1f77b4', 'size': 15},
    ],
)
diag.reaction_data = pd.Series(soln.x, index=network.reactions.labels())
diag.metabolite_data = pd.Series(network.s_matrix @ soln.x, index=network.reactants.labels())

diag

Builder(enable_editing=False, menu='zoom', metabolite_data={'G6P': -1.5000000000000338, 'F6P': 1.4344081478157…

## Run glycolysis backwards (gluconeogensis)?

In [10]:
g6p = KB('Glc.D.6P')
ss_obj = fba_gd.SteadyStateObjective(network, intermediates, weight=1e6)
rbounds_obj = fba_gd.VelocityObjective(network, {rxn: (0.0, None) for rxn in glycolysis.steps if not rxn.reversible}, weight=1e6)
prod_obj = fba_gd.ProductionObjective(network, {g6p: 1.5}, weight=10)
# sparsity = fba_gd.VelocityObjective(network, {rxn: 0. for rxn in glycolysis.steps}, aggfun=fba_gd.l1, weight=1e-8)
exclusion1 = fba_gd.ExclusionObjective(network, [KB('pfk'), KB('fbp')])
exclusion2 = fba_gd.ExclusionObjective(network, [KB('pyk'), KB('pps')])

def fn(v):
    dmdt = network.s_matrix @ v
    return sum(jnp.sum(jnp.square(objective.residual(v, dmdt, objective.params()))) for objective in [ss_obj, rbounds_obj, prod_obj, exclusion1, exclusion2])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, rando_v, jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=network.reactions.labels())


PGI     -1.200000e+00
PFK      2.896874e-12
FBP      9.000000e-01
FBA     -6.000000e-01
TPI      4.500000e-01
GAPDH   -9.000000e-01
PGK      7.500000e-01
GPMM     6.000000e-01
ENO     -4.500000e-01
PYK     -8.180944e-11
PPS      3.000000e-01
PDH     -1.500000e-01
dtype: float64

### Doesn't work... why?
- Break into upper and lower, and run each backward

### Upper

In [11]:
network_upper = ReactionNetwork(glycolysis.steps[:5])
pd.DataFrame(
    network_upper.s_matrix,
    columns=network_upper.reactions.labels(),
    index=network_upper.reactants.labels())


Unnamed: 0,PGI,PFK,FBP,FBA,TPI
G6P,-1.0,0.0,0.0,0.0,0.0
F6P,1.0,-1.0,1.0,0.0,0.0
ATP,0.0,-1.0,0.0,0.0,0.0
ADP,0.0,1.0,0.0,0.0,0.0
F16bP,0.0,1.0,-1.0,-1.0,0.0
H+,0.0,1.0,0.0,0.0,0.0
H2O,0.0,0.0,-1.0,0.0,0.0
Pi,0.0,0.0,1.0,0.0,0.0
GAP,0.0,0.0,0.0,1.0,-1.0
DHAP,0.0,0.0,0.0,1.0,1.0


In [12]:
ss_obj = fba_gd.SteadyStateObjective(network_upper, [KB(m_id) for m_id in ['Fru.D.6P', 'Fru.D.bis16', 'dhap']], weight=1e6)
rbounds_obj = fba_gd.VelocityObjective(network_upper, {rxn: (0.0, None) for rxn in network_upper.reactions if not rxn.reversible}, weight=1e6)
prod_obj = fba_gd.ProductionObjective(network_upper, {g6p: 1.5}, weight=10)
# sparsity = fba_gd.VelocityObjective(network, {rxn: 0. for rxn in glycolysis.steps}, aggfun=fba_gd.l1, weight=1e-8)
exclusion1 = fba_gd.ExclusionObjective(network_upper, [KB('pfk'), KB('fbp')])

def fn(v):
    dmdt = network_upper.s_matrix @ v
    return sum(jnp.sum(jnp.square(objective.residual(v, dmdt, objective.params()))) for objective in [ss_obj, rbounds_obj, prod_obj, exclusion1])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, jax.random.normal(prng, network_upper.shape[1:]), jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=network_upper.reactions.labels())


PGI   -1.500000e+00
PFK    9.470036e-14
FBP    1.500000e+00
FBA   -1.500000e+00
TPI    1.500000e+00
dtype: float64

### Lower

In [13]:
network_lower = ReactionNetwork(glycolysis.steps[5:])
pd.DataFrame(
    network_lower.s_matrix,
    columns=network_lower.reactions.labels(),
    index=network_lower.reactants.labels())


Unnamed: 0,GAPDH,PGK,GPMM,ENO,PYK,PPS,PDH
GAP,-1.0,0.0,0.0,0.0,0.0,0.0,0.0
NAD+,-1.0,0.0,0.0,0.0,0.0,0.0,-1.0
Pi,-1.0,0.0,0.0,0.0,0.0,1.0,0.0
DPG,1.0,1.0,0.0,0.0,0.0,0.0,0.0
H+,1.0,0.0,0.0,0.0,1.0,2.0,0.0
NADH,1.0,0.0,0.0,0.0,0.0,0.0,1.0
3PG,0.0,-1.0,1.0,0.0,0.0,0.0,0.0
ATP,0.0,-1.0,0.0,0.0,-1.0,-1.0,0.0
ADP,0.0,1.0,0.0,0.0,1.0,0.0,0.0
2PG,0.0,0.0,-1.0,-1.0,0.0,0.0,0.0


In [14]:
gap = KB('gap')
ss_obj = fba_gd.SteadyStateObjective(network_lower, [KB(m_id) for m_id in ['dpg', '2pg', '3pg', 'pep', 'pyr']], weight=1e6)
rbounds_obj = fba_gd.VelocityObjective(network_lower, {rxn: (0.0, None) for rxn in network_lower.reactions if not rxn.reversible}, weight=1e6)
prod_obj = fba_gd.ProductionObjective(network_lower, {gap: 3.}, weight=10)
exclusion2 = fba_gd.ExclusionObjective(network_lower, [KB('pyk'), KB('pps')])


def fn(v):
    dmdt = network_lower.s_matrix @ v
    return sum(jnp.sum(jnp.square(objective.residual(v, dmdt, objective.params()))) for objective in [ss_obj, rbounds_obj, prod_obj, exclusion2])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, jax.random.normal(prng, network_lower.shape[1:]), jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=network_lower.reactions.labels())


GAPDH   -2.571429e+00
PGK      2.142857e+00
GPMM     1.714286e+00
ENO     -1.285714e+00
PYK      1.388494e-10
PPS      8.571429e-01
PDH     -4.285714e-01
dtype: float64

## PDH is irreveversible!
- Treat pyruvate as a boundary, not an intermediate

In [15]:
ss_obj = fba_gd.SteadyStateObjective(network_lower, [KB(m_id) for m_id in ['dpg', '2pg', '3pg', 'pep']])

def fn(v):
    dmdt = network_lower.s_matrix @ v
    return sum(jnp.sum(jnp.square(objective.residual(v, dmdt, objective.params()))) for objective in [ss_obj, rbounds_obj, prod_obj, exclusion2])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, jax.random.normal(prng, network_lower.shape[1:]), jac=jacfn, tol=1e-12)
# pd.Series(soln.x, index=network_lower.reactions.labels())
pd.Series(network_lower.reactions.unpack(soln.x))


[gapdh] GAP + NAD+ + Pi <=> DPG + H+ + NADH      -3.000000e+00
[pgk] 3PG + ATP <=> DPG + ADP                     3.000000e+00
[gpm.indep] 2PG <=> 3PG                           3.000000e+00
[eno] 2PG <=> H2O + PEP                          -3.000000e+00
[pyk] ATP + pyr <=> ADP + H+ + PEP                1.044177e-14
[pps] ATP + H2O + pyr => AMP + Pi + PEP + 2 H+    3.000000e+00
[pdh] CoA + NAD+ + pyr => acCoA + CO2 + NADH      1.069016e+00
dtype: float64

- It works. We found the problem.
- Flux goes through PYK and not PPS. This is not biologically correct, but as constructed it is valid
    - PYK is reversible. We could make it irreversible but that is not actually correct (it is pyruvate _kinase_ after all).
    - _in vivo_, PPS would be prefered but that's for energetic reasons, which are not captured in FBA. Unless we can work that into the objective function...