In [None]:
include("main.jl")
using PyCall
using CairoMakie
sepsis_gym = pyimport("custom_sepsis")
using Statistics


In [None]:
# Thompson-sampling runs per model type, three runs each.
# DBN variants are JSON dumps read through the Python `custom_sepsis` package (suffixes 0-2);
# PPL variants are JLD files read with `load_jld` (suffixes 1-3).
ts = let
    dirichlet_runs(name) = [sepsis_gym.DirThompsonSampling.load_json("json/dirichlet/ts/$name-$i.json") for i in 0:2]
    ppl_runs(name) = [load_jld("data/mcmc/runs/$name-$i.jld") for i in 1:3]
    Dict(
        :Simple => dirichlet_runs("Simple"),
        :Medium => dirichlet_runs("Medium"),
        :None => dirichlet_runs("None"),
        :Softmax => ppl_runs("SoftmaxPPL"),
        :SimplePPL => ppl_runs("SimplePPL"),
    )
end

In [None]:
# DQN baseline mean-reward curve (suffix "1M" presumably = training budget — TODO confirm).
dqn_1M = "json/dqn/DQN-mean-rewards-1M.json" |> load_dqn_from_json

In [None]:
# DQN baseline mean-reward curve (suffix "35000" presumably = training budget — TODO confirm).
dqn_35000 = "json/dqn/DQN-mean-rewards-35000.json" |> load_dqn_from_json


In [None]:
# DQN baseline mean-reward curve (suffix "5000" presumably = training budget — TODO confirm).
dqn_5000 = "json/dqn/DQN-mean-rewards-5000.json" |> load_dqn_from_json


In [None]:
"""
    plot_mean_rewards(ts, batch_size, types=[...], dqn=nothing, window_size=5,
                      dqn_window=1000, x_lim=nothing; filename="ts.png")

Plot Thompson-sampling mean-reward curves for each model type in `types`, save the
figure to `filename` and display it.

For every type this draws each run's step-filled curve (faint), the moving average of the
across-run mean at the recorded checkpoints, and a ±1 standard-deviation band. A dashed
random-policy baseline (`random_mean`) and, if `dqn` is given, a DQN curve via `add_dqn!`
are added. `moving_avg`, `add_dqn!`, `colors_dict`, `label_dict` and `random_mean` are
defined outside this cell (presumably in main.jl).

# Arguments
- `ts`: `Dict` from model type (`Symbol`) to a vector of runs; each run exposes a
  `mean_rewards` dictionary keyed by numeric episode checkpoints.
- `batch_size`: only used in the plot title.
- `types`: model types to plot; types with no runs are skipped.
- `dqn`: optional DQN result; its `mean_rewards` are also returned.
- `window_size`: moving-average window for the mean and std curves.
- `dqn_window`: smoothing window forwarded to `add_dqn!`.
- `x_lim`: optional upper x-axis limit.
- `filename`: where the figure is saved.

# Returns
A `Dict` holding, per plotted type, the across-run mean of the step-filled rewards (one
entry per episode from the first checkpoint on), plus `:DQN` and `:SmoothedDQN` when
`dqn` is given.
"""
function plot_mean_rewards(ts, batch_size,
                           types=[:SimplePPL, :Softmax, :Simple, :Medium, :None],
                           dqn=nothing, window_size=5, dqn_window=1000, x_lim=nothing;
                           filename="ts.png")
    fig = Figure(resolution=(900, 500))
    ax = Axis(fig[1, 1], xlabel = "Number of Episodes", ylabel = "Mean Reward Across 100'000 Episodes", title = "Mean Rewards for Thompson Sampling with Batch Size $batch_size")

    # Series returned to the caller for later tables.
    data = Dict{Symbol, Any}()

    for type in types
        runs = ts[type]
        isempty(runs) && continue

        # Use the first run's checkpoints, truncated to the shortest run so that every run
        # can be indexed at every checkpoint.
        all_keys = [sort(collect(keys(model.mean_rewards))) for model in runs]
        min_keys = minimum(length, all_keys)
        ks = all_keys[1][1:min_keys]

        filled_rewards = []
        all_rewards = []
        for model in runs
            # Repeat each checkpoint's value over the episodes since the previous
            # checkpoint, giving a per-episode step curve.
            rewards = [model.mean_rewards[ks[1]]]
            for j in 2:min_keys
                append!(rewards, fill(model.mean_rewards[ks[j]], ks[j] - ks[j - 1]))
            end
            push!(all_rewards, [model.mean_rewards[k] for k in ks])
            push!(filled_rewards, rewards)
            lines!(ax, Float64.(rewards), color=(colors_dict[type], 0.2))
        end

        # Element-wise mean/std across runs (all runs have equal-length series).
        smoothed_mean_rewards = moving_avg(mean(all_rewards), window_size)
        lines!(ax, ks, Float64.(smoothed_mean_rewards), color=colors_dict[type], linewidth=1.5, label=label_dict[type])

        smoothed_std_rewards = Float64.(moving_avg(std(all_rewards), window_size))
        low = smoothed_mean_rewards .- smoothed_std_rewards
        high = smoothed_mean_rewards .+ smoothed_std_rewards
        band!(ax, Float64.(ks), low, high, color=(colors_dict[type], 0.2))

        data[type] = mean(filled_rewards)
    end

    # Length of the longest plotted series. Types skipped above have no entry in `data`,
    # so iterate what was actually stored rather than `types`.
    len = isempty(data) ? 0 : maximum(length, values(data))

    lines!(ax, 1:len, fill(random_mean, len), color=:black, linestyle=:dash, label="Random Policy")
    if !isnothing(dqn)
        smoothed = add_dqn!(ax, dqn, dqn_window)
        data[:DQN] = dqn.mean_rewards
        data[:SmoothedDQN] = smoothed
    end
    Legend(fig[1, 2], ax, position = :right)

    if !isnothing(x_lim)
        xlims!(ax, 0, x_lim)
    end

    ylims!(ax, -1, 0)
    ax.yticks = -1:0.05:0
    save(filename, fig)
    display(fig)

    return data
end


In [None]:
# Batch-size-100 runs with a DQN baseline, x-axis cut at 3400 episodes.
# NOTE(review): `dqn` is not defined in this notebook (only dqn_1M, dqn_35000 and dqn_5000
# are); presumably it comes from main.jl or an earlier session — confirm.
let batch_size = 100, window_size = 5, dqn_window = 300, x_lim = 3400
    plot_mean_rewards(ts, batch_size, [:SimplePPL, :Softmax, :Simple, :Medium, :None],
                      dqn, window_size, dqn_window, x_lim)
end


In [None]:
# mean_rewards = plot_mean_rewards(ts, 100)
# Batch-size-100 runs against the dqn_1M baseline; the returned series feed the tables below.
mean_rewards = let batch_size = 100, window_size = 5, dqn_window = 300
    plot_mean_rewards(ts, batch_size, [:SimplePPL, :Softmax, :Simple, :Medium, :None],
                      dqn_1M, window_size, dqn_window)
end


In [None]:
# Batch-size-100 runs against the dqn_5000 baseline, x-axis cut at 3500 episodes.
let batch_size = 100, window_size = 5, dqn_window = 400, x_lim = 3500
    plot_mean_rewards(ts, batch_size, [:SimplePPL, :Softmax, :Simple, :Medium, :None],
                      dqn_5000, window_size, dqn_window, x_lim)
end

In [None]:
# Batch-size-100 runs against the dqn_35000 baseline; the returned series feed the tables.
mean_rewards_35000 = let batch_size = 100, window_size = 5, dqn_window = 400, x_lim = 3500
    plot_mean_rewards(ts, batch_size, [:SimplePPL, :Softmax, :Simple, :Medium, :None],
                      dqn_35000, window_size, dqn_window, x_lim)
end

In [None]:
# Per-episode ("every") Thompson-sampling runs, three runs per model type (suffixes 1-3).
# DBN variants come from JSON via `custom_sepsis`; PPL variants from JLD via `load_jld`.
ts_every = let
    dirichlet_runs(name) = [sepsis_gym.DirThompsonSampling.load_json("json/dirichlet/ts/$name-every-$i.json") for i in 1:3]
    ppl_runs(name) = [load_jld("data/mcmc/runs/$name-every-$i.jld") for i in 1:3]
    Dict(
        :Simple => dirichlet_runs("Simple"),
        :Medium => dirichlet_runs("Medium"),
        :Softmax => ppl_runs("SoftmaxPPL"),
        :SimplePPL => ppl_runs("SimplePPL"),
    )
end

In [None]:
# Per-episode runs against the dqn_5000 baseline, zoomed to the first 500 episodes.
let batch_size = 1, window_size = 20, dqn_window = 20, x_lim = 500
    plot_mean_rewards(ts_every, batch_size, [:SimplePPL, :Softmax, :Simple, :Medium],
                      dqn_5000, window_size, dqn_window, x_lim)
end


In [None]:
# Per-episode runs against the dqn_5000 baseline; the returned series feed the tables.
mean_rewards_every = let batch_size = 1, window_size = 20, dqn_window = 20
    plot_mean_rewards(ts_every, batch_size, [:SimplePPL, :Softmax, :Simple, :Medium],
                      dqn_5000, window_size, dqn_window)
end


In [None]:
# Per-episode runs for two model types only. `types` is now a Vector; the original
# `[ :Softmax :Medium]` (no comma) built a 1x2 Matrix.
# NOTE(review): `dqn` is not defined in this notebook (only dqn_1M, dqn_35000 and dqn_5000
# are); presumably it comes from main.jl or an earlier session — confirm.
# NOTE(review): the title will read "Batch Size 4", while the other ts_every plots pass 1 —
# confirm which is intended.
plot_mean_rewards(ts_every, 4, [:Softmax, :Medium], dqn, 10, 300)


In [None]:
# LaTeX math-mode row labels (sans-serif) for each model type in the result tables;
# :DQN and :SmoothedDQN share one label.
type_labels = Dict(type => "\\mathsf{" * name * "}" for (type, name) in (
    :Softmax => "SoftmaxPPL",
    :None => "FullDBN",
    :Medium => "MediumDBN",
    :DQN => "DQN",
    :SmoothedDQN => "DQN",
    :Simple => "SimpleDBN",
    :SimplePPL => "SimplePPL",
))

In [None]:
"""
    generate_latex_table(cumsums, checkpoints, type_labels, types)

Build a LaTeX `tabular` of values at the given checkpoints, one row per entry of `types`.

Each cell is `cumsums[type][checkpoint]` rounded to 2 digits, or `-` when the series is
shorter than the checkpoint. Within each column the smallest displayed value is coloured
red and the largest blue.

# Arguments
- `cumsums`: `Dict` from type to a numeric vector (here: cumulative rewards).
- `checkpoints`: 1-based indices into each vector; also used as column headers.
- `type_labels`: `Dict` from type to its LaTeX math-mode row label.
- `types`: rows in display order; only these take part in the min/max highlighting.

# Returns
The table as a `String`.
"""
function generate_latex_table(cumsums, checkpoints, type_labels, types)
    # Rounded value of `key` at `checkpoint`, or `nothing` if the series is too short.
    cell(key, checkpoint) =
        checkpoint <= length(cumsums[key]) ? round(cumsums[key][checkpoint], digits=2) : nothing

    # Column extrema over the displayed rows only: `cumsums` may hold extra series
    # (e.g. :SmoothedDQN), and letting one of those win a column would leave every
    # visible cell in it unhighlighted.
    column_values = [Float64[] for _ in checkpoints]
    for key in types, (j, checkpoint) in enumerate(checkpoints)
        value = cell(key, checkpoint)
        isnothing(value) || isnan(value) || push!(column_values[j], value)
    end
    column_mins = [isempty(col) ? NaN : minimum(col) for col in column_values]
    column_maxs = [isempty(col) ? NaN : maximum(col) for col in column_values]

    io = IOBuffer()
    print(io, "\\begin{tabular}{|l|", "r|"^length(checkpoints), "}\n\\hline\n")
    print(io, " & ", join(["\$" * string(ch) * "\$" for ch in checkpoints], " & "), " \\\\ \\hline\n")

    for key in types
        cells = String[]
        for (j, checkpoint) in enumerate(checkpoints)
            value = cell(key, checkpoint)
            if isnothing(value)
                push!(cells, "-")
            elseif value == column_mins[j]
                push!(cells, "\\color{red}{\$" * string(value) * "\$}")
            elseif value == column_maxs[j]
                push!(cells, "\\color{blue}{\$" * string(value) * "\$}")
            else
                push!(cells, "\$" * string(value) * "\$")
            end
        end
        print(io, "\$", type_labels[key], "\$ & ", join(cells, " & "), " \\\\ \n")
    end

    print(io, "\\hline \n\\end{tabular}\n")
    return String(take!(io))
end


In [None]:
# Running (cumulative) sum of every series returned by the dqn_35000 plot.
cumsums = Dict(type => accumulate(.+, series) for (type, series) in mean_rewards_35000)

In [None]:
# Inspect the per-type series returned by the per-episode ("every") plot.
mean_rewards_every

In [None]:
# Running (cumulative) sum of every series returned by the per-episode plot.
cumsums_every = Dict(type => accumulate(.+, series) for (type, series) in mean_rewards_every)


In [None]:
# Cumulative-reward table for the batch-size-100 runs at selected episode counts.
let checkpoints = [10, 100, 150, 400, 1000, 3000]
    rows = [:SimplePPL, :Softmax, :Simple, :Medium, :None, :DQN]
    print(generate_latex_table(cumsums, checkpoints, type_labels, rows))
end

In [None]:
# Cumulative-reward table for the per-episode runs at selected episode counts.
let checkpoints = [10, 100, 150, 400]
    rows = [:SimplePPL, :Softmax, :Simple, :Medium, :DQN]
    print(generate_latex_table(cumsums_every, checkpoints, type_labels, rows))
end


In [None]:
# Inspect the per-episode series again (same expression as an earlier cell).
mean_rewards_every

In [None]:
"""
    calculate_and_generate_sample_efficiency_table(mean_rewards, thresholds, types,
                                                   labels=type_labels)

Build a LaTeX `tabular` of sample efficiency. For each type in `types` (rows) and each
threshold (columns), show the first 1-based index at which the type's mean-reward series
is `>=` the threshold, or `-` if it never gets there.

Within each column the smallest displayed index (fastest) is coloured blue and the largest
red.

# Arguments
- `mean_rewards`: `Dict` from type to its mean-reward series.
- `thresholds`: reward levels used as columns.
- `types`: rows in display order; only these take part in the highlighting.
- `labels`: `Dict` from type to its LaTeX math-mode row label; defaults to the global
  `type_labels`.

# Returns
The table as a `String`.
"""
function calculate_and_generate_sample_efficiency_table(mean_rewards, thresholds, types, labels=type_labels)
    # First index at which each displayed type's mean reward (not cumulative) reaches
    # each threshold; thresholds that are never reached are simply absent.
    sample_efficiency = Dict{Symbol, Dict{Float64, Int}}()
    for type in types
        rewards = mean_rewards[type]
        reached = Dict{Float64, Int}()
        for threshold in thresholds
            index = findfirst(x -> x >= threshold, rewards)
            isnothing(index) || (reached[threshold] = index)
        end
        sample_efficiency[type] = reached
    end

    # Column extrema over the displayed rows only, so the highlighted cell is always one
    # that appears in the table (`mean_rewards` may hold extra series such as :SmoothedDQN).
    column_values = [Float64[] for _ in thresholds]
    for type in types, (j, threshold) in enumerate(thresholds)
        if haskey(sample_efficiency[type], threshold)
            push!(column_values[j], sample_efficiency[type][threshold])
        end
    end
    column_mins = [isempty(col) ? NaN : minimum(col) for col in column_values]
    column_maxs = [isempty(col) ? NaN : maximum(col) for col in column_values]

    io = IOBuffer()
    print(io, "\\begin{tabular}{|l|", "r|"^length(thresholds), "}\n\\hline\n")
    print(io, " & ", join(["\$" * string(th) * "\$" for th in thresholds], " & "), " \\\\ \\hline\n")

    for type in types
        cells = String[]
        for (j, threshold) in enumerate(thresholds)
            if !haskey(sample_efficiency[type], threshold)
                push!(cells, "-")
                continue
            end
            value = sample_efficiency[type][threshold]
            if value == column_mins[j]
                push!(cells, "\\color{blue}{\$" * string(value) * "\$}")
            elseif value == column_maxs[j]
                push!(cells, "\\color{red}{\$" * string(value) * "\$}")
            else
                push!(cells, "\$" * string(value) * "\$")
            end
        end
        print(io, "\$", labels[type], "\$ & ", join(cells, " & "), " \\\\ \n")
    end

    print(io, "\\hline \\end{tabular}\n")
    return String(take!(io))
end


In [None]:

thresholds = [-0.6, -0.5, -0.4, -0.3, -0.2, -0.1]

# Episodes each model type needs before its mean reward first reaches each threshold
# (series from the dqn_35000 plot).
latex_table = let rows = [:SimplePPL, :Softmax, :Simple, :Medium, :None, :DQN]
    calculate_and_generate_sample_efficiency_table(mean_rewards_35000, thresholds, rows)
end
println(latex_table)


In [None]:

# Same sample-efficiency table for the per-episode ("every") runs.
latex_table = let rows = [:SimplePPL, :Softmax, :Simple, :Medium, :DQN]
    calculate_and_generate_sample_efficiency_table(mean_rewards_every, thresholds, rows)
end
println(latex_table)


In [None]:
# Inspect the smoothed DQN curve that plot_mean_rewards stored for the dqn_1M baseline.
mean_rewards[:SmoothedDQN]


In [None]:
# Inspect the smoothed DQN curve stored for the dqn_5000 baseline in the per-episode plot.
mean_rewards_every[:SmoothedDQN]

In [None]:
"""
    convergence(mean_rewards, checkpoints, types, last_n, labels=type_labels)

Build a LaTeX `tabular` summarising each model type in `types` (one row each):

- Asympt. Rew.: the mean of the last `last_n` entries of the series, or of the whole
  series if it is shorter; `-` for an empty series.
- One column per fraction `cp` in `checkpoints`: the first 1-based index at which the
  series reaches `min + cp * (max - min)` of its own range. Shows `-` if that never
  happens, or if the series is constant or empty.
- Nr. Ep.: the series length.

Among the displayed rows, the lowest asymptotic reward is red and the highest blue. The
smallest convergence index is blue and the largest red.

`labels` maps each type to its LaTeX math-mode row label and defaults to the global
`type_labels`. Returns the table as a `String`.
"""
function convergence(mean_rewards, checkpoints, types, last_n, labels=type_labels)
    asymptotic_rewards = Dict{Symbol, Float64}()
    # Per type and checkpoint: first index reaching the threshold, or -1 if none.
    convergence_speeds = Dict{Symbol, Dict{Float64, Int}}()
    lengths = Dict{Symbol, Int}()

    for type in types
        rewards = mean_rewards[type]
        lengths[type] = length(rewards)
        speeds = Dict{Float64, Int}(checkpoint => -1 for checkpoint in checkpoints)
        convergence_speeds[type] = speeds

        if isempty(rewards)
            asymptotic_rewards[type] = NaN
            continue
        end

        # Exactly the last `last_n` entries, clamped so short series do not go out of bounds.
        asymptotic_rewards[type] = mean(rewards[max(end - last_n + 1, 1):end])

        # The reward range does not depend on the checkpoint, so compute it once.
        min_reward, max_reward = extrema(rewards)
        reward_range = max_reward - min_reward
        reward_range == 0 && continue  # constant series: no meaningful convergence point

        for checkpoint in checkpoints
            threshold_value = min_reward + checkpoint * reward_range
            index = findfirst(x -> x >= threshold_value, rewards)
            isnothing(index) || (speeds[checkpoint] = index)
        end
    end

    # Extremes among the displayed rows, used for highlighting.
    shown_asymptotic = filter(!isnan, [asymptotic_rewards[t] for t in types])
    min_asymptotic, max_asymptotic =
        isempty(shown_asymptotic) ? (NaN, NaN) : extrema(shown_asymptotic)
    speed_bounds = map(checkpoints) do checkpoint
        reached = [convergence_speeds[t][checkpoint] for t in types if convergence_speeds[t][checkpoint] >= 0]
        isempty(reached) ? (NaN, NaN) : extrema(reached)
    end

    dollar(x) = "\$" * string(x) * "\$"

    io = IOBuffer()
    print(io, "\\begin{tabular}{|l|", repeat("r|", length(checkpoints) + 2), "}\\hline\n")
    print(io, " & Asympt. Rew. ", join(["& $(Int(round(cp*100, digits=0)))\\%" for cp in checkpoints]), " & Nr. Ep. \\\\ \\hline\n")

    for type in types
        asymptotic = asymptotic_rewards[type]
        asymptotic_cell = if isnan(asymptotic)
            "-"
        elseif asymptotic == min_asymptotic
            "\\color{red}{" * dollar(round(asymptotic, digits=2)) * "}"
        elseif asymptotic == max_asymptotic
            "\\color{blue}{" * dollar(round(asymptotic, digits=2)) * "}"
        else
            dollar(round(asymptotic, digits=2))
        end

        speed_cells = String[]
        for (j, checkpoint) in enumerate(checkpoints)
            speed = convergence_speeds[type][checkpoint]
            fastest, slowest = speed_bounds[j]
            if speed == -1
                push!(speed_cells, "-")
            elseif speed == fastest
                push!(speed_cells, "\\color{blue}{" * dollar(speed) * "}")
            elseif speed == slowest
                push!(speed_cells, "\\color{red}{" * dollar(speed) * "}")
            else
                push!(speed_cells, dollar(speed))
            end
        end

        print(io, "\$", labels[type], "\$ & ", asymptotic_cell, " & ", join(speed_cells, " & "),
              " & \$", lengths[type], "\$ \\\\ \n")
    end

    print(io, "\\hline\\end{tabular}\n")
    return String(take!(io))
end


In [None]:

# Convergence table for the series from the dqn_1M plot (asymptote over the last 10 entries).
latex_table = let fractions = [0.6, 0.7, 0.8, 0.9, 0.99], last_n = 10
    rows = [:SimplePPL, :Softmax, :Simple, :Medium, :None, :SmoothedDQN]
    convergence(mean_rewards, fractions, rows, last_n)
end
println(latex_table)


In [None]:

# Convergence table for the series from the dqn_35000 plot.
latex_table = let fractions = [0.6, 0.7, 0.8, 0.9, 0.99], last_n = 10
    rows = [:SimplePPL, :Softmax, :Simple, :Medium, :None, :SmoothedDQN]
    convergence(mean_rewards_35000, fractions, rows, last_n)
end
println(latex_table)


In [None]:

# Convergence table for the per-episode ("every") series.
latex_table = let fractions = [0.6, 0.7, 0.8, 0.9, 0.99], last_n = 10
    rows = [:SimplePPL, :Softmax, :Simple, :Medium, :SmoothedDQN]
    convergence(mean_rewards_every, fractions, rows, last_n)
end
println(latex_table)

In [None]:
# First MediumDBN Thompson-sampling run, used for the exploration analysis below.
medium = first(ts[:Medium])

In [None]:
# Exploration coverage of the `medium` run over its stored model snapshots. For each
# snapshot and each factor (HR, BP, O2, GLU), record the percentage of (state, action)
# inputs whose entries, averaged over the factor's next-state levels, exceed 1.
# NOTE(review): the entries are presumably Dirichlet pseudo-counts with a prior of 1, so
# "> 1" would mean "observed at least once" — confirm against custom_sepsis.
exploration_data = Dict(vs => Float64[] for vs in ["HR", "BP", "O2", "GLU"])
# NOTE(review): if `medium.models` is converted to a Julia Dict, this key order is
# arbitrary; sort numerically before plotting exploration_data against `iterations`.
iterations = collect(keys(medium.models))
inp_vals = [sepsis_gym.HR_STATES, sepsis_gym.BP_STATES, sepsis_gym.O2_STATES, sepsis_gym.GLU_STATES]
# Possible next-state levels per factor. Deliberately not named `values`: a global of
# that name shadows `Base.values` (called below). It even fails to assign once
# `Base.values` has been used in Main, e.g. by `convergence`.
state_levels = [[sepsis_gym.Level.LOW.value, sepsis_gym.Level.NORMAL.value, sepsis_gym.Level.HIGH.value],
                [sepsis_gym.Level.LOW.value, sepsis_gym.Level.NORMAL.value, sepsis_gym.Level.HIGH.value],
                [sepsis_gym.Level.LOW.value, sepsis_gym.Level.NORMAL.value],
                [sepsis_gym.Level.SUPER_LOW.value, sepsis_gym.Level.LOW.value, sepsis_gym.Level.NORMAL.value, sepsis_gym.Level.HIGH.value, sepsis_gym.Level.SUPER_HIGH.value]]

for iteration in iterations
    for (i, vs) in enumerate(["HR", "BP", "O2", "GLU"])
        state_vals = state_levels[i]
        exploration_per_input = Dict()
        for state in inp_vals[i], action in sepsis_gym.ACTIONS
            exploration_per_input[(state, action)] = sum(medium.models[iteration][i][(state, action, next_s)] for next_s in state_vals) / length(state_vals)
        end

        flat = collect(values(exploration_per_input))
        explored_percentage = count(x -> x > 1, flat) / length(flat) * 100
        push!(exploration_data[vs], explored_percentage)
    end
end



# plt.figure(figsize=(7, 4))

# for vs in ["HR", "BP", "O2", "GLU"]:
#     plt.plot(iterations, exploration_data[vs], label=f"{vs} Count Exploration %")

# plt.xlabel("Nr Episodes in History")
# plt.ylabel("Exploration Percentage (%)")
# plt.title("Exploration Percentage Growth Over Nr Episodes in History")
# plt.legend()
# plt.xscale("log", base=10)
# plt.grid(True)

# plt.show()
