## RL agent Q-learning for TicTacToe env

The [Tic-Tac-Toe](https://github.com/MauroLuzzatto/OpenAI-Gym-TicTacToe-Environment) environment is a simple game environment for training reinforcement learning agents. This notebook contains an implementation of Q-learning with an epsilon-greedy strategy for the TicTacToe environment.

In [4]:
# load the python modules
import time
import sys
import warnings

import gym
import numpy as np
from tqdm import tqdm
import gym_TicTacToe

from src.qagent import QLearningAgent
from src.play_tictactoe import play_tictactoe, play_tictactoe_with_random

from src.utils import (
    create_state_dictionary,
    reshape_state,
    save_qtable,
    load_qtable
)

# Suppress every warning for the rest of the session, but only when the
# interpreter was started without explicit -W options (sys.warnoptions is empty
# in that case), so a user-supplied warning configuration is left alone.
if not sys.warnoptions:
    warnings.simplefilter("ignore")

In [5]:
class Player:
    """One side of a tic-tac-toe match, identified by its marker colour."""

    def __init__(self, color, episodes: int):
        # marker value that is sent to the environment together with each action
        self.color = color
        self.name = "Player {}".format(color)
        # one zero-initialised slot per training episode
        self.reward_array = np.zeros(shape=episodes)
        self.reset_reward()

    def reset_reward(self):
        """Set the running reward back to zero."""
        self.reward = 0

In [6]:

# initialize the tictactoe environment
# NOTE(review): `small` and `large` are presumably the reward magnitudes used by
# gym_TicTacToe for an ordinary move and for a decisive one — confirm against
# the environment's documentation.
env = gym.envs.make("TTT-v0", small=-1, large=10)

In [7]:
# Lookup from a board-plus-turn configuration to an integer state index.
# Its size and the environment's action count are the two dimensions handed
# to the Q-learning agent.
state_dict = create_state_dictionary()
state_size = len(state_dict)
action_size = env.action_space.n

Number of legal states: 12092


In [8]:
# set training parameters
# number of self-play games run by the training loop
episodes = 660_000
# upper bound on moves per game: the 3x3 board has 9 cells
max_steps = 9

In [9]:
# Epsilon-greedy schedule passed to QLearningAgent.
# NOTE(review): the epsilons logged during training (0.78 at episode 25_000,
# 0.61 at 50_000) are consistent with an exponential decay from max_epsilon
# towards min_epsilon at `decay_rate` per episode — confirm in src/qagent.py.
exploration_parameters = dict(
    max_epsilon=1.0,
    min_epsilon=0.0,
    decay_rate=1e-5,
)

In [10]:
# NOTE(review): the training cell builds a fresh QLearningAgent
# (learning_rate=0.8, gamma=0.9) and rebinds `qagent` before any game is played,
# so the learning_rate/gamma chosen here never influence the trained Q-table.
qagent = QLearningAgent(exploration_parameters, state_size, action_size, learning_rate=0.1, gamma=0.99)

In [11]:

def play(
    qagent: QLearningAgent,
    player: Player,
    state: int,
    action_space: np.ndarray,
    last_turn: bool,
    draw_bonus: float = 7,
) -> tuple:
    """Let ``player`` make one move chosen by ``qagent`` and update the Q-table.

    Uses the module-level ``env``, ``state_dict`` and ``reshape_state``.

    Args:
        qagent: agent that selects the action and owns the Q-table being updated.
        player: the side to move; its ``color`` is sent to the environment.
        state: integer index of the current state (a row of ``qagent.qtable``).
        action_space: array of actions still available in this episode.
        last_turn: True when this is the final move of the episode.
        draw_bonus: amount added to the reward when ``last_turn`` is set and the
            environment did not report the episode as done. Defaults to 7, the
            value that was previously hard-coded.

    Returns:
        ``(new_state, action_space, done)``: the integer index of the resulting
        state, the action array with the chosen action removed, and the ``done``
        flag returned by ``env.step``.
    """
    action = qagent.get_action(state, action_space)

    # the chosen action is no longer available for the rest of the episode
    action_space = action_space[action_space != action]

    new_state, reward, done, _ = env.step((action, player.color))

    # Final move without a terminal signal from the environment (presumably a
    # draw — confirm against the environment): top up the reward.
    if last_turn and not done:
        reward += draw_bonus

    # TODO: maybe should change a marker after this agent turn
    # The state key is the board followed by the colour of the player who just moved.
    new_state = state_dict[reshape_state(np.append(new_state, player.color))]

    qagent.qtable[state, action] = qagent.update_qtable(
        state, new_state, action, reward, done
    )
    return new_state, action_space, done

In [12]:
def play_random(qagent:QLearningAgent, player: Player, state: int, action_space: np.array) -> tuple:
    """Make a random move for ``player`` without updating any Q-table.

    ``qagent`` and the incoming ``state`` are accepted but not read here.
    Uses the module-level ``env``, ``state_dict`` and ``reshape_state``.

    Returns:
        ``(new_state, action_space, done)``: the integer index of the resulting
        state, the action array with the chosen action removed, and the ``done``
        flag returned by ``env.step``.
    """
    chosen = np.random.choice(action_space)
    remaining = action_space[action_space != chosen]
    board, _reward, done, _ = env.step((chosen, player.color))
    # state key: the board followed by the colour of the player who just moved
    encoded = state_dict[reshape_state(np.append(board, player.color))]
    return encoded, remaining, done

In [13]:
# Visit counter with one row per state index; the training loop increments the
# row of every state it reaches, and a later cell uses it to report coverage.
visited_states = np.zeros((state_size, 1))

In [14]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline
import random

start_time = time.time()

player_1 = Player(color=1, episodes=episodes)
player_2 = Player(color=2, episodes=episodes)

track_progress = np.zeros(episodes)

# win rate against the random opponent, one entry per evaluation
win_history = []

rewards = []

# Hyperparameter notes from earlier runs:
# lr = 0.4, gamma = 0.9, winrate = 0.64
# Learning rate: 0.5, Win rate: 0.72, Gamma: 0.9
# Learning rate: 0.6, Win rate: 0.5, Gamma: 0.9
# Learning rate: 0.7, Win rate: 0.66, Gamma: 0.9
# Learning rate: 0.65, Win rate: 0.76, Gamma: 0.8
# best
# Learning rate: 0.8, Win rate: 0.8, Gamma: 0.9
lear_rate = 0.8
gamma = 0.9

# evaluate and log every `eval_interval` episodes, using `num_games` test games
eval_interval = 25_000
num_games = 50

# fresh agent for this run; both players share (and update) its single Q-table
qagent = QLearningAgent(exploration_parameters, state_size, action_size, learning_rate=lear_rate, gamma=gamma)

for episode in tqdm(range(episodes)):
    action_space = np.arange(9)
    player_1.reset_reward()
    player_2.reset_reward()

    # Randomise who opens the game. The initial state is tagged with `start`
    # and the step counter begins at `start`, so with the parity rule below
    # start=1 lets player 2 move first and start=2 lets player 1 move first.
    start = np.random.choice([1,2])

    state, _ = env.reset()
    state = state_dict[reshape_state(np.append(state, start))]

    final_step = max_steps + start - 1
    for _step in range(start, final_step + 1):
        # play() adds a bonus when the final move does not end the game
        last_turn = _step == final_step
        # even steps belong to player 1, odd steps to player 2
        mover = player_1 if _step % 2 == 0 else player_2
        state, action_space, done = play(qagent, mover, state, action_space, last_turn)
        visited_states[state] += 1
        if done:
            break

    # reduce epsilon for the exploration-exploitation tradeoff
    qagent.update_epsilon(episode)

    # periodically check how good the agent is and report progress
    if episode % eval_interval == 0:
        cur_win_rate, reward = play_tictactoe_with_random(env, qagent.qtable, state_dict, num_test_games=num_games)
        win_rate = sum(cur_win_rate) / num_games
        win_history.append(win_rate)
        print("WinRate:", win_rate)

        sum_q_table = np.sum(qagent.qtable)
        time_passed = round((time.time() - start_time) / 60.0, 2)

        print(
            f"episode: {episode}, \
            epsilon: {round(qagent.epsilon, 2)}, \
            sum q-table: {sum_q_table}, \
            elapsed time [min]: {time_passed},  \
            done [%]: {episode / episodes * 100} \
            "
        )


  0%|          | 496/660000 [00:00<04:23, 2504.97it/s]

WinRate: 0.56
episode: 0,             epsilon: 1.0,             sum q-table: 21.949572621298703,             elapsed time [min]: 0.0,              done [%]: 0.0             


  4%|▍         | 25435/660000 [00:08<03:51, 2740.26it/s]

WinRate: 0.28
episode: 25000,             epsilon: 0.78,             sum q-table: 172595.70168621192,             elapsed time [min]: 0.14,              done [%]: 3.787878787878788             


  8%|▊         | 50303/660000 [00:17<03:54, 2601.15it/s]

WinRate: 0.8
episode: 50000,             epsilon: 0.61,             sum q-table: 205213.750002212,             elapsed time [min]: 0.29,              done [%]: 7.575757575757576             


 11%|█▏        | 75486/660000 [00:27<03:54, 2494.43it/s]

WinRate: 0.88
episode: 75000,             epsilon: 0.47,             sum q-table: 209727.00811511706,             elapsed time [min]: 0.45,              done [%]: 11.363636363636363             


 15%|█▌        | 100356/660000 [00:37<03:45, 2484.60it/s]

WinRate: 0.68
episode: 100000,             epsilon: 0.37,             sum q-table: 210861.13928443642,             elapsed time [min]: 0.61,              done [%]: 15.151515151515152             


 19%|█▉        | 125458/660000 [00:46<03:40, 2427.65it/s]

WinRate: 0.64
episode: 125000,             epsilon: 0.29,             sum q-table: 211226.18138236215,             elapsed time [min]: 0.78,              done [%]: 18.939393939393938             


 23%|██▎       | 150353/660000 [00:56<03:29, 2429.99it/s]

WinRate: 0.28
episode: 150000,             epsilon: 0.22,             sum q-table: 211411.6950017766,             elapsed time [min]: 0.94,              done [%]: 22.727272727272727             


 27%|██▋       | 175383/660000 [01:06<03:20, 2420.98it/s]

WinRate: 0.6
episode: 175000,             epsilon: 0.17,             sum q-table: 211509.28419710314,             elapsed time [min]: 1.11,              done [%]: 26.515151515151516             


 30%|███       | 200370/660000 [01:16<03:07, 2449.40it/s]

WinRate: 0.76
episode: 200000,             epsilon: 0.14,             sum q-table: 211543.14668774858,             elapsed time [min]: 1.27,              done [%]: 30.303030303030305             


 34%|███▍      | 225315/660000 [01:26<02:57, 2453.75it/s]

WinRate: 0.28
episode: 225000,             epsilon: 0.11,             sum q-table: 211561.08788998582,             elapsed time [min]: 1.44,              done [%]: 34.090909090909086             


 38%|███▊      | 250467/660000 [01:36<02:48, 2434.15it/s]

WinRate: 0.68
episode: 250000,             epsilon: 0.08,             sum q-table: 211573.34717538272,             elapsed time [min]: 1.6,              done [%]: 37.878787878787875             


 42%|████▏     | 275339/660000 [01:46<02:39, 2404.39it/s]

WinRate: 0.8
episode: 275000,             epsilon: 0.06,             sum q-table: 211576.62945478418,             elapsed time [min]: 1.77,              done [%]: 41.66666666666667             


 45%|████▌     | 300299/660000 [01:56<02:27, 2440.29it/s]

WinRate: 0.72
episode: 300000,             epsilon: 0.05,             sum q-table: 211577.60171361954,             elapsed time [min]: 1.93,              done [%]: 45.45454545454545             


 49%|████▉     | 325346/660000 [02:06<02:16, 2447.78it/s]

WinRate: 0.64
episode: 325000,             epsilon: 0.04,             sum q-table: 211577.98585052762,             elapsed time [min]: 2.1,              done [%]: 49.24242424242424             


 53%|█████▎    | 350398/660000 [02:15<02:05, 2464.64it/s]

WinRate: 0.64
episode: 350000,             epsilon: 0.03,             sum q-table: 211578.33370800092,             elapsed time [min]: 2.26,              done [%]: 53.03030303030303             


 57%|█████▋    | 375251/660000 [02:25<02:00, 2368.12it/s]

WinRate: 0.76
episode: 375000,             epsilon: 0.02,             sum q-table: 211578.41968107535,             elapsed time [min]: 2.43,              done [%]: 56.81818181818182             


 61%|██████    | 400320/660000 [02:36<01:49, 2379.51it/s]

WinRate: 0.56
episode: 400000,             epsilon: 0.02,             sum q-table: 211578.46373153286,             elapsed time [min]: 2.6,              done [%]: 60.60606060606061             


 64%|██████▍   | 425418/660000 [02:46<01:38, 2378.72it/s]

WinRate: 0.4
episode: 425000,             epsilon: 0.01,             sum q-table: 211578.4780909603,             elapsed time [min]: 2.77,              done [%]: 64.39393939393939             


 68%|██████▊   | 450277/660000 [02:56<01:28, 2361.11it/s]

WinRate: 0.72
episode: 450000,             epsilon: 0.01,             sum q-table: 211578.52539201817,             elapsed time [min]: 2.94,              done [%]: 68.18181818181817             


 72%|███████▏  | 475380/660000 [03:06<01:15, 2451.13it/s]

WinRate: 0.48
episode: 475000,             epsilon: 0.01,             sum q-table: 211578.5254100886,             elapsed time [min]: 3.1,              done [%]: 71.96969696969697             


 76%|███████▌  | 500300/660000 [03:16<01:05, 2438.53it/s]

WinRate: 0.72
episode: 500000,             epsilon: 0.01,             sum q-table: 211578.5343368829,             elapsed time [min]: 3.27,              done [%]: 75.75757575757575             


 80%|███████▉  | 525322/660000 [03:26<00:55, 2438.39it/s]

WinRate: 0.6
episode: 525000,             epsilon: 0.01,             sum q-table: 211578.53475012069,             elapsed time [min]: 3.43,              done [%]: 79.54545454545455             


 83%|████████▎ | 550369/660000 [03:36<00:44, 2445.39it/s]

WinRate: 0.48
episode: 550000,             epsilon: 0.0,             sum q-table: 211578.5348173431,             elapsed time [min]: 3.6,              done [%]: 83.33333333333334             


 87%|████████▋ | 575396/660000 [03:46<00:34, 2427.79it/s]

WinRate: 0.8
episode: 575000,             epsilon: 0.0,             sum q-table: 211578.53481735097,             elapsed time [min]: 3.77,              done [%]: 87.12121212121212             


 91%|█████████ | 600401/660000 [03:56<00:24, 2415.43it/s]

WinRate: 0.48
episode: 600000,             epsilon: 0.0,             sum q-table: 211578.53481735097,             elapsed time [min]: 3.93,              done [%]: 90.9090909090909             


 95%|█████████▍| 625389/660000 [04:06<00:14, 2331.59it/s]

WinRate: 0.72
episode: 625000,             epsilon: 0.0,             sum q-table: 211578.5351303892,             elapsed time [min]: 4.1,              done [%]: 94.6969696969697             


 99%|█████████▊| 650461/660000 [04:16<00:04, 2330.66it/s]

WinRate: 0.64
episode: 650000,             epsilon: 0.0,             sum q-table: 211578.53513210372,             elapsed time [min]: 4.27,              done [%]: 98.48484848484848             


100%|██████████| 660000/660000 [04:20<00:00, 2534.42it/s]


In [16]:
# percentage of state indices that were reached at least once during training
n_visited = np.sum(visited_states > 0)
print("Percent:", 100 * n_visited / visited_states.shape[0])

Percent: 90.58881905391995


In [19]:
# Final evaluation of the trained Q-table over a larger batch of test games.
# The result is appended to win_history, so re-running this cell adds another entry.
num_games = 1000
cur_win_rate, _ = play_tictactoe_with_random(env, qagent.qtable, state_dict, num_test_games=num_games)
final_win_rate = sum(cur_win_rate) / num_games
win_history.append(final_win_rate)
print("WinRate:", final_win_rate)

WinRate: 0.646


In [77]:
# Persist the trained Q-table under the name "q_table_best" in 'tables'
# (the helper reports "q_table_best.npy saved!").
qtable = qagent.qtable
save_qtable(qtable, 'tables', "q_table_best")

q_table_best.npy saved!


In [4]:
# Reload the Q-table stored as "q_table_best"; the game cell at the end of the
# notebook plays against this loaded table rather than against qagent.qtable.
q_table = load_qtable('tables', "q_table_best")

In [38]:
# Sanity-check the learned Q-table on one randomly drawn state index.

all_state_ids = np.arange(env.observation_space.n)
state = np.random.choice(all_state_ids)
print(state)

# reverse lookup: the first state key whose index equals the sampled one
matching_keys = [candidate for candidate in state_dict if state_dict[candidate] == state]
key = matching_keys[0]
# a key is the nine board cells followed by the turn marker
board, turn_marker = key[:-1], key[-1]
print(np.array(board).reshape(3,3))
print("Turn was:", turn_marker)
# Q-values of that state, laid out like the board and rounded to one decimal
print(np.round(qagent.qtable[state].reshape(3,3),1))

6007
[[0 0 0]
 [0 2 0]
 [0 1 2]]
Turn was: 2
[[3.8 7.1 7.1]
 [6.7 0.  6.8]
 [7.1 0.  0. ]]


In [5]:
# Play a single game against the loaded Q-table; the printed transcript
# alternates "Move Human" and "move Agent" turns and shows the board after each.
play_tictactoe(env, q_table, state_dict, num_test_games=1)

Agent beginns
--------------------
╒═══╤═══╤═══╕
│ - │ - │ - │
├───┼───┼───┤
│ - │ - │ - │
├───┼───┼───┤
│ - │ - │ - │
╘═══╧═══╧═══╛
--------------------
Move Human
Action: 4
-1


╒═══╤═══╤═══╕
│ - │ - │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ - │ - │ - │
╘═══╧═══╧═══╛
--------------------
move Agent
Action: 6


╒═══╤═══╤═══╕
│ - │ - │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ O │ - │ - │
╘═══╧═══╧═══╛
--------------------
Move Human
Action: 0
-1


╒═══╤═══╤═══╕
│ X │ - │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ O │ - │ - │
╘═══╧═══╧═══╛
--------------------
move Agent
Action: 8


╒═══╤═══╤═══╕
│ X │ - │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ O │ - │ O │
╘═══╧═══╧═══╛
--------------------
Move Human
Action: 1
-1


╒═══╤═══╤═══╕
│ X │ X │ - │
├───┼───┼───┤
│ - │ X │ - │
├───┼───┼───┤
│ O │ - │ O │
╘═══╧═══╧═══╛
--------------------
move Agent
Action: 5


╒═══╤═══╤═══╕
│ X │ X │ - │
├───┼───┼───┤
│ - │ X │ O │
├───┼───┼───┤
│ O │ - │ O │
╘═══╧═══╧═══╛
------