# A Quick Tour of PFJAX

**Martin Lysy, University of Waterloo**

**August 3, 2022**

## Summary

The goal of **PFJAX** is to provide tools for estimating the parameters $\tth$ of a [state-space model](http://www.scholarpedia.org/article/State_space_model)

$$
\begin{aligned}
\xx_0 & \sim \pi(\xx_0 \mid \tth) \\
\xx_t & \sim f(\xx_t \mid \xx_{t-1}, \tth) \\
\yy_t & \sim g(\yy_t \mid \xx_t, \tth).
\end{aligned}
$$

In such models, only the *measurement variables* $\yy_{0:T} = (\yy_0, \ldots, \yy_T)$ are observed, whereas the *state variables* $\xx_{0:T}$ are latent.  The marginal likelihood given the observed data is

$$
\begin{aligned}
\Ell(\tth) & = p(\yy_{0:T} \mid \tth) \\
& = \int \prod_{t=0}^T g(\yy_t \mid \xx_t, \tth) \times \prod_{t=1}^T f(\xx_t \mid \xx_{t-1}, \tth) \times \pi(\xx_0 \mid \tth)\, \ud \xx_{0:T},   
\end{aligned}
$$

but this integral is typically intractable.  The state-of-the-art approach to approximating it is [particle filtering](https://warwick.ac.uk/fac/sci/statistics/staff/academic-research/johansen/publications/dj11.pdf).  **PFJAX** provides several particle filters to estimate the marginal loglikelihood $\ell(\tth) = \log \Ell(\tth)$, along with its gradient $\nabla \ell(\tth) = \frac{\partial}{\partial \tth} \ell(\tth)$ and Hessian $\nabla^2 \ell(\tth) = \frac{\partial^2}{\partial \tth \partial \tth'} \ell(\tth)$.  To do this efficiently, **PFJAX** uses JIT-compilation and automatic differentiation as provided by the [**JAX**](https://github.com/google/jax) library.
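To make the idea of particle filtering concrete, here is a minimal sketch of a bootstrap particle filter written directly in JAX.  It is not the **PFJAX** API: the function `bootstrap_loglik` and its three callbacks (`init_sample`, `step_sample`, `meas_lpdf`) are hypothetical names introduced only for illustration, and the sketch assumes multinomial resampling at every time step.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def bootstrap_loglik(key, y_meas, theta, n_particles,
                     init_sample, step_sample, meas_lpdf):
    """Bootstrap particle filter estimate of log p(y_{0:T} | theta).

    The callbacks are hypothetical and act on a whole batch of particles:
      init_sample(key, theta, n)      -> draws of x_0, leading dimension n
      step_sample(key, x_prev, theta) -> draws of x_t given x_{t-1}
      meas_lpdf(y_t, x_t, theta)      -> log g(y_t | x_t, theta), shape (n,)
    """
    key, subkey = jax.random.split(key)
    x = init_sample(subkey, theta, n_particles)
    logw = meas_lpdf(y_meas[0], x, theta)
    # The log of the average weight estimates log p(y_0 | theta).
    loglik = logsumexp(logw) - jnp.log(n_particles)

    def step(carry, y_t):
        key, x, logw, loglik = carry
        key, key_res, key_prop = jax.random.split(key, 3)
        # Multinomial resampling with probabilities proportional to the weights.
        idx = jax.random.categorical(key_res, logw, shape=(n_particles,))
        # Propagate the resampled particles through the state transition.
        x = step_sample(key_prop, x[idx], theta)
        logw = meas_lpdf(y_t, x, theta)
        # Accumulate the estimate of log p(y_t | y_{0:t-1}, theta).
        loglik = loglik + logsumexp(logw) - jnp.log(n_particles)
        return (key, x, logw, loglik), None

    init = (key, x, logw, loglik)
    (_, _, _, loglik), _ = jax.lax.scan(step, init, y_meas[1:])
    return loglik
```

The returned value is a random estimate of $\ell(\tth)$, so repeated calls with different keys give different answers; the variability shrinks as `n_particles` grows.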

In this tutorial, we'll show how to use **PFJAX** for:

- Estimating the marginal loglikelihood $\ell(\tth)$.

- Approximating the maximum likelihood estimator $\hat \tth = \argmax_{\tth} \ell(\tth)$ and its variance $\var(\hat \tth)$ via stochastic optimization.

- Calculating the posterior distribution $p(\tth \mid \yy_{0:T}) \propto \Ell(\tth) \times \pi(\tth)$ via Markov chain Monte Carlo (MCMC) sampling.

## Example Model: Brownian Motion with Drift

The model is

$$
\begin{aligned}
x_0 & \sim \N(0, \sigma^2 \dt) \\
x_t & \sim \N(x_{t-1} + \mu \dt, \sigma^2 \dt) \\
y_t & \sim \N(x_t, \tau^2).
\end{aligned}
$$
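As a quick illustration, the model above can be simulated in a few lines of JAX.  This is a sketch rather than **PFJAX** code: the function name `bm_simulate` is hypothetical, and the parameter values, time step, and number of observations are arbitrary choices made for the example.

```python
import jax
import jax.numpy as jnp


def bm_simulate(key, theta, dt, n_obs):
    """Simulate (x_{0:T}, y_{0:T}) from the BM model, with T = n_obs - 1."""
    mu, sigma, tau = theta
    key_x, key_y = jax.random.split(key)
    # Standard normal draws: the first gives x_0, the rest give the increments.
    z = jax.random.normal(key_x, (n_obs,))
    x0 = sigma * jnp.sqrt(dt) * z[0]
    incr = mu * dt + sigma * jnp.sqrt(dt) * z[1:]
    x_state = jnp.concatenate([x0[None], x0 + jnp.cumsum(incr)])
    # Measurements add independent N(0, tau^2) noise to each state.
    y_meas = x_state + tau * jax.random.normal(key_y, (n_obs,))
    return x_state, y_meas


theta_true = jnp.array([1.0, 0.5, 0.8])  # (mu, sigma, tau), chosen arbitrarily
x_state, y_meas = bm_simulate(jax.random.PRNGKey(0), theta_true, dt=0.1, n_obs=100)
```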

The parameters of the model are $\tth = (\mu, \sigma, \tau)$.  Since $(\xx_{0:T}, \yy_{0:T})$ are jointly normal given $\tth$, we can show (see [here](#bm_deriv)) that $\yy_{0:T}$ is multivariate normal with mean and covariance 

$$
E[y_t \mid \tth] = \tilde \mu t, \qquad \cov(y_s, y_t) = \tilde \sigma^2 \cdot (1 + \min(s, t)) + \tau^2 \delta_{st},
$$

where $\tilde \mu = \mu \dt$, $\tilde \sigma^2 = \sigma^2 \dt$, and $\delta_{st}$ is the Kronecker delta (equal to 1 if $s = t$ and 0 otherwise).  These formulas provide an analytic expression for $\Ell(\tth) = p(\yy_{0:T} \mid \tth)$, which we can use to benchmark our calculations.
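The mean and covariance formulas translate directly into an exact loglikelihood evaluation.  The sketch below uses only **JAX** itself; `bm_exact_loglik` is a hypothetical helper name, and the data and parameter values are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def bm_exact_loglik(theta, y_meas, dt):
    """Exact log p(y_{0:T} | theta) for the BM model."""
    mu, sigma, tau = theta
    n_obs = y_meas.shape[0]
    t = jnp.arange(n_obs)
    # E[y_t] = mu * dt * t
    mean = mu * dt * t
    # cov(y_s, y_t) = sigma^2 dt (1 + min(s, t)) + tau^2 delta_{st}
    cov = sigma**2 * dt * (1.0 + jnp.minimum(t[:, None], t[None, :]))
    cov = cov + tau**2 * jnp.eye(n_obs)
    return multivariate_normal.logpdf(y_meas, mean, cov)


y_meas = jnp.array([0.3, 1.1, 0.7])  # made-up data for illustration
theta = jnp.array([0.5, 1.0, 0.8])   # (mu, sigma, tau), chosen arbitrarily
loglik = bm_exact_loglik(theta, y_meas, dt=1.0)
# Automatic differentiation gives the exact gradient with respect to theta.
grad = jax.grad(bm_exact_loglik)(theta, y_meas, 1.0)
```

Because the function is differentiable, the same benchmark is available for gradient and Hessian estimates, not just for the loglikelihood itself.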

<a id="bm_deriv"></a>
## Appendix: Exact Likelihood of the BM Model

The joint distribution $p(\xx_{0:T}, \yy_{0:T} \mid \tth)$ is multivariate normal.  Thus, the marginal distribution $p(\yy_{0:T} \mid \tth)$ is also multivariate normal, and we only need to find $E[y_t \mid \tth]$ and $\cov(y_s, y_t \mid \tth)$.

Conditioned on $x_0$ and $\tth$, the Brownian latent variables $\xx_{1:T}$ are multivariate normal with

$$
\newcommand{\cov}{\operatorname{cov}}
\begin{aligned}
E[x_t \mid x_0, \tth] & = x_0 + \tilde \mu t, \\
\cov(x_s, x_t \mid x_0, \tth) & = \tilde \sigma^2 \min(s, t),
\end{aligned}
$$

where $\tilde \mu = \mu \dt$ and $\tilde \sigma^2 = \sigma^2 \dt$.

Therefore, $p(\xx_{0:T} \mid \tth)$ is multivariate normal with

$$
\begin{aligned}
E[x_t \mid \tth] & = E[E[x_t \mid x_0, \tth]] \\
& = E[x_0 \mid \tth] + \tilde \mu t \\
& = \tilde \mu t, \\
\cov(x_s, x_t \mid \tth) & = \cov(E[x_s \mid x_0, \tth], E[x_t \mid x_0, \tth]) + E[\cov(x_s, x_t \mid x_0, \tth)] \\
& = \var(x_0 \mid \tth) + \tilde \sigma^2 \min(s, t) \\
& = \tilde \sigma^2 (1 + \min(s, t)).
\end{aligned}
$$

Similarly, conditioned on $\xx_{0:T}$ and $\tth$, the measurement variables $\yy_{0:T}$ are multivariate normal with

$$
\begin{aligned}
E[y_t \mid \xx_{0:T}, \tth] & = x_t, \\
\cov(y_s, y_t \mid \xx_{0:T}, \tth) & = \tau^2 \delta_{st}.
\end{aligned}
$$

Therefore, $p(\yy_{0:T} \mid \tth)$ is multivariate normal with

$$
\begin{aligned}
E[y_t \mid \tth] & = \tilde \mu t, \\
\cov(y_s, y_t \mid \tth) & = \cov(E[y_s \mid \xx_{0:T}, \tth], E[y_t \mid \xx_{0:T}, \tth]) + E[\cov(y_s, y_t \mid \xx_{0:T}, \tth)] \\
& = \cov(x_s, x_t \mid \tth) + \tau^2 \delta_{st} \\
& = \tilde \sigma^2 (1 + \min(s, t)) + \tau^2 \delta_{st}.
\end{aligned}
$$
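The derivation can also be checked numerically.  The sketch below simulates many independent copies of $\yy_{0:T}$ directly from the model and compares their empirical mean and covariance to the formulas above.  The parameter values, time step, and sample sizes are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

mu, sigma, tau, dt = 0.5, 1.0, 0.8, 0.25  # arbitrary illustrative values
n_obs, n_rep = 4, 200_000
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))

# Each row is one independent draw of y_{0:T}, simulated from the model.
z = jax.random.normal(key_x, (n_rep, n_obs))
x0 = sigma * jnp.sqrt(dt) * z[:, :1]
incr = mu * dt + sigma * jnp.sqrt(dt) * z[:, 1:]
x = jnp.concatenate([x0, x0 + jnp.cumsum(incr, axis=1)], axis=1)
y = x + tau * jax.random.normal(key_y, (n_rep, n_obs))

# Mean and covariance from the derivation above.
t = jnp.arange(n_obs)
mean_exact = mu * dt * t
cov_exact = (sigma**2 * dt * (1 + jnp.minimum(t[:, None], t[None, :]))
             + tau**2 * jnp.eye(n_obs))

# Largest absolute discrepancies between empirical and exact moments.
mean_err = jnp.max(jnp.abs(y.mean(axis=0) - mean_exact))
cov_err = jnp.max(jnp.abs(jnp.cov(y, rowvar=False) - cov_exact))
```

Both discrepancies should be small and should shrink at the usual Monte Carlo rate as `n_rep` increases.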