Partie en cours de construction : la lecture de ce chapitre a été faite, 2 problèmes vont être implémentés avec la fonction valeur et la Q-value.

Or la résolution de ce problème est réalisable uniquement lorsque le chapitre 4 a été étudié, traitant de la programmation dynamique.

In [None]:
from typing import List, Tuple, Union

import torch

In [None]:
class GridWorld:
    """Rectangular grid environment with special reward/teleportation cells.

    Positions are ``(row, column)`` tuples. Four actions are available
    (``'top'``, ``'bottom'``, ``'left'``, ``'right'``), each with probability
    1/4 in ``actions_proba``. Every entry of ``reward_coordinates`` is paired,
    by index, with a reward in ``reward_values`` and a destination in
    ``teleportation_coordinates``.
    """

    def __init__(self, size: Union[int, Tuple[int, int]], reward_coordinates: List[Tuple[int, int]], reward_values: List[int], teleportation_coordinates: List[Tuple[int, int]]) -> None:
        """Store the grid description and place the agent at ``(0, 0)``.

        Args:
            size: Side length of a square grid, or ``(rows, columns)``.
            reward_coordinates: Cells that yield a reward and teleport.
            reward_values: Reward of each special cell (same order).
            teleportation_coordinates: Destination of each special cell
                (same order).

        Raises:
            ValueError: If the three lists do not have the same length.
        """
        # Explicit exceptions instead of ``assert``: assertions are removed
        # when Python runs with -O, which would silently skip the validation.
        if len(reward_coordinates) != len(reward_values):
            raise ValueError('reward_coordinates and reward_values must have the same length')
        if len(reward_coordinates) != len(teleportation_coordinates):
            raise ValueError('reward_coordinates and teleportation_coordinates must have the same length')

        self.size = (size, size) if isinstance(size, int) else size
        self.reward_coordinates = reward_coordinates
        self.reward_values = reward_values
        self.teleportation_coordinates = teleportation_coordinates

        self.current_position = (0, 0)

        self.actions = {'top', 'bottom', 'left', 'right'}
        self.actions_proba = {'top': 1/4, 'bottom': 1/4, 'left': 1/4, 'right': 1/4}
        # (row offset, column offset) applied by each action.
        self._compute_actions = {
            'top': (-1, 0),
            'bottom': (1, 0),
            'left': (0, -1),
            'right': (0, 1)
        }

    def _validate_action(self, action: str) -> None:
        """Raise ``ValueError`` when ``action`` is not an available action."""
        if action not in self.actions:
            raise ValueError(f'{action} is not in available actions : {self.actions}')

    def _move(self, position: Tuple[int, int], action: str) -> Tuple[int, int]:
        """Return ``position`` shifted by the offset of ``action`` (no bounds check)."""
        row_offset, column_offset = self._compute_actions[action]
        return (position[0] + row_offset, position[1] + column_offset)

    def is_out_of_grid(self, position: Tuple[int, int]) -> bool:
        """Return ``True`` when ``position`` lies outside the grid bounds."""
        row, column = position[0], position[1]
        return (
            row < 0
            or row > self.size[0] - 1
            or column < 0
            or column > self.size[1] - 1
        )

    def _is_teleportation_state(self, position: Tuple[int, int]) -> bool:
        """Return ``True`` when ``position`` is one of the special cells."""
        return position in self.reward_coordinates

    def _get_teleport_state(self, position: Tuple[int, int]) -> Tuple[int, int]:
        """Return the destination paired with the special cell ``position``."""
        teleport_idx = self.reward_coordinates.index(position)
        return self.teleportation_coordinates[teleport_idx]

    def choose_action_from_position(self, action: str, position: Tuple[int, int]) -> Tuple[int, int]:
        """Return the position reached by taking ``action`` from ``position``.

        From a special cell every action leads to its teleport destination.
        Otherwise the raw shifted position is returned, which may be outside
        the grid (callers check with ``is_out_of_grid``).

        Raises:
            ValueError: If ``action`` is not an available action.
        """
        self._validate_action(action)

        if self._is_teleportation_state(position):
            return self._get_teleport_state(position)
        return self._move(position, action)

    def choose_action(self, action: str) -> Tuple[int, int]:
        """Return the raw position reached by ``action`` from the current position.

        No teleportation or bounds handling is applied here.

        Raises:
            ValueError: If ``action`` is not an available action.
        """
        self._validate_action(action)
        return self._move(self.current_position, action)

    def execute_action(self, action: str) -> None:
        """Apply ``action`` to the current position.

        Entering a special cell moves the agent to its teleport destination;
        a move that would leave the grid keeps the agent where it is.
        """
        # NOTE(review): teleportation is triggered here when *entering* a
        # special cell, whereas choose_action_from_position triggers it when
        # *leaving* one -- confirm which of the two dynamics is intended.
        new_position = self.choose_action(action)
        if self._is_teleportation_state(new_position):
            self.current_position = self._get_teleport_state(new_position)
        elif not self.is_out_of_grid(new_position):
            self.current_position = new_position

    def choose_action_get_reward(self, action: str, from_position: Tuple[int, int]) -> float:
        """Return the reward for taking ``action`` from ``from_position``.

        The transition is computed from ``from_position`` itself (not from the
        agent's current position), so the reward matches the queried state.

        Raises:
            ValueError: If ``action`` is not an available action.
        """
        new_position = self.choose_action_from_position(action, from_position)
        return self.get_state_reward(new_position, from_position)

    def get_state_reward(self, position: Tuple[int, int], from_position: Tuple[int, int]) -> float:
        """Return the reward of the transition ``from_position`` -> ``position``.

        Returns ``-1`` when ``position`` is outside the grid, the paired
        reward value when ``from_position`` is a special cell, ``0`` otherwise.
        """
        if self.is_out_of_grid(position):
            return -1
        if from_position in self.reward_coordinates:
            reward_idx = self.reward_coordinates.index(from_position)
            return self.reward_values[reward_idx]
        return 0

    def get_current_state(self) -> Tuple[int, int]:
        """Return the agent's current ``(row, column)`` position."""
        return self.current_position

In [None]:
class GridWorldAgent:
    """Agent that evaluates a grid world by iterative Bellman backups.

    The state values are stored in ``states_values`` (one entry per cell) and
    are computed under the action probabilities exposed by the grid world.
    """

    def __init__(self, gridworld: "GridWorld", method: str, discount: float = 0.1, threshold: Union[float, None] = None, max_iterations: int = 10000) -> None:
        """Store the configuration and initialise all state values to zero.

        Args:
            gridworld: Environment providing ``size``, ``actions``,
                ``actions_proba``, ``choose_action_from_position``,
                ``get_state_reward`` and ``is_out_of_grid``.
            method: Name of the resolution method (stored; ``build_policy``
                currently always uses the value function).
            discount: Discount factor applied to the next state's value.
            threshold: Stop once the L1 distance between two successive
                sweeps is at or below this value; ``None`` disables the test.
            max_iterations: Maximum number of sweeps over the grid.
        """
        self.gridworld = gridworld
        self.method = method
        self.states_values = torch.zeros(size=gridworld.size)

        self.discount = discount
        self.threshold = threshold
        self.max_iterations = max_iterations

    def _continue_search(self, old_values: torch.Tensor, new_values: torch.Tensor, iteration: int) -> bool:
        """Return ``True`` while another sweep should be performed.

        Args:
            old_values: State values before the last sweep.
            new_values: State values after the last sweep.
            iteration: Number of sweeps already performed.
        """
        if iteration >= self.max_iterations:
            return False
        if self.threshold is None:
            return True
        # Keep iterating only while successive sweeps still differ by more
        # than the threshold (L1 norm over all cells).
        delta = torch.linalg.norm((old_values - new_values).flatten(), ord=1)
        return bool(delta > self.threshold)

    def _next_state_value(self, value_function: torch.Tensor, position: Tuple[int, int], next_position: Tuple[int, int]) -> torch.Tensor:
        """Return the value of the cell the agent ends up in after a move.

        A move that leaves the grid keeps the agent on ``position`` (this is
        what ``GridWorld.execute_action`` does), so its value is used instead
        of indexing the tensor with an off-grid coordinate, which would wrap
        around on negative indices or raise ``IndexError``.
        """
        if self.gridworld.is_out_of_grid(next_position):
            return value_function[tuple(position)]
        return value_function[tuple(next_position)]

    def _update_value_function(self, value_function: torch.Tensor, row_index: int, column_index: int) -> float:
        """Return the backed-up value of cell ``(row_index, column_index)``.

        Computes the expectation over actions of
        ``reward + discount * value(next state)`` using ``value_function``.
        """
        position = (row_index, column_index)
        value = 0
        for action in self.gridworld.actions:
            action_probability = self.gridworld.actions_proba[action]
            next_position = self.gridworld.choose_action_from_position(action, position)
            reward = self.gridworld.get_state_reward(next_position, position)
            next_value = self._next_state_value(value_function, position, next_position)

            value += action_probability * (reward + self.discount * next_value)

        return value

    def _build_policy_by_value_function(self) -> None:
        """Sweep the grid until ``_continue_search`` says to stop.

        Each sweep recomputes every cell from the values of the previous
        sweep; the result is stored in ``states_values``.
        """
        new_values = self.states_values.clone()
        # Infinite "previous" values guarantee the first convergence test fails.
        old_values = torch.full_like(new_values, float('inf'))
        row_count, column_count = new_values.shape[0], new_values.shape[1]

        # Counts completed sweeps, so exactly max_iterations sweeps can run.
        iteration = 0
        while self._continue_search(old_values, new_values, iteration):
            old_values = new_values.clone()
            for i in range(row_count):
                for j in range(column_count):
                    new_values[i, j] = self._update_value_function(old_values, i, j)

            iteration += 1

        self.states_values = new_values

    def _compute_q_value(self, position: Tuple[int, int], action: str) -> float:
        """Return ``reward + discount * value(next state)`` for one action."""
        next_position = self.gridworld.choose_action_from_position(action, position)
        reward = self.gridworld.get_state_reward(next_position, position)
        next_value = self._next_state_value(self.states_values, position, next_position)
        return reward + self.discount * next_value

    def choose_action_by_q_value(self, position: Tuple[int, int]) -> str:
        """Return the action with the highest Q-value from ``position``."""
        # Sorted so that ties are broken the same way on every run; the
        # iteration order of a set of strings is not stable across processes.
        possible_actions = sorted(self.gridworld.actions)
        q_values = torch.zeros(size=(len(possible_actions),))

        for action_idx, action in enumerate(possible_actions):
            q_values[action_idx] = self._compute_q_value(position, action)

        return possible_actions[torch.argmax(q_values).item()]

    def build_policy(self) -> None:
        """Compute the state values (``method`` is not consulted yet)."""
        self._build_policy_by_value_function()

In [None]:
# 5x5 grid with two special cells, each paired with a reward and a
# teleport destination.
environment = GridWorld(
    size=(5, 5),
    reward_coordinates=[(0, 1), (0, 3)],
    reward_values=[10, 5],
    teleportation_coordinates=[(4, 1), (2, 3)],
)

# No convergence threshold: the agent stops on the iteration limit only.
agent = GridWorldAgent(
    gridworld=environment,
    method='value function',
    discount=0.9,
    threshold=None,
    max_iterations=10000,
)

agent.build_policy()

In [None]:
# Print tensor entries with two decimals and without scientific notation.
torch.set_printoptions(precision=2, sci_mode=False)
# Display the state values stored on the agent.
agent.states_values

In [None]:
# Ask the agent for the action with the highest Q-value from cell (0, 2).
agent.choose_action_by_q_value((0, 2))