Partie en cours de construction : ce chapitre a été lu, et deux problèmes vont être implémentés avec la fonction de valeur et la Q-value.

Or la résolution de ces problèmes n'est réalisable qu'après l'étude du chapitre 4, qui traite de la programmation dynamique.

In [None]:
from typing import List, Tuple, Union

import torch

In [None]:
class GridWorld:
    """Grid environment with special cells that pay a reward and teleport the agent.

    A position is a ``(row, column)`` tuple; it is inside the grid when
    ``0 <= row < size[0]`` and ``0 <= column < size[1]``.

    * Entering ``reward_coordinates[i]`` yields ``reward_values[i]`` and moves
      the agent to ``teleportation_coordinates[i]``.
    * A move that would leave the grid yields ``-1`` and leaves the position
      unchanged.
    * Any other move yields ``0``.
    """

    def __init__(
        self,
        size: Union[int, Tuple[int, int]],
        reward_coordinates: List[Tuple[int, int]],
        reward_values: List[int],
        teleportation_coordinates: List[Tuple[int, int]],
    ) -> None:
        """Create the grid.

        Args:
            size: Side length of a square grid, or ``(rows, columns)``.
            reward_coordinates: Cells that pay a reward when entered.
            reward_values: Reward paid by each cell of ``reward_coordinates``
                (same order, same length).
            teleportation_coordinates: Destination reached after entering each
                cell of ``reward_coordinates`` (same order, same length).

        Raises:
            AssertionError: If the three lists do not have the same length.
        """
        assert len(reward_coordinates) == len(reward_values), \
            'reward_coordinates and reward_values must have the same length'
        assert len(reward_coordinates) == len(teleportation_coordinates), \
            'reward_coordinates and teleportation_coordinates must have the same length'

        self.size = (size, size) if isinstance(size, int) else size
        self.reward_coordinates = reward_coordinates
        self.reward_values = reward_values
        self.teleportation_coordinates = teleportation_coordinates

        # The agent always starts in the (0, 0) cell.
        self.current_position = (0, 0)

        self.actions = {'top', 'bottom', 'left', 'right'}
        # Probability of picking each action (uniform).
        self.actions_proba = {'top': 1/4, 'bottom': 1/4, 'left': 1/4, 'right': 1/4}
        # (row, column) offset applied by each action: 'top' increases the row
        # index and 'bottom' decreases it.
        self._compute_actions = {
            'top': (1, 0),
            'bottom': (-1, 0),
            'left': (0, -1),
            'right': (0, 1)
        }

    def is_out_of_grid(self, position: Tuple[int, int]) -> bool:
        """Return True when ``position`` lies outside the grid.

        Valid indices start at 0, so any negative coordinate is outside.
        """
        row, column = position[0], position[1]
        return not (0 <= row < self.size[0] and 0 <= column < self.size[1])

    def _is_teleportation_state(self, position: Tuple[int, int]) -> bool:
        """Return True when entering ``position`` triggers a teleportation."""
        return position in self.reward_coordinates

    def _get_teleport_state(self, position: Tuple[int, int]) -> Tuple[int, int]:
        """Return the destination paired with the teleportation cell ``position``.

        Raises:
            ValueError: If ``position`` is not in ``reward_coordinates``.
        """
        teleport_idx = self.reward_coordinates.index(position)
        return self.teleportation_coordinates[teleport_idx]

    def choose_action_from_position(self, action: str, position: Tuple[int, int]) -> Tuple[int, int]:
        """Return the raw cell reached by applying ``action`` from ``position``.

        The result is not clipped to the grid and not teleported; those rules
        are applied by ``execute_action``.

        Raises:
            ValueError: If ``action`` is not one of ``self.actions``.
        """
        if action not in self.actions:
            raise ValueError(f'{action} is not in available actions : {self.actions}')

        row_offset, column_offset = self._compute_actions[action]
        return (position[0] + row_offset, position[1] + column_offset)

    def choose_action(self, action: str) -> Tuple[int, int]:
        """Return the raw cell reached by applying ``action`` from the current position.

        Raises:
            ValueError: If ``action`` is not one of ``self.actions``.
        """
        return self.choose_action_from_position(action, self.current_position)

    def execute_action(self, action: str) -> None:
        """Apply ``action`` and update ``current_position``.

        Entering a teleportation cell moves the agent to its destination; a
        move that would leave the grid keeps the agent where it is.
        """
        new_position = self.choose_action(action)
        if self._is_teleportation_state(new_position):
            self.current_position = self._get_teleport_state(new_position)
        elif not self.is_out_of_grid(new_position):
            self.current_position = new_position

    def choose_action_get_reward(self, action: str) -> float:
        """Return the reward for applying ``action`` from the current position.

        The current position is not modified.
        """
        return self.get_state_reward(self.choose_action(action))

    def get_state_reward(self, position: Tuple[int, int]) -> float:
        """Return the reward obtained when a move targets ``position``.

        Reward cells are checked first, then the off-grid penalty of ``-1``;
        every other cell pays ``0``.
        """
        if self._is_teleportation_state(position):
            return self.reward_values[self.reward_coordinates.index(position)]
        if self.is_out_of_grid(position):
            return -1
        return 0

    def get_current_state(self) -> Tuple[int, int]:
        """Return the agent's current position."""
        return self.current_position

In [None]:
class GridWorldAgent:
    """Estimate the state values of a grid world by repeated Bellman sweeps.

    The action probabilities come from ``gridworld.actions_proba``. Despite the
    name ``build_policy``, only ``states_values`` is filled; no explicit policy
    is derived from it.
    """

    ACCEPTED_METHODS = ['value function', 'q function']

    def __init__(self, gridworld: GridWorld, method: str, discount: float=0.1, threshold: float=0.01) -> None:
        """Create the agent.

        Args:
            gridworld: Environment to evaluate.
            method: One of ``ACCEPTED_METHODS``.
            discount: Discount factor applied to the next state's value.
            threshold: Sweeps stop once the L1 distance between two successive
                value tables is no longer above this value.

        Raises:
            ValueError: If ``method`` is not in ``ACCEPTED_METHODS``.
        """
        self._check_method(method)

        self.gridworld = gridworld
        self.method = method
        # One value per grid cell, initialised to zero.
        self.states_values = torch.zeros(size=gridworld.size)
        self.discount = discount
        self.threshold = threshold

    def _check_method(self, method: str) -> None:
        """Raise ``ValueError`` when ``method`` is not supported."""
        if method not in self.ACCEPTED_METHODS:
            raise ValueError(f'{method} not in accepted methods : {self.ACCEPTED_METHODS}')

    def _continue_search(self, old_values: torch.Tensor, new_values: torch.Tensor) -> bool:
        """Return True while the L1 distance between the two tables exceeds the threshold."""
        delta = torch.linalg.norm((old_values - new_values).flatten(), ord=1)
        # Convert the 0-dim tensor so the annotated ``bool`` is what callers get.
        return bool(delta > self.threshold)

    def _is_inside_table(self, position: Tuple[int, int]) -> bool:
        """Return True when ``position`` is a valid non-negative index of the value table."""
        rows, columns = self.states_values.shape[0], self.states_values.shape[1]
        return 0 <= position[0] < rows and 0 <= position[1] < columns

    def _resolve_next_state(self, state: Tuple[int, int], candidate: Tuple[int, int]) -> Tuple[int, int]:
        """Return the state actually reached when a move from ``state`` targets ``candidate``.

        Mirrors ``GridWorld.execute_action``: entering a reward cell lands on
        its teleportation destination, and a move off the grid leaves the
        agent in ``state``.
        """
        world = self.gridworld
        if candidate in world.reward_coordinates:
            return world.teleportation_coordinates[world.reward_coordinates.index(candidate)]
        # Also check against the table itself: a negative index would not fail
        # but silently wrap around to the opposite edge of the tensor.
        if world.is_out_of_grid(candidate) or not self._is_inside_table(candidate):
            return state
        return candidate

    def _update_value_function(self, value_function: torch.Tensor, row_index: int, column_index: int) -> torch.Tensor:
        """Return the Bellman backup of cell ``(row_index, column_index)``.

        Computes ``sum_a p(a) * (reward + discount * V(next_state))`` where
        ``next_state`` is the state reached after applying the teleportation
        and stay-in-place rules of the environment.
        """
        world = self.gridworld
        state = (row_index, column_index)
        value = 0
        # ``actions`` is a set: sort it so the floating-point sum is always
        # accumulated in the same order.
        for action in sorted(world.actions):
            candidate = world.choose_action_from_position(action, state)
            reward = world.get_state_reward(candidate)
            next_row, next_column = self._resolve_next_state(state, candidate)
            next_value = value_function[next_row, next_column]
            value += world.actions_proba[action] * (reward + self.discount * next_value)

        return value

    def _build_policy_by_value_function(self) -> None:
        """Sweep the grid until the value table stops changing, then store it.

        Cells are updated in place during a sweep, so later cells already see
        the values refreshed earlier in the same sweep.
        """
        new_values = self.states_values.clone()
        # Start from an infinite distance so that at least one sweep runs.
        old_values = torch.full_like(new_values, float('inf'))

        while self._continue_search(old_values, new_values):
            old_values = new_values.clone()
            for i in range(new_values.shape[0]):
                for j in range(new_values.shape[1]):
                    new_values[i, j] = self._update_value_function(new_values, i, j)

        self.states_values = new_values

    def _build_policy_by_q_value(self) -> None:
        """Placeholder for the Q-value method; currently does nothing."""
        # TODO: implement the Q-value based evaluation.
        ...

    def build_policy(self) -> None:
        """Run the evaluation routine selected by ``self.method``."""
        if self.method == 'value function':
            self._build_policy_by_value_function()
        elif self.method == 'q function':
            self._build_policy_by_q_value()

In [None]:
# 5x5 grid with two special cells: entering (0, 1) pays 10 and lands on (4, 1),
# entering (0, 3) pays 5 and lands on (2, 3).
environment = GridWorld(
    size=(5, 5),
    reward_coordinates=[(0, 1), (0, 3)],
    reward_values=[10, 5],
    teleportation_coordinates=[(4, 1), (2, 3)],
)
agent = GridWorldAgent(
    gridworld=environment,
    method='value function',
    discount=0.9,
    threshold=1e-7,
)

# Fills agent.states_values by sweeping the grid until convergence.
agent.build_policy()

In [None]:
# Show the value table stored on the agent (as the last expression of a
# notebook cell, it is rendered as the cell's output).
agent.states_values