# Contents:
- [1 Estimation](#estimation)
- [2 Plotting](#plotting)
- [3 GOF](#GOF)

# 1 Estimation <a class="anchor" id="estimation"></a>

## 1.1 Import requirements

In [1]:
import sys
import time
import torch
import pickle

sys.path.append("../../src")
import stats.kernels
import stats.svGPFA.svGPFAModelFactory
import stats.svGPFA.svEM
import stats.pointProcess.tests                                                  
import utils.svGPFA.miscUtils
import utils.svGPFA.initUtils

## 1.2 Set parameters

In [2]:
# Model dimensions: latents, neurons and trials.
nLatents, nNeurons, nTrials = 2, 100, 15
# Number of quadrature points and number of inducing points.
nQuad = 200
nIndPoints = 9
# Prior covariance regularization parameter (passed to the model as
# indPointsLocsKMSRegEpsilon).
indPointsLocsKMSRegEpsilon = 1e-5
# All trials share the same interval [trial_start_time, trial_end_time]
# (presumably in seconds, given the 1/sampling_rate step used for plotting).
trial_start_time, trial_end_time = 0.0, 1.0

## 1.3 Load spike times

In [3]:
# Pickled simulation results (a dict-like object with a "spikes" entry).
# NOTE: pickle can execute arbitrary code while loading; only load result
# files you produced yourself.
simResFilename = (
    "../../scripts/results/32451751_simRes.pickle"  # simulation results filename
)
with open(simResFilename, "rb") as f:
    simRes = pickle.load(f)
# spikesTimes[r][n] holds the spike times of neuron n in trial r. Section 3
# calls .tolist() on such an entry, so entries are presumably tensors/arrays
# rather than plain lists -- TODO confirm against the simulation script.
spikesTimes = simRes["spikes"]

## 1.4 Set initial values of the model parameters

In [4]:
# Initial embedding parameters (Duncker and Sahani, 2018, Eq. 1, middle):
# C0 is (nNeurons, nLatents) and d0 is (nNeurons, 1). Every entry is an
# independent standard-normal draw, N(0, 1), in double precision.
C0 = torch.normal(mean=0.0, std=1.0, size=(nNeurons, nLatents),
                  dtype=torch.double)
d0 = torch.normal(mean=0.0, std=1.0, size=(nNeurons, 1),
                  dtype=torch.double)

# One ExponentialQuadratic kernel per latent
# (Duncker and Sahani, 2018, Eq. 1, top).
kernels = [stats.kernels.ExponentialQuadraticKernel() for _ in range(nLatents)]

# Initial kernel parameters (Duncker and Sahani, 2018, Eq. 1, top). Every
# latent's kernel starts from the same lengthscale0, but each latent gets its
# own one-element double tensor.
lengthscale0 = 1.0  # initial value of the lengthscale parameter
kernelsScaledParams0 = [
    torch.tensor([lengthscale0], dtype=torch.double)
    for _latentIndex in range(nLatents)
]

# Initial inducing-point locations (Duncker and Sahani, 2018, paragraph above
# Eq. 2): for every latent and trial, nIndPoints equally spaced points from
# trial_start_time to trial_end_time. Z0[k] has shape (nTrials, nIndPoints, 1).
# The grid is the same for every trial, so it is computed once and copied in.
indPointsLocsGrid0 = torch.linspace(trial_start_time, trial_end_time,
                                    nIndPoints, dtype=torch.double)
Z0 = [None] * nLatents
for k in range(nLatents):
    Z0[k] = torch.empty((nTrials, nIndPoints, 1), dtype=torch.double)
    for r in range(nTrials):
        Z0[k][r, :, 0] = indPointsLocsGrid0

# Initial variational means m_k (Duncker and Sahani, 2018, m_k in paragraph
# above Eq. 4): qMu0[k] has shape (nTrials, nIndPoints, 1) and every entry is
# an independent standard-normal draw, N(0, 1).
# Samples are drawn directly in double precision, like C0 and d0. Without a
# dtype, torch.normal draws float32 samples that are only up-cast on
# assignment into the double container.
qMu0 = [None] * nLatents
for k in range(nLatents):
    qMu0[k] = torch.empty((nTrials, nIndPoints, 1), dtype=torch.double)
    for r in range(nTrials):
        qMu0[k][r, :, 0] = torch.normal(mean=0.0, std=1.0,
                                        size=(nIndPoints,),
                                        dtype=torch.double)

# Initial variational covariance matrices V_k (Duncker and Sahani, 2018, V_k in
# paragraph above Eq. 4): diag_value times the identity, for every latent and
# trial. qSigma0[k] has shape (nTrials, nIndPoints, nIndPoints).
# The scaled identity is built once and in double precision. A default
# (float32) torch.eye scaled by 1e-2 would store 0.009999999776... instead of
# 0.01 on every diagonal.
diag_value = 1e-2
scaledIdentity0 = torch.eye(nIndPoints, dtype=torch.double) * diag_value
qSigma0 = [None] * nLatents
for k in range(nLatents):
    qSigma0[k] = torch.empty((nTrials, nIndPoints, nIndPoints),
                             dtype=torch.double)
    for r in range(nTrials):
        qSigma0[k][r, :, :] = scaledIdentity0

# The variational covariance is represented through its lower-triangular
# (Cholesky) factor. getSRQSigmaVecsFromSRMatrices takes a list of such
# square-root matrices and extracts their lower-triangular elements.
# NOTE(review): qSigma0 is passed here *as* the square-root factor, so the
# implied initial covariance is presumably qSigma0 @ qSigma0^T
# (= diag_value**2 * I) rather than diag_value * I -- confirm this is intended.
srQSigma0Vecs = utils.svGPFA.initUtils.getSRQSigmaVecsFromSRMatrices(srMatrices=qSigma0)

# Legendre quadrature points and weights for the integral in the first term of
# Eq. 7 (Duncker and Sahani, 2018). Every trial has the same start and end time.
trials_start_times = [trial_start_time] * nTrials
trials_end_times = [trial_end_time] * nTrials
legQuadPoints, legQuadWeights = utils.svGPFA.miscUtils.getLegQuadPointsAndWeights(
    nQuad=nQuad,
    trials_start_times=trials_start_times,
    trials_end_times=trials_end_times,
)

# Nested dictionaries of initial parameters and quadrature parameters, later
# passed to model.setInitialParamsAndData.
qHParams0 = {"C0": C0, "d0": d0}
kmsParams0 = {
    "kernelsParams0": kernelsScaledParams0,
    "inducingPointsLocs0": Z0,
}
qUParams0 = {
    "qMu0": qMu0,
    "srQSigma0Vecs": srQSigma0Vecs,
}
qKParams0 = {
    "svPosteriorOnIndPoints": qUParams0,
    "kernelsMatricesStore": kmsParams0,
}
initialParams = {
    "svPosteriorOnLatents": qKParams0,
    "svEmbedding": qHParams0,
}
quadParams = {
    "legQuadPoints": legQuadPoints,
    "legQuadWeights": legQuadWeights,
}


In [5]:
# Intentionally empty. The qMu0 initialization cell already draws every entry.
# Re-running a single loop-body statement here depended on the loop indices
# k and r leaking out of earlier cells. It also redrew only the last latent's
# last trial after initialParams had already been built from qMu0.

## 1.5 Create a model and set its initial parameters

In [6]:
# Build a point-process svGPFA model with exponential link and linear
# embedding, using the factory's "...Chol" options for kernel-matrix
# inversion and for the inducing-points covariance representation.
kernelMatrixInvMethod = stats.svGPFA.svGPFAModelFactory.kernelMatrixInvChol
indPointsCovRep = stats.svGPFA.svGPFAModelFactory.indPointsCovChol
model = stats.svGPFA.svGPFAModelFactory.SVGPFAModelFactory.buildModelPyTorch(
    conditionalDist=stats.svGPFA.svGPFAModelFactory.PointProcess,
    linkFunction=stats.svGPFA.svGPFAModelFactory.ExponentialLink,
    embeddingType=stats.svGPFA.svGPFAModelFactory.LinearEmbedding,
    kernels=kernels,
    kernelMatrixInvMethod=kernelMatrixInvMethod,
    indPointsCovRep=indPointsCovRep,
)
# Attach the spike times, the initial parameters and the quadrature used to
# compute the expected log-likelihood.
model.setInitialParamsAndData(
    measurements=spikesTimes,
    initialParams=initialParams,
    eLLCalculationParams=quadParams,
    indPointsLocsKMSRegEpsilon=indPointsLocsKMSRegEpsilon,
)

## 1.6 Set the optimization parameters

In [7]:
optimMethod = "EM"
# The E-step and every M-step share the same optimizer settings (the key
# names match torch.optim.LBFGS keyword arguments). Each sub-step gets its
# own copy of the template, so the sub-step dicts are independent objects.
subStepOptimParams = dict(
    max_iter=20,
    lr=1.0,
    tolerance_grad=1e-7,
    tolerance_change=1e-9,
    line_search_fn="strong_wolfe",
)
optimParams = dict(
    em_max_iter=30,
    estep_estimate=True,
    estep_optim_params=dict(subStepOptimParams),
    mstep_embedding_estimate=True,
    mstep_embedding_optim_params=dict(subStepOptimParams),
    mstep_kernels_estimate=True,
    mstep_kernels_optim_params=dict(subStepOptimParams),
    mstep_indpointslocs_estimate=True,
    mstep_indpointslocs_optim_params=dict(subStepOptimParams),
    verbose=True,
)

## 1.7 Maximize the Lower Bound
<span style="color:red">(Warning: with the parameters above, this step takes around 15 minutes)</span>

In [None]:
# Run EM to maximize the lower bound and report the wall-clock time it took.
svEM = stats.svGPFA.svEM.SVEM_PyTorch()
tic = time.perf_counter()
(lowerBoundHist, elapsedTimeHist, terminationInfo,
 iterationsModelParams) = svEM.maximize(model=model, optimParams=optimParams,
                                        method=optimMethod)
toc = time.perf_counter()
print(f"Elapsed time {toc - tic:0.4f} seconds")

Iteration 01, estep start: -inf
Iteration 01, estep end: -2014.637965, niter: 15, nfeval: 26
Iteration 01, mstep_embedding start: -2014.637965
Iteration 01, mstep_embedding end: 720762.646155, niter: 11, nfeval: 25
Iteration 01, mstep_kernels start: 720762.646155
Iteration 01, mstep_kernels end: 719725.429823, niter: 11, nfeval: 19
Iteration 01, mstep_indpointslocs start: 719725.429823


# 2 Plotting <a class="anchor" id="plotting"></a>

## 2.1 Imports for plotting

In [None]:
import numpy as np
import pandas as pd
import sklearn.metrics
import plot.svGPFA.plotUtilsPlotly

## 2.2 Lower bound history

In [None]:
# Lower-bound history returned by the EM optimization.
fig = plot.svGPFA.plotUtilsPlotly.getPlotLowerBoundHist(
    lowerBoundHist=lowerBoundHist)
fig.show()

## 2.3 Set neuron, latent and times to plot

In [None]:
# Neuron and latent shown in the plots below.
neuronToPlot = 0
latentToPlot = 0
# Evaluation grid for predictions: one point every 1/sampling_rate seconds,
# from trial_start_time up to (excluding) trial_end_time.
# The grid is built in double precision to match the model's double-precision
# parameters and quadrature points. A default float32 grid relied on implicit
# type promotion, and the bin width dt derived from it later was only
# float32-accurate.
sampling_rate = 1000.0  # Hz
trial_times = torch.arange(trial_start_time, trial_end_time,
                           1.0 / sampling_rate, dtype=torch.double)

## 2.4 Latents

In [None]:
# plot estimated latent across trials
testMuK, testVarK = model.predictLatents(times=trial_times)
indPointsLocs = model.getIndPointsLocs()
fig = plot.svGPFA.plotUtilsPlotly.getPlotLatentAcrossTrials(times=trial_times.numpy(), latentsMeans=testMuK, latentsSTDs=torch.sqrt(testVarK), indPointsLocs=indPointsLocs, latentToPlot=latentToPlot, xlabel="Time (msec)")
fig.show()

## 2.5 Embedding

In [None]:
# Posterior mean and variance of the embedding on the evaluation grid,
# detached from autograd and converted to NumPy. The last axis indexes
# neurons; only neuronToPlot is plotted, across all trials.
embeddingMeans, embeddingVars = (
    tensor.detach().numpy()
    for tensor in model.predictEmbedding(times=trial_times)
)
title = "Neuron {:d}".format(neuronToPlot)
fig = plot.svGPFA.plotUtilsPlotly.getPlotEmbeddingAcrossTrials(
    times=trial_times.numpy(),
    embeddingsMeans=embeddingMeans[:, :, neuronToPlot],
    embeddingsSTDs=np.sqrt(embeddingVars[:, :, neuronToPlot]),
    title=title,
)
fig.show()

## 2.6 CIFs

In [None]:
# Expected posterior conditional intensity functions (CIFs) on the evaluation
# grid, computed without autograd tracking. Then plot every trial's CIF for
# neuronToPlot.
with torch.no_grad():
    ePosCIFValues = model.computeExpectedPosteriorCIFs(times=trial_times)
fig = plot.svGPFA.plotUtilsPlotly.getPlotCIFsOneNeuronAllTrials(
    times=trial_times,
    cif_values=ePosCIFValues,
    neuron_index=neuronToPlot,
)
fig.show()

## 2.7 Embedding parameters

In [None]:
# Plot the estimated embedding parameters C and d.
# These are the optimized model parameters, which may still require
# gradients. Tensor.numpy() raises on such tensors, so detach first, as the
# embedding cell does. detach() is harmless when no gradient is attached.
estimatedC, estimatedD = model.getSVEmbeddingParams()
fig = plot.svGPFA.plotUtilsPlotly.getPlotEmbeddingParams(
    C=estimatedC.detach().numpy(), d=estimatedD.detach().numpy())
fig.show()

## 2.8 Kernels parameters

In [None]:
# Plot the estimated kernel parameters, labeling each kernel by its class name.
kernelsParams = model.getKernelsParams()
kernelsTypes = [kernel.__class__.__name__ for kernel in model.getKernels()]
fig = plot.svGPFA.plotUtilsPlotly.getPlotKernelsParams(kernelsTypes=kernelsTypes,
                                                       kernelsParams=kernelsParams)
fig.show()

# 3 Goodness of fit (GOF) <a class="anchor" id="GOF"></a>

## 3.1 Set trial and neuron for GOF assessment

In [None]:
# Trial and neuron whose spike train is used for the goodness-of-fit checks.
trialForGOF, neuronForGOF = 0, 0

## 3.2 KS time-rescaling GOF test

In [None]:
# Kolmogorov-Smirnov time-rescaling test with numerical correction for the
# selected trial and neuron. The expected posterior CIFs are indexed as
# [trial][neuron].
ksTest_gamma = 20  # number of simulations for the KS test numerical correction
with torch.no_grad():
    epmcifValues = model.computeExpectedPosteriorCIFs(times=trial_times)
cifValuesKS = epmcifValues[trialForGOF][neuronForGOF]
spikesTimesKS = spikesTimes[trialForGOF][neuronForGOF]
(diffECDFsX, diffECDFsY, estECDFx, estECDFy,
 simECDFx, simECDFy, cb) = stats.pointProcess.tests.KSTestTimeRescalingNumericalCorrection(
    spikesTimes=spikesTimesKS, cifTimes=trial_times, cifValues=cifValuesKS,
    gamma=ksTest_gamma)
title = "Trial {:d}, Neuron {:d}".format(trialForGOF, neuronForGOF)
fig = plot.svGPFA.plotUtilsPlotly.getPlotResKSTestTimeRescalingNumericalCorrection(
    diffECDFsX=diffECDFsX, diffECDFsY=diffECDFsY,
    estECDFx=estECDFx, estECDFy=estECDFy,
    simECDFx=simECDFx, simECDFy=simECDFy,
    cb=cb, title=title)
fig.show()

## 3.3 ROC predictive analysis

In [None]:
# ROC analysis: how well the per-bin spike probability predicts whether the
# selected neuron spiked in each time bin of the selected trial.
# pk approximates the spike probability in each bin as CIF * bin width.
dt = (trial_times[1] - trial_times[0]).item()
pk = cifValuesKS.detach().numpy() * dt
# One bin per CIF sample, covering [trial_start_time, trial_end_time].
# pandas intervals are right-closed, i.e. (a, b].
bins = pd.interval_range(start=trial_start_time, end=trial_end_time,
                         periods=len(pk))
cutRes, _ = pd.cut(spikesTimesKS.tolist(), bins=bins, retbins=True)
# value_counts() on the resulting Categorical gives per-bin spike counts in
# bin order. Collapse them to binary labels ("at least one spike"). With raw
# counts and pos_label=1, roc_curve would silently treat any bin holding 2+
# spikes as a negative.
Y = (cutRes.value_counts().values > 0).astype(int)
fpr, tpr, thresholds = sklearn.metrics.roc_curve(Y, pk, pos_label=1)
roc_auc = sklearn.metrics.auc(fpr, tpr)
title = "Trial {:d}, Neuron {:d}".format(trialForGOF, neuronForGOF)
fig = plot.svGPFA.plotUtilsPlotly.getPlotResROCAnalysis(fpr=fpr, tpr=tpr,
                                                       auc=roc_auc, title=title)
fig.show()