# StoichModel Class

**Michelle Ko & Martin Lysy, University of Waterloo**

**June 29th, 2022**

## Overview

`StoichModel()` is a class for stoichiometry models cast as SDEs. It takes two matrices, the In-Matrix and the Out-Matrix, which record the reactants and the products (hence "in" and "out"), respectively, of the reactions in the dynamical system. Columns correspond to reactions and rows correspond to species. The stoichiometry matrix $S$, given by the net effect `Out-Matrix - In-Matrix`, determines the drift $S\,h(X, \theta)$ and the diffusion $S\,\mathrm{diag}(h(X,\theta))\,S^\mathsf{T}$ of the SDE, where $h(X, \theta)$ is the vector of reaction propensities.

When the rows of $S$ are linearly dependent (i.e., $S$ is rank-deficient), the diffusion matrix $S\,\mathrm{diag}(h(X,\theta))\,S^\mathsf{T}$ is singular rather than positive-definite, so trajectories can no longer be sampled from the corresponding multivariate normal. In terms of the stoichiometry model, this means that the population of one or more species is completely determined by the populations of the other species (together with the initial populations), i.e., there is a conservation law. The procedure to identify, remove, and restore such dependencies, outlined in Ingalls & Bembenek (2014), is implemented in `StoichModel()`.

In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
import pfjax as pf
import pfjax.mcmc as mcmc
from pfjax import sde as sde
from functools import partial

import pandas as pd
import seaborn as sns

  import pandas.util.testing as tm


In [16]:
class StoichModel(sde.SDEModel):
    """
    Stoichiometry (reaction-network) model cast as an SDE on the log scale.

    Args:
        dt: Passed through to `sde.SDEModel`.
        n_res: Passed through to `sde.SDEModel`; also the leading dimension of the state.
        InMatrix_Full: Array of shape `(n_species, n_reactions)`; entry `[j, k]` is the
            number of molecules of species `j` consumed by reaction `k`.
        OutMatrix_Full: Array of the same shape; entry `[j, k]` is the number of
            molecules of species `j` produced by reaction `k`.
        mask: Optional boolean sequence of length `n_species` selecting the linearly
            independent species that are simulated. If `None`, it is derived from the
            QR decomposition of the transposed full stoichiometry matrix.
        epsilon: Tolerance below which `|R[j, j]|` marks species `j` as dependent.
        bootstrap: If `True`, `pf_step` uses the bootstrap filter; otherwise the
            bridge proposal.

    The parameter vector `theta` is laid out as
    `(rate constants [n_reactions], measurement sd tau [n_X], initial populations of all species [n_species])`,
    where `n_X` is the number of independent species.
    """

    def __init__(self, dt, n_res, InMatrix_Full, OutMatrix_Full, mask=None, epsilon=1e-6, bootstrap=True):

        super().__init__(dt, n_res, diff_diag=False)  # Inherits SDEModel class
        self._bootstrap = bootstrap

        # Full stoichiometry matrix (species x reactions)
        self._InMatrix_Full = InMatrix_Full
        self._OutMatrix_Full = OutMatrix_Full
        self._StoichMatrix_Full = self._OutMatrix_Full - self._InMatrix_Full
        self._n_X_full, self._n_Rxn_full = self._StoichMatrix_Full.shape

        # Linearly independent species. In the QR decomposition of S^T, a (near-)zero
        # |R[j, j]| means column j (species j) adds no new direction to the preceding
        # columns. NOTE: QR without pivoting, so this is a heuristic selection.
        if mask is None:
            _, r = jnp.linalg.qr(jnp.transpose(self._StoichMatrix_Full))
            r_diag = jnp.abs(jnp.diag(r)) >= epsilon
            # In reduced mode R has only min(n_reactions, n_species) diagonal entries;
            # species beyond that get no entry and are treated as dependent so the
            # mask always has exactly one entry per species.
            mask = jnp.zeros(self._n_X_full, dtype=bool).at[:r_diag.shape[0]].set(r_diag)
        # Accept any boolean sequence (e.g. a Python list) as a user-supplied mask.
        self._mask = jnp.asarray(mask, dtype=bool)

        # Row-reduced stoichiometry matrix
        self._InMatrix = self._InMatrix_Full[self._mask, :]
        self._OutMatrix = self._OutMatrix_Full[self._mask, :]
        self._StoichMatrix = self._StoichMatrix_Full[self._mask, :]
        self._n_X, self._n_Rxn = self._StoichMatrix.shape

        self._n_state = (self._n_res, self._n_X)

        # Link matrix L (n_species x n_X) restoring all species from the independent ones.
        self._L = jnp.matmul(self._StoichMatrix_Full, jnp.linalg.pinv(self._StoichMatrix))
        # L0: L with the rows of the independent species zeroed out.
        self._L0 = jnp.where(self._mask[:, None], 0.0, self._L)

        # Reactant structure of each reaction as plain Python values, so `_Hazard`
        # branches on static data rather than on JAX arrays (which become tracers,
        # and cannot drive an `if`, under jit / lax.scan).
        in_rows = jnp.asarray(self._InMatrix_Full).tolist()
        self._reactants = []
        for k in range(self._n_Rxn_full):
            col = [row[k] for row in in_rows]
            species = tuple(j for j, n in enumerate(col) if n != 0)
            self._reactants.append((sum(col), species))

    def _Hazard(self, x, param):
        """
        Propensity of every reaction under (stochastic) mass-action kinetics.

        Args:
            x: Population of all species (1-D, length `n_species`).
            param: Rate constants, one per reaction.
        Returns:
            Array of length `n_reactions` with the propensity of each reaction.
        Raises:
            NotImplementedError: If a reaction consumes more than two molecules.
        """

        # Propensity of i-th reaction determined by reaction type
        def h(x, param_i, i):
            n_mols, species = self._reactants[i]
            if n_mols == 0:
                # Zeroth order: constant rate
                return param_i
            if n_mols == 1:
                return param_i * x[species[0]]
            if n_mols == 2 and len(species) == 1:
                # Two molecules of the same species: c * x * (x - 1) / 2
                j = species[0]
                return param_i * x[j] * (x[j] - 1) / 2
            if n_mols == 2 and len(species) == 2:
                return param_i * (x[species[0]] * x[species[1]])
            raise NotImplementedError(
                f"Not supported reaction: reaction {i} consumes {n_mols:g} molecules"
            )

        Hazard = jnp.array([h(x, param[i], i) for i in range(self._n_Rxn)])

        return Hazard

    def _x_full(self, x, x_init):
        """
        Restore the population of all species from the independent species.

        Args:
            x: Population of the independent species (length `n_X`).
            x_init: Initial population of all species (length `n_species`);
                fixes the conserved totals of the dependent species.
        Returns:
            Population of all species (length `n_species`).
        """
        T_tilde = jnp.where(jnp.invert(self._mask), x_init, jnp.zeros(self._n_X_full)) - self._L0 @ x_init[self._mask]
        x_full = self._L @ x + T_tilde

        return x_full

    def _drift(self, x, theta):
        """Drift `S h(x_full, theta)` of the independent species on the regular scale."""
        x_full = self._x_full(x, theta[(self._n_Rxn + self._n_X):])
        Hazard = self._Hazard(x_full, theta[:self._n_Rxn])
        mu = self._StoichMatrix @ Hazard

        return mu

    def _diff(self, x, theta):
        """Diffusion `S diag(h(x_full, theta)) S^T` of the independent species on the regular scale."""
        x_full = self._x_full(x, theta[(self._n_Rxn + self._n_X):])
        Hazard = self._Hazard(x_full, theta[:self._n_Rxn])
        Sigma = self._StoichMatrix @ jnp.diag(Hazard) @ jnp.transpose(self._StoichMatrix)

        return Sigma

    def drift(self, x, theta):
        """
        Drift on the log scale.

        Applies Ito's lemma with `f = log`: `mu / X - 0.5 * diag(Sigma) / X^2`,
        where `X = exp(x)`.
        """
        x = jnp.exp(x)
        mu = self._drift(x, theta)
        Sigma = self._diff(x, theta)

        f_p = 1 / x
        f_pp = -1 / x / x

        mu_trans = f_p * mu + 0.5 * f_pp * jnp.diag(Sigma)
        return mu_trans

    def diff(self, x, theta):
        """Diffusion on the log scale: `Sigma[j, k] / (X[j] X[k])` with `X = exp(x)`."""
        x = jnp.exp(x)
        Sigma = self._diff(x, theta)

        f_p = 1 / x
        Sigma_trans = jnp.outer(f_p, f_p) * Sigma

        return Sigma_trans

    def meas_lpdf(self, y_curr, x_curr, theta):
        """
        Log-density of `p(y_curr | x_curr, theta)`.

        The measurement model is `y_curr ~ N(exp(x_curr[-1]), tau^2)` independently
        per independent species, with `tau = theta[n_reactions:(n_reactions + n_X)]`.

        Args:
            y_curr: Measurement variable at current time `t`.
            x_curr: State variable at current time `t`.
            theta: Parameter value.
        Returns:
            The log-density of `p(y_curr | x_curr, theta)`.
        """
        tau = theta[self._n_Rxn:(self._n_Rxn+self._n_X)]
        return jnp.sum(
            jsp.stats.norm.logpdf(y_curr, loc=jnp.exp(x_curr[-1]), scale=tau)
        )

    def meas_sample(self, key, x_curr, theta):
        """
        Sample from `p(y_curr | x_curr, theta)`, i.e. `exp(x_curr[-1]) + tau * N(0, 1)`.

        Args:
            key: PRNG key.
            x_curr: State variable at current time `t`.
            theta: Parameter value.
        Returns:
            Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
        """
        tau = theta[self._n_Rxn:(self._n_Rxn+self._n_X)]
        return jnp.exp(x_curr[-1]) + tau * random.normal(key, (self._n_state[1],))

    def pf_init(self, key, y_init, theta):
        """
        Particle filter calculation for `x_init`.

        Samples from an importance sampling proposal distribution
        ```
        x_init ~ q(x_init) = q(x_init | y_init, theta)
        ```
        and calculates the log weight
        ```
        logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)
        ```
        The proposal sets `exp(x_init) = y_init + tau * Z`, with `Z` a standard normal
        truncated to `Z > -y_init / tau`. In other words, `exp(x_init)` is `N(y_init, tau^2)`
        truncated to the positive reals. The returned log-weight is
        `sum(log Phi(y_init / tau))`, the log normalizing constant of that truncation.
        NOTE(review): this is the exact weight only if the prior on `exp(x_init)` is
        flat and the Jacobian of the log transform is ignored, so the proposal is not
        exactly perfect. Confirm that this is the intended prior.

        Args:
            y_init: Measurement variable at initial time `t = 0`.
            theta: Parameter value.
            key: PRNG key.
        Returns:
            - x_init: A sample from the proposal distribution for `x_init`, placed in the
              last of the `n_res` slots (the earlier slots are zero).
            - logw: The log-weight of `x_init`.
        """
        tau = theta[self._n_Rxn:(self._n_Rxn+self._n_X)]

        key, subkey = random.split(key)
        x_init = jnp.log(y_init + tau * random.truncated_normal(
            subkey,
            lower=-y_init/tau,
            upper=jnp.inf,
            shape=(self._n_state[1],)
        ))
        logw = jnp.sum(jsp.stats.norm.logcdf(y_init/tau))

        return \
            jnp.append(jnp.zeros((self._n_res-1,) + x_init.shape),
                       jnp.expand_dims(x_init, axis=0), axis=0), \
            logw

    def pf_step(self, key, x_prev, y_curr, theta):
        """
        Choose between bootstrap filter and bridge proposal.

        The bridge proposal treats `log(y_curr)` as an observation of the state with
        identity measurement matrix and (delta-method) variance `(tau / y_curr)^2`.

        Args:
            x_prev: State variable at previous time `t-1`.
            y_curr: Measurement variable at current time `t`.
            theta: Parameter value.
            key: PRNG key.
        Returns:
            - x_curr: Sample of the state variable at current time `t`: `x_curr ~ q(x_curr)`.
            - logw: The log-weight of `x_curr`.
        """
        if self._bootstrap:
            x_curr, logw = super().pf_step(key, x_prev, y_curr, theta)
        else:
            omega = (theta[self._n_Rxn:(self._n_Rxn+self._n_X)] / y_curr)**2
            x_curr, logw = self.bridge_prop(
                key, x_prev, y_curr, theta,
                # Identity must match the number of independent species, not a fixed 4.
                jnp.log(y_curr), jnp.eye(self._n_X), jnp.diag(omega)
            )
        return x_curr, logw


In [24]:
# Auto-regulatory gene network from Golightly & Wilkinson (2006):
# 8 species (rows) x 12 reactions (columns).

# Reactant ("in") matrix: In[j, k] = molecules of species j consumed by reaction k.
In = jnp.array([[1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
                [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.]])

# Product ("out") matrix: Out[j, k] = molecules of species j produced by reaction k.
Out = jnp.array([[0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                 [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.]])

# Initial populations of the 8 species, in the row order of In / Out.
I, G, Ii, Ig, i, g, ri, rg = 8.5, 29, 2, 3, 8, 7, 18, 9
X_init = jnp.array([I, G, Ii, Ig, i, g, ri, rg])

# Conserved totals Ii + i and Ig + g.
K1 = Ii + i
K2 = Ig + g

# Measurement sd for each of the 6 independent species, then the 12 rate constants.
tau = jnp.array([0.1] * 6)
param = jnp.array([0.08, 0.82, 0.09, 0.9, 0.25, 0.1, 0.35, 0.3, 0.1, 0.1, 0.12, 0.1])

# Hand-picked independent species: drop rows 2 and 3 (Ii and Ig).
mask = jnp.array([True, True, False, False, True, True, True, True])

# theta layout: (rate constants, measurement sd, full initial population).
theta = jnp.concatenate((param, tau, X_init))

# One model with the automatically detected mask, one with the hand-picked mask.
gnet = StoichModel(dt=1, n_res=2, InMatrix_Full=In, OutMatrix_Full=Out)
gnet_mask = StoichModel(dt=1, n_res=2, InMatrix_Full=In, OutMatrix_Full=Out, mask=mask)


In [26]:
# Independent-species states for each model: X for the automatically detected
# mask (presumably I, G, Ii, Ig, ri, rg -- confirm against gnet._mask), X_mask
# for the hand-picked mask (I, G, i, g, ri, rg).
X = jnp.array((I, G, Ii, Ig, ri, rg))
X_mask = jnp.array((I, G, i, g, ri, rg))

# NOTE: only the value of the last expression is displayed by the notebook.
gnet._drift(X, theta)
gnet_mask._drift(X_mask, theta)

DeviceArray([-5.675009  , -0.20000052, -3.8000045 , -2.6550033 ,
              0.20000064,  1.5500002 ], dtype=float32)

In [27]:
# Hand-derived drift for gnet_mask's independent species (I, G, i, g, ri, rg),
# with the dependent species eliminated via Ii = K1 - i and Ig = K2 - g.
# It should match gnet_mask._drift(X_mask, theta) up to float32 rounding.
jnp.array([param[1]*(K1-i) + param[3]*(K2-g) + param[5]*ri - param[0]*I*i - param[2]*I*g - param[10]*I ,   # dI/dt
           param[7]*rg - param[11]*G,   # dG/dt
           param[1]*(K1-i) - param[0]*I*i,   # di/dt
           param[3]*(K2-g) - param[2]*I*g,   # dg/dt
           param[4]*i - param[8]*ri,   # dri/dt
           param[6]*g - param[9]*rg   # drg/dt
          ])

DeviceArray([-5.6750007 , -0.20000005, -3.8000002 , -2.6550007 ,
              0.19999993,  1.55      ], dtype=float32)