In [None]:
%pip install git+https://github.com/gbdrt/mu-ppl

# Introduction to Sequential Monte Carlo methods (SMC)

SMC methods are an alternative to Markov chain Monte Carlo (MCMC) methods and are particularly well suited for inference on state-space models (SSMs), i.e., models used to reason about time series.

In the following, we focus on a very simple SSM: a Hidden Markov Model (HMM) that tracks the position of a moving agent from noisy observations (e.g., a boat detected by a radar).

The model is as follows. At each (discrete) time step $t$, we assume that:
1. the true position $X_t$ is normally distributed around the previous position $X_{t-1}$, and
2. the current observation $Y_t$ is normally distributed around $X_t$.


$$
\begin{aligned}
X_t &\sim \mathcal{N}(X_{t-1}, 1)\\
Y_t &\sim \mathcal{N}(X_t, 1)
\end{aligned}
$$
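To make the generative story concrete, the model can be simulated forward with plain NumPy, independently of mu-ppl. This is a sketch under one assumption that the text does not state: the walk starts from a fixed position $X_0 = 0$. The helper `simulate_hmm` is hypothetical.

```python
import numpy as np


def simulate_hmm(n_steps: int, x0: float = 0.0, seed: int = 0):
    """Forward-simulate the HMM and return hidden positions and observations.

    Assumption: the walk starts from a fixed position x0 (X_0 = x0).
    """
    rng = np.random.default_rng(seed)
    xs = np.empty(n_steps)
    ys = np.empty(n_steps)
    x_prev = x0
    for t in range(n_steps):
        xs[t] = rng.normal(x_prev, 1.0)  # X_t ~ N(X_{t-1}, 1)
        ys[t] = rng.normal(xs[t], 1.0)   # Y_t ~ N(X_t, 1)
        x_prev = xs[t]
    return xs, ys


xs, ys = simulate_hmm(30)
print(xs[:3].round(2), ys[:3].round(2))
```

Plotting `ys` next to `xs` gives a feel for how noisy the observations are compared to the hidden trajectory.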

## Model in mu-ppl

**Question 1.** Implement this model in mu-ppl.

In [None]:
from mu_ppl import *
from typing import List, Any, Iterator

def hmm(data: List[float]) -> List[float]:
    #TODO
    return []

Consider the following (very unrealistic) data.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0, 6.5, 0.2)
data = np.sin(t)
plt.scatter(t, data)


**Question 2.** Run one of the existing inference algorithms on this synthetic dataset.

In [None]:
#TODO
dist = None

We can use `distribution.split` to turn a distribution over lists into a list of distributions, and then plot the estimate at each step.
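Conceptually, splitting amounts to computing statistics over the samples separately at each time step. The NumPy sketch below assumes the posterior is represented by weighted sample trajectories (one row per sample) with normalized weights; `per_step_stats` is a hypothetical helper, not mu-ppl's implementation of `split` or `stats`.

```python
import numpy as np


def per_step_stats(trajectories: np.ndarray, weights: np.ndarray):
    """Weighted mean and standard deviation at each time step.

    trajectories: shape (num_samples, num_steps), one sampled trajectory per row.
    weights: shape (num_samples,), non-negative and summing to 1.
    """
    means = weights @ trajectories                      # E[X_t] for each t
    variances = weights @ (trajectories - means) ** 2   # E[(X_t - E[X_t])^2]
    return means, np.sqrt(variances)


trajs = np.array([[0.0, 1.0],
                  [2.0, 3.0]])
w = np.array([0.5, 0.5])
means, stds = per_step_stats(trajs, w)
print(means, stds)  # [1. 2.] [1. 1.]
```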

In [None]:
x_dists = split(dist)
x_means, x_stds = np.array([d.stats() for d in x_dists]).T

plt.plot(t, x_means)
plt.fill_between(t, x_means - x_stds, x_means + x_stds, color='blue', alpha=0.2)
plt.scatter(t, data, color='red')

Depending on your choice of inference algorithm (and your compute budget), the results may range from acceptable to terrible.

**Question 3.** Rerun the previous cells using other inference algorithms.

## Particle filtering

Sequential Monte Carlo methods are also called _particle filters_.
They build on importance sampling.
The inference launches a series of independent simulations, called _particles_.
Each particle computes a value and a score that measures the quality of the particle with respect to the model.

Problem: with importance sampling, each particle behaves like a random walk.
Each new sample is drawn at random, regardless of the particle's current score.
In the HMM, by the end of the trajectory, a particle would have to be extremely lucky to have sampled a trajectory that stays relatively close to all the observed data.
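One way to quantify this problem is the effective sample size (ESS) of the weights, $\mathrm{ESS} = (\sum_i w_i)^2 / \sum_i w_i^2$, which ranges from $1$ (a single particle carries all the weight) to $N$ (all weights equal). The NumPy sketch below runs importance sampling with the prior as proposal on the same sine-wave data as above and records the ESS after each observation. It does not use mu-ppl, and it assumes $X_0 = 0$; `importance_sampling_ess` is a hypothetical name.

```python
import numpy as np


def importance_sampling_ess(data, num_particles: int = 1000, seed: int = 0) -> np.ndarray:
    """Plain importance sampling on the HMM (no resampling), tracking the ESS.

    Assumptions: X_0 = 0 and particles are proposed from the prior random walk.
    """
    rng = np.random.default_rng(seed)
    x = np.zeros(num_particles)       # current position of each particle
    log_w = np.zeros(num_particles)   # accumulated log-weight (score) of each particle
    ess = []
    for y in data:
        x = rng.normal(x, 1.0)            # X_t ~ N(X_{t-1}, 1), blind to the score
        log_w += -0.5 * (y - x) ** 2      # log N(y; x, 1) up to an additive constant
        w = np.exp(log_w - log_w.max())   # rescaled weights, avoids underflow
        ess.append(w.sum() ** 2 / (w ** 2).sum())
    return np.array(ess)


ess = importance_sampling_ess(np.sin(np.arange(0, 6.5, 0.2)))
print(ess[0].round(1), ess[-1].round(1))
```

Since each particle's weight is a product of likelihoods over every observation so far, we would expect the ESS to collapse toward $1$ by the end of the trajectory.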

To mitigate this issue, a particle filter periodically _resamples_ the set of particles during execution, duplicating the most relevant particles and discarding the worst ones.

To implement this algorithm, however, we need a _checkpointing_ mechanism to interrupt the execution of the particles.
At each checkpoint, we can:

1. interrupt all the particles, 
2. resample a new set of particles from their scores, 
3. reset the scores, 
4. resume execution until the next checkpoint.
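Python generators offer one way to pause a computation: a generator function suspends at each `yield` and resumes on the next call to `next`. The plain-Python sketch below (with a hypothetical `particle` generator and no probabilistic content) advances several particles in lockstep, one checkpoint at a time. One practical caveat: generator objects cannot be copied with `copy.deepcopy`, so duplicating a particle during resampling needs another representation, such as an object whose state lives in attributes.

```python
import copy


def particle(name: str):
    """Hypothetical particle: does some work, then pauses at each checkpoint."""
    state = 0
    while True:
        state += 1
        yield f"{name}:{state}"  # checkpoint: execution is suspended here


particles = [particle("a"), particle("b")]
for step in range(3):
    # Advance every particle up to its next checkpoint.
    values = [next(p) for p in particles]
    print(step, values)
    # A particle filter would resample here, before resuming.

try:
    copy.deepcopy(particles[0])
except TypeError as err:
    print("cannot copy a generator:", err)
```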

## Checkpoints and Python generators

One way to implement this checkpointing mechanism is to _hardcode_ the notion of time step into the model.
In Python, an SSM is then a class with two methods:
- the constructor `__init__` describes the first time step;
- the `step` method describes the transition to the next step.
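To illustrate the pattern on something simpler than the HMM, here is a hypothetical deterministic state machine, `RunningMean`, which is not a probabilistic model: the constructor initializes the state, and each call to `step` consumes one input and returns one output. The hypothetical `unfold` helper then runs one instance over a list of inputs, step by step.

```python
from typing import List


class RunningMean:
    """Hypothetical deterministic state machine (not a probabilistic model)."""

    def __init__(self) -> None:
        # State before the first step.
        self.count = 0
        self.total = 0.0

    def step(self, y: float) -> float:
        # Update the state with one input and return the output for this step.
        self.count += 1
        self.total += y
        return self.total / self.count


def unfold(inputs: List[float]) -> List[float]:
    model = RunningMean()  # a single instance carries the state across steps
    return [model.step(y) for y in inputs]


print(unfold([1.0, 2.0, 3.0, 6.0]))  # [1.0, 1.5, 2.0, 3.0]
```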

In [None]:
class SSM:
    def __init__(self):
        pass # What happens at the first step

    def step(self, *args, **kwargs) -> Any:
        pass # Transition to the next step

**Question 4.** Re-implement the HMM model as an SSM.

In [None]:
class HMM(SSM):
    def __init__(self):
        #TODO
        pass

    def step(self, y: float) -> float:
        #TODO
        return 0

From this description, we can always unfold the model over $n$ steps to obtain a model similar to the one used in the first part.

In [None]:
def unfold_hmm(data: List[float]) -> List[float]:
    model = HMM()  # a single instance carries the state from one step to the next
    res = []
    for y in data:
        res.append(model.step(y))
    return res

**Question 5.** Run inference again, this time on the unfolded model.

In [None]:
#TODO 
dist = None

In [None]:
x_dists = split(dist)
x_means, x_stds = np.array([d.stats() for d in x_dists]).T

plt.plot(t, x_means)
plt.fill_between(t, x_means - x_stds, x_means + x_stds, color='blue', alpha=0.2)
plt.scatter(t, data, color='red')

## Resampling

We now need a method to resample a set of particles in the middle of the execution.
Each particle corresponds to an instance of the model (written as an `SSM`).
To clone a particle, we can simply copy it (in Python, `deepcopy` copies the entire data structure).
To resample a set of particles:
1. turn the list of (particle, score) pairs into a `Categorical` distribution;
2. sample $n$ new particles from this distribution.
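A NumPy-only sketch of multinomial resampling, assuming the scores are log-weights (check which convention mu-ppl uses for scores before reusing this). It uses `numpy.random.Generator.choice` in place of mu-ppl's `Categorical`, so it illustrates the idea rather than answering the question as stated; `multinomial_resample` and `Counter` are hypothetical names.

```python
from copy import deepcopy
from typing import List, Sequence, TypeVar

import numpy as np

P = TypeVar("P")


def multinomial_resample(particles: List[P], log_scores: Sequence[float],
                         rng: np.random.Generator) -> List[P]:
    """Draw len(particles) new particles with probability proportional to
    exp(score), deep-copying each one so that duplicates evolve independently.

    Assumption: scores are log-weights.
    """
    log_w = np.asarray(log_scores, dtype=float)
    w = np.exp(log_w - log_w.max())  # subtract the max for numerical stability
    probs = w / w.sum()              # normalized categorical probabilities
    idx = rng.choice(len(particles), size=len(particles), p=probs)
    return [deepcopy(particles[i]) for i in idx]


class Counter:
    """Hypothetical stand-in for a particle with mutable state."""

    def __init__(self, label: str) -> None:
        self.label = label
        self.n = 0


rng = np.random.default_rng(0)
old = [Counter("good"), Counter("bad")]
new = multinomial_resample(old, [0.0, -50.0], rng)
print([p.label for p in new])  # ['good', 'good']
new[0].n += 1                  # only the first copy is modified
print(new[0].n, new[1].n)      # 1 0
```

Copying matters here: without `deepcopy`, both selected entries would be the same object, and advancing one particle would silently advance its duplicate.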

**Question 6.** Implement the `resample` method.

In [None]:
from copy import deepcopy

def resample(particles: List[SSM], scores: List[float]) -> List[SSM]:
    #TODO
    return []


We now have everything we need to implement `SMC` inference.

At each time step, the `infer_stream` method:
1. reads one value from the input data;
2. computes the value and score of each particle;
3. returns the current distribution (using the Python `yield` construct);
4. resamples the set of particles before the next step.
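For comparison, the same four steps can be written for this specific HMM without a probabilistic programming language, as a bootstrap particle filter in which each particle is just a float. This NumPy sketch (the name `bootstrap_filter` is hypothetical) assumes $X_0 = 0$, uses log-weights, and yields the filtering mean and standard deviation at each step rather than a `Categorical` distribution.

```python
from typing import Iterator, Sequence, Tuple

import numpy as np


def bootstrap_filter(data: Sequence[float], num_particles: int = 1000,
                     seed: int = 0) -> Iterator[Tuple[float, float]]:
    """Bootstrap particle filter for X_t ~ N(X_{t-1}, 1), Y_t ~ N(X_t, 1).

    Yields the mean and standard deviation of the particle approximation of
    p(X_t | Y_1, ..., Y_t) after each observation. Assumes X_0 = 0.
    """
    rng = np.random.default_rng(seed)
    x = np.zeros(num_particles)  # one position per particle
    for y in data:                                   # 1. read one observation
        x = rng.normal(x, 1.0)                       # 2. propagate each particle...
        log_w = -0.5 * (y - x) ** 2                  #    ...and score it: log N(y; x, 1) + const
        w = np.exp(log_w - log_w.max())
        w /= w.sum()                                 # normalized weights
        mean = float(w @ x)
        std = float(np.sqrt(w @ (x - mean) ** 2))
        yield mean, std                              # 3. report the current estimate
        idx = rng.choice(num_particles, size=num_particles, p=w)
        x = x[idx]                                   # 4. resample; weights restart from equal


obs = np.sin(np.arange(0, 6.5, 0.2))
pf_means, pf_stds = np.array(list(bootstrap_filter(obs))).T
print(pf_means[:3].round(2), pf_stds[:3].round(2))
```

Because every particle has the same weight right after resampling, recomputing `log_w` from the new observation alone plays the role of resetting the scores.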

**Question 7.** Complete the following implementation.

In [None]:
class SMC(ImportanceSampling):
    """
    Sequential Monte-Carlo.

    Model must be expressed as a state machine (SSM).
    Similar to Importance sampling, but particles are resampled after each step.
    """

    def infer_stream(self, ssm: type[SSM], data: List[Any]) -> Iterator[Categorical]:
        particles: List[SSM] = []  #TODO initialise the particles
        for y in data:  # at each step
            values: List[Any] = []
            scores: List[float] = []
            for i in range(self.num_particles):
                #TODO
                pass
            yield Categorical(list(zip(values, scores)))  # return current distribution
            #TODO resample the particles

We can finally test our new inference on our synthesized data.

In [None]:
with SMC(num_particles=1000) as smc:
    x_dists = smc.infer_stream(HMM, data)  # type: ignore 
    x_means, x_stds = np.array([d.stats() for d in x_dists]).T   

In [None]:
plt.plot(t, x_means)
plt.fill_between(t, x_means - x_stds, x_means + x_stds, color='blue', alpha=0.2)
plt.scatter(t, data, color='red')

**Bonus**
- Adapt the model to a 2D trajectory.
- Try a more complex motion model, e.g., $X_{t+1} \sim \mathcal{N}(f(X_{t}), 1)$.