## Q-learning RL agent for the TicTacToe environment

The [Tic-Tac-Toe](https://github.com/MauroLuzzatto/OpenAI-Gym-TicTacToe-Environment) environment is a simple game environment for training reinforcement learning agents. This notebook contains an implementation of Q-learning with an epsilon-greedy exploration strategy for the TicTacToe environment.

In [1]:
# load the python modules
import time
import sys
import warnings

import gym
import numpy as np
from tqdm import tqdm
import gym_TicTacToe

from src.qagent import QLearningAgent
from src.play_tictactoe import play_tictactoe, play_tictactoe_with_random

from src.utils import (
    create_state_dictionary,
    reshape_state,
    save_qtable,
    load_qtable
)

# Silence warnings for a cleaner notebook, unless -W options were passed
# explicitly on the command line (those take precedence).
user_set_warning_filters = bool(sys.warnoptions)
if not user_set_warning_filters:
    warnings.simplefilter("ignore")

In [2]:
class Player:
    """One side of a TicTacToe game, identified by its colour.

    Attributes:
        color: the player's colour / board marker.
        name: human-readable label, ``"Player <color>"``.
        reward_array: zero-initialised array with one slot per episode.
        reward: running reward, reset to 0 by ``reset_reward``.
    """

    def __init__(self, color, episodes: int):
        self.color = color
        self.name = f"Player {color}"
        # one slot per training episode
        self.reward_array = np.zeros(episodes)
        self.reset_reward()

    def reset_reward(self):
        """Set the running reward back to zero."""
        self.reward = 0

In [3]:

# Build the TicTacToe gym environment registered as "TTT-v0".
# `small` / `large` are forwarded unchanged to the environment
# (presumably reward magnitudes — confirm in gym_TicTacToe).
env = gym.envs.make(
    "TTT-v0",
    small=-1,
    large=10,
)

In [4]:
# Map every legal board encoding (board + turn marker) to a Q-table row index.
state_dict = create_state_dictionary()
# One Q-table row per legal state, one column per board cell / action.
state_size = len(state_dict)
action_size = env.action_space.n

Number of legal states: 12092


In [5]:
# Training schedule
episodes = 960_000  # number of self-play games
max_steps = 9       # a 3x3 board has 9 cells, so a game lasts at most 9 moves

In [6]:
# Epsilon-greedy schedule handed to QLearningAgent.
exploration_parameters = dict(
    max_epsilon=1.0,
    min_epsilon=0.0,
    decay_rate=0.00001,
)

In [7]:
# Initial agent (re-created later with tuned hyper-parameters).
qagent = QLearningAgent(
    exploration_parameters,
    state_size,
    action_size,
    learning_rate=0.1,
    gamma=0.99,
)

In [8]:
def check_for_potential_lose(state, color: int) -> bool:
    """Check whether, after the agent's move, the enemy threatens to win.

    A threat is any row, column or diagonal holding two of the enemy's marks
    and one empty cell.

    Args:
        state: flat board of 9 cells (0 = empty, 1 / 2 = player colours).
        color (int): colour of the player's enemy.

    Returns:
        bool: True if at least one such threatening line exists.
    """
    enemy_color = color
    player_color = 2 if enemy_color == 1 else 1

    # Re-encode as enemy = +1, player = -1, empty = 0. Force a signed dtype:
    # the raw board may be unsigned, where assigning -1 wraps (NumPy 1.x) or
    # raises OverflowError (NumPy 2.x).
    board = np.array(state, dtype=int).reshape(3, 3)
    board[board == player_color] = -1
    board[board == enemy_color] = 1

    # With values in {-1, 0, 1}, a line sums to 2 only for (1, 1, 0):
    # two enemy marks plus one free cell.
    line_sums = np.concatenate(
        (
            board.sum(axis=0),  # columns
            board.sum(axis=1),  # rows
            [np.trace(board), np.trace(np.fliplr(board))],  # both diagonals
        )
    )
    return bool(np.any(line_sums == 2))

In [9]:
# Sanity check: the middle column (cells 1, 4, 7) holds 1, 1 and an empty
# cell, so player 1 threatens to win.
state = np.array([
    0, 1, 2,
    0, 1, 2,
    2, 0, 1,
])
check_for_potential_lose(state, color=1)

True

In [10]:
def play(
    qagent: QLearningAgent,
    player_color,
    state: int,
    action_space: np.ndarray,
    last_turn: bool,
    potential_lose_penalty: float = 7,
) -> tuple:
    """Let the Q-agent move for `player_color` and update its Q-table.

    Uses the module-level `env` and `state_dict`.

    Args:
        qagent: agent whose `get_action` picks the move and whose `qtable`
            entry for (state, action) is overwritten with `update_qtable`'s result.
        player_color: colour (1 or 2) of the player making the move.
        state: index of the current state in `state_dict`.
        action_space: array of still-free board cells.
        last_turn: whether this is the episode's final move. Currently unused
            (draw-reward shaping was disabled); kept for call compatibility.
        potential_lose_penalty: subtracted from the environment reward when
            the game is not over and, after this move, the opponent has two in
            a line with the third cell empty.

    Returns:
        tuple: (new state index, remaining action space, done flag).
    """
    action = qagent.get_action(state, action_space)
    # The chosen cell is no longer available.
    action_space = action_space[action_space != action]

    board, reward, done, _ = env.step((action, player_color))

    if not done:
        opponent_color = 2 if player_color == 1 else 1
        # Reward shaping: punish moves that leave an immediate threat open.
        if check_for_potential_lose(board, opponent_color):
            reward -= potential_lose_penalty

    # The turn marker appended to the board is the colour of the player who
    # just moved (the training loop seeds reset states with `start`, whose
    # player moves second, so the convention is the same).
    new_state = state_dict[reshape_state(np.append(board, player_color))]

    qagent.qtable[state, action] = qagent.update_qtable(
        state, new_state, action, reward, state_dict
    )
    return new_state, action_space, done

In [11]:
def play_random(qagent:QLearningAgent, player_color, state: int, action_space: np.array) -> tuple:
    """Make a uniformly random move for `player_color` without any learning.

    `qagent` and `state` are accepted for signature parity with `play` but are
    not used. Relies on the module-level `env` and `state_dict`.

    Returns:
        tuple: (new state index, remaining action space, done flag).
    """
    chosen = np.random.choice(action_space)
    remaining = action_space[action_space != chosen]
    board, _reward, done, _info = env.step((chosen, player_color))
    # Encode the board plus the mover's colour as the state key.
    encoded = reshape_state(np.append(board, player_color))
    return state_dict[encoded], remaining, done

In [12]:
# Per-state visit counter (one row per Q-table state), filled during training.
visited_states = np.zeros(shape=(state_size, 1))

In [13]:
# Tuned hyper-parameters for the final agent.
lear_rate, gamma = 0.8, 0.9
qagent = QLearningAgent(
    exploration_parameters,
    state_size,
    action_size,
    learning_rate=lear_rate,
    gamma=gamma,
)

In [14]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline
import random

start_time = time.time()

player_1 = Player(color=1, episodes=episodes)
player_2 = Player(color=2, episodes=episodes)

win_history = []

rewards = []
# lr = 0.4, gamma = 0.9, winrate = 0.64
# Learning rate: 0.5, Win rate: 0.72, Gamma: 0.9
# Learning rate: 0.6, Win rate: 0.5, Gamma: 0.9
# Learning rate: 0.7, Win rate: 0.66, Gamma: 0.9
# Learning rate: 0.65, Win rate: 0.76, Gamma: 0.8

# best 
# Learning rate: 0.8, Win rate: 0.8, Gamma: 0.9

# qagent_old = qagent
for episode in tqdm(range(episodes)):
    last_turn = False
    action_space = np.arange(9)
    player_1.reset_reward()
    player_2.reset_reward()

    # randomly change the order players
    start = np.random.choice([1,2])

    state, _ = env.reset()
    state = np.append(state, start)
    state = state_dict[reshape_state(state)]
    # if episode % 10_000 == 0:
    #     save_qtable(qagent.qtable, 'tables', "q_table_old")

    for _step in range(start, max_steps + start):
        if _step == max_steps + start - 1:
            last_turn = True
        # change a turn
        if _step % 2 == 0:
            #state, action_space, done = play_random(qagent, player_1.color, state, action_space)
            #qagent_old.qtable = load_qtable('tables', "q_table_old")
            state, action_space, done = play(qagent, player_1.color, state, action_space, last_turn)
        else:
            state, action_space, done = play(qagent, player_2.color, state, action_space, last_turn)
        visited_states[state] += 1
        if done == True:
            break

    # reduce epsilon for exporation-exploitation tradeoff
    qagent.update_epsilon(episode)

    #cur_win_rate, reward = play_tictactoe_with_random(env, qagent.qtable, state_dict, num_test_games=100)

    #check how good is agent
    if episode % 25_000 == 0:
        num_games = 50
        cur_win_rate, reward = play_tictactoe_with_random(env, qagent.qtable, state_dict, num_test_games=num_games)
        win_history.append(sum(cur_win_rate)/num_games)
        print("WinRate:", sum(cur_win_rate)/num_games)
        # rewards.append(reward)
        # clear_output(True)
        # # plt.title('eps = {:e}, mean reward = {:.1f}'.format(agent.epsilon, np.mean(rewards[-10:])))
        # plt.plot(rewards)
        # plt.show()
    if episode % 25_000 == 0:

        sum_q_table = np.sum(qagent.qtable)
        time_passed = round((time.time() - start_time) / 60.0, 2)

        print(
            f"episode: {episode}, \
            epsilon: {round(qagent.epsilon, 2)}, \
            sum q-table: {sum_q_table}, \
            elapsed time [min]: {time_passed},  \
            done [%]: {episode / episodes * 100} \
            "
        )


  0%|          | 0/960000 [00:00<?, ?it/s]

  0%|          | 135/960000 [00:00<32:42, 489.11it/s] 

WinRate: 0.32
episode: 0,             epsilon: 1.0,             sum q-table: 32.8772405069821,             elapsed time [min]: 0.0,              done [%]: 0.0             


  3%|▎         | 25181/960000 [00:42<25:58, 599.85it/s]

WinRate: 0.56
episode: 25000,             epsilon: 0.78,             sum q-table: 10053.245437175632,             elapsed time [min]: 0.7,              done [%]: 2.604166666666667             


  5%|▌         | 50142/960000 [01:20<24:08, 627.99it/s]

WinRate: 0.52
episode: 50000,             epsilon: 0.61,             sum q-table: 24304.781875876615,             elapsed time [min]: 1.33,              done [%]: 5.208333333333334             


  8%|▊         | 75157/960000 [01:55<21:02, 700.73it/s]

WinRate: 0.52
episode: 75000,             epsilon: 0.47,             sum q-table: 27630.54760220734,             elapsed time [min]: 1.93,              done [%]: 7.8125             


 10%|█         | 100094/960000 [02:29<20:46, 689.65it/s]

WinRate: 0.6
episode: 100000,             epsilon: 0.37,             sum q-table: 28593.042713084873,             elapsed time [min]: 2.48,              done [%]: 10.416666666666668             


 13%|█▎        | 125149/960000 [03:01<18:43, 743.18it/s]

WinRate: 0.6
episode: 125000,             epsilon: 0.29,             sum q-table: 28895.677182157342,             elapsed time [min]: 3.01,              done [%]: 13.020833333333334             


 16%|█▌        | 150143/960000 [03:32<17:24, 775.58it/s]

WinRate: 0.56
episode: 150000,             epsilon: 0.22,             sum q-table: 28994.798990368774,             elapsed time [min]: 3.53,              done [%]: 15.625             


 18%|█▊        | 175200/960000 [04:01<16:24, 797.16it/s]

WinRate: 0.64
episode: 175000,             epsilon: 0.17,             sum q-table: 29038.99814879319,             elapsed time [min]: 4.03,              done [%]: 18.229166666666664             


 21%|██        | 200199/960000 [04:31<15:25, 821.31it/s]

WinRate: 0.56
episode: 200000,             epsilon: 0.14,             sum q-table: 29050.92567414544,             elapsed time [min]: 4.51,              done [%]: 20.833333333333336             


 23%|██▎       | 225137/960000 [04:59<15:15, 802.60it/s]

WinRate: 0.52
episode: 225000,             epsilon: 0.11,             sum q-table: 29055.029605458054,             elapsed time [min]: 4.99,              done [%]: 23.4375             


 26%|██▌       | 250113/960000 [05:27<14:08, 836.63it/s]

WinRate: 0.72
episode: 250000,             epsilon: 0.08,             sum q-table: 29058.250216776894,             elapsed time [min]: 5.46,              done [%]: 26.041666666666668             


 29%|██▊       | 275209/960000 [05:55<13:39, 835.87it/s]

WinRate: 0.64
episode: 275000,             epsilon: 0.06,             sum q-table: 29058.46418815173,             elapsed time [min]: 5.92,              done [%]: 28.645833333333332             


 31%|███▏      | 300121/960000 [06:22<13:23, 821.41it/s]

WinRate: 0.8
episode: 300000,             epsilon: 0.05,             sum q-table: 29058.558424565363,             elapsed time [min]: 6.38,              done [%]: 31.25             


 34%|███▍      | 325136/960000 [06:50<12:02, 879.23it/s]

WinRate: 0.68
episode: 325000,             epsilon: 0.04,             sum q-table: 29058.55846648893,             elapsed time [min]: 6.83,              done [%]: 33.85416666666667             


 36%|███▋      | 350214/960000 [07:17<11:24, 891.20it/s]

WinRate: 0.68
episode: 350000,             epsilon: 0.03,             sum q-table: 29058.559094734206,             elapsed time [min]: 7.29,              done [%]: 36.45833333333333             


 39%|███▉      | 375261/960000 [07:46<10:35, 920.63it/s]

WinRate: 0.52
episode: 375000,             epsilon: 0.02,             sum q-table: 29058.55909499938,             elapsed time [min]: 7.77,              done [%]: 39.0625             


 42%|████▏     | 400136/960000 [08:15<13:51, 673.51it/s]

WinRate: 0.76
episode: 400000,             epsilon: 0.02,             sum q-table: 29063.787855371997,             elapsed time [min]: 8.26,              done [%]: 41.66666666666667             


 44%|████▍     | 425154/960000 [08:43<10:24, 857.08it/s]

WinRate: 0.6
episode: 425000,             epsilon: 0.01,             sum q-table: 29063.787855371997,             elapsed time [min]: 8.73,              done [%]: 44.27083333333333             


 47%|████▋     | 450201/960000 [09:10<09:35, 885.70it/s]

WinRate: 0.52
episode: 450000,             epsilon: 0.01,             sum q-table: 29063.787855371997,             elapsed time [min]: 9.17,              done [%]: 46.875             


 50%|████▉     | 475286/960000 [09:37<08:59, 898.59it/s]

WinRate: 0.6
episode: 475000,             epsilon: 0.01,             sum q-table: 29063.787855371997,             elapsed time [min]: 9.62,              done [%]: 49.47916666666667             


 52%|█████▏    | 500127/960000 [10:04<09:13, 830.31it/s] 

WinRate: 0.56
episode: 500000,             epsilon: 0.01,             sum q-table: 29063.787855371997,             elapsed time [min]: 10.07,              done [%]: 52.083333333333336             


 55%|█████▍    | 525158/960000 [10:30<08:16, 876.01it/s] 

WinRate: 0.8
episode: 525000,             epsilon: 0.01,             sum q-table: 29063.787855371997,             elapsed time [min]: 10.51,              done [%]: 54.6875             


 57%|█████▋    | 550283/960000 [10:57<07:22, 926.31it/s] 

WinRate: 0.6
episode: 550000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 10.95,              done [%]: 57.291666666666664             


 60%|█████▉    | 575235/960000 [11:24<07:16, 880.54it/s]

WinRate: 0.56
episode: 575000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 11.4,              done [%]: 59.895833333333336             


 63%|██████▎   | 600260/960000 [11:50<06:49, 877.99it/s] 

WinRate: 0.56
episode: 600000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 11.84,              done [%]: 62.5             


 65%|██████▌   | 625219/960000 [12:17<06:09, 904.84it/s]

WinRate: 0.32
episode: 625000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 12.29,              done [%]: 65.10416666666666             


 68%|██████▊   | 650133/960000 [12:43<05:50, 884.53it/s] 

WinRate: 0.76
episode: 650000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 12.73,              done [%]: 67.70833333333334             


 70%|███████   | 675185/960000 [13:10<05:23, 880.56it/s] 

WinRate: 0.72
episode: 675000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 13.16,              done [%]: 70.3125             


 73%|███████▎  | 700247/960000 [13:37<05:04, 853.62it/s]

WinRate: 0.84
episode: 700000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 13.61,              done [%]: 72.91666666666666             


 76%|███████▌  | 725239/960000 [14:00<03:36, 1083.16it/s]

WinRate: 0.6
episode: 725000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 14.0,              done [%]: 75.52083333333334             


 78%|███████▊  | 750201/960000 [14:23<03:24, 1027.52it/s]

WinRate: 0.64
episode: 750000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 14.39,              done [%]: 78.125             


 81%|████████  | 775287/960000 [14:46<02:57, 1040.37it/s]

WinRate: 0.6
episode: 775000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 14.77,              done [%]: 80.72916666666666             


 83%|████████▎ | 800221/960000 [15:09<02:35, 1027.11it/s]

WinRate: 0.68
episode: 800000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 15.15,              done [%]: 83.33333333333334             


 86%|████████▌ | 825268/960000 [15:32<02:05, 1076.89it/s]

WinRate: 0.52
episode: 825000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 15.53,              done [%]: 85.9375             


 89%|████████▊ | 850217/960000 [15:55<01:47, 1018.89it/s]

WinRate: 0.68
episode: 850000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 15.91,              done [%]: 88.54166666666666             


 91%|█████████ | 875270/960000 [16:17<01:21, 1037.35it/s]

WinRate: 0.68
episode: 875000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 16.3,              done [%]: 91.14583333333334             


 94%|█████████▍| 900231/960000 [16:40<00:56, 1050.89it/s]

WinRate: 0.32
episode: 900000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 16.67,              done [%]: 93.75             


 96%|█████████▋| 925255/960000 [17:03<00:32, 1072.73it/s]

WinRate: 0.76
episode: 925000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 17.05,              done [%]: 96.35416666666666             


 99%|█████████▉| 950328/960000 [17:26<00:09, 1067.81it/s]

WinRate: 0.56
episode: 950000,             epsilon: 0.0,             sum q-table: 29063.787855371997,             elapsed time [min]: 17.43,              done [%]: 98.95833333333334             


100%|██████████| 960000/960000 [17:35<00:00, 909.92it/s] 


In [15]:
# Share of Q-table states that were visited at least once during training.
n_states = visited_states.shape[0]
n_visited = np.sum(visited_states > 0)
print("Percent:", 100 * n_visited / n_states)

Percent: 90.58881905391995


In [18]:
# Final, larger evaluation against a random opponent.
num_games = 1000
cur_win_rate, _ = play_tictactoe_with_random(
    env, qagent.qtable, state_dict, num_test_games=num_games
)
final_win_rate = sum(cur_win_rate) / num_games
win_history.append(final_win_rate)
print("WinRate:", final_win_rate)

WinRate: 0.57


In [None]:
# Persist the trained Q-table under the name "q_table_best2".
qtable = qagent.qtable
save_qtable(
    qtable,
    'tables',
    "q_table_best2",
)

q_table_best2.npy saved!


In [None]:
# Reload the previously saved Q-table "q_table_best2".
qtable = load_qtable(
    'tables',
    "q_table_best2",
)

In [27]:
# Spot-check the learned Q-values for one random legal state.
# Sampling a key from `state_dict` directly avoids a linear reverse lookup and
# does not rely on env.observation_space.n matching the number of legal states
# (a sampled index with no matching state made the old `[0]` raise IndexError).
state_keys = list(state_dict)
key = state_keys[np.random.randint(len(state_keys))]
state = state_dict[key]
print(state)

state_pure = np.array(key[:-1])
print(state_pure.reshape(3, 3))
# The last element of the key is the turn marker appended to the board.
print("Turn was:", key[-1])
print(np.round(qagent.qtable[state].reshape(3, 3), 2))

# Best legal move: the free cell (value 0) with the highest Q-value.
action_space = np.where(state_pure == 0)[0]
best_action = max(action_space, key=lambda action: qagent.qtable[state, action])
print(best_action)
# action

6220
[[0 1 2]
 [0 1 0]
 [0 0 0]]
Turn was: 1
[[ 0.1  0.   0. ]
 [ 0.1  0.   0.1]
 [ 0.1 -1.   0.1]]
3


In [165]:
# Play one interactive game (agent vs. human) with the loaded Q-table.
play_tictactoe(
    env,
    qtable,
    state_dict,
    num_test_games=1,
)

Agent beginns
--------------------
╒═══╤═══╤═══╕
│ - │ - │ - │
├───┼───┼───┤
│ - │ - │ - │
├───┼───┼───┤
│ - │ - │ - │
╘═══╧═══╧═══╛
--------------------
move Agent
Action: 1
╒═══╤═══╤═══╕
│ - │ O │ - │
├───┼───┼───┤
│ - │ - │ - │
├───┼───┼───┤
│ - │ - │ - │
╘═══╧═══╧═══╛


--------------------
Move Human
Action: 0
╒═══╤═══╤═══╕
│ X │ O │ - │
├───┼───┼───┤
│ - │ - │ - │
├───┼───┼───┤
│ - │ - │ - │
╘═══╧═══╧═══╛
-1


--------------------
move Agent
Action: 6
╒═══╤═══╤═══╕
│ X │ O │ - │
├───┼───┼───┤
│ - │ - │ - │
├───┼───┼───┤
│ O │ - │ - │
╘═══╧═══╧═══╛


--------------------
Move Human
Action: 5
╒═══╤═══╤═══╕
│ X │ O │ - │
├───┼───┼───┤
│ - │ - │ X │
├───┼───┼───┤
│ O │ - │ - │
╘═══╧═══╧═══╛
-1


--------------------
move Agent
Action: 3
╒═══╤═══╤═══╕
│ X │ O │ - │
├───┼───┼───┤
│ O │ - │ X │
├───┼───┼───┤
│ O │ - │ - │
╘═══╧═══╧═══╛


--------------------
Move Human
Action: 4
╒═══╤═══╤═══╕
│ X │ O │ - │
├───┼───┼───┤
│ O │ X │ X │
├───┼───┼───┤
│ O │ - │ - │
╘═══╧═══╧═══╛
-1


------