In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TicTacToeNet(nn.Module):
    """Convolutional policy/value network over a 2-channel 9x9 input.

    Layout: three 3x3, padding-1 convolutions (2 -> 64 -> 128 -> 256 channels,
    each followed by ReLU), a 1024-unit ReLU hidden layer, then two heads:
    ``fc_policy`` emits 81 raw logits and ``fc_value`` a single value squashed
    into [-1, 1] by tanh.
    """

    def __init__(self):
        super().__init__()
        # The generator builds conv1, conv2, conv3 in that order, so parameter
        # initialisation and state_dict layout match a layer-by-layer definition.
        self.conv1, self.conv2, self.conv3 = (
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
            for c_in, c_out in ((2, 64), (64, 128), (128, 256))
        )
        flat_features = 256 * 9 * 9  # 256 channels x 9 x 9 cells after the convs
        self.fc1 = nn.Linear(flat_features, 1024)
        self.fc_policy = nn.Linear(1024, 81)
        self.fc_value = nn.Linear(1024, 1)

    def forward(self, x):
        """Return ``(policy_logits, value)`` with shapes ``(B, 81)`` and ``(B, 1)``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        hidden = F.relu(self.fc1(x.view(-1, 256 * 9 * 9)))
        value = torch.tanh(self.fc_value(hidden))
        return self.fc_policy(hidden), value


In [4]:
import numpy as np

class MCTSNode:
    """A node of the MCTS search tree.

    Attributes (all set in ``__init__``):
        state: game state at this node.
        parent: parent ``MCTSNode``, or ``None`` for the root.
        children: list of ``(prior, child_node)`` tuples, filled on expansion.
        visits: number of simulations that passed through this node.
        value: sum of backed-up values, from the perspective of the player
            who made the move leading to this node.
        prior: policy prior of the move leading here (0.0 for the root).
        action: the action that produced this node (``None`` for the root).
    """

    def __init__(self, state, parent=None, prior=0.0, action=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0
        self.prior = prior
        self.action = action

    def q_value(self):
        """Mean backed-up value, or 0.0 for a node that was never visited."""
        return self.value / self.visits if self.visits else 0.0


class MCTS:
    """AlphaZero-style PUCT tree search guided by a policy/value network.

    States must provide ``is_terminal()``, ``reward()``, ``to_tensor()``,
    ``get_legal_actions()``, ``take_action(action)`` and ``last_move``.
    ``reward()`` is treated as the result for the player to move in that
    terminal state. NOTE(review): this is inferred from the sign flip the
    original code applied; confirm against GameState.
    """

    def __init__(self, model, cpuct=1.0, num_simulations=100, action_size=81):
        self.model = model
        self.cpuct = cpuct
        # Default number of simulations per search (was an undefined global).
        self.num_simulations = num_simulations
        # Length of the vector returned by get_action_probs (matches the 81 logits).
        self.action_size = action_size

    def search(self, state, num_simulations=None):
        """Search from ``state`` and return the most-visited move (``last_move`` of that child).

        Raises:
            ValueError: if the root has no legal actions.
        """
        root = self._build_tree(state, num_simulations)
        return self._get_best_move(root)

    def get_action_probs(self, state, num_simulations=None, temperature=1.0):
        """Search from ``state`` and return root visit counts as a probability vector.

        Args:
            state: root game state.
            num_simulations: overrides ``self.num_simulations`` when given.
            temperature: counts are raised to ``1 / temperature``; 0 gives a
                one-hot vector on the most-visited action.

        Returns:
            ``np.ndarray`` of length ``self.action_size`` summing to 1.

        Raises:
            ValueError: if the root has no legal actions.
        """
        root = self._build_tree(state, num_simulations)
        if not root.children:
            raise ValueError("no legal actions from the given state")
        counts = np.zeros(self.action_size, dtype=np.float64)
        for _, child in root.children:
            counts[child.action] = child.visits
        if counts.sum() == 0:
            # Too few simulations to visit any child: fall back to the priors.
            for prior, child in root.children:
                counts[child.action] = prior
        if temperature == 0:
            probs = np.zeros_like(counts)
            probs[int(np.argmax(counts))] = 1.0
            return probs
        # Scale to max 1 before the power so small temperatures cannot overflow.
        scaled = (counts / counts.max()) ** (1.0 / temperature)
        return scaled / scaled.sum()

    def _build_tree(self, state, num_simulations=None):
        """Create a root for ``state`` and run the requested number of simulations."""
        n_sims = self.num_simulations if num_simulations is None else num_simulations
        root = MCTSNode(state)
        for _ in range(n_sims):
            self._simulate(root)
        return root

    def _simulate(self, node):
        """Run one select/expand/backup pass below ``node``.

        Returns the value of ``node`` from the perspective of the player who
        moved into it, after adding that value to ``node.value`` and counting
        the visit (leaf and terminal nodes included).
        """
        if node.state.is_terminal():
            # reward() is for the player to move at `node`; flip it for the mover.
            value = -float(node.state.reward())
        elif not node.children:
            value = -self._expand(node)
        else:
            _, best_child = max(node.children, key=lambda pc: self._puct_score(node, pc))
            # The child's value is from this node's mover-to-be; flip for our mover.
            value = -self._simulate(best_child)
        node.visits += 1
        node.value += value
        return value

    def _expand(self, node):
        """Attach one child per legal action; return the network value for ``node.state``."""
        priors, value = self._evaluate(node.state)
        actions = list(node.state.get_legal_actions())
        legal = np.array([priors[a] for a in actions], dtype=np.float64)
        total = legal.sum()
        # Renormalise over legal moves; fall back to uniform if they all got 0 mass.
        if total > 0:
            legal = legal / total
        else:
            legal = np.full(len(actions), 1.0 / max(len(actions), 1))
        for action, prior in zip(actions, legal):
            child = MCTSNode(node.state.take_action(action), parent=node,
                             prior=float(prior), action=action)
            node.children.append((float(prior), child))
        return value

    def _evaluate(self, state):
        """Return ``(priors, value)`` from the network for a single state."""
        x = state.to_tensor()
        if x.dim() == 3:
            # Assumes an unbatched (C, H, W) tensor -- add the batch dimension.
            x = x.unsqueeze(0)
        with torch.no_grad():
            logits, value = self.model(x)
        # Drop the batch dimension so priors can be indexed by action.
        priors = F.softmax(logits, dim=-1).reshape(-1).cpu().numpy()
        return priors, float(value.reshape(-1)[0])

    def _puct_score(self, parent, prior_and_child):
        """PUCT: Q(child) + cpuct * P(child) * sqrt(N(parent)) / (1 + N(child))."""
        prior, child = prior_and_child
        exploration = self.cpuct * prior * np.sqrt(parent.visits) / (1 + child.visits)
        return child.q_value() + exploration

    def _get_best_move(self, root):
        """Return ``last_move`` of the root child with the most visits."""
        if not root.children:
            raise ValueError("no legal actions from the given state")
        _, best_child = max(root.children, key=lambda pc: pc[1].visits)
        return best_child.state.last_move


In [5]:
def self_play(model, num_games=100, rng=None, action_size=81,
              game_factory=None, mcts_factory=None):
    """Generate training examples by letting ``model`` play against itself.

    Args:
        model: policy/value network handed to the MCTS factory.
        num_games: number of complete games to play.
        rng: object with a numpy-style ``choice(n, p=...)`` used to sample
            moves; defaults to the global ``np.random`` state.
        action_size: length of the policy target vector (81 matches the
            network's policy head).
        game_factory: zero-argument callable returning a new game state;
            defaults to ``GameState``.
        mcts_factory: callable ``model -> searcher``; defaults to ``MCTS``.

    Returns:
        List of ``(state_tensor, policy_target, outcome)`` tuples, one per
        move played. ``state_tensor`` is the position *before* the move,
        ``policy_target`` is a float32 vector of length ``action_size``, and
        ``outcome`` is the final result from the perspective of the player to
        move in that position.
    """
    rng = np.random if rng is None else rng
    # Resolved lazily so this cell can be defined before GameState / MCTS exist.
    make_game = GameState if game_factory is None else game_factory
    make_mcts = MCTS if mcts_factory is None else mcts_factory
    training_data = []
    for _ in range(num_games):
        state = make_game()
        mcts = make_mcts(model)
        history = []  # (state tensor, policy target) in move order
        while not state.is_terminal():
            move, policy = _choose_move(mcts, state, rng, action_size)
            history.append((state.to_tensor(), policy))
            state = state.take_action(move)
        training_data.extend(_label_outcomes(history, float(state.reward())))
    return training_data


def _choose_move(mcts, state, rng, action_size):
    """Pick a move with ``mcts.search`` and build its policy target.

    ``search`` may return either a single move (as ``MCTS.search`` does) or a
    probability vector over actions; a single move becomes a one-hot target.
    """
    result = np.asarray(mcts.search(state), dtype=np.float64)
    if result.ndim == 0:
        move = int(result)
        policy = np.zeros(action_size, dtype=np.float32)
        policy[move] = 1.0
        return move, policy
    probs = result / result.sum()
    move = int(rng.choice(len(probs), p=probs))
    return move, probs.astype(np.float32)


def _label_outcomes(history, final_reward):
    """Attach the game result to each recorded position, flipping sign every ply.

    Assumes players alternate every move and that ``final_reward`` is from the
    perspective of the player to move in the terminal state. NOTE(review):
    this is inferred from MCTS's ``-reward()`` sign flip; confirm against
    GameState.
    """
    num_moves = len(history)
    return [
        (tensor, policy, final_reward if (num_moves - ply) % 2 == 0 else -final_reward)
        for ply, (tensor, policy) in enumerate(history)
    ]


In [None]:
def train_model(model, training_data, epochs=10, batch_size=32, lr=0.001, shuffle=True):
    """Fit ``model`` to self-play data with a joint policy + value loss.

    Args:
        model: network returning ``(policy_logits, value)`` for a batch of states.
        training_data: sequence of ``(state_tensor, policy_target, outcome)``
            tuples; ``policy_target`` is a probability vector over actions
            (used as a soft cross-entropy target) and ``outcome`` a scalar.
        epochs: passes over ``training_data``.
        batch_size: examples per optimizer step.
        lr: Adam learning rate.
        shuffle: reshuffle example order each epoch; consecutive self-play
            examples come from the same game.

    Returns:
        List with the mean training loss of each epoch (empty if there is no data).
    """
    if len(training_data) == 0:
        return []
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    num_examples = len(training_data)
    epoch_losses = []
    for _ in range(epochs):
        order = torch.randperm(num_examples).tolist() if shuffle else range(num_examples)
        running_loss = 0.0
        for start in range(0, num_examples, batch_size):
            batch = [training_data[i] for i in order[start:start + batch_size]]
            states, policies, rewards = zip(*batch)
            states = torch.stack([torch.as_tensor(s) for s in states]).to(device)
            optimizer.zero_grad()
            pred_policies, pred_values = model(states)
            # Build targets in the prediction dtype/device: float64 numpy targets
            # may otherwise be rejected next to float32 outputs.
            target_policies = torch.as_tensor(np.stack(policies),
                                              dtype=pred_policies.dtype, device=device)
            target_values = torch.as_tensor(rewards, dtype=pred_values.dtype, device=device)
            policy_loss = F.cross_entropy(pred_policies, target_policies)
            # Flatten (B, 1) predictions so the loss is not broadcast to (B, B).
            value_loss = F.mse_loss(pred_values.reshape(-1), target_values)
            loss = policy_loss + value_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * len(batch)
        epoch_losses.append(running_loss / num_examples)
    return epoch_losses
