## 2. Vanilla GPFA

It is "vanilla" in the sense that, for now, it only handles trials that all have the same length (and only the RBF covariance kernel).

In [2]:
%reload_ext autoreload
%autoreload 2

In [3]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import quantities as pq
from sklearn.decomposition import FactorAnalysis

In [4]:
from e_step import e_step
from m_step import m_step

In [19]:
# =====================
# load simulated data
# =====================

# NOTE: allow_pickle=True unpickles arbitrary Python objects, which can execute
# code -- only load files you generated yourself (here, the simulated dataset).
data_file = 'simulated_data1.npy'
seqs = np.load(data_file, allow_pickle=True)

In [55]:
# ==================================
# Initialize state model parameters
# ==================================
x_dim = 2            # number of latent dimensions
bin_width = 20.0     # in ms, this should match how we simulated the synthetic data.
tau_init = 100.0     # initial GP timescale, same units as bin_width (ms)
eps_init = 1.0E-3    # initial GP noise variance
em_tol = 1.0E-3      # relative log-likelihood tolerance used by the EM convergence check

max_iteration_num = 100  # upper bound on EM iterations

params_init = dict()
params_init['covType'] = 'rbf'  # so far only rbf is implemented for this vanilla version
# GP timescale, stored as gamma = (bin_width / tau) ** 2, i.e. the timescale is
# expressed in units of time bins (bin_width is the time step size).
params_init['gamma'] = (bin_width / tau_init) ** 2 * np.ones(x_dim)
# GP noise variance, one entry per latent dimension.
params_init['eps'] = eps_init * np.ones(x_dim)

# ========================================
# Initialize observation model parameters
# ========================================
print('Initializing parameters using factor analysis...')

# Concatenate all trials along axis 1. np.cov treats rows as variables, so y_all
# is assumed to be (n_observed_dims, n_total_bins) -- TODO confirm against the simulator.
y_all = np.hstack(seqs['y'])
fa = FactorAnalysis(n_components=x_dim, copy=True,
                    noise_variance_init=np.diag(np.cov(y_all, bias=True)))
# sklearn expects (n_samples, n_features), hence the transpose.
fa.fit(y_all.T)
params_init['d'] = y_all.mean(axis=1)           # mean of each observed dimension (row)
params_init['C'] = fa.components_.T             # loadings, shape (n_features, x_dim)
params_init['R'] = np.diag(fa.noise_variance_)  # diagonal noise covariance matrix
# Use x_dim rather than a hard-coded 2 so that changing x_dim above keeps
# every parameter's dimensionality consistent.
params_init['x_dim'] = x_dim
params_init['tau'] = tau_init


Initializing parameters using factor analysis...


In [57]:
# =====================
# Fit model parameters
# =====================

# Deep-copy the initialization so the fit can never mutate params_init (in case
# m_step updates its argument in place); re-running this cell then always
# restarts EM from the same initial parameters.
params = copy.deepcopy(params_init)
for i in range(max_iteration_num):
    # E-step: returns the processed sequences and the log likelihood LL_i
    # under the current params (see e_step for details).
    seqs_out, LL_i = e_step(seqs, params)
    # M-step: re-estimate the parameters from the E-step output.
    params = m_step(seqs_out, params)

    # Check convergence. The first two iterations only set the baseline LL_base.
    if i <= 1:
        LL_base = LL_i
    elif LL_i < LL_old:
        # EM should never decrease the likelihood; a drop likely signals a bug
        # or a numerical problem.
        print(f"\nError: Log likelihood decreased from {LL_old:.1f} to {LL_i:.1f}")
        break
    elif (LL_i - LL_base) < (1 + em_tol) * (LL_old - LL_base):
        # Stop once the total improvement over the baseline grew by less than
        # a relative factor em_tol during this iteration.
        print(f"\nConverged after {i+1} EM iterations")
        break
    LL_old = LL_i
else:
    # The loop ran out without hitting a break: params are NOT converged.
    print(f"\nWarning: EM did not converge within {max_iteration_num} iterations")
            


Converged after 42 EM iterations
