1212from mathy .agents .densenet import DenseNetStack
1313
1414
class BahdanauAttention(tf.keras.layers.Layer):
    """Bahdanau (additive) attention, adapted from:
    https://www.tensorflow.org/tutorials/text/nmt_with_attention

    Used rather than the built-in tf.keras Attention because we want
    to get the attention weights back for visualization.
    """

    def __init__(self, units):
        super().__init__()
        # Learned projections for the query and the values, plus the
        # single-unit scoring head that collapses them to one score per step.
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        """Attend over `values` using `query`.

        query:  (batch_size, hidden_size) decoder/query hidden state
        values: (batch_size, max_len, hidden_size) encoder outputs
        Returns (context_vector, attention_weights) where the context is
        (batch_size, hidden_size) and the weights are (batch_size, max_len, 1).
        """
        # Insert a time axis: (batch, hidden) -> (batch, 1, hidden) so the
        # addition below broadcasts against every timestep of `values`.
        expanded_query = tf.expand_dims(query, 1)

        # Additive score per timestep: tensor is (batch, max_len, units)
        # before self.V, and (batch, max_len, 1) after it.
        score = self.V(tf.nn.tanh(self.W1(expanded_query) + self.W2(values)))

        # Normalize the scores over the time axis.
        attention_weights = tf.nn.softmax(score, axis=1)

        # Weighted sum over time yields the (batch, hidden_size) context.
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)

        return context_vector, attention_weights
49+
50+
1551class MathyEmbedding (tf .keras .Model ):
1652 def __init__ (self , config : BaseConfig , ** kwargs ):
1753 super (MathyEmbedding , self ).__init__ (** kwargs )
@@ -26,7 +62,9 @@ def __init__(self, config: BaseConfig, **kwargs):
2662 # +1 for the time - removed for ablation
2763 # +2 for the problem type hashes
2864 self .concat_size = 3
29- self .values_dense = tf .keras .layers .Dense (self .config .units , name = "values_input" )
65+ self .values_dense = tf .keras .layers .Dense (
66+ self .config .units , name = "values_input"
67+ )
3068 # self.time_dense = tf.keras.layers.Dense(self.config.units, name="time_input")
3169 self .type_dense = tf .keras .layers .Dense (self .config .units , name = "type_input" )
3270 self .in_dense = tf .keras .layers .Dense (
@@ -64,7 +102,7 @@ def __init__(self, config: BaseConfig, **kwargs):
64102 time_major = True ,
65103 return_sequences = True ,
66104 )
67- self .lstm_attention = tf . keras . layers . Attention ( )
105+ self .lstm_attention = BahdanauAttention ( self . config . lstm_units )
68106
69107 def call (self , features : MathyInputsType , train : tf .Tensor = None ) -> tf .Tensor :
70108 output = tf .concat (
@@ -84,5 +122,5 @@ def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
84122 output = self .time_lstm_norm (output )
85123 output , state_h , state_c = self .lstm_nodes (output )
86124 output = self .nodes_lstm_norm (output )
87- output = self .lstm_attention ([ output , state_h ] )
88- return self .out_dense_norm (self .output_dense (output ))
125+ output , attention = self .lstm_attention (output , output )
126+ return self .out_dense_norm (self .output_dense (output )), attention
0 commit comments