In [1]:
include("inference/mcmc/sepsis_types.jl")
using .SepsisTypes
include("inference/mcmc/sepsis.jl")
using .Sepsis
include("inference/mcmc/inference.jl")
using .Inference
using Revise
using PyCall;
using Gen;
using CairoMakie
sepsis_gym = pyimport("custom_sepsis");
np = pyimport("numpy");
using BenchmarkTools


In [42]:
function get_transition_matrix(parameters::Parameters)
    states = enumerate(sepsis_gym.STATES)
    n_states = length(states)
    actions = enumerate(sepsis_gym.ACTIONS)
    n_actions = length(states)
    transition_matrix = PyObject(np.ndarray((n_states, n_actions, n_states)))
    for (i, state) in states
        state = to_state(state)
        for (j, action) in actions
            action = to_action(action)
            hr_p = hr_probs(parameters, state, action)
            hr_p = Dict(
                LOW => hr_p[1],
                NORMAL => hr_p[2],
                HIGH => hr_p[3]
                )
            bp_p = bp_probs(parameters, state, action)
            bp_p = Dict(
                LOW => bp_p[1],
                NORMAL => bp_p[2],
                HIGH => bp_p[3]
                )
            o2_p = o2_probs(parameters, state, action)
            o2_p = Dict(
                LOW => o2_p[1],
                NORMAL => o2_p[2],
                )
            glu_p = glu_probs(parameters, state, action)
            glu_p = Dict(
                SUPER_LOW => glu_p[1],
                LOW => glu_p[2],
                NORMAL => glu_p[3],
                HIGH => glu_p[4],
                SUPER_HIGH => glu_p[5]
                )

            diab_p = Dict(true=> Int(state.diabetic), false=> Int(!state.diabetic))
            abx_p = Dict(true=> Int(action.abx), false=> Int(action.abx))
            vaso_p = Dict(true=> Int(action.vaso), false=> Int(action.vaso))
            vent_p = Dict(true=> Int(action.vent), false=> Int(action.vent))
            
            for (k, next_state) in states
                next_state = to_state(next_state)
                transition_matrix[i-1, j-1, k-1] = hr_p[next_state.hr] * bp_p[next_state.bp] * o2_p[next_state.o2] * glu_p[next_state.glu] * diab_p[next_state.diabetic] * abx_p[next_state.abx] * vaso_p[next_state.vaso] * vent_p[next_state.vent]
            end
        end
    end
    return transition_matrix
end;

In [None]:
pms = @btime get_parameters()

In [20]:
hr_p = @btime hr_probs(pms, State(LOW, NORMAL, NORMAL, NORMAL, false, false, false, false), Action(false, false, false))


  915.541 ns (9 allocations: 464 bytes)


3-element Vector{Float64}:
 0.7191078150673792
 0.03927453775722887
 0.24161764717539197

In [23]:
tr_mat = @btime get_transition_matrix(pms)
# 3min 40s

  210.828 s (675565663 allocations: 18.41 GiB)


1440×8×1440 Array{Float64, 3}:
[:, :, 1] =
 0.00124009   0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.00124009   0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.00124009   0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.00124009   0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.00124009   0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.00124009   0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.00124009   0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.00124009   0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                                ⋮         
 0.000452451  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0          0.0  0.0  0.0  0.0  0.0  0.0  0.0


In [None]:
tr_mat = get_transition_matrix(pms)
# 3min 40s

In [25]:
V = np.zeros(length(sepsis_gym.STATES))
@btime sepsis_gym.matrix_value_iteration(V, tr_mat)

  45.487 ms (80716 allocations: 2.26 MiB)


(Dict{Any, Any}((0, -1, 0, 2, false, true, true, true) => (true, true, false), (-1, 0, -1, 1, false, false, false, true) => (true, true, false), (0, 0, 0, 0, true, false, false, true) => (true, true, false), (-1, -1, 0, 0, false, false, false, false) => (true, true, false), (-1, 1, -1, 0, false, false, true, false) => (true, true, false), (-1, 1, 0, -2, false, false, false, false) => (true, true, false), (1, -1, -1, -2, true, false, false, true) => (true, true, false), (1, 1, -1, -1, true, true, false, false) => (true, true, false), (-1, -1, 0, 2, true, true, false, true) => (true, true, false), (-1, 0, -1, -2, true, false, true, false) => (true, true, false)…), [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0  …  -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0])

In [40]:
tr_mat[1,1,1]

0.0012400914260942805

In [34]:
pol, V = sepsis_gym.matrix_value_iteration(V, PyObject(tr_mat))

PyCall.PyError: PyError ($(Expr(:escape, :(ccall(#= /Users/luisastue/.julia/packages/PyCall/1gn3u/src/pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw))))) <class 'ValueError'>
ValueError('einstein sum subscripts string contains too many subscripts for operand 1')
  File "/Users/luisastue/julia/hc-rl-thompson-sampling/sepsis/custom_sepsis/src/custom_sepsis/inference/dirichlet/tr_value_iter.py", line 36, in matrix_value_iteration
    Q = REWARDS + gamma * np.einsum('ijk,k->ij', transition_model, V)
  File "/Users/luisastue/miniconda3/lib/python3.10/site-packages/numpy/core/einsumfunc.py", line 1371, in einsum
    return c_einsum(*operands, **kwargs)


In [27]:
policy = to_policy(pol)

MethodError: MethodError: no method matching to_policy(::Vector{Float64})

Closest candidates are:
  to_policy(!Matched::Dict{Any, Any})
   @ Main.SepsisTypes ~/thesis/hc-rl-thompson-sampling/sepsis/custom_sepsis/src/custom_sepsis/inference/mcmc/sepsis_types.jl:55


In [24]:
@btime to_action((1,1,1))

  1.417 ns (0 allocations: 0 bytes)


Action(true, true, true)

In [22]:
@btime enumerate(sepsis_gym.STATES)

  9.817 ms (43228 allocations: 1.14 MiB)


enumerate(Tuple{Int64, Int64, Int64, Int64, Vararg{Bool, 4}}[(-1, -1, -1, -2, 1, 1, 1, 1), (-1, -1, -1, -2, 1, 1, 1, 0), (-1, -1, -1, -2, 1, 1, 0, 1), (-1, -1, -1, -2, 1, 1, 0, 0), (-1, -1, -1, -2, 1, 0, 1, 1), (-1, -1, -1, -2, 1, 0, 1, 0), (-1, -1, -1, -2, 1, 0, 0, 1), (-1, -1, -1, -2, 1, 0, 0, 0), (-1, -1, -1, -2, 0, 1, 1, 1), (-1, -1, -1, -2, 0, 1, 1, 0)  …  (1, 1, 0, 2, 1, 0, 0, 1), (1, 1, 0, 2, 1, 0, 0, 0), (1, 1, 0, 2, 0, 1, 1, 1), (1, 1, 0, 2, 0, 1, 1, 0), (1, 1, 0, 2, 0, 1, 0, 1), (1, 1, 0, 2, 0, 1, 0, 0), (1, 1, 0, 2, 0, 0, 1, 1), (1, 1, 0, 2, 0, 0, 1, 0), (1, 1, 0, 2, 0, 0, 0, 1), (1, 1, 0, 2, 0, 0, 0, 0)])

In [6]:
state = State(NORMAL, LOW, NORMAL, NORMAL, false, false, false, false)
policy = to_policy(sepsis_gym.random_policy())
next_state = get_next_state(deterministic_params, state, policy[state])
trace, sc = generate(get_next_state, (deterministic_params, state, Action(0, 0, 0)))

states, rewards = simulate_episode(deterministic_params, policy, state)
states