In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# 1. Positional Encoding
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Precomputes a ``(1, max_len, d_model)`` table whose even feature columns
    hold ``sin(pos * w_i)`` and whose odd feature columns hold ``cos(pos * w_i)``,
    with ``w_i = 10000 ** (-2i / d_model)``. ``forward`` adds the first
    ``seq_len`` rows of the table to its input.

    Args:
        d_model: Feature dimension of the inputs. May be odd; the last (unpaired)
            even column then only receives a sine term.
        max_len: Number of positions precomputed. Inputs must satisfy
            ``seq_len <= max_len``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Frequencies 10000^(-2i/d_model), computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))  # (ceil(d_model/2),)
        angles = position * div_term  # (max_len, ceil(d_model/2))

        pe = torch.zeros(1, max_len, d_model)  # (1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)  # even indices: ceil(d_model/2) columns
        # Odd indices have only floor(d_model/2) columns. Slicing the angles keeps an
        # odd d_model from failing with a shape mismatch; this is a no-op for even d_model.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer: it follows .to(device) and is saved in state_dict,
        # but it is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the positional encodings for its first ``x.size(1)`` positions.

        Args:
            x: (batch_size, seq_len, d_model) with ``seq_len <= max_len``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return x + self.pe[:, :x.size(1), :]

In [3]:
# 2.Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects queries, keys and values with learned linear maps and splits them
    into ``n_heads`` heads of size ``d_k = d_model // n_heads``. It then applies
    scaled dot-product attention per head, concatenates the heads and applies
    an output projection.

    Args:
        d_model: Model dimension; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        assert (
            self.d_k * n_heads == d_model
        ), f"d_model {d_model} not divisible by n_heads {n_heads}"

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V`` independently for each head.

        Args:
            Q: (batch_size, n_heads, q_len, d_k)
            K: (batch_size, n_heads, kv_len, d_k)
            V: (batch_size, n_heads, kv_len, d_k)
            mask: Optional tensor broadcastable to (batch_size, n_heads, q_len, kv_len).
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            (batch_size, n_heads, q_len, d_k)
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)  # (batch_size, n_heads, q_len, kv_len)
        if mask is not None:
            # Use the most negative finite value of the scores' dtype instead of a fixed
            # -1e9. -1e9 is not representable in float16 and makes masked_fill raise.
            # A finite value (not -inf) keeps fully-masked rows free of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)  # (batch_size, n_heads, q_len, kv_len)
        attn_weights = self.dropout(attn_weights)
        return torch.matmul(attn_weights, V)  # (batch_size, n_heads, q_len, d_k)

    def _split_heads(self, x):
        # (batch_size, seq_len, d_model) -> (batch_size, n_heads, seq_len, d_k)
        return x.view(x.size(0), -1, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` to ``K``/``V``.

        Args:
            Q: (batch_size, q_len, d_model) queries.
            K: (batch_size, kv_len, d_model) keys.
            V: (batch_size, kv_len, d_model) values.
            mask: Optional mask, see ``scaled_dot_product_attention``.

        Returns:
            (batch_size, q_len, d_model)
        """
        batch_size = Q.size(0)

        Q = self._split_heads(self.W_q(Q))  # (batch_size, n_heads, q_len, d_k)
        K = self._split_heads(self.W_k(K))  # (batch_size, n_heads, kv_len, d_k)
        V = self._split_heads(self.W_v(V))  # (batch_size, n_heads, kv_len, d_k)

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)  # (batch_size, n_heads, q_len, d_k)

        # Merge heads back: (batch_size, n_heads, q_len, d_k) -> (batch_size, q_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(attn_output)  # (batch_size, q_len, d_model)


       

In [4]:
# 3. Feed Forward Network
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Applied independently at every sequence position. It maps
    (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model) through a
    hidden layer of width ``d_ff``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Submodules are created in this order so state_dict keys and seeded
        # initialisation stay the same.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.dropout(self.activation(self.linear1(x)))  # (batch_size, seq_len, d_ff)
        return self.linear2(hidden)  # (batch_size, seq_len, d_model)

    

In [5]:
# 4.Transformer Encoder Layer
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Two sub-layers, each computed as ``LayerNorm(x + Dropout(sublayer(x)))``:
    multi-head self-attention, then a position-wise feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Self-attention sub-layer.
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Feed-forward sub-layer.
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Apply the layer to ``x`` of shape (batch_size, seq_len, d_model).

        ``mask`` is passed unchanged to the self-attention module. The output has
        the same shape as ``x``.
        """
        attended = self.norm1(x + self.dropout1(self.self_attn(x, x, x, mask)))
        return self.norm2(attended + self.dropout2(self.ffn(attended)))


In [6]:
# 5.Transformer Encoder
class Encoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer blocks followed by a final LayerNorm.

    NOTE(review): each EncoderLayer already ends in a LayerNorm, so this final
    norm re-normalises an already-normalised output. Decoder has no counterpart.
    """

    def __init__(self, d_model, n_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Encode ``x`` of shape (batch_size, seq_len, d_model).

        The same ``mask`` is passed to every layer. Returns a tensor of the same
        shape.
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

In [7]:
# 6. Transformer Decoder Layer
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Three sub-layers, each computed as ``LayerNorm(x + Dropout(sublayer(x)))``:
    1. masked self-attention over the target,
    2. cross-attention with target queries and encoder-output keys/values,
    3. a position-wise feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Masked self-attention sub-layer.
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Cross-attention sub-layer.
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # Feed-forward sub-layer.
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, tgt, src, tgt_mask=None, src_mask=None):
        """Apply the layer.

        Args:
            tgt: (batch_size, tgt_seq_len, d_model) target-side hidden states.
            src: (batch_size, src_seq_len, d_model) encoder output ("memory");
                used as keys and values in cross-attention.
            tgt_mask: Optional mask passed to self-attention. It must broadcast
                against (batch_size, n_heads, tgt_seq_len, tgt_seq_len), e.g.
                (batch_size, 1, tgt_seq_len, tgt_seq_len).
            src_mask: Optional mask passed to cross-attention. It must broadcast
                against (batch_size, n_heads, tgt_seq_len, src_seq_len), e.g.
                (batch_size, 1, 1, src_seq_len).

        Returns:
            (batch_size, tgt_seq_len, d_model)
        """
        self_attended = self.norm1(tgt + self.dropout1(self.self_attn(tgt, tgt, tgt, tgt_mask)))
        cross_attended = self.norm2(
            self_attended + self.dropout2(self.cross_attn(self_attended, src, src, src_mask))
        )
        return self.norm3(cross_attended + self.dropout3(self.ffn(cross_attended)))

In [8]:
# 7. Transformer Decoder
class Decoder(nn.Module):
    """Stack of ``num_layers`` DecoderLayer blocks.

    Unlike Encoder, it applies no final LayerNorm.
    """

    def __init__(self, d_model, n_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)
        )

    def forward(self, x, memory, tgt_mask=None, memory_mask=None):
        """Decode ``x`` (batch_size, tgt_seq_len, d_model) against ``memory``.

        Args:
            x: target-side hidden states.
            memory: encoder output, (batch_size, src_seq_len, d_model).
            tgt_mask: mask for each layer's self-attention.
            memory_mask: mask for each layer's cross-attention.

        Returns:
            (batch_size, tgt_seq_len, d_model)
        """
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, tgt_mask, memory_mask)
        return hidden

In [9]:
# 8. Transformer Model
class Transformer(nn.Module):
    """Encoder-decoder Transformer that maps source/target token ids to target logits.

    Args:
        src_vocab_size: Size of the source vocabulary.
        tgt_vocab_size: Size of the target vocabulary (also the output dimension).
        d_model: Model dimension.
        n_heads: Number of attention heads.
        d_ff: Hidden width of each feed-forward network.
        num_layers: Number of encoder layers and of decoder layers.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        self.dropout = nn.Dropout(dropout)

        self.encoder = Encoder(d_model, n_heads, d_ff, num_layers, dropout)
        self.decoder = Decoder(d_model, n_heads, d_ff, num_layers, dropout)

        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def _embed(self, tokens, embedding):
        # Token ids (batch_size, seq_len) -> (batch_size, seq_len, d_model):
        # scaled embedding plus positional encoding, followed by dropout.
        x = embedding(tokens) * math.sqrt(embedding.embedding_dim)
        x = self.positional_encoding(x)
        # Dropout goes on the *sum* of embeddings and positional encodings (as in
        # "Attention Is All You Need", section 5.4), not on the embeddings alone.
        return self.dropout(x)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Compute target logits.

        Args:
            src: (batch_size, src_seq_len) source token ids.
            tgt: (batch_size, tgt_seq_len) target token ids.
            src_mask: Optional source mask, e.g. (batch_size, 1, 1, src_seq_len).
                It is used by encoder self-attention and by decoder cross-attention.
            tgt_mask: Optional target self-attention mask, e.g.
                (batch_size, 1, tgt_seq_len, tgt_seq_len).

        Returns:
            (batch_size, tgt_seq_len, tgt_vocab_size) logits.
        """
        src_emb = self._embed(src, self.encoder_embedding)  # (batch_size, src_seq_len, d_model)
        tgt_emb = self._embed(tgt, self.decoder_embedding)  # (batch_size, tgt_seq_len, d_model)

        memory = self.encoder(src_emb, src_mask)  # (batch_size, src_seq_len, d_model)
        # The source padding mask must also gate cross-attention. Otherwise decoder
        # queries attend to padded source positions.
        dec_output = self.decoder(tgt_emb, memory, tgt_mask, src_mask)  # (batch_size, tgt_seq_len, d_model)

        return self.fc_out(dec_output)  # (batch_size, tgt_seq_len, tgt_vocab_size)

In [10]:
# 9. mask function
def create_padding_mask(src, tgt, pad_idx=0):
    """Build boolean attention masks for a batch of source/target token ids.

    Args:
        src: (batch_size, src_seq_len) source token ids.
        tgt: (batch_size, tgt_seq_len) target token ids.
        pad_idx: Token id used for padding.

    Returns:
        src_mask: bool (batch_size, 1, 1, src_seq_len). True where the source
            token is not padding. It broadcasts over heads and query positions.
        tgt_mask: bool (batch_size, 1, tgt_seq_len, tgt_seq_len). Entry
            [b, 0, i, j] is True iff j <= i (causal), target token j is not
            padding (key mask), and target token i is not padding (query mask).
    """
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, src_seq_len)

    tgt_not_pad = tgt != pad_idx  # (batch_size, tgt_seq_len)
    # Key mask: no query may attend to a padded target position. Without it,
    # left-padded targets let real tokens attend to pad tokens.
    tgt_key_mask = tgt_not_pad.unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, tgt_seq_len)
    # Query mask: rows for padded positions are fully masked (previous behaviour).
    tgt_query_mask = tgt_not_pad.unsqueeze(1).unsqueeze(3)  # (batch_size, 1, tgt_seq_len, 1)

    # Look-ahead (causal) mask, built directly on the target's device.
    tgt_len = tgt.size(1)
    look_ahead_mask = torch.ones(tgt_len, tgt_len, device=tgt.device).tril().bool()  # (tgt_len, tgt_len)

    tgt_mask = tgt_query_mask & tgt_key_mask & look_ahead_mask  # (batch_size, 1, tgt_len, tgt_len)
    return src_mask, tgt_mask
    

In [11]:
# 10. Example usage
if __name__ == "__main__":
    # Base-model hyperparameters from "Attention Is All You Need".
    src_vocab_size = 10000
    tgt_vocab_size = 10000
    d_model = 512
    n_heads = 8
    d_ff = 2048
    num_layers = 6
    dropout = 0.1

    torch.manual_seed(0)  # reproducible weights and dummy inputs

    model = Transformer(src_vocab_size, tgt_vocab_size, d_model, n_heads, d_ff, num_layers, dropout)

    batch_size, src_seq_len, tgt_seq_len = 32, 10, 12
    # Ids start at 1 so no dummy token collides with the default pad_idx=0.
    src = torch.randint(1, src_vocab_size, (batch_size, src_seq_len))
    tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_seq_len))

    # src_mask: (batch_size, 1, 1, src_seq_len); tgt_mask: (batch_size, 1, tgt_seq_len, tgt_seq_len)
    src_mask, tgt_mask = create_padding_mask(src, tgt)

    model.eval()  # disable dropout for a deterministic smoke test
    with torch.no_grad():
        output = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)  # (batch_size, tgt_seq_len, tgt_vocab_size)

    assert output.shape == (batch_size, tgt_seq_len, tgt_vocab_size)
    print(output.shape)  # torch.Size([32, 12, 10000])

torch.Size([32, 12, 10000])
