# Replication of the paper "Attention Is All You Need"

https://arxiv.org/pdf/1706.03762.pdf

In [None]:
# For this notebook to run with updated APIs, we need torch 1.12+ and torchvision 0.13+
try:
    import torch
    import torchvision
    assert int(torch.__version__.split(".")[1]) >= 12, "torch version should be 1.12+"
    assert int(torchvision.__version__.split(".")[1]) >= 13, "torchvision version should be 0.13+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except:
    print(f"[INFO] torch/torchvision versions not as required, installing nightly versions.")
    !pip3 install -U torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
    import torch
    import torchvision
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")

torch version: 1.13.1+cu116
torchvision version: 0.14.1+cu116


In [None]:
# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

# Try to get torchinfo, install it if it doesn't work
try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary

# Try to import the going_modular directory, download it from GitHub if it doesn't work
try:
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves
except:
    # Get the going_modular scripts
    print("[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.")
    !git clone https://github.com/mrdbourke/pytorch-deep-learning
    !mv pytorch-deep-learning/going_modular .
    !mv pytorch-deep-learning/helper_functions.py . # get the helper_functions.py script
    !rm -rf pytorch-deep-learning
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves

[INFO] Couldn't find torchinfo... installing it.
[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.
Cloning into 'pytorch-deep-learning'...
remote: Enumerating objects: 3435, done.[K
remote: Counting objects: 100% (133/133), done.[K
remote: Compressing objects: 100% (87/87), done.[K
remote: Total 3435 (delta 55), reused 97 (delta 41), pack-reused 3302[K
Receiving objects: 100% (3435/3435), 643.58 MiB | 24.70 MiB/s, done.
Resolving deltas: 100% (1962/1962), done.
Updating files: 100% (222/222), done.


In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

# Data

# Self-Attention

In [None]:
class SelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The embedding is split into `heads_num` chunks of size `head_dim`. Each of the
  value/key/query projections is a single bias-free `head_dim -> head_dim` linear
  layer applied to the last dimension, so its weights are shared by all heads.
  The heads are concatenated and passed through a final linear layer.

  Args:
    embedding_size: Size of the input/output embedding. Must be divisible by `heads_num`.
    heads_num: Number of attention heads.

  Raises:
    AssertionError: If `embedding_size` is not divisible by `heads_num`.
  """
  def __init__(self, embedding_size, heads_num):
    super().__init__()
    self.embedding_size = embedding_size
    self.heads_num = heads_num
    self.head_dim = embedding_size//heads_num

    assert (self.head_dim * heads_num == embedding_size), "Embedding size needs to divisible by head number"

    self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

    self.fully_connected = nn.Linear(heads_num*self.head_dim, embedding_size)

  def forward(self, values, keys, query, mask):
    """Attend from `query` over `keys`/`values`.

    Args:
      values: Tensor reshaped here to (N, value_len, heads_num, head_dim).
      keys: Tensor reshaped here to (N, key_len, heads_num, head_dim).
      query: Tensor reshaped here to (N, query_len, heads_num, head_dim).
      mask: Optional tensor broadcastable to (N, heads_num, query_len, key_len);
        positions where `mask == 0` are excluded from attention. `None` disables masking.

    Returns:
      Tensor of shape (N, query_len, embedding_size).
    """
    N = query.shape[0]
    value_len, keys_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

    # Split the embedding dimension into separate heads.
    values = values.reshape(N, value_len, self.heads_num, self.head_dim)
    keys = keys.reshape(N, keys_len, self.heads_num, self.head_dim)
    queries = query.reshape(N, query_len, self.heads_num, self.head_dim)

    values = self.values(values)
    keys = self.keys(keys)
    queries = self.queries(queries)

    # Dot product of every query with every key, per head: (N, heads, query_len, key_len).
    energy = torch.einsum("nqhd,nkhd->nhqk",[queries,keys])

    if mask is not None:
      # A very large negative score becomes a zero weight after the softmax.
      energy = energy.masked_fill(mask == 0,  float("-1e20"))

    # Scale by sqrt(d_k), the per-head key dimension over which each dot product
    # was taken (not the full embedding size).
    attention = torch.softmax(energy / (self.head_dim ** (1/2)), dim=3)

    # Weighted sum of the values, then concatenate the heads back together.
    out = torch.einsum("nhql,nlhd->nqhd",[attention, values]).reshape(
        N, query_len, self.heads_num*self.head_dim
    )

    out = self.fully_connected(out)
    return out


In [None]:
class TransformerBlock(nn.Module):
  """Attention sub-layer followed by a position-wise feed-forward sub-layer.

  Each sub-layer output is added to its input (residual connection), layer
  normalised, and then passed through dropout.

  NOTE(review): dropout is applied after LayerNorm here; confirm this ordering
  is intended.

  Args:
    embed_size: Embedding size of the inputs and outputs.
    heads: Number of attention heads.
    dropout: Dropout probability.
    forward_expansion: The feed-forward hidden layer has
      `forward_expansion * embed_size` units.
  """
  def __init__(self, embed_size, heads, dropout, forward_expansion):
    super().__init__()
    hidden_size = forward_expansion*embed_size

    self.attention = SelfAttention(embedding_size=embed_size, heads_num=heads)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)

    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, embed_size),
    )

    self.dropout = nn.Dropout(dropout)

  def forward(self, value, key, query, mask):
    """Run attention plus feed-forward; the result has the same shape as `query`."""
    # Sub-layer 1: attention with a residual connection to the query.
    attended = self.attention(value, key, query, mask)
    x = self.dropout(self.norm1(attended + query))

    # Sub-layer 2: feed-forward with a residual connection to its own input.
    transformed = self.feed_forward(x)
    return self.dropout(self.norm2(transformed + x))

In [None]:
class Encoder(nn.Module):
  """Stack of `TransformerBlock`s on top of token and learned position embeddings.

  Args:
    src_vocab_size: Number of entries in the source token embedding table.
    embed_size: Embedding size.
    num_layers: Number of `TransformerBlock`s to stack.
    heads: Number of attention heads per block.
    device: Device hint, stored on the instance. `forward` does not depend on
      it; positions are created on the input tensor's device.
    forward_expansion: Feed-forward expansion factor passed to each block.
    dropout: Dropout probability.
    max_length: Number of entries in the position embedding table, i.e. the
      longest sequence the encoder can embed.
  """
  def __init__(self,
               src_vocab_size,
               embed_size,
               num_layers,
               heads,
               device,
               forward_expansion,
               dropout,
               max_length):
    super().__init__()
    self.embed_size = embed_size
    # Store the value itself (a trailing comma here would wrap it in a tuple).
    self.device = device
    self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
    self.position_embedding = nn.Embedding(max_length, embed_size)

    self.layers = nn.ModuleList(
        [
            TransformerBlock(embed_size=embed_size,heads=heads, dropout=dropout, forward_expansion=forward_expansion)
            for _ in range(num_layers)
        ]
    )

    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    """Encode a batch of token ids.

    Args:
      x: Integer tensor of shape (N, seq_length) holding token ids.
      mask: Attention mask forwarded unchanged to every layer (may be `None`).

    Returns:
      Tensor of shape (N, seq_length, embed_size).
    """
    N, seq_length = x.shape
    # Build the position ids on the same device as the input, so the module works
    # wherever it has been moved without mutating `self.device`.
    positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

    out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

    # Self-attention: the same tensor serves as value, key and query.
    for layer in self.layers:
      out = layer(out,out,out,mask)
    return out

In [None]:
class DecoderBlock(nn.Module):
  """Decoder layer: attention over the target, then a `TransformerBlock`
  that attends over the supplied `value`/`key` tensors.

  Args:
    embed_size: Embedding size.
    heads: Number of attention heads.
    forward_expansion: Feed-forward expansion factor of the inner block.
    dropout: Dropout probability.
    device: Accepted for interface compatibility; not used by this class.
  """
  def __init__(self,
               embed_size,
               heads,
               forward_expansion,
               dropout,
               device):
    super().__init__()
    self.attention = SelfAttention(embedding_size=embed_size, heads_num=heads)
    self.norm = nn.LayerNorm(embed_size)
    self.transformer_block = TransformerBlock(
        embed_size=embed_size,
        heads=heads,
        dropout=dropout,
        forward_expansion=forward_expansion,
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, value, key, src_mask, trg_mask):
    """Return the block output for target input `x`.

    `x` attends to itself under `trg_mask`; the normalised result becomes the
    query of the inner block, which attends over `value`/`key` under `src_mask`.
    """
    target_attention = self.attention(x, x, x, trg_mask)
    normed = self.norm(target_attention + x)
    query = self.dropout(normed)
    return self.transformer_block(value, key, query, src_mask)

In [None]:

class Decoder(nn.Module):
  """Stack of `DecoderBlock`s followed by a linear projection onto the target vocabulary.

  Args:
    trg_vocab_size: Number of entries in the target token embedding table and
      size of the output projection.
    embed_size: Embedding size.
    num_layers: Number of `DecoderBlock`s to stack.
    heads: Number of attention heads per block.
    forward_expansion: Feed-forward expansion factor passed to each block.
    dropout: Dropout probability.
    device: Device hint, stored on the instance and forwarded to each block.
      `forward` does not depend on it; positions are created on the input
      tensor's device.
    max_length: Number of entries in the position embedding table, i.e. the
      longest sequence the decoder can embed.
  """
  def __init__(self,
               trg_vocab_size,
               embed_size,
               num_layers,
               heads,
               forward_expansion,
               dropout,
               device,
               max_length):
    super().__init__()
    self.device = device
    self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
    self.position_embedding = nn.Embedding(max_length, embed_size)

    self.layers = nn.ModuleList(
        [DecoderBlock(embed_size=embed_size,heads=heads, forward_expansion=forward_expansion,dropout=dropout,device=device)
        for _ in range(num_layers)]
    )

    self.fc_out = nn.Linear(embed_size, trg_vocab_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc_out, src_mask, trg_mask):
    """Decode a batch of target token ids against the encoder output.

    Args:
      x: Integer tensor of shape (N, seq_length) holding target token ids.
      enc_out: Encoder output, passed to every layer as both value and key.
      src_mask: Mask forwarded to each layer for attention over `enc_out`.
      trg_mask: Mask forwarded to each layer for attention over the target.

    Returns:
      Tensor of shape (N, seq_length, trg_vocab_size) produced by the final linear layer.
    """
    N, seq_length = x.shape
    # Build the position ids on the input's device rather than the constructor-time
    # hint, which may be stale after the module has been moved with `.to(...)`.
    positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
    x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

    for layer in self.layers:
      x = layer(x,enc_out,enc_out, src_mask, trg_mask)

    out = self.fc_out(x)
    return out


In [None]:
class Transformer(nn.Module):
  """Encoder-decoder model built from `Encoder` and `Decoder`.

  Args:
    src_vocab_size: Source vocabulary size.
    trg_vocab_size: Target vocabulary size.
    src_pad_idx: Token id treated as padding in the source; those positions are masked out.
    trg_pad_idx: Target padding token id. It is stored on the instance.
      NOTE(review): it is not used when building the target mask; confirm whether
      target padding should be masked as well.
    embed_size: Embedding size.
    num_layers: Number of layers in both the encoder and the decoder.
    forward_expansion: Feed-forward expansion factor.
    heads: Number of attention heads.
    dropout: Dropout probability.
    device: Device hint, stored and forwarded to the encoder/decoder. Masks are
      built on the device of the input tensors.
    max_length: Maximum sequence length supported by the position embeddings.
  """
  def __init__(self,
               src_vocab_size,
               trg_vocab_size,
               src_pad_idx,
               trg_pad_idx,
               embed_size=256,
               num_layers=6,
               forward_expansion=4,
               heads=8,
               dropout=0,
               device="cpu",
               max_length=100):
    super().__init__()

    self.encoder = Encoder(
        src_vocab_size=src_vocab_size, embed_size=embed_size,
         num_layers=num_layers, heads=heads, device=device,
         forward_expansion=forward_expansion,
        dropout=dropout, max_length=max_length
    )

    self.decoder = Decoder(
        trg_vocab_size=trg_vocab_size,
        embed_size=embed_size, num_layers=num_layers,
        heads=heads,forward_expansion=forward_expansion,
        dropout=dropout, device=device, max_length=max_length
    )

    self.src_pad_idx = src_pad_idx
    self.trg_pad_idx = trg_pad_idx
    self.device = device

  def make_src_mask(self, src):
    """Return a boolean mask of shape (N, 1, 1, src_len) that is False at padding tokens."""
    # The comparison result already lives on `src`'s device, so no transfer is needed.
    src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
    return src_mask

  def make_trg_mask(self, trg):
    """Return a lower-triangular mask of shape (N, 1, trg_len, trg_len).

    Entry (i, j) is 1 when j <= i and 0 otherwise, so position i can only attend
    to positions up to and including itself.
    """
    N, trg_len = trg.shape
    # Create the mask directly on the target's device instead of a stored device hint.
    trg_mask = torch.tril(torch.ones((trg_len,trg_len), device=trg.device)).expand(
        N, 1, trg_len, trg_len
    )
    return trg_mask

  def forward(self, src, trg):
    """Encode `src`, then decode `trg` against it.

    Args:
      src: Integer tensor of shape (N, src_len) with source token ids.
      trg: Integer tensor of shape (N, trg_len) with target token ids.

    Returns:
      The decoder output for `trg`.
    """
    src_mask = self.make_src_mask(src)
    trg_mask = self.make_trg_mask(trg)

    enc_src = self.encoder(src,src_mask)
    out = self.decoder(trg, enc_src, src_mask, trg_mask)
    return out


In [None]:
# Smoke test: push a tiny batch through the model and check the output shape.
device = "cpu"

x = torch.tensor([[1,5,6,4,3,9,5,2,0],[1,8,7,3,4,5,6,7,2]]).to(device)

trg = torch.tensor([[1,7,4,3,5,9,2,0],[1,5,6,2,4,7,6,2]]).to(device)

src_pad_idx = 0
trg_pad_idx = 0

src_vocab_size = 10
trg_vocab_size = 10

# The fourth positional argument is the target padding index, not the vocabulary size.
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)

# The decoder input is the target sequence without its last token.
out = model(x, trg[:, :-1])
print(out.shape)


torch.Size([2, 7, 10])


# My attempt

## Multihead Self-Attention

In [None]:
class MultiHeadSelfAttentionAndNormBlock(nn.Module):
  """Multi-head self-attention with a residual connection and layer normalisation.

  Args:
    embed_dim: Embedding size of the input. Must be divisible by `num_heads`.
    num_heads: Number of attention heads. The default is 8 because the previous
      default of 6 does not divide the default `embed_dim` of 512, which made
      default construction fail.
  """
  def __init__(self,
               embed_dim:int = 512,
               num_heads:int = 8,
               ):
    super().__init__()

    # batch_first=False: inputs are laid out as (seq_len, batch, embed_dim).
    self.multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim,
                                                num_heads=num_heads,
                                                batch_first=False)
    self.layer_norm = nn.LayerNorm(normalized_shape=embed_dim)

  def forward(self, x):
    """Return `LayerNorm(x + SelfAttention(x))`, with the same shape as `x`."""
    # Self-attention: the same tensor is used as query, key and value.
    # The attention weights are not needed, so they are not computed.
    attention_out, _ = self.multihead_attn(query=x,
                                           key=x,
                                           value=x,
                                           need_weights=False)

    residual_conn = attention_out + x
    block_out = self.layer_norm(residual_conn)
    return block_out

In [None]:
class FeedForwardAndNorm(nn.Module):
  """Two-layer feed-forward network with a residual connection and layer normalisation.

  Args:
    embed_dim: Size of the input and output features.
    ffn_size: Size of the hidden layer between the two linear layers.
  """

  def __init__(self,
               embed_dim:int = 512,
               ffn_size:int = 2048):
    super().__init__()

    # Expand to the hidden size, apply ReLU, then project back to the embedding size.
    expand = nn.Linear(embed_dim, ffn_size)
    activation = nn.ReLU()
    contract = nn.Linear(ffn_size, embed_dim)
    self.ffn = nn.Sequential(expand, activation, contract)

    self.layer_norm = nn.LayerNorm(embed_dim)

  def forward(self, x):
    """Return `LayerNorm(FFN(x) + x)`, with the same shape as `x`."""
    return self.layer_norm(self.ffn(x) + x)


## Encoder

In [None]:
class EncoderBlock(nn.Module):
  """Encoder layer: self-attention + norm followed by feed-forward + norm.

  Args:
    embed_dim: Embedding size of the input.
    num_heads: Number of attention heads; must divide `embed_dim`. The default
      is 8 because the previous default of 6 does not divide 512.
    ffn_size: Hidden size of the feed-forward sub-block.
  """
  # The constructor must be spelled `__init__`; under any other name it is never
  # called and the sub-blocks are never created.
  def __init__(self,
              embed_dim:int = 512,
              num_heads:int = 8,
              ffn_size = 512):
    super().__init__()

    self.mhsan_block = MultiHeadSelfAttentionAndNormBlock(embed_dim=embed_dim,
                                                num_heads=num_heads)

    self.ffn_block = FeedForwardAndNorm(embed_dim = embed_dim,
                                        ffn_size = ffn_size)

  def forward(self, x):
    """Apply the attention sub-block, then the feed-forward sub-block."""
    x = self.mhsan_block(x)
    x = self.ffn_block(x)
    return x


## Decoder

In [None]:
# NOTE(review): this class reuses the name `DecoderBlock` and therefore shadows the
# earlier `DecoderBlock` that `Decoder` instantiates; consider renaming one of them.
class DecoderBlock(nn.Module):
  """Decoder layer: two self-attention + norm sub-blocks followed by feed-forward + norm.

  Args:
    encoder_out: Currently unused; kept so existing positional calls still work.
      TODO: the second attention sub-block attends over `x` only, not over the
      encoder output.
    embed_dim: Embedding size of the input.
    num_heads: Number of attention heads; must divide `embed_dim`. The default
      is 8 because the previous default of 6 does not divide 512.
    ffn_size: Hidden size of the feed-forward sub-block.
  """
  # The constructor must be spelled `__init__`; under any other name it is never
  # called and the sub-blocks are never created.
  def __init__(self,
              encoder_out = None,
              embed_dim = 512,
              num_heads = 8,
              ffn_size = 512
              ):
    super().__init__()

    self.mhsan_block1 = MultiHeadSelfAttentionAndNormBlock(embed_dim=embed_dim,
                                                num_heads=num_heads)
    self.mhsan_block2 = MultiHeadSelfAttentionAndNormBlock(embed_dim=embed_dim,
                                                num_heads=num_heads)

    self.ffn_block = FeedForwardAndNorm(embed_dim = embed_dim,
                                        ffn_size = ffn_size)

  def forward(self, x):
    """Apply both attention sub-blocks in order, then the feed-forward sub-block."""
    # Use the attributes that the constructor actually defines.
    x = self.mhsan_block1(x)
    x = self.mhsan_block2(x)
    x = self.ffn_block(x)
    return x

# Transformer