# Lero's Quest
## Value Iteration Algorithm

In [112]:
import numpy as np
from copy import deepcopy
from functools import reduce
from operator import add

In [113]:
HEALTH_RANGE = 5
ARROWS_RANGE = 4
STAMINA_RANGE = 3

HEALTH_VALUES = tuple(range(HEALTH_RANGE))
ARROWS_VALUES = tuple(range(ARROWS_RANGE))
STAMINA_VALUES = tuple(range(STAMINA_RANGE))

HEALTH_FACTOR = 25 # 0, 25, 50, 75, 100
ARROWS_FACTOR = 1 # 0, 1, 2, 3
STAMINA_FACTOR = 50 # 0, 50, 100

NUM_ACTIONS = 3
ACTION_SHOOT = 0
ACTION_DODGE = 1
ACTION_RECHARGE = 2

PRIZE = 10
COST = -10/85

GAMMA = 0.99
DELTA = 1e-10

In [114]:
class State:
    def __init__(self, enemy_health, num_arrows, stamina):
        if (enemy_health not in HEALTH_VALUES) or (num_arrows not in ARROWS_VALUES) or (stamina not in STAMINA_VALUES):
            raise ValueError
        
        self.health = enemy_health 
        self.arrows = num_arrows 
        self.stamina = stamina 

    def show(self):
        return (self.health, self.arrows, self.stamina)

    def get_index(self):
        return ((ARROWS_RANGE + STAMINA_RANGE) * self.health +
                STAMINA_RANGE * self.arrows +
                self.stamina)
    
    def is_final(self):
        return (self.health == 0)

    def __str__(self):
        return f'({self.health * HEALTH_FACTOR}, {self.arrows * ARROWS_FACTOR}, {self.stamina * STAMINA_FACTOR})'

    @classmethod
    def from_index(cls, index):
        if index not in range(60):
            raise ValueError

        enemy_health = index // (ARROWS_RANGE + STAMINA_RANGE)
        index = index % (ARROWS_RANGE + STAMINA_RANGE)

        num_arrows = index // STAMINA_RANGE
        index = index % STAMINA_RANGE

        stamina = index

        return State(enemy_health, num_arrows, stamina)


In [115]:
REWARD = np.zeros((HEALTH_RANGE, ARROWS_RANGE, STAMINA_RANGE))
REWARD[0, :, :] = PRIZE

In [116]:
def action(action_type, state):
    # returns cost, array of tuple of (probability, state)
    print(state)
    state = State(*state)

    if action_type == ACTION_SHOOT:
        if state.arrows == 0 or state.stamina == 0:
            return None, None

        new_arrows = state.arrows - 1
        new_stamina = state.stamina - 1

        choices = []
        choices.append((0.5, State(max(HEALTH_VALUES[0],state.health-1), new_arrows, new_stamina)))
        choices.append((0.5, State(state.health, new_arrows, new_stamina)))

    elif action_type == ACTION_RECHARGE:
        choices = []
        choices.append((0.8, State(state.health, state.arrows, min(STAMINA_VALUES[-1], state.stamina+1))))
        choices.append((0.2, State(state.health, state.arrows, state.stamina)))

    elif action_type == ACTION_DODGE:
        if state.stamina == 0:
            return None, None

        if state.stamina == 2: #if stamina is 100  
            choices = []
            choices.append((0.64, State(state.health,min(ARROWS_VALUES[-1], state.arrows),state.stamina -1)))
            choices.append((0.16, State(state.health,state.arrows,STAMINA_VALUES[1])))
            choices.append((0.04, State(state.health, state.arrows,STAMINA_VALUES[0])))
            choices.append((0.16, State(state.health, min(ARROWS_VALUES[-1] ,state.arrows), STAMINA_VALUES[0])))

        elif state.stamina == 1: # if stamina is 50
            choices = []
            choices.append((0.2, State(state.health,state.arrows,0)))
            choices.append((0.8, State(state.health,min(state.arrows+1,ARROWS_VALUES[-1]),0)))

    cost = 0
    for choice in choices:
        cost += choice[0] * (COST + REWARD[choice[1].show()])
        
    return cost, choices


In [117]:
def show(i, utilities, policies):
    print(f'iteration={i}')

    for state, util in np.ndenumerate(utilities):
        if policies[state] == ACTION_SHOOT:
            act_str = 'SHOOT'
        elif policies[state] == ACTION_DODGE:
            act_str = 'DODGE'
        elif policies[state] == ACTION_RECHARGE:
            act_str = 'RECHARGE'
        
        util_str = '{:.3f}'.format(util)
        print(f'{state}:{act_str}=[{util_str}]')
        print('')

In [118]:
def value_iteration():
    utilities = np.zeros((HEALTH_RANGE, ARROWS_RANGE, STAMINA_RANGE), dtype='double')
    policies = np.full((HEALTH_RANGE, ARROWS_RANGE, STAMINA_RANGE), -1, dtype='int')

    index = 0
    while True: # one iteration of value iteration
        temp = np.zeros(utilities.shape, dtype='double')
        delta = 0
        
        for state, util in np.ndenumerate(utilities):
            new_util = np.NINF

            for act_index in range(NUM_ACTIONS):
                cost, states = action(act_index, state)
                
                if cost is None:
                    continue

                expected_util = reduce(add, map(lambda x: x[0]*utilities[x[1].show()], states))
                new_util = max(new_util, cost + GAMMA * expected_util)
            
            temp[state] = new_util
            delta = max(delta, abs(util - new_util))
        
        utilities = deepcopy(temp)

        for state, _ in np.ndenumerate(utilities):
            best_util = np.NINF
            best_action = None

            for act_index in range(NUM_ACTIONS):
                states = action(act_index, state)[1]

                if states is None:
                    continue

                action_util = reduce(add, map(lambda x: x[0]*utilities[x[1].show()], states))

                if action_util > best_util:
                    best_action = act_index
                    best_util = action_util

            policies[state] = best_action

        show(index, utilities, policies)
        index +=1
        if delta < DELTA:
            break
        
    return utilities, policies

In [119]:
value_iteration()

, 1, 1)
(3, 1, 2)
(3, 1, 2)
(3, 1, 2)
(3, 2, 0)
(3, 2, 0)
(3, 2, 0)
(3, 2, 1)
(3, 2, 1)
(3, 2, 1)
(3, 2, 2)
(3, 2, 2)
(3, 2, 2)
(3, 3, 0)
(3, 3, 0)
(3, 3, 0)
(3, 3, 1)
(3, 3, 1)
(3, 3, 1)
(3, 3, 2)
(3, 3, 2)
(3, 3, 2)
(4, 0, 0)
(4, 0, 0)
(4, 0, 0)
(4, 0, 1)
(4, 0, 1)
(4, 0, 1)
(4, 0, 2)
(4, 0, 2)
(4, 0, 2)
(4, 1, 0)
(4, 1, 0)
(4, 1, 0)
(4, 1, 1)
(4, 1, 1)
(4, 1, 1)
(4, 1, 2)
(4, 1, 2)
(4, 1, 2)
(4, 2, 0)
(4, 2, 0)
(4, 2, 0)
(4, 2, 1)
(4, 2, 1)
(4, 2, 1)
(4, 2, 2)
(4, 2, 2)
(4, 2, 2)
(4, 3, 0)
(4, 3, 0)
(4, 3, 0)
(4, 3, 1)
(4, 3, 1)
(4, 3, 1)
(4, 3, 2)
(4, 3, 2)
(4, 3, 2)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 1)
(0, 0, 1)
(0, 0, 1)
(0, 0, 2)
(0, 0, 2)
(0, 0, 2)
(0, 1, 0)
(0, 1, 0)
(0, 1, 0)
(0, 1, 1)
(0, 1, 1)
(0, 1, 1)
(0, 1, 2)
(0, 1, 2)
(0, 1, 2)
(0, 2, 0)
(0, 2, 0)
(0, 2, 0)
(0, 2, 1)
(0, 2, 1)
(0, 2, 1)
(0, 2, 2)
(0, 2, 2)
(0, 2, 2)
(0, 3, 0)
(0, 3, 0)
(0, 3, 0)
(0, 3, 1)
(0, 3, 1)
(0, 3, 1)
(0, 3, 2)
(0, 3, 2)
(0, 3, 2)
(1, 0, 0)
(1, 0, 0)
(1, 0, 0)
(1, 0, 1)
(1, 0, 1)
(1, 0, 1)
(1

(array([[[988.23529411, 988.23529411, 988.23529411],
         [988.23529411, 988.23529411, 988.23529411],
         [988.23529411, 988.23529411, 988.23529411],
         [988.23529411, 988.23529411, 988.23529411]],
 
        [[903.05548989, 914.60624994, 903.05548989],
         [929.19054293, 941.07129102, 946.78891725],
         [941.96608681, 954.00814227, 959.88911258],
         [948.21112948, 960.3320365 , 966.29285395]],
 
        [[816.76232482, 827.2235247 , 816.76232482],
         [840.43211041, 851.19217122, 862.08809139],
         [864.77810659, 875.84556634, 887.05276675],
         [882.92415218, 894.22072868, 905.65993871]],
 
        [[738.60902195, 748.08343771, 738.60902195],
         [760.04608386, 769.79116959, 779.65929933],
         [782.09557076, 792.11905911, 802.26910664],
         [804.77497873, 815.08482323, 825.52484254]],
 
        [[667.82776432, 676.40847733, 667.82776432],
         [687.24271092, 696.06856214, 705.00585088],
         [707.21231369, 716.290306