<a href="https://colab.research.google.com/github/ethvedbitdesjan/NLP/blob/main/Transformer_From_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch.nn.functional as F 
import torch.nn as nn
import torch

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute attention as ``softmax(Q K^T / sqrt(d)) V``.

    Args:
        query: tensor of shape ``(..., q_len, d)``.
        key: tensor of shape ``(..., k_len, d)``; its last dim must match
            ``query``'s for the dot product to be defined.
        value: tensor of shape ``(..., k_len, d_v)``.
        mask: optional **boolean** tensor broadcastable to
            ``(..., q_len, k_len)``. Positions where it is ``False`` are
            excluded from attention, e.g.
            ``torch.tril(torch.ones(L, L, dtype=torch.bool))`` for causal
            decoding. A query row whose mask is entirely ``False`` yields NaN.
            ``None`` (the default) means no masking.

    Returns:
        Tensor of shape ``(..., q_len, d_v)``.
    """
    # matmul with transpose(-2, -1) instead of bmm/transpose(1, 2): identical
    # for 3-D inputs, but also accepts extra leading axes such as
    # (batch, heads, len, d).
    scores = query @ key.transpose(-2, -1) / (query.size(-1) ** 0.5)
    if mask is not None:
        # -inf scores become exactly 0 weight after the softmax.
        scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ value

In [None]:
class AttentionHead(nn.Module):
    """A single attention head: learned Q/K/V projections + scaled dot-product attention.

    Args:
        in_dim: width of the incoming query/key/value features.
        q_dim: output width of the query projection. It must equal ``k_dim``
            for the query-key dot product to be defined.
        k_dim: output width of the key projection, and also of the value
            projection, so each head emits ``k_dim`` features.
    """

    def __init__(self, in_dim, q_dim, k_dim):
        super().__init__()
        self.q = nn.Linear(in_dim, q_dim)
        self.k = nn.Linear(in_dim, k_dim)
        self.v = nn.Linear(in_dim, k_dim)

    def forward(self, query, key, value):
        projected_query = self.q(query)
        projected_key = self.k(key)
        projected_value = self.v(value)
        return scaled_dot_product_attention(
            projected_query, projected_key, projected_value
        )

In [None]:
class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` independent AttentionHead modules and fuse their outputs.

    Head outputs are concatenated on the last axis (``num_heads * k_dim``
    features) and projected back to ``in_dim`` by one linear layer.
    """

    def __init__(self, num_heads, in_dim, q_dim, k_dim):
        super().__init__()
        head_list = [AttentionHead(in_dim, q_dim, k_dim) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_list)
        self.linear = nn.Linear(num_heads * k_dim, in_dim)

    def forward(self, query, key, value):
        per_head = [head(query, key, value) for head in self.heads]
        fused = torch.cat(per_head, dim=-1)
        return self.linear(fused)

In [None]:
def feed_forward(in_dimput = 512, feedforward_dim = 2048):
    """Build the position-wise feed-forward sub-layer ``Linear -> ReLU -> Linear``.

    Expands features from ``in_dimput`` to ``feedforward_dim`` and projects
    them back to ``in_dimput``. The parameter name ``in_dimput`` is kept
    unchanged so keyword callers keep working.
    """
    stages = [
        nn.Linear(in_dimput, feedforward_dim),  # expand
        nn.ReLU(),
        nn.Linear(feedforward_dim, in_dimput),  # project back
    ]
    return nn.Sequential(*stages)

In [None]:
class TransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer: self-attention, then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(SubLayer(x)))``. Each
    residual step has its own LayerNorm, so the two normalizations learn
    independent scale/shift parameters.

    Args:
        model_dim: feature width of the input and output.
        num_heads: number of attention heads.
        feedforward_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(
        self,
        model_dim=512,
        num_heads=6,
        feedforward_dim=2048,
        dropout=0.2,
    ):
        super().__init__()
        # Per-head projection width. With the defaults, 512 // 6 == 85, so the
        # concatenated heads are 510 wide. MultiHeadAttention's output linear
        # maps them back to model_dim.
        q_dim = k_dim = max(model_dim // num_heads, 1)
        self.layer1 = MultiHeadAttention(num_heads, model_dim, q_dim, k_dim)

        self.norm = nn.LayerNorm(model_dim)   # after self-attention
        self.norm2 = nn.LayerNorm(model_dim)  # after feed-forward
        self.dropout = nn.Dropout(dropout)

        self.layer2 = feed_forward(model_dim, feedforward_dim)

    def forward(self, src):
        """Apply self-attention and feed-forward to ``src``; the output has the same shape as ``src``."""
        attended = self.dropout(self.layer1(src, src, src))
        hidden = self.norm(src + attended)
        transformed = self.dropout(self.layer2(hidden))
        return self.norm2(hidden + transformed)

In [None]:
def position_encoding(seq_len, model_dim, device = torch.device("cpu")):
    """Return sinusoidal positional encodings of shape ``(1, seq_len, model_dim)``.

    Implements ``PE[pos, 2i] = sin(pos / 10000^(2i / model_dim))`` and
    ``PE[pos, 2i+1] = cos(pos / 10000^(2i / model_dim))``. The leading
    singleton axis lets the result broadcast across a batch.

    Args:
        seq_len: number of positions to encode.
        model_dim: embedding width (number of channels).
        device: device on which the returned float32 tensor is created.
    """
    pos = torch.arange(seq_len, dtype=torch.float, device=device).reshape(1, -1, 1)
    dim = torch.arange(model_dim, dtype=torch.float, device=device).reshape(1, 1, -1)
    # Channels 2i and 2i+1 share the frequency exponent 2i / model_dim.
    # Integer-dividing dim by model_dim would always give 0 (dim < model_dim),
    # which would give every channel the same frequency.
    pair_index = dim - dim % 2
    phase = pos / (1e4 ** (pair_index / model_dim))

    return torch.where(dim.long() % 2 == 0, torch.sin(phase), torch.cos(phase))

In [None]:
class TransformerDecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    It has three sub-layers, each wrapped as ``LayerNorm(x + Dropout(SubLayer(x)))``
    with its own LayerNorm:

    1. self-attention over ``tgt``,
    2. attention from the result to ``memory``,
    3. a position-wise feed-forward network.

    Args:
        model_dim: feature width of ``tgt``, ``memory`` and the output.
        num_heads: number of attention heads in each attention sub-layer.
        feedforward_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, model_dim=512, num_heads=6, feedforward_dim=2048, dropout=0.2):
        super().__init__()
        q_dim = k_dim = max(model_dim // num_heads, 1)
        self.attention1 = MultiHeadAttention(num_heads, model_dim, q_dim, k_dim)

        self.attention2 = MultiHeadAttention(num_heads, model_dim, q_dim, k_dim)

        self.feed_forward = feed_forward(model_dim, feedforward_dim)
        self.norm = nn.LayerNorm(model_dim)   # after self-attention
        self.norm2 = nn.LayerNorm(model_dim)  # after memory attention
        self.norm3 = nn.LayerNorm(model_dim)  # after feed-forward
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory):
        """Decode ``tgt`` against encoder output ``memory``; the output has the same shape as ``tgt``."""
        # NOTE(review): MultiHeadAttention.forward accepts no mask, so this
        # self-attention is unmasked: every target position can attend to
        # later positions. Autoregressive training needs a causal mask
        # threaded through the attention modules.
        self_attended = self.dropout(self.attention1(tgt, tgt, tgt))
        hidden = self.norm(tgt + self_attended)

        cross_attended = self.dropout(self.attention2(hidden, memory, memory))
        hidden = self.norm2(hidden + cross_attended)

        transformed = self.dropout(self.feed_forward(hidden))
        return self.norm3(hidden + transformed)

In [None]:
class TransformerEncoder(nn.Module):
    """Positional encoding followed by a stack of ``num_layers`` TransformerEncoderLayer modules.

    Args:
        num_layers: number of stacked encoder layers.
        model_dim: feature width of the input and output.
        num_heads: attention heads per layer.
        feedforward_dim: hidden width of each layer's feed-forward sub-layer.
        dropout: dropout probability used inside each layer.
    """

    def __init__(
        self,
        num_layers=6,
        model_dim=512,
        num_heads=8,
        feedforward_dim=2048,
        dropout=0.2,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(model_dim, num_heads, feedforward_dim, dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(self, src):
        """Encode ``src``, read as ``(batch, seq_len, model_dim)`` (axes 1 and 2 are used as length and width)."""
        seq_len, dimension = src.size(1), src.size(2)
        # Build the encoding on src's device and in src's dtype, so GPU and
        # half-precision inputs work.
        pe = position_encoding(seq_len, dimension, device=src.device).to(src.dtype)
        # Out-of-place add: `src +=` would overwrite the caller's tensor and
        # fails on a leaf tensor that requires grad.
        src = src + pe
        for layer in self.layers:
            src = layer(src)
        return src

In [None]:
class TransformerDecoder(nn.Module):
    """Positional encoding, a stack of TransformerDecoderLayer modules, then linear + softmax.

    Args:
        num_layers: number of stacked decoder layers.
        model_dim: feature width of ``tgt``, ``memory`` and the output.
        num_heads: attention heads per layer.
        feedforward_dim: hidden width of each layer's feed-forward sub-layer.
        dropout: dropout probability used inside each layer.
    """

    def __init__(
        self,
        num_layers=6,
        model_dim=512,
        num_heads=8,
        feedforward_dim=2048,
        dropout=0.2,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerDecoderLayer(model_dim, num_heads, feedforward_dim, dropout)
                for _ in range(num_layers)
            ]
        )
        # NOTE(review): maps to model_dim, so the softmax in forward is over
        # model_dim features rather than a vocabulary. Presumably this is a
        # placeholder for a vocab-sized projection; confirm the intent.
        self.linear = nn.Linear(model_dim, model_dim)

    def forward(self, tgt, memory):
        """Decode ``tgt`` (read as ``(batch, seq_len, model_dim)``) against ``memory``.

        Returns the softmax over the last axis of the final linear projection.
        """
        seq_len, dimension = tgt.size(1), tgt.size(2)
        # Build the encoding on tgt's device and in tgt's dtype.
        pe = position_encoding(seq_len, dimension, device=tgt.device).to(tgt.dtype)
        # Out-of-place add so the caller's tgt tensor is left untouched.
        tgt = tgt + pe
        for layer in self.layers:
            tgt = layer(tgt, memory)
        return torch.softmax(self.linear(tgt), dim=-1)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model: a TransformerEncoder feeding a TransformerDecoder.

    ``model_dim``, ``num_heads``, ``feedforward_dim`` and ``dropout`` are
    passed unchanged to both halves. ``activation`` is accepted but not used
    anywhere in this class.
    """

    def __init__(self, num_encoder_layers = 6, num_decoder_layers = 6, model_dim = 512, num_heads = 6, feedforward_dim = 2048, dropout = 0.2, activation = nn.ReLU()):
        super().__init__()
        shared_config = dict(
            model_dim=model_dim,
            num_heads=num_heads,
            feedforward_dim=feedforward_dim,
            dropout=dropout,
        )
        self.encoder = TransformerEncoder(num_layers=num_encoder_layers, **shared_config)
        self.decoder = TransformerDecoder(num_layers=num_decoder_layers, **shared_config)

    def forward(self, src, tgt):
        """Encode ``src``, then decode ``tgt`` against the encoded memory."""
        memory = self.encoder(src)
        return self.decoder(tgt, memory)

In [None]:
# Smoke test: push random batches through a default-configured Transformer
# and print the output shape. src is (16, 32, 512) and tgt is (16, 16, 512).
src, tgt = torch.rand(16, 32, 512), torch.rand(16, 16, 512)
model = Transformer()
out = model(src=src, tgt=tgt)
print(out.size())

torch.Size([16, 32, 512])
torch.Size([16, 32, 512]) pos torch.Size([1, 32, 512])
torch.Size([16, 32, 512]) multi_attend_out
torch.Size([16, 32, 512]) att_out_norm
torch.Size([16, 32, 512]) percep_ou
torch.Size([16, 32, 512]) encoder...
torch.Size([16, 32, 512]) multi_attend_out
torch.Size([16, 32, 512]) att_out_norm
torch.Size([16, 32, 512]) percep_ou
torch.Size([16, 32, 512]) encoder...
torch.Size([16, 32, 512]) multi_attend_out
torch.Size([16, 32, 512]) att_out_norm
torch.Size([16, 32, 512]) percep_ou
torch.Size([16, 32, 512]) encoder...
torch.Size([16, 32, 512]) multi_attend_out
torch.Size([16, 32, 512]) att_out_norm
torch.Size([16, 32, 512]) percep_ou
torch.Size([16, 32, 512]) encoder...
torch.Size([16, 32, 512]) multi_attend_out
torch.Size([16, 32, 512]) att_out_norm
torch.Size([16, 32, 512]) percep_ou
torch.Size([16, 32, 512]) encoder...
torch.Size([16, 32, 512]) multi_attend_out
torch.Size([16, 32, 512]) att_out_norm
torch.Size([16, 32, 512]) percep_ou
torch.Size([16, 32, 512]) encoder...
torch.Size([16, 16, 512])
torch.Size([16, 16, 512])
torch.Size([16, 16, 512]) decoder...
torch.Size([16, 16, 512]) decoder...
torch.Size([16, 16, 512]) decoder...
torch.Size([16, 16, 512]) decoder...
torch.Size([16, 16, 512]) decoder...
torch.Size([16, 16, 512]) decoder...
torch.Size([16, 16, 512])


In [None]:
# Broadcasting check: an in-place add of a (1, 3, 5) tensor into a (3, 3, 5)
# tensor broadcasts over the leading axis and keeps t2's shape.
t1 = torch.rand(1, 3, 5)
t2 = torch.rand(3, 3, 5)

print(t2, t1)
t2.add_(t1)
print(t2, t2.size())

tensor([[[0.5805, 0.7351, 0.9314, 0.6451, 0.9990],
         [0.6803, 0.9837, 0.7175, 0.3201, 0.7552],
         [0.1876, 0.6099, 0.9486, 0.7384, 0.4878]],

        [[0.1423, 0.9400, 0.3453, 0.8712, 0.3623],
         [0.1356, 0.1505, 0.4291, 0.3869, 0.0773],
         [0.7850, 0.4236, 0.9965, 0.5007, 0.3655]],

        [[0.6787, 0.8164, 0.2383, 0.6309, 0.7243],
         [0.1019, 0.3964, 0.9076, 0.3677, 0.6168],
         [0.6515, 0.6581, 0.3663, 0.9815, 0.1386]]]) tensor([[[0.2268, 0.5846, 0.2379, 0.6112, 0.0735],
         [0.6263, 0.5724, 0.7261, 0.7328, 0.1121],
         [0.7249, 0.4095, 0.3130, 0.4591, 0.0007]]])
tensor([[[0.8073, 1.3197, 1.1693, 1.2563, 1.0725],
         [1.3066, 1.5561, 1.4436, 1.0529, 0.8673],
         [0.9125, 1.0194, 1.2616, 1.1974, 0.4885]],

        [[0.3691, 1.5246, 0.5832, 1.4825, 0.4358],
         [0.7619, 0.7229, 1.1552, 1.1197, 0.1894],
         [1.5099, 0.8331, 1.3095, 0.9597, 0.3662]],

        [[0.9055, 1.4010, 0.4762, 1.2422, 0.7978],
         [0.7283, 0