# Reaction Networks

In [None]:
import numpy as np
import scipy.integrate as spi
import matplotlib.pyplot as plt
from sympy import *
from functools import reduce 

# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline

# Display sympy expressions as LaTeX rendered by MathJax.
init_printing(use_latex='mathjax')

# # Deterministic Reaction Networks

# def StoichiometricMatix(rxns):
#     """Returns the stoichiometric matrix for the given reactions.
#        It is assumed that the reactions are in a list of reactions
#        each of the form
#           [ar, ap, rate_constant]
#        where ar is the reactant vector and ap is the product vector.
#     """
#     vectors = [(-Matrix(a[0]) + Matrix(a[1])) for a in rxns]
#     A = Matrix(vectors[0])
#     for i in range(1,len(vectors)):
#         A = A.col_insert(i,vectors[i])
#     return A

# def MAK(a,species):
#     """The mass action kinetics rate for the reaction a [ ar, ap, k ].
#        species should be a list of sympy symbols such as returned by var("x y z")"""
#     rate = a[2] # k
#     for r,s in zip(a[0],species):
#         rate = rate * s**r
#     return rate

# def KineticsVector(rxns,species):
#     """Builds the kinetics vector for the given reactions"""
#     return Matrix([MAK(a,species) for a in rxns])

# def Dynamics(rxns,species):
#     """Builds the right hand side of the equation dv/dt = AK(v)"""
#     return StoichiometricMatix(rxns) * KineticsVector(rxns,species)

# def ODEs(dynamics,species):
#     """Converts the dynamics into a function that can be used with scipy's odeint function"""
#     faux = lambdify([species],dynamics,'numpy')
#     def f(x,t):
#         return faux(x).flatten()
#     return f

# def vectorize(species,names):
#     a = [0 for s in species]
#     for name in names:
#         a[species.index(name)] += 1
#     return a

# def make_vectorizer(species):
#     def f(*args):
#         return vectorize(species, args)
#     return f

# Stochastic Networks

In [None]:
class ReactionNetwork:
    """A mass-action chemical reaction network.

    Builds symbolic first- and second-moment equations from the extended
    generator and simulates trajectories with Gillespie's stochastic
    simulation algorithm (SSA).

    Attributes:
        names:   list of species names (strings).
        n:       number of species.
        species: sympy symbols for the species, created with ``var`` (which
                 also binds them as global names).
        m:       moment symbols: ``m_<A>`` for every species A, followed by
                 ``m_<A><B>`` for every pair with B at or before A in ``names``.
        rxns:    reactions, each ``[reactant_vector, product_vector, rate]``
                 with one stoichiometric coefficient per species (in ``names``
                 order); assign after construction, e.g. via ``make_vectorizer``.
    """

    def __init__(self, names):
        self.names = names
        self.n = len(self.names)
        self.species = var(names)
        self.m = self.moment_vector()
        self.rxns = []

    def vectorize(self, parts):
        """Return a count vector with one entry per species.

        Each occurrence of a species in *parts* adds 1 to its entry, e.g.
        ``vectorize([R, R, P]) == [2, 1]`` for species ``[R, P]``.
        """
        counts = [0] * len(self.species)
        for part in parts:
            counts[self.species.index(part)] += 1
        return counts

    def make_vectorizer(self):
        """Return ``f(*parts)`` equivalent to ``vectorize(parts)``, e.g. ``v(R, P)``."""
        def f(*args):
            return self.vectorize(args)
        return f

    def moment_vector(self):
        """Return the list of moment symbols (see ``m`` in the class docstring)."""
        m = [var("m_" + name) for name in self.names]
        for i in range(self.n):
            for j in range(i + 1):
                m.append(var("m_" + self.names[i] + self.names[j]))
        return m

    def sublist(self):
        """Return ``(monomial, moment_symbol)`` pairs for use with ``subs``.

        Products ``x_i * x_j`` map to ``m_<i><j>`` and single species ``x_i``
        map to ``m_<i>``.
        """
        s = []
        for i in range(self.n):
            for j in range(i + 1):
                s.append((
                    self.species[i] * self.species[j],
                    var("m_" + self.names[i] + self.names[j]),
                ))
        # First moments come last so that the second-moment products are
        # substituted before their individual factors are replaced.
        for i in range(self.n):
            s.append((self.species[i], var("m_" + self.names[i])))
        return s

    def first_power_vector(self, i):
        """Exponent vector selecting species *i* alone (unit vector e_i)."""
        return [1 if i == j else 0 for j in range(self.n)]

    def second_power_vector(self, i, j):
        """Exponent vector for the product of species *i* and *j* (e_i + e_j)."""
        v = [0] * self.n
        v[i] += 1
        v[j] += 1
        return v

    def moment(self, powers):
        """Return ``f(state)`` computing ``prod(state[k] ** powers[k])``."""
        def f(species):
            terms = [s**p for s, p in zip(species, powers)]
            return reduce(lambda a, b: a * b, terms)
        return f

    def next(self, rxn, state):
        """Return the state (a sympy column Matrix) after *rxn* fires once."""
        return Matrix(state) - Matrix(rxn[0]) + Matrix(rxn[1])

    def MAK(self, a, state=None):
        """Mass-action rate of reaction ``a = [ar, ap, k]``.

        Returns ``k * prod(state[i] ** ar[i])``.  *state* defaults to the
        symbolic species ``self.species``.
        """
        if state is None:  # identity test: `==` would broadcast on numpy arrays
            state = self.species
        rate = a[2]  # k
        for r, s in zip(a[0], state):
            rate = rate * s**r
        return rate

    def extended_generator(self, f):
        """Return ``sum_rxn (f(next state) - f(state)) * rate``, simplified.

        The state is the symbolic species vector.
        """
        state = self.species  # symbolic representation of the state
        terms = [
            (f(self.next(rxn, state)) - f(Matrix(state))) * self.MAK(rxn)
            for rxn in self.rxns
        ]
        # sympify so an empty reaction list gives sympy 0 (which has
        # .simplify) rather than the plain int 0 returned by sum([]).
        return sympify(sum(terms)).simplify()

    def first_moment_equations(self):
        """Time derivatives of the first moments, in terms of moment symbols."""
        return [
            self.extended_generator(
                self.moment(self.first_power_vector(i))
            ).subs(self.sublist())
            for i in range(self.n)
        ]

    def second_moment_equations(self):
        """Time derivatives of the second moments (pairs j <= i), as moment symbols."""
        return [
            self.extended_generator(
                self.moment(self.second_power_vector(i, j))
            ).subs(self.sublist())
            for i in range(self.n)
            for j in range(i + 1)
        ]

    def moment_equations(self):
        """Column Matrix of all moment derivatives, ordered like ``self.m``."""
        return Matrix(self.first_moment_equations() + self.second_moment_equations())

    def moment_odes(self, params):
        """Return ``f(x, t)`` for ``scipy.integrate.odeint``.

        *x* holds the moments in ``self.m`` order; *params* is a
        ``subs``-compatible list of ``(symbol, value)`` pairs.
        """
        eqns = [eqn.subs(params) for eqn in self.moment_equations()]
        dmdt = lambdify([self.m], eqns, 'numpy')

        def f(x, t):
            return dmdt(x)
        return f

    def next_states(self, state, params=None):
        """Return ``(successor_state, rate)`` for every reaction with rate > 0.

        *params* is a ``subs``-compatible list of ``(symbol, value)`` pairs.
        """
        if params is None:
            params = []
        result = []
        for rxn in self.rxns:
            # sympify so plain-number rate constants also support .subs
            rate = sympify(self.MAK(rxn, state)).subs(params)
            if rate > 0:
                result.append((self.next(rxn, state), rate))
        return result

    def total_rate_out(self, state, params=None):
        """Sum of the rates of all reactions that can fire from *state*."""
        return sum(rate for _, rate in self.next_states(state, params))

    def choose_next(self, state, params):
        """Randomly pick a successor of *state*, weighted by reaction rate.

        Raises:
            ValueError: if no reaction can fire from *state*.
        """
        ns = self.next_states(state, params)
        if not ns:
            raise ValueError("no reaction can fire from this state")
        return self._pick(ns, float(sum(rate for _, rate in ns)))

    def choose_time(self, state, params):
        """Draw the exponential waiting time before the next reaction from *state*.

        Returns ``inf`` if no reaction can fire.
        """
        return self._waiting_time(float(self.total_rate_out(state, params)))

    @staticmethod
    def _pick(ns, total):
        """Select a state from ``ns = [(state, rate), ...]`` with probability rate/total."""
        r = total * np.random.random()
        cumulative = 0.0
        for new_state, rate in ns:
            cumulative += float(rate)
            if r <= cumulative:
                return new_state
        # Floating-point round-off can leave r marginally above the final sum.
        return ns[-1][0]

    @staticmethod
    def _waiting_time(total):
        """Exponential(total) sample via inverse transform; inf when total <= 0."""
        if total <= 0:
            return np.inf
        r = np.random.random()
        return np.log(1.0 / (1.0 - r)) / total

    def ssa(self, x0, tmax, params):
        """Simulate one trajectory with Gillespie's SSA.

        Starting from *x0* at time 0, repeatedly fire one reaction (chosen with
        probability proportional to its rate) after an exponential waiting
        time, until the time reaches *tmax* or no reaction can fire.

        Returns:
            ``(t, x)``: a 1-D float array of jump times and a 2-D float array
            whose row k is the state entered at ``t[k]`` (row 0 is *x0*).
            The final jump time may exceed *tmax*.
        """
        x = x0
        t = 0.0
        tdata = [t]
        xdata = [list(x0)]
        while t < tmax:
            # Compute the successors once per step and reuse them for both the
            # reaction choice and the waiting time (same random-draw order).
            ns = self.next_states(x, params)
            if not ns:
                break
            total = float(sum(rate for _, rate in ns))
            x = self._pick(ns, total)
            t = t + self._waiting_time(total)
            tdata.append(t)
            # Flatten the n x 1 sympy Matrix so every row has the same shape.
            xdata.append(list(x))
        return np.array(tdata, dtype=float), np.array(xdata, dtype=float)
        

# Example

In [None]:
# Two-species network: R (RNA) and P (protein).
system = ReactionNetwork(["R", "P"])  # also binds the symbols R and P globally
# var() returns the new symbols; unpack them explicitly rather than assigning
# the tuple to `params` and immediately overwriting it.
k1, k2, k3, k4 = var("k1 k2 k3 k4")
params = [(k1, 4), (k2, 0.25), (k3, 1), (k4, 0.1)]
v = system.make_vectorizer()

system.rxns = [
    [v(), v(R), k1],        # 0 -> R
    [v(R), v(), k2],        # R -> 0
    [v(R), v(R, P), k3],    # R -> R + P
    [v(P), v(), k4],        # P -> 0
]

eqns = system.moment_equations()
eqns

In [None]:
# Stationary moments: set every moment derivative in eqns to zero and solve
# for the moment symbols in system.m.
equilibrium = solve(eqns,system.m)
equilibrium

In [None]:
# Stationary moment vector evaluated at the numeric parameter values in params.
eq = Matrix(system.m).subs(equilibrium).subs(params)
eq

In [None]:
# Stationary mean of R and its variance m_RR - m_R**2, in symbolic form.
m_R.subs(equilibrium).simplify(), (m_RR - m_R**2).subs(equilibrium).simplify()

In [None]:
# Stationary mean of P and its variance m_PP - m_P**2, in symbolic form.
m_P.subs(equilibrium).simplify(), (m_PP - m_P**2).subs(equilibrium).simplify()

In [None]:
# Integrate the moment ODEs starting from zero molecules and overlay one SSA
# trajectory on the mean +/- one standard deviation band.
dmdt = system.moment_odes(params)
tmax = 30
t = np.linspace(0, tmax, 100)
m0 = [0] * len(system.m)  # one initial value per moment symbol in system.m
m = spi.odeint(dmdt, m0, t)

# Columns of m follow system.m: [m_R, m_P, m_RR, m_PR, m_PP].  Clip the
# variances at zero so integration round-off cannot give sqrt(<0) = nan.
rstdev = np.sqrt(np.clip(m[:, 2] - m[:, 0]**2, 0, None))
pstdev = np.sqrt(np.clip(m[:, 4] - m[:, 1]**2, 0, None))

ssa_t, ssa_x = system.ssa([0, 0], tmax, params)

fig, ax = plt.subplots(1, 2, figsize=(15, 5))

# An SSA state holds from its jump time until the next jump, hence where="post".
ax[0].plot(t, m[:, 0], label="RNA")
ax[0].step(ssa_t, ssa_x[:, 0], color="blue", where="post")
ax[0].fill_between(t, m[:, 0] - rstdev, m[:, 0] + rstdev, color="lightblue")
ax[0].set_xlabel("$t$")

ax[1].plot(t, m[:, 1], label="Protein", color="orange")
ax[1].step(ssa_t, ssa_x[:, 1], color="orange", where="post")
ax[1].fill_between(t, m[:, 1] - pstdev, m[:, 1] + pstdev, color="bisque")
ax[1].set_xlabel("$t$")

ax[0].legend()
ax[1].legend();

In [None]:
# Successor states of [R, P] = [50, 0] paired with their positive reaction rates.
system.next_states([50,0],params)

In [None]:
# Total rate out of state [50, 0]: the sum of the rates from next_states.
system.total_rate_out([50,0],params)

In [None]:
# Ten independent random draws of the successor of state [50, 0].
[system.choose_next([50,0],params) for i in range(10)]

In [None]:
# One SSA trajectory of both species up to t = 30.  Each state holds until the
# next jump time, so draw the steps with where="post".
t, x = system.ssa([0, 0], 30, params)
plt.step(t, x, where="post");

# Enzyme Kinetics

In [None]:
# Enzyme E binds substrate S to form complex X, which releases E and product P.
# Note: var() rebinds the global name E, which `from sympy import *` had bound
# to Euler's number.
system = ReactionNetwork(["E", "S", "X", "P"])
# Unpack the symbols explicitly instead of assigning var()'s return value to
# params and immediately overwriting it.
k1, k2, k3 = var("k1 k2 k3")
params = [(k1, 1), (k2, 0.5), (k3, .1)]
v = system.make_vectorizer()

system.rxns = [
    [v(E, S), v(X), k1],    # E + S -> X
    [v(X), v(E, S), k2],    # X -> E + S
    [v(X), v(E, P), k3],    # X -> E + P
]

In [None]:
# Overlay 40 independent SSA runs of the product count P (column 3), starting
# from E=2, S=10.  States hold between jumps, hence where="post".
for _ in range(40):
    ssa_t, ssa_x = system.ssa([2, 10, 0, 0], 140, params)
    plt.step(ssa_t, ssa_x[:, 3], where="post")

# Multiplication

<img width="75%" src="https://raw.githubusercontent.com/klavins/ECE424/master/images/register-machine.png">

In [None]:
# Register machine: S0..S8 are program states (every reaction consumes one S
# and produces one S, so a single state token moves around) and R0..R3 are
# registers whose contents are molecule counts.
multiplier = ReactionNetwork([
    "S0", "S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8",
    "R0", "R1", "R2", "R3"
])

# k: rate of ordinary instruction steps; e: much smaller rate of each dec
# instruction's fallback branch.  Unpack the symbols explicitly instead of
# assigning var()'s return value to params and immediately overwriting it.
k, e = var("k e")
params = [(k, 1), (e, 0.01)]

multiplier.species

In [None]:
# Register-machine program encoded as reactions.  Notation used below:
#   dec(s, r, a, b): in state Ss, consume one Rr and move to Sa (rate k); a
#                    competing slow reaction (rate e) moves to Sb instead,
#                    intended for the case where Rr is empty.
#   inc(s, r, a):    in state Ss, produce one Rr and move to Sa (rate k).
v = multiplier.make_vectorizer()

multiplier.rxns = [
    [v(S0,R0), v(S1),    k],     # dec(0,0,1,8)
    [v(S0),    v(S8),    e],    
    [v(S1,R1), v(S2),    k],     # dec(1,1,2,4)
    [v(S1),    v(S4),    e], 
    [v(S2),    v(R2,S3), k],     # inc(2,2,3)
    [v(S3),    v(R3,S1), k],     # inc(3,3,4)
    [v(S4,R0), v(S5),    k],     # dec(4,0,5,8)
    [v(S4),    v(S8),    e],    
    [v(S5,R2), v(S6),    k],     # dec(5,2,6,0)
    [v(S5),    v(S0),    e],    
    [v(S6),    v(R1,S7), k],     # inc(6,1,7)
    [v(S7),    v(R3,S5), k],     # inc(7,3,5)  
]

In [None]:
# Start in S0 with R0 = 4 and R1 = 5 (R2 = R3 = 0); the program appears intended
# to accumulate R0 * R1 in R3.  Columns 9..12 of x are R0..R3.
x0 = [1,0,0,0,0,0,0,0,0,
      4,5,0,0]
t, x = multiplier.ssa(x0, 1000, params)
# Each state holds until the next jump time, hence where="post".
plots = plt.step(t, x[:, 9], t, x[:, 10], t, x[:, 11], t, x[:, 12], where="post")
plt.legend(plots, ["R0", "R1", "R2", "R3"])
plt.xlabel("$t$");

In [None]:
# Columns 0..8 hold S0..S8, exactly one of which is 1, so weighting each column
# by its index recovers the current program-state number after every jump.
s = x[:, :9] @ np.arange(9)
plt.step(t, s, where="post")
plt.xlabel("$t$")
plt.ylabel("$state$")
plt.title("State versus time");