<a href="https://colab.research.google.com/github/bkim9/Resume/blob/main/11_5_Multi_Head_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install d2l==1.0.3
!pip install -U mxnet-cu112==1.9.1

Collecting d2l==1.0.3
  Downloading d2l-1.0.3-py3-none-any.whl (111 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/111.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.0/111.7 kB[0m [31m1.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m111.7/111.7 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jupyter==1.0.0 (from d2l==1.0.3)
  Downloading jupyter-1.0.0-py2.py3-none-any.whl (2.7 kB)
Collecting matplotlib==3.7.2 (from d2l==1.0.3)
  Downloading matplotlib-3.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m23.7 MB/s[0m eta [36m0:00:00[0m
Collecting pandas==2.0.3 (from d2l==1.0.3)
  Downloading pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

Collecting mxnet-cu112==1.9.1
  Downloading mxnet_cu112-1.9.1-py3-none-manylinux2014_x86_64.whl (499.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m499.4/499.4 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting graphviz<0.9.0,>=0.8.1 (from mxnet-cu112==1.9.1)
  Downloading graphviz-0.8.4-py2.py3-none-any.whl (16 kB)
Installing collected packages: graphviz, mxnet-cu112
  Attempting uninstall: graphviz
    Found existing installation: graphviz 0.20.1
    Uninstalling graphviz-0.20.1:
      Successfully uninstalled graphviz-0.20.1
Successfully installed graphviz-0.8.4 mxnet-cu112-1.9.1


In [None]:
import math
import torch
from torch import nn
from d2l import torch as d2l

In [None]:
class MultiHeadAttention(d2l.Module):
  """Multi-head attention.

  Queries, keys and values are each projected by their own linear layer to
  `num_hiddens` features. The projected features are split into `num_heads`
  heads, and the heads are folded into the batch dimension (see
  `transpose_qkv`) so a single `d2l.DotProductAttention` call handles every
  head. The per-head outputs are then concatenated again (see
  `transpose_output`) and passed through the output projection `W_o`.

  Args:
    num_hiddens: Output size of every projection. Must be divisible by
      `num_heads`; each head receives `num_hiddens // num_heads` features.
    num_heads: Number of attention heads (must be positive).
    dropout: Dropout probability forwarded to `d2l.DotProductAttention`.
    bias: Whether the four projection layers use a bias term.
    **kwargs: Accepted for signature compatibility; currently unused.

  Raises:
    ValueError: If `num_heads` is not positive or does not divide
      `num_hiddens` (otherwise the head split in `transpose_qkv` would fail
      later with an opaque reshape error).
  """
  def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
    super().__init__()
    # Check num_heads first so the modulo below can never divide by zero.
    if num_heads <= 0 or num_hiddens % num_heads != 0:
      raise ValueError(
          f'num_hiddens ({num_hiddens}) must be divisible by a positive '
          f'num_heads ({num_heads})')
    self.num_heads = num_heads
    self.attention = d2l.DotProductAttention(dropout)
    # Lazy layers infer their input feature size on the first forward call.
    self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
    self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
    self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
    self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

  def forward(self, queries, keys, values, valid_lens):
    """Attend from `queries` over `keys`/`values` with `num_heads` heads.

    Args:
      queries: 3-D tensor (batch_size, num_queries, feature_dim).
      keys: 3-D tensor (batch_size, num_kvpairs, feature_dim).
      values: 3-D tensor (batch_size, num_kvpairs, feature_dim).
      valid_lens: None, or a tensor whose dim 0 is batch_size. It is passed
        to the attention as a mask after being replicated once per head.

    Returns:
      Tensor of shape (batch_size, num_queries, num_hiddens).
    """
    # Each becomes (batch_size * num_heads, seq_len, num_hiddens / num_heads).
    queries = self.transpose_qkv(self.W_q(queries))
    keys = self.transpose_qkv(self.W_k(keys))
    values = self.transpose_qkv(self.W_v(values))

    if valid_lens is not None:
      # transpose_qkv orders the merged batch axis with the head index
      # varying fastest (b0h0, b0h1, ..., b1h0, ...). repeat_interleave (not
      # repeat) produces that same ordering for the per-example lengths.
      valid_lens = torch.repeat_interleave(
          valid_lens, repeats=self.num_heads, dim=0
      )
    # (batch_size * num_heads, num_queries, num_hiddens / num_heads)
    output = self.attention(queries, keys, values, valid_lens)
    # (batch_size, num_queries, num_hiddens)
    output_concat = self.transpose_output(output)
    return self.W_o(output_concat)

In [None]:
@d2l.add_to_class(MultiHeadAttention)
def transpose_qkv(self, X):
  """Transposition for parallel computation of multiple attention heads.

  Splits the last (hidden) dimension into `self.num_heads` heads and folds the
  head axis into the batch axis, so one batched attention call covers all
  heads.

  Shapes:
    X (input): (batch_size, seq_len, num_hiddens), where seq_len is the
               number of queries or of key-value pairs.
    returns:   (batch_size * num_heads, seq_len, num_hiddens / num_heads)
  """
  # (batch_size, seq_len, num_heads, num_hiddens / num_heads)
  X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
  # (batch_size, num_heads, seq_len, num_hiddens / num_heads)
  X = X.permute(0, 2, 1, 3)
  # Merge the batch and head axes. The head index varies fastest within the
  # merged axis, which is why forward() uses repeat_interleave on valid_lens.
  return X.reshape(-1, X.shape[2], X.shape[3])

@d2l.add_to_class(MultiHeadAttention)
def transpose_output(self, X):
  """Invert transpose_qkv: unfold the heads from the batch axis and
  concatenate them back along the feature axis.

  Shapes:
    X (input): (batch_size * num_heads, num_queries, num_hiddens / num_heads)
    returns:   (batch_size, num_queries, num_hiddens)
  """
  num_queries, head_dim = X.shape[1], X.shape[2]
  # (batch_size, num_heads, num_queries, head_dim)
  per_head = X.reshape(-1, self.num_heads, num_queries, head_dim)
  # (batch_size, num_queries, num_heads, head_dim)
  per_query = per_head.transpose(1, 2)
  # Flatten the (num_heads, head_dim) pair into a single feature axis.
  batch_size = per_query.shape[0]
  return per_query.reshape(batch_size, per_query.shape[1], -1)

In [None]:
# Shape sanity check on dummy all-ones inputs (the values themselves are
# irrelevant; only output shapes are verified).
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# Masked path: per-example valid lengths are replicated across the heads.
d2l.check_shape(attention(X, Y, Y, valid_lens),
                (batch_size, num_queries, num_hiddens))
# Unmasked path: valid_lens=None skips the repeat_interleave branch.
d2l.check_shape(attention(X, Y, Y, None),
                (batch_size, num_queries, num_hiddens))

