## Quiz 06 Answers
1. True
2. False
3. False
4. True
5. True

## Lab 06 Objectives
1. Implement the Attention(Q, K, V) function of the Transformer
2. Implement the MultiHead(Q, K, V) function of the Transformer
3. Create a mask that enforces the autoregressive property of the Transformer decoder


In [None]:
# Objective 1: Implement Attention(Q, K, V) function of Transformer
import torch
from torch import nn

def attention(Q, K, V, mask=None):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k) + mask) V.

  Args:
    Q: query tensor of shape (..., seq_q, d_k).
    K: key tensor of shape (..., seq_k, d_k).
    V: value tensor of shape (..., seq_k, d_v).
    mask: optional additive mask broadcastable against the
      (..., seq_q, seq_k) logits; large negative entries (e.g. -1e9)
      effectively block the corresponding key positions.

  Returns:
    Tensor of shape (..., seq_q, d_v). Any number of leading batch dims
    (e.g. (batch, heads)) is supported.
  """
  # A Python float scale avoids building an integer CPU tensor on every call.
  scale = Q.size(-1) ** -0.5
  # Work on the last two dims so 4-D (batch, heads, seq, dim) input also works.
  logits = (Q * scale) @ K.transpose(-2, -1)
  if mask is not None:
    # Out-of-place add: lets a mask with extra batch dims broadcast the logits.
    logits = logits + mask
  return torch.softmax(logits, dim=-1) @ V

torch.manual_seed(605)  # the reference maximum asserted below is tied to this seed

batch_size, seq_len, hidden_size = 8, 16, 32

# Q, K and V are drawn from the seeded RNG stream in exactly this order.
Q, K, V = (torch.randn(batch_size, seq_len, hidden_size) for _ in range(3))
out = attention(Q, K, V)

print(out.shape)
# Regression check against the known output for seed 605.
assert (out.max() - 1.204).abs() < 0.01

torch.Size([8, 16, 32])


In [None]:
# Objective 2: Implement MultiHead(Q, K, V) function of Transformer
from torch import nn

class MultiHead(nn.Module):
  """Multi-head attention built on the `attention` function.

  Q, K and V are each linearly projected and split into `num_heads` heads of
  size `head_size`. Attention is applied per head, and the heads are
  concatenated back to `hidden_size`. No output projection is applied after
  the concatenation.
  """

  def __init__(self, hidden_size, num_heads):
    super().__init__()
    # Validate up front: a non-divisible split would otherwise surface later
    # as a confusing reshape error inside forward().
    if num_heads <= 0 or hidden_size % num_heads != 0:
      raise ValueError(
          f"hidden_size ({hidden_size}) must be divisible by num_heads "
          f"({num_heads}), and num_heads must be positive")
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.head_size = hidden_size // num_heads
    self.q_layer = nn.Linear(hidden_size, hidden_size)
    self.k_layer = nn.Linear(hidden_size, hidden_size)
    self.v_layer = nn.Linear(hidden_size, hidden_size)

  def _split_heads(self, M):
    """Reshape (batch, seq, hidden) -> (batch * num_heads, seq, head_size)."""
    # Use this tensor's own sequence length so K/V may differ in length from Q.
    batch_size, seq_len, _ = M.size()
    Mi = M.reshape(batch_size, seq_len, self.num_heads, self.head_size)
    return Mi.transpose(1, 2).reshape(
        batch_size * self.num_heads, seq_len, self.head_size)

  def _merge_heads(self, M, batch_size):
    """Reshape (batch * num_heads, seq, head_size) -> (batch, seq, hidden)."""
    seq_len = M.size(1)
    Mi = M.reshape(batch_size, self.num_heads, seq_len, self.head_size)
    return Mi.transpose(1, 2).reshape(
        batch_size, seq_len, self.num_heads * self.head_size)

  def forward(self, Q, K, V, mask=None):
    """Apply multi-head attention.

    Args:
      Q: (batch, seq_q, hidden_size) queries.
      K: (batch, seq_k, hidden_size) keys; seq_k may differ from seq_q
        (e.g. encoder-decoder attention).
      V: (batch, seq_k, hidden_size) values.
      mask: optional additive mask forwarded to `attention`. It must
        broadcast to (batch * num_heads, seq_q, seq_k), e.g. one shared
        (seq_q, seq_k) mask.

    Returns:
      Tensor of shape (batch, seq_q, hidden_size).
    """
    batch_size = Q.size(0)
    Q_ = self._split_heads(self.q_layer(Q))
    K_ = self._split_heads(self.k_layer(K))
    V_ = self._split_heads(self.v_layer(V))

    out_ = attention(Q_, K_, V_, mask=mask)

    return self._merge_heads(out_, batch_size)


torch.manual_seed(605)  # the reference maximum asserted below is tied to this seed

batch_size, seq_len, hidden_size, num_heads = 8, 16, 32, 4

# RNG order matters: the inputs are drawn first, and MultiHead's nn.Linear
# layers draw their initial weights afterwards.
Q, K, V = (torch.randn(batch_size, seq_len, hidden_size) for _ in range(3))
multi_head = MultiHead(hidden_size, num_heads)
out = multi_head(Q, K, V)

print(out.shape)
# Regression check against the known output for seed 605.
assert (out.max() - 0.5557).abs() < 0.01

torch.Size([8, 16, 32])


In [None]:
# Objective 3: Create a mask for the autoregressive characteristic of Transformer decoder

import torch
from torch import nn

def get_ar_mask(seq_len, *, device=None, dtype=None):
  """Build an additive causal (autoregressive) attention mask.

  Entry [i, j] is 0 when j <= i (query i may attend to key j) and -1e9 when
  j > i (a future position). Adding the mask to the attention logits drives
  the softmax weight of future positions to ~0.

  Args:
    seq_len: number of positions; the mask has shape (seq_len, seq_len).
    device: optional device for the mask (e.g. the logits' device), so it can
      be added to logits that do not live on the CPU.
    dtype: optional floating dtype for the mask. In float16, -1e9 is outside
      the representable range and may become -inf.

  Returns:
    Tensor of shape (seq_len, seq_len).
  """
  # Lower triangle (diagonal included) marks the allowed positions.
  binary_mask = torch.ones(seq_len, seq_len, device=device, dtype=dtype).tril()
  logit_mask = (1.0 - binary_mask) * -1e9
  return logit_mask

torch.manual_seed(605)  # the reference maximum asserted below is tied to this seed

batch_size, seq_len, hidden_size = 8, 16, 32

# Q, K and V are drawn from the seeded RNG stream in exactly this order.
Q, K, V = (torch.randn(batch_size, seq_len, hidden_size) for _ in range(3))
ar_mask = get_ar_mask(seq_len)
out = attention(Q, K, V, mask=ar_mask)

print(ar_mask.shape)
# Regression check against the known masked output for seed 605.
assert (out.max() - 3.8989).abs() < 0.01

torch.Size([16, 16])
