In [3]:
using OrdinaryDiffEq
using LinearAlgebra
using DiffEqSensitivity
using Optim
using Zygote
using DelimitedFiles
using TimerOutputs
using BenchmarkTools

In [6]:
" Reshape a vector of parameters into the variables we know. "
@views function reshapeParams(p)
    w = reshape(p[1:6889], (83, 83))
    alpha = p[6890:6972]
    eps = p[6973:7055]
    
    @assert length(eps) == 83
    @assert length(alpha) == 83
    
    return w, alpha, eps
end

" Melt the variables back into a parameter vector. "
function unshapeParams(w, alpha, eps)
    return vcat(vec(w), alpha, eps)
end

" The ODE equations we're using. "
function ODEeq(du, u, p, t)
    w, alpha, eps = reshapeParams(p)
    
    du .= eps .* (1 .+ tanh.(w * u)) .- alpha .* u
end

" Solve the ODE system. "
function solveODE(ps)
    u0 = zeros(83)
    tspan = (0.0, 10000.0)
    prob = ODEProblem(ODEeq, u0, tspan, ps)
    sol = last(solve(prob, AutoTsit5(TRBDF2()); saveat = tspan[2], reltol=1e-8, abstol=1e-8))
    return sol
end

ps = ones(9408)
solveODE(ps);

In [4]:
" Load the experimental data matrix. "
function get_data(path_RNAseq)
    # Import RNAseq data as 83 x 84 matrix preprocessed using python
    exp = DelimitedFiles.readdlm(path_RNAseq, ',', Float64)
    exp = Matrix(exp)
    return exp
end

get_data

In [41]:
" Remove the effect of one gene across all others to simulate the KO experiments. "

function simKO(pIn, geneNum)
    pIn = copy(pIn) # Need to copy as we're using views
    w_temp, alpha, eps = reshapeParams(pIn)
    
    # Think we remove a column since this is the effect of one gene across all genes
    if geneNum == 1
        w = hcat(zeros(Float64, 83, 1), w_temp[:, 2:83])
    elseif geneNum == 83
        w = hcat(w_temp[:, 1:82], zeros(Float64, 83, 1))
    else
        w= hcat(w_temp[:, 1:geneNum - 1], zeros(Float64, 83, 1), w_temp[:, geneNum + 1:83])
    end
    
    pIn = unshapeParams(w, alpha, eps)
    
    return solveODE(pIn)
end

simKO(ps, 1)

83-element Array{Float64,1}:
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 ⋮
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0
 2.0

In [None]:
" Cost function. Returns SSE between model and experimental RNAseq data. "


function cost(pIn)
    exp_data = get_data("./de/data/exp_data.csv")
    sse = 0
    for i = 1:83
        sol_temp = simKO(pIn, i)
        sse += sum((sol_temp .- exp_data[:, i]) .^ 2)
    end
    neg = solveODE(pIn)
    sse += sum((neg .- exp_data[:, 84]) .^ 2)
    return sse
end

# We can also take the gradient of the cost
@btime grads = Zygote.gradient(cost, ps)

In [None]:
function g!(G, x)
    grads = Zygote.gradient(cost, x)
    G[:] .= grads[1]
end

#optimize(cost, g!, ps, LBFGS(), Optim.Options(iterations = 10, show_trace = true))

In [14]:

to = TimerOutput()


function simKO(pIn, geneNum)
    pIn = copy(pIn) # Need to copy as we're using views
    @timeit to "reshape" w_temp, alpha, eps = reshapeParams(pIn)
    
    # Think we remove a column since this is the effect of one gene across all genes
    if geneNum == 1
        @timeit to "hcat1" w = hcat(zeros(Float64, 83, 1), w_temp[:, 2:83])
    elseif geneNum == 83
        @timeit to "hcat2" w = hcat(w_temp[:, 1:82], zeros(Float64, 83, 1))
    else
        @timeit to "hcat3" w= hcat(w_temp[:, 1:geneNum - 1], zeros(Float64, 83, 1), w_temp[:, geneNum + 1:83])
    end
    
    @timeit to "unshape" pIn = unshapeParams(w, alpha, eps)
    
    @timeit to "solve" return solveODE(pIn)
end

function cost(pIn)
    exp_data = get_data("./de/data/exp_data.csv")
    sse = 0
    for i = 1:83
        @timeit to "loop1" sol_temp = simKO(pIn, i)
        @timeit to "loop2" sse += sum((sol_temp .- exp_data[:, i]) .^ 2)
        end
    end
    neg = solveODE(pIn)
    sse += sum((neg .- exp_data[:, 84]) .^ 2)
    return sse
end

cost(ps)
show(to)


[0m[1m ────────────────────────────────────────────────────────────────────[22m
[0m[1m                     [22m        Time                   Allocations      
                     ──────────────────────   ───────────────────────
  Tot / % measured:       806ms / 74.6%            161MiB / 93.7%    

 Section     ncalls     time   %tot     avg     alloc   %tot      avg
 ────────────────────────────────────────────────────────────────────
 loop1           83    585ms  97.4%  7.05ms    149MiB  98.8%  1.79MiB
   solve         83    569ms  94.7%  6.86ms    130MiB  86.0%  1.56MiB
   hcat3         81   5.97ms  0.99%  73.7μs   8.56MiB  5.68%   108KiB
   unshape       83   3.17ms  0.53%  38.2μs   4.49MiB  2.98%  55.3KiB
   hcat2          1    106μs  0.02%   106μs    108KiB  0.07%   108KiB
   reshape       83   88.0μs  0.01%  1.06μs   16.9KiB  0.01%     208B
   hcat1          1   66.4μs  0.01%  66.4μs    108KiB  0.07%   108KiB
 loop2        6.89k   15.9ms  2.65%  2.31μs   1.79MiB  1.19%   