In [None]:
# Implement a single-head scaled dot-product self-attention mechanism
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with bias-free linear
    layers and returns ``softmax(Q K^T / sqrt(embed_dim)) V`` together with the
    attention weights.

    Args:
        embed_dim: Size of the last (feature) dimension of the input.
    """

    def __init__(self, embed_dim):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)  # Query
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)  # Key
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)  # Value
        # Registered as a non-persistent buffer so it follows the module through
        # .to(device) / .half(), while still not appearing in state_dict
        # (matching the previous plain-attribute behaviour for checkpoints).
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor(embed_dim, dtype=torch.float32)),
            persistent=False,
        )

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: (..., seq_len, embed_dim). The usual case is
                (batch_size, seq_len, embed_dim); an unbatched
                (seq_len, embed_dim) input or extra leading dims also work.
            mask: Optional tensor broadcastable to (..., seq_len, seq_len),
                e.g. (batch_size, seq_len, seq_len) or (seq_len, seq_len)
                for a causal mask. Positions where ``mask == 0`` are excluded
                from attention. A query row whose keys are all masked gets
                NaN weights (softmax over an all ``-inf`` row).

        Returns:
            Tuple ``(output, attn_weights)`` with shapes
            (..., seq_len, embed_dim) and (..., seq_len, seq_len).
        """
        Q = self.W_q(x)  # (..., seq_len, embed_dim)
        K = self.W_k(x)  # (..., seq_len, embed_dim)
        V = self.W_v(x)  # (..., seq_len, embed_dim)

        # matmul + transpose(-2, -1) is identical to bmm for 3-D input and
        # additionally supports 2-D and >3-D inputs.
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Apply mask (for causal masking or padding tokens)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)  # (..., seq_len, seq_len)
        output = torch.matmul(attn_weights, V)  # (..., seq_len, embed_dim)

        return output, attn_weights

# Example usage
torch.manual_seed(0)  # Fixed seed so weights/input (and outputs) reproduce on Restart & Run All

batch_size = 2
seq_len = 4
embed_dim = 8

x = torch.randn(batch_size, seq_len, embed_dim)  # Random input
self_attn = SelfAttention(embed_dim)
output, attn_weights = self_attn(x)

# Inline sanity checks: expected shapes, and each query's weights form a distribution.
assert output.shape == (batch_size, seq_len, embed_dim)
assert attn_weights.shape == (batch_size, seq_len, seq_len)
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch_size, seq_len))

print("Output Shape:", output.shape)  # (batch_size, seq_len, embed_dim)
print("Attention Weights Shape:", attn_weights.shape)  # (batch_size, seq_len, seq_len)

In [None]:
### Transformer Decoder Implementation (masked self-attention + cross-attention)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn

class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer.

    Runs masked self-attention, encoder-decoder (cross) attention, and a
    position-wise feed-forward network; each sub-layer output goes through
    dropout, a residual add, and LayerNorm. Inputs are sequence-first
    (``nn.MultiheadAttention`` is used with its default ``batch_first=False``).

    Args:
        embed_dim: Model/feature dimension.
        num_heads: Number of attention heads (must divide ``embed_dim``).
        ff_dim: Hidden size of the feed-forward network.
        dropout: Dropout probability used inside both attention modules and
            on each sub-layer output before the residual add.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """
        tgt: (tgt_len, batch_size, embed_dim) - Target sequence input
        memory: (src_len, batch_size, embed_dim) - Encoder output
        tgt_mask: (tgt_len, tgt_len), optional - e.g. causal mask for autoregressive decoding
        memory_mask: (tgt_len, src_len), optional - Mask for encoder-decoder attention
        tgt_key_padding_mask: (batch_size, tgt_len), optional - True marks padded
            target positions that must not be attended to
        memory_key_padding_mask: (batch_size, src_len), optional - True marks padded
            encoder positions that must not be attended to
        Returns: (tgt_len, batch_size, embed_dim)
        """
        # need_weights=False: the attention weights are discarded, so don't compute them.

        # Self-Attention (Masked)
        tgt2 = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
            need_weights=False,
        )[0]
        tgt = self.norm1(tgt + self.dropout(tgt2))

        # Cross-Attention (Attend to Encoder Output)
        tgt2 = self.cross_attn(
            tgt, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            need_weights=False,
        )[0]
        tgt = self.norm2(tgt + self.dropout(tgt2))

        # Feedforward Network
        tgt2 = self.feedforward(tgt)
        tgt = self.norm3(tgt + self.dropout(tgt2))

        return tgt

class TransformerDecoder(nn.Module):
    """Stack of ``TransformerDecoderLayer`` blocks over token embeddings plus
    fixed sinusoidal positional encodings, followed by a linear projection to
    vocabulary logits.

    Args:
        num_layers: Number of decoder layers.
        embed_dim: Model/feature dimension.
        num_heads: Attention heads per layer.
        ff_dim: Feed-forward hidden size per layer.
        vocab_size: Number of target tokens (embedding rows and output logits).
        max_seq_len: Longest target sequence ``forward`` accepts.
        dropout: Dropout probability passed to each layer.
    """

    def __init__(self, num_layers, embed_dim, num_heads, ff_dim, vocab_size, max_seq_len, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Non-persistent buffer: it moves with .to(device) (a plain tensor
        # attribute would stay on CPU and break forward on GPU), and since it is
        # recomputed at construction it is kept out of state_dict, so
        # checkpoints remain compatible.
        self.register_buffer(
            "positional_encoding",
            self._generate_positional_encoding(max_seq_len, embed_dim),
            persistent=False,
        )
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def _generate_positional_encoding(self, max_seq_len, embed_dim):
        """Build sinusoidal positional encodings.

        Even feature columns hold ``sin(pos * w_k)`` and odd columns
        ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / embed_dim)``.
        Odd ``embed_dim`` is supported.

        Returns:
            Tensor of shape (max_seq_len, 1, embed_dim): sequence-first, with a
            singleton batch axis for broadcasting.
        """
        position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2, dtype=torch.float32)
            * (-torch.log(torch.tensor(10000.0)) / embed_dim)
        )
        pos_enc = torch.zeros(max_seq_len, embed_dim)
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        # With odd embed_dim there is one fewer cos column than sin column.
        pos_enc[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        return pos_enc.unsqueeze(1)  # Shape: (max_seq_len, 1, embed_dim)

    def generate_mask(self, seq_len, device=None):
        """Return an additive causal mask of shape (seq_len, seq_len).

        Entries above the diagonal are ``-inf`` (position i may not attend to
        j > i); all other entries are 0. ``device`` optionally selects where
        the mask is allocated.
        """
        return torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=device),
            diagonal=1,
        )

    def forward(self, tgt, memory):
        """
        tgt: (batch_size, tgt_len) - Target token indices
        memory: (src_len, batch_size, embed_dim) - Encoder outputs
        Returns: (batch_size, tgt_len, vocab_size) unnormalized logits
        Raises ValueError if tgt_len exceeds the max_seq_len given at construction.
        """
        _, seq_len = tgt.shape
        max_len = self.positional_encoding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"target length {seq_len} exceeds max_seq_len {max_len}"
            )

        # Switch to seq-first BEFORE adding positional encodings: the buffer is
        # (max_seq_len, 1, embed_dim), so its sequence axis must line up with
        # dim 0 and broadcast over batch on dim 1. Adding it to the batch-first
        # embedding would mis-broadcast (error, or wrong per-batch offsets).
        tgt_embed = self.embedding(tgt).permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)
        tgt_embed = tgt_embed + self.positional_encoding[:seq_len]
        tgt_mask = self.generate_mask(seq_len, device=tgt.device)

        for layer in self.layers:
            tgt_embed = layer(tgt_embed, memory, tgt_mask)

        logits = self.fc_out(tgt_embed)  # Unnormalized scores (apply softmax for probabilities)
        return logits.permute(1, 0, 2)  # Back to (batch_size, seq_len, vocab_size)
    