forked from Zeta36/chess-alpha-zero
-
Notifications
You must be signed in to change notification settings - Fork 7
/
model_chess.py
107 lines (90 loc) · 4.18 KB
/
model_chess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import hashlib
import json
import urllib.request
import ftplib
import os
from logging import getLogger
# noinspection PyPep8Naming
import keras.backend as K
from keras.engine.topology import Input
from keras.engine.training import Model
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.merge import Add
from keras.layers.normalization import BatchNormalization
from keras.losses import mean_squared_error
from keras.regularizers import l2
from chess_zero.config import Config
logger = getLogger(__name__)
class ChessModel:
    """AlphaZero-style policy/value network for chess.

    A channels-first convolutional stem feeds a tower of residual blocks;
    the tower output ("res_out") branches into two heads:

    * policy head — 1x1 conv -> flatten -> softmax over
      ``config.n_labels`` move labels (named ``policy_out``);
    * value head — 1x1 conv -> flatten -> fully-connected -> single
      tanh unit in [-1, 1] (named ``value_out``).

    Attributes:
        config: global :class:`Config`; architecture hyperparameters are
            read from ``config.model``.
        model: the Keras :class:`Model`, populated by :meth:`build` or
            :meth:`load`; ``None`` before that.
        digest: SHA-256 hex digest of the most recently saved/loaded
            weight file, or ``None``.
    """

    def __init__(self, config: Config):
        self.config = config
        self.model = None  # type: Model
        self.digest = None

    def build(self):
        """Construct ``self.model``: conv stem -> residual tower -> two heads."""
        mc = self.config.model
        # Input planes are channels-first: (batch, planes, 8, 8).
        in_x = x = Input((mc.input_stack_height, 8, 8))
        x = Conv2D(filters=mc.cnn_filter_num, kernel_size=mc.cnn_filter_size,
                   padding="same", data_format="channels_first",
                   kernel_regularizer=l2(mc.l2_reg),
                   input_shape=(mc.input_stack_height, 8, 8))(x)
        # axis=1 is the channel axis under channels_first.
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)
        for _ in range(mc.res_layer_num):
            x = self._build_residual_block(x)
        res_out = x

        # Policy head. No output for 'pass' — chess has no pass move.
        x = Conv2D(filters=2, kernel_size=1, data_format="channels_first",
                   kernel_regularizer=l2(mc.l2_reg))(res_out)
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)
        x = Flatten()(x)
        policy_out = Dense(self.config.n_labels,
                           kernel_regularizer=l2(mc.l2_reg),
                           activation="softmax", name="policy_out")(x)

        # Value head.
        x = Conv2D(filters=1, kernel_size=1, data_format="channels_first",
                   kernel_regularizer=l2(mc.l2_reg))(res_out)
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)
        x = Flatten()(x)
        x = Dense(mc.value_fc_size, kernel_regularizer=l2(mc.l2_reg),
                  activation="relu")(x)
        value_out = Dense(1, kernel_regularizer=l2(mc.l2_reg),
                          activation="tanh", name="value_out")(x)

        self.model = Model(in_x, [policy_out, value_out], name="chess_model")

    def _build_residual_block(self, x):
        """Return ``x`` passed through one pre-activation-free residual block.

        conv -> BN -> relu -> conv -> BN -> add skip -> relu, all with the
        trunk's filter count/size and L2 regularization.
        """
        mc = self.config.model
        in_x = x
        x = Conv2D(filters=mc.cnn_filter_num, kernel_size=mc.cnn_filter_size,
                   padding="same", data_format="channels_first",
                   kernel_regularizer=l2(mc.l2_reg))(x)
        x = BatchNormalization(axis=1)(x)
        x = Activation("relu")(x)
        x = Conv2D(filters=mc.cnn_filter_num, kernel_size=mc.cnn_filter_size,
                   padding="same", data_format="channels_first",
                   kernel_regularizer=l2(mc.l2_reg))(x)
        x = BatchNormalization(axis=1)(x)
        x = Add()([in_x, x])
        x = Activation("relu")(x)
        return x

    @staticmethod
    def fetch_digest(weight_path):
        """Return the SHA-256 hex digest of ``weight_path``, or ``None``.

        ``None`` is returned explicitly when the file does not exist.
        The file is hashed in 1 MiB chunks so large weight files are not
        read into memory at once.
        """
        if not os.path.exists(weight_path):
            return None
        m = hashlib.sha256()
        with open(weight_path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                m.update(chunk)
        return m.hexdigest()

    def load(self, config_path, weight_path):
        """Load model architecture and weights from disk.

        Args:
            config_path: JSON file produced by :meth:`save` (Keras
                ``get_config`` output).
            weight_path: HDF5 weight file produced by :meth:`save`.

        Returns:
            True if both files existed and the model was loaded,
            False otherwise (``self.model`` is left untouched).
        """
        if os.path.exists(config_path) and os.path.exists(weight_path):
            logger.debug(f"loading model from {config_path}")
            with open(config_path, "rt") as f:
                self.model = Model.from_config(json.load(f))
            self.model.load_weights(weight_path)
            self.digest = self.fetch_digest(weight_path)
            logger.debug(f"loaded model digest = {self.digest}")
            return True
        else:
            logger.debug(f"model files do not exist at {config_path} and {weight_path}")
            return False

    def save(self, config_path, weight_path):
        """Persist the model: architecture as JSON, weights as HDF5.

        Also refreshes ``self.digest`` from the freshly written weight
        file so it matches what :meth:`load` would compute.
        """
        logger.debug(f"save model to {config_path}")
        with open(config_path, "wt") as f:
            json.dump(self.model.get_config(), f)
        self.model.save_weights(weight_path)
        self.digest = self.fetch_digest(weight_path)
        logger.debug(f"saved model digest {self.digest}")
def objective_function_for_policy(y_true, y_pred):
    """Cross-entropy loss between target and predicted policy distributions.

    Numerically identical to categorical cross-entropy written out by
    hand; ``K.epsilon()`` guards against log(0).
    """
    log_pred = K.log(y_pred + K.epsilon())
    return -K.sum(y_true * log_pred, axis=-1)
def objective_function_for_value(y_true, y_pred):
    """Mean-squared-error loss for the value head (delegates to Keras)."""
    mse = mean_squared_error(y_true, y_pred)
    return mse