<a href="https://colab.research.google.com/github/hamednasr/transformers/blob/main/transformers_in_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset
import numpy as np
import matplotlib.pyplot as plt

In [3]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_k, d_v, d_model, n_heads):
    super().__init__()

    self.d_k = d_k
    self.d_v = d_v
    self.n_heads = n_heads

    self.W_q = nn.Linear(d_model,d_k*n_heads)
    self.W_k = nn.Linear(d_model,d_k*n_heads)
    self.W_v = nn.Linear(d_model,d_v*n_heads)
    self.fc = nn.Linear(d_v*n_heads, d_model)

  def forward(self, X, mask=None): # X could be q,k,v that are different
    Q = self.W_q(X)  # N × T × h*d_k
    K = self.W_k(X)  # N × T × h*d_k
    V = self.W_v(X)  # N × T × h*d_v

    N = Q.shape[0]
    T = Q.shape[1]

    Q = Q.view(N, T, self.n_heads, self.d_k).transpose(1,2) # N × T × h*d_k -->> N × h × T × d_k
    K = K.view(N, T, self.n_heads, self.d_k).transpose(1,2) # N × T × h*d_k -->> N × h × T × d_k
    V = V.view(N, T, self.n_heads, self.d_k).transpose(1,2) # N × T × h*d_k -->> N × h × T × d_k

    AttentionScores = Q @ K.transpose(2,3)/torch.sqrt(self.d_k) #  N × h × T × T

    if ~mask:
      mask= torch.unsqueeze(mask, 1)
      mask= torch.unsqueeze(mask, 1)
      AttentionScores = AttentionScores.masked_fill(mask == 0, float('-inf'))

    AttentionWeights = F.softmax(AttentionScores, dim=-1) #  N × h × T × T

    A = AttentionWeights @ V #  N × h × T × d_v
    A = A.transpose(1,2).contiguous().view(N, T, self.n_heads*self.d_v ) #  N × T × h*d_v

    return self.fc(A)


In [93]:
class TransformerBlock(nn.Module):
  def __init__(self, d_k, d_v, d_model, n_heads, dropout_prob=0.2):
    super().__init__()

    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.mha = MultiHeadAttention(d_k, d_v, d_model, n_heads)
    self.ann = nn.Sequential(
        nn.Linear(d_model, d_model*3),
        nn.GELU(),
        nn.Linear(d_model*3, d_model),
        nn.Dropout(dropout_prob)
    )
    self.dropout = nn.Dropout(dropout_prob)

  def forward(self, x, mask= None):
    x = self.ln1(x + self.mha(x,x,x,mask))
    x = self.ln2(x + self.ann(x))
    x = self.dropout(x)
    return x

In [94]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_len = 2048, dropout_prob=0.2):
    super().__init__()
    self.dropout = nn.Dropout(dropout_prob)

    position = torch.arange(max_len).unsqueeze(1)
    exp_term = torch.arange(0, d_model, 2)
    div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))
    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(position * div_term)
    pe[0, :, 1::2] = torch.cos(position * div_term)
    self.register_buffer('pe',pe)

  def forward(self, x):
    x = x+self.pe[:, :x.size(1), :]
    return self.dropout(x)

In [95]:
class Encoder(nn.Module):
  def __init__(self, vocab_size, max_len,d_model,d_k,d_v,
               n_heads,n_layers,n_classes,dropout_prob):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size,d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
    transformer_blocks = [TransformerBlock(d_k, d_v,
                                           d_model,
                                           n_heads,
                                           dropout_prob) for i in range(n_layers)]

    self.transformer_blocks = nn.Sequential(*transformer_blocks)
    self.ln = nn.LayerNorm(d_model)
    self.fc = nn.Linear(d_model, n_classes)

  def forward(self, x, mask=None):
    x = self.embedding(x)
    x = self.pos_encoding(x)
    for block in self.transformer_blocks:
      x = block(x, mask)


    x = x[:,0,:] # get one of the hidden vectors if it is a one-classfification problem


    x = self.ln(x)
    x = self.fc(x)

    return x


In [96]:
model = Encoder(20000,1024,64,16 ,16,4,2,5,.1)

In [97]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device)

Encoder(
  (embedding): Embedding(20000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (W_q): Linear(in_features=64, out_features=64, bias=True)
        (W_k): Linear(in_features=64, out_features=64, bias=True)
        (W_v): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=192, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=192, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elem

In [98]:
x = np.random.randint(0, 20000, size=(8,512))
x_t = torch.tensor(x).to(device)

In [99]:
mask = np.ones((8,512))
mask[:,256:] = 0
mask_t = torch.tensor(mask).to(device)

In [100]:
y = model(x_t, mask_t)

TypeError: ignored