import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class BahdanauAttention(nn.Module):
    """
    Bahdanau (additive) attention (https://arxiv.org/abs/1409.0473).

    Implementation is very similar to tf.contrib.seq2seq.BahdanauAttention.

    The score between a query and a key is
    ``linear_att . tanh(linear_q(query) + linear_k(key) [+ normalize_bias])``;
    scores are soft-maxed over the key axis and used to average the
    (unprojected) keys.
    """

    def __init__(self, query_size, key_size, num_units, normalize=False,
                 batch_first=False, init_weight=0.1):
        """
        Constructor for the BahdanauAttention.

        :param query_size: feature dimension for query
        :param key_size: feature dimension for keys
        :param num_units: internal feature dimension
        :param normalize: whether to normalize energy term
        :param batch_first: if True batch size is the 1st dimension, if False
            the sequence is first and batch size is second
        :param init_weight: range for uniform initializer used to initialize
            Linear key and query transform layers and linear_att vector
        """
        super().__init__()

        self.normalize = normalize
        self.batch_first = batch_first
        self.num_units = num_units

        # Bias-free projections of queries and keys into the shared
        # num_units-dimensional scoring space.
        self.linear_q = nn.Linear(query_size, num_units, bias=False)
        self.linear_k = nn.Linear(key_size, num_units, bias=False)
        nn.init.uniform_(self.linear_q.weight, -init_weight, init_weight)
        nn.init.uniform_(self.linear_k.weight, -init_weight, init_weight)

        # Scoring vector; allocated here, given values in reset_parameters().
        self.linear_att = Parameter(torch.empty(num_units))

        # Optional padding mask, installed by set_mask().
        self.mask = None

        if self.normalize:
            self.normalize_scalar = Parameter(torch.empty(1))
            self.normalize_bias = Parameter(torch.empty(num_units))
        else:
            # Register the names as None so the attributes always exist.
            self.register_parameter('normalize_scalar', None)
            self.register_parameter('normalize_bias', None)

        self.reset_parameters(init_weight)

    def reset_parameters(self, init_weight):
        """
        Sets initial random values for trainable parameters.

        :param init_weight: linear_att is drawn uniformly from
            [-init_weight, init_weight]
        """
        stdv = 1. / math.sqrt(self.num_units)
        nn.init.uniform_(self.linear_att, -init_weight, init_weight)

        if self.normalize:
            nn.init.constant_(self.normalize_scalar, stdv)
            nn.init.zeros_(self.normalize_bias)

    def set_mask(self, context_len, context):
        """
        sets self.mask which is applied before softmax
        ones for inactive context fields, zeros for active context fields

        :param context_len: b
        :param context: if batch_first: (b x t_k x n) else: (t_k x b x n)

        self.mask: (b x t_k)
        """
        if self.batch_first:
            max_len = context.size(1)
        else:
            max_len = context.size(0)

        # Position p of sequence i is padding when p >= context_len[i].
        indices = torch.arange(0, max_len, dtype=torch.int64,
                               device=context.device)
        self.mask = indices >= (context_len.unsqueeze(1))

    def calc_score(self, att_query, att_keys):
        """
        Calculate Bahdanau score

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n

        returns: b x t_q x t_k scores
        """
        # Broadcasting (b, t_q, 1, n) + (b, 1, t_k, n) yields every
        # query/key pairing as a (b, t_q, t_k, n) tensor.
        sum_qk = att_query.unsqueeze(2) + att_keys.unsqueeze(1)

        if self.normalize:
            sum_qk = sum_qk + self.normalize_bias
            # Unit-length scoring vector rescaled by a learned scalar.
            linear_att = self.linear_att / self.linear_att.norm()
            linear_att = linear_att * self.normalize_scalar
        else:
            linear_att = self.linear_att

        # Contracting the feature axis leaves (b, t_q, t_k).
        out = torch.tanh(sum_qk).matmul(linear_att)
        return out

    def forward(self, query, keys):
        """
        :param query: if batch_first: (b x t_q x n) else: (t_q x b x n);
            a 2-D (b x n) tensor is treated as a single query per batch item
        :param keys: if batch_first: (b x t_k x n) else (t_k x b x n)

        :returns: (context, scores_normalized)
        context: if batch_first: (b x t_q x n) else (t_q x b x n),
            where n is the feature size of keys
        scores_normalized: if batch_first (b x t_q x t_k) else (t_q x b x t_k)

        For a 2-D query the t_q axis is squeezed out of both results.
        """
        # first dim of keys and query has to be 'batch', it's needed for bmm
        if not self.batch_first:
            keys = keys.transpose(0, 1)
            if query.dim() == 3:
                query = query.transpose(0, 1)

        if query.dim() == 2:
            single_query = True
            query = query.unsqueeze(1)
        else:
            single_query = False

        b = query.size(0)
        t_k = keys.size(1)
        t_q = query.size(1)

        # FC layers to transform query and key
        processed_query = self.linear_q(query)
        processed_key = self.linear_k(keys)

        # scores: (b x t_q x t_k)
        scores = self.calc_score(processed_query, processed_key)

        if self.mask is not None:
            mask = self.mask.unsqueeze(1).expand(b, t_q, t_k)
            # I can't use -INF because of overflow check in pytorch;
            # -65504.0 is the most negative finite float16 value.
            scores.masked_fill_(mask, -65504.0)

        # Normalize the scores, softmax over t_k
        scores_normalized = F.softmax(scores, dim=-1)

        # Calculate the weighted average of the attention inputs according to
        # the scores
        # context: (b x t_q x n)
        context = torch.bmm(scores_normalized, keys)

        if single_query:
            context = context.squeeze(1)
            scores_normalized = scores_normalized.squeeze(1)
        elif not self.batch_first:
            context = context.transpose(0, 1)
            scores_normalized = scores_normalized.transpose(0, 1)

        return context, scores_normalized
