In [1]:
using DrWatson
@quickactivate "FenrirForNeuro"
using CSV, DataFrames
using Printf
using Statistics, LinearAlgebra
using PrettyTables
using FenrirForNeuro
using ModelingToolkit
using OrdinaryDiffEq

In [80]:
# helpers
"""
    import_csvs(path, fnames, header=nothing)

Read each file in `fnames` (joined onto `path`) as a header-less CSV and return the
resulting `DataFrame`s in the same order as `fnames`.

# Arguments
- `path`: directory the file names are joined onto.
- `fnames`: iterable of file names.
- `header`: optional column names passed to `rename!`; when `nothing` the columns keep
  whatever names `CSV.read` assigns with `header=false`.

# Returns
A `Vector{DataFrame}`. Every frame carries an additional `:iter` column holding the
1-based row number.
"""
function import_csvs(path, fnames, header=nothing)
    dfs = DataFrame[]
    for fname in fnames
        df = CSV.read(joinpath(path, fname), DataFrame; header=false)
        # Identity comparison is the reliable way to test for `nothing`.
        if header !== nothing
            rename!(df, header)
        end
        df[!, :iter] = 1:size(df, 1)
        push!(dfs, df)
    end
    return dfs
end



"""
    get_problem_from_filename(filename, return_prob=false)

Select the reference parameter set, and optionally build the matching problem, for the
experiment identified by `filename`.

The experiment is recognised purely through substring matches (`occursin`) on
`filename`, checked in this order: `hh`+`simple`, `hh`+`2comp`, `hh`+`red_posp`,
`lv`+`2p`, `lv`+`4p`, `pd`. Within a model family a second substring (`2p`, `3p`,
`4p`, `6p`, `8p`, `gNa`, `gK`) selects which parameters are included.

# Arguments
- `filename`: string to search; the aggregation loop in this notebook passes the
  result directory path.
- `return_prob`: when `true`, the corresponding `get_*` constructor is called and its
  second return value is handed back as `prob`.

# Returns
`(θ, prob, proj)`:
- `θ`: vector of `parameter => value` pairs (`nothing` if no parameter subset matched).
- `prob`: problem object, or `nothing` when `return_prob` is `false`.
- `proj`: integer row matrix assigned for the model family; presumably the projection
  onto the observed state(s) -- confirm against the consumers of this value.

All three are `nothing` when no model family matches.
"""
function get_problem_from_filename(filename, return_prob=false)
    θ, prob, proj = nothing, nothing, nothing
    if occursin("hh", filename) && occursin("simple", filename)
        if occursin("2p", filename)
            @parameters gNa gK
            θ = [gNa => 25, gK => 7]
        elseif occursin("3p", filename)
            @parameters gNa gK gleak
            θ = [gNa => 25, gK => 7, gleak => 0.1]
        elseif occursin("gNa", filename)
            @parameters gNa
            θ = [gNa => 25]
        elseif occursin("gK", filename)
            @parameters gK
            θ = [gK => 7]
        end
        proj = [1 0 0 0]

        if return_prob
            _, prob = get_SinglecompartmentHH(θ)
        end

    elseif occursin("hh", filename) && occursin("2comp", filename)
        A = get_Asoma(15.3, 184) / 2
        # NOTE(review): `I₀` is not defined in this function; it is resolved from the
        # enclosing scope (presumably exported by FenrirForNeuro) -- confirm.
        if occursin("4p", filename)
            @parameters gK₁ gNa₁ gK₂ gNa₂
            θ = [gK₁ => 7, gNa₁ => 25, gK₂ => 10, gNa₂ => 20]
            @named c₁ = SimpleHHSystem(; gNa=gNa₁, gK=gK₁, A=A)
            @named c₂ = SimpleHHSystem(; gNa=gNa₂, gK=gK₂, Iₑ=I₀, A=A)
        elseif occursin("6p", filename)
            @parameters gNa₁ gK₁ gleak₁ gNa₂ gK₂ gleak₂
            θ = [gK₁ => 7, gleak₁ => 0.09, gNa₁ => 25, gleak₂ => 0.11, gK₂ => 10, gNa₂ => 20]
            @named c₁ = SimpleHHSystem(; gNa=gNa₁, gK=gK₁, gleak=gleak₁, A=A)
            @named c₂ = SimpleHHSystem(; gNa=gNa₂, gK=gK₂, gleak=gleak₂, Iₑ=I₀, A=A)
        end
        proj = [
            1 0 0 0 0 0 0 0
            0 0 0 0 1 0 0 0
            ]

        if return_prob
            # `c₁`/`c₂` only exist when one of the `4p`/`6p` branches above was taken.
            G = get_cable_matrix([1])
            _, prob = get_MulticompartmentHH(θ, [c₁, c₂], G)
        end

    elseif occursin("hh", filename) && occursin("red_posp", filename)
        if occursin("8p", filename)
            @parameters gNa gK gleak Eleak VT gM gL τ_max
            θ = [
                gNa => 25,
                gK => 7,
                gleak => 0.05,
                Eleak => -70,
                VT => -60,
                gM => 0.1,
                τ_max => 1e3,
                gL => 0.01,
            ]
        elseif occursin("6p", filename)
            @parameters gNa gK gleak VT gM gL
            θ = [gNa => 25, gK => 7, gleak => 0.05, VT => -60, gM => 0.1, gL => 0.01]
        end
        proj = [1 0 0 0 0 0 0]

        if return_prob
            _, prob = get_SinglecompartmentHH(θ; Sys=ReducedPospischilHHSystem, name=:Pospischil)
        end

    elseif occursin("lv", filename) && occursin("2p", filename)
        @parameters α β
        θ = [α => 1.5, β => 1.0]
        proj = [1 0]

        if return_prob
            _, prob = get_LV(θ)
        end

    elseif occursin("lv", filename) && occursin("4p", filename)
        @parameters α β γ δ
        θ = [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0]
        proj = [1 0]

        if return_prob
            _, prob = get_LV(θ)
        end

    elseif occursin("pd", filename)
        @parameters l
        θ = [l => 3]
        proj = [0 1]

        if return_prob
            _, prob = get_Pendulum(θ)
        end
    end
    return θ, prob, proj
end

# Wrap selected cells of a LaTeX table string in \textbf{...}.
#
# The table is cut into rows at every double backslash and each row into cells at
# " & " (a row without " & " counts as one cell). `cells2highlight` is a collection
# of `(row, col)` tuples; `row` is zero-based with respect to the split rows (the
# first chunk is row 0) while `col` is one-based. The rows are re-joined with the
# same separators and the resulting string is returned.
function highlight_cells(latex_table_str, cells2highlight)
    lines = split(latex_table_str, "\\\\")
    rendered_lines = [
        join(
            [
                (line_no - 1, col_no) in cells2highlight ? "\\textbf{$entry}" : entry
                for (col_no, entry) in
                enumerate(occursin(" & ", line) ? split(line, " & ") : [line])
            ],
            " & ",
        ) for (line_no, line) in enumerate(lines)
    ]
    return join(rendered_lines, "\\\\")
end

# Return the (one-based) header column indices whose cell contains "mu" and whose
# immediate right-hand neighbour contains "sigma".
#
# Only the first row of the table (everything before the first double backslash) is
# inspected; cells are separated by "&".
function get_merged_cell_idxs(latex_table_str)
    header_cells = split(first(split(latex_table_str, "\\\\")), "&")
    mean_cols = findall(cell -> occursin("mu", cell), header_cells)
    std_cols = findall(cell -> occursin("sigma", cell), header_cells)
    return filter(col -> (col + 1) in std_cols, mean_cols)
end

# Post-process a LaTeX table string so that each adjacent mu_/sigma_ column pair
# becomes a single "$mean \pm std$" column.
#
# Steps, in order:
#   1. In every row containing "&", the separator after a column returned by
#      `get_merged_cell_idxs` is replaced by the marker "pm".
#   2. Header pairs `\textbf{mu\_X} pm \textbf{sigma\_X}` collapse to `\textbf{X}`.
#   3. The `{rrr...}` column spec is rewritten with `l`s, shortened by the number of
#      merged columns.
#   4. Remaining `& a pm b ` cells become `& $a \pm b$ `.
#   5. `\hline` entries are inserted at the positions in `hline_rows`, and escaped
#      underscores `\_` are turned back into `_`.
#
# Keyword arguments
#   hline_rows: indices (into the table split at double backslashes, applied one after
#               another) where an extra `\hline` is inserted. The default reproduces
#               the layout this notebook was written for. An index that `insert!`
#               could not use (outside 1:length+1 at that moment) is skipped instead
#               of raising a BoundsError, so shorter tables can be formatted too.
#
# Returns the reformatted table as a String.
function merge_means_and_std_cols(latex_table_str; hline_rows=[6, 11, 15, 19, 23, 27, 32, 37])
    merge_at = get_merged_cell_idxs(latex_table_str)
    num_cols_merged = length(merge_at)

    reformatted_rows = String[]
    for row in split(latex_table_str, "\\\\")
        if occursin("&", row)
            split_row = split(row, "&")
            reformatted_row = String[]
            for (i, element) in enumerate(split_row)
                if i in merge_at
                    # mean column: glue it to the following std column via "pm"
                    push!(reformatted_row, element * "pm")
                elseif i == length(split_row)
                    push!(reformatted_row, element)
                else
                    push!(reformatted_row, element * "&")
                end
            end
            push!(reformatted_rows, join(reformatted_row))
        else
            push!(reformatted_rows, row)
        end
    end

    s = join(reformatted_rows, "\\\\")
    s = replace(s, r"\\textbf{mu\\_(.*?)} pm \\textbf{sigma\\_\1}" => s"\\textbf{\1}")
    # `m` is the whole match including both braces, hence the extra `- 2`.
    s = replace(s, r"\{(r+)\}" => m -> "{" * repeat("l", length(m) - num_cols_merged - 2) * "}")
    reformatted_table = replace(s, r"& ([^&]*?) pm ([^&]*?) " => s"& $\1 \\pm \2$ ")

    # Put a line break in front of an \hline that directly follows a row terminator
    # (`replace` is a no-op when the pattern does not occur).
    reformatted_table = replace(reformatted_table, r"\\\\hline" => s"\\\n  \\hline")

    s = split(reformatted_table, "\\\\")
    # insert hlines
    for i in hline_rows
        1 <= i <= length(s) + 1 || continue
        insert!(s, i, "\n  \\hline")
    end
    reformatted_table = join(s, "\\\\")
    # replace \hline\\ with \hline
    reformatted_table = replace(reformatted_table, r"\\hline\\\\\\\\" => s"\\hline")
    # replace \_ with _
    reformatted_table = replace(reformatted_table, r"\\_" => s"_")

    return reformatted_table
end

# Shift the column index of every `(row, col)` coordinate left by the number of
# entries in `merge_at` that are strictly smaller than that column index.
function adjust4merged_cells(coords, merge_at)
    return [
        (row, col - count(merged -> col > merged, merge_at)) for (row, col) in coords
    ]
end

# First component of every entry of `θ`, converted to a `Symbol`.
function pkeys(θ)
    return map(θ) do entry
        Symbol(entry[1])
    end
end

# Second component of every entry of `θ`.
function pvalues(θ)
    return map(θ) do entry
        entry[2]
    end
end

pvalues (generic function with 1 method)

In [81]:
# Aggregate the final optimisation state of every selected experiment under
# RESULTS_DIR into one summary table (mean and std of each metric per experiment).
RESULTS_DIR = "../../results"
TABLE = DataFrame()

# Expand the results tree five directory levels deep. Judging by the variable names
# the levels are problem / #params / model / method / experiment.
PROB_DIRS = readdir(RESULTS_DIR)

PATHS = joinpath.(RESULTS_DIR, PROB_DIRS)
NUM_PARAM_DIRS = readdir.(PATHS)
PATHS = vcat([joinpath.(PATH, NUM_PARAM_DIR) for (PATH, NUM_PARAM_DIR) in zip(PATHS, NUM_PARAM_DIRS)]...)
MODEL_DIRS = readdir.(PATHS)
PATHS = vcat([joinpath.(PATH, MODEL_DIR) for (PATH, MODEL_DIR) in zip(PATHS, MODEL_DIRS)]...)
METHOD_DIRS = readdir.(PATHS)
PATHS = vcat([joinpath.(PATH, METHOD_DIR) for (PATH, METHOD_DIR) in zip(PATHS, METHOD_DIRS)]...)
EXP_DIRS = readdir.(PATHS)
PATHS = vcat([joinpath.(PATH, EXP_DIR) for (PATH, EXP_DIR) in zip(PATHS, EXP_DIRS)]...)


# Keep a path only if it contains none of the excluded substrings and at least one
# of the included method/experiment substrings.
excludes(x) = all(map(y -> !occursin(y, x), ["diff_loss", "loss_surface", "tradeoff", "compare", "loss_", "prior", "wo_bt", "8", "noisy"]))
includes(x) = any(map(y -> occursin(y, x), ["rk4/l2", "fenrir/tempered_diff", "fenrir/learned_diff", "fenrir/low_tol_tempered_diff"]))
PATHS = filter(excludes, PATHS)
PATHS = filter(includes, PATHS)

INCLUDE_MSE = true # SLOW !!! 

for PATH in PATHS
    fnames = readdir(PATH)

    DETAILS = []
    # Last five path components; `splitpath` works with any platform separator.
    EXP = splitpath(PATH)[end-4:end]
    num_params = parse(Int, EXP[2][1:end-1])
    model_name = uppercase(EXP[1])
    model_name = occursin("hh", EXP[1]) ? (occursin("2comp", EXP[3]) ? model_name*"_2" : model_name*"_1") : model_name

    # Later assignments override earlier ones, so the order of these lines matters.
    alg_name = occursin("rk4", EXP[4]) ? "RK" : EXP[4]
    alg_name = occursin("low_tol", EXP[5]) ? "ours+" : alg_name
    alg_name = occursin("tempered_diff", EXP[5]) && !occursin("low_tol", EXP[5]) ? "ours" : alg_name
    alg_name = occursin("learned", EXP[5]) ? "Fenrir" : alg_name
    push!(DETAILS, model_name)
    push!(DETAILS, "$num_params")
    push!(DETAILS, alg_name)

    θ, prob, proj = get_problem_from_filename(PATH, INCLUDE_MSE)

    # Reference trajectory simulated with the reference parameters; only needed for tRMSE.
    _, u_obs = INCLUDE_MSE ? simulate(remake(prob, p=pvalues(θ)), 1e-2) : (NaN, NaN)
    get_trmse(p) = INCLUDE_MSE ? tRMSE(u_obs, simulate(remake(prob, p=p), 1e-2)[2]) : NaN

    # CSV column layout: the parameters first, then the method-specific extra columns.
    extra_cols = occursin("fenrir", PATH) ? Symbol.(["κ²", "loss", "T"]) : Symbol.(["loss", "T",])
    trajectories = import_csvs(PATH, fnames, [pkeys(θ)..., extra_cols...,])
    last_rows = vcat([DataFrame(t[end, :]) for t in trajectories]...)

    metrics = ["iter", "pRMSE", "conv", "tRMSE", "#correct_params"]

    θ_est = Array(last_rows[!, pkeys(θ)])
    pRMSEs = map(x -> rel_pRMSE(x, pvalues(θ)), eachrow(θ_est))

    where_loss_nan = isnan.(last_rows.loss)

    last_rows[!, :pRMSE] = [pRMSEs...]
    # A run counts as converged when its relative pRMSE is below 5%.
    last_rows[!, :conv] = [(pRMSEs .< 5e-2)...]
    last_rows[!, :tRMSE] = [get_trmse(p) for p in eachrow(θ_est)]
    last_rows[!, "#correct_params"] .= sum((abs.(θ_est .- pvalues(θ)') ./ pvalues(θ)') .< 5e-2, dims=2)

    last_rows = Array(last_rows[!, metrics])
    μ = mean(last_rows, dims=1)
    σ = std(last_rows, dims=1)

    # tRMSE (second-to-last metric) is averaged only over runs whose final loss is not NaN.
    μ[end-1] = mean(last_rows[.!where_loss_nan, end-1])
    σ[end-1] = std(last_rows[.!where_loss_nan, end-1])
    if any(where_loss_nan)
        println(DETAILS)
        println("Num NaNs: ", sum(where_loss_nan))
        println(metrics)
        println(mean(last_rows[where_loss_nan, :], dims=1))
        println(std(last_rows[where_loss_nan, :], dims=1))
        println()
    end

    cols = vcat("model", "#params", "alg", [["mu_$m"; "sigma_$m"] for m in metrics]...)

    # Interleave as μ₁, σ₁, μ₂, σ₂, ... to line up with `cols`.
    stats = reshape(cat(μ, σ, dims=1), :)
    row = hcat([DETAILS...; stats]...)
    row = DataFrame(row, cols)

    row = row[:, Not([:sigma_conv])]

    # Explicit `global` so the accumulation also works when this code runs as a
    # script, not only under the notebook/REPL soft-scope rules.
    global TABLE = vcat(TABLE, row)
end
sort!(TABLE, ["model", "#params", "alg"])

latex_table = pretty_table(String, TABLE; backend=Val(:latex), show_subheader=false, formatters=ft_printf("%.2f"))
pretty_table(TABLE, show_subheader=false, formatters=ft_printf("%.2f"))

Any["HH_1", "6", "ours"]
Num NaNs: 9
["iter", "pRMSE", "conv", "tRMSE", "#correct_params"]
[2232.5555555555557 13.330222424303063 0.0 18.45198170264261 1.1111111111111112]
[581.5984678261951 7.6705260248235 0.0 3.993977517330876 0.33333333333333337]

Any["LV", "2", "ours+"]
Num NaNs: 8
["iter", "pRMSE", "conv", "tRMSE", "#correct_params"]
[46.125 1.1048868554871034 0.0 1447.3927483889806 0.0]
[15.21688441924205 0.4414141558064893 0.0 1175.8305451701935 0.0]

Any["LV", "2", "ours"]
Num NaNs: 8
["iter", "pRMSE", "conv", "tRMSE", "#correct_params"]
[43.625 1.104572704416712 0.0 1435.0784880246172 0.0]
[24.5818603271367 0.49542900308677756 0.0 963.5018785535058 0.0]

Any["LV", "4", "ours+"]
Num NaNs: 15
["iter", "pRMSE", "conv", "tRMSE", "#correct_params"]
[114.46666666666667 1.6284630263263595 0.06666666666666667 16174.37919739556 0.3333333333333333]
[108.76046200360723 0.8669455298153883 0.25819888974716115 19109.112903086858 1.046536236944567]

Any["LV", "4", "ours"]
Num NaNs: 20
["iter

In [82]:
# Collapse the mu_/sigma_ column pairs of `latex_table` into single "mean \pm std"
# columns and print the resulting LaTeX source.
reformatted_table = merge_means_and_std_cols(latex_table)
print(reformatted_table)

\begin{tabular}{llllllll}
  \hline
  \textbf{model} & \textbf{\#params} & \textbf{alg} & \textbf{iter} & \textbf{pRMSE} & \textbf{mu_conv} & \textbf{tRMSE} & \textbf{\#correct_params} \\
  \hline
  HH_1 & 1 & Fenrir & $46.96 \pm 19.08$ & $0.38 \pm 0.67$ & 0.68 & $7.85 \pm 10.70$ & $0.68 \pm 0.47$ \\
  HH_1 & 1 & RK & $43.30 \pm 43.45$ & $0.42 \pm 0.48$ & 0.57 & $7.54 \pm 8.26$ & $0.57 \pm 0.50$ \\
  HH_1 & 1 & ours & $382.08 \pm 32.19$ & $0.00 \pm 0.00$ & 1.00 & $0.43 \pm 0.02$ & $1.00 \pm 0.00$ \\
  HH_1 & 1 & ours+ & $116.73 \pm 7.11$ & $0.00 \pm 0.00$ & 1.00 & $0.43 \pm 0.11$ & $1.00 \pm 0.00$ \\
  \hline
  HH_1 & 2 & Fenrir & $110.04 \pm 61.70$ & $0.20 \pm 0.37$ & 0.75 & $5.89 \pm 10.15$ & $1.53 \pm 0.83$ \\
  HH_1 & 2 & RK & $54.02 \pm 62.60$ & $0.28 \pm 0.45$ & 0.72 & $4.88 \pm 7.58$ & $1.44 \pm 0.90$ \\
  HH_1 & 2 & ours & $696.22 \pm 78.10$ & $0.00 \pm 0.00$ & 1.00 & $0.42 \pm 0.04$ & $2.00 \pm 0.00$ \\
  HH_1 & 2 & ours+ & $197.81 \pm 40.38$ & $0.04 \pm 0.19$ & 0.96 & $1.08 \p