<a href="https://colab.research.google.com/github/hamednasr/transformers/blob/main/transformers_in_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset
import numpy as np
import matplotlib.pyplot as plt

In [12]:
# Softmax is not the identity on a vector that already sums to 1: it
# exponentiates and renormalizes, so [0.9, 0.05, 0.05] is flattened toward
# uniform ([0.5391, 0.2304, 0.2304]).
torch.tensor([0.9, 0.05, 0.05]).softmax(dim=-1)

tensor([0.5391, 0.2304, 0.2304])

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  The input is projected into ``n_heads`` query/key/value subspaces, each head
  attends independently, and the heads are concatenated and projected back to
  ``d_model``.

  Args:
    d_k: Per-head dimension of queries and keys.
    d_v: Per-head dimension of values (may differ from ``d_k``).
    d_model: Feature dimension of the input and of the output.
    n_heads: Number of attention heads.
  """

  def __init__(self, d_k, d_v, d_model, n_heads):
    super().__init__()

    self.d_k = d_k
    self.d_v = d_v
    self.n_heads = n_heads

    self.W_q = nn.Linear(d_model, d_k * n_heads)
    self.W_k = nn.Linear(d_model, d_k * n_heads)
    self.W_v = nn.Linear(d_model, d_v * n_heads)
    self.fc = nn.Linear(d_v * n_heads, d_model)

  def forward(self, X, mask=None):
    """Apply multi-head self-attention to ``X``.

    Args:
      X: Input tensor of shape (N, T, d_model).
      mask: Optional padding mask of shape (N, T). Key positions where the
        mask equals 0 (or False) receive no attention. ``None`` attends to
        every position. A sample whose mask is entirely 0 gets all ``-inf``
        scores, and softmax then yields NaN for it.

    Returns:
      Tensor of shape (N, T, d_model).
    """
    N, T, _ = X.shape

    # Project, then split into heads: N × T × h*d -> N × h × T × d.
    Q = self.W_q(X).view(N, T, self.n_heads, self.d_k).transpose(1, 2)
    K = self.W_k(X).view(N, T, self.n_heads, self.d_k).transpose(1, 2)
    # Values are split with their own per-head size d_v, not d_k.
    V = self.W_v(X).view(N, T, self.n_heads, self.d_v).transpose(1, 2)

    # d_k is a Python int, so scale with math.sqrt (torch.sqrt needs a tensor).
    scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)  # N × h × T × T

    if mask is not None:
      # N × T -> N × 1 × 1 × T: broadcast the key mask over heads and queries.
      scores = scores.masked_fill(mask[:, None, None, :] == 0, float('-inf'))

    weights = F.softmax(scores, dim=-1)  # N × h × T × T

    A = weights @ V  # N × h × T × d_v
    # transpose() leaves A non-contiguous, so reshape (not view) is required.
    A = A.transpose(1, 2).reshape(N, T, self.n_heads * self.d_v)  # N × T × h*d_v

    return self.fc(A)


In [None]:
class TransformerBlock(nn.Module):
  """Post-LayerNorm transformer encoder block.

  The block applies two sub-layers in sequence:
    1. Self-attention, added back to the input (residual) and then
       layer-normalized.
    2. A position-wise feed-forward network, also wrapped as a residual and
       layer-normalized.
  Dropout is applied to the block output.

  Args:
    d_k: Per-head query/key dimension for the attention layer.
    d_v: Per-head value dimension for the attention layer.
    d_model: Feature dimension of the input and output.
    n_heads: Number of attention heads.
    dropout_prob: Dropout probability. It is used at the end of the
      feed-forward network and on the block output.
  """

  def __init__(self, d_k, d_v, d_model, n_heads, dropout_prob=0.2):
    super().__init__()

    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.mha = MultiHeadAttention(d_k, d_v, d_model, n_heads)
    # Position-wise feed-forward network: d_model -> 3*d_model -> d_model.
    # These must be Linear layers; LayerNorm(d_model, d_model*3) would treat
    # the second argument as eps.
    self.ann = nn.Sequential(
        nn.Linear(d_model, d_model * 3),
        nn.GELU(),
        nn.Linear(d_model * 3, d_model),
        nn.Dropout(dropout_prob),
    )
    self.dropout = nn.Dropout(dropout_prob)

  def forward(self, x, mask=None):
    """Run the block on ``x``.

    Args:
      x: Input tensor whose last dimension is d_model, presumably
        (N, T, d_model).
      mask: Optional attention mask, forwarded unchanged to the
        self-attention layer.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # MultiHeadAttention performs self-attention on a single input tensor.
    x = self.ln1(x + self.mha(x, mask=mask))
    x = self.ln2(x + self.ann(x))
    x = self.dropout(x)
    return x