attention.py
""" Implementations of attention layers.
"""
import tensorflow as tf
from seq2seq import GraphModule
class AttentionLayer(GraphModule):
"""
Attention layer according to https://arxiv.org/abs/1409.0473.
Args:
num_units: Number of units used in the attention layer
name: Name for this graph module
"""
def __init__(self, num_units, name="attention"):
super(AttentionLayer, self).__init__(name=name)
self.num_units = num_units
  def _build(self, state, inputs):
    """Computes attention scores and outputs.

    Args:
      state: The state based on which to calculate attention scores.
        In seq2seq this is typically the current state of the decoder.
        A tensor of shape `[B, ...]`
      inputs: The elements to compute attention *over*. In seq2seq this is
        typically the sequence of encoder outputs.
        A tensor of shape `[B, T, input_dim]`

    Returns:
      A tuple `(scores, context)`.
      `scores` is a vector of length `T` where each element is the
      normalized "score" of the corresponding `inputs` element.
      `context` is the final attention layer output corresponding to
      the weighted inputs, a tensor of shape `[B, input_dim]`.
    """
    batch_size, inputs_timesteps, _ = tf.unpack(tf.shape(inputs))
    inputs_dim = inputs.get_shape().as_list()[-1]

    # Fully connected layers to transform both inputs and state
    # into a tensor with `num_units` units
    inputs_att = tf.contrib.layers.fully_connected(
        inputs=inputs,
        num_outputs=self.num_units,
        activation_fn=None,
        scope="inputs_att")
    state_att = tf.contrib.layers.fully_connected(
        inputs=state,
        num_outputs=self.num_units,
        activation_fn=None,
        scope="state_att")
    # Take the dot product of state for each time step in inputs
    # Result: A tensor of shape [B, T]
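    # Shapes: `inputs_att_flat` is [B * T, num_units] (one row per timestep);
    # `state_att_flat` is [B * T, num_units] (the projected state tiled T
    # times so its rows line up with the flattened inputs). The batched
    # matmul of [B * T, 1, num_units] x [B * T, num_units, 1] yields one dot
    # product per (batch, timestep) pair, which is reshaped back to [B, T].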
    inputs_att_flat = tf.reshape(inputs_att, [-1, self.num_units])
    state_att_flat = tf.reshape(
        tf.tile(state_att, [1, inputs_timesteps]),
        [inputs_timesteps * batch_size, self.num_units])
    scores = tf.batch_matmul(
        tf.expand_dims(inputs_att_flat, 1), tf.expand_dims(state_att_flat, 2))
    scores = tf.reshape(scores, [batch_size, inputs_timesteps], name="scores")
    # Normalize the scores
    scores_normalized = tf.nn.softmax(scores, name="scores_normalized")

    # Calculate the weighted average of the attention inputs
    # according to the scores
    context = tf.expand_dims(scores_normalized, 2) * inputs
    context = tf.reduce_sum(context, 1, name="context")
    context.set_shape([None, inputs_dim])

    return (scores_normalized, context)
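
# A rough usage sketch (illustrative only): it assumes GraphModule makes the
# layer callable by wrapping `_build`, and that `encoder_outputs` and
# `decoder_state` come from elsewhere in the model; the placeholder shapes
# below are examples, not requirements.
#
#   encoder_outputs = tf.placeholder(tf.float32, [None, None, 256])  # [B, T, input_dim]
#   decoder_state = tf.placeholder(tf.float32, [None, 128])          # [B, state_dim]
#
#   attention = AttentionLayer(num_units=128)
#   scores, context = attention(decoder_state, encoder_outputs)
#   # scores:  [B, T]          -- softmax-normalized attention weights
#   # context: [B, input_dim]  -- weighted sum of the encoder outputs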