<a href="https://colab.research.google.com/github/hamednasr/transformers/blob/main/transformers_in_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset
import numpy as np
import matplotlib.pyplot as plt

In [12]:
# Softmax over the last dim: the three raw scores are exponentiated and
# normalized so that the resulting weights sum to 1.
F.softmax(torch.tensor([.9,0.05,0.05]), dim=-1)

tensor([0.5391, 0.2304, 0.2304])

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  The input is projected to per-head queries, keys and values, attention is
  computed independently in every head, and the concatenated head outputs are
  projected back to ``d_model``.

  Args:
    d_k: Size of each head's query/key vectors.
    d_v: Size of each head's value vectors (may differ from ``d_k``).
    d_model: Size of the input and output embeddings.
    n_heads: Number of attention heads.
  """

  def __init__(self, d_k, d_v, d_model, n_heads):
    super().__init__()

    self.d_k = d_k
    self.d_v = d_v
    self.n_heads = n_heads

    self.W_q = nn.Linear(d_model, d_k * n_heads)
    self.W_k = nn.Linear(d_model, d_k * n_heads)
    self.W_v = nn.Linear(d_model, d_v * n_heads)
    self.fc = nn.Linear(d_v * n_heads, d_model)

  def forward(self, X, mask=None):
    """Apply self-attention to ``X``.

    Args:
      X: Input of shape N × T × d_model.
      mask: Optional N × T tensor; positions where ``mask == 0`` are hidden
        from every query (their attention scores are set to -inf). ``None``
        disables masking.

    Returns:
      Tensor of shape N × T × d_model.
    """
    Q = self.W_q(X)  # N × T × h*d_k
    K = self.W_k(X)  # N × T × h*d_k
    V = self.W_v(X)  # N × T × h*d_v

    N, T = Q.shape[0], Q.shape[1]

    # Split the projection into heads and move the head axis next to the batch.
    Q = Q.view(N, T, self.n_heads, self.d_k).transpose(1, 2)  # N × h × T × d_k
    K = K.view(N, T, self.n_heads, self.d_k).transpose(1, 2)  # N × h × T × d_k
    # Values use d_v, not d_k: W_v produces h*d_v features per position.
    V = V.view(N, T, self.n_heads, self.d_v).transpose(1, 2)  # N × h × T × d_v

    # d_k is a plain Python int, so math.sqrt (not torch.sqrt) is required.
    AttentionScores = Q @ K.transpose(2, 3) / math.sqrt(self.d_k)  # N × h × T × T

    # Explicit None check: the default mask=None must skip masking, and the
    # truth value of a multi-element tensor is ambiguous.
    if mask is not None:
      # N × T -> N × 1 × 1 × T so the mask broadcasts over heads and queries.
      mask = mask.unsqueeze(1).unsqueeze(1)
      AttentionScores = AttentionScores.masked_fill(mask == 0, float('-inf'))

    AttentionWeights = F.softmax(AttentionScores, dim=-1)  # N × h × T × T

    A = AttentionWeights @ V  # N × h × T × d_v
    # transpose() yields a non-contiguous tensor, which view() rejects;
    # reshape() copies when needed.
    A = A.transpose(1, 2).reshape(N, T, self.n_heads * self.d_v)  # N × T × h*d_v

    return self.fc(A)


In [70]:
# Inserting a size-1 axis at position 1 twice grows (2, 2) into (2, 1, 1, 2),
# printing the shape after each step.
x = torch.tensor([[3.0, 4.0], [1.0, 6.0]])
print(x.shape)
x = x.unsqueeze(dim=1)
print(x.shape)
x = x.unsqueeze(dim=1)
x.shape

torch.Size([2, 2])
torch.Size([2, 1, 2])


torch.Size([2, 1, 1, 2])

In [74]:
# Softmax along the last dim: each length-2 row of x is normalized on its own;
# the size-1 axes added by unsqueeze do not affect the result.
F.softmax(x, dim=-1)

tensor([[[[0.2689, 0.7311]]],


        [[[0.0067, 0.9933]]]])

In [54]:
# Two random 4-D tensors whose trailing dims, (3, 4) and (4, 3), are
# compatible for a batched matrix product.
a = torch.rand((2, 5, 3, 4))
b = torch.rand((2, 5, 4, 3))
a.shape

torch.Size([2, 5, 3, 4])

In [56]:
# NOTE: view() only reinterprets the same elements under a new shape; the
# result has the shape of a transposed tensor but is not a.transpose(2, 3).
a.view(2,5,4,3).shape

torch.Size([2, 5, 4, 3])

In [52]:
# Batched matrix product: @ multiplies the last two dims of a and b and
# treats the leading dims as batch dims.
a@b

tensor([[[[0.9868, 1.1049, 1.2843],
          [1.2438, 0.9676, 0.9669],
          [0.3115, 0.1481, 0.3951]],

         [[1.3265, 0.6358, 1.6556],
          [0.4531, 0.6239, 0.8036],
          [1.5561, 0.8614, 2.1342]],

         [[0.5977, 0.6377, 1.1061],
          [0.8970, 0.9109, 1.5910],
          [0.4210, 0.5384, 0.3385]],

         [[1.7268, 1.1087, 1.2237],
          [1.7726, 0.6309, 0.9878],
          [1.6617, 0.6330, 0.8079]],

         [[1.3154, 1.7169, 0.5657],
          [0.6774, 1.0459, 0.6286],
          [1.3178, 1.1842, 0.2263]]],


        [[[2.0023, 0.5348, 1.9632],
          [2.0712, 0.9349, 1.8422],
          [1.1662, 0.7024, 1.1342]],

         [[0.8181, 1.1855, 1.2031],
          [1.0903, 1.5603, 1.6206],
          [1.2615, 1.2045, 1.0205]],

         [[0.9675, 1.2135, 1.5604],
          [0.9938, 1.1461, 0.9769],
          [1.0019, 0.8673, 1.4044]],

         [[0.7846, 1.6117, 0.9175],
          [0.5791, 1.0150, 0.6483],
          [0.5590, 1.0311, 0.5906]],

        

In [51]:
# With a of shape (2, 5, 3, 4) and b of shape (2, 5, 4, 3), the product
# has shape (2, 5, 3, 3).
(a@b).shape

torch.Size([2, 5, 3, 3])