In [1]:

import torch
import torch.nn as nn

In [2]:

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected with three learned linear maps, split into
    ``n_heads`` slices of width ``d_model // n_heads``, attended
    independently per head, re-joined and passed through a final linear map.

    Args:
        d_model: Width of the input and output feature dimension.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        # Without this check a non-divisible d_model can still be reshaped by
        # view(..., -1, ...) in forward(), silently mixing features from
        # different tokens instead of failing.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.fc_out = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) to (batch, n_heads, seq, head_dim)."""
        return projected.view(
            batch_size, -1, self.n_heads, self.head_dim
        ).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: Tensor whose first dimension is the batch and whose last
                dimension is ``d_model``.
            key: Tensor laid out like ``query``; supplies the attention keys.
            value: Tensor laid out like ``key``; supplies the attended values.
            mask: Optional tensor broadcastable to the attention scores
                ``(batch, n_heads, query_len, key_len)``. Entries equal to 0
                are excluded from attention. ``None`` disables masking.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        batch_size = query.shape[0]

        # Linear projections followed by the per-head split.
        Q = self._split_heads(self.query(query), batch_size)
        K = self._split_heads(self.key(key), batch_size)
        V = self._split_heads(self.value(value), batch_size)

        # Scaled dot-product scores. A plain Python float is used for the
        # scale so no tensor is allocated on every call.
        energy = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Use the most negative value of the score dtype: a hard-coded
            # -1e20 is not representable in float16/bfloat16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.nn.functional.softmax(energy, dim=-1)
        x = torch.matmul(attention, V)

        # Move heads back next to the feature axis and merge them.
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.d_model)

        # Final linear layer
        return self.fc_out(x)

class TransformerEncoderLayer(nn.Module):
    """Single encoder block: self-attention, then a two-layer feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``,
    i.e. the residual is added first and normalisation is applied afterwards.

    Args:
        d_model: Feature width of the block's input and output.
        n_heads: Number of heads given to the self-attention module.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability applied to both sub-layer outputs.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads)

        # Position-wise feed-forward: widen, apply ReLU, project back.
        widen = nn.Linear(d_model, ff_dim)
        narrow = nn.Linear(ff_dim, d_model)
        self.ffn = nn.Sequential(widen, nn.ReLU(), narrow)

        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the block on ``x``; ``mask`` is passed to the attention module."""
        # Self-attention sub-layer: x serves as query, key and value.
        attended = self.self_attention(x, x, x, mask)
        x = self.layer_norm1(x + self.dropout(attended))

        # Feed-forward sub-layer with its own residual and normalisation.
        transformed = self.ffn(x)
        return self.layer_norm2(x + self.dropout(transformed))

class TransformerEncoder(nn.Module):
    """Stack of encoder layers preceded by a learned positional embedding.

    Args:
        d_model: Feature width of the input and of every layer.
        n_heads: Number of attention heads in each layer.
        ff_dim: Hidden width of each layer's feed-forward network.
        n_layers: Number of stacked encoder layers.
        max_seq_length: Number of rows in the positional embedding table,
            i.e. the longest sequence ``forward`` accepts.
        dropout: Dropout probability used after the positional embedding
            is added and inside every layer.
    """

    def __init__(self, d_model, n_heads, ff_dim, n_layers, max_seq_length, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_heads, ff_dim, dropout)
            for _ in range(n_layers)
        ])
        # Kept only so existing code that reads `.device` keeps working.
        # forward() no longer uses it: the module may live on a different
        # device than this guess (e.g. left on CPU on a CUDA machine).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.position_embedding = nn.Embedding(max_seq_length, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Add positional embeddings to ``x`` and run it through every layer.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional attention mask handed unchanged to each layer.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            IndexError: If ``seq_len`` exceeds the size of the positional
                embedding table.
        """
        seq_len = x.size(1)
        table_size = self.position_embedding.num_embeddings
        if seq_len > table_size:
            # Fail early with a readable message rather than an opaque
            # embedding lookup error (a device-side assert on CUDA).
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_length {table_size}"
            )

        # Build the indices on the input's own device. The resulting
        # (seq_len, d_model) embedding broadcasts over the batch dimension.
        positions = torch.arange(seq_len, device=x.device)
        x = x + self.position_embedding(positions)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, mask)

        return x

# Example usage: hyperparameters for the demo encoder.
d_model, n_heads = 512, 8
ff_dim, n_layers = 2048, 6
max_seq_length = 100
dropout = 0.1

# Build the encoder stack.
transformer_encoder = TransformerEncoder(
    d_model=d_model,
    n_heads=n_heads,
    ff_dim=ff_dim,
    n_layers=n_layers,
    max_seq_length=max_seq_length,
    dropout=dropout,
)

# Random stand-in for a batch of 16 sequences of length 100.
input_data = torch.rand(16, 100, d_model)

# A position counts as padding when its feature vector sums to exactly zero.
# The two singleton axes let the (16, 1, 1, 100) mask broadcast over heads
# and query positions.
padding_mask = input_data.sum(dim=-1).ne(0)[:, None, None, :]

# Run the encoder and show the result together with its shape.
output_data = transformer_encoder(input_data, padding_mask)
print(output_data)
print("Output shape:", output_data.shape)


tensor([[[-0.8408, -0.8154,  1.0487,  ...,  0.3487, -0.5811, -0.4857],
         [-0.2406, -0.1920,  1.2893,  ..., -0.3320,  0.3077, -0.5451],
         [-1.7497, -0.3533,  0.7075,  ..., -0.9475,  1.7996,  0.8380],
         ...,
         [-0.1005,  0.2460,  1.8928,  ...,  0.5881,  0.3590, -0.3711],
         [-0.5257, -0.6494,  0.1706,  ...,  1.7171,  0.5511,  0.5390],
         [-1.7992, -0.7236,  0.9918,  ..., -0.5738,  1.0991, -0.0445]],

        [[ 0.0956, -0.5797,  0.5266,  ...,  0.2994, -0.9134,  0.3147],
         [ 0.5751, -0.1986,  1.5392,  ..., -0.0737,  0.4916, -0.8520],
         [-2.0388, -0.1143,  0.9389,  ..., -1.4413,  1.6410,  1.2478],
         ...,
         [-0.3299, -0.0159,  0.9269,  ..., -0.6512,  0.9223,  0.1962],
         [-0.3567, -0.5883,  1.6152,  ...,  1.8126,  0.9458,  1.1106],
         [-1.4371, -0.2394, -0.0975,  ..., -0.0598,  0.9240, -0.6363]],

        [[-0.9640, -0.8538,  0.4637,  ...,  0.1433, -1.1513, -0.3659],
         [ 1.0318, -0.6533,  1.5773,  ...,  0