Skip to content

Commit ee77ae5

Browse files
fix(policy_value_model): value head was not learning from hidden state
- Maybe it was a dumb idea to make the value head a single-layer projection of the hidden state. - Reduce the output sequence and use a dense layer activated with ReLU.
1 parent 460f80c commit ee77ae5

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

libraries/mathy_python/mathy/agents/policy_value_model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@ def __init__(
6363
self.args = args
6464
self.predictions = predictions
6565
self.embedding = MathyEmbedding(self.args)
66-
# self.embedding.compile(optimizer=self.optimizer, run_eagerly=True)
66+
self.value_head = tf.keras.layers.Dense(
67+
self.args.units,
68+
name="value_head",
69+
kernel_initializer="he_normal",
70+
activation="relu",
71+
)
6772
self.value_logits = tf.keras.layers.Dense(
6873
1, name="value_logits", kernel_initializer="he_normal", activation=None,
6974
)
@@ -114,7 +119,9 @@ def _call(
114119
inputs = features_window
115120
# Extract features into contextual inputs, sequence inputs.
116121
sequence_inputs = self.embedding(inputs)
117-
values = self.normalize_v(self.value_logits(self.embedding.state_h))
122+
values = self.value_head(tf.reduce_mean(sequence_inputs, axis=1))
123+
values = self.normalize_v(values)
124+
values = self.value_logits(values)
118125
logits = self.normalize_pi(self.policy_logits(sequence_inputs))
119126
mask_logits = self.apply_pi_mask(logits, features_window)
120127
if call_print is True:

0 commit comments

Comments (0)