In [None]:
!pip install d2l==0.14.2

# **10.1 Attention Cues**

10.1.1 Attention Cues in Biology

10.1.2 Queries, Keys, and Values

10.1.3 Visualization of Attention

In [None]:
import torch
from d2l import torch as d2l

In [None]:
#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """Show a grid of heatmaps, one subplot per matrix.

    Args:
        matrices: tensor whose first two axes select the subplot row and
            column; each `matrices[i][j]` is drawn as one heatmap.
        xlabel: x-axis label, set on the bottom row of subplots only.
        ylabel: y-axis label, set on the first column of subplots only.
        titles: optional sequence holding one title per subplot column.
        figsize: figure size forwarded to `subplots`.
        cmap: colormap forwarded to `imshow`.
    """
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    # `squeeze=False` keeps `axes` two-dimensional even for a 1 x 1 grid, so
    # the nested iteration below always works
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            # Move the data to host memory before the NumPy conversion so
            # that tensors living on an accelerator can be plotted as well
            pcm = ax.imshow(matrix.detach().cpu().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    # A single colorbar (taken from the last image drawn) serves all subplots
    fig.colorbar(pcm, ax=axes, shrink=0.6)

In [None]:
# Toy attention weights: the weight is one only where query and key indices
# match. The two leading singleton axes give a 1 x 1 grid of heatmaps.
attention_weights = torch.eye(10).unsqueeze(0).unsqueeze(0)
show_heatmaps(matrices=attention_weights, xlabel='Keys', ylabel='Queries')

# **10.2 Attention Pooling: Nadaraya-Watson Kernel Regression**

In [None]:
import torch
from torch import nn
from d2l import torch as d2l

10.2.1 Generating the Dataset

In [None]:
# Number of training examples
n_train = 50
# Training inputs: uniform samples scaled to [0, 5), kept in ascending order
x_train = torch.sort(torch.rand(n_train) * 5).values

In [None]:
def f(x):
    """Ground-truth function of the toy regression task: 2 * sin(x) + x ** 0.8."""
    sine_term = torch.sin(x) * 2
    power_term = torch.pow(x, 0.8)
    return sine_term + power_term
  
# Training outputs: ground truth plus Gaussian noise (mean 0.0, std 0.5)
noise = torch.normal(0.0, 0.5, (n_train,))
y_train = f(x_train) + noise
# Testing inputs on a regular grid, with their noise-free ground-truth outputs
x_test = torch.arange(0, 5, 0.1)
y_truth = f(x_test)
# Number of testing examples
n_test = x_test.shape[0]
n_test

In [None]:
def plot_kernel_reg(y_hat):
    """Plot the ground truth and the predictions `y_hat` over `x_test`, then
    overlay the training examples as semi-transparent dots."""
    curves = [y_truth, y_hat]
    d2l.plot(x_test, curves, 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)

10.2.2 Average Pooling

In [None]:
# Average pooling ignores the inputs: every testing example gets the mean of
# the training outputs as its prediction
mean_output = y_train.mean()
y_hat = mean_output.repeat_interleave(n_test)
plot_kernel_reg(y_hat)

10.2.3 Nonparametric Attention Pooling

In [None]:
# Shape of `X_repeat`: (`n_test`, `n_train`); row i holds the i-th testing
# input (the query) repeated `n_train` times
X_repeat = x_test.reshape((-1, 1)).repeat(1, n_train)
# `x_train` contains the keys. Gaussian-kernel scores between every query and
# every key
scores = -(X_repeat - x_train) ** 2 / 2
# Shape of `attention_weights`: (`n_test`, `n_train`); each row holds the
# weights distributed over the values (`y_train`) for one query
attention_weights = torch.softmax(scores, dim=1)
# Each element of `y_hat` is the attention-weighted average of the values
y_hat = attention_weights @ y_train
plot_kernel_reg(y_hat)

In [None]:
# Add two leading singleton axes so that the weights form a 1 x 1 heatmap grid
kernel_heatmap = attention_weights.unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(kernel_heatmap,
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

10.2.4 Parametric Attention Pooling

In [None]:
# Batch matrix multiplication: (2, 1, 4) times (2, 4, 6) multiplies the two
# matrices of each batch entry independently
X, Y = torch.ones((2, 1, 4)), torch.ones((2, 4, 6))
torch.bmm(X, Y).shape

In [None]:
# Uniform weights (0.1 each) over ten values per batch entry
weights = torch.full((2, 10), 0.1)
values = torch.arange(20.0).reshape((2, 10))
# Treat each weight row as a 1 x 10 matrix and each value row as a 10 x 1
# matrix; their batched product is the weighted average per batch entry
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))

In [None]:
class NWKernelRegression(nn.Module):
    """Nadaraya-Watson kernel regression with one learnable scalar `w` that
    scales the query-key distance inside the Gaussian kernel."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # `nn.Parameter` tracks gradients by default
        self.w = nn.Parameter(torch.rand((1,)))

    def forward(self, queries, keys, values):
        num_pairs = keys.shape[1]
        # Repeat every query once per key-value pair. Shape of
        # `tiled_queries` and `attention_weights`:
        # (no. of queries, no. of key-value pairs)
        tiled_queries = queries.repeat_interleave(num_pairs).reshape(
            (-1, num_pairs))
        scores = -((tiled_queries - keys) * self.w) ** 2 / 2
        self.attention_weights = nn.functional.softmax(scores, dim=1)
        # Shape of `values`: (no. of queries, no. of key-value pairs). The
        # batched (1 x n) by (n x 1) product yields one weighted average per
        # query
        weighted_avg = torch.bmm(self.attention_weights.unsqueeze(1),
                                 values.unsqueeze(-1))
        return weighted_avg.reshape(-1)

In [None]:
# Shape of `X_tile` and `Y_tile`: (`n_train`, `n_train`); every row is a full
# copy of the training inputs / training outputs
X_tile = x_train.repeat((n_train, 1))
Y_tile = y_train.repeat((n_train, 1))
# Boolean mask that is False on the diagonal, so that row i drops the i-th
# training example
off_diagonal = torch.eye(n_train) == 0
# Shape of `keys` and `values`: (`n_train`, `n_train` - 1)
keys = X_tile[off_diagonal].reshape((n_train, -1))
values = Y_tile[off_diagonal].reshape((n_train, -1))

In [None]:
net = NWKernelRegression()
# Keep the per-example losses; they are summed explicitly below
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])
for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    loss_sum = l.sum()
    loss_sum.backward()
    trainer.step()
    # The reported value is the loss computed before this epoch's update
    loss_value = float(loss_sum)
    print(f'epoch {epoch + 1}, loss {loss_value:.6f}')
    animator.add(epoch + 1, loss_value)

In [None]:
# Shape of `keys` and `values`: (`n_test`, `n_train`); every row is a full
# copy of the training inputs (keys) / training outputs (values), so each
# testing input attends over the whole training set
keys = x_train.repeat(n_test, 1)
values = y_train.repeat(n_test, 1)
predictions = net(x_test, keys, values)
y_hat = predictions.detach().unsqueeze(1)
plot_kernel_reg(y_hat)

In [None]:
# Weights stored by the last forward pass, with two leading singleton axes
# added so that they form a 1 x 1 heatmap grid
learned_heatmap = net.attention_weights.unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(learned_heatmap,
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

# **10.3 Attention Scoring Functions**

In [None]:
import math
import torch
from torch import nn
from d2l import torch as d2l

10.3.1 Masked Softmax Operation

In [None]:
#@save
def masked_softmax(X, valid_lens):
    """Perform softmax over the last axis, masking entries beyond the valid length.

    `X` is a 3D tensor. `valid_lens` is `None` (no masking), a 1D tensor
    (each entry is repeated for every row along axis 1 of `X`) or a 2D tensor
    (flattened to one entry per row of `X`).
    """
    # Guard clause: nothing to mask
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    # Flatten `valid_lens` to one length per row of the 2D view of `X`
    row_lens = (torch.repeat_interleave(valid_lens, shape[1])
                if valid_lens.dim() == 1 else valid_lens.reshape(-1))
    # On the last axis, replace masked elements with a very large negative
    # value, whose exponentiation outputs 0
    flat_scores = X.reshape(-1, shape[-1])
    masked = d2l.sequence_mask(flat_scores, row_lens, value=-1e6)
    return nn.functional.softmax(masked.reshape(shape), dim=-1)

In [None]:
# 1D `valid_lens`: one valid length per batch entry, shared by all its rows
demo_scores = torch.rand(2, 2, 4)
masked_softmax(demo_scores, torch.tensor([2, 3]))

In [None]:
# 2D `valid_lens`: an individual valid length for every row of every batch entry
demo_scores = torch.rand(2, 2, 4)
masked_softmax(demo_scores, torch.tensor([[1, 3], [2, 4]]))

10.3.2 Additive Attention

In [None]:
#@save
class AdditiveAttention(nn.Module):
    """Additive attention."""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        # Bias-free projections of keys and queries into a shared hidden
        # space, followed by a bias-free projection to one score
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        projected_queries = self.W_q(queries)
        projected_keys = self.W_k(keys)
        # After dimension expansion, shape of the queries: (`batch_size`,
        # no. of queries, 1, `num_hiddens`) and shape of the keys:
        # (`batch_size`, 1, no. of key-value pairs, `num_hiddens`).
        # Broadcasting sums every query with every key
        features = torch.tanh(
            projected_queries.unsqueeze(2) + projected_keys.unsqueeze(1))
        # `self.w_v` has a single output, so the trailing axis of size one is
        # removed. Shape of `scores`: (`batch_size`, no. of queries,
        # no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        dropped_weights = self.dropout(self.attention_weights)
        return torch.bmm(dropped_weights, values)

In [None]:
queries = torch.normal(0, 1, (2, 1, 20))
keys = torch.ones((2, 10, 2))
# The two value matrices in the `values` minibatch are identical
value_matrix = torch.arange(40, dtype=torch.float32).reshape(10, 4)
values = value_matrix.unsqueeze(0).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])

# `eval()` returns the module itself; evaluation mode disables dropout
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1).eval()
attention(queries, keys, values, valid_lens)

In [None]:
# Two queries (one per batch entry) against ten keys, drawn as one heatmap
additive_heatmap = attention.attention_weights.reshape((1, 1, 2, 10))
d2l.show_heatmaps(additive_heatmap, xlabel='Keys', ylabel='Queries')

10.3.3 Scaled Dot-Product Attention

In [None]:
#@save
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # Shape of `queries`: (`batch_size`, no. of queries, `d`)
        # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of
        # queries)
        scale = math.sqrt(queries.shape[-1])
        # Swap the last two dimensions of `keys` so that the batched product
        # yields one score per (query, key) pair
        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale
        self.attention_weights = masked_softmax(scores, valid_lens)
        dropped_weights = self.dropout(self.attention_weights)
        return torch.bmm(dropped_weights, values)

In [None]:
# The query feature size (2) matches that of `keys`; `keys`, `values` and
# `valid_lens` are reused from the additive attention example
queries = torch.normal(0, 1, (2, 1, 2))
# `eval()` returns the module itself; evaluation mode disables dropout
attention = DotProductAttention(dropout=0.5).eval()
attention(queries, keys, values, valid_lens)

In [None]:
# Two queries (one per batch entry) against ten keys, drawn as one heatmap
dot_product_heatmap = attention.attention_weights.reshape((1, 1, 2, 10))
d2l.show_heatmaps(dot_product_heatmap, xlabel='Keys', ylabel='Queries')

# **10.4 Bahdanau Attention**

10.4.1 Model

10.4.2 Defining the Decoder with Attention

In [None]:
#@save
class AttentionDecoder(d2l.Decoder):
    """The base attention-based decoder interface."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def attention_weights(self):
        """Attention weights of the decoder; the base class leaves this
        unimplemented, so subclasses have to provide it."""
        raise NotImplementedError

In [None]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
    """GRU decoder that, at every time step, attends over the encoder outputs
    using the top-layer hidden state as the query."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Each GRU input is a context vector concatenated with a token
        # embedding
        rnn_input_size = embed_size + num_hiddens
        self.rnn = nn.GRU(rnn_input_size, num_hiddens, num_layers,
                          dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # Shape of `outputs`: (`num_steps`, `batch_size`, `num_hiddens`).
        # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`,
        # `num_hiddens`)
        outputs, hidden_state = enc_outputs
        # Put the batch axis first, as the attention layer consumes it
        batch_first_outputs = outputs.permute(1, 0, 2)
        return (batch_first_outputs, hidden_state, enc_valid_lens)

    def forward(self, X, state):
        # Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`).
        # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`,
        # `num_hiddens`)
        enc_outputs, hidden_state, enc_valid_lens = state
        # Shape of `embedded`: (`num_steps`, `batch_size`, `embed_size`)
        embedded = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for step_embedding in embedded:
            # Shape of `query`: (`batch_size`, 1, `num_hiddens`)
            query = hidden_state[-1].unsqueeze(1)
            # Shape of `context`: (`batch_size`, 1, `num_hiddens`)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            # Concatenate on the feature dimension
            rnn_input = torch.cat(
                (context, step_embedding.unsqueeze(1)), dim=-1)
            # Reshape the input as (1, `batch_size`, `embed_size` +
            # `num_hiddens`)
            out, hidden_state = self.rnn(
                rnn_input.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # After fully-connected layer transformation, shape of `logits`:
        # (`num_steps`, `batch_size`, `vocab_size`)
        logits = self.dense(torch.cat(outputs, dim=0))
        new_state = [enc_outputs, hidden_state, enc_valid_lens]
        return logits.permute(1, 0, 2), new_state

    @property
    def attention_weights(self):
        """Attention weights recorded by the most recent `forward` call, one
        entry per decoded time step."""
        return self._attention_weights

In [None]:
# Encoder and decoder share the same toy hyperparameters
toy_config = dict(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder = d2l.Seq2SeqEncoder(**toy_config)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(**toy_config)
decoder.eval()
# Shape of `X`: (`batch_size`, `num_steps`)
X = torch.zeros(4, 7, dtype=torch.long)
enc_outputs = encoder(X)
state = decoder.init_state(enc_outputs, None)
output, state = decoder(X, state)
(output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape)

10.4.3 Training

In [None]:
# Model hyperparameters
embed_size, num_hiddens = 32, 32
num_layers, dropout = 2, 0.1
# Data hyperparameters
batch_size, num_steps = 64, 10
# Optimization hyperparameters
lr, num_epochs = 0.005, 250
device = d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
src_vocab_size, tgt_vocab_size = len(src_vocab), len(tgt_vocab)
encoder = d2l.Seq2SeqEncoder(
    src_vocab_size, embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    tgt_vocab_size, embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

In [None]:
engs = ['go .', 'i lost .', "he's calm .", "i'm home ."]
fras = ['va !', "j'ai perdu .", 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    # The trailing `True` is passed positionally; the call then also returns
    # a second value, kept here as `dec_attention_weight_seq`
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    bleu_score = d2l.bleu(translation, fra, k=2)
    print(f'{eng} => {translation}, ', f'bleu {bleu_score:.3f}')

In [None]:
# Build the tensor to plot from the weights recorded while translating the
# last sentence. Without this step `attention_weights` still refers to the
# 2-D weights of the kernel regression section, which cannot be sliced with
# four indices.
# NOTE(review): assumes every entry of `dec_attention_weight_seq` is the
# decoder's `attention_weights` list for one decoding step with batch size 1,
# so that `step[0][0][0]` is one weight vector over the `num_steps` key
# positions — confirm against `d2l.predict_seq2seq`.
attention_weights = torch.cat(
    [step[0][0][0] for step in dec_attention_weight_seq], 0).reshape(
        (1, 1, -1, num_steps))
# Plus one to include the end-of-sequence token
d2l.show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
    xlabel='Key positions', ylabel='Query positions')