<a href="https://colab.research.google.com/github/bhavyajethi/Deep-Learning-practice/blob/main/multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        attn_probs = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        batch_size, num_heads, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

    # Example usage
d_model = 8
num_heads = 4
query_seq_length = 5
ke_seq_length = 6
batch_size = 1

mha = MultiHeadAttention(d_model, num_heads)

# Create dummy input tensors
Q = torch.randn(batch_size, query_seq_length, d_model)
K =  V = torch.randn(batch_size, ke_seq_length, d_model)

# Create a dummy mask (optional)
mask = torch.ones(batch_size, 1, query_seq_length, ke_seq_length)

# Forward pass
output = mha(Q, K, V, mask)

#Self attention.


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiheadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear layers for Q, K, V
        self.linear_q = nn.Linear(embed_dim, embed_dim)
        self.linear_k = nn.Linear(embed_dim, embed_dim)
        self.linear_v = nn.Linear(embed_dim, embed_dim)

        # Final linear layer
        self.final_linear = nn.Linear(embed_dim, embed_dim)

        # Dropout layer
        self.dropout = nn.Dropout(p=0.1)

    def scaled_dot_product_attention(self, query, key, value):
        # Calculate the dot product
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply softmax to get attention scores
        attn = F.softmax(scores, dim=-1)

        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by value
        output = torch.matmul(attn, value)
        return output

    def forward(self, query, key, value):
        batch_size = query.size(0)

        # Perform linear operation and split into h heads

        query_l = self.linear_q(query)
        query = query_l.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.linear_k(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.linear_v(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply scaled dot product attention
        attn_output = self.scaled_dot_product_attention(query, key, value)

        # Concatenate heads and put through final linear layer
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        # combine and redistribute with a final layer
        output = self.final_linear(attn_output)

        return output

# Assuming the MultiheadAttention class is already defined as shown in the previous example

# Set the random seed for reproducibility
torch.manual_seed(0)

# Define the dimensions
embed_dim = 8  # Embedding size for each token
num_heads = 4  # Number of attention heads
seq_length = 5  # Length of the sequence

# Create an instance of MultiheadAttention
multihead_attn = MultiheadAttention(embed_dim, num_heads)

# Simulate a batch of token embeddings (batch size = 1)
token_embeddings = torch.randn(1, seq_length, embed_dim)

# Apply the MultiheadAttention layer (self-attention)
output = multihead_attn(token_embeddings, token_embeddings, token_embeddings)

print("Input Embeddings:")
print(token_embeddings)
print("\nOutput of Multihead Attention:")
print(output)

Input Embeddings:
tensor([[[ 5.5880e-01,  7.9176e-01, -1.8468e-01, -7.3177e-01, -8.0652e-02,
          -9.8006e-01,  6.0491e-02, -4.8895e-01],
         [-8.1373e-01,  8.1999e-01, -6.3317e-01,  1.2948e+00,  1.4628e+00,
          -6.2043e-01,  9.8839e-01, -4.3218e-01],
         [-6.2322e-01, -2.1625e-01, -4.8868e-01,  7.8696e-01,  1.0759e-01,
          -1.0715e+00, -1.1665e-01, -1.0170e+00],
         [ 1.1685e+00,  7.7037e-01,  3.9068e-01,  2.8959e-01, -2.7575e+00,
          -8.3236e-01,  4.8999e-01,  2.9082e-01],
         [-1.1311e+00, -9.3218e-04, -1.6269e-01, -2.4772e-01,  2.4197e+00,
           1.6456e+00, -3.0869e-01, -1.5147e+00]]])

Output of Multihead Attention:
tensor([[[-0.1431, -0.3057, -0.2482,  0.4651,  0.1098,  0.2106, -0.0636,
           0.0551],
         [-0.0478, -0.1567, -0.0146,  0.4996,  0.0789,  0.3726, -0.2046,
           0.3493],
         [ 0.0013, -0.2093, -0.0440,  0.4761,  0.0754,  0.3326, -0.1886,
           0.3104],
         [-0.2508, -0.2179, -0.1793,  0.4721

#With mask

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiheadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.linear_q = nn.Linear(embed_dim, embed_dim)
        self.linear_k = nn.Linear(embed_dim, embed_dim)
        self.linear_v = nn.Linear(embed_dim, embed_dim)

        self.final_linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(p=0.1)

    def scaled_dot_product_attention(self, query, key, value, attn_mask=None, key_padding_mask=None):
        # Scale dot product
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float('-inf'))

        if key_padding_mask is not None:
            # Assuming key_padding_mask is a ByteTensor with shape [batch_size, seq_len] where padding elements are True
            scores = scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        output = torch.matmul(attn, value)
        return output

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        batch_size = query.size(0)

        query = self.linear_q(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.linear_k(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.linear_v(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        attn_output = self.scaled_dot_product_attention(query, key, value, attn_mask, key_padding_mask)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        output = self.final_linear(attn_output)

        return output

# Assuming the MultiheadAttention class is defined as in the previous example

# Set random seed for reproducibility
torch.manual_seed(0)

# Define dimensions
embed_dim = 8  # Embedding size for each token
num_heads = 2  # Number of attention heads
batch_size = 2
seq_length = 4  # Length of the sequence

# Create an instance of MultiheadAttention
multihead_attn = MultiheadAttention(embed_dim, num_heads)

# Simulate a batch of token embeddings
# Batch size is 2, sequence length is 4, embedding dimension is 8
token_embeddings = torch.randn(batch_size, seq_length, embed_dim)

# Create key padding mask
# Assume the second sequence in the batch has two padding tokens at the end
key_padding_mask = torch.tensor([
    [0, 0, 0, 0],  # No padding
    [0, 0, 1, 1]   # Last two are padding
], dtype=torch.bool)

# Create attention mask
# Prevent the first token from attending to the third token in each sequence
attn_mask = torch.zeros((batch_size, seq_length, seq_length), dtype=torch.bool)
attn_mask[:, 0, 2] = 1

# Apply the MultiheadAttention layer
output = multihead_attn(token_embeddings, token_embeddings, token_embeddings,
                        attn_mask=attn_mask, key_padding_mask=key_padding_mask)

print("Input Embeddings:")
print(token_embeddings)
print("\nOutput of Multihead Attention:")
print(output)

Input Embeddings:
tensor([[[ 0.5588,  0.7918, -0.1847, -0.7318, -0.0807, -0.9801,  0.0605,
          -0.4890],
         [-0.8137,  0.8200, -0.6332,  1.2948,  1.4628, -0.6204,  0.9884,
          -0.4322],
         [-0.6232, -0.2162, -0.4887,  0.7870,  0.1076, -1.0715, -0.1166,
          -1.0170],
         [-1.1980,  0.4784, -1.2295, -1.3700,  1.5435, -0.0332, -0.4186,
          -0.2556]],

        [[-0.1292, -0.0546,  0.4083,  1.1264,  1.9351,  1.0077,  1.0046,
          -0.4335],
         [-1.2426,  1.2846,  0.2438,  0.5304, -0.0145, -2.2357,  1.4660,
          -1.2191],
         [ 0.6442,  3.9300, -0.1244,  0.2953,  0.3827, -0.5497, -0.9940,
           1.3459],
         [ 1.9457, -1.2904, -2.3495, -2.0689,  0.9094, -0.6946,  1.9595,
          -1.1038]]])

Output of Multihead Attention:
tensor([[[-0.4881, -0.5333, -0.0413,  0.8922,  0.1431, -0.1657, -0.1078,
          -0.0045],
         [    nan,     nan,     nan,     nan,     nan,     nan,     nan,
              nan],
         [    na