# Multi Head Attention & Transformer

In [1]:
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
%matplotlib inline  
%config InlineBackend.figure_format='retina'
print("PyTorch version:[%s]." % torch.__version__)

# Device configuration: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("This notebook use [%s]." % device)

PyTorch version:[1.7.0].
This notebook use [cuda:0].


# Defining Model

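$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

where $d_k$ is the dimensionality of the keys; dividing by $\sqrt{d_k}$ keeps the dot products from growing with the key size.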

In [4]:
class ScaledDotProductAttention(nn.Module):
    """Compute softmax(Q K^T / sqrt(d_k)) V over the last two dimensions.

    The module has no learnable parameters; it subclasses ``nn.Module`` only so
    it composes with other modules.
    """

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: tensor of shape (..., seq_len_q, d_k).
            key: tensor of shape (..., seq_len_k, d_k).
            value: tensor of shape (..., seq_len_k, d_v).
            mask: optional tensor broadcastable to (..., seq_len_q, seq_len_k).
                It is scaled by -1e9 and added to the logits, so entries equal
                to 1 are effectively excluded and entries equal to 0 are untouched.

        Returns:
            (out, attention): ``out`` has shape (..., seq_len_q, d_v) and
            ``attention`` holds the (..., seq_len_q, seq_len_k) softmax weights.
        """
        dk = key.size(-1)
        scores = query.matmul(key.transpose(-2, -1)) / np.sqrt(dk)
        if mask is not None:
            # Previously this updated an undefined name and raised NameError.
            # Out-of-place so the mask may broadcast ``scores`` to a larger shape.
            scores = scores + mask * -1e9
        attention = F.softmax(scores, dim=-1)
        out = attention.matmul(value)
        return out, attention

# Self-attention on random data: the same tensor serves as query, key and value.
y = torch.rand(1, 60, 512)
out = ScaledDotProductAttention()(query=y, key=y, value=y)
tuple(t.shape for t in out)

(torch.Size([1, 60, 512]), torch.Size([1, 60, 60]))

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project q/k/v, attend per head, then merge the heads.

    Args:
        features: model width; the last dimension of the inputs and of the output.
        num_heads: number of attention heads; must divide ``features``.
        bias: whether the four linear projections use a bias term.
    """

    def __init__(self, features, num_heads, bias=True):
        super(MultiHeadAttention, self).__init__()
        # The f-string previously had no braces, so the offending values never
        # appeared in the message.
        assert features % num_heads == 0, (
            f'"features"({features}) should be divisible by "num_heads"({num_heads})'
        )

        self.features = features
        self.num_heads = num_heads
        self.bias = bias
        self.depth = features // num_heads  # feature size of each head

        self.wq = nn.Linear(features, features, bias=bias)
        self.wk = nn.Linear(features, features, bias=bias)
        self.wv = nn.Linear(features, features, bias=bias)

        self.fc = nn.Linear(features, features, bias=bias)

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, features) to (batch_size, num_heads, seq_len, depth)."""
        x = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return x.permute([0, 2, 1, 3])

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: queries, expected (batch_size, seq_len_q, features).
            k: keys, expected (batch_size, seq_len_k, features).
            v: values, expected (batch_size, seq_len_k, features).
            mask: optional mask forwarded unchanged to ScaledDotProductAttention.

        Returns:
            (out, attention_weights): ``out`` is (batch_size, seq_len_q, features)
            and ``attention_weights`` is (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = q.size(0)

        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        # scaled_attention: (batch_size, num_heads, seq_len_q, depth)
        # attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = ScaledDotProductAttention()(q, k, v, mask)

        # Merge heads: (batch_size, seq_len_q, num_heads, depth) -> (batch_size, seq_len_q, features)
        concat_attention = scaled_attention.permute([0, 2, 1, 3]).reshape(
            batch_size, -1, self.features
        )

        return self.fc(concat_attention), attention_weights

# Cross-attention demo: 45 query positions attend over the 60-step tensor `y`.
temp_mha = MultiHeadAttention(num_heads=8, features=512)
out, attn = temp_mha(
    q=torch.rand(1, 45, 512),
    k=y,
    v=y,
    mask=None,
)
print(out.shape, attn.shape)

torch.Size([1, 45, 512]) torch.Size([1, 8, 45, 60])


In [7]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    Maps the last dimension ``features -> fffeatures -> features``, so the
    output has the same shape as the input.
    """

    def __init__(self, features, fffeatures):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(features, fffeatures),   # expand
            nn.ReLU(inplace=True),
            nn.Linear(fffeatures, features),   # project back
        )

    def forward(self, x):
        return self.net(x)

sample_ffn = FeedForwardNetwork(features=512, fffeatures=2048)
sample_ffn(torch.rand(64, 50, 512)).shape  # last dimension is preserved

torch.Size([64, 50, 512])

In [8]:
class EncoderLayer(nn.Module):
    """A single post-norm Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network; each
    sub-layer is wrapped as ``LayerNorm(residual + Dropout(sublayer(residual)))``.

    Args:
        features: model width (d_model).
        num_heads: number of attention heads.
        fffeatures: hidden width of the feed-forward network.
        rate: dropout probability.
    """

    def __init__(self, features, num_heads, fffeatures, rate=0.1):
        super().__init__()

        # Sub-layers, created in the same order so parameter init consumes the RNG identically.
        self.mha = MultiHeadAttention(features, num_heads)
        self.ffn = FeedForwardNetwork(features, fffeatures)

        # One LayerNorm and one Dropout per residual connection.
        self.layernorm1, self.layernorm2 = nn.LayerNorm(features), nn.LayerNorm(features)
        self.dropout1, self.dropout2 = nn.Dropout(rate), nn.Dropout(rate)

    def forward(self, x, mask):
        # x presumably (batch_size, input_seq_len, d_model); the output has the same shape.
        # Sub-layer 1: self-attention (the attention weights are discarded).
        attended, _ = self.mha(x, x, x, mask)
        x = self.layernorm1(x + self.dropout1(attended))

        # Sub-layer 2: position-wise feed-forward.
        x = self.layernorm2(x + self.dropout2(self.ffn(x)))
        return x

sample_encoder_layer = EncoderLayer(features=512, num_heads=8, fffeatures=2048)

sample_encoder_layer_output = sample_encoder_layer(torch.rand(64, 43, 512), mask=None)

sample_encoder_layer_output.shape  # (batch_size, input_seq_len, d_model)

torch.Size([64, 43, 512])

In [9]:
class DecoderLayer(nn.Module):
    """A single post-norm Transformer decoder layer.

    Three sub-layers, each followed by dropout, a residual add and LayerNorm:
      1. self-attention over the target sequence (``look_ahead_mask``),
      2. attention from the target sequence over ``enc_output`` (``padding_mask``),
      3. position-wise feed-forward network.

    Args:
        features: model width (d_model).
        num_heads: number of attention heads in both attention blocks.
        fffeatures: hidden width of the feed-forward network.
        rate: dropout probability.
    """

    def __init__(self, features, num_heads, fffeatures, rate=0.1):
        # Previously super(EncoderLayer, self), which raises TypeError because a
        # DecoderLayer is not an EncoderLayer instance.
        super(DecoderLayer, self).__init__()

        self.mha1 = MultiHeadAttention(features, num_heads)
        self.mha2 = MultiHeadAttention(features, num_heads)
        self.ffn = FeedForwardNetwork(features, fffeatures)

        self.layernorm1 = nn.LayerNorm(features)
        self.layernorm2 = nn.LayerNorm(features)
        self.layernorm3 = nn.LayerNorm(features)

        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)
        self.dropout3 = nn.Dropout(rate)

    def forward(self, x, enc_output, look_ahead_mask, padding_mask):
        """
        Args:
            x: target-side input, expected (batch_size, target_seq_len, features).
            enc_output: encoder output, expected (batch_size, input_seq_len, features).
            look_ahead_mask: mask for the self-attention block, or None.
            padding_mask: mask for the attention over ``enc_output``, or None.

        Returns:
            (out3, attn_weights_block1, attn_weights_block2): the layer output and the
            attention weights of the self-attention and encoder-attention blocks.
        """
        # The original overwrote `attn`/`out` and referenced undefined attn2, out1,
        # out2 and attn_weights_block1/2; each sub-layer now has its own names.
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
        attn1 = self.dropout1(attn1)
        out1 = self.layernorm1(attn1 + x)

        attn2, attn_weights_block2 = self.mha2(out1, enc_output, enc_output, padding_mask)  # (batch_size, target_seq_len, d_model)
        attn2 = self.dropout2(attn2)
        out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)

        ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
        ffn_output = self.dropout3(ffn_output)
        out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)

        return out3, attn_weights_block1, attn_weights_block2

# Demo: decode a 50-step target sequence against the (64, 43, 512) output of the
# EncoderLayer demo cell. (This cell previously just re-ran the EncoderLayer demo.)
sample_decoder_layer = DecoderLayer(512, 8, 2048)

sample_decoder_layer_output, sample_self_attn_weights, sample_cross_attn_weights = sample_decoder_layer(
    torch.rand(64, 50, 512), sample_encoder_layer_output, None, None)

assert sample_self_attn_weights.shape == (64, 8, 50, 50)
assert sample_cross_attn_weights.shape == (64, 8, 50, sample_encoder_layer_output.size(1))

sample_decoder_layer_output.shape  # (batch_size, target_seq_len, d_model)

torch.Size([64, 50, 512])