In [1]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
import time
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class TicTacToe:
    def __init__(self):
        # Représentation sous forme de liste de 9 cases (0: vide, 1: X, -1: O)
        self.board = [0] * 9  
        self.current_player = 1  # 1 pour X, -1 pour O

    def clone(self):
        env = TicTacToe()
        env.board = self.board[:]
        env.current_player = self.current_player
        return env

    def get_legal_moves(self):
        return [i for i, cell in enumerate(self.board) if cell == 0]

    def play_move(self, move):
        self.board[move] = self.current_player
        self.current_player *= -1  # change de joueur

    def is_terminal(self):
        return self.check_winner() is not None or not any(cell == 0 for cell in self.board)

    def check_winner(self):
        wins = [
            (0,1,2), (3,4,5), (6,7,8),
            (0,3,6), (1,4,7), (2,5,8),
            (0,4,8), (2,4,6)
        ]
        for (i,j,k) in wins:
            if self.board[i] != 0 and self.board[i] == self.board[j] == self.board[k]:
                return self.board[i]
        return None

    def get_result(self):
        winner = self.check_winner()
        if winner is not None:
            return winner
        return 0

    def to_tensor(self):
        return torch.tensor(self.board, dtype=torch.float32)

    def __str__(self):
        symbols = {1: 'X', -1: 'O', 0: ' '}
        lines = ["|".join([symbols[self.board[i*3+j]] for j in range(3)]) for i in range(3)]
        return "\n-----\n".join(lines)

In [3]:
class TicTacToeNet(nn.Module):
    def __init__(self):
        super(TicTacToeNet, self).__init__()
        self.fc1 = nn.Linear(9, 64)
        self.fc2 = nn.Linear(64, 64)
        self.policy_head = nn.Linear(64, 9)
        self.value_head = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        policy_logits = self.policy_head(x)
        policy = torch.softmax(policy_logits, dim=1)
        value = torch.tanh(self.value_head(x))
        return policy, value

In [4]:
class MCTSNode:
    def __init__(self, env):
        self.env = env.clone()
        self.children = {}  # coup -> MCTSNode
        self.N = {}         # nombre de visites N(s,a)
        self.W = {}         # somme des récompenses W(s,a)
        self.P = {}         # probabilité a priori P(s,a)

    def is_leaf(self):
        return len(self.children) == 0 and not self.env.is_terminal()

class MCTS:
    def __init__(self, net, c_puct=1.0):
        self.net = net
        self.c_puct = c_puct
        self.nodes = {}  # clé: string(board)+current_player -> MCTSNode

    def _state_key(self, env):
        return ''.join(map(str, env.board)) + str(env.current_player)

    def get_node(self, env):
        key = self._state_key(env)
        if key not in self.nodes:
            self.nodes[key] = MCTSNode(env)
        return self.nodes[key]

    def policy(self, env):
        x = env.to_tensor().unsqueeze(0)
        with torch.no_grad():
            p, v = self.net(x)
        p = p.squeeze(0).numpy()
        moves = env.get_legal_moves()
        policy = {m: p[m] for m in moves}
        return policy, v.item()

    def puct_select_move(self, node):
        best_move = None
        best_value = -float('inf')
        total_N = sum(node.N.get(m, 0) for m in node.env.get_legal_moves())
        for m in node.env.get_legal_moves():
            Nsa = node.N.get(m, 0)
            Wsa = node.W.get(m, 0)
            Qsa = Wsa / (Nsa + 1e-8)
            prior = node.P.get(m, 0)
            U = self.c_puct * prior * math.sqrt(total_N + 1e-8) / (1 + Nsa)
            value = Qsa + U
            if value > best_value:
                best_value = value
                best_move = m
        return best_move

    def expand(self, node):
        policy, value = self.policy(node.env)
        for move, prob in policy.items():
            node.P[move] = prob
        for move in node.env.get_legal_moves():
            env_copy = node.env.clone()
            env_copy.play_move(move)
            child = self.get_node(env_copy)
            node.children[move] = child
        return value

    def simulate(self, env):
        if env.is_terminal():
            return env.get_result()
        while not env.is_terminal():
            moves = env.get_legal_moves()
            move = random.choice(moves)
            env.play_move(move)
        return env.get_result()

    def backpropagate(self, path, value):
        for node, move, player in reversed(path):
            val = value if player == 1 else -value
            node.N[move] = node.N.get(move, 0) + 1
            node.W[move] = node.W.get(move, 0) + val

    def search(self, env, n_simulations=100):
        root = self.get_node(env)
        legal_moves = env.get_legal_moves()
        dir_noise = np.random.dirichlet([0.3] * len(legal_moves))
        for i, m in enumerate(legal_moves):
            root.P[m] = 0.75 * root.P.get(m, 1/len(legal_moves)) + 0.25 * dir_noise[i]

        for _ in range(n_simulations):
            node = root
            env_clone = env.clone()
            path = []
            while not env_clone.is_terminal() and not node.is_leaf():
                move = self.puct_select_move(node)
                path.append((node, move, env_clone.current_player))
                env_clone.play_move(move)
                if move not in node.children:
                    break
                node = node.children[move]
            if not env_clone.is_terminal() and node.is_leaf():
                value = self.expand(node)
            else:
                value = env_clone.get_result()
            self.backpropagate(path, value)
        visits = {m: root.N.get(m, 0) for m in legal_moves}
        return visits


In [5]:
def self_play_episode_with_visualization(net, n_mcts_simulations=100, delay=0.5):
    game = TicTacToe()
    states = []
    mcts_probs = []
    current_players = []
    mcts = MCTS(net, c_puct=1.0)

    print("Début de la partie :")
    print(game)
    print("\n========================\n")
    time.sleep(delay)
    
    move_counter = 1
    while not game.is_terminal():
        visits = mcts.search(game, n_simulations=n_mcts_simulations)
        total = sum(visits.values())
        probs = {m: v/total for m, v in visits.items()}
        states.append(deepcopy(game.board))
        mcts_probs.append(probs)
        current_players.append(game.current_player)
        best_move = max(probs, key=probs.get)
        game.play_move(best_move)
        
        # Afficher le plateau après chaque coup
        print(f"Coup {move_counter} joué par {'X' if game.current_player == -1 else 'O'} :")
        print(game)
        print("\n------------------------\n")
        move_counter += 1
        time.sleep(delay)
    
    result = game.get_result()
    print("Partie terminée. Résultat :", result)
    return states, mcts_probs, current_players, result

# Exécution d'un épisode avec visualisation
states, mcts_probs, players, result = self_play_episode_with_visualization(TicTacToeNet(), n_mcts_simulations=100, delay=0.5)


Début de la partie :
 | | 
-----
 | | 
-----
 | | 


Coup 1 joué par X :
 | | 
-----
 |X| 
-----
 | | 

------------------------

Coup 2 joué par O :
 | | 
-----
 |X|O
-----
 | | 

------------------------

Coup 3 joué par X :
 | | 
-----
 |X|O
-----
 |X| 

------------------------

Coup 4 joué par O :
 |O| 
-----
 |X|O
-----
 |X| 

------------------------

Coup 5 joué par X :
 |O| 
-----
 |X|O
-----
 |X|X

------------------------

Coup 6 joué par O :
 |O| 
-----
 |X|O
-----
O|X|X

------------------------

Coup 7 joué par X :
X|O| 
-----
 |X|O
-----
O|X|X

------------------------

Partie terminée. Résultat : 1


In [6]:
def train_net(net, optimizer, states, mcts_probs, results, epochs=1):
    net.train()
    loss_fn_policy = nn.CrossEntropyLoss()
    loss_fn_value = nn.MSELoss()
    for epoch in range(epochs):
        total_loss = 0
        for board, mcts_prob, result in zip(states, mcts_probs, results):
            board_tensor = torch.tensor(board, dtype=torch.float32).unsqueeze(0)
            target_move = torch.tensor([max(mcts_prob, key=mcts_prob.get)], dtype=torch.long)
            target_value = torch.tensor([[result]], dtype=torch.float32)
            optimizer.zero_grad()
            policy, value = net(board_tensor)
            loss_policy = loss_fn_policy(policy, target_move)
            loss_value = loss_fn_value(value, target_value)
            loss = loss_policy + loss_value
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(states):.4f}")

# Exemple d'entraînement
net = TicTacToeNet()
optimizer = optim.Adam(net.parameters(), lr=0.01)
episodes = 50
all_states, all_probs, all_results = [], [], []
for i in range(episodes):
    states, mcts_probs, players, result = self_play_episode_with_visualization(net, n_mcts_simulations=100, delay=0.2)
    adjusted_results = [result if p == 1 else -result for p in players]
    all_states.extend(states)
    all_probs.extend(mcts_probs)
    all_results.extend(adjusted_results)
    print(f"Episode {i+1} terminée, résultat: {result}")

train_net(net, optimizer, all_states, all_probs, all_results, epochs=5)

Début de la partie :
 | | 
-----
 | | 
-----
 | | 


Coup 1 joué par X :
 | | 
-----
X| | 
-----
 | | 

------------------------

Coup 2 joué par O :
 | | 
-----
X| | 
-----
 |O| 

------------------------

Coup 3 joué par X :
 | | 
-----
X|X| 
-----
 |O| 

------------------------

Coup 4 joué par O :
 | | 
-----
X|X|O
-----
 |O| 

------------------------

Coup 5 joué par X :
 | | 
-----
X|X|O
-----
X|O| 

------------------------

Coup 6 joué par O :
O| | 
-----
X|X|O
-----
X|O| 

------------------------

Coup 7 joué par X :
O| |X
-----
X|X|O
-----
X|O| 

------------------------

Partie terminée. Résultat : 1
Episode 1 terminée, résultat: 1
Début de la partie :
 | | 
-----
 | | 
-----
 | | 


Coup 1 joué par X :
 | | 
-----
X| | 
-----
 | | 

------------------------

Coup 2 joué par O :
 | | 
-----
X|O| 
-----
 | | 

------------------------

Coup 3 joué par X :
 | | 
-----
X|O| 
-----
X| | 

------------------------

Coup 4 joué par O :
O| | 
-----
X|O| 
-----
X| | 

-----------

In [7]:
# Exemple d'utilisation : jouer un coup dans un état initial avec visualisation
game = TicTacToe()
print("État initial:")
print(game)
print("\n========================\n")
mcts = MCTS(net, c_puct=1.0)
visits = mcts.search(game, n_simulations=1000)
print("Distribution des visites:", visits)
best_move = max(visits, key=visits.get)
print("Meilleur coup sélectionné par MCTS:", best_move)
game.play_move(best_move)
print("\nPlateau après le coup sélectionné:")
print(game)

État initial:
 | | 
-----
 | | 
-----
 | | 


Distribution des visites: {0: 0, 1: 0, 2: 0, 3: 998, 4: 0, 5: 0, 6: 1, 7: 0, 8: 0}
Meilleur coup sélectionné par MCTS: 3

Plateau après le coup sélectionné:
 | | 
-----
X| | 
-----
 | | 
