In [1]:
import projectpath  # keep first: presumably puts the project packages below on sys.path — confirm

import collections
from importlib import resources

from bokeh.plotting import figure, output_notebook, show
import escher
import ipywidgets as widgets
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.optimize  # explicit submodule import; scipy.optimize.minimize is used throughout

import files.pw
from kb import kb
from sim import fba_gd
from model.core import KbEntry, Molecule, Reaction, Pathway
from model.reaction_network import ReactionNetwork

# Switch JAX to 64-bit mode ("jax_enable_x64") before any arrays are created.
jax.config.update("jax_enable_x64", True)
# Seeded PRNG key; the same key is reused for every random draw in this notebook.
prng = jax.random.PRNGKey(0)  # Gets the No GPU warning out of the way

# Knowledge-base handle; later cells look pathways up through it.
KB = kb.configure_kb()

def to_label(entries):
    """Return display labels for one KbEntry or an iterable of entries.

    A single KbEntry is treated as a one-element collection, so the result
    is always a list. Each label is the entry's ``shorthand`` when that is
    truthy, otherwise its ``id``.
    """
    items = [entries] if isinstance(entries, KbEntry) else entries
    labels = []
    for item in items:
        labels.append(item.shorthand or item.id)
    return labels




# Start with straight glycolysis

In [2]:
# Pull the glycolysis pathway out of the knowledge base and wrap its steps in a network.
glycolysis = KB.find(KB.pathways, 'glycolysis')[0]
network = ReactionNetwork(glycolysis.steps)

# Id-keyed lookup tables for the pathway's metabolites, reactions and enzymes.
mets = dict((met.id, met) for met in glycolysis.metabolites)
rxns = dict((rxn.id, rxn) for rxn in glycolysis.steps)
enzs = dict((enz.id, enz) for enz in glycolysis.enzymes)

# Labeled stoichiometry matrix: columns are reactions, rows are reactants.
pd.DataFrame(
    data=network.s_matrix,
    columns=to_label(network.reactions()),
    index=to_label(network.reactants()),
)

Unnamed: 0,PGI,PFK,FBP,FBA,TPI,GAPDH,PGK,GPMM,ENO,PYK,PPS,PDH
Glc.D.6P,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Fru.D.6P,1.0,-1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ATP,0.0,-1.0,0.0,0.0,0.0,0.0,-1.0,0.0,0.0,-1.0,-1.0,0.0
ADP,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0
Fru.D.bis16,0.0,1.0,-1.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
H+,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,2.0,0.0
H2O,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,-1.0,0.0
Pi,0.0,0.0,1.0,0.0,0.0,-1.0,0.0,0.0,0.0,0.0,1.0,0.0
GAP,0.0,0.0,0.0,1.0,-1.0,-1.0,0.0,0.0,0.0,0.0,0.0,0.0
DHAP,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Get some random reaction velocities, and corresponding rates of change

In [3]:
# Standard-normal random velocities shaped like network.shape[1:] (presumably one per
# reaction — confirm); later cells reuse this vector as the optimizer's starting point.
rando_v = jax.random.normal(prng, network.shape[1:])

# Rates of change implied by those velocities: stoichiometry matrix times velocities.
rando_dmdt = network.s_matrix @ rando_v

print(rando_v, rando_dmdt)

[ 0.36753958 -0.90820405 -2.00643986  0.16056284  0.13233447 -1.30543452
 -0.40556757 -1.79353506 -1.35665474  0.80958459 -0.37977961  0.0844285 ] [-0.36753958 -0.73069623  0.88396663 -0.50418703  0.93767298 -2.1636132
  1.02956473 -1.08078495  1.33366288  0.29289731  1.22100602 -1.71100209
 -1.22100602 -1.38796749  3.1501898  -0.92684976 -0.51423348 -0.37977961
 -0.0844285   0.0844285   0.0844285 ]


### Define intermediates and boundaries, and corresponding Steady State objective component

In [4]:
# Metabolites (by id) that the steady-state objective is built over; every other
# metabolite of the pathway is collected as a boundary.
intermediates = [
    mets[met_id]
    for met_id in (
        'Fru.D.6P', 'Fru.D.bis16', 'dhap', 'gap', 'dpg', '3pg', '2pg', 'pep', 'pyr',
    )
]
boundaries = [
    candidate for candidate in glycolysis.metabolites if candidate not in intermediates
]

# Steady-state residual for the random velocities and rates of change drawn earlier.
ss_obj = fba_gd.SteadyStateObjective(network, intermediates)
print(ss_obj.residual(rando_v, rando_dmdt))

[-0.73069623  0.93767298  0.29289731  1.33366288 -1.71100209 -1.38796749
  3.1501898  -0.92684976 -0.51423348]


### Super simple problem: target acCoA production

In [5]:
# Objective on the rate of change of acCoA; print its residual for the random
# velocities against a target of 3.
prod_obj = fba_gd.TargetDmdtObjective(network, [mets['accoa']])
print(prod_obj.residual(rando_v, rando_dmdt, prod_obj.prepare_targets({mets['accoa']: 3})))

[-2.9155715]


In [6]:
# Target used by the production objective: acCoA rate of change of 3.
targets = {'prod': prod_obj.prepare_targets({mets['accoa']: 3})}


def loss(v, dmdt, targets):
    """Sum of squared residuals of the steady-state and production objectives."""
    sq_steady_state = jnp.sum(jnp.square(ss_obj.residual(v, dmdt)))
    sq_production = jnp.sum(jnp.square(prod_obj.residual(v, dmdt, targets['prod'])))
    return sum([sq_steady_state, sq_production])


def fn(v):
    """Loss as a function of velocities only; dm/dt is derived from the stoichiometry matrix."""
    return loss(v, network.s_matrix @ v, targets)


# JIT-compiled loss and its forward-mode derivative, handed to SciPy's minimizer.
jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(fun=jfn, x0=rando_v, jac=jacfn, tol=1e-12)
pd.Series(data=soln.x, index=to_label(network.reactions()))

PGI      1.500000
PFK     -0.707322
FBP     -2.207322
FBA      1.500000
TPI     -1.500000
GAPDH    3.000000
PGK     -3.000000
GPMM    -3.000000
ENO      3.000000
PYK     -0.905318
PPS     -2.094682
PDH      3.000000
dtype: float64

### Add an objective component for irreversible reactions

In [7]:
# Bound the velocity of every non-reversible step of the pathway to [0, inf).
rbounds_obj = fba_gd.VelocityBoundsObjective(
    network,
    {step: (0.0, np.inf) for step in glycolysis.steps if not step.reversible},
)

targets = {
    'prod': prod_obj.prepare_targets({mets['accoa']: 3}),
    'rbounds': rbounds_obj.prepare_targets(),
}


def loss(v, dmdt, targets):
    """Sum of squared residuals: steady state, velocity bounds, production target."""
    sq_steady_state = jnp.sum(jnp.square(ss_obj.residual(v, dmdt)))
    sq_bounds = jnp.sum(jnp.square(rbounds_obj.residual(v, dmdt, targets['rbounds'])))
    sq_production = jnp.sum(jnp.square(prod_obj.residual(v, dmdt, targets['prod'])))
    return sum([sq_steady_state, sq_bounds, sq_production])


def fn(v):
    """Loss as a function of velocities only; dm/dt is derived from the stoichiometry matrix."""
    return loss(v, network.s_matrix @ v, targets)


# JIT-compiled loss and its forward-mode derivative, handed to SciPy's minimizer.
jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(fun=jfn, x0=rando_v, jac=jacfn, tol=1e-12)
pd.Series(data=soln.x, index=to_label(network.reactions()))

PGI      1.500000
PFK      1.620765
FBP      0.120765
FBA      1.500000
TPI     -1.500000
GAPDH    3.000000
PGK     -3.000000
GPMM    -3.000000
ENO      3.000000
PYK     -3.155047
PPS      0.155047
PDH      3.000000
dtype: float64

### Minimize unnecessary flux with an L1 regularization term

In [8]:
targets = {
    'prod': prod_obj.prepare_targets({mets['accoa']: 3}),
    'rbounds': rbounds_obj.prepare_targets(),
}


def loss(v, dmdt, targets):
    """Weighted sum of squared objective residuals plus a small L1 penalty on `v`."""
    sq_steady_state = jnp.sum(jnp.square(ss_obj.residual(v, dmdt)))
    sq_bounds = jnp.sum(jnp.square(rbounds_obj.residual(v, dmdt, targets['rbounds'])))
    sq_production = jnp.sum(jnp.square(prod_obj.residual(v, dmdt, targets['prod'])))
    l1_velocity = jnp.sum(jnp.abs(v))
    return sum([
        sq_steady_state * 1e6,
        sq_bounds * 1e6,
        sq_production * 1e1,
        l1_velocity * 1e-8,  # tiny weight so the L1 term does not dominate the loss
    ])


def fn(v):
    """Loss as a function of velocities only; dm/dt is derived from the stoichiometry matrix."""
    return loss(v, network.s_matrix @ v, targets)


# JIT-compiled loss and its forward-mode derivative, handed to SciPy's minimizer.
jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(fun=jfn, x0=rando_v, jac=jacfn, tol=1e-12)
pd.Series(data=soln.x, index=to_label(network.reactions()))

PGI      1.500000
PFK      1.500002
FBP      0.000002
FBA      1.500000
TPI     -1.500000
GAPDH    3.000000
PGK     -3.000000
GPMM    -3.000000
ENO      3.000000
PYK     -3.018893
PPS      0.018893
PDH      3.000000
dtype: float64

In [9]:
# Escher diagram on the 'central_carbon.json' map. A second map in the same package,
# 'glycolysis_ppp_ed.json', is loaded by a later cell of this notebook.
diag = escher.Builder(
    map_json=resources.read_text(files.pw, 'central_carbon.json'),
    menu='zoom',
    enable_editing=False,
    never_ask_before_quit=True,
    reaction_styles=['abs', 'color', 'size'],
    metabolite_styles=['color'],
    reaction_scale=[
        {'type': 'min', 'color': '#eeeeee', 'size': 5},
        {'type': 'max', 'color': '#1f77b4', 'size': 15},
    ],
    metabolite_scale=[
        {'type': 'min', 'color': '#b30019', 'size': 5},
        {'type': 'value', 'value': 0, 'color': '#eeeeee', 'size': 5},
        {'type': 'max', 'color': '#1f77b4', 'size': 15},
    ],
)

# Attach the latest solution: velocities per reaction, rates of change per reactant.
diag.reaction_data = pd.Series(data=soln.x, index=to_label(network.reactions()))
diag.metabolite_data = pd.Series(
    data=network.s_matrix @ soln.x,
    index=to_label(network.reactants()),
)

diag

Builder(enable_editing=False, menu='zoom', metabolite_data={'Glc.D.6P': -1.499999984355431, 'Fru.D.6P': -1.695…

## Run glycolysis backwards (gluconeogenesis)?

In [10]:
# Same objectives as before on the full network, but the production target is now Glc.D.6P.
ss_obj = fba_gd.SteadyStateObjective(network, intermediates)
prod_obj = fba_gd.TargetDmdtObjective(network, [mets['Glc.D.6P']])
rbounds_obj = fba_gd.VelocityBoundsObjective(
    network,
    {step: (0.0, np.inf) for step in glycolysis.steps if not step.reversible},
)

targets = {
    'prod': prod_obj.prepare_targets({mets['Glc.D.6P']: 1.5}),
    'rbounds': rbounds_obj.prepare_targets(),
}


def loss(v, dmdt, targets):
    """Weighted sum of squared objective residuals plus a small L1 penalty on `v`."""
    sq_steady_state = jnp.sum(jnp.square(ss_obj.residual(v, dmdt)))
    sq_bounds = jnp.sum(jnp.square(rbounds_obj.residual(v, dmdt, targets['rbounds'])))
    sq_production = jnp.sum(jnp.square(prod_obj.residual(v, dmdt, targets['prod'])))
    l1_velocity = jnp.sum(jnp.abs(v))
    return sum([
        sq_steady_state * 1e6,
        sq_bounds * 1e6,
        sq_production * 1e1,
        l1_velocity * 1e-8,  # tiny weight so the L1 term does not dominate the loss
    ])


def fn(v):
    """Loss as a function of velocities only; dm/dt is derived from the stoichiometry matrix."""
    return loss(v, network.s_matrix @ v, targets)


# JIT-compiled loss and its forward-mode derivative, handed to SciPy's minimizer.
jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(fun=jfn, x0=rando_v, jac=jacfn, tol=1e-12)
pd.Series(data=soln.x, index=to_label(network.reactions()))


PGI     -0.000060
PFK      0.031224
FBP      0.031269
FBA     -0.000030
TPI      0.000022
GAPDH   -0.000045
PGK      0.000037
GPMM     0.000030
ENO     -0.000022
PYK     -0.003049
PPS      0.003064
PDH     -0.000007
dtype: float64

### Doesn't work... why?
- Break into upper and lower, and run each backward

### Upper

In [11]:
# Upper half of the pathway: a network over the first five steps only.
network_upper = ReactionNetwork(glycolysis.steps[:5])

# Its labeled stoichiometry matrix: columns are reactions, rows are reactants.
pd.DataFrame(
    data=network_upper.s_matrix,
    columns=to_label(network_upper.reactions()),
    index=to_label(network_upper.reactants()),
)


Unnamed: 0,PGI,PFK,FBP,FBA,TPI
Glc.D.6P,-1.0,0.0,0.0,0.0,0.0
Fru.D.6P,1.0,-1.0,1.0,0.0,0.0
ATP,0.0,-1.0,0.0,0.0,0.0
ADP,0.0,1.0,0.0,0.0,0.0
Fru.D.bis16,0.0,1.0,-1.0,-1.0,0.0
H+,0.0,1.0,0.0,0.0,0.0
H2O,0.0,0.0,-1.0,0.0,0.0
Pi,0.0,0.0,1.0,0.0,0.0
GAP,0.0,0.0,0.0,1.0,-1.0
DHAP,0.0,0.0,0.0,1.0,1.0


In [12]:
# Objectives rebuilt on the upper network, targeting Glc.D.6P.
ss_obj = fba_gd.SteadyStateObjective(
    network_upper,
    [mets[met_id] for met_id in ['Fru.D.6P', 'Fru.D.bis16', 'dhap']],
)
prod_obj = fba_gd.TargetDmdtObjective(network_upper, [mets['Glc.D.6P']])
rbounds_obj = fba_gd.VelocityBoundsObjective(
    network_upper,
    {step: (0.0, np.inf) for step in network_upper.reactions() if not step.reversible},
)

targets = {
    'prod': prod_obj.prepare_targets({mets['Glc.D.6P']: 1.5}),
    'rbounds': rbounds_obj.prepare_targets(),
}


def loss(v, dmdt, targets):
    """Weighted sum of squared objective residuals plus a small L1 penalty on `v`."""
    sq_steady_state = jnp.sum(jnp.square(ss_obj.residual(v, dmdt)))
    sq_bounds = jnp.sum(jnp.square(rbounds_obj.residual(v, dmdt, targets['rbounds'])))
    sq_production = jnp.sum(jnp.square(prod_obj.residual(v, dmdt, targets['prod'])))
    l1_velocity = jnp.sum(jnp.abs(v))
    return sum([
        sq_steady_state * 1e6,
        sq_bounds * 1e8,
        sq_production * 1e1,
        l1_velocity * 1e-8,  # tiny weight so the L1 term does not dominate the loss
    ])


def fn(v):
    """Loss as a function of velocities only; dm/dt comes from the upper network's matrix."""
    return loss(v, network_upper.s_matrix @ v, targets)


# JIT-compiled loss and its forward-mode derivative, handed to SciPy's minimizer.
jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

# Start from a fresh random draw shaped for the upper network.
soln = scipy.optimize.minimize(
    fun=jfn,
    x0=jax.random.normal(prng, network_upper.shape[1:]),
    jac=jacfn,
    tol=1e-12,
)
pd.Series(data=soln.x, index=to_label(network_upper.reactions()))


PGI   -1.500000e+00
PFK    1.680164e-07
FBP    1.500000e+00
FBA   -1.500000e+00
TPI    1.500000e+00
dtype: float64

### Lower

In [13]:
# Lower half of the pathway: a network over every step from the sixth onward.
network_lower = ReactionNetwork(glycolysis.steps[5:])

# Its labeled stoichiometry matrix: columns are reactions, rows are reactants.
pd.DataFrame(
    data=network_lower.s_matrix,
    columns=to_label(network_lower.reactions()),
    index=to_label(network_lower.reactants()),
)


Unnamed: 0,GAPDH,PGK,GPMM,ENO,PYK,PPS,PDH
GAP,-1.0,0.0,0.0,0.0,0.0,0.0,0.0
NAD,-1.0,0.0,0.0,0.0,0.0,0.0,-1.0
Pi,-1.0,0.0,0.0,0.0,0.0,1.0,0.0
DPG,1.0,1.0,0.0,0.0,0.0,0.0,0.0
H+,1.0,0.0,0.0,0.0,1.0,2.0,0.0
NADH,1.0,0.0,0.0,0.0,0.0,0.0,1.0
3PG,0.0,-1.0,1.0,0.0,0.0,0.0,0.0
ATP,0.0,-1.0,0.0,0.0,-1.0,-1.0,0.0
ADP,0.0,1.0,0.0,0.0,1.0,0.0,0.0
2PG,0.0,0.0,-1.0,-1.0,0.0,0.0,0.0


In [14]:
# Objectives rebuilt on the lower network, targeting gap.
ss_obj = fba_gd.SteadyStateObjective(
    network_lower,
    [mets[met_id] for met_id in ['dpg', '2pg', '3pg', 'pep', 'pyr']],
)
prod_obj = fba_gd.TargetDmdtObjective(network_lower, [mets['gap']])
rbounds_obj = fba_gd.VelocityBoundsObjective(
    network_lower,
    {step: (0.0, np.inf) for step in network_lower.reactions() if not step.reversible},
)

targets = {
    'prod': prod_obj.prepare_targets({mets['gap']: 3}),
    'rbounds': rbounds_obj.prepare_targets(),
}


def loss(v, dmdt, targets):
    """Weighted sum of squared objective residuals plus a small L1 penalty on `v`."""
    sq_steady_state = jnp.sum(jnp.square(ss_obj.residual(v, dmdt)))
    sq_bounds = jnp.sum(jnp.square(rbounds_obj.residual(v, dmdt, targets['rbounds'])))
    sq_production = jnp.sum(jnp.square(prod_obj.residual(v, dmdt, targets['prod'])))
    l1_velocity = jnp.sum(jnp.abs(v))
    return sum([
        sq_steady_state * 1e6,
        sq_bounds * 1e8,
        sq_production * 1e1,
        l1_velocity * 1e-8,  # tiny weight so the L1 term does not dominate the loss
    ])


def fn(v):
    """Loss as a function of velocities only; dm/dt comes from the lower network's matrix."""
    return loss(v, network_lower.s_matrix @ v, targets)


# JIT-compiled loss and its forward-mode derivative, handed to SciPy's minimizer.
jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

# Start from a fresh random draw shaped for the lower network.
soln = scipy.optimize.minimize(
    fun=jfn,
    x0=jax.random.normal(prng, network_lower.shape[1:]),
    jac=jacfn,
    tol=1e-12,
)
pd.Series(data=soln.x, index=to_label(network_lower.reactions()))


GAPDH   -1.502925e-04
PGK      1.202940e-04
GPMM     9.029548e-05
ENO     -6.029698e-05
PYK      3.021373e-05
PPS      8.475176e-08
PDH     -2.999850e-07
dtype: float64

## PDH is irreversible!
- Treat pyruvate as a boundary, not an intermediate

In [15]:
# Steady-state objective without 'pyr' in the list. prod_obj and rbounds_obj are not
# rebuilt here; the ones defined for the lower network in the preceding cell are reused.
ss_obj = fba_gd.SteadyStateObjective(
    network_lower,
    [mets[met_id] for met_id in ['dpg', '2pg', '3pg', 'pep']],
)

targets = {
    'prod': prod_obj.prepare_targets({mets['gap']: 3}),
    'rbounds': rbounds_obj.prepare_targets(),
}


def loss(v, dmdt, targets):
    """Weighted sum of squared objective residuals plus a small L1 penalty on `v`."""
    sq_steady_state = jnp.sum(jnp.square(ss_obj.residual(v, dmdt)))
    sq_bounds = jnp.sum(jnp.square(rbounds_obj.residual(v, dmdt, targets['rbounds'])))
    sq_production = jnp.sum(jnp.square(prod_obj.residual(v, dmdt, targets['prod'])))
    l1_velocity = jnp.sum(jnp.abs(v))
    return sum([
        sq_steady_state * 1e6,
        sq_bounds * 1e8,
        sq_production * 1e1,
        l1_velocity * 1e-8,  # tiny weight so the L1 term does not dominate the loss
    ])


def fn(v):
    """Loss as a function of velocities only; dm/dt comes from the lower network's matrix."""
    return loss(v, network_lower.s_matrix @ v, targets)


# JIT-compiled loss and its forward-mode derivative, handed to SciPy's minimizer.
jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

# Start from a fresh random draw shaped for the lower network.
soln = scipy.optimize.minimize(
    fun=jfn,
    x0=jax.random.normal(prng, network_lower.shape[1:]),
    jac=jacfn,
    tol=1e-12,
)
pd.Series(data=soln.x, index=to_label(network_lower.reactions()))


GAPDH   -3.000000
PGK      3.000000
GPMM     3.000000
ENO     -3.000000
PYK      2.999933
PPS      0.000067
PDH      0.003823
dtype: float64

- It works. We found the problem.
- Flux goes through PYK and not PPS. This is not biologically correct, but as constructed it is valid
    - PYK is reversible. We could make it irreversible but that is not actually correct (it is pyruvate _kinase_ after all).
    - _in vivo_, PPS would be preferred, but that's for energetic reasons, which are not captured in FBA. Unless we can work that into the objective function...

## Next: some analysis

### Boundary fluxes

In [16]:
# Return to the whole pathway running forward: rebuild the network, the
# intermediate/boundary split and all three objectives, targeting acCoA again.
network = ReactionNetwork(glycolysis.steps)
intermediates = [
    mets[met_id]
    for met_id in (
        'Fru.D.6P', 'Fru.D.bis16', 'dhap', 'gap', 'dpg', '3pg', '2pg', 'pep', 'pyr',
    )
]
boundaries = [
    candidate for candidate in glycolysis.metabolites if candidate not in intermediates
]

ss_obj = fba_gd.SteadyStateObjective(network, intermediates)
rbounds_obj = fba_gd.VelocityBoundsObjective(
    network,
    {step: (0.0, np.inf) for step in glycolysis.steps if not step.reversible},
)
prod_obj = fba_gd.TargetDmdtObjective(network, [mets['accoa']])

targets = {
    'prod': prod_obj.prepare_targets({mets['accoa']: 3}),
    'rbounds': rbounds_obj.prepare_targets(),
}


def loss(v, dmdt, targets):
    """Weighted sum of squared objective residuals plus a small L1 penalty on `v`."""
    sq_steady_state = jnp.sum(jnp.square(ss_obj.residual(v, dmdt)))
    sq_bounds = jnp.sum(jnp.square(rbounds_obj.residual(v, dmdt, targets['rbounds'])))
    sq_production = jnp.sum(jnp.square(prod_obj.residual(v, dmdt, targets['prod'])))
    l1_velocity = jnp.sum(jnp.abs(v))
    return sum([
        sq_steady_state * 1e3,
        sq_bounds * 1e3,
        sq_production,
        l1_velocity * 1e-6,  # small weight so the L1 term does not dominate the loss
    ])


def fn(v):
    """Loss as a function of velocities only; dm/dt is derived from the stoichiometry matrix."""
    return loss(v, network.s_matrix @ v, targets)


# JIT-compiled loss and its forward-mode derivative, handed to SciPy's minimizer.
jfn = jax.jit(fn)
jacfn = jax.jit(jax.jacfwd(fn))

soln = scipy.optimize.minimize(fun=jfn, x0=rando_v, jac=jacfn, tol=1e-12)

# Labeled results for the following cells: velocity per reaction, rate of change per reactant.
flux = pd.Series(data=soln.x, index=to_label(network.reactions()))
dmdt = pd.Series(data=network.s_matrix @ soln.x, index=to_label(network.reactants()))

In [17]:
# Escher diagram on the 'glycolysis_ppp_ed.json' map (an earlier cell of this notebook
# used 'central_carbon.json' from the same package), fed directly with the labeled
# flux and dm/dt series.
diag = escher.Builder(
    map_json=resources.read_text(files.pw, 'glycolysis_ppp_ed.json'),
    menu='zoom',
    enable_editing=False,
    never_ask_before_quit=True,
    reaction_data=flux,
    metabolite_data=dmdt,
    reaction_styles=['abs', 'color', 'size'],
    metabolite_styles=['color'],
    reaction_scale=[
        {'type': 'min', 'color': '#eeeeee', 'size': 5},
        {'type': 'max', 'color': '#1f77b4', 'size': 15},
    ],
    metabolite_scale=[
        {'type': 'min', 'color': '#b30019', 'size': 5},
        {'type': 'value', 'value': 0, 'color': '#eeeeee', 'size': 5},
        {'type': 'max', 'color': '#1f77b4', 'size': 15},
    ],
)

diag

Builder(enable_editing=False, menu='zoom', metabolite_data={'Glc.D.6P': -1.4999979914978734, 'Fru.D.6P': -4.99…

In [23]:
# Boundary fluxes: horizontal bar chart of the net rate of change (dm/dt) of every
# boundary metabolite in the solution computed above.
# `figure`, `output_notebook` and `show` come from `bokeh.plotting`, imported with the
# rest of the notebook's dependencies in the first cell.
output_notebook()  # attach bokeh to the notebook before anything is rendered

# Labels are reversed, presumably so the first boundary metabolite ends up at the top
# of the categorical y axis — confirm against the rendered plot.
x = to_label(reversed(boundaries))
plot = figure(
    width=400,
    height=300,
    tools=[],
    y_range=x,
    # Title and axis labels so the figure can be read on its own.
    title='Boundary fluxes',
    x_axis_label='dm/dt',
    y_axis_label='Boundary metabolite',
)
plot.hbar(
    y=x,
    right=dmdt[x],
    height=0.5,
    color='#1f77b4',
)
show(plot)


In [20]:
# Carbon counts used for the balance, keyed by metabolite label (for acCoA this is
# presumably just the acetyl group — confirm).
carbons = {'CO2': 1, 'acCoA': 2, 'Glc.D.6P': 6}

# Per-metabolite carbon flux: rate of change times carbon count, rounded to 3 places.
c_yield = {
    label: round(dmdt[label] * n_carbons, 3)
    for label, n_carbons in carbons.items()
}
c_yield

{'CO2': 3.0, 'acCoA': 6.0, 'Glc.D.6P': -9.0}