<a href="https://colab.research.google.com/github/jprashant21/language-translation/blob/main/transformer%20encoder%20classes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## scaled dot-product self-attention
# example query shape: [batch, seq_len, emb_dim] = [1, 5, 768]

def self_attn_dot_product(query, key, value, mask=None):
    """Compute scaled dot-product attention.

    Returns ``softmax(query @ key^T / sqrt(d)) @ value``, where ``d`` is the
    size of the last dimension of ``query``.

    Args:
        query: tensor of shape ``(..., seq_q, d)``.
        key: tensor of shape ``(..., seq_k, d)``.
        value: tensor of shape ``(..., seq_k, d_v)``.
        mask: optional tensor broadcastable to ``(..., seq_q, seq_k)``.
            Positions where it equals 0 / False get no attention weight.
            ``None`` (the default) attends to every key. A query row whose
            keys are all masked produces NaN.

    Returns:
        Tensor of shape ``(..., seq_q, d_v)``.
    """
    import math

    dim_k = query.size(-1)
    # matmul (unlike bmm) accepts any number of leading batch dimensions,
    # e.g. (batch, heads, seq, d); for 3-D inputs it is identical to bmm.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(dim_k)
    if mask is not None:
        # -inf becomes exactly 0 after softmax, removing masked keys.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)


In [None]:
## attention head

class AttentionHead(nn.Module):
    """A single self-attention head.

    The input hidden state is projected to query, key and value vectors of
    size ``head_dim``, which are then combined by ``self_attn_dot_product``.

    Args:
        emb_dim: size of the last dimension of the input hidden state.
        head_dim: size of each projected query/key/value vector.
    """

    def __init__(self, emb_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(emb_dim, head_dim)
        self.k = nn.Linear(emb_dim, head_dim)
        self.v = nn.Linear(emb_dim, head_dim)

    def forward(self, hidden_state):
        """Attend over ``hidden_state``; the output's last dimension is ``head_dim``."""
        return self_attn_dot_product(
            self.q(hidden_state),
            self.k(hidden_state),
            self.v(hidden_state),
        )


In [None]:
## multi headed attention

class MultiHeadedAttention(nn.Module):
    """Multi-head self-attention built from independent ``AttentionHead`` modules.

    Each head works in an ``emb_dim // num_heads``-dimensional subspace. The
    head outputs are concatenated back to ``emb_dim`` and mixed by a final
    linear layer.

    Args:
        config: object providing the embedding size as ``emb_dim`` (falling
            back to ``hidden_size``) and the head count as ``num_heads``
            (falling back to ``num_attention_heads``).

    Raises:
        ValueError: if the head count is not positive or does not divide the
            embedding size evenly.
    """

    def __init__(self, config):
        super().__init__()
        # hidden_size matches the field the sibling modules read;
        # num_attention_heads is assumed to be the Hugging Face-style name.
        emb_dim = getattr(config, "emb_dim", None)
        self.emb_dim = emb_dim if emb_dim is not None else config.hidden_size
        num_heads = getattr(config, "num_heads", None)
        self.num_heads = (
            num_heads if num_heads is not None else config.num_attention_heads
        )
        # Otherwise the concatenated heads would not be emb_dim wide and
        # output_linear would fail with a shape mismatch.
        if self.num_heads <= 0 or self.emb_dim % self.num_heads != 0:
            raise ValueError(
                f"emb_dim ({self.emb_dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        self.head_dim = self.emb_dim // self.num_heads

        self.heads = nn.ModuleList(
            [AttentionHead(self.emb_dim, self.head_dim) for _ in range(self.num_heads)]
        )
        self.output_linear = nn.Linear(self.emb_dim, self.emb_dim)

    def forward(self, hidden_state):
        """Run every head on ``hidden_state``, concatenate, and project."""
        x = torch.cat([head(hidden_state) for head in self.heads], dim=-1)
        return self.output_linear(x)

In [None]:
## feed forward

class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout.

    Args:
        config: object with ``hidden_size``, ``intermediate_size`` and
            ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.l1 = nn.Linear(self.hidden_size, self.intermediate_size)
        self.l2 = nn.Linear(self.intermediate_size, self.hidden_size)
        # torch.nn has no ``GLEU``; the GELU activation module is ``nn.GELU``.
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Apply the sub-layer; the last dimension goes hidden -> intermediate -> hidden."""
        x = self.l1(x)
        x = self.gelu(x)
        x = self.l2(x)
        return self.dropout(x)



In [None]:
## layer norm and skip connection

class TransformerEncoderLayer(nn.Module):
    """Pre-layer-norm transformer encoder layer.

    Each sub-layer (multi-head attention, then feed-forward) is applied to a
    layer-normalized copy of its input. Its output is then added back to the
    un-normalized input through a skip connection.

    Args:
        config: passed to ``MultiHeadedAttention`` and ``FeedForward``;
            ``config.hidden_size`` sets the layer-norm width.
    """

    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadedAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x):
        # Normalize only the sub-layer input; the residual path keeps the
        # original x so the skip connection stays an identity mapping.
        x = x + self.attention(self.layer_norm_1(x))
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x


In [None]:
## positional embedding

class PostionalEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings.

    The two embeddings are summed, layer-normalized and passed through dropout.

    Args:
        config: object with ``vocab_size``, ``hidden_size``,
            ``hidden_dropout_prob`` and the maximum sequence length as
            ``max_position_size`` (falling back to ``max_position_embeddings``).
    """

    def __init__(self, config):
        super().__init__()
        max_positions = getattr(config, "max_position_size", None)
        if max_positions is None:
            # assumed Hugging Face-style name for the same setting
            max_positions = config.max_position_embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(max_positions, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        """Embed ``input_ids`` laid out as ``(batch, seq_len)``.

        Returns a tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        seq_len = input_ids.size(1)
        # Create position ids on the inputs' device so GPU inputs work too.
        position_ids = torch.arange(
            seq_len, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)

        # (1, seq_len, hidden) broadcasts over the batch dimension.
        embeddings = token_embeddings + position_embeddings
        x = self.layer_norm(embeddings)
        return self.dropout(x)



In [None]:
class TransformerEncoder(nn.Module):
    """Embedding layer followed by a stack of encoder layers.

    Args:
        config: passed to ``PostionalEmbeddings`` and to every
            ``TransformerEncoderLayer``; ``config.num_hidden_layers`` sets how
            many layers are stacked.
    """

    def __init__(self, config):
        super().__init__()
        self.embeddings = PostionalEmbeddings(config)
        layer_list = [
            TransformerEncoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ]
        self.encoderstack = nn.ModuleList(layer_list)

    def forward(self, x):
        """Embed ``x`` and run the result through each layer in order."""
        hidden = self.embeddings(x)
        for encoder_layer in self.encoderstack:
            hidden = encoder_layer(hidden)
        return hidden

# Build the full encoder and display the shape of its output.
# NOTE(review): `config` and `inputs` are not defined in this notebook
# excerpt; they must come from an earlier cell.
encoder = TransformerEncoder(config)
encoder(inputs.input_ids).shape
