# **AI Assignment: Connect 4 with MCTS and ID3**

### Assignment Done by:
- David Ventura Mendes de Sá (UP202303580)
- Samuel José Sousa Ventura da Silva (UP202305647)

## 0. Contents
1. Introduction

2. Connect Four  
    **2.1.** Game Implementation  
    **2.2.** Bitboard vs Matrix  

3. Algorithms  
    **3.1.** Monte Carlo Tree Search (MCTS)  
    **3.2.** Decision Trees (ID3)  
        **3.2.1.** Dataset Generation  

4. Connect Four Algorithms Implementation  
    **4.1.** Libraries  

5. User Interface Game

6. Results

7. Conclusion

8. References


## **1. Introduction** ##

## **2. Connect Four** ##

Connect Four is a two-player game in which players take turns dropping discs into a grid of 7 columns and 6 rows. The first player to connect four of their own discs in a line (horizontally, vertically, or diagonally) wins. If the board fills up without a winner, the game is a draw.

### **2.1. Game Implementation** ###

The `Bitboard` class provides a compact and fast representation of the state of a Connect Four game. Instead of a traditional 2D array, it uses bitboards (integers in which each bit represents a cell of the board) to store the positions of each player's pieces. This makes move generation, win detection, and state copying cheap, which is particularly useful for AI algorithms such as Monte Carlo Tree Search that copy and play out many positions.

Board Encoding:
- The Connect Four board is 7 columns by 6 rows (7x6 = 42 cells).
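- Each player's pieces are stored in a separate integer (`player1`, `player2`), where each bit corresponds to a cell. Each column occupies 7 bits: 6 playable cells plus one unused bit on top, so at most 49 bits are needed and the board fits in a 64-bit word.
- The mapping from board coordinates to bit positions is column-major: `bit_position = col * 7 + row`, with row 0 at the bottom.
- The `height` array keeps track of how many pieces are in each column, allowing for efficient move generation.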
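
As an illustration of the column-major mapping above, the following minimal sketch converts between board coordinates and bit positions. The helper names `to_bit` and `from_bit` are hypothetical and not part of the `Bitboard` class; the stride of 7 is taken from the formula above.

```python
ROWS, COLS = 6, 7
STRIDE = ROWS + 1  # 7 bits per column: 6 playable cells + 1 unused bit on top

def to_bit(col, row):
    """Bit index of the cell at (col, row); row 0 is the bottom row."""
    return col * STRIDE + row

def from_bit(bit_position):
    """Inverse mapping: bit index back to (col, row)."""
    return divmod(bit_position, STRIDE)

# Bottom-left cell is bit 0, the cell above it is bit 1, the next column starts at bit 7.
print(to_bit(0, 0), to_bit(0, 1), to_bit(1, 0))  # 0 1 7
print(from_bit(24))                               # (3, 3)
```

The round trip `from_bit(to_bit(col, row))` returns the original coordinates for every playable cell.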


#### **Principal methods:** ####

- `make_move(col)`
    - **Purpose:** Places a piece for the current player in the specified column.
    - **How it works:**
        - Checks if the column is full.
        - Updates the bitboard for the current player using a bitwise OR operation.
        - Increments the column height.
        - Switches the turn to the other player.

- `check_player_win(player)`
    - **Purpose:** Checks if the player has achieved four in a row (win condition).
    - **How it works:**
        - Uses bitwise operations to check for four consecutive pieces in all directions (vertical, horizontal, and both diagonals).
        - For example, a horizontal win is detected by ANDing the board with a copy of itself shifted by one column (7 bits) to find adjacent pairs, then ANDing those pairs with a copy shifted by two columns (14 bits). Any bit that survives marks four in a row.
    
- `get_legal_moves()`
    - **Purpose:** Returns a list of columns where a move is possible (i.e., not full).
    - **How it works:**
        - Checks the `height` array for columns with fewer than 6 pieces.

- `is_over()`
    - **Purpose:** Determines if the game has ended, either by a win or a draw.
    - **How it works:**
        - Calls `check_player_win` for both players and checks if all columns are full.

- `matrix()`
    - **Purpose:** Converts the internal bitboard representation to a 2D matrix for visualization or further processing.
    - **How it works:**
        - Iterates through all possible bit positions and assigns values to the matrix based on which player's bitboard contains the bit.

- `__str__()`
    - **Purpose:** Provides a human-readable string representation of the board, useful for debugging and visualization.
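
The shift-and-AND idea behind `check_player_win` can be tried on a single direction in isolation. The sketch below is a standalone illustration rather than the class's own code: `has_four` is a hypothetical helper, and the `shift` argument assumes the 7-bits-per-column layout (1 for vertical, 7 for horizontal, 6 and 8 for the two diagonals).

```python
def has_four(board, shift):
    """True if `board` contains four set bits spaced `shift` apart."""
    pairs = board & (board >> shift)             # bit i survives if cells i and i+shift are set
    return bool(pairs & (pairs >> (2 * shift)))  # two pairs, two steps apart: four in a line

# Four discs on the bottom row of columns 0-3: bits 0, 7, 14, 21.
horizontal = (1 << 0) | (1 << 7) | (1 << 14) | (1 << 21)
print(has_four(horizontal, 7))  # True  (horizontal line)
print(has_four(horizontal, 1))  # False (no vertical line)

# Three in a row is not a win.
three = (1 << 0) | (1 << 7) | (1 << 14)
print(has_four(three, 7))       # False
```

The unused bit at the top of each column is what keeps a vertical run from wrapping into the bottom of the next column.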

**Why did we use bitboards?**
- **Speed:** A win check takes a handful of shifts and ANDs instead of a scan over every cell of an array.
- **Memory Efficiency:** The entire board state fits in two integers.
- **Convenience:** Copying and comparing board states is trivial.





In [2]:
class Bitboard:
    def __init__(self):
        self.player1 = 0
        self.player2 = 0
        self.height = [0] * 7
        self.current_player = 1

    # 05 12 19 26 33 40 47
    # 04 11 18 25 32 39 46
    # 03 10 17 24 31 38 45
    # 02 09 16 23 30 37 44
    # 01 08 15 22 29 36 43 
    # 00 07 14 21 28 35 42

    def make_move(self, col):
        
        if col == -1: return

        if self.height[col] >= 6:
            return False

        # Get position
        row = self.height[col]
        bit_position = col * 7 + row

        # Update bitboard
        if self.current_player == 1:
            self.player1 |= (1 << bit_position)
        else:
            self.player2 |= (1 << bit_position)

        # Update heightmap
        self.height[col] += 1

        # Switch to the other player (1 -> 2, 2 -> 1)
        self.current_player = 3 - self.current_player
        return True

    def check_player_win(self, player):
        # Select the bitboard of the requested player
        board = self.player1 if player == 1 else self.player2

        # Diagonal \ (cells 6 bits apart)
        y = board & (board >> 6)
        if (y & (y >> 2 * 6)):
            return True
        
        # Horizontal
        y = board & (board >> 7)
        if (y & (y >> 2 * 7)):
            return True

        # Diagonal /
        y = board & (board >> 8)
        if (y & (y >> 2 * 8)):
            return True

        # Vertical
        y = board & (board >> 1)
        if (y & (y >> 2)):      
            return True
        return False

    def get_legal_moves(self):
        return [col for col in range(7) if self.height[col] < 6]
    
    def is_over(self):
        return self.check_player_win(1) or self.check_player_win(2) or all(h == 6 for h in self.height)

    def copy(self): # returns deep copy of self
        new_bitboard = Bitboard()
        new_bitboard.player1 = self.player1
        new_bitboard.player2 = self.player2
        new_bitboard.height = self.height.copy()
        new_bitboard.current_player = self.current_player
        return new_bitboard

    def matrix(self):
        # 6 rows x 7 columns; matrix[0] is the bottom row of the board
        matrix = [[0] * 7 for _ in range(6)]

        for col in range(7):
            for row in range(6):
                bit = 1 << (col * 7 + row)
                # Check which player's bitboard, if any, contains this cell
                if self.player1 & bit:
                    matrix[row][col] = 1
                elif self.player2 & bit:
                    matrix[row][col] = 2

        return matrix

    def __str__(self):
        # matrix()[0] is the bottom row, so print the rows in reverse
        # to show the board the right way up
        symbols = {0: "- ", 1: "X ", 2: "O "}
        resul = ""
        for row in reversed(self.matrix()):
            resul += "".join(symbols[cell] for cell in row) + "\n"
        return resul


### **2.2. Bitboard vs Matrix** ###

In [3]:
## TODO: example code implementing Connect 4 with a matrix or array
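
For comparison, a matrix-based representation could look like the minimal sketch below. It is an illustrative assumption rather than code from this project: the names `new_board`, `drop_disc`, and `wins` are hypothetical, and the board is a plain list of lists with row 0 at the bottom.

```python
ROWS, COLS = 6, 7
DIRECTIONS = [(0, 1), (1, 0), (1, 1), (1, -1)]  # horizontal, vertical, two diagonals

def new_board():
    return [[0] * COLS for _ in range(ROWS)]

def drop_disc(board, col, player):
    """Place `player` in the lowest empty cell of `col`; return the row, or None if full."""
    for row in range(ROWS):
        if board[row][col] == 0:
            board[row][col] = player
            return row
    return None

def wins(board, player):
    """Scan every cell in every direction for four consecutive discs of `player`."""
    for r in range(ROWS):
        for c in range(COLS):
            for dr, dc in DIRECTIONS:
                cells = [(r + k * dr, c + k * dc) for k in range(4)]
                if all(0 <= rr < ROWS and 0 <= cc < COLS and board[rr][cc] == player
                       for rr, cc in cells):
                    return True
    return False

board = new_board()
for col in range(4):
    drop_disc(board, col, 1)
print(wins(board, 1), wins(board, 2))  # True False
```

Here the win check visits every cell in four directions, whereas the bitboard version needs only a few shifts and ANDs per direction.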

## **3. Algorithms** ##

### **3.1. Monte Carlo Tree Search (MCTS)** ###

MCTS is a heuristic search algorithm that combines random sampling with tree search to choose strong moves without exploring the whole game tree. It suits games like Connect Four, where the number of reachable positions is far too large to enumerate during play. The algorithm operates in four phases:  
- **Selection:** Traverse the tree using Upper Confidence Bound (UCB) to balance exploration/exploitation.

- **Expansion:** Add a new child node for an unexplored move.

- **Simulation:** Perform random playouts from new nodes to a terminal state.

- **Backpropagation:** Update node statistics with simulation results.

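The UCB formula balances known good moves with unexplored possibilities:

$$
UCB = \frac{U}{N} + C \sqrt{\frac{\ln N_{parent}}{N}}
$$

where $U$ is the number of wins recorded at the node, $N$ is its visit count, $N_{parent}$ is the visit count of its parent, and $C$ is the exploration weight.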
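
To make the trade-off concrete, here is a minimal sketch of the formula as a standalone function. The function name `ucb` and the sample numbers are assumptions chosen for illustration; the default exploration weight of 5 mirrors the value used in the `Node` class later in this section.

```python
from math import log, sqrt

def ucb(wins, visits, parent_visits, c=5.0):
    """Upper Confidence Bound: average reward plus an exploration bonus."""
    if visits == 0:
        return float("inf")  # unvisited children are always tried first
    return wins / visits + c * sqrt(log(parent_visits) / visits)

# Same win rate (0.6), but the rarely visited child receives a larger bonus.
often = ucb(wins=60, visits=100, parent_visits=110)
rarely = ucb(wins=6, visits=10, parent_visits=110)
print(rarely > often)  # True
```

A larger `c` pushes the search toward rarely visited moves, while a smaller one concentrates visits on moves that already look good.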

#### **3.1.1. Libraries** ####

In [4]:
from math import sqrt, log
import random

- We import __sqrt__ and __log__ from the math module for mathematical calculations used in the UCB formula, and __random__ for selecting random moves during the search process.

#### **3.1.2. Class Node** ####

In [5]:
class Node:
    __slots__ = ['parent', 'move', 'children', 'wins', 'visits']
        
    def __init__(self, parent, move):
        self.parent = parent  # Node
        self.move = move  # move that led to this state
        self.children = []  # child Nodes, filled by expand()
        self.wins = 0
        self.visits = 0

    def ucb_score(self, exploration_weight=5):
        # Unvisited nodes are always explored first
        if self.visits == 0:
            return float('inf')

        return (self.wins / self.visits) + exploration_weight * sqrt(log(self.parent.visits) / self.visits)

    def expand(self, bitboard):
        # Create one child per legal move and return a random one for simulation
        self.children = [Node(self, move) for move in bitboard.get_legal_moves()]
        return random.choice(self.children)


The Node class represents a single state in the search tree.

- Memory-efficient with `__slots__`

- Each node tracks its parent, the move that led to this state, its children, and statistics (wins and visits).

- The `ucb_score` method computes the Upper Confidence Bound score, which balances exploration and exploitation during node selection. Unvisited nodes score infinity, so they are always tried first.

- The `expand` method generates all possible child nodes from the current state and returns a randomly selected child for simulation.

#### **3.1.3. Class MCTS** ####

In [6]:
class MCTS:

    def __init__(self, iterations):
        self.iterations = iterations

    def select(self, root, state):
        node = root
        while node.children: 
            node = max(node.children, key=lambda c: c.ucb_score())
            state.make_move(node.move)
        return node, state


    def simulate(self, state):
        # Play random moves until the game ends. is_over() is checked before
        # every move so that an already finished position is never played on.
        while not state.is_over():
            state.make_move(random.choice(state.get_legal_moves()))
        if state.check_player_win(1): return 1
        if state.check_player_win(2): return 2
        return 0
        

    def backpropagate(self, winner, node, state):

        reward = 0 if state.current_player == winner else 1

        while node is not None:
            node.visits += 1
            if winner == 0:
                reward = 0
            else:
                node.wins += reward
                reward = 1 - reward
            node = node.parent


    def search(self, bitboard):
        root = Node(None, None)
        root.expand(bitboard)

        for _ in range(self.iterations):

            state = bitboard.copy()

            leaf, state = self.select(root, state)
            
            # only expand if the selected leaf is not a terminal state
            if not state.is_over():
                leaf = leaf.expand(state)
                state.make_move(leaf.move)
            
            winner = self.simulate(state.copy())
            
            self.backpropagate(winner, leaf, state)

        # stats for the display
        arr = [0] * 14
        for child in root.children:
            arr[child.move] = child.visits
            arr[7+child.move] = child.wins
    
        # return the child with MOST VISITS, we don't use winrate here
        return max(root.children, key=lambda c: c.visits).move, arr


The principal methods of the class `MCTS` are:

- `__init__(self, iterations)` : The constructor takes a single parameter, the number of iterations the search will run. More iterations build a larger tree and therefore a deeper search.


- `select(self, root, state)` : This method implements the __selection__ phase of MCTS
    - Starts at the root node and descends through the tree  
    - At each level, selects the child with the highest UCB score, from the class `Node` 
    - Updates the game state as it descends  
    - Returns the selected leaf node and its corresponding state


- `simulate(self, state)` : This method performs the __simulation__ phase of MCTS
    - Executes a random playout from the current state
    - Continues making random moves until the game ends
    - Returns the result: 1 if player 1 wins, 2 if player 2 wins, 0 for a draw


- `backpropagate(self, winner, node, state)` : This method implements the __backpropagation__ phase of MCTS  
    - Updates statistics (visits and wins) on all nodes in the path back to the root
    - Alternates the reward (0/1) to handle zero-sum games
    - If the result was a draw (winner=0), no wins are added


- `search(self, bitboard)` : This is the __main__ method that manages the entire MCTS process
    - Creates a root node and expands it
    - For each iteration:  
        - Copies the current game state
        - Selects a leaf node using UCB
        - If the game isn't over, expands the node and makes a move  
        - Simulates the game to completion  
        - Propagates the results back up the tree
    - Collects statistics for visualization  
    - Returns the move with the most visits (considered the best) and the statistics


The MCTS algorithm is powerful because it doesn't require domain-specific knowledge beyond the game rules, and naturally balances exploration of new moves with exploitation of moves known to be good.
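
The alternating reward described for `backpropagate` can be traced on a three-node path. This is a minimal sketch under simplifying assumptions: `TinyNode` is a hypothetical stand-in that keeps only the fields backpropagation needs, and the reward for the leaf is passed in directly instead of being derived from the game state.

```python
class TinyNode:
    """Hypothetical stand-in for the search-tree node."""
    def __init__(self, parent=None):
        self.parent = parent
        self.wins = 0
        self.visits = 0

def backpropagate(leaf, leaf_reward, draw=False):
    """Walk from `leaf` to the root, flipping the reward at every level."""
    node, reward = leaf, leaf_reward
    while node is not None:
        node.visits += 1
        if not draw:
            node.wins += reward
            reward = 1 - reward  # the move into the parent was made by the opponent
        node = node.parent

root = TinyNode()
child = TinyNode(root)
grandchild = TinyNode(child)

backpropagate(grandchild, leaf_reward=1)  # the player who moved into `grandchild` won
print(grandchild.wins, child.wins, root.wins)        # 1 0 1
print(grandchild.visits, child.visits, root.visits)  # 1 1 1
```

Every node on the path gains a visit, but only the nodes whose move was made by the winning player gain a win.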





### **3.2. Decision Trees (ID3)** ###

#### **3.2.1. Dataset Generation Libraries** ####

In [None]:
import csv
import random
from tqdm import tqdm
from game import Bitboard
from mcts import MCTS
import multiprocessing as mp
from multiprocessing import Pool
import os

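These libraries support the dataset generation:

- `csv` writes the generated samples to `dataset.csv`.
- `random` plays the random opening moves of each game.
- `tqdm` displays a progress bar while the samples are generated.
- `Bitboard` and `MCTS` are the project's own classes, described in sections 2.1 and 3.1.
- `multiprocessing` (and its `Pool` class) distributes the games across all CPU cores.
- `os` is imported but not used in the code shown here.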

#### **3.2.2. Dataset Generator** ####

In [None]:
def encode_board(board):
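    # Bit indices that do not map to a cell: the unused top bit of each column
    # (columns are 7 bits wide, but only rows 0-5 are playable)
    unused_bits = {6, 13, 20, 27, 34, 41}

    # Initialize the result array (42 positions, one per cell)
    positions = [0] * 42

    # Current position in the output array (0-based)
    pos_idx = 0

    # Iterate through bits 0-47 (0-based); bit 48 is also unused and never reached
    for i in range(48):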
        if i in unused_bits:
            continue  # Skip unused bits
            
        p1_bit = (board.player1 >> i) & 1
        p2_bit = (board.player2 >> i) & 1
        if p1_bit:
            positions[pos_idx] = 1
        elif p2_bit:
            positions[pos_idx] = -1
  
        pos_idx += 1
    
    return positions

def worker_process(args):
    """Worker function that generates a single game sample"""
    mcts_iterations, min_random_moves, max_random_moves, process_id = args
    # Create MCTS instance per process to avoid sharing issues
    mcts = MCTS(mcts_iterations)
    
    board = Bitboard()
    # Play a random number of moves, between min_random_moves and max_random_moves
    for _ in range(random.randint(min_random_moves, max_random_moves)):
        legal = board.get_legal_moves()
        if not legal or board.is_over():
            return (None, True)
        board.make_move(random.choice(legal))
    
    if board.is_over(): return (None, True)

    encoded = encode_board(board)
    move, _ = mcts.search(board.copy())
    return (encoded + [move], False)

def generate_dataset_parallel(n_games=1000, batch_size=500, mcts_iterations=10000, min_random_moves=8, max_random_moves=25):
    # Determine the number of CPU cores to use
    num_processes = mp.cpu_count()
    completed = 0
    terminated_early = 0

    # Create a pool of worker processes
    with Pool(processes=num_processes) as pool:
        # Prepare arguments for each worker
        worker_args = [(mcts_iterations, min_random_moves, max_random_moves, i) for i in range(n_games)]
        
        # Open the output file  
        with open("dataset.csv", 'a', newline='') as f:
            writer = csv.writer(f)  
            if f.tell() == 0:  # File is empty
                writer.writerow([f"pos_{i}" for i in range(42)] + ["move"])
            
            buffer = []
            buffer_size = 0
            
            # Use imap_unordered for faster results as they become available
            with tqdm(total=n_games, desc="Generating Examples") as pbar:
                for result in pool.imap_unordered(worker_process, worker_args):
                    if result[1]: terminated_early += 1
                    elif result[0] is not None:

                        completed += 1

                        buffer.append(result[0])
                        buffer_size += 1
                        # Write in batches
                        if buffer_size >= batch_size:
                            writer.writerows(buffer)
                            buffer.clear()
                            buffer_size = 0


                    pbar.update(1)

                # Write remaining samples
                if buffer:
                    writer.writerows(buffer)

    print(f"Completed games: {completed} ({completed/n_games:.1%})")
    print(f"Terminated Early: {terminated_early} ({terminated_early/n_games:.1%})")

if __name__ == '__main__':
    # Required for Windows multiprocessing support
    generate_dataset_parallel(n_games=180000, batch_size=1000, min_random_moves=10, max_random_moves=30)


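Each worker process builds one training example:

- It starts from an empty board and plays a random number of random moves (between `min_random_moves` and `max_random_moves`), so that the dataset covers many different positions.
- If the game ends during these random moves, the sample is discarded and counted as terminated early.
- Otherwise, the position is encoded by `encode_board` as 42 values (1 for player 1, -1 for player 2, 0 for an empty cell) and labelled with the move chosen by MCTS (10,000 iterations by default).

To make the generation more efficient:

- The games are independent of each other, so they are distributed over one worker process per CPU core with `multiprocessing.Pool`.
- `imap_unordered` yields each result as soon as it is ready instead of waiting for results in submission order.
- Rows are buffered and written to `dataset.csv` in batches of `batch_size` rather than one row at a time.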
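
The board encoding used for each sample can be illustrated with a minimal standalone sketch. It assumes the 7-bits-per-column layout from section 2.1 and takes the two bitboards as plain integers; the function name `encode` is hypothetical and the loop is written per cell rather than per bit, so it is not the project's `encode_board`.

```python
ROWS, COLS, STRIDE = 6, 7, 7  # 7 bits per column; the top bit of each column is unused

def encode(player1, player2):
    """Flatten two bitboards into 42 values: 1 (player 1), -1 (player 2), 0 (empty)."""
    features = []
    for col in range(COLS):
        for row in range(ROWS):  # rows 0-5 only, so the unused bit is skipped
            bit = 1 << (col * STRIDE + row)
            if player1 & bit:
                features.append(1)
            elif player2 & bit:
                features.append(-1)
            else:
                features.append(0)
    return features

# Player 1 holds the bottom cell of column 0 (bit 0), player 2 the bottom of column 1 (bit 7).
vector = encode(player1=1 << 0, player2=1 << 7)
print(len(vector), vector[0], vector[6])  # 42 1 -1
```

Each dataset row is such a 42-value vector followed by the column that MCTS selected for that position.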

#### **3.2.3. ID3 Libraries** ####

In [None]:
import numpy as np

__NumPy__ is the only external library used in this implementation. It provides essential functionality for numerical operations, including:

- Efficient array manipulation and mathematical functions
- Statistical tools for counting and probability calculations
- Mathematical operations like logarithms used in entropy calculations


#### **3.2.4. Auxiliary Functions** ####

In [None]:
def entropy(y):
    counts = np.bincount(y)
    probs = counts / len(y)
    return -np.sum([p * np.log2(p) for p in probs if p > 0])

def gini(y):
    counts = np.bincount(y)
    ps = counts / len(y)
    return 1 - np.sum(ps ** 2)

def information_gain(parent, left, right):
    weight_l = len(left) / len(parent)
    weight_r = len(right) / len(parent)
    return entropy(parent) - (weight_l * entropy(left) + weight_r * entropy(right))

def gini_gain(parent, left, right):
    weight_l = len(left) / len(parent)
    weight_r = len(right) / len(parent)
    return gini(parent) - (weight_l * gini(left) + weight_r * gini(right))  
 

The auxiliary functions above compute:

- `entropy(y)`  
    - Calculates Shannon entropy, a measure of impurity in a set of labels
    - Uses np.bincount() to count occurrences of each class
    - Computes probability distribution by dividing counts by total number of samples
    - Returns the negative sum of p * log2(p) for each probability p
    - Higher entropy indicates more mixed classes (more impurity)


- `gini(y)`  
    - Calculates Gini impurity, an alternative measure of node impurity
    - Like entropy, starts by counting class occurrences and calculating probabilities
    - Returns 1 minus the sum of squared probabilities
    - A value of 0 indicates a pure node (all samples belong to the same class)
    - A higher value indicates more impurity


- `information_gain(parent, left, right)`
    - Measures the reduction in entropy achieved by splitting a parent node
    - Calculates weighted average entropy of child nodes (left and right)
    - Subtracts this from the parent entropy to determine information gain
    - Higher values indicate more informative splits


- `gini_gain(parent, left, right)`
    - Similar to information gain but uses Gini impurity instead of entropy
    - Measures the reduction in Gini impurity achieved by a split
    - Higher values indicate more effective splits
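
A small worked example may help to check these measures by hand. The sketch below is a pure-Python variant written only for illustration (it uses `collections.Counter` instead of NumPy), applied to a perfectly mixed two-class set.

```python
from collections import Counter
from math import log2

def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def gini(labels):
    n = len(labels)
    return 1 - sum((c / n) ** 2 for c in Counter(labels).values())

def information_gain(parent, left, right):
    w_l, w_r = len(left) / len(parent), len(right) / len(parent)
    return entropy(parent) - (w_l * entropy(left) + w_r * entropy(right))

parent = [0, 0, 1, 1]  # two classes, perfectly mixed
print(entropy(parent))                           # 1.0
print(gini(parent))                              # 0.5
print(information_gain(parent, [0, 0], [1, 1]))  # 1.0
```

Splitting the set into two pure halves removes all the impurity, so the information gain equals the parent's entropy.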



#### **3.2.5. ID3 Class Node** ####

In [None]:
class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature   
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value
        

The Node class represents elements in the decision tree structure:
- __Internal Nodes (decision nodes)__
    - `feature`: Index of the feature to test
    - `threshold`: Value to compare against
    - `left`: Child node for samples where feature ≤ threshold
    - `right`: Child node for samples where feature > threshold
    - `value`: None
    
- __Leaf nodes (prediction nodes)__
    - `feature`, `threshold`, `left`, `right`: None
    - `value`: The predicted class

This flexible design allows the same class to represent both decision points and final predictions.
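
How the two kinds of node work together can be seen on a hand-built tree. In this minimal sketch, `TreeNode` is a hypothetical dataclass that mirrors the fields listed above, and `predict_one` is an illustrative helper, not a method of the project.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TreeNode:
    """A leaf sets only `value`; a decision node sets the other four fields."""
    feature: Optional[int] = None
    threshold: Optional[float] = None
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None
    value: Optional[int] = None

def predict_one(node, sample):
    """Follow the <= / > comparisons from the root until a leaf is reached."""
    while node.value is None:
        node = node.left if sample[node.feature] <= node.threshold else node.right
    return node.value

# Hand-built stump: if feature 0 <= 0 predict column 3, otherwise column 4.
stump = TreeNode(feature=0, threshold=0, left=TreeNode(value=3), right=TreeNode(value=4))
print(predict_one(stump, [-1, 0, 1]))  # 3
print(predict_one(stump, [1, 0, 1]))   # 4
```

Because leaves are detected with `value is None`, a leaf that predicts class 0 is still treated as a leaf.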


#### **3.2.6. Class ID3DecisionTree** ####

In [None]:
class ID3DecisionTree:

    def __init__(self, max_depth=3, criterion='gini'):
        self.criterion = gini_gain if criterion == 'gini' else information_gain
        self.max_depth = max_depth
        self.root = None


    def fit(self, X, y):
        print(self.criterion.__name__)
        self.root = self.grow(X, y)
    
    def grow(self, X, y, depth=0):
        if len(set(y)) == 1 or depth >= self.max_depth: 
            return Node(value=np.bincount(y).argmax())  

        best_gain = -1
        best_feature, best_threshold = None, None
        best_left_mask, best_right_mask = None, None

        for feature in range(X.shape[1]):
            thresholds = np.unique(X[:, feature])
            if len(thresholds) > 10:
                thresholds = np.quantile(X[:, feature], [0.25, 0.5, 0.75])
            for t in thresholds:
                left_mask = X[:, feature] <= t   
                right_mask = ~left_mask

                if left_mask.sum() == 0 or right_mask.sum() == 0:
                    continue

                gain = self.criterion(y, y[left_mask], y[right_mask])

                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = t
                    best_left_mask = left_mask
                    best_right_mask = right_mask

        if best_gain < 1e-6 or best_feature is None:  
            return Node(value=np.bincount(y).argmax())

        left = self.grow(X[best_left_mask], y[best_left_mask], depth + 1)
        right = self.grow(X[best_right_mask], y[best_right_mask], depth + 1)
        return Node(feature=best_feature, threshold=best_threshold, left=left, right=right)

    def predict(self, X):
        return np.array([self._predict(inputs, self.root) for inputs in X])

    def _predict(self, inputs, node):
        if node.value is not None:
            return node.value
        if inputs[node.feature] <= node.threshold:
            return self._predict(inputs, node.left)
        else:
            return self._predict(inputs, node.right)
        

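The `ID3DecisionTree` class builds a binary decision tree:

- `__init__(self, max_depth=3, criterion='gini')` : Stores the maximum depth and selects the split criterion, `gini_gain` when `criterion` is `'gini'` and `information_gain` otherwise.

- `fit(self, X, y)` : Prints the name of the criterion in use and builds the tree by calling `grow` on the whole training set.

- `grow(self, X, y, depth=0)` : Builds the tree recursively
    - Returns a leaf with the majority class when all labels are equal or `max_depth` is reached
    - For every feature, tries each distinct value as a threshold (or the three quartiles when the feature has more than 10 distinct values) and keeps the split with the highest gain
    - Returns a leaf if the best gain is below 1e-6; otherwise recurses on the samples with feature ≤ threshold (left) and feature > threshold (right)

- `predict(self, X)` and `_predict(self, inputs, node)` : Route each sample from the root to a leaf and return the class stored in that leaf.

Note that classical ID3 creates one branch per attribute value and selects attributes by information gain. This implementation uses binary threshold splits instead, and information gain is used only when `criterion` is set to a value other than `'gini'`.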
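
The search for the best split inside `grow` can be sketched on a tiny dataset. This is a simplified pure-Python illustration under stated assumptions: `best_split` is a hypothetical helper, it always uses information gain, and it tries every distinct feature value as a threshold without the quantile shortcut.

```python
from collections import Counter
from math import log2

def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def best_split(X, y):
    """Return (feature, threshold, gain) of the binary split with the highest information gain."""
    best = (None, None, 0.0)
    for feature in range(len(X[0])):
        for t in sorted({row[feature] for row in X}):
            left = [label for row, label in zip(X, y) if row[feature] <= t]
            right = [label for row, label in zip(X, y) if row[feature] > t]
            if not left or not right:
                continue  # a split must send samples to both sides
            weighted = (len(left) * entropy(left) + len(right) * entropy(right)) / len(y)
            gain = entropy(y) - weighted
            if gain > best[2]:
                best = (feature, t, gain)
    return best

# Feature 1 separates the two classes perfectly; feature 0 carries no information.
X = [[0, -1], [1, -1], [0, 1], [1, 1]]
y = [3, 3, 4, 4]
print(best_split(X, y))  # (1, -1, 1.0)
```

The tree would therefore test feature 1 at the root and send each class to its own leaf.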

## **4. Connect Four Algorithms Implementation** ##

### **4.1. Libraries** ###


In [None]:
from os import environ
# Must be set before pygame is imported, otherwise the support prompt is still printed
environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1'

import time
import pygame
from pygame import gfxdraw
import game
import mcts

### **4.2. TBD** ###


## **5. User Interface Game** ##


### **5.1. Human vs Human** ###


### **5.2. Human vs MCTS** ###


### **5.3. Human vs ID3** ###


### **5.4. MCTS vs ID3** ###


## **6. Results** ##


## **7. Conclusion** ##
