In [1]:
from google.colab import drive
drive.mount('/content/drive')
%cd drive/MyDrive/python/BayesianFiltering

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/python/BayesianFiltering


In [2]:
%%capture
!pip install jaxtyping
!pip install dynamax

In [3]:
from jax import numpy as jnp
from jax import jacfwd, jacrev, jit, vmap, lax, make_jaxpr
from jax import random as jr
from jax import tree_util as jtu
import jax
from functools import partial
import tensorflow

import gaussfiltax.utils as utils
import gaussfiltax.containers as containers
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
import time
import gaussfiltax.inference as gf
import gaussfiltax.particlefilt as pf
from gaussfiltax.models import ParamsNLSSM, NonlinearGaussianSSM, NonlinearSSM, ParamsBPF

import matplotlib.pyplot as plt
import matplotlib_inline
from IPython.display import set_matplotlib_formats
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

## Models and simulations

In [4]:
# Parameters
state_dim = 3
state_noise_dim = 3
emission_dim = 1
emission_noise_dim = 1
seq_length = 100
mu0 = jnp.zeros(state_dim)
Sigma0 = 1.0 * jnp.eye(state_dim)
Q = 1.0 * jnp.eye(state_noise_dim)
R = 1.0 * jnp.eye(emission_noise_dim)

# ICASSP
f1 = lambda x, q, u: (1-u) * x / 2.  + u * jnp.sin(10 * x) + q
g1 = lambda x, r, u:  0.01 * jnp.dot(x, x) + r
def g1lp(x,y,u):
    return MVN(loc = g1(x, 0.0, u), covariance_matrix = R).log_prob(y)

g2 = lambda x, r, u:  1.0 * jnp.sqrt(jnp.dot(x, x)) + r
def g2lp(x,y,u):
    return MVN(loc = g2(x, 0.0, u), covariance_matrix = R).log_prob(y)


# Lorenz 63
def lorentz_63(x, sigma=10, rho=28, beta=2.667, dt=0.01):
    dx = dt * sigma * (x[1] - x[0])
    dy = dt * (x[0] * rho - x[1] - x[0] *x[2]) 
    dz = dt * (x[0] * x[1] - beta * x[2])
    return jnp.array([dx+x[0], dy+x[1], dz+x[2]])
f63 = lambda x, q, u: lorentz_63(x) + q

# Lorentz 96
alpha = 1.0
beta = 1.0
gamma = 8.0
dt = 0.01
H = jnp.zeros((emission_dim,state_dim))
for row in range(emission_dim):
    col = 2*row
    H = H.at[row,col].set(1.0)
CP = lambda n: jnp.block([[jnp.zeros((1,n-1)), 1.0 ],[jnp.eye(n-1), jnp.zeros((n-1,1))]])
A = CP(state_dim)
B = jnp.power(A, state_dim-1) - jnp.power(A, 2)
f96 = lambda x, q, u: x + dt * (alpha * jnp.multiply(A @ x, B @ x) - beta * x + gamma * jnp.ones(state_dim)) + q
g96 = lambda x, r, u: H @ x + r
def g96lp(x,y,u):
  return MVN(loc = g96(x, 0.0, u), covariance_matrix = R).log_prob(y)


# stochastic growth model
f3 = lambda x, q, u: x / 2. + 100. * x / (1 + jnp.power(x, 2)) * u + q
g3 = lambda x, r, u: 0.8 * x + r
def g3lp(x,y,u):
    return MVN(loc = g3(x, 0.0, u), covariance_matrix = R).log_prob(y)


# Stochastic Volatility
alpha = 0.91
sigma = 1.0
beta = 0.5
f_sv = lambda x, q, u: alpha * x + sigma * q
g_sv = lambda x, r, u: beta * jnp.exp(x/2) * r
def svlp(x,y,u):
    return MVN(loc = g_sv(x, 0.0, u), covariance_matrix = g_sv(x, 1.0, u)**2 * R).log_prob(y)

# Multivariate SV
Phi = 0.8 * jnp.eye(state_dim)
f_msv = lambda x, q, u: Phi @ x +  q
g_msv = lambda x, r, u: 0.5 * jnp.multiply(jnp.exp(x/2), r)
def msvlp(x,y,u):
    return MVN(loc = g_msv(x, 0.0, u), covariance_matrix = jnp.diag(jnp.exp(x/2.0)) @ R @ jnp.diag(jnp.exp(x/2.0))).log_prob(y)


# Inputs
# inputs = 1. * jnp.cos(0.1 * jnp.arange(seq_length))
sm = lambda x : jnp.exp(x) / (1+jnp.exp(x))
inputs = sm(jnp.arange(seq_length)-50) # off - on
# inputs = 1.0 * jnp.ones(seq_length) # on - on

In [5]:
f = f63
g = g1
glp = g1lp

In [6]:
# initialization
model = NonlinearSSM(state_dim, state_noise_dim, emission_dim, emission_noise_dim)
params = ParamsNLSSM(
    initial_mean=mu0,
    initial_covariance=Sigma0,
    dynamics_function=f,
    dynamics_noise_bias=jnp.zeros(state_noise_dim),
    dynamics_noise_covariance=Q,
    emission_function=g,
    emission_noise_bias=jnp.zeros(emission_noise_dim),
    emission_noise_covariance=R,
)

## Experiments

In [7]:
Nsim = 100
gsf_rmse = jnp.zeros(Nsim)
gsf_time = jnp.zeros(Nsim)
agsf_rmse = jnp.zeros(Nsim)
agsf_time = jnp.zeros(Nsim)
bpf_rmse = jnp.zeros(Nsim)
bpf_time = jnp.zeros(Nsim)
next_key = jr.PRNGKey(10123412)
for i in range(Nsim):
    print('sim {}/{}'.format(i+1, Nsim))
    # Generate Data
    key, next_key = jr.split(next_key)
    states, emissions = model.sample(params, key, seq_length, inputs = inputs)
    

    # GSF
    M = 3
    tin = time.time()
    posterior_filtered_gsf = gf.gaussian_sum_filter(params, emissions, M, 1, inputs)
    point_estimate_gsf = jnp.sum(jnp.einsum('ijk,ij->ijk', posterior_filtered_gsf.means, posterior_filtered_gsf.weights), axis=0)
    time_gsf = time.time() - tin
    print('       Time taken for GSF: ', time_gsf)

    # AGSF
    tin = time.time()
    num_components = [M, 2, 2] # has to be set correctly OW "TypeError: Cannot interpret '<function <lambda> at 0x12eae3ee0>' as a data type". Check internal containers._branch_from_node
    posterior_filtered_agsf, aux_outputs = gf.augmented_gaussian_sum_filter(params, emissions, num_components, rng_key = key, opt_args = (1.0, 1.0), inputs=inputs)    
    point_estimate_agsf = jnp.sum(jnp.einsum('ijk,ij->ijk', posterior_filtered_agsf.means, posterior_filtered_agsf.weights), axis=0)
    time_agsf = time.time() - tin
    print('       Time taken for AGSF: ', time_agsf)

    # BPF
    tin = time.time()
    num_particles = 100

    params_bpf = ParamsBPF(
        initial_mean=mu0,
        initial_covariance=Sigma0,
        dynamics_function=f,
        dynamics_noise_bias=jnp.zeros(state_noise_dim),
        dynamics_noise_covariance=Q,
        emission_function=g,
        emission_noise_bias=jnp.zeros(emission_noise_dim),
        emission_noise_covariance=R,
        emission_distribution_log_prob = glp
    )

    posterior_bpf = gf.bootstrap_particle_filter(params_bpf, emissions, num_particles, key, inputs)
    point_estimate_bpf = jnp.sum(jnp.einsum('ijk,ij->ijk', posterior_bpf["particles"], posterior_bpf["weights"]), axis=0)
    time_bpf = time.time()- tin
    print('       Time taken for BPF: ', time_bpf)

    # Computation of errors
    gsf_rmse = gsf_rmse.at[i].set(utils.rmse(point_estimate_gsf, states))
    agsf_rmse = agsf_rmse.at[i].set(utils.rmse(point_estimate_agsf, states))
    bpf_rmse = bpf_rmse.at[i].set(utils.rmse(point_estimate_bpf, states))

    gsf_time = gsf_time.at[i].set(time_gsf)
    agsf_time = agsf_time.at[i].set(time_agsf)
    bpf_time = bpf_time.at[i].set(time_bpf)

    print('              GSF RMSE:', gsf_rmse[i])
    print('              AGSF RMSE:', agsf_rmse[i])
    print('              BPF RMSE:', bpf_rmse[i])


sim 1/100
       Time taken for GSF:  4.420901298522949
       Time taken for AGSF:  8.683189630508423
       Time taken for BPF:  2.309436082839966
              GSF RMSE: 7.892008
              AGSF RMSE: 6.739996
              BPF RMSE: 26.818485
sim 2/100
       Time taken for GSF:  0.7338500022888184
       Time taken for AGSF:  2.321514368057251
       Time taken for BPF:  0.9086179733276367
              GSF RMSE: 4.1749196
              AGSF RMSE: 4.225998
              BPF RMSE: 10.4304085
sim 3/100
       Time taken for GSF:  1.097968578338623
       Time taken for AGSF:  3.2195329666137695
       Time taken for BPF:  0.9216327667236328
              GSF RMSE: 25.43598
              AGSF RMSE: 4.047869
              BPF RMSE: 21.87963
sim 4/100
       Time taken for GSF:  0.9531919956207275
       Time taken for AGSF:  2.3769543170928955
       Time taken for BPF:  0.9057424068450928
              GSF RMSE: 25.680283
              AGSF RMSE: 29.950012
              BPF RMSE: 

In [8]:
import pandas as pd
gsf_armse = jnp.mean(gsf_rmse)
agsf_armse = jnp.mean(agsf_rmse)
bpf_armse = jnp.mean(bpf_rmse)
gsf_atime = jnp.mean(gsf_time)
agsf_atime = jnp.mean(agsf_time)
bpf_atime = jnp.mean(bpf_time)

gsf_tab_out = '{:10.2f}±{:10.2f}'.format(gsf_armse, jnp.std(gsf_rmse))
agsf_tab_out = '{:10.2f}±{:10.2f}'.format(agsf_armse, jnp.std(agsf_rmse))
bpf_tab_out = '{:10.2f}±{:10.2f}'.format(bpf_armse, jnp.std(bpf_rmse))

gsf_tab_out1 = '{:10.2f}±{:10.2f}'.format(gsf_atime, jnp.std(gsf_time))
agsf_tab_out1 = '{:10.2f}±{:10.2f}'.format(agsf_atime, jnp.std(agsf_time))
bpf_tab_out1 = '{:10.2f}±{:10.2f}'.format(bpf_atime, jnp.std(bpf_time))

df = pd.DataFrame(columns = [' ','RMSE','time(s)'])
df[' '] = ['GSF', 'AGSF', 'BPF']
df['RMSE'] = [gsf_tab_out, agsf_tab_out, bpf_tab_out]
df['time(s)'] = [gsf_tab_out1, agsf_tab_out1, bpf_tab_out1]
print(df.to_latex(index=False))  
df

\begin{tabular}{lll}
\toprule
     &                  RMSE &               time(s) \\
\midrule
 GSF &      19.83±     12.46 &       0.89±      0.41 \\
AGSF &      18.56±     13.84 &       2.79±      0.75 \\
 BPF &      17.36±      9.49 &       1.05±      0.26 \\
\bottomrule
\end{tabular}



  print(df.to_latex(index=False))


Unnamed: 0,Unnamed: 1,RMSE,time(s)
0,GSF,19.83± 12.46,0.89± 0.41
1,AGSF,18.56± 13.84,2.79± 0.75
2,BPF,17.36± 9.49,1.05± 0.26
