# cross_check_eval.py
import argparse
import features
import serialize
import nnue_bin_dataset
import nnue_dataset
import subprocess
import re
from model import NNUE
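
# Example invocation (the paths and the feature-set name below are
# hypothetical; the --features flag is assumed to be supplied by
# features.add_argparse_args):
#
#   python cross_check_eval.py --net nn.nnue --engine ./stockfish \
#       --data test.bin --count 1000 --features=HalfKP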


def read_model(nnue_path, feature_set):
    # Deserialize a trained network from a .nnue file.
    with open(nnue_path, 'rb') as f:
        reader = serialize.NNUEReader(f, feature_set)
        return reader.model


def make_data_reader(data_path, feature_set):
    return nnue_bin_dataset.NNUEBinData(data_path, feature_set)


def eval_model_batch(model, batch):
    us, them, white_indices, white_values, black_indices, black_values, outcome, score = batch.contents.get_tensors('cpu')
    # Scale the raw model output by 600.0 to bring it onto the engine's
    # score scale.
    evals = [v.item() for v in model.forward(us, them, white_indices, white_values, black_indices, black_values) * 600.0]
    # The model evaluates from the side-to-move perspective; negate evals
    # of black-to-move positions so everything is from white's perspective.
    for i in range(len(evals)):
        if them[i] > 0.5:
            evals[i] = -evals[i]
    return evals


# Matches the NNUE evaluation (in pawns) printed by the engine's 'eval' command.
re_nnue_eval = re.compile(r'NNUE evaluation:\s*?(-?\d*?\.\d*)')
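
# The regex captures the signed pawn value from a line shaped like
#   NNUE evaluation: +0.25
# (assumed format; exact spacing and any trailing text such as
# "(white side)" vary between engine versions).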


def compute_basic_eval_stats(evals):
    # Summary statistics for one stream of evals (engine or model).
    min_eval = min(evals)
    max_eval = max(evals)
    avg_eval = sum(evals) / len(evals)
    avg_abs_eval = sum(abs(v) for v in evals) / len(evals)
    return min_eval, max_eval, avg_eval, avg_abs_eval


def compute_correlation(engine_evals, model_evals):
    if len(engine_evals) != len(model_evals):
        raise ValueError("number of engine evals doesn't match the number of model evals")

    min_engine_eval, max_engine_eval, avg_engine_eval, avg_abs_engine_eval = compute_basic_eval_stats(engine_evals)
    min_model_eval, max_model_eval, avg_model_eval, avg_abs_model_eval = compute_basic_eval_stats(model_evals)
    print('Min engine/model eval: {} / {}'.format(min_engine_eval, min_model_eval))
    print('Max engine/model eval: {} / {}'.format(max_engine_eval, max_model_eval))
    print('Avg engine/model eval: {} / {}'.format(avg_engine_eval, avg_model_eval))
    print('Avg abs engine/model eval: {} / {}'.format(avg_abs_engine_eval, avg_abs_model_eval))

    # Mean relative differences; the +0.001 guards against division by zero
    # for near-zero evals.
    relative_model_error = sum(abs(model - engine) / (abs(engine) + 0.001) for model, engine in zip(model_evals, engine_evals)) / len(engine_evals)
    relative_engine_error = sum(abs(model - engine) / (abs(model) + 0.001) for model, engine in zip(model_evals, engine_evals)) / len(engine_evals)
    print('Relative engine error: {}'.format(relative_engine_error))
    print('Relative model error: {}'.format(relative_model_error))
    print('Avg abs difference: {}'.format(sum(abs(model - engine) for model, engine in zip(model_evals, engine_evals)) / len(engine_evals)))


def eval_engine_batch(engine_path, net_path, fens):
    # Drive the engine over UCI: load the given net, then run 'eval' on
    # each position and collect the printed NNUE evaluations.
    engine = subprocess.Popen([engine_path], stdin=subprocess.PIPE, stdout=subprocess.PIPE, universal_newlines=True)
    parts = ['uci', 'setoption name EvalFile value {}'.format(net_path)]
    for fen in fens:
        parts.append('position fen {}'.format(fen))
        parts.append('eval')
    parts.append('quit')
    query = '\n'.join(parts)
    out = engine.communicate(input=query)[0]
    evals = re.findall(re_nnue_eval, out)
    # The engine reports evals in pawns; scale by 208 to get internal units.
    return [int(float(v) * 208) for v in evals]
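
# Note (assumption): the factor of 208 here and the factor of 600.0 in
# eval_model_batch are what put both eval streams on a common scale; 208
# appears to correspond to Stockfish's classical PawnValueEg. If either
# scale changes upstream, the two sets of evals will diverge even for a
# perfectly serialized net.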


def main():
    parser = argparse.ArgumentParser(description="Cross-check the model's evals against the engine's NNUE evals.")
    parser.add_argument("--net", type=str, help="path to a .nnue net")
    parser.add_argument("--engine", type=str, help="path to stockfish")
    parser.add_argument("--data", type=str, help="path to a .bin dataset")
    parser.add_argument("--checkpoint", type=str, help="optional checkpoint (used instead of the .nnue net for local eval)")
    parser.add_argument("--count", type=int, default=100, help="number of datapoints to process")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    feature_set = features.get_feature_set_from_name(args.features)

    if args.checkpoint:
        model = NNUE.load_from_checkpoint(args.checkpoint, feature_set=feature_set)
    else:
        model = read_model(args.net, feature_set)
    model.eval()

    data_reader = make_data_reader(args.data, feature_set)

    fens = []
    results = []
    scores = []
    plies = []
    model_evals = []
    engine_evals = []
    i = -1

    def commit_batch():
        # Evaluate the buffered positions with both the model and the engine,
        # then reset the per-batch buffers.
        nonlocal fens, results, scores, plies, model_evals, engine_evals
        if len(fens) == 0:
            return
        b = nnue_dataset.make_sparse_batch_from_fens(feature_set, fens, scores, plies, results)
        model_evals += eval_model_batch(model, b)
        nnue_dataset.destroy_sparse_batch(b)
        engine_evals += eval_engine_batch(args.engine, args.net, fens)
        fens = []
        results = []
        scores = []
        plies = []

    done = 0
    while done < args.count:
        i += 1
        item = data_reader.get_raw(i)
        board = item[0]
        # Skip positions where the side to move is in check.
        if board.is_check():
            continue
        fens.append(board.fen())
        # Map the game outcome from [0, 1] to {-1, 0, 1}.
        results.append(int(round(item[2] * 2 - 1)))
        scores.append(int(item[3]))
        plies.append(1)
        done += 1
        if done % 1024 == 0:
            # Flush periodically so batches don't grow too big.
            commit_batch()
    commit_batch()

    compute_correlation(engine_evals, model_evals)


if __name__ == '__main__':
    main()