Skip to content

Commit 8b97e9d

Browse files
feat(embedding): add optional non-recurrent DenseNet+Attention architecture --use-lstm=False
- while trying to train the Zero agent on Colab it is _really slow_, and I thought it might be the LSTM mixing with hundreds of calls per episode step. - Use a DenseNet stack with a Luong style attention layer at the end when use-lstm is false
1 parent 8534239 commit 8b97e9d

File tree

8 files changed

+230
-71
lines changed

8 files changed

+230
-71
lines changed

libraries/mathy_python/mathy/agents/a3c/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class A3CConfig(BaseConfig):
1111
# syncing the latest model from the learner process.
1212
update_gradients_every: int = 64
1313

14+
normalization_style: str = "layer"
15+
1416
# How many times to think about the initial state before acting.
1517
# (intuition) is that the LSTM updates the state each time it processes
1618
# the init sequence meaning that it gets more time to fine-tune the hidden

libraries/mathy_python/mathy/agents/base_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ class Config:
1818
url: str = about.__uri__
1919
mathy_version: str = f">={about.__version__},<1.0.0"
2020

21+
# One of "batch" or "layer"
22+
normalization_style = "batch"
23+
24+
# Whether to use the LSTM or non-recurrent architecture
25+
use_lstm: bool = True
2126
units: int = 64
2227
embedding_units: int = 128
2328
lstm_units: int = 128
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from typing import List, Optional
2+
3+
import tensorflow as tf
4+
5+
6+
class DenseNetBlock(tf.keras.layers.Layer):
    """DenseNet-like block layer that concatenates its input with the outputs
    of all previous tower layers, then applies Dense -> activation ->
    normalization.

    Args:
        units: width of the internal Dense projection.
        num_layers: currently unused; kept for interface compatibility.
        use_shared: when True, reuse this block's variables across calls via a
            tf.compat.v1 variable scope.
        activation: activation applied after the Dense projection.
        normalization: one of "batch" or "layer".

    Raises:
        ValueError: if ``normalization`` is not "batch" or "layer".
    """

    def __init__(
        self,
        units=128,
        num_layers=2,
        use_shared=False,
        activation="relu",
        normalization="batch",
        **kwargs,
    ):
        super(DenseNetBlock, self).__init__(**kwargs)
        self.use_shared = use_shared
        if normalization == "batch":
            self.normalize = tf.keras.layers.BatchNormalization()
        elif normalization == "layer":
            self.normalize = tf.keras.layers.LayerNormalization()
        else:
            # fix: original message had a typo ("unkknown")
            raise ValueError(f"unknown layer normalization style: {normalization}")
        self.dense = tf.keras.layers.Dense(
            units, use_bias=False, name=f"{self.name}_dense"
        )
        self.concat = tf.keras.layers.Concatenate(name=f"{self.name}_concat")
        # fix: the original assigned self.activate twice — a hard-coded "relu"
        # layer followed by the configured activation. Only the configured
        # activation is kept; the first assignment was dead code.
        self.activate = tf.keras.layers.Activation(activation, name="activate")

    def do_op(self, input_tensor: tf.Tensor, previous_tensors: List[tf.Tensor]):
        """Concatenate the input with all previous tower outputs (when any
        exist) and apply the dense/activation/normalization transform."""
        inputs = previous_tensors[:]
        inputs.append(input_tensor)
        # concatenate the input and previous layer inputs
        if len(inputs) > 1:
            input_tensor = self.concat(inputs)
        return self.normalize(self.activate(self.dense(input_tensor)))

    def call(self, input_tensor: tf.Tensor, previous_tensors: List[tf.Tensor]):
        """Apply the block, optionally inside a reuse=True variable scope so
        that weights are shared across invocations."""
        if self.use_shared:
            with tf.compat.v1.variable_scope(
                self.name, reuse=True, auxiliary_name_scope=False
            ):
                return self.do_op(input_tensor, previous_tensors)
        else:
            with tf.compat.v1.variable_scope(self.name):
                return self.do_op(input_tensor, previous_tensors)
51+
52+
53+
class DenseNetStack(tf.keras.layers.Layer):
    """DenseNet-like stack of residually connected layers where the input and all
    of the previous outputs are provided as the input to each layer:

    https://arxiv.org/pdf/1608.06993.pdf

    From the paper: "Crucially, in contrast to ResNets, we never combine features
    through summation before they are passed into a layer; instead, we combine features
    by concatenating them"

    Args:
        units: width of the first DenseNetBlock; each subsequent block is
            scaled down by ``layer_scaling_factor``.
        num_layers: total number of DenseNetBlock layers in the stack.
        layer_scaling_factor: multiplier applied to the unit count of each
            successive block (floored to int).
        share_weights: passed to each block's ``use_shared``.
        activation: optional activation applied to the final stack output.
        output_transform: optional layer applied to the concatenated output
            before the final activation.
        normalization_style: one of "batch" or "layer", passed to each block.
    """

    def __init__(
        self,
        units=64,
        num_layers=4,
        layer_scaling_factor=0.75,
        share_weights=False,
        activation="relu",
        output_transform: Optional[tf.keras.layers.Layer] = None,
        normalization_style: str = "layer",
        **kwargs,
    ):
        # fix: call super().__init__ FIRST — tf.keras.layers.Layer requires it
        # before sub-layer attributes are assigned so that variable/layer
        # tracking is initialized (assigning layers earlier raises in TF2).
        super(DenseNetStack, self).__init__(**kwargs)
        self.units = units
        self.layer_scaling_factor = layer_scaling_factor
        self.num_layers = num_layers
        self.output_transform = output_transform
        if activation is not None:
            self.activate = tf.keras.layers.Activation(activation)
        else:
            self.activate = None
        self.concat = tf.keras.layers.Concatenate()
        # First block uses the full unit count; later blocks shrink by the
        # scaling factor each step.
        block_units = int(self.units)
        self.dense_stack = [
            DenseNetBlock(
                block_units,
                name="densenet_0",
                use_shared=share_weights,
                normalization=normalization_style,
            )
        ]
        for i in range(self.num_layers - 1):
            block_units = int(block_units * self.layer_scaling_factor)
            self.dense_stack.append(
                DenseNetBlock(
                    block_units,
                    name=f"densenet_{i + 1}",
                    use_shared=share_weights,
                    normalization=normalization_style,
                )
            )

    def call(self, input_tensor: tf.Tensor):
        """Run the input through every block, feeding each block the list of
        all previous block inputs, then concatenate with the root input."""
        stack_inputs: List[tf.Tensor] = []
        root_input = input_tensor
        # Iterate the stack and call each layer, also inputting a list
        # of all the previous stack layers outputs.
        for layer in self.dense_stack:
            # Save reference to input
            prev_tensor = input_tensor
            # Apply densenet block
            input_tensor = layer(input_tensor, stack_inputs)
            # Append the current input to the list of previous input tensors
            stack_inputs.append(prev_tensor)
        # Residual-style skip: concatenate final block output with root input
        output = self.concat([input_tensor, root_input])
        if self.output_transform is not None:
            output = self.output_transform(output)
        return self.activate(output) if self.activate is not None else output

libraries/mathy_python/mathy/agents/embedding.py

Lines changed: 78 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
MathyWindowObservation,
1010
ObservationFeatureIndices,
1111
)
12+
from mathy.agents.densenet import DenseNetStack
1213

1314

1415
class MathyEmbedding(tf.keras.Model):
@@ -65,25 +66,36 @@ def __init__(
6566
activation="relu",
6667
kernel_initializer="he_normal",
6768
)
68-
self.out_dense_norm = tf.keras.layers.LayerNormalization(name="out_dense_norm")
69-
self.time_lstm_norm = tf.keras.layers.LayerNormalization(name="lstm_time_norm")
70-
self.nodes_lstm_norm = tf.keras.layers.LayerNormalization(
71-
name="lstm_nodes_norm"
72-
)
73-
self.time_lstm = tf.keras.layers.LSTM(
74-
self.config.lstm_units,
75-
name="timestep_lstm",
76-
return_sequences=True,
77-
time_major=True,
78-
return_state=True,
79-
)
80-
self.nodes_lstm = tf.keras.layers.LSTM(
81-
self.config.lstm_units,
82-
name="nodes_lstm",
83-
return_sequences=True,
84-
time_major=False,
85-
return_state=False,
86-
)
69+
70+
NormalizeClass = tf.keras.layers.LayerNormalization
71+
if self.config.normalization_style == "batch":
72+
NormalizeClass = tf.keras.layers.BatchNormalization
73+
self.out_dense_norm = NormalizeClass(name="out_dense_norm")
74+
if self.config.use_lstm:
75+
self.time_lstm_norm = NormalizeClass(name="lstm_time_norm")
76+
self.nodes_lstm_norm = NormalizeClass(name="lstm_nodes_norm")
77+
self.time_lstm = tf.keras.layers.LSTM(
78+
self.config.lstm_units,
79+
name="timestep_lstm",
80+
return_sequences=True,
81+
time_major=True,
82+
return_state=True,
83+
)
84+
self.nodes_lstm = tf.keras.layers.LSTM(
85+
self.config.lstm_units,
86+
name="nodes_lstm",
87+
return_sequences=True,
88+
time_major=False,
89+
return_state=False,
90+
)
91+
else:
92+
self.densenet = DenseNetStack(
93+
units=self.config.units,
94+
num_layers=6,
95+
output_transform=self.output_dense,
96+
normalization_style=self.config.normalization_style,
97+
)
98+
self.dense_attention = tf.keras.layers.Attention()
8799

88100
def compute_output_shape(self, input_shapes: List[tf.TensorShape]) -> Any:
89101
nodes_shape: tf.TensorShape = input_shapes[0]
@@ -170,42 +182,52 @@ def call(self, features: MathyInputsType) -> tf.Tensor:
170182
# Input dense transforms
171183
query = self.in_dense(query)
172184

173-
with tf.name_scope("prepare_initial_states"):
174-
in_time_h = in_rnn_state_h[-1:, :]
175-
in_time_c = in_rnn_state_c[-1:, :]
176-
time_initial_h = tf.tile(
177-
in_time_h, [sequence_length, 1], name="time_hidden",
178-
)
179-
time_initial_c = tf.tile(in_time_c, [sequence_length, 1], name="time_cell",)
185+
if self.config.use_lstm:
186+
with tf.name_scope("prepare_initial_states"):
187+
in_time_h = in_rnn_state_h[-1:, :]
188+
in_time_c = in_rnn_state_c[-1:, :]
189+
time_initial_h = tf.tile(
190+
in_time_h, [sequence_length, 1], name="time_hidden",
191+
)
192+
time_initial_c = tf.tile(
193+
in_time_c, [sequence_length, 1], name="time_cell",
194+
)
180195

181-
with tf.name_scope("rnn"):
182-
query = self.nodes_lstm(query)
183-
query = self.nodes_lstm_norm(query)
184-
query, state_h, state_c = self.time_lstm(
185-
query, initial_state=[time_initial_h, time_initial_c]
186-
)
187-
query = self.time_lstm_norm(query)
188-
# historical_state_h = tf.squeeze(
189-
# tf.concat(in_rnn_history_h[0], axis=0, name="average_history_hidden"),
190-
# axis=1,
191-
# )
192-
193-
self.state_h.assign(state_h[-1:])
194-
self.state_c.assign(state_c[-1:])
195-
196-
# Concatenate the RNN output state with our historical RNN state
197-
# See: https://arxiv.org/pdf/1810.04437.pdf
198-
with tf.name_scope("combine_outputs"):
199-
rnn_state_with_history = tf.concat(
200-
[state_h[-1:], in_rnn_history_h[-1:]], axis=-1,
201-
)
202-
self.state_h_with_history.assign(rnn_state_with_history)
203-
lstm_context = tf.tile(
204-
tf.expand_dims(rnn_state_with_history, axis=0),
205-
[batch_size, sequence_length, 1],
206-
)
207-
# Combine the output LSTM states with the historical average LSTM states
208-
# and concatenate it with the query
209-
output = tf.concat([query, lstm_context], axis=-1, name="combined_outputs")
196+
with tf.name_scope("rnn"):
197+
query = self.nodes_lstm(query)
198+
query = self.nodes_lstm_norm(query)
199+
query, state_h, state_c = self.time_lstm(
200+
query, initial_state=[time_initial_h, time_initial_c]
201+
)
202+
query = self.time_lstm_norm(query)
203+
# historical_state_h = tf.squeeze(
204+
# tf.concat(in_rnn_history_h[0], axis=0, name="average_history_hidden"),
205+
# axis=1,
206+
# )
207+
208+
self.state_h.assign(state_h[-1:])
209+
self.state_c.assign(state_c[-1:])
210+
211+
# Concatenate the RNN output state with our historical RNN state
212+
# See: https://arxiv.org/pdf/1810.04437.pdf
213+
with tf.name_scope("combine_outputs"):
214+
rnn_state_with_history = tf.concat(
215+
[state_h[-1:], in_rnn_history_h[-1:]], axis=-1,
216+
)
217+
self.state_h_with_history.assign(rnn_state_with_history)
218+
lstm_context = tf.tile(
219+
tf.expand_dims(rnn_state_with_history, axis=0),
220+
[batch_size, sequence_length, 1],
221+
)
222+
# Combine the output LSTM states with the historical average LSTM states
223+
# and concatenate it with the query
224+
output = tf.concat(
225+
[query, lstm_context], axis=-1, name="combined_outputs"
226+
)
227+
output = self.output_dense(output)
228+
else:
229+
# Non-recurrent model
230+
output = self.densenet(query)
231+
output = self.dense_attention([output, output])
210232

211-
return self.out_dense_norm(self.output_dense(output))
233+
return self.out_dense_norm(output)

libraries/mathy_python/mathy/agents/mcts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ class MCTS:
1818
env: MathyEnv
1919
# cpuct is a hyperparameter controlling the degree of exploration
2020
# (1.0 in Suragnair experiments.)
21-
cpuct: int
21+
cpuct: float
2222
num_mcts_sims: int
2323
# Set epsilon = 0 to disable dirichlet noise in root node.
2424
# e.g. for ExaminationRunner competitions
25-
epsilon: int
25+
epsilon: float
2626
dir_alpha: float
2727

2828
def __init__(
@@ -128,7 +128,7 @@ def search(self, env_state: MathyEnvState, rnn_state: List[Any], isRootNode=Fals
128128
rnn_history_h=rnn_state[0],
129129
)
130130
observations = observations_to_window([obs]).to_inputs()
131-
out_policy, state_v = self.model.predict_next(observations)
131+
out_policy, state_v = self.model.predict_next(observations, use_graph=False)
132132
out_rnn_state = [
133133
tf.squeeze(self.model.embedding.state_h).numpy(),
134134
tf.squeeze(self.model.embedding.state_c).numpy(),

libraries/mathy_python/mathy/agents/policy_value_model.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,7 @@ def call(
9292
mask_logits = self.apply_pi_mask(logits, features_window)
9393
mask_result = logits if not apply_mask else mask_logits
9494
if call_print is True:
95-
print(
96-
"call took : {0:03f} for batch of {1:03}".format(
97-
time.time() - start, batch_size
98-
)
99-
)
95+
print("call took : {0:03f}".format(time.time() - start))
10096
return logits, values, mask_result
10197

10298
def apply_pi_mask(
@@ -126,10 +122,15 @@ def call_graph(
126122
"""Autograph optimized function"""
127123
return self.call(inputs)
128124

129-
def predict_next(self, inputs: MathyInputsType) -> Tuple[tf.Tensor, tf.Tensor]:
125+
def predict_next(
126+
self, inputs: MathyInputsType, use_graph=False
127+
) -> Tuple[tf.Tensor, tf.Tensor]:
130128
"""Predict one probability distribution and value for the
131-
given sequence of inputs"""
132-
logits, values, masked = self.call(inputs)
129+
given sequence of inputs """
130+
if use_graph:
131+
logits, values, masked = self.call_graph(inputs)
132+
else:
133+
logits, values, masked = self.call(inputs)
133134
# take the last timestep
134135
masked = masked[-1][:]
135136
flat_logits = tf.reshape(tf.squeeze(masked), [-1])

libraries/mathy_python/mathy/agents/zero/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ class SelfPlayConfig(BaseConfig):
88
temperature_threshold: float = 0.5
99
self_play_problems: int = 64
1010
training_iterations: int = 100
11-
cpuct: float = 1.0
12-
13-
11+
cpuct: float = 1.0
12+
normalization_style: str = "batch"
1413
# When profile is true and workers == 1 the main thread will output worker_0.profile on exit
1514
profile: bool = False

0 commit comments

Comments
 (0)