In [1]:
using DrWatson
@quickactivate "FenrirForNeuro"
using CairoMakie
using CSV, DataFrames
using Printf
using PrettyTables
using Statistics, LinearAlgebra
using FenrirForNeuro, ModelingToolkit, OrdinaryDiffEq
using FileIO
using StatsBase

# Bind the bare name `Axis` explicitly to Makie's axis type (the cell echoes `Makie.Axis`).
# NOTE(review): presumably another package loaded above also exports an `Axis`, which
# would make the unqualified name ambiguous — confirm.
Axis = CairoMakie.Axis

Makie.Axis

In [2]:
# helpers
"""
    import_csvs(fpaths, header=nothing)

Read every headerless CSV file in `fpaths` into a `DataFrame`.

If `header` is given, it is passed to `rename!` to name the columns. Every frame gets an
extra `:iter` column numbering its rows from 1.

Return a vector with one `DataFrame` per path, in the iteration order of `fpaths`.
"""
function import_csvs(fpaths, header=nothing)
    return [_read_csv_with_iter(fpath, header) for fpath in fpaths]
end

# Read one headerless CSV, optionally rename its columns, and append a 1-based `:iter` column.
function _read_csv_with_iter(fpath, header)
    df = CSV.read(fpath, DataFrame; header=false)
    if !isnothing(header)
        rename!(df, header)
    end
    df[!, :iter] = 1:size(df, 1)
    return df
end

"""
    get_last_rows(trajectories)

Collect the final row of each trajectory table into a single `DataFrame`, one row per
trajectory, in input order.
"""
function get_last_rows(trajectories)
    final_rows = [DataFrame(trajectory[end, :]) for trajectory in trajectories]
    return vcat(final_rows...)
end


"""
    where_converged(trajectories, threshold=5e-2)

Return a Boolean vector marking which trajectories converged. A trajectory converged when
the relative parameter RMSE (`rel_pRMSE`) of its final-row parameter estimates is strictly
below `threshold`. The RMSE is taken against the reference values `pvalues(θ)`, and the
estimates are read from the columns `pkeys(θ)`.

NOTE(review): this reads a global `θ` that is not defined in this part of the notebook.
Confirm it is set (entries indexed as `[1]` = name, `[2]` = value) before calling.
"""
function where_converged(trajectories, threshold=5e-2)
    last_rows = get_last_rows(trajectories)
    # The reference values are the same for every trajectory, so compute them once.
    θ_ref = pvalues(θ)
    estimates = Array(last_rows[!, pkeys(θ)])
    pRMSEs = map(x -> rel_pRMSE(x, θ_ref), eachrow(estimates))
    return pRMSEs .< threshold
end


# NOTE(review): `get_l2_loss` is not defined in this notebook (presumably it comes from
# FenrirForNeuro), and `l2` is not referenced in the cells shown here — confirm it is
# still needed.
l2 = get_l2_loss()

# Parameter names of θ as Symbols. Each entry of θ is indexed with [1] for its name
# (e.g. a Pair or a 2-tuple). The container kind of θ is preserved by `map`.
function pkeys(θ)
    return map(θ) do entry
        Symbol(entry[1])
    end
end

# Parameter values of θ: the [2] component of every entry, in the same container kind.
function pvalues(θ)
    return map(θ) do entry
        entry[2]
    end
end

pvalues (generic function with 1 method)

In [43]:
"""
    get_merged_cell_idxs(latex_table_str)

Return the column indices of header cells containing "mu" that are immediately followed by
a header cell containing "sigma".

The header is the text before the first LaTeX row terminator. It is split into cells at
every "&", and indices are 1-based positions among those cells.
"""
function get_merged_cell_idxs(latex_table_str)
    header_row = first(split(latex_table_str, "\\\\"))
    header_cells = split(header_row, "&")
    mean_cols = findall(cell -> occursin("mu", cell), header_cells)
    std_cols = findall(cell -> occursin("sigma", cell), header_cells)
    return filter(col -> (col + 1) in std_cols, mean_cols)
end

# Post-process a LaTeX table string (as produced by `pretty_table(String, ...;
# backend=Val(:latex))` in this notebook) in these steps:
#   1. Merge each header pair `mu_<x>` / `sigma_<x>` that sits in adjacent columns (as found
#      by `get_merged_cell_idxs`) into one `<x>` column whose cells read `$mean \pm std$`.
#   2. Rewrite the `{rr...r}` column spec to one `l` per remaining column.
#   3. Insert `\hline` separators before row groups.
#   4. Turn `\_` back into `_`.
#
# Arguments:
#   latex_table_str -- the complete LaTeX table as a String.
# Returns the reformatted LaTeX table as a String.
#
# NOTE(review): several steps are tailored to the tables in this notebook. For example,
# group separators go before rows containing "\n  0.00". Confirm before reusing this on
# other tables.
function merge_means_and_std_cols(latex_table_str)
    reformatted_rows = []
    # Header cell indices of every `mu_*` column that is directly followed by a `sigma_*` column.
    merge_at = get_merged_cell_idxs(latex_table_str)
    num_cols_merged = length(merge_at)
    # Rebuild every "\\"-terminated row. In rows that contain cells, the separator after a
    # merge column becomes the placeholder "pm" instead of "&". The last cell gets no separator.
    for row in split(latex_table_str, "\\\\")
        if occursin("&", row)
            split_row = split(row, "&")
            reformatted_row = []
            for (i, element) in enumerate(split_row)
                if i in merge_at
                    push!(reformatted_row, element * "pm")
                elseif i == length(split_row)
                    push!(reformatted_row, element)
                else
                    push!(reformatted_row, element * "&")
                end
            end
            reformatted_row = join(reformatted_row)
            push!(reformatted_rows, reformatted_row)
        else
            push!(reformatted_rows, row)
        end
    end

    s = join(reformatted_rows, "\\\\")
    # Header: `\textbf{mu\_X} pm \textbf{sigma\_X}` -> `\textbf{X}`. The back-reference
    # \1 requires both halves to carry the same suffix X.
    s = replace(s, r"\\textbf{mu\\_(.*?)} pm \\textbf{sigma\\_\1}" => s"\\textbf{\1}")
    # Column spec: replace `{r...r}` with one `l` per column left after merging. `m` is the
    # matched text including both braces, hence the extra `- 2`.
    s = replace(s, r"\{(r+)\}" => m -> "{" * repeat("l", length(m) - num_cols_merged - 2) * "}")
    # Body cells: `& A pm B ` -> `& $A \pm B$ `.
    # NOTE(review): the pattern requires a leading "& ", so a merged pair in the very first
    # column would not be converted.
    reformatted_table = replace(s, r"& ([^&]*?) pm ([^&]*?) " => s"& $\1 \\pm \2$ ")

    # If a `\hline` directly follows a row terminator `\\`, put a line break in between.
    # NOTE(review): confirm that the `\n` inside this `s"..."` substitution string is
    # accepted by `replace`. This branch only runs when "\\hline" appears verbatim.
    if occursin("\\\\hline", reformatted_table)
        reformatted_table = replace(reformatted_table, r"\\\\hline" => s"\\\n  \\hline")
    end

    s = split(reformatted_table, "\\\\")
    # Insert an `\hline` element before every row containing "\n  0.00", except the first
    # such row. The `+ i - 1` accounts for elements already inserted before it.
    lines_starting_w_0 = [occursin("\n  0.00", line) ? i : 0 for (i,line) in enumerate(s)]
    lines_starting_w_0 = filter(x -> x != 0, lines_starting_w_0)[2:end]
    for (i, l) in enumerate(lines_starting_w_0)
        insert!(s, l+i-1, "\n  \\hline")
    end
    reformatted_table = join(s, "\\\\")
    # Drop the row terminator that the join above placed after an inserted `\hline`.
    # NOTE(review): as written, this regex matches `\hline` followed by FOUR literal
    # backslashes, while the join leaves a two-backslash terminator. Confirm this
    # replacement actually fires.
    reformatted_table = replace(reformatted_table, r"\\hline\\\\\\\\" => s"\\hline")
    # Un-escape underscores: `\_` -> `_`.
    reformatted_table = replace(reformatted_table, r"\\_" => s"_")

    return reformatted_table
end

merge_means_and_std_cols (generic function with 1 method)

In [3]:
# Figure geometry, specified in inches and converted to points (72 pt per inch):
# - full width: 6.75 in;
# - half width: (6.75 - 0.25) / 2 in;
# - default height: half of the half width.
PT_PER_INCH = 72
HALF_WIDTH, FULL_WIDTH, HEIGHT = let full_in = 6.75, spacing_in = 0.25
    half_in = (full_in - spacing_in) / 2
    (half_in * PT_PER_INCH, full_in * PT_PER_INCH, 0.5 * half_in * PT_PER_INCH)
end

# Relative locations of the experiment results and of the generated figures.
RESULTS_PATH, FIGURES_PATH = "../../results/", "../../figures/"

# Shared Makie attribute overrides: font size 7 for titles, labels and tick labels;
# tick mark size 2.
PLOT_DEFAULTS = (
    titlesize=7,
    xlabelsize=7,
    ylabelsize=7,
    xticklabelsize=7,
    yticklabelsize=7,
    xticksize=2,
    yticksize=2,
)

(titlesize = 7, xlabelsize = 7, ylabelsize = 7, xticklabelsize = 7, yticklabelsize = 7, xticksize = 2, yticksize = 2)

In [14]:
# Summarize the noisy-gradient single-pendulum runs (rk4, L2 loss). For every
# (decay rate, gradient-noise scale) combination, compute:
# - the fraction of converged runs;
# - mean and std of the final relative parameter RMSE;
# - mean and std of the iteration count.
IN_PATH = RESULTS_PATH * "pd/1p/single_pendulum/rk4/l2_noisy_grad"
FNAMES = readdir(IN_PATH)
FPATHS = [joinpath(IN_PATH, fname) for fname in FNAMES]

grad_noise_scales = [0, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]

metrics = Vector{Float64}[]
for (j, decay_rate) in enumerate([0, 5, 10])
    for (k, scale) in enumerate(grad_noise_scales)
        # Runs for this combination are the files whose name contains "<j>_<k>-", where j and
        # k are the 1-based decay-rate and noise-scale indices. Match on the file name only,
        # so directory names cannot produce false hits.
        tag = "$(j)_$(k)-"
        fpaths = filter(x -> occursin(tag, basename(x)), FPATHS)
        isempty(fpaths) && error("no result files tagged \"$tag\" in $IN_PATH")

        trajectories = import_csvs(fpaths, [:l, :loss, :T])
        last_rows = get_last_rows(trajectories)
        θ_est = Array(last_rows[!, [:l]])
        # Relative parameter RMSE of the final estimate of `l` against the reference value 3.0.
        pRMSEs = map(x -> rel_pRMSE(x, [3.0]), eachrow(θ_est))

        last_rows[!, :pRMSE] = pRMSEs
        # A run counts as converged when its relative parameter RMSE is below 5e-2.
        last_rows[!, :conv] = pRMSEs .< 5e-2

        push!(metrics, [
            scale,
            decay_rate,
            mean(last_rows[!, :conv]),
            mean(last_rows[!, :pRMSE]),
            std(last_rows[!, :pRMSE]),
            mean(last_rows[!, :iter]),
            std(last_rows[!, :iter]),
        ])
    end
end

# One row per (decay rate, scale) combination, in loop order.
metrics = DataFrame(
    permutedims(reduce(hcat, metrics)),
    [:scale, :decay_rate, :conv, :mu_pRMSE, :sigma_pRMSE, :mu_iter, :sigma_iter],
)
println(metrics)

[1m30×7 DataFrame[0m
[1m Row [0m│[1m scale   [0m[1m decay_rate [0m[1m conv    [0m[1m mu_pRMSE [0m[1m sigma_pRMSE [0m[1m mu_iter [0m[1m sigma_iter [0m
     │[90m Float64 [0m[90m Float64    [0m[90m Float64 [0m[90m Float64  [0m[90m Float64     [0m[90m Float64 [0m[90m Float64    [0m
─────┼──────────────────────────────────────────────────────────────────────────
   1 │    0.0          0.0     0.32  1.42454      1.09027     33.47    18.1077
   2 │    0.01         0.0     0.35  1.40087      1.11206     33.79    12.8396
   3 │    0.05         0.0     0.34  1.41001      1.10422     29.89    10.7674
   4 │    0.1          0.0     0.37  1.40297      1.12603     27.3     11.8369
   5 │    0.5          0.0     0.49  1.14034      1.12978     15.19    10.723
   6 │    1.0          0.0     0.63  0.770379     1.04765     15.52    14.1196
   7 │    5.0          0.0     0.64  0.509967     0.881543    13.01     8.86771
   8 │   10.0          0.0     0.39  0.642715     0.93

In [44]:
# Render the metrics as a LaTeX table (two decimals, no type subheader). Then merge each
# mu_/sigma_ column pair into a single "mean ± std" column and print the result.
metrics_latex_table = merge_means_and_std_cols(
    pretty_table(
        String,
        metrics;
        backend=Val(:latex),
        show_subheader=false,
        formatters=ft_printf("%.2f"),
    ),
)

print(metrics_latex_table)

\begin{tabular}{lllll}
  \hline
  \textbf{scale} & \textbf{decay_rate} & \textbf{conv} & \textbf{pRMSE} & \textbf{iter} \\
  \hline
  0.00 & 0.00 & 0.32 & $1.42 \pm 1.09$ & $33.47 \pm 18.11$ \\
  0.01 & 0.00 & 0.35 & $1.40 \pm 1.11$ & $33.79 \pm 12.84$ \\
  0.05 & 0.00 & 0.34 & $1.41 \pm 1.10$ & $29.89 \pm 10.77$ \\
  0.10 & 0.00 & 0.37 & $1.40 \pm 1.13$ & $27.30 \pm 11.84$ \\
  0.50 & 0.00 & 0.49 & $1.14 \pm 1.13$ & $15.19 \pm 10.72$ \\
  1.00 & 0.00 & 0.63 & $0.77 \pm 1.05$ & $15.52 \pm 14.12$ \\
  5.00 & 0.00 & 0.64 & $0.51 \pm 0.88$ & $13.01 \pm 8.87$ \\
  10.00 & 0.00 & 0.39 & $0.64 \pm 0.93$ & $10.43 \pm 7.19$ \\
  50.00 & 0.00 & 0.45 & $0.49 \pm 0.83$ & $11.66 \pm 8.26$ \\
  100.00 & 0.00 & 0.28 & $0.44 \pm 0.74$ & $10.26 \pm 7.42$ \\
  \hline
  0.00 & 5.00 & 0.32 & $1.42 \pm 1.09$ & $33.47 \pm 18.11$ \\
  0.01 & 5.00 & 0.31 & $1.45 \pm 1.08$ & $35.94 \pm 18.58$ \\
  0.05 & 5.00 & 0.33 & $1.45 \pm 1.10$ & $32.58 \pm 13.39$ \\
  0.10 & 5.00 & 0.35 & $1.40 \pm 1.11$ & $33.65 \pm 1