In [8]:
import torch
import torch.nn as nn

# Fix the RNG so the weight init below and the torch.rand inputs in later cells
# are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

embed_dim = 64
num_heads = 8  # must divide embed_dim: each head works on embed_dim // num_heads = 8 features
# batch_first is left at its default (False): tensors are (sequence_length, batch_size, embed_dim).
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)


In [3]:
# Cross-attention demo: 6 query positions attend over 10 key/value positions, batch of 32.
# nn.MultiheadAttention is used with its default batch_first=False, so every tensor is
# laid out as (sequence_length, batch_size, embed_dim).
query = torch.rand((6, 32, embed_dim))
key, value = (torch.rand((10, 32, embed_dim)) for _ in range(2))  # drawn in order: key, then value

attn_output, attn_output_weights = multihead_attn(query=query, key=key, value=value)

In [4]:
# The output follows the query layout: (target_len, batch_size, embed_dim).
# With the default average_attn_weights=True the weights are averaged over heads
# and laid out as (batch_size, target_len, source_len); each row is a softmax.
assert attn_output.shape == (query.shape[0], query.shape[1], embed_dim)
assert attn_output_weights.shape == (query.shape[1], query.shape[0], key.shape[0])
assert (attn_output_weights.sum(dim=-1) - 1).abs().max() < 1e-5
attn_output.shape

torch.Size([6, 32, 64])

In [5]:
# key_padding_mask is (batch_size, source_len); True marks a key position to ignore.
# Derive both sizes from `key` (source_len, batch_size, embed_dim) instead of repeating 32/10.
key_padding_mask = torch.zeros(key.shape[1], key.shape[0], dtype=torch.bool)
key_padding_mask[:, 5:] = True  # Mask out positions after the 5th token

attn_output, attn_output_weights = multihead_attn(query, key, value, key_padding_mask=key_padding_mask)

# Masked keys are filled with -inf before the softmax, so they must get exactly zero weight.
assert torch.all(attn_output_weights[..., 5:] == 0)
attn_output, attn_output_weights


(tensor([[[ 1.5192e-01, -2.9437e-02,  5.4236e-01,  ..., -3.2966e-02,
            1.3585e-01, -4.9723e-02],
          [ 1.0286e-01, -2.5614e-03,  4.7557e-01,  ...,  2.4605e-02,
            1.5624e-01, -5.9461e-02],
          [ 1.0709e-01, -1.2067e-01,  4.8760e-01,  ...,  1.0833e-01,
            1.0579e-01,  1.6395e-02],
          ...,
          [ 1.2937e-01, -1.1613e-01,  5.2437e-01,  ..., -1.4627e-02,
            1.4489e-01, -7.5824e-02],
          [ 2.3998e-03, -9.3569e-02,  5.2716e-01,  ..., -5.3976e-02,
            1.7679e-01, -1.5284e-01],
          [ 6.2723e-02, -1.3448e-01,  5.4299e-01,  ..., -5.4690e-02,
            1.5695e-01, -4.9929e-02]],
 
         [[ 1.5290e-01, -2.9095e-02,  5.4099e-01,  ..., -3.0139e-02,
            1.3850e-01, -5.1365e-02],
          [ 1.0303e-01, -1.2757e-03,  4.7245e-01,  ...,  2.2874e-02,
            1.5828e-01, -5.6766e-02],
          [ 1.0563e-01, -1.2055e-01,  4.8639e-01,  ...,  1.1004e-01,
            1.0627e-01,  1.6607e-02],
          ...,
    

In [6]:
def make_key_padding_mask(batch_size: int, src_len: int, num_valid: int) -> torch.Tensor:
    """Build a boolean key-padding mask for ``nn.MultiheadAttention``.

    ``True`` marks a key position that attention should ignore. Every sequence in the
    batch keeps its first ``num_valid`` keys visible and masks the remainder.

    Args:
        batch_size: number of sequences (rows of the mask).
        src_len: number of key positions (columns of the mask).
        num_valid: leading key positions left visible; values >= src_len mask nothing.

    Returns:
        A ``(batch_size, src_len)`` tensor with dtype ``torch.bool``.

    Raises:
        ValueError: if ``num_valid`` is negative (a negative slice start would
            silently mask only the *last* positions instead).
    """
    if num_valid < 0:
        raise ValueError(f"num_valid must be non-negative, got {num_valid}")
    mask = torch.zeros(batch_size, src_len, dtype=torch.bool)
    mask[:, num_valid:] = True
    return mask


# Same layout as the mask built earlier: 32 sequences, 10 keys, first 5 visible.
key_padding_mask = make_key_padding_mask(32, 10, 5)


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer is applied as ``norm(x + dropout(sublayer(x)))``. This is the same
    arrangement and the same submodule names as ``nn.TransformerEncoderLayer`` with its
    default ``norm_first=False``, so state dicts are interchangeable between the two.

    Args:
        embed_dim: model width; the last dimension of the input and output.
        num_heads: number of attention heads; must divide ``embed_dim``.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability used inside attention, inside the feed-forward
            network, and on both residual branches.
        activation: callable applied between the two feed-forward linears
            (default ``F.relu``, the previously hard-coded choice).
        batch_first: if ``True`` inputs are ``(batch, seq, embed_dim)``; otherwise
            (the default) ``(seq, batch, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads, dim_feedforward=2048, dropout=0.1,
                 activation=F.relu, batch_first=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=batch_first
        )
        # Feed-forward sub-layer: expand to dim_feedforward, activate, project back.
        self.linear1 = nn.Linear(embed_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, embed_dim)
        self.activation = activation

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Dropout on the residual branches (after attention, after the feed-forward net).
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Apply one encoder layer.

        Args:
            src: input sequence, laid out according to ``batch_first``.
            src_mask: optional mask forwarded to attention as ``attn_mask``.
            src_key_padding_mask: optional ``(batch, seq)`` mask; ``True`` marks
                padded positions that attention should ignore.

        Returns:
            Tensor with the same shape as ``src``.
        """
        # need_weights=False: the attention weights are discarded, so skip computing them.
        attn_out, _ = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=False,
        )
        src = self.norm1(src + self.dropout1(attn_out))
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ff_out))
        return src

# Instantiate the layer.
# NOTE: this rebinds the notebook-level embed_dim/num_heads (64/8 in the
# MultiheadAttention cells above); re-run those cells before reusing them.
torch.manual_seed(0)  # weight init, dummy_input and train-mode dropout are all random
embed_dim = 512
num_heads = 8
layer = TransformerEncoderLayer(embed_dim, num_heads)
dummy_input = torch.rand(10, 32, embed_dim)  # (sequence_length, batch_size, embed_dim)

# Forward pass through the layer; an encoder layer preserves the input shape.
output = layer(dummy_input)
assert output.shape == dummy_input.shape
output.shape

torch.Size([10, 32, 512])
