In [0]:
import numpy as np
import pickle

class State:
    def __init__(self):
        self.data = np.zeros((3, 3))
        self.winner = None
        self.hash_val = None
        self.end = None

    def hash(self):
        if self.hash_val is None:
            self.hash_val = 0
            for i in np.nditer(self.data):
                self.hash_val = self.hash_val * 3 + i + 1
        return self.hash_val

    def is_end(self):
        if self.end is not None:
            return self.end
        results = []
        for i in range(3):
            results.append(np.sum(self.data[i, :]))
        for i in range(3):
            results.append(np.sum(self.data[:, i]))

        trace = 0
        reverse_trace = 0
        for i in range(3):
            trace += self.data[i, i]
            reverse_trace += self.data[i, 3 - 1 - i]
        results.append(trace)
        results.append(reverse_trace)

        for result in results:
            if result == 3:
                self.winner = 1
                self.end = True
                return self.end
            if result == -3:
                self.winner = -1
                self.end = True
                return self.end

        sum_values = np.sum(np.abs(self.data))
        if sum_values == 9:
            self.winner = 0
            self.end = True
            return self.end

        self.end = False
        return self.end

    def next_state(self, i, j, symbol):
        new_state = State()
        new_state.data = np.copy(self.data)
        new_state.data[i, j] = symbol
        return new_state

    def print_state(self):
        for i in range(3):
            print('-------------')
            out = '| '
            for j in range(3):
                if self.data[i, j] == 1:
                    token = 'X'
                elif self.data[i, j] == -1:
                    token = 'O'
                else:
                    token = ' '
                out += token + ' | '
            print(out)
        print('-------------')

In [0]:
def get_all_states_impl(current_state, current_symbol, all_states):
    for i in range(3):
        for j in range(3):
            if current_state.data[i][j] == 0:
                new_state = current_state.next_state(i, j, current_symbol)
                new_hash = new_state.hash()
                if new_hash not in all_states:
                    is_end = new_state.is_end()
                    all_states[new_hash] = (new_state, is_end)
                    if not is_end:
                        get_all_states_impl(new_state, -current_symbol, all_states)


def get_all_states():
    current_symbol = 1
    current_state = State()
    all_states = dict()
    all_states[current_state.hash()] = (current_state, current_state.is_end())
    get_all_states_impl(current_state, current_symbol, all_states)
    return all_states


all_states = get_all_states()

In [0]:
class Judger:
    def __init__(self, player1, player2):
        self.p1 = player1
        self.p2 = player2
        self.current_player = None
        self.p1_symbol = 1
        self.p2_symbol = -1
        self.p1.set_symbol(self.p1_symbol)
        self.p2.set_symbol(self.p2_symbol)
        self.current_state = State()

    def reset(self):
        self.p1.reset()
        self.p2.reset()

    def alternate(self):
        while True:
            yield self.p1
            yield self.p2

    def play(self, print_state=False):
        alternator = self.alternate()
        self.reset()
        current_state = State()
        self.p1.set_state(current_state)
        self.p2.set_state(current_state)
        while True:
            player = next(alternator)
            i, j, symbol = player.act()
            next_state_hash = current_state.next_state(i, j, symbol).hash()
            current_state, is_end = all_states[next_state_hash]
            self.p1.set_state(current_state)
            self.p2.set_state(current_state)
            if is_end:
                if print_state:
                    current_state.print_state()
                return current_state.winner

In [0]:
class Player:
    def __init__(self, step_size=0.1, epsilon=0.1):
        self.estimations = dict()
        self.step_size = step_size
        self.epsilon = epsilon
        self.states = []
        self.greedy = []
        self.symbol = 0

    def reset(self):
        self.states = []
        self.greedy = []

    def set_state(self, state):
        self.states.append(state)
        self.greedy.append(True)

    def set_symbol(self, symbol):
        self.symbol = symbol
        for hash_val in all_states:
            state, is_end = all_states[hash_val]
            if is_end:
                if state.winner == self.symbol:
                    self.estimations[hash_val] = 1.0
                elif state.winner == 0:
                    self.estimations[hash_val] = 0.5
                else:
                    self.estimations[hash_val] = 0
            else:
                self.estimations[hash_val] = 0.5

    def backup(self):
        states = [state.hash() for state in self.states]

        for i in reversed(range(len(states) - 1)):
            state = states[i]
            td_error = self.greedy[i] * (
                self.estimations[states[i + 1]] - self.estimations[state]
            )
            self.estimations[state] += self.step_size * td_error

    def act(self):
        state = self.states[-1]
        next_states = []
        next_positions = []
        for i in range(3):
            for j in range(3):
                if state.data[i, j] == 0:
                    next_positions.append([i, j])
                    next_states.append(state.next_state(
                        i, j, self.symbol).hash())

        if np.random.rand() < self.epsilon:
            action = next_positions[np.random.randint(len(next_positions))]
            action.append(self.symbol)
            self.greedy[-1] = False
            return action

        values = []
        for hash_val, pos in zip(next_states, next_positions):
            values.append((self.estimations[hash_val], pos))
        np.random.shuffle(values)
        values.sort(key=lambda x: x[0], reverse=True)
        action = values[0][1]
        action.append(self.symbol)
        return action

    def save_policy(self):
        with open('policy_%s.bin' % ('first' if self.symbol == 1 else 'second'), 'wb') as f:
            pickle.dump(self.estimations, f)

    def load_policy(self):
        with open('policy_%s.bin' % ('first' if self.symbol == 1 else 'second'), 'rb') as f:
            self.estimations = pickle.load(f)

In [0]:
class HumanPlayer:
    def __init__(self, **kwargs):
        self.symbol = None
        self.keys = [7, 8, 9, 4, 5, 6, 1, 2, 3]
        self.state = None

    def reset(self):
        pass

    def set_state(self, state):
        self.state = state

    def set_symbol(self, symbol):
        self.symbol = symbol

    def act(self):
        self.state.print_state()
        key = int(input("Input your position:"))
        data = self.keys.index(key)
        i = data // 3
        j = data % 3
        return i, j, self.symbol

In [0]:
def train(epochs, print_every_n=1000):
    print('Training is started.')
    player1 = Player(epsilon=0.01)
    player2 = Player(epsilon=0.01)
    judger = Judger(player1, player2)
    for i in range(1, epochs + 1):
        winner = judger.play()
        if i % print_every_n == 0:
            print(f'Epoch {i}')
        player1.backup()
        player2.backup()
        judger.reset()
    player1.save_policy()
    player2.save_policy()
    print('Training is done.')

train(int(1e5))

Training is started.
Epoch 1000
Epoch 2000
Epoch 3000
Epoch 4000
Epoch 5000
Epoch 6000
Epoch 7000
Epoch 8000
Epoch 9000
Epoch 10000
Epoch 11000
Epoch 12000
Epoch 13000
Epoch 14000
Epoch 15000
Epoch 16000
Epoch 17000
Epoch 18000
Epoch 19000
Epoch 20000
Epoch 21000
Epoch 22000
Epoch 23000
Epoch 24000
Epoch 25000
Epoch 26000
Epoch 27000
Epoch 28000
Epoch 29000
Epoch 30000
Epoch 31000
Epoch 32000
Epoch 33000
Epoch 34000
Epoch 35000
Epoch 36000
Epoch 37000
Epoch 38000
Epoch 39000
Epoch 40000
Epoch 41000
Epoch 42000
Epoch 43000
Epoch 44000
Epoch 45000
Epoch 46000
Epoch 47000
Epoch 48000
Epoch 49000
Epoch 50000
Epoch 51000
Epoch 52000
Epoch 53000
Epoch 54000
Epoch 55000
Epoch 56000
Epoch 57000
Epoch 58000
Epoch 59000
Epoch 60000
Epoch 61000
Epoch 62000
Epoch 63000
Epoch 64000
Epoch 65000
Epoch 66000
Epoch 67000
Epoch 68000
Epoch 69000
Epoch 70000
Epoch 71000
Epoch 72000
Epoch 73000
Epoch 74000
Epoch 75000
Epoch 76000
Epoch 77000
Epoch 78000
Epoch 79000
Epoch 80000
Epoch 81000
Epoch 82000
Epoc

In [0]:
def play():
    print('The game has started, all the best.')
    print('"X" is for you\n"O" is for computer')
    print('Lets the battle commence')
    again = True
    while again:
        print('\n')
        player1 = HumanPlayer()
        player2 = Player(epsilon=0)
        judger = Judger(player1, player2)
        player2.load_policy()
        winner = judger.play(True)
        if winner == player2.symbol:
            print("You lose!\nBetter luck next time")
        elif winner == player1.symbol:
            print("You win!")
        else:
            print("It is a tie!")
        again = input('Ýou wanna play again: ')
        again = True if again in ['y', 'yes'] else False

In [0]:
play()

The game has started, all the best.
"X" is for you
"O" is for computer
Lets the battle commence


-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
Input your position:5
-------------
|   |   |   | 
-------------
|   | X |   | 
-------------
|   |   | O | 
-------------
Input your position:1
-------------
|   |   | O | 
-------------
|   | X |   | 
-------------
| X |   | O | 
-------------
Input your position:7
-------------
| X |   | O | 
-------------
|   | X | O | 
-------------
| X |   | O | 
-------------
You lose!
Better luck next time
Ýou wanna play again: no
