|
9 | 9 | MathyWindowObservation, |
10 | 10 | ObservationFeatureIndices, |
11 | 11 | ) |
| 12 | +from mathy.agents.densenet import DenseNetStack |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class MathyEmbedding(tf.keras.Model): |
@@ -65,25 +66,36 @@ def __init__( |
65 | 66 | activation="relu", |
66 | 67 | kernel_initializer="he_normal", |
67 | 68 | ) |
68 | | - self.out_dense_norm = tf.keras.layers.LayerNormalization(name="out_dense_norm") |
69 | | - self.time_lstm_norm = tf.keras.layers.LayerNormalization(name="lstm_time_norm") |
70 | | - self.nodes_lstm_norm = tf.keras.layers.LayerNormalization( |
71 | | - name="lstm_nodes_norm" |
72 | | - ) |
73 | | - self.time_lstm = tf.keras.layers.LSTM( |
74 | | - self.config.lstm_units, |
75 | | - name="timestep_lstm", |
76 | | - return_sequences=True, |
77 | | - time_major=True, |
78 | | - return_state=True, |
79 | | - ) |
80 | | - self.nodes_lstm = tf.keras.layers.LSTM( |
81 | | - self.config.lstm_units, |
82 | | - name="nodes_lstm", |
83 | | - return_sequences=True, |
84 | | - time_major=False, |
85 | | - return_state=False, |
86 | | - ) |
| 69 | + |
| 70 | + NormalizeClass = tf.keras.layers.LayerNormalization |
| 71 | + if self.config.normalization_style == "batch": |
| 72 | + NormalizeClass = tf.keras.layers.BatchNormalization |
| 73 | + self.out_dense_norm = NormalizeClass(name="out_dense_norm") |
| 74 | + if self.config.use_lstm: |
| 75 | + self.time_lstm_norm = NormalizeClass(name="lstm_time_norm") |
| 76 | + self.nodes_lstm_norm = NormalizeClass(name="lstm_nodes_norm") |
| 77 | + self.time_lstm = tf.keras.layers.LSTM( |
| 78 | + self.config.lstm_units, |
| 79 | + name="timestep_lstm", |
| 80 | + return_sequences=True, |
| 81 | + time_major=True, |
| 82 | + return_state=True, |
| 83 | + ) |
| 84 | + self.nodes_lstm = tf.keras.layers.LSTM( |
| 85 | + self.config.lstm_units, |
| 86 | + name="nodes_lstm", |
| 87 | + return_sequences=True, |
| 88 | + time_major=False, |
| 89 | + return_state=False, |
| 90 | + ) |
| 91 | + else: |
| 92 | + self.densenet = DenseNetStack( |
| 93 | + units=self.config.units, |
| 94 | + num_layers=6, |
| 95 | + output_transform=self.output_dense, |
| 96 | + normalization_style=self.config.normalization_style, |
| 97 | + ) |
| 98 | + self.dense_attention = tf.keras.layers.Attention() |
87 | 99 |
|
88 | 100 | def compute_output_shape(self, input_shapes: List[tf.TensorShape]) -> Any: |
89 | 101 | nodes_shape: tf.TensorShape = input_shapes[0] |
@@ -170,42 +182,52 @@ def call(self, features: MathyInputsType) -> tf.Tensor: |
170 | 182 | # Input dense transforms |
171 | 183 | query = self.in_dense(query) |
172 | 184 |
|
173 | | - with tf.name_scope("prepare_initial_states"): |
174 | | - in_time_h = in_rnn_state_h[-1:, :] |
175 | | - in_time_c = in_rnn_state_c[-1:, :] |
176 | | - time_initial_h = tf.tile( |
177 | | - in_time_h, [sequence_length, 1], name="time_hidden", |
178 | | - ) |
179 | | - time_initial_c = tf.tile(in_time_c, [sequence_length, 1], name="time_cell",) |
| 185 | + if self.config.use_lstm: |
| 186 | + with tf.name_scope("prepare_initial_states"): |
| 187 | + in_time_h = in_rnn_state_h[-1:, :] |
| 188 | + in_time_c = in_rnn_state_c[-1:, :] |
| 189 | + time_initial_h = tf.tile( |
| 190 | + in_time_h, [sequence_length, 1], name="time_hidden", |
| 191 | + ) |
| 192 | + time_initial_c = tf.tile( |
| 193 | + in_time_c, [sequence_length, 1], name="time_cell", |
| 194 | + ) |
180 | 195 |
|
181 | | - with tf.name_scope("rnn"): |
182 | | - query = self.nodes_lstm(query) |
183 | | - query = self.nodes_lstm_norm(query) |
184 | | - query, state_h, state_c = self.time_lstm( |
185 | | - query, initial_state=[time_initial_h, time_initial_c] |
186 | | - ) |
187 | | - query = self.time_lstm_norm(query) |
188 | | - # historical_state_h = tf.squeeze( |
189 | | - # tf.concat(in_rnn_history_h[0], axis=0, name="average_history_hidden"), |
190 | | - # axis=1, |
191 | | - # ) |
192 | | - |
193 | | - self.state_h.assign(state_h[-1:]) |
194 | | - self.state_c.assign(state_c[-1:]) |
195 | | - |
196 | | - # Concatenate the RNN output state with our historical RNN state |
197 | | - # See: https://arxiv.org/pdf/1810.04437.pdf |
198 | | - with tf.name_scope("combine_outputs"): |
199 | | - rnn_state_with_history = tf.concat( |
200 | | - [state_h[-1:], in_rnn_history_h[-1:]], axis=-1, |
201 | | - ) |
202 | | - self.state_h_with_history.assign(rnn_state_with_history) |
203 | | - lstm_context = tf.tile( |
204 | | - tf.expand_dims(rnn_state_with_history, axis=0), |
205 | | - [batch_size, sequence_length, 1], |
206 | | - ) |
207 | | - # Combine the output LSTM states with the historical average LSTM states |
208 | | - # and concatenate it with the query |
209 | | - output = tf.concat([query, lstm_context], axis=-1, name="combined_outputs") |
| 196 | + with tf.name_scope("rnn"): |
| 197 | + query = self.nodes_lstm(query) |
| 198 | + query = self.nodes_lstm_norm(query) |
| 199 | + query, state_h, state_c = self.time_lstm( |
| 200 | + query, initial_state=[time_initial_h, time_initial_c] |
| 201 | + ) |
| 202 | + query = self.time_lstm_norm(query) |
| 203 | + # historical_state_h = tf.squeeze( |
| 204 | + # tf.concat(in_rnn_history_h[0], axis=0, name="average_history_hidden"), |
| 205 | + # axis=1, |
| 206 | + # ) |
| 207 | + |
| 208 | + self.state_h.assign(state_h[-1:]) |
| 209 | + self.state_c.assign(state_c[-1:]) |
| 210 | + |
| 211 | + # Concatenate the RNN output state with our historical RNN state |
| 212 | + # See: https://arxiv.org/pdf/1810.04437.pdf |
| 213 | + with tf.name_scope("combine_outputs"): |
| 214 | + rnn_state_with_history = tf.concat( |
| 215 | + [state_h[-1:], in_rnn_history_h[-1:]], axis=-1, |
| 216 | + ) |
| 217 | + self.state_h_with_history.assign(rnn_state_with_history) |
| 218 | + lstm_context = tf.tile( |
| 219 | + tf.expand_dims(rnn_state_with_history, axis=0), |
| 220 | + [batch_size, sequence_length, 1], |
| 221 | + ) |
| 222 | + # Combine the output LSTM states with the historical average LSTM states |
| 223 | + # and concatenate the result with the query |
| 224 | + output = tf.concat( |
| 225 | + [query, lstm_context], axis=-1, name="combined_outputs" |
| 226 | + ) |
| 227 | + output = self.output_dense(output) |
| 228 | + else: |
| 229 | + # Non-recurrent model |
| 230 | + output = self.densenet(query) |
| 231 | + output = self.dense_attention([output, output]) |
210 | 232 |
|
211 | | - return self.out_dense_norm(self.output_dense(output)) |
| 233 | + return self.out_dense_norm(output) |
0 commit comments