In [None]:
import torch as th
import torch.nn as nn

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are projected with full ``embed_size -> embed_size`` linear maps, then
    split into ``heads`` chunks of ``head_dim`` features. Each head attends
    independently, and the concatenated head outputs are mixed by ``fc_out``.

    Args:
        embed_size: model dimension D; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads  # e.g. 256 // 8 = 32

        assert (self.head_dim * heads == embed_size), "Embed size needs to div by heads"

        # Full D -> D projections; their outputs are split into heads in forward().
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)

        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` to ``keys`` and aggregate ``values``.

        Args:
            values: (N, value_len, embed_size); value_len must equal key_len.
            keys: (N, key_len, embed_size).
            query: (N, query_len, embed_size).
            mask: tensor broadcastable to (N, heads, query_len, key_len) whose
                zero entries mark positions to exclude, or None.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        values = self.values(values)   # (N, value_len, D)
        keys = self.keys(keys)         # (N, key_len, D)
        query = self.queries(query)    # (N, query_len, D)

        # Split the last dim into heads: (N, T, D) -> (N, T, heads, head_dim).
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        # Per-head similarity scores: (N, heads, query_len, key_len).
        energy = th.einsum("nqhd,nkhd->nhqk", [query, keys])

        if mask is not None:
            # Positions with mask == 0 get a huge negative score, so softmax gives
            # them ~0 weight (e.g. padding, or future tokens in the decoder).
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), where d_k is the per-head key size (head_dim),
        # not the full model width.
        attention = th.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of values per head: (N, query_len, heads, head_dim).
        out = th.einsum("nhql,nlhd->nqhd", [attention, values])

        # Concatenate heads back into the model dimension: (N, query_len, D).
        out = out.reshape(N, query_len, self.embed_size)

        return self.fc_out(out)




In [None]:
embed_size = 256
batch_size = 64
seq_len = 5
heads = 8  # 8 heads is a sensible default (it is also the Transformer class default)
model = SelfAttention(embed_size, heads)
x = th.randn((batch_size, seq_len, embed_size))
out = model(x, x, x, mask=None)
# Self-attention must preserve the (batch, seq, embed) shape.
assert out.shape == (batch_size, seq_len, embed_size)
out.shape

torch.Size([64, 5, 256])

In [None]:
x = th.randn((1, 8, 5, 5))
score = x.softmax(dim=-1)  # normalise over the last axis (same as dim=3 for this 4-D tensor)
score.sum(dim=-1)  # every row of weights should sum to 1

tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]])

In [None]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer layer: multi-head attention followed by a feed-forward net.

    Each sub-layer is wrapped in a residual connection, then LayerNorm, then dropout.

    Args:
        embed_size: model dimension D.
        heads: number of attention heads.
        dropout: dropout probability applied after each sub-layer's LayerNorm.
        forward_expansion: the feed-forward hidden width is
            ``embed_size * forward_expansion``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.layernorm1 = nn.LayerNorm(embed_size)
        self.layernorm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size * forward_expansion),
            nn.ReLU(),
            nn.Linear(embed_size * forward_expansion, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention plus feed-forward.

        Args:
            value, key: tensors attended over (equal to ``query`` for self-attention).
            query: the stream this layer updates; also the residual input.
            mask: attention mask forwarded to SelfAttention, or None.

        Returns:
            Tensor with the same shape as ``query``.
        """
        attention = self.attention(value, key, query, mask)

        # The residual must use `query`, not value/key. In the decoder's
        # cross-attention, value/key come from the encoder while `query` carries
        # the decoder stream, so `query` is the input this layer builds on.
        out = self.dropout(self.layernorm1(attention + query))

        # The position-wise feed-forward network is applied exactly once.
        ff_out = self.feed_forward(out)
        return self.dropout(self.layernorm2(ff_out + out))


In [None]:
# Smoke test for a single TransformerBlock (self-attention: value = key = query).
embed_size, heads, forward_expansion = 256, 8, 4
dropout = 0.4
batch_size, seq_len = 16, 5
model = TransformerBlock(embed_size, heads, dropout, forward_expansion)

x = th.randn((batch_size, seq_len, embed_size))
score = model(x, x, x, mask=None)
score.shape

torch.Size([16, 5, 256])

In [None]:
class Encoder(nn.Module):
    """Token embeddings plus learned positional embeddings, followed by stacked TransformerBlocks.

    Args:
        src_vocab_size: number of source tokens (rows of the word embedding).
        embed_size: model dimension D.
        num_layers: number of TransformerBlock layers.
        heads: attention heads per layer.
        device: device the position indices are moved to.
        forward_expansion: feed-forward width multiplier for each layer.
        dropout: dropout probability.
        max_len: largest supported sequence length (rows of the position embedding).
    """

    def __init__(self,
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_len
                 ):
        super(Encoder, self).__init__()

        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.pos_embedding = nn.Embedding(max_len, embed_size)

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(embed_size, heads, dropout, forward_expansion))
        self.layers = nn.ModuleList(blocks)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: integer token ids of shape (N, seq_len).
            mask: attention mask passed to every layer, or None.

        Returns:
            Tensor of shape (N, seq_len, embed_size).
        """
        batch, length = x.shape[0], x.shape[1]
        # Positions 0..length-1, one row per sequence in the batch.
        positions = th.arange(length).expand(batch, -1).to(self.device)
        hidden = self.dropout(self.word_embedding(x) + self.pos_embedding(positions))

        # Self-attention: value, key and query are all the running hidden state.
        for block in self.layers:
            hidden = block(hidden, hidden, hidden, mask)
        return hidden

  

In [None]:
th.arange(5).expand(4, -1)  # positions 0..4 repeated for a batch of 4 (the pattern Encoder uses)


tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])

In [None]:
embed_size = 256
heads = 8
dropout = 0.4
forward_expansion = 4
batch_size = 16
seq_len = 5
device = "cpu"
num_layers = 6
src_vocab_size = 10090
max_len = 100
model = Encoder(src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_len)

# Token ids must lie in [0, src_vocab_size); max_len bounds positions, not tokens.
x = th.randint(0, src_vocab_size, size=(batch_size, seq_len))
score = model(x, mask=None)
assert score.shape == (batch_size, seq_len, embed_size)
score.shape

torch.Size([16, 5, 256])

In [None]:
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention over the target, then a TransformerBlock
    that attends to the encoder output.

    Args:
        embed_size, heads, forward_expansion, dropout: layer hyper-parameters
            (the last three are forwarded to the inner TransformerBlock).
        device: accepted for signature symmetry; not used by this layer.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()

        self.attention = SelfAttention(embed_size, heads)
        self.layernorm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Run one decoder layer.

        Args:
            x: decoder hidden state.
            value, key: encoder output used for cross-attention.
            src_mask: mask over source positions so padded tokens are ignored, or None.
            trg_mask: mask for target self-attention (e.g. causal), or None.

        Returns:
            Updated decoder hidden state, same shape as ``x``.
        """
        self_attended = self.attention(x, x, x, trg_mask)
        # Residual + LayerNorm + dropout; the result is the query for cross-attention.
        decoder_query = self.dropout(self.layernorm(self_attended + x))
        return self.transformer_block(value, key, decoder_query, src_mask)


In [None]:
model = DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
# Decoder input plus random stand-ins for the encoder output (value and key).
x, value, key = (th.randn((batch_size, seq_len, embed_size)) for _ in range(3))

score = model(x, value, key, src_mask=None, trg_mask=None)
score.shape

torch.Size([16, 5, 256])

In [None]:
class Decoder(nn.Module):
    """Target embeddings, stacked DecoderBlocks, and a linear projection to the target vocabulary.

    Args:
        trg_vocab_size: number of target tokens (embedding rows and output width).
        embed_size: model dimension D.
        num_layers: number of DecoderBlock layers.
        heads, forward_expansion, dropout: forwarded to every DecoderBlock.
        device: device the position indices are moved to.
        max_len: largest supported target length (rows of the position embedding).
    """

    def __init__(self,
                 trg_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 forward_expansion,
                 dropout,
                 device,
                 max_len
                 ):
        super(Decoder, self).__init__()
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.pos_embedding = nn.Embedding(max_len, embed_size)
        self.layers = nn.ModuleList([
            DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.device = device
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_out, src_mask, trg_mask):
        """Decode target token ids against the encoder output.

        Args:
            x: target token ids, shape (N, trg_len).
            encoder_out: encoder output, used as value and key in every layer.
            src_mask, trg_mask: forwarded to every DecoderBlock (may be None).

        Returns:
            Unnormalised scores of shape (N, trg_len, trg_vocab_size).
        """
        batch, length = x.shape
        positions = th.arange(length).expand(batch, -1).to(self.device)
        hidden = self.dropout(self.word_embedding(x) + self.pos_embedding(positions))

        for block in self.layers:
            hidden = block(hidden, encoder_out, encoder_out, src_mask, trg_mask)

        # Project each position onto the target vocabulary.
        return self.fc_out(hidden)


In [None]:
trg_vocab_size, embed_size, num_layers, heads = 10000, 256, 6, 8
forward_expansion, dropout, device, max_len = 4, 0.4, "cpu", 100
decoder = Decoder(
    trg_vocab_size, embed_size, num_layers, heads,
    forward_expansion, dropout, device, max_len,
)
# Random target token ids and a random stand-in for the encoder output.
x = th.randint(0, trg_vocab_size, size=(batch_size, seq_len))
encoder_out = th.randn(batch_size, seq_len, embed_size)
score = decoder(x, encoder_out, src_mask=None, trg_mask=None)
score.shape

torch.Size([16, 5, 10000])

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence modelling.

    Args:
        src_vocab_size: number of source tokens.
        trg_vocab_size: number of target tokens (width of the output scores).
        src_pad_idx: source token id treated as padding; masked out of attention.
        trg_pad_idx: target padding id. NOTE(review): stored but not used here.
            Target padding is not masked, so handle it elsewhere if needed
            (e.g. by ignoring it in the loss).
        embed_size: model dimension D.
        num_layer: number of encoder layers and of decoder layers.
        forward_expansion, heads, dropout: forwarded to every layer.
        device: device the masks and position indices are placed on.
        max_len: largest supported sequence length.
    """

    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size = 256,
                 num_layer = 6,
                 forward_expansion = 4,
                 heads = 8,
                 dropout = 0,
                 device = "cpu",
                 max_len = 100
                 ):
        super(Transformer, self).__init__()

        # Both stacks take their depth from the `num_layer` argument. Reading a
        # notebook-global here would ignore the argument and fail on a fresh kernel.
        self.encoder = Encoder(src_vocab_size,
                               embed_size,
                               num_layer,
                               heads,
                               device,
                               forward_expansion,
                               dropout,
                               max_len
                               )
        self.decoder = Decoder(trg_vocab_size, embed_size, num_layer, heads,
                               forward_expansion, dropout, device,
                               max_len)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Build the source padding mask.

        Args:
            src: source token ids, shape (N, src_len).

        Returns:
            Bool tensor of shape (N, 1, 1, src_len): True where ``src`` is not padding.
            It broadcasts over heads and query positions in SelfAttention.
        """
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        """Build the causal mask for target self-attention.

        Args:
            trg: target token ids, shape (N, trg_len).

        Returns:
            Lower-triangular float tensor of ones, shape (trg_len, trg_len), on
            ``self.device``. Row i allows attention to target positions 0..i.
        """
        _, trg_len = trg.shape
        trg_mask = th.tril(th.ones((trg_len, trg_len)))
        # Place the mask on the model's device, matching make_src_mask.
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        """Run the full encoder-decoder.

        Args:
            src: source token ids, shape (N, src_len).
            trg: target token ids, shape (N, trg_len).

        Returns:
            Unnormalised scores of shape (N, trg_len, trg_vocab_size).
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        encoder_out = self.encoder(src, src_mask)
        out = self.decoder(trg, encoder_out, src_mask, trg_mask)
        return out

In [None]:
th.ones((5, 5)).tril()  # causal mask: row i keeps columns 0..i

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

In [None]:
# Toy batch: 16 source sequences of length 100 (the default max_len) and 16 target
# sequences of length 8, with ids drawn from a 10-token vocabulary where id 0 is padding.
src = th.randint(0, 10, size=(16, 100))
trg = th.randint(0, 10, size=(16, 8))

src_pad_idx = trg_pad_idx = 0
src_vocab_size = trg_vocab_size = 10
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx)
out = model(src, trg)

In [None]:
out.size()  # (batch, trg_len, trg_vocab_size)

torch.Size([16, 8, 10])