Skip to content

Commit 0241070

Browse files
fix(model): remove second LSTM from recurrent model
- Previously I thought that having one LSTM to process the nodes in a sequence, and another to process the timesteps in a sequence, was a good idea. It may still be a good idea, but I had to do weird RNN state tiling to make the initial_state shapes work out. - Remove the time LSTM and use only one LSTM with no tiled initial_state. This appears to train at least as well on poly-easy.
1 parent 02b11ee commit 0241070

File tree

2 files changed

+11
-35
lines changed

2 files changed

+11
-35
lines changed

libraries/mathy_python/mathy/agents/a3c/worker.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import gym
1111
import numpy as np
1212
import tensorflow as tf
13-
from memory_profiler import profile
1413
from wasabi import msg
1514

1615
from ...envs.gym.mathy_gym_env import MathyGymEnv
@@ -82,7 +81,10 @@ def __init__(
8281
self.reset_episode_loss()
8382
self.last_model_write = -1
8483
self.last_histogram_write = -1
85-
msg.good(f"Worker {worker_idx} started.")
84+
display_e = self.greedy_epsilon
85+
if self.worker_idx == 0 and self.args.main_worker_use_epsilon is False:
86+
display_e = 0.0
87+
msg.good(f"Worker {worker_idx} started. (e={display_e:.3f})")
8688

8789
@property
8890
def tb_prefix(self) -> str:
@@ -526,7 +528,7 @@ def compute_policy_value_loss(
526528
else:
527529
# Predict the reward using the local network
528530
_, values, _ = self.local_model.predict(
529-
observations_to_window([observation]).to_inputs()
531+
[observations_to_window([observation]).to_inputs()]
530532
)
531533
# Select the last timestep
532534
values = values[-1]

libraries/mathy_python/mathy/agents/embedding.py

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,13 @@ def __init__(
7272
NormalizeClass = tf.keras.layers.BatchNormalization
7373
self.out_dense_norm = NormalizeClass(name="out_dense_norm")
7474
if self.config.use_lstm:
75-
self.time_lstm_norm = NormalizeClass(name="lstm_time_norm")
7675
self.nodes_lstm_norm = NormalizeClass(name="lstm_nodes_norm")
77-
self.time_lstm = tf.keras.layers.LSTM(
76+
self.lstm = tf.keras.layers.LSTM(
7877
self.config.lstm_units,
79-
name="timestep_lstm",
80-
return_sequences=True,
81-
time_major=True,
82-
return_state=True,
83-
)
84-
self.nodes_lstm = tf.keras.layers.LSTM(
85-
self.config.lstm_units,
86-
name="nodes_lstm",
78+
name="lstm",
8779
return_sequences=True,
8880
time_major=False,
89-
return_state=False,
81+
return_state=True,
9082
)
9183
else:
9284
self.densenet = DenseNetStack(
@@ -143,8 +135,6 @@ def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
143135
nodes_shape = tf.shape(features[ObservationFeatureIndices.nodes])
144136
batch_size = nodes_shape[0] # noqa
145137
sequence_length = nodes_shape[1]
146-
# batch_size = features[ObservationFeatureIndices.nodes].shape[0] # noqa
147-
# sequence_length = features[ObservationFeatureIndices.nodes].shape[1] # noqa
148138

149139
in_rnn_state_h = features[ObservationFeatureIndices.rnn_state_h]
150140
in_rnn_state_c = features[ObservationFeatureIndices.rnn_state_c]
@@ -178,27 +168,11 @@ def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
178168
query = self.in_dense(query)
179169

180170
if self.config.use_lstm:
181-
with tf.name_scope("prepare_initial_states"):
182-
in_time_h = in_rnn_state_h[-1:, :]
183-
in_time_c = in_rnn_state_c[-1:, :]
184-
time_initial_h = tf.tile(
185-
in_time_h, [sequence_length, 1], name="time_hidden",
186-
)
187-
time_initial_c = tf.tile(
188-
in_time_c, [sequence_length, 1], name="time_cell",
189-
)
190-
191171
with tf.name_scope("rnn"):
192-
query = self.nodes_lstm(query)
193-
query = self.nodes_lstm_norm(query)
194-
query, state_h, state_c = self.time_lstm(
195-
query, initial_state=[time_initial_h, time_initial_c]
172+
query, state_h, state_c = self.lstm(
173+
query, initial_state=[in_rnn_state_h, in_rnn_state_c]
196174
)
197-
query = self.time_lstm_norm(query)
198-
# historical_state_h = tf.squeeze(
199-
# tf.concat(in_rnn_history_h[0], axis=0, name="average_history_hidden"),
200-
# axis=1,
201-
# )
175+
query = self.nodes_lstm_norm(query)
202176

203177
self.state_h.assign(state_h[-1:])
204178
self.state_c.assign(state_c[-1:])

0 commit comments

Comments
 (0)