**FeedForward Network**: A simple feed-forward network applied after the multi-head attention mechanism.

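**Scaled Dot-Product Attention**: The core attention mechanism. It computes attention scores from the query and key (scaled by the square root of their feature dimension), normalizes them with a softmax, and uses the resulting weights to combine the values.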

**Multi-Head Attention**: Multi-head attention mechanism where we have multiple attention heads running in parallel.

**Positional Encoding**: Since Transformers don’t have any built-in recurrence or convolution, we add positional encodings to the input embeddings to give the model some sense of word order.

**Encoder and Decoder Layers**: The encoder and decoder layers implement the attention mechanisms and feed-forward networks as described in the Transformer paper.

**Transformer**: The model itself consists of several encoder layers, decoder layers, and a final linear layer to produce outputs.

*Usage:*

This model can be used for tasks such as machine translation, text generation, and more.
The `vocab_size` and the sequence lengths used in the example are placeholders; adjust them to match your dataset.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# Position-wise Feed Forward Network
class FeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Size of the last dimension of both the input and the output.
        d_ff: Width of the hidden layer between the two linear maps.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to the last dimension of ``x`` and return the result."""
        hidden = self.linear1(x)
        hidden = self.relu(hidden)
        # Dropout acts on the hidden activations, before projecting back.
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

# Scaled Dot-Product Attention
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend over ``value`` using the similarity of ``query`` and ``key``.

        Args:
            query: Tensor whose last two dimensions are (query_len, d_k).
            key: Tensor whose last two dimensions are (key_len, d_k).
            value: Tensor whose last two dimensions are (key_len, d_v).
            mask: Optional tensor broadcastable to the score shape
                (..., query_len, key_len). Positions where ``mask == 0``
                are excluded from attention.

        Returns:
            Tensor whose last two dimensions are (query_len, d_v).
        """
        score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))

        if mask is not None:
            # Use the most negative value representable in the score dtype.
            # A hard-coded -1e9 is out of range for float16 and makes
            # masked_fill fail on half-precision scores; after the softmax
            # both choices give the same weights for float32.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        attention = torch.softmax(score, dim=-1)
        attention = self.dropout(attention)

        return torch.matmul(attention, value)

# Multi-Head Attention Mechanism
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``ScaledDotProductAttention``.

    The inputs are linearly projected, split into ``num_heads`` heads of size
    ``d_model // num_heads``, attended independently, concatenated, and passed
    through a final linear layer.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must be positive and divide
            ``d_model``.
        dropout: Dropout probability passed to the inner attention module.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and a non-divisible d_model would then mis-shape
        # the heads in ``view`` instead of failing clearly.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.d_k = d_model // num_heads  # Dimension of each head
        self.num_heads = num_heads
        self.attn = ScaledDotProductAttention(dropout)

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        # Kept for backward compatibility; forward() does not apply it.
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) to (batch, num_heads, seq, d_k)."""
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query: Tensor of shape (batch, query_len, d_model).
            key: Tensor of shape (batch, key_len, d_model).
            value: Tensor of shape (batch, key_len, d_model).
            mask: Optional mask passed unchanged to the inner attention; it
                must broadcast against scores of shape
                (batch, num_heads, query_len, key_len).

        Returns:
            Tensor of shape (batch, query_len, d_model).
        """
        batch_size = query.size(0)

        # Linear projections, then split the feature axis into heads.
        query = self._split_heads(self.query(query), batch_size)
        key = self._split_heads(self.key(key), batch_size)
        value = self._split_heads(self.value(value), batch_size)

        # Apply attention to each head
        attention = self.attn(query, key, value, mask)

        # Concat heads and apply final linear layer
        attention = attention.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.out(attention)

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``. The table is precomputed for ``max_len`` positions
    and stored as the non-trainable buffer ``pe`` of shape
    (1, max_len, d_model).

    Args:
        d_model: Feature size of the embeddings; may be even or odd.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the last frequency; for even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (batch, seq_len, d_model). Only ``max_len``
                positions are precomputed, so ``seq_len`` should not exceed
                ``max_len``.
        """
        return x + self.pe[:, :x.size(1)]

# Encoder Layer
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input, and the sum is layer-normalized.

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used throughout the layer.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode ``x``; ``mask`` is handed unchanged to the self-attention."""
        # Sub-layer 1: self-attention (query, key and value are all ``x``).
        attended = self.attn(x, x, x, mask)
        x = self.layernorm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward network.
        transformed = self.ff(x)
        return self.layernorm2(x + self.dropout(transformed))

# Decoder Layer
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, attention over ``memory``, feed-forward.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input, and the sum is layer-normalized.

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads in both attention sub-layers.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used throughout the layer.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn1 = MultiHeadAttention(d_model, num_heads, dropout)
        self.attn2 = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.layernorm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask=None, memory_mask=None):
        """Decode ``x`` against ``memory``.

        Args:
            x: Decoder-side input.
            memory: Tensor used as key and value in the second attention.
            tgt_mask: Optional mask for the self-attention over ``x``.
            memory_mask: Optional mask for the attention over ``memory``.
        """
        # Sub-layer 1: self-attention over the decoder input.
        self_attended = self.attn1(x, x, x, tgt_mask)
        x = self.layernorm1(x + self.dropout(self_attended))

        # Sub-layer 2: queries come from ``x``, keys/values from ``memory``.
        cross_attended = self.attn2(x, memory, memory, memory_mask)
        x = self.layernorm2(x + self.dropout(cross_attended))

        # Sub-layer 3: feed-forward network.
        transformed = self.ff(x)
        return self.layernorm3(x + self.dropout(transformed))

# Transformer Model
class Transformer(nn.Module):
    """Encoder-decoder Transformer that maps token ids to vocabulary logits.

    Source and target tokens share one embedding table. Embeddings are scaled
    by ``sqrt(d_model)`` and combined with positional encodings before the
    encoder and decoder stacks. No mask is created internally; masks are used
    exactly as the caller supplies them.

    Args:
        d_model: Embedding / hidden feature size.
        num_heads: Number of attention heads in every attention sub-layer.
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        d_ff: Hidden width of the feed-forward sub-layers.
        vocab_size: Number of rows in the embedding table and size of the
            output logits.
        max_len: Number of positions precomputed by the positional encoding.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, d_model, num_heads, num_encoder_layers, num_decoder_layers, d_ff, vocab_size, max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_encoder_layers)
        ])

        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_decoder_layers)
        ])

        self.fc_out = nn.Linear(d_model, vocab_size)

    def _embed(self, tokens):
        """Embed token ids, scale by sqrt(d_model), and add positional encodings."""
        # Scale by the embedding size. ``tokens`` holds ids of shape
        # (batch, seq_len), so ``tokens.size(-1)`` is the sequence length and
        # must not be used as the scale.
        scaled = self.embedding(tokens) * math.sqrt(self.d_model)
        return self.positional_encoding(scaled)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """Compute output logits for every target position.

        Args:
            src: Integer tensor of source token ids, shape (batch, src_len).
            tgt: Integer tensor of target token ids, shape (batch, tgt_len).
            src_mask: Optional mask for encoder self-attention.
            tgt_mask: Optional mask for decoder self-attention.
            memory_mask: Optional mask for decoder attention over the encoder
                output.

        Returns:
            Tensor of shape (batch, tgt_len, vocab_size) with unnormalized
            scores (no softmax is applied).
        """
        # Encoder pass
        memory = self._embed(src)
        for layer in self.encoder_layers:
            memory = layer(memory, src_mask)

        # Decoder pass
        output = self._embed(tgt)
        for layer in self.decoder_layers:
            output = layer(output, memory, tgt_mask, memory_mask)

        return self.fc_out(output)

# Example of creating a Transformer model.
# The sizes are named once so the model configuration and the random inputs
# cannot drift apart (token ids must stay below the embedding table size).
vocab_size = 10000  # Example vocab size
batch_size = 32
src_len = 10  # Source sequence length
tgt_len = 10  # Target sequence length

model = Transformer(
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    d_ff=2048,
    vocab_size=vocab_size,
)

# Example input: random token ids in [0, vocab_size)
src = torch.randint(0, vocab_size, (batch_size, src_len))
tgt = torch.randint(0, vocab_size, (batch_size, tgt_len))

# Forward pass
output = model(src, tgt)
print(output.shape)  # (batch_size, target_seq_len, vocab_size)

# Output:
# For each of the 32 sequences in the batch, and for each position in the 10-token target sequence,
# the output is a vector of size vocab_size (10000 in this case).
# This vector holds the model's unnormalized scores (logits) for each word in the vocabulary
# being the correct token at that position; apply a softmax over the last dimension to obtain probabilities.


torch.Size([32, 10, 10000])
