Skip to content

Commit a0edaa2

Browse files
committed
check tt earlier
1 parent ae582a0 commit a0edaa2

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

src/mcts/search.cc

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,7 +1879,32 @@ void SearchWorker::ExtendNode(NodeToProcess& picked_node,
18791879
// We don't need the mutex because other threads will see that N=0 and
18801880
// N-in-flight=1 and will not touch this node.
18811881
const auto& board = history->Last().GetBoard();
1882-
picked_node.legal_moves = board.GenerateLegalMoves();
1882+
1883+
// Check the transposition table first.
1884+
picked_node.hash = history->HashLast(params_.GetCacheHistoryLength() + 1);
1885+
std::shared_ptr<LowNode> tt_low_node;
1886+
auto tt_iter = search_->tt_->find(picked_node.hash);
1887+
if (tt_iter != search_->tt_->end()) {
1888+
tt_low_node = tt_iter->second.lock();
1889+
}
1890+
1891+
// Try to get the moves from the transposition table entry and verify them.
1892+
if (tt_low_node) {
1893+
picked_node.legal_moves.reserve(tt_low_node->GetNumEdges());
1894+
const KingAttackInfo king_attack_info = board.GenerateKingAttackInfo();
1895+
for (int ct = 0; ct < tt_low_node->GetNumEdges(); ct++) {
1896+
auto move = tt_low_node->GetEdges()[ct].GetMove();
1897+
if (!board.IsLegalMove(move, king_attack_info)) {
1898+
// It was a hash collision, forget it.
1899+
tt_low_node.reset();
1900+
break;
1901+
}
1902+
picked_node.legal_moves.emplace_back(move);
1903+
}
1904+
}
1905+
if (!tt_low_node) {
1906+
picked_node.legal_moves = board.GenerateLegalMoves();
1907+
}
18831908

18841909
// Check whether it's a draw/lose by position. Importantly, we must check
18851910
// these before doing the by-rule checks below.
@@ -1953,16 +1978,10 @@ void SearchWorker::ExtendNode(NodeToProcess& picked_node,
19531978

19541979
picked_node.nn_queried = true; // Node::SetLowNode() required.
19551980

1956-
// Check the transposition table first and NN cache second before asking for
1957-
// NN evaluation.
1958-
picked_node.hash = history->HashLast(params_.GetCacheHistoryLength() + 1);
1959-
auto tt_iter = search_->tt_->find(picked_node.hash);
1960-
if (tt_iter != search_->tt_->end()) {
1961-
// assert(!tt_iter->second.expired());
1962-
picked_node.tt_low_node = tt_iter->second.lock();
1963-
}
1964-
if (picked_node.tt_low_node) {
1981+
// Check the NN cache before asking for NN evaluation.
1982+
if (tt_low_node) {
19651983
assert(!tt_iter->second.expired());
1984+
picked_node.tt_low_node = std::move(tt_low_node);
19661985
picked_node.is_tt_hit = true;
19671986
} else {
19681987
picked_node.lock = NNCacheLock(search_->cache_, picked_node.hash);

0 commit comments

Comments
 (0)