Skip to content

Commit daba776

Browse files
feat(a3c): add bahdanau attention layer
- the built-in Keras Attention layer does not expose its attention weights, so this adds a custom layer that returns them; adapted from: https://www.tensorflow.org/tutorials/text/nmt_with_attention
1 parent f5740ad commit daba776

File tree

3 files changed

+46
-7
lines changed

3 files changed

+46
-7
lines changed

libraries/mathy_python/mathy/agents/base_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ class Config:
3636
init_model_from: Optional[str] = None
3737
train: bool = False
3838
verbose: bool = False
39-
lr: float = 2e-4
39+
# Initial learning rate that decays over time.
40+
lr: float = 0.01
4041
max_eps: int = 15000
4142
# How often to write histograms to tensorboard (in training steps)
4243
summary_interval: int = 100

libraries/mathy_python/mathy/agents/embedding.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,42 @@
1212
from mathy.agents.densenet import DenseNetStack
1313

1414

15+
class BahdanauAttention(tf.keras.layers.Layer):
    """Additive (Bahdanau-style) attention layer.

    Adapted from:
    https://www.tensorflow.org/tutorials/text/nmt_with_attention

    Implemented by hand instead of using the built-in
    ``tf.keras.layers.Attention`` because the built-in layer does not
    expose its attention weights, which we want for visualization.
    """

    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        # Dense projections that score each timestep of `values`
        # against the `query` vector.
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        # Collapses the `units` dimension down to one scalar score
        # per timestep.
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        """Attend over `values` using `query`.

        Args:
            query: hidden state, shape ``(batch_size, hidden_size)``.
            values: sequence features, shape
                ``(batch_size, max_len, hidden_size)``.

        Returns:
            A ``(context_vector, attention_weights)`` tuple with shapes
            ``(batch_size, hidden_size)`` and ``(batch_size, max_len, 1)``.
        """
        # Insert a time axis so the query broadcasts across every
        # timestep of `values` during the additive score below.
        expanded_query = tf.expand_dims(query, 1)
        # Additive score: (batch_size, max_len, 1). The last axis is 1
        # because self.V projects the `units` dimension to a scalar.
        scores = self.V(tf.nn.tanh(self.W1(expanded_query) + self.W2(values)))
        # Normalize the scores over the time axis.
        weights = tf.nn.softmax(scores, axis=1)
        # Weighted sum of `values` over time -> (batch_size, hidden_size).
        context = tf.reduce_sum(weights * values, axis=1)
        return context, weights
49+
50+
1551
class MathyEmbedding(tf.keras.Model):
1652
def __init__(self, config: BaseConfig, **kwargs):
1753
super(MathyEmbedding, self).__init__(**kwargs)
@@ -26,7 +62,9 @@ def __init__(self, config: BaseConfig, **kwargs):
2662
# +1 for the time - removed for ablation
2763
# +2 for the problem type hashes
2864
self.concat_size = 3
29-
self.values_dense = tf.keras.layers.Dense(self.config.units, name="values_input")
65+
self.values_dense = tf.keras.layers.Dense(
66+
self.config.units, name="values_input"
67+
)
3068
# self.time_dense = tf.keras.layers.Dense(self.config.units, name="time_input")
3169
self.type_dense = tf.keras.layers.Dense(self.config.units, name="type_input")
3270
self.in_dense = tf.keras.layers.Dense(
@@ -64,7 +102,7 @@ def __init__(self, config: BaseConfig, **kwargs):
64102
time_major=True,
65103
return_sequences=True,
66104
)
67-
self.lstm_attention = tf.keras.layers.Attention()
105+
self.lstm_attention = BahdanauAttention(self.config.lstm_units)
68106

69107
def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
70108
output = tf.concat(
@@ -84,5 +122,5 @@ def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
84122
output = self.time_lstm_norm(output)
85123
output, state_h, state_c = self.lstm_nodes(output)
86124
output = self.nodes_lstm_norm(output)
87-
output = self.lstm_attention([output, state_h])
88-
return self.out_dense_norm(self.output_dense(output))
125+
output, attention = self.lstm_attention(output, output)
126+
return self.out_dense_norm(self.output_dense(output)), attention

libraries/mathy_python/mathy/agents/policy_value_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def call(
147147
start = time.time()
148148
inputs = features_window
149149
# Extract features into contextual inputs, sequence inputs.
150-
sequence_inputs = self.embedding(inputs)
150+
sequence_inputs, attention_weights = self.embedding(inputs)
151151
sequence_mean = tf.reduce_mean(sequence_inputs, axis=1)
152152
values = self.value_net(sequence_mean)
153153
reward_logits = self.reward_net(sequence_mean)
@@ -163,7 +163,7 @@ def call(
163163
time.time() - start, batch_size
164164
)
165165
)
166-
return logits, values, mask_logits, reward_logits, grouping
166+
return logits, values, mask_logits, reward_logits, attention_weights
167167

168168
def apply_pi_mask(
169169
self, logits: tf.Tensor, features_window: MathyInputsType,

0 commit comments

Comments (0)