<a href="https://colab.research.google.com/github/mirpouya/Transformer_EDU/blob/main/einsum_and_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Consider the situation where you want to merge two dimensions of a 4D tensor.

In [1]:
import torch
from torch import nn

In [2]:
# Small 2-D example: 2 rows x 5 columns of random values.
rand_tensor = torch.rand(2, 5)
rand_tensor

tensor([[0.9049, 0.7497, 0.7549, 0.6756, 0.6255],
        [0.5269, 0.8023, 0.4166, 0.6086, 0.3092]])

In [3]:
# 4-D example tensor; its axes are read below as (b, c, h, w) = (3, 4, 5, 6).
rand_tensor = torch.rand(3, 4, 5, 6)
rand_tensor

tensor([[[[0.4272, 0.5442, 0.2517, 0.7889, 0.8619, 0.8654],
          [0.0666, 0.8560, 0.9473, 0.5588, 0.7690, 0.0586],
          [0.2043, 0.5678, 0.0451, 0.5862, 0.7480, 0.0267],
          [0.4652, 0.9349, 0.4918, 0.2418, 0.1082, 0.6770],
          [0.5552, 0.8100, 0.4213, 0.9764, 0.0273, 0.5012]],

         [[0.3132, 0.1171, 0.7283, 0.0282, 0.3125, 0.9177],
          [0.4310, 0.5009, 0.7681, 0.8635, 0.2355, 0.2121],
          [0.4530, 0.6935, 0.4342, 0.6764, 0.0459, 0.6705],
          [0.2524, 0.1069, 0.0038, 0.3601, 0.7893, 0.1103],
          [0.5747, 0.5341, 0.8096, 0.3766, 0.9147, 0.8954]],

         [[0.4420, 0.8605, 0.9480, 0.3037, 0.4664, 0.0353],
          [0.5814, 0.1024, 0.8893, 0.5554, 0.9763, 0.5051],
          [0.2044, 0.1190, 0.9412, 0.3974, 0.9323, 0.9973],
          [0.8388, 0.6731, 0.2284, 0.9694, 0.1819, 0.4629],
          [0.3343, 0.2152, 0.6220, 0.7947, 0.6908, 0.8189]],

         [[0.8629, 0.9964, 0.0746, 0.0200, 0.2734, 0.8139],
          [0.2242, 0.7166, 0.0622,

In [4]:
!pip install einops



In [5]:
from einops import rearrange, reduce, repeat

In [6]:
# Shape before rearranging.
rand_tensor.shape

torch.Size([3, 4, 5, 6])

In [7]:
# Fold the last axis (w) into the first axis (b); axes c and h keep their order.
# NOTE: this rebinds rand_tensor, so the cell is not idempotent — running it a
# second time feeds a 3-D tensor to a pattern that names four axes.
rand_tensor = rearrange(rand_tensor, "b c h w -> (b w) c h")

# b = 3, c = 4, h = 5, w = 6
# to (3 * 6, 4, 5)

In [8]:
# Shape after merging b and w.
rand_tensor.shape

torch.Size([18, 4, 5])

### <b> einsum </b>

In [9]:
# 3-D tensor of random values; show its shape.
a = torch.rand(10, 10, 30)
a.shape

torch.Size([10, 10, 30])

In [10]:
# Same idea with a different middle axis; show its shape.
a = torch.rand((10, 20, 30))
a.shape

torch.Size([10, 20, 30])

In [11]:
a = torch.randn(10, 20, 30)  # b = 10, i = 20, k = 30
b = torch.randn(10, 50, 30)  # b = 10, j = 50, k = 30

In [12]:
# For every batch index n, contract the shared k axis:
#   y[n, i, j] = sum_k a[n, i, k] * b[n, j, k]
# The k axis is absent from the output spec, so it is summed away.
y = torch.einsum("b i k, b j k -> b i j", a, b)

In [13]:
# (batch, i, j)
y.shape

torch.Size([10, 20, 50])

In [14]:
import numpy as np

# A is the vector [0, 1, 2]; B is the 3 x 4 matrix holding 0..11 in row-major order.
A = np.arange(3)

B = np.arange(12).reshape(3, 4)

In [15]:
# j does not appear in the output, so it is summed over:
#   out[i] = sum_j A[i] * B[i, j]   (each row sum of B scaled by the matching entry of A)
np.einsum('i,ij->i', A, B)

array([ 0, 22, 76])

# <b> Transformer's blocks implementation </b>

## <b> Scaled dot product self-attention </b>

In the article, after embedding and positional embedding, embeddings are multiplied by weight matrices  <b> W_Q </b>,  <b> W_K </b>,  <b> W_V </b>.

the input <b> X </b> to the attention block is of shape: (batch, sequence_len, embedding_dim)

The matrix multiplication happens in the <b> embedding_dim </b> dimension, regardless of batch_size and sequence_len

In [16]:
# Linear projection applied before the attention block.
embed_dim = 512
# Example batch: 10 sequences of 12 tokens, each token an embed_dim-sized vector.
# NOTE(review): the name `input` shadows Python's built-in input(); kept because
# other cells may refer to it.
input = torch.randn(10, 12, embed_dim)

# A single bias-free layer produces q, k and v together: its output is three
# times as wide as the embedding and is split into three tensors below.
qkv_weights = nn.Linear(in_features = embed_dim, out_features = 3 * embed_dim, bias = False)

qkv = qkv_weights(input)

# Split the last axis into (d, k = 3) and move k to the front; unpacking along
# that leading axis yields q, k and v.
q, k, v = rearrange(qkv, "b t (d k) -> k b t d", k = 3)

In [17]:
# q, k and v all share the input's shape.
q.shape, k.shape, v.shape

(torch.Size([10, 12, 512]),
 torch.Size([10, 12, 512]),
 torch.Size([10, 12, 512]))

<b> Step 2 </b>

Calculate the scaled dot product, apply the mask, and finally compute the softmax over the last dimension.

In [18]:
# scaled dot product
# Scores are scaled by 1 / sqrt(d_k), where d_k is the size of the axis that q
# and k are contracted over (their last axis). Multiplying by a large constant
# instead would saturate the softmax.
scale_factor = q.shape[-1] ** (-0.5)
scaled_dot_product = torch.einsum("b i d, b j d -> b i j", q, k) * scale_factor
# resulting shape: (batch_size, sequence_len, sequence_len)

# masking if needed (decoder)
# mask is expected to be a boolean (sequence_len, sequence_len) tensor whose
# True entries mark positions that must NOT be attended to.
mask = None
if mask is not None:
  assert mask.shape == scaled_dot_product.shape[1:]
  scaled_dot_product = scaled_dot_product.masked_fill(mask, float("-inf"))

# normalise over the key positions (last axis)
attention = torch.softmax(scaled_dot_product, dim = -1)

# multiply attention scores with V
print(f"attention size: {attention.size()}")
print(f"value size: {v.size()}")

# Contract over j, the axis the softmax normalised:
#   attention_final[b, i, d] = sum_j attention[b, i, j] * v[b, j, d]
attention_final = torch.einsum("b i j, b j d -> b i d", attention, v)

attention size: torch.Size([10, 12, 12])
value size: torch.Size([10, 12, 512])


## <b> Implementation of Scaled Dot-Product Self-Attention </b>

In [19]:
import torch
from torch import nn

import numpy
from einops import rearrange

In [20]:
class SelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention.

  One bias-free linear layer produces queries, keys and values from the input;
  the output is softmax(q k^T / sqrt(embed_dim)) v.
  """

  def __init__(self, embed_dim):
    """
    Args:
        embed_dim: embedding dimension, i.e. the size of the last axis of the
        3D tensor x that is passed to forward(x) (for example 512)
    """
    super().__init__()

    # weight matrices for query, key, and value, fused into one projection
    self.qkv_weights = nn.Linear(embed_dim, embed_dim * 3, bias = False)
    # 1 / sqrt(d_k); with a single head d_k equals embed_dim
    self.scale_factor = embed_dim ** (-0.5)

  # forward method
  def forward(self, x, mask = None):
    """
    Args:
        x: tensor of shape (batch_size, sequence_len, embed_dim)
        mask: optional boolean tensor of shape (sequence_len, sequence_len);
        True marks (query, key) pairs that must not attend to each other
    Returns:
        tensor of shape (batch_size, sequence_len, embed_dim)
    """
    assert x.dim() == 3  # x must be a 3D tensor (batch_size, sequence_len, embed_dim)

    # step 1: joint projection, shape (batch_size, sequence_len, 3 * embed_dim)
    qkv = self.qkv_weights(x)

    # step 2: decomposing to q, k, v
    # rearranging to [3, batch_size, sequence_len, embed_dim]
    q, k, v = tuple(rearrange(qkv, "b t (d k) -> k b t d", k = 3))

    # scaled_dot_product, shape (batch_size, sequence_len, sequence_len)
    scaled_dot_product = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale_factor

    # mask attention
    if mask is not None:
      # the mask covers (query position, key position), i.e. the score matrix
      assert mask.shape == scaled_dot_product.shape[1:]
      scaled_dot_product = scaled_dot_product.masked_fill(mask, float("-inf"))

    attention = torch.softmax(scaled_dot_product, dim = -1)
    # V must be contracted over the same axis the softmax was applied to (j, the key positions)
    attention_final = torch.einsum("b i j, b j d -> b i d", attention, v)

    return attention_final

## <b> Implementing Multi-Head_Self-Attention </b>

In the single-head case, we project the embedded and positionally encoded input with a weight matrix of size (embed_dim, embed_dim * 3).

In the multi-head case, we project the input matrix with a weight matrix of size (embed_dim, head_dim * n_heads * 3), which is the same size, but in the `rearrange()` step it is easy to separate the heads.

In [21]:
# Multi-head setup: 8 heads, each working on a 512 // 8 = 64-dimensional slice.
n_heads = 8
head_dim = 512 // 8
# Example batch: 10 sequences of 12 tokens with 512 features each.
x = torch.randn((10, 12, 512))

# One bias-free projection yields q, k and v for all heads at once.
qkv_weights = nn.Linear(in_features = embed_dim, out_features = head_dim * n_heads * 3, bias = False)
qkv = qkv_weights(x)

In [22]:
# decompose qkv to q, k, v
# The last axis of qkv is split as (h n k): h = features per head, n = heads,
# k = the q/k/v selector (size 3). Moving k to the front lets tuple() unpack it
# into three tensors laid out as (batch, n_heads, seq_len, head_dim).
q, k, v = tuple(rearrange(qkv, "b s (h n k) -> k b n s h", k = 3, n = n_heads))

In [23]:
# (batch, seq_len, head_dim * n_heads * 3)
qkv.shape

torch.Size([10, 12, 1536])

In [24]:
# each of q, k, v: (batch, n_heads, seq_len, head_dim)
tuple(t.shape for t in (q, k, v))

(torch.Size([10, 8, 12, 64]),
 torch.Size([10, 8, 12, 64]),
 torch.Size([10, 8, 12, 64]))

The next step is to calculate the `scaled-dot-product`, apply the mask, and finally compute the `softmax` over the last dimension (the key positions), independently for each head.

In [25]:
# matrix multiplication of q and k in heads
# q, k -> (batch_size, n_heads, seq_len, dim_head)
# Each head contracts over dim_head features, so the scores are scaled by
# 1 / sqrt(dim_head). It is recomputed here instead of reusing a value left over
# from an earlier cell.
scale_factor = q.shape[-1] ** (-0.5)
scaled_dot_product = torch.einsum("b n s d, b n t d -> b n s t", q, k) * scale_factor   # (batch_size, n_heads, seq_len, tokens) "seq_len = tokens"

if mask is not None:
  # check mask shape: (seq_len, tokens), broadcast over batch and heads
  assert mask.shape == scaled_dot_product.shape[2:]
  scaled_dot_product = scaled_dot_product.masked_fill(mask, float("-inf"))

# normalise over the key positions (last axis)
attention = torch.softmax(scaled_dot_product, dim = -1)

now that the attention is computed, we must multiply the attention score of each head with the corresponding value of each head

In [26]:
# attention shape : (batch_size, n_heads, sentence_words, sentence_words)
# value shape : (batch_size, n_heads, sentence_words, head_dim)
# j (the axis the softmax normalised) is summed away, so each output row is a
# weighted sum of that head's value vectors.

out = torch.einsum("b n i j, b n j d -> b n i d", attention, v)

it's time to merge heads into one matrix and multiply it by W_O

In [27]:
# concatenating heads
# The n (heads) and d (per-head features) axes are folded into one last axis.
# NOTE: this rebinds `out`, so re-running the cell on the already-merged
# 3-D tensor no longer matches the four-axis pattern.
out = rearrange(out, "b n i d -> b i (n d)")

$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_8)\, W^O$ <br>
$\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$

In [28]:
# Output projection W_O: a bias-free linear map applied to the concatenated heads.
W_O = nn.Linear(in_features = embed_dim, out_features = embed_dim, bias = False)

final_output = W_O(out)

## <b> Implementation of Multi-Head-Attention </b>

In [29]:
import torch
from torch import nn
from einops import rearrange
import numpy as np

class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  The input is projected once into queries, keys and values for all heads,
  attention is computed independently per head, and the concatenated head
  outputs are passed through the output projection W_O.
  """

  def __init__(self, embed_dim, n_heads):
    """
    Args:
        embed_dim: size of the last axis of the input to forward()
        n_heads: number of attention heads
    """
    super().__init__()

    self.n_heads = n_heads
    self.dim_head = embed_dim // self.n_heads
    # NOTE: floor division means this equals the requested embed_dim only when
    # embed_dim is a multiple of n_heads; otherwise the layers below are built
    # for the smaller size dim_head * n_heads.
    self.embed_dim = self.dim_head * n_heads
    # 1 / sqrt(d_k), where d_k is the per-head feature size
    self.scale_factor = self.dim_head ** (-0.5)

    # weight matrices
    self.qkv_weights = nn.Linear(self.embed_dim, n_heads * self.dim_head * 3, bias = False)
    self.W_O = nn.Linear(self.embed_dim, self.embed_dim, bias = False)

  # forward method
  def forward(self, x, mask = None):
    """
    Args:
        x: tensor of shape (batch_size, sentence_length, embedding_dim)
        mask: optional boolean tensor of shape (sentence_length, sentence_length);
        True marks (query, key) pairs that must not attend to each other. It is
        broadcast over batch and heads.
    Returns:
        tensor of shape (batch_size, sentence_length, embedding_dim)
    """
    # check x has all 3 dimensions -> (batch_size, sentence_length, embedding_dim)
    assert x.dim() == 3

    # step 1: compute query, key, value
    qkv = self.qkv_weights(x)   # (batch_size, sentence_length, dim_head * n_heads * 3)

    # step 2: decompose to q, k, v
    # resulting shape before tuple():
    # [3, batch_size, n_heads, sentence_len, head_dim]
    q, k, v = tuple(rearrange(qkv, "b s (d n k) -> k b n s d", k = 3, n = self.n_heads))

    # step 3: compute scaled_dot_product between queries and keys
    # shape: (batch_size, n_heads, sentence_len, sentence_len)
    scaled_dot_product = torch.einsum("b n s d, b n t d -> b n s t", q, k) * self.scale_factor

    # mask if needed
    if mask is not None:
      assert mask.shape == scaled_dot_product.shape[2:]
      scaled_dot_product = scaled_dot_product.masked_fill(mask, float("-inf"))

    # normalise over key positions
    attention = torch.softmax(scaled_dot_product, dim = -1)

    # step 4: calculate output (weighted sum of values per head)
    out = torch.einsum("b n s j, b n j d -> b n s d", attention, v)

    # step 5: merge heads -> (batch_size, sentence_len, n_heads * head_dim)
    out = rearrange(out, "b n s d -> b s (n d)")

    # step 6: apply W_O
    output = self.W_O(out)

    return output

## <b> Transformer Encoder Block </b>

In [29]:
class TransformerBlock(nn.Module):
  def __init__(self, embed_dim, n_heads = 8, dim_head = None, dim_linear_block = 1024, droupout_rate = 0.1):
    super().__init__()

    self.multi_head_att = MultiHeadAttention(embed_dim=)