In [1]:
import torch
from torch import nn
import torch.nn.functional as f
import numpy as np

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Softmax is a "soft" argmax; this alias is what the attention code was written against.
nn_softargmax = nn.Softmax

# Multi-head attention

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected to ``d_model`` features,
    split into ``num_heads`` heads of width ``d_k = d_model // num_heads``,
    attended independently per head, concatenated back together and mixed by
    a final linear layer ``W_h``.

    Args:
        d_model: model width; must be a multiple of ``num_heads``.
        num_heads: number of attention heads.
        p: dropout probability. NOTE(review): accepted for API compatibility
            with the callers, but no dropout is applied anywhere in this module.
        d_input: optional ``(d_xq, d_xk, d_xv)`` feature sizes of the raw
            query/key/value inputs; when ``None`` all three equal ``d_model``.
    """

    def __init__(self, d_model, num_heads, p, d_input=None):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        if d_input is None:
            d_xq = d_xk = d_xv = d_model
        else:
            d_xq, d_xk, d_xv = d_input
        # make sure the embedding dimension of model is a multiple of number of heads
        assert d_model % self.num_heads == 0

        # width of a single head
        self.d_k = d_model // self.num_heads

        # bias-free input projections for the query / key / value roles
        self.W_q = nn.Linear(d_xq, d_model, bias=False)
        self.W_k = nn.Linear(d_xk, d_model, bias=False)
        self.W_v = nn.Linear(d_xv, d_model, bias=False)

        # output projection applied to the concatenation of all heads
        self.W_h = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        """Attend over values ``V`` using queries ``Q`` and keys ``K``.

        The last two dims of each input are (length, features); any leading
        dims (e.g. batch and head) must broadcast under ``torch.matmul``, so
        both (batch, heads, len, d_k) and (batch, len, d_k) inputs work.

        Returns:
            H: attention-weighted values, shape (..., q_len, v_features).
            A: attention weights, a softmax over the key axis,
               shape (..., q_len, k_len).
        """
        # divide by sqrt(d_k) so the dot products do not grow with the head
        # width and saturate the softmax
        scores = torch.matmul(Q / np.sqrt(self.d_k), K.transpose(-2, -1))
        A = torch.softmax(scores, dim=-1)
        # weighted average of the values
        H = torch.matmul(A, V)
        return H, A

    def split_heads(self, x, batch_size):
        """Split the last dim into heads.

        (batch, seq, num_heads * d_k) -> (batch, num_heads, seq, d_k)
        """
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def group_heads(self, x, batch_size):
        """Concatenate the heads back into one feature dim (inverse of split_heads).

        (batch, num_heads, seq, d_k) -> (batch, seq, num_heads * d_k)
        """
        return x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

    def forward(self, X_q, X_k, X_v):
        """Multi-head attention of ``X_q`` over ``X_k`` / ``X_v``.

        Args:
            X_q: (batch, q_len, d_xq) queries.
            X_k: (batch, k_len, d_xk) keys.
            X_v: (batch, k_len, d_xv) values.

        Returns:
            H: (batch, q_len, d_model) attended output after ``W_h``.
            A: (batch, num_heads, q_len, k_len) per-head attention weights.
        """
        batch_size = X_q.size(0)
        # project, then split each projection into num_heads heads
        Q = self.split_heads(self.W_q(X_q), batch_size)
        K = self.split_heads(self.W_k(X_k), batch_size)
        V = self.split_heads(self.W_v(X_v), batch_size)
        # attention weights and outputs for every head at once
        H_cat, A = self.scaled_dot_product_attention(Q, K, V)
        # put all heads back together by concatenation
        H_cat = self.group_heads(H_cat, batch_size)
        # final linear layer
        H = self.W_h(H_cat)
        return H, A

## Some sanity checks

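To check that self-attention works: if the query matches exactly one of the keys, all of the attention should be focused on that key, and the output should be the value stored at that index. If the query matches several keys equally well, the attention is split evenly between them and the output is the average of their values.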

In [4]:
# Standalone attention module for the sanity checks below: width 512 split into 8 heads of size 64, p=0.
temp_mha = MultiHeadAttention(num_heads=8, d_model=512, p=0)

In [5]:
def print_out(Q,K,V):
    """Run ``temp_mha``'s scaled dot-product attention on (Q, K, V) and print
    the attention weights followed by the attended output, with all
    singleton dimensions squeezed away."""
    output, weights = temp_mha.scaled_dot_product_attention(Q, K, V)
    for label, tensor in (("Attention weights are: ", weights), ("Output is: ", output)):
        print(label, tensor.squeeze())

In [6]:
# Four keys; the last two are identical, so a query matching them splits its attention.
test_K = torch.tensor(
    [[10, 0, 0],
     [0, 10, 0],
     [0, 0, 10],
     [0, 0, 10]],
    dtype=torch.float32,
).unsqueeze(0).unsqueeze(0)
# One value row per key.
test_V = torch.tensor(
    [[1, 0, 0],
     [10, 0, 0],
     [100, 5, 0],
     [1000, 6, 0]],
    dtype=torch.float32,
).unsqueeze(0).unsqueeze(0)
# This query has a non-zero dot product with the second key only.
test_Q = torch.tensor([[0, 10, 0]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

print_out(test_Q,test_K,test_V)

Attention weights are:  tensor([3.7266e-06, 9.9999e-01, 3.7266e-06, 3.7266e-06])
Output is:  tensor([1.0004e+01, 4.0993e-05, 0.0000e+00])


In [7]:
# This query matches the last two (identical) keys equally, so attention is shared between them.
test_Q = torch.tensor([[0, 0, 10]], dtype=torch.float32)
print_out(test_Q,test_K,test_V)

Attention weights are:  tensor([1.8633e-06, 1.8633e-06, 5.0000e-01, 5.0000e-01])
Output is:  tensor([549.9979,   5.5000,   0.0000])


In [8]:
# Three queries at once: each row of the printed weight matrix belongs to one query.
test_Q = torch.tensor([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=torch.float32)
print_out(test_Q,test_K,test_V)

Attention weights are:  tensor([[1.8633e-06, 1.8633e-06, 5.0000e-01, 5.0000e-01],
        [3.7266e-06, 9.9999e-01, 3.7266e-06, 3.7266e-06],
        [5.0000e-01, 5.0000e-01, 1.8633e-06, 1.8633e-06]])
Output is:  tensor([[5.5000e+02, 5.5000e+00, 0.0000e+00],
        [1.0004e+01, 4.0993e-05, 0.0000e+00],
        [5.5020e+00, 2.0497e-05, 0.0000e+00]])


# 1D convolution with kernel_size=1

This is essentially an MLP with one hidden layer and a ReLU activation, applied independently to each element of the set.

In [9]:
class CNN(nn.Module):
    """Position-wise feed-forward block.

    Computes Linear(d_model -> hidden_dim) -> ReLU -> Linear(hidden_dim -> d_model)
    over the last dimension of its input, which is equivalent to two 1-D
    convolutions with kernel size 1.

    Args:
        d_model: input and output feature size.
        hidden_dim: width of the hidden layer.
        p: NOTE(review): accepted but not used by this module.
    """

    def __init__(self,d_model,hidden_dim,p):
        super().__init__()
        self.k1convL1 = nn.Linear(d_model,hidden_dim)
        self.k1convL2 = nn.Linear(hidden_dim,d_model)
        self.activation = nn.ReLU()

    def forward(self,x):
        # expand -> ReLU -> project back, as a single expression
        return self.k1convL2(self.activation(self.k1convL1(x)))

# Transformer encoder

In [10]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward block; each
    sub-layer's output is added to its input (residual connection) and the
    sum is passed through a LayerNorm.

    Args:
        d_model: feature size of the input and output.
        num_heads: number of attention heads.
        conv_hidden_dim: hidden width of the feed-forward block.
        p: forwarded to both sub-layers.
    """

    def __init__(self,d_model,num_heads,conv_hidden_dim,p=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model,num_heads,p)
        self.cnn = CNN(d_model,conv_hidden_dim,p)
        self.layernorm1 = nn.LayerNorm(normalized_shape=d_model,eps=1e-6)
        self.layernorm2 = nn.LayerNorm(normalized_shape=d_model,eps=1e-6)

    def forward(self, x):
        # self-attention (queries, keys and values are all x); the weights are discarded
        attended, _ = self.mha(x, x, x)
        normed = self.layernorm1(attended + x)
        # feed-forward sub-layer with its own residual + LayerNorm
        return self.layernorm2(self.cnn(normed) + normed)

## Encoder

A stack of N encoder layers, preceded by an input (word) embedding and a positional encoding.

Self-attention by itself has no recurrence or convolution, so it cannot tell the positions of its inputs apart. To make it position-aware we add positional encodings. For position $p$, embedding size $d$, and dimension indices $2i$ and $2i+1$, they are computed as follows:
$$ E(p,2i) = \sin\left(p/10000^{2i/d}\right) $$
$$ E(p,2i+1) = \cos\left(p/10000^{2i/d}\right) $$

In [11]:
def create_sin_embeddings(nb_p,dim,E):
    """Fill ``E`` in place with sinusoidal position encodings and freeze it.

    Row ``p`` column ``j`` gets ``sin(p / 10000**(2*(j//2)/dim))`` for even
    ``j`` and ``cos`` of the same angle for odd ``j``.

    Args:
        nb_p: number of positions (rows of ``E`` to fill).
        dim: embedding size (columns of ``E``).
        E: tensor of shape (nb_p, dim), e.g. an ``nn.Embedding.weight``;
            modified in place and left with ``requires_grad=False``.
    """
    positions = np.arange(nb_p)[:, None]
    # columns 2i and 2i+1 share the same frequency
    exponents = 2 * (np.arange(dim) // 2) / dim
    theta = positions / np.power(10000, exponents)
    # E may be a leaf Parameter that requires grad; in-place writes into it
    # must not be recorded by autograd (they raise a RuntimeError otherwise).
    with torch.no_grad():
        E[:, 0::2] = torch.as_tensor(np.sin(theta[:, 0::2]), dtype=E.dtype, device=E.device)
        E[:, 1::2] = torch.as_tensor(np.cos(theta[:, 1::2]), dtype=E.dtype, device=E.device)
    # keep the encodings fixed during training
    E.requires_grad_(False)

In [12]:
class Embeddings(nn.Module):
    """Token embedding plus sinusoidal position embedding, then LayerNorm.

    Args:
        d_model: embedding size.
        vocab_size: number of token ids; id 1 is used as the padding index.
        max_position_embeddings: number of rows in the position table, i.e. the
            longest sequence that can be embedded.
        p: NOTE(review): accepted but not used by this module.
    """

    def __init__(self,d_model,vocab_size,max_position_embeddings,p):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size,d_model,padding_idx=1)
        self.position_embeddings = nn.Embedding(max_position_embeddings,d_model)
        # overwrite the position table with fixed sinusoidal encodings
        create_sin_embeddings(
            nb_p=max_position_embeddings,
            dim=d_model,
            E=self.position_embeddings.weight,
        )
        self.LayerNorm = nn.LayerNorm(d_model,eps=1e-12)

    def forward(self,input_ids):
        # one row of positions 0..seq_len-1, broadcast to every sequence in the batch
        positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        positions = positions[None].expand_as(input_ids)
        summed = self.word_embeddings(input_ids) + self.position_embeddings(positions)
        return self.LayerNorm(summed)

In [13]:
class Encoder(nn.Module):
    """Transformer encoder: embeddings followed by ``num_layers`` EncoderLayers.

    Args:
        num_layers: number of stacked encoder layers.
        d_model: model width.
        num_heads: attention heads per layer.
        ff_hidden_dim: hidden width of each layer's feed-forward block.
        input_vocab_size: vocabulary size for the token embedding.
        maximum_position_encoding: size of the position-embedding table.
        p: forwarded to the embeddings and to every encoder layer.
    """

    def __init__(self,num_layers,d_model,num_heads,ff_hidden_dim,input_vocab_size,maximum_position_encoding,p=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = Embeddings(d_model,input_vocab_size,maximum_position_encoding,p)
        # independent (non-shared) layers, built in order
        self.enc_layers = nn.ModuleList(
            [EncoderLayer(d_model,num_heads,ff_hidden_dim,p) for _ in range(num_layers)]
        )

    def forward(self, x):
        hidden = self.embeddings(x)
        for layer in self.enc_layers:
            hidden = layer(hidden)
        return hidden

## Use transformer layers

In [7]:
from torchtext.datasets import IMDB