In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention hyper-parameters: a 64-dimensional embedding shared by 8 heads.
embed_dim, num_heads = 64, 8

# Multi-head attention module with the library defaults for everything else.
multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)


In [4]:
# Random inputs laid out as (sequence_length, batch_size, embed_dim).
# The query has 6 positions; key and value share a length of 10.
query, key, value = (
    torch.rand(seq_len, 32, embed_dim) for seq_len in (6, 10, 10)
)

# The module returns the attended output together with the attention weights.
attn_output, attn_output_weights = multihead_attn(query=query, key=key, value=value)

In [5]:
# The output keeps the query's layout: (query_length, batch_size, embed_dim).
attn_output.shape

torch.Size([6, 32, 64])

In [6]:
# Boolean padding mask with one row per batch element and one column per key
# position; True marks a key position that attention should ignore.
key_padding_mask = torch.arange(10).expand(32, 10) >= 5  # Mask out positions after the 5th token

attn_output, attn_output_weights = multihead_attn(
    query=query,
    key=key,
    value=value,
    key_padding_mask=key_padding_mask,
)
attn_output, attn_output_weights


(tensor([[[-0.1369,  0.1961,  0.2079,  ..., -0.1719, -0.1980, -0.0251],
          [-0.0897,  0.1646,  0.2608,  ..., -0.1870, -0.2238,  0.0297],
          [-0.1206, -0.0445,  0.2325,  ..., -0.2061, -0.1994,  0.0830],
          ...,
          [-0.1698,  0.1564,  0.1423,  ..., -0.0906, -0.1583, -0.0686],
          [-0.1446,  0.0404,  0.1077,  ..., -0.1331, -0.1693,  0.0088],
          [-0.1325,  0.0605,  0.1609,  ..., -0.1453, -0.2163,  0.0154]],
 
         [[-0.1383,  0.1979,  0.2113,  ..., -0.1676, -0.1983, -0.0230],
          [-0.0906,  0.1603,  0.2635,  ..., -0.1909, -0.2201,  0.0315],
          [-0.1182, -0.0488,  0.2319,  ..., -0.2063, -0.1986,  0.0826],
          ...,
          [-0.1649,  0.1524,  0.1416,  ..., -0.0879, -0.1607, -0.0731],
          [-0.1439,  0.0390,  0.1072,  ..., -0.1326, -0.1730,  0.0078],
          [-0.1275,  0.0629,  0.1581,  ..., -0.1431, -0.2167,  0.0156]],
 
         [[-0.1417,  0.1946,  0.2076,  ..., -0.1701, -0.1948, -0.0224],
          [-0.0890,  0.1609,

In [7]:
# Rebuilds the same (32, 10) boolean padding mask as the previous cell:
# False for the first five key positions, True (masked) from index 5 onwards.
key_padding_mask = torch.arange(10)[None, :].expand(32, -1) >= 5  # Mask out positions after the 5th token


In [10]:
class TransformerEncoderLayer(nn.Module):
    """A single Transformer encoder layer with post-layer normalisation.

    The layer chains two sub-layers: multi-head self-attention and a two-layer
    ReLU feed-forward network. The output of each sub-layer passes through
    dropout, is added back to that sub-layer's input (residual connection) and
    is then layer-normalised.

    Args:
        embed_dim: Size of the feature dimension of the input.
        num_heads: Number of attention heads given to ``nn.MultiheadAttention``.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside the attention module, inside
            the feed-forward network and on both residual branches.
    """

    def __init__(self, embed_dim, num_heads, dim_feedforward=2048, dropout=0.1):
        super().__init__()

        # Sub-modules are registered in a fixed order so that parameter
        # ordering and seeded initialisation stay stable.
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)

        # Feed-forward network: embed_dim -> dim_feedforward -> embed_dim.
        self.linear1 = nn.Linear(embed_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, embed_dim)

        # One normalisation and one dropout per residual branch.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Run the layer on ``src`` and return a tensor of the same shape.

        Args:
            src: Input sequence; it is used as query, key and value of the
                self-attention.
            src_mask: Optional mask forwarded to the attention module as
                ``attn_mask``.
            src_key_padding_mask: Optional mask forwarded to the attention
                module as ``key_padding_mask``.

        Returns:
            The encoded sequence after both residual sub-layers.
        """
        # Sub-layer 1: self-attention; the returned attention weights are unused.
        attended = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        hidden = self.norm1(src + self.dropout1(attended))

        # Sub-layer 2: position-wise feed-forward network.
        expanded = F.relu(self.linear1(hidden))
        projected = self.linear2(self.dropout(expanded))
        return self.norm2(hidden + self.dropout2(projected))

# Build one encoder layer.
# NOTE: this rebinds the notebook-level `embed_dim` (64 in the first cell) to
# 512, so re-running the earlier attention cells afterwards would create
# inputs that no longer match `multihead_attn`.
embed_dim, num_heads = 512, 8
layer = TransformerEncoderLayer(embed_dim=embed_dim, num_heads=num_heads)

# Random input of 10 positions for a batch of 32 sequences.
dummy_input = torch.rand(10, 32, embed_dim)

# Smoke test: one forward pass; the output shape should equal the input shape.
output = layer(dummy_input)
print(output.shape)

torch.Size([10, 32, 512])
