# In [1]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as f

# In [2]:

class MultiHeadAttention(nn.Module):
  '''
  Multi-head self-attention: queries, keys and values are all projected from the same input.

  d_model - size of the feature (last) dimension of the input and of the output
  num_heads - number of attention heads; d_model must be divisible by num_heads
  head_dims - d_model // num_heads, the per-head size of the query, key and value vectors
  '''
  def __init__(self,d_model,num_heads):
    super().__init__()

    # Fail early: otherwise the reshape in forward() fails later with a far less helpful message.
    if d_model % num_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

    # Sizes of the most recent forward() input; kept only for inspection.
    self.sequence_length = None
    self.batch_size = None
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dims = self.d_model // self.num_heads # head_dims = d_k(dimension of key vector) = d_v(dimension of value vector)
    self.qkv_layer = nn.Linear(in_features = self.d_model,out_features = 3*self.d_model)  # one projection producing q, k and v together
    self.linear_layer = nn.Linear(in_features = self.d_model,out_features = self.d_model) # output projection applied after the heads are merged

  def scaled_dot_product_attention(self,q,k,v,mask = None):

    '''
    Computes softmax(q @ k^T / sqrt(head_dims) + mask) @ v.

    q shape (batch_size,num_heads,num_queries,head_dims)
    k shape (batch_size,num_heads,num_kv,head_dims)
    v shape (batch_size,num_heads,num_kv,head_dims)
    mask - optional additive mask that must broadcast to (batch_size,num_heads,num_queries,num_kv);
           use 0 where attention is allowed and -inf where it is forbidden

    returns (values,attention)
      values shape (batch_size,num_heads,num_queries,head_dims)
      attention shape (batch_size,num_heads,num_queries,num_kv), each row sums to 1
    '''
    d_k = self.head_dims
    scaled = torch.matmul(q,k.transpose(-2,-1)) / (d_k ** 0.5)  # shape (batch_size,num_heads,num_queries,num_kv)
    if mask is not None:
      scaled = scaled + mask  # additive mask, broadcast over the batch and head dimensions
    attention = f.softmax(scaled,dim = -1) # normalise over the key axis
    values = torch.matmul(attention,v) # shape (batch_size,num_heads,num_queries,head_dims)
    return values,attention



  def forward(self,x,mask = None):
    '''
    x shape (batch_size,sequence_length,d_model)
    mask - optional additive mask, see scaled_dot_product_attention

    returns a tensor of shape (batch_size,sequence_length,d_model)
    '''
    batch_size,sequence_length = x.size(0),x.size(1)
    self.batch_size = batch_size
    self.sequence_length = sequence_length

    qkv = self.qkv_layer(x) # shape (batch_size,sequence_length,3*d_model)
    qkv = qkv.reshape(batch_size,sequence_length,self.num_heads,3*self.head_dims) # shape (batch_size,sequence_length,num_heads,3*head_dims)
    qkv = qkv.permute(0,2,1,3)  # shape (batch_size,num_heads,sequence_length,3*head_dims)
    q,k,v = qkv.chunk(3,dim = -1)  # each shape (batch_size,num_heads,sequence_length,head_dims)
    values,_ = self.scaled_dot_product_attention(q,k,v,mask = mask)

    # Move the head axis back next to the feature axis BEFORE merging. Reshaping
    # (batch,heads,seq,head_dims) directly to (batch,seq,d_model) would interleave
    # different sequence positions into one output row.
    values = values.permute(0,2,1,3)  # shape (batch_size,sequence_length,num_heads,head_dims)
    values = values.reshape(batch_size,sequence_length,self.d_model) # heads concatenated along the feature axis
    out = self.linear_layer(values) # shape (batch_size,sequence_length,d_model)
    return out


class MultiHeadCrossAttention(nn.Module):
  '''
  Multi-head cross-attention: queries are projected from one sequence (y) while
  keys and values are projected from another (x).

  d_model - size of the feature (last) dimension of both inputs and of the output
  num_heads - number of attention heads; d_model must be divisible by num_heads
  head_dims - d_model // num_heads, the per-head size of the query, key and value vectors
  '''
  def __init__(self,d_model,num_heads):
    super().__init__()

    # Fail early: otherwise the reshape in forward() fails later with a far less helpful message.
    if d_model % num_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

    # Sizes of the most recent forward() input x; kept only for inspection.
    self.sequence_length = None
    self.batch_size = None
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dims = self.d_model // self.num_heads # head_dims = d_k(dimension of key vector) = d_v(dimension of value vector)
    self.kv_layer = nn.Linear(in_features = self.d_model,out_features = 2 * self.d_model) # one projection producing k and v together
    self.q_layer = nn.Linear(in_features = self.d_model,out_features = self.d_model)
    self.linear_layer = nn.Linear(in_features = self.d_model,out_features = self.d_model) # output projection applied after the heads are merged

  def scaled_dot_product_attention(self,q,k,v,mask = None):

    '''
    Computes softmax(q @ k^T / sqrt(head_dims) + mask) @ v.

    q shape (batch_size,num_heads,num_queries,head_dims)
    k shape (batch_size,num_heads,num_kv,head_dims)
    v shape (batch_size,num_heads,num_kv,head_dims)
    mask - optional additive mask that must broadcast to (batch_size,num_heads,num_queries,num_kv);
           use 0 where attention is allowed and -inf where it is forbidden

    num_queries and num_kv may differ.

    returns (values,attention)
      values shape (batch_size,num_heads,num_queries,head_dims)
      attention shape (batch_size,num_heads,num_queries,num_kv), each row sums to 1
    '''
    d_k = self.head_dims
    scaled = torch.matmul(q,k.transpose(-2,-1)) / (d_k ** 0.5)  # shape (batch_size,num_heads,num_queries,num_kv)
    if mask is not None:
      scaled = scaled + mask  # additive mask, broadcast over the batch and head dimensions
    attention = f.softmax(scaled,dim = -1) # normalise over the key axis
    values = torch.matmul(attention,v) # shape (batch_size,num_heads,num_queries,head_dims)
    return values,attention


  def forward(self,x,y,mask = None):
    '''
    x shape (batch_size,num_kv,d_model) {source of keys and values, e.g. output of the top most encoder layer}
    y shape (batch_size,num_queries,d_model) {source of queries, e.g. decoder state after its self-attention block}
    mask - optional additive mask, see scaled_dot_product_attention

    returns a tensor of shape (batch_size,num_queries,d_model)
    '''
    batch_size = x.size(0)
    kv_length = x.size(1)
    # The query length comes from y, not x, so the two sequences may have different lengths.
    query_length = y.size(1)
    self.batch_size = batch_size
    self.sequence_length = kv_length

    q = self.q_layer(y) # shape (batch_size,num_queries,d_model)
    kv = self.kv_layer(x) # shape (batch_size,num_kv,2*d_model)
    q = q.reshape(batch_size,query_length,self.num_heads,self.head_dims) # shape (batch_size,num_queries,num_heads,head_dims)
    kv = kv.reshape(batch_size,kv_length,self.num_heads,2 * self.head_dims) # shape (batch_size,num_kv,num_heads,2 * head_dims)
    q = q.permute(0,2,1,3)  # shape (batch_size,num_heads,num_queries,head_dims)
    kv = kv.permute(0,2,1,3)  # shape (batch_size,num_heads,num_kv,2 * head_dims)
    k,v = kv.chunk(2,dim = -1) # each shape (batch_size,num_heads,num_kv,head_dims)
    values,_ = self.scaled_dot_product_attention(q,k,v,mask = mask)

    # Move the head axis back next to the feature axis BEFORE merging. Reshaping
    # (batch,heads,queries,head_dims) directly to (batch,queries,d_model) would
    # interleave different query positions into one output row.
    values = values.permute(0,2,1,3)  # shape (batch_size,num_queries,num_heads,head_dims)
    values = values.reshape(batch_size,query_length,self.d_model) # heads concatenated along the feature axis
    out = self.linear_layer(values) # shape (batch_size,num_queries,d_model)
    return out



class PositionalwiseFeedForwrd(nn.Module):
  '''
  Two-layer feed-forward network applied to the last dimension of its input:
  linear2(dropout(relu(linear1(x))))

  d_model - size of the input and output feature dimension
  hidden - size of the intermediate feature dimension
  drop_prob - dropout probability applied after the ReLU
  '''
  def __init__(self,d_model,hidden,drop_prob):
    super().__init__()
    self.linear1 = nn.Linear(d_model,hidden)  # expands d_model -> hidden
    self.linear2 = nn.Linear(hidden,d_model)  # projects hidden -> d_model
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(drop_prob)

  def forward(self,x):
    '''
    x has d_model features in its last dimension; the result has the same shape as x.
    '''
    activated = self.relu(self.linear1(x))  # last dimension is now hidden
    return self.linear2(self.dropout(activated))  # last dimension is d_model again

class LayerNormalization(nn.Module):
  '''
  Layer normalization over the trailing len(parameter_shape) dimensions of the input.

  parameter_shape - shape of the trailing dimensions to normalize over;
                    gamma and beta are created with exactly this shape
  eps - value added to the variance before the square root, for numerical stability
  gamma (initialised to ones) and beta (initialised to zeros) are learnable parameters

  example: input ---> (batch_size,num_queries,d_model)
          to normalize along the last dimension
          ---> parameter_shape = (d_model,)
          ---> gamma and beta shape = (d_model,)
  '''
  def __init__(self,parameter_shape,eps = 1e-5):
    super().__init__()
    self.parameter_shape = parameter_shape
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameter_shape))
    self.beta = nn.Parameter(torch.zeros(parameter_shape))

  def forward(self,inputs):
    '''
    Computes gamma * (inputs - mean) / sqrt(var + eps) + beta, where mean and the
    (biased) variance are taken over the trailing len(parameter_shape) dimensions.

    The result has the same shape as inputs.
    '''
    n_norm_dims = len(self.parameter_shape)
    dims = list(range(-1,-n_norm_dims - 1,-1)) # e.g. [-1] for one trailing dimension, [-1,-2] for two
    mean = inputs.mean(dim = dims,keepdim = True) # reduced dimensions are kept with size 1 so they broadcast
    centered = inputs - mean
    var = centered.pow(2).mean(dim = dims,keepdim = True)
    normalized = centered / (var + self.eps).sqrt()
    return self.gamma * normalized + self.beta  # gamma and beta broadcast over the leading dimensions

class DecoderLayer(nn.Module):
  '''
  One Transformer decoder layer made of three residual sub-blocks, each followed
  by dropout, a residual addition and its own layer normalization:
    1. masked self-attention over y            ---> norm1
    2. cross-attention (queries from y, keys/values from x) ---> norm2
    3. position-wise feed-forward network      ---> norm3

  d_model - feature dimension of x and y
  ffn_hidden - hidden size of the feed-forward network
  num_heads - number of attention heads used by both attention blocks
  drop_prob - dropout probability (shared dropout module and feed-forward dropout)
  '''
  def __init__(self,d_model,ffn_hidden,num_heads,drop_prob):
    super().__init__()
    self.masked_attention = MultiHeadAttention(d_model = d_model,num_heads = num_heads)
    self.dropout = nn.Dropout(p = drop_prob)
    self.norm1 = LayerNormalization(parameter_shape=(d_model,))
    self.encoder_decoder_attention = MultiHeadCrossAttention(d_model = d_model,num_heads = num_heads)
    self.norm2 = LayerNormalization(parameter_shape=(d_model,))
    self.ffn = PositionalwiseFeedForwrd(d_model = d_model,hidden = ffn_hidden,drop_prob = drop_prob)
    self.norm3 = LayerNormalization(parameter_shape=(d_model,))

  def forward(self,x,y,decoder_mask):
    '''
    x shape (batch_size,num_queries,d_model) {x represent output of top most encoder stacked layer}
    y shape (batch_size,num_queries,d_model) {decoder-side input: the value passed in by the caller or the previous decoder layer}
    decoder_mask - additive mask passed to the masked self-attention block (may be None)

    returns a tensor with the same shape as y
    '''
    # Sub-modules are called directly (not via .forward) so registered hooks run.

    # 1. masked self-attention + add & norm
    residual_y = y
    y = self.masked_attention(y,mask = decoder_mask)
    y = self.dropout(y)
    y = self.norm1(y + residual_y)

    # 2. encoder-decoder (cross) attention + add & norm; no mask is applied here
    residual_y = y
    y = self.encoder_decoder_attention(x,y)
    y = self.dropout(y)
    y = self.norm2(y + residual_y)

    # 3. feed-forward + add & norm; uses norm3 (reusing norm2 here would leave norm3 unused and untrained)
    residual_y = y
    y = self.ffn(y)
    y = self.dropout(y)
    y = self.norm3(y + residual_y)
    return y

class SequentialDecoder(nn.Sequential):
    '''
    An nn.Sequential whose forward takes three inputs instead of one.

    Construction:
                  pass the decoder layers in stacking order, exactly as for nn.Sequential
                  example: SequentialDecoder(DecoderLayer(...),DecoderLayer(...),DecoderLayer(...))

    Why forward is overridden:
                  the default nn.Sequential.forward threads a single value through the stack,
                  but every decoder layer needs three arguments (x,y,mask).
                  Here x and mask are handed unchanged to every layer, while y is replaced
                  by each layer's output before being passed to the next one.
    '''
    def forward(self, *inputs):
        '''
        inputs - exactly three positional values (x, y, mask)

        returns the output of the last layer, or y unchanged if the stack is empty
        '''
        x, y, mask = inputs
        # Iterating an nn.Sequential yields its sub-modules in registration order.
        for layer in self:
            y = layer(x, y, mask)
        return y

class Decoder(nn.Module):
  '''
  Stack of num_stacked identical DecoderLayer modules applied one after another.

  d_model - feature dimension of the inputs
  ffn_hidden - hidden size of each layer's feed-forward network
  num_heads - number of attention heads in each layer
  drop_prob - dropout probability used in each layer
  num_stacked - number of decoder layers in the stack
  '''
  def __init__(self,d_model,ffn_hidden,num_heads,drop_prob,num_stacked):
    super().__init__()
    self.layers = SequentialDecoder(*[DecoderLayer(d_model,ffn_hidden,num_heads,drop_prob)
                                        for  _ in range(num_stacked)])
  def forward(self,x,y,mask):

    '''
    x shape (batch_size,num_queries,d_model) {x represent output of top most encoder stacked layer}
    y shape (batch_size,num_queries,d_model) {decoder-side input fed to the first decoder layer}
    mask shape (num_queries,num_queries) or (max_sequence_length,max_sequence_length)

    returns the output of the last decoder layer
    '''
    # Call the module itself rather than .forward so that registered hooks are executed.
    y = self.layers(x,y,mask)
    return y



# In [None]:
# Smoke test: push random tensors through a decoder stack and print the output size.
d_model = 512
num_heads = 8
drop_prob = 0.1
batch_size = 30
max_sequence_length = 200
ffn_hidden = 2048
num_stacked = 5


x = torch.randn(batch_size,max_sequence_length,d_model)  # stands in for the encoder output
y = torch.randn(batch_size,max_sequence_length,d_model)  # stands in for the decoder input

# Additive causal mask: -inf strictly above the diagonal, 0 on and below it.
# The triu result must be assigned back to `mask`; passing the untouched all -inf
# matrix would turn every softmax row into NaN.
mask = torch.full([max_sequence_length,max_sequence_length],fill_value = float('-inf'))
mask = torch.triu(mask,diagonal = 1)

decoder = Decoder(d_model = d_model,ffn_hidden = ffn_hidden,num_heads = num_heads,drop_prob = drop_prob,num_stacked = num_stacked)
out = decoder(x,y,mask)
print(f"out.size(): {out.size()}")