In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

### Transformer Block Module


In [2]:
class TransformerLayer(nn.Module):
    """Post-norm transformer block: multi-head attention, then a feed-forward network.

    Each sub-layer is wrapped as ``dropout(layer_norm(sublayer(x) + x))``.

    Args:
        embed_dim: Size of the feature dimension of every input and output.
        expansion_factor: Width multiplier of the hidden feed-forward layer.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability applied after each layer norm.
        batch_first: When True (default) tensors are ``(batch, seq, embed_dim)``;
            when False they are ``(seq, batch, embed_dim)``.
    """

    def __init__(self, embed_dim, expansion_factor=4, num_heads=8, dropout=0.2, batch_first=True):
        super(TransformerLayer, self).__init__()

        # nn.MultiheadAttention defaults to sequence-first inputs. The models in
        # this notebook feed (batch, seq, embed) tensors, so batch_first must be
        # enabled or attention would be computed across the batch dimension.
        self.self_attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=batch_first)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, expansion_factor * embed_dim),
            nn.ReLU(),
            nn.Linear(expansion_factor * embed_dim, embed_dim),
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, key, query, value):
        """Apply attention and the feed-forward network.

        Note the parameter order ``(key, query, value)``, which is kept for
        backward compatibility with existing callers.

        Args:
            key: Tensor the queries attend over.
            query: Tensor that produces the queries; the output has its shape.
            value: Tensor whose rows are mixed by the attention weights; must
                have the same sequence length as ``key``.

        Returns:
            Tensor with the same shape as ``query``.
        """
        # nn.MultiheadAttention.forward expects (query, key, value) in that order.
        # The attention weights are discarded, so skip computing them.
        attention_out, _ = self.self_attention(query, key, value, need_weights=False)

        # The attention output is aligned with the query, so the residual
        # connection has to use the query (identical to `value` for self-attention).
        attention_residual_out = attention_out + query
        norm1_out = self.dropout1(self.norm1(attention_residual_out))

        # Position-wise feed-forward network with its own residual connection.
        feed_fwd_out = self.feed_forward(norm1_out)
        feed_fwd_residual_out = feed_fwd_out + norm1_out
        norm2_out = self.dropout2(self.norm2(feed_fwd_residual_out))

        return norm2_out

### Transformer Encoder Module

In [3]:
class TransformerEncoder(nn.Module):
    """Token embedding plus learned positional embedding, followed by a stack of transformer layers.

    Args:
        seq_len: Maximum sequence length supported by the positional table.
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Embedding / model dimension.
        num_layers: Number of stacked ``TransformerLayer`` blocks.
        expansion_factor: Feed-forward width multiplier passed to every layer.
        num_heads: Number of attention heads passed to every layer.
    """

    def __init__(self, seq_len, vocab_size, embed_dim, num_layers=2, expansion_factor=4, num_heads=8):
        super(TransformerEncoder, self).__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoder = nn.Embedding(seq_len, embed_dim)  # Learned (simplified) positional encoding
        self.layers = nn.ModuleList(
            [TransformerLayer(embed_dim, expansion_factor, num_heads) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids, shape ``(batch, length)`` with
                ``length <= seq_len``.

        Returns:
            Tensor of shape ``(batch, length, embed_dim)``.

        Raises:
            ValueError: If ``length`` exceeds the size of the positional table.
        """
        length = x.size(1)
        max_len = self.positional_encoder.num_embeddings
        if length > max_len:
            raise ValueError(
                f"Input sequence length {length} exceeds the maximum supported length {max_len}"
            )

        # Look up only the positions actually present so inputs shorter than
        # seq_len work; for length == seq_len this equals adding the full table.
        positions = torch.arange(length, device=x.device)
        out = self.embedding_layer(x) + self.positional_encoder(positions).unsqueeze(0)

        for layer in self.layers:  # Self-attention: key, query and value are the same tensor
            out = layer(out, out, out)
        return out

### Transformer Decoder Module


In [4]:
class TransformerDecoder(nn.Module):
    """Target-side embedding, a stack of transformer layers and a projection to vocabulary logits.

    NOTE(review): ``encoder_output`` is accepted by ``forward`` but is not read
    by any layer in this class, and no causal mask is applied — cross-attention
    and masked self-attention are not implemented here.

    Args:
        seq_len: Maximum sequence length supported by the positional table.
        vocab_size: Target vocabulary size (embedding rows and number of output logits).
        embed_dim: Embedding / model dimension.
        num_layers: Number of stacked ``TransformerLayer`` blocks.
        expansion_factor: Feed-forward width multiplier passed to every layer.
        num_heads: Number of attention heads passed to every layer.
    """

    def __init__(self, seq_len, vocab_size, embed_dim, num_layers=2, expansion_factor=4, num_heads=8):
        super(TransformerDecoder, self).__init__()

        self.embedding_layer = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoder = nn.Embedding(seq_len, embed_dim)  # Learned (simplified) positional encoding
        self.layers = nn.ModuleList(
            [TransformerLayer(embed_dim, expansion_factor, num_heads) for _ in range(num_layers)]
        )
        self.embed_dim = embed_dim

        # The output projection must be a registered sub-module: building it
        # inside forward() would re-initialise its weights on every call, hide
        # them from the optimizer / state_dict, and require a notebook global
        # for the vocabulary size.
        self.output_projection = nn.Linear(embed_dim, vocab_size)

    def forward(self, x, encoder_output):
        """Compute vocabulary logits for a batch of target token ids.

        Args:
            x: Integer tensor of target token ids, shape ``(batch, length)``
                with ``length <= seq_len``.
            encoder_output: Encoder result; currently unused (see class note).

        Returns:
            Tensor of logits with shape ``(batch, length, vocab_size)``.

        Raises:
            ValueError: If ``length`` exceeds the size of the positional table.
        """
        length = x.size(1)
        max_len = self.positional_encoder.num_embeddings
        if length > max_len:
            raise ValueError(
                f"Input sequence length {length} exceeds the maximum supported length {max_len}"
            )

        # Only the positions present in the input are looked up, so targets
        # shorter than seq_len are supported.
        positions = torch.arange(length, device=x.device)
        out = self.embedding_layer(x) + self.positional_encoder(positions).unsqueeze(0)

        for layer in self.layers:  # Self-attention over the target sequence
            out = layer(out, out, out)

        final_predictions = self.output_projection(out)

        return final_predictions

### Modified Transformer Encoder-Decoder for Machine Translation


In [5]:
class TransformerEncoderDecoder(nn.Module):
    """Pairs a ``TransformerEncoder`` for the source with a ``TransformerDecoder`` for the target.

    Args:
        src_vocab_size: Source vocabulary size, passed to the encoder.
        tgt_vocab_size: Target vocabulary size, passed to the decoder.
        embed_dim: Embedding / model dimension shared by both halves.
        num_layers: Number of transformer layers in each half.
        expansion_factor: Feed-forward width multiplier.
        num_heads: Number of attention heads.
        max_seq_length: Maximum sequence length for both positional tables.
            When omitted, the module-level name ``max_seq_length`` is used, as
            earlier versions of this class did implicitly.

    Raises:
        ValueError: If ``max_seq_length`` is neither passed nor defined at module level.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, embed_dim, num_layers=2, expansion_factor=4,
                 num_heads=8, max_seq_length=None):
        super(TransformerEncoderDecoder, self).__init__()

        if max_seq_length is None:
            # Backward compatibility: this class used to read the notebook
            # global directly. Prefer passing the value explicitly.
            max_seq_length = globals().get("max_seq_length")
            if max_seq_length is None:
                raise ValueError(
                    "max_seq_length must be passed to TransformerEncoderDecoder "
                    "or defined as a module-level variable"
                )
        self.max_seq_length = max_seq_length

        # Encoder for source sequences
        self.encoder = TransformerEncoder(
            seq_len=max_seq_length,
            vocab_size=src_vocab_size,
            embed_dim=embed_dim,
            num_layers=num_layers,
            expansion_factor=expansion_factor,
            num_heads=num_heads
        )

        # Decoder for target sequences
        self.decoder = TransformerDecoder(
            seq_len=max_seq_length,
            vocab_size=tgt_vocab_size,
            embed_dim=embed_dim,
            num_layers=num_layers,
            expansion_factor=expansion_factor,
            num_heads=num_heads
        )

    def forward(self, source_sequences, target_sequences):
        """Run the encoder on the source, then hand its output and the target to the decoder.

        Args:
            source_sequences: Source token ids, passed to the encoder.
            target_sequences: Target token ids, passed to the decoder.

        Returns:
            Whatever the decoder returns for ``(target_sequences, encoder_output)``.
        """
        encoder_output = self.encoder(source_sequences)  # Forward pass through the encoder
        final_predictions = self.decoder(target_sequences, encoder_output)  # Forward pass through the decoder
        return final_predictions

### Example Usage: Machine Translation with Transformer Encoder-Decoder

In [6]:
# Demo configuration. These are notebook-level names and must be defined
# before the model is built.
SEED = 0
src_vocab_size = 5000
tgt_vocab_size = 5000
max_seq_length = 30
batch_size = 4

# Seed the global RNG so the random inputs, the weight initialisation and the
# dropout masks are repeatable on Restart & Run All.
torch.manual_seed(SEED)

# Create random tensors representing a batch of source and target sequences
source_sequences = torch.randint(0, src_vocab_size, (batch_size, max_seq_length))
target_sequences = torch.randint(0, tgt_vocab_size, (batch_size, max_seq_length))

# Create a Transformer Encoder-Decoder model for machine translation
translation_model = TransformerEncoderDecoder(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    embed_dim=512,
    num_layers=4,
    expansion_factor=4,
    num_heads=8
)

# Forward pass through the model to get translation predictions.
# The model is still in training mode (the nn.Module default), so dropout is active.
translation_predictions = translation_model(source_sequences, target_sequences)

# Sanity check: one logit vector per target position.
assert translation_predictions.shape == (batch_size, max_seq_length, tgt_vocab_size)

In [7]:
# Bare last expression of the cell: renders the model's module tree (its repr).
translation_model

TransformerEncoderDecoder(
  (encoder): TransformerEncoder(
    (embedding_layer): Embedding(5000, 512)
    (positional_encoder): Embedding(30, 512)
    (layers): ModuleList(
      (0-3): 4 x TransformerLayer(
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
    )
  )
  (decoder): TransformerDecoder(
    (embedding_layer): Embedding(5000, 512)
    (positional_encoder): Embedding(30, 512)
    (layers): ModuleList(
      (0-3): 4 x Tra

In [8]:
# Report the output dimensions first, then the tensor itself
# (torch abbreviates large tensors when printing).
predictions_shape = translation_predictions.shape
print("Translation Predictions Shape:", predictions_shape)
print("\n\n", translation_predictions)

Translation Predictions Shape: torch.Size([4, 30, 5000])


 tensor([[[-0.2388, -0.3183, -0.1799,  ...,  0.3858,  0.0165, -0.6137],
         [ 0.2497, -0.3958,  0.2757,  ...,  1.0989, -0.6485,  0.2483],
         [ 0.0302, -0.3694,  0.1863,  ..., -0.5506,  1.0771, -0.7228],
         ...,
         [-0.4957, -0.0409, -0.2502,  ..., -0.2347, -0.4575, -0.0296],
         [ 1.2763,  0.5245, -0.6327,  ..., -0.0232,  1.4271,  0.1770],
         [ 0.6265, -0.2581,  0.1957,  ...,  0.0897, -0.0408,  0.0698]],

        [[-0.1906,  0.5440, -0.1165,  ...,  0.1000, -0.2049, -0.3806],
         [ 1.4077,  0.4157,  0.2377,  ..., -0.3097,  0.9587, -0.1685],
         [-0.1463, -0.0877,  0.2488,  ...,  0.7598,  1.1981,  1.4609],
         ...,
         [ 0.7917,  0.6685,  1.4523,  ...,  0.8139, -0.2970, -0.5302],
         [ 0.2495,  0.1917, -1.2684,  ...,  0.1425, -0.3022,  1.0394],
         [-0.2949, -0.5780, -0.0741,  ...,  0.3656, -0.6301,  0.3284]],

        [[-0.2424,  0.5086,  0.0317,  ..., -0.2032,  0.2