# Monte Carlo Algorithms in Connect4

In this notebook, we will explore Monte Carlo Tree Search in the classic game Connect4! If you are unfamiliar with the game, you can read the [rules](https://www.hasbro.com/common/instruct/ConnectFour.PDF) and [watch](https://www.youtube.com/watch?v=ylZBRUJi3UQ) a game being played.

![picture](https://miro.medium.com/v2/resize:fit:2000/1*ZhiICBnWZLN4Mz93xozWkA.png)

## Game Model

First, we will model the game state: the Connect4 board! Some key features we want our board to record are:

* Current turn (P1 or P2)
* Terminal state
* All the positions of the pieces
  * Differentiating P1 and P2 pieces
* Rewards of state!
  * (+1) for Player 1 win
  * (-1) for Player 2 win
  * (0) for a tie or a non-terminal state

One implementation detail to note is that we keep track of each column's height. This makes it easy to find where a new piece lands in a column and whether the column is full. Instead of scanning the column, which is O(Board_Height), we look up the stored height in O(1).
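Here is a minimal sketch of the column-height idea. It is independent of the `Board` class below, and `drop_piece` is a hypothetical helper name. Placing a piece becomes one index lookup plus one increment:

```python
def drop_piece(board, column_heights, col, piece):
    """Hypothetical helper: place `piece` in column `col` in O(1).

    Row 0 is the bottom row, so the next free cell in a column is simply
    column_heights[col]. Returns the row used, or None if the column is full.
    """
    row = column_heights[col]
    if row >= len(board):  # column full: detected without scanning
        return None
    board[row][col] = piece
    column_heights[col] += 1
    return row


width, height = 7, 6
board = [["+"] * width for _ in range(height)]
column_heights = [0] * width
print([drop_piece(board, column_heights, 3, p) for p in "XOX"])  # [0, 1, 2]
```

Storing row 0 as the bottom row matches the `Board` class, which prints its grid in reverse.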

Another point to note is that a win can only be created around the most recently played piece. That is, a Connect4 state only becomes terminal if the most recently played piece completes a line of four or fills the board. We exploit this fact: whenever we check whether a player has won, we only look along the 4 lines through the most recently played piece (vertical, horizontal, and both diagonals), checking at most 3 cells in each direction. The time complexity goes from O(Board_Height * Board_Width) for a full-board scan to O(1) (more precisely, O(WinLength), independent of the board size).
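Here is a standalone sketch of this local check. `is_winning_move` is a hypothetical helper. It inspects the board after the piece has been placed, whereas `Board.check_win` below inspects the empty cell the piece is about to fill. Either way, only the cells on the 4 lines through that position are read:

```python
def is_winning_move(board, row, col, win_length=4):
    """Hypothetical helper: check whether the piece just placed at (row, col)
    completes a line of `win_length`. Only cells on the 4 lines through that
    piece are inspected, so the work does not grow with the board size."""
    piece = board[row][col]
    height, width = len(board), len(board[0])
    # One (d_row, d_col) per line: vertical, horizontal, and the two diagonals
    for d_row, d_col in [(1, 0), (0, 1), (1, 1), (1, -1)]:
        count = 1  # the new piece itself
        for sign in (1, -1):  # walk both ways along the line
            for step in range(1, win_length):
                r = row + sign * step * d_row
                c = col + sign * step * d_col
                if not (0 <= r < height and 0 <= c < width) or board[r][c] != piece:
                    break
                count += 1
        if count >= win_length:
            return True
    return False


board = [["+"] * 7 for _ in range(6)]
for c in range(4):
    board[0][c] = "X"  # four in a row along the bottom
print(is_winning_move(board, 0, 1))  # True
```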

These optimizations make the game run faster. That lets Monte Carlo Tree Search run more simulations in the same amount of time later on! This is especially important if you decide to make the board much bigger than the standard 7x6.

In [2]:
import copy

class Board:
    """
    Connect 4 Game Board
    """

    P1symbol = "X"
    P2symbol = "O"
    WinLength = 4
    TerminalTurn = 2

    def __init__(self, width, height) -> None:
        self.width = width
        self.height = height
        self.board = [["+" for _ in range(width)] for _ in range(height)]
        self.column_heights = [0 for _ in range(width)]

        # 0 = P1, 1 = P2, 2 = Terminal
        self.turn = 0
        self.reward = 0

    def __str__(self) -> str:
        """
        Print the board
        """
        board = [" ".join([str(x) for x in row]) for row in reversed(self.board)]
        board.append("-" * (self.width*2-1))
        board.append(" ".join([str(x) for x in range(self.width)]))
        board.append("-" * (self.width*2-1))
        return "\n".join(board)

    def play(self, col) -> bool:
        """
        Play a piece in the column. Returns False if the move is invalid.
        """
        # Out of bounds (checked before indexing column_heights)
        if col < 0 or col >= self.width:
            return False

        height = self.column_heights[col]
        piece = Board.P1symbol if self.turn == 0 else Board.P2symbol

        # Column is full
        if height >= self.height:
            return False

        # Game already over!
        if self.is_terminal():
            return False

        # check_win inspects the cell the piece is about to fill
        winner = self.check_win(col)

        # Play turn!
        self.board[height][col] = piece
        self.column_heights[col] += 1

        if winner:
            # Game over! player won --> terminal
            points = 1 if self.turn == 0 else -1
            self.set_terminal(points)
        elif len(self.get_actions()) == 0:
            # Game over! board is full with no winner --> tie
            # (a win on the last empty cell must keep its +1/-1 reward)
            self.set_terminal(0)
        else:
            self.update_turn()

        return True

    def set_terminal(self, reward) -> None:
        """
        Set the board to terminal
        """
        self.turn = Board.TerminalTurn
        self.reward = reward


    def get_turn(self) -> str:
        """
        Return the current turn
        """
        turn_map = {
            0 : Board.P1symbol,
            1 : Board.P2symbol,
            Board.TerminalTurn : "Terminal"
        }
        return turn_map[self.turn]

    def actor(self) -> int:
        """
        Return the current actor
        """
        return self.turn

    def get_actions(self) -> list:
        """
        Return a list of valid moves
        """
        return [col for col in range(self.width) if self.column_heights[col] < self.height]

    def successor(self, col):
        """
        Return a new board with the move played
        """
        new_board = copy.deepcopy(self)
        new_board.play(col)
        return new_board

    def is_terminal(self) -> bool:
        """
        Check if the game is over
        """
        return self.turn == Board.TerminalTurn or len(self.get_actions()) == 0

    def update_turn(self) -> None:
        """
        Update the turn
        """
        if self.turn != Board.TerminalTurn:
            self.turn = (self.turn + 1) % 2

    def payoff(self) -> int:
        """
        Return the reward
        """
        return self.reward

    def check_win(self, col) -> bool:
        """
        Check if adding a piece to the column will win the game
        """

        direction_groups = [
            [(1,0), (-1,0)],    # Horizontal -
            [(0,1), (0,-1)],    # Vertical   |
            [(1,1), (-1,-1)],   # Diagonal   /
            [(1,-1), (-1,1)]    # Diagonal   \
        ]

        piece = Board.P1symbol if self.turn == 0 else Board.P2symbol
        row = self.column_heights[col]

        # Cannot win if the column is full
        if row >= self.height:
            return False

        for direction_group in direction_groups:

            # Count contiguous pieces on the line through the new piece (the new piece counts as 1)
            count = 1
            for dx, dy in direction_group:
                for i in range(1, Board.WinLength):
                    x = col + i * dx
                    y = row + i * dy
                    if x < 0 or x >= self.width or y < 0 or y >= self.height:
                        break
                    if self.board[y][x] == piece:
                        count += 1
                    else:
                        break

            if count >= Board.WinLength:
                return True

        return False

    def check_opposing_win(self, col) -> bool:
        """
        Check if adding a piece to the column will cause the opponent to win
        """
        self.update_turn()
        result = self.check_win(col)
        self.update_turn()
        return result

## Agents

Next, we will define the standard agents for this game which will include:

1. **Human**: Allows people to interact in the CLI to play against any of the agents!

*Note: humans cannot play on a default Colab account; you must download the Python files and run them in your own terminal to play!*

2. **Random**: As you may guess, the random agent places pieces into random columns until the game is over. It pays no regard to the relative position of pieces.

3. **RandomGreedy**: Similar to the last agent, this one plays mostly randomly unless there is a critical condition. Before each placement it checks 3 key conditions, in the following priority: (1) Will I win if I place a piece in this column? (2) Will the opponent win next turn unless I block this column? (3) Will placing a piece in this column let the opponent win by playing on top of it? It takes a winning move first, then a blocking move, and otherwise plays randomly among the moves that avoid (3).

4. **MonteCarlo**: This agent uses Monte Carlo methods, simulating numerous playouts of the game to figure out which move results in the best rewards for the agent!
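The RandomGreedy priority order can be sketched as a pure function over precomputed sets of columns. `greedy_choice` and its arguments are hypothetical. Each priority is checked over *all* legal moves before the function falls through to the next one, so a winning move is never passed over in favor of a block:

```python
import random


def greedy_choice(moves, winning, blocking, unsafe, rng=random):
    """Hypothetical helper: pick a column using the RandomGreedy priorities.

    moves    -- legal columns
    winning  -- columns where our piece wins immediately
    blocking -- columns the opponent would win with on their next turn
    unsafe   -- columns that let the opponent win by playing on top of us
    """
    # (1) Win now, checked over every legal move first
    for move in moves:
        if move in winning:
            return move
    # (2) Otherwise, block the opponent's immediate win
    for move in moves:
        if move in blocking:
            return move
    # (3) Otherwise, play randomly among moves that don't hand over a win
    safe = [move for move in moves if move not in unsafe]
    return rng.choice(safe if safe else moves)


# Column 0 blocks, but column 3 wins outright, so the win is chosen.
print(greedy_choice([0, 1, 2, 3], winning={3}, blocking={0}, unsafe=set()))  # 3
```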

In [3]:
import random as rand
from abc import ABC, abstractmethod

class InvalidMoveException(Exception):

    def __init__(self, move):
        self.move = move

class Agent(ABC):

    @abstractmethod
    def get_move(self, board: Board) -> int:
        pass

class HumanAgent(Agent):

    def get_move(self, board: Board) -> int:

        valid_moves = board.get_actions()

        while True:
            try:
                move = int(input("Enter a move: "))
                if move in valid_moves:
                    return move
                else:
                    raise InvalidMoveException(move)
            except InvalidMoveException:
                print("Invalid move! Must be within", valid_moves)
            except ValueError:
                print("Invalid move! Must be an integer.")

class RandomAgent(Agent):

    def get_move(self, board: Board) -> int:
        moves = board.get_actions()
        return rand.choice(moves)

class RandomGreedyAgent(Agent):

    def get_move(self, board: Board) -> int:

        moves = board.get_actions()

        # Priority 1: win the game! (checked over all moves before any block)
        for move in moves:
            if board.check_win(move):
                return move

        # Priority 2: block the opponent from winning!
        for move in moves:
            if board.check_opposing_win(move):
                return move

        # Remove moves that lead to a loss the next turn!
        good_moves = []
        for move in moves:
            new_board = board.successor(move)
            opponent_win = new_board.check_win(move)
            if not opponent_win:
                good_moves.append(move)

        if good_moves:
            return rand.choice(good_moves)
        else:
            return rand.choice(moves)


## Monte Carlo Tree Search

Monte Carlo Tree Search (MCTS) is a decision-making algorithm widely used in artificial intelligence, especially for games like Go and Chess. It strikes a balance between exploring new actions and exploiting known strategies. Here's how it works:

**Selection**: Starting from the root node (the current state), the algorithm selects the most promising child node based on a specific criterion, such as the Upper Confidence Bound applied to Trees (UCT). This balances the exploration of less visited nodes and the exploitation of nodes with a high win rate. The formula is as follows:

![picture](https://i.stack.imgur.com/lFPTK.png)
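Here is a small numeric sketch of the selection rule. It uses the UCB1/UCT value with exploration constant c = √2, which is the form `get_ucb_value` in the code below computes. `uct_value` itself is a hypothetical helper. For Player 2, who minimizes the reward, the exploration term is subtracted instead of added:

```python
import math


def uct_value(child_total_reward, child_visits, parent_visits, maximizing=True, c=math.sqrt(2)):
    """Hypothetical helper: average reward plus (or minus) an exploration bonus."""
    average = child_total_reward / child_visits
    # c * sqrt(ln N / n) with c = sqrt(2) equals sqrt(2 ln N / n)
    exploration = c * math.sqrt(math.log(parent_visits) / child_visits)
    return average + exploration if maximizing else average - exploration


# Parent visited 100 times: a well-explored child with average 0.6 vs. a child seen once
print(round(uct_value(6, 10, 100), 3))  # 1.56
print(round(uct_value(0, 1, 100), 3))   # 3.035
```

The second child has the lower average, but its large exploration bonus means UCT selects it next. As its visits accumulate, the bonus shrinks and the average reward dominates.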

**Expansion**: Upon reaching a node that is not yet fully expanded (one with actions that have no child node yet), the algorithm expands the tree by adding one or more child nodes, representing possible future states.

**Simulation**: From the newly added node, the algorithm plays the game out to the end using a default policy (here, uniformly random moves). The outcome of this playout serves as an estimate of the new node's value.

**Backpropagation**: Finally, the results of these simulations are propagated back up the tree. Each node is updated with the new information, such as the win/loss ratio, helping to refine the decision-making process for future iterations.

MCTS iteratively runs through these steps, building a tree of possibilities, until a termination condition is met such as a time limit or iteration count. This method allows AI to handle games with vast numbers of possible moves by focusing on the most promising strategies.

![picture](https://www.researchgate.net/publication/320742905/figure/fig1/AS:631642972504115@1527606828915/Diagram-representing-the-4-steps-of-MCTS-In-the-first-two-steps-the-tree-is-traversed.png)
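To see the four steps in isolation, here is a compact, self-contained sketch of the same loop on a toy game instead of Connect4. All names here are hypothetical. It follows the notebook's zero-sum convention: rewards are +1 for a Player 1 win and -1 for a Player 2 win, Player 1 maximizes and Player 2 minimizes. It differs from the notebook's code in two assumed ways: it runs a fixed number of iterations instead of a time budget, and it returns the most-visited child rather than the best-average one.

```python
import math
import random


class NimState:
    """Toy game (hypothetical stand-in for Board): players alternately take
    1 or 2 stones; whoever takes the last stone wins."""

    def __init__(self, stones, player=0):
        self.stones = stones
        self.player = player  # 0 = P1 (maximizer), 1 = P2 (minimizer)

    def actions(self):
        return [k for k in (1, 2) if k <= self.stones]

    def successor(self, k):
        return NimState(self.stones - k, 1 - self.player)

    def is_terminal(self):
        return self.stones == 0

    def payoff(self):
        winner = 1 - self.player  # the player who just moved took the last stone
        return 1 if winner == 0 else -1


class Node:
    def __init__(self, state, parent=None, action=None):
        self.state, self.parent, self.action = state, parent, action
        self.children = []
        self.untried = [] if state.is_terminal() else state.actions()
        self.visits = 0
        self.total = 0.0

    def ucb(self, child):
        explore = math.sqrt(2 * math.log(self.visits) / child.visits)
        avg = child.total / child.visits
        # P1 maximizes the reward, P2 minimizes it (same convention as the notebook)
        return avg + explore if self.state.player == 0 else avg - explore


def mcts(root_state, iterations, rng):
    root = Node(root_state)
    for _ in range(iterations):
        node = root
        # 1. Selection: descend through fully expanded nodes using UCB
        while not node.untried and node.children:
            pick = max if node.state.player == 0 else min
            node = pick(node.children, key=node.ucb)
        # 2. Expansion: add one untried child
        if node.untried:
            action = node.untried.pop()
            child = Node(node.state.successor(action), node, action)
            node.children.append(child)
            node = child
        # 3. Simulation: random playout to a terminal state
        state = node.state
        while not state.is_terminal():
            state = state.successor(rng.choice(state.actions()))
        reward = state.payoff()
        # 4. Backpropagation: update every node on the path to the root
        while node is not None:
            node.visits += 1
            node.total += reward
            node = node.parent
    # Final choice: most-visited child (the notebook instead uses the best average)
    return max(root.children, key=lambda c: c.visits).action


print(mcts(NimState(4), 2000, random.Random(0)))  # 1
```

In this toy game, a pile that is a multiple of 3 is lost for the player to move. So from 4 stones the winning move is to take 1, and from 5 stones it is to take 2. A few thousand iterations are plenty for a tree this small.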

In [4]:
import time
import math

# Note: Connect4 is set up as a zero-sum game where:
# P1 is trying to maximize the score (+1)
# P2 is trying to minimize the score (-1)

class MonteCarloNode:

    def __init__(self, state : Board, parent=None, parent_action=None):

        self.state = state
        self.parent = parent
        self.parent_action = parent_action

        self.total_visits = 0
        self.total_rewards = 0

        self.children = []

        if self.state.is_terminal():
            self.missing_child_actions = []
        else:
            self.missing_child_actions = self.state.get_actions()

    def is_fully_expanded(self):
        """
        Determines if the node is fully expanded.
        """
        return len(self.missing_child_actions) == 0

    def get_average_reward(self):
        """
        Returns the average reward of a node.
        """
        if self.total_visits == 0:
            return 0
        return self.total_rewards / self.total_visits

    def get_best_average_child(self):
        """
        Returns the child node with the best average reward.
        """

        if self.state.actor() == 0:
            node = max(self.children, key = lambda x: x.get_average_reward())
        else:
            node = min(self.children, key = lambda x: x.get_average_reward())
        return node

    def get_ucb_value(self, parent_total_visits, parent_actor):
        """
        Returns the UCB value of the node using parents visits and actor
        """

        average_reward = self.get_average_reward()
        exploration_term = math.sqrt(2 * math.log(parent_total_visits) / self.total_visits)
        if parent_actor == 0:
            value = average_reward + exploration_term
        else:
            # Note: Player 2 prioritizes negative rewards
            # Therefore we must subtract the exploration term to encourage exploring!
            value = average_reward - exploration_term
        return value

    def get_best_ucb_child(self):
        """
        Returns the child node with the best UCB value.
        """

        if self.state.actor() == 0:
            node = max(self.children, key = lambda x: x.get_ucb_value(self.total_visits, self.state.actor()))
        else:
            node = min(self.children, key = lambda x: x.get_ucb_value(self.total_visits, self.state.actor()))
        return node

    def expand(self):
        """
        Expands the node by adding a new child node from an unexplored action.
        """
        action = self.missing_child_actions.pop()
        next_state = self.state.successor(action)
        child_node = MonteCarloNode(next_state, parent=self, parent_action=action)
        self.children.append(child_node)
        return child_node

    def find_leaf_node(self):
        """
        Returns the leaf node of the tree using UCB values. Expands node if necessary!
        """

        current_node = self
        while not current_node.state.is_terminal():

            # If node is within the tree! with missing children
            if not current_node.is_fully_expanded():
                return current_node.expand()
            else:
                # Traverse down the tree!
                current_node = current_node.get_best_ucb_child()

        # The MCTS leaf node belongs to a terminal state!
        return current_node

    def simulate(self):
        """
        Returns the terminal value of the node by randomly simulating game.
        """

        state = self.state
        while not state.is_terminal():
            state = state.successor(rand.choice(state.get_actions()))
        return state.payoff()

    def update_rewards(self, reward):
        """
        Updates the total reward and total visits of the node and all its parents.
        """

        parent_node = self
        while parent_node is not None:
            parent_node.total_rewards += reward
            parent_node.total_visits += 1
            parent_node = parent_node.parent

    def __str__(self) -> str:

        ucb_value = "Undefined"
        if self.parent:
            ucb_value = self.get_ucb_value(self.parent.total_visits, self.parent.state.actor())

        string_form = f"""
        Node: {self.state}
        Total Visits: {self.total_visits}
        Average Reward: {self.get_average_reward()}
        UCB Value: {ucb_value}
        Parent Action: {self.parent_action}
        Actor: {self.state.actor()}
        Children: {len(self.children)}
        """

        return string_form


def mcts_policy(time_duration):

    def fxn(initial_position: Board):

        start_time = time.time()
        root = MonteCarloNode(initial_position, None)

        while time.time() - start_time < time_duration:

            # Gets the leaf node in the tree (step 1 & 2)
            node = root.find_leaf_node()

            # Determines the random terminal value of the node (step 3)
            reward = node.simulate()

            # Updates the rewards of the parents (step 4)
            node.update_rewards(reward)

        node = root.get_best_average_child()

        return node.parent_action

    return fxn


Now that we have the Monte Carlo Tree Search algorithm built, we can configure the agent to use the function! If you are looking for more resources to learn MCTS, you can watch the following [video](https://www.youtube.com/watch?v=onBYsen2_eA) on YouTube!

In [5]:
class MonteCarloAgent(Agent):

    def __init__(self, duration, random_move_prob=0.05) -> None:
        super().__init__()
        self.random_move_prob = random_move_prob
        self.policy_fxn = mcts_policy(duration)

    def get_move(self, board: Board) -> int:

        # Occasional random move to introduce non-determinism
        if rand.random() < self.random_move_prob:
            return rand.choice(board.get_actions())
        else:
            return self.policy_fxn(board)

## Game Play!

In [6]:
import time

class Game:

    def __init__(self, p1_agent: Agent, p2_agent: Agent, width: int, height: int, turn_sleep: bool = False) -> None:

        self.width = width
        self.height = height
        self.p1_agent = p1_agent
        self.p2_agent = p2_agent
        self.turn_sleep = turn_sleep

    def play(self, extra_print=False) -> int:
        """
        Plays the game until a terminal state is reached:
            Returns 1 if P1 wins, -1 if P2 wins, 0 if tie
        """

        board = Board(self.width, self.height)

        while True:

            if self.turn_sleep:
                time.sleep(1)

            if extra_print:
                print(board)

            # Terminal Tie
            if len(board.get_actions()) == 0:
                if extra_print:
                    print("Tie!")
                return 0

            # Get move from agent
            current_turn = board.get_turn()
            agent = self.p1_agent if current_turn == Board.P1symbol else self.p2_agent
            move = agent.get_move(board)

            # Check if winning move & play!
            winner = board.check_win(move)
            valid = board.play(move)

            if winner:
                if extra_print:
                    print(board)
                    print("-"*10)
                    print(current_turn, "wins!")
                    print("-"*10)
                return 1 if current_turn == Board.P1symbol else -1

            if not valid:
                print("Warning: invalid move:", move)

def simulate_matchups(p1: Agent, p2: Agent, width: int, height: int, count: int) -> tuple:
    game = Game(p1, p2, width, height)
    p1_wins, p2_wins, ties = 0, 0, 0
    for _ in range(count):
        result = game.play()
        if result == 1:
            p1_wins += 1
        elif result == -1:
            p2_wins += 1
        else:
            ties += 1
    return p1_wins, p2_wins, ties

Below, we will simulate matchups between all of the agents!

In [7]:
width = 7
height = 6
count = 100
monte_carlo_time = 0.15
randomness = 0  # Probability that the Monte Carlo agent plays a random (possibly sub-optimal) move, as an optional handicap

human = HumanAgent()
random = RandomAgent()
greedy = RandomGreedyAgent()
monte = MonteCarloAgent(monte_carlo_time, randomness)

setups = [
    ((random, random), ("Random", "Random")),
    ((random, greedy), ("Random", "Greedy")),
    ((greedy, greedy), ("Greedy", "Greedy")),
    ((random, monte), ("Random", "Monte Carlo")),
    ((greedy, monte), ("Greedy", "Monte Carlo")),
    ((monte, monte), ("Monte Carlo", "Monte Carlo")),
]

for config, names in setups:
    p1_wins, p2_wins, ties = simulate_matchups(*config, width, height, count)
    print(f"{names[0]}: {round(p1_wins/count*100, 3)}%",
        f"{names[1]}: {round(p2_wins/count*100, 3)}%",
        f"Ties: {round(ties/count*100, 3)}%",
        sep="\t"
        )

Random: 51.0%	Random: 48.0%	Ties: 1.0%
Random: 2.0%	Greedy: 98.0%	Ties: 0.0%
Greedy: 41.0%	Greedy: 49.0%	Ties: 10.0%
Random: 0.0%	Monte Carlo: 100.0%	Ties: 0.0%
Greedy: 2.0%	Monte Carlo: 97.0%	Ties: 1.0%
Monte Carlo: 56.0%	Monte Carlo: 39.0%	Ties: 5.0%


If you don't have time to run the above code (it can take up to 10 minutes), the expected output looks something like the following:

\begin{array}{|c|c|c|} \hline
\text{Random: 57.0\%} & \text{Random: 43.0\%} & \text{Ties: 0.0\%} \\
\text{Random: 6.0\%} & \text{Greedy: 94.0\%} & \text{Ties: 0.0\%} \\
\text{Greedy: 45.0\%} & \text{Greedy: 39.0\%} & \text{Ties: 16.0\%} \\
\text{Random: 0.0\%} & \text{Monte Carlo: 100.0\%} & \text{Ties: 0.0\%} \\
\text{Greedy: 6.0\%} & \text{Monte Carlo: 93.0\%} & \text{Ties: 1.0\%} \\
\text{Monte Carlo: 54.0\%} & \text{Monte Carlo: 39.0\%} & \text{Ties: 7.0\%} \\ \hline
\end{array}

One interesting thing to note is that Player 1 tends to have an advantage over Player 2 when the same algorithms play against each other! This is consistent with the fact that the game has been solved: under optimal play, Player 1 can always force a win. Here is a [reference](https://connect4.gamesolver.org/en/) to an AI which can solve the game from any position!

In [8]:
# Greedy vs Greedy with extra iterations
# Goal: Test the advantage for P1 with a larger sample size!
n_games = count * 100
p1_wins, p2_wins, ties = simulate_matchups(greedy, greedy, width, height, n_games)
print(f"Greedy: {round(p1_wins/n_games*100, 1)}%",
    f"Greedy: {round(p2_wins/n_games*100, 1)}%",
    f"Ties: {round(ties/n_games*100, 1)}%",
    sep="\t"
    )

Greedy: 46.4%	Greedy: 41.4%	Ties: 12.2%


The above code chunk has the fast greedy agents play 100 times as many games (10,000) as before. This gives a better sense of how strongly the game favors the first mover. With this larger sample, Player 1 still wins noticeably more often than Player 2 (46.4% vs. 41.4%).
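Could this gap be noise? Here is a rough sketch that answers the question with a normal-approximation 95% confidence interval for each win rate, an assumption that is reasonable at this sample size. It plugs in the rounded rates from the 10,000-game run (46.4% and 41.4%, i.e. about 4,640 and 4,140 wins). `win_rate_interval` is a hypothetical helper:

```python
import math


def win_rate_interval(wins, games, z=1.96):
    """Hypothetical helper: approximate 95% confidence interval for a win rate
    using the normal approximation to the binomial."""
    p = wins / games
    half_width = z * math.sqrt(p * (1 - p) / games)
    return p - half_width, p + half_width


games = 10_000
p1_low, p1_high = win_rate_interval(4640, games)  # P1: 46.4%
p2_low, p2_high = win_rate_interval(4140, games)  # P2: 41.4%
print(f"P1: {p1_low:.3f} to {p1_high:.3f}")
print(f"P2: {p2_low:.3f} to {p2_high:.3f}")
```

The intervals come out to roughly 45.4% to 47.4% for Player 1 and 40.4% to 42.4% for Player 2. They do not overlap, so the first-mover edge among greedy agents is unlikely to be sampling noise. The same calculation for the 100-game runs gives intervals about ten times wider (around ±10 percentage points), so those earlier percentages should be read loosely.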

In [9]:
# Greedy vs MonteCarlo Agent with one third of the simulation time (~67% less)
# Goal: Test how MCTS performs with a smaller search budget!

short_time = monte_carlo_time / 3  # separate variable, so re-running this cell doesn't keep shrinking the budget
p1_wins, p2_wins, ties = simulate_matchups(greedy, MonteCarloAgent(short_time, randomness), width, height, count)
print(f"Greedy: {round(p1_wins/count*100, 3)}%",
    f"MonteCarlo: {round(p2_wins/count*100, 3)}%",
    f"Ties: {round(ties/count*100, 3)}%",
    sep="\t"
    )

Greedy: 12.0%	MonteCarlo: 82.0%	Ties: 6.0%


Another important parameter with Monte Carlo Tree Search is the amount of time allowed for simulations. We may expect that as we give the model more time, it performs better. Similarly, if we reduce time, it will perform worse. The following is an example output from the above code chunk:

\begin{array}{|c|c|c|} \hline
\text{Greedy: 29.0\%} & \text{MonteCarlo: 65.0\%} & \text{Ties: 6.0\%} \\ \hline
\end{array}

We observe that when we gave the algorithm 66% less time, the Greedy agent started winning almost 5 times as often in this example run (from 6% to about 30%)! This is consistent with our understanding of MCTS: less time means fewer playouts per move, so the value estimates for each move are noisier. If you have time to burn, you could give the MCTS agent a much larger time budget and see how well it performs then!

## Conclusion

By following this tutorial, you have learned how Monte Carlo Tree Search works: it traverses the tree by balancing exploitation and exploration, expands the tree, simulates a playout of the game, and propagates the reward back up the path. A big advantage of Monte Carlo Tree Search is that we have the flexibility to change how long it searches depending on the desired use case! As the examples above show, there is a tradeoff between time and performance. This tutorial also gives guidance on how to design a game state representation that search algorithms can work with.

Monte Carlo tree search has been applied to many board games, including Backgammon, Chess, Checkers, Contract Bridge, Go, and Scrabble. In recent years it has also been combined with neural networks, most famously in DeepMind's AlphaGo, to play Go at a superhuman level!

## Future Extensions

Some interesting extensions to this baseline tutorial include:

1. **Minimax or Neural Net**: I believe a different algorithm could find the truly optimal move for Connect4 in a reasonable amount of time. The game has been solved with a first-mover advantage: under optimal play, the first player always wins! Yet MCTS playing against itself won only about 55% of games as Player 1 (56% and 54% in the runs above), even though Player 1 should in theory win 100% of the time. An exact game-tree search such as minimax with alpha-beta pruning could compute the value of each state. (Expectimax is the variant for games with chance events; Connect4 is deterministic.) Alternatively, a neural network could be trained to evaluate positions and guide the search.

2. **Greedy Algorithm Improvements**: Currently, the greedy algorithm only considers whether a move leads to a terminal state on the current or next turn, and acts to maximize its own rewards. Otherwise, it places pieces randomly. One improvement could be to prioritize placing pieces next to the agent's own pieces, since connecting pieces is the win condition of the game after all.

3. **Connect4 Variants**: All of the examples in the tutorial use the standard Connect4 board of 7 columns and 6 rows. It could be interesting to change some of the rules of the game and see how that influences the effectiveness of MCTS, the rate of ties, and the bias towards P1 winning. Factors that could be changed include the board size (both width and height), the win condition (Connect4 → Connect5), the turn mechanism (each player takes 2 sequential turns instead), etc.

4. **MCTS Improvements**: Improve the efficiency of MCTS by mapping each game state to exactly one MCTS node. Across all the possible playouts of Connect4, there are numerous move orders that reach the same position. Currently, each of these ends up as a separate node in the tree even though they represent the same underlying game state. Merging them, for example with a transposition table keyed by the board, would let the search share simulation data across every path to the same state!
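Extensions 1 and 4 can be illustrated together on a toy game. This is not Connect4 and not a drop-in replacement for the notebook's MCTS. It uses an exact negamax search, the single-function form of minimax for zero-sum games. `functools.lru_cache` acts as a transposition table, so each distinct state is evaluated once no matter how many move orders reach it. The game (players alternately take 1 or 2 stones; whoever takes the last stone wins) and the function names are hypothetical. Stopping as soon as a forced win is found is a simple cutoff in the spirit of alpha-beta pruning.

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # transposition table: one cached value per distinct state
def negamax(stones):
    """Exact value for the player to move: +1 = forced win, -1 = forced loss."""
    if stones == 0:
        return -1  # the opponent just took the last stone
    best = -1
    for take in (1, 2):
        if take <= stones:
            # A position's value for us is the negation of its value for the opponent
            best = max(best, -negamax(stones - take))
            if best == 1:
                break  # a forced win can't be improved on, so skip the remaining moves
    return best


def best_move(stones):
    """Pick the move that leaves the opponent in the worst position."""
    legal = [take for take in (1, 2) if take <= stones]
    return max(legal, key=lambda take: -negamax(stones - take))


print([negamax(n) for n in range(1, 7)])  # [1, 1, -1, 1, 1, -1]
print(best_move(4), best_move(5))         # 1 2
```

A real Connect4 solver needs much more, such as a compact board encoding, good move ordering, and full alpha-beta windows. The caching idea, though, is the same one extension 4 proposes for MCTS nodes.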
