Commit 11095ab

refactor(model): remove episode-long RNN state tracking
R2D2 and friends use this technique to great effect, but either I don't train long enough or I messed up the implementation. Remove it for now to simplify model usage, and let the LSTM operate only on the current batch of observations (which is a rolling window of timesteps).

BREAKING CHANGE: this removes long-term RNN state tracking across episodes. Tracking the state required a significant amount of code, and it wasn't clear that it made the model substantially better at any given task. The overhead of keeping many hidden states in memory and computing state histories was also considerable on CPU training setups.
1 parent 28afa2c commit 11095ab

33 files changed: +63 / −642 lines
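
For context, the change described above amounts to letting the LSTM rebuild its context from a short rolling window on every forward pass, instead of carrying (h, c) across an episode. A minimal sketch of that idea, assuming a hypothetical window_tokens batch and relying on the Keras default of a zero initial state when none is passed:

import tensorflow as tf

lstm = tf.keras.layers.LSTM(128, return_sequences=True, return_state=True)

def embed_window(window_tokens: tf.Tensor) -> tf.Tensor:
    """window_tokens: [batch, window_size, features] for the last N timesteps."""
    # With no initial_state argument, Keras starts from zeros on every call,
    # so nothing is carried over between steps or between episodes.
    output, state_h, state_c = lstm(window_tokens)
    return output  # per-timestep embeddings for the current window only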

libraries/mathy_python/mathy/agents/a3c/agent.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def __init__(self, args: A3CConfig, env_extra: dict = None):
         self.action_size = env.action_space.n
         self.log_dir = os.path.join(self.args.model_dir, "tensorboard")
         self.writer = tf.summary.create_file_writer(self.log_dir)
-        init_window = env.initial_window(self.args.lstm_units)
+        init_window = env.initial_window()
         self.global_model = get_or_create_policy_model(
             args=args, predictions=self.action_size, is_main=True, env=env.mathy
         )

libraries/mathy_python/mathy/agents/a3c/worker.py

Lines changed: 1 addition & 88 deletions
@@ -221,53 +221,16 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
         ep_steps = 1
         time_count = 0
         done = False
-        last_observation: MathyObservation = env.reset(rnn_size=self.args.lstm_units)
+        last_observation: MathyObservation = env.reset()
         last_text = env.state.agent.problem
         last_action = -1
         last_reward = -1

         selector = self.build_episode_selector(env)

-        # Set RNN to 0 state for start of episode
-        selector.model.embedding.reset_rnn_state()
-
-        # Start with the "init" sequence [n] times
-        for i in range(self.args.num_thinking_steps_begin):
-            rnn_state_h = tf.squeeze(selector.model.embedding.state_h.numpy())
-            rnn_state_c = tf.squeeze(selector.model.embedding.state_c.numpy())
-            seq_start = env.state.to_start_observation(rnn_state_h, rnn_state_c)
-            try:
-                window = observations_to_window([seq_start, last_observation])
-                selector.model.call(window.to_inputs())
-            except BaseException as err:
-                print_error(
-                    err, f"Episode begin thinking steps prediction failed.",
-                )
-                continue
-
         while not done and A3CWorker.request_quit is False:
             if self.args.print_training and self.worker_idx == 0:
                 env.render(self.args.print_mode, None)
-            # store rnn state for replay training
-            rnn_state_h = tf.squeeze(selector.model.embedding.state_h.numpy())
-            rnn_state_c = tf.squeeze(selector.model.embedding.state_c.numpy())
-            rnn_history_h = episode_memory.rnn_weighted_history(self.args.lstm_units)[0]
-            last_rnn_state = [rnn_state_h, rnn_state_c]
-
-            # named tuples are read-only, so add rnn state to a new copy
-            last_observation = MathyObservation(
-                nodes=last_observation.nodes,
-                mask=last_observation.mask,
-                values=last_observation.values,
-                type=last_observation.type,
-                time=last_observation.time,
-                rnn_state_h=tf.squeeze(rnn_state_h),
-                rnn_state_c=tf.squeeze(rnn_state_c),
-                rnn_history_h=rnn_history_h,
-            )
-            # before_rnn_state_h = selector.model.embedding.state_h.numpy()
-            # before_rnn_state_c = selector.model.embedding.state_c.numpy()
-
             window = episode_memory.to_window_observation(
                 last_observation, window_size=self.args.prediction_window_size
             )
@@ -277,33 +240,20 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
                     last_window=window,
                     last_action=last_action,
                     last_reward=last_reward,
-                    last_rnn_state=last_rnn_state,
                 )
             except BaseException as err:
                 print_error(err, "failed to select an action during an episode step")
                 continue

             # Take an env step
             observation, reward, done, _ = env.step(action)
-            rnn_state_h = tf.squeeze(selector.model.embedding.state_h.numpy())
-            rnn_state_c = tf.squeeze(selector.model.embedding.state_c.numpy())
-
-            # TODO: make this a unit test, check that EpisodeMemory states are not equal
-            # across time steps.
-            # compare_states_h = tf.math.equal(before_rnn_state_h,rnn_state_h)
-            # compare_states_c = tf.math.equal(before_rnn_state_h,rnn_state_h)
-            # assert before_rnn_state_h != rnn_state_h
-            # assert before_rnn_state_c != rnn_state_c

             observation = MathyObservation(
                 nodes=observation.nodes,
                 mask=observation.mask,
                 values=observation.values,
                 type=observation.type,
                 time=observation.time,
-                rnn_state_h=rnn_state_h,
-                rnn_state_c=rnn_state_c,
-                rnn_history_h=rnn_history_h,
             )

             new_text = env.state.agent.problem
@@ -315,38 +265,12 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
                 observation=last_observation,
                 action=action,
                 reward=reward,
-                grouping_change=grouping_change,
                 value=value,
             )
             if time_count == self.args.update_gradients_every or done:
                 if done and self.args.print_training and self.worker_idx == 0:
                     env.render(self.args.print_mode, None)

-                # TODO: Make this a unit test?
-                # Check that the LSTM h/c states changed over time in the episode.
-                #
-                # NOTE: in practice it seems every once in a while the state doesn't
-                # change, and I suppose this makes sense if the LSTM thought the
-                # existing state was... fine?
-                #
-                # check_rnn = None
-                # for obs in episode_memory.observations:
-                #     if check_rnn is not None:
-                #         h_equal_indices = (
-                #             tf.squeeze(tf.math.equal(obs.rnn_state_h, check_rnn[0]))
-                #             .numpy()
-                #             .tolist()
-                #         )
-                #         c_equal_indices = (
-                #             tf.squeeze(tf.math.equal(obs.rnn_state_c, check_rnn[1]))
-                #             .numpy()
-                #             .tolist()
-                #         )
-                #         assert False in h_equal_indices
-                #         assert False in c_equal_indices
-
-                #     check_rnn = [obs.rnn_state_h, obs.rnn_state_c]
-
                 self.update_global_network(done, observation, episode_memory)
                 self.maybe_write_histograms()
                 time_count = 0
@@ -406,17 +330,6 @@ def maybe_write_histograms(self) -> None:
             tf.summary.histogram(
                 var.name, var, step=self.global_model.optimizer.iterations,
             )
-            # Write out current LSTM hidden/cell states
-            tf.summary.histogram(
-                "memory/lstm_c",
-                self.local_model.embedding.state_c,
-                step=self.global_model.optimizer.iterations,
-            )
-            tf.summary.histogram(
-                "memory/lstm_h",
-                self.local_model.embedding.state_h,
-                step=self.global_model.optimizer.iterations,
-            )

     def update_global_network(
         self, done: bool, observation: MathyObservation, episode_memory: EpisodeMemory,
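
Condensed, the episode loop after this diff reduces to the shape below. This is a sketch assembled from the surviving lines above, not a verbatim excerpt: initialization mirrors the diff, and selector.select is assumed to also take arguments not shown in these hunks.

last_observation = env.reset()
last_action = -1
last_reward = -1
done = False
while not done:
    # A rolling window of recent observations replaces the carried RNN state
    window = episode_memory.to_window_observation(
        last_observation, window_size=args.prediction_window_size
    )
    action, value = selector.select(
        last_window=window, last_action=last_action, last_reward=last_reward
    )
    observation, reward, done, _ = env.step(action)
    episode_memory.store(
        observation=last_observation, action=action, reward=reward, value=value
    )
    last_observation, last_action, last_reward = observation, action, reward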

libraries/mathy_python/mathy/agents/action_selectors.py

Lines changed: 1 addition & 4 deletions
@@ -26,7 +26,6 @@ def select(
         last_window: MathyWindowObservation,
         last_action: int,
         last_reward: float,
-        last_rnn_state: List[float],
     ) -> Tuple[int, float]:
         raise NotImplementedError(self.select)

@@ -46,7 +45,6 @@ def select(
         last_window: MathyWindowObservation,
         last_action: int,
         last_reward: float,
-        last_rnn_state: List[float],
     ) -> Tuple[int, float]:

         probs, value = self.model.predict_next(last_window.to_inputs())
@@ -85,7 +83,6 @@ def select(
         last_window: MathyWindowObservation,
         last_action: int,
         last_reward: float,
-        last_rnn_state: List[float],
     ) -> Tuple[int, float]:
-        probs, value = self.mcts.estimate_policy(last_state, last_rnn_state)
+        probs, value = self.mcts.estimate_policy(last_state)
         return np.argmax(probs), value
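
After this change a selector implementation needs only the observation window, not any RNN state. A sketch of a custom greedy selector against the trimmed signature (the base class name and any parameters above last_window are not visible in these hunks, so this reconstruction is an assumption):

import numpy as np

class GreedyActionSelector(ActionSelector):  # ActionSelector: assumed base class
    def select(
        self,
        last_window: MathyWindowObservation,
        last_action: int,
        last_reward: float,
    ) -> Tuple[int, float]:
        # Predict policy/value from the rolling window alone
        probs, value = self.model.predict_next(last_window.to_inputs())
        return int(np.argmax(probs)), value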

libraries/mathy_python/mathy/agents/embedding.py

Lines changed: 10 additions & 55 deletions
@@ -13,18 +13,9 @@


 class MathyEmbedding(tf.keras.Model):
-    def __init__(
-        self,
-        config: BaseConfig,
-        episode_reset_state_h: Optional[bool] = True,
-        episode_reset_state_c: Optional[bool] = True,
-        **kwargs,
-    ):
+    def __init__(self, config: BaseConfig, **kwargs):
         super(MathyEmbedding, self).__init__(**kwargs)
         self.config = config
-        self.episode_reset_state_h = episode_reset_state_h
-        self.episode_reset_state_c = episode_reset_state_c
-        self.init_rnn_state()
         self.token_embedding = tf.keras.layers.Embedding(
             input_dim=MathTypeKeysMax,
             output_dim=self.config.embedding_units,
@@ -69,24 +60,15 @@ def __init__(
         if self.config.normalization_style == "batch":
             NormalizeClass = tf.keras.layers.BatchNormalization
         self.out_dense_norm = NormalizeClass(name="out_dense_norm")
-        if self.config.use_lstm:
-            self.nodes_lstm_norm = NormalizeClass(name="lstm_nodes_norm")
-            self.lstm_nodes = tf.keras.layers.LSTM(
-                self.config.lstm_units,
-                name="nodes_lstm",
-                time_major=False,
-                return_sequences=True,
-                return_state=True,
-            )
-            self.lstm_attention = tf.keras.layers.Attention()
-        else:
-            self.densenet = DenseNetStack(
-                units=self.config.units,
-                num_layers=6,
-                output_transform=self.output_dense,
-                normalization_style=self.config.normalization_style,
-            )
-            self.dense_attention = tf.keras.layers.Attention()
+        self.nodes_lstm_norm = NormalizeClass(name="lstm_nodes_norm")
+        self.lstm_nodes = tf.keras.layers.LSTM(
+            self.config.lstm_units,
+            name="nodes_lstm",
+            time_major=False,
+            return_sequences=True,
+            return_state=True,
+        )
+        self.lstm_attention = tf.keras.layers.Attention()

     def compute_output_shape(self, input_shapes: List[tf.TensorShape]) -> Any:
         nodes_shape: tf.TensorShape = input_shapes[0]
@@ -98,27 +80,6 @@ def compute_output_shape(self, input_shapes: List[tf.TensorShape]) -> Any:
             )
         )

-    def init_rnn_state(self):
-        """Track RNN states with variables in the graph"""
-        self.state_c = tf.Variable(
-            tf.zeros([1, self.config.lstm_units]),
-            trainable=False,
-            name="embedding/rnn/agent_state_c",
-        )
-        self.state_h = tf.Variable(
-            tf.zeros([1, self.config.lstm_units]),
-            trainable=False,
-            name="embedding/rnn/agent_state_h",
-        )
-
-    def reset_rnn_state(self, force: bool = False):
-        """Zero out the RNN state for a new episode"""
-        if self.episode_reset_state_h or force is True:
-            self.state_h.assign(tf.zeros([1, self.config.lstm_units]))
-
-        if self.episode_reset_state_c or force is True:
-            self.state_c.assign(tf.zeros([1, self.config.lstm_units]))
-
     def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
         nodes = features[ObservationFeatureIndices.nodes]
         values = features[ObservationFeatureIndices.values]
@@ -128,10 +89,6 @@ def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
         batch_size = nodes_shape[0]  # noqa
         sequence_length = nodes_shape[1]

-        in_rnn_state_h = features[ObservationFeatureIndices.rnn_state_h]
-        in_rnn_state_c = features[ObservationFeatureIndices.rnn_state_c]
-        in_rnn_history_h = features[ObservationFeatureIndices.rnn_history_h]
-
         with tf.name_scope("prepare_inputs"):
             values = tf.expand_dims(values, axis=-1, name="values_input")
             query = self.token_embedding(nodes)
@@ -159,8 +116,6 @@ def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
             output, state_h, state_c = self.lstm_nodes(query)
             output = self.lstm_attention([output, state_h])
             output = self.nodes_lstm_norm(output)
-            self.state_h.assign(state_h[-1:])
-            self.state_c.assign(state_c[-1:])
         else:
             # Non-recurrent model
             output = self.densenet(query)
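
With init_rnn_state/reset_rnn_state gone, the embedding is stateless between calls: every forward pass runs the LSTM from a zero initial state over whatever window it is given. A sketch of the intended usage (mirroring example.py below; construction of the observation itself is elided):

model = MathyEmbedding(BaseConfig())
window = observations_to_window([observation])  # rolling window, any length
embeddings = model.call(window.to_inputs())     # no reset_rnn_state() between episodes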

libraries/mathy_python/mathy/agents/episode_memory.py

Lines changed: 0 additions & 33 deletions
@@ -7,34 +7,9 @@
     MathyObservation,
     MathyWindowObservation,
     observations_to_window,
-    rnn_placeholder_state,
 )


-def rnn_weighted_history(
-    observations: List[MathyObservation], rnn_size: int = 128,
-):
-    """Build a historical LSTM state: https://arxiv.org/pdf/1810.04437.pdf
-
-    Take the mean of the previous LSTM states for this episode. """
-
-    import tensorflow as tf
-
-    if len(observations) > 0:
-        in_h = []
-        in_c = []
-        for obs in observations:
-            in_h.append(obs.rnn_state_h)
-            in_c.append(obs.rnn_state_c)
-        # Take the mean of the historical states:
-        memory_context_h = tf.reduce_mean(in_h, axis=0)
-        memory_context_c = tf.reduce_mean(in_c, axis=0)
-    else:
-        memory_context_h = rnn_placeholder_state(rnn_size)
-        memory_context_c = memory_context_h
-    return [memory_context_h, memory_context_c]
-
-
 class EpisodeMemory(object):
     # Observation from the environment
     observations: List[MathyObservation]
@@ -44,8 +19,6 @@ class EpisodeMemory(object):
     rewards: List[float]
     # Estimated value from the model
     values: List[float]
-    # Grouping Control error from the environment
-    grouping_changes: List[float]

     def __init__(self):
         self.clear()
@@ -55,7 +28,6 @@ def clear(self):
         self.actions = []
         self.rewards = []
         self.values = []
-        self.grouping_changes = []

     def to_window_observation(
         self, observation: MathyObservation, window_size: int = 3
@@ -74,14 +46,9 @@ def store(
         observation: MathyObservation,
         action: int,
         reward: float,
-        grouping_change: float,
         value: float,
     ):
         self.observations.append(observation)
         self.actions.append(action)
         self.rewards.append(reward)
         self.values.append(value)
-        self.grouping_changes.append(grouping_change)
-
-    def rnn_weighted_history(self, rnn_size):
-        return rnn_weighted_history(self.observations, rnn_size)
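
With grouping_changes and rnn_weighted_history removed, the memory's lifecycle is just store/window/clear. A short sketch of the trimmed API (argument values are placeholders; the exact windowing behavior of to_window_observation is inferred from its call sites, not shown here):

memory = EpisodeMemory()
# per step: build a bounded context window from the current observation
# plus whatever has been stored so far this episode
window = memory.to_window_observation(observation, window_size=3)
# per step: record the transition; note there is no grouping_change argument anymore
memory.store(observation=observation, action=1, reward=0.0, value=0.5)
# per episode: drop everything and start fresh
memory.clear()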

libraries/mathy_python/mathy/agents/example.py

Lines changed: 1 addition & 1 deletion
@@ -10,5 +10,5 @@ def example() -> MathyWindowObservation:
     passed forward through a Mathy model. """
     env = PolySimplify()
     state = env.get_initial_state()[0]
-    observation = env.state_to_observation(state, rnn_size=BaseConfig().lstm_units)
+    observation = env.state_to_observation(state)
     return observations_to_window([observation])
