In [1]:
include("main.jl")
using PyCall
using CairoMakie
sepsis_gym = pyimport("dbn_ppl_rl")
using Statistics


  from pandas.core import (


In [2]:
struct MeanRewardsType
    mean_rewards::Vector{Float64}
    individual_runs::Vector{Vector{Float64}}
    smoothed_mean::Vector{Float64}
    smoothed_std_low::Vector{Float64}
    smoothed_std_high::Vector{Float64}
    keys_of_smoothed::Vector{Float64}
    name::String
    info::Dict
end

In [3]:

function load_rewards_from_json(file_path)
    json_data = JSON3.read(file_path)
    mean_rewards = [Float64(rew) for rew in json_data["mean_rewards"]]
    individual_runs = [[Float64(rew) for rew in r] for r in json_data["individual_runs"]]
    smoothed_mean = [Float64(rew) for rew in json_data["smoothed_mean"]]
    smoothed_std_low = [Float64(rew) for rew in json_data["smoothed_std_low"]]
    smoothed_std_high = [Float64(rew) for rew in json_data["smoothed_std_high"]]
    keys_of_smoothed = [Int(k) for k in json_data["keys_of_smoothed"]]
    name = json_data["name"]
    info = Dict(
        string(k) => string(v) for (k, v) in json_data["info"]
    )
    return MeanRewardsType(mean_rewards, individual_runs, smoothed_mean, smoothed_std_low, smoothed_std_high, keys_of_smoothed, name, info)
end

load_rewards_from_json (generic function with 1 method)

In [4]:
rewards = Dict(
    :Simple100 => load_rewards_from_json("data/processed/ts/Simple100.json"),
    :Medium100 => load_rewards_from_json("data/processed/ts/Medium100.json"),
    :None100 => load_rewards_from_json("data/processed/ts/None100.json"),
    :None100P => load_rewards_from_json("data/processed/ts/None100P.json"),
    :None1P => load_rewards_from_json("data/processed/ts/None1P.json"),
    :Softmax100 => load_rewards_from_json("data/processed/ts/Softmax100.json"),
    :SimplePPL100 => load_rewards_from_json("data/processed/ts/SimplePPL100.json"),
    :Simple1 => load_rewards_from_json("data/processed/ts/Simple1.json"),
    :Medium1 => load_rewards_from_json("data/processed/ts/Medium1.json"),
    :Softmax1 => load_rewards_from_json("data/processed/ts/Softmax1.json"),
    :SimplePPL1 => load_rewards_from_json("data/processed/ts/SimplePPL1.json"),
    :DQN_S => load_rewards_from_json("data/processed/dqn/small_buff.json"),
    :DQN_SS => load_rewards_from_json("data/processed/dqn/small_buff_small_expl.json"),
    :DQN_L => load_rewards_from_json("data/processed/dqn/large_buff.json"),
    :DQN_M => load_rewards_from_json("data/processed/dqn/medium.json"),
    :QLearning => load_rewards_from_json("data/processed/qlearning/q_learning_results.json"),
)

Dict{Symbol, MeanRewardsType} with 16 entries:
  :Softmax1     => MeanRewardsType([-0.67245, -0.750493, -0.7667, -0.81518, -0.…
  :None100P     => MeanRewardsType([-0.66737, -0.66169, -0.655537, -0.641487, -…
  :None1P       => MeanRewardsType([-0.667683, -0.658437, -0.667097, -0.672253,…
  :DQN_L        => MeanRewardsType([-0.740741, -0.666667, -0.740741, -0.814815,…
  :SimplePPL100 => MeanRewardsType([-0.624323, -0.453003, -0.38704, -0.38704, -…
  :DQN_M        => MeanRewardsType([-0.666667, -0.619048, -0.761905, -0.714286,…
  :QLearning    => MeanRewardsType([-0.73, -0.7, -0.73, -0.6, -0.66, -0.67, -0.…
  :Medium100    => MeanRewardsType([-0.671667, -0.64189, -0.64451, -0.64451, -0…
  :Simple100    => MeanRewardsType([-0.671597, -0.62753, -0.618317, -0.618317, …
  :None100      => MeanRewardsType([-0.658183, -0.67427, -0.668743, -0.668743, …
  :Simple1      => MeanRewardsType([-0.65016, -0.60532, -0.634067, -0.55527, -0…
  :Medium1      => MeanRewardsType([-0.652907, -0.668417, -0.6

In [6]:
latex_labels = Dict(
    key => "\$\\mathsf{$(replace(value, "_" => "\\_"))}\$" for (key, value) in label_dict
)

Dict{Symbol, String} with 20 entries:
  :Simple       => "\$\\mathsf{SimpleDBN}\$"
  :Softmax1     => "\$\\mathsf{SoftmaxPPL\\_TS1}\$"
  :None100P     => "\$\\mathsf{FullDBN\\_SmallPrior\\_TS100}\$"
  :None1P       => "\$\\mathsf{FullDBN\\_SmallPrior\\_TS1}\$"
  :DQN_L        => "\$\\mathsf{DQN\\_LargeBuff}\$"
  :SimplePPL100 => "\$\\mathsf{SimplePPL\\_TS100}\$"
  :DQN_M        => "\$\\mathsf{DQN\\_M}\$"
  :QLearning    => "\$\\mathsf{QLearning}\$"
  :Medium100    => "\$\\mathsf{MediumDBN\\_TS100}\$"
  :Simple100    => "\$\\mathsf{SimpleDBN\\_TS100}\$"
  :Simple1      => "\$\\mathsf{SimpleDBN\\_TS1}\$"
  :Medium       => "\$\\mathsf{MediumDBN}\$"
  :Medium1      => "\$\\mathsf{MediumDBN\\_TS1}\$"
  :None100      => "\$\\mathsf{FullDBN\\_TS100}\$"
  :SimplePPL    => "\$\\mathsf{SimplePPL}\$"
  :Softmax      => "\$\\mathsf{SoftmaxPPL}\$"
  :SimplePPL1   => "\$\\mathsf{SimplePPL\\_TS1}\$"
  :DQN_SS       => "\$\\mathsf{DQN\\_SmallBuff\\_SmallEps}\$"
  :DQN_S        => "\$\\mathsf{DQN\\_Sm

In [7]:
function get_cumsums(means)
    cumsums = Dict()
    for (type, mean) in means
        cumsums[type] = cumsum(mean.mean_rewards, dims=1)
    end
    return cumsums
end

get_cumsums (generic function with 1 method)

In [8]:
function get_sample_efficiency(means, checkpoints=[-1:0.05:1])
    sample_eff = Dict()
    for (type, mean) in means
        sample_eff[type] = Dict()
        for checkpoint in checkpoints
            index = findfirst(x -> x >= checkpoint, mean.smoothed_mean)
            sample_eff[type][checkpoint] = isnothing(index) ? NaN : mean.keys_of_smoothed[index]
            # index = findfirst(x -> x >= checkpoint, mean.mean_rewards)
            # sample_eff[type][checkpoint] = isnothing(index) ? NaN : index
        end
    end
    return sample_eff  
end

get_sample_efficiency (generic function with 2 methods)

In [9]:
function get_asymptotic_reward(means, checkpoints=[0.5,0.7,0.9,0.99])
    asymptotic_rewards = Dict{Symbol, Float64}()
    convergence_speeds = Dict{Symbol, Dict{Float64, Int}}()
    lengths = Dict{Symbol, Int}()

    for (type, m) in means
        rewards = m.mean_rewards
        # Store the length of the rewards array
        lengths[type] = Int(m.keys_of_smoothed[end])
        last_n = Int(round(lengths[type] / 100))

        # Asymptotic mean reward
        if !isempty(rewards)
            asymptotic_rewards[type] = mean(rewards[end-last_n:end])
        else
            asymptotic_rewards[type] = NaN
        end

        # Convergence speed for each checkpoint
        convergence_speeds[type] = Dict()
        for checkpoint in checkpoints
            min_reward = minimum(rewards)
            max_reward = maximum(rewards)
            reward_range = max_reward - min_reward
            threshold_value = min_reward + checkpoint * reward_range
            index = findfirst(x -> x >= threshold_value, rewards)
            if isnothing(index)
                convergence_speeds[type][checkpoint] = NaN  # Indicate no convergence
            else
                convergence_speeds[type][checkpoint] = index
            end
        end
    end
    return asymptotic_rewards, convergence_speeds, lengths
end

get_asymptotic_reward (generic function with 2 methods)

In [10]:
type_order = [[:SimplePPL1, :Softmax1, :Simple1, :Medium1, :None1P], [:SimplePPL100, :Softmax100, :Simple100, :Medium100, :None100, :None100P],  [:QLearning, :DQN_L, :DQN_S, :DQN_SS]]

3-element Vector{Vector{Symbol}}:
 [:SimplePPL1, :Softmax1, :Simple1, :Medium1, :None1P]
 [:SimplePPL100, :Softmax100, :Simple100, :Medium100, :None100, :None100P]
 [:QLearning, :DQN_L, :DQN_S, :DQN_SS]

In [11]:
using Printf

# Helper function to format best/worst values across types with your adjustments
function format_best_worst_across_types(values_matrix, digs; is_higher_better=true, skip=false)
    rows, cols = size(values_matrix)
    formatted_matrix = Array{String}(undef, rows, cols)  # Initialize a matrix to store formatted values

    for col in 1:cols
        column_values = values_matrix[:, col]
        max_val = maximum(column_values[.!isnan.(column_values)]; init=-typemax(Int64))
        min_val = minimum(column_values[.!isnan.(column_values)]; init=typemax(Int64))
        for row in 1:rows
            v = column_values[row]
            if isnan(v)
                formatted_matrix[row, col] = "-"
            else
                if digs == 0
                    formatted_value = @sprintf("%d", round(v, digits=digs))
                elseif digs == 1
                    formatted_value = @sprintf("%0.1f", round(v, digits=digs))
                elseif digs == 2
                    formatted_value = @sprintf("%0.2f", round(v, digits=digs))
                elseif digs == 3
                    formatted_value = @sprintf("%0.3f", round(v, digits=digs))
                else 
                    formatted_value = v
                end
                if v == max_val && is_higher_better && !skip
                    formatted_matrix[row, col] = "\\color{blue}{\$$(formatted_value)\$}"
                elseif v == min_val && !is_higher_better&& !skip
                    formatted_matrix[row, col] = "\\color{blue}{\$$(formatted_value)\$}"
                elseif v == min_val && is_higher_better&& !skip
                    formatted_matrix[row, col] = "\\color{red}{\$$(formatted_value)\$}"
                elseif v == max_val && !is_higher_better&& !skip
                    formatted_matrix[row, col] = "\\color{red}{\$$(formatted_value)\$}"
                else
                    formatted_matrix[row, col] = "\$$(formatted_value)\$"
                end
            end
        end
    end
    return formatted_matrix
end

# Generate LaTeX table for cumulative rewards
function generate_cumulative_rewards_table(cumsums, checkpoints)
    types = collect(keys(cumsums))
    cols = length(checkpoints)
    
    # Collect rewards for each type at each checkpoint into a matrix
    values_matrix = [length(cumsums[type]) >= chk ? cumsums[type][chk-1] : NaN for type in types, chk in checkpoints]

    # Format the values with best/worst highlighting
    formatted_matrix = format_best_worst_across_types(values_matrix, 2, is_higher_better=true)

    # Create LaTeX table
    header = "\\hline\n & \$" * join(checkpoints, "\$ & \$") * "\$ \\\\\n\\hline\n"
    rows = Dict()
    for (i, type) in enumerate(types)
        rows[type] = "$(latex_labels[type]) & " * join(formatted_matrix[i, :], " & ") * " \\\\\n"
    end
    rows_text = ""
    for ts in type_order
        for type in ts
            rows_text *= rows[type]
        end
        rows_text *= "\\hline\n"
    end
    return "\\begin{tabular}{|l|" * "r"^cols * "|}\n\\hline\n" * header * rows_text * "\\hline\n\\end{tabular}"
end

# Generate LaTeX table for sample efficiency
function generate_sample_efficiency_table(sample_eff, checkpoints, lengths)
    types = collect(keys(sample_eff))
    cols = length(checkpoints)
    
    # Collect sample efficiencies into a matrix
    values_matrix = [sample_eff[type][chk] for type in types, chk in checkpoints]

    # Format the values with best/worst highlighting
    formatted_matrix = format_best_worst_across_types(values_matrix, 0, is_higher_better=false)

    # Create LaTeX table
    header = "\\hline\n & \$" * join(checkpoints, "\$ & \$") * "\$ & Length \\\\\n\\hline\n"
    rows = Dict()
    for (i, type) in enumerate(types)
        rows[type] = "$(latex_labels[type]) & " * join(formatted_matrix[i, :], " & ") * " & \$$(lengths[type])\$ \\\\\n"
    end

    rows_text = ""
    for ts in type_order
        for type in ts
            rows_text *= rows[type]
        end
        rows_text *= "\\hline\n"
    end
    return "\\begin{tabular}{|l|" * "r"^cols * "|r|}\n\\hline\n" * header * rows_text * "\\hline\n\\end{tabular}"
end

# Generate LaTeX table for asymptotic rewards
function generate_asymptotic_rewards_table(asymptotic_rewards, convergence_speeds, checkpoints, lengths)
    types = collect(keys(asymptotic_rewards))
    cols = length(checkpoints)
    
    # Collect convergence speeds into a matrix
    conv_matrix = [convergence_speeds[type][chk] for type in types, chk in checkpoints]
    asym_matrix = [asymptotic_rewards[type] for type in types, chk in [1]]

    # Format the values with best/worst highlighting
    formatted_conv = format_best_worst_across_types(conv_matrix, 0, is_higher_better=false, skip=true)
    formatted_asym = format_best_worst_across_types(asym_matrix, 3, is_higher_better=true, )
    percentages = ["\$$(Int(chk*100))\\%\$" for chk in checkpoints]
    # Create LaTeX table
    header = "\\hline\n & Asympt.~Rew. & " * join(percentages, " & ") * " & Length \\\\\n\\hline\n"
    rows = Dict()
    for (i, type) in enumerate(types)
        rows[type] = "$(latex_labels[type]) & {$(formatted_asym[i,1])} & " * 
                     join(formatted_conv[i, :], " & ") * " & \$$(lengths[type])\$ \\\\\n"
    end
    rows_text = ""
    for ts in type_order
        for type in ts
            rows_text *= rows[type]
        end
        rows_text *= "\\hline\n"
    end
    return "\\begin{tabular}{|l|r|" * "r"^cols * "|r|}\n\\hline\n" * header * rows_text * "\\hline\n\\end{tabular}"
end




generate_asymptotic_rewards_table (generic function with 1 method)

In [12]:
cumsums = get_cumsums(rewards)
sample_eff = get_sample_efficiency(rewards, [-0.6,-0.5,-0.4,-0.3,-0.2,-0.15])
asymptotic_rewards, convergence_speeds, lengths = get_asymptotic_reward(rewards, [0.6,0.7,0.8,0.9,0.99])


# Generate tables
cumulative_table = generate_cumulative_rewards_table(cumsums, [10,100,150,400,1000,3000])
sample_eff_table = generate_sample_efficiency_table(sample_eff, [-0.6,-0.5,-0.4,-0.3,-0.2,-0.15], lengths)
asymptotic_table = generate_asymptotic_rewards_table(asymptotic_rewards, convergence_speeds, [0.6,0.7,0.8,0.9,0.99], lengths)

# Print LaTeX tables
println(cumulative_table)

\begin{tabular}{|l|rrrrrr|}
\hline
\hline
 & $10$ & $100$ & $150$ & $400$ & $1000$ & $3000$ \\
\hline
$\mathsf{SimplePPL\_TS1}$ & \color{blue}{$-3.77$} & $-25.91$ & $-36.82$ & $-91.07$ & - & - \\
$\mathsf{SoftmaxPPL\_TS1}$ & \color{red}{$-6.80$} & $-53.18$ & $-70.00$ & $-137.41$ & - & - \\
$\mathsf{SimpleDBN\_TS1}$ & $-5.24$ & $-36.91$ & $-51.74$ & - & - & - \\
$\mathsf{MediumDBN\_TS1}$ & $-5.77$ & $-52.97$ & $-75.41$ & $-164.50$ & - & - \\
$\mathsf{FullDBN\_SmallPrior\_TS1}$ & $-5.97$ & $-65.26$ & $-97.67$ & - & - & - \\
\hline
$\mathsf{SimplePPL\_TS100}$ & $-3.78$ & \color{blue}{$-21.82$} & \color{blue}{$-28.58$} & \color{blue}{$-64.26$} & \color{blue}{$-143.31$} & \color{blue}{$-422.26$} \\
$\mathsf{SoftmaxPPL\_TS100}$ & $-6.26$ & $-58.21$ & $-83.68$ & $-176.00$ & $-286.61$ & $-596.47$ \\
$\mathsf{SimpleDBN\_TS100}$ & $-5.36$ & $-34.38$ & $-46.83$ & $-109.72$ & $-251.75$ & $-697.69$ \\
$\mathsf{MediumDBN\_TS100}$ & $-5.86$ & $-55.63$ & $-79.55$ & $-188.78$ & $-432.30$ & $-1194.55$ \

In [13]:
println(sample_eff_table)

\begin{tabular}{|l|rrrrrr|r|}
\hline
\hline
 & $-0.6$ & $-0.5$ & $-0.4$ & $-0.3$ & $-0.2$ & $-0.15$ & Length \\
\hline
$\mathsf{SimplePPL\_TS1}$ & \color{blue}{$2$} & $3$ & \color{blue}{$6$} & \color{blue}{$13$} & $383$ & - & $400$ \\
$\mathsf{SoftmaxPPL\_TS1}$ & $37$ & $58$ & $92$ & $192$ & - & - & $400$ \\
$\mathsf{SimpleDBN\_TS1}$ & $3$ & $11$ & $28$ & $121$ & - & - & $299$ \\
$\mathsf{MediumDBN\_TS1}$ & $11$ & $76$ & $185$ & $679$ & - & - & $924$ \\
$\mathsf{FullDBN\_SmallPrior\_TS1}$ & - & - & - & - & - & - & $360$ \\
\hline
$\mathsf{SimplePPL\_TS100}$ & \color{blue}{$2$} & \color{blue}{$2$} & $8$ & $32$ & \color{blue}{$100$} & \color{blue}{$500$} & $3500$ \\
$\mathsf{SoftmaxPPL\_TS100}$ & $64$ & $200$ & $300$ & $500$ & $900$ & $2600$ & $3400$ \\
$\mathsf{SimpleDBN\_TS100}$ & $8$ & $16$ & $64$ & $200$ & - & - & $7600$ \\
$\mathsf{MediumDBN\_TS100}$ & $32$ & $200$ & $1000$ & - & - & - & $7700$ \\
$\mathsf{FullDBN\_TS100}$ & - & - & - & - & - & - & $7700$ \\
$\mathsf{FullDBN\_SmallP

In [14]:

println(asymptotic_table)

\begin{tabular}{|l|r|rrrrr|r|}
\hline
\hline
 & Asympt.~Rew. & $60\%$ & $70\%$ & $80\%$ & $90\%$ & $99\%$ & Length \\
\hline
$\mathsf{SimplePPL\_TS1}$ & {$-0.224$} & $4$ & $12$ & $14$ & $34$ & $273$ & $400$ \\
$\mathsf{SoftmaxPPL\_TS1}$ & {$-0.259$} & $45$ & $78$ & $90$ & $217$ & $248$ & $400$ \\
$\mathsf{SimpleDBN\_TS1}$ & {$-0.225$} & $27$ & $44$ & $79$ & $138$ & $271$ & $299$ \\
$\mathsf{MediumDBN\_TS1}$ & {$-0.283$} & $114$ & $181$ & $205$ & $339$ & $750$ & $924$ \\
$\mathsf{FullDBN\_SmallPrior\_TS1}$ & {$-0.609$} & $6$ & $145$ & $248$ & $248$ & $248$ & $360$ \\
\hline
$\mathsf{SimplePPL\_TS100}$ & {\color{blue}{$-0.132$}} & $17$ & $33$ & $33$ & $65$ & $401$ & $3500$ \\
$\mathsf{SoftmaxPPL\_TS100}$ & {$-0.156$} & $301$ & $301$ & $401$ & $501$ & $1401$ & $3400$ \\
$\mathsf{SimpleDBN\_TS100}$ & {$-0.213$} & $17$ & $33$ & $65$ & $501$ & $1201$ & $7600$ \\
$\mathsf{MediumDBN\_TS100}$ & {$-0.356$} & $201$ & $201$ & $701$ & $1401$ & $5301$ & $7700$ \\
$\mathsf{FullDBN\_TS100}$ & {\color{