Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Added NNUE support via egbbdll a la CFish.
  • Loading branch information
dshawul committed Oct 15, 2020
1 parent 7b80a83 commit 258bb61
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 3 deletions.
13 changes: 11 additions & 2 deletions bin/scorpio.ini
Expand Up @@ -101,7 +101,8 @@ egbb_cache_size 32
egbb_load_type 3
egbb_depth_limit 6
#######################################
# NN
# NN settings
#------------------------
# use_nn - turn off/on neural network
# nn_cache_size - neural network cache size in MBs
# nn_path - path to neural network
Expand Down Expand Up @@ -129,7 +130,7 @@ draw_weight 100
loss_weight 100
min_policy_value 0
################################################################
# Multiple nets settings
# Multiple NNs settings
#------------------------
# The opening is played by default network specified above
# nn_man_m/e -- threshold piece counts for middle/end game
Expand Down Expand Up @@ -163,3 +164,11 @@ wdl_head_m 0
wdl_head_e 0
ensemble 0
ensemble_type 0
#######################################
# NNUE settings
#------------------------
# use_nnue - turn off/on NNUE
# nnue_path - path to NNUE
#######################################
use_nnue 0
nnue_path ../nets-scorpio/nn.nnue
8 changes: 7 additions & 1 deletion src/eval.cpp
Expand Up @@ -64,13 +64,19 @@ int SEARCHER::eval(bool skip_nn_l) {
/*number of evaluation calls*/
ecalls++;

/* neural network evaluation */
#ifdef EGBB
/* neural network evaluation */
if(use_nn && !skip_nn && !skip_nn_l) {
pstack->actual_score = probe_neural();
record_eval_hash(hash_key,pstack->actual_score);
return pstack->actual_score;
}
/*nnue evaluation*/
if(use_nnue) {
pstack->actual_score = probe_nnue();
record_eval_hash(hash_key,pstack->actual_score);
return pstack->actual_score;
}
#endif

#endif
Expand Down
28 changes: 28 additions & 0 deletions src/probe.cpp
Expand Up @@ -42,9 +42,16 @@ typedef void (CDECL *PLOAD_NN)(
typedef void (CDECL *PSET_NUM_ACTIVE_SEARCHERS) (
int n_searchers);

typedef void (CDECL *PNNUE_INIT) (
const char * evalFile);
typedef int (CDECL *PNNUE_EVALUATE) (
int player, int* pieces, int* squares);

static PPROBE_EGBB probe_egbb;
static PPROBE_NN probe_nn;
static PSET_NUM_ACTIVE_SEARCHERS set_num_active_searchers = 0;
static PNNUE_INIT nnue_init;
static PNNUE_EVALUATE nnue_evaluate;

int SEARCHER::egbb_is_loaded = 0;
int SEARCHER::egbb_load_type = LOAD_4MEN;
Expand All @@ -57,8 +64,10 @@ char SEARCHER::egbb_files_path[MAX_STR] = "egbb/";
char SEARCHER::nn_path[MAX_STR] = "../nets-scorpio/net-6x64.pb";
char SEARCHER::nn_path_e[MAX_STR] = "../nets-scorpio/net-6x64.pb";
char SEARCHER::nn_path_m[MAX_STR] = "../nets-scorpio/net-6x64.pb";
char SEARCHER::nnue_path[MAX_STR] = "../nets-scorpio/nnue.bin";
int SEARCHER::nn_cache_size = 16;
int SEARCHER::use_nn = 0;
int SEARCHER::use_nnue = 0;
int SEARCHER::save_use_nn = 0;
int SEARCHER::n_devices = 1;
int SEARCHER::device_type = CPU;
Expand Down Expand Up @@ -220,6 +229,8 @@ void LoadEgbbLibrary(char* main_path,int egbb_cache_size,int nn_cache_size) {
probe_nn = (PPROBE_NN) GetProcAddress(hmod,"probe_neural_network");
set_num_active_searchers =
(PSET_NUM_ACTIVE_SEARCHERS) GetProcAddress(hmod,"set_num_active_searchers");
nnue_init = (PNNUE_INIT) GetProcAddress(hmod,"nnue_init");
nnue_evaluate = (PNNUE_EVALUATE) GetProcAddress(hmod,"nnue_evaluate");

if(load_egbb) {
clean_path(SEARCHER::egbb_files_path);
Expand All @@ -245,6 +256,12 @@ void LoadEgbbLibrary(char* main_path,int egbb_cache_size,int nn_cache_size) {
init_input_planes();
} else
SEARCHER::use_nn = 0;

if(nnue_init && SEARCHER::use_nnue) {
nnue_init(SEARCHER::nnue_path);
} else {
SEARCHER::use_nnue = 0;
}
} else {
print("EgbbProbe not Loaded!\n");
}
Expand Down Expand Up @@ -958,4 +975,15 @@ int SEARCHER::compress_input_planes(float** iplanes, char* buffer) {
bcount += sprintf(&buffer[bcount], "%d", cnt);

return bcount;
}
/*
NNUE
*/
int SEARCHER::probe_nnue() {
#ifdef EGBB
int piece[33],square[33],count = 0;
fill_list(count,piece,square);
return nnue_evaluate(player,piece,square);
#endif
return 0;
}
13 changes: 13 additions & 0 deletions src/scorpio.cpp
Expand Up @@ -407,9 +407,12 @@ static void print_options() {
print_spin("pht",pht,1,256);
print_path("egbb_path",SEARCHER::egbb_path);
print_path("egbb_files_path",SEARCHER::egbb_files_path);
print_check("use_nn",SEARCHER::use_nn);
print_check("use_nnue",SEARCHER::use_nnue);
print_path("nn_path",SEARCHER::nn_path);
print_path("nn_path_e",SEARCHER::nn_path_e);
print_path("nn_path_m",SEARCHER::nn_path_m);
print_path("nnue_path",SEARCHER::nnue_path);
print_spin("egbb_cache_size",SEARCHER::egbb_cache_size,1,16384);
print_spin("egbb_load_type",SEARCHER::egbb_load_type,0,3);
print_spin("egbb_depth_limit",SEARCHER::egbb_depth_limit,0,MAX_PLY);
Expand Down Expand Up @@ -520,6 +523,13 @@ bool internal_commands(char** commands,char* command,int& command_num) {
SEARCHER::use_nn = false;
SEARCHER::save_use_nn = SEARCHER::use_nn;
command_num++;
} else if (!strcmp(command, "use_nnue")) {
if(!strcmp(commands[command_num],"on") ||
!strcmp(commands[command_num],"1"))
SEARCHER::use_nnue = true;
else
SEARCHER::use_nnue = false;
command_num++;
} else if(!strcmp(command, "nn_path")) {
strcpy(SEARCHER::nn_path,commands[command_num]);
command_num++;
Expand All @@ -529,6 +539,9 @@ bool internal_commands(char** commands,char* command,int& command_num) {
} else if(!strcmp(command, "nn_path_m")) {
strcpy(SEARCHER::nn_path_m,commands[command_num]);
command_num++;
} else if(!strcmp(command, "nnue_path")) {
strcpy(SEARCHER::nnue_path,commands[command_num]);
command_num++;
} else if(!strcmp(command, "n_devices")) {
SEARCHER::n_devices = atoi(commands[command_num]);
command_num++;
Expand Down
3 changes: 3 additions & 0 deletions src/scorpio.h
Expand Up @@ -867,6 +867,7 @@ typedef struct SEARCHER{
void ensemble_net(int,int,int,float&);
float probe_neural_(bool,float*,int,int,int);
int probe_neural(bool=false);
int probe_nnue();
void handle_terminal(Node*,bool);
void self_play_thread();
void self_play_thread_all(FILE*,FILE*,int);
Expand All @@ -890,7 +891,9 @@ typedef struct SEARCHER{
static char nn_path[MAX_STR];
static char nn_path_e[MAX_STR];
static char nn_path_m[MAX_STR];
static char nnue_path[MAX_STR];
static int nn_cache_size;
static int use_nnue;
static int use_nn;
static int save_use_nn;
static int n_devices;
Expand Down

0 comments on commit 258bb61

Please sign in to comment.