In [1]:
include("inference/mcmc/sepsis_types.jl")
using .SepsisTypes
include("inference/mcmc/sepsis.jl")
using .Sepsis
include("inference/mcmc/inference.jl")
using .Inference
include("inference/mcmc/softmax.jl")
using .Softmax
include("inference/mcmc/smart.jl")
using .Smart
include("inference/mcmc/value_iter.jl")
using .ValueIteration
using Revise
using PyCall;
using Gen;
using CairoMakie
sepsis_gym = pyimport("custom_sepsis");
np = pyimport("numpy");
using BenchmarkTools
using Serialization

In [None]:
"""
    TS()

Mutable container holding the collected episodes and the per-round inference results.

A freshly constructed `TS` has an empty choice map, no policies or start states, `index`
equal to `0`, and empty result dictionaries.

# Fields
- `choices::ChoiceMap`: observations accumulated by `add_episodes!` via `update_choicemap`.
- `policies::Vector{Policy}`: one converted policy per recorded episode.
- `start_states::Vector{State}`: first visited state of each recorded episode.
- `index::Int`: index of the most recently recorded episode (`0` when none).
- `params`, `scores`, `acceptance`, `sampled_params`, `mean_rewards`: dictionaries filled
  by `train_ts!`, keyed by the value of `index` at the time of the call.
"""
mutable struct TS
    choices::ChoiceMap
    policies::Vector{Policy}
    start_states::Vector{State}
    index::Int
    params::Dict
    scores::Dict
    acceptance::Dict
    sampled_params::Dict
    mean_rewards::Dict

    # The inner constructor must carry the struct's own name; it was previously declared
    # as `History()`, which left `TS()` undefined and made `History` a plain function.
    function TS()
        return new(
            choicemap(), Policy[], State[], 0, Dict(), Dict(), Dict(), Dict(), Dict()
        )
    end
end;

# Backward-compatible name: `History()` still constructs an instance, and `::History`
# annotations used elsewhere in this file now refer to a real type.
const History = TS;


In [None]:
"""
    add_episodes!(history::TS, nr::Int)

Run `nr` new episodes in `sepsis_gym`, each under a fresh random policy, and record them
in `history`.

For every episode the converted policy is appended to `history.policies`, the first entry
of `episode.visited` is converted and appended to `history.start_states`, the episode is
merged into `history.choices` through `update_choicemap`, and `history.index` is advanced.
Episode indices continue from the current `history.index`.

Returns `nothing`.
"""
function add_episodes!(history::TS, nr::Int)
    # Annotated with `TS` (the struct actually defined in this file) so that the method
    # agrees with `train_ts!`, which works on the same object.
    first_new = history.index + 1
    last_new = history.index + nr

    for episode_idx in first_new:last_new
        gym_policy = sepsis_gym.random_policy()
        episode = sepsis_gym.run_episode(gym_policy)

        push!(history.policies, to_policy(gym_policy))
        push!(history.start_states, to_state(episode.visited[1]))

        history.choices = update_choicemap(history.choices, episode_idx, episode)
        # Updated inside the loop so the counter stays accurate if an episode throws.
        history.index = episode_idx
    end

    return nothing
end

"""
    train_ts!(ts::TS, nr_iter::Int, functions;
              step_size=0.01, posterior_window=101, n_policies=10,
              n_eval_episodes=100000)

Run `nr_iter` `drift_update` steps on a trace of `sepsis_model` constrained to
`ts.choices`, then evaluate policies optimized for parameters drawn from the tail of the
chain. All results are stored in `ts` under the key `ts.index`.

# Arguments
- `ts`: state object providing `policies`, `start_states`, `choices` and `index`.
- `nr_iter`: number of `drift_update` steps.
- `functions`: object passed to `sepsis_model` and `optimize`; its `extract_parameters`
  is called on every trace.

# Keywords
- `step_size`: second argument forwarded to `drift_update`.
- `posterior_window`: how many of the most recent parameter samples are kept as the pool
  to draw from. When the chain is shorter than this, the whole chain is used.
- `n_policies`: number of parameter draws (and therefore evaluated policies).
- `n_eval_episodes`: second argument forwarded to `sepsis_gym.evaluate_policy`.

# Returns
`(params, scores, acceptance, mean_rew)`: the parameter chain (initial trace plus one
entry per step), the matching trace scores, the mean of the second value returned by
`drift_update` (`NaN` when `nr_iter == 0`), and the evaluation results of the sampled
policies.

# Throws
- `ArgumentError` if `posterior_window < 1`.
"""
function train_ts!(
    ts::TS,
    nr_iter::Int,
    functions;
    step_size::Real=0.01,
    posterior_window::Integer=101,
    n_policies::Integer=10,
    n_eval_episodes::Integer=100000
)
    posterior_window >= 1 || throw(
        ArgumentError("posterior_window must be at least 1, got $posterior_window")
    )

    trace, _ = generate(sepsis_model, (ts.policies, ts.start_states, functions), ts.choices)

    params = [functions.extract_parameters(trace)]
    scores = [get_score(trace)]
    acceptance = 0.0

    for _ in 1:nr_iter
        trace, accepted = drift_update(trace, step_size)

        push!(params, functions.extract_parameters(trace))
        push!(scores, get_score(trace))
        acceptance += accepted
    end
    acceptance /= nr_iter

    ts.params[ts.index] = params
    ts.scores[ts.index] = scores
    ts.acceptance[ts.index] = acceptance

    # Clamp the window start: the unclamped `params[end-100:end]` raised a BoundsError
    # whenever the chain held fewer than 101 samples (i.e. `nr_iter < 100`).
    window_start = max(firstindex(params), lastindex(params) - posterior_window + 1)
    posterior = params[window_start:end]

    mean_rew = []
    sampled_params = []
    for _ in 1:n_policies
        param = rand(posterior)
        push!(sampled_params, param)
        # `optimize` also returns a second value, which is not needed here.
        policy, _ = optimize(param, functions)
        gym_policy = to_gym_pol(policy)
        push!(mean_rew, sepsis_gym.evaluate_policy(gym_policy, n_eval_episodes))
    end

    ts.sampled_params[ts.index] = sampled_params
    ts.mean_rewards[ts.index] = mean_rew

    return params, scores, acceptance, mean_rew
end

"""
    save_file(self, name::String)

Serialize `self` with `Serialization.serialize` into the file at path `name`.

The file is opened in `"w"` mode, so an existing file at that path is overwritten. The
file handle is closed when writing finishes or throws.
"""
function save_file(self::Any, name::String)
    write_payload = io -> serialize(io, self)
    return open(write_payload, name, "w")
end

"""
    load_history(name::String)::TS

Read the file at path `name` with `Serialization.deserialize` and return its content.

The return annotation converts the value to `TS` and asserts the result is a `TS`, so a
file holding a different kind of object raises an error instead of being returned.

`deserialize` does not validate its input; only load files written by a trusted,
compatible `serialize`.
"""
function load_history(name::String)::TS
    # The annotation names `TS`, the struct defined in this file; the former `History`
    # annotation did not refer to a type.
    return open(deserialize, name, "r")
end