## CHEME 5660 Lab 8: Markov Decision Process (MDP) solution of the Branched Tiger Problem

<img src="./figs/Fig-Branched-MDP-Schematic-no-a-labels.png" style="margin:auto; width:60%"/>

## Introduction
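In this lab, we formulate the branched tiger problem shown in the schematic above as a Markov decision process (MDP) and compute a policy for it. The world has 15 states: state 1 (`safety`) and state 15 (`tiger`) are absorbing end states, and states 2 and 9 are junctions joined by a top hallway (states 3–5) and a bottom hallway (states 6–8). In each state the agent chooses one of four actions (left, right, up, or down); an unblocked move succeeds with probability $\alpha$. We configure the rewards array $R$ and the transition array $T$, check that every row of $T$ sums to one, solve the problem with discount factor $\gamma$, and tabulate the resulting $Q$-values and policy $\pi(s)$.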

## Lab 8 setup

In [1]:
import Pkg; Pkg.activate("."); Pkg.resolve(); Pkg.instantiate();

[32m[1m  Activating[22m[39m project at `~/Desktop/julia_work/CHEME-5660-Markets-Mayhem-Example-Notebooks/labs/lab-8-MDP-Tiger-Problem`
[32m[1m  No Changes[22m[39m to `~/Desktop/julia_work/CHEME-5660-Markets-Mayhem-Example-Notebooks/labs/lab-8-MDP-Tiger-Problem/Project.toml`
[32m[1m  No Changes[22m[39m to `~/Desktop/julia_work/CHEME-5660-Markets-Mayhem-Example-Notebooks/labs/lab-8-MDP-Tiger-Problem/Manifest.toml`


In [2]:
# load req packages -
using PrettyTables

In [3]:
include("CHEME-5660-Lab-8-CodeLib.jl");

In [4]:
# setup some global constants -
α = 0.95; # probability of moving in the direction we expect (the intended direction of the chosen action)

In [5]:
# indices of the two end states; the state list runs from `safety` up to `tiger` -
safety, tiger = 1, 15;

# rewards attached to the two ends (assigned into R when the rewards array is configured) -
left_reward = -100.0;   # a₁ (left) taken from the state next to `safety`
right_reward = 1000.0;  # a₂ (right) taken from the state next to `tiger`

# by-pass (hallway) costs, and the penalty used to mark a move as blocked -
top_hallway_cost = -20.0;
bottom_hallway_cost = -1.0;
blocked = -10000.0;

# state and action index sets, plus the discount factor -
states = collect(safety:tiger);
actions = collect(1:4); # a₁ = move left, a₂ = move right, a₃ = move up, a₄ = move down
γ = 0.95;

# scenario flag: set to true to re-configure R and T for a bottom hallway cave-in -
bottom_hallway_cave_in_flag = false;

#### Configure the rewards array

In [6]:
# Rewards array: R[s,a] is the reward for taking action a in state s.
# Action columns: a₁ = left, a₂ = right, a₃ = up, a₄ = down.
# Every (s,a) pair that is not assigned below keeps a reward of zero.
R = zeros(Float64, length(states), length(actions));

# rewards for the moves onto the two end states -
R[safety + 1, 1] = left_reward;  # state 2, a₁ (left): the move toward `safety`
R[tiger - 1, 2] = right_reward;  # state N - 1, a₂ (right): the move toward `tiger`
# NOTE(review): as configured, left_reward is a penalty and right_reward is a large positive
# reward, i.e., the move toward `tiger` is the one that pays off — confirm this is the intended scenario.

# linear nodes (state 10 up to the state next to `tiger`): up and down are blocked -
list_linear_nodes = collect(10:(tiger - 1))
for i ∈ list_linear_nodes
    R[i, 3:4] .= blocked;
end

# by-pass entry/exit nodes 2 and 9: (state, action, reward) -
for (s, a, r) ∈ (
    (2, 2, blocked), (2, 3, top_hallway_cost), (2, 4, bottom_hallway_cost),
    (9, 1, blocked), (9, 3, top_hallway_cost), (9, 4, bottom_hallway_cost),
)
    R[s, a] = r;
end

# top hallway, nodes 3, 4 and 5: (state, action, reward) -
for (s, a, r) ∈ (
    (3, 1, blocked), (3, 2, top_hallway_cost), (3, 3, blocked),
    (4, 1, top_hallway_cost), (4, 2, top_hallway_cost), (4, 3, blocked), (4, 4, blocked),
    (5, 1, top_hallway_cost), (5, 2, blocked), (5, 3, blocked),
)
    R[s, a] = r;
end

# bottom hallway, nodes 6, 7 and 8: (state, action, reward) -
for (s, a, r) ∈ (
    (6, 1, blocked), (6, 2, bottom_hallway_cost), (6, 4, blocked),
    (7, 1, bottom_hallway_cost), (7, 2, bottom_hallway_cost), (7, 3, blocked), (7, 4, blocked),
    (8, 1, bottom_hallway_cost), (8, 2, blocked), (8, 4, blocked),
)
    R[s, a] = r;
end

#### Configure the transition array $T$

In [7]:
# Setup the transitions
# T[s, s′, a] holds the probability of moving from state s to state s′ when action a is taken.
# Action index: 1 = left, 2 = right, 3 = up, 4 = down.
# NOTE: statement order matters in this cell. The left/right loops write default values for
# every interior state, and the node-specific sections that follow clear and overwrite them.
T = Array{Float64,3}(undef, length(states), length(states), length(actions));
fill!(T,0.0);

# We need to put values into the transition array (these are probabilities, so each row must sum to 1)
T[safety, 1, 1:length(actions)] .= 1.0; # if we are in state 1 (safety), we stay in state 1 ∀a ∈ 𝒜
T[tiger, tiger, 1:length(actions)] .= 1.0; # if we are in the tiger state (state 15), we stay there ∀a ∈ 𝒜

# left actions -
# default: move to s-1 with probability α, otherwise (probability 1-α) move to s+1
for s ∈ 2:(tiger - 1)
    T[s,s-1,1] = α;
    T[s,s+1,1] = (1-α);
end

# right actions -
# default: move to s+1 with probability α, otherwise (probability 1-α) move to s-1
for s ∈ 2:(tiger - 1)
    T[s,s-1,2] = (1-α);
    T[s,s+1,2] = α; 
end

# Node 2 -
# the left action keeps the loop default; right is cleared and replaced by a self-loop
T[2,:,2] .= 0.0
T[2,3,3] = α; # up: 2 → 3 
T[2,6,3] = (1-α); # up, otherwise: 2 → 6
T[2,6,4] = α; # down: 2 → 6
T[2,3,4] = (1-α); # down, otherwise: 2 → 3
T[2,2,2] = 1.0; # node 2 can't go right

# Node 3 -
# clear everything the loops wrote for node 3, then set each action explicitly
T[3,:,:] .= 0.0
T[3,3,1] = 1.0; # left: stay in 3
T[3,4,2] = α; # right: 3 → 4
T[3,3,2] = (1-α); # right, otherwise: stay in 3
T[3,3,3] = 1.0; # up: stay in 3
T[3,2,4] = α; # down: 3 → 2
T[3,3,4] = (1-α); # down, otherwise: stay in 3

# Node 4 -
# left/right keep the loop defaults (toward 3 and 5); up and down are self-loops
T[4,4,3] = 1.0;
T[4,4,4] = 1.0;

# Node 5 -
# clear the loop defaults, then set each action explicitly
T[5,:,:] .= 0.0;
T[5,4,1] = α; # left: 5 → 4
T[5,5,1] = (1-α); # left, otherwise: stay in 5
T[5,5,2] = 1.0; # right: stay in 5
T[5,5,3] = 1.0; # up: stay in 5
T[5,9,4] = α; # down: 5 → 9
T[5,5,4] = (1-α); # down, otherwise: stay in 5

# Node 6 -
# clear the loop defaults, then set each action explicitly
T[6,:,:] .= 0.0;
T[6,6,1] = 1.0; # left: stay in 6
T[6,7,2] = α; # right: 6 → 7
T[6,6,2] = (1-α); # right, otherwise: stay in 6
T[6,2,3] = α; # up: 6 → 2
T[6,6,3] = (1-α); # up, otherwise: stay in 6
T[6,6,4] = 1.0; # down: stay in 6

# Node 7 -
# left/right keep the loop defaults (toward 6 and 8); up and down are self-loops
T[7,7,3] = 1.0;
T[7,7,4] = 1.0;

# Node 8 -
# clear the loop defaults, then set each action explicitly
T[8,:,:] .= 0.0
T[8,7,1] = α; # left: 8 → 7
T[8,8,1] = (1-α); # left, otherwise: stay in 8
T[8,8,2] = 1.0; # right: stay in 8
T[8,9,3] = α; # up: 8 → 9
T[8,8,3] = (1-α); # up, otherwise: stay in 8
T[8,8,4] = 1.0; # down: stay in 8

# Node 9 -
# clear the loop defaults, then set each action explicitly
T[9,:,:] .= 0.0;
T[9,9,1] = 1.0; # left: stay in 9
T[9,10,2] = α; # right: 9 → 10
T[9,9,2] = (1-α); # right, otherwise: stay in 9
T[9,5,3] = α; # up: 9 → 5
T[9,9,3] = (1-α); # up, otherwise: stay in 9
T[9,8,4] = α; # down: 9 → 8
T[9,9,4] = (1-α); # down, otherwise: stay in 9

# Nodes 10 -> end
# left/right keep the loop defaults; up and down are self-loops
for s = 10:(tiger-1)
    T[s,s,3] = 1.0
    T[s,s,4] = 1.0
end

In [8]:
# re-configure the setup to include a bottom hallway cave-in
# (only runs when the scenario flag is set; otherwise R and T are left untouched)
if bottom_hallway_cave_in_flag

    # rewards: right from node 6 and left from node 8 (the moves toward node 7) are now blocked -
    R[6, 2] = blocked;
    R[8, 1] = blocked;

    # transitions: wipe every entry for nodes 6 and 8 before rebuilding them -
    T[[6, 8], :, :] .= 0.0;

    # going up is the only move that still leaves nodes 6 and 8 -
    T[6, 2, 3], T[6, 6, 3] = α, (1 - α); # 6 → 2, otherwise stay in 6
    T[8, 9, 3], T[8, 8, 3] = α, (1 - α); # 8 → 9, otherwise stay in 8

    # left and right are self-loops for both nodes -
    for a ∈ (1, 2)
        T[6, 6, a] = 1.0;
        T[8, 8, a] = 1.0;
    end

    # down is a self-loop for both nodes -
    T[6, 6, 4] = 1.0; # we stay in 6 if we go down
    T[8, 8, 4] = 1.0; # we stay in 8 if we go down
end

In [9]:
# row summation check: for each action a, every row of T[:, :, a] should sum to 1.
# Table layout: column 1 is the state index, column a + 1 holds the row sums for action a.
T_array_check_table = Array{Any,2}(undef, length(states), length(actions) + 1)

# first column: the state index -
T_array_check_table[:, 1] .= 1:length(states);

# remaining columns: row sums of the transition slice for each action -
for a ∈ eachindex(actions)
    T_array_check_table[:, a + 1] .= vec(sum(T[:, :, a], dims=2));
end

# header -
T_check_header = ["State s", "a₁ (left)", "a₂ (right)", "a₃ (up)", "a₄ (down)"];

# display -
pretty_table(T_array_check_table; header=T_check_header)

┌─────────┬───────────┬────────────┬─────────┬───────────┐
│ State s │ a₁ (left) │ a₂ (right) │ a₃ (up) │ a₄ (down) │
├─────────┼───────────┼────────────┼─────────┼───────────┤
│       1 │       1.0 │        1.0 │     1.0 │       1.0 │
│       2 │       1.0 │        1.0 │     1.0 │       1.0 │
│       3 │       1.0 │        1.0 │     1.0 │       1.0 │
│       4 │       1.0 │        1.0 │     1.0 │       1.0 │
│       5 │       1.0 │        1.0 │     1.0 │       1.0 │
│       6 │       1.0 │        1.0 │     1.0 │       1.0 │
│       7 │       1.0 │        1.0 │     1.0 │       1.0 │
│       8 │       1.0 │        1.0 │     1.0 │       1.0 │
│       9 │       1.0 │        1.0 │     1.0 │       1.0 │
│      10 │       1.0 │        1.0 │     1.0 │       1.0 │
│      11 │       1.0 │        1.0 │     1.0 │       1.0 │
│      12 │       1.0 │        1.0 │     1.0 │       1.0 │
│      13 │       1.0 │        1.0 │     1.0 │       1.0 │
│      14 │     

In [10]:
# build the MDP problem instance from the states, actions, transition array T, rewards array R and discount factor γ -
# NOTE(review): `build` and `MDPProblem` are not defined in this notebook; presumably they come from CHEME-5660-Lab-8-CodeLib.jl — confirm
mdp_problem = build(MDPProblem; 𝒮 = states, 𝒜 = actions, T = T, R = R, γ = γ);

In [11]:
# solve the MDP problem; the result U is passed to Q(...) in the next cell -
# NOTE(review): the second argument (1000) is presumably an iteration limit for the solver — confirm against the definition of `solve`
U = solve(mdp_problem,1000);

In [12]:
# compute the Q array -
# Q_array is indexed as Q_array[s,a] (state, action) when the Q-table is built below
Q_array = Q(mdp_problem, U)

# compute the policy -
# policy is indexed by state (policy[s]) when the Q-table is built below
# NOTE(review): `π` is called as a function here, so it is presumably (re)defined in CHEME-5660-Lab-8-CodeLib.jl rather than being the constant π — confirm
policy = π(Q_array);

In [13]:
# make a Q-table: one row per state, holding
# [state index, direction label, policy index π(s), then one Q_array value per action]
Q_table_data_array = Array{Any,2}(undef, length(states), length(actions) + 3)

for s ∈ eachindex(states)

    # translate the policy index into a label; any index other than 0, 2, 3 or 4 is shown as "left"
    policy_index = policy[s];
    direction = policy_index == 2 ? "right" :
                policy_index == 3 ? "up" :
                policy_index == 4 ? "down" :
                policy_index == 0 ? "stop" : "left";

    # leading columns: state, label and raw policy index -
    Q_table_data_array[s, 1] = s;
    Q_table_data_array[s, 2] = direction;
    Q_table_data_array[s, 3] = policy_index;

    # trailing columns: the Q_array entry for each action in this state -
    for a ∈ eachindex(actions)
        Q_table_data_array[s, a + 3] = Q_array[s, a];
    end
end

# header -
Q_table_header = ["State", "Direction", "π(s)", "U(a₁) l", "U(a₂) r", "U(a₃) u", "U(a₄) d"]

# show -
pretty_table(Q_table_data_array; header = Q_table_header)

┌───────┬───────────┬──────┬──────────┬──────────┬──────────┬──────────┐
│ State │ Direction │ π(s) │  U(a₁) l │  U(a₂) r │  U(a₃) u │  U(a₄) d │
├───────┼───────────┼──────┼──────────┼──────────┼──────────┼──────────┤
│     1 │      stop │    0 │      0.0 │      0.0 │      0.0 │      0.0 │
│     2 │      down │    4 │ -47.5291 │ -8950.58 │  1029.47 │  1049.42 │
│     3 │      down │    4 │ -8950.58 │  1013.32 │ -8950.58 │  1049.42 │
│     4 │     right │    2 │  1029.58 │  1032.47 │ -8967.53 │ -8967.53 │
│     5 │      down │    4 │  1013.48 │ -8947.37 │ -8947.37 │  1052.63 │
│     6 │     right │    2 │ -8949.53 │  1050.47 │  1049.47 │ -8949.53 │
│     7 │     right │    2 │  1049.58 │  1051.52 │ -8948.48 │ -8948.48 │
│     8 │        up │    3 │  1050.58 │ -8947.37 │  1052.63 │ -8947.37 │
│     9 │     right │    2 │ -8947.37 │  1052.63 │  1032.63 │  1051.63 │
│    10 │      left │    1 │  1052.63 │  1052.63 │ -8947.37 │ -8947.

<img src="./figs/Fig-Branched-MDP-Schematic-no-a-labels.png" style="width:50%; align:left"/>