In [2]:
import torch
import torch.nn as nn
import math

In [2]:
# Choose the compute device: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are three separate linear projections of the
    same input tensor.
    """

    def __init__(self, embed_dim, dim):
        """Create the projection layers.

        Args:
            embed_dim: Size of the last dimension of the input.
            dim: Size of the projected query/key/value vectors, and therefore
                of the output's last dimension.
        """
        super().__init__()
        self.query_w = nn.Linear(embed_dim, dim)
        self.key_w = nn.Linear(embed_dim, dim)
        self.value_w = nn.Linear(embed_dim, dim)
        # Normalise each query's scores over the key axis.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, embed):
        """Attend every position of ``embed`` to every position of ``embed``.

        Args:
            embed: Tensor whose last dimension is ``embed_dim`` and which has
                at least two dimensions (the last two are transposed below).

        Returns:
            Tensor with the same leading dimensions as ``embed`` and a last
            dimension of ``dim``.
        """
        q = self.query_w(embed)
        k = self.key_w(embed)
        v = self.value_w(embed)

        # Divide the dot products by sqrt(key width) before the softmax.
        scale = k.shape[-1] ** 0.5
        logits = torch.matmul(q, k.transpose(-2, -1)) / scale
        return torch.matmul(self.softmax(logits), v)

In [6]:
class MultiheadAttention(nn.Module):
    """Several independent ``SelfAttention`` heads merged by a linear layer."""

    def __init__(self, num_heads, embed_dim, head_dim):
        """Build the heads and the output projection.

        Args:
            num_heads: Number of ``SelfAttention`` heads to create.
            embed_dim: Size of the input's (and output's) last dimension.
            head_dim: Output width of each individual head.
        """
        super().__init__()
        self.num_heads, self.head_dim, self.embed_dim = num_heads, head_dim, embed_dim

        attention_heads = []
        for _ in range(num_heads):
            attention_heads.append(SelfAttention(embed_dim, head_dim))
        self.multi_head_attn = nn.ModuleList(attention_heads)

        # Maps the concatenated head outputs back to the embedding width.
        self.W = nn.Linear(num_heads * head_dim, embed_dim)

    def forward(self, embed):
        """Run every head on ``embed``, concatenate on the last axis and project.

        Args:
            embed: Tensor whose last dimension is ``embed_dim``.

        Returns:
            Tensor whose last dimension is ``embed_dim``.
        """
        per_head = []
        for attention in self.multi_head_attn:
            per_head.append(attention(embed))
        merged = torch.cat(per_head, dim=-1)
        return self.W(merged)


        

In [3]:
class LayerNormalization(nn.Module):
    """Layer normalisation over the last dimension with a learned scale and shift."""

    def __init__(self, embed_dim, eps=1e-5):
        """Create the learnable parameters.

        Args:
            embed_dim: Length of the last dimension being normalised.
            eps: Constant added to the variance before the square root.
        """
        super().__init__()
        self.eps = eps
        # Scale starts at one and shift at zero.
        self.alpha = nn.Parameter(torch.ones(embed_dim))
        self.beta = nn.Parameter(torch.zeros(embed_dim))

    def forward(self, embed):
        """Normalise ``embed`` along its last axis, then apply ``alpha`` and ``beta``.

        Args:
            embed: Tensor whose last dimension broadcasts against ``alpha``.

        Returns:
            Tensor of the same shape as ``embed``.
        """
        mu = embed.mean(dim=-1, keepdim=True)
        # Population variance (no Bessel correction).
        sigma_sq = embed.var(dim=-1, keepdim=True, unbiased=False)
        centered = embed - mu
        scaled = centered / torch.sqrt(sigma_sq + self.eps)
        return self.alpha * scaled + self.beta


In [4]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: linear -> ReLU -> dropout -> linear.

    The ReLU between the two linear layers is what makes this more than a
    single affine map; without it ``W2(W1(x))`` collapses to one linear layer.
    """

    def __init__(self, embed_dim, hidden_dim=None, dropout=0.2):
        """Create the two linear layers and the dropout layer.

        Args:
            embed_dim: Size of the input's and output's last dimension.
            hidden_dim: Width of the inner layer. Defaults to ``embed_dim``.
            dropout: Dropout probability applied to the hidden activations.
        """
        super().__init__()
        if hidden_dim is None:
            hidden_dim = embed_dim
        self.W1 = nn.Linear(embed_dim, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, embed):
        """Apply the feed-forward network to ``embed``.

        Args:
            embed: Tensor whose last dimension is ``embed_dim``.

        Returns:
            Tensor of the same shape as ``embed``.
        """
        hidden = torch.relu(self.W1(embed))
        hidden = self.dropout(hidden)
        return self.W2(hidden)


In [9]:
class Encoder(nn.Module):
    """One encoder layer: self-attention and feed-forward sub-layers.

    Each sub-layer's output is added to its input and then layer-normalised.
    """

    def __init__(self, num_heads, embed_dim, head_dim):
        """Build the attention, feed-forward and normalisation modules.

        Args:
            num_heads: Number of attention heads.
            embed_dim: Size of the input's last dimension.
            head_dim: Output width of each attention head.
        """
        super().__init__()
        self.multiheadattention = MultiheadAttention(num_heads, embed_dim, head_dim)
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.feedforward = FeedForward(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run ``x`` through the attention and feed-forward sub-layers.

        Args:
            x: Tensor whose last dimension is ``embed_dim``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Sub-layer 1: attention, residual add, normalise.
        attended = self.multiheadattention(x)
        after_attention = self.layernorm1(attended + x)

        # Sub-layer 2: feed-forward, residual add, normalise.
        transformed = self.feedforward(after_attention)
        return self.layernorm2(transformed + after_attention)

In [8]:
class StackEncoder(nn.Module):
    """A sequential stack of ``Encoder`` layers."""

    def __init__(self, num_heads, embed_dim, head_dim, num_layers=6):
        """Build the stack.

        Args:
            num_heads: Number of attention heads in every layer.
            embed_dim: Size of the input's last dimension.
            head_dim: Output width of each attention head.
            num_layers: How many ``Encoder`` layers to chain. Defaults to 6,
                the depth that was previously hard-coded.
        """
        super().__init__()
        self.encoders = nn.Sequential(
            *[Encoder(num_heads, embed_dim, head_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Pass ``x`` through every encoder layer in order and return the result."""
        return self.encoders(x)

In [37]:
def mask_mat(len_seq):
    """Build an additive causal mask of shape ``(len_seq, len_seq)``.

    Entry ``[i, j]`` is ``-inf`` when ``j > i`` and ``0`` otherwise, so adding
    the mask to attention scores before a softmax removes every position that
    comes after the query position.

    Args:
        len_seq: Sequence length (number of rows and columns).

    Returns:
        A ``torch.Tensor`` of shape ``(len_seq, len_seq)`` holding 0 / -inf.
    """
    positions = torch.arange(len_seq)
    # future[i, j] is True when column j lies strictly after row i.
    future = positions.unsqueeze(0) > positions.unsqueeze(1)
    return torch.zeros(len_seq, len_seq).masked_fill(future, float("-inf"))

In [18]:
class MaskAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an optional additive mask."""

    def __init__(self, embed_dim):
        """Create square (``embed_dim`` -> ``embed_dim``) query/key/value projections.

        Args:
            embed_dim: Size of the input's last dimension; also the output width.
        """
        super().__init__()
        self.q_w = nn.Linear(embed_dim, embed_dim)
        self.k_w = nn.Linear(embed_dim, embed_dim)
        self.v_w = nn.Linear(embed_dim, embed_dim)
        # Normalise each query's scores over the key axis.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, embed, mask_mat=None):
        """Attend ``embed`` to itself, optionally adding ``mask_mat`` to the scores.

        Args:
            embed: Tensor whose last dimension is ``embed_dim``.
            mask_mat: Optional tensor added to the raw attention scores before
                the softmax; it must broadcast against the score matrix.
                ``None`` leaves the scores unmasked.

        Returns:
            Tensor of the same shape as ``embed``.
        """
        q = self.q_w(embed)
        k = self.k_w(embed)
        v = self.v_w(embed)

        logits = torch.matmul(q, k.transpose(-2, -1)) / (k.shape[-1] ** 0.5)
        masked_logits = logits if mask_mat is None else logits + mask_mat

        weights = self.softmax(masked_logits)
        return torch.matmul(weights, v)


In [19]:
class MaskMultiheadAttention(nn.Module):
    """Several ``MaskAttention`` heads merged by a linear layer.

    Every head works at the full embedding width, so the concatenated head
    outputs have width ``num_heads * embed_dim`` before the final projection.
    """

    def __init__(self, num_heads, embed_dim):
        """Build the heads and the output projection.

        Args:
            num_heads: Number of ``MaskAttention`` heads to create.
            embed_dim: Size of the input's (and output's) last dimension.
        """
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        # MaskAttention takes a single width argument; passing two raised
        # TypeError as soon as this module was constructed.
        self.multi_head_attn = nn.ModuleList([
            MaskAttention(embed_dim) for _ in range(num_heads)
        ])
        self.W = nn.Linear(num_heads * embed_dim, embed_dim)

    def forward(self, embed, mask_mat=None):
        """Run every head with the same mask, concatenate and project.

        Args:
            embed: Tensor whose last dimension is ``embed_dim``.
            mask_mat: Additive mask forwarded to each head, or ``None`` for
                unmasked attention.

        Returns:
            Tensor whose last dimension is ``embed_dim``.
        """
        heads = [head(embed, mask_mat) for head in self.multi_head_attn]
        heads_cat = torch.cat(heads, dim=-1)
        return self.W(heads_cat)


In [20]:
class CrossAttention(nn.Module):
    """Single-head scaled dot-product attention between two sequences.

    Queries are projected from ``y``; keys and values are projected from ``x``.
    """

    def __init__(self, embed_dim):
        """Create square (``embed_dim`` -> ``embed_dim``) query/key/value projections.

        Args:
            embed_dim: Size of the last dimension of both inputs and of the output.
        """
        super().__init__()
        self.q_w = nn.Linear(embed_dim, embed_dim)
        self.k_w = nn.Linear(embed_dim, embed_dim)
        self.v_w = nn.Linear(embed_dim, embed_dim)
        # Normalise each query's scores over the key axis.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y):
        """Let every position of ``y`` attend over the positions of ``x``.

        Args:
            x: Tensor supplying keys and values; last dimension ``embed_dim``.
            y: Tensor supplying queries; last dimension ``embed_dim``.

        Returns:
            Tensor with one output row per query row of ``y`` and a last
            dimension of ``embed_dim``.
        """
        q = self.q_w(y)
        k = self.k_w(x)
        v = self.v_w(x)

        scale = k.shape[-1] ** 0.5
        weights = self.softmax(torch.matmul(q, k.transpose(-2, -1)) / scale)
        return torch.matmul(weights, v)

In [22]:
class MultiheadCrossAttention(nn.Module):
    """Several ``CrossAttention`` heads merged by a linear layer."""

    def __init__(self,embed_dim, num_head):
        """Build the heads and the output projection.

        Args:
            embed_dim: Size of the last dimension of both inputs and of the output.
            num_head: Number of ``CrossAttention`` heads to create.
        """
        super().__init__()
        cross_heads = []
        for _ in range(num_head):
            cross_heads.append(CrossAttention(embed_dim))
        self.atten_list = nn.ModuleList(cross_heads)
        # Maps the concatenated head outputs back to the embedding width.
        self.W = nn.Linear(embed_dim * num_head, embed_dim)

    def forward(self, x, y):
        """Run every head on ``(x, y)``, concatenate on the last axis and project.

        Args:
            x: Tensor supplying keys and values to each head.
            y: Tensor supplying queries to each head.

        Returns:
            Tensor whose last dimension is ``embed_dim``.
        """
        per_head = []
        for attention in self.atten_list:
            per_head.append(attention(x, y))
        merged = torch.cat(per_head, dim=-1)
        return self.W(merged)

In [24]:
class Decoder(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer's output is added to that sub-layer's input (a residual
    connection) and then layer-normalised.
    """

    def __init__(self, embed_dim, num_heads):
        """Build the three sub-layers and their normalisation modules.

        Args:
            embed_dim: Size of the last dimension of both inputs.
            num_heads: Number of heads in each attention sub-layer.
        """
        super().__init__()
        self.mask_attention = MaskMultiheadAttention(num_heads, embed_dim)
        self.layer_norm_1 = LayerNormalization(embed_dim)
        self.cross_attention = MultiheadCrossAttention(embed_dim, num_heads)
        self.layer_norm_2 = LayerNormalization(embed_dim)
        self.feed_forward = FeedForward(embed_dim)
        self.layer_norm_3 = LayerNormalization(embed_dim)

    def forward(self, x, y, mask_mat):
        """Update the decoder stream ``y`` using the encoder output ``x``.

        Args:
            x: Encoder-side tensor; provides keys and values for cross-attention.
            y: Decoder-side tensor; provides the queries and carries the residuals.
            mask_mat: Additive mask passed to the masked self-attention.

        Returns:
            Tensor of the same shape as ``y``.
        """
        mask_atten = self.mask_attention(y, mask_mat)
        y_norm = self.layer_norm_1(mask_atten + y)

        cross_atten = self.cross_attention(x, y_norm)
        # The residual must be the sub-layer's own input (the decoder stream),
        # not the encoder output: cross_atten has one row per query in y_norm,
        # so adding x is the wrong signal and fails when the lengths differ.
        cross_norm = self.layer_norm_2(cross_atten + y_norm)

        ff_out = self.feed_forward(cross_norm)
        out_norm = self.layer_norm_3(ff_out + cross_norm)
        return out_norm


In [38]:
class DecoderStack(nn.Module):
    """A stack of ``Decoder`` layers applied one after another."""

    def __init__(self, embed_dim, num_heads, num_layers=6):
        """Build the stack.

        Args:
            embed_dim: Size of the last dimension of both inputs.
            num_heads: Number of attention heads in every layer.
            num_layers: How many ``Decoder`` layers to chain. Defaults to 6,
                the depth that was previously hard-coded.
        """
        super().__init__()
        self.decoders = nn.ModuleList([
            Decoder(embed_dim, num_heads) for _ in range(num_layers)
        ])

    def forward(self, x, y, mask_mat):
        """Thread ``y`` through every layer; ``x`` and ``mask_mat`` are reused unchanged.

        Args:
            x: Encoder-side tensor handed to every layer as-is.
            y: Decoder-side tensor; each layer's output becomes the next layer's ``y``.
            mask_mat: Additive mask handed to every layer.

        Returns:
            The output of the last decoder layer (``y`` itself if there are no layers).
        """
        for decoder in self.decoders:
            y = decoder(x, y, mask_mat)
        return y


In [33]:
def even_position(p, i, dim):
    """Sine term ``sin(p / 10000 ** (2 * i / dim))``.

    Args:
        p: Token position.
        i: Index of the sin/cos pair (column index // 2), not the raw column.
        dim: Embedding width.
    """
    return math.sin(p / (10000 ** ((2 * i) / dim)))

def odd_position(p, i, dim):
    """Cosine term ``cos(p / 10000 ** (2 * i / dim))``.

    Args:
        p: Token position.
        i: Index of the sin/cos pair (column index // 2), not the raw column.
        dim: Embedding width.
    """
    return math.cos(p / (10000 ** ((2 * i) / dim)))

def positional_encoding(tokens_len, embed_dim):
    """Build the sinusoidal positional-encoding table.

    Even columns hold the sine term and odd columns the cosine term. Columns
    ``2k`` and ``2k + 1`` share the same frequency, ``10000 ** (2k / embed_dim)``.

    Args:
        tokens_len: Number of positions (rows).
        embed_dim: Embedding width (columns).

    Returns:
        A ``torch.Tensor`` of shape ``(tokens_len, embed_dim)``.
    """
    table = []
    for p in range(tokens_len):
        row = []
        for column in range(embed_dim):
            # The helpers expect the pair index; passing the raw column
            # doubled the exponent and gave each sin/cos pair two frequencies.
            pair = column // 2
            term = even_position if column % 2 == 0 else odd_position
            row.append(term(p, pair, embed_dim))
        table.append(row)
    # reshape keeps the (tokens_len, embed_dim) shape even when the table is empty.
    return torch.tensor(table).reshape(tokens_len, embed_dim)


In [36]:
# Sanity check: the table has one row per position and one column per embedding dimension.
positional_encoding(10, 1024).shape

torch.Size([10, 1024])

In [39]:
class TransformerBlock(nn.Module):
    """Encoder stack, decoder stack, and a softmax over a vocabulary projection."""

    def __init__(self, num_heads, embed_dim, vocab_size):
        """Build the encoder/decoder stacks and the output head.

        Args:
            num_heads: Number of attention heads in every layer.
            embed_dim: Size of the last dimension of both inputs; also used as
                the per-head width of the encoder.
            vocab_size: Size of the output's last dimension.
        """
        super().__init__()
        self.encoder = StackEncoder(num_heads, embed_dim, embed_dim)
        self.decoders = DecoderStack(embed_dim, num_heads)
        self.linear = nn.Linear(embed_dim, vocab_size)
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _add_positions(tokens, encoding):
        """Add ``encoding`` to ``tokens``, trimming it to the token sequence length."""
        if torch.is_tensor(encoding) and encoding.dim() >= 2:
            # One shared table may be longer than this sequence and may live
            # on another device; keep only the rows needed and move them.
            encoding = encoding[..., : tokens.size(-2), :].to(tokens.device)
        return tokens + encoding

    def forward(self, x, y, mask_mat, postional_encoding):
        """Encode ``x``, decode ``y`` against it and return vocabulary probabilities.

        Args:
            x: Encoder input; last dimension ``embed_dim``.
            y: Decoder input; last dimension ``embed_dim``.
            mask_mat: Additive mask for the decoder's masked self-attention.
            postional_encoding: Positional table added to both inputs. A table
                with two or more dimensions is sliced along its second-to-last
                axis to each input's length, so it may be longer than either.

        Returns:
            Tensor with ``y``'s leading dimensions and a last dimension of
            ``vocab_size``; a softmax has already been applied over that last
            dimension (these are probabilities, not logits).
        """
        x = self._add_positions(x, postional_encoding)
        y = self._add_positions(y, postional_encoding)
        x = self.encoder(x)
        y = self.decoders(x, y, mask_mat)
        out = self.linear(y)
        out = self.softmax(out)
        return out