In [1]:
import numpy as np
from itertools import product
from collections import defaultdict


transformations = [
    lambda x: x,                         # Identity
    lambda x: np.rot90(x, 1),            # Rotate 90°
    lambda x: np.rot90(x, 2),            # Rotate 180°
    lambda x: np.rot90(x, 3),            # Rotate 270°
    lambda x: np.fliplr(x),              # Horizontal reflection
    lambda x: np.flipud(x),              # Vertical reflection
    lambda x: np.transpose(x),           # Diagonal reflection (TL-BR)
    # lambda x: np.fliplr(np.transpose(x)) # Diagonal reflection (TR-BL)
]

inverse_transformations = [
    lambda x: x,                         # Identity
    lambda x: np.rot90(x, 3),            # Rotate 90° inverse (rotate 270°)
    lambda x: np.rot90(x, 2),            # Rotate 180° inverse
    lambda x: np.rot90(x, 1),            # Rotate 270° inverse (rotate 90°)
    lambda x: np.fliplr(x),              # Horizontal reflection inverse
    lambda x: np.flipud(x),              # Vertical reflection inverse
    lambda x: np.transpose(x),           # Diagonal reflection (TL-BR) inverse
    # lambda x: np.transpose(np.fliplr(x))# Diagonal reflection (TR-BL) inverse
]

original_actions = np.array(range(9)).reshape(3, 3)
for i, transform in enumerate(transformations):
    assert inverse_transformations[i](transform(original_actions)).flatten().tolist() == original_actions.flatten().tolist()

transformed_actions = [transform(original_actions).flatten().tolist() for transform in transformations]
for i in range(len(transformed_actions)):
    for j in range(i + 1, len(transformed_actions)):
        assert transformed_actions[i] != transformed_actions[j]

def generate_all_valid_boards():
    symbols = [' ', 'X', 'O']
    all_boards = list(product(symbols, repeat=9))  # Generate all 3^9 combinations
    assert (3**9) == len(all_boards)
    all_valid_boards = []

    for board in all_boards:
        x_count = board.count('X')
        o_count = board.count('O')
        
        # Valid boards must satisfy these conditions:
        if x_count == o_count or x_count == o_count + 1:
            all_valid_boards.append(board)

    return all_valid_boards

def board_to_matrix(board):
    return np.array(board).reshape(3, 3)

def matrix_to_board(matrix):
    return matrix.flatten().tolist()

def generate_symmetries(board):
    matrix = board_to_matrix(board)
    symmetries = [transform(matrix) for transform in transformations]
    return [matrix_to_board(sym) for sym in symmetries]

def get_canonical_representation(board):
    symmetries = generate_symmetries(board)
    min_symmetry = min(symmetries)
    return tuple(min_symmetry), symmetries.index(min_symmetry)

# Generate all empty positions on the board
def get_empty_positions(board):
    return [i for i, cell in enumerate(board) if cell == ' ']

def display_board(board):
    print(f" {board[0]} | {board[1]} | {board[2]} ")
    print("---+---+---")
    print(f" {board[3]} | {board[4]} | {board[5]} ")
    print("---+---+---")
    print(f" {board[6]} | {board[7]} | {board[8]} ")
    print("\n")

def get_next_player(board):
    x_count = board.count('X')
    o_count = board.count('O')
    return 'X' if x_count == o_count else 'O'

def get_next_board(board, action):
    new_board = list(board)
    next_player = get_next_player(board)
    if next_player == 'X':
        new_board[action] = 'X'
    else:
        new_board[action] = 'O'
    
    return tuple(new_board)

# Generate all valid Tic-Tac-Toe boards
all_valid_boards = generate_all_valid_boards()
state_action_pairs = 0
for valid_board in all_valid_boards:
    empty_positions = get_empty_positions(valid_board)
    state_action_pairs += len(empty_positions)

# Generate all canonical Tic-Tac-Toe boards
all_canonical_boards = set()
get_canonical_boards = {}
get_transform = {}
get_inverse_transform = {}
get_canonical_actions = {}
get_inverse_canonical_actions = {}
for valid_board in all_valid_boards:
    canonical_board, transform_idx = get_canonical_representation(valid_board)
    all_canonical_boards.add(canonical_board)
    get_canonical_boards[valid_board] = canonical_board
    get_transform[valid_board] = transformations[transform_idx]
    get_inverse_transform[valid_board] = inverse_transformations[transform_idx]
    get_inverse_canonical_actions[valid_board] = get_transform[valid_board](original_actions).flatten().tolist()
    get_canonical_actions[valid_board] = get_inverse_transform[valid_board](original_actions).flatten().tolist()

all_canonical_boards = sorted(list(all_canonical_boards))
all_canonical_actions = {}
canonical_state_action_pairs = 0
for canonical_board in all_canonical_boards:
    empty_positions = get_empty_positions(canonical_board)
    all_canonical_actions[canonical_board] = empty_positions
    canonical_state_action_pairs += len(empty_positions)

def get_canonical_board(board):
    return get_canonical_boards[tuple(board)]

def get_canonical_action(board, action):
    return get_canonical_actions[board][action]

def get_inverse_canonical_action(board, canonical_action):
    return get_inverse_canonical_actions[board][canonical_action]

def canonicalize(board, action):
    canonical_board = get_canonical_board(board)
    canonical_action = get_canonical_action(board, action)
    return canonical_board, canonical_action

canonical_board_to_next_canonical_board = {}
all_next_canonical_boards = set()
for canonical_board in all_canonical_boards:
    canonical_actions_to_next_canonical_board = {}
    for canonical_action in all_canonical_actions[canonical_board]:
        next_board = get_next_board(canonical_board, canonical_action)
        next_canonical_board = get_canonical_board(next_board)
        all_next_canonical_boards.add(next_canonical_board)
        canonical_actions_to_next_canonical_board[canonical_action] = next_canonical_board

    canonical_board_to_next_canonical_board[canonical_board] = canonical_actions_to_next_canonical_board

qMatrix = defaultdict(lambda: -1)
for i, next_canonical_board in enumerate(all_next_canonical_boards):
    qMatrix[next_canonical_board] = i

def get(board=None, action=None):
    if board is None and action is None:
        return qMatrix
    if action is None and board is not None:
        actions = get_empty_positions(board)
        canonical_actions = [get_canonical_action(board, action) for action in actions]
        canonical_board = get_canonical_board(board)
        return [qMatrix[canonical_board_to_next_canonical_board[canonical_board][canonical_action]] for canonical_action in canonical_actions]
    if action is not None and board is not None:
        return qMatrix[canonical_board_to_next_canonical_board[get_canonical_board(board)][get_canonical_action(board, action)]]

def set(board, action, value):
    qMatrix[canonical_board_to_next_canonical_board[get_canonical_board(board)][get_canonical_action(board, action)]] = value

def displayQ(valid_board):
    valid_actions = get_empty_positions(valid_board)
    canonical_board = get_canonical_board(valid_board)
    Qvalid_board = list(valid_board)
    for valid_action in valid_actions:
        Qvalid_board[valid_action] = qMatrix[canonical_board_to_next_canonical_board[canonical_board][get_canonical_action(valid_board, valid_action)]]

    print("\n")
    print(np.array(Qvalid_board).reshape(3,3))

print(f"Number of boards:                               {(3**9)}")
print(f"Number of state-action pairs:                   {(3**9)*9}")
print(f"Number of valid boards:                         {len(all_valid_boards)}")    
print(f"Number of valid state-action pairs:             {state_action_pairs}")
print(f"Number of canonical boards:                     {len(all_canonical_boards)}")
print(f"Number of canonical state-action pairs:         {canonical_state_action_pairs}")
print(f"Number of next canonical boards:                {len(all_next_canonical_boards)}")
print(f"Number of reduced canonical state-action pairs: {len(all_next_canonical_boards)}")

displayQ(('O', ' ', ' ', ' ', ' ', ' ', ' ', ' ', 'X'))
displayQ((' ', ' ', 'O', ' ', ' ', ' ', 'X', ' ', ' '))


Number of boards:                               19683
Number of state-action pairs:                   177147
Number of valid boards:                         6046
Number of valid state-action pairs:             19107
Number of canonical boards:                     1520
Number of canonical state-action pairs:         4808
Number of next canonical boards:                1073
Number of reduced canonical state-action pairs: 1073


[['O' '905' '320']
 ['3' '281' '668']
 ['802' '595' 'X']]


[['802' '3' 'O']
 ['595' '281' '905']
 ['X' '668' '320']]


In [5]:
qMatrix = defaultdict(lambda: -1)
board = (' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
set(board, 0, 0.2)
displayQ(board)



[[ 0.2 -1.   0.2]
 [-1.  -1.  -1. ]
 [ 0.2 -1.   0.2]]


In [None]:
valid_board = all_valid_boards[5]
actions = get_empty_positions(valid_board)
canonical_actions = [get_canonical_action(valid_board, action) for action in actions]
print(f"Board:             {valid_board}")
print(f"Canonical board:   {get_canonical_board(valid_board)}")
print(f"Actions:           {actions}")
print(f"Canonical actions: {sorted(canonical_actions)}")
print(f"Canonical actions: {sorted(all_canonical_actions[get_canonical_board(valid_board)])}")

tot = 0
for i, valid_board in enumerate(all_valid_boards):
    actions = get_empty_positions(valid_board)
    canonical_board = get_canonical_board(valid_board)
    canonical_actions1 = sorted(all_canonical_actions[canonical_board])
    canonical_actions2 = sorted([get_canonical_action(valid_board, action) for action in actions])
    if not canonical_actions2 == canonical_actions1:
        tot += 1

print(f"Number of wrong canonical actions:      {tot}/{len(all_valid_boards)}")

In [None]:
for i in range(10):
    if i == 10 - 1:
        print(f"final i: {i}")
    else:
        print(f"{i}, {i + 1}")

for i in reversed(range(10)):
    if i == 10 - 1:
        print(f"final i: {i}")
    else:
        print(f"{i}, {i + 1}")