<a href="https://colab.research.google.com/github/elangbijak4/LLM-SLM-Examples/blob/main/Demo_Rev1_Monte_Carlo_Tree_Search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import random

import torch
import torch.nn as nn
import torch.optim as optim

In [20]:
class SimpleNN(nn.Module):
    """Small MLP mapping a 9-value board vector to one score in (0, 1).

    Layers: Linear(9, 128) -> ReLU -> Linear(128, 128) -> ReLU
    -> Linear(128, 1) -> sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        """Map input of shape (..., 9) to sigmoid scores of shape (..., 1)."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))


# Fall back to CPU so the notebook also runs without a GPU; a bare `.cuda()`
# raises when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
# NOTE(review): no training step appears in the visible code, so unless the
# model is trained elsewhere, evaluate_position scores come from randomly
# initialised weights.


def evaluate_position(state):
    """Score ``state.board`` with the global ``model`` and return a float.

    Args:
        state: object with a ``board`` attribute convertible to a float32
            tensor. Presumably 9 numeric cells, matching ``SimpleNN``'s input
            width — confirm against the game class.

    Returns:
        The model's sigmoid output for the single board, as a Python float.
    """
    # Build the input on whatever device the model currently lives on, so the
    # two can never disagree (e.g. if the model is moved after this cell).
    model_device = next(model.parameters()).device
    board = torch.tensor(
        state.board, dtype=torch.float32, device=model_device
    ).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        value = model(board)
    return value.item()

In [21]:
def mcts(root, iterations):
    """Run Monte Carlo Tree Search from ``root`` and return the chosen child.

    Each iteration performs selection, expansion, a random playout whose
    terminal position is scored by ``evaluate_position``, and backpropagation
    of that score along parent links.

    Args:
        root: tree node exposing ``state``, ``fully_expanded()``,
            ``best_child()``, ``add_child(state)``, ``update(result)`` and
            ``parent``.
        iterations: number of search iterations to run.

    Returns:
        ``root.best_child(c_param=0.0)``. Presumably this is the
        exploitation-only choice (exploration weight 0); confirm in
        ``Node.best_child``.
    """
    for _ in range(iterations):
        node = root
        state = root.state

        # Selection: descend through fully expanded nodes. Each child node
        # already stores the state reached by its move, so track that state
        # directly. The previous code applied a move taken from the child's
        # legal-move list to the parent's state, so `state` drifted away from
        # `node`.
        while node.fully_expanded() and not state.is_terminal():
            node = node.best_child()
            state = node.state

        # Expansion: add one child for a random legal move.
        # NOTE(review): the move is not checked against existing children, so
        # an already-expanded move may be added again; confirm how
        # Node.add_child / fully_expanded handle this.
        if not state.is_terminal():
            move = random.choice(state.get_possible_moves())
            state = state.play_move(move)
            node = node.add_child(state)

        # Simulation: random playout to a terminal state, then score that
        # terminal position with the neural network.
        while not state.is_terminal():
            state = state.play_move(random.choice(state.get_possible_moves()))

        result = evaluate_position(state)

        # Backpropagation: walk parent links up to the root. `is not None`
        # (rather than truthiness) avoids stopping early if Node defines
        # __len__/__bool__.
        # NOTE(review): the same `result` is given to every node regardless of
        # which player moved there; confirm Node.update handles perspective.
        while node is not None:
            node.update(result)
            node = node.parent

    return root.best_child(c_param=0.0)

# Usage example: search from a fresh TicTacToe() state for 1000 iterations
# and print the board held by the node that mcts selects.
initial_state = TicTacToe()
root = Node(initial_state)

best_node = mcts(root, iterations=1000)
best_board = best_node.state.board
print("Best move board state:", best_board)

Best move board state: [0, 0, 0, 0, 0, 0, 0, 1, 0]
