In [1]:
#using DifferentialEquations
using Plots, Colors, LaTeXStrings, StatsPlots, Dates
using Plots.Measures
# Plot styling for the whole notebook: title 24pt, axis guides 12pt, legend 8pt,
# rendered with the PyPlot backend.
font_title = Plots.font("Arial", 24)
font_axis = Plots.font("Arial", 12)
font_legend = Plots.font("Arial", 8)
pyplot(titlefont=font_title, guidefont = font_axis, legendfont = font_legend)

#Define different symbols used in the plotting
delta = latexstring("\$\\delta\$")
micro = latexstring("\$\\mu\$")
vt = latexstring("\$V_t\$")
nt = latexstring("\$N_t\$")
bt = latexstring("\$B_t\$")
wt = latexstring("\$W_t\$")
i_ach = latexstring("\$I_{ACh}\$")

# Line colors, one per variable (v_color / e_color are used in Figure 5)
v_color = :deepskyblue
n_color = :magenta
c_color = :green
a_color = :purple
b_color = :red
e_color = :blue
w_color = :gray

# Output directory for figures and animations. The default is the original
# author's machine; set the RETINALCHAOS_FIGURES environment variable to write
# somewhere else without editing the notebook.
figure_path = get(
    ENV, "RETINALCHAOS_FIGURES",
    "C:/users/mtarc/JuliaScripts/RetinalChaos/Notebooks/Figures",
)

"C:/users/mtarc/JuliaScripts/RetinalChaos/Notebooks/Figures"

In [2]:
#If you have a GPU available, you can run the code on it using CuArrays:
#using CuArrays
#gpu = true; CuArrays.allowscalar(false)

In [3]:
#otherwise set gpu to false
# (this flag is passed as `gpu = gpu` to Network(...) when the lattice is built)
gpu = false;

In [4]:
using Logging, TerminalLoggers
# Use TerminalLogger as the global logger for every subsequent cell.
global_logger(TerminalLogger());

In [5]:
import RetinalChaos: Network, extract_dict, read_JSON
import RetinalChaos: tar_conds, tar_pars
import RetinalChaos: T_sde
import RetinalChaos: SDEProblem, noise_2D, SOSRI, solve

In [6]:
using JLD

In [7]:
# Simulation settings and lattice size (nx × ny cells).
# NOTE(review): the warm-up and recorded solves below pass their own hard-coded
# time spans, so this `tspan` is not used by them. Both `dt` and `tspan` are
# reassigned again in the isolated-SAC section further down.
dt = 10.0; tspan = (0.0, 300e3)
nx = 96; ny = 96;

In [8]:
# Build the nx × ny SAC network model (version :gACh, μ = 0.65, GPU per `gpu`)
# and load its parameters and initial conditions from the JSON files.
SACnet = Network(nx, ny; μ = 0.65, gpu = gpu, version = :gACh)
p_dict = read_JSON("params.json");
u_dict = read_JSON("conds.json");
# Initial conditions for the variables in tar_conds, presumably laid out over
# the (nx, ny) lattice — confirm against extract_dict.
u0 = extract_dict(u_dict, tar_conds, (nx, ny));
# Parameter values for the names in tar_pars.
p0 = extract_dict(p_dict, tar_pars);

In [9]:
println("Warming up solution")
# Warm-up run over t ∈ [0, 60e3] from the JSON initial conditions. Only the
# final state is needed, so intermediate steps are not stored.
warmup_tspan = (0.0, 60e3)
solver_kw = (abstol = 0.2, reltol = 2e-2, maxiters = 1e7, progress = true)
prob = SDEProblem(SACnet, noise_2D, u0, warmup_tspan, p0);
@time sol = solve(prob, SOSRI(); solver_kw..., save_everystep = false);
# Start a fresh test.jld ("w" truncates) holding the final warm-up state as "0.0".
jldopen("test.jld", "w") do jld
    jld["0.0"] = sol[end]
end;
println("file saved")

Warming up solution


[32mSDE   0%|█                                              |  ETA: N/A[39m
[32mSDE   0%|█                                              |  ETA: 0:51:29[39m
[32mSDE   1%|█                                              |  ETA: 1:00:38[39m
[32mSDE   1%|█                                              |  ETA: 1:07:31[39m
[32mSDE   1%|█                                              |  ETA: 1:12:05[39m
[32mSDE   1%|█                                              |  ETA: 1:16:29[39m
[32mSDE   1%|█                                              |  ETA: 1:19:18[39m
[32mSDE   1%|█                                              |  ETA: 1:22:09[39m
[32mSDE   1%|█                                              |  ETA: 1:23:54[39m
[32mSDE   1%|█                                              |  ETA: 1:25:25[39m
[32mSDE   1%|█                                              |  ETA: 1:26:39[39m
[32mSDE   2%|█                                              |  ETA: 1:27:08[39m
[32mSDE   2%|█     

927.519534 seconds (113.01 M allocations: 11.831 GiB, 0.27% gc time)



[90mSDE 100%|███████████████████████████████████████████████| Time: 0:15:06[39m


file saved


In [10]:
# Continue from the final warm-up state (stored as "0.0" in test.jld): this run
# starts at t = 1.0 and keeps a sample every 1.0 time unit.
# NOTE(review): the span ends at 30e3 while the markdown below mentions 300 s —
# confirm which length is intended.
record_tspan = (1.0, 30e3)
prob = SDEProblem(prob.f, prob.g, sol[end], record_tspan, prob.p);
sol = solve(prob, SOSRI();
    abstol = 0.2, reltol = 2e-2, maxiters = 1e7, progress = true, saveat = 1.0,
);

[32mSDE   0%|█                                              |  ETA: N/A[39m
[32mSDE   4%|██                                             |  ETA: 0:05:40[39m
[32mSDE   7%|████                                           |  ETA: 0:04:49[39m
[32mSDE   9%|█████                                          |  ETA: 0:05:09[39m
[32mSDE  10%|█████                                          |  ETA: 0:05:34[39m
[32mSDE  12%|██████                                         |  ETA: 0:06:02[39m
[32mSDE  13%|██████                                         |  ETA: 0:06:27[39m
[32mSDE  14%|███████                                        |  ETA: 0:06:53[39m
[32mSDE  14%|███████                                        |  ETA: 0:07:15[39m
[32mSDE  15%|████████                                       |  ETA: 0:07:36[39m
[32mSDE  16%|████████                                       |  ETA: 0:07:54[39m
[32mSDE  17%|████████                                       |  ETA: 0:08:11[39m
[32mSDE  17%|██████

In [11]:
#Save the solution: one dataset per saved timepoint (keyed by its time), plus
#the full time axis under "time" (0.0 for the warm-up state, then sol.t).
# Reading the stored states from sol.u avoids calling the interpolant sol(t),
# which builds a full state array for every point only to slice it.
# NOTE(review): "0.0" (written in the warm-up cell) holds the full state, while
# these entries hold only the first variable slice [:, :, 1].
jldopen("test.jld", "r+") do file
    for (t, u) in zip(sol.t, sol.u)
        file["$t"] = u[:, :, 1]
    end
    file["time"] = [0.0; sol.t]
end;

### Instead of using the solution directly from memory, we can save it to a JLD file (built on HDF5). This preserves the types and structures within the arrays
- In order to write the solution to a JLD file, we save the individual timesteps as separate datasets
- The stepsize should be 1.0 ms, and the length of time can be 300s

In [12]:
# Read back only the time axis stored in test.jld; `data` is that time vector.
data = jldopen(f -> read(f, "time"), "test.jld", "r");

In [None]:
# `data` already holds the "time" vector read from test.jld (not the file
# contents), so display it directly; indexing a Vector with "time" would throw.
data

### [1.5.b] Graphing and visualization

In [None]:
import RetinalChaos: calculate_threshold, get_timestamps, max_interval_algorithim, extract_burstmap, timescale_analysis

In [None]:
# NOTE(review): `et_sol` (like vt_sol/ct_sol/bt_sol) is not assigned in any
# visible cell — presumably a per-variable slice of `sol` created elsewhere.
maximum(et_sol)

In [None]:
# Binarize vt_sol (presumably V_t over the lattice) into a Boolean spike array
# using the threshold from calculate_threshold.
thresh = calculate_threshold(vt_sol)
spike_array = vt_sol .> thresh;

In [None]:
# Ten evenly spaced sample sites along each lattice axis (rounded to indices);
# used as star markers and trace locations in the plots below.
x_locs = [round(Int, x) for x in LinRange(1, nx, 10)];
y_locs = [round(Int, y) for y in LinRange(1, ny, 10)];

In [None]:
# NOTE(review): burst_map is recomputed identically in the analysis section and
# is not used in this cell.
burst_map = extract_burstmap(spike_array);

# Lattice snapshot panels for the 2x2 grid: (array, colormap, color limits).
lattice_panels = (
    (vt_sol, :curl, (-70.0, 0.0)),
    (ct_sol, :kgy, (0.0, 1.0)),
    (et_sol, :bgy, (0.0, 3.0)),
    (bt_sol, :reds, (0.0, 1.0)),
)
# Time-course rows for the 3x1 grid: (array, colormap).
trace_rows = ((vt_sol, :curl), (ct_sol, :kgy), (bt_sol, :reds))

@time anim = @animate for i = 1:10:size(sol, 4)
    println("Animating frame $i")

    # Lattice snapshots at frame i, with the sample sites starred.
    p1 = plot(layout = grid(2, 2), size = (1000,1000))
    for (k, (arr, cmap, lims)) in enumerate(lattice_panels)
        heatmap!(p1[k], arr[:, :, i], ratio = :equal, grid = false,
            xticks = ([],[]), yticks = ([],[]), xlims = (0, nx), ylims = (0, ny),
            c = cmap, clims = lims,
        )
        scatter!(p1[k], x_locs, y_locs, marker = :star, c = :yellow, label = "")
    end

    # Full time courses at every sample site, colored by site index. The first
    # site uses the same colormap as the others; the black line marks frame i.
    p2 = plot(layout = grid(3, 1), size = (500,150))
    for (k, (arr, cmap)) in enumerate(trace_rows)
        for site in eachindex(x_locs)
            plot!(p2[k], arr[x_locs[site], y_locs[site], :],
                c = cmap, line_z = site, label = "")
        end
        vline!(p2[k], [i], label = "", lw = 3.0, c = :black)
    end
    p = plot(p2, p1, layout = grid(2,1), size = (1000,1000))
end

In [None]:
# Export the lattice animation as MP4 and GIF (10 fps) into figure_path.
mp4(anim, "$(figure_path)/wave_propagation.mp4", fps = 10)
gif(anim, "$(figure_path)/wave_propagation.gif", fps = 10)

In [None]:
# Nine evenly spaced time indices from 2700 to 3200 (rounded to Int); used as
# snapshot indices into et_sol and sol.t for the 3x3 wave montage.
wave_locs = round.(Int, LinRange(2700,3200, 9));

In [None]:
# 3x3 montage of the et_sol lattice at the time indices in wave_locs; only the
# sixth panel draws a colorbar.
grid_p = plot(layout = grid(3,3), xaxis = "", yaxis = "",  margin = 1mm, size = (500, 500))
for (idx, tstep) in enumerate(wave_locs)
    heatmap!(grid_p[idx], et_sol[:, :, tstep], ratio = :equal, grid = false,
        xticks = ([],[]), yticks = ([],[]), xlims = (0, nx), ylims = (0, ny),
        c = :delta, clims = (0.0, 3.0), cbar = (idx == 6)
    )
    # Time stamp: sol.t divided by 1000 (presumably ms -> s).
    annotate!(grid_p[idx], [48], [5], "t = $(sol.t[tstep]/1000)s", :white)
end
# Panel label, using the same `title_loc` keyword as the other figure panels.
title!(grid_p[1], "C", title_loc = :left)

### [1.5.c] Analysis of the wave Data
- The output of the wave simulation is very similar to the output of the trace simulation, except it is in 3D (x, y, d(Var)) with respect to time (t)
- Dispatches of count_interval are available to get all the intervals within the grid. 
    - This function, however, discards information about the x and y location and adds all the intervals to a single list
- Dispatches of get_timestamps and max_interval_algorithim are available
    - These return a tuple with (x, y, data) 
    - for get_timestamps, data is the timestamps
    - for max_interval_algorithim, data is:
        - 1) Burst Timestamps
        - 2) Duration list
        - 3) Spike per burst list
- A dispatch of timescale_analysis is also available; much like count_interval, all spatial information is lost in the process of conducting the timescale analysis. 
    - In order to conduct more precise spike duration analysis, a higher resolution needs to be used, which can end up very memory consuming. 

In [None]:
#get_timestamps() returns (x, y, timestamps)
# Computed from the Boolean spike array; @time reports how long it takes.
timestamp_data = @time get_timestamps(spike_array);

In [None]:
#max_interval_algorithim() returns (x, y, [burst_timestamps, durations, spikes_per_burst and interburst interval])
burst_data = max_interval_algorithim(spike_array);

In [None]:
#This function breaks down the spike array into bursts according to the max interval algorithim
# NOTE(review): identical to the call in the animation cell; the result is not
# used in any visible cell.
burst_map = extract_burstmap(spike_array);

In [None]:
using Statistics, StatsBase

In [None]:
# Timescale analysis of the lattice voltage. As unpacked in the summary cell
# further down: mode = 2 yields the (spike, burst, IBI) duration lists, and
# mode = 1 yields six values (presumably mean and std for each of those three).
ts_lattice = timescale_analysis(vt_sol; verbose = 1, mode = 2);
ts_lattice_vals = timescale_analysis(vt_sol; mode = 1);

#### Comparing to 1D traces

In [None]:
# Parameters and initial conditions for the isolated-SAC (1D trace) run below.
# NOTE(review): this reassigns the globals `u0`, `dt` and `tspan` that the
# lattice cells also use. Any later cell reading them (e.g. the parameter sweep)
# sees these single-cell values if it is run after this one.
p = read_JSON("params.json") |> extract_dict;
u0 = read_JSON("conds.json") |> extract_dict;
dt = 1.0
tspan = (0.0, 500e3);

In [None]:
# Simulate one isolated SAC (T_sde model) over `tspan`, saving every `dt`.
SDEprob = SDEProblem(T_sde, u0, tspan, p)
# Report the span actually simulated. It is derived from tspan (presumably ms)
# rather than hard-coded.
println("Time it took to simulate $((tspan[2] - tspan[1]) / 1000)s:")
@time SDEsol = solve(SDEprob, SOSRI(), abstol = 2e-2, reltol = 2e-2, maxiters = 1e7, saveat = dt); 
# Materialized so that rows = saved time points and columns = state variables.
trace = permutedims(Array(SDEsol));

In [None]:
# First state variable of the isolated trace (presumably V_t; plotted in mV below).
vt_iso = trace[:, 1];

In [None]:
#Try to get these solutions to line up by adjusting this value
iso_begin = 265000
# Figure-specific names: globals called `xlims`/`xticks` would shadow (or, once
# used, fail to overwrite) the functions exported by Plots.
fig5_xlims = (0.0, 60e3)
# Tick positions every 5000 samples, labelled in seconds.
fig5_xticks = (collect(fig5_xlims[1]:5000:fig5_xlims[2]), collect(0:5.0:fig5_xlims[2]/1000))

# Isolated-SAC voltage over a 60e3-sample window, overlaid with one lattice site.
fig5_A = plot(vt_iso[iso_begin:Int(iso_begin+60e3)],  
    label = "", ylabel = "$vt (mV)", xlabel = "Time (s)",
    lw = 2.0, c = v_color
)
plot!(fig5_A, sol.t, vt_sol[21,10,:], 
    label = "", 
    lw = 2.0, c = e_color, linestyle = :dash, 
    xticks = fig5_xticks, xlims = fig5_xlims
);
title!(fig5_A, "A", title_loc = :left);

In [None]:
#To conduct the whole analysis in a single call, use the imported function timescale_analysis
ts_iso = timescale_analysis(trace[:,1]; dt = dt, verbose = 1, mode = 2);
ts_iso_vals = timescale_analysis(trace[:,1]; dt = dt, mode = 1);
# Display the isolated burst durations (2nd element of the mode-2 result).
# `iso_burst_list` is only unpacked from ts_iso in the next cell, so it does not
# exist yet here.
ts_iso[2]

In [None]:
# Unpack the summary values (mode = 1) and the raw duration lists (mode = 2)
# for the lattice and the isolated SAC.
lat_spike_dur, lat_spike_std, lat_burst_dur, lat_burst_std, lat_IBI, lat_IBI_std = ts_lattice_vals;

iso_spike_dur, iso_spike_std, iso_burst_dur, iso_burst_std, iso_IBI, iso_IBI_std = ts_iso_vals;

lat_spike_list, lat_burst_list, lat_IBI_list = ts_lattice;

iso_spike_list, iso_burst_list, iso_IBI_list = ts_iso;

# Sample sizes behind each standard deviation.
n_lat_spike = length(lat_spike_list)
n_lat_burst = length(lat_burst_list)
n_lat_IBI = length(lat_IBI_list)

n_iso_spike = length(iso_spike_list)
n_iso_burst = length(iso_burst_list)
n_iso_IBI = length(iso_IBI_list)

# Standard error of the mean: SEM = SD / sqrt(n), not SD / n.
lat_spike_SEM = lat_spike_std / sqrt(n_lat_spike)
lat_burst_SEM = lat_burst_std / sqrt(n_lat_burst)
lat_IBI_SEM = lat_IBI_std / sqrt(n_lat_IBI)

iso_spike_SEM = iso_spike_std / sqrt(n_iso_spike)
iso_burst_SEM = iso_burst_std / sqrt(n_iso_burst)
iso_IBI_SEM = iso_IBI_std / sqrt(n_iso_IBI)

In [None]:
# Bins spanning the lattice burst durations.
# NOTE(review): `bins` is not passed to either histogram below.
bins = LinRange(minimum(ts_lattice[2]), maximum(ts_lattice[2]), 100);

# Keyword sets shared by each isolated-vs-network pair of plots.
group_ticks = ((1.0, 2.0), ("Isolated SAC", "SAC in Network"))
box_kw = (c = [v_color e_color], xticks = group_ticks, labels = ["" ""])
burst_hist_kw = (normalize = :pdf, xlims = (500,1500))
ibi_hist_kw = (normalize = :pdf, xlims = (0,60))

# Burst durations: overlaid pdf histograms (ms) and a boxplot (s).
fig5_Ba = histogram(ts_iso[2]; burst_hist_kw...,
    c = v_color, ylabel = "Probability of Burst", label = "Isolated SAC");
histogram!(fig5_Ba, ts_lattice[2]; burst_hist_kw...,
    c = e_color, xlabel = "Burst Length (ms)", label = "SAC in Network");
fig5_B1 = boxplot([iso_burst_list./1000, lat_burst_list./1000]; box_kw...,
    ylabel = "Burst Length (s)")

# Interburst intervals, divided by 1000 for display in seconds.
fig5_Bb = histogram(ts_iso[3]./1000; ibi_hist_kw...,
    c = v_color, ylabel = "Probability of Burst", label = "Isolated SAC");
histogram!(fig5_Bb, ts_lattice[3]./1000; ibi_hist_kw...,
    c = e_color, xlabel = "Interburst Interval (s)", label = "SAC in Network")
fig5_B2 = boxplot([iso_IBI_list./1000, lat_IBI_list./1000]; box_kw...,
    ylabel = "IBI Length (s)")

# Stack the two histograms as panel B, then place it beside the montage (C).
fig5_B = plot(fig5_Ba, fig5_Bb,  layout = grid(2,1))
title!(fig5_B[1], "B", title_loc = :left);
fig5_BC = plot(fig5_B, grid_p, layout = grid(1,2), size = (1000, 750));

In [None]:
# Assemble Figure 5: panel A on top (25% of the height) above the B/C row.
fig5 = plot(fig5_A, fig5_BC, layout = grid(2,1,  heights = [0.25, 0.75]), size = (1000,1000))

In [None]:
# Write Figure 5 as a PNG into figure_path.
savefig(fig5, "$(figure_path)/Figure5_Isolated vs Lattice Sim.png")

### [1.5.d] Comparing Lattice Wave model to Physiological data
- As in th

### [1.5.e] We can do multiple runs of this simulation loop. This allows us to 
- Perform repeated trials
- Alter a variable

I will only be doing one of these, however, because these trials are costly. The code below should be run only on a computer with a capable GPU; otherwise it will take a very long time. I will be altering the amount of acetylcholine released by a single SAC, $\rho$ (in $\mu M / ms$).

In [None]:
# Parameter sweep setup: vary `par` over n_trials evenly spaced values in [1, 10].
par = :ρ
n_trials = 10
par_range = LinRange(1.0, 10.0, n_trials)

In [None]:
# Parameter sweep. For each value of `par`: warm up the lattice from the JSON
# initial conditions, then run a recorded trial from the warm-up's final state.
# The lattice initial conditions are rebuilt here because the isolated-SAC cell
# reassigns the global `u0` to single-cell conditions.
u0_lattice = extract_dict(u_dict, tar_conds, (nx, ny))
# Fail loudly if the swept parameter is not a key of p_dict (e.g. String keys),
# rather than silently running every trial with unchanged parameters.
haskey(p_dict, par) || throw(KeyError(par))
for val in par_range
    with_logger(TerminalLogger()) do
        # Parameters for this trial: p_dict with `par` set to `val`.
        p_trial_dict = copy(p_dict)
        p_trial_dict[par] = val
        p_trial = extract_dict(p_trial_dict, tar_pars)
        #Warmup: only the final state is used, so intermediate steps are not stored
        SDE_mat_prob = SDEProblem(SACnet, noise_2D, u0_lattice, (0.0, 60e3), p_trial);
        SDE_mat_sol = solve(
            SDE_mat_prob,
            SOSRI(),
            abstol = 0.2,
            reltol = 2e-2,
            maxiters = 1e7,
            progress = true, 
            save_everystep = false,
        )
        #Run the trial
        # NOTE(review): `tspan`/`dt` are globals that the isolated-SAC cell sets to
        # (0.0, 500e3)/1.0; saving a full lattice every dt over that span is very
        # memory-hungry — confirm the intended trial length.
        u0_new = SDE_mat_sol[end]
        SDE_mat_prob = SDEProblem(SACnet, noise_2D, u0_new, tspan, p_trial);
        SDE_mat_sol = solve(
            SDE_mat_prob,
            SOSRI(),
            abstol = 0.2,
            reltol = 2e-2,
            maxiters = 1e7,
            progress = true, 
            saveat = dt,
        )
        #Perform the Data Analysis TODO
    end
end