## RL agent Q-learning for TicTacToe env

The [Tic-Tac-Toe](https://github.com/MauroLuzzatto/OpenAI-Gym-TicTacToe-Environment) environment is a simple game environment for training reinforcement learning agents. This notebook contains an implementation of Q-learning with an epsilon-greedy strategy for the TicTacToe env.

In [140]:
# load the python modules
import time
import sys
import warnings

import gym
import numpy as np
from tqdm import tqdm
import gym_TicTacToe

from src.qagent import QLearningAgent
from src.play_tictactoe import play_tictactoe, play_tictactoe_with_random

from src.utils import (
    create_state_dictionary,
    reshape_state,
    save_qtable,
    load_qtable
)

# ignore warnings: silence every warning category, but only when the interpreter
# was started without explicit -W options (sys.warnoptions is empty in that case)
if not sys.warnoptions:
    warnings.simplefilter("ignore")

In [141]:
class Player:
    """Bookkeeping record for one side of the game.

    Attributes are plain values; nothing here talks to the environment.
    """

    def __init__(self, color, episodes: int) -> None:
        """Create the record for the side playing `color`.

        Args:
            color: marker value of this side (the notebook uses 1 and 2)
            episodes (int): length of the zero-initialised `reward_array`
        """
        # identity of the side
        self.color = color
        self.name = f"Player {color}"
        # one zero-initialised slot per episode
        self.reward_array = np.zeros(episodes)
        # start with a cleared running reward
        self.reset_reward()

    def reset_reward(self) -> None:
        """Set the running `reward` back to 0."""
        self.reward = 0

In [142]:

# initialize the tictactoe environment
# NOTE(review): `small` and `large` are forwarded to the "TTT-v0" environment;
# presumably the per-move and the winning reward — confirm against gym_TicTacToe
env = gym.envs.make("TTT-v0", small=-1, large=10)

In [143]:
# lookup table: encoded board (as produced by reshape_state) -> integer state index
state_dict = create_state_dictionary()
# number of entries in the lookup table; passed to QLearningAgent as state_size
state_size = len(state_dict)
# number of discrete actions exposed by the environment
action_size = env.action_space.n

Number of legal states: 12092


In [144]:
# set training parameters
# number of games iterated over by the training loop
episodes = 660_000
# number of moves the training loop allows per game
max_steps = 9

In [145]:
# epsilon-greedy exploration settings handed to QLearningAgent
# NOTE(review): the recorded training log (epsilon 0.78 at episode 25000, 0.61 at
# 50000, 0.37 at 100000) is consistent with
#   epsilon = min_epsilon + (max_epsilon - min_epsilon) * exp(-decay_rate * episode)
# — confirm against src/qagent.py
exploration_parameters = {
    "max_epsilon": 1.0,
    "min_epsilon": 0.0,
    "decay_rate": 0.00001,
}

In [146]:
# first agent instance (learning_rate=0.1, gamma=0.99)
# NOTE: the training cell further down rebinds `qagent` to a new QLearningAgent
# with learning_rate=0.8 and gamma=0.9 before any training happens
qagent = QLearningAgent(exploration_parameters, state_size, action_size, learning_rate=0.1, gamma=0.99)

In [147]:
def check_for_lose(state, color: int) -> bool:
    """Check whether `color` still has an unblocked line on the board.

    A row, column or diagonal counts as unblocked when each of its three
    cells is either 0 or equal to `color`, i.e. no other marker sits on it.

    Args:
        state: nine cell values of the board in row-major order; any
            array-like that can be reshaped to 3x3 (numpy array, list, tuple)
        color (int): marker of the player's enemy

    Returns:
        bool: True if at least one row, column or diagonal is unblocked
    """
    # np.asarray lets plain lists/tuples through; arrays are used as they are
    board = np.asarray(state).reshape(3, 3)

    # cells that are either already held by `color` or equal to 0
    open_cells = (board == color) | (board == 0)

    # each of the eight lines is evaluated exactly once
    return bool(
        # columns
        open_cells.all(axis=0).any()
        # rows
        or open_cells.all(axis=1).any()
        # main diagonal
        or np.diag(open_cells).all()
        # anti-diagonal
        or np.diag(np.fliplr(open_cells)).all()
    )

In [148]:
# sanity check on a hand-written board (row-major):
#   1 2 0
#   1 2 1
#   0 1 2
# for color 2 the anti-diagonal (cells 2, 4, 6) holds only 0s and 2s -> True
state = np.array([1,2,0,1,2,1,0,1,2])
check_for_lose(state, color=2)

True

In [149]:

def play(qagent:QLearningAgent, player_color, state: int, action_space: np.array, last_turn: bool) -> tuple:
    """Make one agent move for `player_color` and update the Q-table with it.

    Uses the module-level `env`, `state_dict` and `reshape_state`.

    Args:
        qagent: agent that selects the action and owns the Q-table
        player_color: marker sent to the environment together with the action
        state (int): index of the current state in the Q-table
        action_space: array of actions that are still available
        last_turn (bool): True when this is the final move of the game

    Returns:
        tuple: (index of the resulting state, action space without the
        chosen action, `done` flag reported by the environment)
    """
    action = qagent.get_action(state, action_space)

    # the chosen action is no longer available to later moves
    remaining_actions = action_space[action_space != action]

    board, reward, done, _ = env.step((action, player_color))

    # final move without the environment ending the game: reward for draw
    if last_turn and not done:
        reward += 12

    #TODO: maybe should change a marker after this agent turn
    # the encoded state is the board plus the marker of the player who just moved
    next_state = state_dict[reshape_state(np.append(board, player_color))]

    qagent.qtable[state, action] = qagent.update_qtable(
        state, next_state, action, reward, done
    )
    return next_state, remaining_actions, done

In [150]:
def play_random(qagent:QLearningAgent, player: Player, state: int, action_space: np.array) -> tuple:
    """Make one uniformly random move for `player`; no Q-table update.

    Uses the module-level `env`, `state_dict` and `reshape_state`. The
    `qagent` and `state` arguments are accepted but not read.

    Args:
        qagent: unused
        player: side whose `color` is sent to the environment
        state (int): unused
        action_space: array of actions that are still available

    Returns:
        tuple: (index of the resulting state, action space without the
        chosen action, `done` flag reported by the environment)
    """
    action = np.random.choice(action_space)
    remaining_actions = action_space[action_space != action]

    # the reward reported by the environment is ignored for random moves
    board, _reward, done, _ = env.step((action, player.color))

    next_state = state_dict[reshape_state(np.append(board, player.color))]
    return next_state, remaining_actions, done

In [151]:
# visit counter with one row per state index; the training loop increments the
# row of every state it reaches
visited_states = np.zeros((state_size, 1))

In [152]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline
import random

# wall-clock reference for the "elapsed time" column of the progress log
start_time = time.time()

player_1 = Player(color=1, episodes=episodes)
player_2 = Player(color=2, episodes=episodes)

# result containers; within this cell only win_history receives entries
track_progress = np.zeros(episodes)
win_history = []
rewards = []

# hyper-parameter notes from earlier runs
#   Learning rate: 0.4,  Win rate: 0.64, Gamma: 0.9
#   Learning rate: 0.5,  Win rate: 0.72, Gamma: 0.9
#   Learning rate: 0.6,  Win rate: 0.5,  Gamma: 0.9
#   Learning rate: 0.7,  Win rate: 0.66, Gamma: 0.9
#   Learning rate: 0.65, Win rate: 0.76, Gamma: 0.8
# best
#   Learning rate: 0.8,  Win rate: 0.8,  Gamma: 0.9
lear_rate = 0.8
gamma = 0.9

# fresh agent for this run; both colors act through (and update) its single Q-table
qagent = QLearningAgent(exploration_parameters, state_size, action_size, learning_rate=lear_rate, gamma=gamma)

for episode in tqdm(range(episodes)):
    action_space = np.arange(9)
    player_1.reset_reward()
    player_2.reset_reward()

    # randomly change the order of the players: the step counter begins at `start`,
    # so start=2 lets player 1 open (even count) and start=1 lets player 2 open
    start = np.random.choice([1,2])

    # the initial state is the empty board plus `start` as trailing marker
    state, _ = env.reset()
    state = state_dict[reshape_state(np.append(state, start))]

    final_step = max_steps + start - 1
    last_turn = False
    for _step in range(start, final_step + 1):
        last_turn = bool(_step == final_step)
        # change a turn: even step counts belong to player 1, odd ones to player 2
        mover = player_1 if _step % 2 == 0 else player_2
        state, action_space, done = play(qagent, mover.color, state, action_space, last_turn)
        visited_states[state] += 1
        if done:
            break

    # reduce epsilon for exploration-exploitation tradeoff
    qagent.update_epsilon(episode)

    # every 25_000 episodes: check how good the agent is, then log the progress
    if episode % 25_000 == 0:
        # NOTE(review): presumably games against a random opponent, judging by
        # the helper's name — confirm in src/play_tictactoe.py
        num_games = 50
        cur_win_rate, reward = play_tictactoe_with_random(env, qagent.qtable, state_dict, num_test_games=num_games)
        win_rate = sum(cur_win_rate) / num_games
        win_history.append(win_rate)
        print("WinRate:", win_rate)

        sum_q_table = np.sum(qagent.qtable)
        time_passed = round((time.time() - start_time) / 60.0, 2)

        print(
            f"episode: {episode}, \
            epsilon: {round(qagent.epsilon, 2)}, \
            sum q-table: {sum_q_table}, \
            elapsed time [min]: {time_passed},  \
            done [%]: {episode / episodes * 100} \
            "
        )


  0%|          | 517/660000 [00:00<04:12, 2610.49it/s]

WinRate: 0.36
episode: 0,             epsilon: 1.0,             sum q-table: 40.230236245810815,             elapsed time [min]: 0.0,              done [%]: 0.0             


  4%|▍         | 25552/660000 [00:08<03:47, 2791.11it/s]

WinRate: 0.0
episode: 25000,             epsilon: 0.78,             sum q-table: 181341.2208335659,             elapsed time [min]: 0.14,              done [%]: 3.787878787878788             


  8%|▊         | 50445/660000 [00:17<03:50, 2646.40it/s]

WinRate: 0.12
episode: 50000,             epsilon: 0.61,             sum q-table: 213178.9779147902,             elapsed time [min]: 0.29,              done [%]: 7.575757575757576             


 11%|█▏        | 75525/660000 [00:26<03:46, 2582.57it/s]

WinRate: 0.16
episode: 75000,             epsilon: 0.47,             sum q-table: 217699.84511674,             elapsed time [min]: 0.44,              done [%]: 11.363636363636363             


 15%|█▌        | 100394/660000 [00:36<03:42, 2515.50it/s]

WinRate: 0.36
episode: 100000,             epsilon: 0.37,             sum q-table: 219155.50397632344,             elapsed time [min]: 0.6,              done [%]: 15.151515151515152             


 19%|█▉        | 125480/660000 [00:45<03:33, 2508.98it/s]

WinRate: 0.4
episode: 125000,             epsilon: 0.29,             sum q-table: 219647.11337683446,             elapsed time [min]: 0.76,              done [%]: 18.939393939393938             


 23%|██▎       | 150418/660000 [00:55<03:25, 2485.39it/s]

WinRate: 0.56
episode: 150000,             epsilon: 0.22,             sum q-table: 219815.4346443758,             elapsed time [min]: 0.92,              done [%]: 22.727272727272727             


 27%|██▋       | 175326/660000 [01:04<03:12, 2515.18it/s]

WinRate: 0.16
episode: 175000,             epsilon: 0.17,             sum q-table: 219913.12111874926,             elapsed time [min]: 1.08,              done [%]: 26.515151515151516             


 30%|███       | 200507/660000 [01:14<03:03, 2499.97it/s]

WinRate: 0.44
episode: 200000,             epsilon: 0.14,             sum q-table: 219953.89880188502,             elapsed time [min]: 1.24,              done [%]: 30.303030303030305             


 34%|███▍      | 225341/660000 [01:24<02:54, 2491.78it/s]

WinRate: 0.24
episode: 225000,             epsilon: 0.11,             sum q-table: 219984.69295190214,             elapsed time [min]: 1.4,              done [%]: 34.090909090909086             


 38%|███▊      | 250388/660000 [01:34<02:45, 2472.23it/s]

WinRate: 0.44
episode: 250000,             epsilon: 0.08,             sum q-table: 219995.5616273599,             elapsed time [min]: 1.57,              done [%]: 37.878787878787875             


 42%|████▏     | 275324/660000 [01:43<02:32, 2516.11it/s]

WinRate: 0.44
episode: 275000,             epsilon: 0.06,             sum q-table: 219998.46136382496,             elapsed time [min]: 1.73,              done [%]: 41.66666666666667             


 46%|████▌     | 300339/660000 [01:53<02:21, 2533.60it/s]

WinRate: 0.48
episode: 300000,             epsilon: 0.05,             sum q-table: 220000.00650442115,             elapsed time [min]: 1.89,              done [%]: 45.45454545454545             


 49%|████▉     | 325387/660000 [02:03<02:13, 2510.54it/s]

WinRate: 0.36
episode: 325000,             epsilon: 0.04,             sum q-table: 220000.9487200193,             elapsed time [min]: 2.05,              done [%]: 49.24242424242424             


 53%|█████▎    | 350347/660000 [02:12<02:01, 2545.08it/s]

WinRate: 0.2
episode: 350000,             epsilon: 0.03,             sum q-table: 220002.97142386838,             elapsed time [min]: 2.21,              done [%]: 53.03030303030303             


 57%|█████▋    | 375346/660000 [02:21<01:51, 2556.15it/s]

WinRate: 0.44
episode: 375000,             epsilon: 0.02,             sum q-table: 220003.14834267844,             elapsed time [min]: 2.36,              done [%]: 56.81818181818182             


 61%|██████    | 400438/660000 [02:31<01:42, 2529.06it/s]

WinRate: 0.28
episode: 400000,             epsilon: 0.02,             sum q-table: 220003.2473474869,             elapsed time [min]: 2.52,              done [%]: 60.60606060606061             


 64%|██████▍   | 425442/660000 [02:41<01:32, 2533.26it/s]

WinRate: 0.36
episode: 425000,             epsilon: 0.01,             sum q-table: 220003.26891585978,             elapsed time [min]: 2.68,              done [%]: 64.39393939393939             


 68%|██████▊   | 450407/660000 [02:50<01:22, 2541.66it/s]

WinRate: 0.2
episode: 450000,             epsilon: 0.01,             sum q-table: 220003.34787746833,             elapsed time [min]: 2.84,              done [%]: 68.18181818181817             


 72%|███████▏  | 475657/660000 [03:00<01:12, 2536.99it/s]

WinRate: 0.32
episode: 475000,             epsilon: 0.01,             sum q-table: 220003.36886956287,             elapsed time [min]: 3.0,              done [%]: 71.96969696969697             


 76%|███████▌  | 500328/660000 [03:09<01:03, 2503.00it/s]

WinRate: 0.08
episode: 500000,             epsilon: 0.01,             sum q-table: 220003.3690289154,             elapsed time [min]: 3.16,              done [%]: 75.75757575757575             


 80%|███████▉  | 525260/660000 [03:19<00:53, 2513.05it/s]

WinRate: 0.52
episode: 525000,             epsilon: 0.01,             sum q-table: 220003.38113007304,             elapsed time [min]: 3.32,              done [%]: 79.54545454545455             


 83%|████████▎ | 550492/660000 [03:29<00:43, 2501.09it/s]

WinRate: 0.44
episode: 550000,             epsilon: 0.0,             sum q-table: 220003.38114182587,             elapsed time [min]: 3.48,              done [%]: 83.33333333333334             


 87%|████████▋ | 575329/660000 [03:38<00:33, 2508.46it/s]

WinRate: 0.44
episode: 575000,             epsilon: 0.0,             sum q-table: 220003.38114182645,             elapsed time [min]: 3.64,              done [%]: 87.12121212121212             


 91%|█████████ | 600266/660000 [03:48<00:23, 2494.59it/s]

WinRate: 0.36
episode: 600000,             epsilon: 0.0,             sum q-table: 220003.38114182686,             elapsed time [min]: 3.8,              done [%]: 90.9090909090909             


 95%|█████████▍| 625267/660000 [03:58<00:13, 2496.60it/s]

WinRate: 0.12
episode: 625000,             epsilon: 0.0,             sum q-table: 220003.38114182686,             elapsed time [min]: 3.97,              done [%]: 94.6969696969697             


 99%|█████████▊| 650272/660000 [04:07<00:03, 2518.70it/s]

WinRate: 0.64
episode: 650000,             epsilon: 0.0,             sum q-table: 220003.38114182686,             elapsed time [min]: 4.13,              done [%]: 98.48484848484848             


100%|██████████| 660000/660000 [04:11<00:00, 2624.99it/s]


In [153]:
# bare expression: number of rows in the visit counter (the recorded output of
# this cell shows only the printed line below)
visited_states.shape[0]
# share of state indices whose visit counter is non-zero, in percent
print("Percent:",100*np.sum(visited_states > 0)/visited_states.shape[0])

Percent: 90.58881905391995


In [159]:
# evaluate the trained Q-table over a larger batch of test games
num_games = 1000
cur_win_rate, _ = play_tictactoe_with_random(env, qagent.qtable, state_dict, num_test_games=num_games)

# fraction of the test games counted in cur_win_rate; stored alongside the
# values collected during training
final_win_rate = sum(cur_win_rate) / num_games
win_history.append(final_win_rate)
print("WinRate:", final_win_rate)

WinRate: 0.378


In [20]:
# keep a handle on the trained table and persist it under the name
# "q_table_best" in 'tables' (recorded output: "q_table_best.npy saved!")
qtable = qagent.qtable
save_qtable(qtable, 'tables', "q_table_best")

q_table_best.npy saved!


In [6]:
# reload the table stored under the name "q_table_best" in 'tables'
qtable = load_qtable('tables', "q_table_best")

In [220]:
# check how correct the q-table is: draw one random state index and show its
# board next to the Q-values stored for it

state = np.random.choice(np.arange(env.observation_space.n))
print(state)

# reverse lookup in state_dict: first key whose stored index equals `state`
key = [candidate for candidate in state_dict if state_dict[candidate] == state][0]

# every entry except the last is a board cell; the last entry is the turn marker
board_cells = np.array(key[:-1]).reshape(3,3)
print(board_cells)
print("Turn was:", key[-1])

# Q-values of the nine actions, laid out like the board
q_values = np.round(qagent.qtable[state].reshape(3,3),1)
print(q_values)


8603
[[2 0 0]
 [2 1 1]
 [1 0 0]]
Turn was: 1
[[0.  7.1 5.4]
 [0.  0.  0. ]
 [0.  7.1 7.1]]


In [10]:
# play a single test game with the loaded Q-table; the recorded transcript below
# labels the two sides "Human" and "Agent"
play_tictactoe(env, qtable, state_dict, num_test_games=1)

Agent beginns
--------------------
╒═══╤═══╤═══╕
│ - │ - │ - │
├───┼───┼───┤
│ - │ - │ - │
├───┼───┼───┤
│ - │ - │ - │
╘═══╧═══╧═══╛
--------------------
Move Human
Action: 4
╒═══╤═══╤═══╕
│ - │ - │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ - │ - │ - │
╘═══╧═══╧═══╛
-1


--------------------
move Agent
Action: 0
╒═══╤═══╤═══╕
│ O │ - │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ - │ - │ - │
╘═══╧═══╧═══╛


--------------------
Move Human
Action: 1
╒═══╤═══╤═══╕
│ O │ X │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ - │ - │ - │
╘═══╧═══╧═══╛
-1


--------------------
move Agent
Action: 6
╒═══╤═══╤═══╕
│ O │ X │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ O │ - │ - │
╘═══╧═══╧═══╛


--------------------
Move Human
Action: 7
╒═══╤═══╤═══╕
│ O │ X │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ O │ X │ - │
╘═══╧═══╧═══╛
-1


--------------------
move Agent
Action: 3
╒═══╤═══╤═══╕
│ O │ X │ - │
├───┼───┼───┤
│ O │ X │ - │
├───┼───┼───┤
│ O │ X │ - │
╘═══╧═══╧═══╛
********