In [None]:
import torch
import torch.nn as nn
import math

In [29]:
# d_model: model dimension
# vocab_size: vocabulary size
class InputEmbedding(nn.Module):
    """Token-embedding lookup whose output is scaled by ``sqrt(d_model)``.

    Args:
        d_model: Size of each embedding vector.
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # nn.Embedding (capital E) is the lookup-table module.
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Embed token indices and apply the sqrt(d_model) scaling.

        Args:
            x: Integer tensor of token indices in ``[0, vocab_size)``.

        Returns:
            Tensor with one trailing dimension of size ``d_model`` appended
            to the shape of ``x``.
        """
        # paper says to do sqrt(d_model) scaling; no activation is applied
        return self.embedding(x) * math.sqrt(self.d_model)
    
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then dropout.

    The encoding table is computed once in ``__init__`` and stored as the
    buffer ``pe`` of shape ``(1, seq_len, d_model)``.

    Args:
        d_model: Size of the last dimension of the inputs.
        seq_len: Number of positions for which encodings are precomputed.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model: int,  seq_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # create a matrix of shape seq_len x d_model
        pe = torch.zeros(seq_len, d_model)

        # column vector of positions, shape (seq_len, 1)
        # this is the numerator of the positional encoding
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # calculated in log space for numerical stability
        # this is the (inverse) denominator of the positional encoding
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # apply the sin to the even columns
        pe[:, 0::2] = torch.sin(position * div_term)
        # apply the cos to the odd columns; when d_model is odd there is one
        # fewer odd column than even columns, so trim div_term to match
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)  # 1 x seq_len x d_model

        # Registered exactly once: a buffer is saved with the module and moved
        # with .to()/.cuda(), but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings for the first ``x.shape[1]`` positions to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, length, d_model)``.

        Returns:
            Tensor of the same shape as ``x`` after dropout.
        """
        # The buffer does not require grad, so the slice is a constant here.
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)

class LayerNorm(nn.Module):
    """Normalize over the last dimension with one scalar gain and one scalar bias.

    Args:
        eps: Small constant added to the standard deviation (not the variance)
            to avoid division by zero.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # A single learnable scale, initialised to 1.
        self.alpha = nn.Parameter(torch.ones(1))
        # A single learnable shift, initialised to 0.
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias``.

        The mean and standard deviation are taken over the last dimension of
        ``x`` and kept as size-1 dimensions so they broadcast against ``x``.
        """
        reduce_last = {"dim": -1, "keepdim": True}
        mu = x.mean(**reduce_last)
        sigma = x.std(**reduce_last)
        centered = x - mu
        return self.alpha * centered / (sigma + self.eps) + self.bias
    
class FeedForwardBlock(nn.Module):
    """Two-layer feed-forward network: linear -> ReLU -> dropout -> linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff) # weight1 and bias1
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model) # weight2 and bias2

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``.

        Shape flow:
        (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff)
        -> (batch_size, seq_len, d_model)
        """
        x = self.linear1(x)
        # Functional ReLU: the module never defined a `self.relu` attribute.
        x = nn.functional.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
            
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.

    After each ``forward`` call, the attention weights of that call are
    available as ``self.attention_scores`` (useful for debugging).
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model is not divisible by num_heads"

        self.head_dim = d_model // num_heads # d_k = d_model / num_heads

        self.w_q = nn.Linear(d_model, d_model) # weight and bias for query
        self.w_k = nn.Linear(d_model, d_model) # weight and bias for key
        self.w_v = nn.Linear(d_model, d_model) # weight and bias for value

        self.W_o = nn.Linear(d_model, d_model) # weight and bias for output

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Compute scaled dot-product attention.

        Args:
            query: Tensor of shape ``(..., seq_q, head_dim)``.
            key: Tensor of shape ``(..., seq_k, head_dim)``.
            value: Tensor of shape ``(..., seq_k, head_dim)``.
            mask: ``None``, or a tensor broadcastable to the score shape
                ``(..., seq_q, seq_k)``; positions where ``mask == 0`` are
                excluded from attention.
            dropout: ``None``, or a dropout module applied to the weights.

        Returns:
            Tuple ``(output, attention_weights)``, where ``output`` is
            ``attention_weights @ value``.
        """
        head_dim = query.shape[-1]
        # note: torch.matmul(x, y) = x @ y

        # (Batch_size, num_heads, seq_len, head_dim) x (Batch_size, num_heads, head_dim, seq_len) -> (Batch_size, num_heads, seq_len, seq_len)
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dim)

        if mask is not None:
            # masked_fill is out-of-place: the result must be assigned back,
            # otherwise the mask has no effect.
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim = -1) #(Batch_size, num_heads, seq_len, seq_len)

        if dropout is not None:
            attention_scores = dropout(attention_scores)

        # return the output and the attention weights (used to debug)
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        """Project q/k/v, attend per head, and merge the heads.

        Args:
            q: Query input of shape ``(batch_size, seq_len, d_model)``.
            k: Key input of shape ``(batch_size, seq_len, d_model)``.
            v: Value input of shape ``(batch_size, seq_len, d_model)``.
            mask: Passed through to :meth:`attention`; may be ``None``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.
        """
        # for all 3, we apply the linear transformation
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)

        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, num_heads, head_dim) -> (batch_size, num_heads, seq_len, head_dim)
        query = query.view(query.shape[0], query.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.num_heads, self.head_dim).transpose(1, 2)

        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # (batch_size, num_heads, seq_len, head_dim) -> (batch_size, seq_len, num_heads, head_dim)
        # -> (batch_size, seq_len, num_heads * head_dim); contiguous() is needed before view()
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.num_heads * self.head_dim)

        # (batch_size, seq_len, num_heads * head_dim) -> (batch_size, seq_len, d_model)
        return self.W_o(x)

# Residual (skip) connection: normalizes the input, applies a sublayer and dropout, then adds the result back to the input
class ResidualConnection(nn.Module):
    """Skip connection around a sublayer, with normalization applied first.

    Args:
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNorm()

    def forward(self, x, sublayer):
        """Return ``x + dropout(sublayer(norm(x)))``.

        Args:
            x: Input tensor.
            sublayer: Callable applied to the normalized input; its output is
                added to ``x``, so it must be shape-compatible with ``x``.
        """
        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch
       
    
# a single encoder block
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention, then feed-forward, each in a residual.

    Args:
        self_attention_block: Attention module called with the same tensor as
            query, key and value.
        feed_forward_block: Position-wise feed-forward module.
        dropout: Dropout probability for both residual connections.
    """

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # One residual wrapper per sublayer (attention, feed-forward).
        residuals = [ResidualConnection(dropout) for _ in range(2)]
        self.residual_connection = nn.ModuleList(residuals)

    def forward(self, x, src_mask):
        """Run ``x`` through both sublayers.

        Args:
            x: Input tensor of the encoder layer.
            src_mask: Mask forwarded to the self-attention block.
        """
        attention_residual, feed_forward_residual = self.residual_connection

        def self_attend(normed):
            # query, key and value are all the same (normalized) tensor
            return self.self_attention_block(normed, normed, normed, src_mask)

        x = attention_residual(x, self_attend)
        return feed_forward_residual(x, self.feed_forward_block)
        
# the encoder is a stack of encoder blocks      
class Encoder(nn.Module):
    """Apply a sequence of encoder layers, then a final layer normalization.

    Args:
        layers: Modules applied in order; each is called as ``layer(x, mask)``.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm()

    def forward(self, x, mask):
        """Feed ``x`` through every layer with the same ``mask`` and normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)
        
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, then feed-forward.

    Each of the three sublayers is wrapped in its own residual connection.

    Args:
        self_attention_block: Attention over the decoder input itself.
        cross_attention_block: Attention whose keys and values are the
            encoder output.
        feed_forward_block: Position-wise feed-forward module.
        dropout: Dropout probability for all three residual connections.
    """

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        # One residual wrapper per sublayer.
        residuals = [ResidualConnection(dropout) for _ in range(3)]
        self.residual_connection = nn.ModuleList(residuals)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Run ``x`` through the three sublayers.

        Args:
            x: Decoder-side input tensor.
            encoder_output: Output of the encoder, used as keys and values in
                the cross-attention.
            src_mask: Mask forwarded to the cross-attention block.
            tgt_mask: Mask forwarded to the self-attention block.
        """
        self_attn_residual, cross_attn_residual, ff_residual = self.residual_connection

        def masked_self_attention(normed):
            return self.self_attention_block(normed, normed, normed, tgt_mask)

        def encoder_decoder_attention(normed):
            return self.cross_attention_block(normed, encoder_output, encoder_output, src_mask)

        x = self_attn_residual(x, masked_self_attention)
        x = cross_attn_residual(x, encoder_decoder_attention)
        return ff_residual(x, self.feed_forward_block)
    
    
class Decoder(nn.Module):
    """Apply a sequence of decoder layers, then a final layer normalization.

    Args:
        layers: Modules applied in order; each is called as
            ``layer(x, encoder_output, src_mask, tgt_mask)``.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Feed ``x`` through every layer with shared arguments and normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)
    
    
    
class ProjectionLayer(nn.Module):
    """Linear map from ``d_model`` features to log-probabilities over the vocabulary.

    Args:
        d_model: Size of the input feature dimension.
        vocab_size: Size of the output dimension.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return the log-softmax, over the last dimension, of the projected ``x``."""
        logits = self.linear(x)
        return logits.log_softmax(dim=-1)
    
    
    
class Transformer(nn.Module):
    """Encoder-decoder model assembled from pre-built components.

    Args:
        encoder: Module called as ``encoder(src, src_mask)``.
        decoder: Module called as
            ``decoder(tgt, encoder_output, src_mask, tgt_mask)``.
        src_embed: Embedding applied to the source tokens.
        tgt_embed: Embedding applied to the target tokens.
        src_pos: Positional encoding applied to the embedded source.
        tgt_pos: Positional encoding applied to the embedded target.
        projection_layer: Final module applied to the decoder output.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbedding, tgt_embed: InputEmbedding, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Embed both sequences, encode the source, decode the target, project.

        Args:
            src: Source input passed to ``src_embed``.
            tgt: Target input passed to ``tgt_embed``.
            src_mask: Mask given to the encoder and to the decoder.
            tgt_mask: Mask given to the decoder.

        Returns:
            The output of ``projection_layer`` applied to the decoder output.
        """
        # Embedding + positional encoding for each side (source first).
        src_embedded = self.src_embed(src)
        src_input = self.src_pos(src_embedded)
        tgt_embedded = self.tgt_embed(tgt)
        tgt_input = self.tgt_pos(tgt_embedded)

        memory = self.encoder(src_input, src_mask)
        decoded = self.decoder(tgt_input, memory, src_mask, tgt_mask)
        return self.projection_layer(decoded)