In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import warnings
warnings.filterwarnings('ignore')
from Config import conf
from transformers import T5Tokenizer
from torchtext.nn.modules.multiheadattention import ScaledDotProduct

In [2]:
# Pull every hyper-parameter out of the project configuration object once, up front.
config = conf()

# h: number of attention heads built by MultiAttentionHead.
# N: not referenced in the visible cells - presumably the number of stacked layers (TODO confirm).
# dmodel: embedding / model width.
h, N, dmodel = config.h, config.N, config.dmodel

# dk / dv: per-head output width of the key-query and value projections.
# dff: hidden width of the position-wise feed-forward layers.
dk, dv, dff = config.dk, config.dv, config.dff

# The tokenizer is loaded from the path stored in the configuration.
tokenizer = T5Tokenizer.from_pretrained(config.tokenizer_path)

# max_length: length every sentence is padded / truncated to.
# vocab_size: number of rows of the token embedding table.
max_length, vocab_size = config.max_length, config.vocab

In [3]:
# Toy sentences; the tokenisation cell below encodes these four for both
# encoder_inputs and decoder_inputs.
sentence1input = 'I love dog'
sentence2input = 'I love cat'
sentence3input = 'I love money'
sentence4input = 'I love overtime'

# NOTE(review): the four strings below are not referenced anywhere else in the
# visible notebook, and the first one (decoder1input_) breaks the
# sentenceNinput_ naming pattern of the other three - confirm whether they were
# meant to be the decoder-side inputs.
decoder1input_ = 'dog meat is delicious'
sentence2input_ = 'cat meat is bad '
sentence3input_ = 'I can buy dogs'
sentence4input_ = 'I can buy cats'

In [4]:
def _encode_batch(sentences):
    """Tokenise a list of sentences into a padded tensor of token ids on cuda:0.

    Every sentence is padded / truncated to ``max_length`` tokens.
    ``padding='max_length'`` replaces the deprecated ``pad_to_max_length=True``
    flag and produces the same padding.

    Args:
        sentences: list of strings to encode.

    Returns:
        The ``input_ids`` tensor returned by the tokenizer, moved to ``cuda:0``.
    """
    encoded = tokenizer.batch_encode_plus(sentences,
                                          max_length=max_length,
                                          padding='max_length',
                                          truncation=True,
                                          return_tensors='pt'
                                         )
    return encoded['input_ids'].to('cuda:0')


encoder_inputs = _encode_batch([sentence1input,sentence2input,sentence3input,sentence4input])

# NOTE(review): the decoder is fed exactly the same sentences as the encoder.
# Kept as-is so the outputs shown further down stay valid; confirm whether the
# decoder should use different (target) sentences instead.
decoder_inputs = _encode_batch([sentence1input,sentence2input,sentence3input,sentence4input])

In [5]:
def create_mask(sequence_length,cuda_number):
    """Return a square lower-triangular mask of ones and zeros.

    Args:
        sequence_length: side length of the mask.
        cuda_number: device the mask is created on (e.g. ``'cuda:0'``).

    Returns:
        ``int32`` tensor of shape ``(sequence_length, sequence_length)`` holding
        1 on and below the main diagonal and 0 above it.
    """
    all_ones = torch.ones(sequence_length, sequence_length, dtype=torch.int32, device=cuda_number)
    # Keeping only the lower triangle gives row t ones in columns 0..t.
    return torch.tril(all_ones)

# Example: build a 4x4 mask on cuda:0 and print it (the output is shown below).
example_mask = create_mask(4,'cuda:0')
print(example_mask)

tensor([[1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1]], device='cuda:0', dtype=torch.int32)


In [6]:
class PositionalEncoding(nn.Module):
    """Token embedding plus fixed sinusoidal positional encoding, followed by dropout.

    The input is a batch-first tensor of token ids, ``(batch, seq_len)``. The
    encoding added to a token depends only on its position inside the sequence,
    never on which row of the batch it sits in.

    Args:
        d_model: embedding width.
        dropout: dropout probability applied to the summed embedding.
        max_len: largest sequence length the encoding table covers.
        vocab_size: number of rows of the embedding table.
    """
    def __init__(self, d_model, dropout, max_len, vocab_size):
        super(PositionalEncoding, self).__init__()
        self.embedded_layer = nn.Embedding(vocab_size,d_model)
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term)[:, :d_model // 2]
        # Buffer layout stays (max_len, 1, d_model) so existing checkpoints still load.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Embed token ids ``x`` of shape (batch, seq_len) and add the position signal.

        Raises:
            ValueError: if the sequence is longer than the encoding table.
        """
        x = self.embedded_layer(x)              # (batch, seq_len, d_model)
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                'sequence length %d exceeds the positional table size %d'
                % (seq_len, self.pe.size(0))
            )
        # Slice by sequence position (dim 1), not by batch size, then move the
        # position axis next to the batch axis: (1, seq_len, d_model).
        x = x + self.pe[:seq_len].transpose(0, 1)
        return self.dropout(x)

In [7]:
class SingleAttentionHead(nn.Module):
    """One scaled dot-product attention head.

    With only ``x`` the head performs self-attention. With ``y`` supplied it
    performs encoder-decoder attention: queries come from ``x`` and keys/values
    from ``y``, so the result always has one row per position of ``x``, even
    when ``x`` and ``y`` have different lengths.

    Args:
        dmodel: width of the incoming features.
        dk: width of the query/key projections.
        dv: width of the value projection (and of the returned head).
        max_length: stored for reference; the causal mask is sized from the
            actual input instead.
        cuda_number: device the projections live on and the attention runs on.
        applyMask: if True, self-attention is causal (a position cannot attend
            to later positions). Ignored when ``y`` is given.
        output_device: device the result is moved to before returning. The
            default ``'cuda:0'`` matches the previous hard-coded behaviour.
    """
    def __init__(self,dmodel,dk,dv,max_length,cuda_number='cuda:0', applyMask = False, output_device='cuda:0'):
        super(SingleAttentionHead,self).__init__()
        self.proj_key = nn.Linear(dmodel,dk).to(cuda_number)
        self.proj_query = nn.Linear(dmodel,dk).to(cuda_number)
        self.proj_value  = nn.Linear(dmodel,dv).to(cuda_number)
        self.dk = dk
        self.cuda_number = cuda_number
        self.max_length = max_length
        self.applyMask = applyMask
        self.output_device = output_device

    def forward(self,x,y=None):
        """Attend from ``x`` (batch, len_x, dmodel) over ``x`` itself or over ``y``.

        Returns:
            Tensor of shape (batch, len_x, dv) on ``output_device``.
        """
        x = x.to(self.cuda_number)
        # `is None` instead of `== None`: the latter is a tensor comparison.
        memory = x if y is None else y.to(self.cuda_number)

        q = self.proj_query(x)          # queries always come from x
        k = self.proj_key(memory)       # keys and values come from x (self-attention)
        v = self.proj_value(memory)     # or from the encoder output y (cross-attention)

        scores = torch.einsum('b i d , b j d -> b i j', q, k) / (self.dk ** 0.5)

        if self.applyMask and y is None:
            # Strict upper triangle marks the future positions to hide; sized
            # from the actual sequence and broadcast over the batch.
            future = torch.ones(scores.size(-2), scores.size(-1), device=scores.device).triu(diagonal=1).bool()
            scores = scores.masked_fill(future, float('-inf'))

        attention = F.softmax(scores, dim=-1)

        head = torch.einsum('b i j , b j d -> b i d', attention, v)

        # No-op when the head already lives on the output device.
        return head.to(self.output_device)

In [8]:
class MultiAttentionHead(nn.Module):
    """Multi-head attention split over two GPUs, with output projection, residual add and layer norm.

    The number of heads is read from the module-level ``h``. Heads are divided
    between ``cuda:0`` and ``cuda:1``; with an odd ``h`` the extra head goes to
    ``cuda:0`` so that exactly ``h`` heads are always built. Each head returns
    its result on ``cuda:0``, where the concatenation, projection and
    normalisation run.

    Args:
        dmodel: model width (input and output).
        dk: per-head query/key width.
        dv: per-head value width.
        max_length: forwarded to every head.
        applyMask: forwarded to every head (causal self-attention when True).
    """
    def __init__(self,dmodel,dk,dv,max_length,applyMask = False):
        super(MultiAttentionHead, self).__init__()

        # Split the h heads as evenly as possible without dropping one.
        nlayers_GPU_1 = h // 2
        nlayers_GPU_0 = h - nlayers_GPU_1

        self.head_GPU0 = nn.ModuleList([
            SingleAttentionHead(dmodel,dk,dv,max_length,'cuda:0',applyMask) for _ in range(nlayers_GPU_0)
        ])

        self.head_GPU1 = nn.ModuleList([
            SingleAttentionHead(dmodel,dk,dv,max_length,'cuda:1',applyMask) for _ in range(nlayers_GPU_1)
        ])
        # Output projection: the concatenated heads are h*dv wide. This equals
        # dmodel in the usual configuration, but is no longer assumed.
        self.W0 = nn.Linear(h*dv,dmodel).to('cuda:0')
        # Layer normalisation applied after the residual connection
        self.Add_and_Nom = nn.LayerNorm(dmodel, eps=1e-05, elementwise_affine=True).to('cuda:0')
        self.dropout = nn.Dropout(0.1).to('cuda:0')

    def forward(self,x,y=None):
        """Run every head on ``(x, y)`` and merge the results.

        ``x`` is expected on ``cuda:0`` because it is added to the projected
        heads there.
        """
        # Collect all heads first and concatenate once along the feature axis.
        heads = [head(x,y) for head in self.head_GPU0]
        heads.extend(head(x,y) for head in self.head_GPU1)
        concatenated = torch.cat(heads, dim=2)

        projected = self.W0(concatenated)
        # NOTE(review): dropout is applied after the layer norm here, not to the
        # sub-layer output before the residual add - kept as in the original.
        normalised = self.Add_and_Nom(x + projected)  #cuda:0
        return self.dropout(normalised)

In [9]:
class EncoderStack(nn.Module):
    """One encoder layer: multi-head self-attention, then a feed-forward block.

    The feed-forward block is Linear(dmodel -> dff), ReLU, dropout,
    Linear(dff -> dmodel), wrapped in a residual connection and a layer norm.
    ``dff`` is read from module level; every sub-module is placed on cuda:0.
    """
    def __init__(self,dmodel,dk,dv,max_length):
        super().__init__()
        # Unmasked attention: each position may attend to the whole sequence.
        self.multiAttentionHeads = MultiAttentionHead(dmodel,dk,dv,max_length,False)
        self.lin1a = nn.Linear(dmodel,dff).to('cuda:0')
        self.dropout1 = nn.Dropout(0.1).to('cuda:0')
        self.lin1b = nn.Linear(dff,dmodel).to('cuda:0')
        self.Add_and_Nom = nn.LayerNorm(dmodel, eps=1e-05, elementwise_affine=True).to('cuda:0')

    def forward(self,x):
        """Return the layer output for input ``x``."""
        attended = self.multiAttentionHeads(x)
        # Expand to dff, apply the non-linearity and dropout, project back.
        hidden = self.dropout1(F.relu(self.lin1a(attended)))
        projected = self.lin1b(hidden)
        # Residual connection around the feed-forward block, then normalise.
        return self.Add_and_Nom(attended + projected)

In [10]:
class DecoderStack(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, feed-forward block.

    The feed-forward block is Linear(dmodel -> dff), ReLU, dropout,
    Linear(dff -> dmodel), wrapped in a residual connection and a layer norm.
    ``dff`` is read from module level; every sub-module is placed on cuda:0.
    """
    def __init__(self,dmodel,dk,dv,max_length):
        super(DecoderStack, self).__init__()
        self.masked_multi_head_attention = MultiAttentionHead(dmodel,dk,dv,max_length,True)
        self.multi_head_attention = MultiAttentionHead(dmodel,dk,dv,max_length,False)
        self.lin1a = nn.Linear(dmodel,dff).to('cuda:0')
        self.dropout1 = nn.Dropout(0.1).to('cuda:0')
        self.lin1b = nn.Linear(dff,dmodel).to('cuda:0')
        self.Add_and_Nom = nn.LayerNorm(dmodel, eps=1e-05, elementwise_affine=True).to('cuda:0')

    def forward(self,x,y=None):
        """Run the layer on decoder input ``x`` and encoder output ``y``.

        When ``y`` is None the second attention block attends over the decoder
        states themselves.
        """
        z = self.masked_multi_head_attention(x)
        # Feed the masked self-attention result (z, not the raw input x) into
        # the second attention block; otherwise the first block's work is
        # thrown away.
        z = self.multi_head_attention(z,y)
        sublayer_z = self.lin1a(z)
        sublayer_z = F.relu(sublayer_z)
        sublayer_z = self.dropout1(sublayer_z)
        sublayer_z = self.lin1b(sublayer_z)
        # Residual connection around the feed-forward block, then normalise.
        sublayer_z = self.Add_and_Nom(z + sublayer_z)
        return sublayer_z

In [11]:
class EncoderTransformerStacks(nn.Module):
    """A sequence of EncoderStack layers applied one after another.

    Args:
        dmodel, dk, dv, max_length: forwarded unchanged to every EncoderStack.
        num_stacks: number of layers to build. Defaults to 6, the depth that
            was previously hard-coded.
    """
    def __init__(self,dmodel,dk,dv,max_length,num_stacks=6):
        super(EncoderTransformerStacks, self).__init__()
        self.encoderStack = nn.ModuleList([
            EncoderStack(dmodel,dk,dv,max_length) for _ in range(num_stacks)
        ])

    def forward(self,x):
        """Pass ``x`` through every layer in order and return the final output."""
        for layer in self.encoderStack:
            x = layer(x)
        return x

In [12]:
class DecoderTransformerStacks(nn.Module):
    """A sequence of DecoderStack layers applied one after another.

    Args:
        dmodel, dk, dv, max_length: forwarded unchanged to every DecoderStack.
        num_stacks: number of layers to build. Defaults to 6, the depth that
            was previously hard-coded.
    """
    def __init__(self,dmodel,dk,dv,max_length,num_stacks=6):
        super(DecoderTransformerStacks, self).__init__()
        # The attribute name (including its spelling) is kept because it is part
        # of the parameter names stored in checkpoints.
        self.dencoderStack = nn.ModuleList([
            DecoderStack(dmodel,dk,dv,max_length) for _ in range(num_stacks)
        ])

    def forward(self,x,y):
        """Pass ``x`` through every layer in order; each layer receives the same ``y``."""
        for layer in self.dencoderStack:
            x = layer(x,y)
        return x

In [13]:
class EncoderTransformer(nn.Module):
    """Full encoder: embedding with positional encoding, then the encoder stacks."""
    def __init__(self,dmodel,dk,dv,max_length,vocab_size):
        super().__init__()
        # Embedding and positional encoding live on cuda:0, dropout fixed at 0.1.
        self.positionEncoder = PositionalEncoding(dmodel,0.1, max_length,vocab_size).to('cuda:0')
        self.encoder_Stacks = EncoderTransformerStacks(dmodel,dk,dv,max_length)

    def forward(self,x):
        """Encode the token ids ``x`` and return the output of the last stack."""
        embedded = self.positionEncoder(x)
        return self.encoder_Stacks(embedded)

In [14]:
class DecoderTransformer(nn.Module):
    """Full decoder: embedding with positional encoding, then the decoder stacks."""
    def __init__(self,dmodel,dk,dv,max_length,vocab_size):
        super().__init__()
        # Embedding and positional encoding live on cuda:0, dropout fixed at 0.1.
        self.positionEncoder = PositionalEncoding(dmodel,0.1, max_length,vocab_size).to('cuda:0')
        self.decoder_Stacks = DecoderTransformerStacks(dmodel,dk,dv,max_length)

    def forward(self,x,y):
        """Decode the token ids ``x`` against ``y`` and return the output of the last stack."""
        embedded = self.positionEncoder(x)
        return self.decoder_Stacks(embedded, y)

In [15]:
# Build the decoder input for every decoding step in one broadcasted operation.
# new_sequence[t, b, :] holds the first t+1 tokens of decoder_inputs[b] and 0
# everywhere else, so the result has shape (seq_len, batch, seq_len).
# The mask is sized from the actual padded length of decoder_inputs, so nothing
# here depends on max_length matching it.
mask = create_mask(decoder_inputs.size(1), decoder_inputs.device)
new_sequence = decoder_inputs.unsqueeze(0) * mask.unsqueeze(1).to(decoder_inputs.dtype)

In [16]:
new_sequence  # Each slice along dim 0 is one decoding step; the decoder consumes them one at a time
                # while the encoder outputs stay the same until every step of the batch has been processed.

tensor([[[   27,     0,     0,     0,     0,     0],
         [   27,     0,     0,     0,     0,     0],
         [   27,     0,     0,     0,     0,     0],
         [   27,     0,     0,     0,     0,     0]],

        [[   27,   333,     0,     0,     0,     0],
         [   27,   333,     0,     0,     0,     0],
         [   27,   333,     0,     0,     0,     0],
         [   27,   333,     0,     0,     0,     0]],

        [[   27,   333,  1782,     0,     0,     0],
         [   27,   333,  1712,     0,     0,     0],
         [   27,   333,   540,     0,     0,     0],
         [   27,   333, 22624,     0,     0,     0]],

        [[   27,   333,  1782,     1,     0,     0],
         [   27,   333,  1712,     1,     0,     0],
         [   27,   333,   540,     1,     0,     0],
         [   27,   333, 22624,     1,     0,     0]],

        [[   27,   333,  1782,     1,     0,     0],
         [   27,   333,  1712,     1,     0,     0],
         [   27,   333,   540,     1, 

In [17]:
#Loss = ...
#Criterion = ...
#Optimiser = ...
#Scaler= 
Encoder = EncoderTransformer(dmodel,dk,dv,max_length,vocab_size)
encoder_outputs = Encoder(encoder_inputs)

# Build the decoder once, outside the loop: constructing it per iteration would
# re-initialise its weights at every decoding step once the loop covers the
# whole of new_sequence.
decoder = DecoderTransformer(dmodel,dk,dv,max_length,vocab_size)

# Only the first decoding step is run for now; drop the [:1] slice to iterate
# over every step.
for sequence in new_sequence[:1]: #testing my code for the first batch sequence
    output = decoder(sequence,encoder_outputs)
    print(output)
    #final_linear_layer = nn.Linear(x,x)     #Need a bit research
    #output = final_linear_layer(output)
    #output = F.softmax(output,dim=1)


tensor([[[-3.2559e-01, -9.5723e-01, -1.6681e-01,  ..., -1.5691e-02,
           3.3745e-01,  1.9378e+00],
         [ 1.7054e+00,  6.7709e-02, -1.6083e-01,  ...,  6.3935e-01,
           1.0993e+00,  3.8847e-01],
         [ 1.0393e-02,  9.1850e-01, -1.8591e+00,  ...,  2.6690e-01,
           3.5427e-01, -1.8566e-01],
         [-4.0424e-01,  3.5961e-01,  1.5181e-01,  ...,  8.8463e-01,
           2.6146e-02,  4.8610e-01],
         [-1.1620e-01, -1.8751e-01, -1.2720e+00,  ..., -2.6234e-01,
           7.3851e-01,  1.3576e+00],
         [ 1.2964e+00,  2.7380e-01,  2.3247e-01,  ..., -7.7654e-01,
           3.2613e-01, -4.0830e-01]],

        [[ 6.7240e-01, -8.2397e-01,  1.9303e-01,  ..., -9.0539e-01,
          -1.8578e+00,  1.3617e-01],
         [ 1.1636e+00,  1.9351e-01, -2.1027e-01,  ..., -5.8921e-02,
           4.0367e-01,  3.0951e-01],
         [ 7.5501e-01,  5.5409e-01, -7.1989e-01,  ...,  2.1878e-01,
          -4.6132e-01,  1.5635e+00],
         [ 1.7250e+00, -2.3884e-01, -9.8896e-01,  ...