In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import time
import math 

In [2]:
# Model / data configuration shared by every later cell.
SEED = 0
torch.manual_seed(SEED)  # every input and weight below is random; seed for reproducibility

dmodel = 512      # embedding width
heads = 4         # number of attention heads; must divide dmodel
batch_size = 128
max_len = 100     # sequence length
assert dmodel % heads == 0, "dmodel must be divisible by heads"

# Random stand-in for a batch of already-embedded source sentences.
encoded_sentence = torch.randn((batch_size, max_len, dmodel))

In [13]:
def positional_encoding(encoded_sentence,shape):
    '''
    Adds sinusoidal positional encodings to a batch of embedded sequences.

    For position `pos` and even feature index `2i` (dmodel = shape[2]):
        PE[pos, 2i]   = sin(pos / 10000^(2i / dmodel))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / dmodel))

    Arguments:
            encoded_sentence : floating-point embedding tensor to be positionally encoded,
                               shape (batch_size, max_len, dmodel).
            shape : shape of the embedding tensor => tuple(batch_size, max_len, dmodel).
                    Only shape[1] (max_len) and shape[2] (dmodel) are used.

    Returns :
            encoded_sentence + PE, with the same shape, dtype and device as encoded_sentence.
            Odd dmodel is supported (the last feature is then a sine).
    '''
    max_len = shape[1]
    dmodel = shape[2]
    device = encoded_sentence.device

    # (max_len, 1) column of token positions 0 .. max_len-1
    position = torch.arange(0, max_len, device=device).float().unsqueeze(1)

    # 1 / 10000^(2i / dmodel) for every even feature index 2i, computed in log space
    div_term = torch.exp(torch.arange(0, dmodel, 2, device=device).float() * -(math.log(10000.0) / dmodel))

    # (max_len, ceil(dmodel / 2)) table of angles pos * div_term
    angles = position * div_term

    # The table is identical for every sequence in the batch, so build it once as
    # (max_len, dmodel) and let broadcasting add it to each batch element.
    pos_enc = torch.zeros((max_len, dmodel), device=device)
    pos_enc[:, 0::2] = torch.sin(angles)
    # Only floor(dmodel / 2) odd columns exist, so drop the surplus angle when dmodel is odd.
    pos_enc[:, 1::2] = torch.cos(angles[:, : dmodel // 2])

    # Add the encoding to the embeddings. Returning the bare table would discard the
    # token content and make every downstream layer independent of the input.
    return encoded_sentence + pos_enc.to(encoded_sentence.dtype)

In [6]:
def attention(k,q,v):
    '''
    Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    Attention is taken over the second-to-last axis (the sequence axis); any
    leading axes (batch, heads, ...) are treated independently and broadcast.

    Arguments:
            k : key,   shape (..., seq_k, d_k)
            q : query, shape (..., seq_q, d_k)
            v : value, shape (..., seq_k, d_v)
    Returns :
            tensor of shape (..., seq_q, d_v)
    '''
    # Scale by the key width read from k itself rather than the notebook globals
    # dmodel/heads, which are wrong whenever a different head count is passed in.
    d_k = k.shape[-1]
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v)
    
def multi_headed_attention(K,Q,V,heads):
    
    '''
    Multi-headed attention with freshly sampled (untrained) projection matrices.

    Arguments:
            K : key,   shape (batch_size, seq_k, dmodel)
            Q : query, shape (batch_size, seq_q, dmodel)
            V : value, shape (batch_size, seq_k, dmodel)
            heads : number of heads; must divide dmodel
    Returns :
            tensor of shape (batch_size, seq_q, dmodel); this equals K.shape when
            K and Q have the same shape (self-attention).
    Raises :
            ValueError if dmodel is not divisible by heads.
    '''
    dmodel = K.shape[-1]
    if dmodel % heads != 0:
        raise ValueError(f"dmodel ({dmodel}) must be divisible by heads ({heads})")
    head_size = dmodel // heads

    # Projection matrices are sampled fresh on every call (nothing here is trained).
    # Scaling by 1/sqrt(dmodel) keeps projected activations at roughly unit variance;
    # raw randn weights inflate them by ~sqrt(dmodel), which saturates the softmax.
    init_scale = 1.0 / math.sqrt(dmodel)
    Wk = torch.randn((dmodel, dmodel), device=K.device, dtype=K.dtype) * init_scale
    Wv = torch.randn((dmodel, dmodel), device=K.device, dtype=K.dtype) * init_scale
    Wq = torch.randn((dmodel, dmodel), device=K.device, dtype=K.dtype) * init_scale

    def split_heads(matrix):
        # (batch, seq, dmodel) -> (batch, heads, seq, head_size). The heads axis must sit
        # in front of the sequence axis so attention() mixes positions, not heads.
        batch, seq, _ = matrix.shape
        return matrix.reshape(batch, seq, heads, head_size).transpose(1, 2)

    K_heads = split_heads(K @ Wk)
    Q_heads = split_heads(Q @ Wq)
    V_heads = split_heads(V @ Wv)

    # (batch, heads, seq_q, head_size)
    per_head = attention(K_heads, Q_heads, V_heads)

    # Concatenate heads: (batch, heads, seq_q, head_size) -> (batch, seq_q, dmodel)
    return per_head.transpose(1, 2).reshape(Q.shape[0], Q.shape[1], dmodel)

In [7]:
def add_and_norm(positional_encoded,attention):
    '''
    Residual connection followed by layer normalisation ("Add & Norm", post-LN).

    Arguments:
            positional_encoded : sublayer input (the residual branch), shape (..., dmodel)
            attention : sublayer output, same shape as positional_encoded
    Returns :
            LayerNorm(positional_encoded + attention), normalised over the last axis
    '''
    # Add first, then normalise the sum. The previous form x + LN(sublayer(x)) left the
    # residual stream un-normalised, so its scale grew with every stacked sublayer.
    summed = positional_encoded + attention
    return F.layer_norm(summed, normalized_shape=(summed.shape[-1],))

In [8]:
def feed_forward_layer(matrix, hidden_dim=None):
    '''
    Position-wise feed-forward network Linear -> ReLU -> Linear over the last axis.

    The layers are created (randomly initialised) on every call, so nothing is learned.

    Arguments:
            matrix : input tensor, shape (..., dmodel)
            hidden_dim : width of the inner layer; defaults to dmodel, which matches the
                         previous 512 -> 512 -> 512 layout when dmodel is 512.
    Returns :
            tensor of shape (..., dmodel)
    '''
    # Read the model width from the input instead of hard-coding 512.
    dmodel = matrix.shape[-1]
    if hidden_dim is None:
        hidden_dim = dmodel
    net = nn.Sequential(
        nn.Linear(dmodel, hidden_dim, bias=True),
        nn.ReLU(),
        nn.Linear(hidden_dim, dmodel, bias=True),
    ).to(device=matrix.device, dtype=matrix.dtype)
    return net(matrix)

In [18]:
def encoder(X,shape,heads):
    '''
    One encoder layer: positional encoding, then self-attention + Add & Norm,
    then feed-forward + Add & Norm.

    Arguments:
            X : input tensor; assumed (batch_size, max_len, dmodel)
            shape : tuple(batch_size, max_len, dmodel), forwarded to positional_encoding
            heads : number of attention heads
    Returns :
            output of the second add_and_norm call
    '''
    x = positional_encoding(X, shape)

    # Self-attention sublayer: keys, queries and values all come from the same tensor.
    self_attended = multi_headed_attention(x, x, x, heads)
    x = add_and_norm(x, self_attended)

    # Feed-forward sublayer with its own residual + norm.
    return add_and_norm(x, feed_forward_layer(x))

In [60]:
def decoder(encoded_sequence , output ,shape , heads ,output_size):
    '''
    One decoder layer followed by a flatten -> Linear -> softmax output head.

    Arguments:
            encoded_sequence : encoder output, used as keys/values for cross-attention
            output : target-side input tensor; assumed (batch_size, max_len, dmodel)
            shape : tuple(batch_size, max_len, dmodel), forwarded to positional_encoding
            heads : number of attention heads
            output_size : number of output classes
    Returns :
            tensor of shape (batch_size, output_size); each row is a softmax distribution
            computed from the whole flattened sequence.
    '''
    output_embedding = positional_encoding(output, shape)

    # NOTE(review): multi_headed_attention takes no mask, so this self-attention is not
    # causal -- each target position can attend to later positions.
    attention1 = multi_headed_attention(output_embedding, output_embedding, output_embedding, heads)
    add_and_norm1 = add_and_norm(output_embedding, attention1)

    # Cross-attention: keys/values from the encoder, queries from the decoder stream.
    attention2 = multi_headed_attention(encoded_sequence, add_and_norm1, encoded_sequence, heads)
    add_and_norm2 = add_and_norm(add_and_norm1, attention2)

    feed_forward = feed_forward_layer(add_and_norm2)
    add_and_norm3 = add_and_norm(add_and_norm2, feed_forward)

    # (batch, seq, dmodel) -> (batch, seq * dmodel)
    flattened_matrix = torch.flatten(add_and_norm3, start_dim=1, end_dim=2)
    # in_features comes from the tensor instead of the hard-coded 512*100, so any
    # max_len / dmodel works. The layer is randomly initialised on every call.
    linear = nn.Linear(flattened_matrix.shape[1], output_size).to(
        device=flattened_matrix.device, dtype=flattened_matrix.dtype
    )
    return F.softmax(linear(flattened_matrix), dim=1)
    

In [57]:
# One encoder layer over the random source batch (expected (batch_size, max_len, dmodel)).
encoded_sequence = encoder(
    encoded_sentence,
    (batch_size, max_len, dmodel),
    heads,
)

In [58]:
# Random stand-in for the embedded target sequence fed to the decoder.
output = torch.randn(batch_size, max_len, dmodel)

In [61]:
decoder = decoder(encoded_sequence,output,(128, 100, 512),heads,1000)

In [70]:
decoder[0][886]

tensor(0.0020, grad_fn=<SelectBackward0>)