In [None]:
#Name :- Devashish Mayur Potnis
#Roll No :- 43557
#Practical No :- 4

In [2]:
import torch
import torch.nn as nn

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected with separate linear layers, split into
    ``n_heads`` heads of size ``d_model // n_heads``, attended per head, and
    merged back through a final linear layer.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        # Fail early: an uneven split would otherwise only surface later as
        # an opaque shape error inside ``forward``.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.fc_out = nn.Linear(d_model, d_model)

    def _split_heads(self, t, batch_size):
        """Reshape (batch, seq, d_model) into (batch, n_heads, seq, head_dim)."""
        return t.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Apply attention of ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (batch, q_len, d_model).
            key: Tensor of shape (batch, k_len, d_model).
            value: Tensor of shape (batch, k_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, n_heads, q_len, k_len). Positions where the mask
                equals 0 are excluded from attention. ``None`` disables
                masking.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size = query.shape[0]

        # Linear projections followed by the split into heads.
        Q = self._split_heads(self.query(query), batch_size)
        K = self._split_heads(self.key(key), batch_size)
        V = self._split_heads(self.value(value), batch_size)

        # Scaled dot-product scores. A plain Python float is used for the
        # scale so no tensor has to be allocated on every call.
        energy = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Use the most negative value representable in the score dtype;
            # a hard-coded -1e20 overflows in half precision.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.nn.functional.softmax(energy, dim=-1)
        x = torch.matmul(attention, V)

        # Merge the heads back: (batch, n_heads, q_len, head_dim)
        # -> (batch, q_len, d_model).
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.d_model)

        # Final linear layer
        return self.fc_out(x)

class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention then a feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Feature size of the input and output.
        n_heads: Number of attention heads passed to the attention module.
        ff_dim: Hidden size of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads)
        # Feed-forward network: expand to ff_dim, ReLU, project back to d_model.
        expand = nn.Linear(d_model, ff_dim)
        activation = nn.ReLU()
        project = nn.Linear(ff_dim, d_model)
        self.ffn = nn.Sequential(expand, activation, project)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the block on ``x``; ``mask`` is forwarded to the attention module."""
        # Self-attention sub-layer with residual connection and normalisation.
        attended = self.self_attention(x, x, x, mask)
        normed = self.layer_norm1(x + self.dropout(attended))

        # Feed-forward sub-layer with residual connection and normalisation.
        transformed = self.ffn(normed)
        return self.layer_norm2(normed + self.dropout(transformed))

class TransformerEncoder(nn.Module):
    """Stack of encoder layers with learned positional embeddings.

    Args:
        d_model: Feature size of the input and output.
        n_heads: Number of attention heads in every layer.
        ff_dim: Hidden size of each layer's feed-forward network.
        n_layers: Number of stacked encoder layers.
        max_seq_length: Number of positions in the embedding table, i.e. the
            longest sequence that can be encoded.
        dropout: Dropout probability used after adding the positional
            embeddings and inside every layer.
    """

    def __init__(self, d_model, n_heads, ff_dim, n_layers, max_seq_length, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_heads, ff_dim, dropout)
            for _ in range(n_layers)
        ])
        # Retained for backward compatibility only. ``forward`` no longer
        # reads it, because it reflects CUDA availability at construction
        # time rather than where the module or its input actually live.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.position_embedding = nn.Embedding(max_seq_length, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).
            mask: Optional attention mask passed unchanged to every layer.

        Returns:
            Tensor of shape (batch, seq_len, d_model).

        Raises:
            IndexError: If ``seq_len`` exceeds the number of positions in the
                positional embedding table.
        """
        seq_len = x.size(1)
        max_len = self.position_embedding.num_embeddings
        if seq_len > max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_length {max_len}"
            )

        # Create the indices on the input's device so the lookup works
        # wherever the data lives. The (seq_len, d_model) embedding then
        # broadcasts over the batch dimension.
        positions = torch.arange(seq_len, device=x.device)
        x = x + self.position_embedding(positions)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, mask)

        return x

# Example usage: push one random batch through a freshly built encoder.

# Hyperparameters for the demonstration run.
d_model, n_heads, ff_dim = 512, 8, 2048
n_layers, max_seq_length, dropout = 6, 100, 0.1

# Build the encoder stack.
transformer_encoder = TransformerEncoder(
    d_model=d_model,
    n_heads=n_heads,
    ff_dim=ff_dim,
    n_layers=n_layers,
    max_seq_length=max_seq_length,
    dropout=dropout,
)

# Random stand-in input: 16 sequences of 100 positions, d_model features each.
input_data = torch.rand(16, 100, d_model)

# Padding mask: a position counts as real when its feature vector does not
# sum to zero. The two inserted singleton axes give shape (16, 1, 1, 100).
padding_mask = input_data.sum(dim=-1).ne(0)[:, None, None, :]

# Run the encoder and show the result.
output_data = transformer_encoder(input_data, padding_mask)
print(output_data)
print("Output shape:", output_data.shape)


tensor([[[ 2.1190, -0.3070,  0.8684,  ..., -0.5964,  2.0397, -0.1870],
         [-0.1040,  0.7411,  0.3018,  ...,  0.3549,  0.9209,  1.1006],
         [-0.0560, -0.0112, -0.6999,  ..., -1.7957,  1.8597,  1.1984],
         ...,
         [-0.8701,  0.4968,  0.5823,  ...,  1.4080,  0.4889, -0.0852],
         [ 0.1658,  1.8266,  0.4646,  ...,  0.9672,  2.0625,  0.5057],
         [ 2.1088, -0.4424, -0.1533,  ..., -1.7459,  1.5994,  0.2504]],

        [[ 1.7615, -0.2889, -0.3891,  ..., -0.1173,  2.2502,  0.6596],
         [-0.0801,  0.9621,  1.0993,  ...,  0.2810,  0.4042,  1.0537],
         [-0.2148, -0.4786, -0.8580,  ..., -1.7894,  0.3084,  1.9672],
         ...,
         [-0.1644,  0.1798,  0.6835,  ...,  1.9033,  0.4577, -0.2276],
         [ 0.4457,  1.8287,  0.3376,  ...,  1.2803,  1.7753,  0.1414],
         [ 1.3813, -0.1967,  0.3527,  ..., -2.0245,  1.8690,  0.7219]],

        [[ 1.5487,  0.4550, -0.3476,  ...,  0.5577, -0.2677,  0.4746],
         [-0.2533,  1.2018, -0.3337,  ...,  1