In [None]:
from cathedral_rl import cathedral_v0  
import random


import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from collections import deque
import matplotlib.pyplot as plt
from tqdm import tqdm 
import math
import copy
from cathedral_rl.game.board import Board 

from MCTS.MCTS_parallel import MCTSNode, initial_state, is_terminal, mcts_search, get_policy_from_mcts, get_legal_moves, next_state, evaluate_terminal

In [None]:
def evaluate_model(network, device, n_simulations, c_puct, temperature, n_actions, board_size, n_eval_games=10):
    """Evaluate ``network`` with MCTS against a random opponent and in self-play.

    Two series of ``n_eval_games`` games are played, in this order:

    1. MCTS-guided moves for the first mover, uniformly random legal moves
       for the second mover.
    2. MCTS-guided moves for both sides (self-play).

    After each series a summary is printed. A game counts as a win when
    ``evaluate_terminal(final_state) > 0``, a draw when it equals 0, and a
    loss otherwise.
    NOTE(review): this assumes ``evaluate_terminal`` scores the final state
    from the first mover's point of view -- confirm against MCTS_parallel.

    Args:
        network: Model forwarded unchanged to ``mcts_search``.
        device: Device forwarded unchanged to ``mcts_search``.
        n_simulations: Simulation budget forwarded to ``mcts_search``.
        c_puct: Exploration constant forwarded to ``mcts_search``.
        temperature: Temperature forwarded to ``get_policy_from_mcts``.
        n_actions: Size of the action space; actions are sampled from
            ``range(n_actions)``.
        board_size: Forwarded to ``initial_state``.
        n_eval_games: Games per series. When it is 0 (or negative) no game
            is played and the win rates are reported as 0.0 instead of
            raising ``ZeroDivisionError``.

    Returns:
        dict: ``{"random": {...}, "self_play": {...}}`` where each inner dict
        has the keys ``"wins"``, ``"draws"``, ``"losses"`` and ``"win_rate"``
        (``wins / n_eval_games``).
    """
    # -----------------------------
    # Random opponent
    # -----------------------------
    wins_random, draws_random, losses_random, win_rate_random = _play_eval_games(
        network, device, n_simulations, c_puct, temperature, n_actions,
        board_size, n_eval_games, random_opponent=True,
    )
    print(f"VS random:  Victoires: {wins_random}, Nuls: {draws_random}, Défaites: {losses_random}, Taux de victoire: {win_rate_random * 100:.2f}%")

    # -----------------------------
    # Self-play
    # -----------------------------
    wins_self, draws_self, losses_self, win_rate_self = _play_eval_games(
        network, device, n_simulations, c_puct, temperature, n_actions,
        board_size, n_eval_games, random_opponent=False,
    )
    print(f"  Victory (player 1): {wins_self}, Draws: {draws_self}, Defeat: {losses_self}")
    print(f"  Winrate: {win_rate_self * 100:.2f}%")

    return {
        "random": {
            "wins": wins_random,
            "draws": draws_random,
            "losses": losses_random,
            "win_rate": win_rate_random,
        },
        "self_play": {
            "wins": wins_self,
            "draws": draws_self,
            "losses": losses_self,
            "win_rate": win_rate_self,
        },
    }


def _sample_mcts_action(state, network, device, n_simulations, c_puct, temperature, n_actions):
    """Run an MCTS search from ``state`` and sample one action from its policy."""
    root = MCTSNode(state)
    mcts_search(root, network, device, n_simulations, c_puct, n_actions)
    pi = get_policy_from_mcts(root, n_actions, temperature)
    return np.random.choice(n_actions, p=pi)


def _play_eval_games(network, device, n_simulations, c_puct, temperature, n_actions,
                     board_size, n_eval_games, random_opponent):
    """Play one evaluation series and return ``(wins, draws, losses, win_rate)``.

    When ``random_opponent`` is true the second mover plays uniformly random
    legal moves; otherwise both sides pick their moves with MCTS.
    """
    wins = draws = losses = 0

    for _ in range(n_eval_games):
        state = initial_state(board_size=board_size)
        model_to_move = True  # the model (player 1) moves first
        while not is_terminal(state):
            if model_to_move or not random_opponent:
                action = _sample_mcts_action(
                    state, network, device, n_simulations, c_puct, temperature, n_actions
                )
            else:
                action = random.choice(get_legal_moves(state))
            state = next_state(state, action)
            model_to_move = not model_to_move

        outcome = evaluate_terminal(state)
        if outcome > 0:
            wins += 1
        elif outcome == 0:
            draws += 1
        else:
            losses += 1

    # Guard the division: with no games played there is no meaningful rate,
    # so report 0.0 rather than crashing.
    win_rate = wins / n_eval_games if n_eval_games > 0 else 0.0
    return wins, draws, losses, win_rate
