# Estimating firing rates from neural rasters

### Data: 
N trials (each of length 1), M neurons, {$t_{m,i}^n$} spike times for neuron $m$ on trial $n$

### Model:
* D latent functions over [0, 1], each function is in an RKHS with kernel $k_d$

$$ x^n_d(t) = \sum_{j=1}^J \alpha_{d,j}^n k_d(t, u^d_j) $$

* $\mathbf{C} \in \mathsf{R}^{M\times D}$ maps the vector of latent functions drawn for trial n ($\mathbf{x}^n$) into neural space

* $\mathbf{b} \in \mathsf{R}^M_+$ is the log mean firing rate for neurons.

* There is a non-linearity between the linear combination of latent functions, and firing rates $\mathbf{\lambda}_m^n(t)$

$$ \lambda_m^n(t) = \exp\left( \sum_d C_{md}\, x^n_d(t) + b_m \right) $$


### Log-likelihood
$\log p(y | \lambda) = \sum_{m,n} \left[ \sum_{i}\log\lambda_m^n(t_{m,i}^n) - \int_0^1 \lambda_m^n(t)\, dt \right]$

### Score function with penalties

$$ \mathcal{J}(\mathbf{C}, \alpha, u) = \\ \\
 \text{Score objective:} \\ \\
 \sum_{m,n,i} [ \frac{1}{2}(\sum_d C_{md} \sum_j \alpha_{d,j}^n \nabla_t k_d(t_{m,i}^n, u^d_j))^2 \\
 + \sum_d C_{md} \sum_j \alpha_{d,j}^n (\nabla_t)^2 k_d(t_{m,i}^n, u^d_j) ] \\
 \text{RKHS smoothness:} \\
+ \eta_\text{RKHS} \sum_{n,d} \| x_d^n(t) \|_\text{RKHS} \\
 \text{Loading matrix penalty} \\
 + \eta_\text{loading} \sum_m \| C_{m,\cdot} \|_p 
 $$
 
 
 Different choices of the $p$-norm entail different penalty types for C:
 * $p=0$ - sparsity
 * $p=1$ - semi-sparse
 * $p=2$ - limited power
 * $p=\infty$ - limited maximal contribution

### Kernel choices

In order to evaluate the score matching objective, we need to be able to cheaply evaluate the time-derivatives of kernels around the data points.

Different kernels mean different interpretations. We want to design the D individual kernels such that the row space of $\mathbf{C}$ mapping from the individual RKHSs is interpretable biologically.

Certain kernels are useful to represent certain behavior:
* Gaussian (RBF) kernel - Different $\sigma$s represent frequency bands
* Linear / polynomial kernel - Represents trends in the data ?
* Sobolev-like kernel - Possibly frequency band limited -> Frequency content of the data

In [None]:
#using MLKernels # Kernel functions (cannot differentiate)
#include("/Users/gergobohner/Dropbox/Gatsby/scripts/julia/MLKernels.jl/src/MLKernels.jl") # my version for ForwardDiffable kernels
#using ForwardDiff
using MLKernels
using Optim # Optimisation for parameters
using Calculus

In [None]:
# Define parameters
# Mutable container for the model parameters. α and u are ragged: one vector of
# length J_d per latent dimension d (see `create_rand_params` for how they are built).
type PP_params
    C::AbstractArray # M x D loading matrix (neurons x latent dimensions)
    α::AbstractArray # N x D x {J_i}: α[n,d] holds the J_d kernel weights for trial n, latent d
    u::AbstractArray # D x {J_i}: u[d] holds the J_d kernel centres for latent d
end

"""
    create_rand_params(N, M, D, J)

Build an initial `PP_params` for `N` trials, `M` neurons and `D` latent dimensions,
where `J[d]` is the number of kernel centres of latent dimension `d`.

* `C` is an `M x D` matrix of ones (deterministic start; a random start is `rand(M,D)`).
* `u[d]` is a grid of `J[d]` equally spaced centres between 0.1 and 0.9.
* `α[n,d]` is a vector of `J[d]` standard-normal weights.
"""
function create_rand_params{T<:Integer}(N::T, M::T, D::T, J::AbstractArray{T})
    loading = ones(M,D)

    # Kernel centres: one evenly spaced grid per latent dimension
    centres = Array(Any, D)
    for d = 1:D
        centres[d] = linspace(0.1, 0.9, J[d])
    end

    # Kernel weights: drawn latent-by-latent, trial-by-trial
    weights = Array(Any, N, D)
    for d = 1:D
        for n = 1:N
            weights[n,d] = randn(J[d])
        end
    end

    return PP_params(loading, weights, centres)
end


"""
    name_params(params::AbstractArray)

Unpack a 3-element parameter container `[C, α, u]` and derive the problem sizes.

Returns the tuple `(C, α, u, N, M, D, J)` where `N = size(α,1)` is the number of
trials, `M, D = size(C)` are the number of neurons and latent dimensions, and
`J[d] = length(u[d])` is the number of kernel centres of latent dimension `d`.
"""
function name_params(params::AbstractArray)
    C = params[1]
    α = params[2]
    u = params[3]

    N = size(α, 1)
    M = size(C, 1)
    D = size(C, 2)

    # Counts must be integers: they are used as loop bounds, index ranges and
    # `reshape` dimensions by the callers (a `zeros(D)` buffer would make them Float64).
    J = Int[length(u[d]) for d = 1:D]

    return C, α, u, N, M, D, J
end

"""
    name_params(params::Union{PP_params, PoissonProcessEstimation.RKHS_params})

Unpack the fields `C`, `α`, `u` of a parameter object and derive the problem sizes.

Returns the tuple `(C, α, u, N, M, D, J)` where `N = size(α,1)`, `M, D = size(C)`
and `J[d] = length(u[d])`. The returned arrays are the object's own fields, not copies.
"""
function name_params(params::Union{PP_params, PoissonProcessEstimation.RKHS_params})
    C = params.C
    α = params.α
    u = params.u

    N = size(α, 1)
    M = size(C, 1)
    D = size(C, 2)

    # `length` is already an integer, so no rounding is needed; a concretely
    # typed Int vector also avoids an abstractly-typed container.
    J = Int[length(u[d]) for d = 1:D]

    return C, α, u, N, M, D, J
end

"""
    param_to_vec(θ::PP_params)

Return the fields of `θ` as a 3-element `Any` vector `[C, α, u]`
(the layout accepted by the `AbstractArray` method of `name_params`).
"""
function param_to_vec(θ::PP_params)
    unpacked = name_params(θ)
    # Only the first three entries (C, α, u) are needed; the sizes are dropped.
    return Any[unpacked[1], unpacked[2], unpacked[3]]
end

"""
    vec_to_param(params::AbstractArray)

Inverse of `param_to_vec`: build a `PP_params` from a container `[C, α, u]`.
"""
function vec_to_param(params::AbstractArray)
    # Splat the argument itself; the previous body referenced an undefined name `out`.
    return PP_params(params...)
end

In [None]:
# Visualisation functions
"""
    firing_rates(t_grid, θ, KernList; n_range=1, d_range=1, m_range=1)

Evaluate the model firing rates on the time points `t_grid`.

For every trial `n` in `n_range` and neuron `m` in `m_range` the result holds, at
index `[n,m]`, the vector `exp(sum_d C[m,d] * x_d^n(t))` over `t_grid`, where
`x_d^n(t) = sum_j α[n,d][j] * kernel(KernList[d], t, u[d][j])` and `d` runs over `d_range`.

The returned `Any` array has size `(maximum(n_range), maximum(m_range))`; only the
requested entries are assigned. Grid points must be `Float64`.
"""
function firing_rates(t_grid, θ::Any, KernList::AbstractArray; n_range=1, d_range=1, m_range=1)
    C, α, u, N, M, D, J = name_params(θ)

    # One closure per (trial, latent dimension) evaluating the latent function at a time t
    latents = Array(Any,(maximum(n_range), maximum(d_range)))
    for n = n_range
        for d = d_range
            latents[n, d] = function(t::Float64)
                total = 0
                for j = 1:length(u[d])
                    total += (α[n,d][j] * kernel(KernList[d], t, u[d][j]))
                end
                return total
            end
        end
    end

    rates = Array(Any, (maximum(n_range), maximum(m_range)))
    for n = n_range
        for m = m_range
            # Linear combination of the latent functions, accumulated per grid point
            log_rate = zeros(length(t_grid))
            for k = 1:length(t_grid)
                for d = d_range
                    log_rate[k] += C[m,d]*latents[n,d](t_grid[k])
                end
            end
            rates[n,m] = exp(log_rate) # exponential link function
        end
    end

    return rates
end

In [None]:
"""
The function takes data, parameters and a list of latent kernels in, and computes the cost defined above

gives an option of computing the cost wrt only a single trial or neuron or latent dimension
"""
function Cost(data::AbstractArray, θ::Union{PP_params, AbstractArray}, KernList::AbstractArray; 
    slicing=[])
    # Penalised score-matching cost.
    #
    # Arguments
    #   data     : data[n][m] is the collection of spike times of neuron m on trial n
    #   θ        : parameters (anything accepted by `name_params`)
    #   KernList : one kernel per latent dimension
    #   slicing  : [] for the full cost, or ["n", idx] / ["m", idx] / ["d", idx] to
    #              restrict the trials / neurons / latent dimensions that are summed over
    #
    # Returns a scalar:
    #   mean over spikes of [ 1/2 (d/dt log-rate)^2 + d^2/dt^2 log-rate ]
    #   + 1e1 * exp(mean RKHS quadratic form α' K α)
    #   + 1e0 * exp(mean over neurons of (2-norm + inf-norm of the row of C))
    # Callers that index the result with `[1]` keep working on a scalar.

    # Name the parameters
    C, α, u, N, M, D, J = name_params(θ)

    # Define the computations we need to make
    n_range = 1:N
    m_range = 1:M
    d_range = 1:D

    if !isempty(slicing)
        if slicing[1] == "d"
            d_range = slicing[2]
        elseif slicing[1] == "n"
            n_range = slicing[2]
        elseif slicing[1] == "m"
            m_range = slicing[2]
        end
    end

    # Numerical first / second derivatives of each kernel section k_d(u_j, .) with
    # respect to its second argument. They do not depend on the trial, neuron or
    # spike, so build them once instead of once per spike.
    dK = Array(Any, D)
    ddK = Array(Any, D)
    for d = d_range
        dK[d] = Any[Calculus.derivative(x->kernel(KernList[d], u[d][j], x)) for j = 1:length(u[d])]
        ddK[d] = Any[Calculus.second_derivative(x->kernel(KernList[d], u[d][j], x)) for j = 1:length(u[d])]
    end

    # Computing the score
    num_spikes = 0
    score = 0.0
    for n = n_range
        for m = m_range
            for t = 1:length(data[n][m])
                spike = data[n][m][t]
                num_spikes += 1
                first_order = 0.0
                second_order = 0.0
                for d = d_range
                    for j = 1:length(u[d])
                        w = C[m,d]*α[n,d][j]
                        first_order += w*dK[d][j](spike)
                        second_order += w*ddK[d][j](spike)
                    end
                end
                # The 1/2 factor matches the score objective stated in the notes at
                # the top of this notebook (it was missing from the squared term).
                score += 0.5*first_order^2 + second_order
            end
        end
    end

    # Computing the RKHS norm (quadratic form α' K α); the kernel matrix only
    # depends on d, so it is built once per latent dimension.
    rkhs_norm = 0.0
    for d = d_range
        Jd = length(u[d])
        if Jd > 1
            K = kernelmatrix(KernList[d], reshape(u[d], Jd, 1))
            for n = n_range
                # `dot` yields a scalar, whereas α' * K * α can yield a 1-element array
                rkhs_norm += dot(α[n,d], K*α[n,d])
            end
        else
            for n = n_range
                rkhs_norm += α[n,d][1]^2
            end
        end
    end

    # Computing the C-norm. `vec` forces vector norms: on a 1 x D row slice
    # `norm(., Inf)` would be the matrix operator norm (sum of absolute values).
    C_2norm = 0.0
    C_infnorm = 0.0
    for m = m_range
        row = vec(C[m,:])
        C_2norm += norm(row)
        C_infnorm += norm(row, Inf)
    end

    # Avoid 0/0 = NaN when the selected slice contains no spikes
    mean_score = num_spikes > 0 ? score/num_spikes : 0.0

    return (mean_score + 
            1e1*exp(rkhs_norm/(length(n_range)*length(d_range))) + 
            1e0*exp((C_2norm+C_infnorm)/length(m_range)))
end

In [None]:
# Load the data
using MAT # to load mat file

# Read the raw raster "S" (neuron by timebin by trial) and release the file
# handle straight away; the handle was previously left open.
file = matopen("MotorHam1.mat")
rawdata = read(file, "S")
close(file)

# Read the data into our input format: data[trial][neuron] = vector of spike times
hasfired = zeros(size(rawdata,1),1) # flags neurons with more than 10 spikes on at least one trial
data = Array(Any, (size(rawdata,3),))
for n = 1:length(data)
    data[n] = Array(Any, (size(rawdata,1),))
    for m = 1:size(rawdata,1)
        # Bins 201:1200 are kept and bin indices are scaled by 1e-3,
        # giving a set of spiketimes between 0 and 1 sec
        data[n][m] = find(x->(x>=1), rawdata[m,201:1200,n])*1e-3
        if length(data[n][m]) > 10
            hasfired[m] = 1
        end
    end
end

# Trim the data of low-activity neurons: a neuron is kept only if it fired more
# than 10 spikes on at least one trial (stricter than "silent on every trial").
for n = 1:length(data)
    data[n] = data[n][find(x->x==1, hasfired)]
end
        

In [None]:
# One kernel per latent dimension; length(KernList) sets D further down.
# Alternative kernel sets that were tried (periodic / linear):
# KernList = [PeriodicKernel(1.0, convert(Float64, f1)) for f1 in collect(2:2:10)]
# KernList = [LinearKernel(1.0); KernList]
# Three Gaussian kernels with different parameters (presumably different time scales — confirm
# against the MLKernels `GaussianKernel` parameterisation).
KernList = [GaussianKernel(200.0), GaussianKernel(50.0), GaussianKernel(10.0)]

In [None]:
# Initial parameters: N = 30 trials, M = 20 neurons, D = length(KernList), 8 centres per latent.
# NOTE(review): N and M are hard-coded; `Cost` indexes data[n][m] for n in 1:N, m in 1:M,
# so the loaded data must contain at least 30 trials and 20 retained neurons — confirm.
# θ = create_rand_params(5,3,length(KernList),repmat([1],length(KernList)));
θ = create_rand_params(30,20,length(KernList),[8,8,8]);

In [None]:
# Time one full evaluation of the cost at the initial parameters
@time Cost(data, θ, KernList)

In [None]:
# Coordinate descent functions
"""
    opt_αn!(x, n, θ, KernList)

Objective for the α-step of coordinate descent. Writes the flat vector `x` into
`θ.α[n,d]` for `d = 1:D` (consecutive chunks of length `J[d]`), then returns the cost
restricted to trial `n`. Mutates `θ` and reads the global `data`.
"""
function opt_αn!(x, n, θ::PP_params, KernList)
    C, α, u, N, M, D, J = name_params(θ)
    offset = 0
    for d = 1:D
        θ.α[n,d] = x[(offset+1):(offset+J[d])]
        offset += J[d]
    end
    trial_cost = Cost(data, θ, KernList, slicing=["n", n])
    return trial_cost[1]
end

"""
    opt_Cm!(x, m, θ, KernList)

Objective for the C-step of coordinate descent. Overwrites row `m` of `θ.C` with `x`
and returns the cost restricted to neuron `m`. Mutates `θ` and reads the global `data`.
"""
function opt_Cm!(x, m, θ::PP_params, KernList)
    row_shape = size(θ.C[m,:])
    θ.C[m,:] = reshape(x, row_shape)
    neuron_cost = Cost(data, θ, KernList, slicing=["m", m])
    return neuron_cost[1]
end

"""
    opt_ud!(x, d, θ, KernList)

Objective for the u-step of coordinate descent. Replaces the kernel centres `θ.u[d]`
by `x` (reshaped to the current shape of `θ.u[d]` when `x` has more than one element)
and returns the full, unsliced cost. Mutates `θ` and reads the global `data`.
"""
function opt_ud!(x, d, θ::PP_params, KernList)
    θ.u[d] = length(x) > 1 ? reshape(x, size(θ.u[d])) : x
    full_cost = Cost(data, θ, KernList)
    return full_cost[1]
end


"""
    opt_all!(x, θ, KernList)

Objective for joint optimisation of all parameters. The flat vector `x` is laid out
as `[C[:]; α[1,1]; ...; α[1,D]; α[2,1]; ...; α[N,D]; u[1]; ...; u[D]]` and is written
into `θ` before the full cost is evaluated. Mutates `θ` and reads the global `data`.

Returns the first element of the cost, like the other `opt_*!` objectives, so the
optimiser always receives a scalar.
"""
function opt_all!(x, θ::PP_params, KernList)
    C, α, u, N, M, D, J = name_params(θ)
    pos = 1

    # Set C
    θ.C = reshape(x[pos:(pos+M*D-1)], M, D)
    pos += M*D

    # Set α
    for n = 1:N
        for d = 1:D
            θ.α[n,d] = x[pos:(pos+J[d]-1)]
            pos += J[d]
        end
    end

    # Set u
    for d = 1:D
        θ.u[d] = x[pos:(pos+J[d]-1)]
        pos += J[d]
    end

    return Cost(data, θ, KernList)[1]
end
    

In [None]:
# θ_orig is another name for the same object as θ (no copy is made);
# θ_opt is an independent deep copy that the optimisation below mutates.
θ_orig = θ
θ_opt = deepcopy(θ)
# Bind the sizes (and references to θ_opt's C, α, u) as globals for the cells below
C, α, u, N, M, D, J = name_params(θ_opt);

In [None]:
# Run the coordinate descent: three sweeps, each optimising the kernel weights α
# trial by trial and then the loading matrix C row by row.
# NOTE(review): the objectives mutate θ_opt on every evaluation, so after `optimize`
# returns θ_opt holds the last point evaluated, which is presumably not guaranteed to be
# the minimiser — consider writing the optimiser's result back explicitly.
#opt_method = LBFGS()
opt_method = NelderMead()

for iters = 1:3
    @show iters, Cost(data, θ_opt, KernList)

    # α-step: start each trial from its current estimate
    @time for n=1:N
        res = optimize(x->(opt_αn!(x, n, θ_opt, KernList)), vcat(θ_opt.α[n,:]...), method=opt_method)
        #@show n, Cost(data, θ_opt, KernList), Optim.converged(res)
    end
    @show N, Cost(data, θ_opt, KernList)
    #= u-step (disabled)
    @time for d=1:D
        if length(u[d])>1
            res = optimize(x->(opt_ud!(x, d, θ_opt, KernList)), collect(θ_opt.u[d][:]), method=opt_method)
        else
            res = optimize(x->(opt_ud!(x, d, θ_opt, KernList)), 0.0, 1.0)
        end
        #@show d, Cost(data, θ_opt, KernList), Optim.converged(res)
    end
    =#
    @show D, Cost(data, θ_opt, KernList)

    # C-step: start each row from the current estimate θ_opt.C. It previously started
    # from the initial θ.C, restarting every sweep from the initial loading matrix.
    @time for m=1:M
        res = optimize(x->(opt_Cm!(x, m, θ_opt, KernList)), θ_opt.C[m,:][:], method=opt_method)
        #@show m, Cost(data, θ_opt, KernList), Optim.converged(res)
    end
    @show M, Cost(data, θ_opt, KernList)

    #= joint optimisation over all parameters (disabled)
    α_vec = vcat(θ_opt.α[1,:]...)
    for n = 2:N α_vec = vcat(α_vec, θ_opt.α[n,:]...) end
    res = optimize(x->opt_all!(x, θ_opt, KernList), vcat(θ_opt.C[:], α_vec, θ_opt.u...), method=opt_method)
    =#
end

In [None]:
# Plot, for one neuron on the selected trials: the spike times, the rate under the
# optimised parameters (θ_opt) and the rate under the initial parameters (θ_orig)
using PlotlyJS

num_neur = 1
trials = 1:30
dim_latent = 1:D

# Trace containers, one slot per trial
to_plot = Array(PlotlyJS.GenericTrace{Dict{Symbol,Any}}, size(data,1))
rate_rand_plot = Array(PlotlyJS.GenericTrace{Dict{Symbol,Any}}, size(data,1))
rate_est_plot = Array(PlotlyJS.GenericTrace{Dict{Symbol,Any}}, size(data,1))
latent_plot = Array(PlotlyJS.GenericTrace{Dict{Symbol,Any}}, size(data,1))
latent_est_plot = Array(PlotlyJS.GenericTrace{Dict{Symbol,Any}}, size(data,1))
colors = ["rgb(0,0,255)", "rgb(0,255,0)", "rgb(255,0,0)", "rgb(128,128,0)", "rgb(0,128,128)"]
for i1 = trials
    # Cycle through the five colours so a trial's raster and rate curves match
    trial_color = colors[mod(i1,5)+1]
    spike_times = collect(data[i1][num_neur])

    # Raster row: spike times on the x axis, trial index on the y axis
    to_plot[i1] = scatter(;x=spike_times, y=i1*collect(ones(size(data[i1][num_neur]))), mode="markers", marker_color=trial_color, yaxis="y2")

    # Rates on a 1 ms grid under the initial and the optimised parameters
    rate_rand = firing_rates(0:0.001:1, θ_orig, KernList, n_range=i1, d_range=dim_latent, m_range=num_neur)
    rate_est = firing_rates(0:0.001:1, θ_opt, KernList, n_range=i1, d_range=dim_latent, m_range=num_neur)

    rate_est_plot[i1] = scatter(;
        x=collect(0:0.001:1),
        y=collect(rate_est[i1, num_neur]),
        line_color = trial_color,
        yaxis = "y0"
    )

    rate_rand_plot[i1] = scatter(;
        x=collect(0:0.001:1),
        y=collect(rate_rand[i1, num_neur]),
        line_color = trial_color,
        yaxis = "y0"
    )
end

lo = Layout(;xaxis_range=[0,1])

# Three stacked panels: raster, estimated rates, initial rates
plt = plot([Plot(to_plot[trials], lo); Plot(rate_est_plot[trials],lo); Plot(rate_rand_plot[trials],lo)])

relayout!(plt, height=700)

plt

In [None]:
# Display the optimised loading matrix with the neuron index as first column
[1:M θ_opt.C]

In [None]:
# Display the centres of latent 1 in increasing order next to the
# corresponding weights of trial 1
inds = sortperm(θ_opt.u[1])
[θ_opt.u[1][inds] θ_opt.α[1,1][inds]]

In [None]:
# Freq power for neuron num_neur
# Combines row num_neur of `C` with the per-latent weight vectors of trial 5.
# NOTE(review): `C` is the global bound by the earlier `name_params(θ_opt)` call —
# confirm it still refers to the intended parameter set when this cell is run.
num_neur = 1
C[num_neur,:]*θ_opt.α[5,:][:]

In [None]:
# Inspect the optimised weights of trial 3, latent 1
θ_opt.α[3,1]

In [None]:
# Inspect the kernel centres of latent 1 in the optimised parameters
collect(θ_opt.u[1])

In [None]:
# Inspect the kernel centres of latent 1 in the initial parameters, for comparison
θ.u[1]

In [None]:
# Save the optimised parameters, the data, the initial parameters and the kernel list
using JLD
save("MotorHam1_results_kernel_big_closedform_onlyalpha.jld", "theta_opt", θ_opt, "data", data, "theta", θ, "KernList", KernList)

In [None]:
# Load the PoissonProcessEstimation code from its notebook.
# NOTE(review): the `name_params` method defined earlier names
# `PoissonProcessEstimation.RKHS_params` in its signature, so this cell presumably
# has to run before that definition — confirm the intended cell order.
using NBInclude
nbinclude("PoissonProcessestimation.jl/PoissonProcessEstimation.ipynb")
import PoissonProcessEstimation

In [None]:
# Optimise with the module implementation, starting from the initial C and u
# and the weights of the first 5 trials only.
# NOTE(review): `data` is passed unsliced while α covers 5 trials — confirm that
# `optimise!` restricts itself to the trials present in the parameters.
θ_opt2 = deepcopy(PoissonProcessEstimation.RKHS_params(θ_orig.C, θ_orig.α[1:5,:], θ_orig.u))
PoissonProcessEstimation.optimise!(data, θ_opt2, KernList)

In [None]:
# Keep the coordinate-descent result under a new name and make the module's
# result the current θ_opt, so the cells that read θ_opt can be re-run on it
θ_opt_slow =θ_opt
θ_opt = θ_opt2