In [1]:
import numpy as np
import random

In [2]:
def convert_to_b3(number):
    """Encode an integer board state as a base-3 string over " OX".

    Trigit 0 maps to " ", 1 to "O" and 2 to "X", most significant trigit
    first. The result is left-padded with spaces to at least 9 characters;
    values of 3**9 or more produce a longer string, and values <= 0 produce
    nine spaces.
    """
    symbols = " OX"
    radix = 3

    # Peel off trigits least-significant first, then flip the order.
    trigits = []
    remaining = number
    while remaining > 0:
        remaining, trigit = divmod(remaining, radix)
        trigits.append(symbols[trigit])
    encoded = "".join(reversed(trigits))

    # Left-pad with the "empty" symbol to make it 9 trigits.
    return encoded.rjust(9, symbols[0])

def convert_from_b3(code):
    """Decode a board string over " OX" back into its integer state.

    Only the first 9 characters are read, most significant trigit first
    (" " -> 0, "O" -> 1, "X" -> 2). A shorter string raises IndexError and
    a character outside " OX" raises ValueError.
    """
    symbols = " OX"
    radix = 3

    # Horner's scheme: shift the running value one trigit left, add the next.
    value = 0
    for position in range(9):
        value = value * radix + symbols.index(code[position])
    return value

def print_state(state):
    """Print the board encoded by the integer ``state`` as a 3x3 grid.

    Cells within a row are separated by "|" and consecutive rows by a line
    of five dashes.
    """
    cells = convert_to_b3(state)
    separator = 5 * "-"
    for row_index in range(3):
        # The rule goes between rows only, never before the first one.
        if row_index > 0:
            print(separator)
        start = row_index * 3
        print("|".join(cells[start:start + 3]))

def generate_next_states(state, player="X"):
    """Return every state reachable by ``player`` placing one mark.

    The board is decoded to its 9-character form, and each empty cell (" ")
    in turn is replaced by ``player``. Results are integer states, ordered
    by cell index. A board with no empty cell yields an empty list.
    """
    board = convert_to_b3(state)
    return [
        convert_from_b3(board[:pos] + player + board[pos + 1:])
        for pos in range(9)
        if board[pos] == " "
    ]

def explore(state, player="X"):
    """Pick one of ``player``'s legal successor states uniformly at random.

    Args:
        state: integer-encoded board.
        player: mark to place, "X" or "O".

    Returns:
        A randomly chosen successor state, or ``state`` itself when the
        board has no empty cell (the same convention as ``guided_explore``
        and ``exploit``).
    """
    next_states = generate_next_states(state, player)

    # random.choice raises IndexError on an empty sequence; a full board
    # simply has no move to make.
    if len(next_states) == 0:
        return state

    return random.choice(next_states)

def guided_explore(state, p, player="X"):
    """Sample a successor state, weighted by how good it looks for ``player``.

    Args:
        state: integer-encoded board.
        p: NumPy array indexed by state; ``p[s]`` is the estimated
            probability that "X" wins from ``s``.
        player: mark to place, "X" or "O".

    Returns:
        A successor state drawn with probability proportional to ``p`` (for
        "X") or ``1 - p`` (for "O"), or ``state`` itself when the board has
        no empty cell.
    """
    next_states = generate_next_states(state, player)

    if len(next_states) == 0:
        return state

    # Sample from the next possible states according to how likely they
    # are to lead to a win for the mover, but allow for randomness.
    weights = p[next_states] if player == "X" else 1 - p[next_states]

    # random.choices raises ValueError when all weights are zero (every
    # move looks lost for the mover); fall back to a uniform pick.
    if weights.sum() <= 0:
        return random.choice(next_states)

    return random.choices(next_states, weights=weights)[0]

def exploit(state, p, player="X"):
    """Return the successor state that looks best for ``player``.

    ``p[s]`` is read as the probability that "X" wins from ``s``; "X"
    maximises it and "O" maximises ``1 - p``. Ties go to the first
    candidate, since ``np.argmax`` returns the first maximum. A board with
    no empty cell is returned unchanged.
    """
    candidates = generate_next_states(state, player)
    if not candidates:
        return state

    x_win_prob = p[candidates]
    mover_win_prob = x_win_prob if player == "X" else 1 - x_win_prob
    return candidates[np.argmax(mover_win_prob)]
    
def is_win(state, player="X"):
    """Report whether ``player`` has three marks in a line on the board.

    Args:
        state: integer-encoded board.
        player: mark to look for, "X" or "O".

    Returns:
        True if any row, column or diagonal is filled with ``player``,
        otherwise False (never None, so the result is a real boolean).
    """
    code = convert_to_b3(state)

    # Cell indices are row-major: 0 1 2 / 3 4 5 / 6 7 8.
    winning_lines = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )
    return any(
        all(code[cell] == player for cell in line)
        for line in winning_lines
    )

In [3]:
# Size of the state space: 9 cells with 3 possible symbols each.
num_states = 3**9

# Show the smallest and largest encodings, then one drawn at random.
demo_states = (
    ("First state:", 0),
    ("Max state:", num_states - 1),
    ("Random state:", random.randint(0, num_states - 1)),
)
for label, state in demo_states:
    print(label, state)
    print_state(state)

First state: 0
 | | 
-----
 | | 
-----
 | | 
Max state: 19682
X|X|X
-----
X|X|X
-----
X|X|X
Random state: 17610
X|X| 
-----
 |O|O
-----
 |X| 


In [4]:
# Initialize states
# p[s] is the estimated probability that "X" wins from state s. Boards on
# which X has a line start at 1, boards on which only O has a line start at
# 0, and every other board starts undecided at 0.5.
p = np.zeros(num_states)
for i in range(num_states):
    if is_win(i, "X"):
        p[i] = 1.0
    elif is_win(i, "O"):
        p[i] = 0.0
    else:
        p[i] = 0.5

initial_state = 0
visited_states = {initial_state}
num_games = 200_000
p_explore = 0.1  # chance of a uniformly random move instead of a greedy one
alpha = 0.1      # step size of the value update

# Report progress ten times per run; the max() keeps the modulus non-zero
# when num_games is set below 10.
progress_interval = max(1, num_games // 10)

# NOTE(review): neither `random` nor `np.random` is seeded in the visible
# cells, so the learned values differ from run to run.
for i in range(num_games):
    if i % progress_interval == 0:
        print("Game", i)
    player = "X"
    state = initial_state
    state_sequence = [state]
    for j in range(9):
        if np.random.rand() < p_explore:
            # Explore
            new_state = explore(state, player)
        else:
            # Exploit
            new_state = exploit(state, p, player)

        state = new_state
        state_sequence.append(state)
        visited_states.add(state)

        if is_win(state, player):
            # Back up the outcome along the game, last move first. p always
            # refers to the probability that "X" wins, so the same update
            # serves both winners: moving 1 - p towards 1 - p_next is
            # algebraically the same as moving p towards p_next. The
            # terminal state keeps its initial value (1 for X, 0 for O).
            # Games that end in a draw never reach this branch and leave p
            # untouched.
            for k in range(len(state_sequence) - 2, -1, -1):
                current = state_sequence[k]
                successor = state_sequence[k + 1]
                p[current] += alpha * (p[successor] - p[current])

            break

        player = "O" if player == "X" else "X"

print("Number of visited states:", len(visited_states), "out of", num_states)


Game 0
Game 20000
Game 40000
Game 60000
Game 80000
Game 100000
Game 120000
Game 140000
Game 160000
Game 180000
Number of visited states: 4292 out of 19683


In [5]:
# Play one demonstration game with the learned table p.
print("Greedy play")
player = "X"
state = 0
print_state(state)
for i in range(9):
    print("Player", player)
    if player == "X":
        # X is the smart player: always take the highest-valued move.
        state = exploit(state, p, player)
    else:
        # O is the weaker player: sample a move with probability
        # proportional to 1 - p instead of always taking the best one.
        state = guided_explore(state, p, player)
    print("Player's move:")
    print_state(state)
    if is_win(state, player):
        print(player, "wins!")
        break
    # Nine moves without a winner fill the board.
    if i == 8:
        print("Draw!")
        break

    player = "O" if player == "X" else "X"

Greedy play
 | | 
-----
 | | 
-----
 | | 
Player X
Player's move:
 | | 
-----
 |X| 
-----
 | | 
Player O
Player's move:
 |O| 
-----
 |X| 
-----
 | | 
Player X
Player's move:
X|O| 
-----
 |X| 
-----
 | | 
Player O
Player's move:
X|O| 
-----
 |X| 
-----
 | |O
Player X
Player's move:
X|O| 
-----
X|X| 
-----
 | |O
Player O
Player's move:
X|O|O
-----
X|X| 
-----
 | |O
Player X
Player's move:
X|O|O
-----
X|X|X
-----
 | |O
X wins!


In [7]:

def print_state_value(code):
    """Print the learned value ``p[state]`` for a 9-character board code.

    ``code`` is decoded with ``convert_from_b3`` and looked up in the
    module-level table ``p``; the code is echoed between "|" markers so
    that its spaces stay visible.
    """
    value = p[convert_from_b3(code)]
    print(f"Value for code {value!s}\t|{code}|")

# Boards to inspect, each a 9-character row-major code over " OX".
code_array = [
    "  O X XO ",
    "O O X X  ",
    "OXO X XO ",
    # One base position followed by the same position with one extra X
    # placed in cell 8, cell 2 and cell 6 respectively.
    "OX XXO O ",
    "OX XXO OX",
    "OXXXXO O ",
    "OX XXOXO ",
    # The empty board and two boards after one move by each side.
    "         ",
    "   XO    ",
    "   X O   ",
]

for code in code_array:
    print_state_value(code)


Value for code 0.9986230167294833	|  O X XO |
Value for code 0.7332161770587183	|O O X X  |
Value for code 0.890678422635053	|OXO X XO |
Value for code 0.9999610454891611	|OX XXO O |
Value for code 0.5	|OX XXO OX|
Value for code 0.9999999999999991	|OXXXXO O |
Value for code 0.9603656337632881	|OX XXOXO |
Value for code 0.9302821570081439	|         |
Value for code 0.3981000346047587	|   XO    |
Value for code 0.592675327540317	|   X O   |
