In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
# Compact tensor printing for the exploration cells: 3 decimal places, no scientific notation.
torch.set_printoptions(precision=3, sci_mode=False)

In [30]:
from dataclasses import dataclass

@dataclass
class Config:
    """Hyper-parameters shared by every Transformer component in this notebook.

    Model-wide sizes (``d_model``, ``d_ff``, head/layer counts, dropout and the
    layer-norm epsilon) apply to both stacks; the ``src_*`` / ``tgt_*`` fields
    size the source- and target-side vocabularies and maximum sequence lengths.

    Raises:
        ValueError: from ``__post_init__`` when the values cannot build a valid
            model (non-positive sizes, ``d_model`` not divisible by the number of
            heads, odd ``d_model``, or dropout outside [0, 1]).
    """

    d_model: int = 512  # Dimensionality of the model
    d_ff: int = 2048    # Dimensionality of the feedforward network
    num_attention_heads: int = 8  # Number of attention heads
    num_encoder_layers: int = 6   # Number of encoder layers
    num_decoder_layers: int = 6   # Number of decoder layers
    dropout: float = 0.1  # Dropout rate
    eps: float = 1e-6  # Epsilon value for layer normalization

    # For source sequence
    src_vocab_size: int = 32000  # Source vocabulary size
    src_seq_len: int = 512       # Maximum length of source sequence

    # For target sequence
    tgt_vocab_size: int = 32000  # Target vocabulary size
    tgt_seq_len: int = 512       # Maximum length of target sequence

    def __post_init__(self) -> None:
        # Fail fast at construction time instead of deep inside a layer constructor.
        for name in ("d_model", "d_ff", "num_attention_heads",
                     "src_vocab_size", "src_seq_len", "tgt_vocab_size", "tgt_seq_len"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        for name in ("num_encoder_layers", "num_decoder_layers"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be non-negative, got {value}")
        # Each attention head gets d_model // num_attention_heads dimensions.
        if self.d_model % self.num_attention_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        # The sinusoidal positional encoding interleaves (sin, cos) column pairs.
        if self.d_model % 2 != 0:
            raise ValueError(f"d_model must be even, got {self.d_model}")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")

# Example usage: build a configuration with every default hyper-parameter
# and print its dataclass repr.
config = Config()
print(repr(config))

Config(d_model=512, d_ff=2048, num_attention_heads=8, num_encoder_layers=6, num_decoder_layers=6, dropout=0.1, eps=1e-06, src_vocab_size=32000, src_seq_len=512, tgt_vocab_size=32000, tgt_seq_len=512)


In [29]:
# Override the vocabulary sizes and maximum sequence lengths of both sides;
# every other hyper-parameter keeps its default value.
config = Config(
    src_vocab_size=50000,
    src_seq_len=400,
    tgt_vocab_size=45000,
    tgt_seq_len=500,
)
print(repr(config))

Config(d_model=512, d_ff=2048, num_attention_heads=8, num_encoder_layers=6, num_decoder_layers=6, dropout=0.1, eps=1e-06, src_vocab_size=50000, src_seq_len=400, tgt_vocab_size=45000, tgt_seq_len=500)


In [3]:
class InputEmbeddings(nn.Module):
    """Token-embedding lookup scaled by sqrt(d_model).

    Args:
        config: model hyper-parameters; ``config.d_model`` sets the embedding width.
        vocab_size: number of token ids this table covers. When omitted, falls
            back to ``config.vocab_size``. ``Config`` itself only defines
            ``src_vocab_size`` / ``tgt_vocab_size``, so pass one of those explicitly.

    Raises:
        ValueError: if no vocabulary size is given and ``config`` has no ``vocab_size``.
    """

    def __init__(self, config: "Config", vocab_size: "int | None" = None) -> None:
        super().__init__()
        if vocab_size is None:
            vocab_size = getattr(config, "vocab_size", None)
        if vocab_size is None:
            raise ValueError(
                "InputEmbeddings needs a vocabulary size: pass vocab_size= "
                "(e.g. config.src_vocab_size or config.tgt_vocab_size)."
            )
        self.d_model = config.d_model
        self.vocab_size = vocab_size
        # Plain Python float: computed once instead of a tensor sqrt on every call.
        self.scale = config.d_model ** 0.5
        self.embeddings = nn.Embedding(vocab_size, config.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map integer token ids ``x`` to embeddings (new trailing ``d_model`` axis) times sqrt(d_model)."""
        return self.embeddings(x) * self.scale

In [4]:
# Toy sizes for a step-by-step walk through the sinusoidal positional encoding.
d_model = 6
seq_len = 10
# Even embedding dimensions 0, 2, 4, ... (< d_model); these get the sine terms.
even_i = torch.arange(0, d_model, 2)
even_i

tensor([0, 2, 4])

In [5]:
# Exponent 2i/d_model of the frequency term for each even dimension.
even_i / d_model

tensor([0.000, 0.333, 0.667])

In [6]:
# Denominator 10000^(2i/d_model) for the even (sine) dimensions.
even_denominator = torch.pow(10000, even_i/d_model)
even_denominator

tensor([  1.000,  21.544, 464.159])

In [7]:
# Odd embedding dimensions 1, 3, 5, ... (< d_model); these get the cosine terms.
odd_i = torch.arange(1, d_model, 2)
odd_i

tensor([1, 3, 5])

In [8]:
# Subtract 1 so each odd dimension 2i+1 reuses the exponent of its even partner 2i.
(odd_i - 1) / d_model

tensor([0.000, 0.333, 0.667])

In [9]:
# Same denominators as the even dimensions (compare with even_denominator).
odd_denominator = torch.pow(10000, (odd_i-1)/d_model)
odd_denominator

tensor([  1.000,  21.544, 464.159])

In [10]:
# Positions 0..seq_len-1 as a float column vector of shape (seq_len, 1), so that
# dividing by a (d_model/2,) denominator broadcasts to (seq_len, d_model/2).
pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(-1)
pos

tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.],
        [9.]])

In [11]:
# sin for the even dimensions, cos for the odd ones; each is (seq_len, d_model/2).
evenPE = torch.sin(pos / even_denominator)
oddPE = torch.cos(pos / odd_denominator)

In [12]:
# Sine half of the encoding: one row per position, one column per frequency.
evenPE

tensor([[ 0.000,  0.000,  0.000],
        [ 0.841,  0.046,  0.002],
        [ 0.909,  0.093,  0.004],
        [ 0.141,  0.139,  0.006],
        [-0.757,  0.185,  0.009],
        [-0.959,  0.230,  0.011],
        [-0.279,  0.275,  0.013],
        [ 0.657,  0.319,  0.015],
        [ 0.989,  0.363,  0.017],
        [ 0.412,  0.406,  0.019]])

In [13]:
# Cosine half: the first column oscillates fastest; the last one barely moves from 1.
oddPE

tensor([[ 1.000,  1.000,  1.000],
        [ 0.540,  0.999,  1.000],
        [-0.416,  0.996,  1.000],
        [-0.990,  0.990,  1.000],
        [-0.654,  0.983,  1.000],
        [ 0.284,  0.973,  1.000],
        [ 0.960,  0.961,  1.000],
        [ 0.754,  0.948,  1.000],
        [-0.146,  0.932,  1.000],
        [-0.911,  0.914,  1.000]])

In [14]:
# Stack to (seq_len, d_model/2, 2) so each sin value sits next to its cos partner,
# then flatten the last two axes to interleave them: [sin0, cos0, sin1, cos1, ...].
stackedPE = torch.stack([evenPE, oddPE], dim=-1)
pe = stackedPE.reshape(seq_len, d_model) # (seq_len, d_model)
pe.shape

torch.Size([10, 6])

In [15]:
# Interleaved positional-encoding table: even columns are sin, odd columns are cos.
pe

tensor([[ 0.000,  1.000,  0.000,  1.000,  0.000,  1.000],
        [ 0.841,  0.540,  0.046,  0.999,  0.002,  1.000],
        [ 0.909, -0.416,  0.093,  0.996,  0.004,  1.000],
        [ 0.141, -0.990,  0.139,  0.990,  0.006,  1.000],
        [-0.757, -0.654,  0.185,  0.983,  0.009,  1.000],
        [-0.959,  0.284,  0.230,  0.973,  0.011,  1.000],
        [-0.279,  0.960,  0.275,  0.961,  0.013,  1.000],
        [ 0.657,  0.754,  0.319,  0.948,  0.015,  1.000],
        [ 0.989, -0.146,  0.363,  0.932,  0.017,  1.000],
        [ 0.412, -0.911,  0.406,  0.914,  0.019,  1.000]])

In [16]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings, then apply dropout.

    Column ``2i`` holds ``sin(pos / 10000^(2i/d_model))`` and column ``2i+1``
    holds the ``cos`` of the same argument.

    Args:
        config: reads ``d_model`` and ``dropout``; ``seq_len`` is read from it only
            when the ``seq_len`` argument is omitted.
        seq_len: maximum sequence length to precompute. ``Config`` itself only
            defines ``src_seq_len`` / ``tgt_seq_len``, so pass one of those explicitly.

    Raises:
        ValueError: if no sequence length is available, or ``d_model`` is odd
            (sin/cos columns are interleaved in pairs).
    """

    def __init__(self, config: "Config", seq_len: "int | None" = None) -> None:
        super().__init__()
        if seq_len is None:
            seq_len = getattr(config, "seq_len", None)
        if seq_len is None:
            raise ValueError(
                "PositionalEncoding needs a maximum length: pass seq_len= "
                "(e.g. config.src_seq_len or config.tgt_seq_len)."
            )
        if config.d_model % 2 != 0:
            raise ValueError(f"d_model must be even for sin/cos interleaving, got {config.d_model}")

        self.d_model = config.d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(config.dropout)

        position = torch.arange(self.seq_len, dtype=torch.float32).unsqueeze(-1)  # (seq_len, 1)
        exponent = torch.arange(0, self.d_model, 2, dtype=torch.float32) / self.d_model  # (d_model/2,)
        denominator = torch.pow(10000.0, exponent)  # (d_model/2,)
        angles = position / denominator  # (seq_len, d_model/2)
        # Stack to (seq_len, d_model/2, 2) then flatten so columns read sin, cos, sin, cos, ...
        pe = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        pe = pe.reshape(self.seq_len, self.d_model).unsqueeze(0)  # (1, seq_len, d_model)

        # A buffer, not a Parameter: follows .to(device), is saved in state_dict,
        # and is never updated by the optimizer.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :x.size(1)])``; expects ``x`` as (batch, seq, d_model).

        Raises:
            ValueError: if ``x`` is longer than the precomputed ``seq_len``.
        """
        if x.size(1) > self.seq_len:
            raise ValueError(
                f"sequence length {x.size(1)} exceeds the precomputed maximum {self.seq_len}"
            )
        return self.dropout(x + self.pe[:, : x.size(1), :])

In [17]:
class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network: Linear(d_model -> d_ff), ReLU, dropout, Linear(d_ff -> d_model)."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.linear1 = nn.Linear(config.d_model, config.d_ff)
        self.linear2 = nn.Linear(config.d_ff, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Expand to the hidden width, apply the non-linearity and dropout,
        # then project back down to d_model.
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [18]:
class AttentionHead(nn.Module):
    """A single scaled dot-product attention head.

    Projects ``d_model`` inputs down to ``d_k = d_model // num_attention_heads``
    for queries/keys and ``d_v = d_k`` for values. The multi-head block that
    concatenates heads owns the output projection, so a head has none of its own.
    """

    def __init__(self, config: "Config") -> None:
        super().__init__()
        assert config.d_model % config.num_attention_heads == 0,\
            "d_model must be divisible by num_attention_heads"

        self.d_model = config.d_model
        self.num_attention_heads = config.num_attention_heads
        self.d_k = self.d_model // self.num_attention_heads
        self.d_v = self.d_k
        self.dropout = nn.Dropout(config.dropout)

        self.W_q = nn.Linear(self.d_model, self.d_k)
        self.W_k = nn.Linear(self.d_model, self.d_k)
        self.W_v = nn.Linear(self.d_model, self.d_v)
        # No per-head W_o: it was never used in forward and only added untrained parameters.

    def _scaled_dot_product_attention(self, query: torch.Tensor, key: torch.Tensor,
                                      value: torch.Tensor, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            query: (..., q_len, d_k). key: (..., k_len, d_k). value: (..., k_len, d_v).
            mask: optional tensor broadcastable to (..., q_len, k_len); positions where
                it equals 0 receive a -1e9 score and are effectively ignored.

        Returns:
            ``(output, weights)``: attended values (..., q_len, d_v) and the softmax
            attention weights (..., q_len, k_len), taken before dropout.
        """
        # matmul (not bmm) so extra leading batch dimensions broadcast.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(self.dropout(weights), value)
        return output, weights

    def forward(self, x_q: torch.Tensor, x_k: torch.Tensor, x_v: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Attend from ``x_q`` over ``x_k`` / ``x_v`` and return a (..., q_len, d_v) tensor.

        Inputs are expected to end in ``d_model``, e.g. (batch, seq_len, d_model).
        """
        output, _ = self._scaled_dot_product_attention(
            self.W_q(x_q), self.W_k(x_k), self.W_v(x_v), mask
        )
        return output

In [19]:
class MultiHeadAttentionBlock(nn.Module):
    """Run ``num_attention_heads`` ``AttentionHead``s in parallel, concatenate them, and project to d_model."""

    def __init__(self, config: "Config") -> None:
        super().__init__()

        assert config.d_model % config.num_attention_heads == 0,\
            "d_model must be divisible by num_attention_heads"

        # These are needed to size W_o below; previously they were read but never assigned.
        self.d_model = config.d_model
        self.num_attention_heads = config.num_attention_heads
        self.d_v = self.d_model // self.num_attention_heads

        self.heads = nn.ModuleList([AttentionHead(config)
                                    for _ in range(config.num_attention_heads)])
        self.dropout = nn.Dropout(config.dropout)
        self.W_o = nn.Linear(self.num_attention_heads * self.d_v, self.d_model)

    def forward(self, x_q: torch.Tensor, x_k: torch.Tensor, x_v: torch.Tensor
                , mask: torch.Tensor = None) -> torch.Tensor:
        """Multi-head attention of ``x_q`` over ``x_k`` / ``x_v``.

        Args:
            x_q, x_k, x_v: query/key/value inputs ending in ``d_model``.
            mask: optional attention mask forwarded to every head (0 = masked out).

        Returns:
            Tensor shaped like ``x_q``'s leading dims with a trailing ``d_model``.
        """
        # Assumes each head returns (..., q_len, d_v); concatenation gives num_heads * d_v.
        head_outputs = torch.cat([head(x_q, x_k, x_v, mask) for head in self.heads], dim=-1)
        # Dropout on the concatenated head outputs, then the output projection.
        return self.W_o(self.dropout(head_outputs))

In [20]:
class Risidual(nn.Module):
    """Pre-norm residual connection: ``x + dropout(sublayer(layer_norm(x)))``.

    The class name keeps the notebook's original spelling so existing call sites keep working.
    """

    def __init__(self, config: "Config") -> None:
        super().__init__()
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.eps)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalised input and add the result back onto ``x``.

        Args:
            x: input tensor whose last dimension is ``d_model``.
            sublayer: callable mapping a tensor to one that can be added to ``x``.
        """
        # The first parameter was misspelled `selg`, which made `self` an undefined name here.
        return x + self.dropout(sublayer(self.layer_norm(x)))

In [21]:
class EncoderBlock(nn.Module):
    """One encoder layer: a self-attention sub-layer followed by a feed-forward sub-layer,
    each wrapped in its own ``Risidual`` connection."""

    def __init__(self, config: Config) -> None:
        super().__init__()

        self.multi_head_attention_block = MultiHeadAttentionBlock(config)
        self.feed_forward_block = FeedForwardBlock(config)
        self.risidual1 = Risidual(config)
        self.risidual2 = Risidual(config)
        # Not used in forward; kept so the module layout is unchanged.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        def self_attention(h: torch.Tensor) -> torch.Tensor:
            # Queries, keys and values all come from the same input.
            return self.multi_head_attention_block(h, h, h, src_mask)

        attended = self.risidual1(x, self_attention)
        return self.risidual2(attended, self.feed_forward_block)

In [22]:
class Encoder(nn.Module):
    """Stack of ``num_encoder_layers`` EncoderBlocks followed by a final LayerNorm."""

    def __init__(self, config: Config) -> None:
        super().__init__()

        blocks = [EncoderBlock(config) for _ in range(config.num_encoder_layers)]
        self.encoder_blocks = nn.ModuleList(blocks)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        hidden = x
        for block in self.encoder_blocks:
            hidden = block(hidden, mask)
        # Normalise the output of the last block.
        return self.layer_norm(hidden)

In [23]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, then feed-forward.

    Each of the three sub-layers is wrapped in its own ``Risidual`` connection.
    """

    def __init__(self, config: "Config") -> None:
        super().__init__()

        self.masked_multi_head_attention_block = MultiHeadAttentionBlock(config)
        self.multi_head_attention_block = MultiHeadAttentionBlock(config)
        self.feed_forward_block = FeedForwardBlock(config)
        self.risidual1 = Risidual(config)
        self.risidual2 = Risidual(config)
        self.risidual3 = Risidual(config)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor,
                src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
        """Run the three sub-layers on target-side states ``x``.

        Args:
            x: target-side hidden states.
            encoder_output: encoder memory, used as keys and values for cross-attention.
            src_mask: mask over encoder positions for cross-attention.
            tgt_mask: mask for decoder self-attention (presumably causal + padding; built by the caller).
        """
        x = self.risidual1(x, lambda h: self.masked_multi_head_attention_block(h, h, h, tgt_mask))
        x = self.risidual2(x, lambda h: self.multi_head_attention_block(h, encoder_output, encoder_output, src_mask))
        # Each sub-layer needs its own residual/LayerNorm; this previously reused risidual2.
        x = self.risidual3(x, self.feed_forward_block)
        return x

In [24]:
class Decoder(nn.Module):
    """Stack of ``num_decoder_layers`` DecoderBlocks followed by a final LayerNorm."""

    def __init__(self, config: Config) -> None:
        super().__init__()

        blocks = [DecoderBlock(config) for _ in range(config.num_decoder_layers)]
        self.decoder_blocks = nn.ModuleList(blocks)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.eps)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor,
                src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
        hidden = x
        for block in self.decoder_blocks:
            # Every block attends to the same encoder output.
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.layer_norm(hidden)

In [31]:
class Generator(nn.Module):
    """Project decoder outputs to log-probabilities over the target vocabulary.

    Args:
        config: reads ``d_model``; when ``vocab_size`` is omitted, also supplies the
            vocabulary size (``config.vocab_size`` if present, else ``config.tgt_vocab_size``).
        vocab_size: explicit output vocabulary size.
    """

    def __init__(self, config: "Config", vocab_size: "int | None" = None):
        super().__init__()
        if vocab_size is None:
            vocab_size = getattr(config, "vocab_size", None)
        if vocab_size is None:
            # Config has no `vocab_size` field; the generator predicts target tokens.
            vocab_size = config.tgt_vocab_size
        self.linear = nn.Linear(config.d_model, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``log_softmax(linear(x))`` over the last (vocabulary) dimension."""
        return F.log_softmax(self.linear(x), dim=-1)

### Transformer

In [26]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built components.

    Subclassing ``nn.Module`` registers every component as a sub-module, so
    ``parameters()``, ``to(device)``, ``train()`` / ``eval()`` and ``state_dict()``
    cover the whole model.
    """

    def __init__(self,
                 encoder: "Encoder",
                 decoder: "Decoder",
                 src_input_embed: "InputEmbeddings",
                 tgt_input_embed: "InputEmbeddings",
                 src_pos_embed: "PositionalEncoding",
                 tgt_pos_embed: "PositionalEncoding",
                 generator: "Generator"
                 ) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator
        self.src_input_embed = src_input_embed
        self.tgt_input_embed = tgt_input_embed
        self.src_pos_embed = src_pos_embed
        # Previously assigned src_pos_embed here, silently sharing the source encoding.
        self.tgt_pos_embed = tgt_pos_embed

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """Embed source token ids, add positional encodings, and run the encoder.

        Returns:
            The encoder output (memory) consumed by ``decode``.
        """
        # The positional module adds its encoding to its input, so it is applied to
        # the token embeddings, not to the raw token ids.
        encoder_input = self.src_pos_embed(self.src_input_embed(src))
        return self.encoder(encoder_input, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor,
               tgt: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
        """Embed target token ids, add positional encodings, and run the decoder against ``encoder_output``."""
        decoder_input = self.tgt_pos_embed(self.tgt_input_embed(tgt))
        return self.decoder(decoder_input, encoder_output, src_mask, tgt_mask)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor,
                src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
        """Encode ``src`` and decode ``tgt`` against it; returns decoder hidden states.

        Apply ``generate`` to the result to obtain log-probabilities.
        """
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def generate(self, x: torch.Tensor) -> torch.Tensor:
        """Project decoder hidden states through the generator."""
        return self.generator(x)

In [34]:
def make_transformer(config: "Config") -> "Transformer":
    """Build an encoder-decoder Transformer from ``config`` and initialise its weights.

    Source-side modules use ``config.src_vocab_size`` / ``config.src_seq_len``;
    target-side modules and the output generator use ``config.tgt_vocab_size`` /
    ``config.tgt_seq_len``. Every parameter with more than one dimension is
    re-initialised with ``kaiming_uniform_``.
    """
    from types import SimpleNamespace

    def _side_config(vocab_size: int, seq_len: int) -> SimpleNamespace:
        # The embedding / positional / generator constructors read `vocab_size` and
        # `seq_len` from their config, which Config does not define; give each side a
        # copy of the hyper-parameters that does.
        return SimpleNamespace(**{**vars(config), "vocab_size": vocab_size, "seq_len": seq_len})

    src_config = _side_config(config.src_vocab_size, config.src_seq_len)
    tgt_config = _side_config(config.tgt_vocab_size, config.tgt_seq_len)

    # Input embedding layers for source and target.
    src_input_embeddings = InputEmbeddings(src_config)
    tgt_input_embeddings = InputEmbeddings(tgt_config)

    # Positional encodings, each sized for its own side's maximum length.
    src_pos_embeddings = PositionalEncoding(src_config)
    tgt_pos_embeddings = PositionalEncoding(tgt_config)

    # Encoder and decoder stacks.
    encoder = Encoder(config)
    decoder = Decoder(config)

    # Output projection onto the target vocabulary.
    generator = Generator(tgt_config)

    transformer = Transformer(encoder=encoder,
                              decoder=decoder,
                              generator=generator,
                              src_input_embed=src_input_embeddings,
                              src_pos_embed=src_pos_embeddings,
                              tgt_input_embed=tgt_input_embeddings,
                              tgt_pos_embed=tgt_pos_embeddings)

    # Iterate the components directly so initialisation works whether or not
    # Transformer registers them as sub-modules. Only weight matrices / embedding
    # tables (dim > 1) are touched; biases and LayerNorm parameters keep their defaults.
    components = nn.ModuleList([encoder, decoder, generator,
                                src_input_embeddings, tgt_input_embeddings,
                                src_pos_embeddings, tgt_pos_embeddings])
    for p in components.parameters():
        if p.dim() > 1:
            nn.init.kaiming_uniform_(p)

    return transformer

## Train the Transformer model!

In [36]:
import datasets
from datasets import load_dataset, load_dataset_builder

In [39]:
# Look up the OPUS-100 builder for the Arabic-English ('ar-en') pair; the data
# itself is fetched by load_dataset in the next cell.
load_dataset_builder('opus100', name='ar-en')

<datasets_modules.datasets.opus100.256f3196b69901fb0c79810ef468e2c4ed84fbd563719920b1ff1fdc750f7704.opus100.Opus100 at 0x16c52ae10>

In [40]:
# Download and prepare every split (test / train / validation) of the ar-en pair.
raw_dataset = load_dataset('opus100', name='ar-en')

Downloading data:   0%|          | 0.00/55.3M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1000000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [41]:
# Show the splits and their sizes; each row has a single 'translation' feature.
raw_dataset

DatasetDict({
    test: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
    train: Dataset({
        features: ['translation'],
        num_rows: 1000000
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
})