In [1]:
using Pkg
using SSM
using Random
using Distributions
using LinearAlgebra
using Plots
using ForwardDiff
using Optim

In [2]:
Random.seed!(123)

function toy_PoissonLDS()
    T = 100
    # create a PLDS model
    x0 = [1.0, -1.0]
    p0 = Matrix(Diagonal([0.001, 0.001]))
    A = [cos(0.1) -sin(0.1); sin(0.1) cos(0.1)]
    Q = Matrix(Diagonal([0.001, 0.001]))
    C = [0.05 0.05; 0.05 0.01; 0.1 0.01]
    log_d = log.([0.05, 0.05, 1.])
    D = Matrix(Diagonal([0., 0., 0.]))
    b = ones(T, 2) * 0.01

    plds = PoissonLDS(A=A, C=C, Q=Q, D=D, b=b, log_d=log_d, x0=x0, p0=p0, refractory_period=1, obs_dim=3, latent_dim=2)
    # sample data
    x, y = SSM.sample(plds, T, 3)
    return plds, x, y
end

plds, x, y = toy_PoissonLDS()

# model = SSM.PoissonPCA(;latent_dim=2, obs_dim=3)
# lls = fit!(model, y[1, :, :], 20)

b = ones(100, 2) * 0.005;
plds.A = Matrix{Float64}([0.1 0.; 0. 0.1])
plds.Q = Matrix{Float64}([0.1 0; 0 0.1])
plds.x0 = [0.0, -0.0]
plds.p0 = Matrix{Float64}([0.1 0; 0 0.1])
plds.b = b;

In [24]:
E_z, E_zz, E_zz_prev, x_sm, p_sm = SSM.E_Step(plds, y)

function Q_initial_obs(x0::Vector{<:Real}, sqrt_p0::Matrix{<:Real}, E_z::Array{<:Real}, E_zz::Array{<:Real})
    # reparametrize p0
    p0 = sqrt_p0 * sqrt_p0'
    # Compute Q
    Q_val = 0.0
    trials = size(E_z, 1)
    for k in 1:trials
        Q_val += -0.5 * (logdet(p0) + tr(inv(p0) * (E_zz[k, 1, :, :] - 2*(E_z[k, 1, :] * x0') + (x0 * x0'))))
    end
    return Q_val
end

function Q_state_model(A::Matrix{<:Real}, sqrt_Q::Matrix{<:Real}, E_zz::Array{<:Real}, E_zz_prev::Array{<:Real})
    # reparametrize Q
    Q = sqrt_Q * sqrt_Q'
    Q_inv = pinv(Q)
    # Compute Q
    Q_val = 0.0
    trials = size(E_zz, 1)
    time_steps = size(E_zz, 2)
    for k in 1:trials
        for t in 2:time_steps
            term1 = E_zz[k, t, :, :]
            term2 = A * E_zz_prev[k, t, :, :]'
            term3 = E_zz_prev[k, t, :, :] * A'
            term4 = A * E_zz[k, t-1, :, :] * A'
            Q_val += -0.5 * (logdet(Q) + tr(Q_inv * (term1 - term2 - term3 + term4)))
        end
    end
    return Q_val
end

function Q_observation_model(C::Matrix{<:Real}, D::Matrix{<:Real}, log_d::Vector{<:Real}, E_z::Array{<:Real}, E_zz::Array{<:Real}, y::Array{<:Real})
    # Re-parametrize log_d
    d = exp.(log_d)
    # Compute Q
    Q_val = 0.0
    trials = size(E_z, 1)
    time_steps = size(E_z, 2)
    # sum over trials
    for k in 1:trials
        spikes = SSM.countspikes(y[k, :, :])
        # sum over time-points
        for t in 1:time_steps
            # linear term
            term_1 = y[k, t, :]' * (C * E_z[k, t, :] + D*spikes[t, :] + d)
            # first part of quadratic term (sum over neurons)
            term_2 = sum(exp.(C * E_z[k, t, :] + D*spikes[t, :] + d))
            # second part of quadratic term (sum over neurons)
            term_3 = 0.0
            for i in axes(C, 1)
                term_3 += 0.5 * C[i, :]' * E_zz[k, t, :, :] * C[i, :]
            end
            Q_val += term_1 - term_2 + term_3
        end
    end
    return Q_val
end

Q_observation_model(plds.C, plds.D, plds.log_d, E_z, E_zz, y)

-571.0573778719382

In [27]:
# param_update = Matrix(cholesky((dropdims(sum(p_sm[:, 1, :, :], dims=1)./3, dims=1) + dropdims(sum(p_sm[:, 1, :, :], dims=1)./3, dims=1)') / 2))
# p0_l = Matrix(cholesky(plds.p0).L)
grad_p0 = ForwardDiff.gradient(p0 -> Q_initial_obs(plds.x0, p0, E_z, E_zz), result_p0.minimizer)

2×2 Matrix{Float64}:
  5.82607e-9   -4.28744e-11
 -4.27871e-11   5.6684e-9

In [23]:
typeof(result_p0.minimizer)

Matrix{Float64}[90m (alias for [39m[90mArray{Float64, 2}[39m[90m)[39m

In [16]:
(dropdims(sum(p_sm[:, 1, :, :], dims=1)./3, dims=1) + dropdims(sum(p_sm[:, 1, :, :], dims=1)./3, dims=1)') / 2

2×2 Matrix{Float64}:
  0.0996733  -5.9114e-5
 -5.9114e-5   0.0999697

In [19]:
# optimize x0 and p0
opt_x0 = x0 -> -Q_initial_obs(x0, plds.p0, E_z, E_zz)
opt_p0 = p0 -> -Q_initial_obs(plds.x0, p0, E_z, E_zz)

result_x0 = optimize(opt_x0, plds.x0, LBFGS())
result_p0 = optimize(opt_p0, plds.p0, LBFGS())

result_x0.minimizer
result_p0.minimizer * result_p0.minimizer'

# @assert isapprox(result_x0.minimizer, SSM.update_initial_state_mean!(plds, E_z))
# @assert isapprox(result_p0.minimizer * result_p0.minimizer', SSM.update_initial_state_covariance!(plds, E_zz, E_z), atol=1e-4)


2×2 Matrix{Float64}:
  0.0997496   -4.96169e-5
 -4.96169e-5   0.0999744

In [22]:
result_p0.minimizer

2×2 Matrix{Float64}:
 -0.315832     7.85007e-5
  7.85102e-5  -0.316187

In [55]:
# optimize A and Q
opt_A = A -> -Q_state(A, plds.Q, E_zz, E_zz_prev)
result_A = optimize(opt_A, plds.A, LBFGS())

@assert isapprox(result_A.minimizer, SSM.update_A_plds!(plds, E_zz, E_zz_prev), atol=1e-6)

# if pass assertion, update the model
plds.A = result_A.minimizer

# optimize Q now
opt_Q = Q -> -Q_state(plds.A, Q, E_zz, E_zz_prev)
result_Q = optimize(opt_Q, plds.Q, LBFGS())

@assert isapprox(result_Q.minimizer * result_Q.minimizer', SSM.update_Q_plds!(plds, E_zz, E_zz_prev))

In [11]:
function countspikes(y::Matrix{<:Real}, window::Int=1)
    # Get size of the observation matrix
    T, D = size(y)
    # Initialize the spike-count matrix
    s = zeros(T, D)
    # Compute the cumulative sum of the observation matrix along the first dimension (time)
    cumsum_y = cumsum(y, dims=1)
    
    # Loop over time points from 2 to T
    for t in 2:T
        if t - window <= 1
            # If the time window is less than or equal to 1, use the cumulative sum directly
            s[t, :] = cumsum_y[t-1, :]
        else
            # Otherwise, calculate the sum of the window by subtracting cumulative sums
            s[t, :] = cumsum_y[t-1, :] .- cumsum_y[t-window-1, :]
        end
    end
    
    return s
end

function logposterior_nonthreaded(x::AbstractMatrix{<:Real}, plds::PoissonLDS, y::Matrix{<:Real}) 
    # Re-parameterize log_d
    d = exp.(plds.log_d)
    # Calculate the log-posterior
    T = size(y, 1)
    # Get an array of prior spikes
    s = countspikes(y, plds.refractory_period)
    # calculate the first term
    pygivenx = 0.0
    for t in 1:T
        pygivenx += (y[t, :]' * ((plds.C * x[t, :]) + (plds.D * s[t, :]) + d)) - sum(exp.((plds.C * x[t, :]) + (plds.D * s[t, :]) + d))
    end
    # calculate the second term
    px1 = -0.5 * (x[1, :] - plds.x0)' * pinv(plds.p0) * (x[1, :] - plds.x0)
    # calculate the last term
    pxtgivenxt1 = 0.0
    for t in 2:T
        pxtgivenxt1 += -0.5 * (x[t, :] - ((plds.A * x[t-1, :]) + plds.b[t, :]))' * pinv(plds.Q) * (x[t, :] - ((plds.A * x[t-1, :]) + plds.b[t, :])) 
    end
    # sum the terms
    return pygivenx + px1 + pxtgivenxt1
end
@btime logposterior(zeros(100, 2), plds, y[1, :, :])
# obj = x -> -logposterior(x, plds, y[1, :, :])

# using ForwardDiff 
# ForwardDiff.gradient(obj, zeros(100, 2))

  332.400 μs (4893 allocations: 277.09 KiB)


-203.40715212110916

In [14]:
using Base.Threads
using BenchmarkTools

function logposterior(x::AbstractMatrix{<:Real}, plds::PoissonLDS, y::Matrix{<:Real})
    # Convert the log firing rate to firing rate
    d = exp.(plds.log_d)
    T = size(y, 1)
    s = countspikes(y, plds.refractory_period)
    # Get the number of time steps
    pygivenx = zeros(T)
    # Calculate p(yₜ|xₜ)
    @threads for t in 1:T
        temp = (plds.C * x[t, :] .+ plds.D * s[t, :] .+ d)
        pygivenx[t] = (y[t, :]' * temp) - sum(exp.(temp))
    end
    pygivenx_sum = sum(pygivenx)
    # Calculate p(x₁)
    px1 = -0.5 * (x[1, :] .- plds.x0)' * pinv(plds.p0) * (x[1, :] .- plds.x0)
    # Calculate p(xₜ|xₜ₋₁)
    pxtgivenxt1 = zeros(T-1)
    @threads for t in 2:T
        temp = (x[t, :] .- (plds.A * x[t-1, :] .+ plds.b[t, :]))
        pxtgivenxt1[t-1] = -0.5 * temp' * pinv(plds.Q) * temp
    end
    pxtgivenxt1_sum = sum(pxtgivenxt1)
    # Return the log-posterior
    return pygivenx_sum + px1 + pxtgivenxt1_sum
end

Zygote.CompileError: Compiling Tuple{typeof(Base._wait), Task}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations


In [17]:
H, main, super, sub = SSM.Hessian(x[1, :, :], plds, y[1, :, :])
p_sm, ptt1 = SSM.block_tridiagonal_inverse(-sub, -main, -super)

([0.09964753619401641 -6.171103550769067e-5; 0.10063365806081821 -6.414153286612634e-5; … ; 0.10069165946854451 -5.933540801424813e-5; 0.10069680914924041 -5.8432329248563064e-5;;; -6.171103550768466e-5 0.09996948579925777; -6.414153286613473e-5 0.10096827160413231; … ; -5.933540801424887e-5 0.10097872615981948; -5.8432329248559595e-5 0.10097917916072414], [0.009929308717424798 -1.2375836282808027e-5; 0.01002731642289832 -1.2729857659660688e-5; … ; 0.010037861736130216 -1.1761267247855036e-5; 0.010038559471751356 -1.1640917676080294e-5;;; -1.2375837256344493e-5 0.009993872511605752; -1.2729879043070484e-5 0.010093691994177598; … ; -1.1761281080845925e-5 0.010094786682380183; -1.1640936728077284e-5 0.010094846290284554])

In [30]:
# Wrapper function for the objective (negative log-posterior)
function objective(x_vec::Vector{T}, plds::PoissonLDS, y::Matrix{T}) where T<:Real
    T_steps, latent_dim = size(y, 1), plds.latent_dim
    x = interleave_reshape(x_vec, T_steps, latent_dim)
    return -logposterior(x, plds, y)
end

# Wrapper function for the gradient
function gradient!(G::Vector{T}, x_vec::Vector{T}, plds::PoissonLDS, y::Matrix{T}) where T<:Real
    T_steps, latent_dim = size(y, 1), plds.latent_dim
    x = interleave_reshape(x_vec, T_steps, latent_dim)
    grad = SSM.Gradient(x, plds, y)
    G .= vec(grad)
end

# Wrapper function for the Hessian
function hessian!(H::Matrix{T}, x_vec::Vector{T}, plds::PoissonLDS, y::Matrix{T}) where T<:Real
    T_steps, latent_dim = size(y, 1), plds.latent_dim
    x = interleave_reshape(x_vec, T_steps, latent_dim)
    hess, _, _, _ = SSM.Hessian(x, plds, y)
    H .= hess
end

# Function to run the optimization
function optimize_latent_states(initial_x::Matrix{T}, plds::PoissonLDS, y::Matrix{T}) where T<:Real
    T_steps, latent_dim = size(initial_x)
    
    # Create wrapper functions with fixed plds and y
    f(x) = objective(x, plds, y)
    g!(G, x) = gradient!(G, x, plds, y)
    h!(H, x) = hessian!(H, x, plds, y)
    
    # Flatten the initial guess
    initial_x_vec = vec(permutedims(initial_x))
    
    # Set up the optimization problem
    result = optimize(f, g!, h!, initial_x_vec)
    
    # Reshape the result back to a matrix
    optimal_x = interleave_reshape(result.minimizer, T_steps, latent_dim)
    
    return optimal_x, result
end

optimize_latent_states(zeros(100, 2), plds, y[1, :, :]) 

([-0.01727703627004362 -0.02090831686249808; -0.04152467064665573 0.01935306490561603; … ; -0.0007752634189086974 0.001402238832428847; -0.0032608823985584316 -0.006620031372885339],  * Status: failure (reached maximum number of iterations)

 * Candidate solution
    Final objective value:     2.036107e+02

 * Found with
    Algorithm:     Newton's Method

 * Convergence measures
    |x - x'|               = 2.29e-05 ≰ 0.0e+00
    |x - x'|/|x'|          = 5.52e-04 ≰ 0.0e+00
    |f(x) - f(x')|         = 2.04e-04 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.00e-06 ≰ 0.0e+00
    |g(x)|                 = 6.05e-01 ≰ 1.0e-08

 * Work counters
    Seconds run:   97  (vs limit Inf)
    Iterations:    1000
    f(x) calls:    64046
    ∇f(x) calls:   64046
    ∇²f(x) calls:  1001
)

In [27]:
x_sm, p_sm, ptt1 = SSM.directsmooth(plds, y[1, :, :])

([0.013369048145516576 0.0007830711933090716; 0.01742122896861685 0.003249000830713299; … ; 0.015695964843495183 0.005740105402084179; 0.013825347030947756 0.006538013121058254], [0.0996730428873636 -5.9159419522603635e-5; 0.10066312523569919 -6.095100068439873e-5; … ; 0.10067315667024221 -6.09683215243662e-5; 0.10067647846717766 -6.036991281801671e-5;;; -5.915941952260452e-5 0.09996966659552509; -6.095100068438583e-5 0.10096874712805827; … ; -6.096832152437012e-5 0.10097883664292145; -6.036991281802365e-5 0.10097914348996768], [0.0 0.0; 0.009934705099655475 -1.1813006990574534e-5; … ; 0.010034513924107511 -1.2034454505465084e-5; 0.010034726686880257 -1.199274956016177e-5;;; 0.0 0.0; -1.181300610629099e-5 0.009993936682818378; … ; -1.203444189415753e-5 0.010094830792888657; -1.1992750437381327e-5 0.010094853896451077])

In [33]:
maximum(SSM.Gradient(x_sm, plds, y[1, :, :]))

6.245004513516506e-17