### Get started

- [x] Install PyTorch
- [x] Install Pyro (the Kalman filter example below asserts Pyro 0.3.0):
> pip install pyro-ppl==0.3.0
- Intro to Models: http://pyro.ai/examples/intro_part_i.html
- Kalman Filter with Pyro: http://pyro.ai/examples/ekf.html

In [7]:
import torch
import pyro

In [8]:
## Sample distribution: draw one value from a standard normal N(0, 1) and score it.
# torch's Normal is parameterized by (mean, standard deviation); with a std of 1
# the variance is 1 as well.
loc, scale = 0.0, 1.0
normal = torch.distributions.Normal(loc=loc, scale=scale)
# rsample() is the reparameterized (differentiable) sampler.
x = normal.rsample()
print("sample", x)
# Log-density of the drawn value under N(loc, scale).
print("log prob", normal.log_prob(x))

sample tensor(0.7935)
log prob tensor(-1.2337)


In [9]:
## Simple weather model
def weather(p_cloudy=0.3):
    """Sample one day of weather from a two-stage generative model.

    First draws whether the day is cloudy (Bernoulli with probability
    ``p_cloudy``). It then draws a temperature from a normal distribution
    whose (mean, std) depend on that outcome: cloudy -> (55, 10),
    sunny -> (75, 15).

    Args:
        p_cloudy: probability that the day is cloudy. Defaults to 0.3, the
            value that was previously hard-coded.

    Returns:
        tuple (str, float): ``('cloudy' or 'sunny', sampled temperature)``.
    """
    # Per-condition (mean, std) of the temperature distribution, kept in one
    # table so the two parameters cannot drift apart.
    temp_params = {'cloudy': (55.0, 10.0), 'sunny': (75.0, 15.0)}
    # Bernoulli.sample() returns a float tensor holding 0. or 1.
    is_cloudy = torch.distributions.Bernoulli(p_cloudy).sample()
    # Separate name for the label rather than rebinding the tensor variable.
    cloudy = 'cloudy' if is_cloudy.item() == 1.0 else 'sunny'
    mean_temp, scale_temp = temp_params[cloudy]
    temp = torch.distributions.Normal(mean_temp, scale_temp).rsample()
    return cloudy, temp.item()

In [13]:
# Scratch check: a single Bernoulli(0.3) draw is a float tensor holding 0. or 1.
c = torch.distributions.Bernoulli(probs=0.3).sample()

In [17]:
# Draw one (condition, temperature) sample from the weather model. The returned
# tuple is the cell's displayed output; no seed is set beforehand, so it varies
# between runs.
weather()

('sunny', 92.30638885498047)

In [18]:
# Scratch check of the label mapping in weather(): Python treats 1 == 1.0 as True.
cloudy = 'sunny' if 1 != 1.0 else 'cloudy'

In [21]:
# Scratch check of the mean-temperature lookup used in weather().
mean_temp = dict(cloudy=55.0, sunny=75.0)['sunny']

In [22]:
# Display the looked-up mean temperature for 'sunny'.
mean_temp

75.0

## Example: Kalman Filter
- Ref: http://pyro.ai/examples/ekf.html

In [2]:
import os
import math

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, config_enumerate
from pyro.contrib.tracking.extended_kalman_filter import EKFState
from pyro.contrib.tracking.distributions import EKFDistribution
from pyro.contrib.tracking.dynamic_models import NcvContinuous
from pyro.contrib.tracking.measurements import PositionMeasurement

# True when running under CI; the training loop below then runs only 2 steps.
smoke_test = ('CI' in os.environ)
# Fail fast with a readable message on other Pyro versions.
# NOTE(review): presumably pinned because imports above (e.g.
# pyro.contrib.autoguide) were relocated in later Pyro releases -- confirm
# before bumping the version.
assert pyro.__version__.startswith('0.3.0'), \
    'this notebook requires pyro-ppl 0.3.0 (found {})'.format(pyro.__version__)
# Turn on Pyro's validation checks (e.g. of distribution arguments).
pyro.enable_validation(True)

In [3]:
# --- Simulation settings ---
dt = 1e-2         # time between consecutive frames
num_frames = 10   # length of the simulated track
dim = 4           # state dimension handed to the dynamic model

# Continuous-time dynamic model from pyro.contrib.tracking.
# NOTE(review): 2.0 is presumably its process-noise parameter -- confirm in the docs.
ncv = NcvContinuous(dim, 2.0)

# Ground-truth trajectory: one row per frame, filled in below.
xs_truth = torch.zeros(num_frames, dim)
# Heading angle (radians) of the initial direction.
theta0_truth = 0.0
with torch.no_grad():
    # Initial state: zeros in the first two slots, a unit vector along
    # theta0_truth in the last two. Presumably laid out as [x, y, vx, vy] --
    # confirm against NcvContinuous.
    xs_truth[0] = torch.tensor([0.0, 0.0, math.cos(theta0_truth), math.sin(theta0_truth)])
    for frame_num in range(1, num_frames):
        # Independent process noise for this step, under a unique sample-site name.
        dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))
        # Propagate the previous state through the dynamics, then perturb it.
        xs_truth[frame_num] = ncv(xs_truth[frame_num - 1], dt=dt) + dx

In [6]:
# Measurements
measurements = []
mean = torch.zeros(2)
# no correlations
cov = 1e-5 * torch.eye(2)
with torch.no_grad():
    # sample independent measurement noise
    dzs = pyro.sample('dzs', dist.MultivariateNormal(mean, cov).expand((num_frames,)))
    # compute measurement means
    zs = xs_truth[:, :2] + dzs

In [7]:
def model(data):
    """Pyro model used to MAP-estimate the EKF noise scales from measurements.

    Draws two scalar scales, turns them into diagonal covariance matrices,
    and scores ``data`` under an ``EKFDistribution`` started from the true
    initial state ``xs_truth[0]`` and driven by ``ncv``.

    Args:
        data: observed measurements, one row per frame (``zs`` in this notebook).

    Relies on globals from earlier cells: ``xs_truth``, ``ncv``, ``num_frames``
    and ``dim``.
    """
    # a HalfNormal can be used here as well
    # 'pv_cov' scales the dim x dim covariance passed as EKFDistribution's second
    # argument. NOTE(review): presumably the initial state covariance -- confirm.
    R = pyro.sample('pv_cov', dist.HalfCauchy(2e-6)) * torch.eye(dim)
    # 'measurement_cov' scales the 2 x 2 measurement-noise covariance.
    Q = pyro.sample('measurement_cov', dist.HalfCauchy(1e-6)) * torch.eye(2)
    # Observe the measurements under a fixed site name. The former name,
    # 'track_{}'.format(i), read the global training-loop counter `i`: the site
    # was renamed on every SVI step, and calling the model on a fresh kernel
    # before the loop raised NameError.
    pyro.sample('track', EKFDistribution(xs_truth[0], R, ncv, Q, time_steps=num_frames),
                obs=data)

guide = AutoDelta(model)  # MAP estimation

In [10]:
# MAP fit of the noise scales with stochastic variational inference.
optim = Adam({'lr': 2e-2})
svi = SVI(model, guide, optim, loss=Trace_ELBO(retain_graph=True))

# Seed Pyro's RNGs and drop previously learned parameters before training.
pyro.set_rng_seed(0)
pyro.clear_param_store()

num_steps = 2 if smoke_test else 250  # keep CI runs short
for i in range(num_steps):
    loss = svi.step(zs)
    # Report progress every 10 steps.
    if i % 10 == 0:
        print('loss: ', loss)

loss:  -10.6841459274292
loss:  -11.320069313049316
loss:  -11.894853591918945
loss:  -12.406867980957031
loss:  -12.855977058410645
loss:  -13.241031646728516
loss:  -13.559944152832031
loss:  -13.811338424682617
loss:  -13.99674129486084
loss:  -14.122206687927246
loss:  -14.198473930358887
loss:  -14.239151000976562
loss:  -14.257674217224121
loss:  -14.264633178710938
loss:  -14.2666654586792
loss:  -14.267088890075684
loss:  -14.267152786254883
loss:  -14.26719856262207
loss:  -14.267251014709473
loss:  -14.267306327819824
loss:  -14.267353057861328
loss:  -14.267391204833984
loss:  -14.267426490783691
loss:  -14.267455101013184
loss:  -14.267480850219727


In [11]:
# Retrieve the MAP noise scales and re-run the filter to get per-frame states
# for visualization. The AutoDelta guide returns point estimates, so it is
# called once and its output reused.
map_estimates = guide()
R = map_estimates['pv_cov'] * torch.eye(dim)
Q = map_estimates['measurement_cov'] * torch.eye(2)
ekf_dist = EKFDistribution(xs_truth[0], R, ncv, Q, time_steps=num_frames)
# NOTE(review): presumably one filtered EKF state per frame -- confirm in the docs.
states = ekf_dist.filter_states(zs)