# English to French Translation with a Transformer from Scratch

## Building Some Modules

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
class PositionalEncoder(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    A table ``P`` of shape ``(1, n_max, d)`` is precomputed once: even feature
    columns hold ``sin(pos / 10000^(2j/d))`` and odd columns hold the matching
    cosine. The table is stored as a non-persistent buffer, so it follows the
    module across devices but is not written to ``state_dict``.

    Args:
        n_max: Maximum sequence length the table covers.
        d: Feature (embedding) dimension of the inputs.
    """

    def __init__(self, n_max, d):
        super().__init__()
        self.n_max = n_max
        self.d = d

        # Column vector of positions 0..n_max-1, shape (n_max, 1).
        pos = torch.arange(0, n_max, dtype=torch.float).unsqueeze(1)
        # Inverse frequencies computed in log space, shape (ceil(d / 2),).
        div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
        angles = pos * div

        P = torch.zeros(n_max, d)
        P[:, 0::2] = torch.sin(angles)
        # For odd d there is one fewer odd column than even columns.
        P[:, 1::2] = torch.cos(angles[:, : d // 2])

        # Make it part of the module's state (moves with .to()/.cuda()).
        self.register_buffer('P', P.unsqueeze(0), persistent=False)

    def forward(self, X):
        """Return ``X`` plus the encodings for its first ``X.size(1)`` positions.

        Args:
            X: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d``; the table broadcasts over dimension 0.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = X.size(1)
        if seq_len > self.P.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum of {self.P.size(1)}"
            )
        return X + self.P[:, :seq_len]

In [None]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Dropout is applied to the sub-layer output ``Y`` before it is added to the
    residual input ``X``; the sum is then layer-normalized.

    Args:
        norm_shape: Shape passed to ``nn.LayerNorm``.
        dropout: Dropout probability applied to ``Y``.
    """

    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(normalized_shape=norm_shape)

    def forward(self, X, Y):
        """Return ``LayerNorm(Dropout(Y) + X)``."""
        regularized = self.dropout(Y)
        residual_sum = regularized + X
        return self.layer_norm(residual_sum)

In [None]:
class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network applied to the last dimension of the input.

    Computes ``dense2(relu(dense1(X)))``. Both layers are ``nn.LazyLinear``, so
    only output sizes are specified; input sizes are inferred on the first
    forward pass.

    Args:
        hidden_units: Output size of the first (hidden) linear layer.
        output_shape: Output size of the second linear layer.
    """

    def __init__(self, hidden_units, output_shape):
        super().__init__()
        self.dense1 = nn.LazyLinear(hidden_units)  # LazyLinear: only the output size is given
        self.relu = nn.ReLU()
        self.dense2 = nn.LazyLinear(output_shape)

    def forward(self, X):
        """Apply the hidden layer, ReLU, then the output layer to ``X``."""
        return self.dense2(self.relu(self.dense1(X)))

In [4]:
# Full Multi-Head Attention Module
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional length masking.

    Queries, keys and values are linearly projected to ``hidden_dim`` features,
    split into ``num_heads`` heads that are folded into the batch dimension,
    attended independently, then concatenated and projected by ``W_o``.
    The most recent attention weights are kept in ``self.attention_weights``
    with shape ``(batch * num_heads, num_queries, num_keys)``.

    Args:
        hidden_dim: Output size of each projection; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        bias: Whether the projection layers use a bias term.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    # Large negative logit so masked keys get (numerically) zero softmax weight.
    _MASK_VALUE = -1e6

    def __init__(self, hidden_dim, num_heads, dropout, bias=False):
        super().__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        # Must be a module, not the raw probability: forward() calls it.
        self.dropout = nn.Dropout(dropout)
        self.W_q = nn.LazyLinear(hidden_dim, bias=bias)
        self.W_k = nn.LazyLinear(hidden_dim, bias=bias)
        self.W_v = nn.LazyLinear(hidden_dim, bias=bias)
        self.W_o = nn.LazyLinear(hidden_dim, bias=bias)

    def transpose_qkv(self, X):
        """Reshape ``(batch, seq, hidden)`` to ``(batch * num_heads, seq, hidden / num_heads)``."""
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        X = X.permute(0, 2, 1, 3)
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X):
        """Reverse ``transpose_qkv``: merge the heads back into the feature dimension."""
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)

    def _mask_logits(self, attn_logits, valid_lens):
        """Replace logits of key positions at or beyond the valid length with a large negative value.

        ``valid_lens`` is either 1-D (one length per batch*head row) or 2-D
        (one length per batch*head row and query position).
        """
        num_keys = attn_logits.shape[-1]
        if valid_lens.dim() == 1:
            lens = valid_lens[:, None, None]
        else:
            lens = valid_lens.reshape(attn_logits.shape[0], -1)[:, :, None]
        key_index = torch.arange(num_keys, device=attn_logits.device)[None, None, :]
        keep = key_index < lens
        # Out-of-place fill: keeps autograd safe and leaves the input untouched.
        return attn_logits.masked_fill(~keep, self._MASK_VALUE)

    def forward(self, Q_in, K_in, V_in, valid_lens):
        """Compute multi-head attention.

        Args:
            Q_in: Queries, shape ``(batch, num_queries, features)``.
            K_in: Keys, shape ``(batch, num_keys, features)``.
            V_in: Values, shape ``(batch, num_keys, features)``.
            valid_lens: ``None`` for no masking, or a tensor of shape
                ``(batch,)`` or ``(batch, num_queries)`` giving the number of
                leading key positions each query may attend to.

        Returns:
            Tensor of shape ``(batch, num_queries, hidden_dim)``.
        """
        Q = self.transpose_qkv(self.W_q(Q_in))
        K = self.transpose_qkv(self.W_k(K_in))
        V = self.transpose_qkv(self.W_v(V_in))

        # Scaled dot-product:
        d = Q.shape[-1]
        attn_logits = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d)

        if valid_lens is not None:
            # Heads are folded into the batch axis, so repeat each length per head.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
            attn_logits = self._mask_logits(attn_logits, valid_lens)

        self.attention_weights = nn.functional.softmax(attn_logits, dim=-1)

        # Multiply softmax output with value matrix:
        output = torch.bmm(self.dropout(self.attention_weights), V)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

## Building the Encoder

In [None]:
class EncoderBlock(nn.Module):
    """One Transformer encoder block: self-attention and a position-wise FFN.

    Each sub-layer is wrapped in an ``AddNorm`` (dropout, residual add, layer
    norm).

    Args:
        hidden_dim: Model dimension; used for the attention projections, the
            layer-norm shape and the FFN output size.
        FFN_hidden_units: Hidden size of the position-wise FFN.
        num_heads: Number of attention heads.
        dropout: Dropout probability used in attention and both ``AddNorm``s.
        use_bias: Whether the attention projections use a bias term.
    """

    def __init__(self, hidden_dim, FFN_hidden_units, num_heads, dropout, use_bias=False):
        super().__init__()
        # The attention class in this file is spelled `MultiheadAttention`.
        self.MHA = MultiheadAttention(hidden_dim, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(hidden_dim, dropout)
        self.FFN = PositionWiseFFN(FFN_hidden_units, hidden_dim)
        self.addnorm2 = AddNorm(hidden_dim, dropout)

    def forward(self, X, valid_lens):
        """Apply self-attention (masked by ``valid_lens``) then the FFN, each with add & norm."""
        X = self.addnorm1(X, self.MHA(X, X, X, valid_lens))
        return self.addnorm2(X, self.FFN(X))

In [None]:
class TransformerEncoder(nn.Module):
    """Transformer encoder: token embedding, positional encoding, stacked blocks.

    Args:
        vocab_size: Number of rows in the embedding table.
        hidden_dim: Embedding / model dimension.
        FFN_hidden_units: Hidden size of each block's position-wise FFN.
        num_heads: Number of attention heads per block.
        num_blocks: Number of stacked ``EncoderBlock``s.
        dropout: Dropout probability passed to every block.
        use_bias: Whether the attention projections use a bias term.
    """

    def __init__(self, vocab_size, hidden_dim, FFN_hidden_units, num_heads, num_blocks, dropout, use_bias=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_encoder = PositionalEncoder(1000, hidden_dim)  # max 1000 tokens
        # nn.Sequential is used only as a named container ("block0", "block1", ...);
        # forward() iterates it manually because each block takes two arguments.
        self.blocks = nn.Sequential()
        for i in range(num_blocks):
            self.blocks.add_module(
                "block" + str(i),
                EncoderBlock(hidden_dim, FFN_hidden_units, num_heads, dropout, use_bias),
            )

    def forward(self, X, valid_lens):
        """Encode a batch of token indices.

        Args:
            X: Token indices fed to the embedding layer.
            valid_lens: Valid-length tensor (or ``None``) forwarded to every
                block's attention mask.

        Returns:
            The output of the last encoder block.
        """
        # Embeddings are scaled by sqrt(hidden_dim) before the positional
        # encodings are added.
        X = self.pos_encoder(self.embedding(X) * math.sqrt(self.hidden_dim))
        for block in self.blocks:
            X = block(X, valid_lens)
        return X

## Building the Decoder