Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Added incremental NNUE support.
  • Loading branch information
dshawul committed Oct 17, 2020
1 parent 258bb61 commit 04bfc4f
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/Makefile
Expand Up @@ -48,6 +48,7 @@ DEFINES += -DMAX_CPUS=640
#DEFINES += -DCLUSTER_TT_TYPE=1
#DEFINES += -DCUTECHESS_FIX
DEFINES += -DNODES_PRIOR
DEFINES += -DNNUE_INC
############################
# Compiler choice
############################
Expand Down Expand Up @@ -102,16 +103,16 @@ endif

ifeq ($(COMP),win)
LXXFLAGS += -static
CXXFLAGS = -Wall -fstrict-aliasing -fno-exceptions -fno-rtti $(UNUSED) -std=c++11
CXXFLAGS = -Wall -fstrict-aliasing -fno-exceptions -fno-rtti $(UNUSED) -std=c++17
else ifeq ($(COMP),pgcc)
CXXFLAGS = warn -Mvect=sse -c11
LXXFLAGS += -ldl
else ifeq ($(COMP),icpc)
CXXFLAGS = -wd128 -wd981 -wd869 -wd2259 -wd383 -wd1418
CXXFLAGS += -fstrict-aliasing -fno-exceptions -fno-rtti $(UNUSED) -std=c++11
CXXFLAGS += -fstrict-aliasing -fno-exceptions -fno-rtti $(UNUSED) -std=c++17
LXXFLAGS += -ldl
else
CXXFLAGS = -Wall -fstrict-aliasing -fno-exceptions -fno-rtti $(UNUSED) -std=c++11
CXXFLAGS = -Wall -fstrict-aliasing -fno-exceptions -fno-rtti $(UNUSED) -std=c++17
LXXFLAGS += -ldl
endif

Expand Down
54 changes: 51 additions & 3 deletions src/moves.cpp
Expand Up @@ -23,6 +23,15 @@ void SEARCHER::do_move(const MOVE& move) {
phstack->pawns_bb[white] = pawns_bb[white];
phstack->pawns_bb[black] = pawns_bb[black];

/*nnue*/
#ifdef NNUE_INC
DirtyPiece* dp = &(nnue[hply+1].dirtyPiece);
if(use_nnue) {
nnue[hply+1].accumulator.computedAccumulation = 0;
dp->dirtyNum = 1;
}
#endif

/*remove captured piece*/
if((pic = m_capture(move)) != 0) {
if(is_ep(move)) {
Expand All @@ -32,7 +41,14 @@ void SEARCHER::do_move(const MOVE& move) {
else sq = to;
pcRemove(pic,sq,phstack->pCapt);
board[sq] = blank;

#ifdef NNUE_INC
if(use_nnue) {
dp->dirtyNum = 2;
dp->pc[1] = pic;
dp->from[1] = SQ8864(sq);
dp->to[1] = 64;
}
#endif
hash_key ^= PC_HKEY(pic,sq);
if(PIECE(pic) == pawn) {
pawn_hash_key ^= PC_HKEY(pic,sq);
Expand All @@ -50,13 +66,28 @@ void SEARCHER::do_move(const MOVE& move) {
/*move piece*/
all_bb ^= BB(from);
all_bb |= BB(to);
#ifdef NNUE_INC
if(use_nnue) {
dp->pc[0] = m_piece(move);
dp->from[0] = SQ8864(from);
dp->to[0] = SQ8864(to);
}
#endif
if((pic = m_promote(move)) != 0) {
pic1 = COMBINE(player,pawn);
board[to] = pic;
board[from] = blank;
pcAdd(pic,to);
pcRemove(pic1,from,phstack->pProm);

#ifdef NNUE_INC
if(use_nnue) {
dp->to[0] = 64;
dp->pc[dp->dirtyNum] = pic;
dp->from[dp->dirtyNum] = 64;
dp->to[dp->dirtyNum] = SQ8864(to);
dp->dirtyNum++;
}
#endif
pawns_bb[player] ^= BB(from);
pieces_bb[player] ^= BB(to);
hash_key ^= PC_HKEY(pic1,from);
Expand Down Expand Up @@ -103,7 +134,14 @@ void SEARCHER::do_move(const MOVE& move) {
board[fromc] = blank;
pcSwap(fromc,toc);
pic = COMBINE(player,rook);

#ifdef NNUE_INC
if(use_nnue) {
dp->dirtyNum = 2;
dp->pc[1] = pic;
dp->from[1] = SQ8864(fromc);
dp->to[1] = SQ8864(toc);
}
#endif
pieces_bb[player] ^= (BB(fromc) | BB(toc));
all_bb ^= (BB(fromc) | BB(toc));
hash_key ^= PC_HKEY(pic,toc);
Expand Down Expand Up @@ -238,6 +276,16 @@ void SEARCHER::do_null() {
phstack->checks = 0;
phstack->hash_key = hash_key;

/*nnue*/
#ifdef NNUE_INC
if(use_nnue) {
memcpy(&nnue[hply+1].accumulator,
&nnue[hply].accumulator, sizeof(Accumulator));
DirtyPiece* dp = &(nnue[hply+1].dirtyPiece);
dp->dirtyNum = 0;
}
#endif

if(epsquare)
hash_key ^= EP_HKEY(epsquare);
epsquare = 0;
Expand Down
28 changes: 26 additions & 2 deletions src/probe.cpp
Expand Up @@ -47,12 +47,21 @@ typedef void (CDECL *PNNUE_INIT) (
typedef int (CDECL *PNNUE_EVALUATE) (
int player, int* pieces, int* squares);

#ifdef NNUE_INC
typedef int (CDECL *PNNUE_EVALUATE_INCREMENTAL) (
int player, int* pieces, int* squares, NNUEdata**);
#endif

static PPROBE_EGBB probe_egbb;
static PPROBE_NN probe_nn;
static PSET_NUM_ACTIVE_SEARCHERS set_num_active_searchers = 0;
static PNNUE_INIT nnue_init;
static PNNUE_EVALUATE nnue_evaluate;

#ifdef NNUE_INC
static PNNUE_EVALUATE_INCREMENTAL nnue_evaluate_incremental;
#endif

int SEARCHER::egbb_is_loaded = 0;
int SEARCHER::egbb_load_type = LOAD_4MEN;
int SEARCHER::egbb_depth_limit = 3;
Expand Down Expand Up @@ -231,6 +240,10 @@ void LoadEgbbLibrary(char* main_path,int egbb_cache_size,int nn_cache_size) {
(PSET_NUM_ACTIVE_SEARCHERS) GetProcAddress(hmod,"set_num_active_searchers");
nnue_init = (PNNUE_INIT) GetProcAddress(hmod,"nnue_init");
nnue_evaluate = (PNNUE_EVALUATE) GetProcAddress(hmod,"nnue_evaluate");
#ifdef NNUE_INC
nnue_evaluate_incremental =
(PNNUE_EVALUATE_INCREMENTAL) GetProcAddress(hmod,"nnue_evaluate_incremental");
#endif

if(load_egbb) {
clean_path(SEARCHER::egbb_files_path);
Expand Down Expand Up @@ -983,7 +996,18 @@ int SEARCHER::probe_nnue() {
#ifdef EGBB
int piece[33],square[33],count = 0;
fill_list(count,piece,square);
return nnue_evaluate(player,piece,square);

#ifdef NNUE_INC
NNUEdata* a_nnue[3] = {0, 0, 0};
for(int i = 0; i < 3 && hply >= i; i++)
a_nnue[i] = &nnue[hply - i];
return nnue_evaluate_incremental(
player,piece,square,&a_nnue[0]);
#else
return nnue_evaluate(
player,piece,square);
#endif

#endif
return 0;
}
}
26 changes: 26 additions & 0 deletions src/scorpio.h
Expand Up @@ -523,6 +523,29 @@ struct Node {
static float compute_fpu(Node*,bool);
};
/*
NNUE - data structures for incemental update
*/
#ifdef NNUE_INC

typedef struct DirtyPiece {
int dirtyNum;
int pc[3];
int from[3];
int to[3];
} DirtyPiece;

typedef struct Accumulator {
CACHE_ALIGN int16_t accumulation[2][256];
int computedAccumulation;
} Accumulator;

typedef struct NNUEdata {
Accumulator accumulator;
DirtyPiece dirtyPiece;
} NNUEdata;

#endif
/*
stacks
*/
typedef struct HIST_STACK{
Expand Down Expand Up @@ -680,6 +703,9 @@ typedef struct SEARCHER{
PLIST list[128];
PLIST plist[15];
HIST_STACK hstack[MAX_HSTACK];
#ifdef NNUE_INC
CACHE_ALIGN NNUEdata nnue[MAX_HSTACK];
#endif
STACK stack[MAX_PLY];
float prev_kld;
/*eval data*/
Expand Down
10 changes: 9 additions & 1 deletion src/util.cpp
Expand Up @@ -690,6 +690,11 @@ void SEARCHER::init_data() {
if(epsquare)
hash_key ^= EP_HKEY(epsquare);
hash_key ^= CAST_HKEY(castle);

#ifdef NNUE_INC
if(use_nnue)
nnue[hply].accumulator.computedAccumulation = 0;
#endif
}

void SEARCHER::set_board(const char* fen_str) {
Expand Down Expand Up @@ -964,7 +969,10 @@ void SEARCHER::COPY(SEARCHER* srcSearcher) {
all_man_c = srcSearcher->all_man_c;
root_node = srcSearcher->root_node;
root_key = srcSearcher->root_key;

#ifdef NNUE_INC
if(use_nnue)
memcpy(nnue, srcSearcher->nnue, (hply + 1) * sizeof(NNUEdata));
#endif
/*history stack*/
memcpy(&hstack[0],&srcSearcher->hstack[0], (hply + 1) * sizeof(HIST_STACK));

Expand Down

0 comments on commit 04bfc4f

Please sign in to comment.