In [None]:
using Statistics

In [None]:
#Used for loading Phys data
#PyCall bridges to Python; the Python package `pyabf` is imported through it
using PyCall
#Handle to the imported Python module; later cells call `pyABF.ABF(path)` on it
pyABF = pyimport("pyabf")

In [None]:
#using DifferentialEquations
using Plots, Colors, LaTeXStrings, StatsPlots
using Plots.Measures

#Fonts shared by every figure, passed to the PyPlot backend as defaults
font_title = Plots.font("Arial", 24)
font_axis = Plots.font("Arial", 12)
font_legend = Plots.font("Arial", 8)
pyplot(; titlefont = font_title, guidefont = font_axis, legendfont = font_legend)

#LaTeX symbols used in the plotting (the L"..." literal needs no backslash escaping)
delta = L"$\delta$"
micro = L"$\mu$"
vt = L"$V_t$"
nt = L"$N_t$"
bt = L"$B_t$"
wt = L"$W_t$"
i_ach = L"$I_{ACh}$"

#One color per plotted quantity; phys_color is used for experimental traces
v_color = :deepskyblue
n_color = :magenta
c_color = :green
a_color = :purple
b_color = :red
e_color = :blue
w_color = :gray
phys_color = :black

#Folder that figures are saved to
figure_path = "C:/users/mtarc/JuliaScripts/RetinalChaos/Notebooks/Figures"

In [None]:
#Logging for 2D simulations
#A TerminalLogger is installed as the global logger; the 2D solves below pass `progress = true`
using Logging, TerminalLoggers
global_logger(TerminalLogger());

In [None]:
#Import all the functions for extracting parameters
import RetinalChaos: read_JSON, extract_dict 
import RetinalChaos: tar_conds, tar_pars
import RetinalChaos: calculate_threshold, count_intervals, timescale_analysis
import RetinalChaos: SOSRI, SDEProblem, solve
import RetinalChaos: T_sde, noise_2D
import RetinalChaos: parse_abf, extract_abf
import RetinalChaos: Network

## Table of contents:

### [0] Introduction to RetinalChaos.jl

### [1] Methods
#### [1.1] Voltage and Potassium gating components of the model
#### [1.2] Calcium and the Biochemical Reactions of the sAHP
#### [1.3] Acetylcholine Diffusion and Dynamics
#### [1.4] Gaussian White noise and dynamics

### _**[2] Analyzing Data Output of the model**_
#### [2.1] Time Scale Analysis
#### [2.2] **Fitting Experimental Data (Patch)**
#### [2.3] Fitting Experimental Data (Multi-electrode array)
#### [2.4] Fitting Experimental Data (Calcium Imaging) 

### [3] Experiments
#### [3.1] Dual Eye Correlations
#### [3.2] Recapitulations of older papers
#### [3.x] Dynamical Analysis

### Optimizing current Models

- We have so far established an acceptable model to characterize the neuronal firing of starburst amacrine cells, and established a method to quantify the timescales within the cell's firing. 
- Our goal is to measure how closely our model and its quantification match actual data, including whole-cell patch clamp recordings. 

##### We can use data collected from whole cell patch clamp to quantify how closely our simulated data matches the physiological recordings 

### [2.2.a] Loading Physiological Data from .abf files
- All physiological traces were obtained in ClampEx
- The first thing we have to do is to load the .abf file.

In [None]:
target_folder = "D:\\2019_Renna Lab\\Data\\Patching\\"
target_file = "D:\\Data\\Patching\\2019_11_03_Patch\\Animal_2\\Cell_3\\19n03042.abf"
#Load experimental data
exp_data = pyABF.ABF(target_file)
println("File successfully loaded")
#Extract the time trace and derive the sampling interval and acquisition rate from it
t = exp_data.sweepX
#Difference of consecutive time stamps, so the interval stays correct even if t[1] is not 0
dt_exp = t[2] - t[1]
aq = 1/dt_exp
println("Data from time stamp $(t[1]) s to $(t[end]+dt_exp) s with dt = $dt_exp s")
println("Data was acquired at $aq Hz")
println("$(length(t)) data points")

##### Analyzing the data before loading into Julia. 
- It may be helpful to look over data before loading it into the Julia interface. 
- With 300 s at 20 kHz you will have 6 million data points. This can cause memory overhead issues. 
- In this specific recording, we only care about 100s (from 130-230s). 

In [None]:
#In order to account for Junction potentials we can add an offset
v_offset = -25.0
#Extract the 1st sweep 1st channel from the .abf file
exp_data.setSweep(sweepNumber = 0, channel = 0);
#Vm is the raw data
vm = Float64.(exp_data.sweepY);
#We reduce the data to 100k data points between 130-230s for memory saving and graphing.
#round(Int, x) is used for the window bounds because 130/dt_exp is a floating point
#quotient and Int(x) throws an InexactError whenever it is not exactly integral.
reduced = round.(Int, LinRange(round(Int, 130/dt_exp), round(Int, 230/dt_exp), 100_000));
#For some graphing purposes, we reset the timestamps (keeping t as the original)
#The factor 1000 converts the time stamps from s to ms
t_offset = (t[reduced] .- t[reduced][1]).*1000;
#The new interval should be 1ms or 1kHz
dt_sim = round(t_offset[2]-t_offset[1], digits = 2)
println("Reduced data acquisition $dt_sim")

In [None]:
#We have two traces:
#1) Normal for data analysis, but slow graphing with the original interval.
#   round(Int, x) avoids the InexactError that Int(x) raises when the floating point
#   quotient 130/dt_exp (or 230/dt_exp) is not exactly integral.
Vm_trace = vm[round(Int, 130/dt_exp):round(Int, 230/dt_exp)].+v_offset;
#2) Reduced for quicker graphing with an interval of 1ms
Vm_trace_RED = vm[reduced].+v_offset;

In [None]:
#Reduced experimental trace; t_offset is in ms, so divide by 1000 to plot in seconds
plot(
    t_offset / 1000, Vm_trace_RED;
    xlabel = "Time (s)",
    ylabel = "Voltage (mV)",
    c = :black,
)

### [2.2.b] Timescale analysis
- As described in the previous notebook, we can extend the quantification methods used to analyze the simulation. 

In [None]:
#Threshold the full-rate trace, then measure the intervals between supra-threshold samples
threshold = calculate_threshold(Vm_trace);
println("The spiking threshold = $(round(threshold, digits = 2)) mV")
spike_phys = Vm_trace .> threshold;
#dt_exp is in s; the factor 1000 expresses the intervals in ms
intervals = count_intervals(spike_phys) .* dt_exp .* 1000;
println("$(length(intervals)+1) spiking events have been detected")

In [None]:
#Longest interval in `intervals` (same units as `intervals`)
maximum(intervals)

In [None]:
#Tick positions are placed at log(x) so the log-transformed histogram reads in raw ms
xstops = [1, 5, 10, 50, 100, 500, 1000, 5000, 10000, 50000];
xticks = (map(log, xstops), xstops);
histogram(
    map(log, intervals);
    yaxis = :log,
    xticks = xticks,
    ylabel = "Log Magnitude",
    xlabel = "Interspike Interval (ms)",
)

In [None]:
#The imported timescale_analysis runs the whole analysis in one call.
#dt is passed as 1000 * dt_exp, i.e. the sampling interval converted from s to ms.
ts_analysis = timescale_analysis(Vm_trace; dt = 1000 * dt_exp, mode = 2);

In [None]:
#Display the value returned by timescale_analysis in the previous cell
ts_analysis

### [2.2.c] Visual comparison with simulations. 
- Using the process defined in the previous notebook, we can compare our physiologically obtained data to data that was simulated

In [None]:
#Model parameters and initial conditions are read from JSON files in the working directory
p = read_JSON("params.json") |> extract_dict;
u0 = read_JSON("conds.json") |> extract_dict;
#Save interval of the solution (passed as `saveat` below)
dt = 1.0
tspan = (0.0, 300e3);
SDEprob = SDEProblem(T_sde, u0, tspan, p)
#Report the simulated span derived from tspan (the old message hard-coded "200ms").
#NOTE(review): assumes model time is in ms, as the /1000 conversions to seconds elsewhere
#in this notebook suggest.
println("Time it took to simulate $((tspan[2] - tspan[1])/1000) s:")
@time SDEsol = solve(SDEprob, SOSRI(), abstol = 2e-2, reltol = 2e-2, maxiters = 1e7, saveat = dt); 
#Transposed so that each column is one state variable; column 1 is used as voltage below
trace = Array(SDEsol)';

In [None]:
##Here we can plot the traces against each other, they will not always align well.  
xlims = (0, t_offset[end]); ylims = (-90.0, 0.0)
#Plotted span in seconds (xlims are in ms)
ellapsed_time = (xlims[end]-xlims[1])/1000
#Tick spacing in ms
dt_lims = 20e3
#Tick positions are in ms, tick labels show the matching value in seconds
xticks = (collect(xlims[1]:dt_lims:xlims[2]), collect(0:(dt_lims/1000):ellapsed_time)); yticks = (collect(ylims[1]:5:ylims[2]))

#The figure gets its own name: it used to be stored in `p`, which overwrote the model
#parameters loaded from params.json and broke any re-run of the SDEProblem cell.
fig_compare = plot(xlabel = "time (s)", ylabel = "membrane voltage (mV)")
plot!(fig_compare, SDEsol, vars = [:v], label = "Model data", 
    lw = 2.0, c = v_color)
plot!(fig_compare, t_offset, Vm_trace_RED, label = "Experimental Data", 
    lw = 2.0, c = phys_color,
    xlims = xlims, xticks = xticks, 
    ylims = ylims, yticks = yticks
)

In [None]:
#One-call analysis of the simulated voltage (column 1 of `trace`), followed by the
#explicit threshold -> spike mask -> interval steps for the histogram below
println("Timescale analysis for simulation")
sim_spikes, sim_bursts, sim_IBIs = timescale_analysis(trace[:, 1]; dt = dt, mode = 2);
sim_thresh = calculate_threshold(trace[:, 1])
println("The spiking threshold = $(round(sim_thresh, digits = 2)) mV")
sim_spike = trace[:, 1] .> sim_thresh;
sim_intervals = dt .* count_intervals(sim_spike);
println("$(length(sim_intervals)+1) spiking events have been detected")

In [None]:
#Same steps as for the simulation, applied to the full-rate physiological trace.
#dt_exp is in s, so it is scaled by 1000 wherever ms are expected.
println("Timescale analysis for physiological trace")
exp_spikes, exp_bursts, exp_IBIs = timescale_analysis(Vm_trace; dt = dt_exp*1000, mode = 2);
exp_thresh = calculate_threshold(Vm_trace)
println("The spiking threshold = $(round(exp_thresh, digits = 2)) mV")
exp_spike = Vm_trace .> exp_thresh;
exp_intervals = count_intervals(exp_spike) .* dt_exp .* 1000;
println("$(length(exp_intervals)+1) spiking events have been detected")

In [None]:
#Overlay the log-transformed interval distributions of experiment and simulation
xstops = [0.01, 0.1, 1, 10, 100, 1000, 10000, 50000];
xticks = (map(log, xstops), xstops);
histogram(map(log, exp_intervals); yaxis = :log, label = "Experimental Data", c = :gray)
histogram!(
    map(log, sim_intervals);
    yaxis = :log,
    label = "Simulation Data",
    c = v_color,
    xticks = xticks,
    ylabel = "Log Magnitude",
    xlabel = "Interspike Interval (ms)",
)

### [2.2.d] We can analyze many files in an automated fashion. 
- We need to be able to parse through a file hierarchy and grab not only the patch traces, but also the regions where the patch traces are decent. 
- Jordan has indicated which recordings are decent. This could help eliminate several traces which are not going to be good. 

In [None]:
#Folder that parse_abf is pointed at
target_folder = "D:\\Data\\Jordans_Patch_Data\\Starburst Recordings\\"
#Single example recording (not used in this cell)
target_file = "D:\\Data\\Patching\\2019_11_03_Patch\\Animal_2\\Cell_3\\19n03042.abf"
#`paths` is iterated in the batch-analysis cell below
@time paths = parse_abf(target_folder);

In [None]:
"""
    extract_abf(abf_path; verbose = false, v_offset = -25.0)

Read channel 0 of the `.abf` file at `abf_path` through `pyABF` and return the tuple
`(data, time, dt)`.

A path consisting of a single component is resolved against `pwd()`; any other path is
used as given. Every sweep in `sweepList` is read in order and concatenated into one
trace. Each sweep's `sweepX` is shifted so that it continues one sample after the end of
the previous sweep (previously the boundary time stamp was duplicated).

# Keywords
- `verbose`: when `true`, print the resolved path, the time range, the sampling
  interval, the sampling rate and the number of samples.
- `v_offset`: accepted but currently unused; the returned data is NOT shifted.
  NOTE(review): confirm whether the offset should be applied here.

# Returns
- `data::Vector{Float64}`: concatenated `sweepY` values.
- `time::Vector{Float64}`: matching time stamps built from `sweepX`
  (presumably seconds, as the verbose output labels them; confirm against pyABF).
- `dt`: `time[2] - time[1]` multiplied by 1000 (labelled ms in the verbose output).

# Throws
- `ArgumentError` if fewer than two samples were read.
"""
function extract_abf(abf_path; verbose = false, v_offset = -25.0)
    #A bare file name is looked up in the current working directory
    full_path = length(splitpath(abf_path)) > 1 ? abf_path : joinpath(pwd(), abf_path)
    #extract the abf file by using pyABF
    exp_data = pyABF.ABF(full_path)

    data = Float64[]
    timestamps = Float64[]
    sweep_start = 0.0
    #If the data is segmented into sweeps (which Jordans data is) all sweeps are
    #concatenated; a single-sweep file simply runs this loop once.
    for sweepNumber in exp_data.sweepList
        exp_data.setSweep(sweepNumber = sweepNumber, channel = 0)
        sweep_time = exp_data.sweepX
        #append! instead of push!(x, y...): splatting whole sweeps into varargs is very slow
        append!(data, exp_data.sweepY)
        append!(timestamps, sweep_time .+ sweep_start)
        #The next sweep starts one sampling step after the last sample of this one
        step = length(sweep_time) > 1 ? sweep_time[2] - sweep_time[1] : 0.0
        sweep_start = isempty(timestamps) ? 0.0 : timestamps[end] + step
    end

    length(timestamps) > 1 ||
        throw(ArgumentError("$full_path contains fewer than two samples"))
    #Sampling interval, scaled by 1000 (s -> ms)
    dt = (timestamps[2] - timestamps[1]) * 1000

    if verbose
        #These lines used to read the global `t` instead of the local time vector,
        #added dt in ms to a time in s, and computed the rate as 1/dt/1000.
        println("Data extracted from $full_path")
        println("Data from time stamp $(timestamps[1]) s to $(timestamps[end] + dt/1000) s with dt = $dt ms")
        println("Data was acquired at $(1000/dt) Hz")
        println("$(length(timestamps)) data points")
    end
    return data, timestamps, dt
end

In [None]:
#One entry per recording: mean spike duration, mean burst duration and mean IBI
spike_durs = Float64[]
burst_durs = Float64[]
IBIs = Float64[]
#The loop variables deliberately do not reuse the names `dt_exp` or `ts_analysis`.
#In a notebook, assigning to an existing global inside a top-level loop overwrites it,
#which used to replace dt_exp (s) by the last file's interval in ms.
@time for path in paths
    file_trace, _, file_dt = extract_abf(path; verbose = false)
    file_stats = timescale_analysis(file_trace, dt = file_dt)
    #Elements 1, 3 and 5 are the means; 2, 4 and 6 are the matching standard deviations
    push!(spike_durs, file_stats[1])
    push!(burst_durs, file_stats[3])
    push!(IBIs, file_stats[5])
end

In [None]:
#Drop recordings whose analysis produced NaN
spike_durs = filter(!isnan, spike_durs);
burst_durs = filter(!isnan, burst_durs);
IBIs = filter(!isnan, IBIs);

This is the output from running the entire analysis on the patch files. 

In [None]:
#Mean and standard deviation across recordings; IBIs are divided by 1000 (ms -> s)
mean_spike_dur = mean(spike_durs)
std_spike_dur = std(spike_durs)
mean_burst_dur = mean(burst_durs)
std_burst_dur = std(burst_durs)
mean_IBI = mean(IBIs ./ 1000)
std_IBI = std(IBIs ./ 1000)

In [None]:
#Each panel: violin of the per-recording values, the individual points, and the mean line
fig4_a1 = violin(spike_durs, label = "", c = :gray,
    xaxis = nothing, ylabel = "Spike Duration (ms)"
)
dotplot!(fig4_a1, repeat([1.0], length(spike_durs)), spike_durs, label = "", c = :black, markersize = 3.0)
hline!(fig4_a1, [mean_spike_dur], 
    c = :black, lw = 3.0, 
    label = "$(round(mean_spike_dur, digits = 2)) ms +-  $(round(std_spike_dur, digits = 2)) ms"
)

fig4_a2 = violin(burst_durs, label = "", c = :gray, 
    xaxis = nothing, ylabel = "Burst Duration (ms)"
)
dotplot!(fig4_a2, repeat([1.0], length(burst_durs)), burst_durs, label = "", c = :black, markersize = 3.0)
hline!(fig4_a2, [mean_burst_dur], 
    c = :black, lw = 3.0,
    label = "$(round(mean_burst_dur, digits = 2)) ms +-  $(round(std_burst_dur, digits = 2)) ms"
)

#IBIs are plotted as IBIs./1000, i.e. in seconds, so the axis is labelled (s).
#It used to say (ms), which contradicted both the data and the legend.
fig4_a3 = violin(IBIs./1000, label = "", c = :gray, 
    xaxis = nothing, ylabel = "Interburst Interval (s)"
)
dotplot!(fig4_a3, repeat([1.0], length(IBIs)), IBIs./1000, label = "", c = :black, markersize = 3.0)
hline!(fig4_a3, [mean_IBI], 
    c = :black, lw = 3.0,
    label = "$(round(mean_IBI, digits = 2)) s +- $(round(std_IBI, digits = 2)) s"
)

fig4_A = plot(fig4_a1, fig4_a2, fig4_a3, layout = grid(1,3))

### [2.2.e] In order to make a good comparison with our data, we can run repeated trials.
- We can use the saved simulations from the previous notebook. 

In [None]:
#Build a square nx-by-ny network and load matching parameters / initial conditions
nx = ny = 96
SACnet = Network(nx, ny; μ = 0.65, version = :gACh)
p_dict = read_JSON("params.json");
u_dict = read_JSON("conds.json");
#Initial conditions are extracted for the (nx, ny) grid; parameters need no grid size
u0 = extract_dict(u_dict, tar_conds, (nx, ny));
p0 = extract_dict(p_dict, tar_pars);

In [None]:
#Warm-up run over (0.0, 60e3): save_everystep = false, so intermediate steps are not saved
println("Warming up solution")
prob = SDEProblem(SACnet, noise_2D, u0, (0.0, 60e3), p0);
@time sol = solve(prob, SOSRI();
    abstol = 0.2, reltol = 2e-2, maxiters = 1e7,
    progress = true, save_everystep = false,
);

In [None]:
#Restart from the last state of the warm-up run and save the solution at every 1.0 time unit
prob = SDEProblem(prob.f, prob.g, sol[end], (0.0, 60e3), prob.p);
sol = solve(prob, SOSRI();
    abstol = 0.2, reltol = 2e-2, maxiters = 1e7,
    progress = true, saveat = 1.0,
);

In [None]:
using HDF5

In [None]:
#Copy of the first state variable (index 1 of the 3rd dimension) at every saved time point
vt_sol = copy(sol[:,:,1,:]);
#The remaining state variables (indices 2-6) can be extracted the same way if needed:
#nt_sol = Array(sol[:,:,2,:]);
#ct_sol = Array(sol[:,:,3,:]);
#at_sol = Array(sol[:,:,4,:]);
#bt_sol = Array(sol[:,:,5,:]);
#et_sol = Array(sol[:,:,6,:]);

In [None]:
#Preallocated nx-by-ny-by-time buffer, filled and written to disk in the next cell
arr = zeros(nx, ny, length(sol.t));

In [None]:
#Copy the first state variable of every saved time point into the preallocated `arr`.
#The per-step println was removed because it printed one line per saved time point.
for i in eachindex(sol.t)
    arr[:, :, i] = sol[:, :, 1, i]
end
#The file is only opened once the data is ready; the do-block closes it even on error
h5open("test.h5", "w") do file
    write(file, "Vt", arr)  # alternatively, say "@write file A"
end

In [None]:
#NOTE(review): `sim` is created by the ensemble cells further down the notebook, so those
#cells have to be run before this one.
sim_spike_durs = Float64[]
sim_burst_durs = Float64[]
sim_IBIs = Float64[]
#The loop variables must not reuse mean_spike_dur, std_spike_dur, mean_burst_dur,
#std_burst_dur or trace. In a notebook, assigning to an existing global inside a
#top-level loop overwrites it, which used to replace the experimental statistics drawn
#in the figures below with those of the last trajectory.
for traj in sim
    traj_trace = Array(traj)'
    traj_spike, _, traj_burst, _, traj_ibi, _ = timescale_analysis(traj_trace[:, 1]; dt = dt)
    push!(sim_spike_durs, traj_spike)
    push!(sim_burst_durs, traj_burst)
    #IBI is divided by 1000 to match the experimental IBIs plotted in seconds
    push!(sim_IBIs, traj_ibi / 1000)
end

In [None]:
#Mean and standard deviation across simulated trajectories
sim_mean_spike_dur = mean(sim_spike_durs)
sim_std_spike_dur = std(sim_spike_durs)
sim_mean_burst_dur = mean(sim_burst_durs)
sim_std_burst_dur = std(sim_burst_durs)
sim_mean_IBI = mean(sim_IBIs)
sim_std_IBI = std(sim_IBIs)

In [None]:
#Each panel compares experiment (x = 1, gray, dashed mean) with simulation (x = 2, colored,
#solid mean): violin, individual points, then a short horizontal segment at the mean.

#Spike duration
fig4_b1 = violin(spike_durs;
    label = "", c = :gray, ylabel = "Spike Duration (ms)",
    xticks = ((1.0, 2.0), ("Experiment", "Simulation")),
)
dotplot!(fig4_b1, fill(1.0, length(spike_durs)), spike_durs;
    label = "", c = :black, markersize = 2.0)
plot!(fig4_b1, [0.5, 1.5], fill(mean_spike_dur, 2);
    c = :black, lw = 3.0, linestyle = :dash, label = "")
violin!(fig4_b1, fill(2.0, length(sim_spike_durs)), sim_spike_durs;
    c = v_color, label = "")
dotplot!(fig4_b1, fill(2.0, length(sim_spike_durs)), sim_spike_durs;
    label = "", c = :black, markersize = 5.0)
plot!(fig4_b1, [1.5, 2.5], fill(sim_mean_spike_dur, 2);
    c = :black, lw = 3.0, label = "")

#Burst duration
fig4_b2 = violin(burst_durs;
    label = "", c = :gray, ylabel = "Burst Duration (ms)",
    xticks = ((1.0, 2.0), ("Experiment", "Simulation")),
)
dotplot!(fig4_b2, fill(1.0, length(burst_durs)), burst_durs;
    label = "", c = :black, markersize = 3.0)
plot!(fig4_b2, [0.5, 1.5], fill(mean_burst_dur, 2);
    c = :black, lw = 3.0, linestyle = :dash, label = "")
violin!(fig4_b2, fill(2.0, length(sim_burst_durs)), sim_burst_durs;
    c = v_color, label = "")
dotplot!(fig4_b2, fill(2.0, length(sim_burst_durs)), sim_burst_durs;
    label = "", c = :black, markersize = 5.0)
plot!(fig4_b2, [1.5, 2.5], fill(sim_mean_burst_dur, 2);
    c = :black, lw = 3.0, label = "")

#Interburst interval (experimental IBIs are divided by 1000 to plot seconds)
fig4_b3 = violin(IBIs ./ 1000;
    label = "", c = :gray, ylabel = "Interburst Interval(s)",
    xticks = ((1.0, 2.0), ("Experiment", "Simulation")),
)
dotplot!(fig4_b3, fill(1.0, length(IBIs)), IBIs ./ 1000;
    label = "", c = :black, markersize = 3.0)
plot!(fig4_b3, [0.5, 1.5], fill(mean_IBI, 2);
    c = :black, lw = 3.0, linestyle = :dash, label = "")
violin!(fig4_b3, fill(2.0, length(sim_IBIs)), sim_IBIs;
    c = v_color, label = "")
dotplot!(fig4_b3, fill(2.0, length(sim_IBIs)), sim_IBIs;
    label = "", c = :black, markersize = 5.0)
plot!(fig4_b3, [1.5, 2.5], fill(sim_mean_IBI, 2);
    c = :black, lw = 3.0, label = "")

#Assemble the three panels side by side and letter the first one
fig4_b = plot(fig4_b1, fig4_b2, fig4_b3; layout = grid(1, 3))
title!(fig4_b[1], "A"; title_pos = :left)

In [None]:
#Histograms of experimental (gray) against simulated (colored) values, one row per measure.
#Only the middle panel carries a legend.
fig4_c1 = histogram(spike_durs, 
    c = :gray, legend = nothing, 
    xlabel = "Spike Duration (ms)", ylabel = "Count")
histogram!(fig4_c1, sim_spike_durs, c = v_color)
fig4_c2 = histogram(burst_durs, 
    c = :gray, label = "Experimental Data", 
    xlabel = "Burst Duration (ms)", ylabel = "Count")
#Legend label spelling fixed (was "Simuation Data")
histogram!(fig4_c2, sim_burst_durs, c = v_color, label = "Simulation Data")
#Experimental IBIs are divided by 1000 to plot seconds
fig4_c3 = histogram(IBIs./1000, 
    c = :gray, legend = nothing, 
    xlabel = "Interburst Interval (s)", ylabel = "Count")
histogram!(fig4_c3, sim_IBIs, c = v_color)

fig4_c = plot(fig4_c1,fig4_c2, fig4_c3, layout = grid(3,1))
#NOTE(review): this panel is lettered "A" just like fig4_b; confirm the intended letter
title!(fig4_c[1], "A", title_pos = :left)

### [2.2.f] Wave vs Isolated
- We will go more in depth (or we have gone more in depth) with the differences between wave simulation and isolated simulations. The point is however that wave simulations result in usually longer burst durations. 

### [2.2.g] Actively fitting the Data
- Using a combination of an ensemble test and data analysis, we can adjust parameters of the model to try to pick the optimal values

In [None]:
#Sweep one model parameter over n_sims evenly spaced values
n_sims = 50
par_sym = :V4
#Position of the swept parameter within T_sde.ps
par = findall(==(par_sym), Symbol.(T_sde.ps))[1]
p_range = LinRange(1.0, 14.0, n_sims)
#Per-trajectory problem modifier handed to the EnsembleProblem
function prob_func(prob, i, repeat)
    return ensemble_func(prob, i, repeat; pars = par, rng = p_range)
end
ensemble_prob = EnsembleProblem(SDEprob; prob_func = prob_func)

In [None]:
#Solve all n_sims trajectories on threads, saving only state index 1 at every dt
@time sim = solve(ensemble_prob, SOSRI(), EnsembleThreads();
    abstol = 2e-2,
    reltol = 2e-2,
    maxiters = 1e7,
    saveat = dt,
    trajectories = n_sims,
    save_idxs = [1],
);

In [None]:
vars = [:v]
#Swept parameter values and the matching mean spike / burst durations and IBIs.
#The containers are typed so the plotted vectors are Float64 rather than Any.
vals = Float64[]; e_spike_durs = Float64[]; e_burst_durs = Float64[]; e_ibis = Float64[]
for (idx, traj) in enumerate(sim)
    #Named e_trace so the global `trace` of the single-cell simulation is not overwritten
    e_trace = Array(traj)'
    try
        e_spike, _, e_burst, _, e_ibi, _ = timescale_analysis(e_trace[:, 1]; dt = dt)
        push!(vals, p_range[idx])
        push!(e_spike_durs, e_spike)
        push!(e_burst_durs, e_burst)
        #ms -> s, to match mean_IBI
        push!(e_ibis, e_ibi / 1000)
    catch err
        #Still best-effort: the trajectory is skipped, but no longer silently
        @warn "timescale_analysis failed; trajectory skipped" idx exception = err
    end
end
#The figure gets its own name so the model parameters stored in `p` are left untouched
fig_fit = plot(layout = grid(3,1))
#The horizontal lines mark the experimental means. The IBI and spike legends used to have
#their units and standard deviations swapped.
plot!(fig_fit[1], vals, e_ibis, marker = :circle, markersize = 5, lw = 3.0, ylabel = "IBI (s)")
hline!(fig_fit[1], [mean_IBI], label = "$(round(mean_IBI, digits = 2)) s +- $(round(std_IBI, digits = 2)) s")

plot!(fig_fit[2], vals, e_burst_durs, marker = :circle, markersize = 5, lw = 3.0, ylabel = "Bursts (ms)")
hline!(fig_fit[2], [mean_burst_dur], label = "$(round(mean_burst_dur, digits = 2)) ms +-  $(round(std_burst_dur, digits = 2)) ms")

plot!(fig_fit[3], vals, e_spike_durs, marker = :circle, markersize = 5, lw = 3.0, ylabel = "Spikes (ms)", xlabel = "$par_sym")
hline!(fig_fit[3], [mean_spike_dur], label = "$(round(mean_spike_dur, digits = 2)) ms +-  $(round(std_spike_dur, digits = 2)) ms")