In [1]:
%matplotlib inline
import autograd.numpy as np
from autograd.numpy.linalg import inv, det
from autograd import grad
from functools import reduce
import autograd.numpy.random as npr
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal, gamma, invgamma 
from scipy.special import digamma, polygamma
from scipy.special import gamma as gafun
from util import *
from kalman import *

In [2]:
# Bouncing-ball dataset configuration: number of sequences, frames per
# sequence, and image side length (the same values load_data reports).
num_sequences, time_length, image_length = 50, 100, 20

# Path handed to generate_dataset (no extension) and the .npy file read by load_data.
generate_dataset_filename, dataset_filename = './bouncingball', './bouncingball.npy'

## Load Dataset

In [3]:
# Load the pre-generated dataset from disk. To regenerate it first, run:
#   video_arrays = generate_dataset(generate_dataset_filename, num_sequences, image_length, time_length)
train_dataset = load_data(dataset_filename, num_sequences, image_length, time_length)

# The model below is fit to a single sequence: the first one in the dataset.
video = train_dataset[0]

dataset size : 50, image length : 20, sequence length : 100


## Hyper-parameters and initialization

In [17]:
# Model dimensions. N and T are axes 1 and 2 of train_dataset (i.e. the shape of
# each video); presumably N = per-frame observation size and T = number of
# frames — TODO confirm against load_data. K is the hidden-state dimension
# (mu_x0 is K x 1, sigma_x0 is K x K).
N, T = train_dataset.shape[1:3]
K = 8

# Fixed seed so that Restart & Run All reproduces the same sampled
# hyper-parameters and initial-state prior (these draws were previously unseeded).
# One shared RandomState keeps successive draws independent of each other.
SEED = 0
rng = np.random.RandomState(SEED)

# initialize hyperpriors
alpha_a_0, alpha_b_0, r_a_0, r_b_0, rho_a, rho_b, mu_x0_0_sigma_diag, sigma_x0_a, sigma_x0_b = init_hyper(K)

## sample hyper parameters, hidden state priors
# Gamma draws are parameterized by shape `a` and rate `b` (scipy wants scale = 1 / b).
alpha = gamma.rvs(a=alpha_a_0, scale=(1.0 / alpha_b_0), random_state=rng)
r = gamma.rvs(a=r_a_0, scale=(1.0 / r_b_0), random_state=rng)

# Initial-state mean as a (K, 1) column. `reshape` (instead of assigning `.shape`)
# also works when K == 1, where multivariate_normal.rvs returns a numpy scalar.
mu_x0 = multivariate_normal.rvs(
    np.zeros(K), np.diag(mu_x0_0_sigma_diag), random_state=rng
).reshape(K, 1)

# Initial-state covariance: diagonal, with variances equal to 1 / Gamma draws.
# atleast_1d keeps np.diag valid if the draw comes back 0-d.
sigma_x0 = np.diag(np.atleast_1d(
    1.0 / gamma.rvs(a=sigma_x0_a, scale=(1.0 / sigma_x0_b), random_state=rng)
))

## Initialize Hidden State Sufficient Statistics
W_A, S_A, W_C, S_C = init_hsss(N, K)
Y_hat = compute_yhat(video, N, T)

## VBM step: variational posteriors over the model parameters

In [18]:
# Variational M-step: from the current hidden-state sufficient statistics and
# hyper-parameters, infer the variational distributions over the parameters.
q_sigma_A, q_mu_A, sigma_C, q_rho_a, q_rho_b, q_mu_C = infer_qs(
    W_A, S_A, W_C, S_C, Y_hat, alpha, r, rho_a, rho_b, N, T
)

# Expected natural parameters under those distributions. They are kept both as a
# list (splatted into the forward/backward recursions) and as individual names
# (used directly by update_marginals and update_hyper).
natstats_list = list(natstats(q_rho_a, q_rho_b, S_A, S_C, q_sigma_A, sigma_C, N, K))
E_A, E_ATA, E_rho_s, E_log_rho_s, E_R_inv, E_C, E_R_inv_C, E_CT_R_inv_C = natstats_list

## VBE step: forward-backward recursions over the hidden states

In [20]:
# Variational E-step over the hidden states, driven by the expected natural
# parameters from the VBM step.

# Forward recursion from the initial-state prior; also yields log_Z, used by the ELBO.
Mu_xt, Sigma_xt, Sigma_star, log_Z = forward(mu_x0, sigma_x0, N, T, K, video, *natstats_list)

# Backward recursion, started from psi_T_inv = 0 (K x K) and eta_T = 1 (K x 1).
# NOTE(review): with a zero terminal inverse covariance the value of eta_T may be
# irrelevant — confirm against backward().
psi_T_inv, eta_T = np.zeros((K, K)), np.ones((K, 1))
Psi, Eta = backward(psi_T_inv, eta_T, video, N, T, K, *natstats_list)

# Combine both passes into the state marginals (Gamma / Omega) ...
Gamma_ts, Omega_ts, Gamma_ttp1s = update_marginals(
    Mu_xt, Sigma_xt, Sigma_star, Psi, Eta, N, T, K, E_A, E_CT_R_inv_C
)
# ... and refresh the hidden-state sufficient statistics for the next VBM step.
W_A, S_A, W_C, S_C = update_hsss(Gamma_ts, Omega_ts, Gamma_ttp1s, video, N, T)

  return f_raw(*args, **kwargs)


## Update Hyper-parameters

In [7]:
# Re-estimate the hyper-parameters from the current posteriors and hidden-state
# statistics. NOTE(review): of the *_new values only alpha_new and r_new are used
# afterwards (in the ELBO cell); none are written back into alpha, r, mu_x0,
# sigma_x0, rho_a, rho_b — do that if an iterated VB-EM loop is intended.
alpha_new, r_new, mu_x0_new, sigma_x0_new, rho_a_new, rho_b_new = update_hyper(
    Gamma_ts, Omega_ts,               # from update_marginals (VBE step)
    S_A, S_C,                         # from update_hsss (VBE step)
    q_sigma_A, sigma_C,               # from infer_qs (VBM step)
    E_R_inv, E_rho_s, E_log_rho_s,    # from natstats (VBM step)
    N, rho_a, rho_b,
)

## ELBO (evidence lower bound)

In [17]:
# Prior covariances built from the updated hyper-parameters, treating
# alpha_new / r_new as precisions (variance = 1 / value).
p_sigma_A = np.diag(1.0 / alpha_new)
p_sigma_C = np.diag(1.0 / r_new)

# KL terms between the variational distributions and their priors (KL_* helpers).
# NOTE(review): `q_rhos` is not assigned by any cell in this notebook (infer_qs
# returns q_rho_a / q_rho_b), so on a fresh kernel this relies on the star
# imports or a deleted cell — confirm what KL_C expects here.
# NOTE(review): kl_A / kl_C use the updated alpha_new / r_new priors, while kl_rho
# uses the pre-update rho_a / rho_b, and log_Z came from the VBE step run with the
# pre-update hyper-parameters — confirm this mix is intended.
kl_A = KL_A(q_mu_A, q_sigma_A, p_sigma_A, K)
kl_C = KL_C(q_mu_C, q_rhos, sigma_C, p_sigma_C, N)
kl_rho = KL_gamma(q_rho_a, q_rho_b, rho_a, rho_b)

# ELBO = log normalizer from the forward pass minus the total KL penalty.
ELBO = log_Z - (kl_A + kl_C + kl_rho)

  return f_raw(*args, **kwargs)
