# Fitting generalized models to continuous spiking data via score matching

Model is
$$ y_m \sim PoissonProcess(\lambda_m) \\
\lambda_m(t) = exp\{\sum_d C_{md}x^{(d)}(t)+b_m\} \\
$$

where the $x^{(d)}(t)$ are latent functions shared between neurons ($D < M$)

Let us parametrize the latent functions in a finite bandwidth Fourier basis, with weights $\Theta_{1:D,1:F}$ weighting the frequencies $\mathsf{F} = \{ \omega, 2\omega, \dots, F\omega \}$ and phases $\Phi_{1:D,1:F}$ 

Now we wish to minimize the score matching objective
$$ J(C,b,\Theta, \Phi, \omega) = \frac{1}{2} \left\langle\sum_{i=1}^N \left( \partial_{t_i} \log p^*(\mathsf{T}) - \partial_{t_i} \log p^{C,b,\Theta, \Phi, \omega}(\mathsf{T}) \right)^2 \right\rangle_{\mathsf{T} \sim P^*} $$

with respect to the parameters $C, b, \Theta, \Phi, \omega$. We fix the frequency content of the signal $\Theta$, but allow phases $\Phi$ to vary between trials. This can be done by following the usual score matching tricks, then doing coordinate descent in the parameters.

In [None]:
using Optim
using Cubature

In [None]:
# Get fourier basis
"""
    fourier_basis(t, omega, F, phi=zeros(F,1))

Evaluate the F cosine basis functions cos(2*pi*f*omega*t + phi[f]), f = 1..F, together with
their first and second derivatives with respect to t.

## Parameters
* t: time point(s); a scalar, range, vector or matrix. Points are visited in iteration order.
* omega: base frequency; basis function f oscillates at f*omega
* F: number of basis functions
* phi: phase of each basis function, indexed linearly as phi[f]

## Returns
Three F x length(t) matrices: the basis values, their first and their second derivatives.
"""
function fourier_basis(t, omega, F, phi=zeros(F,1))
    # Size by the number of time points rather than size(t,1), so that row-shaped
    # inputs are handled as well as scalars, ranges and column vectors.
    K = length(t)
    out = zeros(F, K)
    dout = zeros(F, K)
    ddout = zeros(F, K)
    for f = 1:F
        w = 2pi*f*omega # angular frequency of the f-th basis function
        k = 0
        for tk in t
            k += 1
            arg = w*tk + phi[f]
            out[f,k] = cos(arg)
            dout[f,k] = -w*sin(arg)
            ddout[f,k] = -w^2*cos(arg)
        end
    end

    return out, dout, ddout
end

In [None]:
"""
    latent_functions(t, theta; omega=1, phi=zeros(size(theta,2),1))

Evaluate the latent functions at the time point(s) t: the Fourier basis values returned by
fourier_basis, mixed by the weight matrix theta (one row per latent function, one column per
basis function).
"""
function latent_functions(t, theta; omega=1, phi=zeros(size(theta,2),1))
    n_basis = size(theta, 2)
    # Only the basis values are needed here, not the derivatives
    basis = fourier_basis(t, omega, n_basis, phi)[1]
    return theta*basis
end
    

In [None]:
"""
    lambda_functions(t, C, theta; b=..., phi=..., omega=1)

Evaluate the firing rate functions exp(C*theta*basis(t) + b) at the time point(s) t.
The offset b is broadcast across time points and the exponential is taken element-wise.
The result has one row per row of C and one column per time point.
"""
function lambda_functions(t, C, theta; b=zeros(size(C,1),1), phi=zeros(size(theta,2),1), omega=1)
    basis = fourier_basis(t, omega, size(theta,2), phi)[1]
    log_rates = broadcast(+, C*theta*basis, b)
    # Element-wise exponential of every entry (not a matrix exponential)
    return map(exp, log_rates)
end

In [None]:
"""
**Using thinning of a higher rate homogeneous Poisson Process, samples spikes from an inhomogeneous one**

## Parameters
* C: MxD matrix mapping from D dim latent space to M neurons
* θ: DxF matrix, mapping from F dimensional Fourier function-space to D dim latent space
* ϕ: FxN matrix, setting the phase of the Fourier functions on each trial (the number of trials N is read from its size)
* b: Mx1 vector, setting the base log firing rate for each neuron (it is added inside the exponential)
* ω: base frequency of the Fourier basis
* T: Length of trial

## Returns
An N element array `out`, where `out[n][m]` is the vector of spike times of neuron m on trial n
"""
function sample_spikes(C, θ, ϕ, b, ω, T)
    N = size(ϕ, 2) # Trials
    M = size(C, 1) # Neurons
    D = size(C, 2) # Latent dimensions
    @show N, M, D

    out = Array(Any, N) # Output array, N array (for trials) of M arrays (for neuron spike times)

    # For each trial
    for n = 1:N
        out[n] = Array(Any, M)

        # Evaluate the rate functions on a fine grid to find a homogeneous rate majorising each.
        # λ_samp is M x (grid points), so the maximum must be taken over TIME (dimension 2)
        # to obtain one bound per neuron; the 5% margin covers peaks between grid points.
        λ_samp = lambda_functions(0:0.001:T, C, θ, phi=ϕ[:,n], b=b, omega=ω);
        λ_hom = maximum(λ_samp, 2)*1.05; # Mx1

        for m1 = 1:M
            # Candidate events: homogeneous Poisson process with rate λ_hom[m1] on [0, T],
            # generated by accumulating exponential inter-event intervals
            candidates = Float64[]
            t_next = randexp()/λ_hom[m1]
            while t_next <= T
                push!(candidates, t_next)
                t_next += randexp()/λ_hom[m1]
            end

            # Do the thinning: keep each candidate with probability λ(t)/λ_hom
            keep_probs = lambda_functions(candidates, C[m1,:], θ, phi=ϕ[:,n], b=b[m1], omega=ω)/λ_hom[m1]
            u = rand(length(candidates))
            out[n][m1] = candidates[u .< keep_probs[:]]
        end

    end

    return out
end
    
    

In [None]:
# Define GLMparams type, and a function that returns the parameters for "sample_spikes" 

"""
Container for the model parameters, in the same order as the positional arguments of
`sample_spikes` (see `getall`).
"""
type GLMparams
    C::AbstractArray # MxD, maps the D latent dimensions to the M neurons
    θ::AbstractArray # DxF, weights of the F Fourier basis functions for each latent dimension
    ϕ::AbstractArray # FxN, phase of each Fourier basis function on each of the N trials
    b::AbstractArray # Mx1, offset added inside the exponential for each neuron
    ω::Float64 #base frequency
    T::Float64 # Trial length
end

"""
    getall(x::GLMparams)

Return an independent (deep) copy of all fields of x as an Any array ordered C, θ, ϕ, b, ω, T,
i.e. the positional argument order of sample_spikes.
"""
function getall(x::GLMparams)
    fields = Any[x.C, x.θ, x.ϕ, x.b, x.ω, x.T]
    return deepcopy(fields)
end

"""
    create_rand_params(M, D, F, N)

Draw a random GLMparams for M neurons, D latent dimensions, F Fourier basis functions and
N trials. Base frequency and trial length are both set to 1.
"""
function create_rand_params(M, D, F, N)
    C = randn(M,D)              # standard normal loadings
    θ = randn(D,F)*0.3          # Fourier weights, scaled down by 0.3
    ϕ = 2*π.*rand(F, N)         # phases uniform on [0, 2π)
    b = log(30)*ones(M,1)       # every neuron gets the offset log(30)
    return GLMparams(C, θ, ϕ, b, 1, 1)
end

In [None]:
#=
using PlotlyJS


out = sample_spikes(getall(ex_params)...);
to_plot = Array(PlotlyJS.GenericTrace{Dict{Symbol,Any}}, size(out,1))
rate_plot = Array(PlotlyJS.GenericTrace{Dict{Symbol,Any}}, size(out,1))
for i1 = 1:5
    to_plot[i1] = scatter(;x=collect(out[i1][1]), y=i1*collect(ones(size(out[i1][1]))), mode="markers", yaxis="y2")
    bla = lambda_functions(0:0.01:ex_params.T, ex_params.C, ex_params.θ; b=ex_params.b, phi=ex_params.ϕ[:,i1], omega=ex_params.ω)
    rate_plot[i1] = scatter(; 
        x=collect(0:0.01:ex_params.T), 
        y=collect(bla[1,:]),
    yaxis = "y1"
)

end

lo = Layout(;xaxis_range=[0,1])

plot([Plot(to_plot[1:5]); Plot(rate_plot[1:5])])
=#

In [23]:
"""
    name_params(params)

Unpack the first five entries of params (C, θ, ϕ, b, ω) and derive the problem dimensions
from the array sizes.

Returns C, θ, ϕ, b, ω, N, M, D, F where N = columns of ϕ, M = rows of C, D = columns of C
and F = columns of θ. The arrays are returned as they are stored in params (no copies).
"""
function name_params(params::AbstractArray)
    C, θ, ϕ, b, ω = params[1], params[2], params[3], params[4], params[5]

    M, D = size(C,1), size(C,2)
    N = size(ϕ,2)
    F = size(θ,2)

    return C, θ, ϕ, b, ω, N, M, D, F
end

"""
    mean_cov(a)

Average of the outer products of the columns of a with themselves (an uncentred second
moment matrix). Columns are samples; the result is size(a,1) x size(a,1).
"""
function mean_cov(a::Array)
    dim, n_samples = size(a,1), size(a,2)
    acc = zeros(dim, dim)
    for k = 1:n_samples
        col = a[:,k]
        acc = acc + col*(col')
    end
    return acc./n_samples
end

"""
    mean_cov(a, b)

Average of the outer products of matching columns of a and b (an uncentred cross moment
matrix). Columns are samples; the result is size(a,1) x size(b,1).

Throws a DimensionMismatch when a and b do not have the same number of columns.
"""
function mean_cov(a::Array, b::Array)
    n_samples = size(a,2)
    # Validate before doing any work: same number of samples to average over
    if size(b,2) != n_samples
        throw(DimensionMismatch(string("a has ", n_samples, " columns but b has ", size(b,2))))
    end

    out = zeros(size(a,1), size(b,1))
    for t = 1:n_samples
        out += a[:,t]*(b[:,t]')
    end
    return out./n_samples
end

"""
    get_suff_stats(u)

Store the relevant sufficient statistics of the phases u, always averaging over the spikes
of a neuron within a trial.

## Parameters
* u: N x M array (trials x neurons); u[n,m] is an F x (number of spikes) matrix of phases

## Returns
A Dict with the entries
* "cos", "sin": N x M containers of length F vectors (mean over spikes of cos / sin of u)
* "Cov_sin": N x M container of FxF matrices (mean outer product of sin(u) with itself)
* "XCov_sincos": N x M container of FxF matrices (mean outer product of sin(u) with cos(u))
* "sum_n_cos": M container of length F vectors ("cos" summed over trials)
* "sum_n_Cov_sin": M container of FxF matrices ("Cov_sin" summed over trials)

Entries of the N x M containers are left unassigned for neurons that are silent on a trial.
A neuron that is silent on every trial gets all-zero sums.
"""
function get_suff_stats(u)
    N = size(u,1)
    M = size(u,2)
    S = Dict()
    S["cos"] = Array(Any, (N,M))
    S["sin"] = Array(Any, (N,M))
    S["Cov_sin"] = Array(Any, (N,M))
    S["XCov_sincos"] = Array(Any, (N,M))
    S["sum_n_cos"] = Array(Any, (M,))
    S["sum_n_Cov_sin"] = Array(Any, (M,))

    active_neur = zeros(N,M)

    for m = 1:M
        for n = 1:N
            if length(u[n,m]) > 0 # Skip neurons that are silent on certain trials
                sin_u = sin(u[n,m])
                cos_u = cos(u[n,m])
                S["cos"][n,m] = mean(cos_u, 2)[:] # Fx1
                S["sin"][n,m] = mean(sin_u, 2)[:] # Fx1
                S["Cov_sin"][n,m] = mean_cov(sin_u) # FxF
                # Store per trial and neuron; assigning to S["XCov_sincos"] itself would
                # replace the whole container by a single matrix
                S["XCov_sincos"][n,m] = mean_cov(sin_u, cos_u) # FxF
                active_neur[n,m] = 1;
            end
        end

        active = active_neur[:,m].==1
        if any(active)
            S["sum_n_cos"][m] = sum(S["cos"][active, m]) # Fx1
            S["sum_n_Cov_sin"][m] = sum(S["Cov_sin"][active, m]) # FxF
        else
            # Neuron silent on every trial: it contributes nothing, and summing an empty
            # container of Any would throw
            F = N > 0 ? size(u[1,m],1) : 0
            S["sum_n_cos"][m] = zeros(F)
            S["sum_n_Cov_sin"][m] = zeros(F,F)
        end
    end
    return S
end


"""
    J(data::AbstractArray, params::AbstractArray; penalise_integrals=0)

Score matching cost computed directly from the spike times.

## Parameters
* data: data[n][m] holds the spike times of neuron m on trial n
* params: array holding C, θ, ϕ, b, ω (read through name_params) and, when the penalty is used, T as entry 6
* penalise_integrals: 0 switches the penalty off. Otherwise an array with
    1. an array indexed [n,m] with the total number of spikes (the integrals are constrained relative to that)
    2. the weight of the penalty
    3. keyword arguments passed on to hquadrature

## Returns
cost, u, s where u[n,m] is the F x (number of spikes) matrix of basis function phases at the
spikes and s the vector of angular frequencies 2π*f*ω.
"""
function J(data::AbstractArray, params::AbstractArray; penalise_integrals=0)
    C, θ, ϕ, b, ω, N, M, D, F = name_params(params);

    s = 2*π*collect(1:F)*ω; # Scaling factor

    cost = 0;
    u = Array(Any, (N,M))
    for n = 1:N, m = 1:M
        u[n,m] = broadcast(+, s * data[n][m]', ϕ[:,n]); # Rescaled plus phase added
        if length(u[n,m]) == 0
            continue # Skip silent neurons
        end
        n_spikes = size(u[n,m], 2)
        V = s.*sin(u[n,m])
        # Term linear in the parameters, averaged over spikes
        cost -= mean(C[m,:]*θ*((s.^2).*cos(u[n,m])))
        # Term quadratic in the parameters, averaged over spikes
        cost += 0.5 * trace(θ'*C[m,:]'*C[m,:]*θ*V*V') / n_spikes;
    end

    # Optionally compute the integral of firing rate functions and add the total integral as penalty
    if !(penalise_integrals==0)
        T = params[6]

        int_rates = zeros(N,M)
        for n = 1:N, m = 1:M
            # Integrate the rate of neuron m on trial n over the trial
            int_rates[n,m] = hquadrature(x -> lambda_functions(x, C, θ; b=b, phi=ϕ[:,n], omega=ω)[m],
                                         0.0, T; penalise_integrals[3]...)[1]
        end

        # Increase cost by error in rate function per neuron per trial per sec x weight
        sq_err = sum((int_rates .- penalise_integrals[1]).^2)
        cost += penalise_integrals[2] * sq_err * (1/(N*M*T))
    end

    return cost, u, s

end

"""
    J(S, params; penalise_integrals=0)

Score matching cost computed from the sufficient statistics S (see get_suff_stats) instead of
the raw spike times. Only the entries "sum_n_cos" and "sum_n_Cov_sin" of S are read.

penalise_integrals is 0 to switch the penalty off, otherwise an array with
1. an array indexed [n,m] with the total number of spikes (the integrals are constrained relative to that)
2. the weight of the penalty
3. keyword arguments passed on to hquadrature

Returns the cost as a single number.
"""
function J(S, params; penalise_integrals=0)
    C, θ, ϕ, b, ω, N, M, D, F = name_params(params);
    s = 2*π*collect(1:F)*ω; # Scaling factor

    total = 0;
    for m = 1:M
        Cm = C[m,:]
        # s[i] * Cov[i,j] * s[j] for every pair of basis functions
        W = broadcast(*, broadcast(*, s, S["sum_n_Cov_sin"][m]), s')
        total = total - Cm * θ * ((s.^2) .* S["sum_n_cos"][m])
        total = total + 0.5 * trace(θ'*Cm'*Cm*θ*W)
    end

    # Optionally compute the integral of firing rate functions and add the total integral as penalty
    if !(penalise_integrals==0)
        T = params[6]

        int_rates = zeros(N,M)
        for n = 1:N, m = 1:M
            # Integrate the rate of neuron m on trial n over the trial
            int_rates[n,m] = hquadrature(x -> lambda_functions(x, C, θ; b=b, phi=ϕ[:,n], omega=ω)[m],
                                         0.0, T; penalise_integrals[3]...)[1]
        end

        # Increase cost by error in rate function per neuron per trial per sec x weight
        sq_err = sum((int_rates .- penalise_integrals[1]).^2)
        total += penalise_integrals[2] * sq_err * (1/(N*M*T))
    end

    # The accumulator is a one element array once a neuron has been added; unwrap it
    return total[1]
end

# Specific function for cost related to C[m,:] (m-th neuron only)
"""
    JCm(S, params, m; penalise_integrals=0)

The part of the cost J(S, params) that depends on row m of C, i.e. on the m-th neuron only.
Minimising it over C[m,:] is one coordinate step on J, so every term carries the same
weight here as it does in J.

penalise_integrals is 0 to switch the penalty off, otherwise an array with
1. an array indexed [n,m] with the total number of spikes (the integrals are constrained relative to that)
2. the weight of the penalty
3. keyword arguments passed on to hquadrature

Returns the cost as a single number.
"""
function JCm(S, params, m; penalise_integrals=0)
    C, θ, ϕ, b, ω, N, M, D, F = name_params(params);
    s = 2*π*collect(1:F)*ω; # Scaling factor

    J = 0;
    J += - C[m,:] * θ * ((s.^2) .* S["sum_n_cos"][m]);
    J += 0.5 * trace(θ'*C[m,:]'*C[m,:]*θ*broadcast(*, broadcast(*, s, S["sum_n_Cov_sin"][m]), s'));

    # Optionally compute the integral of firing rate functions and add the total integral as penalty
    if !(penalise_integrals==0)
        T = params[6]

        # Only neuron m is needed: one integral per trial
        int_rates = zeros(N)
        for n = 1:N
            rate_func(x) = lambda_functions(x, C, θ; b=b, phi=ϕ[:,n], omega=ω)[m]
            int_rate, err = hquadrature(rate_func, 0.0, T; penalise_integrals[3]...)
            int_rates[n] = int_rate;
        end

        # Normalise by N*M*T exactly as J does. Dividing by N*T only would weight the
        # penalty M times more heavily relative to the score terms than in J, and the
        # row-wise updates would then no longer decrease J.
        J += penalise_integrals[2] *
            sum((int_rates .- penalise_integrals[1][:,m]).^2) *
            (1/(N*M*T))
    end

    return J[1]
end

# Specific function for cost related to ϕ[:,n] (n-th trial only)
"""
    Jϕn(data, params, n; penalise_integrals=0)

The part of the cost J(data, params) that depends on column n of ϕ, i.e. on the n-th trial only.

penalise_integrals is 0 to switch the penalty off, otherwise an array with
1. an array indexed [n,m] with the total number of spikes (the integrals are constrained relative to that)
2. the weight of the penalty
3. keyword arguments passed on to hquadrature

Returns cost, u, s. u is an N x M container of which only row n is assigned; s is the vector
of angular frequencies 2π*f*ω.
"""
function Jϕn(data, params, n; penalise_integrals=0)
    C, θ, ϕ, b, ω, N, M, D, F = name_params(params);

    s = 2*π*collect(1:F)*ω; # Scaling factor

    cost = 0;
    u = Array(Any, (N,M))
    for m = 1:M
        u[n,m] = broadcast(+, s * data[n][m]', ϕ[:,n]); # Rescaled plus phase added
        if length(u[n,m]) == 0
            continue # Skip silent neurons
        end
        n_spikes = size(u[n,m], 2)
        V = s.*sin(u[n,m])
        cost -= mean(C[m,:]*θ*((s.^2).*cos(u[n,m])))
        cost += 0.5 * trace(θ'*C[m,:]'*C[m,:]*θ*V*V') / n_spikes;
    end

    # Optionally compute the integral of firing rate functions and add the total integral as penalty
    if !(penalise_integrals==0)
        T = params[6]

        int_rates = zeros(N,M) # only row n is filled in and used
        for m = 1:M
            int_rates[n,m] = hquadrature(x -> lambda_functions(x, C, θ; b=b, phi=ϕ[:,n], omega=ω)[m],
                                         0.0, T; penalise_integrals[3]...)[1]
        end

        # Increase cost by error in rate function per neuron per trial per sec x weight
        sq_err = sum((int_rates[n,:] .- penalise_integrals[1][n,:]).^2)
        cost += penalise_integrals[2] * sq_err * (1/(N*M*T))
    end

    return cost, u, s
end


"""
    gradC(u, s, params, m)

Gradient of the (unpenalised) score matching cost with respect to row m of C, and the row
at which that gradient vanishes.

As a function of the row c = C[m,:] the cost is -c*lin + 0.5*c*A*c', with
lin = sum over trials of θ*mean(s.^2 .* cos(u)) and A = sum over trials of θ*Cov(s.*sin(u))*θ'.

## Parameters
* u: N x M container of phases as returned by J(data, params)
* s: vector of angular frequencies as returned by J(data, params)
* params: parameter array, unpacked with name_params
* m: index of the neuron

## Returns
∇Cm, Cm0: the gradient (A*c' - lin)' and the minimiser (A\\lin)', both shaped like C[m,:].
Cm0 is all zeros when the neuron is silent on every trial.
"""
function gradC(u::AbstractArray, s, params::AbstractArray, m)
    C, θ, ϕ, b, ω, N, M, D, F = name_params(params);

    Cm = C[m, :]
    Cm0 = zeros(size(Cm))

    # Separate names for the linear and quadratic coefficients, so that the model
    # offset b unpacked above is not shadowed
    lin = zeros(length(Cm),1)
    A = zeros(length(Cm),length(Cm))

    n_active = 0
    for n = 1:size(u,1)
        if length(u[n,m]) == 0
            continue # Skip silent trials: means over zero spikes would be NaN
        end
        n_active += 1
        lin += mean(θ*((s.^2).*cos(u[n,m])), 2)
        A += θ*mean_cov(s.*sin(u[n,m]))*θ';
    end

    ∇Cm = (A*Cm' - lin)';

    # Stationary point: A*c' = lin
    if n_active > 0
        Cm0 = (A\lin)';
    end

    return ∇Cm, Cm0
end

"""
    gradθ(S, s, params)

Gradient of the (unpenalised) cost J(S, params) with respect to θ.

## Parameters
* S: sufficient statistics from get_suff_stats; the per neuron sums over trials
  "sum_n_cos" and "sum_n_Cov_sin" are read
* s: vector of angular frequencies as returned by J(data, params)
* params: parameter array, unpacked with name_params

## Returns
∇θ, an array of the same size as θ.
"""
function gradθ(S, s, params::AbstractArray)
    C, θ, ϕ, b, ω, N, M, D, F = name_params(params);

    ∇θ = zeros(size(θ))
    ss = s*s' # FxF, scaling of every pair of basis functions

    for m = 1:M
        ∇θ -= ((s.^2).*S["sum_n_cos"][m]*C[m, :])'
        # Use the covariance summed over trials for neuron m. S["Cov_sin"] is the
        # trials x neurons container, so indexing it with [m] would pick a single trial
        # of the first neuron instead.
        ∇θ += (C[m,:]'*C[m,:])*θ*(ss.*S["sum_n_Cov_sin"][m])
    end

    return ∇θ
end
    

"""
    gradϕ(u, s, params, f, n)

Derivative of the (unpenalised) cost J(data, params) with respect to the single phase ϕ[f,n].

With c = C[m,:]*θ for neuron m, the contribution of that neuron is

    c[f]*s[f]^2*mean(sin(u_f)) + sum over f1 of c[f]*c[f1]*s[f]*s[f1]*mean(sin(u_f1).*cos(u_f))

where u_f is row f of u[n,m] and the means run over the spikes of the neuron on trial n.

## Parameters
* u: N x M container of phases as returned by J(data, params)
* s: vector of angular frequencies as returned by J(data, params)
* params: parameter array, unpacked with name_params
* f, n: index of the basis function and of the trial

## Returns
The derivative, wrapped in a one element array.
"""
function gradϕ(u::AbstractArray, s, params::AbstractArray, f, n)
    C, θ, ϕ, b, ω, N, M, D, F = name_params(params);
    ∇ϕfn = 0.0;

    for m=1:M
        um = u[n,m]
        if length(um) == 0
            continue # Skip silent neurons: means over zero spikes would be NaN
        end

        Cθ = C[m,:]*θ # loading of neuron m on every basis function
        cos_f = cos(um[f,:])

        ∇ϕfn += Cθ[f]*s[f]^2*mean(sin(um[f,:])) # Linear part

        # Quadratic part. Every pair (f, f1) is scaled by s[f]*s[f1]; f1 == f gives the
        # diagonal term with s[f]^2.
        for f1 = 1:F
            ∇ϕfn += Cθ[f]*Cθ[f1]*s[f]*s[f1]*mean(sin(um[f1,:]).*cos_f)
        end
    end

    return [∇ϕfn]
end

gradϕ (generic function with 1 method)

In [24]:
"""
Fit the model parameters to spike data by coordinate descent on the score matching cost J.

## Parameters
* data: data[n][m] holds the spike times of neuron m on trial n
* params: array holding C, θ, ϕ, b, ω, T in that order (the first five are unpacked with name_params, T is params[6]). It is modified in place.
* max_iter: number of sweeps over the parameter blocks; there is no convergence test
* verbose: when >= 1, the cost is printed after every update
* penalise_integrals: 0 switches the rate integral penalty off. Otherwise the 3 element array described in J; its first entry is overwritten here with the spike count of every neuron on every trial.
* optim_params: keyword arguments passed on to every call of optimize

## Procedure
1. b is set to the log of each neuron's mean firing rate (total spikes / (N*T))
2. the sufficient statistics S are computed for the current ϕ
3. on every sweep θ is optimised through J(S, .), then every row of C through JCm, then every column of ϕ through J(data, .); finally S is recomputed for the new ϕ

## Returns
params, S: the updated parameter array (the same object that was passed in) and the sufficient statistics of the last sweep
"""
function coordinate_descent(data::AbstractArray, params::AbstractArray; max_iter=10, verbose=1, penalise_integrals=0, optim_params=[])
    C, θ, ϕ, b, ω, N, M, D, F = name_params(params);
    T = params[6]
    
    
    if !(penalise_integrals==0)
        # count the number of spikes for each neuron on each trial
        # (this overwrites the first entry of the caller's penalise_integrals)
        penalise_integrals[1] = zeros(N,M)
        for n = 1:N
            for m = 1:M
                penalise_integrals[1][n,m] = length(data[n][m])
            end
        end
    end
    
    @show N,M,D,F
    cost, u, s = J(data, params, penalise_integrals=penalise_integrals)
    if verbose>=1 println("Initial cost is $(cost)") end
    
    
    # Estimate mean firing rate of individual neurons across all trials (b vector)
    btmp = zeros(size(b))
    for n = 1:N
        for m = 1:M
            btmp[m] += size(data[n][m],1) # Counting total number of spikes for each neuron
        end
    end
    btmp = log(btmp./(N*T)) # dividing total spikes by number of trials * length of each trial then taking log
    params[4] = btmp;
    b = deepcopy(params[4])
    
    
    # Get sufficient statistics (DEPENDS ON CURRENT SETTING OF ϕ !!!)
    cost, u, s = J(data, params, penalise_integrals=penalise_integrals)
    S = get_suff_stats(u)
    
    if verbose>=1 println("After setting b, cost is $(cost)") end
        
    
    # Use these sufficient statistics to perform the optimization steps
    
    for iters = 1:max_iter
        if verbose>=1 println(); println(); println("Starting iteration $(iters)...............") end
    
        # θ optimization: the optimizer works on the flattened θ, which is reshaped
        # into a copy of params for every cost evaluation
        function θ_opt(x::Vector)
            tmp_params = deepcopy(params)
            tmp_params[2] = reshape(x, size(params[2]));
            return J(S, tmp_params, penalise_integrals=penalise_integrals)
        end
        res = optimize(θ_opt, θ[:]; optim_params...)
        params[2] = reshape(Optim.minimizer(res), size(params[2]))
        cost, = J(S, params, penalise_integrals=penalise_integrals)
            C, θ, ϕ, b, ω, N, M, D, F = name_params(params);
            if verbose>=1 println("Updated θ ----------------- Current cost is $(cost)") end


        # C[m,:] optimization - rowwise for neurons (should be equivalent to global minimization)
        for m = 1:M
            function C_m_opt(x::Vector)
                tmp_params = deepcopy(params)
                tmp_params[1][m,:] = reshape(x, size(params[1][m,:]));
                return JCm(S, tmp_params, m, penalise_integrals=penalise_integrals)
            end
            res = optimize(C_m_opt, C[m,:][:]; optim_params...)
            params[1][m,:] = reshape(Optim.minimizer(res), size(params[1][m,:]))
        end
        cost, = J(S, params, penalise_integrals=penalise_integrals)
        C, θ, ϕ, b, ω, N, M, D, F = name_params(params);
        if verbose>=1 println("Updated C ----------------- Current cost is $(cost)") end


        # ϕ_n optimization - this step can be extremely easily parallelized
        # (the sufficient statistics depend on ϕ, so the cost is evaluated from the data here)
        for n = 1:N
            function ϕ_n_opt(x::Vector)
                tmp_params = deepcopy(params)
                tmp_params[3][:,n] = reshape(x, F);
                    # I think this optimizer doesn't need to worry about the integral
                #return Jϕn(data, tmp_params, n; penalise_integrals=penalise_integrals)
                return J(data, tmp_params; penalise_integrals=penalise_integrals)[1]
            end
            res = optimize(ϕ_n_opt, ϕ[:,n]; optim_params...)
            params[3][:,n] = reshape(Optim.minimizer(res), size(params[3][:,n]))
            cost, = J(data, params, penalise_integrals=penalise_integrals)
                C, θ, ϕ, b, ω, N, M, D, F = name_params(params);
            if verbose>=1 println("Updated ϕ_$(n) ------------ Current cost is $(cost)") end
        end
        # Update sufficient stats after changing ϕ!!!
        cost, u, s = J(data, params, penalise_integrals=penalise_integrals)
        S = get_suff_stats(u)
        cost = J(S, params, penalise_integrals=penalise_integrals)
            
        if verbose>=1 println("Updated suff stats --------- Current cost is $(cost)") end
    
    end
    
    
    return params, S
end

coordinate_descent (generic function with 2 methods)