Skip to content

Commit

Permalink
Allow for NUMA memory replication for NNUE weights. Bind threads to e…
Browse files Browse the repository at this point in the history
…nsure execution on a specific NUMA node.
  • Loading branch information
Sopel97 committed May 25, 2024
1 parent 4759764 commit 91a511d
Show file tree
Hide file tree
Showing 18 changed files with 1,297 additions and 243 deletions.
73 changes: 60 additions & 13 deletions src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ constexpr auto StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -

Engine::Engine(std::string path) :
binaryDirectory(CommandLine::get_binary_directory(path)),
numaContext(NumaConfig::from_system()),
states(new std::deque<StateInfo>(1)),
networks(NN::Networks(
threads(),
networks(numaContext, NN::Networks(
NN::NetworkBig({EvalFileDefaultNameBig, "None", ""}, NN::EmbeddedNNUEType::BIG),
NN::NetworkSmall({EvalFileDefaultNameSmall, "None", ""}, NN::EmbeddedNNUEType::SMALL))) {
pos.set(StartFEN, false, &states->back());
Expand All @@ -74,7 +76,7 @@ void Engine::stop() { threads.stop = true; }
void Engine::search_clear() {
wait_for_search_finished();

tt.clear(options["Threads"]);
tt.clear(threads);
threads.clear();

// @TODO won't work with multiple instances
Expand Down Expand Up @@ -124,40 +126,68 @@ void Engine::set_position(const std::string& fen, const std::vector<std::string>

// modifiers

// Forwards {options, threads, tt, networks} and updateContext to threads.set().
// NOTE(review): this one-line definition looks like the removed side of the
// diff view; a multi-line definition of the same function follows further down
// in this file -- confirm which one is current.
void Engine::resize_threads() { threads.set({options, threads, tt, networks}, updateContext); }
// Installs the NUMA configuration selected by the option string `o` and then
// calls resize_threads().
//   "none"             -> a default-constructed NumaConfig
//   "auto" or "system" -> NumaConfig::from_system()
//   any other value    -> NumaConfig::from_string(o)
void Engine::set_numa_config_from_option(const std::string& o) {
    if (o == "none")
        numaContext.set_numa_config(NumaConfig{});
    else if (o == "system" || o == "auto")
        numaContext.set_numa_config(NumaConfig::from_system());
    else
        numaContext.set_numa_config(NumaConfig::from_string(o));

    // The threads are re-created unconditionally, since their affinities
    // may need to change with the new configuration.
    resize_threads();
}

// Reconfigures the thread pool through threads.set() and afterwards re-applies
// the "Hash" option through set_tt_size(). The three calls run in this order.
void Engine::resize_threads() {
// Wait on the pool first (judging by the call's name, until no search is running).
threads.wait_for_search_finished();
// Hand the current NUMA configuration, the engine components and the update
// context to the pool.
threads.set(numaContext.get_numa_config(), {options, threads, tt, networks}, updateContext);

// Reallocate the hash with the new threadpool size
set_tt_size(options["Hash"]);
}

// Waits for the search to finish, then resizes the transposition table to `mb`
// (presumably megabytes -- TODO confirm against tt.resize()).
// NOTE(review): the two tt.resize() calls below look like the removed and the
// added line of the diff view (one passes options["Threads"], the other the
// thread pool itself); presumably only one of them belongs in the real file.
void Engine::set_tt_size(size_t mb) {
wait_for_search_finished();
tt.resize(mb, options["Threads"]);
tt.resize(mb, threads);
}

// Stores `b` in the `ponder` member of the thread pool's main manager.
void Engine::set_ponderhit(bool b) {
    threads.main_manager()->ponder = b;
}

// network related

// Calls verify() on the big and the small network, passing the values of the
// "EvalFile" and "EvalFileSmall" options respectively.
// NOTE(review): the body shows both `networks.` and `networks->` access; these
// look like the removed and added lines of the diff view -- confirm which pair
// is current.
void Engine::verify_networks() const {
networks.big.verify(options["EvalFile"]);
networks.small.verify(options["EvalFileSmall"]);
networks->big.verify(options["EvalFile"]);
networks->small.verify(options["EvalFileSmall"]);
}

// Loads both networks from the files named by the "EvalFile" and
// "EvalFileSmall" options, then calls threads.clear().
// NOTE(review): the two load_*_network() calls and the modify_and_replicate()
// call look like the removed and added sides of the diff view -- confirm which
// form is current.
void Engine::load_networks() {
load_big_network(options["EvalFile"]);
load_small_network(options["EvalFileSmall"]);
// The lambda receives a mutable NN::Networks and loads both nets into it,
// resolving files relative to binaryDirectory.
networks.modify_and_replicate([this](NN::Networks& networks_) {
networks_.big.load(binaryDirectory, options["EvalFile"]);
networks_.small.load(binaryDirectory, options["EvalFileSmall"]);
});
threads.clear();
}

// Loads the big network from `file` (passed to load() together with
// binaryDirectory), then calls threads.clear().
// NOTE(review): the direct networks.big.load() call and the
// modify_and_replicate() call look like the removed and added sides of the
// diff view -- confirm which form is current.
void Engine::load_big_network(const std::string& file) {
networks.big.load(binaryDirectory, file);
// `file` is captured by reference; the lambda is passed straight into the call.
networks.modify_and_replicate([this, &file](NN::Networks& networks_) {
networks_.big.load(binaryDirectory, file);
});
threads.clear();
}

// Loads the small network from `file` (passed to load() together with
// binaryDirectory), then calls threads.clear().
// NOTE(review): the direct networks.small.load() call and the
// modify_and_replicate() call look like the removed and added sides of the
// diff view -- confirm which form is current.
void Engine::load_small_network(const std::string& file) {
networks.small.load(binaryDirectory, file);
// `file` is captured by reference; the lambda is passed straight into the call.
networks.modify_and_replicate([this, &file](NN::Networks& networks_) {
networks_.small.load(binaryDirectory, file);
});
threads.clear();
}

// Saves the big network using files[0].first and the small network using
// files[1].first (each an optional file name). The `.second` members are not
// read here.
// NOTE(review): the direct save() calls and the modify_and_replicate() call
// look like the removed and added sides of the diff view -- confirm which form
// is current.
void Engine::save_network(const std::pair<std::optional<std::string>, std::string> files[2]) {
networks.big.save(files[0].first);
networks.small.save(files[1].first);
networks.modify_and_replicate([this, &files](NN::Networks& networks_) {
networks_.big.save(files[0].first);
networks_.small.save(files[1].first);
});
}

// utility functions
Expand All @@ -169,7 +199,7 @@ void Engine::trace_eval() const {

verify_networks();

sync_cout << "\n" << Eval::trace(p, networks) << sync_endl;
sync_cout << "\n" << Eval::trace(p, *networks) << sync_endl;
}

// Returns a mutable reference to the engine's option map.
OptionsMap& Engine::get_options() {
    return options;
}
Expand All @@ -184,4 +214,21 @@ std::string Engine::visualize() const {
return ss.str();
}

std::vector<std::pair<size_t, size_t>> Engine::get_bound_thread_count_by_numa_node() const {
auto counts = threads.get_bound_thread_count_by_numa_node();
const NumaConfig& cfg = numaContext.get_numa_config();
std::vector<std::pair<size_t, size_t>> ratios;
NumaIndex n = 0;
for (; n < counts.size(); ++n)
ratios.emplace_back(counts[n], cfg.num_cpus_in_numa_node(n));
if (!counts.empty())
for (; n < cfg.num_numa_nodes(); ++n)
ratios.emplace_back(0, cfg.num_cpus_in_numa_node(n));
return ratios;
}

std::string Engine::get_numa_config_as_string() const {
return numaContext.get_numa_config().to_string();
}

}
15 changes: 14 additions & 1 deletion src/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "thread.h"
#include "tt.h"
#include "ucioption.h"
#include "numa.h"

namespace Stockfish {

Expand All @@ -47,6 +48,13 @@ class Engine {
using InfoIter = Search::InfoIteration;

Engine(std::string path = "");

// Can't be movable due to components holding backreferences to fields
Engine(const Engine&) = delete;
Engine(Engine&&) = delete;
Engine& operator=(const Engine&) = delete;
Engine& operator=(Engine&&) = delete;

~Engine() { wait_for_search_finished(); }

std::uint64_t perft(const std::string& fen, Depth depth, bool isChess960);
Expand All @@ -63,6 +71,7 @@ class Engine {

// modifiers

void set_numa_config_from_option(const std::string& o);
void resize_threads();
void set_tt_size(size_t mb);
void set_ponderhit(bool);
Expand All @@ -88,18 +97,22 @@ class Engine {
std::string fen() const;
void flip();
std::string visualize() const;
std::vector<std::pair<size_t, size_t>> get_bound_thread_count_by_numa_node() const;
std::string get_numa_config_as_string() const;

private:
const std::string binaryDirectory;

NumaReplicationContext numaContext;

Position pos;
StateListPtr states;
Square capSq;

OptionsMap options;
ThreadPool threads;
TranspositionTable tt;
Eval::NNUE::Networks networks;
NumaReplicated<Eval::NNUE::Networks> networks;

Search::SearchManager::UpdateContext updateContext;
};
Expand Down
2 changes: 1 addition & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
#include "types.h"
#include "uci.h"
#include "tune.h"
#include "numa.h"

using namespace Stockfish;

int main(int argc, char* argv[]) {

std::cout << engine_info() << std::endl;

Bitboards::init();
Expand Down
123 changes: 0 additions & 123 deletions src/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,129 +592,6 @@ void aligned_large_pages_free(void* mem) { std_aligned_free(mem); }
#endif


namespace WinProcGroup {

#ifndef _WIN32

// Variant compiled when _WIN32 is not defined: accepts the thread index and
// does nothing.
void bind_this_thread(size_t) {
}

#else

namespace {
// Retrieves logical processor information using Windows-specific
// API and returns the best node id for the thread with index idx. Original
// code from Texel by Peter Österlund.
int best_node(size_t idx) {

int threads = 0;
int nodes = 0;
int cores = 0;
DWORD returnLength = 0;
DWORD byteOffset = 0;

// Early exit if the needed API is not available at runtime
HMODULE k32 = GetModuleHandle(TEXT("Kernel32.dll"));
auto fun1 = (fun1_t) (void (*)()) GetProcAddress(k32, "GetLogicalProcessorInformationEx");
if (!fun1)
return -1;

// First call to GetLogicalProcessorInformationEx() to get returnLength.
// We expect the call to fail due to null buffer.
if (fun1(RelationAll, nullptr, &returnLength))
return -1;

// Once we know returnLength, allocate the buffer
SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX *buffer, *ptr;
ptr = buffer = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*) malloc(returnLength);

// Second call to GetLogicalProcessorInformationEx(), now we expect to succeed
if (!fun1(RelationAll, buffer, &returnLength))
{
free(buffer);
return -1;
}

while (byteOffset < returnLength)
{
if (ptr->Relationship == RelationNumaNode)
nodes++;

else if (ptr->Relationship == RelationProcessorCore)
{
cores++;
threads += (ptr->Processor.Flags == LTP_PC_SMT) ? 2 : 1;
}

assert(ptr->Size);
byteOffset += ptr->Size;
ptr = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*) (((char*) ptr) + ptr->Size);
}

free(buffer);

std::vector<int> groups;

// Run as many threads as possible on the same node until the core limit is
// reached, then move on to filling the next node.
for (int n = 0; n < nodes; n++)
for (int i = 0; i < cores / nodes; i++)
groups.push_back(n);

// In case a core has more than one logical processor (we assume 2) and we
// still have threads to allocate, spread them evenly across available nodes.
for (int t = 0; t < threads - cores; t++)
groups.push_back(t % nodes);

// If we still have more threads than the total number of logical processors
// then return -1 and let the OS to decide what to do.
return idx < groups.size() ? groups[idx] : -1;
}
}


// Sets the group affinity of the current thread
void bind_this_thread(size_t idx) {

// Use only local variables to be thread-safe
int node = best_node(idx);

if (node == -1)
return;

// Early exit if the needed API are not available at runtime
HMODULE k32 = GetModuleHandle(TEXT("Kernel32.dll"));
auto fun2 = fun2_t((void (*)()) GetProcAddress(k32, "GetNumaNodeProcessorMaskEx"));
auto fun3 = fun3_t((void (*)()) GetProcAddress(k32, "SetThreadGroupAffinity"));
auto fun4 = fun4_t((void (*)()) GetProcAddress(k32, "GetNumaNodeProcessorMask2"));
auto fun5 = fun5_t((void (*)()) GetProcAddress(k32, "GetMaximumProcessorGroupCount"));

if (!fun2 || !fun3)
return;

if (!fun4 || !fun5)
{
GROUP_AFFINITY affinity;
if (fun2(node, &affinity)) // GetNumaNodeProcessorMaskEx
fun3(GetCurrentThread(), &affinity, nullptr); // SetThreadGroupAffinity
}
else
{
// If a numa node has more than one processor group, we assume they are
// sized equal and we spread threads evenly across the groups.
USHORT elements, returnedElements;
elements = fun5(); // GetMaximumProcessorGroupCount
GROUP_AFFINITY* affinity = (GROUP_AFFINITY*) malloc(elements * sizeof(GROUP_AFFINITY));
if (fun4(node, affinity, elements, &returnedElements)) // GetNumaNodeProcessorMask2
fun3(GetCurrentThread(), &affinity[idx % returnedElements],
nullptr); // SetThreadGroupAffinity
free(affinity);
}
}

#endif

} // namespace WinProcGroup

#ifdef _WIN32
#include <direct.h>
#define GETCWD _getcwd
Expand Down
Loading

0 comments on commit 91a511d

Please sign in to comment.