# Hex world problem

A hex world problem is a tuple of the following attributes:
1. hexes: all tiles of the board; each tile is represented as a pair of integers (its coordinate on the board).
2. reward hexes: a subset of hexes containing the tiles that carry a reward, together with the corresponding reward value. These tiles lead to the terminal state.
3. reward border: the reward the agent receives when it hits the wall while moving (usually negative, as a penalty).
4. p_intended: when moving in direction $i$, the agent moves as intended with probability p_intended and veers to each of the two neighboring directions, $((i - 2) \bmod 6) + 1$ and $(i \bmod 6) + 1$, with probability (1 - p_intended)/2.
5. $\gamma$: the discount factor used by the lookahead function.

In [4]:
struct HexWorld
    # each state is the coordinate (x, y) of a hex tile
    hexes::Vector{Tuple{Int,Int}}
    # tiles that carry a reward, mapped to the reward value
    reward_hexes::Dict{Tuple{Int, Int}, Float64}
    # reward received when hitting the wall
    reward_border::Float64
    # probability of moving in the chosen direction
    p_intended::Float64
    # discount factor
    γ::Float64
end

## Example Hex World

Here are two example hex worlds from the algorithms book.

In [5]:
const REWARD_BORDER = -1.0 # Reward for falling off hex map
const P_INTENDED = 0.7 # Probability of going intended direction
const DISCOUNT_FACTOR = 0.9

const general_hexworld = HexWorld(
    [(1,0), (2,0), (3,0), (4,0), (5,0), (6,0), (7,0), (8,0), (9,0), (10,0),
    (1,1), (2,1), (3,1), (5,1), (8,1), (9,1),
    (0,2), (1,2), (2,2), (3,2), (5,2), (6,2), (7,2), (9,2)],
    Dict{Tuple{Int,Int}, Float64}(
        (1,1) =>  5.0,
        (2,2) => -10.0,
        (9,2) =>  10.0,
    ),
    REWARD_BORDER,
    P_INTENDED,
    DISCOUNT_FACTOR
)

const straight_line_hexworld = HexWorld(
    [(0,0), (1,0), (2,0), (3,0), (4,0), (5,0), (6,0)],
    Dict{Tuple{Int,Int}, Float64}(
        (6,0) => 10.0, # right side reward
    ),
    REWARD_BORDER,
    P_INTENDED,
    DISCOUNT_FACTOR
)



HexWorld([(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)], Dict((6, 0) => 10.0), -1.0, 0.7, 0.9)

## State Space

Each hex tile is considered a state, plus one additional terminal state.
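
To make the state indexing concrete, the sketch below numbers the tiles of a small board and reserves one extra index for the terminal state. The names `tiles` and `terminal_state` are illustrative assumptions. The solver later in this notebook builds the same kind of `hex2state` dictionary, but it treats reward hexes as absorbing instead of allocating a separate terminal index.

```julia
# Illustrative board: three tiles in a row (not one of the book's examples)
tiles = [(0, 0), (1, 0), (2, 0)]

# Map each tile coordinate to a 1-based state index
hex2state = Dict(hex => s for (s, hex) in enumerate(tiles))

# Assumed convention: the terminal state takes the next free index
terminal_state = length(tiles) + 1

println(hex2state[(1, 0)])  # state index of tile (1, 0)
println(terminal_state)     # index reserved for the terminal state
```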

## Action Space

From each hex tile the agent can move in six directions, numbered from 1 to 6 and stored in the `hex_direction` array.
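
As a small illustration of how a direction index is applied, the sketch below adds a direction offset to a tile coordinate with broadcasting. The helper name `neighbor` is hypothetical, and the offsets are copied from the notebook's `hex_direction` array under a different name so that the snippet is self-contained.

```julia
# Copy of the notebook's direction offsets (renamed to avoid redefining the const)
directions = [(+1, 0), (+1, -1), (0, -1), (-1, 0), (-1, +1), (0, +1)]

# Hypothetical helper: coordinate reached from `hex` by moving in direction `a`.
# Broadcasting over two tuples returns a tuple.
neighbor(hex::Tuple{Int,Int}, a::Int) = hex .+ directions[a]

println(neighbor((1, 1), 1))  # one step along +x
println(neighbor((1, 1), 4))  # one step along -x
```

Whether the resulting coordinate is actually on the board still has to be checked against the list of hexes.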

In [6]:
# offsets of the six valid directions for a hex tile
const hex_direction = [(+1, 0), (+1, -1), (0, -1), (-1, 0), (-1, +1), (0, +1)]
const direction2char = [">", "^>", "<^", "<", "<v", "v>"]



6-element Vector{String}:
 ">"
 "^>"
 "<^"
 "<"
 "<v"
 "v>"

## Reward Function

For a state that has been assigned a reward, $R(s)$ is that assigned value; for every other state the reward is $0$. In addition, a move that hits the wall yields the border reward.
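
A minimal sketch of the per-tile reward lookup, assuming the reward table of `general_hexworld`. The names `rewards` and `R` are illustrative and do not appear elsewhere in the notebook; the border reward is not covered here.

```julia
# Reward table copied from `general_hexworld`
rewards = Dict{Tuple{Int,Int},Float64}(
    (1, 1) => 5.0,
    (2, 2) => -10.0,
    (9, 2) => 10.0,
)

# Hypothetical helper: the assigned reward for a reward hex, 0 for any other tile
R(hex::Tuple{Int,Int}) = get(rewards, hex, 0.0)

println(R((1, 1)))  # a reward hex
println(R((3, 0)))  # an ordinary tile
```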

## Transition Function
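
Based on the `p_intended` description in the problem definition and the `mod1` calls in `bellman_equation`, the transition model can be sketched as follows. The function name `move_distribution` and the use of `nothing` for a move that leaves the board are assumptions made for illustration only.

```julia
# Copy of the notebook's direction offsets (renamed to avoid redefining the const)
directions = [(+1, 0), (+1, -1), (0, -1), (-1, 0), (-1, +1), (0, +1)]

# Hypothetical helper: list of (next_hex, probability) pairs for action `a`.
# `next_hex` is `nothing` when the move would leave the board.
function move_distribution(hexes::Vector{Tuple{Int,Int}}, hex::Tuple{Int,Int},
                           a::Int, p_intended::Float64)
    n = length(directions)
    p_veer = (1.0 - p_intended) / 2
    # intended direction plus its two neighbors, wrapping around with mod1
    outcomes = [(a, p_intended), (mod1(a - 1, n), p_veer), (mod1(a + 1, n), p_veer)]
    result = Tuple{Union{Nothing,Tuple{Int,Int}},Float64}[]
    for (d, p) in outcomes
        target = hex .+ directions[d]
        push!(result, (target in hexes ? target : nothing, p))
    end
    return result
end

line = [(0, 0), (1, 0), (2, 0)]
for (next_hex, p) in move_distribution(line, (1, 0), 1, 0.7)
    println(next_hex, " with probability ", p)
end
```

On this three-tile line, only the intended move to the right stays on the board; both veering outcomes leave it.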

## Value Iteration

Value iteration is an alternative to policy iteration that is often used because of its simplicity. Unlike policy improvement, value iteration updates the value function directly.

The value function can be improved by applying the Bellman equation:
$$U_{k+1}(s) = \max_a\left(R(s,a) + \gamma \sum_{s'}T(s'|s,a)U_k(s')\right)$$
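
To illustrate the update on its own, here is a sketch of one Bellman backup for a generic tabular MDP. The function `bellman_backup`, the array layout `T[s, a, s']`, and the two-state example are assumptions for illustration; the hex world implementation below computes the same maximization without building `T` explicitly.

```julia
# One Bellman backup for a tabular MDP.
# R[s, a] is the reward, T[s, a, sp] the probability of reaching sp from s under a.
function bellman_backup(R::Matrix{Float64}, T::Array{Float64,3},
                        U::Vector{Float64}, γ::Float64)
    nS, nA = size(R)
    new_U = zeros(nS)
    for s in 1:nS
        new_U[s] = maximum(
            R[s, a] + γ * sum(T[s, a, sp] * U[sp] for sp in 1:nS) for a in 1:nA
        )
    end
    return new_U
end

# Toy MDP: in state 1, action 1 moves to state 2 and action 2 stays (reward 1).
# State 2 is absorbing under both actions.
R = [0.0 1.0; 0.0 0.0]
T = zeros(2, 2, 2)
T[1, 1, 2] = 1.0
T[1, 2, 1] = 1.0
T[2, 1, 2] = 1.0
T[2, 2, 2] = 1.0

U = [0.0, 10.0]
println(bellman_backup(R, T, U, 0.9))
```

In state 1 the backup compares $0 + 0.9 \cdot 10$ (move to state 2) with $1 + 0.9 \cdot 0$ (stay), so moving wins.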

In [7]:
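# Contribution of one movement outcome (direction `a`, taken with probability `p`)
# to the Q-value of `hex`, given the current value estimates `U`.
function get_Qvalue(hex::Tuple{Int,Int}, 
                    a::Int, 
                    p::Float64,
                    hex_world::HexWorld,
                    hex2state::Dict{Tuple{Int, Int}, Int},
                    U::Vector{Float64})
    # broadcasting adds the offset element-wise, so no `+` method on tuples is needed
    next_hex = hex .+ hex_direction[a]
    if !haskey(hex2state, next_hex)
        # the move leaves the board: collect the border reward
        return p * hex_world.reward_border
    else
        # the move stays on the board: discounted value of the destination tile
        new_s = hex2state[next_hex]
        return hex_world.γ * p * U[new_s]
    end
end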

function bellman_equation(hex_world::HexWorld, 
                            U::Vector{Float64}, 
                            hex2state::Dict{Tuple{Int, Int}, Int})
    nS = length(hex_world.hexes)
    nA = length(hex_direction)
    
    Q = zeros(nS, nA)
    p_veer = (1.0 - hex_world.p_intended) / 2
    p_intended = hex_world.p_intended
    for hex in hex_world.hexes
        s = hex2state[hex]
        for a = 1:nA
            if haskey(hex_world.reward_hexes, hex)
                Q[s, a] = hex_world.reward_hexes[hex]
            else
                Q[s,a] += get_Qvalue(hex, a, p_intended, hex_world, hex2state, U)
                Q[s,a] += get_Qvalue(hex, mod1(a - 1, nA), p_veer, hex_world, hex2state, U)
                Q[s,a] += get_Qvalue(hex, mod1(a + 1, nA), p_veer, hex_world, hex2state, U)
            end
        end
    end
    
    # best value and the index of the best action for each state
    new_U, new_policy = findmax(Q, dims = 2)

    return vec(new_U), vec(map(x -> x[2], new_policy))
end

bellman_equation (generic function with 1 method)

In [8]:
using LinearAlgebra

# Convergence test: true when the 1-norm of the change in U is at most `threshold`
function is_terminal(U::Vector{Float64}, new_U::Vector{Float64}, threshold::Float64)
    # println(LinearAlgebra.norm(U - new_U, 1))
    return LinearAlgebra.norm(U - new_U, 1) <= threshold
end

is_terminal (generic function with 1 method)
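
The stopping rule above compares successive value vectors in the 1-norm. The sketch below, with made-up vectors `U_old` and `U_new`, shows how the outcome of the test depends on the chosen threshold.

```julia
using LinearAlgebra

# Made-up value vectors that differ in a single entry by about 5e-4
U_old = [0.0, 1.0, 2.0]
U_new = [0.0, 1.0005, 2.0]

change = norm(U_old - U_new, 1)  # sum of absolute differences

println(change <= 1e-3)  # a loose threshold accepts this change
println(change <= 1e-4)  # a tighter threshold does not
```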

In [9]:
function get_optimal_policy(hex_world::HexWorld, max_step::Int, threshold::Float64)
    nS = length(hex_world.hexes)
    U = zeros(nS)
    policy = zeros(Int, nS)
    # map each hex coordinate to its state index
    hex2state = Dict{Tuple{Int, Int}, Int}()
    for (s, hex) in enumerate(hex_world.hexes)
        hex2state[hex] = s
    end
    
    for step = 1 : max_step
        new_U, policy = bellman_equation(hex_world, U, hex2state)
        # println(step, ": ", map(x -> direction2char[x], policy))

        # stop once the value function changes by no more than `threshold`
        if is_terminal(U, new_U, threshold)
            println(step) # number of iterations until convergence
            break
        end

        U = new_U
    end

    return policy
end

get_optimal_policy (generic function with 1 method)

In [10]:
optimal_policy = get_optimal_policy(general_hexworld, 10000, 0.0)

for (i, direction) in enumerate(optimal_policy)
    println(general_hexworld.hexes[i], " ", direction2char[direction])
end

# optimal_policy = get_optimal_policy(straight_line_hexworld, 100000, 0.0)
# for (i, direction) in enumerate(optimal_policy)
#     println(straight_line_hexworld.hexes[i], " ", direction2char[direction])
# end

45
(1, 0) v>
(2, 0) <v
(3, 0) <v
(4, 0) <
(5, 0) <
(6, 0) >
(7, 0) >
(8, 0) >
(9, 0) v>
(10, 0) <v
(1, 1) >
(2, 1) <
(3, 1) <^
(5, 1) <^
(8, 1) ^>
(9, 1) v>
(0, 2) ^>
(1, 2) <^
(2, 2) >
(3, 2) <^
(5, 2) >
(6, 2) >
(7, 2) ^>
(9, 2) >
