In [52]:
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F

In [41]:
class ScaledDotProductAttention:
    def __init__(self) -> None:
        """
        Implementing scaled dot procduct attention.
        """

    def __call__(self, Qs: torch.tensor, Ks: torch.tensor, Vs: torch.tensor, mask: bool) -> torch.tensor:
        """
        Calling scaled dot procduct attention.

        Args:
            Qs: query matrix    [batch_size, n_heads, max_seq_len, d_k]
            Ks: key matrix      [batch_size, n_heads, max_seq_len, d_k]
            Vs: values  matrix  [batch_size, n_heads, max_seq_len, d_v]
            mask: whether to apply mask based on whether called during encoding/decoding.

        Returns:
            Calculated attention weights.
        """
        score_mat = Qs @ Ks.permute(0, 1, 3, 2)
        score_mat_scaled = score_mat / np.sqrt(len(Ks[0]))

        if mask:
            score_mat_scaled = score_mat_scaled.masked_fill(mask == 0, -1e9)

        attention = F.softmax(score_mat_scaled, -1)
        out =  attention @ Vs
        return out

In [55]:
# Combining all the concept to build a multihead attention class using numpy
class MultiheadAttention(nn.Module):
    def __init__(self, h: int, d_k: int, d_v: int, d_emb: int) -> None:
        """
        Initialize multihead attention mechanism.

        Args:
            h: number of heads
            d_k: dimension of key and query vectors
            d_v: dimension of value vector
            d_emb: embedding dimension of each token.
            max_seq_len: maximum sequence length allowed.
        """
        super().__init__()
        self.n_heads = h                    # No. of heads.
        self.d_k = d_k                      # key dimension.
        self.d_v = d_v                      # value dimension.
        self.d_model = int(h * d_v)         # Model dimension.
        self.d_model2 = int(h * d_k)        # Model dimension.
        
        self.W_Q = nn.Linear(d_emb, self.d_model2)  # Query matrix.
        self.W_K = nn.Linear(d_emb, self.d_model2)  # Key matrix.
        self.W_V = nn.Linear(d_emb, self.d_model)   # Value matrix.
        self.W_o = nn.Linear(self.d_model, self.d_model)  # Output matrix.
        self.scaled_dot_product_attention = ScaledDotProductAttention()

    def split_head(self, x: torch.tensor) -> torch.tensor:
        """
        Split the input tensor into multiple heads.

        Args:
            x: input tensor of shape (batch_size, max_seq_len, d_k*n_heads)

        Returns:
            Reshaped tensor of dimension (batch_size, n_heads, max_seq_len, d_k/d_v)
        """
        x = x.view(x.shape[0], x.shape[1], self.n_heads, -1)    # (batch_size, max_seq_len, n_heads, d_k/d_v)
        out = x.permute(0, 2, 1, 3)                             # (batch_size, n_heads, max_seq_len, d_k/d_v)
        return out
    
    def forward(self, x, mask: bool) -> torch.tensor:
        """
        Forward pass of multihead attention.

        Args:
            x: input tensor of shape (batch_size, max_seq_len, d_k*n_heads)
            mask: whether to apply mask based on whether called during encoding/decoding.

        Returns:
            Reshaped tensor of dimension (batch_size, max_seq_len, d_model)
        """
        Qs = self.W_Q(x)       # Queries, shape: (batch_size, max_seq_len, d_k*n_heads)
        Ks = self.W_K(x)       # Keys, shape: (batch_size, max_seq_len, d_k*n_heads)
        Vs = self.W_V(x)       # Values, shape: (batch_size, max_seq_len, d_v*n_heads)

        Qs = self.split_head(Qs)   # (batch_size, n_heads, max_seq_len, d_k)
        Ks = self.split_head(Ks)   # (batch_size, n_heads, max_seq_len, d_k)
        Vs = self.split_head(Vs)   # (batch_size, n_heads, max_seq_len, d_v)

        multihead_vals = self.scaled_dot_product_attention(Qs, Ks, Vs, mask)   # (batch_size, n_heads, max_seq_len, d_v)
        multihead_vals = multihead_vals.view(x.shape[0], x.shape[1], self.d_model)   # (batch_size, max_seq_len, d_model)
        out = self.W_o(multihead_vals)   # (batch_size, max_seq_len, d_model)
        return out

In [43]:
class LayerNormalization(nn.Module):
    def __init__(self, params_shape:list, esp:int=1e-5):
        super().__init__()
        self.params_shape = params_shape
        self.gamma = nn.Parameter(torch.ones(params_shape))
        self.beta = nn.Parameter(torch.zeros(params_shape))
        self.esp = esp

    def forward(self, input):
        dims = [-(i+1) for i in range(len(self.params_shape))]
        mean = input.mean(dim=dims, keepdim=True)
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = var.sqrt()
        y = (input - mean) / (std + self.esp)
        out = self.gamma * y + self.beta
        return out

In [44]:
class PositionwiseFeedForward(nn.Module):
  def __init__(self, d_model, n_hidden, drop_prob):
    super().__init__()
    self.linear1 = nn.Linear(d_model, n_hidden)
    self.linear2 = nn.Linear(n_hidden, d_model)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(drop_prob)

  def forward(self, x):
    x = self.linear1(x)
    print(f"x after first linear layer: {x.size()}")
    x = self.relu(x)
    print(f"x after activation: {x.size()}")
    x = self.dropout(x)
    print(f"x after dropout: {x.size()}")
    x = self.linear2(x)
    print(f"x after 2nd linear layer: {x.size()}")
    return x

In [59]:
class EncoderLayer(nn.Module):
  def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
    super().__init__()
    d_k = d_model/num_heads
    d_v = d_model/num_heads
    self.attention = MultiheadAttention(num_heads, d_k, d_v, d_model)
    self.norm1 = LayerNormalization(params_shape=[d_model])
    self.dropout1 = nn.Dropout(p=drop_prob)
    self.ffn = PositionwiseFeedForward(d_model=d_model, n_hidden=ffn_hidden, drop_prob=drop_prob)
    self.norm2 = LayerNormalization(params_shape=[d_model])
    self.dropout2 = nn.Dropout(p=drop_prob)

  def forward(self, x):
    residual_x = x
    print("------- ATTENTION 1 ------")
    x = self.attention(x, mask=None)
    print("------- DROPOUT 1 ------")
    x = self.dropout1(x)
    print("------- ADD AND LAYER NORMALIZATION 1 ------")
    x = self.norm1(x + residual_x)
    residual_x = x
    print("------- ATTENTION 2 ------")
    x = self.ffn(x)
    print("------- DROPOUT 2 ------")
    x = self.dropout2(x)
    print("------- ADD AND LAYER NORMALIZATION 2 ------")
    x = self.norm2(x + residual_x)
    return x

In [60]:
class Encoder:
  def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers):
    self.layers = nn.Sequential(*[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)])

  def forward(self, x):
    x = self.layers(x)
    return x

In [61]:
d_model = 512
num_heads = 8
drop_prob = 0.1
batch_size = 30
max_sequence_length = 200
ffn_hidden = 2048
num_layers = 5

In [62]:
encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers)

In [63]:
x = torch.randn( (batch_size, max_sequence_length, d_model) ) # includes positional encoding
out = encoder.forward(x)

------- ATTENTION 1 ------
------- DROPOUT 1 ------
------- ADD AND LAYER NORMALIZATION 1 ------
------- ATTENTION 2 ------
x after first linear layer: torch.Size([30, 200, 2048])
x after activation: torch.Size([30, 200, 2048])
x after dropout: torch.Size([30, 200, 2048])
x after 2nd linear layer: torch.Size([30, 200, 512])
------- DROPOUT 2 ------
------- ADD AND LAYER NORMALIZATION 2 ------
------- ATTENTION 1 ------
------- DROPOUT 1 ------
------- ADD AND LAYER NORMALIZATION 1 ------
------- ATTENTION 2 ------
x after first linear layer: torch.Size([30, 200, 2048])
x after activation: torch.Size([30, 200, 2048])
x after dropout: torch.Size([30, 200, 2048])
x after 2nd linear layer: torch.Size([30, 200, 512])
------- DROPOUT 2 ------
------- ADD AND LAYER NORMALIZATION 2 ------
------- ATTENTION 1 ------
------- DROPOUT 1 ------
------- ADD AND LAYER NORMALIZATION 1 ------
------- ATTENTION 2 ------
x after first linear layer: torch.Size([30, 200, 2048])
x after activation: torch.Si