# Building the Transformer from Scratch

In this notebook, we implement the Transformer architecture ("Attention Is All You Need", Vaswani et al., 2017) from scratch in PyTorch.

The code is based on the following repositories and blog posts:

- [attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch)
- [nlp-tutorial](https://github.com/graykode/nlp-tutorial)
- [The Illustrated Transformer (Korean translation)](https://nlpinkorean.github.io/illustrated-transformer/)

Thanks so much to their authors!

In [1]:
# All notebook-wide imports. `np` is used later by the attention and
# masking cells, so the alias must be exactly `np`.
import math

import numpy as np
import torch
import torch.nn as nn

![image](https://camo.githubusercontent.com/88e8f36ce61dedfd2491885b8df2f68c4d1f92f5/687474703a2f2f696d6775722e636f6d2f316b72463252362e706e67)

In [2]:
# Toy parallel corpus. 'P' is the padding token (id 0 in both vocabularies).
# 'S' prefixes the decoder input and 'E' suffixes the decoder target.
sentences = [
    '기분이 저기압일 때에는 고기 앞으로 가라 P',  # encoder input (source)
    'S eat meat when you feel low ',            # decoder input
    'eat meat when you feel low E',             # decoder target
]

In [3]:
# Source vocabulary: ids follow list order, so the pad token 'P' gets id 0.
src_vocab = {token: idx for idx, token in enumerate(
    ['P', '기분이', '저기압일', '때에는', '고기', '앞으로', '가라'])}
src_vocab_size = len(src_vocab)

In [4]:
# Show the source vocabulary (token -> integer id).
src_vocab

{'P': 0, '기분이': 1, '저기압일': 2, '때에는': 3, '고기': 4, '앞으로': 5, '가라': 6}

In [5]:
# Target vocabulary: ids follow list order ('P' = 0, ..., 'S' = 7, 'E' = 8).
tgt_vocab = {token: idx for idx, token in enumerate(
    ['P', 'eat', 'meat', 'when', 'you', 'feel', 'low', 'S', 'E'])}
tgt_vocab_size = len(tgt_vocab)

# Fixed sequence lengths used by this toy example.
src_len = tgt_len = 7

In [6]:
# Look up every whitespace-separated token (unknown tokens raise KeyError);
# the outer list adds a batch dimension of size 1.
enc_input_batch = [list(map(src_vocab.__getitem__, sentences[0].split()))]
dec_input_batch = [list(map(tgt_vocab.__getitem__, sentences[1].split()))]
dec_output_batch = [list(map(tgt_vocab.__getitem__, sentences[2].split()))]

In [7]:
# One batch per line: encoder input, decoder input, decoder target.
print(enc_input_batch, dec_input_batch, dec_output_batch, sep='\n')

[[1, 2, 3, 4, 5, 6, 0]]
[[7, 1, 2, 3, 4, 5, 6]]
[[1, 2, 3, 4, 5, 6, 8]]


In [8]:
from torch.autograd import Variable

In [9]:
# Convert the id lists to int64 tensors. `Variable` has been a no-op wrapper
# since PyTorch 0.4 (plain tensors carry autograd state), so it is not needed.
# `torch.as_tensor` also accepts an existing tensor, so re-running this cell
# is harmless.
enc_input_batch = torch.as_tensor(enc_input_batch, dtype=torch.long)
dec_input_batch = torch.as_tensor(dec_input_batch, dtype=torch.long)
dec_output_batch = torch.as_tensor(dec_output_batch, dtype=torch.long)

In [10]:
# Same three batches, now as LongTensors.
print(enc_input_batch, dec_input_batch, dec_output_batch, sep='\n')

tensor([[1, 2, 3, 4, 5, 6, 0]])
tensor([[7, 1, 2, 3, 4, 5, 6]])
tensor([[1, 2, 3, 4, 5, 6, 8]])


In [11]:
import math

In [12]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding ("Attention Is All You Need", 3.5).

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: width of each encoding vector (odd widths are supported).
        max_len: number of positions precomputed; longer inputs are rejected.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer rather than a Parameter. It is still saved in state_dict under
        # the same key and moved by .to(), but it is not yielded by .parameters(),
        # so optimizers never see it.
        self.register_buffer("weight", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Args:
            x: tensor whose dim 1 is the sequence length (e.g. token ids, (B, S)).

        Returns:
            Tensor of shape (1, S, d_model), broadcastable over the batch.

        Raises:
            ValueError: if S exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        if seq_len > self.weight.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.weight.size(1)}")
        return self.weight[:, :seq_len, :]  # (1, Seq, Feature)

In [13]:
class WordPositionEmbedding(nn.Module):
    """Learned token embedding plus fixed sinusoidal position encoding.

    Args:
        vocab_size: number of token ids in the embedding table.
        d_model: embedding width (default 512).
    """

    def __init__(self, vocab_size, d_model=512):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = PositionalEmbedding(d_model)

    def forward(self, x, mask=None):
        """Embed token ids ``x`` (presumably (batch, seq)); ``mask`` is unused."""
        tokens = self.word_embedding(x)
        # (1, seq, d_model); broadcast across the batch when added.
        positions = self.position_embedding(x)
        return tokens + positions

In [14]:
# Source-side embedding layer (d_model keeps its default of 512).
input_emb = WordPositionEmbedding(vocab_size=src_vocab_size)

In [15]:
# Inspect the submodules of the source embedding layer.
input_emb

WordPositionEmbedding(
  (word_embedding): Embedding(7, 512)
  (position_embedding): PositionalEmbedding()
)

In [16]:
# Embed the source ids: result is (batch, seq_len, d_model), see the shape below.
enc_emb = input_emb(x=enc_input_batch)

In [17]:
# Shape of the embedded source batch: (batch, seq_len, d_model).
enc_emb.shape

torch.Size([1, 7, 512])

![image](https://camo.githubusercontent.com/88e8f36ce61dedfd2491885b8df2f68c4d1f92f5/687474703a2f2f696d6775722e636f6d2f316b72463252362e706e67)

In [18]:
class TransformerEncoder(nn.Module):
    """Stack of ``n_blocks`` identical encoder blocks.

    Args:
        n_blocks: number of stacked EncoderBlocks.
        d_model: model width.
        n_heads: attention heads per block; each head has width d_model // n_heads.
        d_ff: hidden width of the position-wise feed-forward network.
        dropout: dropout probability forwarded to every block.
    """

    def __init__(self, n_blocks=6, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.encoders = nn.ModuleList([
            EncoderBlock(d_model=d_model, d_feature=d_model // n_heads, d_ff=d_ff,
                         n_heads=n_heads, dropout=dropout)
            for _ in range(n_blocks)
        ])

    def forward(self, x, mask=None):
        """Run ``x`` through every block in order.

        Args:
            x: embedded input, presumably (batch, seq, d_model).
            mask: optional attention mask (nonzero = blocked), passed to every block.
                Previously this argument was silently dropped.
        """
        for encoder in self.encoders:
            x = encoder(x, mask=mask)
        return x

```python
encoder = TransformerEncoder()
enc_output_batch = encoder(enc_emb)
```

In [19]:
class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer.

    x = LayerNorm(x + Dropout(SelfAttention(x)))
    x = LayerNorm(x + Dropout(FFN(x)))

    Args:
        d_model: model width.
        d_feature: per-head width for queries/keys/values.
        d_ff: hidden width of the feed-forward network.
        n_heads: number of attention heads.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        print('[Encoder Block]')
        print(x.shape, "Encoder block input")
        att = self.attn_head(x, x, x, mask=mask)
        print(att.shape, "Attention output")

        # Residual + layer norm. Dropout goes on the sub-layer output *before*
        # the residual add (paper, 5.4), so the residual path is never dropped.
        x = self.layer_norm1(x + self.dropout(att))

        # Apply position-wise feedforward network
        pos = self.position_wise_feed_forward(x)
        print(pos.shape, "Feedforward output")

        # Residual + layer norm, same dropout placement as above.
        x = self.layer_norm2(x + self.dropout(pos))
        print(x.shape, "Encoder output\n")
        return x

In [20]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Args:
        d_model: input/output width.
        d_feature: per-head width of queries, keys and values.
        n_heads: number of heads.
        dropout: accepted for interface compatibility; currently unused.
    """

    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_feature = d_feature

        self.W_Q = nn.Linear(d_model, d_feature * n_heads)
        self.W_K = nn.Linear(d_model, d_feature * n_heads)
        self.W_V = nn.Linear(d_model, d_feature * n_heads)

        self.W_O = nn.Linear(n_heads * d_feature, d_model)

    def _split_heads(self, t, batch_size):
        # (B, S, H*W) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        return t.view(batch_size, -1, self.n_heads, self.d_feature).transpose(1, 2)

    def forward(self, x1, x2, x3, mask):
        """Attend from queries ``x1`` to keys ``x2`` / values ``x3``.

        Args:
            x1: queries, (batch, len_q, d_model).
            x2: keys, (batch, len_k, d_model).
            x3: values, (batch, len_k, d_model).
            mask: None or (batch, len_q, len_k); nonzero entries are blocked.

        Returns:
            (batch, len_q, d_model).
        """
        print('\t[MULTIHEAD]')
        batch_size = x1.size(0)

        Q = self._split_heads(self.W_Q(x1), batch_size)  # (B, H, len_q, d_k)
        K = self._split_heads(self.W_K(x2), batch_size)  # (B, H, len_k, d_k)
        V = self._split_heads(self.W_V(x3), batch_size)  # (B, H, len_k, d_v)
        print('\t# Q, K, V shape(batch, heads, length, d_model/heads) :')
        print('\t',Q.shape, K.shape, V.shape)
        if mask is not None:
            # Add a head axis. masked_fill broadcasts it over all heads, so there
            # is no need to materialize n_heads copies with repeat().
            mask = mask.unsqueeze(1)  # (B, 1, len_q, len_k)

        Z = ScaledDotProductAttention()(Q, K, V, mask, self.d_feature)  # (B, H, len_q, d_v)
        print('\t# Z shape : ', Z.shape)
        # Merge heads back: (B, H, len_q, d_v) -> (B, len_q, H * d_v)
        Z = Z.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_feature)
        print('\t# Z shape changed : ', Z.shape)
        return self.W_O(Z)  # (B, len_q, d_model)

$$ \textrm{Attention}(Q, K, V) = \textrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

In [21]:
class ScaledDotProductAttention(nn.Module):
    """softmax(Q K^T / sqrt(d_k)) V with optional masking. Holds no parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask, d_k):
        """Compute attention-weighted values.

        Args:
            Q: (..., len_q, d_k) queries.
            K: (..., len_k, d_k) keys.
            V: (..., len_k, d_v) values.
            attn_mask: None, or a bool/uint8 tensor broadcastable to
                (..., len_q, len_k); nonzero entries mark blocked positions.
            d_k: key width used for the 1/sqrt(d_k) scaling.

        Returns:
            (..., len_q, d_v).
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / (d_k ** 0.5)  # (..., len_q, len_k)
        print('\t# scores, V : ', scores.shape, V.shape)
        if attn_mask is not None:
            # Current PyTorch only accepts bool masks in masked_fill, and
            # get_attn_subsequent_mask returns uint8, so normalize first.
            scores = scores.masked_fill(attn_mask.bool(), -1e9)
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V)

In [22]:
# First pass without any attention mask (padding positions are attended to).
encoder = TransformerEncoder()
enc_output_batch = encoder(x=enc_emb)

[Encoder Block]
torch.Size([1, 7, 512]) Encoder block input
	[MULTIHEAD]
	# Q, K, V shape(batch, heads, length, d_model/heads) :
	 torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64])
	# scores, V :  torch.Size([1, 8, 7, 7]) torch.Size([1, 8, 7, 64])
	# Z shape :  torch.Size([1, 8, 7, 64])
	# Z shape changed :  torch.Size([1, 7, 512])
torch.Size([1, 7, 512]) Attention output
torch.Size([1, 7, 512]) Feedforward output
torch.Size([1, 7, 512]) Encoder output

[Encoder Block]
torch.Size([1, 7, 512]) Encoder block input
	[MULTIHEAD]
	# Q, K, V shape(batch, heads, length, d_model/heads) :
	 torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64])
	# scores, V :  torch.Size([1, 8, 7, 7]) torch.Size([1, 8, 7, 64])
	# Z shape :  torch.Size([1, 8, 7, 64])
	# Z shape changed :  torch.Size([1, 7, 512])
torch.Size([1, 7, 512]) Attention output
torch.Size([1, 7, 512]) Feedforward output
torch.Size([1, 7, 512]) Encoder output

[Encoder Block]
torch.Size([

In [23]:
class TransformerDecoder(nn.Module):
    """Stack of ``n_blocks`` identical decoder blocks.

    Args:
        n_blocks: number of stacked DecoderBlocks.
        d_model: model width.
        d_feature: ignored (kept for backward compatibility); the per-head
            width is always d_model // n_heads.
        d_ff: hidden width of the feed-forward network.
        n_heads: attention heads per block (previously not forwarded).
        dropout: dropout probability forwarded to every block.
    """

    def __init__(self, n_blocks=6, d_model=512, d_feature=64,
                 d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        # The unused PositionalEmbedding that used to live here was removed;
        # positions are added by WordPositionEmbedding before the decoder.
        self.decoders = nn.ModuleList([
            DecoderBlock(d_model=d_model, d_feature=d_model // n_heads, d_ff=d_ff,
                         n_heads=n_heads, dropout=dropout)
            for _ in range(n_blocks)
        ])

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        """Run ``x`` through every block, attending to ``enc_out``.

        Args:
            x: embedded decoder input, presumably (batch, tgt_len, d_model).
            enc_out: encoder output, presumably (batch, src_len, d_model).
            src_mask: mask for encoder-decoder attention.
            tgt_mask: mask for decoder self-attention.
        """
        for decoder in self.decoders:
            x = decoder(x, enc_out, src_mask=src_mask, tgt_mask=tgt_mask)
        return x

In [24]:
class DecoderBlock(nn.Module):
    """One post-norm Transformer decoder layer.

    x = LayerNorm(x + Dropout(MaskedSelfAttention(x)))         # tgt_mask
    x = LayerNorm(x + Dropout(CrossAttention(x, enc_out)))     # src_mask
    x = LayerNorm(x + Dropout(FFN(x)))

    Args:
        d_model: model width.
        d_feature: per-head width for queries/keys/values.
        d_ff: hidden width of the feed-forward network.
        n_heads: number of attention heads.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        self.enc_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.dec_attn_head = MultiHeadAttention(d_model, d_feature, n_heads, dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        """Apply masked self-attention, cross-attention and FFN.

        Args:
            x: decoder input, presumably (batch, tgt_len, d_model).
            enc_out: encoder output, presumably (batch, src_len, d_model).
            src_mask: mask for cross-attention; expected (batch, tgt_len, src_len),
                nonzero = encoder position blocked (e.g. padding).
            tgt_mask: mask for self-attention; expected (batch, tgt_len, tgt_len),
                nonzero = blocked (padding and/or future positions).
        """
        print('[Decoder Block]')
        print(x.shape, "Decoder block input")
        # 1) Masked self-attention: the causal/padding mask is tgt_mask.
        att = self.dec_attn_head(x, x, x, mask=tgt_mask)
        print(att.shape, "First Attention output")
        x = self.layer_norm1(x + self.dropout(att))

        # 2) Encoder-decoder attention with its own head: queries from the
        #    decoder, keys/values from the encoder, encoder padding hidden by src_mask.
        att = self.enc_attn_head(x1=x, x2=enc_out, x3=enc_out, mask=src_mask)
        print(att.shape, "Second Attention output")
        x = self.layer_norm2(x + self.dropout(att))

        # 3) Position-wise feed-forward with its own (third) layer norm.
        pos = self.position_wise_feed_forward(x)
        print(pos.shape, "Feedforward output")
        x = self.layer_norm3(x + self.dropout(pos))
        print(x.shape, "Decoder output\n")
        return x

In [25]:
# Target-side embedding layer and an untrained decoder stack (default sizes).
output_emb = WordPositionEmbedding(vocab_size=tgt_vocab_size)
decoder = TransformerDecoder()

In [26]:
# Embed the decoder input and run the decoder with no masks (first pass).
dec_emb = output_emb(x=dec_input_batch)
result = decoder(x=dec_emb, enc_out=enc_output_batch)

[Decoder Block]
torch.Size([1, 7, 512]) Decoder block input
	[MULTIHEAD]
	# Q, K, V shape(batch, heads, length, d_model/heads) :
	 torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64])
	# scores, V :  torch.Size([1, 8, 7, 7]) torch.Size([1, 8, 7, 64])
	# Z shape :  torch.Size([1, 8, 7, 64])
	# Z shape changed :  torch.Size([1, 7, 512])
torch.Size([1, 7, 512]) First Attention output
	[MULTIHEAD]
	# Q, K, V shape(batch, heads, length, d_model/heads) :
	 torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64])
	# scores, V :  torch.Size([1, 8, 7, 7]) torch.Size([1, 8, 7, 64])
	# Z shape :  torch.Size([1, 8, 7, 64])
	# Z shape changed :  torch.Size([1, 7, 512])
torch.Size([1, 7, 512]) Second Attention output
torch.Size([1, 7, 512]) Feedforward output
torch.Size([1, 7, 512]) Decoder output

[Decoder Block]
torch.Size([1, 7, 512]) Decoder block input
	[MULTIHEAD]
	# Q, K, V shape(batch, heads, length, d_model/heads) :
	 torch.Size([1, 8, 7, 64]) t

In [27]:
# Decoder output shape: (batch, tgt_len, d_model).
result.shape

torch.Size([1, 7, 512])

------------

In [28]:
def get_attn_pad_mask(seq_q, seq_k, pad_idx=0):
    """Mask that hides padding positions of ``seq_k`` from every query.

    Args:
        seq_q: (batch, len_q) token ids; only its length is used.
        seq_k: (batch, len_k) token ids; positions equal to ``pad_idx`` are masked.
        pad_idx: id of the padding token ('P' maps to 0 in both vocabularies).

    Returns:
        (batch, len_q, len_k) mask, nonzero/True where attention is blocked.
        It is an expand()-ed view, so do not modify it in place.
    """
    _, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.eq(pad_idx).unsqueeze(1)  # (batch, 1, len_k)
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # (batch, len_q, len_k)

In [29]:
# Encoder self-attention only needs to hide the padding token(s).
enc_self_attn_mask = get_attn_pad_mask(seq_q=enc_input_batch, seq_k=enc_input_batch)
print(enc_self_attn_mask.shape, enc_self_attn_mask, sep='\n')

torch.Size([1, 7, 7])
tensor([[[0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 1]]], dtype=torch.uint8)


In [30]:
def get_attn_subsequent_mask(seq):
    """Causal ("look-ahead") mask for decoder self-attention.

    Args:
        seq: (batch, seq_len) tensor; only its shape and device are used.

    Returns:
        uint8 tensor of shape (batch, seq_len, seq_len) on ``seq.device`` with
        1 strictly above the diagonal, so position i cannot attend to j > i.
    """
    batch_size, seq_len = seq.size(0), seq.size(1)
    # Build directly in torch on the input's device. The previous NumPy
    # round-trip always produced a CPU tensor, which breaks for GPU inputs.
    future = torch.triu(torch.ones(seq_len, seq_len, device=seq.device), diagonal=1)
    return future.byte().unsqueeze(0).repeat(batch_size, 1, 1)

In [31]:
# Decoder self-attention hides both padding and future positions:
# the two masks are combined as a logical OR (sum > 0).
dec_self_attn_pad_mask = get_attn_pad_mask(seq_q=dec_input_batch, seq_k=dec_input_batch)
dec_self_attn_subsequent_mask = get_attn_subsequent_mask(seq=dec_input_batch)
dec_self_attn_mask = (dec_self_attn_pad_mask + dec_self_attn_subsequent_mask) > 0

In [32]:
# Decoder input contains no 'P' tokens, so nothing is masked here.
print(dec_self_attn_pad_mask.shape, dec_self_attn_pad_mask, sep='\n')

torch.Size([1, 7, 7])
tensor([[[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)


In [33]:
# Upper-triangular causal mask: 1 marks future positions.
print(dec_self_attn_subsequent_mask.shape, dec_self_attn_subsequent_mask, sep='\n')

torch.Size([1, 7, 7])
tensor([[[0, 1, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)


In [34]:
# Combined decoder self-attention mask (padding OR future positions).
print(dec_self_attn_mask)

tensor([[[0, 1, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)


In [35]:
# Re-run both stacks, this time with attention masks.
encoder = TransformerEncoder()
enc_output_batch = encoder(enc_emb, enc_self_attn_mask)

output_emb = WordPositionEmbedding(tgt_vocab_size)
decoder = TransformerDecoder()

# Encoder-decoder attention uses decoder positions as queries and encoder
# positions as keys. Its padding mask is therefore built from
# (dec_input_batch, enc_input_batch) and has shape (batch, tgt_len, src_len).
# Reusing enc_self_attn_mask only worked because src_len == tgt_len here.
dec_enc_attn_mask = get_attn_pad_mask(dec_input_batch, enc_input_batch)

dec_emb = output_emb(dec_input_batch)
result = decoder(dec_emb, enc_output_batch, src_mask=dec_enc_attn_mask, tgt_mask=dec_self_attn_mask)

[Encoder Block]
torch.Size([1, 7, 512]) Encoder block input
	[MULTIHEAD]
	# Q, K, V shape(batch, heads, length, d_model/heads) :
	 torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64])
	# scores, V :  torch.Size([1, 8, 7, 7]) torch.Size([1, 8, 7, 64])
	# Z shape :  torch.Size([1, 8, 7, 64])
	# Z shape changed :  torch.Size([1, 7, 512])
torch.Size([1, 7, 512]) Attention output
torch.Size([1, 7, 512]) Feedforward output
torch.Size([1, 7, 512]) Encoder output

[Encoder Block]
torch.Size([1, 7, 512]) Encoder block input
	[MULTIHEAD]
	# Q, K, V shape(batch, heads, length, d_model/heads) :
	 torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64]) torch.Size([1, 8, 7, 64])
	# scores, V :  torch.Size([1, 8, 7, 7]) torch.Size([1, 8, 7, 64])
	# Z shape :  torch.Size([1, 8, 7, 64])
	# Z shape changed :  torch.Size([1, 7, 512])
torch.Size([1, 7, 512]) Attention output
torch.Size([1, 7, 512]) Feedforward output
torch.Size([1, 7, 512]) Encoder output

[Encoder Block]
torch.Size([