<a href="https://colab.research.google.com/github/constantin50/machine_learning/blob/master/transformer/multihead_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn.functional as F
from torch import nn

Multi-head attention takes three sequences: queries $Q$, keys $K$ and values $V$. It proceeds in four steps:

1) Compute the Gram matrix of queries and keys: $QK^T$

It gives a measure of relevance for each pair of tokens, one taken from Q and one from K.


2) Apply a mask to the Gram matrix

$mask[i][j] = 0$ if the model is allowed to attend to the $j$-th token when it predicts the $i$-th token

$mask[i][j] = -\infty$ otherwise

Key positions that are padding are also set to $-\infty$.

3) Normalize the relevance scores with a softmax over the keys

4) Weight the values by the normalized relevance scores and sum over the keys (independently for each head):

$ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}(QK^T + mask)\,V $







In [0]:
def Multihead_Attention(Q, K, V, K_padding_mask, dependency_mask, is_training, weights_dropout):
    """Multi-head dot-product attention (without 1/sqrt(KeySize) scaling).

    params
    ---
    Q - BatchSize x QueriesLen x HeadN x KeySize
    K - BatchSize x KeysLen x HeadN x KeySize
    V - BatchSize x KeysLen x HeadN x ValueSize
    K_padding_mask - BatchSize x KeysLen
        Key positions where the mask is true (nonzero) are excluded from attention.
    dependency_mask - QueriesLen x KeysLen
        Additive mask: 0 where query i may attend to key j, -inf where it may not.
    is_training - bool
        Enables dropout on the attention weights.
    weights_dropout - float
        Dropout probability applied to the normalized attention weights.

    returns
    ---
    tuple of two:
    1) BatchSize x QueriesLen x HeadN x ValueSize - features for each query for each head
    2) BatchSize x QueriesLen x KeysLen x HeadN - attention weights of each query over the keys

    notes
    ---
    A query whose keys are all masked gets NaN weights (softmax over only -inf).
    """
    # relevance of every key to every query, per head
    # BatchSize x QueriesLen x KeysLen x HeadN
    relevances = torch.einsum("bqhs,bkhs->bqkh", Q, K)

    # exclude padded key positions; masked_fill needs a bool mask, so coerce 0/1 masks
    key_mask = K_padding_mask.bool()[:, None, :, None]
    relevances = relevances.masked_fill(key_mask, float("-inf"))

    # additive dependency (e.g. causal) mask, broadcast over batch and heads
    relevances = relevances + dependency_mask[None, :, :, None]

    # normalize over the keys dimension
    normed_rels = F.softmax(relevances, dim=2)

    # dropout on the attention weights (no-op when is_training is False)
    normed_rels = F.dropout(normed_rels, p=weights_dropout, training=is_training)

    # weighted sum of values over keys, contracted directly instead of materializing
    # a BatchSize x QueriesLen x KeysLen x HeadN x ValueSize intermediate
    # BatchSize x QueriesLen x HeadN x ValueSize
    result = torch.einsum("bqkh,bkhv->bqhv", normed_rels, V)

    return result, normed_rels


In [0]:
class Multihead_SelfAttention(nn.Module):
    """Multi-head self-attention over a single sequence.

    The input is linearly projected to queries, keys and values. Each projection is
    split into ``n_heads`` heads of ``model_size // n_heads`` features and passed to
    ``Multihead_Attention``. The per-head outputs are concatenated back to
    ``model_size`` features; no output projection is applied.

    The attention weights from the most recent forward pass are kept, detached, in
    ``last_attention_map`` (BatchSize x Len x Len x HeadN).
    """

    def __init__(self, model_size, n_heads, dropout=0):
        super().__init__()
        assert model_size % n_heads == 0, 'model size should be divided by number of head'
        self.n_heads = n_heads

        self.Q_proj = nn.Linear(model_size, model_size)
        self.K_proj = nn.Linear(model_size, model_size)
        self.V_proj = nn.Linear(model_size, model_size)

        # dropout probability applied to the attention weights
        self.dropout = dropout

        self.last_attention_map = None

    def _split_heads(self, flat):
        """Reshape BatchSize x Len x ModelSize into BatchSize x Len x HeadN x HeadSize."""
        batch_size, max_len, _ = flat.shape
        return flat.view(batch_size, max_len, self.n_heads, -1)

    def forward(self, sequence, padding_mask, dependency_mask):
        """
        sequence : BatchSize x Len x ModelSize
          batch of texts
        padding_mask : BatchSize x Len
        dependency_mask - Len x Len

        result - BatchSize x Len x ModelSize
        """
        batch_size, max_len, model_size = sequence.shape

        # project to Q/K/V, then split ModelSize into (number of heads, features per head)
        Q = self._split_heads(self.Q_proj(sequence))
        K = self._split_heads(self.K_proj(sequence))
        V = self._split_heads(self.V_proj(sequence))

        # BatchSize x Len x HeadsN x ValueSize
        result, att_map = Multihead_Attention(Q, K, V,
                                              padding_mask, dependency_mask,
                                              self.training, self.dropout)

        # detach so the stored map does not keep the autograd graph alive
        self.last_attention_map = att_map.detach()

        # concatenate heads back to BatchSize x Len x ModelSize
        return result.reshape(batch_size, max_len, model_size)

In [0]:
class TransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Sub-layer 1: multi-head self-attention -> dropout -> residual add -> LayerNorm.
    Sub-layer 2: Linear -> ReLU -> Dropout -> Linear -> Dropout -> residual add -> LayerNorm.
    """

    def __init__(self, model_size, n_heads, dim_feedforward, dropout):
        super().__init__()
        # Submodules are registered in this exact order: it determines the order of
        # self.parameters(), which any weight-initialisation loop iterates over.
        self.self_attention = Multihead_SelfAttention(model_size, n_heads, dropout=dropout)
        self.first_dropout = nn.Dropout(dropout)
        self.first_norm = nn.LayerNorm(model_size)

        ffn_stages = [
            nn.Linear(model_size, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, model_size),
            nn.Dropout(dropout),
        ]
        self.feedforward = nn.Sequential(*ffn_stages)
        self.second_norm = nn.LayerNorm(model_size)

    def forward(self, sequence, padding_mask, dependency_mask):
        """
        sequence : BatchSize x Len x ModelSize
          batch of texts
        padding_mask : BatchSize x Len
        dependency_mask - Len x Len

        result - BatchSize x Len x ModelSize
        """
        # gather context from the whole sequence, then residual add + normalize
        context = self.self_attention(sequence, padding_mask, dependency_mask)
        attended = self.first_norm(sequence + self.first_dropout(context))

        # position-wise two-layer perceptron, then residual add + normalize
        return self.second_norm(attended + self.feedforward(attended))

**Encoder layer**

1) Self-attention to aggregate global context, plus a skip connection

2) Layer normalization

3) Two-layer perceptron (feed-forward block), plus a skip connection

4) Layer normalization

In [0]:
class MyTransformerEncoder(nn.Module):
    """Stack of ``n_layers`` encoder layers applied one after another.

    Every keyword argument in ``layer_kwargs`` is forwarded to each
    ``TransformerEncoderLayer`` (model_size, n_heads, dim_feedforward, dropout).
    At construction, all parameters with more than one dimension are re-initialised
    with Xavier-uniform.
    """

    def __init__(self, n_layers, **layer_kwargs):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(**layer_kwargs)
            for _ in range(n_layers)
        ])
        self.initialize_weights()

    def forward(self, sequence, mask, src_key_padding_mask):
        """
        sequence : BatchSize x Len x ModelSize
        mask : Len x Len - dependency mask passed to every layer
        src_key_padding_mask : BatchSize x Len - padding mask passed to every layer

        result - BatchSize x Len x ModelSize
        """
        for layer in self.layers:
            # layers take (sequence, padding_mask, dependency_mask) in that order
            sequence = layer(sequence, src_key_padding_mask, mask)
        return sequence

    def initialize_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension.

        One-dimensional parameters (biases, LayerNorm weights) keep their defaults.
        """
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)