In [1]:
import torch
import torch.nn as nn

# Self Attention

In [12]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding (size ``embed_size``) is split into ``heads`` pieces of size
    ``head_dim = embed_size // heads``. Each piece is passed through a
    ``head_dim -> head_dim`` projection (one projection per role, shared by all
    heads). Attention is computed independently per head, and the concatenated
    head outputs are mixed by a final linear layer.

    Args:
        embed_size: size of each word embedding (the last input dimension).
        heads: how many pieces the embedding is split into; must divide
            ``embed_size`` exactly.

    Raises:
        AssertionError: if ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads  # integer division

        assert (self.head_dim * heads == embed_size), "embed_size needs to be divisible by heads"

        # Per-head projections; the same weights are applied to every head.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        """Attend from ``queries`` over ``keys`` and aggregate ``values``.

        Args:
            values: expected shape (N, value_len, embed_size), N = batch size.
            keys: expected shape (N, key_len, embed_size); key_len must equal
                value_len, because both share the summed index below.
            queries: expected shape (N, query_len, embed_size).
            mask: ``None``, or a tensor broadcastable to
                (N, heads, query_len, key_len). Positions where ``mask == 0``
                receive (numerically) zero attention weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into self.heads pieces, then apply the learned
        # projections (previously these layers were created but never used).
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(queries.reshape(N, query_len, self.heads, self.head_dim))

        # queries: (N, query_len, heads, head_dim)
        # keys:    (N, key_len,   heads, head_dim)
        # score:   (N, heads, query_len, key_len)
        score = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scale by sqrt(d_k) where d_k is the per-head key size, as in scaled
        # dot-product attention; this keeps the softmax out of saturation.
        score = score / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the dtype's most negative finite value. A fixed -1e20 cannot
            # be represented in float16/bfloat16. A finite value, rather than
            # -inf, keeps fully masked rows from producing NaN.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        attention = torch.softmax(score, dim=3)

        # attention: (N, heads, query_len, key_len)
        # values:    (N, value_len, heads, head_dim)
        # out:       (N, query_len, heads, head_dim) -> concatenate the heads
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out)
