In [1]:
import torch
import torch.nn as nn

In [2]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each passed through their own linear
    projection, split into ``num_heads`` heads of ``head_dim`` features,
    combined with softmax attention per head, merged back to ``d_model``
    features and passed through a final linear projection.

    Args:
        d_model: Size of the feature (last) dimension of inputs and output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Fail early: with a non-divisible d_model the reshape in forward()
        # would silently produce a wrong sequence length before crashing.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Initialize linear layers for query, key, and value
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        # Initialize linear layer for output
        self.out = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, heads, seq, head_dim)``."""
        return projected.view(
            batch_size, -1, self.num_heads, self.head_dim
        ).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor expected as ``(batch, query_len, d_model)``.
            key: Tensor expected as ``(batch, key_len, d_model)``.
            value: Tensor expected as ``(batch, key_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, query_len, key_len)``. Positions where it
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        # Local import so the class does not depend on ``math`` having been
        # imported by some other notebook cell.
        import math

        batch_size = query.size(0)

        # Project the inputs and split the feature dimension into heads
        query = self._split_heads(self.query(query), batch_size)
        key = self._split_heads(self.key(key), batch_size)
        value = self._split_heads(self.value(value), batch_size)

        # Scaled dot-product scores: (batch, heads, query_len, key_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Use the dtype's most negative finite value rather than a fixed
            # -1e9, which is not representable in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Compute attention weights and apply to value
        attention_weights = nn.functional.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, value)

        # Merge the heads back into d_model features and project
        context = context.permute(0, 2, 1, 3).contiguous().view(
            batch_size, -1, self.num_heads * self.head_dim
        )
        return self.out(context)

In [4]:
class TransformerEncoderBlock(nn.Module):
    """Single post-norm Transformer encoder layer.

    Applies self-attention followed by a position-wise feed-forward network.
    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and is then layer-normalised.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads passed to ``MultiHeadAttention``.
        dropout: Dropout probability used on both sub-layer outputs.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super(TransformerEncoderBlock, self).__init__()
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        # The trailing Dropout regularises the feed-forward output before the
        # residual addition in forward().
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the encoder layer.

        Args:
            x: Input tensor whose last dimension is ``d_model``.
            mask: Optional attention mask forwarded unchanged to the
                self-attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Multi-head self-attention. Dropout is applied to the sub-layer
        # output, not to the residual input, so the skip path stays intact.
        attended = self.multi_head_attention(x, x, x, mask)
        attended = self.norm1(x + self.dropout(attended))

        # Feed-forward (its dropout is the last stage of self.feed_forward)
        forwarded = self.feed_forward(attended)
        output = self.norm2(attended + forwarded)

        return output

class TransformerDecoderBlock(nn.Module):
    """Single post-norm Transformer decoder layer.

    Applies, in order: self-attention over the decoder input, attention over
    the encoder output (``memory``), and a position-wise feed-forward network.
    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual connection) and is then layer-normalised.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads in both attention modules.
        dropout: Dropout probability used on every sub-layer output.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super(TransformerDecoderBlock, self).__init__()
        self.multi_head_attention1 = MultiHeadAttention(d_model, num_heads)
        self.multi_head_attention2 = MultiHeadAttention(d_model, num_heads)
        # The trailing Dropout regularises the feed-forward output before the
        # residual addition in forward().
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        """Run the decoder layer.

        Args:
            x: Decoder input tensor whose last dimension is ``d_model``.
            memory: Encoder output used as key and value in the second
                attention module.
            src_mask: Optional mask forwarded to the encoder-decoder
                attention.
            tgt_mask: Optional mask forwarded to the decoder self-attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention. Dropout is applied to the sub-layer output, not to
        # the residual input, so the skip path stays intact.
        attended = self.multi_head_attention1(x, x, x, tgt_mask)
        query = self.norm1(x + self.dropout(attended))

        # Encoder-decoder attention, regularised like the other sub-layers
        attended = self.multi_head_attention2(query, memory, memory, src_mask)
        query = self.norm2(query + self.dropout(attended))

        # Feed-forward (its dropout is the last stage of self.feed_forward)
        forwarded = self.feed_forward(query)
        output = self.norm3(query + forwarded)

        return output

In [5]:
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer built from the block classes in this file.

    Source and target token ids are embedded, passed through
    ``PositionalEncoding`` (defined outside this block; presumably it adds
    position information -- confirm against its definition), run through the
    encoder and decoder stacks, and finally projected to target-vocabulary
    logits.

    Args:
        src_vocab_size: Number of rows in the source embedding table.
        tgt_vocab_size: Number of rows in the target embedding table and size
            of the output projection.
        d_model: Embedding / hidden feature size.
        num_heads: Attention heads per block.
        num_encoder_layers: Number of ``TransformerEncoderBlock`` layers.
        num_decoder_layers: Number of ``TransformerDecoderBlock`` layers.
        dropout: Dropout probability handed to every block.
    """

    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        d_model,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        dropout=0.1,
    ):
        super().__init__()

        # Token embeddings and the shared positional encoding
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        # Encoder stack is built before the decoder stack
        encoder_blocks = [
            TransformerEncoderBlock(d_model, num_heads, dropout)
            for _ in range(num_encoder_layers)
        ]
        self.encoder_layers = nn.ModuleList(encoder_blocks)

        decoder_blocks = [
            TransformerDecoderBlock(d_model, num_heads, dropout)
            for _ in range(num_decoder_layers)
        ]
        self.decoder_layers = nn.ModuleList(decoder_blocks)

        # Projection from hidden features to target-vocabulary logits
        self.linear = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Return target-vocabulary logits for ``tgt`` given ``src``.

        Args:
            src: Source token ids, looked up in ``src_embedding``.
            tgt: Target token ids, looked up in ``tgt_embedding``.
            src_mask: Mask passed to every encoder block and to the
                encoder-decoder attention of every decoder block.
            tgt_mask: Mask passed to the self-attention of every decoder
                block.

        Returns:
            Tensor whose last dimension is ``tgt_vocab_size``.
        """
        # Both inputs are embedded before either stack runs
        embedded_src = self.positional_encoding(self.src_embedding(src))
        embedded_tgt = self.positional_encoding(self.tgt_embedding(tgt))

        encoded = self._encode(embedded_src, src_mask)
        decoded = self._decode(embedded_tgt, encoded, src_mask, tgt_mask)

        logits = self.linear(decoded)
        return logits

    def _encode(self, src, src_mask):
        """Feed ``src`` through every encoder block in order."""
        hidden = src
        for block in self.encoder_layers:
            hidden = block(hidden, src_mask)
        return hidden

    def _decode(self, tgt, memory, src_mask, tgt_mask):
        """Feed ``tgt`` through every decoder block, attending to ``memory``."""
        hidden = tgt
        for block in self.decoder_layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        return hidden