Skip to content

Commit

Permalink
mcts: Print stat
Browse files Browse the repository at this point in the history
Disabled by default.

Change-Id: I80b4d37959cc17fafc4ca114fd77d3d806b1feff
  • Loading branch information
calcitem committed May 3, 2023
1 parent 22d6ad8 commit 928395b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 10 deletions.
82 changes: 72 additions & 10 deletions src/mcts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#include <algorithm>
#include <chrono>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <stack>
#include <thread>
#include <vector>

#include "mcts.h"
#include "movepick.h"
#include "option.h"
#include "position.h"
#include "search.h"
#include "types.h"
#include <algorithm>
#include <chrono>
#include <cmath>
#include <stack>
#include <vector>
#include <thread>
#include <mutex>
#include <atomic>
#include "uci.h"

Value MTDF(Position *pos, Sanmill::Stack<Position> &ss, Value firstguess,
Depth depth, Depth originDepth, Move &bestMove);
Expand All @@ -42,6 +45,7 @@ class ThreadSafeNodeVisits
public:
// Builds the per-move visit and win counters: both vectors get
// `initial_size` entries, each starting at zero.
explicit ThreadSafeNodeVisits(size_t initial_size)
: node_visits_(initial_size, 0)
, node_wins_(initial_size, 0)
{ }

void increment_visits(int move_index, uint32_t visits)
Expand All @@ -50,14 +54,27 @@ class ThreadSafeNodeVisits
node_visits_[move_index] += visits;
}

// Adds `wins` to the win counter stored for `move_index`.
// The mutex is held for the duration of the update.
void increment_wins(int move_index, uint32_t wins)
{
    const std::lock_guard<std::mutex> guard(mutex_);
    uint32_t &win_counter = node_wins_[move_index];
    win_counter += wins;
}

// Returns the visit counter stored for `move_index`.
// The value is read while holding the mutex.
uint32_t visits(int move_index)
{
    const std::lock_guard<std::mutex> guard(mutex_);
    const uint32_t visit_counter = node_visits_[move_index];
    return visit_counter;
}

// Returns the win counter stored for `move_index`.
// The value is read while holding the mutex.
uint32_t wins(int move_index)
{
    const std::lock_guard<std::mutex> guard(mutex_);
    const uint32_t win_counter = node_wins_[move_index];
    return win_counter;
}

private:
std::vector<uint32_t> node_visits_;
std::vector<uint32_t> node_wins_;
std::mutex mutex_;
};

Expand Down Expand Up @@ -165,7 +182,8 @@ Node *expand(Node *node)

MovePicker mp(*pos);
mp.next_move(); // Sort moves
//const int moveCount = std::max(mp.move_count() / SEARCH_PRUNING_FACTOR, 1);
// const int moveCount = std::max(mp.move_count() / SEARCH_PRUNING_FACTOR,
// 1);
const int moveCount = mp.move_count();

// Add child nodes for each sorted legal move
Expand Down Expand Up @@ -206,6 +224,45 @@ void backpropagate(Node *node, bool win)
}
}

#ifdef MCTS_PRINT_STAT
// Prints a table of the aggregated MCTS statistics to std::cout: one row per
// move known to `mp` (move string, win rate, wins, visits), followed by the
// chosen best move, its win score, its value and the total visit count.
//
// mp            - move picker whose moves[0 .. move_count()) are listed.
// shared_visits - aggregated per-move visit/win counters, indexed like `mp`.
// bestMove      - move reported on the "Best Move" line.
// best_value    - value reported on the "Best Move Value" line (as int).
// win_score     - score reported on the "Best Move Win Score" line.
//
// A move with zero visits is reported with a win rate of 0 instead of the
// NaN that 0/0 would produce. The formatting flags and precision of
// std::cout are restored before returning so this debug output does not
// change how later output is formatted.
void print_stats(const MovePicker &mp, ThreadSafeNodeVisits &shared_visits,
                 Move bestMove, Value best_value, double win_score)
{
    static const char *const kTableRule =
        "----------------------------------------\n";

    // Remember the caller's stream formatting; std::fixed and
    // std::setprecision are sticky and would otherwise leak out.
    const std::ios_base::fmtflags saved_flags = std::cout.flags();
    const std::streamsize saved_precision = std::cout.precision();

    uint32_t total_visits = 0;

    // Iterate through all moves and print their statistics
    std::cout << "\n";
    std::cout << std::setw(5) << "Move"
              << " " << std::setw(9) << "Win Rate"
              << " " << std::setw(6) << "Wins"
              << " " << std::setw(6) << "Visits" << '\n';
    std::cout << kTableRule;

    std::cout << std::fixed << std::setprecision(6);
    for (int i = 0; i < mp.move_count(); ++i) {
        const uint32_t visits = shared_visits.visits(i);
        const uint32_t wins = shared_visits.wins(i);
        total_visits += visits;

        // Unvisited moves have no meaningful ratio; avoid 0/0 (NaN).
        const double win_rate =
            visits == 0 ? 0.0 : static_cast<double>(wins) / visits;

        const std::string move_str = UCI::move(mp.moves[i].move);

        std::cout << std::setw(5) << move_str << " " << std::setw(9)
                  << win_rate << " " << std::setw(6) << wins << " "
                  << std::setw(6) << visits << '\n';
    }
    std::cout << kTableRule;

    std::cout << "Best Move: " << UCI::move(bestMove) << '\n';
    std::cout << "Best Move Win Score: " << win_score << "\n";
    std::cout << "Best Move Value: " << static_cast<int>(best_value) << '\n';

    std::cout << "-----------------------------\n";
    std::cout << "Total visits: " << total_visits << '\n';
    std::cout << "\n";

    std::cout.flags(saved_flags);
    std::cout.precision(saved_precision);
}
#endif // MCTS_PRINT_STAT

void mcts_worker(Position *pos, int max_iterations,
ThreadSafeNodeVisits &shared_visits)
{
Expand Down Expand Up @@ -242,6 +299,7 @@ void mcts_worker(Position *pos, int max_iterations,

for (Node *child : root->children) {
shared_visits.increment_visits(child->move_index, child->num_visits);
shared_visits.increment_wins(child->move_index, child->num_wins);
}

delete_tree(root);
Expand Down Expand Up @@ -285,7 +343,7 @@ Value monte_carlo_tree_search(Position *pos, Move &bestMove)
uint32_t max_visits = 0;

for (int i = 0; i < mp.move_count(); ++i) {
uint32_t visits = shared_visits.visits(i);
uint32_t visits = shared_visits.visits(i);
if (visits > max_visits) {
max_visits = visits;
best_move_index = i;
Expand All @@ -298,5 +356,9 @@ Value monte_carlo_tree_search(Position *pos, Move &bestMove)
(max_iterations / num_threads);
Value best_value = static_cast<Value>(win_score * 2.0 - 1.0);

#ifdef MCTS_PRINT_STAT
print_stats(mp, shared_visits, bestMove, best_value, win_score);
#endif // MCTS_PRINT_STAT

return best_value;
}
2 changes: 2 additions & 0 deletions src/mcts.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "position.h"
#include "types.h"

//#define MCTS_PRINT_STAT

// The role of exploration_parameter in Monte Carlo Tree Search (MCTS) is to
// balance exploration and exploitation. During the search,
// MCTS needs to choose between nodes that have not been fully explored
Expand Down

0 comments on commit 928395b

Please sign in to comment.