Commit b750bfc
feat(a3c): add self-attention over sequences

- common math manipulations require knowledge of terms w.r.t. other terms, and self-attention provides that
- we can also inspect the attention weights after training to determine whether the attention is finding useful relationships
- add dropout to the LSTM layers
1 parent 8f2008a commit b750bfc
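
A toy illustration (not part of this commit) of the "terms w.r.t. other terms" idea: self-attention scores every token against every other token, so related terms such as the two x terms in "4x + 2x" can find each other. The hidden states and dot-product scoring below are made up for illustration only.

import numpy as np

# toy hidden states for the tokens in "4x + 2x"
tokens = ["4x", "+", "2x"]
h = np.array([[1.0, 0.2], [0.0, 1.0], [0.9, 0.3]])
# score every token against every other token, then normalize per row
scores = h @ h.T
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
print(weights.round(2))  # the "4x" row puts most weight on itself and "2x"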

File tree

4 files changed: +21 -39 lines changed

libraries/mathy_python/mathy/agents/attention.py
libraries/mathy_python/mathy/agents/base_config.py
libraries/mathy_python/mathy/agents/embedding.py
libraries/mathy_python/requirements.txt
libraries/mathy_python/mathy/agents/attention.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+"""TensorFlow compatible import of SeqSelfAttention"""
+import os
+
+# required to tell library to use TF backend
+os.environ["TF_KERAS"] = "1"
+from keras_self_attention import SeqSelfAttention
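
A minimal sketch (assuming keras-self-attention is installed) of what the imported layer returns when return_attention=True; the shapes in the comments are the point:

import os
os.environ["TF_KERAS"] = "1"  # must be set before importing the library
import tensorflow as tf
from keras_self_attention import SeqSelfAttention

layer = SeqSelfAttention(attention_activation="sigmoid", return_attention=True)
x = tf.random.normal((2, 16, 64))  # (batch, timesteps, features)
output, attention = layer(x)
print(output.shape)     # (2, 16, 64): attention-weighted sequence
print(attention.shape)  # (2, 16, 16): per-pair weights, kept for inspection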

libraries/mathy_python/mathy/agents/base_config.py

Lines changed: 4 additions & 1 deletion

@@ -24,6 +24,9 @@ class Config:
     # (n - 1) previous timesteps
     prediction_window_size: int = 16
 
+    # Dropout to apply to LSTMs
+    dropout: float = 0.2
+
     # Whether to use the LSTM or non-recurrent architecture
     use_lstm: bool = True
     units: int = 64
@@ -38,7 +41,7 @@ class Config:
     verbose: bool = False
     # Initial learning rate that decays over time.
     lr_initial: float = 0.01
-    lr_decay_steps: int = 1000
+    lr_decay_steps: int = 100
     lr_decay_rate: float = 0.96
     lr_decay_staircase: bool = True
     max_eps: int = 15000
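
A hedged sketch of how the decay fields above plausibly map to a schedule; tf.keras's ExponentialDecay is an assumption here, not necessarily what mathy wires up internally:

import tensorflow as tf

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01,  # lr_initial
    decay_steps=100,             # lr_decay_steps (reduced from 1000)
    decay_rate=0.96,             # lr_decay_rate
    staircase=True,              # lr_decay_staircase: drop in discrete steps
)
# staircase form: lr(step) = 0.01 * 0.96 ** (step // 100),
# so moving lr_decay_steps from 1000 to 100 decays the rate 10x more often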

libraries/mathy_python/mathy/agents/embedding.py

Lines changed: 9 additions & 38 deletions

@@ -1,5 +1,4 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
-
 import tensorflow as tf
 
 from mathy.agents.base_config import BaseConfig
@@ -11,41 +10,7 @@
 )
 from mathy.agents.densenet import DenseNetStack
 
-
-class BahdanauAttention(tf.keras.layers.Layer):
-    """Bahdanau Attention from:
-    https://www.tensorflow.org/tutorials/text/nmt_with_attention
-
-    Used rather than the built-in tf.keras Attention because we want
-    to get the weights for visualization.
-    """
-
-    def __init__(self, units):
-        super(BahdanauAttention, self).__init__()
-        self.W1 = tf.keras.layers.Dense(units)
-        self.W2 = tf.keras.layers.Dense(units)
-        self.V = tf.keras.layers.Dense(1)
-
-    def call(self, query, values):
-        # query hidden state shape == (batch_size, hidden size)
-        # query_with_time_axis shape == (batch_size, 1, hidden size)
-        # values shape == (batch_size, max_len, hidden size)
-        # we are doing this to broadcast addition along the time axis to calculate the score
-        query_with_time_axis = tf.expand_dims(query, 1)
-
-        # score shape == (batch_size, max_length, 1)
-        # we get 1 at the last axis because we are applying score to self.V
-        # the shape of the tensor before applying self.V is (batch_size, max_length, units)
-        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
-
-        # attention_weights shape == (batch_size, max_length, 1)
-        attention_weights = tf.nn.softmax(score, axis=1)
-
-        # context_vector shape after sum == (batch_size, hidden_size)
-        context_vector = attention_weights * values
-        context_vector = tf.reduce_sum(context_vector, axis=1)
-
-        return context_vector, attention_weights
+from .attention import SeqSelfAttention
 
 
 class MathyEmbedding(tf.keras.Model):
@@ -95,6 +60,7 @@ def __init__(self, config: BaseConfig, **kwargs):
                 name="nodes_lstm",
                 time_major=False,
                 return_sequences=True,
+                dropout=self.config.dropout,
             ),
             merge_mode="sum",
         )
@@ -103,8 +69,13 @@ def __init__(self, config: BaseConfig, **kwargs):
             name="time_lstm",
             time_major=True,
             return_sequences=True,
+            dropout=self.config.dropout,
+        )
+        self.lstm_attention = SeqSelfAttention(
+            attention_activation="sigmoid",
+            name="self_attention",
+            return_attention=True,
        )
-        self.lstm_attention = BahdanauAttention(self.config.lstm_units)
 
     def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
         output = tf.concat(
@@ -124,5 +95,5 @@ def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
         output = self.time_lstm_norm(output)
         output = self.lstm_nodes(output)
         output = self.nodes_lstm_norm(output)
-        output, attention = self.lstm_attention(output, output)
+        output, attention = self.lstm_attention(output)
         return self.out_dense_norm(self.output_dense(output)), attention
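
Since MathyEmbedding.call now returns (output, attention), the weights can be pulled out after training for the inspection the commit message mentions. A hedged sketch; `model` and `features` are hypothetical stand-ins for a trained embedding and a prepared input batch:

# hypothetical: model is a trained MathyEmbedding, features a prepared batch
output, attention = model(features)
weights = attention[0].numpy()  # (timesteps, timesteps) for the first example
for i, row in enumerate(weights):
    print(f"step {i} attends most to step {row.argmax()}")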

libraries/mathy_python/requirements.txt

Lines changed: 2 additions & 0 deletions

@@ -9,3 +9,5 @@ wasabi
 gym<=0.12.5
 tensorflow_probability
 typing-extensions>=3.7.4.1
+
+keras-self-attention
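
The new dependency installs from PyPI:

pip install keras-self-attention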
