In [1]:
# first import the POMDPs.jl interface
using POMDPs
using POMDPDistributions
using Distributions
using SARSOP
using POMDPModels
using Iterators

In [2]:
################ Functions and Structures definitions #######################

# State of an embryo culture. `v` and `n` hold one entry per embryo.
type EmbryoState <: State
    v::Array{Int64,1}  # viability flag per embryo
    n::Array{Int64,1}  # developmental day achieved per embryo, in {1,2,3,4,5}
    done::Int64        # terminal flag (only if TR or D at last step or if day >= 6)
    day::Int64         # current culture day
end

# Convenience constructors: both produce a not-yet-done state on culture day 1.
# Given only viabilities, every embryo starts at developmental day 1; the
# development vector is sized from the number of embryos, length(v).
EmbryoState(v::Array{Int64,1}) = EmbryoState(v, ones(Int64, length(v)), 0, 1);
# Given viabilities and developmental days.
EmbryoState(v::Array{Int64,1}, n::Array{Int64,1}) = EmbryoState(v, n, 0, 1);

# Field-wise equality of two culture states.
function ==(s1::EmbryoState, s2::EmbryoState)
    return (s1.v == s2.v) && (s1.n == s2.n) && (s1.done == s2.done) && (s1.day == s2.day);
end

# Hash consistent with the == above, so that equal states land in the same
# bucket when used as Dict keys or Set members.
Base.hash(s::EmbryoState, h::UInt) = hash(s.v, hash(s.n, hash(s.done, hash(s.day, h))))

# Serialize an EmbryoState into a string key (used as a Dict key).
# Layout: all entries of v, then all entries of n, then done, then day,
# concatenated with no separators.
function embryoStateToString(es::EmbryoState)
    return string(join(es.v), join(es.n), es.done, es.day)
end

# Rebuild an EmbryoState from a string key.
# The string is read as m digits of v, m digits of n, one digit for done and
# one digit for day, where m = length(s)/2 - 1; every character is parsed as
# a single integer.
function stringToEmbryoState(s::ASCIIString)
    m = convert(Int64, length(s)/2) - 1
    viab = zeros(Int64, m)
    devel = zeros(Int64, m)

    for k in 1:m
        viab[k] = parse(Int64, s[k])
        devel[k] = parse(Int64, s[m+k])
    end

    return EmbryoState(viab, devel, parse(Int64, s[end-1]), parse(Int64, s[end]))
end

# Convert an array of string keys into an array of EmbryoStates.
# Each element is decoded by the single-string method of stringToEmbryoState,
# so the two methods cannot drift apart. The result is an untyped (cell)
# array with one EmbryoState per input string, in the same order.
function stringToEmbryoState(sa::Array{ASCIIString,1})
    nStates = length(sa);
    ea = cell(nStates);

    for j in 1:nStates
        ea[j] = stringToEmbryoState(sa[j]);
    end

    return ea;
end

# Overwrite every field of s1 with the contents of s2 and return s1.
# The vectors are copied rather than assigned, so the two states never share
# storage and a later in-place change to one cannot leak into the other.
function Base.copy!(s1::EmbryoState, s2::EmbryoState)
    s1.v = copy(s2.v);
    s1.n = copy(s2.n);
    s1.done = s2.done;
    s1.day = s2.day;
    return s1;
end

# An action in the embryo culture problem.
type EmbryoAction <: Action
    # action code: CC = continue culture, TL = measure time lapse params,
    # B = biopsy (genetic testing), D = discard all, TR = transfer to patient
    a::String
    tr::Array{Int64,1}  # per-embryo transfer selection (used if a == TR)
end

# Embryo culture POMDP scenario.
type EmbryoCulture <: POMDP
    m::Int64                 # number of embryos starting in culture
    vStart::Array{Int64,1}   # starting viability info
    nStart::Array{Int64,1}   # starting cell #
    dayStart::Int64          # culture day at which the scenario starts
    discount::Float64        # discount factor
    pv::Float64              # probability that a viable embryo will continue to be viable
    pn::Array{Float64,1}     # probability that a [nonviable, viable] embryo will have normal cell division
end

# Start at a chosen culture day (with nStart <= dayStart); discount = 1,
# pv = 0.9 and pn = [0.7, 0.95] are filled in as defaults.
function EmbryoCulture(m::Int64, vStart::Array{Int64,1}, nStart::Array{Int64,1},
    dayStart::Int64)
    return EmbryoCulture(m, vStart, nStart, dayStart, 1, 0.9, [0.7, 0.95])
end

# Without a start day, the culture starts at day 1.
function EmbryoCulture(m::Int64, vStart::Array{Int64,1}, nStart::Array{Int64,1})
    return EmbryoCulture(m, vStart, nStart, 1)
end

# Expected number of viable embryos at the time of transfer, given the number
# of currently viable embryos and the culture day.
# The count is scaled by 2^(-(5 - nDay)/4): a factor of 1/2 on day 1, rising
# to 1 on day 5.
# TODO: adjust this function based on actual pv param
function expectedViable(nViable::Int64, nDay::Int64)
    survival = 2.0 ^ ((nDay - 5) / 4)
    return nViable * survival
end


expectedViable (generic function with 1 method)

In [3]:
# state space functions and structures
# Iterator over EmbryoStates.
# minVals / maxVals hold the per-slot bounds for the concatenated [v; n]
# vector: the first m entries bound v, the last m entries bound n.
type EmbryoStateIterator
    m::Int64                  # number of embryos
    startState::EmbryoState   # first state returned by the iteration
    minVals::Vector{Int64}
    maxVals::Vector{Int64}
end

# Iterate over all states of a culture with m embryos: v ranges over 0..1 and
# n over 1..5 for every embryo, starting from all-viable, all day 1.
function EmbryoStateIterator(m::Int64)
    first = EmbryoState(ones(Int64, m), ones(Int64, m), 0, 1)
    lower = vcat(zeros(Int64, m), ones(Int64, m))
    upper = vcat(ones(Int64, m), 5*ones(Int64, m))
    return EmbryoStateIterator(m, first, lower, upper)
end

# Iterate over the states reachable from s: v is bounded above by s.v and
# n is bounded below by s.n.
function EmbryoStateIterator(s::EmbryoState)
    m = length(s.v)
    lower = vcat(zeros(Int64, m), s.n)
    upper = vcat(s.v, 5*ones(Int64, m))
    return EmbryoStateIterator(m, s, lower, upper)
end

# The iteration state is the EmbryoState itself; begin at the stored start state.
Base.start(e::EmbryoStateIterator) = e.startState

# Iteration is finished once the day counter has been advanced to 6.
Base.done(e::EmbryoStateIterator, state) = (state.day == 6)

# Return (state, successor) for the iteration protocol.
# The successor is computed odometer-style over the vector [day; v; n]:
# the last n entry moves fastest (counting up from its min to its max), the
# v entries move next (counting DOWN from their max to their min), and the
# day moves slowest. The `done` flag is carried over unchanged.
function Base.next(e::EmbryoStateIterator, state)
    iterVector = [state.day; state.v ; state.n];
    m = e.m;
    i = 2*m + 1 # 1 = day; 2:m+1 = v; m+2:2m+1 = n;
    
    # increment development days, decrement viability
    # n slots that are already at their max wrap back to their min and carry left
    # (minVals/maxVals are indexed by i-1 because they have no day entry)
    while (i > (m+1)) && (iterVector[i] == e.maxVals[i-1])
        iterVector[i] = e.minVals[i-1];
        i -= 1;
    end
    
    # v slots that are already at their min wrap back to their max and carry left;
    # only entered when the carry from the n slots reached the v section
    while (i < (m+2)) && (i > 1) && (iterVector[i] == e.minVals[i-1])
        iterVector[i] = e.maxVals[i-1];
        i -= 1;
    end
    
    # i now points at the slot that absorbs the carry
    if i == 1
        iterVector[1] += 1; # increment day
    elseif i < (m + 2)
        iterVector[i] -= 1;  
    else
        iterVector[i] += 1;
    end
    
    return (state,EmbryoState(iterVector[2:(m+1)], iterVector[(m+2):(2*m+1)],state.done,iterVector[1]))
end

# State space wrapper; its single field is the iterator over the states.
type EmbryoStateSpace <: AbstractSpace
    states::EmbryoStateIterator
end

# Number of states the space's iterator can still produce: the product of the
# per-slot range sizes (max - min + 1) times the number of remaining culture
# days (6 - starting day).
function numRemainingStates(space::EmbryoStateSpace)
    iter = space.states
    slotSizes = iter.maxVals - iter.minVals + 1
    remainingDays = 6 - iter.startState.day
    return prod(slotSizes) * remainingDays
end

numRemainingStates (generic function with 1 method)

In [4]:
# Build the state space reachable from the problem's starting configuration
# (starting viability, starting development, not done, starting day).
function POMDPs.states(mdp::EmbryoCulture)
    firstState = EmbryoState(mdp.vStart, mdp.nStart, 0, mdp.dayStart)
    return EmbryoStateSpace(EmbryoStateIterator(firstState))
end

# The domain of a state space is its EmbryoStateIterator.
POMDPs.domain(space::EmbryoStateSpace) = space.states

# Sample a state uniformly from the state space, write it into es and return es.
function POMDPs.rand!(es::EmbryoState, space::EmbryoStateSpace)
    maxNumStates = numRemainingStates(space);
    # rand(1:n) is always in 1..n; the earlier ceil(rand()*n) form evaluates to 0
    # whenever rand() returns exactly 0.0, which is not a valid state number
    randStateNum = rand(1:maxNumStates);
    # NOTE(review): relies on start() of Iterators.takenth skipping ahead to the
    # randStateNum-th element of the underlying iterator -- confirm against Iterators.jl
    sp = start(takenth(space.states, randStateNum));
    copy!(es, sp);
    return es;
end

# Number of states reported to the solver: 50^m for m embryos.
# The value depends only on the number of embryos, not on the starting
# configuration (for now, assume nStart <= dayStart).
POMDPs.n_states(mdp::EmbryoCulture) = 50^(mdp.m)

# Number of actions: 4 fixed actions plus 2^m further ones.
POMDPs.n_actions(mdp::EmbryoCulture) = 2^(mdp.m) + 4

# Discount factor stored on the problem.
POMDPs.discount(mdp::EmbryoCulture) = mdp.discount

# Linear index of state s within the grid of states reachable from the
# problem's starting configuration.
# Grid dimensions, in order: one per embryo for viability (counting down from
# vStart), one per embryo for development (counting up from nStart to 5), and
# one for the culture day (counting up from dayStart to 5).
# The development dimensions are sized from nStart so that they match the
# offsets used in currInds below.
# TODO: the `done` flag is not part of the index -- double check against the solver.
function POMDPs.index(mdp::EmbryoCulture, s::EmbryoState)
    dimSizes = [mdp.vStart + 1; 5 - mdp.nStart + 1; 5 - mdp.dayStart + 1];
    currInds = [mdp.vStart - s.v + 1; s.n - mdp.nStart + 1; s.day - mdp.dayStart + 1];
    # sub2ind expects a tuple of dimensions followed by one subscript per dimension
    return sub2ind(tuple(dimSizes...), currInds...);
end

index (generic function with 6 methods)

In [5]:
# test cell only
# builds a 2-embryo culture, draws one random state, then walks the state space
tic()
mdp = EmbryoCulture(2, [1;0], [4;4]);
stateSpace = states(mdp)

# rand! overwrites startingState in place and returns it
startingState = EmbryoState(mdp.vStart,mdp.nStart);
randState = rand!(startingState, stateSpace)

display(randState)

# display the states in iteration order, stopping after at most 60 of them
i = 1;
for es in domain(stateSpace)
    i += 1;
    display(es)
    if i > 60
        break;
    end
end
toc()

EmbryoState([0,0],[5,5],0,3)

EmbryoState([1,0],[4,4],0,1)

EmbryoState([1,0],[4,5],0,1)

EmbryoState([1,0],[5,4],0,1)

EmbryoState([1,0],[5,5],0,1)

EmbryoState([0,0],[4,4],0,1)

EmbryoState([0,0],[4,5],0,1)

EmbryoState([0,0],[5,4],0,1)

EmbryoState([0,0],[5,5],0,1)

EmbryoState([1,0],[4,4],0,2)

EmbryoState([1,0],[4,5],0,2)

EmbryoState([1,0],[5,4],0,2)

EmbryoState([1,0],[5,5],0,2)

EmbryoState([0,0],[4,4],0,2)

EmbryoState([0,0],[4,5],0,2)

EmbryoState([0,0],[5,4],0,2)

EmbryoState([0,0],[5,5],0,2)

EmbryoState([1,0],[4,4],0,3)

EmbryoState([1,0],[4,5],0,3)

EmbryoState([1,0],[5,4],0,3)

EmbryoState([1,0],[5,5],0,3)

EmbryoState([0,0],[4,4],0,3)

EmbryoState([0,0],[4,5],0,3)

EmbryoState([0,0],[5,4],0,3)

EmbryoState([0,0],[5,5],0,3)

EmbryoState([1,0],[4,4],0,4)

EmbryoState([1,0],[4,5],0,4)

EmbryoState([1,0],[5,4],0,4)

EmbryoState([1,0],[5,5],0,4)

EmbryoState([0,0],[4,4],0,4)

EmbryoState([0,0],[4,5],0,4)

EmbryoState([0,0],[5,4],0,4)

EmbryoState([0,0],[5,5],0,4)

EmbryoState([1,0],[4,4],0,5)

EmbryoState([1,0],[4,5],0,5)

EmbryoState([1,0],[5,4],0,5)

EmbryoState([1,0],[5,5],0,5)

EmbryoState([0,0],[4,4],0,5)

EmbryoState([0,0],[4,5],0,5)

EmbryoState([0,0],[5,4],0,5)

EmbryoState([0,0],[5,5],0,5)

elapsed time: 0.945549589 seconds


0.945549589

In [6]:
# Action space: an explicit list of every available EmbryoAction.
type EmbryoActionSpace <: AbstractSpace
    actions::Array{EmbryoAction,1}
end

# Enumerate every action of the problem.
# Slots 1-4 are the fixed actions CC, TL, B and D (with an empty transfer
# vector); the remaining 2^m slots are TR actions, one per subset of embryos.
# The k-th TR action uses the low m bits of k-1 (most significant first) as
# the per-embryo transfer mask.
function POMDPs.actions(mdp::EmbryoCulture)
    nSubsets = 2^(mdp.m)
    acts = Array(EmbryoAction, nSubsets + 4)

    acts[1] = EmbryoAction("CC", [])   # continue culture only
    acts[2] = EmbryoAction("TL", [])   # collect cell cycle params
    acts[3] = EmbryoAction("B", [])    # biopsy all
    acts[4] = EmbryoAction("D", [])    # discard all

    for k in 1:nSubsets
        mask = Array(Bool, mdp.m)
        for j in 1:mdp.m
            # embryo j corresponds to bit (m - j) of k-1
            mask[j] = (((k - 1) >> (mdp.m - j)) & 1) == 1
        end
        acts[k+4] = EmbryoAction("TR", mask)
    end

    return EmbryoActionSpace(acts)
end

# State-dependent action query: the same action space applies in every state,
# so the supplied (or freshly built) space is returned unchanged.
function POMDPs.actions(mdp::EmbryoCulture, s::EmbryoState, as::EmbryoActionSpace=actions(mdp))
    return as
end

# The domain of an action space is its array of actions.
POMDPs.domain(space::EmbryoActionSpace) = space.actions

# Pick one action uniformly at random from the action space, write it into
# ea and return ea. Every entry of space.actions is equally likely, so each
# individual TR subset is as likely as CC, TL, B or D.
function POMDPs.rand!(ea::EmbryoAction, space::EmbryoActionSpace)
    # size the draw from the space itself rather than from a global problem object
    ap = space.actions[rand(1:length(space.actions))];
    ea.a = ap.a;
    # copy so that ea does not share its transfer vector with the action space
    ea.tr = copy(ap.tr);
    return ea;
end


rand! (generic function with 27 methods)

In [7]:
# transition distribution stuff
# Distribution over next states, stored as a Dict from the string encoding
# of a next state s' to its probability p(s').
type EmbryoTransitionDistribution <: AbstractDistribution
    tD::Dict{ASCIIString, Float64}
end

# Preallocate a transition distribution; it starts with no entries and is
# filled in later.
function POMDPs.create_transition_distribution(mdp::EmbryoCulture)
    return EmbryoTransitionDistribution(Dict{ASCIIString, Float64}())
end

# All next states stored in the distribution, decoded from their string keys.
function POMDPs.domain(d::EmbryoTransitionDistribution)
    stateKeys = collect(keys(d.tD))
    return stringToEmbryoState(stateKeys)
end

# Probability p(s') of state s under the distribution; states without an
# entry have probability 0.0.
function POMDPs.pdf(d::EmbryoTransitionDistribution, s::EmbryoState)
    return get(d.tD, embryoStateToString(s), 0.0)
end

# Sample a next state according to the stored probabilities, write it into s
# and return s.
function POMDPs.rand!(d::EmbryoTransitionDistribution, s::EmbryoState)
    stateKeys = collect(keys(d.tD))
    probs = collect(values(d.tD))

    # draw an index into stateKeys with the matching probabilities
    pick = rand(Categorical(probs))
    copy!(s, stringToEmbryoState(stateKeys[pick]))
    return s
end



rand! (generic function with 28 methods)

In [8]:
# transition model
# Fill d with the distribution over next states for taking action a in state s
# and return d.
#
# Terminal case: if the culture day is past 4, the state is already done, or
# the action is TR, D or B, the only next state is s itself with done = 1.
#
# Otherwise the day advances by one and, independently for every embryo:
#   - a viable embryo (v > 0) stays viable with probability pv and loses one
#     unit of viability with probability 1 - pv;
#   - an embryo below developmental day 5 advances one day with probability
#     pn[v + 1] (indexed by its CURRENT viability) and stays put otherwise.
# Every combination of these binary outcomes becomes one entry of d.
function POMDPs.transition(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction,
    d::EmbryoTransitionDistribution=create_transition_distribution(mdp))

    # if we're already "done" or if the action leads to a terminal state
    if (s.day > 4) || (s.done > 0) || (a.a == "TR") || (a.a == "D") || (a.a == "B")
        # can not transition out of done state
        d.tD = Dict{ASCIIString, Float64}();
        d.tD[embryoStateToString(EmbryoState(s.v, s.n, 1, s.day))] = 1.0;
        return d;
    end

    pv = mdp.pv; # p(v(t+1) = 1 | v(t) = 1)
    pn = mdp.pn; # p(n(t+1) = n(t) + 1) for [nonviable, viable] embryo

    # which entries can change, and how many binary outcomes that gives
    canDrop = s.v .> 0;
    canAdvance = s.n .< 5;
    nChange = sum(canDrop) + sum(canAdvance);
    numNeighbors = 2^nChange; # number of possible next states (<= 4^m)

    # re-init dictionary w/ s' => p(s')
    d.tD = Dict{ASCIIString, Float64}();

    for i in 1:numNeighbors

        vNext = copy(s.v);
        nNext = copy(s.n);
        pNext = 1.0;

        # the bits of i-1, most significant first, choose the outcome of each
        # changeable entry: viability entries first, then development entries
        outcomes = i - 1;
        bitPos = nChange - 1;

        for j in 1:mdp.m
            if canDrop[j]
                if ((outcomes >> bitPos) & 1) == 1 # v(t) = 1 -> v(t+1) = 0
                    vNext[j] -= 1;
                    pNext *= (1 - pv);
                else # v(t) = 1 -> v(t+1) = 1
                    pNext *= pv;
                end
                bitPos -= 1;
            end
        end

        for j in 1:mdp.m
            if canAdvance[j]
                pAdvance = pn[s.v[j] + 1];
                if ((outcomes >> bitPos) & 1) == 1 # n(t+1) = n(t) + 1
                    nNext[j] += 1;
                    pNext *= pAdvance;
                else # n(t+1) = n(t)
                    pNext *= (1 - pAdvance);
                end
                bitPos -= 1;
            end
        end

        d.tD[embryoStateToString(EmbryoState(vNext, nNext, s.done, s.day + 1))] = pNext;

    end

    return d;

end


transition (generic function with 9 methods)

In [9]:
# reward model
# Reward for taking action a in state s; the culture day is taken from s.day.
#   done state : 0
#   TR         : -20 if the expected number of viable transferred embryos is
#                below 1, otherwise 56 - 16*floor(expected)
#   D          : -2 if the expected number of viable embryos is below 1,
#                otherwise -20
#   B          : always -40
#   TL         : -40 after day 3, otherwise -2
#   otherwise (continue culture): -40 after day 4, otherwise -1
function POMDPs.reward(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction)

    nDay = s.day;

    # if we are in a terminal state, no reward
    if (s.done == 1)
        return 0;
    # if we are transferring some of the embryos
    elseif a.a == "TR"

        # count only the embryos that are both viable and selected for transfer
        nViable = sum(s.v .* a.tr);
        ev = expectedViable(nViable, nDay);

        # big negative reward if no viable embryos transferred
        # big positive reward if one viable embryo transferred
        # diminishing positive rewards for >1 viable embryo transferred
        if ev < 1
            return -20; # transfer 0 viable embryos
        else
            return 56 - 16*floor(ev); # transfer >= 1 viable embryo
        end
    elseif a.a == "D"

        nViable = sum(s.v);
        ev = expectedViable(nViable, nDay);

        # slight negative reward for discarding all (meaning no viable embryos available to transfer)
        # big negative reward for discarding viable embryos
        if ev < 1
            return -2; # none viable, discard all
        else
            return -20; # discard >= 1 viable embryo
        end
    elseif a.a == "B"
        # biopsy is penalized regardless of day or viability
        return -40;
    elseif a.a == "TL"
        if nDay > 3
            return -40;
        else
            return -2; # slight negative reward for doing test
        end
    else
        # do nothing / continue culture
        if nDay > 4
            return -40;
        else
            return -1; # slight negative reward for continuing culture
        end
    end
end

reward (generic function with 6 methods)

In [10]:
# Observation of the culture.
type EmbryoObservation <: Observation
    oN::Array{Int64,1}  # developmental stage reached, one entry per embryo
    oD::Int64           # day in culture
end

# Overwrite every field of o1 with the contents of o2 and return o1.
# The stage vector is copied rather than assigned, so the two observations
# never share storage.
function Base.copy!(o1::EmbryoObservation, o2::EmbryoObservation)
    o1.oN = copy(o2.oN);
    o1.oD = o2.oD;
    return o1;
end

# Serialize an EmbryoObservation into a string key (used as a Dict key).
# Layout: all entries of oN followed by oD, concatenated with no separators.
function embryoObsToString(obs::EmbryoObservation)
    return string(join(obs.oN), obs.oD)
end

# Rebuild an EmbryoObservation from a string key.
# All characters but the last are parsed as single-digit developmental
# stages; the last character is parsed as the day.
function stringToEmbryoObs(s::ASCIIString)
    nEmbryos = convert(Int64, length(s)) - 1
    stages = zeros(Int64, nEmbryos)

    for k in 1:nEmbryos
        stages[k] = parse(Int64, s[k])
    end

    return EmbryoObservation(stages, parse(Int64, s[end]))
end

# Convert an array of string keys into an array of EmbryoObservations.
# Each element is decoded by the single-string method of stringToEmbryoObs;
# the result is an untyped (cell) array in the same order as the input.
function stringToEmbryoObs(oa::Array{ASCIIString,1})
    count = length(oa)
    decoded = cell(count)

    for j in 1:count
        decoded[j] = stringToEmbryoObs(oa[j])
    end

    return decoded
end


# Iterator over EmbryoObservations.
type EmbryoObservationIterator
    m::Int64                 # number of embryos
    startDay::Int64          # day of the first observation produced
    minVals::Vector{Int64}   # per-embryo lower bound of the observed stage
    maxVals::Vector{Int64}   # per-embryo upper bound of the observed stage
end

# Iterate over all observations of a problem: stages range over 1..5 for each
# of the mdp.m embryos, starting at the problem's start day.
function EmbryoObservationIterator(mdp::EmbryoCulture)
    return EmbryoObservationIterator(mdp.m, mdp.dayStart,
        ones(Int64, mdp.m), 5*ones(Int64, mdp.m))
end

# Same stage range, sized and dated from a state instead of a problem.
function EmbryoObservationIterator(s::EmbryoState)
    nEmbryos = length(s.n)
    return EmbryoObservationIterator(nEmbryos, s.day,
        ones(Int64, nEmbryos), 5*ones(Int64, nEmbryos))
end

# First observation: every embryo at stage 1 on the iterator's start day.
Base.start(e::EmbryoObservationIterator) = EmbryoObservation(ones(Int64, e.m), e.startDay)

# Iteration is finished once the day counter has been advanced to 6.
Base.done(e::EmbryoObservationIterator, obs) = (obs.oD == 6)

# Return (obs, successor) for the iteration protocol.
# The successor is computed odometer-style over [day; stages]: the last
# stage moves fastest, and when every stage is at its max the day advances.
function Base.next(e::EmbryoObservationIterator, obs)
    counter = vcat(obs.oD, obs.oN)   # slot 1 = day; slots 2:m+1 = stages
    pos = e.m + 1

    # stages already at their max wrap to their min and carry to the left
    # (minVals/maxVals are indexed by pos-1 because they have no day entry)
    while (pos > 1) && (counter[pos] == e.maxVals[pos-1])
        counter[pos] = e.minVals[pos-1]
        pos -= 1
    end

    # the slot that absorbs the carry (a stage, or the day when pos == 1)
    counter[pos] += 1

    successor = EmbryoObservation(counter[2:(e.m+1)], counter[1])
    return (obs, successor)
end

# Observation space wrapper; its single field is the observation iterator.
type EmbryoObservationSpace <: AbstractSpace
    obs::EmbryoObservationIterator
end

# Observation space of the problem, built from an iterator over all of its
# observations.
function POMDPs.observations(mdp::EmbryoCulture)
    iter = EmbryoObservationIterator(mdp)
    return EmbryoObservationSpace(iter)
end
    
# State-dependent observation query: the supplied space is returned unchanged.
function POMDPs.observations(mdp::EmbryoCulture, s::EmbryoState, obs::EmbryoObservationSpace)
    return obs
end

# The domain of an observation space is its observation iterator.
function POMDPs.domain(space::EmbryoObservationSpace)
    return space.obs
end

In [11]:
# Distribution over observations, stored as a Dict from the string encoding
# of an observation o to its probability p(o).
type EmbryoObservationDistribution <: AbstractDistribution
    obsD::Dict{ASCIIString, Float64}
end

# Preallocate an observation distribution; it starts with no entries.
function POMDPs.create_observation_distribution(mdp::EmbryoCulture)
    emptyTable = Dict{ASCIIString, Float64}()
    return EmbryoObservationDistribution(emptyTable)
end

# All observations stored in the distribution, decoded from their string keys.
function POMDPs.domain(d::EmbryoObservationDistribution)
    obsKeys = collect(keys(d.obsD))
    return stringToEmbryoObs(obsKeys)
end

# Probability of observation o under the distribution; observations without
# an entry have probability 0.0.
function POMDPs.pdf(d::EmbryoObservationDistribution, o::EmbryoObservation)
    return get(d.obsD, embryoObsToString(o), 0.0)
end

# Sample an observation according to the stored probabilities, write it into
# o and return o.
function POMDPs.rand!(d::EmbryoObservationDistribution, o::EmbryoObservation)
    obsKeys = collect(keys(d.obsD))
    probs = collect(values(d.obsD))

    # draw an index into obsKeys with the matching probabilities
    pick = rand(Categorical(probs))
    copy!(o, stringToEmbryoObs(obsKeys[pick]))
    return o
end

# observation model
# Fill d with the observation distribution for state s and return d.
# For now the observation is deterministic: exactly s.n and s.day are
# observed with probability 1. The action a does not affect the result.
function POMDPs.observation(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction, 
    d::EmbryoObservationDistribution=create_observation_distribution(mdp))

    # build the single-entry table explicitly as a Dict of the field's type
    d.obsD = Dict{ASCIIString, Float64}();
    d.obsD[embryoObsToString(EmbryoObservation(s.n, s.day))] = 1.0;
    return d;
end

# Number of observations: 5 possible stages for each embryo, times the number
# of culture days from dayStart through day 5.
function POMDPs.n_observations(mdp::EmbryoCulture)
    stageCombos = 5^length(mdp.nStart)
    remainingDays = 5 - mdp.dayStart + 1   # extra dim for day
    return stageCombos * remainingDays
end

n_observations (generic function with 4 methods)

In [12]:
# finally ... belief stuff
using POMDPToolbox

# Belief constructors. Both size a DiscreteBelief by the state count of a
# culture that matches mdp except that every embryo is taken to be viable
# at the start.
function POMDPs.create_belief(mdp::EmbryoCulture)
    allViable = EmbryoCulture(mdp.m, ones(mdp.vStart), mdp.nStart, mdp.dayStart)
    return DiscreteBelief(n_states(allViable))
end

function POMDPs.initial_belief(mdp::EmbryoCulture)
    allViable = EmbryoCulture(mdp.m, ones(mdp.vStart), mdp.nStart, mdp.dayStart)
    return DiscreteBelief(n_states(allViable))
end


In [13]:
using QMDP

# initialize the solver
# key-word args are the maximum number of iterations the solver will run for, and the Bellman tolerance
solver = QMDPSolver(max_iterations=10, tolerance=1) 
# two embryos, both viable and at developmental day 1, culture starting on day 1
# (this rebinds the global `mdp` used by later cells)
mdp = EmbryoCulture(2, [1;1], [1;1], 1);

# initialize the QMDP policy
qmdp_policy = create_policy(solver, mdp)

# run the solver
solve(solver, mdp, qmdp_policy, verbose=true)

Iteration : 1, residual: 40.0, iteration run-time: 35.053637607, total run-time: 35.053637607
Iteration : 2, residual: 35.80000000000001, iteration run-time: 35.875828792, total run-time: 70.929466399
Iteration : 3, residual: 32.120000000000005, iteration run-time: 36.508056991, total run-time: 107.43752339
Iteration : 4, residual: 27.90800000000001, iteration run-time: 37.524894955, total run-time: 144.962418345
Iteration : 5, residual: 24.117200000000008, iteration run-time: 38.308252316, total run-time: 183.270670661
Iteration : 6, residual: 0.0, iteration run-time: 35.737039733, total run-time: 219.00771039400001


QMDPPolicy(2500x8 Array{Float64,2}:
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
 36.0434  35.0434  -40.0  -20.0  -20.0  -20.0  -20.0  40.0
  ⋮                                       ⋮               
  0.0      0.0       0.0    0.0    0.0    0.0    0.0   0.0
  0.0      0.0       0.0    0.0    0.0    0.0    0.0   0.0
  0.0      0.0      

In [None]:
# simulate actions based on policy
# NOTE(review): uses the globals `mdp` and `qmdp_policy` left by an earlier cell;
# `mdp` may have been built with different starting viabilities than vInit -- confirm

vInit = [1;0];
nInit = [1;1];
#pomdp = EmbryoCulture(length(vInit), vInit, nInit, 1);

# start at day 1 with embryo 1 viable and embryo 2 not viable
s = EmbryoState(vInit, nInit, 0, 1);
o = EmbryoObservation(nInit, 1);
b = initial_belief(mdp)
updater = DiscreteUpdater(mdp) # this comes from POMDPToolbox

rtot = 0.0
# run the simulation for 5 days max
for i = 1:5
    # get the action from our QMDP policy
    a = action(qmdp_policy, b) # the QMDP action function returns the POMDP action not its index like the SARSOP action function
    # compute the reward
    r = reward(mdp, s, a)
    rtot += r
    
    println("Time step $i")
    println("Taking action: $(a), got reward: $(r)")
    
    # transition the system state (rand! overwrites s in place)
    trans_dist = transition(mdp, s, a)
    rand!(trans_dist, s)

    # sample a new observation (rand! overwrites o in place)
    obs_dist = observation(mdp, s, a)
    rand!(obs_dist, o)
    
    # update the belief
    b = update(updater, b, a, o)
    
    println("Saw observation: $(o)\n")

end
println("Total reward: $rtot")


In [None]:
# try SARSOP ...
using SARSOP

# single viable embryo at developmental day 1, culture starting on day 1
# (this rebinds the global `mdp`)
mdp = EmbryoCulture(1, [1], [1], 1);
policy = POMDPPolicy("embryo.policy")

# create the .pomdpx file, this is the format which the SARSOP backend reads in
pomdpfile = POMDPFile(mdp, "embryo.pomdpx") # must end in .pomdpx


In [None]:
# initialize the solver
solver = SARSOPSolver()
# run the solve function on the `pomdpfile` and `policy` created in the previous cell
solve(solver, pomdpfile, policy)