```diff
@@ -1,5 +1,4 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
-
 import tensorflow as tf
 
 from mathy.agents.base_config import BaseConfig
@@ -11,41 +10,7 @@
 )
 from mathy.agents.densenet import DenseNetStack
 
-
-class BahdanauAttention(tf.keras.layers.Layer):
-    """Bahdanau Attention from:
-    https://www.tensorflow.org/tutorials/text/nmt_with_attention
-
-    Used rather than the built-in tf.keras Attention because we want
-    to get the weights for visualization.
-    """
-
-    def __init__(self, units):
-        super(BahdanauAttention, self).__init__()
-        self.W1 = tf.keras.layers.Dense(units)
-        self.W2 = tf.keras.layers.Dense(units)
-        self.V = tf.keras.layers.Dense(1)
-
-    def call(self, query, values):
-        # query hidden state shape == (batch_size, hidden size)
-        # query_with_time_axis shape == (batch_size, 1, hidden size)
-        # values shape == (batch_size, max_len, hidden size)
-        # we are doing this to broadcast addition along the time axis to calculate the score
-        query_with_time_axis = tf.expand_dims(query, 1)
-
-        # score shape == (batch_size, max_length, 1)
-        # we get 1 at the last axis because we are applying score to self.V
-        # the shape of the tensor before applying self.V is (batch_size, max_length, units)
-        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
-
-        # attention_weights shape == (batch_size, max_length, 1)
-        attention_weights = tf.nn.softmax(score, axis=1)
-
-        # context_vector shape after sum == (batch_size, hidden_size)
-        context_vector = attention_weights * values
-        context_vector = tf.reduce_sum(context_vector, axis=1)
-
-        return context_vector, attention_weights
+from .attention import SeqSelfAttention
 
 
 class MathyEmbedding(tf.keras.Model):
```
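The removed layer was standard additive (Bahdanau) attention, kept as custom code only so the weights could be pulled out for visualization. A minimal standalone sketch of the score computation it performed (hypothetical sizes, not part of the mathy codebase) confirms the shapes its comments describe:

```python
import tensorflow as tf

# Standalone re-creation of the removed BahdanauAttention.call math,
# using made-up sizes for illustration only.
batch, max_len, hidden, units = 2, 7, 32, 16
query = tf.random.normal((batch, hidden))            # one hidden state
values = tf.random.normal((batch, max_len, hidden))  # the full sequence

W1 = tf.keras.layers.Dense(units)
W2 = tf.keras.layers.Dense(units)
V = tf.keras.layers.Dense(1)

# Broadcast the query along the time axis, score every position, then
# take the attention-weighted sum of the values as the context vector.
score = V(tf.nn.tanh(W1(tf.expand_dims(query, 1)) + W2(values)))
weights = tf.nn.softmax(score, axis=1)               # (batch, max_len, 1)
context = tf.reduce_sum(weights * values, axis=1)    # (batch, hidden)
assert weights.shape == (batch, max_len, 1)
assert context.shape == (batch, hidden)
```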
```diff
@@ -95,6 +60,7 @@ def __init__(self, config: BaseConfig, **kwargs):
                 name="nodes_lstm",
                 time_major=False,
                 return_sequences=True,
+                dropout=self.config.dropout,
             ),
             merge_mode="sum",
         )
```
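This hunk only threads the configured dropout into the bidirectional nodes LSTM. A minimal sketch of the resulting layer, with made-up numbers standing in for the BaseConfig fields, shows that merge_mode="sum" keeps the output width at lstm_units rather than doubling it:

```python
import tensorflow as tf

# Hypothetical stand-ins for BaseConfig.lstm_units / BaseConfig.dropout:
lstm_units, dropout = 128, 0.2
nodes_lstm = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(
        lstm_units,
        name="nodes_lstm",
        time_major=False,
        return_sequences=True,
        dropout=dropout,  # drops a fraction of step inputs during training
    ),
    merge_mode="sum",  # sum forward/backward outputs instead of concatenating
)
# (batch=2, time=7, features=64) -> (2, 7, lstm_units)
out = nodes_lstm(tf.random.normal((2, 7, 64)))
assert out.shape == (2, 7, lstm_units)
```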
```diff
@@ -103,8 +69,13 @@ def __init__(self, config: BaseConfig, **kwargs):
             name="time_lstm",
             time_major=True,
             return_sequences=True,
+            dropout=self.config.dropout,
+        )
+        self.lstm_attention = SeqSelfAttention(
+            attention_activation="sigmoid",
+            name="self_attention",
+            return_attention=True,
         )
-        self.lstm_attention = BahdanauAttention(self.config.lstm_units)
 
     def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
         output = tf.concat(
```
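SeqSelfAttention comes from the local .attention module; its arguments (attention_activation, return_attention) match the API of the keras-self-attention package, so it is presumably a vendored copy of that layer. The sketch below is a simplified stand-in, not the actual implementation: additive scoring between every pair of positions in one sequence, which is why the call site now passes a single argument instead of a (query, values) pair:

```python
import tensorflow as tf

class AdditiveSelfAttention(tf.keras.layers.Layer):
    """Simplified stand-in for SeqSelfAttention: each position attends to
    every position of the same sequence with additive (Bahdanau-style)
    scoring."""

    def __init__(self, units: int = 32, **kwargs):
        super().__init__(**kwargs)
        self.Wq = tf.keras.layers.Dense(units)
        self.Wk = tf.keras.layers.Dense(units)
        self.v = tf.keras.layers.Dense(1)

    def call(self, x):  # x: (batch, len, hidden)
        q = tf.expand_dims(self.Wq(x), 2)           # (batch, len, 1, units)
        k = tf.expand_dims(self.Wk(x), 1)           # (batch, 1, len, units)
        e = tf.squeeze(self.v(tf.tanh(q + k)), -1)  # (batch, len, len)
        # attention_activation="sigmoid" gates each pairwise score; the
        # rows are then normalized so each position's weights sum to one.
        a = tf.sigmoid(e)
        a = a / (tf.reduce_sum(a, axis=-1, keepdims=True) + 1e-9)
        # return_attention=True -> (weighted output, attention weights)
        return tf.matmul(a, x), a

out, attn = AdditiveSelfAttention(name="self_attention")(
    tf.random.normal((2, 7, 16))
)
assert out.shape == (2, 7, 16) and attn.shape == (2, 7, 7)
```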
```diff
@@ -124,5 +95,5 @@ def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
         output = self.time_lstm_norm(output)
         output = self.lstm_nodes(output)
         output = self.nodes_lstm_norm(output)
-        output, attention = self.lstm_attention(output, output)
+        output, attention = self.lstm_attention(output)
         return self.out_dense_norm(self.output_dense(output)), attention
```
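With return_attention=True the embedding still returns the attention weights alongside the output sequence, preserving the original motivation of visualizing them. A hypothetical plotting snippet, with random data standing in for a real attention matrix:

```python
import numpy as np
import matplotlib.pyplot as plt

# `attention` stands in for the second value returned by
# MathyEmbedding.call; random data is used here as a placeholder.
attention = np.random.rand(7, 7)  # (sequence_len, sequence_len)
plt.imshow(attention, cmap="viridis")
plt.xlabel("attended position")
plt.ylabel("query position")
plt.colorbar(label="attention weight")
plt.title("self_attention weights")
plt.show()
```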