# Needle Insertion Markov Decision Process

In [None]:
using POMDPs
using Distributions
using POMDPToolbox
using PyPlot

## States

In [None]:
# State of the needle tip on the grid: cell coordinates, orientation bin and a terminal flag.
# The type is mutable because states are overwritten in place elsewhere in this file.
type NeedleState 
    x::Int64 # x position
    y::Int64 # y position
    psi::Int64 # orientation
    done::Bool # are we in a terminal state?
end

In [None]:
# convenience constructor: a state built from a pose alone is not terminal
function NeedleState(x::Int64, y::Int64, psi::Int64)
    return NeedleState(x, y, psi, false)
end
# true when both states share the same x, y and orientation; the done flag is not compared
function posequal(s1::NeedleState, s2::NeedleState)
    s1.x == s2.x || return false
    s1.y == s2.y || return false
    return s1.psi == s2.psi
end
# overwrite every field of s1 with the corresponding field of s2 and return s1
function Base.copy!(s1::NeedleState, s2::NeedleState)
    s1.x, s1.y, s1.psi, s1.done = s2.x, s2.y, s2.psi, s2.done
    return s1
end
# Monte Carlo Tree Search needs states to be hashable and comparable by value,
# so both functions below look at all four fields.
function Base.hash(s::NeedleState, h::UInt64 = zero(UInt64))
    # fold the fields in from the innermost (done) to the outermost (x)
    acc = hash(s.done, h)
    acc = hash(s.psi, acc)
    acc = hash(s.y, acc)
    return hash(s.x, acc)
end
function Base.isequal(s1::NeedleState, s2::NeedleState)
    return (s1.x, s1.y, s1.psi, s1.done) == (s2.x, s2.y, s2.psi, s2.done)
end;

## MDP

In [None]:
# the needle mdp type
type Needle <: MDP{NeedleState, Symbol} # Note that our MDP is parameterized by the state and the action types
    size_x::Int64 # x size of the grid
    size_y::Int64 # y size of the grid
    size_psi::Int64 # number of orientation bins
    reward_states::Vector{NeedleState} # target/obstacle states
    reward_values::Vector{Float64} # reward values for those states
    tprob::Array{Float64} # probabilities assigned to the first four candidate successor states in transition
    discount_factor::Float64 # discount factor
end

In [None]:
# Keyword constructor, so any subset of the defaults can be overridden.
#   sx, sy, spsi     grid size in x and y, and the number of orientation bins
#   rs               reward states, laid out as: spsi target states, spsi obstacle states,
#                    then the boundary states
#   rv               reward value for each entry of rs, in the same order
#   tp               transition probabilities (stored as tprob)
#   discount_factor  discount factor
# The default rv gives one -5.0 per state that follows the target and obstacle entries of rs,
# so it stays aligned with rs instead of relying on a hand-computed count.
# NOTE: the default boundary orientations (3:7, 1:5, ...) are written for 8 orientation bins.
function Needle(;sx::Int64 = 10, # size_x
                sy::Int64 = 10, # size_y
                spsi::Int64 = 8, # size_psi
                rs::Vector{NeedleState} = [[NeedleState(8,4,psi) for psi = 1:spsi]; # target states
                                            [NeedleState(4,6,psi) for psi = 1:spsi]; # obstacle states
                                            [NeedleState(1,y,psi) for y = 1:sy, psi = 3:7][:]; # boundary states
                                            [NeedleState(sx,y,psi) for y = 1:sy, psi = [1:3;7:spsi]][:];
                                            [NeedleState(x,1,psi) for x = 2:sx-1, psi = [5:spsi;1]][:];
                                            [NeedleState(x,sy,psi) for x = 2:sx-1, psi = 1:5][:]],
                rv::Vector{Float64} = [fill(100.0,spsi); fill(-20.0,spsi); fill(-5.0,max(length(rs)-2*spsi,0))],
                tp::Array{Float64} = [0.05, 0.9, 0.05, 0.0], # tprob
                discount_factor::Float64 = 0.9)
    return Needle(sx, sy, spsi, rs, rv, tp, discount_factor)
end

# we can now create a Needle mdp instance like this:
mdp = Needle()
# mdp.reward_states # mdp contains all the default values from the constructor

## Spaces

### State Space ($\mathcal{S}$)

In [None]:
# Container for the explicit list of states of the needle MDP.
type StateSpace <: AbstractSpace
    states::Vector{NeedleState} # every state, in the order produced by states(mdp)
end

In [None]:
# Build the full state space: every (x, y, psi) pose, once as a live state and once as a done state.
function POMDPs.states(mdp::Needle)
    all_states = NeedleState[]
    # enumeration order: psi varies fastest, then x, then y, then the done flag
    for done_flag in 0:1
        for y in 1:mdp.size_y
            for x in 1:mdp.size_x
                for psi in 1:mdp.size_psi
                    push!(all_states, NeedleState(x, y, psi, done_flag))
                end
            end
        end
    end
    return StateSpace(all_states)
end;

In [None]:
# iterating over the state space is iterating over the stored vector of states
POMDPs.iterator(space::StateSpace) = space.states;

In [None]:
# Sample a state uniformly from the space, write it into s, and return s.
function POMDPs.rand(rng::AbstractRNG, space::StateSpace, s::NeedleState)
    idx = rand(rng, 1:length(space.states))
    return copy!(s, space.states[idx])
end;

### Action Space ($\mathcal{A}$)

In [None]:
# Container for the explicit list of actions of the needle MDP.
type ActionSpace <: AbstractSpace
    actions::Vector{Symbol} # the available actions, stored as symbols
end

In [None]:
# the needle can be rotated clockwise (:cw) or counter-clockwise (:ccw)
POMDPs.actions(mdp::Needle) = ActionSpace([:cw, :ccw]);
# every state offers the same actions, so the supplied space is handed back untouched
function POMDPs.actions(mdp::Needle, s::NeedleState, as::ActionSpace=actions(mdp))
    return as
end;

In [None]:
# iterating over the action space is iterating over the stored vector of actions
POMDPs.iterator(space::ActionSpace) = space.actions;

In [None]:
# Sample an action uniformly from the space; the third argument is not used for the draw.
function POMDPs.rand(rng::AbstractRNG, space::ActionSpace, a::Symbol)
    choices = space.actions
    return choices[rand(rng, 1:length(choices))]
end;
# Sample an action uniformly from the space.
# Actions in this problem are plain Symbols, so one is drawn directly from the stored
# vector; no placeholder action object has to be constructed first.
function POMDPs.rand(rng::AbstractRNG, space::ActionSpace)
    return space.actions[rand(rng, 1:length(space.actions))]
end;

In [None]:
# placeholder state handed out when a caller needs a state object to overwrite
function POMDPs.create_state(mdp::Needle)
    return NeedleState(1, 1, 1)
end
# placeholder action handed out when a caller needs an action value
function POMDPs.create_action(mdp::Needle)
    return :cw
end;

## Transition Distribution

In [None]:
# Discrete distribution over successor states: neighbors[i] has probability probs[i].
type NeedleDistribution <: AbstractDistribution
    neighbors::Array{NeedleState} # the states s' in the distribution
    probs::Array{Float64} # the probability corresponding to each state s'
    cat::Categorical # this comes from Distributions.jl and is used for sampling
end

In [None]:
# Allocate a reusable distribution with five slots (four candidate successors plus the
# current state), filled with placeholder states and uniform probabilities.
function POMDPs.create_transition_distribution(mdp::Needle)
    n = 5
    placeholder_states = [NeedleState(k, k, 1) for k in 1:n]
    uniform_probs = fill(1.0 / n, n)
    return NeedleDistribution(placeholder_states, uniform_probs, Categorical(n))
end;

In [None]:
# the support of the distribution is its list of neighbor states
POMDPs.iterator(d::NeedleDistribution) = d.neighbors;

In [None]:
# Probability of state s under d: the probability stored for the first neighbor equal to s,
# or 0.0 when s is not among the neighbors.
function POMDPs.pdf(d::NeedleDistribution, s::NeedleState)
    for i in 1:length(d.neighbors)
        isequal(s, d.neighbors[i]) && return d.probs[i]
    end
    return 0.0
end;

In [None]:
# Sample a successor state from d, write it into s, and return s.
# The draw uses the caller's rng (inverse-CDF sampling over d.probs), so a seeded
# simulation is reproducible; the global random number generator is not touched.
function POMDPs.rand(rng::AbstractRNG, d::NeedleDistribution, s::NeedleState)
    d.cat = Categorical(d.probs) # keep the stored categorical in sync with probs, as before
    u = rand(rng) # uniform draw in [0, 1)
    chosen = 0
    cumulative = 0.0
    for i in 1:length(d.probs)
        p = d.probs[i]
        p > 0.0 || continue # never select a zero-probability neighbor
        chosen = i # remember the last neighbor with mass, guards against round-off in the sum
        cumulative += p
        u < cumulative && break
    end
    chosen == 0 && error("NeedleDistribution has no neighbor with positive probability")
    copy!(s, d.neighbors[chosen])
    return s # return the pointer to s
end;

## Transition Model (T)

In [None]:
# transition helpers

# true when (x, y) lies on the grid and psi is a valid orientation bin
function inbounds(mdp::Needle,x::Int64,y::Int64,psi::Int64)
    x_ok = 1 <= x <= mdp.size_x
    y_ok = 1 <= y <= mdp.size_y
    psi_ok = 1 <= psi <= mdp.size_psi
    return x_ok && y_ok && psi_ok
end

# state overload: forwards the pose of the state to the coordinate version
inbounds(mdp::Needle,state::NeedleState) = inbounds(mdp, state.x, state.y, state.psi)

###########################################################

# True when the pose counts as "at the boundary": any corner cell, or a wall cell whose
# orientation falls in the outward-facing range for that wall.
function atbounds(mdp::Needle,x::Int64,y::Int64,psi::Int64)
    at_xmin = x == 1
    at_xmax = x == mdp.size_x
    at_ymin = y == 1
    at_ymax = y == mdp.size_y

    # corner cells qualify regardless of orientation
    (at_xmin || at_xmax) && (at_ymin || at_ymax) && return true

    # wall cells qualify only for the orientations listed for that wall
    at_xmin && 3 <= psi <= 7 && return true
    at_xmax && (psi >= 7 || psi <= 3) && return true
    at_ymin && (psi >= 5 || psi <= 1) && return true
    return at_ymax && 1 <= psi <= 5
end

# state overload: forwards the pose of the state to the coordinate version
atbounds(mdp::Needle,state::NeedleState) = atbounds(mdp, state.x, state.y, state.psi)

###########################################################

# Turn p into a one-hot style vector: zero everywhere, val at position index.
# An index outside 1:length(p) simply leaves p all zeros. Returns nothing.
function fill_probability!(p::Vector{Float64}, val::Float64, index::Int64)
    fill!(p, 0.0)
    if 1 <= index <= length(p)
        p[index] = val
    end
    return nothing
end;


In [None]:
# overwrite the pose of s in place (the done flag is left alone)
function _set_pose!(s::NeedleState, x::Int64, y::Int64, psi::Int64)
    s.x = x
    s.y = y
    s.psi = psi
    return s
end

# Transition model T(s' | s, a): fills the distribution d in place and returns it.
#
#   mdp     the needle problem (grid size, reward states, transition probabilities)
#   state   current state s
#   action  :ccw (orientation bin + 1) or :cw (orientation bin - 1)
#   d       distribution to overwrite; needs at least five neighbor slots
#
# Slots 1-4 of d.neighbors hold the candidate successors and receive mdp.tprob:
#   1: step along the current heading, keep the heading
#   2: step along the current heading, rotate by one bin
#   3: step along the rotated heading, end in the rotated heading
#   4: step along the rotated heading, rotate by a second bin
# Slot 5 holds the current pose and is used for "stay" / termination.
#
# A done state is absorbing. A state that is a reward state or at the boundary moves to its
# own done copy with probability one. If any candidate successor leaves the grid, the needle
# stays where it is. Throws ArgumentError for an unknown action or an orientation outside
# the 8 supported bins.
function POMDPs.transition(mdp::Needle,
                            state::NeedleState,
                            action::Symbol,
                            d::NeedleDistribution=create_transition_distribution(mdp))
    neighbors = d.neighbors
    probability = d.probs
    fill!(probability, 0.0)

    # let's handle the done case first: can only transition to the same done state
    if state.done
        probability[1] = 1.0
        copy!(neighbors[1], state)
        # when we sample d, we will only get the state in neighbors[1] - our done state
        return d
    end

    x = state.x
    y = state.y
    psi = state.psi

    # unit step for each orientation bin: bin 1 points along +x and each following bin is
    # rotated a further 45 degrees counter-clockwise
    dx = (1, 1, 0, -1, -1, -1, 0, 1)
    dy = (0, 1, 1, 1, 0, -1, -1, -1)
    nbins = length(dx)
    1 <= psi <= nbins || throw(ArgumentError("orientation must be in 1:$nbins, got $psi"))

    if action == :ccw
        turn = 1
    elseif action == :cw
        turn = -1
    else
        throw(ArgumentError("unknown action $action, expected :cw or :ccw"))
    end
    psi_once = mod1(psi + turn, nbins)    # heading after one rotation step, wrapped to 1:8
    psi_twice = mod1(psi + 2*turn, nbins) # heading after two rotation steps

    _set_pose!(neighbors[1], x + dx[psi],      y + dy[psi],      psi)
    _set_pose!(neighbors[2], x + dx[psi],      y + dy[psi],      psi_once)
    _set_pose!(neighbors[3], x + dx[psi_once], y + dy[psi_once], psi_once)
    _set_pose!(neighbors[4], x + dx[psi_once], y + dy[psi_once], psi_twice)
    _set_pose!(neighbors[5], x, y, psi)

    # initialize done states
    for i = 1:5
        neighbors[i].done = false
    end

    # detection of done states: terminate at the boundary or at a target/obstacle.
    # The boundary test does not depend on the reward states, so it is evaluated once
    # and also applies when mdp.reward_states is empty.
    terminal = atbounds(mdp, state)
    if !terminal
        for rs in mdp.reward_states
            if isequal(state, rs)
                terminal = true
                break
            end
        end
    end
    if terminal
        fill_probability!(probability, 1.0, 5)
        neighbors[5].done = true
        return d
    end

    if any(i -> !inbounds(mdp, neighbors[i]), 1:4)
        # at least one of the candidates is outside bounds: stay in the current state
        fill_probability!(probability, 1.0, 5)
    else
        # none of the candidates is outside bounds
        probability[1:4] = mdp.tprob
    end

    return d
end;

## Reward Model (R)

In [None]:
# Reward for leaving `state`: 0.0 for a done state, otherwise the sum of the reward values
# of all reward states equal to `state`, minus a step cost of 1.
# The action and the successor state do not influence the result.
function POMDPs.reward(mdp::Needle, state::NeedleState, action::Symbol, statep::NeedleState) #deleted action
    state.done && return 0.0

    total = 0.0
    for (i, rs) in enumerate(mdp.reward_states)
        if isequal(state, rs) # reward, obstacle and wall states
            total += mdp.reward_values[i]
        end
    end
    return total - 1 # penalty for every step taken
end;


## Miscellaneous Functions

In [None]:
# number of states: every (x, y, psi) pose exists as a live state and as a done state
function POMDPs.n_states(mdp::Needle)
    cells = mdp.size_x * mdp.size_y
    return 2 * cells * mdp.size_psi
end
# two actions: :cw and :ccw
function POMDPs.n_actions(mdp::Needle)
    return 2
end
# discount factor stored on the problem
function POMDPs.discount(mdp::Needle)
    return mdp.discount_factor
end;

In [None]:
# Linear index of `state`, i.e. its position in the vector produced by states(mdp).
# states(mdp) enumerates psi fastest, then x, then y, then the done flag, so the index
# dimensions are listed in exactly that order; a solver that stores values by enumeration
# position and looks them up through state_index then addresses the same state both ways.
function POMDPs.state_index(mdp::Needle, state::NeedleState)
    sd = Int(state.done + 1) # done flag as a 1-based index: false -> 1, true -> 2
    return sub2ind((mdp.size_psi, mdp.size_x, mdp.size_y, 2), state.psi, state.x, state.y, sd)
end;

In [None]:
# Index of an action: :cw -> 1, :ccw -> 2; any other symbol yields nothing.
function POMDPs.action_index(mdp::Needle, action::Symbol)
    action == :cw && return 1
    action == :ccw && return 2
    return nothing
end;

In [None]:
# a state is terminal exactly when its done flag is set
POMDPs.isterminal(mdp::Needle, s::NeedleState) = s.done;

## Value Iteration Solver

In [None]:
using DiscreteValueIteration

# initialize the problem
mdp_vi = Needle()

# initialize the solver
# max_iterations: maximum number of iterations value iteration runs for (default is 100)
# belres: the value of the Bellman residual used in the solver (default is 1e-3)
solver = ValueIterationSolver(max_iterations=500, belres=1e-4)

# initialize the policy by passing in your problem
policy_vi = ValueIterationPolicy(mdp_vi)

# solve for an optimal policy
# if verbose=false, the text output will be suppressed (false by default)
solve(solver, mdp_vi, policy_vi, verbose=true);

### Value iteration policy simulation

In [None]:
# Simulate the value-iteration policy from a fixed start state and plot the trajectory.
s = NeedleState(4,10,7)
hist_vi = HistoryRecorder()

r = simulate(hist_vi, mdp_vi, policy_vi, s)

println("Total discounted reward: $r")

# reward_states starts with the target states followed by the obstacle states, one entry
# per orientation bin, so the first entry of each group gives the cell of that object
target = mdp_vi.reward_states[1]
obstacle = mdp_vi.reward_states[mdp_vi.size_psi + 1]
final_state = hist_vi.state_hist[end]

# the target is a target in every orientation, so only the cell is compared
if final_state.x == target.x && final_state.y == target.y
    println("Target reached")
else
    println("Target missed")
end

# define tissue environment (taken from the problem that was actually simulated)
sx = mdp_vi.size_x
sy = mdp_vi.size_y
plot([1 sx sx 1 1]',[1 1 sy sy 1]',linewidth=10,color="r") # tissue bounds
plot(target.x,target.y,marker="o",markersize=40,color="g",markeredgecolor="none")
plot(obstacle.x,obstacle.y,marker="o",markersize=40,color="r",markeredgecolor="none")

# needle trajectory: one marker per visited state, colored by the action taken there
steps = length(hist_vi.state_hist)
for i = 1:steps-1 
    state = hist_vi.state_hist[i]
    action = hist_vi.action_hist[i]
    if action == :cw
        c = "y"
    else
        c = "b"
    end
    plot(state.x,state.y,color=c,marker="o",markersize=15)
    # arrow along the heading: bin 1 is angle 0 and each bin adds 45 degrees
    quiver(state.x,state.y,0.5*cos((state.psi-1)*pi/4),0.5*sin((state.psi-1)*pi/4))
end

title(@sprintf("Needle tip trajectory (value iteration): reward = %0.2f",r))
axis("equal")
axis([0, sx+1, 0, sy+1])
xlabel("x")
ylabel("y")
grid(true)

In [None]:
# N = 100000
# r_all = ones(1,N)
# for i = 1:N
#     r_all[i] = simulate(hist_vi, mdp, policy_vi, s)
# end
# println(value(policy_vi, s))
# mean(r_all)

## Monte-Carlo Tree Search Solver

In [None]:
using MCTS

# initialize the problem
mdp_MCTS = Needle()

# initialize the solver
# the hyperparameters in MCTS can be tricky to set properly
# n_iterations: the number of iterations that each search runs for
# depth: the depth of the tree (how far away from the current state the algorithm explores)
# exploration_constant: this is how much weight to put into exploratory actions.
# A good rule of thumb is to set the exploration constant to what you expect the upper bound on your average expected reward to be.
solver = MCTSSolver(n_iterations=100, depth=50, exploration_constant=1.0)

# initialize the policy by passing in your problem and the solver
policy_MCTS = MCTSPolicy(solver, mdp_MCTS);

### MCTS policy simulation

In [None]:
# Simulate the MCTS policy from a fixed start state and plot the trajectory.
s = NeedleState(4,10,7)

hist_MCTS = HistoryRecorder()
r = simulate(hist_MCTS, mdp_MCTS, policy_MCTS, s)

println("Total discounted reward: $r")

# reward_states starts with the target states followed by the obstacle states, one entry
# per orientation bin, so the first entry of each group gives the cell of that object
target = mdp_MCTS.reward_states[1]
obstacle = mdp_MCTS.reward_states[mdp_MCTS.size_psi + 1]
final_state = hist_MCTS.state_hist[end]

# the target is a target in every orientation, so only the cell is compared
if final_state.x == target.x && final_state.y == target.y
    println("Target reached")
else
    println("Target missed")
end

# define tissue environment (taken from the problem that was actually simulated)
sx = mdp_MCTS.size_x
sy = mdp_MCTS.size_y
plot([1 sx sx 1 1]',[1 1 sy sy 1]',linewidth=10,color="r") # tissue bounds
plot(target.x,target.y,marker="o",markersize=40,color="g",markeredgecolor="none")
plot(obstacle.x,obstacle.y,marker="o",markersize=40,color="r",markeredgecolor="none")

# needle trajectory: one marker per visited state, colored by the action taken there
steps = length(hist_MCTS.state_hist)
for i = 1:steps-1 
    state = hist_MCTS.state_hist[i]
    action = hist_MCTS.action_hist[i]
    if action == :cw
        c = "y"
    else
        c = "b"
    end
    plot(state.x,state.y,color=c,marker="o",markersize=15)
    # arrow along the heading: bin 1 is angle 0 and each bin adds 45 degrees
    quiver(state.x,state.y,0.5*cos((state.psi-1)*pi/4),0.5*sin((state.psi-1)*pi/4))
end

title(@sprintf("Needle tip trajectory (MCTS): reward = %0.2f",r))
axis("equal")
axis([0, sx+1, 0, sy+1])
xlabel("x")
ylabel("y")
grid(true)