In [2]:

import math
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

### Functions

In [28]:
def softmax(arr: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return the softmax of ``arr`` along ``axis``.

    Each slice along ``axis`` is mapped to ``exp(x) / sum(exp(x))``, so it is
    non-negative and sums to 1.

    Args:
        arr: Input scores (any array-like of numbers).
        axis: Axis to normalise over. The default (last axis) gives one
            distribution per row of a 2-D score matrix.

    Returns:
        A float array with the same shape as ``arr``.
    """
    scores = np.asarray(arr, dtype=float)
    if scores.ndim == 0:
        # A single score forms a one-element distribution.
        return np.ones_like(scores)
    if scores.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return scores.copy()

    # Subtracting the per-slice maximum does not change the result but keeps
    # np.exp from overflowing on large scores.
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exponents = np.exp(shifted)
    return exponents / np.sum(exponents, axis=axis, keepdims=True)

### Attention

In [8]:
# Random stand-in for a sentence: a (10, 1024) array of standard-normal samples.
sent_embedding = np.random.randn(10, 1024)

#### Attention Block without Linear Layers

In [9]:
def attention_block_without_linear_layers(sentence: np.ndarray):
    """Self-attention over ``sentence`` with no learned projections.

    The input matrix itself serves as queries, keys and values: the pairwise
    cosine similarities between its rows are passed through ``softmax`` and
    the resulting weight matrix is multiplied with the rows of ``sentence``.

    Args:
        sentence: 2-D array with one row per token.

    Returns:
        The weighted combination ``weights @ sentence``.
    """
    similarity = cosine_similarity(sentence, sentence)
    attention = softmax(similarity)
    return np.matmul(attention, sentence)

In [None]:
# Show only the shape of the result instead of dumping the full array.
attention_block_without_linear_layers(sent_embedding).shape

#### Attention Block

In [26]:
def attention_block(sentence: np.ndarray):
    """Single-head self-attention with freshly initialised linear projections.

    Three new ``nn.Linear`` layers project the input into queries, keys and
    values. Attention weights are the row-wise softmax of the pairwise cosine
    similarity between queries and keys, and they are used to mix the values.
    The layers are created on every call, so the output is random and nothing
    is trained; use ``AttentionBlock`` for a reusable module.

    Args:
        sentence: Array (or tensor) of shape ``(num_tokens, dim)``. The
            projection width is taken from the last axis.

    Returns:
        Tensor of shape ``(num_tokens, dim)``.
    """
    # as_tensor also casts float64 NumPy input to the dtype nn.Linear expects.
    tokens = torch.as_tensor(sentence, dtype=torch.get_default_dtype())
    dim = tokens.shape[-1]

    # Creation order (query, keys, values) is kept so seeded runs stay comparable.
    query = nn.Linear(dim, dim)(tokens)
    keys = nn.Linear(dim, dim)(tokens)
    values = nn.Linear(dim, dim)(tokens)

    # Broadcasting (T, 1, d) against (1, T, d) gives the full (T, T) similarity
    # matrix; nn.CosineSimilarity on (T, d) inputs only compares matching rows.
    weights = F.cosine_similarity(query.unsqueeze(1), keys.unsqueeze(0), dim=-1)
    # Normalise over the key axis so each query's weights sum to 1.
    weights = F.softmax(weights, dim=-1)

    return torch.matmul(weights, values)

In [None]:
# Run the functional attention block on the current `sent_embedding`.
attention_block(sent_embedding)

In [12]:
class AttentionBlock(nn.Module):
    """Single-head self-attention that scores tokens by cosine similarity.

    Queries, keys and values are produced by three independent ``dim -> dim``
    linear layers.
    """

    def __init__(self, dim= 1024):
        super().__init__()

        self.query_layer = nn.Linear(dim, dim)
        self.keys_layer = nn.Linear(dim, dim)
        self.values_layer = nn.Linear(dim, dim)

    def forward(self, sentence):
        """Attend over ``sentence`` and return the re-weighted values.

        Non-tensor input (e.g. a NumPy array) is wrapped in a ``Tensor`` first.
        """
        tokens = sentence if isinstance(sentence, Tensor) else Tensor(sentence)

        projected_queries = self.query_layer(tokens)
        projected_keys = self.keys_layer(tokens)
        projected_values = self.values_layer(tokens)

        # Insert singleton axes so every query is compared with every key.
        similarity = F.cosine_similarity(
            projected_queries[:, None], projected_keys[None], dim=-1
        )
        attention = F.softmax(similarity, dim=-1)

        return attention @ projected_values

In [None]:
# Call the module itself rather than .forward(): nn.Module.__call__ is the
# supported entry point and also runs any registered hooks.
att= AttentionBlock()
att(sent_embedding)

#### Multi Head Attention Block

In [14]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head self-attention using scaled cosine similarity as the score.

    ``dim`` is split evenly across ``num_heads`` heads; a final linear layer
    mixes the concatenated head outputs.
    """

    def __init__(self, dim= 1024, num_heads= 4):
        super().__init__()

        assert dim%num_heads==0, "dim should be divisible by num_heads"

        self.dim = dim
        self.heads = num_heads
        self.per_head = dim // num_heads

        self.query_layer = nn.Linear(dim, dim)
        self.keys_layer = nn.Linear(dim, dim)
        self.values_layer = nn.Linear(dim, dim)
        self.linear_layer = nn.Linear(dim, dim)

    def split_head(self, tensor: Tensor):
        """Reshape ``(batch, tokens, dim)`` into ``(batch, heads, tokens, per_head)``."""
        batch_size, num_tokens, _ = tensor.size()
        per_head_view = tensor.view(batch_size, num_tokens, self.heads, self.per_head)
        return per_head_view.permute(0, 2, 1, 3)

    def forward(self, sentence):
        """Apply self-attention to a ``(batch, tokens, dim)`` input."""
        tokens = sentence if isinstance(sentence, Tensor) else Tensor(sentence)
        batch_size = tokens.shape[0]

        head_queries = self.split_head(self.query_layer(tokens))
        head_keys = self.split_head(self.keys_layer(tokens))
        head_values = self.split_head(self.values_layer(tokens))

        # (B, H, T, 1, d) against (B, H, 1, T, d) -> (B, H, T, T) similarities.
        similarity = F.cosine_similarity(
            head_queries.unsqueeze(3), head_keys.unsqueeze(2), dim=-1
        )
        scaled = similarity / math.sqrt(self.per_head)
        attention = F.softmax(scaled, dim=-1)

        mixed = torch.matmul(attention, head_values)
        # Put tokens back in front of heads, then merge the heads into `dim`.
        merged = mixed.transpose(1, 2).contiguous().reshape(batch_size, -1, self.dim)

        return self.linear_layer(merged)

In [None]:
# The multi-head block unpacks its input as (batch, tokens, dim), so rebuild
# the demo embedding with a leading batch axis of size 1.
sent_embedding = np.random.rand(1, 10, 1024)
mha = MultiHeadAttentionBlock()
# Call the module rather than .forward() so nn.Module.__call__ (and hooks) run.
mha(sent_embedding).shape

#### Masked Multi-Head Attention Block

In [255]:
class MultiHeadAttentionWithMaskBlock(nn.Module):
    """Multi-head attention with an optional mask and separate q/k/v inputs.

    Scores are cosine similarities between projected queries and keys, divided
    by ``sqrt(per_head)``. Because queries, keys and values are passed
    separately, the block serves both self-attention and cross-attention.
    """

    def __init__(self, dim= 1024, num_heads= 4):
        super().__init__()

        assert dim%num_heads==0, "dim should be divisible by num_heads"

        self.dim= dim
        self.heads = num_heads
        self.per_head = dim // num_heads

        self.query_layer = nn.Linear(dim, dim)
        self.keys_layer = nn.Linear(dim, dim)
        self.values_layer = nn.Linear(dim, dim)
        self.linear_layer = nn.Linear(dim, dim)

    def split_head(self, tensor: Tensor):
        """Reshape ``(batch, tokens, dim)`` into ``(batch, heads, tokens, per_head)``."""
        batch_size, num_tokens, dim = tensor.size()
        return tensor.view(batch_size, num_tokens, self.heads, self.per_head).transpose(1, 2)

    def forward(self, query, keys, values, mask= None):
        """Compute masked multi-head attention.

        Args:
            query: Tensor of shape ``(batch, query_len, dim)``.
            keys: Tensor of shape ``(batch, key_len, dim)``.
            values: Tensor of shape ``(batch, key_len, dim)``.
            mask: Optional tensor broadcastable to
                ``(batch, heads, query_len, key_len)``. Positions where it
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, query_len, dim)``.
        """
        batch_size = query.shape[0]

        query= self.split_head(self.query_layer(query))
        keys= self.split_head(self.keys_layer(keys))
        values= self.split_head(self.values_layer(values))

        # (B, H, Tq, 1, d) against (B, H, 1, Tk, d) -> (B, H, Tq, Tk) scores.
        weights = F.cosine_similarity(query.unsqueeze(3), keys.unsqueeze(2), dim=-1) / math.sqrt(self.per_head)
        if mask is not None:
            # A large negative score drives the softmax weight to zero. -1e9 is
            # used rather than -inf so a fully masked row gives uniform weights
            # instead of NaN.
            weights = weights.masked_fill(mask == 0, -1e9)
        # Normalise over the key axis (last dim) so each query's weights sum to
        # 1. Normalising over the query axis would make a position's output
        # depend on queries from other (possibly future) positions.
        weights = F.softmax(weights, dim= -1)

        weighted_values = torch.matmul(weights, values)
        # Put tokens back in front of heads, then merge the heads into `dim`.
        weighted_values = weighted_values.transpose(1,2).contiguous().reshape(batch_size, -1, self.dim)
        attention = self.linear_layer(weighted_values)

        return attention

In [None]:
# Batch of 3 sequences with 10 embeddings of width 1024 each.
sent_embedding = Tensor(np.random.rand(3, 10, 1024))
# (3, 10) float tensor that the mask cells below test with `sent != 0`.
sent= Tensor(np.random.rand(3, 10))
# Zero the last two positions of every row so those cells have entries to mask.
sent[:, 8:] = 0
sent

In [32]:
def generate_mask(input, target):
    """Build attention masks from padded token-id tensors (0 = padding).

    Args:
        input: Tensor of shape ``(batch, source_len)``.
        target: Tensor of shape ``(batch, target_len)``.

    Returns:
        A tuple ``(input_mask, combined_mask)`` of boolean tensors:

        * ``input_mask`` has shape ``(batch, 1, 1, source_len)`` and is True
          at non-zero source positions.
        * ``combined_mask`` has shape ``(batch, 1, target_len, target_len)``.
          Entry ``[b, 0, i, j]`` is True when target position ``i`` is
          non-zero and ``j <= i`` (no attention to later positions).
    """
    input_mask = (input != 0).unsqueeze(1).unsqueeze(2)
    target_mask = (target != 0).unsqueeze(1).unsqueeze(3)
    # The causal mask covers target positions, so size it from the target.
    sequence_length = target.size(1)

    # Lower triangle (diagonal included) marks the positions each query may
    # attend to; build it on the target's device so logical_and does not mix
    # devices.
    causal_mask = torch.tril(
        torch.ones(sequence_length, sequence_length, device=target.device), diagonal=0
    ).bool()
    causal_mask = causal_mask.unsqueeze(0).unsqueeze(1)

    # True means "may attend", so the causal mask is used as-is (not inverted).
    combined_mask = torch.logical_and(target_mask, causal_mask)
    return input_mask, combined_mask

In [None]:
# True where `sent` is non-zero; the two unsqueeze calls add singleton axes at
# positions 1 and 3, turning (batch, tokens) into (batch, 1, tokens, 1).
target_mask = (sent != 0).unsqueeze(1).unsqueeze(3)
target_mask

In [None]:
# Square lower-triangular boolean matrix (diagonal included) whose side is the
# length of axis 1 of `sent`.
sequence_length = sent.size(1)
causal_mask = torch.tril(torch.ones(sequence_length, sequence_length), diagonal=0).bool()
causal_mask


In [None]:
# `mask` marks non-zero entries of `sent`, with singleton axes inserted at
# positions 1 and 2: (batch, 1, 1, tokens).
mask = (sent != 0).unsqueeze(1).unsqueeze(2)
# Lower-triangular ones with hard-coded sizes (3, 1, 10, 10).
tri = torch.tril(torch.ones(3, 1, 10, 10))
# The combination is only displayed; `mask` keeps its value from the first line.
torch.logical_and(mask, tri).long()

In [None]:

mham = MultiHeadAttentionWithMaskBlock()
# Self-attention: the same tensor is passed as query, keys and values.
# Call the module rather than .forward() so nn.Module.__call__ (and hooks) run.
res = mham(sent_embedding, sent_embedding, sent_embedding, mask)
res

### Feed Forward Network

In [18]:
class FeedForwardNetworkBlock(nn.Module):
    """Two-layer feed-forward network: ``dim -> inter_dim -> dim`` with ReLU between."""

    def __init__(self, dim= 1024, inter_dim= 512):
        super().__init__()

        self.ff1 = nn.Linear(dim, inter_dim)
        self.ff2 = nn.Linear(inter_dim, dim)
        self.relu = nn.ReLU()

    def forward(self, attention):
        """Apply the first linear layer, ReLU, then the second linear layer."""
        hidden = self.ff1(attention)
        activated = self.relu(hidden)
        return self.ff2(activated)

In [None]:
ffn = FeedForwardNetworkBlock()
# Feed the masked attention output `res` through the feed-forward block.
# Call the module rather than .forward() so nn.Module.__call__ (and hooks) run.
ffn(res)

### Positional Encoding

In [20]:
class PositionalEncodingBlock(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of token embeddings.

    Even feature columns hold ``sin(pos / 10000^(i / dim))`` and odd columns
    hold the matching ``cos``, for even ``i`` in ``0, 2, 4, ...``. The table is
    registered as the buffer ``pe`` with shape ``(1, max_token_length, dim)``,
    so it moves with the module across devices but is not a trainable
    parameter.
    """

    def __init__(self, max_token_length, dim):
        super().__init__()

        pe = torch.zeros(max_token_length, dim)
        position = torch.arange(0, max_token_length, dtype=torch.float).unsqueeze(1)
        # One frequency per sin/cos column pair.
        div = 1 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        angles = position * div

        pe[:, 0::2] = torch.sin(angles)
        # An odd `dim` has one fewer odd column than even columns, so trim the
        # angle table to the number of odd columns before filling them.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, tokens):
        """Return ``tokens`` plus the encodings for its first ``tokens.size(1)`` positions.

        ``tokens`` is indexed as ``(batch, num_tokens, dim)``.
        """
        return tokens + self.pe[:, :tokens.size(1), :]

In [None]:
# Encoding table for up to 100 positions of width 1024.
pe= PositionalEncodingBlock(100, 1024)
# Call the module rather than .forward() so nn.Module.__call__ (and hooks) run.
pe(sent_embedding)

### Encoder

In [212]:
class EncoderBlock(nn.Module):
    """One encoder layer: masked multi-head self-attention followed by a
    feed-forward network, each wrapped in a residual connection and LayerNorm.
    """

    def __init__(self, num_heads: int = 4, dim: int = 1024, inter_dim: int = 512):
        super().__init__()

        self.mha = MultiHeadAttentionWithMaskBlock(dim, num_heads)
        self.ff = FeedForwardNetworkBlock(dim, inter_dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, token_embeddings, input_mask):
        """Run self-attention and the feed-forward network over ``token_embeddings``.

        ``input_mask`` is forwarded unchanged to the attention block.
        """
        # Self-attention: the embeddings act as query, keys and values.
        attended = self.mha(token_embeddings, token_embeddings, token_embeddings, input_mask)
        normed_attention = self.norm1(token_embeddings + attended)

        transformed = self.ff(normed_attention)
        return self.norm2(normed_attention + transformed)


### Decoder

In [213]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the
    encoder output, and a feed-forward network.

    Each of the three sub-layers is followed by a residual connection and its
    own LayerNorm (``norm1``, ``norm2``, ``norm3``).
    """

    def __init__(self, num_heads: int = 4, dim: int = 1024, inter_dim: int = 512):
        super().__init__()

        self.self_attention = MultiHeadAttentionWithMaskBlock(dim, num_heads)
        self.cross_attention = MultiHeadAttentionWithMaskBlock(dim, num_heads)
        self.ff = FeedForwardNetworkBlock(dim, inter_dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, output_embeddings, encoder_output, source_mask, target_mask):
        """Decode one layer.

        Args:
            output_embeddings: Decoder-side input to this layer.
            encoder_output: Encoder result used as keys and values for
                cross-attention.
            source_mask: Mask passed to the cross-attention block.
            target_mask: Mask passed to the self-attention block.

        Returns:
            Tensor with the same shape as ``output_embeddings``.
        """
        mmha_res = self.self_attention(output_embeddings, output_embeddings, output_embeddings, target_mask)
        add_norm1 = self.norm1(torch.add(output_embeddings, mmha_res))

        mha_res = self.cross_attention(add_norm1, encoder_output, encoder_output, source_mask)
        # The residual is the cross-attention *input* (add_norm1), not the raw
        # self-attention output.
        add_norm2 = self.norm2(torch.add(add_norm1, mha_res))

        ff_res = self.ff(add_norm2)
        # The feed-forward sub-layer has its own LayerNorm.
        add_norm3 = self.norm3(torch.add(add_norm2, ff_res))
        return add_norm3


In [265]:
# Demo tensors for the decoder cells: embeddings of shape (3, 10, 1024) and a
# (3, 10) tensor used for mask construction.
sent_embedding = Tensor(np.random.randn(3, 10, 1024))
sent= torch.rand(3, 10)
# NOTE(review): enc_embedding (3, 12, 1024) is not referenced by any later
# visible cell — presumably meant as the encoder output for DecoderBlock; confirm.
enc_embedding = Tensor(np.random.randn(3, 12, 1024))

In [None]:
# Non-zero entries of `sent`, with singleton axes inserted at positions 1 and 2.
target_mask = (sent != 0).unsqueeze(1).unsqueeze(2)
batch_size, sequence_length= sent.size()
# Lower-triangular ones, one (sequence_length x sequence_length) matrix per batch item.
tri = torch.tril(torch.ones(batch_size, 1, sequence_length, sequence_length))
# The combination is only displayed; it is not assigned to a name.
torch.logical_and(target_mask, tri).long()

In [263]:
# NOTE(review): `out_embedding` is not defined in any visible cell, so this cell
# raises NameError on a fresh kernel unless it is defined elsewhere — confirm
# which tensor's length was intended.
# target_mask starts as all ones and source_mask as all zeros, both of shape
# (1, 1, sent_embedding.size(1), 1); entries from index out_embedding.size(1)
# onward along axis 2 are then set to 0 and 1 respectively.
target_mask= torch.ones(1, 1, sent_embedding.size(1), 1)
target_mask[:, :, out_embedding.size(1):, :] = 0
source_mask= torch.zeros(1, 1, sent_embedding.size(1), 1)
source_mask[:, :, out_embedding.size(1):, :] = 1

In [None]:
# Run one decoder layer; `sent_embedding` is passed as both the decoder input
# and the encoder output.
dec= DecoderBlock()
dec(sent_embedding, sent_embedding, source_mask, target_mask)

### Transformer

In [259]:
class TransformerModule(nn.Module):
    """Encoder-decoder transformer assembled from the blocks defined in this notebook.

    Token ids are embedded, combined with sinusoidal position encodings, passed
    through ``num_layers`` encoder and decoder layers, and projected to a
    probability distribution over the decoder vocabulary. Token id 0 is
    treated as padding when the masks are built.
    """

    def __init__(self, dim, encoder_vocab_size, decoder_vocab_size, max_token_length, num_heads, num_layers):
        super().__init__()

        self.encoder_embedding = nn.Embedding(encoder_vocab_size, dim)
        self.decoder_embedding = nn.Embedding(decoder_vocab_size, dim)
        self.positional_encoder = PositionalEncodingBlock(max_token_length, dim)

        self.encoder_layers = nn.ModuleList([EncoderBlock(num_heads, dim, 2* dim) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderBlock(num_heads, dim, 2* dim) for _ in range(num_layers)])

        self.linear = nn.Linear(dim, decoder_vocab_size)
        # Normalise over the vocabulary axis. Without an explicit dim, PyTorch
        # picks dim 0 (the batch axis) for a 3-D input.
        self.softmax = nn.Softmax(dim=-1)

    def generate_mask(self, input, target):
        """Build the source padding mask and the causal target mask.

        Args:
            input: Source token ids, shape ``(batch, source_len)``.
            target: Target token ids, shape ``(batch, target_len)``.

        Returns:
            ``(input_mask, target_mask)`` where ``input_mask`` is a boolean
            tensor of shape ``(batch, 1, 1, source_len)`` and ``target_mask``
            is an integer tensor of shape
            ``(batch, 1, target_len, target_len)``. A non-zero entry means the
            position may be attended to.
        """
        input_mask = (input != 0).unsqueeze(1).unsqueeze(2)
        target_mask = (target != 0).unsqueeze(1).unsqueeze(2)
        sequence_length = target.size(1)
        # Built on the target's device so logical_and does not mix devices; the
        # (T, T) triangle broadcasts against the (batch, 1, 1, T) padding mask.
        tri = torch.tril(torch.ones(sequence_length, sequence_length, device=target.device))
        target_mask = torch.logical_and(target_mask, tri).long()

        return input_mask, target_mask

    def forward(self, input, target):
        """Return next-token probabilities of shape ``(batch, target_len, decoder_vocab_size)``.

        Both sequences must be no longer than ``max_token_length``.
        """
        input_mask, target_mask = self.generate_mask(input, target)

        # Attention alone has no notion of order, so position encodings are
        # added to both embedded sequences.
        input_embedding = self.positional_encoder(self.encoder_embedding(input))
        output_embedding = self.positional_encoder(self.decoder_embedding(target))

        encoder_output = input_embedding
        for encoder_layer in self.encoder_layers:
            encoder_output = encoder_layer(encoder_output, input_mask)

        decoder_output = output_embedding
        for decoder_layer in self.decoder_layers:
            decoder_output = decoder_layer(decoder_output, encoder_output, input_mask, target_mask)

        linear_output = self.linear(decoder_output)
        return self.softmax(linear_output)


In [260]:
# Small demo configuration: 64-dimensional embeddings, 500-token vocabularies on
# both sides, 4 heads and 2 encoder/decoder layers.
transformer = TransformerModule(
    dim= 64,
    encoder_vocab_size=500,
    decoder_vocab_size=500,
    max_token_length=100,
    num_heads=4,
    num_layers=2,
)

In [None]:
# Random token ids in [1, 500) for a batch of 4 sequences of length 100.
# Id 0 is never drawn, so the `!= 0` padding checks in generate_mask leave
# every position unmasked.
input_sentences = torch.randint(1, 500, (4, 100))
target_sentences = torch.randint(1, 500, (4, 100))
transformer(input_sentences, target_sentences)