In [1]:
using Pkg
using SSM
using Random
using Distributions
using LinearAlgebra
using Plots
using ForwardDiff
using Optim

In [2]:
"""
    toy_PoissonLDS(; T=100, trials=3)

Build a small `PoissonLDS` (2 latent dimensions, 3 observed dimensions) with fixed
parameters and draw data from it with `SSM.sample`.

# Keywords
- `T::Integer=100`: number of time steps; also the number of rows of the bias matrix `b`.
- `trials::Integer=3`: value passed as the third argument of `SSM.sample` (presumably the
  number of trials — confirm against `SSM.sample`).

# Returns
`(plds, x, y)`: the constructed model and the two values returned by `SSM.sample`.
"""
function toy_PoissonLDS(; T::Integer=100, trials::Integer=3)
    # initial state mean and covariance
    x0 = [1.0, -1.0]
    p0 = Matrix(Diagonal([0.1, 0.1]))
    # latent transition: a 2-D rotation by 0.1 rad, with diagonal Q
    A = [cos(0.1) -sin(0.1); sin(0.1) cos(0.1)]
    Q = Matrix(Diagonal([0.1, 0.1]))
    # observation parameters; D and b are all zeros
    C = [0.5 0.5; 0.5 0.1; 0.9 0.1]
    log_d = log.([0.1, 0.05, 1])
    D = zeros(3, 3)
    b = zeros(T, 2)

    plds = PoissonLDS(A=A, C=C, Q=Q, D=D, b=b, log_d=log_d, x0=x0, p0=p0, refractory_period=1, obs_dim=3, latent_dim=2)
    # sample data
    x, y = SSM.sample(plds, T, trials)
    return plds, x, y
end

# Build the toy model and keep the sampled pair (x, y) returned alongside it.
plds, x, y = toy_PoissonLDS()
# Set all six entries of fit_bool to true.
plds.fit_bool = [true, true, true, true, true, true]

# NOTE(review): this global `b` is not referenced again in the visible cells.
b = ones(100, 2) * 0.000;
# Overwrite the generating parameters with a different starting point: A is unchanged,
# Q and p0 become the identity and x0 becomes zero.
plds.A = [cos(0.1) -sin(0.1); sin(0.1) cos(0.1)]
plds.Q = Matrix(Diagonal([1., 1.]))
plds.x0 = [0.0, -0.0]
plds.p0 = Matrix{Float64}([1. 0; 0 1.])

In [3]:
# Fit the model to the sampled observations. The third argument (10) is presumably an
# iteration cap — TODO confirm against SSM.fit!.
# NOTE(review): the run recorded below aborted with
# `MethodError: no method matching iterate(::Nothing)` inside a threaded loop
# (SSM src/LDS.jl:1084), i.e. a tuple destructuring received `nothing`.
fit!(plds, y, 10)



CompositeException: TaskFailedException

    nested task error: MethodError: no method matching iterate(::Nothing)
    
    Closest candidates are:
      iterate(!Matched::Union{LinRange, StepRangeLen})
       @ Base range.jl:880
      iterate(!Matched::Union{LinRange, StepRangeLen}, !Matched::Integer)
       @ Base range.jl:880
      iterate(!Matched::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
       @ Base dict.jl:698
      ...
    
    Stacktrace:
     [1] indexed_iterate(I::Nothing, i::Int64)
       @ Base .\tuple.jl:91
     [2] macro expansion
       @ c:\Users\ryansenne\Documents\GitHub\ssm_julia\src\LDS.jl:1084 [inlined]
     [3] (::SSM.var"#200#threadsfor_fun#47"{SSM.var"#200#threadsfor_fun#46#48"{PoissonLDS, Array{Float64, 3}, Array{Float64, 4}, Array{Float64, 4}, Array{Float64, 3}, UnitRange{Int64}}})(tid::Int64; onethread::Bool)
       @ SSM .\threadingconstructs.jl:200
     [4] #200#threadsfor_fun
       @ .\threadingconstructs.jl:167 [inlined]
     [5] (::Base.Threads.var"#1#2"{SSM.var"#200#threadsfor_fun#47"{SSM.var"#200#threadsfor_fun#46#48"{PoissonLDS, Array{Float64, 3}, Array{Float64, 4}, Array{Float64, 4}, Array{Float64, 3}, UnitRange{Int64}}}, Int64})()
       @ Base.Threads .\threadingconstructs.jl:139

In [3]:
# Noise-free rollout of the latent dynamics: x_pred_1[t, :] = A * x_pred_1[t-1, :],
# started from plds.x0. The hard-coded 100 matches T in toy_PoissonLDS.
x_pred_1 = zeros(100, 2)
x_pred_1[1, :] = plds.x0

for t in 2:100
    x_pred_1[t, :] = plds.A * x_pred_1[t-1, :]
end

# Evaluate SSM.Hessian along that trajectory using the y[3, :, :] slice (elsewhere in this
# notebook the first axis of y is looped over as the trial index). Only the first of the
# four returned values is kept.
H, _, _, _ = SSM.Hessian(x_pred_1, plds, y[3, :, :])

([-4.7409187846647445 -0.5735016488990272 … 0.0 0.0; -0.5735016488990272 -2.3139882587672624 … 0.0 0.0; … ; 0.0 0.0 … -3.740918784664745 -0.5735016488990272; 0.0 0.0 … -0.5735016488990272 -1.3139882587672627], [[-4.7409187846647445 -0.5735016488990272; -0.5735016488990272 -2.3139882587672624], [-4.7409187846647445 -0.5735016488990272; -0.5735016488990272 -2.3139882587672624], [-4.7409187846647445 -0.5735016488990272; -0.5735016488990272 -2.3139882587672624], [-4.7409187846647445 -0.5735016488990272; -0.5735016488990272 -2.3139882587672624], [-4.7409187846647445 -0.5735016488990272; -0.5735016488990272 -2.3139882587672624], [-4.7409187846647445 -0.5735016488990272; -0.5735016488990272 -2.3139882587672624], [-4.7409187846647445 -0.5735016488990272; -0.5735016488990272 -2.3139882587672624], [-4.7409187846647445 -0.5735016488990272; -0.5735016488990272 -2.3139882587672624], [-4.7409187846647445 -0.5735016488990272; -0.5735016488990272 -2.3139882587672624], [-4.7409187846647445 -0.573501648

In [10]:
# true when every entry of H is finite (predicate form, no temporary Bool array)
all(isfinite, H)

true

In [13]:
plds

PoissonLDS([0.9950041652780258 -0.09983341664682815; 0.09983341664682815 0.9950041652780258], [0.5 0.5; 0.5 0.1; 0.9 0.1], [1.0 0.0; 0.0 1.0], [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], [-2.3025850929940455, -2.995732273553991, 0.0], [0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0], [0.0, -0.0], [1.0 0.0; 0.0 1.0], 1, 3, 2, Bool[1, 1, 1, 1, 1, 0])

In [15]:
# Run SSM's E-step on all of y. The functions in this cell index the results as
# E_z[k, t, :], E_zz[k, t, :, :] and E_zz_prev[k, t, :, :]; x_sm and p_sm are not used in
# the visible cells.
E_z, E_zz, E_zz_prev, x_sm, p_sm = SSM.E_Step(plds, y)

"""
    Q_initial_obs(x0, sqrt_p0, E_z, E_zz)

Evaluate the initial-state term of the objective, summed over the first axis of `E_z`:

    sum_k -0.5 * (logdet(p0) + tr(inv(p0) * (E_zz[k, 1, :, :] - 2 * E_z[k, 1, :] * x0' + x0 * x0')))

where `p0 = sqrt_p0 * sqrt_p0'`, so any full-rank `sqrt_p0` yields a positive-definite `p0`.

# Arguments
- `x0`: initial state mean.
- `sqrt_p0`: square-root factor of the initial covariance.
- `E_z`: array indexed as `E_z[k, 1, :]`.
- `E_zz`: array indexed as `E_zz[k, 1, :, :]`.

# Returns
The scalar objective value; `0` when `E_z` has no entries along its first axis.
"""
function Q_initial_obs(x0::Vector{<:Real}, sqrt_p0::Matrix{<:Real}, E_z::Array{<:Real}, E_zz::Array{<:Real})
    # reparametrize p0
    p0 = sqrt_p0 * sqrt_p0'
    # none of these depend on the trial index, so compute them once
    logdet_p0 = logdet(p0)
    p0_inv = inv(p0)
    x0_outer = x0 * x0'
    # start from a zero of the promoted element type so the accumulator keeps one type
    # (e.g. ForwardDiff duals) for the whole loop
    Q_val = zero(promote_type(typeof(logdet_p0), eltype(p0_inv), eltype(x0), eltype(E_z), eltype(E_zz)))
    for k in axes(E_z, 1)
        Q_val += -0.5 * (logdet_p0 + tr(p0_inv * (E_zz[k, 1, :, :] - 2 * (E_z[k, 1, :] * x0') + x0_outer)))
    end
    return Q_val
end

"""
    Q_state_model(A, sqrt_Q, E_zz, E_zz_prev)

Evaluate the state-transition term of the objective, summed over the first axis of `E_zz`
and over time steps `t = 2, …, size(E_zz, 2)`:

    -0.5 * (logdet(Q) + tr(pinv(Q) * (E_zz[k,t] - A*E_zz_prev[k,t]' - E_zz_prev[k,t]*A' + A*E_zz[k,t-1]*A')))

where `Q = sqrt_Q * sqrt_Q'`.

# Arguments
- `A`: transition matrix.
- `sqrt_Q`: square-root factor of `Q`.
- `E_zz`: array indexed as `E_zz[k, t, :, :]`.
- `E_zz_prev`: array indexed as `E_zz_prev[k, t, :, :]`.

# Returns
The scalar objective value; `0` when there are fewer than two time steps.
"""
function Q_state_model(A::Matrix{<:Real}, sqrt_Q::Matrix{<:Real}, E_zz::Array{<:Real}, E_zz_prev::Array{<:Real})
    # reparametrize Q
    Q = sqrt_Q * sqrt_Q'
    Q_inv = pinv(Q)
    # loop-invariant, so evaluated once instead of once per (trial, time step)
    logdet_Q = logdet(Q)
    # start from a zero of the promoted element type so the accumulator keeps one type
    Q_val = zero(promote_type(typeof(logdet_Q), eltype(Q_inv), eltype(A), eltype(E_zz), eltype(E_zz_prev)))
    for k in axes(E_zz, 1), t in 2:size(E_zz, 2)
        # slice taken once and reused for both cross terms
        cross = E_zz_prev[k, t, :, :]
        term1 = E_zz[k, t, :, :]
        term2 = A * cross'
        term3 = cross * A'
        term4 = A * E_zz[k, t-1, :, :] * A'
        Q_val += -0.5 * (logdet_Q + tr(Q_inv * (term1 - term2 - term3 + term4)))
    end
    return Q_val
end

"""
    Q_observation_model(C, D, log_d, E_z, E_zz, y)

Evaluate the observation-model term of the objective. For every index `k` along the first
axis of `E_z` and every time step `t` it accumulates

    y[k,t,:]' * eta - sum(exp.(eta)) + sum_i 0.5 * C[i,:]' * E_zz[k,t,:,:] * C[i,:]

with `eta = C * E_z[k,t,:] + D * spikes[t,:] + exp.(log_d)` and
`spikes = SSM.countspikes(y[k, :, :])`.

# Returns
The accumulated scalar.
"""
function Q_observation_model(C::Matrix{<:Real}, D::Matrix{<:Real}, log_d::Vector{<:Real}, E_z::Array{<:Real}, E_zz::Array{<:Real}, y::Array{<:Real})
    # map log_d back through exp
    offset = exp.(log_d)
    n_trials, n_steps = size(E_z, 1), size(E_z, 2)
    total = 0.0
    for trial in 1:n_trials
        history = SSM.countspikes(y[trial, :, :])
        for step in 1:n_steps
            # linear predictor shared by the linear and the exponential term
            eta = C * E_z[trial, step, :] + D*history[step, :] + offset
            second_moment = E_zz[trial, step, :, :]
            # quadratic form, one contribution per row of C
            quad = 0.0
            for row in axes(C, 1)
                quad += 0.5 * C[row, :]' * second_moment * C[row, :]
            end
            total += y[trial, step, :]' * eta - sum(exp.(eta)) + quad
        end
    end
    return total
end

# Evaluate the observation-model term at the current model parameters.
Q_observation_model(plds.C, plds.D, plds.log_d, E_z, E_zz, y)

-571.0573778719382

In [25]:
# optimize x0 and p0
# Each closure negates Q_initial_obs so that Optim's minimizer maximizes it, varying one
# argument while the other is held at the current model value.
# NOTE(review): opt_x0 passes plds.p0 in the sqrt_p0 slot, so the covariance actually used
# is plds.p0 * plds.p0' — identical only when p0 is the identity.
opt_x0 = x0 -> -Q_initial_obs(x0, plds.p0, E_z, E_zz)
opt_p0 = p0 -> -Q_initial_obs(plds.x0, p0, E_z, E_zz)

# No gradient is supplied to either call; plds.x0 / plds.p0 serve as starting points.
result_x0 = optimize(opt_x0, plds.x0, LBFGS())
result_p0 = optimize(opt_p0, plds.p0, LBFGS())

# Only the last expression of the cell is displayed: the covariance rebuilt from the
# optimized square-root factor.
result_x0.minimizer
result_p0.minimizer * result_p0.minimizer'

# Disabled comparisons against SSM's own update functions.
# @assert isapprox(result_x0.minimizer, SSM.update_initial_state_mean!(plds, E_z))
# @assert isapprox(result_p0.minimizer * result_p0.minimizer', SSM.update_initial_state_covariance!(plds, E_zz, E_z), atol=1e-4)


2×2 Matrix{Float64}:
  0.0997496  -4.9617e-5
 -4.9617e-5   0.0999744

In [26]:
# Average of E_zz[k, 1, :, :] over the first axis; the divisor is read from the array
# instead of the literal 3 so it stays correct for any number of trials.
dropdims(sum(E_zz[:, 1, :, :], dims=1), dims=1) / size(E_zz, 1)

2×2 Matrix{Float64}:
  0.0997496   -4.96169e-5
 -4.96169e-5   0.0999744

In [17]:
# Gradient of Q_initial_obs with respect to the square-root factor, evaluated at the
# optimizer's solution; entries close to zero indicate a stationary point.
grad_p0 = ForwardDiff.gradient(p0 -> Q_initial_obs(plds.x0, p0, E_z, E_zz), result_p0.minimizer)

2×2 Matrix{Float64}:
 5.84571e-9   2.59299e-11
 2.58909e-11  5.88439e-9

In [20]:
# Call SSM's own initial-state covariance update (the `!` suggests it mutates plds —
# confirm against SSM).
SSM.update_initial_state_covariance!(plds, E_zz, E_z)

"""
    Q_initial_obs_NO(x0, p0, E_z, E_zz)

Evaluate the initial-state term of the objective with `p0` used directly (no square-root
reparametrization), summed over the first axis of `E_z`:

    sum_k -0.5 * (logdet(p0) + tr(inv(p0) * (E_zz[k, 1, :, :] - 2 * E_z[k, 1, :] * x0' + x0 * x0')))

# Arguments
- `x0`: initial state mean.
- `p0`: initial state covariance.
- `E_z`: array indexed as `E_z[k, 1, :]`.
- `E_zz`: array indexed as `E_zz[k, 1, :, :]`.

# Returns
The scalar objective value; `0` when `E_z` has no entries along its first axis.
"""
function Q_initial_obs_NO(x0::Vector{<:Real}, p0::Matrix{<:Real}, E_z::Array{<:Real}, E_zz::Array{<:Real})
    # none of these depend on the trial index, so compute them once
    logdet_p0 = logdet(p0)
    p0_inv = inv(p0)
    x0_outer = x0 * x0'
    # start from a zero of the promoted element type so the accumulator keeps one type
    # (e.g. ForwardDiff duals) for the whole loop
    Q_val = zero(promote_type(typeof(logdet_p0), eltype(p0_inv), eltype(x0), eltype(E_z), eltype(E_zz)))
    for k in axes(E_z, 1)
        Q_val += -0.5 * (logdet_p0 + tr(p0_inv * (E_zz[k, 1, :, :] - 2 * (E_z[k, 1, :] * x0') + x0_outer)))
    end
    return Q_val
end

# Gradient of the un-reparametrized objective with respect to p0 itself, evaluated at the
# current plds.p0.
grad = ForwardDiff.gradient(p0 -> Q_initial_obs_NO(plds.x0, p0, E_z, E_zz), plds.p0)

2×2 Matrix{Float64}:
 0.0115105   0.00143689
 0.00143689  0.00070821

In [29]:
# optimize A and Q
# Maximize Q_state_model over A by minimizing its negation; no gradient is supplied.
# NOTE(review): plds.Q is passed in the sqrt_Q slot, so the covariance used inside
# Q_state_model is plds.Q * plds.Q' — identical only when Q is the identity.
opt_A = A -> -Q_state_model(A, plds.Q, E_zz, E_zz_prev)
result_A = optimize(opt_A, plds.A, LBFGS())

# The numerical optimum must agree with the value returned by SSM.update_A_plds!.
@assert isapprox(result_A.minimizer, SSM.update_A_plds!(plds, E_zz, E_zz_prev), atol=1e-6)

# # if pass assertion, update the model
# plds.A = result_A.minimizer

# # optimize Q now (disabled; uses Q_state_model, the function defined in this notebook)
# opt_Q = Q -> -Q_state_model(plds.A, Q, E_zz, E_zz_prev)
# result_Q = optimize(opt_Q, plds.Q, LBFGS())

# @assert isapprox(result_Q.minimizer * result_Q.minimizer', SSM.update_Q_plds!(plds, E_zz, E_zz_prev))

In [28]:
"""
    objective(x_vec, plds, y)

Return the negated value of `logposterior` for the latent trajectory stored in the flat
vector `x_vec`, so that a minimizer maximizes the log-posterior.

`x_vec` is reshaped with `interleave_reshape(x_vec, size(y, 1), plds.latent_dim)` first.
NOTE(review): `interleave_reshape` and `logposterior` are called unqualified; they must be
exported by SSM or defined elsewhere in the notebook.
"""
function objective(x_vec, plds, y)
    n_steps = size(y, 1)
    latent = interleave_reshape(x_vec, n_steps, plds.latent_dim)
    return -logposterior(latent, plds, y)
end

"""
    grad!(G, x_vec, plds, y)

Fill `G` in place with the gradient of `objective` at `x_vec` and return `G`.

`objective` is the negated log-posterior, so the gradient written here is the negation of
`SSM.Gradient(x, plds, y)`. NOTE(review): this takes `SSM.Gradient` to be the gradient of
the log-posterior itself, consistent with `SSM.Hessian` — confirm against SSM.
"""
function grad!(G, x_vec, plds, y)
    T, d = size(y, 1), plds.latent_dim
    x = interleave_reshape(x_vec, T, d)
    grad_matrix = SSM.Gradient(x, plds, y)
    # The transpose flattens the gradient in the same order as vec(permutedims(initial_x))
    # in optimize_latent_states; the minus sign matches the negated objective.
    G[:] = -vec(grad_matrix')
    return G
end

"""
    hess!(H, x_vec, plds, y)

Fill `H` in place with the Hessian of `objective` at `x_vec` and return `H`.

`SSM.Hessian` yields the Hessian of the log-posterior (its recorded output in this notebook
has a negative diagonal), while `objective` is the negated log-posterior, so the values
are negated before being stored. Only the first of the four values returned by
`SSM.Hessian` is used.
"""
function hess!(H, x_vec, plds, y)
    T, d = size(y, 1), plds.latent_dim
    x = interleave_reshape(x_vec, T, d)
    hessian, _, _, _ = SSM.Hessian(x, plds, y)
    # linear-index copy: works for any H whose length equals (T*d)^2
    H[:] = -vec(hessian)
    return H
end

"""
    optimize_latent_states(initial_x, plds, y)

Minimize `objective` over the latent trajectory with Newton's method, starting from
`initial_x`.

# Arguments
- `initial_x`: matrix whose `size` gives `(T, d)`; it is flattened with
  `vec(permutedims(initial_x))` before optimization.
- `plds`, `y`: forwarded unchanged to `objective`, `grad!` and `hess!`.

# Returns
`(optimized_x, result)`: the minimizer reshaped by `interleave_reshape(_, T, d)` and the
Optim result object.
"""
function optimize_latent_states(initial_x, plds, y)
    T, d = size(initial_x)
    x_vec = vec(permutedims(initial_x))

    obj = x -> objective(x, plds, y)
    g! = (G, x) -> grad!(G, x, plds, y)
    h! = (H, x) -> hess!(H, x, plds, y)

    # Newton needs a twice-differentiable objective. The third positional argument of
    # OnceDifferentiable is a combined value-and-gradient callback, not a Hessian, so the
    # Hessian callback has to go through TwiceDifferentiable instead.
    td = TwiceDifferentiable(obj, g!, h!, x_vec)

    result = optimize(td, x_vec, Newton(), Optim.Options(show_trace = true))

    optimized_x = interleave_reshape(Optim.minimizer(result), T, d)
    return optimized_x, result
end

optimize_latent_states (generic function with 1 method)