# Monte Carlo Tree Search for ConnectX

### Helper functions for ConnectX

In [6]:
# Get new board with given piece dropped
def drop_piece(grid, col, mark, config):
    next_grid = grid.copy()
    for row in range(config.rows-1, -1, -1):
        if next_grid[row][col] == 0:
            break
    next_grid[row][col] = mark
    return next_grid

def legal_moves(grid, config):
    return [col for col in range(config.columns) if grid[0, col] == 0]

In [8]:
def is_terminal_window(window, config):
    return window.count(1) == config.inarow or window.count(2) == config.inarow

def is_terminal_grid(grid, config):
    return check_winner(grid, config) != 0

def check_window_winner(window, config):
    if window.count(1) == config.inarow:
        return 1
    if window.count(2) == config.inarow:
        return 2
    else:
        raise Exception("Winner not found")
    
def check_winner(grid, config):
    # Check for win: horizontal, vertical, or diagonal
    # horizontal 
    for row in range(config.rows):
        for col in range(config.columns-(config.inarow-1)):
            window = list(grid[row, col:col+config.inarow])
            if is_terminal_window(window, config):
                return check_window_winner(window, config)
    # vertical
    for row in range(config.rows-(config.inarow-1)):
        for col in range(config.columns):
            window = list(grid[row:row+config.inarow, col])
            if is_terminal_window(window, config):
                return check_window_winner(window, config)
    # positive diagonal
    for row in range(config.rows-(config.inarow-1)):
        for col in range(config.columns-(config.inarow-1)):
            window = list(grid[range(row, row+config.inarow), range(col, col+config.inarow)])
            if is_terminal_window(window, config):
                return check_window_winner(window, config)
    # negative diagonal
    for row in range(config.inarow-1, config.rows):
        for col in range(config.columns-(config.inarow-1)):
            window = list(grid[range(row, row-config.inarow, -1), range(col, col+config.inarow)])
            if is_terminal_window(window, config):
                return check_window_winner(window, config)
    # Check for draw 
    if list(grid[0, :]).count(0) == 0:
        return 3
    
    return 0

### Monte Carlo Tree Search

In [25]:
import math

class Node:

    UCT_CONSTANT = math.sqrt(2)

    def __init__(self, parent, grid, config, to_move):
        self.parent = parent
        self.grid = grid
        self.config = config
        self.to_move = to_move
        self.n_wins = 0
        self.n_rollouts = 0
        self.children = []

    def is_leaf(self):
        return len(self.children) == 0
    
    def has_parent(self):
        return self.parent is not None
    
    def is_terminal(self):
        return is_terminal_grid(self.grid, self.config)
    
    def is_fully_expanded(self):
        if len(self.children) == 0 :
            return False
        for child in self.children:
            if child.n_rollouts == 0:
                return False
        return True
    
    def uct_value(self):
        if self.n_rollouts == 0:
            raise Exception("Cannot compute UCT value of node with no rollouts")
        if not self.has_parent():
            raise Exception("Cannot compute UCT value of node with no parent")
        exploit_val = self.n_wins / self.n_rollouts
        explore_val = math.log(self.parent.n_rollouts / self.n_rollouts)
        return exploit_val + Node.UCT_CONSTANT * explore_val
    
    def update_stats(self, result):
        self.n_rollouts += 1
        if result == self.to_move:
            self.n_wins += 1
        if result == 3:    # draw
            self.n_wins += .5

In [5]:
import random
import numpy as np

# main function for the Monte Carlo Tree Search
def monte_carlo_tree_search(root, config):
	
	while resources_left(time, computational power):
		leaf = traverse(root) 
		simulation_result = rollout(leaf)
		backpropagate(leaf, simulation_result)
		
	return best_child(root)

# function for node traversal
def traverse(node):
	while node.is_fully_expanded():
		node = best_uct(node)
		
	# in case no children are present / node is terminal 
	return pick_unvisited(node.children) or node

# function for the result of the simulation
def rollout(node):
	while not node.is_terminal():
		node = rollout_policy(node)
	return check_winner(node)

# function for randomly selecting a child node
def rollout_policy(node):
	return random.choice(node.children)

# function for backpropagation
def backpropagate(node, result):
	if not node.has_parent():
		return
	node.update_stats(result)
	backpropagate(node.parent)

def best_uct(node):
	best_node = None
	best_value = 0
	for child in node.children:
		current_value = child.uct_value()
		if current_value >= best_value:
			best_value = current_value
			best_node = child
	return best_node

# best child based on num_rollouts
def best_child(node):
	best_node = None
	most_rollouts = 0
	for child in node.children:
		if child.n_rollouts >= most_rollouts:
			most_rollouts = child.n_rollouts
			best_node = child
	return best_node

SyntaxError: invalid syntax (3254252707.py, line 7)