In [1]:
# first import the POMDPs.jl interface
using POMDPs
using POMDPDistributions
using Distributions
using SARSOP
using POMDPModels
using Iterators

In [48]:
################ Functions and Structures definitions #######################

# EmbryoState definition
type EmbryoState <: State
    v::Array{Int64,1} # viable or not? 
    n::Array{Int64,1} # developmental day achieved {1,2,3,4,5}
    done::Int64; # done? (only if TR or D at last step)
    day::Int64; # culture day
end

# initialize day1 embryos with known viability
EmbryoState(v::Array{Int64,1}) = EmbryoState(v,ones(Int64,v),0,1);
EmbryoState(v::Array{Int64,1}, n::Array{Int64,1}) = EmbryoState(v,n,0,1);

function ==(s1::EmbryoState, s2::EmbryoState)
    return (s1.v == s2.v) && (s1.n == s2.n) && (s1.done == s2.done) && (s1.day == s2.day);
end

# convert EmbryoState to string for hashing purposes
function embryoStateToString(es::EmbryoState)
    sOut = "";
    for i in 1:length(es.v)
        sOut = string(sOut, es.v[i]);
    end
    for i in 1:length(es.n)
        sOut = string(sOut, es.n[i]);
    end
    sOut = string(sOut, es.done);
    sOut = string(sOut, es.day);
    return sOut;
end

# convert string to EmbryoState for hashing purposes
function stringToEmbryoState(s::ASCIIString)
   
    m = convert(Int64, length(s)/2) - 1;
    v = zeros(Int64, m);
    n = zeros(Int64, m);
    
    for i in 1:m
        v[i] = parse(Int64,s[i]);
    end
    for i in 1:m
        n[i] = parse(Int64,s[m+i]);
    end
    
    done = parse(Int64, s[end-1]);
    day = parse(Int64, s[end]);
        
    return EmbryoState(v,n,done,day);
end

# convert Array{string} to Array{EmbryoState} for hashing purposes
function stringToEmbryoState(sa::Array{ASCIIString,1})
   
    nStates = length(sa);
    ea = cell(nStates);
    
    for j in 1:nStates
        s = sa[j];
    
        m = convert(Int64, length(s)/2) - 1;
        v = zeros(Int64, m);
        n = zeros(Int64, m);

        for i in 1:m
            v[i] = parse(Int64,s[i]);
        end
        for i in 1:m
            n[i] = parse(Int64,s[m+i]);
        end
        d = parse(Int64, s[end]);
        
        done = parse(Int64, s[end-1]);
        day = parse(Int64, s[end]);
        
        ea[j] = EmbryoState(v,n,done,day);
    end
    
    return ea;
end

function Base.copy!(s1::EmbryoState, s2::EmbryoState)
    s1.v = s2.v;
    s1.n = s2.n;
    s1.done = s2.done;
    s1.day = s2.day;
    s1
end

# EmbryoAction definition
type EmbryoAction <: Action
    a::String # {CC = continue culture, TL = measure time lapse params, B = biopsy (genetic testing)}
    # if we're transferring, TR = transfer to patient}
    tr::Array{Int64,1}; # vector of embryo nums to transfer (if a == TR)
end

# Define embryo culture MDP scenario
type EmbryoCulture <: POMDP
    #rf::Function; # calculates reward based on input S,A
    #tf::Function; #
    m::Int64; # number of embryos starting in culture
    vStart::Array{Int64,1}; # starting viability info
    nStart::Array{Int64,1}; # starting cell #
    dayStart::Int64;
    discount::Float64; # = 1; # discount_factor
    pv::Float64; # = 0.9; # probability that a viable embryo will continue to be viable
    pn::Array{Float64,1}; # = [.7; .95]; # probability that a [nonviable, viable] embryo will have normal cell division
end

EmbryoCulture(m::Int64, vStart::Array{Int64,1}, 
    nStart::Array{Int64,1}) = EmbryoCulture(m,vStart,nStart,1,1,0.9,[0.7;0.95]);
EmbryoCulture(m::Int64, vStart::Array{Int64,1}, nStart::Array{Int64,1}, 
    dayStart::Int64) = EmbryoCulture(m,vStart,nStart,dayStart,1,0.9,[0.7;0.95]);

# given number of actual viable embryos and culture day #,
# return expected # of viable embryos at the time of transfer
function expectedViable(nViable::Int64, nDay::Int64)
    return nViable * (2 ^ (-(5 - nDay)/4))
end


expectedViable (generic function with 1 method)

In [3]:
# state space functions and structures
# define iterator over entire EmbryoState space
type EmbryoStateIterator
    m::Int64;
    startState::EmbryoState;
    minVals::Vector{Int64};
    maxVals::Vector{Int64};
end

# can iterate over all states in embryo culture with m embryos
EmbryoStateIterator(m::Int64) = EmbryoStateIterator(m,EmbryoState(ones(m),ones(m),0,1),
                                            [zeros(m);ones(m)],[ones(m);5*ones(m)]);
# can also iterate over possible next states
EmbryoStateIterator(s::EmbryoState) = EmbryoStateIterator(length(s.v),s,[zeros(length(s.v));s.n],
    [s.v;5*ones(length(s.v))]);

function Base.start(e::EmbryoStateIterator)
    return e.startState;
end

function Base.done(e::EmbryoStateIterator, state)
    return (state.v[1] == -1);
    #return (state.v == e.minVals[1:m]) && (state.n == e.maxVals[(m+1):(2*m)]);
end

function Base.next(e::EmbryoStateIterator, state)
    iterVector = [state.v ; state.n];
    m = e.m;
    i = 2*m;
    
    # increment development days, decrement viability
    while (i > m) && (iterVector[i] == e.maxVals[i])
        iterVector[i] = e.minVals[i];
        i -= 1;
    end
    
    while (i < (m+1)) && (i > 1) && (iterVector[i] == e.minVals[i])
        iterVector[i] = e.maxVals[i];
        i -= 1;
    end
    
    if (i > m)
        iterVector[i] += 1;  
    else
        iterVector[i] -= 1;
    end
    
    return (state,EmbryoState(iterVector[1:m], iterVector[(m+1):(2*m)],state.done,state.day))
end

type EmbryoStateSpace <: AbstractSpace
    # only variable is state iterator 
    states::EmbryoStateIterator; 
end

function numRemainingStates(space::EmbryoStateSpace)
    minVals = space.states.minVals;
    maxVals = space.states.maxVals;
    nums = maxVals - minVals + ones(maxVals);
    
    stateProd = 1;
    for i in 1:length(nums)
        stateProd *= nums[i];
    end
    return stateProd;
end

numRemainingStates (generic function with 1 method)

In [4]:
# returns EmbryoStateSpace
function POMDPs.states(mdp::EmbryoCulture)
    return EmbryoStateSpace(EmbryoStateIterator(EmbryoState(mdp.vStart,mdp.nStart,0,1)));
end;

# returns EmbryoStateIterator
function POMDPs.domain(space::EmbryoStateSpace)
    return space.states;
end

# define function to uniformly sample state space
# max num states is 10^m
function POMDPs.rand!(es::EmbryoState, space::EmbryoStateSpace)
    maxNumStates = numRemainingStates(space); #10^(space.states.m);
    randStateNum = convert(Int64,ceil(rand(1)*maxNumStates)[1]);
    sp = start(takenth(space.states,randStateNum));
    copy!(es, sp)
    es
end

# for now, assume nStart <= dayStart
function POMDPs.n_states(mdp::EmbryoCulture)
    n = 2^(sum(mdp.vStart));
    
    for i in 1:length(mdp.nStart)
        n *= (5 - mdp.dayStart + 1); # min(5-mdp.dayStart+1,5-mdp.nStart[i]+1); 
    end
    
    n*= (5 - mdp.dayStart + 1); # extra dim for day
    return n;
end

function POMDPs.n_actions(mdp::EmbryoCulture)
    return 4 + 2^(mdp.m);
end

function POMDPs.discount(mdp::EmbryoCulture)
    return mdp.discount; 
end

function POMDPs.index(mdp::EmbryoCulture, s::EmbryoState)
    dimSizes = [[mdp.vStart + 1];[5 - mdp.dayStart*ones(Int64, mdp.m) + 1];[5 - mdp.dayStart + 1]];
    currInds = [[mdp.vStart - s.v + 1];[s.n - mdp.nStart + 1];[s.day - mdp.dayStart + 1]];
    return sub2ind(dimSizes, currInds);
end

index (generic function with 6 methods)

In [5]:
# test cell only
tic()
mdp = EmbryoCulture(2, [1;0], [4;4]);
stateSpace = states(mdp)

startingState = EmbryoState(mdp.vStart,mdp.nStart);
randState = rand!(startingState, stateSpace)

display(randState)

i = 1;
for es in domain(stateSpace)
    i += 1;
    display(es)
    if i > 60
        break;
    end
end
toc()

EmbryoState([0,0],[5,5],0,1)

EmbryoState([1,0],[4,4],0,1)

EmbryoState([1,0],[4,5],0,1)

EmbryoState([1,0],[5,4],0,1)

EmbryoState([1,0],[5,5],0,1)

EmbryoState([0,0],[4,4],0,1)

EmbryoState([0,0],[4,5],0,1)

EmbryoState([0,0],[5,4],0,1)

EmbryoState([0,0],[5,5],0,1)

elapsed time: 0.755795647 seconds


0.755795647

In [6]:
# EmbryoActionSpace definitions and functions
type EmbryoActionSpace <: AbstractSpace
    actions::Array{EmbryoAction,1}
end

function POMDPs.actions(mdp::EmbryoCulture)
    
    acts = Array(EmbryoAction,(2^(mdp.m))+4);
    acts[1] = EmbryoAction("CC", []); # continue culture only
    acts[2] = EmbryoAction("TL", []); # collect cell cycle params
    acts[3] = EmbryoAction("B", []); # biopsy all
    acts[4] = EmbryoAction("D", []); # discard all
    
    for i in 1:(2^(mdp.m))
        transferBits = bits(i-1)[end-mdp.m+1:end];
        transferVector = Array(Bool,mdp.m);
        
        for j in 1:mdp.m
            transferVector[j] = parse(Bool,transferBits[j]);
        end
        
        acts[i+4] = EmbryoAction("TR", transferVector);
    end
    
    return EmbryoActionSpace(acts)
end

POMDPs.actions(mdp::EmbryoCulture, s::EmbryoState, as::EmbryoActionSpace=actions(mdp)) = as;

function POMDPs.domain(space::EmbryoActionSpace)
    return space.actions; 
end

function POMDPs.rand!(ea::EmbryoAction, space::EmbryoActionSpace)
    ap = space.actions[rand(1:end)];
    ea.a = ap.a;
    ea.tr = ap.tr;
    ea
end


rand! (generic function with 27 methods)

In [7]:
# transition distribution stuff
# define transition distribution
type EmbryoCultureDistribution <: AbstractDistribution
    tD::Dict{ASCIIString, Float64}; # s' ==> p(s')
end

# initialize transition distribution (for preallocation only)
function POMDPs.create_transition_distribution(mdp::EmbryoCulture)
    # initialize dict with neighbors => probabilities 
    # init with only min v, max n to begin with
    # possibly define iterator if space becomes too large
    
    #tD = Dict{ASCIIString, Float64}; # transition distribution
    tD = [embryoStateToString(EmbryoState(zeros(Int64, mdp.m), 
        5*ones(Int64, mdp.m))) => 1.0];
    
    return EmbryoCultureDistribution(tD);
end

# get all possible next states
function POMDPs.domain(d::EmbryoCultureDistribution)
    return stringToEmbryoState(collect(keys(d.tD))); 
end

# get p(s')
function POMDPs.pdf(d::EmbryoCultureDistribution, s::EmbryoState)
    sk = embryoStateToString(s);
    if haskey(d.tD, sk)
        return d.tD[sk];
    else
        return 0.0;
    end
end

# randomly sample next state according to transition probs
function POMDPs.rand!(d::EmbryoCultureDistribution, s::EmbryoState)
    keyArray = collect(keys(d.tD));
    pArray = collect(values(d.tD));
    
    cat = Categorical(pArray);
    randState = stringToEmbryoState(keyArray[rand(cat)]);
    copy!(s, randState);
    return s;
end



rand! (generic function with 28 methods)

In [8]:
# transition model
function POMDPs.transition(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction,
    d::EmbryoCultureDistribution=create_transition_distribution(mdp))

    # if we're already "done" or if action leads to terminal state
    # (action = TR or D, day = 5) then can't change state
    if (s.day > 4) || (s.done > 0) || (a.a == "TR") || (a.a == "D")
        # can not transition out of done state
        d.tD = [embryoStateToString(EmbryoState(s.v, s.n, 1, s.day + 1)) => 1.0];
        return d;
    end
    
    pv = mdp.pv; # p(v(t+1) = 1 | v(t) = 1)
    pn = mdp.pn; # p(n(t+1) = n(t) + 1) for [nonviable, viable] embryo 
    
    # biopsy slightly harms embryos
    if a.a == "B" 
        pv = 0.9*pv; 
        pn = 0.9*pn;
    end 
    
    # calc number of possible next states (<= 4^m)
    nChange = sum(s.v .> 0) + sum(s.n .< 5);
    numNeighbors = 2^nChange;
    indChange = [s.v .> 0; s.n .< 5];

    # re-init dictionary w/ s' => p(s')
    d.tD = [embryoStateToString(EmbryoState(zeros(Int64,mdp.m), 
        5*(ones(Int64,mdp.m)), s.done, s.day+1)) => 0.0];
    
    for i in 1:numNeighbors

        vNext = deepcopy(s.v);
        nNext = deepcopy(s.n);
        pNext = 1;
        
        # which indices to change in indChange (length is nChange)
        currNum = bits(i-1)[end-nChange+1:end];
        cInd = 1;
        
        for j in 1:(2*mdp.m)
            if (j < (mdp.m + 1))
                
                if indChange[j] # if we can change the current value

                    if parse(Bool,currNum[cInd]) # v(t) = 1 -> v(t+1) = 0
                        vNext[j] -= 1; # decrement viability of embryo j
                        pNext *= (1-pv); # p(viable -> nonviable)
                    else # v(t) = 1 -> v(t+1) = 1
                        pNext *= pv; # p(viable -> viable)
                    end
                    
                    cInd += 1;
                end
            elseif (j > mdp.m)

                if indChange[j] # if we can change current value

                    if parse(Bool,currNum[cInd]) # n(t+1) = n(t) + 1
                        nNext[j - mdp.m] += 1;
                        pNext *= (pn[s.v[j - mdp.m] + 1]);
                    else # n(t+1) = n(t)
                        pNext *= (1 - pn[s.v[j - mdp.m] + 1]); 
                    end
                    
                    cInd += 1;
                end
            end
        end
        
        d.tD[embryoStateToString(EmbryoState(vNext,nNext,s.done,s.day+1))] = pNext;
        
    end
    
    return d;
    
end


transition (generic function with 9 methods)

In [9]:
# test cell only
mdp = EmbryoCulture(2, [1;0], [3;3], 3);
d = create_transition_distribution(mdp);
s = EmbryoState(mdp.vStart, mdp.nStart, 0, mdp.dayStart);
a = EmbryoAction("B", [1;0])

d = transition(mdp,s,a,d);

display(rand!(d,s))
display(rand!(d,s))
display(rand!(d,s))
display(d.tD)
display(pdf(d,EmbryoState([1;0],[4;3],0,4)))
#display(d.tD[embryoStateToString(EmbryoState([0,0], [1,2], 0, 2))])


EmbryoState([1,0],[3,3],0,4)

EmbryoState([1,0],[4,3],0,4)

EmbryoState([0,0],[4,4],0,4)

Dict{ASCIIString,Float64} with 9 entries:
  "103404" => 0.07399350000000002
  "004304" => 0.06010649999999998
  "004404" => 0.10234349999999998
  "104304" => 0.2562435
  "103304" => 0.04345650000000001
  "005504" => 0.0
  "104404" => 0.4363065
  "003304" => 0.010193499999999998
  "003404" => 0.017356499999999997

0.2562435

In [10]:
# reward model
# take day into account
# output reward given state,action and culture day
function POMDPs.reward(mdp::EmbryoCulture, s::EmbryoState, a::EmbryoAction)
    
    nDay = s.day;
    
    # if we are in a terminal state, no reward
    if (nDay > 5) || (s.done == 1)
        return 0;
    # if we are transferring some of the embryos
    elseif a.a == "TR"

        nViable = sum(s.v .* a.tr);
        ev = expectedViable(nViable, nDay);
        
        # big negative reward if no viable embryos transferred
        # big positive reward if one viable embryo transferred
        # diminishing positive rewards for >1 viable embryo transferred
        if ev < 1
            return -20; # transfer 0 viable embryos
        else 
            return 28 - 8*floor(ev); # transfer >= 1 viable embryo
        end
    elseif a.a == "D"
        
        nViable = sum(s.v);
        ev = expectedViable(nViable, nDay);
        
        # slight negative reward for discarding all (meaning no viable embryos available to transfer)
        # big negative reward for discarding viable embryos
        if ev < 1
            return -5; # none viable, discard all
        else 
            return -20; # discard >= 1 viable embryo
        end
    elseif a.a == "B"
        
        # if we're doing the biopsy at day 5, we would know for sure if an embryo is viable 
        if nDay == 5
            nViable = sum(s.v);
            if nViable == 0
                return -5; # small negative reward for having to discard all
            else
                return 20; # assume transfer 1 viable embryo
            end
        else 
            return -2; # slight negative reward for doing test
        end
    elseif a.a == "TL"
        return -2; # slight negative reward for doing test
    else
        # do nothing / continue culture
        return -1; # slight negative reward for continuing culture
    end
end

reward (generic function with 6 methods)

In [74]:
# now ... test simple MDP implementation

using DiscreteValueIteration
mdp = EmbryoCulture(4, [1;1;1;1], [1;1;1;1], 1);

solver = ValueIterationSolver(max_iterations=5, belres=1e-3)
policy = ValueIterationPolicy(mdp) 
solve(solver, mdp, policy, verbose=true);



Iteration : 1, residual: 20.0, iteration run-time: 36.788593373, total run-time: 36.788593373
Iteration : 2, residual: 0.0, iteration run-time: 41.270112356, total run-time: 78.058705729


In [80]:
s = EmbryoState([0;0;0;0],[1;1;1;1],0,1)
a = action(mdp, policy, s)

EmbryoAction("TR",[1,1,1,1])

In [51]:
display(s.v)
display(reward(mdp,s,EmbryoAction("TR",[1;0;0;0])))

2-element Array{Int64,1}:
 1
 1

LoadError: arrays could not be broadcast to a common size
while loading In[51], in expression starting on line 2

In [73]:
display(policy)

ValueIterationPolicy(50000x20 Array{Float64,2}:
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0  …  -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0     -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0     -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0     -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0     -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0  …  -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0     -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0     -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0     -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0     -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0  …  -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0     -20.0  -20.0  -20.0  -20.0  20.0
 -1.0  -2.0  -2.0  -20.0  -20.0  -20.0  

In [None]:
using SARSOP
using POMDPModels

policy = POMDPPolicy("tiger.policy")
pomdp = TigerPOMDP()

In [None]:
#using POMDPXFile

pomdpfile = POMDPFile(pomdp, "tiger.pomdpx")

#pomdpx = POMDPX("tiger.pomdpx")
#write(pomdp, pomdpx)

In [None]:
solver = SARSOPSolver()
solve(solver, pomdpfile, policy)

In [None]:
4^5
