Skip to content

Commit 38690c2

Browse files
fix(zero): default to non-recurrent architecture
- The random batching breaks the temporal link between examples that the time_lstm in A3C would normally process in order. This could perhaps be fixed, but it is not clear how at the moment.
1 parent c9c409a commit 38690c2

File tree

4 files changed

+63
-2
lines changed

4 files changed

+63
-2
lines changed

libraries/mathy_python/mathy/agents/densenet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(
1717
**kwargs,
1818
):
1919
super(DenseNetBlock, self).__init__(**kwargs)
20+
self.units = units
2021
self.use_shared = use_shared
2122
self.activate = tf.keras.layers.Activation("relu", name="relu")
2223
if normalization == "batch":

libraries/mathy_python/mathy/agents/zero/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ class SelfPlayConfig(BaseConfig):
88
temperature_threshold: float = 0.5
99
self_play_problems: int = 64
1010
training_iterations: int = 100
11-
cpuct: float = 1.0
11+
cpuct: float = 1.0
1212
normalization_style: str = "batch"
1313
# When profile is true and workers == 1 the main thread will output worker_0.profile on exit
1414
profile: bool = False
15+
# Don't use the LSTM with Zero agents because the random sampling breaks
16+
# timestep correlations across the batch.
17+
use_lstm: bool = False
18+

libraries/mathy_python/mathy/agents/zero/self_play_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_predictor(self, game):
3939
)
4040

4141
model: PolicyValueModel = get_or_create_policy_model(
42-
args=config, env_actions=game.action_space.n,
42+
args=config, env_actions=game.action_space.n, is_main=True
4343
)
4444
return model
4545

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
import numpy as np
3+
from ..mathy.agents.densenet import DenseNetBlock, DenseNetStack
4+
import tensorflow as tf
5+
6+
7+
def test_densenet_construction():
8+
layer: DenseNetStack = DenseNetStack()
9+
assert layer is not None
10+
model = tf.keras.Sequential([layer])
11+
model(np.zeros(shape=(128, 1)))
12+
13+
14+
def test_densenet_errors():
15+
with pytest.raises(ValueError):
16+
DenseNetStack(normalization_style="invalid")
17+
18+
19+
@pytest.mark.parametrize("norm_type", ("batch", "layer"))
20+
def test_densenet_normalization_types(norm_type: str):
21+
model = tf.keras.Sequential(
22+
[DenseNetStack(units=24, num_layers=4, normalization_style=norm_type)]
23+
)
24+
y = model(np.zeros(shape=(10, 1)))
25+
26+
27+
def test_densenet_no_activation():
28+
model = tf.keras.Sequential(
29+
[DenseNetStack(units=24, num_layers=4, activation=None)]
30+
)
31+
y = model(np.zeros(shape=(10, 1)))
32+
assert y.shape[0] == 10
33+
34+
35+
def test_densenet_output_transform():
36+
layer: DenseNetStack = DenseNetStack(
37+
units=10,
38+
num_layers=4,
39+
activation=None,
40+
output_transform=tf.keras.layers.Dense(128),
41+
)
42+
assert layer is not None
43+
model = tf.keras.Sequential([layer])
44+
x = np.zeros(shape=(10, 1))
45+
y = model(x)
46+
assert y.shape[1] == 128
47+
48+
49+
def test_densenet_share_weights():
50+
layer: DenseNetStack = DenseNetStack(share_weights=True)
51+
layer_two: DenseNetStack = DenseNetStack(share_weights=True)
52+
assert layer.get_weights() == layer_two.get_weights()
53+
model = tf.keras.Sequential([layer, layer_two])
54+
x = np.zeros(shape=(10, 1))
55+
y = model(x)
56+

0 commit comments

Comments
 (0)