In [1]:
import numpy as np


class TicTacToe:
    """Tic-tac-toe environment with a gym-like ``reset``/``step`` interface.

    Cells hold 0 (empty), 1 (X, the player/agent) or 2 (O, the opponent).
    The board is encoded as one integer state by reading the nine cells in
    row-major order as base-3 digits, least-significant digit first, which
    gives ``3**9`` distinct states.

    Rewards produced by ``calc_reward_and_done``:
        * every evaluated move costs ``STEP_REWARD`` (-0.1);
        * X owning a line adds ``WIN_BONUS`` (net 10.0);
        * O owning a line subtracts ``LOSS_PENALTY`` (net -10.1);
        * a full board without a winner subtracts ``DRAW_PENALTY`` (net -10.1).
    Trying to play an occupied cell yields ``INVALID_MOVE_REWARD`` (-1.1) and
    leaves the board untouched.
    """

    EMPTY = 0
    MARK_X = 1
    MARK_O = 2

    STEP_REWARD = -0.1
    WIN_BONUS = 10.1
    LOSS_PENALTY = 10.
    DRAW_PENALTY = 10.
    INVALID_MOVE_REWARD = -1.1

    # How each cell value is drawn by ``render``.
    _SYMBOLS = {0: ' ', 1: 'X', 2: 'O'}

    def __init__(self, policy=None):
        """Create the environment.

        Args:
            policy: optional object exposing ``act_greedy_sampled(state)``;
                only consulted when the built-in opponent is run with
                ``self_play_type="q_policy"``.
        """
        self.policy = policy

        self.reset()

    def reset(self):
        """Clear the board and return the initial state (0)."""
        self.board = np.zeros((3, 3), dtype=np.uint8)
        self.state = 0

        return self.state

    def check_board_is_full(self):
        """Return True when no empty cell is left."""
        return not bool((self.board == self.EMPTY).any())

    def _winner(self):
        """Return MARK_X / MARK_O if that mark owns a full line, else EMPTY.

        X is checked first; in a legally played game at most one side can
        own a line, so the order only matters for hand-built boards.
        """
        board = self.board
        lines = [board[i, :] for i in range(3)]
        lines += [board[:, j] for j in range(3)]
        lines.append(board[[0, 1, 2], [0, 1, 2]])  # main diagonal
        lines.append(board[[0, 1, 2], [2, 1, 0]])  # anti-diagonal

        for mark in (self.MARK_X, self.MARK_O):
            if any(bool((line == mark).all()) for line in lines):
                return mark
        return self.EMPTY

    def calc_reward_and_done(self):
        """Score the current board.

        Returns:
            (reward, done): see the class docstring for the reward values.
            The win/loss adjustment is applied once per game outcome, even
            when a single move completes two lines simultaneously.
        """
        reward = self.STEP_REWARD
        winner = self._winner()

        if winner == self.MARK_X:
            return reward + self.WIN_BONUS, True
        if winner == self.MARK_O:
            return reward - self.LOSS_PENALTY, True
        if self.check_board_is_full():
            return reward - self.DRAW_PENALTY, True

        return reward, False

    def calc_state(self):
        """Recompute, store and return the base-3 integer encoding of the board."""
        # Cast first: the board is uint8 and 2 * 3**8 does not fit in a byte.
        digits = self.board.astype(np.int64).ravel()
        self.state = int(sum(int(d) * 3 ** i for i, d in enumerate(digits)))

        return self.state

    def take_action(self, action, act_as_O):
        """Write a mark into flat cell index ``action`` (0-8) and refresh the state.

        No validity check is done here; callers are expected to have verified
        that the cell is empty.
        """
        self.board.flat[action] = self.MARK_O if act_as_O else self.MARK_X

        self.calc_state()

    def play_as_opponent(self, self_play, self_play_type):
        """Let the built-in opponent place one O on an empty cell.

        Args:
            self_play: when falsy, nothing happens.
            self_play_type: ``"random"`` picks uniformly among empty cells;
                ``"q_policy"`` asks ``self.policy.act_greedy_sampled`` once and
                falls back to a uniformly random empty cell if the suggested
                cell is occupied.

        Raises:
            ValueError: if ``self_play_type`` is not one of the two values above.
        """
        if not self_play or self.check_board_is_full():
            return

        empty_cells = np.flatnonzero(self.board == self.EMPTY)

        if self_play_type == "q_policy":
            opponent_action = int(self.policy.act_greedy_sampled(self.state))
            if opponent_action not in empty_cells:
                opponent_action = int(np.random.choice(empty_cells))
        elif self_play_type == "random":
            opponent_action = int(np.random.choice(empty_cells))
        else:
            raise ValueError(f"unknown self_play_type: {self_play_type!r}")

        self.take_action(opponent_action, act_as_O=True)

    def step(self, action, self_play=True, self_play_type="random", act_as_O=False):
        """Play ``action`` and, optionally, the opponent's reply.

        Args:
            action: flat cell index (0-8) to mark.
            self_play: if True the built-in opponent answers with an O.
            self_play_type: opponent strategy, see ``play_as_opponent``.
            act_as_O: mark the cell with O (2) instead of X (1).

        Returns:
            (state, reward, done, info). ``info["wrong input"]`` is True when
            the cell was already taken; in that case the board is unchanged,
            the reward is -1.1 and ``done`` is False.
        """
        if self.board.flat[action] != self.EMPTY:
            return self.state, self.INVALID_MOVE_REWARD, False, {"wrong input": True}

        self.take_action(action, act_as_O)

        reward, done = self.calc_reward_and_done()

        if not done and self_play:
            self.play_as_opponent(self_play, self_play_type)
            # Re-score after the reply so that an opponent win is reported on
            # the very transition that caused it rather than one step later.
            reward, done = self.calc_reward_and_done()

        return self.state, reward, done, {"wrong input": False}

    def render(self):
        """Print the board as an ASCII grid followed by a blank line."""
        separator = '|' + '|'.join(['_' * 7] * 3) + '|'

        print('_' * 25)
        for row in self.board:
            cells = [self._SYMBOLS.get(int(value), '') for value in row]
            print('|' + '|'.join(' ' * 3 + cell + ' ' * 3 for cell in cells) + '|')
            print(separator)
        print()

    def _pick_free_cell(self, policy_fn, state, max_tries=10):
        """Ask ``policy_fn`` for a move until it names an empty cell.

        After ``max_tries`` unsuccessful attempts (e.g. a deterministic policy
        stuck on an occupied cell) a uniformly random empty cell is returned,
        so the policy never silently loses its turn.
        """
        for _ in range(max_tries):
            action = int(policy_fn(state))
            if 0 <= action < 9 and self.board.flat[action] == self.EMPTY:
                return action
        return int(np.random.choice(np.flatnonzero(self.board == self.EMPTY)))

    def play_against_opponent(self, policy_fn):
        """Interactive game: ``policy_fn`` plays X, a human types O's moves.

        Args:
            policy_fn: callable mapping the integer state to a cell index.

        The human is prompted for a cell index 0-8 and re-prompted on
        non-numeric, out-of-range or already-occupied input. The board is
        rendered after every accepted move.
        """
        state = self.reset()

        while True:
            action = self._pick_free_cell(policy_fn, state)
            _, _, done, _ = self.step(action, self_play=False)
            self.render()
            if done:
                break

            while True:
                try:
                    action_opponent = int(input("O's turn: "))
                except ValueError:
                    action_opponent = -1  # force the "wrong choice" path

                if 0 <= action_opponent < 9:
                    state, _, done, info = self.step(action_opponent, self_play=False, act_as_O=True)
                    if not info["wrong input"]:
                        break
                print("Wrong choice! Choose again.")

            self.render()
            if done:
                break

In [2]:
class Policy:
    """Tabular Q-learning policy.

    ``Q_table[state, action]`` stores the estimated return of playing
    ``action`` in ``state``; the default shape covers the 3**9 board
    encodings and the 9 cells of tic-tac-toe.
    """

    def __init__(self, table_shape=(3**9, 9)):
        """Allocate a zero-initialised Q-table of shape ``table_shape``."""
        self.Q_table = np.zeros(table_shape)

    def save(self, path="./Policy.npy"):
        """Write the Q-table to ``path`` with ``np.save``."""
        np.save(path, self.Q_table)

    def load(self, path="./Policy.npy"):
        """Replace the Q-table with the array stored at ``path``."""
        self.Q_table = np.load(path)

    def train(self, env, n_epochs, max_steps=100, gamma=0.95, lr=0.1):
        """Run epsilon-greedy Q-learning against ``env``.

        Args:
            env: object with ``reset() -> state`` and
                ``step(action) -> (state, reward, done, info)``.
            n_epochs: number of episodes to play.
            max_steps: cap on the number of steps per episode.
            gamma: discount factor.
            lr: learning rate.

        Epsilon decays linearly from ~1 to a floor of 0.01 over the first
        60% of the episodes. Progress is printed on a single, overwritten line.
        """
        for epoch in range(1, n_epochs+1):
            epsilon = max(0.01, 1 - epoch/(0.6*n_epochs))

            state = env.reset()
            rewards = 0.

            for _ in range(max_steps):
                action = self.act_epsilon_greedy(state, epsilon)

                state_new, reward, done, _ = env.step(action)

                # A terminal transition has no future: do not bootstrap from
                # whatever happens to be stored for the terminal state.
                future = 0. if done else gamma * self.Q_table[state_new].max()
                td_error = reward + future - self.Q_table[state][action]
                self.Q_table[state][action] += lr * td_error

                state = state_new
                rewards += reward

                if done:
                    break

            print(f"\rEpoch: {epoch}, Epsilon: {epsilon:.4f}, Rewards: {rewards:.4f}", end="")

    def act_greedy(self, state):
        """Return the action with the highest Q-value (first one on ties)."""
        return self.Q_table[state].argmax()

    def act_epsilon_greedy(self, state, epsilon=0.01):
        """Return a uniformly random action with probability ``epsilon``,
        otherwise the greedy action."""
        if np.random.rand() > epsilon:
            action = self.act_greedy(state)
        else:
            # Use the table width rather than a hard-coded 9 so that a custom
            # ``table_shape`` keeps working.
            action = np.random.randint(self.Q_table.shape[1])

        return action

    def act_greedy_sampled(self, state, temperature=0.5):
        """Sample an action with probability proportional to
        ``(Q + |min Q| + 1e-5) ** (1 / temperature)``.

        Args:
            state: row index into the Q-table.
            temperature: positive sharpening factor; values below 1 favour
                the best action more strongly, values near 0 approach greedy.

        Raises:
            ValueError: if ``temperature`` is not positive.

        The weights are normalised exactly. Adding a constant to the
        normaliser would leave probability mass unassigned, and
        ``np.random.multinomial`` hands any such remainder to the last
        action, which heavily biased rows with small weights (notably
        all-zero rows) towards action 8.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        preds = self.Q_table[state]
        # Shift by |min| so every weight is strictly positive before the log.
        shifted = preds + np.abs(preds.min()) + 1e-5

        # Work in log-space and subtract the max for numerical stability.
        logits = np.log(shifted) / temperature
        logits = logits - logits.max()
        weights = np.exp(logits)
        probs = weights / weights.sum()

        return np.random.choice(probs.size, p=probs)

In [3]:
# Fresh tabular policy: Policy.__init__ allocates a zero-filled Q-table of
# shape (3**9, 9), i.e. one row per board encoding and one column per cell.
q_policy = Policy()

In [4]:
# The environment keeps a reference to q_policy; it is only consulted when the
# built-in opponent is run with self_play_type="q_policy".
env = TicTacToe(policy=q_policy)

In [5]:
# Train for 100,000 episodes. train() calls env.step(action) with its defaults,
# so the opponent plays random moves. The Q-table is then written to
# ./Policy.npy (save()'s default path).
# NOTE(review): no random seed is set, so results differ between runs.
q_policy.train(env, n_epochs=100_000)
q_policy.save()

Epoch: 100000, Epsilon: 0.0100, Rewards: 9.800000

In [6]:
# Reload the Q-table from ./Policy.npy (load()'s default path) and display it.
q_policy.load()
q_policy.Q_table

array([[7.84111308, 8.01525486, 7.84942915, ..., 8.07641734, 7.75388114,
        8.03591443],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [7]:
# Interactive game: the trained policy plays X via act_greedy_sampled, and the
# human types O's flat cell index (0-8, row-major) at the "O's turn: " prompt.
env.play_against_opponent(policy_fn=q_policy.act_greedy_sampled)

_________________________
|       |       |   X   |
|_______|_______|_______|
|       |       |       |
|_______|_______|_______|
|       |       |       |
|_______|_______|_______|

_________________________
|   O   |       |   X   |
|_______|_______|_______|
|       |       |       |
|_______|_______|_______|
|       |       |       |
|_______|_______|_______|

_________________________
|   O   |       |   X   |
|_______|_______|_______|
|       |       |   X   |
|_______|_______|_______|
|       |       |       |
|_______|_______|_______|

_________________________
|   O   |       |   X   |
|_______|_______|_______|
|       |       |   X   |
|_______|_______|_______|
|       |       |   O   |
|_______|_______|_______|

_________________________
|   O   |       |   X   |
|_______|_______|_______|
|       |   X   |   X   |
|_______|_______|_______|
|       |       |   O   |
|_______|_______|_______|

_________________________
|   O   |       |   X   |
|_______|_______|_______|
|      