# Lotka-Volterra Model with Multiplicative Noise

## Model Definition

The SDE for $\XX_t = (U_t, V_t)$ is given by

$$
\begin{bmatrix}
\ud U_t \\
\ud V_t
\end{bmatrix}
= 
\begin{bmatrix} 
\alpha U_t - \beta U_t V_t \\ 
\beta U_t V_t - \gamma V_t 
\end{bmatrix} \ud t + 
\begin{bmatrix}
\alpha U_t + \beta U_t V_t & - \beta U_t V_t \\
-\beta U_t V_t & \gamma V_t + \beta U_t V_t
\end{bmatrix}^{1/2} \ud \BB_t.
$$
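
As a sketch of what the SDE above means operationally, the snippet below takes one Euler-Maruyama step on the original scale. It assumes the standard convention that $\ud \BB_t$ over a step of length $\Delta t$ is $\N(0, \Delta t\, I)$, and it uses the Cholesky factor of the diffusion matrix in place of the symmetric square root $(\cdot)^{1/2}$; any factor $L$ with $L L^\top$ equal to the matrix gives the same increment distribution. The helper names `lv_drift`, `lv_diff`, and `euler_step` are hypothetical and not part of pfjax.

```python
import jax.numpy as jnp
from jax import random


def lv_drift(x, phi):
    """Drift of the Lotka-Volterra SDE on the original scale (hypothetical helper)."""
    alpha, beta, gamma = phi
    U, V = x
    return jnp.array([alpha * U - beta * U * V, beta * U * V - gamma * V])


def lv_diff(x, phi):
    """Matrix whose square root multiplies dB_t, on the original scale (hypothetical helper)."""
    alpha, beta, gamma = phi
    U, V = x
    bUV = beta * U * V
    return jnp.array([[alpha * U + bUV, -bUV],
                      [-bUV, gamma * V + bUV]])


def euler_step(key, x, phi, dt):
    """One Euler-Maruyama step: x + mu(x) dt + L(x) sqrt(dt) eps, where L L^T = Sigma(x)."""
    L = jnp.linalg.cholesky(lv_diff(x, phi))  # lower-triangular factor
    eps = random.normal(key, (2,))            # standard normal increment
    return x + lv_drift(x, phi) * dt + jnp.sqrt(dt) * (L @ eps)


phi = jnp.array([0.5, 0.0025, 0.3])   # (alpha, beta, gamma)
x = jnp.array([71.0, 79.0])           # (U, V)
x_next = euler_step(random.PRNGKey(0), x, phi, dt=0.1)
```

Here `dt = 0.1` matches the sub-step length `dt / n_res` of Setting 1 below. A step taken on the original scale like this one can produce negative populations, which the log-scale formulation avoids.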

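In practice, the SDE is transformed to the log scale, $\ZZ_t = \log \XX_t$, using Itô's lemma, and it is this transformed SDE that is approximated numerically; this guarantees that $\XX_t = \exp(\ZZ_t)$ stays positive.  The noisy observations are given by 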

$$
\yy_n \ind \N(\exp(\ZZ_n), \diag(\tta^2)).
$$
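
For reference, the log-scale dynamics follow from Itô's lemma applied componentwise to $\ZZ_t = \log \XX_t$, using that the first and second derivatives of $\log x$ are $1/x$ and $-1/x^2$. Writing $\mu(\XX_t)$ for the drift vector and $\Sigma(\XX_t)$ for the matrix inside $(\cdot)^{1/2}$ in the SDE above (notation introduced here for convenience), and $X_{1,t} = U_t$, $X_{2,t} = V_t$:

$$
\ud \ZZ_t = \tilde{\mu}(\ZZ_t) \ud t + \tilde{\Sigma}(\ZZ_t)^{1/2} \ud \BB_t,
\qquad
\tilde{\mu}_i = \frac{\mu_i(\XX_t)}{X_{i,t}} - \frac{\Sigma_{ii}(\XX_t)}{2 X_{i,t}^2},
\qquad
\tilde{\Sigma}_{ij} = \frac{\Sigma_{ij}(\XX_t)}{X_{i,t} X_{j,t}},
$$

where $\XX_t = \exp(\ZZ_t)$. For the Lotka-Volterra drift and diffusion this gives

$$
\tilde{\mu} =
\begin{bmatrix}
\alpha - \beta V_t - \dfrac{\alpha + \beta V_t}{2 U_t} \\
\beta U_t - \gamma - \dfrac{\gamma + \beta U_t}{2 V_t}
\end{bmatrix},
\qquad
\tilde{\Sigma} =
\begin{bmatrix}
(\alpha + \beta V_t) / U_t & -\beta \\
-\beta & (\gamma + \beta U_t) / V_t
\end{bmatrix},
$$

and the measurement log-density is

$$
\log p(\yy_n \mid \ZZ_n, \tta) = -\sum_{i=1}^{2} \left[ \log \tau_i + \tfrac{1}{2} \log(2\pi) + \frac{(y_{n,i} - \exp(Z_{n,i}))^2}{2 \tau_i^2} \right].
$$

These are the quantities computed by the `drift`, `diff`, and `meas_lpdf` methods of `LVMultModel` below.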

## Inference Settings

- In each case, please try to estimate $\pph = (\alpha, \beta, \gamma)$, assuming that $\tta$ is known.  

- Please do this with stochastic optimization, and make sure you provide a Hessian-based estimate of the variance at the end.  

- You should use the "Bayesian normal approximation": approximate the posterior of each parameter by a normal distribution whose mean is the MLE and whose variance is the corresponding diagonal element of the inverse Hessian of the negative log-likelihood.  Plot these densities with a vertical line at the true parameter value, as in Ryder et al. (2018), Figure 3.

- Make sure you do both the optimization and the Hessian calculation on $\log \pph$.  You are welcome to rewrite the `LVMultModel` class below with $\tth = (\log \pph, \log \tta)$ instead of $\tth = (\pph, \tta)$.

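The workflow in these bullets can be sketched as follows. The real objective would be a particle-filter estimate of the log-likelihood; here, as a stand-in assumption, a hypothetical `neg_loglik(log_phi)` is a quadratic in $\log \pph$ with an assumed curvature matrix, and its stochastic gradient is the exact gradient plus Gaussian noise. The plain SGD update and its step-size schedule are also assumptions. Only the structure is meant to carry over: optimize over $\log \pph$, evaluate the Hessian at the optimum, invert it, and use the diagonal as variances of the normal approximation on the log scale.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

# Hypothetical stand-in for the negative particle-filter log-likelihood in log_phi.
log_phi_true = jnp.log(jnp.array([0.5, 0.0025, 0.3]))
prec = jnp.array([[40.0, 5.0, 0.0],
                  [5.0, 30.0, 2.0],
                  [0.0, 2.0, 50.0]])  # assumed curvature (observed information)


def neg_loglik(log_phi):
    d = log_phi - log_phi_true
    return 0.5 * d @ prec @ d


def noisy_grad(key, log_phi, noise_sd=1.0):
    """Exact gradient plus Gaussian noise, mimicking a stochastic gradient estimate."""
    return jax.grad(neg_loglik)(log_phi) + noise_sd * random.normal(key, log_phi.shape)


def sgd(key, log_phi0, n_iter=2000, lr0=0.02):
    """Plain SGD with the assumed decaying step size lr0 / (1 + t)^0.6."""
    def step(log_phi, inputs):
        subkey, t = inputs
        lr = lr0 / (1.0 + t) ** 0.6
        return log_phi - lr * noisy_grad(subkey, log_phi), None

    keys = random.split(key, n_iter)
    ts = jnp.arange(n_iter, dtype=jnp.float32)
    log_phi_hat, _ = jax.lax.scan(step, log_phi0, (keys, ts))
    return log_phi_hat


log_phi_hat = sgd(random.PRNGKey(1), jnp.zeros(3))

# Normal approximation on the log scale: mean = estimate, variance = diag(H^{-1}).
H = jax.hessian(neg_loglik)(log_phi_hat)
var = jnp.diag(jnp.linalg.inv(H))
sd = jnp.sqrt(var)

# Density of each marginal normal approximation on a grid of +/- 4 sd.
grid = log_phi_hat[:, None] + sd[:, None] * jnp.linspace(-4.0, 4.0, 201)[None, :]
dens = norm.pdf(grid, loc=log_phi_hat[:, None], scale=sd[:, None])
```

With matplotlib, each panel could then be drawn with `plt.plot(grid[i], dens[i])` followed by `plt.axvline(log_phi_true[i])` for the true value.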
### Setting 1

The following settings are modified from Ryder et al. (2018), Section 5.1:

- $\pph = (\alpha, \beta, \gamma) = (0.5, 0.0025, 0.3)$.

- $\tta = (1, 1)$.

- $\XX_0 = (71, 79)$.

- `dt = 1`.

- `n_res = 10`.

- `n_obs = 50`.
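
A few quantities implied by Setting 1 can be checked directly. The snippet assumes that each observation interval of length `dt` is split into `n_res` Euler sub-steps, and that the first observation is taken at the initial state (consistent with `y_meas[0]` being close to $\XX_0$ in the simulation output below). It also computes the non-trivial fixed point $(U^*, V^*) = (\gamma/\beta, \alpha/\beta)$ of the drift, obtained by setting both drift components to zero.

```python
import jax.numpy as jnp

# Setting 1 values.
phi = jnp.array([0.5, 0.0025, 0.3])   # (alpha, beta, gamma)
tau = jnp.array([1.0, 1.0])
X0 = jnp.array([71.0, 79.0])
dt, n_res, n_obs = 1.0, 10, 50

# Parameter vector on the log scale, theta = (log phi, log tau).
log_theta = jnp.log(jnp.concatenate([phi, tau]))

# Euler sub-step length, assuming n_res sub-steps per observation interval.
h = dt / n_res

# Assuming the first observation is at t = 0, the last one is at t = (n_obs - 1) * dt.
t_last = (n_obs - 1) * dt

# Non-trivial fixed point of the drift: alpha - beta V = 0 and beta U - gamma = 0.
alpha, beta, gamma = phi
U_star, V_star = gamma / beta, alpha / beta
```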

### Setting 2

**Update:** Ignore this for now.

- Same as above, but $V_t$ is unobserved.  The easiest way to handle this is to do the optimization and Hessian calculation over $\pph$ while fixing $\tta = (1, 100)$.  The first element of $\tta$ is its true value, but the second is so large that the corresponding observations provide essentially no information about $\pph$.

- That is, you simulate data with $\pph = \pph_{\text{true}}$ and $\tta = \tta_{\text{true}}$.  For inference, you optimize and compute the Hessian over $\pph$ only, as in Setting 1, but with the fixed parameters changed to $\tta = (1, 100)$.

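To see why $\tau_2 = 100$ makes the predator observations nearly uninformative, compare how much the measurement log-density of a single predator observation varies as the latent predator count $V$ ranges over plausible values. The observation `y_V = 200` and the range of `V` below are illustrative assumptions. The variation comes only from the term $-(y - V)^2 / (2\tau_2^2)$, so raising $\tau_2$ from 1 to 100 shrinks it by a factor of $100^2 = 10^4$.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

y_V = 200.0                               # illustrative predator observation
V_grid = jnp.linspace(50.0, 500.0, 451)   # illustrative range of latent predator counts


def loglik_range(tau_V):
    """Spread (max - min) of log N(y_V | V, tau_V^2) over V_grid."""
    ll = norm.logpdf(y_V, loc=V_grid, scale=tau_V)
    return ll.max() - ll.min()


spread_1 = loglik_range(1.0)       # tau_V = 1 (true value)
spread_100 = loglik_range(100.0)   # tau_V = 100 (fixed value used for inference)
ratio = spread_100 / spread_1
```

The spread with $\tau_2 = 100$ is small but not zero, so the predator observations are only approximately uninformative.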


```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
from jax import lax
import pfjax as pf
from pfjax import sde as sde

key = random.PRNGKey(0)


class LVMultModel(sde.SDEModel):
    r"""
    Lotka-Volterra model with multiplicative noise, simulated on the log scale.

    Parameters are `theta = (alpha, beta, gamma, tau_U, tau_V)`.  Each state
    `x_curr` has shape `(n_res, 2)`, and each row holds `(log U, log V)`.
    """

    def __init__(self, dt, n_res, bootstrap=False):
        self._dt = dt
        self._n_res = n_res
        self._n_state = (self._n_res, 2)
        self._bootstrap = bootstrap
        super().__init__(dt, n_res, diff_diag=False)

    def _get_params(self, theta):
        # tau has one entry per observed component: theta[3:5] = (tau_U, tau_V).
        return theta[0], theta[1], theta[2], theta[3:5]

    def _get_data(self, x):
        # Split a single state row into its prey (U) and predator (V) components.
        return x[0], x[1]

    def _drift(self, x, theta):
        """Drift on the original scale."""
        alpha, beta, gamma, tau = self._get_params(theta)
        U, V = self._get_data(x)
        aU = alpha * U
        bUV = beta * U * V
        gV = gamma * V
        return jnp.array([aU - bUV, bUV - gV])

    def _diff(self, x, theta):
        """Diffusion (variance) matrix on the original scale."""
        alpha, beta, gamma, tau = self._get_params(theta)
        U, V = self._get_data(x)
        aU = alpha * U
        bUV = beta * U * V
        gV = gamma * V
        return jnp.array([[aU + bUV, -bUV], [-bUV, gV + bUV]])

    def _ito_inv(self, x):
        """
        Inverse Ito transformation.

        In this case the transformation is `log(x)`, so the inverse transformation is `exp(x)`.
        """
        return jnp.exp(x)

    def _ito_dx(self, x):
        """
        Derivative of Ito transformation.

        If the transformation is `log(x)`, the derivative is `1/x`.
        """
        return 1.0 / x

    def _ito_dx2(self, x):
        """
        Second derivative of Ito transformation.

        If the transformation is `log(x)`, the second derivative is `-1/x^2`.
        """
        return -1.0 / (x * x)

    def drift(self, x, theta):
        """
        Drift on the log scale, obtained from Ito's lemma.

        Component `i` is `mu_i / X_i - Sigma_ii / (2 X_i^2)`, where `X = exp(x)`.
        """
        z = self._ito_inv(x)  # z = exp(x) is the state on the original scale
        dx = self._ito_dx(z)
        dx2 = self._ito_dx2(z)
        return dx * self._drift(z, theta) + \
            0.5 * dx2 * jnp.diag(self._diff(z, theta))

    def diff(self, x, theta):
        """
        Diffusion (variance) matrix on the log scale, obtained from Ito's lemma.

        Entry `(i, j)` is `Sigma_ij / (X_i X_j)`, where `X = exp(x)`.
        """
        z = self._ito_inv(x)  # z = exp(x) is the state on the original scale
        dx = self._ito_dx(z)
        return jnp.outer(dx, dx) * self._diff(z, theta)

    def meas_lpdf(self, y_curr, x_curr, theta):
        r"""
        Log-density of `p(y_curr | x_curr, theta)`.

        Each component of `y_curr` is independently normal with mean equal to
        the state at the observation time on the original scale,
        `exp(x_curr[-1])`, and standard deviation `tau`.  This matches the
        observation model `y_n ~ N(exp(Z_n), diag(tau^2))`.

        Args:
            y_curr: Measurement variable at current time `t`.
            x_curr: State variable at current time `t`.
            theta: Parameter value.

        Returns:
            The log-density of `p(y_curr | x_curr, theta)`.
        """
        (alpha, beta, gamma, tau) = self._get_params(theta)
        return jnp.sum(
            jsp.stats.norm.logpdf(y_curr,
                                  loc=jnp.exp(x_curr[-1]), scale=tau)
        )

    def meas_sample(self, key, x_curr, theta):
        r"""
        Sample from `p(y_curr | x_curr, theta)`.

        Args:
            key: PRNG key.
            x_curr: State variable at current time `t`.
            theta: Parameter value.

        Returns:
            Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
        """
        (alpha, beta, gamma, tau) = self._get_params(theta)
        return jnp.exp(x_curr[-1]) + \
            tau * random.normal(key, (self._n_state[1],))

    def pf_init(self, key, y_init, theta):
        r"""
        Importance sampler for `x_init`.

        Draws `X_0` from `N(y_init, diag(tau^2))` truncated to `X_0 > 0` and
        returns `log(X_0)` in the last row of the state.  Assuming a flat prior
        on `X_0 > 0`, this is exactly `p(X_0 | y_init, theta)`, so the sampler
        is "perfect": the log-weight `logw = sum(log Phi(y_init / tau))` is the
        log normalizing constant and does not depend on the sample.

        Args:
            key: PRNG key.
            y_init: Measurement variable at initial time `t = 0`.
            theta: Parameter value.

        Returns:
            - x_init: A sample from the proposal distribution for `x_init`.
            - logw: The log-weight of `x_init`.
        """
        (alpha, beta, gamma, tau) = self._get_params(theta)
        key, subkey = random.split(key)
        # X_0 = y_init + tau * eps, with eps truncated so that X_0 > 0.
        x_init = jnp.log(y_init + tau * random.truncated_normal(
            subkey,
            lower=-y_init/tau,
            upper=jnp.inf,
            shape=(self._n_state[1],)
        ))
        logw = jnp.sum(jsp.stats.norm.logcdf(y_init/tau))
        # Only the last row holds the initial state; earlier rows are filled with zeros.
        return \
            jnp.append(jnp.zeros((self._n_res-1,) + x_init.shape),
                       jnp.expand_dims(x_init, axis=0), axis=0), \
            logw

    def pf_step(self, key, x_prev, y_curr, theta):
        """
        Particle filter proposal for `x_curr`.

        Uses the bootstrap proposal if `bootstrap=True`; otherwise uses a bridge
        proposal that treats `log(y_curr)` as a Gaussian observation of the
        log-state.

        Args:
            key: PRNG key.
            x_prev: State variable at previous time `t-1`.
            y_curr: Measurement variable at current time `t`.
            theta: Parameter value.

        Returns:
            - x_curr: Sample of the state variable at current time `t`: `x_curr ~ q(x_curr)`.
            - logw: The log-weight of `x_curr`.
        """
        if self._bootstrap:
            x_curr, logw = super().pf_step(key, x_prev, y_curr, theta)
        else:
            (alpha, beta, gamma, tau) = self._get_params(theta)
            # Delta method: Var(log y) ~= tau^2 / y^2, using y_curr in place of
            # the unknown mean.  Requires y_curr > 0.
            omega = (tau / y_curr)**2
            x_curr, logw = self.bridge_prop(
                key=key,
                x_prev=x_prev,
                y_curr=y_curr,
                theta=theta,
                Y=jnp.log(y_curr),
                A=jnp.eye(self._n_state[1]),
                Omega=jnp.diag(omega)
            )
        return x_curr, logw
```

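In `pf_step`, the bridge proposal treats `log(y_curr)` as a Gaussian observation of the log-state with variance `omega = (tau / y_curr)**2`. This is the first-order delta-method approximation $\operatorname{Var}(\log y) \approx \tau^2 / \mu^2$ for $y \sim \N(\mu, \tau^2)$, with `y_curr` standing in for the unknown $\mu$. The Monte Carlo check below is a sketch; the values `mu = 100` and `tau = 1` are illustrative.

```python
import jax.numpy as jnp
from jax import random

mu, tau = 100.0, 1.0   # illustrative population size and measurement sd
n = 200_000

y = mu + tau * random.normal(random.PRNGKey(0), (n,))
var_mc = jnp.var(jnp.log(y))     # Monte Carlo variance of log(y)
var_delta = (tau / mu) ** 2      # delta-method approximation used for omega
rel_err = jnp.abs(var_mc - var_delta) / var_delta
```

The approximation degrades when `y_curr` is not large relative to `tau`, and `jnp.log(y_curr)` is undefined for `y_curr <= 0`.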
```python
# Simulate data from Setting 1

# theta = (alpha, beta, gamma, tau_U, tau_V)
theta = jnp.array([.5, .0025, .3, 1., 1.])
# Initial state on the log scale.
x0 = jnp.log(jnp.array([71., 79.]))

dt = 1.
n_res = 10
n_obs = 50

key, subkey = random.split(key)

# The state block has shape (n_res, 2); its last row holds x0 and the other rows are zeros.
x_init = jnp.block([[jnp.zeros((n_res-1, 2))], [x0]])
lv_model = LVMultModel(dt=dt, n_res=n_res)
(y_meas, x_state) = pf.simulate(
    model=lv_model,
    key=subkey,
    n_obs=n_obs,
    x_init=x_init,
    theta=theta
)
y_meas
```

DeviceArray([[ 70.21262487,  79.92433539],
             [ 80.42041459,  80.25574707],
             [111.96799467,  67.56168716],
             [150.11845576,  66.22223218],
             [198.75781923,  74.74067897],
             [276.58517712,  93.53309362],
             [366.87895763, 141.75974266],
             [371.87596578, 293.92770769],
             [222.32580286, 480.23536364],
             [108.23886615, 534.23730638],
             [ 33.82819555, 472.51697638],
             [ 21.50549159, 381.08688888],
             [ 20.72492078, 284.9783271 ],
             [ 25.77922956, 208.36774137],
             [ 28.82206524, 163.42226061],
             [ 36.90370283, 117.95917004],
             [ 55.96794108,  89.51758997],
             [ 60.5158919 ,  76.82175295],
             [ 90.2713489 ,  66.24808305],
             [131.16006656,  59.91892593],
             [192.18155443,  45.27773199],
             [268.66083948,  48.91028414],
             [379.43974795,  75.73259249],
           