# LAB10

Use reinforcement learning to devise a tic-tac-toe player.

### Deadlines:

* Submission: [Dies Natalis Solis Invicti](https://en.wikipedia.org/wiki/Sol_Invictus)
* Reviews: [Befana](https://en.wikipedia.org/wiki/Befana)

Notes:

* Reviews will be assigned  on Monday, December 4
* You need to commit in order to be selected as a reviewer (ie. better to commit an empty work than not to commit)

# !! WORK IN PROGRESS

In [1]:

import numpy as np

from itertools import combinations
from collections import defaultdict
from copy import copy
from tqdm.auto import tqdm
from random import random
from pprint import pprint

In [2]:
MAGIC = [2, 7, 6, 9, 5, 1, 4, 3, 8] # Magic square for 3x3 Tic-Tac-Toe

In [3]:
class TicTacToeState:
    '''
    Represents the state of a Tic-Tac-Toe game.
    
    Attributes:
    - x (list): The list of cells occupied by 'x'.
    - o (list): The list of cells occupied by 'o'.
    
    Methods:
    - get_board: Returns the current state of the board as a 3x3 numpy array.
    - get_actions: Returns the list of possible actions from the current state.
    - step: Returns the next state and reward after taking an action.
    - get_reward: Returns the reward for the current state.
    - is_terminal: Returns whether the current state is terminal.
    '''
    def __init__(self, x, o) -> None:
        self.x = x
        self.o = o
    
    def get_board(self):
        """
        Returns the current state of the board as a 3x3 numpy array.

        The board is represented by a list of length 9, where each element represents a cell on the board.
        The value 0 represents an empty cell, 1 represents a cell occupied by 'x', and 2 represents a cell occupied by 'o'.

        Returns:
        np.ndarray: The current state of the board as a 3x3 numpy array.
        """
        board = [0 for _ in range(9)]
        for e in self.x:
            board[MAGIC.index(e)] = 1
        for e in self.o:
            board[MAGIC.index(e)] = 2
        
        return board

    def available_actions(self):
        """
        Returns the list of possible actions from the current state.

        The list of possible actions is represented by a list of length 9, where each element represents a cell on the board.
        The value 0 represents an empty cell, 1 represents a cell occupied by 'x', and 2 represents a cell occupied by 'o'.

        Returns:
        list: The list of possible actions from the current state.
        """
        actions = []
        for i in range(9):
            if MAGIC[i] not in self.x and MAGIC[i] not in self.o:
                actions.append(i)
        return actions
    
    def not_available_actions(self):
        """
        Returns the list of impossible actions from the current state.

        The list of impossible actions is represented by a list of length 9, where each element represents a cell on the board.
        The value 0 represents an empty cell, 1 represents a cell occupied by 'x', and 2 represents a cell occupied by 'o'.

        Returns:
        list: The list of impossible actions from the current state.
        """
        actions = []
        for i in range(9):
            if MAGIC[i] in self.x or MAGIC[i] in self.o:
                actions.append(i)
        return actions
    
    def step(self, action):
        """
        Returns the next state and reward after taking an action.

        Args:
        action (int): The action to take.

        Returns:
        (TicTacToeState, int): The next state and reward after taking an action.
        """
        if action not in self.available_actions():
            raise ValueError("Invalid action.")
        
        x = copy(self.x)
        o = copy(self.o)
        if len(x) == len(o):
            x.append(MAGIC[action])
        else:
            o.append(MAGIC[action])
            
        next_state = TicTacToeState(x, o)
        
        return next_state, next_state.get_reward()
    
    def get_reward(self):
        """
        Returns the reward for the current state.

        Returns:
        int: The reward for the current state.
        """
        if self.is_terminal():
            if len(self.x) > len(self.o):
                return 10
            elif len(self.x) < len(self.o):
                return -10
            else:
                return 0
        else:
            return 0
    
    def is_terminal(self):
        """
        Returns whether the current state is terminal.

        Returns:
        bool: Whether the current state is terminal.
        """
        if len(self.x) + len(self.o) == 9:
            return True
        
        for c in combinations(self.x, 3):
            if sum(c) == 15:
                return True
        for c in combinations(self.o, 3):
            if sum(c) == 15:
                return True
        
        return False

    def print_board(self):
        """
        Prints the current state of the board.
        """
        board = np.array(self.get_board()).reshape(3, 3)
        for i in range(3):
            for j in range(3):
                if board[i][j] == 0:
                    print('.', end='')
                elif board[i][j] == 1:
                    print('X', end='')
                else:
                    print('O', end='')
            print()
    
    
class TicTacToeAgent:
    '''
    Represents a Tic-Tac-Toe agent.
    
    Attributes:
    - alpha (float): The learning rate.
    - gamma (float): The discount factor.
    - epsilon (float): The probability of taking a random action.
    
    Methods:
    - take_action: Returns the action to take given the current state.
    - train: Updates the Q table given the current state, action, next state, and reward.
    '''
    def __init__(self, alpha=0.1, gamma=0.5, epsilon=0.1) -> None:
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.Q = defaultdict(lambda: np.zeros(9))
           
    def take_action(self, state):
        """
        Returns the action to take given the current state.

        Args:
        state (TicTacToeState): The current state.

        Returns:
        int: The action to take given the current state.
        """
        
        self.Q[tuple(state.get_board())][state.not_available_actions()] = -100
        
        if random() < self.epsilon:
            return np.random.choice(state.available_actions())
        else:
            return np.argmax(self.Q[tuple(state.get_board())])
    
    def train(self, state, action, next_state, reward):
        """
        Updates the Q table given the current state, action, next state, and reward.

        Args:
        state (TicTacToeState): The current state.
        action (int): The action taken.
        next_state (TicTacToeState): The next state.
        reward (int): The reward received.
        """
        
        self.Q[tuple(state.get_board())][action] += self.alpha * (reward + self.gamma * np.max(self.Q[tuple(state.get_board())]) - self.Q[tuple(state.get_board())][action])

In [5]:
agent = TicTacToeAgent(epsilon=0.1)
agent_random = TicTacToeAgent(epsilon=1)

for _ in tqdm(range(100_000)):
    state = TicTacToeState([], [])
    while not state.is_terminal():
        action = agent.take_action(state)
        next_state, reward = state.step(action)
        agent.train(state, action, next_state, reward)
        state = next_state
        #state.print_board()

        if state.is_terminal():
            break
        
        action = agent_random.take_action(state)
        next_state, reward = state.step(action)
        agent.train(state, action, next_state, -reward)
        state = next_state
        #state.print_board()

pprint(agent.Q)

  0%|          | 0/100000 [00:00<?, ?it/s]

defaultdict(<function TicTacToeAgent.__init__.<locals>.<lambda> at 0x7faec059d1f0>,
            {(0, 0, 0, 0, 0, 0, 0, 0, 0): array([0., 0., 0., 0., 0., 0., 0., 0., 0.]),
             (0, 0, 0, 0, 0, 0, 0, 0, 1): array([0., 0., 0., 0., 0., 0., 0., 0., 0.]),
             (0, 0, 0, 0, 0, 0, 0, 1, 0): array([0., 0., 0., 0., 0., 0., 0., 0., 0.]),
             (0, 0, 0, 0, 0, 0, 0, 1, 2): array([   0.,    0.,    0.,    0.,    0.,    0.,    0., -100., -100.]),
             (0, 0, 0, 0, 0, 0, 0, 2, 1): array([   0.,    0.,    0.,    0.,    0.,    0.,    0., -100., -100.]),
             (0, 0, 0, 0, 0, 0, 1, 0, 0): array([0., 0., 0., 0., 0., 0., 0., 0., 0.]),
             (0, 0, 0, 0, 0, 0, 1, 0, 2): array([   0.,    0.,    0.,    0.,    0.,    0., -100.,    0., -100.]),
             (0, 0, 0, 0, 0, 0, 1, 1, 2): array([0., 0., 0., 0., 0., 0., 0., 0., 0.]),
             (0, 0, 0, 0, 0, 0, 1, 2, 0): array([   0.,    0.,    0.,    0.,    0.,    0., -100., -100.,    0.]),
             (0, 0, 0, 0,