In [3]:
import jax
from jax import numpy as jnp
from jax import random as jr
from jax import vmap, lax

In [4]:
from ssm_jax.bp.gauss_bp import (potential_from_conditional_linear_gaussian,
                                 info_condition,
                                 info_marginalise,
                                 info_multiply,
                                 info_divide)

from ssm_jax.lgssm.models import LinearGaussianSSM
from ssm_jax.lgssm.info_inference import LGSSMInfoParams, lgssm_info_smoother

#### Construct LGSSM model

In [5]:
# The dynamics matrix looks like a 2D constant-velocity model: state components
# 0 and 1 are advanced by `delta` times components 2 and 3, which are kept fixed.
delta = 1.0
F = jnp.eye(4) + delta * jnp.eye(4, k=2)

# Only the first two state components are observed.
H = jnp.eye(2, 4)

state_size = F.shape[0]
observation_size = H.shape[0]

# Isotropic noise, kept in both covariance and precision form.
Q = 0.001 * jnp.eye(state_size)
Q_prec = jnp.linalg.inv(Q)
R = 1.0 * jnp.eye(observation_size)
R_prec = jnp.linalg.inv(R)

# Input weights and biases for the dynamics and the emissions.
input_size = 1
B = jnp.array([1.0, 0.5, -0.05, -0.01]).reshape((state_size, input_size))
b = jnp.ones((state_size,)) * 0.01
D = jnp.ones((observation_size, input_size))
d = jnp.ones((observation_size,)) * 0.02

# Prior over the initial state, in moment and information form.
mu0 = jnp.array([8.0, 10.0, 1.0, 0.0])
Sigma0 = jnp.eye(state_size) * 0.1
Lambda0 = jnp.linalg.inv(Sigma0)
eta0 = Lambda0 @ mu0

# Parameters shared by the moment-form model and its information-form twin;
# the two differ only in whether the noise is given as covariance or precision.
shared_params = dict(
    initial_mean=mu0,
    dynamics_matrix=F,
    dynamics_input_weights=B,
    dynamics_bias=b,
    emission_matrix=H,
    emission_input_weights=D,
    emission_bias=d,
)

lgssm = LinearGaussianSSM(
    initial_covariance=Sigma0,
    dynamics_covariance=Q,
    emission_covariance=R,
    **shared_params,
)

lgssm_info = LGSSMInfoParams(
    initial_precision=Lambda0,
    dynamics_precision=Q_prec,
    emission_precision=R_prec,
    **shared_params,
)



#### Sample from the model

In [177]:
num_timesteps = 15
key = jr.PRNGKey(111)
inputs = jnp.zeros((num_timesteps, input_size))

# Draw latent states z and observations y, then compute the reference
# posterior with the library's information-form smoother.
z, y = lgssm.sample(key, num_timesteps, inputs)
lgssm_info_posterior = lgssm_info_smoother(lgssm_info, y, inputs)

### Message Passing Inference

In [7]:
def info_blocks_condition(K_blocks, h_blocks,y):
    """Condition a pairwise potential on an observed value of its second variable.

    Args:
        K_blocks: precision blocks (K11, K12, K22) of the pairwise potential.
        h_blocks: information-vector blocks (h1, h2) of the pairwise potential.
        y: observed value of the second variable.

    Returns:
        The result of `info_condition(K11, K12, h1, y)`; presumably the
        information-form parameters (K, h) of the first variable given y.
        K22 and h2 are not needed for conditioning and are ignored.
    """
    (K11, K12, _K22), (h1, _h2) = K_blocks, h_blocks
    return info_condition(K11, K12, h1, y)

# TODO(open question): could the potentials be built without the inputs?
def cliques_from_lgssm(lgssm_info, inputs):
    """Build the Gaussian clique potentials of an LGSSM in information form.

    Args:
        lgssm_info: an LGSSMInfoParams instance.
        inputs: (T, D_in) array of inputs.

    Returns:
        prior_pot: (Lambda0, Lambda0 @ mu0), the initial-state prior.
        emission_pots: stacked pairwise emission potentials
            ((K11, K12, K22), (h1, h2)) for timesteps 1..T-1; the emission at
            t = 0 is not included.
        latent_pots: stacked pairwise latent potentials
            ((K11, K12, K22), (h1, h2)), one per consecutive pair (t, t+1) for
            t = 0..T-2, each built from the net input at time t.
    """
    # Apply a shared weight matrix to every row of `inputs`.
    batched_dot = vmap(jnp.dot, (None, 0))
    # One potential per row of offsets, sharing the matrix and the precision.
    batched_potential = vmap(potential_from_conditional_linear_gaussian, (None, 0, None))

    # Net additive offsets per timestep: B u_t + b (dynamics), D u_t + d (emissions).
    latent_net_inputs = batched_dot(lgssm_info.dynamics_input_weights, inputs) + lgssm_info.dynamics_bias
    emission_net_inputs = batched_dot(lgssm_info.emission_input_weights, inputs) + lgssm_info.emission_bias

    latent_pots = batched_potential(lgssm_info.dynamics_matrix,
                                    latent_net_inputs[:-1],
                                    lgssm_info.dynamics_precision)
    emission_pots = batched_potential(lgssm_info.emission_matrix,
                                      emission_net_inputs[1:],
                                      lgssm_info.emission_precision)

    initial_precision = lgssm_info.initial_precision
    prior_pot = (initial_precision, initial_precision @ lgssm_info.initial_mean)
    return prior_pot, emission_pots, latent_pots

def absorb_emission_message(message, latent_pot):
    """Multiply a message over the second variable into a pairwise latent potential.

    Args:
        message: information-form parameters (K, h) over the second variable.
        latent_pot: pairwise potential ((K11, K12, K22), (h1, h2)).

    Returns:
        The potential with `message` combined (via `info_multiply`) into its
        (K22, h2) slot; all other blocks are returned unchanged.
    """
    K_blocks, h_blocks = latent_pot
    K11, K12, K22 = K_blocks
    h1, h2 = h_blocks
    K22_new, h2_new = info_multiply(message, (K22, h2))
    return (K11, K12, K22_new), (h1, h2_new)

def absorb_latent_message(message, latent_pot):
    """Multiply a message over the first variable into a pairwise latent potential.

    Args:
        message: information-form parameters (K, h) over the first variable.
        latent_pot: pairwise potential ((K11, K12, K22), (h1, h2)).

    Returns:
        The potential with `message` combined (via `info_multiply`) into its
        (K11, h1) slot; all other blocks are returned unchanged.
    """
    K_blocks, h_blocks = latent_pot
    K11, K12, K22 = K_blocks
    h1, h2 = h_blocks
    K11_new, h1_new = info_multiply(message, (K11, h1))
    return (K11_new, K12, K22), (h1_new, h2)

def step(carry, x):
    """One forward message-passing step, for use with `lax.scan`.

    Args:
        carry: information-form parameters (K, h) of the belief over the
            first variable of the current latent clique.
        x: pairwise latent potential ((K11, K12, K22), (h1, h2)) into which
            the emission message for its second variable has already been
            absorbed.

    Returns:
        (message_out, message_out): the result of marginalising the updated
        clique with `info_marginalise`, used both as the next carry and as the
        per-step output.
    """
    message = carry
    latent_pot = absorb_latent_message(message, x)
    message_out = info_marginalise(*latent_pot)
    return message_out, message_out

In [8]:
# Message over the first latent state from the first observation: build the
# emission potential for t = 0 and condition it on y[0].
emission_pot1 = potential_from_conditional_linear_gaussian(
    H, D @ inputs[0] + d, R_prec
)
emission_message1 = info_blocks_condition(*emission_pot1, y[0])

# Belief over the first latent state: prior (Lambda0, eta0) times the emission message.
x1 = info_multiply((Lambda0, eta0), emission_message1)

In [9]:
_, emission_pots, latent_pots = cliques_from_lgssm(lgssm_info, inputs)

# Condition each emission clique on its observation, giving one message per
# latent state from t = 1 onwards.
emission_K_blocks, emission_h_blocks = emission_pots
emission_messages = vmap(info_blocks_condition)(emission_K_blocks, emission_h_blocks, y[1:])

# Fold each emission message into the second-variable slot of its latent clique.
latent_pots = vmap(absorb_emission_message)(emission_messages, latent_pots)

# Run the forward pass, seeded with the belief over the first latent state.
init_carry = x1
_, out = lax.scan(step, init_carry, latent_pots)

In [10]:
# Add the initial marginal by prepending it to each stacked array.
# `jax.tree_util.tree_map` / `jnp.concatenate` replace the deprecated
# `jax.tree_map` / `jnp.row_stack` (identical results along axis 0 here).
marg_up = jax.tree_util.tree_map(lambda h, t: jnp.concatenate((h[None, ...], t)), x1, out)
K_up, h_up = marg_up
print(jnp.allclose(lgssm_info_posterior.filtered_precisions, K_up,
                   rtol=1e-3))
print(jnp.allclose(lgssm_info_posterior.filtered_etas, h_up,
                   rtol=1e-3))

True
True


### Heading Up: forward (filtering) pass

In [176]:
def info_blocks_condition(K_blocks, h_blocks,y):
    """Condition a pairwise potential on its second variable taking value `y`.

    Only K11, K12 and h1 are needed, so K22 and h2 are unpacked and ignored.
    The returned value comes straight from `info_condition`; presumably it is
    the information-form (K, h) of the first variable given y.
    """
    K11, K12, _K22 = K_blocks
    h1, _h2 = h_blocks
    conditioned = info_condition(K11, K12, h1, y)
    return conditioned

def cliques_from_lgssm(lgssm_params, inputs):
    """Construct pairwise latent and emission cliques from model.

    Args:
        lgssm_params: an LGSSMInfoParams instance.
        inputs: (T, D_in) array of inputs.

    Returns:
        prior_pot: A tuple of parameters representing the prior potential,
                   (Lambda0, eta0) with eta0 = Lambda0 @ mu0.
        emission_pots: A tuple containing the stacked parameters for each
                   pairwise emission clique potential -
                   ((K11, K12, K22), (h1, h2)) - for timesteps 1..T-1. The
                   emission at t = 0 is NOT included; it is handled separately
                   together with the prior.
        latent_pots: A tuple containing the stacked parameters for each
                   pairwise latent clique potential -
                   ((K11, K12, K22), (h1, h2)) - one per consecutive pair of
                   states (t, t+1), t = 0..T-2, built from the net input at t.
    """

    B, b = lgssm_params.dynamics_input_weights, lgssm_params.dynamics_bias
    D, d = lgssm_params.emission_input_weights, lgssm_params.emission_bias
    # Net additive offsets per timestep: B u_t + b and D u_t + d.
    latent_net_inputs = vmap(jnp.dot, (None, 0))(B, inputs) + b
    emission_net_inputs = vmap(jnp.dot, (None, 0))(D, inputs) + d

    F, Q_prec = lgssm_params.dynamics_matrix, lgssm_params.dynamics_precision
    H, R_prec = lgssm_params.emission_matrix, lgssm_params.emission_precision
    # Each of these is a tuple ((K11, K12, K22), (h1, h2)) where each element contains
    #  the clique parameters for different timepoints as stacked rows.
    latent_pots = vmap(potential_from_conditional_linear_gaussian, (None, 0, None))(F, latent_net_inputs[:-1], Q_prec)
    emission_pots = vmap(potential_from_conditional_linear_gaussian, (None, 0, None))(H, emission_net_inputs[1:], R_prec)

    Lambda0, mu0 = lgssm_params.initial_precision, lgssm_params.initial_mean
    prior_pot = (Lambda0, Lambda0 @ mu0)

    return prior_pot, emission_pots, latent_pots

def absorb_emission_message(message, latent_pot):
    """Combine a message over the clique's second variable into the (K22, h2) slot.

    K11, K12 and h1 pass through untouched; the updated blocks come from
    `info_multiply(message, (K22, h2))`.
    """
    (K11, K12, K22), (h1, h2) = latent_pot
    updated_K22, updated_h2 = info_multiply(message, (K22, h2))
    return ((K11, K12, updated_K22), (h1, updated_h2))

def absorb_latent_message(message, latent_pot):
    """Combine a message over the clique's first variable into the (K11, h1) slot.

    K12, K22 and h2 pass through untouched; the updated blocks come from
    `info_multiply(message, (K11, h1))`.
    """
    (K11, K12, K22), (h1, h2) = latent_pot
    updated_K11, updated_h1 = info_multiply(message, (K11, h1))
    return ((updated_K11, K12, K22), (updated_h1, h2))

def forward_step(carry, x):
    """One filtering step, for use with `lax.scan`.

    Args:
        carry: (K, h) belief over the latent state at time t.
        x: tuple (latent_pot, emission_pot, y) holding the pairwise latent
            potential for (t, t+1), the emission potential for t+1 and the
            observation at t+1.

    Returns:
        The belief over the state at t+1 as the new carry, and the pair
        (belief, predictive latent message) as the per-step output.
    """
    latent_pot, emission_pot, y_next = x

    # Push the current belief through the dynamics clique and integrate out
    # the current state, leaving a predictive message over the next one.
    joint = absorb_latent_message(carry, latent_pot)
    predicted = info_marginalise(*joint)

    # Evidence from the next observation, then combine with the prediction.
    likelihood = info_blocks_condition(*emission_pot, y_next)
    filtered = info_multiply(predicted, likelihood)

    return filtered, (filtered, predicted)

In [21]:
# Belief over the first latent state: condition the t = 0 emission potential
# on y[0], then combine the resulting message with the prior (Lambda0, eta0).
first_emission_offset = D @ inputs[0] + d
emission_pot1 = potential_from_conditional_linear_gaussian(H, first_emission_offset, R_prec)
emission_message1 = info_blocks_condition(*emission_pot1, y[0])
x1 = info_multiply((Lambda0, eta0), emission_message1)

In [26]:
_, emission_pots, latent_pots = cliques_from_lgssm(lgssm_info, inputs)

# Seed the forward scan with the belief over the first latent state; each step
# consumes one latent clique, one emission clique and the matching observation.
init_carry = x1
scan_inputs = (latent_pots, emission_pots, y[1:])
_, (bels, messages) = lax.scan(forward_step, init_carry, scan_inputs)

In [61]:
# Add the initial marginal by prepending it to each stacked array.
# `jax.tree_util.tree_map` / `jnp.concatenate` replace the deprecated
# `jax.tree_map` / `jnp.row_stack` (identical results along axis 0 here).
bels_up = jax.tree_util.tree_map(lambda h, t: jnp.concatenate((h[None, ...], t)), x1, bels)
K_up, h_up = bels_up
print(jnp.allclose(lgssm_info_posterior.filtered_precisions, K_up,
                   rtol=1e-3))
print(jnp.allclose(lgssm_info_posterior.filtered_etas, h_up,
                   rtol=1e-3))

True
True


### Heading Down: backward (smoothing) pass

In [156]:
def info_marginalise_down(K_blocks, hs):
    r"""Marginalise out the SECOND variable of a joint Gaussian in information form.

    For x1, x2 jointly distributed as
        p(x1, x2) = Nc(x1, x2 | h, K),
    the marginal distribution of x1 is given by:
        p(x1) = \int p(x1, x2) dx2 = Nc(x1 | h1_marg, K1_marg)
    where,
        h1_marg = h1 - K12 K22^{-1} h2
        K1_marg = K11 - K12 K22^{-1} K21

    Args:
        K_blocks: blocks of the joint precision matrix, (K11, K12, K22),
                    K11 (dim1, dim1),
                    K12 (dim1, dim2),
                    K22 (dim2, dim2).
        hs: joint precision weighted mean, (h1, h2):
                    h1 (dim1,),
                    h2 (dim2,).
    Returns:
        K1_marg (dim1, dim1): marginal precision matrix of x1.
        h1_marg (dim1,): marginal precision weighted mean of x1.
    """
    K11, K12, K22 = K_blocks
    h1, h2 = hs
    # G = K22^{-1} K21, computed with a solve rather than an explicit inverse.
    G = jnp.linalg.solve(K22, K12.T)
    K1_marg = K11 - K12 @ G
    # G.T = K12 K22^{-1} relies on K22 being symmetric, as any precision block is.
    h1_marg = h1 - G.T @ h2
    return K1_marg, h1_marg

def backward_step(carry, x):
    """One smoothing step, for use with `lax.scan(..., reverse=True)`.

    Args:
        carry: (K, h) smoothed belief over the latent state at time t+1.
        x: tuple (bel, message_up, latent_pot) holding the filtered belief at
            t, the forward (predictive) message that was sent into t+1, and
            the pairwise latent potential for (t, t+1).

    Returns:
        The smoothed belief at t, as both the new carry and the output.
    """
    filtered_bel, message_up, latent_pot = x

    # Remove the forward message from the next smoothed belief so that the
    # information it carries is not counted twice.
    ratio = info_divide(carry, message_up)
    K_ratio, h_ratio = ratio[0], ratio[1]

    # Place the ratio in the second-variable slot of the pairwise clique
    # (zeros elsewhere), then integrate out the second variable.
    padded_ratio = ((0, 0, K_ratio), (0, h_ratio))
    joint = info_multiply(latent_pot, padded_ratio)
    message_down = info_marginalise_down(*joint)

    smoothed_bel = info_multiply(filtered_bel, message_down)
    return smoothed_bel, smoothed_bel

In [174]:
# Split the forward beliefs: the last one (the final filtered belief is also
# the final smoothed belief) seeds the backward scan, the rest are refined.
init_carry = jax.tree_util.tree_map(lambda a: a[-1], bels_up)
bels_rest = jax.tree_util.tree_map(lambda a: a[:-1], bels_up)
_, bels_down = lax.scan(backward_step, init_carry, (bels_rest, messages, latent_pots), reverse=True)
# Append the final belief, which the scan does not emit.
bels_down = jax.tree_util.tree_map(lambda h, t: jnp.concatenate((h, t[None, ...])), bels_down, init_carry)

In [175]:
K_down, h_down = bels_down
# Compare against the library's information-form smoother.
for reference, computed in ((lgssm_info_posterior.smoothed_precisions, K_down),
                            (lgssm_info_posterior.smoothed_etas, h_down)):
    print(jnp.allclose(reference, computed, rtol=1e-3))

True
True


## All together now

In [198]:
from ssm_jax.lgssm.info_inference_test import TestInfoFilteringAndSmoothing

In [199]:
# Reuse the moment-form model and its information-form parameters from the unit tests.
lgssm, lgssm_info = TestInfoFilteringAndSmoothing.lgssm, TestInfoFilteringAndSmoothing.lgssm_info

In [200]:
key = jr.PRNGKey(111)
num_timesteps = 15
# Size the inputs from the test model's own input weights (B @ u_t needs
# u_t of length B.shape[-1]) rather than from the `input_size` global of the
# earlier model, so this cell does not depend on hidden state.
inputs = jnp.zeros((num_timesteps, lgssm_info.dynamics_input_weights.shape[-1]))
z, y = lgssm.sample(key, num_timesteps, inputs)

lgssm_info_posterior = lgssm_info_smoother(lgssm_info, y, inputs)

In [201]:
def cliques_from_lgssm(lgssm_params, inputs):
    """Construct pairwise latent and emission cliques from model.

    Args:
        lgssm_params: an LGSSMInfoParams instance.
        inputs: (T, D_in) array of inputs.

    Returns:
        prior_pot: A tuple of parameters representing the prior potential,
                   (Lambda0, eta0) with eta0 = Lambda0 @ mu0.
        emission_pots: A tuple containing the stacked parameters for each
                   pairwise emission clique potential -
                   ((K11, K12, K22), (h1, h2)) - one per timestep t = 0..T-1
                   (including the first emission).
        latent_pots: A tuple containing the stacked parameters for each
                   pairwise latent clique potential -
                   ((K11, K12, K22), (h1, h2)) - one per consecutive pair of
                   states (t, t+1), t = 0..T-2, built from the net input at t.
    """

    B, b = lgssm_params.dynamics_input_weights, lgssm_params.dynamics_bias
    D, d = lgssm_params.emission_input_weights, lgssm_params.emission_bias
    # Net additive offsets per timestep: B u_t + b and D u_t + d.
    latent_net_inputs = vmap(jnp.dot, (None, 0))(B, inputs) + b
    emission_net_inputs = vmap(jnp.dot, (None, 0))(D, inputs) + d

    F, Q_prec = lgssm_params.dynamics_matrix, lgssm_params.dynamics_precision
    H, R_prec = lgssm_params.emission_matrix, lgssm_params.emission_precision
    # Each of these is a tuple ((K11, K12, K22), (h1, h2)) where each element contains
    #  the clique parameters for different timepoints as stacked rows.
    latent_pots = vmap(potential_from_conditional_linear_gaussian, (None, 0, None))(F, latent_net_inputs[:-1], Q_prec)
    emission_pots = vmap(potential_from_conditional_linear_gaussian, (None, 0, None))(H, emission_net_inputs, R_prec)

    Lambda0, mu0 = lgssm_params.initial_precision, lgssm_params.initial_mean
    prior_pot = (Lambda0, Lambda0 @ mu0)

    return prior_pot, emission_pots, latent_pots

In [202]:
prior_pot, emission_pots, latent_pots = cliques_from_lgssm(lgssm_info, inputs)
# Split off the first emission clique: it is combined with the prior rather
# than passed along the chain.
init_emission_pot = jax.tree_util.tree_map(lambda a: a[0], emission_pots)
emission_pots_rest = jax.tree_util.tree_map(lambda a: a[1:], emission_pots)

# Absorb first emission message into the prior of THIS model. Use `prior_pot`
# rather than the `Lambda0`/`eta0` globals, which were computed for the model
# built in the earlier cells and need not match `lgssm_info`.
init_emission_message = info_blocks_condition(*init_emission_pot, y[0])
init_carry = info_multiply(prior_pot, init_emission_message)

# Message pass along chain (filtering).
_, (bels, messages) = lax.scan(forward_step, init_carry, (latent_pots, emission_pots_rest, y[1:]))
# Prepend the first belief, which the scan does not emit.
bels_up = jax.tree_util.tree_map(lambda h, t: jnp.concatenate((h[None, ...], t)), init_carry, bels)

# Extract final belief: the final filtered belief is also the final smoothed one.
init_carry = jax.tree_util.tree_map(lambda a: a[-1], bels_up)
bels_rest = jax.tree_util.tree_map(lambda a: a[:-1], bels_up)

# Message pass back along chain (smoothing).
_, bels_down = lax.scan(backward_step, init_carry, (bels_rest, messages, latent_pots), reverse=True)
# Append the final belief, which the scan does not emit.
bels_down = jax.tree_util.tree_map(lambda h, t: jnp.concatenate((h, t[None, ...])), bels_down, init_carry)

In [203]:
K_down, h_down = bels_down
# Both checks against the library smoother, printed one per line.
checks = [
    jnp.allclose(lgssm_info_posterior.smoothed_precisions, K_down, rtol=1e-3),
    jnp.allclose(lgssm_info_posterior.smoothed_etas, h_down, rtol=1e-3),
]
print(*checks, sep="\n")

True
True
