<a href="https://colab.research.google.com/github/huynhthaihoa/Dive-into-Deep-Learning-Practice/blob/master/Attention_mechanism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Attention Mechanism

In [1]:
!pip install mxnet



In [2]:
import math
from mxnet import nd
from mxnet.gluon import nn

In [3]:
# Save to the d2l package.
def masked_softmax(X, valid_length):
  """Softmax over the last axis of X that ignores positions past valid_length.

  X is a 3-D tensor. valid_length is either None (plain softmax), a 1-D
  tensor holding one length per entry of X's first axis (shared by all rows
  of that entry), or a 2-D tensor holding one length per row. Positions at
  or beyond the valid length end up with weight 0.
  """
  # Nothing to mask: defer to the tensor's own softmax.
  if valid_length is None:
    return X.softmax()

  original_shape = X.shape
  # Bring the lengths to exactly one entry per row of the flattened 2-D view.
  if valid_length.ndim == 1:
    per_row_length = valid_length.repeat(original_shape[1], axis=0)
  else:
    per_row_length = valid_length.reshape((-1,))

  rows = X.reshape((-1, original_shape[-1]))
  # Overwrite masked entries with a large negative number, whose exp is 0.
  masked_rows = nd.SequenceMask(rows, per_row_length, use_sequence_length=True,
                                value=-1e6, axis=1)
  return masked_rows.softmax().reshape(original_shape)

In [4]:
# Sanity check: a random (2, 2, 4) batch with valid lengths 2 and 3. In the
# output, entries at or beyond each example's valid length get zero weight
# and the remaining entries of every row sum to 1.
masked_softmax(nd.random.uniform(shape=(2,2,4)), nd.array([2,3]))


[[[0.488994   0.511006   0.         0.        ]
  [0.4365484  0.56345165 0.         0.        ]]

 [[0.288171   0.3519408  0.3598882  0.        ]
  [0.29034296 0.25239873 0.45725837 0.        ]]]
<NDArray 2x2x4 @cpu(0)>

In [5]:
# batch_dot multiplies the i-th matrix of the first argument with the i-th
# matrix of the second: (2, 1, 3) x (2, 3, 2) -> (2, 1, 2).
nd.batch_dot(nd.ones((2,1,3)), nd.ones((2,3,2)))


[[[3. 3.]]

 [[3. 3.]]]
<NDArray 2x1x2 @cpu(0)>

## 1.1. Dot Product Attention 

In [6]:
# Save to the d2l package.
class DotProductAttention(nn.Block):
  """Scaled dot-product attention for batches of queries and key-value pairs.

  Scores are the query-key dot products divided by sqrt(d), where d is the
  size of the query's last axis. They are turned into weights by a masked
  softmax, passed through dropout, and used to average the values.
  """

  def __init__(self, dropout, **kwargs):
    """Create the block; `dropout` is the rate handed to nn.Dropout."""
    super().__init__(**kwargs)
    self.dropout = nn.Dropout(dropout)

  def forward(self, query, key, value, valid_length=None):
    """Return the attention-weighted combination of `value`.

    @query: (batch_size, #queries, d)
    @key: (batch_size, #kv_pairs, d)
    @value: (batch_size, #kv_pairs, value_dim)
    @valid_length: either (batch_size, ) or (batch_size, xx); None disables
      masking.
    """
    scale = math.sqrt(query.shape[-1])
    # transpose_b=True swaps key's last two axes, giving one score per
    # (query, key) pair.
    similarity = nd.batch_dot(query, key, transpose_b=True) / scale
    weights = masked_softmax(similarity, valid_length)
    # Dropout on the attention weights acts as a regularizer.
    return nd.batch_dot(self.dropout(weights), value)

In [7]:
# Demo: 2 examples, 1 query each, 10 key-value pairs per example.
atten = DotProductAttention(dropout=0.5)
atten.initialize()
# All keys are identical, so every unmasked position gets the same score.
keys = nd.ones((2,10,2))
# Row i of `values` is [4i, 4i+1, 4i+2, 4i+3]; both examples share the rows.
values = nd.arange(40).reshape((1,10,4)).repeat(2,axis=0)
# Valid lengths 2 and 6: the displayed result equals the mean of the first 2
# (resp. first 6) value rows.
# NOTE(review): the output shows no dropout effect despite dropout=0.5 --
# presumably dropout is inactive outside training mode; confirm.
atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6]))


[[[ 2.        3.        4.        5.      ]]

 [[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>

## 1.2. Multilayer Perceptron Attention

In [8]:
# Save to the d2l package.
class MLPAttention(nn.Block):
  """Additive (multilayer perceptron) attention.

  The score of a query q against a key k is v^T tanh(W_q q + W_k k). Scores
  are normalized with a masked softmax, passed through dropout, and used to
  average the values.

  tanh is applied to the *sum* of the two projections. Applying it to each
  projection separately would make the score a(q) + b(k); the query term is
  then constant across keys and cancels in the softmax, so the weights would
  not depend on the query at all.

  NOTE: W_q now projects the query and W_k the key, matching their names.
  Parameters saved by a variant that wired them the other way round are not
  interchangeable when query and key feature sizes differ.
  """

  def __init__(self, units, dropout, **kwargs):
    """
    @units: size of the hidden projection shared by queries and keys
    @dropout: rate handed to nn.Dropout for the attention weights
    """
    super(MLPAttention, self).__init__(**kwargs)
    # Use flatten=False to keep query's and key's 3-D shapes.
    # No activation on the projections: tanh is applied to their sum in
    # forward().
    self.W_k = nn.Dense(units, use_bias=False, flatten=False)
    self.W_q = nn.Dense(units, use_bias=False, flatten=False)
    self.v = nn.Dense(1, use_bias=False, flatten=False)
    self.dropout = nn.Dropout(dropout)

  def forward(self, query, key, value, valid_length=None):
    """Return the attention-weighted combination of `value`.

    @query: (batch_size, #queries, query_dim)
    @key: (batch_size, #kv_pairs, key_dim)
    @value: (batch_size, #kv_pairs, value_dim)
    @valid_length: either (batch_size, ) or (batch_size, xx); None disables
      masking.
    """
    query, key = self.W_q(query), self.W_k(key)
    # expand query to (batch_size, #querys, 1, units), and key to
    # (batch_size, 1, #kv_pairs, units). Then add them with broadcast and
    # apply the nonlinearity to the sum.
    features = nd.tanh(query.expand_dims(axis=2) + key.expand_dims(axis=1))
    # v maps every feature vector to a scalar score; drop the size-1 axis.
    scores = self.v(features).squeeze(axis=-1)
    attention_weights = self.dropout(masked_softmax(scores, valid_length))
    return nd.batch_dot(attention_weights, value)

In [9]:
# Same query and valid lengths as the dot-product demo. `keys` and `values`
# are the arrays created in that earlier cell, so it must be run first.
atten = MLPAttention(units=8, dropout=0.1)
atten.initialize()
# The keys are all identical, so every valid position gets the same weight
# and the displayed result matches the dot-product demo's output.
atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6]))


[[[ 2.        3.        4.        5.      ]]

 [[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>

# 2. Sequence to Sequence with Attention Mechanism

In [22]:
!pip install d2l==0.11.2

Collecting d2l==0.11.2
  Downloading https://files.pythonhosted.org/packages/76/a4/8744a4f55dd50752a503848d605c86ee49728f8311fa3eee6478ff0a6bec/d2l-0.11.2-py3-none-any.whl
Installing collected packages: d2l
  Found existing installation: d2l 0.15.0
    Uninstalling d2l-0.15.0:
      Successfully uninstalled d2l-0.15.0
Successfully installed d2l-0.11.2


In [10]:
import d2l
from mxnet import nd
from mxnet.gluon import rnn, nn

## 2.1. Decoder

In [11]:
class Seq2SeqAttentionDecoder(d2l.Decoder):
  """LSTM decoder that attends over the encoder outputs at every time step.

  At each step the last layer's hidden state is used as the attention query
  against the encoder outputs; the resulting context is concatenated with
  the embedded input token and fed to the LSTM.
  """

  def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
    """
    @vocab_size: size of the target vocabulary (embedding input, dense output)
    @embed_size: embedding dimension
    @num_hiddens: LSTM hidden size, also used as the attention unit count
    @num_layers: number of LSTM layers
    @dropout: dropout rate shared by the attention cell and the LSTM
    """
    super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
    self.attention_cell = d2l.MLPAttention(num_hiddens, dropout)
    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.rnn = rnn.LSTM(num_hiddens, num_layers, dropout=dropout)
    self.dense = nn.Dense(vocab_size, flatten=False)

  def init_state(self, enc_outputs, enc_valid_len, *args):
    """Build the decoder state from the encoder's (outputs, hidden_state).

    Returns (encoder outputs, hidden state, encoder valid lengths).
    """
    outputs, hidden_state = enc_outputs
    # Transpose outputs to (batch_size, seq_len, hidden_size)
    return (outputs.swapaxes(0,1), hidden_state, enc_valid_len)

  def forward(self, X, state):
    """Decode X step by step.

    @X: token ids; embedded and then transposed so the loop runs over time
      steps.
    @state: (enc_outputs, hidden_state, enc_valid_len) as built by
      init_state or returned by a previous call.

    Returns the per-step vocabulary scores with the first two axes swapped
    back, and the updated state list.
    """
    enc_outputs, hidden_state, enc_valid_len = state
    X = self.embedding(X).swapaxes(0, 1)
    outputs = []
    for x in X:
      # query shape: (batch_size, 1, hidden_size)
      query = hidden_state[0][-1].expand_dims(axis=1)
      # context has same shape as query
      context = self.attention_cell(query, enc_outputs, enc_outputs, enc_valid_len)
      # concatenate on the feature dimension
      x = nd.concat(context, x.expand_dims(axis=1), dim=-1)
      # reshape x to (1, batch_size, embed_size+hidden_size)
      out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state)
      # Collect the output of *every* time step; appending after the loop
      # would keep only the last step.
      outputs.append(out)
    outputs = self.dense(nd.concat(*outputs, dim=0))
    return outputs.swapaxes(0, 1), [enc_outputs, hidden_state, enc_valid_len]

In [19]:
from mxnet import npx
# NOTE(review): set_np() presumably switches MXNet to its NumPy-compatible
# array interface for the rest of the session, while the other cells in this
# notebook are written against the legacy `nd` API -- confirm the two array
# types are not being mixed downstream.
npx.set_np()

In [24]:
import numpy as np

In [29]:
# Smoke test: push a (4, 7) batch of zero token ids through the encoder and
# the attention decoder, then inspect the output and state shapes.
# NOTE(review): the recorded run of this cell ended in a TypeError. The
# encoder is fed an np-style array (as_np_ndarray) while the decoder receives
# the legacy nd array together with the encoder's state; mixing the two array
# types is a likely cause -- confirm and pick a single array API.
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.initialize()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8,
num_hiddens=16, num_layers=2)
decoder.initialize()
X = nd.zeros((4, 7))
# No valid lengths are supplied (None), so attention is not masked.
state = decoder.init_state(encoder(X.as_np_ndarray()), None)
out, state = decoder(X, state)
out.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

TypeError: ignored