In [83]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

In [84]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
    def forward(self, x):
        return self.embedding(x)

In [204]:
emb = Embedding(5, 10)
x = torch.randint(0, 5, (5, 1), dtype=torch.long)
print(x)
emb(x)

tensor([[4],
        [1],
        [2],
        [1],
        [2]])


tensor([[[ 0.2962, -0.7657,  1.6776,  0.1946,  1.3157,  0.5141,  0.6003,
           1.3683, -1.2092, -2.0529]],

        [[ 1.2193,  1.4391, -0.5128, -1.2088, -0.1759, -0.7791,  0.6225,
           0.3070,  0.9709,  0.3641]],

        [[-1.2436, -1.4334,  1.2185, -0.1228,  1.0417, -1.4216,  0.2985,
           1.7465, -1.1521, -0.7095]],

        [[ 1.2193,  1.4391, -0.5128, -1.2088, -0.1759, -0.7791,  0.6225,
           0.3070,  0.9709,  0.3641]],

        [[-1.2436, -1.4334,  1.2185, -0.1228,  1.0417, -1.4216,  0.2985,
           1.7465, -1.1521, -0.7095]]], grad_fn=<EmbeddingBackward0>)

In [85]:
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.embed_dim = embed_size
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * -(math.log(10000.0) / embed_size))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x * math.sqrt(self.embed_dim)
        seq_len = x.size(1)
        x = x + torch.autograd.Variable(self.pe[:,:seq_len], requires_grad=False)
        return x

In [215]:
pEnc = PositionalEncoding(10, 2)
x = torch.randint(0, 5, (5, 2), dtype=torch.long)
e = emb(x)
print(x)
print(e)
pEnc(e)

tensor([[4, 3],
        [1, 0],
        [3, 3],
        [4, 4],
        [4, 0]])
tensor([[[ 0.2962, -0.7657,  1.6776,  0.1946,  1.3157,  0.5141,  0.6003,
           1.3683, -1.2092, -2.0529],
         [ 1.1632, -1.7874,  0.3352,  2.2159,  2.7371, -0.1069,  1.1971,
          -1.1965,  0.1703,  1.5804]],

        [[ 1.2193,  1.4391, -0.5128, -1.2088, -0.1759, -0.7791,  0.6225,
           0.3070,  0.9709,  0.3641],
         [ 0.2013,  0.6307, -0.2933,  0.7859, -2.3128,  0.0748,  0.3854,
          -0.9105, -1.2003, -1.8793]],

        [[ 1.1632, -1.7874,  0.3352,  2.2159,  2.7371, -0.1069,  1.1971,
          -1.1965,  0.1703,  1.5804],
         [ 1.1632, -1.7874,  0.3352,  2.2159,  2.7371, -0.1069,  1.1971,
          -1.1965,  0.1703,  1.5804]],

        [[ 0.2962, -0.7657,  1.6776,  0.1946,  1.3157,  0.5141,  0.6003,
           1.3683, -1.2092, -2.0529],
         [ 0.2962, -0.7657,  1.6776,  0.1946,  1.3157,  0.5141,  0.6003,
           1.3683, -1.2092, -2.0529]],

        [[ 0.2962, -0.7

tensor([[[ 0.9366, -1.4215,  5.3050,  1.6153,  4.1606,  2.6257,  1.8984,
           5.3270, -3.8239, -5.4918],
         [ 4.5198, -5.1118,  1.2178,  7.9948,  8.6807,  0.6616,  3.7896,
          -2.7836,  0.5393,  5.9977]],

        [[ 3.8558,  5.5510, -1.6218, -2.8227, -0.5562, -1.4637,  1.9686,
           1.9707,  3.0704,  2.1513],
         [ 1.4779,  2.5348, -0.7698,  3.4727, -7.2887,  1.2362,  1.2227,
          -1.8794, -3.7951, -4.9428]],

        [[ 3.6784, -4.6521,  1.0600,  8.0073,  8.6556,  0.6619,  3.7856,
          -2.7836,  0.5387,  5.9977],
         [ 4.5198, -5.1118,  1.2178,  7.9948,  8.6807,  0.6616,  3.7896,
          -2.7836,  0.5393,  5.9977]],

        [[ 0.9366, -1.4215,  5.3050,  1.6153,  4.1606,  2.6257,  1.8984,
           5.3270, -3.8239, -5.4918],
         [ 1.7781, -1.8812,  5.4628,  1.6027,  4.1857,  2.6254,  1.9024,
           5.3270, -3.8232, -5.4918]],

        [[ 0.9366, -1.4215,  5.3050,  1.6153,  4.1606,  2.6257,  1.8984,
           5.3270, -3.8239, -5.

In [225]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size=512, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        assert embed_size % num_heads == 0, "Embedding size must be divisible by number of heads"
        self.embed_size = embed_size # 512 by default
        self.num_heads = num_heads # 8 by default
        self.head_dim = int(embed_size / num_heads) # 512/8 = 64 by default

        # Linear transformations for queries, keys, and values
        self.query_projection = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.key_projection = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.value_projection = nn.Linear(self.head_dim, self.head_dim, bias=False)

        # Linear transformation for output projection (8*64)x512 or 512x512
        self.output_projection = nn.Linear(num_heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask=None):

        N, value_len, key_len, query_len = queries.shape[0], values.shape[1], keys.shape[1], queries.shape[1]
        values = values.reshape(N, value_len, self.num_heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.num_heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.num_heads, self.head_dim)

        values = self.value_projection(values)
        keys = self.key_projection(keys)
        queries = self.query_projection(queries)

        keys_transpose = keys.transpose(-2, -1)
        scaled_scores = torch.matmul(queries, keys_transpose)

        if mask is not None:
            scaled_scores = scaled_scores.masked_fill(mask == 0, float("-1e20"))

        scaled_scores = scaled_scores * (1.0 / math.sqrt(self.head_dim))
        attention_weights = F.softmax(scaled_scores, dim=-1)
        
        attention_output = torch.matmul(attention_weights, values)
        attention_output = attention_output.transpose(1, 2).contiguous().view(N, value_len, self.head_dim * self.num_heads)
        output = self.output_projection(attention_output)
        return output

In [247]:
mAttention = MultiHeadAttention(10, 2)
x = torch.rand((2, 2, 10))
a = mAttention(x, x, x)
print(x)
print(a)

tensor([[[0.4789, 0.4563, 0.8299, 0.9584, 0.7924, 0.7973, 0.6738, 0.6807,
          0.7644, 0.3278],
         [0.3840, 0.8490, 0.0580, 0.3516, 0.9290, 0.1023, 0.4729, 0.3873,
          0.0588, 0.4671]],

        [[0.3305, 0.8745, 0.8956, 0.7049, 0.0280, 0.7392, 0.5881, 0.5698,
          0.7126, 0.5654],
         [0.0745, 0.8963, 0.4403, 0.5683, 0.4776, 0.0545, 0.2787, 0.8023,
          0.7805, 0.9913]]])
tensor([[[ 0.0386, -0.0310, -0.0048,  0.4215, -0.4450, -0.0599,  0.5028,
           0.3107,  0.1961,  0.5020],
         [ 0.0385, -0.0319, -0.0041,  0.4222, -0.4451, -0.0609,  0.5040,
           0.3112,  0.1966,  0.5031]],

        [[ 0.0248, -0.1939, -0.0070,  0.2906, -0.3201, -0.0891,  0.5954,
           0.3117,  0.4549,  0.7071],
         [ 0.0253, -0.1940, -0.0078,  0.2893, -0.3208, -0.0881,  0.5951,
           0.3112,  0.4552,  0.7076]]], grad_fn=<ViewBackward0>)


In [87]:
class FeedForwardNetwork(nn.Module):
    def __init__(self, embed_size, ff_hidden_size, dropout_rate=0.1):
        super(FeedForwardNetwork, self).__init__()
        self.fc1 = nn.Linear(embed_size, ff_hidden_size)
        self.fc2 = nn.Linear(ff_hidden_size, embed_size)
        self.dropout = nn.Dropout(dropout_rate)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.constant_(self.fc2.bias, 0)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [250]:
ffn = FeedForwardNetwork(10, 100, 0.2)
x = torch.rand((1, 2, 10))
o = ffn(x)
print(x)
print(o)

tensor([[[0.3819, 0.2336, 0.0704, 0.3549, 0.9188, 0.8499, 0.6359, 0.0572,
          0.9218, 0.6735],
         [0.0855, 0.4643, 0.6691, 0.5916, 0.7340, 0.1227, 0.4601, 0.3107,
          0.8145, 0.4057]]])
tensor([[[-0.3285, -0.3927, -0.1556, -0.0652, -0.5466,  0.0249, -0.0745,
          -0.4057, -0.0870, -0.0800],
         [-0.2573, -0.4286, -0.2811,  0.1576, -0.4314,  0.0887, -0.2062,
          -0.1129, -0.1918,  0.1081]]], grad_fn=<ViewBackward0>)


In [88]:
def _init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.LayerNorm):
        nn.init.constant_(module.bias, 0)
        nn.init.constant_(module.weight, 1.0)

In [89]:

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, ff_hidden_size, num_heads=8, dropout_rate=0.2):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = FeedForwardNetwork(embed_dim, ff_hidden_size, dropout_rate)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.apply(_init_weights)

    def forward(self, value, key, query):
        norm_query = self.norm1(query)
        attention_out = self.attention(value, key, norm_query)
        query = query + self.dropout1(attention_out) # Residual connection
        norm_query = self.norm2(query)
        feed_fwd_out = self.feed_forward(norm_query)
        query = query + self.dropout2(feed_fwd_out) # Residual connection
        return query

In [90]:
class TransformerEncoder(nn.Module):
    def __init__(self, seq_len, vocab_size, embed_dim, num_layers=2, ff_hidden_size=2048, num_heads=8):
        super(TransformerEncoder, self).__init__()
        self.embedding_layer = Embedding(vocab_size, embed_dim)
        self.positional_encoder = PositionalEncoding(seq_len, embed_dim)
        self.layers = nn.ModuleList([TransformerBlock(embed_dim, ff_hidden_size, num_heads) for i in range(num_layers)])

    def forward(self, x):
        embed_out = self.embedding_layer(x)
        out = self.positional_encoder(embed_out)
        for layer in self.layers:
            out = layer(out, out, out)
        return out  #32x10x512

In [91]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, ff_hidden_size, num_heads=8, dropout_rate=0.2):
        super(DecoderBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.transformer_block = TransformerBlock(embed_dim, ff_hidden_size, num_heads, dropout_rate)
        self.dropout = nn.Dropout(dropout_rate)
        self.apply(_init_weights)

    def forward(self, key, query, value, mask):
        norm_query = self.norm1(query)
        attention = self.attention(value, value, norm_query, mask=mask)
        query = query + self.dropout(attention)
        return self.transformer_block(key, query, value)

In [92]:
class TransformerDecoder(nn.Module):
    def __init__(self, target_vocab_size, embed_dim, seq_len, num_layers=2, ff_hidden_size=2048, num_heads=8, dropout_rate=0.2):
        super(TransformerDecoder, self).__init__()
        self.word_embedding = nn.Embedding(target_vocab_size, embed_dim)
        self.position_embedding = PositionalEncoding(seq_len, embed_dim)
        self.layers = nn.ModuleList([DecoderBlock(embed_dim, ff_hidden_size, num_heads, dropout_rate) for _ in range(num_layers)])
        self.fc_out = nn.Linear(embed_dim, target_vocab_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, enc_out, mask):
        x = self.word_embedding(x)  # Shape: (batch_size, seq_len, embed_dim)
        x = self.position_embedding(x) # Shape: (batch_size, seq_len, embed_dim)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(enc_out, x, enc_out, mask)
        return F.softmax(self.fc_out(x), dim=-1) # Shape: (batch_size, seq_len, target_vocab_size)

In [93]:
class Transformer(nn.Module):
    def __init__(self, embed_dim, src_vocab_size, target_vocab_size, seq_length, num_layers=2, ff_hidden_size=2048, num_heads=8):
        super(Transformer, self).__init__()
        self.target_vocab_size = target_vocab_size
        self.encoder = TransformerEncoder(seq_length, src_vocab_size, embed_dim, num_layers=num_layers, ff_hidden_size=ff_hidden_size, num_heads=num_heads)
        self.decoder = TransformerDecoder(target_vocab_size, embed_dim, seq_length, num_layers=num_layers, ff_hidden_size=ff_hidden_size, num_heads=num_heads)

    def make_trg_mask(self, trg):
        batch_size, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            batch_size, 1, trg_len, trg_len
        )
        return trg_mask

    def decode(self,src,trg):
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src)
        out_labels = []
        batch_size,seq_len = src.shape[0],src.shape[1]
        out = trg
        for i in range(seq_len): #10
            out = self.decoder(out,enc_out,trg_mask) #bs x seq_len x vocab_dim
            out = out[:,-1,:]
            out = out.argmax(-1)
            out_labels.append(out.item())
            out = torch.unsqueeze(out,axis=0)
        return out_labels

    def forward(self, src, trg):
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src)
        outputs = self.decoder(trg, enc_out, trg_mask)
        return outputs

In [102]:
# Set hyperparameters
src_vocab_size = 11
target_vocab_size = 11
num_layers = 6
seq_length = 12
embed_dim = 512
ff_hidden_size = 2048  # Feed-forward hidden layer size
num_heads = 8

src_sequence = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
                             [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])
target_sequence = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1],
                                [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])
print("Shape of source sequence:", src_sequence.shape)
print("Shape of target sequence:", target_sequence.shape)
model = Transformer(embed_dim=embed_dim, src_vocab_size=src_vocab_size,
                    target_vocab_size=target_vocab_size, seq_length=seq_length,
                    num_layers=num_layers, ff_hidden_size=ff_hidden_size, num_heads=num_heads)

Shape of source sequence: torch.Size([2, 12])
Shape of target sequence: torch.Size([2, 12])


In [105]:
model = Transformer(embed_dim=4, src_vocab_size=3,
                    target_vocab_size=3, seq_length=2,
                    num_layers=2, ff_hidden_size=10, num_heads=2)
print(model.encoder.embedding_layer.embedding.weight.data.cpu().numpy())

[[-0.52016807 -0.8028974  -2.3331447  -0.781131  ]
 [-0.9884004   0.37007698  0.4746922   1.450523  ]
 [ 0.602024   -0.35471794  0.213776    0.3411836 ]]
