Skip to content

Commit a8f0d54

Browse files
feat(embedding): use LSTMs for batch and time axes
- refactor inputs to have a dense transform for the node values coming in - the time LSTM contextualizes the episode timesteps in the batch - the nodes LSTM contextualizes each node of the timesteps - the combination works well for multitask training, and exceeds previous models on acc of the poly task
1 parent 3d2d78b commit a8f0d54

File tree

4 files changed

+27
-30
lines changed

4 files changed

+27
-30
lines changed

libraries/mathy_python/mathy/agents/embedding.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,16 @@ def __init__(self, config: BaseConfig, **kwargs):
1919
self.token_embedding = tf.keras.layers.Embedding(
2020
input_dim=MathTypeKeysMax,
2121
output_dim=self.config.embedding_units,
22-
name="nodes_embedding",
22+
name="nodes_input",
2323
mask_zero=True,
2424
)
25-
2625
# +1 for the value
2726
# +1 for the time
2827
# +2 for the problem type hashes
29-
self.concat_size = (
30-
4
31-
if self.config.use_env_features
32-
else 1
33-
if self.config.use_node_values
34-
else 0
35-
)
36-
if self.config.use_env_features:
37-
self.time_dense = tf.keras.layers.Dense(
38-
self.config.units, name="time_input"
39-
)
40-
self.type_dense = tf.keras.layers.Dense(
41-
self.config.units, name="type_input"
42-
)
28+
self.concat_size = 4
29+
self.values_dense = tf.keras.layers.Dense(self.config.units, name="values_input")
30+
self.time_dense = tf.keras.layers.Dense(self.config.units, name="time_input")
31+
self.type_dense = tf.keras.layers.Dense(self.config.units, name="type_input")
4332
self.in_dense = tf.keras.layers.Dense(
4433
self.config.units,
4534
# In transform gets the embeddings concatenated with the
@@ -61,28 +50,39 @@ def __init__(self, config: BaseConfig, **kwargs):
6150
NormalizeClass = tf.keras.layers.BatchNormalization
6251
self.out_dense_norm = NormalizeClass(name="out_dense_norm")
6352
self.nodes_lstm_norm = NormalizeClass(name="lstm_nodes_norm")
53+
self.time_lstm_norm = NormalizeClass(name="time_lstm_norm")
6454
self.lstm_nodes = tf.keras.layers.LSTM(
6555
self.config.lstm_units,
6656
name="nodes_lstm",
6757
time_major=False,
6858
return_sequences=True,
6959
return_state=True,
7060
)
61+
self.lstm_time = tf.keras.layers.LSTM(
62+
self.config.lstm_units,
63+
name="time_lstm",
64+
time_major=True,
65+
return_sequences=True,
66+
)
7167
self.lstm_attention = tf.keras.layers.Attention()
7268

7369
def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
7470
output = tf.concat(
7571
[
7672
self.token_embedding(features[ObservationFeatureIndices.nodes]),
77-
tf.expand_dims(features[ObservationFeatureIndices.values], axis=-1),
73+
self.values_dense(
74+
tf.expand_dims(features[ObservationFeatureIndices.values], axis=-1)
75+
),
7876
self.type_dense(features[ObservationFeatureIndices.type]),
7977
self.time_dense(features[ObservationFeatureIndices.time]),
8078
],
8179
axis=-1,
8280
name="input_vectors",
8381
)
8482
output = self.in_dense(output)
83+
output = self.lstm_time(output)
84+
output = self.time_lstm_norm(output)
8585
output, state_h, state_c = self.lstm_nodes(output)
86-
output = self.lstm_attention([output, state_h])
8786
output = self.nodes_lstm_norm(output)
87+
output = self.lstm_attention([output, state_h])
8888
return self.out_dense_norm(self.output_dense(output))

libraries/mathy_python/mathy/agents/episode_memory.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ def clear(self):
2929
self.rewards = []
3030
self.values = []
3131

32+
def clear_except_window(self, window_size: int):
33+
"""Clear all data except for a window's worth of elements"""
34+
self.observations = self.observations[-window_size:]
35+
self.actions = self.actions[-window_size:]
36+
self.rewards = self.rewards[-window_size:]
37+
self.values = self.values[-window_size:]
38+
3239
def to_window_observation(
3340
self, observation: MathyObservation, window_size: int = 3
3441
) -> MathyWindowObservation:

libraries/mathy_python/mathy/envs/complex_simplify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def get_env_namespace(self) -> str:
2626
def max_moves_fn(
2727
self, problem: MathyEnvProblem, config: MathyEnvProblemArgs
2828
) -> int:
29-
return problem.complexity * 5
29+
return problem.complexity * 8
3030

3131
def problem_fn(self, params: MathyEnvProblemArgs) -> MathyEnvProblem:
3232
"""Given a set of parameters to control term generation, produce

libraries/mathy_python/mathy/envs/poly_simplify.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,7 @@ def get_env_namespace(self) -> str:
3131
def max_moves_fn(
3232
self, problem: MathyEnvProblem, config: MathyEnvProblemArgs
3333
) -> int:
34-
if config.difficulty == MathyEnvDifficulty.easy:
35-
multiplier = 4
36-
elif problem.complexity < 5:
37-
multiplier = 2
38-
elif problem.complexity < 7:
39-
multiplier = 3
40-
elif problem.complexity < 12:
41-
multiplier = 4
42-
else:
43-
multiplier = 3
44-
return problem.complexity * multiplier
34+
return problem.complexity * 6
4535

4636
def transition_fn(
4737
self,

0 commit comments

Comments
 (0)