In [1]:
import numpy as np
from numpy.typing import NDArray
np.__version__  # NumPy version the outputs recorded in this notebook were produced with

'2.2.3'

In [2]:
type State = NDArray[np.int8]

class TicTacToe:
    def __init__(self):
        self.row_count: int = 3
        self.col_count: int = 3
        self.action_size: int = self.row_count * self.col_count
        self.board: State = self.get_initial_state()

    def get_initial_state(self) -> State:
        return np.zeros((self.row_count, self.col_count), dtype=np.int8)

    def get_next_state(self, state: State, action: int, player: int) -> State:
        row = action // self.col_count
        col = action % self.col_count
        state[row, col] = player
        return state
    
    def get_valid_moves(self, state: State) -> NDArray[np.bool_]:
        return (state.reshape(-1) == 0).astype(np.bool_)
    
    def check_win(self, state: State, action: int) -> bool:
        row = action // self.col_count
        col = action % self.col_count
        player = state[row, col]
    
        return (
            np.sum(state[row, :]) == player * self.col_count
            or np.sum(state[:, col]) == player * self.row_count
            or np.sum(np.diag(state)) == player * self.row_count
            or np.sum(np.diag(np.fliplr(state))) == player * self.row_count
        )
    
    def get_value_and_terminated(self, state: State, action: int):
        if self.check_win(state, action):
            return 1, True
        elif not np.any(self.get_valid_moves(state)):
            return 0, True
        return 0, False
    
    def get_opponent(self, player):
        return -player


In [3]:
# Interactive two-player game: players alternate entering a cell index until
# someone completes a line or the board fills up.
ttt = TicTacToe()
player = 1

state = ttt.get_initial_state()

while True:
    print(state)
    valid_moves = ttt.get_valid_moves(state)
    print("valid_moves", np.flatnonzero(valid_moves).tolist())
    try:
        action = int(input(f"{player}:"))
    except ValueError:
        # int() raises ValueError (not TypeError) for non-numeric text.
        print("numbers only")
        continue

    # Reject out-of-range indices explicitly: 9 or -10 would raise IndexError,
    # and -1..-9 would otherwise be accepted through NumPy negative indexing.
    if not 0 <= action < ttt.action_size or not valid_moves[action]:
        print("action not valid")
        continue

    state = ttt.get_next_state(state, action, player)
    value, terminated = ttt.get_value_and_terminated(state, action)

    if terminated:
        print(state)
        if value == 1:
            print(player, "won")
        else:
            print("draw")
        break

    player = ttt.get_opponent(player)


[[0 0 0]
 [0 0 0]
 [0 0 0]]
valid_moves [0, 1, 2, 3, 4, 5, 6, 7, 8]
[[0 0 0]
 [0 1 0]
 [0 0 0]]
valid_moves [0, 1, 2, 3, 5, 6, 7, 8]
action not valid
[[0 0 0]
 [0 1 0]
 [0 0 0]]
valid_moves [0, 1, 2, 3, 5, 6, 7, 8]
[[-1  0  0]
 [ 0  1  0]
 [ 0  0  0]]
valid_moves [1, 2, 3, 5, 6, 7, 8]
[[-1  1  0]
 [ 0  1  0]
 [ 0  0  0]]
valid_moves [2, 3, 5, 6, 7, 8]
[[-1  1 -1]
 [ 0  1  0]
 [ 0  0  0]]
valid_moves [3, 5, 6, 7, 8]
[[-1  1 -1]
 [ 0  1  0]
 [ 0  1  0]]
1 won
