# TODO

## Current

### Pranav

- [ ] Create Jupyter Notebook with stochastic optimization results for all settings in Michelle's Lotka-Volterra example.

    - Please include details about the model, inference settings (e.g., true parameters, number of observations, resolution number, what is and isn't observed, etc.), timings, etc.  Also, please include images whenever possible.
    
    - Please focus on content, not so much on presentation, since I'll have to change this for my slides anyway.
    
    - Please check that things still work with a few minor changes I've made to the `LotVolModel` class.  In particular, the last of the `n_res` latent variables per time step now lines up with the observation, whereas it used to be the first.  Also, the prior is now mathematically correct (I'd previously omitted the truncation), but I haven't checked the calculations.

    - Note that in addition to mode finding, you can estimate the Fisher information by taking the Hessian of the stochastic objective function (i.e., the negative log-likelihood) at the mode.  Approximate Bayesian inference can then be conducted by taking the parameters to be multivariate normal with mean at the mode and variance equal to the inverse of the Fisher information.  Please talk to Mohan for details, since he's done this type of Bayesian normal approximation many times in **rodeo**.
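
A minimal sketch of the normal approximation, assuming the objective is a differentiable function of `theta` (for a particle-filter estimate this would mean holding the PRNG key fixed). Here `toy_loglik` is a hypothetical quadratic stand-in for the stochastic log-likelihood, chosen so that its Fisher information is known exactly:

```python
import jax
import jax.numpy as jnp

def toy_loglik(theta):
    # hypothetical stand-in for the log-likelihood: quadratic with mode [1, -2]
    prec = jnp.array([[4.0, 1.0], [1.0, 2.0]])
    diff = theta - jnp.array([1.0, -2.0])
    return -0.5 * diff @ prec @ diff

theta_hat = jnp.array([1.0, -2.0])             # mode, e.g. from stochastic optimization
fisher = -jax.hessian(toy_loglik)(theta_hat)   # observed Fisher information
theta_cov = jnp.linalg.inv(fisher)             # covariance of the normal approximation

# approximate posterior draws: theta ~ N(theta_hat, fisher^{-1})
samples = jax.random.multivariate_normal(
    jax.random.PRNGKey(0), theta_hat, theta_cov, shape=(1000,)
)
```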

### Jonathan

- [ ] Create R Markdown file with **sdetmb** inference results for the gene network model (PGNET).

    - Please include details about the model, inference settings (e.g., true parameters, number of observations, resolution number, what is and isn't observed, etc.), timings, etc.  Also, please include images whenever possible.
    
    - Please focus on content, not so much on presentation, since I'll have to change this for my slides anyway.

    - Please include code to run the inference.  However, since this takes a long time to compute, please have the R Markdown file load results from one or more external files, and send these files to me as well.
    
- [ ] Create Jupyter Notebook with stochastic optimization results using the MVN filter for as many settings as possible in Michelle's Lotka-Volterra example.

    - Please see comments to Pranav above about content vs presentation, and about the Bayesian normal approximation.

    - Please try to make your `particle_filter_mvn()` function jittable and gradable.  I think this means removing as many Python `if` statements on traced values as possible.  Those that remain should be replaced by `lax.cond()` (or `jnp.where()` for elementwise selection).
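
A minimal sketch of the `lax.cond()` pattern, using a hypothetical `maybe_reset_logw()` helper (not part of `particle_filter_mvn()`) that resets the log-weights to uniform when the effective sample size drops below a given fraction of the number of particles:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def maybe_reset_logw(logw, ess_frac):
    # effective sample size from unnormalized log-weights
    w = jnp.exp(logw - jnp.max(logw))
    ess = jnp.sum(w) ** 2 / jnp.sum(w ** 2)
    # a Python `if ess < ...:` fails under jit because `ess` is a tracer;
    # lax.cond chooses the branch at run time instead
    return lax.cond(
        ess < ess_frac * logw.shape[0],
        lambda lw: jnp.zeros_like(lw),  # degenerate: reset to uniform weights
        lambda lw: lw,                  # otherwise keep the weights
        logw,
    )
```

Note that under `vmap` with a batched predicate, `lax.cond()` is converted to a select, so both branches get computed.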
    
    
### Mohan

- [ ] In the `sde.py` module, please add functions `euler_sim_var()` and `euler_lpdf_var()` which mirror `euler_sim_diag()` and `euler_lpdf_diag()` but with a dense diffusion matrix instead of a diagonal one.

- [ ] Use this to create a `PGNETModel` class similar to `LotVolModel`.

    - It is best to do this on the unconstrained scale.  Please talk to Michelle about how she did this.
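
A minimal sketch of Euler-Maruyama simulation and log-density with a dense diffusion matrix, assuming `diff(x, theta)` returns the SPD variance matrix of the increment per unit time (so a Cholesky factor can be used). All function names and signatures here are hypothetical and are not meant to match the existing `euler_sim_diag()` and `euler_lpdf_diag()`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

def euler_step_sim_dense(key, x, dt, drift, diff, theta):
    # one step: x_new ~ N(x + drift(x) dt, diff(x) dt), sampled via Cholesky
    chol = jnp.linalg.cholesky(diff(x, theta))
    z = jax.random.normal(key, x.shape)
    return x + drift(x, theta) * dt + jnp.sqrt(dt) * chol @ z

def euler_step_lpdf_dense(x_new, x, dt, drift, diff, theta):
    # log-density of the same Euler transition
    return multivariate_normal.logpdf(
        x_new, mean=x + drift(x, theta) * dt, cov=diff(x, theta) * dt
    )

def euler_sim_dense(key, x0, dt, n_steps, drift, diff, theta):
    # simulate n_steps Euler steps with lax.scan; returns shape (n_steps, n_dim)
    def step(x, k):
        x_new = euler_step_sim_dense(k, x, dt, drift, diff, theta)
        return x_new, x_new
    _, xs = jax.lax.scan(step, x0, jax.random.split(key, n_steps))
    return xs

def euler_lpdf_dense(xs, x0, dt, drift, diff, theta):
    # sum of transition log-densities along the path x0 -> xs[0] -> ... -> xs[-1]
    x_prev = jnp.concatenate([x0[None], xs[:-1]])
    lp = jax.vmap(
        lambda xn, xp: euler_step_lpdf_dense(xn, xp, dt, drift, diff, theta)
    )(xs, x_prev)
    return jnp.sum(lp)
```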

### Ferris

- [ ] Create a Jupyter Notebook about inference for the QM-GLE model.  

    - It should contain the simplified SDE model (i.e., after diagonalizing the SPD variance matrix), and the modified Euler approximation.
    
    - Start working out the bridge proposal for this Euler approximation.  Please look carefully at what I did in `sde.ipynb` and `bridge_proposal.ipynb`.  The main similarity is that you assume the normal error model `y_t ~ N(A x_t, Omega)`, where `Omega = 0` means exact observations.  The main difference is that in the "formula", you now need `X | W ~ N(Gamma X + mu_{X|W}, Sigma_{X|W})`, whereas in the existing formula we have `Gamma = Identity`.
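
The core of such a bridge proposal is standard Gaussian conditioning. A minimal sketch, assuming a prior `x ~ N(mu, Sigma)` and the normal error model `y ~ N(A x, Omega)` (with `Omega = 0` allowed as long as `A Sigma A'` is invertible). If the `Gamma` term multiplies a quantity known at conditioning time (e.g. a previous state), it would presumably just shift the prior mean `mu`; this is an assumption, not a worked-out derivation. `gaussian_condition()` is a hypothetical helper:

```python
import jax.numpy as jnp

def gaussian_condition(mu, Sigma, A, Omega, y):
    # x ~ N(mu, Sigma), y | x ~ N(A x, Omega)  =>  x | y ~ N(mu_post, Sigma_post)
    S = A @ Sigma @ A.T + Omega             # marginal variance of y
    K = jnp.linalg.solve(S, A @ Sigma).T    # gain: Sigma A' S^{-1}
    mu_post = mu + K @ (y - A @ mu)
    Sigma_post = Sigma - K @ A @ Sigma
    return mu_post, Sigma_post
```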

### Michelle

- [ ] Please send me all code and data files needed to recreate the figures in the section on the Lotka-Volterra example in the manuscript.

### Kanika

- [ ] Create a Python package out of the projection plot codes.

    - You can also focus on automating the axis selection for now.  We'll get back to adding the vectorization option later.
    
### Yunfeng

- [ ] Please make sure that `cppmagi` and `stanmagi` give the same result whenever possible for the Lotka-Volterra model (I recall from last time that there was a difference when the prey was unobserved...)

### Other

- [ ] Example where `n_state` and `n_meas` are not scalar.  This can be an SDE with `n_res > 1`.

    **Update:** Implemented with `LotVolModel` but not really tested...

- [ ] Add CPU parallelism support.

    **Update:** Have started on an `xmap` implementation for this.  The trick will be to split the `n_particles` axis into two axes: the first across (CPU) devices and the second within each device.
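
    A minimal sketch of the two-axis split, using `jax.pmap` over devices and `jax.vmap` within each device in place of `xmap`. The per-particle `propagate()` step is hypothetical, and this only covers the embarrassingly parallel propagation step; resampling across all particles would still need communication between devices. On CPU, multiple host devices can be exposed by setting `XLA_FLAGS=--xla_force_host_platform_device_count=4` before JAX is imported; with a single device the code still runs.

    ```python
    import jax
    import jax.numpy as jnp

    def propagate(key, x, theta):
        # hypothetical per-particle step: Gaussian random walk with scale theta
        return x + theta * jax.random.normal(key, x.shape)

    def propagate_all(key, x_particles, theta):
        # split n_particles into (n_devices, n_per_device)
        n_devices = jax.local_device_count()
        n_particles = x_particles.shape[0]
        assert n_particles % n_devices == 0
        n_per = n_particles // n_devices
        x = x_particles.reshape((n_devices, n_per) + x_particles.shape[1:])
        keys = jax.random.split(key, n_particles)
        keys = keys.reshape((n_devices, n_per) + keys.shape[1:])
        # pmap over the device axis, vmap over particles within each device
        step = jax.pmap(
            jax.vmap(propagate, in_axes=(0, 0, None)), in_axes=(0, 0, None)
        )
        return step(keys, x, theta).reshape(x_particles.shape)
    ```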
    
- [ ] Add support for non-homogeneous state-space models.  So for example instead of `state_sample(x_prev, theta)` we have `state_sample(x_prev, t, theta)`, where `t` is the step index.
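
    A minimal sketch of how the step index could be threaded through `lax.scan`, by scanning over `(key, t)` pairs. The time-varying AR(1) `state_sample()` and the `simulate()` driver are hypothetical, and the extra `key` argument is an assumption needed for JAX's PRNG:

    ```python
    import jax
    import jax.numpy as jnp

    def state_sample(key, x_prev, t, theta):
        # hypothetical non-homogeneous AR(1): the mean depends on the step index t
        mean = theta[0] * x_prev + theta[1] * jnp.sin(t)
        return mean + jax.random.normal(key)

    def simulate(key, x0, theta, n_steps):
        # scan over (key, t) pairs so each step sees its own index t = 1, ..., n_steps
        def step(x_prev, inp):
            k, t = inp
            x = state_sample(k, x_prev, t, theta)
            return x, x
        x0 = jnp.asarray(x0, dtype=jnp.float32)
        keys = jax.random.split(key, n_steps)
        ts = jnp.arange(1, n_steps + 1)
        _, xs = jax.lax.scan(step, x0, (keys, ts))
        return xs
    ```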

## November - December 2021

- [x] Test against true BM likelihood.

    Seem to get reasonable projection plots and a reasonable stochastic optimization.

- [x] Data should be $y_{0:T}$ instead of $y_{1:T}$.  This has been changed in some of the documentation but not all of it...

- [x] Add arbitrary prior specification $p(x_0 \mid \theta)$.  Should also be able to specify $x_0$ directly to calculate the marginal likelihood $\mathcal{L}(x_0, \theta) = p(y_{0:T} \mid x_0, \theta)$.

    This is done with `pf_init()`.  However, for testing purposes in `particle_filter_for()` we break this into `init_logw()` and `init_sample()`.
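
    For reference, one standard way to initialize a particle filter with an arbitrary prior (a sketch, not necessarily what `pf_init()` does) is to draw $x_0^{(i)} \sim q(x_0 \mid y_0, \theta)$ and set

    $$
    \log w_0^{(i)} = \log p(y_0 \mid x_0^{(i)}, \theta) + \log p(x_0^{(i)} \mid \theta) - \log q(x_0^{(i)} \mid y_0, \theta).
    $$

    When $x_0$ is supplied directly, every particle equals $x_0$ and the initial log-weight reduces to $\log p(y_0 \mid x_0, \theta)$.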

- [x] Interface for supplying new models to generic PF code.  Currently we have the following:

    ```python
    import jax
    import pfjax as pf # generic code
    from my_model import MyModel # user-defined model

    # pf estimate of marginal loglikelihood
    def marginal_loglik(model, y_meas, theta, n_particles, key):
        out = pf.particle_filter(model, y_meas, theta, n_particles, key)
        return pf.particle_loglik(out["logw_particles"])
    
    # construct model object (dt, y_meas, theta, n_particles, key defined elsewhere)
    model = MyModel(dt=dt) # set dt as a data member
    marginal_loglik(model, y_meas, theta, n_particles, key)
    
    # jit + grad version: model and n_particles are static, so must be hashable
    marginal_loglik_jgrad = jax.jit(jax.grad(marginal_loglik, argnums=2), static_argnums=(0, 3))
    marginal_loglik_jgrad(model, y_meas, theta, n_particles, key)
    ```
    
    **WARNING:** The following code does not behave as "expected":
    
    ```python
    model.dt = 2 * dt # update value of dt
    marginal_loglik_jgrad(model, y_meas, theta, n_particles, key) # WARNING: uses the original dt...
    ```
    
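    In other words, JAX does not recompile `marginal_loglik_jgrad()` with the updated value of `model`, even though we've flagged it with `static_argnums`.  This is because `jax.jit` caches compilations by the hash and equality of the static arguments, and (assuming `MyModel` doesn't override them) the default `__hash__()` and `__eq__()` are based on object identity, so the mutated `model` still matches the cached entry.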
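
    A minimal sketch of one workaround, assuming the model's only configuration is `dt`: make the model immutable and hashed by value (here a frozen `dataclass`; `ToyModel` and `toy_loglik` are hypothetical), and construct a new instance instead of mutating `dt`.  A Python-side counter shows when JAX retraces:

    ```python
    import dataclasses
    import jax
    import jax.numpy as jnp

    @dataclasses.dataclass(frozen=True)
    class ToyModel:
        # frozen=True: immutable, with value-based __eq__ and __hash__
        dt: float

    trace_count = 0

    def toy_loglik(model, theta):
        global trace_count
        trace_count += 1  # Python side effect: only runs when JAX traces
        return -0.5 * jnp.sum((theta * model.dt) ** 2)

    toy_grad = jax.jit(jax.grad(toy_loglik, argnums=1), static_argnums=0)

    theta = jnp.array([1.0, 2.0])
    g1 = toy_grad(ToyModel(dt=0.1), theta)
    c1 = trace_count
    g2 = toy_grad(ToyModel(dt=0.1), theta)  # equal static arg: cached version reused
    c2 = trace_count
    g3 = toy_grad(ToyModel(dt=0.2), theta)  # new dt: retraced and recompiled
    c3 = trace_count
    ```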
    
- [x] Add arbitrary proposal distribution $q(x_t \mid x_{t-1}, \theta)$.  This should perhaps be done using:

    ```python
    step_sample(x_prev, theta)
    step_logw(x_curr, x_prev, theta)
    ```
    
    instead of e.g., `state_prop_lpdf()`, `state_targ_lpdf()` and `meas_lpdf()`, which get assembled internally for the user.  This is what was done in [**SMCTC**](https://warwick.ac.uk/fac/sci/statistics/staff/academic-research/johansen/smctc/).  But what are the right args to these functions?  It sort of depends on whether they know about the "global" `y_meas`.  I think the right compromise is:
    
    ```python
    step_sample(x_prev, y_curr, theta)
    step_logw(x_curr, x_prev, y_curr, theta)
    ```
    
    In other words, `y_curr` can be used for an optimal filter (e.g., the locally optimal proposal $q(x_t \mid x_{t-1}, y_t, \theta)$), and anything more complicated will require globals.
    
    **Update:** This is done with `pf_step()`, which combines both `step_sample()` and `step_logw()` into one, since there are a lot of shared calculations between the two.
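
    A minimal sketch of a combined step of this kind for a hypothetical linear-Gaussian toy model ($x_t \sim N(a x_{t-1}, \sigma^2)$, $y_t \sim N(x_t, \tau^2)$), using the locally optimal proposal, which needs `y_curr`.  The signature `pf_step(key, x_prev, y_curr, theta)` is an assumption for illustration (the extra `key` is needed for JAX's PRNG), not necessarily the one in the package.  The weight uses the generic formula measurement + transition - proposal, which for this model should equal $\log p(y_t \mid x_{t-1})$:

    ```python
    import jax
    import jax.numpy as jnp
    from jax.scipy.stats import norm

    def pf_step(key, x_prev, y_curr, theta):
        # hypothetical toy model: x_t ~ N(a x_{t-1}, sig^2), y_t ~ N(x_t, tau^2)
        a, sig, tau = theta
        # locally optimal proposal q(x_t | x_{t-1}, y_t), Gaussian for this model
        v = 1.0 / (1.0 / sig**2 + 1.0 / tau**2)
        m = v * (a * x_prev / sig**2 + y_curr / tau**2)
        x_curr = m + jnp.sqrt(v) * jax.random.normal(key, jnp.shape(x_prev))
        # generic importance weight: measurement + state transition - proposal
        logw = (norm.logpdf(y_curr, x_curr, tau)
                + norm.logpdf(x_curr, a * x_prev, sig)
                - norm.logpdf(x_curr, m, jnp.sqrt(v)))
        return x_curr, logw
    ```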
    
- [x] Add proper unit tests.  But of what?  Seems like JAX and NumPy are so similar that checking code from one against the other will hardly be that helpful...

    So, decided to test for-loop vs `vmap`/`xmap`/`lax.scan` etc., always using JAX.  The reasoning is that, apart from these constructs, JAX is identical to NumPy except for its PRNGs.
    
    Also testing OOP vs globals interface and jit + grad.
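
    A minimal sketch of such a for-loop vs `lax.scan` test, on a hypothetical AR(1) `step()`.  Both versions consume the same split keys in the same order, so they should agree up to floating-point rounding:

    ```python
    import jax
    import jax.numpy as jnp

    def step(x, key):
        # hypothetical AR(1) step, shared by both implementations
        x_new = 0.9 * x + jax.random.normal(key)
        return x_new, x_new

    def sim_for(key, x0, n_steps):
        # reference implementation with a Python for-loop
        keys = jax.random.split(key, n_steps)
        xs, x = [], x0
        for t in range(n_steps):
            x, _ = step(x, keys[t])
            xs.append(x)
        return jnp.stack(xs)

    def sim_scan(key, x0, n_steps):
        # same computation with lax.scan (same keys, same order)
        keys = jax.random.split(key, n_steps)
        _, xs = jax.lax.scan(step, x0, keys)
        return xs

    key = jax.random.PRNGKey(0)
    x0 = jnp.float32(0.0)
    assert jnp.allclose(sim_for(key, x0, 10), sim_scan(key, x0, 10), atol=1e-5)
    ```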
