In [5]:
# Import the necessary libraries
import math
from typing import Optional,List

import torch
from torch import nn
from labml import tracker

In [4]:
!pip install labml

Collecting labml
  Downloading labml-0.5.3-py3-none-any.whl.metadata (7.1 kB)
Downloading labml-0.5.3-py3-none-any.whl (94 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/94.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m92.2/94.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.6/94.6 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: labml
Successfully installed labml-0.5.3


## Prepare for multi-head attention

In [6]:
class prepareForMultiHeadAttention(nn.Module):
  """Project vectors with one linear layer and split the result into heads.

  The last dimension of the input is mapped from ``d_model`` to
  ``heads * d_k`` features, then viewed as ``(heads, d_k)`` so that every
  head owns a contiguous ``d_k``-sized slice. All leading dimensions are
  passed through unchanged.

  Args:
    d_model: size of the input's last dimension.
    heads: number of heads to split the projection into.
    d_k: number of features per head.
    bias: whether the linear projection has a bias term.
  """

  def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
    super().__init__()
    # Single projection that produces the features of all heads at once.
    self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
    # Head geometry used to reshape the projection's output.
    self.heads, self.d_k = heads, d_k

  def forward(self, x: torch.Tensor):
    """Return ``x`` projected and reshaped to ``(*x.shape[:-1], heads, d_k)``."""
    leading_dims = tuple(x.shape[:-1])
    projected = self.linear(x)
    return projected.view(*leading_dims, self.heads, self.d_k)

In [7]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Inputs are sequence-first: ``query``, ``key`` and ``value`` are
  ``[seq_len, batch_size, d_model]``. The output is
  ``[seq_len_of_query, batch_size, d_model]``.

  Args:
    heads: number of attention heads.
    d_model: number of features in the input and output vectors.
    dropout_prob: dropout probability applied to the attention weights.
    bias: whether the query and key projections have a bias term. The value
      projection always has a bias.

  Each head uses ``d_k = d_model // heads`` features. When ``d_model`` is not
  divisible by ``heads``, the concatenated heads have ``heads * d_k``
  features. The output layer projects them back to ``d_model``.
  """

  def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
    super().__init__()
    # Number of features per head.
    self.d_k = d_model // heads
    # Number of heads.
    self.heads = heads
    # Per-head projections for query, key and value. The value projection
    # always uses a bias, independent of the `bias` argument.
    self.query = prepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
    self.key = prepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
    self.value = prepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
    # Scores are laid out [seq_q, seq_k, batch, heads]. A softmax over dim 1
    # normalises each query's weights across the key positions.
    self.softmax = nn.Softmax(dim=1)
    # Output layer. Its input is the concatenation of all heads
    # (heads * d_k features). That equals d_model when d_model is divisible
    # by heads. Hard-coding d_model made forward() fail with a shape
    # mismatch whenever it was not divisible.
    self.output = nn.Linear(heads * self.d_k, d_model)
    # Dropout applied to the attention weights.
    self.dropout = nn.Dropout(dropout_prob)
    # 1/sqrt(d_k) scaling applied to the scores before the softmax.
    self.scale = 1 / math.sqrt(self.d_k)
    # Most recent attention weights (detached), kept for logging or other
    # inspection.
    self.attn = None

  def get_scores(self, query: torch.Tensor, key: torch.Tensor):
    """Compute raw dot-product scores between every query and key position.

    ``query`` is ``[seq_q, batch, heads, d_k]`` and ``key`` is
    ``[seq_k, batch, heads, d_k]``. The result is
    ``[seq_q, seq_k, batch, heads]``.
    """
    return torch.einsum('ibhd,jbhd->ijbh', query, key)

  def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
    """Validate ``mask`` and add a trailing heads axis so it broadcasts over heads.

    ``mask`` must be ``[seq_q or 1, seq_k, batch or 1]``. Entries equal to 0
    mark key positions that a query may not attend to. The returned mask has
    shape ``[seq_q or 1, seq_k, batch or 1, 1]``.
    """
    assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
    assert mask.shape[1] == key_shape[0]
    assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
    # The same mask is applied to every head.
    return mask.unsqueeze(-1)

  def forward(self, *,
              query: torch.Tensor,
              key: torch.Tensor,
              value: torch.Tensor,
              mask: Optional[torch.Tensor] = None):
    """Attend from ``query`` to ``key``/``value``.

    Args:
      query, key, value: ``[seq_len, batch_size, d_model]`` tensors.
        ``key`` and ``value`` share the same sequence length.
      mask: optional mask in the layout accepted by ``prepare_mask``. Zero
        entries are excluded from attention. A query row whose keys are all
        masked gets NaN weights, because the softmax of an all ``-inf`` row
        is undefined.

    Returns:
      ``[seq_len_of_query, batch_size, d_model]`` tensor.
    """
    seq_len, batch_size, _ = query.shape

    if mask is not None:
      mask = self.prepare_mask(mask, query.shape, key.shape)

    # Project and split into heads: [seq, batch, heads, d_k].
    query = self.query(query)
    key = self.key(key)
    value = self.value(value)

    # Scaled dot-product scores: [seq_q, seq_k, batch, heads].
    scores = self.get_scores(query, key) * self.scale

    # Masked positions get -inf so that they receive zero weight after the softmax.
    if mask is not None:
      scores = scores.masked_fill(mask == 0, float('-inf'))

    attn = self.softmax(scores)
    tracker.debug('attn', attn)
    attn = self.dropout(attn)

    # Weighted sum of the values: [seq_q, batch, heads, d_k].
    x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
    # NOTE: stored after dropout. In training mode these weights therefore
    # include dropout's zeroing and rescaling.
    self.attn = attn.detach()
    # Concatenate the heads: [seq_q, batch, heads * d_k].
    x = x.reshape(seq_len, batch_size, -1)
    return self.output(x)
