Skip to content

Commit ad23139

Browse files
feat(embedding): use bilstm for node sequences
1 parent f32600e commit ad23139

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

libraries/mathy_python/mathy/agents/embedding.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -89,12 +89,14 @@ def __init__(self, config: BaseConfig, **kwargs):
89 89
self.out_dense_norm = NormalizeClass(name="out_dense_norm")
90 90
self.nodes_lstm_norm = NormalizeClass(name="lstm_nodes_norm")
91 91
self.time_lstm_norm = NormalizeClass(name="time_lstm_norm")
92-
self.lstm_nodes = tf.keras.layers.LSTM(
93-
self.config.lstm_units,
94-
name="nodes_lstm",
95-
time_major=False,
96-
return_sequences=True,
97-
return_state=True,
92+
self.lstm_nodes = tf.keras.layers.Bidirectional(
93+
tf.keras.layers.LSTM(
94+
self.config.lstm_units,
95+
name="nodes_lstm",
96+
time_major=False,
97+
return_sequences=True,
98+
),
99+
merge_mode="sum",
98 100
)
99 101
self.lstm_time = tf.keras.layers.LSTM(
100 102
self.config.lstm_units,
@@ -120,7 +122,7 @@ def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
120 122
output = self.in_dense(output)
121 123
output = self.lstm_time(output)
122 124
output = self.time_lstm_norm(output)
123-
output, state_h, state_c = self.lstm_nodes(output)
125+
output = self.lstm_nodes(output)
124 126
output = self.nodes_lstm_norm(output)
125 127
output, attention = self.lstm_attention(output, output)
126 128
return self.out_dense_norm(self.output_dense(output)), attention

0 commit comments

Comments
 (0)