#### This notebook implements self-attention, multi-head self-attention, and finally a transformer block

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

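# NOTE(review): the line below appears to be captured cell output (a source line
# echoed by a warning during `import torch`), not code from this cell:
#   cpu = _conversion_method_template(device=torch.device("cpu"))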


In [3]:
class self_attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected into queries, keys and values by three separate
    ``embed_dim -> embed_dim`` linear layers. Every position then attends to
    every position of the same sequence.

    Args:
        embed_dim: Size of the last (feature) dimension of the input. It is
            also the size of the projections and of the output.
    """

    def __init__(self, embed_dim):
        super().__init__()

        self.embed_dim = embed_dim
        # Registration order (query, key, value) is kept so that parameter
        # initialisation and state_dict ordering stay the same.
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor expected as (batch, seq_len, embed_dim).

        Returns:
            A tuple ``(out, attn_weight)``: ``out`` has the same shape as
            ``x``; ``attn_weight`` is (batch, seq_len, seq_len) and each of
            its rows sums to 1.
        """
        queries, keys, values = self.query(x), self.key(x), self.value(x)

        # Pairwise similarity of every query with every key, scaled by
        # sqrt(embed_dim): (batch, seq_len, seq_len).
        scale = self.embed_dim ** 0.5
        similarity = (queries @ keys.transpose(-2, -1)) / scale

        # Normalise over the key axis so each query's weights form a
        # probability distribution.
        attn_weight = similarity.softmax(dim=-1)

        # Weighted sum of the values: (batch, seq_len, embed_dim).
        return attn_weight @ values, attn_weight
    

#  Q, K, V are linear projections of the input.

# We take dot product QK^T to measure similarity between tokens.

# Divide by √(d_k) so the dot products do not grow with the dimension and saturate the softmax (which would shrink its gradients).

# Softmax → weights → weighted sum of values (context vector).

# Output has same shape as input.


In [4]:
class multiheadedattention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, each projection is
    split into ``num_heads`` slices of width ``head_dim``, attention is
    computed independently per head, and the heads are concatenated and
    passed through a final linear layer.

    Args:
        embed_dim: Size of the feature dimension of the input and output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Input projections, registered in the order query, key, value.
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        # Output projection that mixes the concatenated heads.
        self.out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, batch, seq_len):
        """Reshape (batch, seq_len, embed_dim) to (batch, num_heads, seq_len, head_dim)."""
        per_head = projected.view(batch, seq_len, self.num_heads, self.head_dim)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: 3-D tensor (batch, seq_len, embed_dim).

        Returns:
            A tuple ``(out, attention_weights)``: ``out`` is
            (batch, seq_len, embed_dim); ``attention_weights`` is
            (batch, num_heads, seq_len, seq_len).
        """
        batch, seq_len, width = x.size()

        # Reshape each projection into (batch, num_heads, seq_len, head_dim).
        q = self._split_heads(self.query(x), batch, seq_len)
        k = self._split_heads(self.key(x), batch, seq_len)
        v = self._split_heads(self.value(x), batch, seq_len)

        # Per-head similarity, scaled by sqrt(head_dim).
        scale = self.head_dim ** 0.5
        scores = (q @ k.transpose(-2, -1)) / scale
        attention_weights = scores.softmax(dim=-1)
        per_head_out = attention_weights @ v  # (batch, num_heads, seq_len, head_dim)

        # Concatenate heads back into a single feature axis.
        merged = per_head_out.transpose(1, 2).reshape(batch, seq_len, width)

        return self.out(merged), attention_weights
       

In [5]:
class feedforward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout.

    Args:
        embed_dim: Size of the input and output feature dimension.
        ff_hidden_dim: Size of the hidden layer between the two linear maps.
        dropout: Drop probability of the trailing ``nn.Dropout`` (default 0.1).
    """

    def __init__(self, embed_dim, ff_hidden_dim, dropout=0.1):
        super().__init__()
        # Layer order fixes the Sequential indices (0..3) used in state_dict keys.
        layers = [
            nn.Linear(embed_dim, ff_hidden_dim),  # expand
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim),  # project back
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network; the last dimension must be ``embed_dim``."""
        return self.net(x)

In [6]:
class transformerblock(nn.Module):
    """Transformer encoder block with post-layer-norm residual connections.

    Two sub-layers are applied in order: multi-head self-attention, then a
    position-wise feed-forward network. Each sub-layer's output goes through
    dropout, is added to its input, and the sum is layer-normalised.

    Args:
        embed_dim: Feature dimension of the input and output.
        num_heads: Number of attention heads passed to ``multiheadedattention``.
        ff_hidden_dim: Hidden width passed to ``feedforward``.
        dropout: Drop probability for this block's dropout and for the
            feed-forward network (default 0.1).
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()
        # Sub-modules are registered in the original order so parameter
        # initialisation and state_dict ordering are unchanged.
        self.attention = multiheadedattention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.feed_forward = feedforward(embed_dim, ff_hidden_dim, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        # One Dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the block output; same shape as ``x``.

        The attention weights returned by the attention sub-layer are discarded.
        """
        # Sub-layer 1: self-attention, then residual add and layer norm.
        attended = self.attention(x)[0]
        after_attention = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward, then residual add and layer norm.
        # NOTE(review): feedforward already ends in its own Dropout built with
        # the same rate, so in training mode this output is dropped out twice.
        # Confirm whether that is intended.
        transformed = self.feed_forward(after_attention)
        return self.norm2(after_attention + self.dropout(transformed))
 