In [1]:
import numpy as np
from dataclasses import dataclass
import random
from collections import Counter
from collections import defaultdict

In [12]:
# Cell markers stored in the board array.
#   EMPTY    - unoccupied cell; also returned as the "winner" when nobody has won
#   PLAYER_1 - the player who moves first (moves whenever piece counts are equal)
#   PLAYER_2 - the player who moves second
EMPTY, PLAYER_1, PLAYER_2 = 0, 1, 2

@dataclass(frozen=True)
class Move:
    """A single piece placement: which player puts a piece on which cell.

    The dataclass is frozen, so instances are immutable and hashable; MCTree
    uses them as dictionary keys.
    """

    # Marker of the player making the move (PLAYER_1 or PLAYER_2).
    player: int
    # Zero-based row index into the board array.
    row: int
    # Zero-based column index into the board array.
    column: int

class BoardState:
    """Grid of pieces for a two-player, place-anywhere board game.

    Cells hold EMPTY, PLAYER_1 or PLAYER_2. A player wins by owning every
    cell of a whole row, a whole column, or one of the two diagonals that
    start in the top corners. The game also ends (with no winner) once the
    board is full.

    NOTE(review): the default 6x7 size matches a Connect Four board, but the
    rule implemented here is "fill an entire line" and pieces may be placed
    on any empty cell -- confirm this is the intended game.

    Instances hash and compare by board contents so they can be used as
    dictionary keys. ``play_move`` mutates the board in place, so never
    mutate an instance while it is stored as a key.

    Attributes:
        board: 2-D integer array of cell markers.
        n_rows: number of rows of ``board``.
        n_columns: number of columns of ``board``.
        outcomes: integer counters indexed by winner marker (EMPTY for games
            without a winner, PLAYER_1, PLAYER_2); filled in by MCTree. The
            array has four slots; the last one is not used by the code in
            this file.
        win_frac: fraction of recorded outcomes won by the player who moved
            last; filled in by MCTree.
    """

    def __init__(self, board=None):
        """Create a state, empty 6x7 by default.

        Args:
            board: optional 2-D array of cell markers. An ndarray is stored
                as-is (not copied), so the caller's array is shared.

        Raises:
            ValueError: if ``board`` is not two-dimensional.
        """
        if board is None:
            board = np.zeros((6, 7), int)

        # asarray returns an existing ndarray unchanged, so sharing semantics
        # are preserved while nested lists are accepted too.
        self.board = np.asarray(board)

        # Take the dimensions from the actual board instead of assuming 6x7,
        # so boards of other sizes are scanned correctly.
        self.n_rows, self.n_columns = self.board.shape
        self.outcomes = np.zeros(4, int)
        self.win_frac = 0.0

    def _count_pieces(self, player):
        """Return how many cells currently hold ``player``'s marker."""
        return int(np.count_nonzero(self.board == player))

    def _check_winner(self, player):
        """Return True if ``player`` owns a complete row, column or diagonal."""
        owned = self.board == player

        # any fully owned row / any fully owned column
        if owned.all(axis=1).any() or owned.all(axis=0).any():
            return True

        # Diagonal from the top-left corner, then the one from the top-right
        # corner (the main diagonal of the rotated board). On a non-square
        # board each has min(n_rows, n_columns) cells.
        if np.diag(owned).all():
            return True

        return bool(np.diag(np.rot90(owned)).all())

    def determine_winner(self):
        """Return ``(has_winner, winner)``; winner is EMPTY when nobody won."""
        for player in (PLAYER_1, PLAYER_2):
            if self._check_winner(player):
                return True, player

        return False, EMPTY

    def is_end_state(self):
        """Return True when the game is over (someone won or board is full)."""
        if self.is_initial_state():
            return False

        has_winner, _ = self.determine_winner()

        # If any cell is empty, then there is a possible move to make
        has_empty_cells = bool((self.board == EMPTY).any())

        return has_winner or not has_empty_cells

    def is_initial_state(self):
        """Return True when no piece has been placed yet."""
        return bool((self.board == EMPTY).all())

    def __hash__(self):
        return hash(tuple(self.board.flatten()))

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of raising AttributeError on ``other.board``.
        if not isinstance(other, BoardState):
            return NotImplemented

        # array_equal also copes with boards of different shapes.
        return bool(np.array_equal(self.board, other.board))

    def __repr__(self):
        return str(self.board)

    def determine_next_player(self):
        """Return the marker of the player whose turn it is.

        PLAYER_1 moves first, so it is PLAYER_1's turn whenever both players
        have the same number of pieces on the board:

            turn 1: 0 pieces each                  -> PLAYER_1
            turn 2: 1 PLAYER_1 piece, 0 PLAYER_2   -> PLAYER_2
            turn 3: 1 PLAYER_1 piece, 1 PLAYER_2   -> PLAYER_1
            turn 4: 2 PLAYER_1 pieces, 1 PLAYER_2  -> PLAYER_2
        """
        if self._count_pieces(PLAYER_1) == self._count_pieces(PLAYER_2):
            return PLAYER_1

        return PLAYER_2

    def determine_previous_player(self):
        """Return the marker of the player who moved last.

        Returns EMPTY for a board without any pieces:

            turn 1: 0 pieces each                  -> EMPTY (nobody moved)
            turn 2: 1 PLAYER_1 piece, 0 PLAYER_2   -> PLAYER_1
            turn 3: 1 PLAYER_1 piece, 1 PLAYER_2   -> PLAYER_2
            turn 4: 2 PLAYER_1 pieces, 1 PLAYER_2  -> PLAYER_1
        """
        n_player1_pieces = self._count_pieces(PLAYER_1)
        n_player2_pieces = self._count_pieces(PLAYER_2)

        if n_player1_pieces == 0 and n_player2_pieces == 0:
            return EMPTY

        if n_player1_pieces == n_player2_pieces:
            return PLAYER_2

        return PLAYER_1

    def enumerate_moves(self):
        """Return every legal Move for the next player, in row-major order.

        The list is empty when the game is already over.
        """
        if self.is_end_state():
            return []

        next_player = self.determine_next_player()

        return [
            Move(next_player, r, c)
            for r in range(self.n_rows)
            for c in range(self.n_columns)
            if self.board[r, c] == EMPTY
        ]

    def play_move(self, move: Move):
        """Place ``move.player``'s piece on the board, in place.

        Raises:
            ValueError: if the target cell is already occupied.
        """
        if self.board[move.row, move.column] != EMPTY:
            # The old message called .format() on a string without
            # placeholders and therefore always read "(r, c)".
            raise ValueError(
                "Cannot place piece at ({}, {})".format(move.row, move.column)
            )

        self.board[move.row, move.column] = move.player

    def copy(self):
        """Return a new state with a copied board.

        ``outcomes`` and ``win_frac`` are not carried over; the copy starts
        with fresh statistics.
        """
        return BoardState(self.board.copy())

class MCTree:
    """Game tree grown from uniformly random playouts.

    ``children`` maps BoardState -> Move -> BoardState. Because BoardState
    hashes by board contents, positions reached through different move
    orders share a single entry in ``children``.
    """

    def __init__(self):
        self.root = BoardState()

        # BoardState -> Move -> BoardState
        self.children = defaultdict(dict)
        self.children[self.root] = dict()

    def reroot_tree(self, board_state):
        """Make a copy of ``board_state`` the new root.

        Entries already recorded in ``children`` are kept.
        """
        self.root = board_state.copy()

    def sample_paths(self, n_paths):
        """Play ``n_paths`` random games from the root and record every step."""
        for _ in range(n_paths):
            current = self.root
            while True:
                available_moves = current.enumerate_moves()
                if not available_moves:
                    # terminal position reached; this playout is finished
                    break

                chosen_move = random.choice(available_moves)
                next_state = current.copy()
                next_state.play_move(chosen_move)
                self.children[current][chosen_move] = next_state
                current = next_state

    def propagate_outcomes(self):
        """Recompute ``outcomes`` and ``win_frac`` for the tree under the root."""
        return self._propagate_outcomes(self.root)

    def _propagate_outcomes(self, root):
        """Recursively total the terminal results below ``root``.

        ``root.outcomes`` is indexed by winner marker. ``root.win_frac``
        becomes the share of those outcomes won by the player who made the
        move leading to ``root``.
        """
        # clear outcomes
        root.outcomes[:] = 0

        if root.is_end_state():
            _, winner = root.determine_winner()
            root.outcomes[winner] = 1
        else:
            # .get() keeps this traversal from inserting empty entries into
            # the defaultdict for positions that were never expanded.
            for child in self.children.get(root, {}).values():
                self._propagate_outcomes(child)
                root.outcomes += child.outcomes

        player = root.determine_previous_player()
        total = root.outcomes.sum()

        # A non-terminal node without sampled children has no outcomes at
        # all; report 0.0 instead of dividing 0 by 0 (NaN plus a warning).
        root.win_frac = root.outcomes[player] / total if total else 0.0

    def find_policy(self):
        """Return a dict mapping each reachable non-terminal state to a Move.

        For every state the recorded move whose child has the highest
        ``win_frac`` is chosen; on ties the move recorded last wins. A state
        with no recorded children maps to None. Terminal states are omitted.
        """
        policy_map = dict()

        # modified DFS
        stack = [self.root]
        while stack:
            current = stack.pop()

            # Equal positions share one children entry, so a position that
            # was already handled would yield exactly the same result.
            if current in policy_map or current.is_end_state():
                continue

            best_move = None
            best_score = 0.0
            for move, child_state in self.children.get(current, {}).items():
                stack.append(child_state)
                # some states will be no-win scenarios
                # so the win_frac will be 0
                if child_state.win_frac >= best_score:
                    best_score = child_state.win_frac
                    best_move = move

            policy_map[current] = best_move

        return policy_map

    def size(self):
        """Return the number of nodes visited by a DFS from the root.

        A position reachable along several move orders is counted once per
        path, not once per distinct board.
        """
        # modified DFS
        stack = [self.root]
        size = 0
        while stack:
            current = stack.pop()
            size += 1
            stack.extend(self.children.get(current, {}).values())

        return size

class PolicyAgent:
    """Agent that picks its move from fresh random playouts on every turn."""

    def __init__(self, state_tree, n_paths=10):
        """Store the search tree and the per-move sampling budget.

        Args:
            state_tree: tree object providing ``reroot_tree``,
                ``sample_paths``, ``propagate_outcomes`` and ``find_policy``
                (an MCTree).
            n_paths: number of random playouts sampled before each move.
                Defaults to 10, the value that used to be hard-coded.
        """
        self.state_tree = state_tree
        self.n_paths = n_paths

    def make_move(self, board_state: BoardState):
        """Choose a move for the current position and play it in place.

        Raises:
            KeyError: if ``board_state`` is already terminal, because
                ``find_policy`` records no move for end states.
        """
        tree = self.state_tree

        # Search from the current position: sample, score, then pick.
        tree.reroot_tree(board_state)
        tree.sample_paths(self.n_paths)
        tree.propagate_outcomes()
        policy = tree.find_policy()

        board_state.play_move(policy[board_state])

class RandomAgent:
    """Agent that plays a uniformly random legal move."""

    def make_move(self, board_state: BoardState):
        """Pick one of the currently legal moves at random and play it in place.

        ``random.choice`` raises IndexError when there is no legal move,
        i.e. when the game is already over.
        """
        board_state.play_move(random.choice(board_state.enumerate_moves()))

def play_game(agent1, agent2):
    """Play one game on a fresh board and return the winner's marker.

    ``agent1`` moves first and the agents then alternate, each mutating the
    shared board through ``make_move``. Returns EMPTY when the game ends
    without a winner.
    """
    board_state = BoardState()
    to_move, waiting = agent1, agent2

    while not board_state.is_end_state():
        to_move.make_move(board_state)
        # hand the turn to the other agent
        to_move, waiting = waiting, to_move

    _, winner = board_state.determine_winner()
    return winner

In [13]:
# Play 1000 independent games of the playout-based policy agent (moving
# first) against a uniformly random opponent, recording each game's winner.
winners = []
for i in range(1000):
    # A fresh tree per game, so nothing sampled in one game carries over.
    agent1 = PolicyAgent(MCTree())
    agent2 = RandomAgent()
    winner = play_game(agent1, agent2)
    winners.append(winner)

# Tally the results: key 0 (EMPTY) = no winner, 1 = PLAYER_1, 2 = PLAYER_2.
Counter(winners)

Counter({0: 562, 1: 370, 2: 68})