In [1]:
# drastically reduced size of state,action,observation spaces to make problem more tractable
# s : 50^n -> 5n^3
# a : 2^n -> n
# o : 25^n -> 5n^2

using POMDPs
using POMDPDistributions
using Distributions
using SARSOP
using POMDPModels
using Iterators
using POMDPToolbox
using QMDP
using SARSOP

In [2]:
##################################################
######### EmbryoCulture POMDP ####################
##################################################

type EmbryoCulture <: POMDP
    n::Int64; # number of embryos starting in culture
    dayStart::Int64;
    discount::Float64; # = 1; # discount_factor
    pv::Float64; # = 0.9; # probability that a viable embryo will continue to be viable
    pn::Array{Float64,1}; # = [.7; .95]; # probability that a [nonviable, viable] embryo will keep its morphology
end

EmbryoCulture(n::Int64) = EmbryoCulture(n,1,1,0.9,[0.7;0.95]);

# calculate expected viable at day 5
# about half of viable day 1 embryos make it to day 5
function expectedViable(nv::Int64, day::Int64)
    return nv * (2 ^ (-(5 - day)/4))
end

function POMDPs.discount(mdp::EmbryoCulture)
    return mdp.discount; 
end

discount (generic function with 5 methods)

In [3]:
##################################################
######### EmbryoState definitions ################
##################################################

# EmbryoState definition
type EmbryoState <: State
    n::Int64 # total number of embryos in culture
    v::Int64 # number of currently viable embryos
    
    # number of embryos with {poor, fair, good} morphology
    mV::Array{Int64,1} # morphology distribution for viable mebryos
    mNV::Array{Int64,1} # morphology distribution for nonviable embryos
    done::Int64; # done? (only if B, TR or D)
    day::Int64; # culture day
end

# helper to generate morphology distribution {poor, fair, good}
function generateMorphology(nEmbryos::Int64, viable::Bool, theta::Array{Float64,1}=[.8;.9])
    
    mOut = zeros(Int64,3);
    b = Bernoulli(theta[viable + 1]);
    
    for i = 1:nEmbryos
        currInd = sum(rand(b,2)) + 1;
        mOut[currInd] += 1;
    end
    
    return mOut;
end

# initialize day1 embryos with known viability
EmbryoState(n::Int64, v::Int64) = EmbryoState(n,v,generateMorphology(v,true), generateMorphology(n-v,false),0,1);
EmbryoState(n::Int64, v::Int64, mV::Array{Int64,1}, mNV::Array{Int64,1}) = EmbryoState(n,v,mV,mNV,0,1);

function ==(s1::EmbryoState, s2::EmbryoState)
    return (s1.n == s2.n) && (s1.v == s2.v) && (s1.mV == s2.mV) && 
        (s1.mNV == s2.mNV) && (s1.done == s2.done) && (s1.day == s2.day);
end

# convert EmbryoState to string for hashing purposes
function embryoStateToString(s::EmbryoState)
    return string(s.n,s.v,s.mV[1], s.mV[2], s.mV[3],s.mNV[1], s.mNV[2], s.mNV[3],s.done,s.day);
end

# convert string to EmbryoState for hashing purposes
function stringToEmbryoState(s::ASCIIString)
   
    n = parse(Int64, s[1]);
    v = parse(Int64, s[2]);
    
    mV = [parse(Int64, s[3]);parse(Int64, s[4]);parse(Int64, s[5])];
    mNV = [parse(Int64, s[6]);parse(Int64, s[7]);parse(Int64, s[8])];

    done = parse(Int64, s[9]);
    day = parse(Int64, s[10]);
        
    return EmbryoState(n,v,mV,mNV,done,day);
end

# convert Array{string} to Array{EmbryoState} for hashing purposes
function stringToEmbryoState(sa::Array{ASCIIString,1})
   
    nStates = length(sa);
    ea = cell(nStates);
    
    for j in 1:nStates
        s = sa[j];
    
        n = parse(Int64, s[1]);
        v = parse(Int64, s[2]);

        mV = [parse(Int64, s[3]);parse(Int64, s[4]);parse(Int64, s[5])];
        mNV = [parse(Int64, s[6]);parse(Int64, s[7]);parse(Int64, s[8])];

        done = parse(Int64, s[9]);
        day = parse(Int64, s[10]);

        ea[j] = EmbryoState(n,v,mV,mNV,done,day);
    end
    
    return ea;
end

function Base.copy!(s1::EmbryoState, s2::EmbryoState)
    s1.n = s2.n;
    s1.v = s2.v;
    s1.mV = s2.mV;
    s1.mNV = s2.mNV;
    s1.done = s2.done;
    s1.day = s2.day;
    s1
end

copy! (generic function with 26 methods)

In [4]:
##################################################
######### EmbryoStateSpace etc. ##################
##################################################
type EmbryoStateIterator
    n::Int64;
    minVals::Array{Int64,1};
    maxVals::Array{Int64,1};
end

EmbryoStateIterator(n::Int64) = EmbryoStateIterator(n,[1;0;0;0;0;0;0;0],[5;n;n;n;n;n;n;n]);

function Base.start(e::EmbryoStateIterator)
    return EmbryoState(e.n,0,[0;0;0],[0,0,e.n],0,1);
end

function Base.done(e::EmbryoStateIterator, state)
    return state.day == 6;
end

function Base.next(e::EmbryoStateIterator, state)

    currV = state.v;
    currNV = state.n - state.v;
        
    iterVector = [state.day;state.v;state.mV;state.mNV];
    
    # advance iterVector until we get a valid v,mV,mNV combo
    isValid = false;
    
    while (isValid == false)
    
        i = 8;

        # increment entire vector by 1, check if it's valid
        while (i > 1) && (iterVector[i] == e.maxVals[i])
            iterVector[i] = e.minVals[i];
            i -= 1;
        end
        
        iterVector[i] += 1;
        vNew = iterVector[2];
        vSum = sum(iterVector[3:5]);
        nvSum = sum(iterVector[6:8]);
        isValid = (vSum == vNew) && (nvSum == (state.n - vNew))
        
    end
    
    return (state, EmbryoState(state.n, iterVector[2], iterVector[3:5], iterVector[6:8], 0, iterVector[1]));
    
end

type EmbryoStateSpace <: AbstractSpace
    # only variable is state iterator 
    states::EmbryoStateIterator; 
end

# returns EmbryoStateSpace
function POMDPs.states(mdp::EmbryoCulture)
    return EmbryoStateSpace(EmbryoStateIterator(mdp.n));
end;

# returns EmbryoStateIterator
function POMDPs.domain(space::EmbryoStateSpace)
    return space.states;
end

# count number of states in state space
function POMDPs.n_states(mdp::EmbryoCulture)
    s = 0;  
    for i in 0:mdp.n
        s += (i+1)*(i+2)*(mdp.n-i+1)*(mdp.n-i+2)/4;
    end  
    return convert(Int64,s*5);
end

# define function to uniformly sample state space
function POMDPs.rand!(es::EmbryoState, space::EmbryoStateSpace)
    
    n = space.states.n;
    s = 0;
    for i in 0:n
        s += (i+1)*(i+2)*(n-i+1)*(n-i+2)/4;
    end 
    s *= 5;
    
    sp = start(takenth(space.states,rand(1:s)));
    copy!(es, sp)
    es
end

rand! (generic function with 26 methods)

In [5]:
##################################################
######### EmbryoAction definitions ###############
##################################################

type EmbryoAction <: Action
    a::String # {CC = continue culture, TL = measure time lapse params, B = biopsy (genetic testing), D = discard all}
    # if we're transferring, TR = # to transfer to patient}
    tr::Int64; # transfer tr embryos to patient in order of morphology. If morphologically equal, choose at random
end

type EmbryoActionSpace <: AbstractSpace
    actions::Array{EmbryoAction,1}
end

function POMDPs.actions(mdp::EmbryoCulture)
    
    acts = Array(EmbryoAction,mdp.n+4);
    acts[1] = EmbryoAction("CC", 0); # continue culture only
    acts[2] = EmbryoAction("TL", 0); # collect cell cycle params
    acts[3] = EmbryoAction("B", 0); # biopsy all
    acts[4] = EmbryoAction("D", 0); # discard all
    
    for i in 1:mdp.n
        acts[i+4] = EmbryoAction("TR", i);
    end
    
    return EmbryoActionSpace(acts)
end

POMDPs.actions(mdp::EmbryoCulture, s::EmbryoState, as::EmbryoActionSpace=actions(mdp)) = as;

function POMDPs.domain(space::EmbryoActionSpace)
    return space.actions; 
end

function POMDPs.rand!(ea::EmbryoAction, space::EmbryoActionSpace)
    ap = space.actions[rand(1:(length(space.actions)))];
    ea.a = ap.a;
    ea.tr = ap.tr;
    ea 
end

function POMDPs.n_actions(mdp::EmbryoCulture)
    return 4 + (mdp.n);
end


n_actions (generic function with 5 methods)

In [6]:
##################################################
############## Transition Function ###############
##################################################

# define transition distribution
type EmbryoTransitionDistribution <: AbstractDistribution
    tD::Dict{ASCIIString, Float64}; # s' ==> p(s')
end

# initialize transition distribution (for preallocation only)
function POMDPs.create_transition_distribution(mdp::EmbryoCulture)
    # initialize dict with neighbors => probabilities 
    tD = Dict{ASCIIString, Float64}(); # transition distribution
    return EmbryoTransitionDistribution(tD);
end

# get all possible next states
function POMDPs.domain(d::EmbryoTransitionDistribution)
    return stringToEmbryoState(collect(keys(d.tD))); 
end

# get p(s')
function POMDPs.pdf(d::EmbryoTransitionDistribution, s::EmbryoState)
    sk = embryoStateToString(s);
    if haskey(d.tD, sk)
        return d.tD[sk];
    else
        return 0.0;
    end
end

# randomly sample next state according to transition probs
function POMDPs.rand!(d::EmbryoTransitionDistribution, s::EmbryoState)
    keyArray = collect(keys(d.tD));
    pArray = collect(values(d.tD));
    
    # normalize to make sure it's a probability distribution
    pArray /= sum(pArray);
    cat = Categorical(pArray);
    
    randState = stringToEmbryoState(keyArray[rand(cat)]);
    copy!(s, randState);
    return s;
end

function calcP(mNew::Array{Int64,1},mOld::Array{Int64,1},p)
   
    pOut = 1;
    mOldCopy = deepcopy(mOld);
    
    for i in 1:(length(mNew)-1)
        nStay = mOldCopy[i];
        nDrop = mNew[i] - nStay;
        mOldCopy[i+1] -= nDrop;
        pOut *= (p^nStay)*(1-p)^(nDrop);
    end
    
    pOut *= p^(mOldCopy[3]);
    return pOut;
end

# transition model
function POMDPs.transition(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction,
    d::EmbryoTransitionDistribution=create_transition_distribution(mdp))

    # if we're already "done" or if action leads to terminal state
    # (action = TR or D or B, day = 5) then can't change state
    if (s.day > 4) || (s.done > 0) || (a.a == "TR") || (a.a == "D") || (a.a == "B")
        # can not transition out of done state
        d.tD = [embryoStateToString(EmbryoState(s.n,s.v,s.mV,s.mNV,1,s.day)) => 1.0];
        return d;
    end
    
    pv = mdp.pv; # p(v(t+1) = 1 | v(t) = 1)
    pn = mdp.pn; # p(n(t+1) = n(t)) for [nonviable, viable] embryo 

    # v(t+1) could go from 0 to v(t)
    iterVector = zeros(Int64,7);
    minVals = zeros(Int64,7);
    maxVals = s.n * ones(Int64,7);
    
    # re-init dictionary w/ s' => p(s')
    d.tD = Dict{ASCIIString, Float64}();
    
    # advance iterVector until we go through ALL valid v,mV,mNV combos
    while (iterVector[1] < (s.v + 1))
        
        isValid = false;
        while (isValid == false) && (iterVector[1] < (s.v + 1))
            i = 7;

            # increment entire vector by 1, check if it's valid
            while (i > 1) && (iterVector[i] == maxVals[i])
                iterVector[i] = minVals[i];
                i -= 1;
            end

            iterVector[i] += 1;
            vNew = deepcopy(iterVector[1]);
            vSum = sum(iterVector[2:4]);
            nvSum = sum(iterVector[5:7]);
            isValid = (vSum == vNew) && (nvSum == (s.n - vNew));
            
        end
            
        if vNew < (s.v + 1)
           
            vNew = deepcopy(iterVector[1]);
            mvNew = deepcopy(iterVector[2:4]);
            mnvNew = deepcopy(iterVector[5:7]);
            sNew = EmbryoState(s.n,vNew,mvNew,mnvNew,0,s.day+1);
                
            # 1. subtract viable embryos from s.mV
            mvSub = deepcopy(s.mV);
            eRM = [0;0;0]; # embryos removed from mvSub
            
            for i in 1:3
                numChange = min(s.v - vNew - sum(eRM), mvSub[i]);
                eRM[i] += numChange;
                mvSub[i] -= numChange;
            end
            
            # 2. add embryos of same morphology to mNV
            mnvAdd = deepcopy(s.mNV) + eRM; 
            maxV = [mvSub[1] + mvSub[2]; mvSub[2] + mvSub[3]; mvSub[3]];
            maxNV = [mnvAdd[1] + mnvAdd[2]; mnvAdd[2] + mnvAdd[3]; mnvAdd[3]];
            
            # 3. if state is unreachable, write p = 0
            if (mvNew[1] > maxV[1]) || (mvNew[2] > maxV[2]) || (mvNew[3] > maxV[3]) 
                d.tD[embryoStateToString(sNew)] = 0.0;
            elseif (mnvNew[1] > maxNV[1]) || (mnvNew[2] > maxNV[2]) || (mnvNew[3] > maxNV[3]) 
                d.tD[embryoStateToString(sNew)] = 0.0;
            else   
                # 4. if it is reachable, compute probability and write to dict
                pNew = (pv^(vNew))*((1-pv)^(s.v-vNew)); # p(vNew(t+1)|s.v(t))
                pNew *= calcP(mvNew,mvSub,pn[2]); # pdf for mvNew
                pNew *= calcP(mnvNew,mnvAdd,pn[1]); # pdf for mnvNew
                d.tD[embryoStateToString(sNew)] = pNew; 
            end
        end 
    end
    
    return d;
end

transition (generic function with 9 methods)

In [7]:
# test cell only
tic()
mdp = EmbryoCulture(1);
stateSpace = states(mdp)

#startingState = EmbryoState(mdp.vStart,mdp.nStart);
#randState = rand!(startingState, stateSpace)

#display(randState)

i = 1;
for es in domain(transition(mdp, EmbryoState(1,1,[1;0;0],[0;0;0]), EmbryoAction("CC",0)))
    i += 1;
    display(es)
    if i > 200
        break;
    end
end
toc()

EmbryoState(1,0,[0,0,0],[0,0,1],0,2)

EmbryoState(1,0,[0,0,0],[1,0,0],0,2)

EmbryoState(1,1,[0,0,1],[0,0,0],0,2)

EmbryoState(1,1,[0,1,0],[0,0,0],0,2)

EmbryoState(1,1,[1,0,0],[0,0,0],0,2)

EmbryoState(1,0,[0,0,0],[0,1,0],0,2)

elapsed time: 0.776275026 seconds


0.776275026

In [8]:
##################################################
################ Reward Function #################
##################################################
# output reward given state,action
function POMDPs.reward(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction)
    
    nDay = s.day;
    
    # if we are in a terminal state, no reward
    if (s.done == 1)
        return 0;
    # if we are transferring some of the embryos
    elseif a.a == "TR"

        nTransfer = a.tr;
        nSampled = [];
        
        # randomly choose nTransfer best morphology embryos
        i = 3;
        while (i > 0) && (length(nSampled) < nTransfer)
            currV = s.mV[i];
            currNV = s.mNV[i];
            numToSample = min(nTransfer - length(nSampled), currV + currNV);
            arrayToSample = [zeros(Int64,currNV); ones(Int64,currV)];
            nSampled = [nSampled; sample(arrayToSample,numToSample,replace=false)];
            i -= 1;
        end
        
        nViable = sum(nSampled);
        ev = expectedViable(nViable, nDay);
        
        # big negative reward if no viable embryos transferred
        # big positive reward if one viable embryo transferred
        # diminishing positive rewards for >1 viable embryo transferred
        if ev < 1
            return -20; # transfer 0 viable embryos
        else 
            return 56 - 16*floor(ev); # transfer >= 1 viable embryo
        end
    elseif a.a == "D"
        
        nViable = s.v;
        ev = expectedViable(nViable, nDay);
        
        # slight negative reward for discarding all (meaning no viable embryos available to transfer)
        # big negative reward for discarding viable embryos
        if ev < 1
            return -2; # none viable, discard all
        else 
            return -20; # discard >= 1 viable embryo
        end
    elseif a.a == "B"
        
        # ignore biopsy option for now
        nViable = s.v;
        ev = expectedViable(nViable, nDay);
        
        return -80;
        if nDay < 5
            return -40;
        elseif ev < 1 
            return -5;
        else
            return 40;
        end
    elseif a.a == "TL"
        if nDay > 3
            return -40;
        else
            return -2; # slight negative reward for doing test
        end
    else
        # do nothing / continue culture
        if nDay > 4
            return -40;
        else
            return -1; # slight negative reward for continuing culture
        end
    end
end

reward (generic function with 6 methods)

In [9]:
##################################################
########## Observation Distribution ##############
##################################################
# observation definitions
type EmbryoObservation <: Observation
    n::Int64; # number of embryos in culture
    oN::Array{Int64,1}; # morphology of all embryos (3x1 vector)
    oD::Int64; # day in culture
end

function Base.copy!(o1::EmbryoObservation, o2::EmbryoObservation)
    o1.n = o2.n;
    o1.oN = o2.oN;
    o1.oD = o2.oD;
    return o1;
end

# convert EmbryoState to string for hashing purposes
function embryoObsToString(obs::EmbryoObservation)
    sOut = string(obs.oN[1],obs.oN[2],obs.oN[3], obs.oD, obs.n);
    return sOut;
end

# convert string to EmbryoState for hashing purposes
function stringToEmbryoObs(s::ASCIIString)
    oN = [parse(Int64, s[1]);parse(Int64, s[2]);parse(Int64, s[3])];
    oD = parse(Int64, s[4]);   
    n = parse(Int64,s[5]);
    return EmbryoObservation(n,oN,oD);
end

# convert Array{string} to Array{EmbryoState} for hashing purposes
function stringToEmbryoObs(oa::Array{ASCIIString,1})
   
    nStates = length(oa);
    oOut = cell(nStates);
    
    for j in 1:nStates
        s = oa[j];
        oN = [parse(Int64, s[1]);parse(Int64, s[2]);parse(Int64, s[3])];
        oD = parse(Int64, s[4]);  
        n = parse(Int64,s[5]);
        oOut[j] = EmbryoObservation(n,oN,oD);
    end
    
    return oOut;
end

type EmbryoObservationIterator
    n::Int64;
    minVals::Array{Int64,1};
    maxVals::Array{Int64,1};
end

EmbryoObservationIterator(n::Int64) = EmbryoObservationIterator(n,[1;0;0;0],[5;n;n;n]);

function Base.start(e::EmbryoObservationIterator)
    return EmbryoObservation(e.n,[0;0;e.n],1);
end

function Base.done(e::EmbryoObservationIterator, obs)
    return obs.oD == 6;
end

function Base.next(e::EmbryoObservationIterator, obs)
        
    iterVector = [obs.oD;obs.oN];
    
    # advance iterVector until we get a valid oN
    isValid = false;
    
    while (isValid == false)
    
        i = 4;

        # increment entire vector by 1, check if it's valid
        while (i > 1) && (iterVector[i] == e.maxVals[i])
            iterVector[i] = e.minVals[i];
            i -= 1;
        end
        
        iterVector[i] += 1;
        oNewSum = sum(iterVector[2:4]);
        isValid = (obs.n == oNewSum);
    end
    
    return (obs, EmbryoObservation(obs.n, iterVector[2:4], iterVector[1]));
end

type EmbryoObservationSpace <: AbstractSpace
    obs::EmbryoObservationIterator;
end

# function returning observation space
function POMDPs.observations(mdp::EmbryoCulture)
    return EmbryoObservationSpace(EmbryoObservationIterator(mdp.n));
end
    
POMDPs.observations(mdp::EmbryoCulture, s::EmbryoState, obs::EmbryoObservationSpace) = obs;
POMDPs.domain(space::EmbryoObservationSpace) = space.obs;

# define observation distribution
type EmbryoObservationDistribution <: AbstractDistribution
    obsD::Dict{ASCIIString, Float64}; # s' ==> p(s')
end

# initialize observation distribution
function POMDPs.create_observation_distribution(mdp::EmbryoCulture)
    # initialize dict with neighbors => probabilities 
    return EmbryoObservationDistribution(Dict{ASCIIString, Float64}());
end

# get all possible next obs
function POMDPs.domain(d::EmbryoObservationDistribution)
    return stringToEmbryoObs(collect(keys(d.obsD))); 
end

# get p(s')
function POMDPs.pdf(d::EmbryoObservationDistribution, o::EmbryoObservation)
    ok = embryoObsToString(o);
    if haskey(d.obsD, ok)
        return d.obsD[ok];
    else
        return 0.0;
    end
end

# randomly sample next state according to transition probs
function POMDPs.rand!(d::EmbryoObservationDistribution, o::EmbryoObservation)
    keyArray = collect(keys(d.obsD));
    pArray = collect(values(d.obsD));
    cat = Categorical(pArray);
    randObs = stringToEmbryoObs(keyArray[rand(cat)]);
    copy!(o, randObs);
    return o;
end

# observation model
function POMDPs.observation(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction, 
    d::EmbryoObservationDistribution=create_observation_distribution(mdp))
    
    # for now, observation is simply EXACTLY s.n, s.mV+s.mNV and s.day
    d.obsD = [embryoObsToString(EmbryoObservation(s.n, s.mV+s.mNV, s.day)) => 1.0];
    return d;
end

function POMDPs.n_observations(mdp::EmbryoCulture)
    return 6*(mdp.n+1)*(mdp.n+2)/2;  
end

n_observations (generic function with 4 methods)

In [10]:
# Belief ... modify this eventually to reflect starting day always = 1, know starting n

POMDPs.create_belief(mdp::EmbryoCulture) = DiscreteBelief(n_states(EmbryoCulture(mdp.n)));
POMDPs.initial_belief(mdp::EmbryoCulture) = DiscreteBelief(n_states(EmbryoCulture(mdp.n)));


In [11]:
using QMDP

# initialize the solver
# key-word args are the maximum number of iterations the solver will run for, and the Bellman tolerance
solver = QMDPSolver(max_iterations=6, tolerance=1e-3) 
mdp = EmbryoCulture(5);

# initialize the QMDP policy
qmdp_policy = create_policy(solver, mdp)
 
# run the solver
solve(solver, mdp, qmdp_policy, verbose=true)

Iteration : 1, residual: 40.0, iteration run-time: 8464.483812594, total run-time: 8464.483812594
Iteration : 2, residual: 8223.208980277543, iteration run-time: 8757.951377781, total run-time: 17222.435190375
Iteration : 3, residual: 121770.33804317136, iteration run-time: 11856.853575946, total run-time: 29079.288766321
Iteration : 4, residual: 171323.66920148907, iteration run-time: 11502.90915023, total run-time: 40582.197916551
Iteration : 5, residual: 261275.6143158019, iteration run-time: 11514.306230765, total run-time: 52096.504147315994
Iteration : 6, residual: 1947.0747859033945, iteration run-time: 11147.275590122, total run-time: 63243.779737437995


QMDPPolicy(1260x9 Array{Float64,2}:
  -1.47106   -2.47106  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -1.76718   -2.76718  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -1.86206   -2.86206  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -1.83998   -2.83998  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -1.70803   -2.70803  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -1.41305   -2.41305  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -2.77125   -3.77125  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -2.96752   -3.96752  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -2.8574    -3.8574   -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -2.41278   -3.41278  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -1.40544   -2.40544  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -5.25851   -6.25851  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
  -4.55837   -5.55837  -80.0   -2.0  -20.0  -20.0  -20.0  -20.0  -20.0
   ⋮                                     

In [None]:
# simulate actions based on policy
nInit = 5;
vInit = 5;
#pomdp = EmbryoCulture(length(vInit), vInit, nInit, 1);

# start with two viable embryos at day 1
s = EmbryoState(nInit, vInit);
o = EmbryoObservation(2, [0;0;2], 1);
b = initial_belief(mdp)
updater = DiscreteUpdater(mdp) # this comes from POMDPToolbox

rtot = 0.0
# run the simulation for 5 days max
for i = 1:5

    a = action(qmdp_policy, b) # the QMDP action function returns the POMDP action not its index like the SARSOP action function
    # compute the reward
    r = reward(mdp, s, a)
    rtot += r
    
    println("Time step $i")
    println("Taking action: $(a), got reward: $(r)")
    
    # transition the system state
    trans_dist = transition(mdp, s, a)
    rand!(trans_dist, s)
    # sample a new observation
    obs_dist = observation(mdp, s, a)
    rand!(obs_dist, o)
    
    # update the belief
    b = update(updater, b, a, o)
    
    println("Saw observation: $(o)\n")

end
println("Total reward: $rtot")

Time step 1

In [13]:
isprobvec([0,.1,.9])

true