Copyright **`(c)`** 2023 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free for personal or classroom use; see [`LICENSE.md`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

# LAB10

Use reinforcement learning to devise a tic-tac-toe player.

### Deadlines:

* Submission: [Dies Natalis Solis Invicti](https://en.wikipedia.org/wiki/Sol_Invictus)
* Reviews: [Befana](https://en.wikipedia.org/wiki/Befana)

Notes:

* Reviews will be assigned on Monday, December 4
* You need to commit in order to be selected as a reviewer (i.e., it is better to commit empty work than not to commit at all)

### In collaboration with [Giovanni Bordero s313010](https://github.com/Giobordi)

In [1]:
import pickle
import itertools
import numpy as np
from tqdm.auto import tqdm
from collections import defaultdict, namedtuple
from copy import deepcopy
from enum import Enum
import random

In [2]:
# Final-game rewards handed to the training updates: a loss is penalised
# more heavily than a win is rewarded, and a draw is mildly negative.
WIN_SCORE, LOSE_SCORE, DRAW_SCORE = 2, -5, -1

# Number of games ReinforcedPlayer2 plays during training.
TRAINING_EPOCHS = 100000

# Number of test games (presumably for the evaluation loop).
NUM_TEST_GAMES = 10000

In [3]:
# Board occupancy: `x` and `o` are the sets of cell indices owned by each side.
State = namedtuple('State', ['x', 'o'])


class Player(Enum):
    """The two sides of a tic-tac-toe game."""
    # No trailing comma here: `X = 1,` would make the member's value the tuple (1,).
    X = 1
    O = 0

In [4]:
class TicTacToe:
    """Tic-tac-toe board that detects wins through a 3x3 magic square.

    Cells are numbered 0..8 row by row. ``self.magic`` assigns each cell its
    value in a magic square whose rows, columns and diagonals all sum to 15,
    so a player has won exactly when three of their cells sum to 15.
    """

    def __init__(self):
        self.magic = [4, 9, 2,
                      3, 5, 7,
                      8, 1, 6]
        # Cells taken by player X go in state.x, cells taken by O in state.o.
        self.state = State(set(), set())
        # Number of moves played so far.
        self.move = 0
        # Side considered "ours" by evaluate_match.
        self.my_player = Player.X

    def sum_magic(self, comb: tuple) -> int:
        """Return the sum of the magic-square values of the cells in ``comb``."""
        return sum(self.magic[i] for i in comb)

    def check_win(self, player: Player) -> bool:
        """Return True if ``player`` owns three cells that form a line."""
        if self.move < 5:
            # Nobody can own three cells before the fifth move.
            return False
        if player == Player.X:
            cells = self.state.x
        elif player == Player.O:
            cells = self.state.o
        else:
            return False
        return any(self.sum_magic(comb) == 15 for comb in itertools.combinations(cells, 3))

    def check_draw(self) -> bool:
        """Return True if the board is full (9 moves) and nobody has won."""
        return self.move == 9 and not self.check_win(Player.X) and not self.check_win(Player.O)

    def move_done(self, move: int, player: Player) -> None:
        """Mark cell ``move`` as taken by ``player`` and advance the move counter."""
        self.move += 1
        if player == Player.X:
            self.state.x.add(move)
        elif player == Player.O:
            self.state.o.add(move)

    def evaluate_match(self) -> int:
        """Score the game from ``self.my_player``'s point of view.

        Returns 1 for a win, -1 for a loss and 0 for a draw, or None while
        the game is still in progress.
        """
        opponent = Player.O if self.my_player == Player.X else Player.X
        if self.check_win(self.my_player):
            return 1
        if self.check_win(opponent):
            return -1
        if self.check_draw():
            return 0
        return None

    def good_print(self):
        """Print the board: ✖️/⭕ for taken cells, the cell number for free ones."""
        num = ['0️⃣', '1️⃣', '2️⃣', '3️⃣', '4️⃣', '5️⃣', '6️⃣', '7️⃣', "8️⃣"]
        for r in range(3):
            print('|', end='')
            for c in range(3):
                val = r * 3 + c
                if val in self.state.x:
                    print('✖️|', end='')
                elif val in self.state.o:
                    print('⭕|', end='')
                else:
                    print(f'{num[val]}|', end="")
            print()
        print()

In [5]:
class ReinforcedPlayer2:
    """Tabular tic-tac-toe agent that learns state-action values from play.

    ``self.Q`` maps ``((frozenset(x_cells), frozenset(o_cells)), action)`` to a
    value estimate. During training the agent plays against a uniformly random
    opponent, and every (state, action) pair it chose in a game is moved
    towards that game's final reward.
    """

    def __init__(self):
        self.Q = defaultdict(float)
        # NOTE: epsilon plays two roles. It is the exploration probability in
        # choose_action and the step size (learning rate) in update_Q_value.
        self.epsilon = 0.01
        self.training_epochs = TRAINING_EPOCHS

    @staticmethod
    def _state_key(state):
        """Return a hashable snapshot of ``state``; frozenset() copies the cells."""
        return (frozenset(state.x), frozenset(state.o))

    def get_Q_value(self, state, action):
        """Return Q(state, action), storing 0.0 for pairs not seen before."""
        # No deepcopy needed: _state_key already builds independent frozensets.
        return self.Q.setdefault((self._state_key(state), action), 0.0)

    def update_Q_value(self, state, action: int, reward):
        """Move Q(state, action) towards ``reward`` by a step of size epsilon."""
        key = (self._state_key(state), action)
        current = self.Q.setdefault(key, 0.0)
        self.Q[key] = current + self.epsilon * (reward - current)

    def choose_action(self, state, available_moves):
        """Pick a move epsilon-greedily from ``available_moves``.

        With probability epsilon a random move is returned. Otherwise the
        method returns a move with the highest Q value, breaking ties
        uniformly at random.
        """
        if random.uniform(0, 1) < self.epsilon:
            return random.choice(available_moves)
        Q_values = [self.get_Q_value(state, action) for action in available_moves]
        max_Q = max(Q_values)
        best_moves = [move for move, q in zip(available_moves, Q_values) if q == max_Q]
        if len(best_moves) > 1:
            return random.choice(best_moves)
        return best_moves[0]

    def training(self):
        """Play ``training_epochs`` games, updating Q after each, then save the model."""
        for _ in tqdm(range(self.training_epochs)):
            trajectory, reward = self.random_game(TicTacToe())
            # Every (state, action) pair the agent chose receives the game's final reward.
            for state, action in trajectory:
                self.update_Q_value(state=state, action=action, reward=reward)

        save_model(self, f"EPOCHS_{str(self.training_epochs)}-LOSE_{LOSE_SCORE}-WIN_{WIN_SCORE}-DRAW_{DRAW_SCORE}")

    def random_game(self, game: "TicTacToe"):
        """Play one game of the agent against a uniformly random opponent.

        The agent's side (X or O) and the side that moves first are both
        chosen at random, so the model learns to play either side.

        Returns:
            (trajectory, reward). ``trajectory`` lists, for each agent move, a
            deep copy of the state before the move together with the move.
            ``reward`` is WIN_SCORE if the agent won, LOSE_SCORE if it lost,
            and DRAW_SCORE otherwise.
        """
        trajectory = list()
        available_moves = list(range(0, 9))
        # Random starting side: `turn` is flipped before the first move.
        turn = np.random.choice([0, 1])

        players = [Player.X, Player.O]
        game.my_player = random.choice(players)
        while len(available_moves) != 0 and not game.check_draw():
            turn = 1 - turn

            if game.my_player == players[turn]:
                move = self.choose_action(game.state, available_moves)
                trajectory.append((deepcopy(game.state), move))
            else:
                # int() keeps numpy scalars out of the board state, and thus out of the pickled Q keys.
                move = int(np.random.choice(available_moves))

            available_moves.remove(move)
            game.move_done(move, players[turn])

            if game.check_win(players[turn]):
                if game.my_player == players[turn]:
                    return trajectory, WIN_SCORE
                return trajectory, LOSE_SCORE

        return trajectory, DRAW_SCORE

In [6]:
def save_model(model: "ReinforcedPlayer2", text: str = None):
    """Pickle ``model`` to ``models/agent-<text>.pkl``.

    The ``models`` directory is created under the current working directory
    if it does not exist yet, so the first save does not fail.
    """
    import os

    os.makedirs('models', exist_ok=True)
    with open(f'models/agent-{text}.pkl', 'wb') as f:
        pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_model(path: str) -> "ReinforcedPlayer2":
    """Unpickle and return the model stored at ``path``.

    SECURITY: unpickling can execute arbitrary code. Only load files you
    produced yourself (e.g. with save_model), never untrusted downloads.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)

In [9]:
train = False

if train:
    rr = ReinforcedPlayer2()
    rr.training()
    # Reload the model training() just wrote: save_model stores it as
    # models/agent-<text>.pkl, and training() builds <text> from these constants.
    rr2 = load_model(
        f'models/agent-EPOCHS_{TRAINING_EPOCHS}-LOSE_{LOSE_SCORE}-WIN_{WIN_SCORE}-DRAW_{DRAW_SCORE}.pkl'
    )
else:
    rr = load_model('models/agent-EPOCHS_10000000-LOSE_-5-WIN_2-DRAW_-1.pkl')
    rr2 = load_model('models/agent-EPOCHS_10000000-LOSE_-5-WIN_2-DRAW_-1.pkl')

# Evaluate `rr` against an opponent. In every game both the side `rr` plays
# and the side that moves first are chosen at random.
wins = 0
draws = 0
for _ in range(NUM_TEST_GAMES):
    game = TicTacToe()

    turn = np.random.choice([0, 1])
    players = [Player.X, Player.O]
    rr_player = random.choice(players)

    available_moves = list(range(0, 9))

    while len(available_moves) != 0 and not game.check_draw():
        turn = 1 - turn
        if players[turn] == rr_player:
            move = rr.choose_action(game.state, available_moves)
        else:
            # Opponent: uniformly random. Alternatives: a human player via
            # int(input("Enter your move: ")), or the second model via
            # rr2.choose_action(game.state, available_moves).
            move = random.choice(available_moves)

        available_moves.remove(move)
        game.move_done(move, players[turn])

        if game.check_win(players[turn]):
            if players[turn] == rr_player:
                wins += 1
            break

    if game.check_draw():
        draws += 1

print(f'wins : {wins}, draws: {draws}, lost: {NUM_TEST_GAMES - wins - draws}')

wins : 9415, draws: 547, lost: 38
