In [None]:
using Statistics
using Gen
using PyCall
using Random
gym = pyimport("gymnasium")
DTRGym = pyimport("DTRGym")
spaces = pyimport("gym.spaces")

In [6]:
@dist function labeled_categorical(labels, probs)
    index = categorical(probs)
    labels[index]
end;

const Action = Int
struct State
    hr::Int
    bp::Int
    o2::Int
    glu::Float64
    diabetic::Bool
    abx::Bool
    vaso::Bool
    vent::Bool
end;
const Policy = Dict{State,Action}
struct Episode
    policy::Policy
    rewards::Vector{Float64}
    visited::Vector{State}
end;

const DirichletCounts = Dict{Tuple{State,Action,State},Int}
ACTIONS = [i for i in 1:8]
STATES = [State(hr, bp, o2, glu / 2, diabetic, abx, vaso, vent) for hr in -1:1 for bp in -1:1 for o2 in -1:1 for glu in -2:2 for diabetic in [true, false] for abx in [true, false] for vaso in [true, false] for vent in [true, false]]

state_to_index = Dict(state => i for (i, state) in enumerate(STATES))

function get_reward(state::State)::Float64
    reward = 0.0
    critical_counts = count(c -> c != 0, [state.hr, state.bp, state.o2, state.glu])
    if critical_counts >= 3
        reward = -1.0
    elseif critical_counts == 0 && !state.abx && !state.vaso && !state.vent
        reward = 1.0
    end
    return reward
end;

function to_state(dict::Dict{Any,Any})::State
    return State(dict["hr_state"], dict["sysbp_state"], dict["percoxyg_state"], dict["glucose_state"], dict["diabetic_idx"], dict["antibiotic_state"], dict["vaso_state"], dict["vent_state"])
end;

const TransitionModel = Dict{Tuple{State, Action}, Vector{Float64}}

function random_policy()::Policy
    policy = Dict{State,Action}()
    for state in STATES
        policy[state] = rand(ACTIONS)
    end
    return policy
end;



## Gen

In [7]:
@gen function transition_model(dirichlet_counts::DirichletCounts=DirichletCounts())::TransitionModel
    # beliefs are a mapping S,A -> S
    beliefs = TransitionModel()
    for state in STATES
        for action in ACTIONS
            beliefs[(state, action)] = {state => action} ~ dirichlet([haskey(dirichlet_counts, (state,action,new_state)) ? dirichlet_counts[(state, action, new_state)] : 1.0 for new_state in STATES])
        end
    end
    return beliefs
end;

In [8]:
@gen function simulate_episode(beliefs::TransitionModel, actions::Vector{Action}, start_state::State)
    states = [start_state]
    state = start_state
    rewards = []
    for (t, action) in enumerate(actions)
        {t => :action} ~ labeled_categorical([action], [1])
        new_state = {t => :new_state} ~ labeled_categorical(STATES, beliefs[(state, action)])
        push!(states, new_state)
        reward = {t => :reward} ~ labeled_categorical([get_reward(new_state)], [1])
        push!(rewards, reward)
        state = new_state
    end
    return states, rewards
end;


## Gym

In [9]:
function run_episode(env, policy::Policy, max_length::Int)
    obs, info = env.reset()
    state = to_state(info["state"])
    visited = [state]
    rewards = []
    for t in 1:max_length
        action = policy[state]
        obs, reward, terminated, truncated, info = env.step(action)
        new_state = to_state(info["state"])
        push!(visited, new_state)
        push!(rewards, reward)
        state = new_state
        if terminated
            break
        end
    end
    return Episode(policy, rewards, visited)
end;

sepsis_env = gym.make("OberstSepsisEnv-discrete")
run_episode(sepsis_env, random_policy(), 100)

Episode(Dict{State, Int64}(State(-1, 0, 1, 1.0, true, false, false, false) => 2, State(-1, 1, 1, 1.0, false, false, false, false) => 5, State(1, -1, -1, 0.5, true, false, true, false) => 7, State(1, 0, 1, 1.0, false, false, true, true) => 6, State(0, -1, -1, 0.0, true, true, false, false) => 7, State(1, 0, -1, -0.5, false, false, false, false) => 2, State(-1, 1, -1, 1.0, true, true, false, true) => 6, State(-1, -1, 0, -0.5, false, false, true, true) => 2, State(-1, -1, 0, 0.0, false, false, true, false) => 4, State(0, -1, 1, -0.5, false, false, false, false) => 1…), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0], State[State(1, 0, 1, 0.0, false, false, false, false), State(1, 0, 1, 0.0, false, false, true, true), State(1, 0, 1, 0.0, false, false, true, false), State(1, 0, 1, 0.0, false, true, false, true), State(1, 0, 1, 0.0, false, false, false, true), State(1, 1, 1, 0.0, false, false, true, false), State(0, 0, 1, 0.0, false, true, false, false), State(0, 1, 1, 0.0, fal

In [10]:
function run_episode_sepsis_model(env, model, max_length::Int)
    obs, info = env.reset()
    state = to_state(info["state"])
    visited = [state]
    rewards = []
    for t in 1:max_length
        action = model.policy.predict(obs)[1][1]
        obs, reward, terminated, truncated, info = env.step(action)
        new_state = to_state(info["state"])
        push!(visited, new_state)
        push!(rewards, reward)
        state = new_state
        if terminated
            break
        end
    end
    return Episode(model, rewards, visited)
end;

sepsis_env = gym.make("OberstSepsisEnv-discrete")
run_episode(sepsis_env, random_policy(), 100)

Episode(Dict{State, Int64}(State(-1, 0, 1, 1.0, true, false, false, false) => 6, State(-1, 1, 1, 1.0, false, false, false, false) => 1, State(1, -1, -1, 0.5, true, false, true, false) => 1, State(1, 0, 1, 1.0, false, false, true, true) => 2, State(0, -1, -1, 0.0, true, true, false, false) => 4, State(1, 0, -1, -0.5, false, false, false, false) => 6, State(-1, 1, -1, 1.0, true, true, false, true) => 7, State(-1, -1, 0, -0.5, false, false, true, true) => 5, State(-1, -1, 0, 0.0, false, false, true, false) => 4, State(0, -1, 1, -0.5, false, false, false, false) => 8…), [-1.0], State[State(-1, 0, 0, 0.0, false, false, false, false), State(-1, 1, 0, 0.0, false, false, true, false)])

In [11]:
# Import Stable-Baselines3
sb3 = pyimport("stable_baselines3")

PyObject <module 'stable_baselines3' from '/Users/luisastue/miniconda3/lib/python3.10/site-packages/stable_baselines3/__init__.py'>

In [None]:
# Create a PPO agent
ppo = sb3.PPO("MlpPolicy", sepsis_env, verbose=0)

# Train the agent using `learn`
ppo.learn(total_timesteps=100000, log_interval=1000) 


In [13]:
function get_policy_sepsis_model(model)
    policy = Dict{State, Int}()
    for state in STATES
        policy[state] = model.policy.predict([state.hr, state.bp, state.o2, state.abx, state.vaso, state.vent])[1][1]
    end
    return policy
end;


function get_policy(model)
    policy = Dict{State, Int}()
    for i in 1:length(STATES)
        policy[STATES[i]] = model.policy.predict([i-1])[1][1]
    end
    return policy
end;


In [None]:
eval_episodes = []
for _ in 1:100
    ppo_sepsis_optimal_policy = get_policy_sepsis_model(ppo)
    for _ in 1:1000
        push!(eval_episodes, run_episode(sepsis_env, ppo_sepsis_optimal_policy, 100))
    end
end


In [None]:
mean([sum(ep.rewards) for ep in eval_episodes])

In [21]:
@pydef mutable struct GymEnv <: gym.Env
    action_space = gym.spaces.Discrete(length(ACTIONS)) # from 0 to 7
    observation_space = gym.spaces.Discrete(length(STATES)) # from 0 to length(STATES) - 1
    state = rand(0:length(STATES) - 1)
    done = false

    function __init__(self, transition_model)
        self.transition_model = transition_model
        self.state = rand(0:length(STATES) - 1)
        self.done = false
        self.nr_actions = 0
    end

    function reset(self; seed...)
        self.state = rand(0:length(STATES) - 1)
        self.done = false
        self.nr_actions = 0
        return [self.state], Dict()
    end

    function step(self, action) # action from 0 to 7
        transition_probs = self.transition_model[(STATES[self.state + 1], action + 1)]
        next_state = categorical(transition_probs) -1
        self.nr_actions += 1
        reward = get_reward(STATES[self.state + 1])  # Compute reward
        self.done = (reward != 0|| self.nr_actions > 100)  # Termination logic
        self.state = next_state
        return ([self.state], reward, self.done, false, Dict())  # Return Gym-compatible tuple
    end

    function render(self, mode="human")
        println("State: $(self.state)")
    end

    function close(self)
        println("Closing environment.")
    end
end;

In [22]:
# Create an instance of the environment
tr = transition_model();

In [23]:
env = GymEnv(tr)

PyObject <GymEnv object at 0x336bb7ca0>

In [19]:
# Reset the environment
obs = env.reset()
println("Initial Observation: $obs")

# Step through the environment
done = false
total_reward = 0.0

while !done
    action = env.action_space.sample() # Take a random action
    obs, reward, done, _ = env.step(action)
    total_reward += reward
    println("Obs: $obs, Reward: $reward, Done: $done")
end

println("Total Reward: $total_reward")
env.close()

Initial Observation: ([2046], Dict{Any, Any}())
Obs: [393], Reward: 0.0, Done: false
Obs: [1395], Reward: 0.0, Done: false
Obs: [2113], Reward: 0.0, Done: false
Obs: [1276], Reward: -1.0, Done: true
Total Reward: -1.0
Closing environment.


## Running DQN on random transition model

In [20]:
dqn = sb3.DQN("MlpPolicy", env, verbose=0)

PyObject <stable_baselines3.dqn.dqn.DQN object at 0x3402c7c70>

In [None]:
dqn.learn(total_timesteps=10000)

In [17]:
# eval_episodes = []
# for _ in 1:100
#     opt_pol = get_policy(dqn)
#     for _ in 1:1000
#         push!(eval_episodes, run_episode(sepsis_env, opt_pol, 100))
#     end
# end

# mean([sum(ep.rewards) for ep in eval_episodes])

# Learning from a history

In [18]:
history = [run_episode(sepsis_env, random_policy(), 100) for _ in 1:1e4];

In [19]:
random_reward = mean([sum(episode.rewards) for episode in history])
random_reward

-0.8853

In [20]:
function update_state_counts(state_counts::DirichletCounts, episodes::Vector{Episode})
    for episode in episodes
        for (i, state) in enumerate(episode.visited[1:end-1])
            action = episode.policy[state]
            new_state = episode.visited[i+1]
            state_counts[(state, action, new_state)] = get(state_counts, (state, action, new_state), 1) + 1
        end
    end
    return state_counts
end;

In [21]:
state_counts = update_state_counts(DirichletCounts(), history)
tr = transition_model(state_counts)

Dict{Tuple{State, Int64}, Vector{Float64}} with 17280 entries:
  (State(1, 1, 1, -0.5, tr… => [0.00111337, 0.000788264, 0.000693644, 0.0010753…
  (State(-1, -1, 1, 1.0, f… => [0.000149381, 0.000617536, 0.000593287, 0.001233…
  (State(-1, -1, -1, 1.0, … => [0.000140698, 0.000716112, 0.000117856, 0.000309…
  (State(1, 0, 0, 0.0, fal… => [0.000831603, 4.99648e-6, 0.000776732, 0.0019834…
  (State(-1, 1, 0, 0.5, tr… => [0.000511288, 0.00112391, 5.30439e-5, 0.00050322…
  (State(-1, 1, -1, 1.0, t… => [0.000921606, 0.000401734, 0.00412987, 1.61151e-…
  (State(0, 1, -1, 0.0, tr… => [2.21788e-5, 0.000811698, 0.000562487, 0.0004956…
  (State(1, 1, 1, -0.5, tr… => [0.000190075, 0.000369742, 0.000852077, 0.001127…
  (State(0, -1, 0, 1.0, tr… => [0.00216023, 0.000416265, 0.000569966, 0.0004093…
  (State(0, 0, -1, 0.5, fa… => [3.38045e-5, 0.000694458, 4.25422e-5, 0.00181959…
  (State(0, 0, 1, 0.0, fal… => [0.000214809, 0.000550098, 0.000332522, 0.000164…
  (State(1, 1, 0, -1.0, tr… => [3.75901e-5, 0.

In [22]:
env = GymEnv(tr)

PyObject <GymEnv object at 0x3244c1510>

In [23]:
# # dqn = sb3.DQN("MlpPolicy", env, verbose=0)
# # dqn.learn(total_timesteps=10000, log_interval=1000)
# # dqn = sb3.DQN("MlpPolicy", env, replay_buffer_size=10_000, verbose=0)
# dqn = sb3.DQN("MlpPolicy", env, batch_size=32, verbose=0)
# dqn.learn(total_timesteps=5000, log_interval=500)
# policy = get_policy(dqn)

In [None]:
ppo = sb3.PPO("MlpPolicy", env, verbose=0)
ppo.learn(total_timesteps=5000)
ppo_policy = get_policy(ppo)

In [None]:
a2c = sb3.A2C("MlpPolicy", env, verbose=0)
a2c.learn(total_timesteps=5000, log_interval=500)
policy = get_policy(a2c)

In [None]:
test = [run_episode(sepsis_env, policy, 100) for _ in 1:1000]
test_reward = mean([sum(episode.rewards) for episode in test])