<a href="https://colab.research.google.com/github/elangbijak4/LLM-SLM-Examples/blob/main/Demo_Monte_Carlo_Tree_Search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import math
import random

In [4]:
class Node:
    """A node of the Monte Carlo search tree.

    Each node wraps one game state and accumulates the statistics that
    ``best_child`` uses to rank it against its siblings.
    """

    def __init__(self, state, parent=None):
        """Create a node for *state*, optionally linked to *parent*."""
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0  # number of results recorded through update()
        self.wins = 0  # running sum of those results

    def add_child(self, child_state):
        """Create a child node for *child_state*, attach it and return it."""
        child = Node(child_state, parent=self)
        self.children.append(child)
        return child

    def update(self, result):
        """Record one more visit and add *result* to the accumulated wins."""
        self.visits += 1
        self.wins += result

    def fully_expanded(self):
        """Return True when the node has as many children as its state has moves."""
        return len(self.children) == len(self.state.get_possible_moves())

    def best_child(self, c_param=1.4):
        """Return the child with the highest score.

        A visited child scores
        ``wins / visits + c_param * sqrt(2 * ln(self.visits) / visits)``.
        Ties go to the child that was added first.

        Args:
            c_param: Weight of the exploration term; ``0.0`` ranks children
                by average result only.

        Returns:
            The highest-scoring child node.

        Raises:
            ValueError: If the node has no children.
        """
        if not self.children:
            raise ValueError("best_child() called on a node with no children")

        # Guard the logarithm so a parent without recorded visits cannot
        # trigger a math domain error.
        log_parent_visits = math.log(max(self.visits, 1))

        def score(child):
            if child.visits == 0:
                # No statistics yet: try it first when exploring, never
                # prefer it when ranking purely by average result.
                return math.inf if c_param > 0 else -math.inf
            exploitation = child.wins / child.visits
            exploration = math.sqrt(2 * log_parent_visits / child.visits)
            return exploitation + c_param * exploration

        return max(self.children, key=score)


class TicTacToe:
    """A tic-tac-toe position.

    ``board`` is a flat list of nine cells: ``0`` marks an empty cell and
    ``1`` / ``-1`` mark the two players.  ``current_player`` is the player
    whose mark the next move places.  ``play_move`` returns a new position
    and leaves the one it was called on untouched.
    """

    # Index triples that win the game when all three cells hold the same mark:
    # three rows, three columns, two diagonals.
    _WIN_LINES = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),
        (0, 3, 6), (1, 4, 7), (2, 5, 8),
        (0, 4, 8), (2, 4, 6),
    )

    def __init__(self):
        """Create the empty starting position with player ``1`` to move."""
        self.board = [0] * 9
        self.current_player = 1

    def get_possible_moves(self):
        """Return the indices of all empty cells, in ascending order."""
        return [i for i, cell in enumerate(self.board) if cell == 0]

    def play_move(self, move):
        """Return the position reached when the current player marks *move*.

        Args:
            move: Index of the cell to mark.

        Returns:
            A new ``TicTacToe`` with the cell marked and the other player to move.

        Raises:
            ValueError: If the cell is already occupied.
            IndexError: If *move* is outside the board.
        """
        if self.board[move] != 0:
            # Overwriting a mark would silently corrupt the position.
            raise ValueError(f"cell {move} is already occupied")
        new_state = TicTacToe()
        new_state.board = self.board[:]
        new_state.board[move] = self.current_player
        new_state.current_player = -self.current_player
        return new_state

    def _winner(self):
        """Return the mark that fills a winning line, or 0 if there is none."""
        for x, y, z in self._WIN_LINES:
            if self.board[x] == self.board[y] == self.board[z] != 0:
                return self.board[x]
        return 0

    def is_terminal(self):
        """Return True when a player has won or no empty cell is left."""
        return self._winner() != 0 or 0 not in self.board

    def get_result(self):
        """Return the winner's mark (``1`` or ``-1``), or ``0`` when nobody has won."""
        return self._winner()


def _untried_moves(node):
    """Return the legal moves of ``node.state`` that no child of *node* covers yet.

    A move counts as covered when playing it yields a state whose instance
    attributes equal those of an existing child's state, so trees grown by an
    earlier call (or by other code) are handled too.  This relies on states
    exposing their data through ``vars()``.
    """
    moves = list(node.state.get_possible_moves())
    if not node.children:
        return moves
    expanded = [vars(child.state) for child in node.children]
    return [m for m in moves if vars(node.state.play_move(m)) not in expanded]


def mcts(root, iterations):
    """Run Monte Carlo Tree Search below *root* and return its best child.

    Every iteration walks down the tree through ``best_child()`` until it
    reaches a node with an unexpanded move, adds one child for a randomly
    chosen unexpanded move, plays random moves from there to a terminal
    state, and records the outcome on every node of the walked path.

    The outcome of ``get_result()`` is credited from the point of view of the
    player who moved *into* each node (``-node.state.current_player``), so
    each level of the tree favours the player choosing there.  States without
    a ``current_player`` attribute receive the raw result unchanged.

    Args:
        root: Tree node to search from; the tree under it is grown in place.
        iterations: Number of search iterations to run.

    Returns:
        The child of *root* with the best average result.

    Raises:
        ValueError: If *root* has no children after the search, i.e.
            *iterations* is 0 or the root state is terminal.
    """
    # id(node) -> moves not yet expanded; valid for this call only.  Nodes
    # stay referenced by the tree, so their ids cannot be reused meanwhile.
    pending = {}

    for _ in range(iterations):
        node = root

        # Selection, ending in the expansion of exactly one new child.
        while not node.state.is_terminal():
            untried = pending.get(id(node))
            if untried is None:
                untried = pending[id(node)] = _untried_moves(node)
            if untried:
                # Expansion: never expand the same move twice.
                move = untried.pop(random.randrange(len(untried)))
                node = node.add_child(node.state.play_move(move))
                break
            node = node.best_child()

        # Simulation: always start from the state the reached node holds.
        state = node.state
        while not state.is_terminal():
            state = state.play_move(random.choice(state.get_possible_moves()))

        # Backpropagation
        result = state.get_result()
        while node is not None:
            mover = -getattr(node.state, "current_player", -1)
            node.update(result * mover)
            node = node.parent

    return root.best_child(c_param=0.0)


# Usage example: search from the empty board and print the position reached
# by the move the search rates best.
initial_state = TicTacToe()
root = Node(initial_state)
# mcts() grows the tree under `root` in place and returns one of root's children.
best_node = mcts(root, iterations=1000)

print("Best move board state:", best_node.state.board)

Best move board state: [0, 0, 0, 0, 0, 0, 0, 0, 1]
