In [1]:
import projectpath

import collections
from importlib import resources
import time
from typing import Iterable, Union

import escher
import ipywidgets as widgets
import jax
import jax.numpy as jnp
import bokeh.models
import bokeh.plotting as plt
import numpy as np
import pandas as pd
import panel as pn
import scipy

import files.pw
from kb import kb
from sim import fba_gd
from model.core import KbEntry, Molecule, Reaction, Pathway
from model.reaction_network import ReactionNetwork

jax.config.update('jax_enable_x64', True)
prng = jax.random.PRNGKey(int(time.time() * 1000))  # Gets the No GPU warning out of the way

KB = kb.configure_kb()

plt.output_notebook() #hide_banner=True)

def labels(entries: Union[KbEntry, Iterable[KbEntry]]):
    if isinstance(entries, KbEntry):
        entries.shorthand or entries.id
    else:
        return [entry.shorthand or entry.id for entry in entries]



# Start with straight glycolysis

In [2]:
glycolysis = KB.find(KB.pathways, 'glycolysis')[0]
network = ReactionNetwork(glycolysis.steps)

pd.DataFrame(
    network.s_matrix,
    columns=labels(network.reactions()),
    index=labels(network.reactants()))

Unnamed: 0,PGI,PFK,FBP,FBA,TPI,GAPDH,PGK,GPMM,ENO,PYK,PPS,PDH
Glc.D.6P,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Fru.D.6P,1.0,-1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ATP,0.0,-1.0,0.0,0.0,0.0,0.0,-1.0,0.0,0.0,-1.0,-1.0,0.0
ADP,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0
Fru.D.bis16,0.0,1.0,-1.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
H+,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,2.0,0.0
H2O,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,-1.0,0.0
Pi,0.0,0.0,1.0,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,1.0,0.0
GAP,0.0,0.0,0.0,1.0,-1.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0
DHAP,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Get some random reaction velocities, and corresponding rates of change

In [3]:
rando_v = jax.random.normal(prng, network.shape[1:])

rando_dmdt = network.s_matrix @ rando_v

print(rando_v, rando_dmdt)

[ 1.02038161  1.39429933 -0.3873127   0.47118126 -0.61552604 -1.48504818
  1.46096379 -0.63432217 -1.30791163  0.46672782 -0.08572719 -0.1023924 ] [-1.02038161 -0.76123042 -3.23626374  3.32199093  1.31043076  0.20452458
 -0.83487174  1.01200829  2.57175549 -0.14434478  1.58744058 -0.0240844
 -1.58744058 -2.09528596  1.9422338  -0.926911   -0.27860823 -0.08572719
  0.1023924  -0.1023924  -0.1023924 ]


### Define intermediates and boundaries, and corresponding Steady State objective component

In [4]:
intermediates = [
    KB.get(KB.compounds, mol_id)
    for mol_id in ('Fru.D.6P', 'Fru.D.bis16', 'dhap', 'gap', 'dpg', '3pg', '2pg', 'pep', 'pyr')
]
boundaries = [met for met in glycolysis.metabolites if met not in intermediates]

ss_obj = fba_gd.SteadyStateObjective(network, intermediates)
print(ss_obj.residual(rando_v, rando_dmdt))

[-0.76123042  1.31043076 -0.14434478  2.57175549 -0.0240844  -2.09528596
  1.9422338  -0.926911   -0.27860823]


### Super simple problem: target acCoA production

In [5]:
acCoA = KB.get(KB.compounds, 'accoa')
prod_obj = fba_gd.ProductionObjective(network, {acCoA: 3})
print(prod_obj.residual(rando_v, rando_dmdt, prod_obj.params()))

[-3.1023924]


In [6]:
def fn(v):
    dmdt = network.s_matrix @ v
    return sum(objective.loss(v, dmdt, objective.params())
               for objective in [ss_obj, prod_obj])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, rando_v, jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=labels(network.reactions()))

PGI      1.500000
PFK      1.253493
FBP     -0.246507
FBA      1.500000
TPI     -1.500000
GAPDH    3.000000
PGK     -3.000000
GPMM    -3.000000
ENO      3.000000
PYK     -1.223772
PPS     -1.776228
PDH      3.000000
dtype: float64

### Add an objective component for irreversible reactions

In [7]:
rbounds_obj = fba_gd.VelocityObjective(network, {rxn: (0.0, None) for rxn in glycolysis.steps if not rxn.reversible})

def fn(v):
    dmdt = network.s_matrix @ v
    return sum(objective.loss(v, dmdt, objective.params())
               for objective in [ss_obj, prod_obj, rbounds_obj])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, rando_v, jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=labels(network.reactions()))

PGI      1.500000
PFK      1.500049
FBP      0.000049
FBA      1.500000
TPI     -1.500000
GAPDH    3.000000
PGK     -3.000000
GPMM    -3.000000
ENO      3.000000
PYK     -3.444193
PPS      0.444193
PDH      3.000000
dtype: float64

### Minimize unnecessary flux with a (low-weighted, L1) regularization term

In [8]:
ss_obj = fba_gd.SteadyStateObjective(network, intermediates, weight=1e6)
rbounds_obj = fba_gd.VelocityObjective(network, {rxn: (0.0, None) for rxn in glycolysis.steps if not rxn.reversible}, weight=1e6)
prod_obj = fba_gd.ProductionObjective(network, {acCoA: 3}, weight=10)
sparsity = fba_gd.VelocityObjective(network, {rxn: 0. for rxn in glycolysis.steps}, aggfun=fba_gd.l1, weight=1e-8)

def fn(v):
    dmdt = network.s_matrix @ v
    return sum(objective.loss(v, dmdt, objective.params())
               for objective in [ss_obj, prod_obj, rbounds_obj, sparsity])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, rando_v, jac=jacfn, tol=1e-12)
pd.Series(np.round(soln.x, 3), index=labels(network.reactions()))

PGI      1.500
PFK      1.946
FBP      0.446
FBA      1.500
TPI     -1.500
GAPDH    3.000
PGK     -3.000
GPMM    -3.000
ENO      3.000
PYK     -3.005
PPS      0.005
PDH      3.000
dtype: float64

In [9]:
diag = escher.Builder(
    # map_json=resources.read_text(files.pw, 'glycolysis_ppp_ed.json'),
    map_json=resources.read_text(files.pw, 'central_carbon.json'),
    menu='zoom',
    enable_editing=False,
    never_ask_before_quit=True,
    reaction_styles = ['abs', 'color', 'size',],
    metabolite_styles = ['color', ],
    reaction_scale = [
        {'type': 'min', 'color': '#eeeeee', 'size': 5},
        {'type': 'max', 'color': '#1f77b4', 'size': 15},
    ],
    metabolite_scale = [
        {'type': 'min', 'color': '#b30019', 'size': 5},
        {'type': 'value', 'value': 0, 'color': '#eeeeee', 'size': 5},
        {'type': 'max', 'color': '#1f77b4', 'size': 15},
    ],
)
diag.reaction_data = pd.Series(soln.x, index=labels(network.reactions()))
diag.metabolite_data = pd.Series(network.s_matrix @ soln.x, index=labels(network.reactants()))

diag

Builder(enable_editing=False, menu='zoom', metabolite_data={'Glc.D.6P': -1.5000012719990328, 'Fru.D.6P': -2.37…

## Run glycolysis backwards (gluconeogensis)?

In [10]:
g6p = KB.get(KB.compounds, 'Glc.D.6P')
ss_obj = fba_gd.SteadyStateObjective(network, intermediates, weight=1e6)
rbounds_obj = fba_gd.VelocityObjective(network, {rxn: (0.0, None) for rxn in glycolysis.steps if not rxn.reversible}, weight=1e6)
prod_obj = fba_gd.ProductionObjective(network, {g6p: 1.5}, weight=10)
sparsity = fba_gd.VelocityObjective(network, {rxn: 0. for rxn in glycolysis.steps}, aggfun=fba_gd.l1, weight=1e-8)

def fn(v):
    dmdt = network.s_matrix @ v
    return sum(objective.loss(v, dmdt, objective.params())
               for objective in [ss_obj, prod_obj, rbounds_obj, sparsity])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, rando_v, jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=labels(network.reactions()))


PGI     -0.000060
PFK      0.000193
FBP      0.000238
FBA     -0.000030
TPI      0.000022
GAPDH   -0.000045
PGK      0.000037
GPMM     0.000030
ENO     -0.000022
PYK     -0.046651
PPS      0.046666
PDH     -0.000007
dtype: float64

### Doesn't work... why?
- Break into upper and lower, and run each backward

### Upper

In [11]:
network_upper = ReactionNetwork(glycolysis.steps[:5])
pd.DataFrame(
    network_upper.s_matrix,
    columns=labels(network_upper.reactions()),
    index=labels(network_upper.reactants()))


Unnamed: 0,PGI,PFK,FBP,FBA,TPI
Glc.D.6P,-1.0,0.0,0.0,0.0,0.0
Fru.D.6P,1.0,-1.0,1.0,0.0,0.0
ATP,0.0,-1.0,0.0,0.0,0.0
ADP,0.0,1.0,0.0,0.0,0.0
Fru.D.bis16,0.0,1.0,-1.0,-1.0,0.0
H+,0.0,1.0,0.0,0.0,0.0
H2O,0.0,0.0,-1.0,0.0,0.0
Pi,0.0,0.0,1.0,0.0,0.0
GAP,0.0,0.0,0.0,1.0,-1.0
DHAP,0.0,0.0,0.0,1.0,1.0


In [12]:
ss_obj = fba_gd.SteadyStateObjective(network_upper, [KB.get(KB.compounds, m_id) for m_id in ['Fru.D.6P', 'Fru.D.bis16', 'dhap']], weight=1e6)
rbounds_obj = fba_gd.VelocityObjective(network_upper, {rxn: (0.0, None) for rxn in network_upper.reactions() if not rxn.reversible}, weight=1e6)
prod_obj = fba_gd.ProductionObjective(network_upper, {g6p: 1.5}, weight=10)
sparsity = fba_gd.VelocityObjective(network_upper, {rxn: 0. for rxn in network_upper.reactions()}, aggfun=fba_gd.l1, weight=1e-8)

def fn(v):
    dmdt = network_upper.s_matrix @ v
    return sum(objective.loss(v, dmdt, objective.params())
               for objective in [ss_obj, prod_obj, rbounds_obj, sparsity])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, jax.random.normal(prng, network_upper.shape[1:]), jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=labels(network_upper.reactions()))


PGI   -1.500001
PFK    0.018722
FBP    1.518722
FBA   -1.500001
TPI    1.500001
dtype: float64

### Lower

In [13]:
network_lower = ReactionNetwork(glycolysis.steps[5:])
pd.DataFrame(
    network_lower.s_matrix,
    columns=labels(network_lower.reactions()),
    index=labels(network_lower.reactants()))


Unnamed: 0,GAPDH,PGK,GPMM,ENO,PYK,PPS,PDH
GAP,-1.0,0.0,0.0,0.0,0.0,0.0,0.0
NAD,-1.0,0.0,0.0,0.0,0.0,0.0,-1.0
Pi,-1.0,0.0,0.0,0.0,0.0,1.0,0.0
DPG,1.0,1.0,0.0,0.0,0.0,0.0,0.0
H+,1.0,0.0,0.0,0.0,1.0,2.0,0.0
NADH,1.0,0.0,0.0,0.0,0.0,0.0,1.0
3PG,0.0,-1.0,1.0,0.0,0.0,0.0,0.0
ATP,0.0,-1.0,0.0,0.0,-1.0,-1.0,0.0
ADP,0.0,1.0,0.0,0.0,1.0,0.0,0.0
2PG,0.0,0.0,-1.0,-1.0,0.0,0.0,0.0


In [14]:
gap = KB.get(KB.compounds, 'gap')
# ss_obj = fba_gd.SteadyStateObjective(network_lower, [KB.get(KB.compounds, m_id) for m_id in ['dpg', '2pg', '3pg', 'pep', 'pyr']])
# prod_obj = fba_gd.TargetDmdtObjective(network_lower, [gap])
# rbounds_obj = fba_gd.VelocityBoundsObjective(network_lower, {rxn: (0.0, np.inf) for rxn in network_lower.reactions() if not rxn.reversible})

ss_obj = fba_gd.SteadyStateObjective(network_lower, [KB.get(KB.compounds, m_id) for m_id in ['dpg', '2pg', '3pg', 'pep', 'pyr']], weight=1e6)
rbounds_obj = fba_gd.VelocityObjective(network_lower, {rxn: (0.0, None) for rxn in network_lower.reactions() if not rxn.reversible}, weight=1e6)
prod_obj = fba_gd.ProductionObjective(network_lower, {gap: 3.}, weight=10)
sparsity = fba_gd.VelocityObjective(network_lower, {rxn: 0. for rxn in network_lower.reactions()}, aggfun=fba_gd.l1, weight=1e-8)

def fn(v):
    dmdt = network_lower.s_matrix @ v
    return sum(objective.loss(v, dmdt, objective.params())
               for objective in [ss_obj, prod_obj, rbounds_obj, sparsity])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, jax.random.normal(prng, network_lower.shape[1:]), jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=labels(network_lower.reactions()))


GAPDH   -0.000180
PGK      0.000150
GPMM     0.000120
ENO     -0.000090
PYK     -0.133553
PPS      0.133613
PDH     -0.000030
dtype: float64

## PDH is irreveversible!
- Treat pyruvate as a boundary, not an intermediate

In [15]:
ss_obj = fba_gd.SteadyStateObjective(network_lower, [KB.get(KB.compounds, m_id) for m_id in ['dpg', '2pg', '3pg', 'pep']])

def fn(v):
    dmdt = network_lower.s_matrix @ v
    return sum(objective.loss(v, dmdt, objective.params())
               for objective in [ss_obj, prod_obj, rbounds_obj, sparsity])

jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(jfn, jax.random.normal(prng, network_lower.shape[1:]), jac=jacfn, tol=1e-12)
pd.Series(soln.x, index=labels(network_lower.reactions()))


GAPDH   -2.999995
PGK      2.999956
GPMM     2.999953
ENO     -2.999989
PYK      2.869977
PPS      0.129975
PDH      0.053572
dtype: float64

- It works. We found the problem.
- Flux goes through PYK and not PPS. This is not biologically correct, but as constructed it is valid
    - PYK is reversible. We could make it irreversible but that is not actually correct (it is pyruvate _kinase_ after all).
    - _in vivo_, PPS would be prefered but that's for energetic reasons, which are not captured in FBA. Unless we can work that into the objective function...

## Higher-level analysis of $\vec{v}$ and $d{\vec{M}}/dt$

In [17]:
carbons = {
    'Glc.D.6P': 6,
    'acCoA': 2,
    'CO2': 1,
}

# c_yield = {metname: round(dmdt[metname] * carbons[metname], 3) for metname in carbons}
# c_yield

In [18]:
phosphates = {
    'ATP': 3,
    'ADP': 2,
    'AMP': 1,
}
# hep_yield = sum(dmdt[metname] * phosphates[metname] for metname in phosphates)
# hep_yield