In [38]:
# drastically reduced size of state,action,observation spaces to make problem more tractable
# s : 50^n -> 5n^3
# a : 2^n -> n
# o : 25^n -> 5n^2

using POMDPs
using POMDPDistributions
using Distributions
using SARSOP
using POMDPModels
using Iterators
using POMDPToolbox
using QMDP
using POMDPXFile

# Number of embryos in culture; read by the model-construction and simulation cells below.
nEmbryos = 3;

In [39]:
##################################################
######### EmbryoCulture POMDP ####################
##################################################

# POMDP problem definition for culturing a batch of embryos.
type EmbryoCulture <: POMDP
    n::Int64               # number of embryos starting in culture
    dayStart::Int64        # presumably the first culture day; not read by the code visible in this file
    discount::Float64      # discount factor returned by POMDPs.discount
    pv::Float64            # per-day probability that a viable embryo stays viable
    pn::Array{Float64,1}   # per-day probability that a [nonviable, viable] embryo keeps its morphology
end

# Convenience constructor: n embryos, dayStart = 1, discount = 1, pv = 0.9, pn = [0.5, 0.95].
function EmbryoCulture(n::Int64)
    return EmbryoCulture(n, 1, 1, 0.9, [0.5, 0.95])
end

# Draw ONE random sample of how many of `nv` viable embryos are still viable on day 5.
# Each remaining culture day (from `day` up to day 4) thins the survivors with a
# Binomial(survivors, mdp.pv) draw. Despite the name, the result is a sample, not an
# expectation. With pv = 0.9 a day-1 embryo survives to day 5 with probability
# 0.9^4 (about 0.66). If `day` is already 5 or later, `nv` is returned unchanged.
function expectedViable(mdp::EmbryoCulture, nv::Int64, day::Int64)
    survivors = nv
    for cultureDay in day:4
        survivors = rand(Binomial(survivors, mdp.pv))
    end
    return survivors
end

# Discount factor of the problem, taken straight from the model struct.
POMDPs.discount(mdp::EmbryoCulture) = mdp.discount

discount (generic function with 5 methods)

In [40]:
##################################################
######### EmbryoState definitions ################
##################################################

# Hidden state of the culture.
# Morphology vectors have three entries counting embryos with
# {poor, fair, good} morphology, in that order.
type EmbryoState <: State
    n::Int64                # total number of embryos in culture
    v::Int64                # number of currently viable embryos
    mV::Array{Int64,1}      # morphology counts for the viable embryos
    mNV::Array{Int64,1}     # morphology counts for the nonviable embryos
    day::Int64              # culture day
end

# Sample a morphology count vector {poor, fair, good} for `nEmbryos` embryos.
# `viable` (0 or 1) selects the success probability theta[viable + 1]. Each embryo
# gets two Bernoulli draws, and the number of successes (0, 1 or 2) picks the
# morphology bin 1, 2 or 3 respectively.
function generateMorphology(nEmbryos::Int64, viable::Int64, theta::Array{Float64,1}=[.75;.95])
    coin = Bernoulli(theta[viable + 1])
    counts = zeros(Int64, 3)

    for embryo = 1:nEmbryos
        bin = sum(rand(coin, 2)) + 1
        counts[bin] += 1
    end

    return counts
end

# Day-1 state with `v` viable embryos out of `n`; morphologies are sampled randomly
# with generateMorphology's default theta.
function EmbryoState(n::Int64, v::Int64)
    return EmbryoState(n, v, generateMorphology(v, 1), generateMorphology(n - v, 0), 1)
end

# Day-1 state with explicitly supplied morphology counts.
function EmbryoState(n::Int64, v::Int64, mV::Array{Int64,1}, mNV::Array{Int64,1})
    return EmbryoState(n, v, mV, mNV, 1)
end

# Structural equality: two states are equal when every field matches
# (morphology vectors are compared element-wise).
function ==(s1::EmbryoState, s2::EmbryoState)
    return (s1.n == s2.n) && (s1.v == s2.v) && (s1.mV == s2.mV) && 
        (s1.mNV == s2.mNV) && (s1.day == s2.day);
end

# Hash over the same fields that `==` compares, so that equal states also hash
# equally when they are used as Dict keys or Set members.
function Base.hash(s::EmbryoState, h::UInt)
    return hash(s.day, hash(s.mNV, hash(s.mV, hash(s.v, hash(s.n, h)))));
end

# Encode a state as a Dict key: n, v, mV[1..3], mNV[1..3] and day are written back
# to back with no separator.
# NOTE(review): stringToEmbryoState reads one character per field, so the encoding
# only round-trips while every field is a single digit (0-9).
function embryoStateToString(s::EmbryoState)
    fields = (s.n, s.v, s.mV[1], s.mV[2], s.mV[3], s.mNV[1], s.mNV[2], s.mNV[3], s.day)
    return string(fields...)
end

# Decode a key produced by embryoStateToString back into an EmbryoState.
# Characters 1-9 are read one digit per field: n, v, mV (3 digits), mNV (3 digits), day.
function stringToEmbryoState(s::ASCIIString)
    digit = Int64[parse(Int64, s[k]) for k in 1:9]
    return EmbryoState(digit[1], digit[2], digit[3:5], digit[6:8], digit[9])
end

# Decode an array of keys produced by embryoStateToString into an array of states.
# Each element is decoded by the single-string method, so there is one parser to
# maintain, and the result is a concretely typed Array{EmbryoState,1}.
function stringToEmbryoState(sa::Array{ASCIIString,1})
    return EmbryoState[stringToEmbryoState(s) for s in sa];
end

# Overwrite every field of `s1` with the values of `s2` and return `s1`.
# The morphology vectors are copied rather than shared, so later in-place changes
# to one state's counts cannot leak into the other.
function Base.copy!(s1::EmbryoState, s2::EmbryoState)
    s1.n = s2.n;
    s1.v = s2.v;
    s1.mV = copy(s2.mV);
    s1.mNV = copy(s2.mNV);
    s1.day = s2.day;
    return s1;
end

copy! (generic function with 27 methods)

In [41]:
##################################################
######### EmbryoStateSpace etc. ##################
##################################################
# Iterator over all states of a culture with `n` embryos. minVals/maxVals bound the
# eight odometer digits [day, v, mV(3), mNV(3)] that Base.next steps through.
type EmbryoStateIterator
    n::Int64
    minVals::Array{Int64,1}
    maxVals::Array{Int64,1}
end

# Day runs 1..6; every count digit runs 0..n.
function EmbryoStateIterator(n::Int64)
    lower = [1; zeros(Int64, 7)]
    upper = [6; fill(n, 7)]
    return EmbryoStateIterator(n, lower, upper)
end

# First state of the enumeration: day 1, no viable embryos, all nonviable embryos
# in the third (good) morphology bin.
Base.start(e::EmbryoStateIterator) = EmbryoState(e.n, 0, zeros(Int64, 3), [0, 0, e.n], 1)

# Iteration ends once the odometer has rolled the day digit over to 7.
Base.done(e::EmbryoStateIterator, state) = (state.day == 7)

# Return (current state, following state). The following state is found by stepping
# an eight-digit odometer [day, v, mV(3), mNV(3)] (last digit fastest) until the digits
# are consistent: mV sums to v and mNV sums to n - v. The day digit is never wrapped,
# so it eventually reaches 7, which Base.done treats as the end.
function Base.next(e::EmbryoStateIterator, state)

    odo = [state.day; state.v; state.mV; state.mNV]
    found = false

    while !found
        pos = 8

        # carry: reset every saturated trailing digit, then bump the next one
        while (pos > 1) && (odo[pos] == e.maxVals[pos])
            odo[pos] = e.minVals[pos]
            pos -= 1
        end
        odo[pos] += 1

        viableOk = (sum(odo[3:5]) == odo[2])
        nonviableOk = (sum(odo[6:8]) == (state.n - odo[2]))
        found = viableOk && nonviableOk
    end

    successor = EmbryoState(state.n, odo[2], odo[3:5], odo[6:8], odo[1])
    return (state, successor)
end

# State space wrapper; its single field is the iterator that enumerates the states.
type EmbryoStateSpace <: AbstractSpace
    states::EmbryoStateIterator
end

# Build the state space (an EmbryoStateSpace) for a culture of mdp.n embryos.
POMDPs.states(mdp::EmbryoCulture) = EmbryoStateSpace(EmbryoStateIterator(mdp.n))

# The iterable domain of a state space is its EmbryoStateIterator.
POMDPs.domain(space::EmbryoStateSpace) = space.states

# Count the states in the state space.
# For v viable embryos there are (v+1)(v+2)/2 ways to spread them over three morphology
# bins, and likewise for the n - v nonviable ones; the total is multiplied by the 6
# culture days. Everything is done in exact integer arithmetic (each product of two
# consecutive integers is even, so the divisions are exact).
function POMDPs.n_states(mdp::EmbryoCulture)
    perDay = zero(Int64);
    for v in 0:mdp.n
        nv = mdp.n - v;
        perDay += div((v+1)*(v+2), 2) * div((nv+1)*(nv+2), 2);
    end
    return 6*perDay;
end

# Overwrite `es` with a state drawn uniformly from the state space and return it.
# The number of states is computed with the same closed form as n_states (in integer
# arithmetic), a 1-based position is drawn uniformly, and the state iterator is
# advanced explicitly to that position.
function POMDPs.rand!(es::EmbryoState, space::EmbryoStateSpace)

    n = space.states.n;
    total = zero(Int64);
    for v in 0:n
        nv = n - v;
        total += div((v+1)*(v+2), 2) * div((nv+1)*(nv+2), 2);
    end
    total *= 6;

    target = rand(1:total);

    # walk the iterator: after target - 1 steps `sp` is the target-th state
    iter = space.states;
    sp = start(iter);
    for k in 2:target
        (prev, sp) = next(iter, sp);
    end

    copy!(es, sp);
    return es;
end

rand! (generic function with 29 methods)

In [42]:
##################################################
######### EmbryoAction definitions ###############
##################################################

# One action.
#   a  : "CC" = continue culture, "B" = biopsy (genetic testing), "D" = discard all,
#        "TR" = transfer embryos to the patient
#   tr : number of embryos to transfer (used with "TR"); embryos are taken in order of
#        morphology, choosing at random among morphologically equal ones
type EmbryoAction <: Action
    a::String
    tr::Int64
end

# Action space: simply the list of available actions.
type EmbryoActionSpace <: AbstractSpace
    actions::Array{EmbryoAction, 1}
end

# Build the action space: continue culture, biopsy all, discard all, followed by
# "transfer i embryos" for i = 1..mdp.n (mdp.n + 3 actions in total, in that order).
function POMDPs.actions(mdp::EmbryoCulture)
    acts = EmbryoAction[
        EmbryoAction("CC", 0),  # continue culture only
        EmbryoAction("B", 0),   # biopsy all
        EmbryoAction("D", 0)    # discard all
    ]

    for nTransfer in 1:mdp.n
        push!(acts, EmbryoAction("TR", nTransfer))
    end

    return EmbryoActionSpace(acts)
end

# State-dependent action query: every action is available in every state, so the
# supplied (or freshly built) action space is returned unchanged.
function POMDPs.actions(mdp::EmbryoCulture, s::EmbryoState, as::EmbryoActionSpace=actions(mdp))
    return as
end

# The iterable domain of an action space is its action array.
POMDPs.domain(space::EmbryoActionSpace) = space.actions

# Overwrite `ea` with an action drawn uniformly from the action space and return it.
function POMDPs.rand!(ea::EmbryoAction, space::EmbryoActionSpace)
    pick = rand(1:length(space.actions))
    chosen = space.actions[pick]
    ea.a = chosen.a
    ea.tr = chosen.tr
    return ea
end

# Number of actions: CC, B, D plus one transfer action per embryo count 1..n.
POMDPs.n_actions(mdp::EmbryoCulture) = 3 + (mdp.n)


n_actions (generic function with 5 methods)

In [43]:
##################################################
############## Transition Function ###############
##################################################

# Transition distribution: maps the string key of a next state
# (see embryoStateToString) to its probability p(s').
type EmbryoTransitionDistribution <: AbstractDistribution
    tD::Dict{ASCIIString, Float64}
end

# Preallocate an empty transition distribution; POMDPs.transition fills it in.
function POMDPs.create_transition_distribution(mdp::EmbryoCulture)
    return EmbryoTransitionDistribution(Dict{ASCIIString, Float64}())
end

# All next states stored in the distribution, decoded from their string keys.
function POMDPs.domain(d::EmbryoTransitionDistribution)
    stateKeys = collect(keys(d.tD))
    return stringToEmbryoState(stateKeys)
end

# Probability p(s') of next state `s`; states absent from the distribution have
# probability 0.0. A single `get` lookup replaces the haskey/index pair, and the
# unused array of all probabilities is no longer built on every call.
function POMDPs.pdf(d::EmbryoTransitionDistribution, s::EmbryoState)
    return get(d.tD, embryoStateToString(s), 0.0);
end

# Sample a next state from the distribution, write it into `s`, and return `s`.
# The stored weights are renormalised first so they form a proper probability vector.
function POMDPs.rand!(d::EmbryoTransitionDistribution, s::EmbryoState)
    candidates = collect(keys(d.tD))
    weights = collect(values(d.tD))

    sampler = Categorical(weights / sum(weights))
    chosenKey = candidates[rand(sampler)]

    copy!(s, stringToEmbryoState(chosenKey))
    return s
end

# Probability that morphology counts `mOld` become `mNew` in one step, when each embryo
# either keeps its morphology (probability `p`) or drops exactly one bin (probability
# 1 - p). Embryos already in the lowest bin always stay there.
#
#   mNew, mOld : count vectors ordered from lowest to highest morphology bin
#   p          : probability that an embryo keeps its morphology
#
# Returns 0.0 when `mNew` would need some embryo to move up a bin, i.e. when any
# cumulative count of `mNew` is smaller than the matching cumulative count of `mOld`.
function calcP(mNew::Array{Int64,1},mOld::Array{Int64,1},p)

    # gain[i] = number of embryos that must have dropped from bin i+1 into bin i
    gain = cumsum(mNew) - cumsum(mOld);
    if any(g -> g < 0, gain)
        return 0.0;
    end

    # start from one(p) so the accumulator has its final type from the first line on
    pOut = one(p);
    remaining = copy(mOld); # embryos still in each bin once the droppers have left

    for i in 1:(length(mNew)-1)
        nStay = remaining[i];
        nDrop = mNew[i] - nStay;
        remaining[i+1] -= nDrop;
        pOut *= (1-p)^(nDrop);
        pOut *= binomial(mOld[i+1], nDrop); # which embryos of bin i+1 dropped
        if i > 1
            pOut *= (p^nStay); # lowest bin stays put with certainty, hence i > 1
        end
    end

    # whatever is left in the top bin kept its morphology
    pOut *= p^(remaining[end]);
    return pOut;
end

# Transition model: fill `d` with p(s' | s, a) and return it.
#
# * If the culture is already past day 4, or the action is TR, D or B, the only
#   successor is the same embryo configuration moved to day 6 (the "done" day).
# * Otherwise (continue culture) every configuration with vNew <= s.v viable embryos
#   on day s.day + 1 is enumerated with a seven-digit odometer [v, mV(3), mNV(3)]:
#     1. the s.v - vNew embryos that lose viability are taken out of the viable
#        counts starting from the lowest morphology bin,
#     2. they are added to the nonviable counts in the same bins,
#     3. configurations that would need more than a one-bin drop are stored with
#        probability 0,
#     4. the rest get Binomial(s.v, pv) for the viability change times calcP for the
#        viable and the nonviable morphology change.
# * If nothing was enumerated, the state maps to itself with probability 1.
function POMDPs.transition(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction,
    d::EmbryoTransitionDistribution=create_transition_distribution(mdp))

    # terminal handling: can not transition out of the done state
    if (s.day > 4) || (a.a == "TR") || (a.a == "D") || (a.a == "B")
        doneKey = embryoStateToString(EmbryoState(s.n,s.v,s.mV,s.mNV,6));
        d.tD = Dict{ASCIIString, Float64}(doneKey => 1.0);
        return d;
    end

    pv = mdp.pv; # p(v(t+1) = 1 | v(t) = 1)
    pn = mdp.pn; # p(n(t+1) = n(t)) for [nonviable, viable] embryo

    # odometer digits: [vNew, mvNew(3), mnvNew(3)], each count digit in 0..s.n
    iterVector = zeros(Int64,7);
    minVals = zeros(Int64,7);
    maxVals = s.n * ones(Int64,7);

    # re-init dictionary w/ s' => p(s')
    d.tD = Dict{ASCIIString, Float64}();

    # advance iterVector until we go through ALL valid v,mV,mNV combos
    while iterVector[1] <= s.v

        isValid = false;
        while !isValid && (iterVector[1] <= s.v)
            i = 7;

            # carry: reset saturated trailing digits, then bump the next one
            while (i > 1) && (iterVector[i] == maxVals[i])
                iterVector[i] = minVals[i];
                i -= 1;
            end

            iterVector[i] += 1;
            isValid = (sum(iterVector[2:4]) == iterVector[1]) &&
                (sum(iterVector[5:7]) == (s.n - iterVector[1]));
        end

        # the inner loop stops either on a valid combo or once vNew has passed s.v
        vNew = iterVector[1];

        if vNew <= s.v

            mvNew = iterVector[2:4];
            mnvNew = iterVector[5:7];
            sNew = EmbryoState(s.n,vNew,mvNew,mnvNew,s.day+1);
            sNewKey = embryoStateToString(sNew);

            # 1. subtract newly nonviable embryos from s.mV, lowest morphology first
            mvSub = copy(s.mV);
            eRM = [0;0;0]; # embryos removed from mvSub, per bin

            for m in 1:3
                numChange = min(s.v - vNew - sum(eRM), mvSub[m]);
                eRM[m] += numChange;
                mvSub[m] -= numChange;
            end

            # 2. add embryos of same morphology to mNV
            mnvAdd = s.mNV + eRM;
            maxV = [mvSub[1] + mvSub[2]; mvSub[2] + mvSub[3]; mvSub[3]];
            maxNV = [mnvAdd[1] + mnvAdd[2]; mnvAdd[2] + mnvAdd[3]; mnvAdd[3]];

            if any(mvNew .> maxV) || any(mnvNew .> maxNV)
                # 3. if state is unreachable, write p = 0
                d.tD[sNewKey] = 0.0;
            else
                # 4. if it is reachable, compute probability and write to dict
                pNew = binomial(s.v,vNew)*(pv^(vNew))*((1-pv)^(s.v-vNew)); # p(vNew(t+1)|s.v(t))
                pNew *= calcP(mvNew,mvSub,pn[2]); # pdf for mvNew
                pNew *= calcP(mnvNew,mnvAdd,pn[1]); # pdf for mnvNew

                d.tD[sNewKey] = get(d.tD, sNewKey, 0.0) + pNew;
            end
        end
    end

    if isempty(d.tD)
        d.tD = Dict{ASCIIString, Float64}(embryoStateToString(s) => 1.0);
    end

    return d;
end

transition (generic function with 9 methods)

In [44]:
# test cell only
# Times repeated calls to transition() on a 2-embryo model and prints a line whenever
# the stored probabilities sum to less than 1.0.
tic()
mdp = EmbryoCulture(2);
stateSpace = states(mdp)

#startingState = EmbryoState(mdp.vStart,mdp.nStart);
#randState = rand!(startingState, stateSpace)

#display(randState)

# reusable distribution object, refilled in place by every transition() call
d = create_transition_distribution(mdp);
transition(mdp, EmbryoState(2,1,[1;0;0],[1;0;0]), EmbryoAction("CC",0), d);
#println(sum(collect(values(d.tD))))
#println(d);

i = 1;
# NOTE(review): the loop variable `es` is never used; every iteration re-evaluates the
# same hard-coded state. If the goal is to check every state, pass `es` to transition().
for es in domain(states(mdp))#domain(transition(mdp, EmbryoState(1,0,[0;0;0],[0;0;1]), EmbryoAction("CC",0)))
    
    transition(mdp, EmbryoState(2,1,[1;0;0],[1;0;0]), EmbryoAction("CC",0), d);
    
    # exact comparison against 1.0: floating-point rounding alone can trigger this
    if sum(collect(values(d.tD))) < 1.0
        println(sum(collect(values(d.tD))))
        println("i: $i");
    end
    
    # stop after 100 iterations
    i += 1;
    if i > 100
        break;
    end
end
toc()

elapsed time: 0.589329681 seconds


0.589329681

In [45]:
##################################################
################ Reward Function #################
##################################################
# Reward for taking action `a` in state `s`.
#
#   day > 5 (terminal)  : 0
#   "TR" (transfer)     : the a.tr best-morphology embryos are picked (random among
#                         equals), their viability is projected to day 5 by sampling;
#                         -80 if none is viable, +80 if exactly one is, 0 otherwise
#   "D"  (discard all)  : 0 if no embryo is projected to be viable, -40 otherwise
#   "B"  (biopsy)       : flat -80, which effectively disables the action for now
#   otherwise (culture) : -1 per day, -40 when culturing on day 5
#
# NOTE(review): the TR and D rewards depend on random draws (sample, expectedViable),
# so repeated calls with the same (s, a) can return different values.
function POMDPs.reward(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction)

    nDay = s.day;

    # if we are in a terminal state, no reward
    if (s.day > 5)
        return 0;
    # if we are transferring some of the embryos
    elseif a.a == "TR"

        nTransfer = a.tr;
        nSampled = Int64[]; # 1 = viable, 0 = nonviable, one entry per transferred embryo

        # randomly choose nTransfer best morphology embryos, best bin first
        i = 3;
        while (i > 0) && (length(nSampled) < nTransfer)
            currV = s.mV[i];
            currNV = s.mNV[i];
            numToSample = min(nTransfer - length(nSampled), currV + currNV);
            arrayToSample = [zeros(Int64,currNV); ones(Int64,currV)];
            nSampled = [nSampled; sample(arrayToSample,numToSample,replace=false)];
            i -= 1;
        end

        nViable = sum(nSampled);
        ev = expectedViable(mdp,nViable,nDay);

        if ev < 1
            return -80; # no viable embryo transferred
        elseif ev < 2
            return 80;  # exactly one viable embryo transferred
        else
            return 0;   # two or more viable embryos transferred
        end
    elseif a.a == "D"

        nViable = s.v;
        ev = expectedViable(mdp,nViable, nDay);

        if ev < 1
            return 0;   # none viable, discard all
        else
            return -40; # discard >= 1 viable embryo
        end
    elseif a.a == "B"

        # Biopsy is ignored for now. The schedule that used to sit (unreachable) behind
        # this return was: -40 before day 5, -5 if no embryo is viable, +40 otherwise.
        return -80;
    else
        # do nothing / continue culture
        if nDay > 4
            return -40;
        else
            return -1; # slight negative reward for continuing culture
        end
    end
end

reward (generic function with 6 methods)

In [46]:
##################################################
########## Observation Distribution ##############
##################################################
# Observation: what can be seen of the culture without knowing viability.
type EmbryoObservation <: Observation
    n::Int64               # number of embryos in culture
    oN::Array{Int64,1}     # morphology counts of all embryos (3x1 vector)
    oD::Int64              # day in culture
end

# Overwrite every field of `o1` with the values of `o2` and return `o1`.
# The morphology vector is copied rather than shared, so the two observations stay
# independent afterwards.
function Base.copy!(o1::EmbryoObservation, o2::EmbryoObservation)
    o1.n = o2.n;
    o1.oN = copy(o2.oN);
    o1.oD = o2.oD;
    return o1;
end

# Encode an observation as a Dict key: oN[1..3], oD and n written back to back with
# no separator.
# NOTE(review): stringToEmbryoObs reads one character per field, so the encoding only
# round-trips while every field is a single digit (0-9).
function embryoObsToString(obs::EmbryoObservation)
    return string(obs.oN[1], obs.oN[2], obs.oN[3], obs.oD, obs.n)
end

# Decode a key produced by embryoObsToString back into an EmbryoObservation.
# Characters 1-3 are the morphology counts, character 4 the day, character 5 n.
function stringToEmbryoObs(s::ASCIIString)
    digit = Int64[parse(Int64, s[k]) for k in 1:5]
    return EmbryoObservation(digit[5], digit[1:3], digit[4])
end

# Decode an array of keys produced by embryoObsToString into an array of observations.
# Each element is decoded by the single-string method, so there is one parser to
# maintain, and the result is a concretely typed Array{EmbryoObservation,1}.
function stringToEmbryoObs(oa::Array{ASCIIString,1})
    return EmbryoObservation[stringToEmbryoObs(s) for s in oa];
end

# Iterator over all observations of a culture with `n` embryos. minVals/maxVals bound
# the four odometer digits [day, oN(3)] that Base.next steps through.
type EmbryoObservationIterator
    n::Int64
    minVals::Array{Int64,1}
    maxVals::Array{Int64,1}
end

# Day runs 1..6; every morphology count runs 0..n.
function EmbryoObservationIterator(n::Int64)
    lower = [1; zeros(Int64, 3)]
    upper = [6; fill(n, 3)]
    return EmbryoObservationIterator(n, lower, upper)
end

# First observation of the enumeration: day 1 with all embryos in the third bin.
Base.start(e::EmbryoObservationIterator) = EmbryoObservation(e.n, [0, 0, e.n], 1)

# Iteration ends once the odometer has rolled the day digit over to 7.
Base.done(e::EmbryoObservationIterator, obs) = (obs.oD == 7)

# Return (current observation, following observation). The following observation is
# found by stepping a four-digit odometer [day, oN(3)] (last digit fastest) until the
# morphology counts sum to obs.n. The day digit is never wrapped, so it eventually
# reaches 7, which Base.done treats as the end.
function Base.next(e::EmbryoObservationIterator, obs)

    odo = [obs.oD; obs.oN]
    found = false

    while !found
        pos = 4

        # carry: reset every saturated trailing digit, then bump the next one
        while (pos > 1) && (odo[pos] == e.maxVals[pos])
            odo[pos] = e.minVals[pos]
            pos -= 1
        end
        odo[pos] += 1

        found = (obs.n == sum(odo[2:4]))
    end

    successor = EmbryoObservation(obs.n, odo[2:4], odo[1])
    return (obs, successor)
end

# Observation space wrapper; its single field is the iterator over all observations.
type EmbryoObservationSpace <: AbstractSpace
    obs::EmbryoObservationIterator
end

# Build the observation space for a culture of mdp.n embryos.
POMDPs.observations(mdp::EmbryoCulture) = EmbryoObservationSpace(EmbryoObservationIterator(mdp.n))
    
# State-dependent observation query: the supplied observation space is returned unchanged.
function POMDPs.observations(mdp::EmbryoCulture, s::EmbryoState, obs::EmbryoObservationSpace)
    return obs
end
# The iterable domain of an observation space is its EmbryoObservationIterator.
function POMDPs.domain(space::EmbryoObservationSpace)
    return space.obs
end

# Observation distribution: maps the string key of an observation
# (see embryoObsToString) to its probability p(o).
type EmbryoObservationDistribution <: AbstractDistribution
    obsD::Dict{ASCIIString, Float64}
end

# Preallocate an empty observation distribution; POMDPs.observation fills it in.
function POMDPs.create_observation_distribution(mdp::EmbryoCulture)
    emptyTable = Dict{ASCIIString, Float64}()
    return EmbryoObservationDistribution(emptyTable)
end

# All observations stored in the distribution, decoded from their string keys.
function POMDPs.domain(d::EmbryoObservationDistribution)
    obsKeys = collect(keys(d.obsD))
    return stringToEmbryoObs(obsKeys)
end

# Probability p(o) of observation `o`; observations absent from the distribution
# have probability 0.0.
function POMDPs.pdf(d::EmbryoObservationDistribution, o::EmbryoObservation)
    return get(d.obsD, embryoObsToString(o), 0.0)
end

# Sample an observation from the distribution, write it into `o`, and return `o`.
# The stored probabilities are passed to Categorical as they are (no renormalisation).
function POMDPs.rand!(d::EmbryoObservationDistribution, o::EmbryoObservation)
    candidates = collect(keys(d.obsD))
    weights = collect(values(d.obsD))

    sampler = Categorical(weights)
    chosenKey = candidates[rand(sampler)]

    copy!(o, stringToEmbryoObs(chosenKey))
    return o
end

# Observation model: fill `d` with p(o | s, a) and return it.
# For now the observation is deterministic: exactly s.n, the total morphology counts
# s.mV + s.mNV (viability itself stays hidden) and s.day, with probability 1.
# The table is built with the Dict constructor instead of the deprecated `[k => v]`
# literal.
function POMDPs.observation(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction, 
    d::EmbryoObservationDistribution=create_observation_distribution(mdp))

    obsKey = embryoObsToString(EmbryoObservation(s.n, s.mV+s.mNV, s.day));
    d.obsD = Dict{ASCIIString, Float64}(obsKey => 1.0);
    return d;
end

# Number of observations: 6 culture days times the (n+1)(n+2)/2 ways of spreading n
# embryos over three morphology bins. (n+1)(n+2) is always even, so integer division
# is exact and the count comes back as an integer, matching n_states.
function POMDPs.n_observations(mdp::EmbryoCulture)
    return 6*div((mdp.n+1)*(mdp.n+2), 2);
end

n_observations (generic function with 4 methods)

In [47]:
# Belief ... modify this eventually to reflect starting day always = 1, know starting n

# Belief vector with uniform weight on the first sixth of the state indices and zero
# on the remaining five sixths. The state iterator enumerates day 1 first, so these
# are presumably the day-1 states; confirm against DiscreteBelief's index order.
function POMDPs.create_belief(mdp::EmbryoCulture)
    nFirstDay = convert(Int64, n_states(mdp)/6)
    weights = [(1.0/nFirstDay)*ones(Float64, nFirstDay); zeros(Float64, 5*nFirstDay)]
    return DiscreteBelief(weights)
end
# Initial belief: uniform weight on the first sixth of the state indices and zero on
# the remaining five sixths. The state iterator enumerates day 1 first, so these are
# presumably the day-1 states; confirm against DiscreteBelief's index order.
function POMDPs.initial_belief(mdp::EmbryoCulture)
    nFirstDay = convert(Int64, n_states(mdp)/6)
    weights = [(1.0/nFirstDay)*ones(Float64, nFirstDay); zeros(Float64, 5*nFirstDay)]
    return DiscreteBelief(weights)
end


In [48]:
using POMDPXFile
# Build the model for nEmbryos embryos; file names for the policy and the exported
# model are derived from nEmbryos.
mdp = EmbryoCulture(nEmbryos);
policy = POMDPPolicy(string("embryo",nEmbryos,"_ccp_R1.policy"))

# create the .pomdpx file, this is the format which the SARSOP backend reads in
pomdpfile = POMDPFile(mdp, string("embryo",nEmbryos,"_ccp_R1.pomdpx")) # must end in .pomdpx
#pomdpfile = POMDPFile(string("embryo",nEmbryos,".pomdpx"))
#write(mdp, )

POMDPFile("embryo3_ccp.pomdpx")

In [49]:
# try SARSOP ...
#using SARSOP

# initialize the solver
# NOTE(review): timeout=3.0 is presumably in seconds; confirm against the SARSOP.jl docs.
solver = SARSOPSolver(randomization=true, timeout=3.0)

# run the solve function on the exported model; uses `pomdpfile` and `policy` from the
# previous cell
solve(solver, pomdpfile, policy)



Loading the model ...
  input file   : embryo3_ccp.pomdpx
  loading time : 0.63s 

SARSOP initializing ...
  initialization time : 0.06s

-------------------------------------------------------------------------------
 Time   |#Trial |#Backup |LBound    |UBound    |Precision  |#Alphas |#Beliefs  
-------------------------------------------------------------------------------
 0.056   0       0        -14.4048   6.44236    20.8471     6        1        
 0.087   9       51       -14.4048   -7.51382   6.89094     29       27       
 0.117   19      103      -14.4048   -10.1392   4.26559     45       49       
 0.146   29      153      -14.4048   -11.3473   3.0575      69       73       
 0.168   38      201      -14.4048   -12.1929   2.21187     83       93       
 0.232   46      250      -14.4048   -12.8652   1.53958     98       109      
 0.292   54      303      -14.4048   -13.2035   1.20122     120      125      
 0.331   61      350      -14.4048   -13.4974   0.907404    135     

POMDPAlphas(336x194 Array{Float64,2}:
  -1.0     -77.0      -71.4425  -47.6964   …  -41.9614   -25.4009  -80.0
  -1.0     -77.0      -71.4425  -39.9968      -36.0515   -25.9099  -80.0
  -1.0     -77.0      -71.4425  -29.4973      -27.461    -27.469   -80.0
  -1.0     -77.0      -71.4425  -15.1798      -15.1798   -30.5552  -80.0
  -1.0     -77.0      -71.4425  -38.0719      -34.1266   -26.8724  -80.0
  -1.0     -77.0      -71.4425  -26.8724   …  -24.8361   -28.7814  -80.0
  -1.0     -77.0      -71.4425  -11.6004      -11.6004   -32.3449  -80.0
  -1.0     -77.0      -71.4425  -24.0725      -22.0363   -30.1813  -80.0
  -1.0     -77.0      -71.4425   -7.78241      -7.78241  -34.2539  -80.0
  -1.0     -77.0      -71.4425   -3.70987      -3.70987  -36.2901  -80.0
 -34.7725  -36.6013   -16.023   -16.4812   …  -35.6271   -13.8345  -80.0
 -26.65    -37.5631   -15.3091  -14.92        -34.2735   -12.7461   10.0
 -18.5275   -1.97375  -15.1901  -12.7911      -32.3918   -11.8095  -80.0
   ⋮         

In [50]:
# simulate the SARSOP policy
# NOTE(review): the arguments are presumably (steps per simulation, number of
# simulations); the captured output reports 50 simulations. Confirm against SARSOP.jl.
simulator = SARSOPSimulator(6, 50)
simulate(simulator, pomdpfile, policy)


Loading the model ...
  input file   : embryo3_ccp.pomdpx

Loading the policy ...
  input file   : embryo3_ccp.policy

Simulating ...
  action selection :  one-step look ahead

-----------------------------------
 #Simulations  | Exp Total Reward  
-----------------------------------
 5               -12
 10              -4
 15              -8
 20              -4.5
 25              -3.6
 30              -2.33333
 35              -6.28571
 40              -4.75
 45              -4
 50              -8.2
-----------------------------------

Finishing ...

-------------------------------------------------------------
 #Simulations  | Exp Total Reward | 95% Confidence Interval 
-------------------------------------------------------------
 50              -8.2               (-16.4169, 0.0168661)
-------------------------------------------------------------


In [51]:
# evaluate the SARSOP policy
# NOTE(review): the arguments are presumably (steps per simulation, number of
# simulations); the captured output reports 50 simulations. Confirm against SARSOP.jl.
evaluator = SARSOPEvaluator(6, 50)
evaluate(evaluator, pomdpfile, policy)


Loading the model ...
  input file   : embryo3_ccp.pomdpx

Loading the policy ...
  input file   : embryo3_ccp.policy

Simulating ...
  action selection :  one-step look ahead

-----------------------------------
 #Simulations  | Exp Total Reward  
-----------------------------------
 5               -14.4048
 10              -14.4048
 15              -14.4048
 20              -14.4048
 25              -14.4048
 30              -14.4048
 35              -14.4048
 40              -14.4048
 45              -14.4048
 50              -14.4048
-----------------------------------

Finishing ...

-------------------------------------------------------------
 #Simulations  | Exp Total Reward | 95% Confidence Interval 
-------------------------------------------------------------
 50              -14.4048           (-14.4048, -14.4048)
-------------------------------------------------------------


In [52]:
# generates a policy graph
#graphgen = PolicyGraphGenerator("Embryo3.dot")
#polgraph(graphgen,pomdpfile,policy)

In [53]:
# Simulate actions chosen by the SARSOP policy and log every episode to a text file.
# For each possible number of initially viable embryos (0..nEmbryos), ten episodes are
# run with starting morphologies sampled with theta = [.5, .9]. Uses the globals
# `mdp` and `policy` defined in the earlier cells.

nInit = nEmbryos;
println("numEmbryos = $nInit")
f = open(string("SARSOP_sim", nInit, "_3_R1.txt"), "w")
try
    write(f, "CC cost: 1 \n");
    # newline-terminated so the first "NumViable" line starts on its own line
    write(f, "nonviable and viable starting morphology DIFF (.5,.9)\n")

    for k in 0:nEmbryos
        vInit = k;
        println("numViable = $vInit");
        write(f, string("NumViable: ", vInit, "\n"));

        for j in 1:10

            # start with vInit viable embryos at day 1
            s = EmbryoState(nInit, vInit, generateMorphology(vInit,1,[.5;.9]), generateMorphology(nInit-vInit,0,[.5;.9]));
            # placeholder observation; overwritten by rand! before it is first used
            o = EmbryoObservation(mdp.n, [3;0;0], 1);
            b = initial_belief(mdp)
            action_map = domain(actions(mdp)) # create a mapping array
            updater = DiscreteUpdater(mdp) # this comes from POMDPToolbox

            println("Starting state: $(s)")
            write(f, string("Starting state: ", s, "\n"));

            rtot = 0.0
            # run until the state reaches the terminal day (day 6)
            while s.day < 6
                # get the action from our SARSOP policy
                #a = action(qmdp_policy, b) # the QMDP action function returns the POMDP action not its index like the SARSOP action function

                ai = action(policy,b);
                a = action_map[ai];
                # compute the reward
                r = reward(mdp, s, a)
                rtot += r

                println("Time step $(s.day)")
                println("Taking action: $(a), got reward: $(r)")
                write(f, string("Time step $(s.day)", "\n"));
                write(f, string("Taking action: $(a), got reward: $(r)", "\n"));

                # transition the system state
                trans_dist = transition(mdp, s, a)
                rand!(trans_dist, s)
                # sample a new observation
                obs_dist = observation(mdp, s, a)
                rand!(obs_dist, o)

                # update the belief
                b = update(updater, b, a, o)

                println("Saw observation: $(o)\n")
                write(f, string("Saw observation: $(o)", "\n"));

            end
            println("Total reward: $rtot \n")
            write(f, string("Total reward: $rtot", "\n"));
        end
    end
finally
    # close the log even if the simulation throws part-way through
    close(f)
end

numEmbryos = 3
numViable = 0
Starting state: EmbryoState(3,0,[0,0,0],[1,1,1],1)
Time step 1
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[1,1,1],2)

Time step 2
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[2,0,1],3)

Time step 3
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[2,0,1],4)

Time step 4
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[2,1,0],5)

Time step 5
Taking action: EmbryoAction("D",0), got reward: 0
Saw observation: EmbryoObservation(3,[2,1,0],6)

Total reward: -4.0 

Starting state: EmbryoState(3,0,[0,0,0],[0,2,1],1)
Time step 1
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[1,2,0],2)

Time step 2
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[2,1,0],3)

Time step 3
Taking action: EmbryoAction("CC",0), got reward

In [35]:
# Simulate actions chosen by the SARSOP policy and log every episode to a text file.
# For each possible number of initially viable embryos (0..nEmbryos), ten episodes are
# run with starting morphologies sampled with theta = [.9, .9]. Uses the globals
# `mdp` and `policy` defined in the earlier cells.

nInit = nEmbryos;
println("numEmbryos = $nInit")
f = open(string("SARSOP_sim", nInit, "_4_R1.txt"), "w")
try
    write(f, "CC cost: 1 \n");
    # newline-terminated so the first "NumViable" line starts on its own line
    write(f, "nonviable and viable starting morphology SAME (.9,.9)\n")

    for k in 0:nEmbryos
        vInit = k;
        println("numViable = $vInit");
        write(f, string("NumViable: ", vInit, "\n"));

        for j in 1:10

            # start with vInit viable embryos at day 1
            s = EmbryoState(nInit, vInit, generateMorphology(vInit,1,[.9;.9]), generateMorphology(nInit-vInit,0,[.9;.9]));
            # placeholder observation; overwritten by rand! before it is first used
            o = EmbryoObservation(mdp.n, [3;0;0], 1);
            b = initial_belief(mdp)
            action_map = domain(actions(mdp)) # create a mapping array
            updater = DiscreteUpdater(mdp) # this comes from POMDPToolbox

            println("Starting state: $(s)")
            write(f, string("Starting state: ", s, "\n"));

            rtot = 0.0
            # run until the state reaches the terminal day (day 6)
            while s.day < 6
                # get the action from our SARSOP policy
                #a = action(qmdp_policy, b) # the QMDP action function returns the POMDP action not its index like the SARSOP action function

                ai = action(policy,b);
                a = action_map[ai];
                # compute the reward
                r = reward(mdp, s, a)
                rtot += r

                println("Time step $(s.day)")
                println("Taking action: $(a), got reward: $(r)")
                write(f, string("Time step $(s.day)", "\n"));
                write(f, string("Taking action: $(a), got reward: $(r)", "\n"));

                # transition the system state
                trans_dist = transition(mdp, s, a)
                rand!(trans_dist, s)
                # sample a new observation
                obs_dist = observation(mdp, s, a)
                rand!(obs_dist, o)

                # update the belief
                b = update(updater, b, a, o)

                println("Saw observation: $(o)\n")
                write(f, string("Saw observation: $(o)", "\n"));

            end
            println("Total reward: $rtot \n")
            write(f, string("Total reward: $rtot", "\n"));
        end
    end
finally
    # close the log even if the simulation throws part-way through
    close(f)
end

numEmbryos = 3
numViable = 0
Starting state: EmbryoState(3,0,[0,0,0],[0,0,3],1)
Time step 1
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[0,1,2],2)

Time step 2
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[0,2,1],3)

Time step 3
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[2,1,0],4)

Time step 4
Taking action: EmbryoAction("D",0), got reward: 0
Saw observation: EmbryoObservation(3,[2,1,0],6)

Total reward: -3.0 

Starting state: EmbryoState(3,0,[0,0,0],[0,0,3],1)
Time step 1
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[0,2,1],2)

Time step 2
Taking action: EmbryoAction("CC",0), got reward: -1
Saw observation: EmbryoObservation(3,[1,1,1],3)

Time step 3
Taking action: EmbryoAction("TR",2), got reward: -40
Saw observation: EmbryoObservation(3,[1,1,1],6)

Total reward: -42.0 

Starting state: EmbryoState(3,0,[0,0

In [36]:
# Scratch check: draw one sample from Binomial(5, 0.9).
# Note that this rebinds the global `d`, which an earlier cell used for the transition
# distribution.
d = Binomial(5,.9);
rand(d)

5