# Example: Sequential Monte Carlo Filtering

This example is from the Pyro documentation:
- https://pyro.ai/examples/smcfilter.html

This notebook demonstrates how to use Pyro's `SMCFilter` (sequential Monte Carlo, i.e. particle filtering) on
a simple model of a noisy harmonic oscillator of the form:

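$$
\begin{aligned}
z_t &\sim \mathcal{N}\left(z_{t-1} A,\; B\,\sigma_z\right) \\
y_t &\sim \mathcal{N}\left(z_{t,0},\; \sigma_y\right)
\end{aligned}
$$

Here $z_t$ is a two-dimensional latent state, multiplied on the right by $A$ (as in `z.matmul(A)`). Only its first component $z_{t,0}$ is observed. The second argument of $\mathcal{N}$ is a standard deviation, applied elementwise for $z_t$.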
    
Copyright (c) 2017-2019 Uber Technologies, Inc.
SPDX-License-Identifier: Apache-2.0

In [1]:
from types import SimpleNamespace

import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SMCFilter

In [2]:
# Arguments - defaults.
# A SimpleNamespace stands in for parsed command-line arguments, so later cells
# can keep reading settings as `args.<name>`. The noise settings are floats so
# that the tensors built from them have a floating-point dtype.
args = SimpleNamespace(
    num_timesteps=500,  # number of timesteps to simulate and filter
    num_particles=100,  # number of SMC particles
    process_noise=1.0,  # multiplier on the latent-state noise scale (sigma_z)
    measurement_noise=1.0,  # observation noise scale (sigma_y)
    seed=0,  # passed to pyro.set_rng_seed
)

In [3]:
class SimpleHarmonicModel:
    """Generative model of a noisy harmonic oscillator, driven step by step by SMCFilter.

    Latent dynamics: ``z_t ~ Normal(z_{t-1} @ A, B * sigma_z)`` with independent
    components; observation: ``y_t ~ Normal(z_t[..., 0], sigma_y)``.

    Args:
        process_noise: Scalar multiplier on the per-component process noise scale ``B``.
            May be a Python number or a tensor.
        measurement_noise: Observation noise scale. May be a Python number or a tensor.
    """

    def __init__(self, process_noise, measurement_noise):
        # Transition matrix, applied on the right of the state (z @ A).
        self.A = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
        # Per-component base scale of the process noise.
        self.B = torch.tensor([3.0, 3.0])
        # Cast to the default float dtype so integer inputs (e.g. 1) do not give
        # integer tensors. torch.as_tensor also accepts existing tensors without
        # the copy warning that torch.tensor() emits for them.
        float_dtype = torch.get_default_dtype()
        self.sigma_z = torch.as_tensor(process_noise, dtype=float_dtype)
        self.sigma_y = torch.as_tensor(measurement_noise, dtype=float_dtype)

    def init(self, state, initial):
        """Reset the step counter and pin the latent state to ``initial``.

        Args:
            state: Mutable dict shared across steps; ``state["z"]`` is set here.
            initial: Initial latent state, sampled from a Delta with one event dimension.
        """
        self.t = 0
        state["z"] = pyro.sample("z_init", dist.Delta(initial, event_dim=1))

    def step(self, state, y=None):
        """Advance the model by one time step.

        Args:
            state: Dict holding the previous latent state under ``"z"``; updated in place.
            y: Observation to condition on, or ``None`` to sample a new one.

        Returns:
            Tuple ``(z, y)``: the new latent state and the observation for this step.
        """
        self.t += 1
        state["z"] = pyro.sample(
            "z_{}".format(self.t),
            dist.Normal(state["z"].matmul(self.A), self.B * self.sigma_z).to_event(1),
        )
        y = pyro.sample(
            "y_{}".format(self.t), dist.Normal(state["z"][..., 0], self.sigma_y), obs=y
        )
        return state["z"], y


class SimpleHarmonicModel_Guide:
    """Proposal (guide) for SimpleHarmonicModel, used by SMCFilter.

    Each new latent state is proposed from
    ``Normal(z_{t-1} @ model.A, proposal_scale)``, reusing the model's transition
    matrix. The observation ``y`` is not used by the proposal.

    Args:
        model: The SimpleHarmonicModel whose transition matrix ``A`` is reused.
        proposal_scale: Standard deviation of the proposal, per component or a
            scalar broadcast over components. Defaults to ``torch.tensor([1.0, 1.0])``.
    """

    def __init__(self, model, proposal_scale=None):
        self.model = model
        if proposal_scale is None:
            proposal_scale = torch.tensor([1.0, 1.0])
        self.proposal_scale = torch.as_tensor(
            proposal_scale, dtype=torch.get_default_dtype()
        )

    def init(self, state, initial):
        """Mirror the model's init: reset the step counter and emit the ``z_init`` Delta site."""
        self.t = 0
        pyro.sample("z_init", dist.Delta(initial, event_dim=1))

    def step(self, state, y=None):
        """Propose the next latent state.

        Args:
            state: Dict whose ``"z"`` entry (written by the model) is the previous
                latent state; it is only read here.
            y: Current observation; accepted for interface compatibility but unused.
        """
        self.t += 1

        # Proposal distribution
        pyro.sample(
            "z_{}".format(self.t),
            dist.Normal(
                state["z"].matmul(self.model.A), self.proposal_scale
            ).to_event(1),
        )


In [4]:
def generate_data(args, initial=None):
    """Simulate a trajectory from SimpleHarmonicModel.

    Args:
        args: Namespace providing ``process_noise``, ``measurement_noise`` and
            ``num_timesteps``.
        initial: Initial latent state. Defaults to ``torch.tensor([1.0, 0.0])``.

    Returns:
        ``(zs, ys)``: two lists of length ``args.num_timesteps + 1``.
        ``zs[0]`` is the initial state and ``ys[0]`` is ``None`` (there is no
        observation at time 0). Entry ``t`` holds the latent state and the
        observation produced by step ``t``.
    """
    if initial is None:
        initial = torch.tensor([1.0, 0.0])
    model = SimpleHarmonicModel(args.process_noise, args.measurement_noise)

    state = {}
    model.init(state, initial=initial)
    zs = [initial]
    ys = [None]  # placeholder so ys[t] lines up with zs[t]
    for _ in range(args.num_timesteps):
        z, y = model.step(state)
        zs.append(z)
        ys.append(y)

    return zs, ys

In [7]:
# Seed the RNGs first so that the simulated data and the particle filter are reproducible.
pyro.set_rng_seed(args.seed)

# Filtering model, plus the proposal (guide) that SMCFilter draws particles from.
model = SimpleHarmonicModel(args.process_noise, args.measurement_noise)
guide = SimpleHarmonicModel_Guide(model)
smc = SMCFilter(model, guide, num_particles=args.num_particles, max_plate_nesting=0)

print("Generating data")
zs, ys = generate_data(args)

print("Filtering")
# Start the filter from the simulator's own initial state (zs[0]).
# ys[0] is a placeholder, so the real observations begin at index 1.
smc.init(initial=zs[0])
for y in ys[1:]:
    smc.step(y)

# Particle approximation of the latent state after the final step.
z = smc.get_empirical()["z"]

print("At final time step:")
final_summary = (("truth", zs[-1]), ("mean", z.mean), ("std", z.variance**0.5))
for label, value in final_summary:
    print(f"{label}: {value}")

Generating data
Filtering
At final time step:
truth: tensor([ 0.0809, 67.0811])
mean: tensor([-1.1345, 65.9661])
std: tensor([0.8514, 1.6066])
