Merge pull request #92 from Sopel97/better_eval_checking
Make the cross_check_eval script form batches from FENs, with batch construction delegated to the data loader.
Sopel97 committed Apr 18, 2021
2 parents 26c8b95 + 3d47f97 commit fae9b9c
Showing 3 changed files with 102 additions and 19 deletions.
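
The change in outline: instead of evaluating one position at a time, the script now buffers FENs (with scores, plies, and results) and flushes them to the native data loader in fixed-size batches. A minimal, self-contained sketch of that accumulate-and-flush shape, using generic names rather than the script's own:

# Generic accumulate-and-flush pattern, mirroring how main() below calls
# commit_batch() every 1024 positions and once more after the loop.
def process_in_batches(items, process_batch, batch_size=1024):
    pending = []
    outputs = []
    for item in items:
        pending.append(item)
        if len(pending) == batch_size:
            outputs += process_batch(pending)
            pending = []
    if pending:  # final partial batch
        outputs += process_batch(pending)
    return outputs

# Toy usage: "processing" a batch just doubles each item.
print(process_in_batches(range(10), lambda batch: [x * 2 for x in batch], batch_size=4))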
60 changes: 41 additions & 19 deletions cross_check_eval.py
@@ -2,6 +2,7 @@
 import features
 import serialize
 import nnue_bin_dataset
+import nnue_dataset
 import subprocess
 import re
 from model import NNUE
@@ -14,21 +15,14 @@ def read_model(nnue_path, feature_set):
 def make_data_reader(data_path, feature_set):
     return nnue_bin_dataset.NNUEBinData(data_path, feature_set)

-def eval_model(model, item):
-    us, them, white, black, outcome, score = item
-    us = us.unsqueeze(dim=0)
-    them = them.unsqueeze(dim=0)
-    white = white.unsqueeze(dim=0)
-    black = black.unsqueeze(dim=0)
+def eval_model_batch(model, batch):
+    us, them, white, black, outcome, score = batch.contents.get_tensors('cpu')

-    eval = model.forward(us, them, white, black).item() * 600.0
-    if them[0] > 0.5:
-        return -eval
-    else:
-        return eval
-
-def eval_engine(engine, fen):
-    pass
+    evals = [v.item() for v in model.forward(us, them, white, black) * 600.0]
+    for i in range(len(evals)):
+        if them[i] > 0.5:
+            evals[i] = -evals[i]
+    return evals

 re_nnue_eval = re.compile(r'NNUE evaluation:\s*?(-?\d*?\.\d*)')

@@ -89,8 +83,31 @@ def main():
     data_reader = make_data_reader(args.data, feature_set)

     fens = []
+    results = []
+    scores = []
+    plies = []
     model_evals = []
+    engine_evals = []
     i = -1
+
+    def commit_batch():
+        nonlocal fens
+        nonlocal results
+        nonlocal scores
+        nonlocal plies
+        nonlocal model_evals
+        nonlocal engine_evals
+        if len(fens) == 0:
+            return
+        b = nnue_dataset.make_sparse_batch_from_fens(feature_set, fens, scores, plies, results)
+        model_evals += eval_model_batch(model, b)
+        nnue_dataset.destroy_sparse_batch(b)
+        engine_evals += eval_engine_batch(args.engine, args.net, fens)
+        fens = []
+        results = []
+        scores = []
+        plies = []
+
     done = 0
     while done < args.count:
         i += 1
@@ -100,14 +117,19 @@ def main():
         if board.is_check():
             continue

-        fen = board.fen()
-        fens.append(fen)
-        eval = eval_model(model, data_reader.transform(item))
-        model_evals.append(eval)
+        fens.append(board.fen())
+        results.append(int(round(item[2] * 2 - 1)))
+        scores.append(int(item[3]))
+        plies.append(1)

         done += 1

-    engine_evals = eval_engine_batch(args.engine, args.net, fens)
+        if done % 1024 == 0:
+            # flush periodically so a single batch never grows too large
+            commit_batch()
+
+    commit_batch()

     compute_correlation(engine_evals, model_evals)

 if __name__ == '__main__':
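
The loop in eval_model_batch negates each eval where `them` is set, so the model's side-to-move output lines up with the engine's reporting perspective. A minimal sketch of an equivalent vectorized form (not part of the commit), assuming `raw` and `them` are the tensors produced above:

import torch

# Vectorized equivalent of the per-element sign flip in eval_model_batch.
# `raw` stands for the model output, `them` for the 0/1 flags returned by
# get_tensors(); 600.0 is the same scaling constant the script uses.
def white_relative_evals(raw: torch.Tensor, them: torch.Tensor) -> list:
    scaled = raw.flatten() * 600.0
    return torch.where(them.flatten() > 0.5, -scaled, scaled).tolist()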
19 changes: 19 additions & 0 deletions nnue_dataset.py
@@ -106,6 +106,25 @@ def __del__(self):
 fetch_next_sparse_batch.argtypes = [ctypes.c_void_p]
 destroy_sparse_batch = dll.destroy_sparse_batch

+get_sparse_batch_from_fens = dll.get_sparse_batch_from_fens
+get_sparse_batch_from_fens.restype = SparseBatchPtr
+get_sparse_batch_from_fens.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)]
+
+def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
+    results_ = (ctypes.c_int*len(results))()
+    scores_ = (ctypes.c_int*len(scores))()
+    plies_ = (ctypes.c_int*len(plies))()
+    fens_ = (ctypes.c_char_p * len(fens))()
+    fens_[:] = [fen.encode('utf-8') for fen in fens]
+    for i, v in enumerate(scores):
+        scores_[i] = v
+    for i, v in enumerate(plies):
+        plies_[i] = v
+    for i, v in enumerate(results):
+        results_[i] = v
+    b = get_sparse_batch_from_fens(feature_set.name.encode('utf-8'), len(fens), fens_, scores_, plies_, results_)
+    return b

 class SparseBatchProvider(TrainingDataProvider):
     def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, device='cpu'):
         super(SparseBatchProvider, self).__init__(
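
A hedged usage sketch of the new helper, mirroring how cross_check_eval.py drives it. `features.get_feature_set_from_name` is assumed to be this repo's existing lookup in the features module, and the FEN and score/ply/result values are purely illustrative:

import features
import nnue_dataset

# Assumption: get_feature_set_from_name is the repo's feature-set lookup.
feature_set = features.get_feature_set_from_name('HalfKP')

fens = ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1']
scores = [0]    # engine score placeholder
plies = [1]     # cross_check_eval.py passes 1 for every position
results = [0]   # game result in {-1, 0, 1}

b = nnue_dataset.make_sparse_batch_from_fens(feature_set, fens, scores, plies, results)
us, them, white, black, outcome, score = b.contents.get_tensors('cpu')
nnue_dataset.destroy_sparse_batch(b)  # the batch is allocated by C++, so free it explicitly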
42 changes: 42 additions & 0 deletions training_data_loader.cpp
@@ -490,6 +490,48 @@ struct FeaturedBatchStream : Stream<StorageT>

 extern "C" {

+    EXPORT SparseBatch* get_sparse_batch_from_fens(
+        const char* feature_set_c,
+        int num_fens,
+        const char* const* fens,
+        int* scores,
+        int* plies,
+        int* results
+    )
+    {
+        std::vector<TrainingDataEntry> entries;
+        entries.reserve(num_fens);
+        for (int i = 0; i < num_fens; ++i)
+        {
+            auto& e = entries.emplace_back();
+            e.pos = Position::fromFen(fens[i]);
+            movegen::forEachLegalMove(e.pos, [&](Move m){e.move = m;}); // any legal move will do; the entry just needs a valid move
+            e.score = scores[i];
+            e.ply = plies[i];
+            e.result = results[i];
+        }
+
+        std::string_view feature_set(feature_set_c);
+        if (feature_set == "HalfKP")
+        {
+            return new SparseBatch(FeatureSet<HalfKP>{}, entries);
+        }
+        else if (feature_set == "HalfKP^")
+        {
+            return new SparseBatch(FeatureSet<HalfKPFactorized>{}, entries);
+        }
+        else if (feature_set == "HalfKA")
+        {
+            return new SparseBatch(FeatureSet<HalfKA>{}, entries);
+        }
+        else if (feature_set == "HalfKA^")
+        {
+            return new SparseBatch(FeatureSet<HalfKAFactorized>{}, entries);
+        }
+        fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
+        return nullptr;
+    }
+
     EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping)
     {
         std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr;
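
The exported function reports an unrecognized feature set by printing to stderr and returning nullptr, which ctypes exposes on the Python side as a NULL (falsy) pointer. A small hedged sketch of a defensive wrapper a caller could add; `make_sparse_batch_from_fens_checked` is a hypothetical name, not part of this commit:

import nnue_dataset

# Hypothetical defensive wrapper: surface the C++ side's nullptr return
# (unknown feature set) as a Python exception instead of a NULL pointer.
def make_sparse_batch_from_fens_checked(feature_set, fens, scores, plies, results):
    b = nnue_dataset.make_sparse_batch_from_fens(feature_set, fens, scores, plies, results)
    if not b:  # NULL ctypes pointers are falsy
        raise ValueError('unknown feature set: ' + feature_set.name)
    return b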
