# Fit a linear dynamical system (LDS) to the Allen data

In [7]:
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from pykalman import KalmanFilter
import pdb

In [8]:
# Directory that holds the pickled Allen data dictionaries.
datadir = '/Users/antoniomoretti/Desktop/dhern-ts_wcommona-b4b1ad88b3aa/data/allendata/'

# NOTE(review): encoding='latin1' is what Python 3 needs to unpickle NumPy
# arrays written by Python 2 -- presumably this file was pickled under
# Python 2; confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(f"{datadir}allendatadict_030", "rb") as pickle_file:
    data = pickle.load(pickle_file, encoding="latin1")

In [9]:
# Training observations. The LDS class unpacks X.shape into three values,
# so this must be a 3-D array of (n_trials, NTbins, Dx).
X = data["Ytrain"]

In [10]:
class LDS:
    """
    Train an LDS with the EM algorithm, find latent paths using Kalman Filter and Smoother.

    EM is run trial by trial. Each trial gets a single ``KalmanFilter.em``
    iteration, and the parameters that iteration produces seed the next trial.
    One epoch is one sweep over all trials.
    """

    # Parameters re-estimated by EM unless the caller passes ``em_vars``.
    # pykalman's own default omits ``transition_matrices`` and
    # ``observation_matrices``, which would leave A and C at their initial
    # values no matter how many epochs are run.
    DEFAULT_EM_VARS = (
        'transition_matrices',
        'observation_matrices',
        'transition_covariance',
        'observation_covariance',
        'initial_state_mean',
        'initial_state_covariance',
    )

    def __init__(self, X, Dz, em_vars=None):
        """
        Initialize observations (X) and buffers for the latent-state estimates.

        X: 3-Tensor of n_trials, NTbins, Dx
        Dz: dimension of the latent state
        em_vars: parameters EM should learn -- a list of pykalman parameter
            names, or 'all'. Defaults to ``DEFAULT_EM_VARS``.

        Filled in by ``inference``:
        Z: 3-Tensor of n_trials, NTbins, Dz (smoothed state means)
        filtered_paths: n_trials, NTbins, Dz (filtered state means)
        filtered_covar / smoothed_covar: n_trials, NTbins, Dz, Dz
        """
        if em_vars is None:
            em_vars = list(self.DEFAULT_EM_VARS)
        self.X = X
        self.NTrials, self.NTbins, self.Dx = X.shape
        self.Dz = Dz
        self.kf = KalmanFilter(n_dim_state=Dz, n_dim_obs=self.Dx, em_vars=em_vars)
        self.Z = np.zeros([self.NTrials, self.NTbins, self.Dz])
        self.filtered_paths = np.zeros([self.NTrials, self.NTbins, self.Dz])
        self.filtered_covar = np.zeros([self.NTrials, self.NTbins, self.Dz, self.Dz])
        self.smoothed_covar = np.zeros([self.NTrials, self.NTbins, self.Dz, self.Dz])
        # Learned parameters; populated by train().
        self.A = None
        self.Q = None
        self.C = None
        self.Sigma = None

    def train(self, epochs, verbose=True):
        """
        Run ``epochs`` sweeps of single-iteration EM over every trial.

        Afterwards the learned parameters are exposed as
        A (transition matrix), Q (transition covariance),
        C (observation matrix) and Sigma (observation covariance).

        verbose: print per-epoch / per-trial progress.
        """
        if verbose:
            print("EM Algorithm Training...")
        for i in range(epochs):
            if verbose:
                print("- Epoch %i" % i)
            for n in range(self.NTrials):
                if verbose:
                    print("-- Trial %i" % n)
                # em() stores its estimates on self.kf, so each trial starts
                # from the parameters left by the previous trial.
                self.kf.em(self.X[n], n_iter=1)

        self.A = self.kf.transition_matrices
        self.Q = self.kf.transition_covariance
        self.C = self.kf.observation_matrices
        self.Sigma = self.kf.observation_covariance

    def inference(self, verbose=True):
        """
        Filter and smooth every trial with the current model parameters.

        Writes filtered means/covariances into filtered_paths / filtered_covar
        and smoothed means/covariances into Z / smoothed_covar.

        verbose: print per-trial progress.
        """
        for n in range(self.NTrials):
            if verbose:
                print("-- Trial %i" % n)
            self.filtered_paths[n], self.filtered_covar[n] = self.kf.filter(self.X[n])
            self.Z[n], self.smoothed_covar[n] = self.kf.smooth(self.X[n])

## Define an LDS and train it on the Allen data

In [None]:
# Build the model with a 5-dimensional latent state.
AllenLDS = LDS(X=X, Dz=5)

# Fit the parameters with 100 epochs of trial-by-trial EM.
AllenLDS.train(epochs=100)

# Show the transition matrix A after training.
print(AllenLDS.A)

# Filter and smooth every trial to recover the latent paths.
AllenLDS.inference()

EM Algorithm Training...
- Epoch 0
-- Trial 0
-- Trial 1
-- Trial 2
-- Trial 3
-- Trial 4
-- Trial 5
-- Trial 6
-- Trial 7
-- Trial 8
-- Trial 9
-- Trial 10
-- Trial 11
-- Trial 12
-- Trial 13
-- Trial 14
-- Trial 15
-- Trial 16
-- Trial 17
