In [33]:
import torch
import torch.nn as nn
import math
import random
import time

In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
def masked_softmax(X, valid_lens): # X.shape=[batch_size, seq_len]
    def seq_mask(X, valid_len, value=-1e6):
        max_len = X.size(1)
        mask = torch.arange((max_len), dtype=torch.float32, 
                            device=X.device)[None, :]<valid_len[:, None]
        X[~mask] = value
        return X

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
        # valid_lens is 1d so same lenght for a batch 
        # eg ([2,3] means batch 0 has valid_len 2 and batch 1 has valid_len 3)
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
        # different length for each query
            valid_lens = valid_lens.reshape(-1)
        X = seq_mask(X.reshape(-1, shape[-1]), valid_lens)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

In [36]:
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([[2, 1],[4, 3]])))


tensor([[[0.3662, 0.6338, 0.0000, 0.0000],
         [0.3621, 0.6379, 0.0000, 0.0000]],

        [[0.3764, 0.1834, 0.4402, 0.0000],
         [0.2390, 0.4887, 0.2723, 0.0000]]])
tensor([[[0.4923, 0.5077, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]],

        [[0.2256, 0.2762, 0.2646, 0.2336],
         [0.2460, 0.2275, 0.5265, 0.0000]]])


In [37]:
class DotProductAttention(nn.Module):
    def __init__(self, dropout=0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, valid_lens=None):
        d = Q.shape[-1]
        scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), V)


In [38]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_hidden, num_heads, dropout=0, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hidden, bias=bias)
        self.W_k = nn.LazyLinear(num_hidden, bias=bias)
        self.W_v = nn.LazyLinear(num_hidden, bias=bias)
        self.W_o = nn.LazyLinear(num_hidden, bias=bias)
    
    # for parallel computation
    def transpose_qkv(self, X):
        # Shape of input X [batch_size, num_queries/kv, num_hidden]
        # Shape of output X [batch_size, num_queries/kv, num_heads, num_hidden/num_heads]
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)

        # Shape of output X [batch_size, num_heads, num_queries/kv, num_hidden/num_heads]
        X = X.permute(0, 2, 1, 3)
        # Shape of output X [batch_size*num_heads, num_queries/kv, num_hidden/num_heads]
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X):
        # reverese the transpose_qkv
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, Q, K, V, valid_lens):
        # Shape of Q, K, V [batch_size, num_queries/kv, num_hidden]
        # After transposing shape of Q, K, V [batch_size * num_heads, num_queries/kv, num_hiddens / num_heads]
        Q = self.transpose_qkv(self.W_q(Q))
        K = self.transpose_qkv(self.W_k(K))
        V = self.transpose_qkv(self.W_v(V))

        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens, self.num_heads, dim=0)

        # Shape of output [batch_size * num_heads, num_queries, num_hiddens / num_heads]
        output = self.attention(Q, K, V, valid_lens)
        # Shape of output_concat [batch_size, num_queries, num_hidden]
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)


In [39]:
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
o = attention(X, Y, Y, valid_lens)
o.shape #should be [batch_size, num_queries, num_hiddens]


torch.Size([2, 4, 100])

In [40]:
class PositionalEncoding(nn.Module):
    def __init__(self, num_hidden, max_len=1000):
        super().__init__()
        self.P = torch.zeros((1, max_len, num_hidden))
        pos = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        indices = torch.arange(0, num_hidden, 2, dtype=torch.float32)
        term = torch.pow(10000, indices / num_hidden)
        X = pos / term
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return X

In [41]:
class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_hidden, ffn_num_outputs):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.LazyLinear(ffn_num_hidden),
            nn.ReLU(),
            nn.LazyLinear(ffn_num_outputs),
        )
    def forward(self, X):
        return self.ffn(X)

In [42]:
class LayerNorm(nn.Module):
    def __init__(self, features,epsilon=1e-6):
        super().__init__()
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))

    def forward(self, X):
        mean = X.mean(dim=-1, keepdim=True)
        var = X.var(dim=-1, keepdim=True)
        x_norm = (X - mean) / torch.sqrt(var + self.epsilon)
        return self.gamma * x_norm + self.beta

In [43]:
# We use Post Normalization (residual connection -> layer norm) as per the original paper
# Pre Normalization is more stable tho and it's the default nowadays
class AddNorm(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.features = features
        self.ln = LayerNorm(features)
    def forward(self, X, Y):
        return self.ln(X + Y)

In [44]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, num_hidden, ffn_num_hidden, num_heads, dropout, use_bias=False):
        super().__init__()
        self.attention = MultiHeadAttention(num_hidden, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(num_hidden)
        self.ffn = PositionWiseFFN(ffn_num_hidden, num_hidden)
        self.addnorm2 = AddNorm(num_hidden)
    
    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

In [45]:
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = TransformerEncoderBlock(24, 48, 8, 0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape == X.shape


True

In [46]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embedding = nn.Parameter(torch.randn(vocab_size, embed_size))
    def forward(self, X):
        return self.embedding[X]

In [47]:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, num_hidden, num_heads, ffn_num_hidden,
                num_layers, dropout=0, use_bias=False):
        super().__init__()
        self.num_hidden = num_hidden
        self.embedding = Embedding(vocab_size, num_hidden)
        self.pos_encoding = PositionalEncoding(num_hidden)
        self.layers = nn.Sequential()
        for i in range(num_layers):
            self.layers.add_module(f'layer{i}',TransformerEncoderBlock(
                num_hidden, ffn_num_hidden, num_heads, dropout, use_bias
            ))
    
    def forward(self, X, valid_lens):
        X = self.pos_encoding(self.embedding(X)) * math.sqrt(self.num_hidden)
        for layer in self.layers:
            X = layer(X, valid_lens)
        return X
        

In [48]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, num_hidden, num_heads, ffn_num_hidden, i, dropout=0):
        super().__init__()
        self.i = i
        self.attention1 = MultiHeadAttention(num_hidden, num_heads, dropout)
        self.addnorm1 = AddNorm(num_hidden)
        self.attention2 = MultiHeadAttention(num_hidden, num_heads, dropout)
        self.addnorm2 = AddNorm(num_hidden)
        self.ffn = PositionWiseFFN(ffn_num_hidden, num_hidden)
        self.addnorm3 = AddNorm(num_hidden)

    def forward(self, X, state):
        enc_output, enc_valid_lens = state[0], state[1]
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Shape of dec_valid_lens [batch_size, num_steps]
            dec_valid_lens = torch.arange(1, num_steps + 1, 
                    dtype=torch.float32, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)  
        Y = self.addnorm1(X, X2)
        Y2 = self.attention2(Y, enc_output, enc_output, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

In [49]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, num_hidden, num_heads, ffn_num_hidden, dropout, num_layers):
        super().__init__()
        self.num_hidden = num_hidden
        self.num_layers = num_layers
        self.embedding = Embedding(vocab_size, num_hidden)
        self.pos_encoding = PositionalEncoding(num_hidden)
        self.layers = nn.Sequential()
        for layer in range(num_layers):
            self.layers.add_module(f"layer{layer}", TransformerDecoderBlock(num_hidden, num_heads,
                                    ffn_num_hidden, layer, dropout))
        self.dense = nn.LazyLinear(vocab_size)

    def init_state(self, enc_output, enc_valid_lens):
        return [enc_output, enc_valid_lens, [None]*self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hidden))
        for layer in self.layers:
            X, state = layer(X, state)
        return self.dense(X), state        

In [None]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, num_hidden, num_heads, ffn_num_hidden, 
                num_layers, dropout, use_bias=False):
        super().__init__()
        self.encoder = TransformerEncoder(vocab_size, num_hidden, num_heads,
                                        ffn_num_hidden, num_layers, use_bias)
        self.decoder = TransformerDecoder(vocab_size, num_hidden, num_heads,
                                        ffn_num_hidden, dropout, num_layers)
    
    def forward(self, enc_X, dec_X, src_valid_lens):
        enc_output = self.encoder(enc_X, src_valid_lens)
        state = self.decoder.init_state(enc_output, src_valid_lens)
        output, _ = self.decoder(dec_X, state)
        return output

In [51]:
# We use a simple character level tokenization
class Tokenizer():
    def __init__(self, corpus, use_path:bool):
        self.corpus = corpus
        self.use_path = use_path

    def build_vocab(self, add_special_tokens=True):
        ds_text = ""
        if self.use_path:
            with open(f"{self.corpus}", "r") as f:
                ds_text = f.read()
        else:
            ds_text = self.corpus
        vocab = sorted(set(ds_text))
        char_to_idx = {c:i for i,c in enumerate(vocab)}
        if add_special_tokens:
            char_to_idx["<bos>"] = len(vocab) + 2
            char_to_idx["<eos>"] = len(vocab) + 3
        char_to_idx["<unk>"] = len(vocab)
        char_to_idx["<pad>"] = len(vocab) + 1
        idx_to_char = {i:c for c,i in char_to_idx.items()}
        return char_to_idx, idx_to_char

    def tokenize(self, text):
        encoded_text = []
        for c in text:
            if c in char_to_idx:
                encoded_text.append(char_to_idx[c])
            else:
                encoded_text.append(char_to_idx["<unk>"])
        return encoded_text

In [None]:
class TransformerDataLoader:
    def __init__(self, src_texts, tgt_texts, char_to_idx, max_len=20, batch_size=4, shuffle=True):
        self.char_to_idx = char_to_idx
        self.max_len = max_len
        self.batch_size = batch_size
        self.shuffle = shuffle
        
        # Pre-tokenize pairs during init to save time
        self.data = []
        for src, tgt in zip(src_texts, tgt_texts):
            # Tokenize Source
            src_tokens = [self.char_to_idx["<bos>"]]
            for char in src:
                src_tokens.append(self.char_to_idx.get(char, self.char_to_idx["<unk>"]))
            src_tokens.append(self.char_to_idx["<eos>"])
            
            # Tokenize Target
            tgt_tokens = [self.char_to_idx["<bos>"]]
            for char in tgt:
                tgt_tokens.append(self.char_to_idx.get(char, self.char_to_idx["<unk>"]))
            tgt_tokens.append(self.char_to_idx["<eos>"])
            
            self.data.append((src_tokens, tgt_tokens))

    def __len__(self):
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.data)
            
        for i in range(0, len(self.data), self.batch_size):
            batch = self.data[i : i + self.batch_size]
            
            # Prepare batch lists
            encoder_inputs = []
            decoder_inputs = []
            decoder_targets = []
            src_valid_lens = []
            
            for src_seq, tgt_seq in batch:
                # Process encoder input
                # Truncate if needed
                if len(src_seq) > self.max_len:
                    src_seq = src_seq[:self.max_len]
                
                src_valid_lens.append(len(src_seq))
                
                # Pad Source
                pad_len = self.max_len - len(src_seq)
                if pad_len > 0:
                    src_padded = src_seq + [self.char_to_idx["<pad>"]] * pad_len
                else:
                    src_padded = src_seq
                encoder_inputs.append(src_padded)
                
                # Process decoder input/target
                # We need offset by 1 for Teacher Forcing
                # Input:  <bos> A B C ...
                # Target: A B C <eos> ...
                
                if len(tgt_seq) > self.max_len + 1:
                    tgt_seq = tgt_seq[:self.max_len + 1]
                    
                dec_in = tgt_seq[:-1]
                dec_out = tgt_seq[1:]
                
                # Pad Target
                pad_len = self.max_len - len(dec_in)
                if pad_len > 0:
                    padding = [self.char_to_idx["<pad>"]] * pad_len
                    dec_in_padded = dec_in + padding
                    dec_out_padded = dec_out + padding
                else:
                    dec_in_padded = dec_in[:self.max_len]
                    dec_out_padded = dec_out[:self.max_len]
                    
                decoder_inputs.append(dec_in_padded)
                decoder_targets.append(dec_out_padded)
            
            # Yield batch of tensors
            yield (
                torch.tensor(encoder_inputs),
                torch.tensor(decoder_inputs), 
                torch.tensor(decoder_targets),
                torch.tensor(src_valid_lens),
            )

In [None]:
# for gpt models (decoder only)
class DecoderDataLoader:
    def __init__(self, texts, char_to_idx, max_len=50, batch_size=4, shuffle=True):
        self.char_to_idx = char_to_idx
        self.max_len = max_len
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.data = []

        for text in texts:
            tokens = [self.char_to_idx["<bos>"]]
            for char in text:
                tokens.append(self.char_to_idx.get(char, self.char_to_idx["<unk>"]))
            tokens.append(self.char_to_idx["<eos>"])
            if len(tokens) > self.max_len + 1:
                tokens = tokens[:self.max_len +1]
            self.data.append(tokens)

    def __len__(self):
        return (len(self.data) + self.batch_size- 1) // self.batch_size
    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.data)
        for i in range(0, len(self.data), self.batch_size):
            batch_tokens = self.data[i:i+self.batch_size]
            X_batch = []
            Y_batch = []
            valid_lens = []
            for seq in batch_tokens:
                # Create X input and Y target (offset by one)
                x_seq = seq[:-1]
                y_seq = seq[1:]
                valid_lens.append(len(x_seq))
                pad_len = self.max_len - len(x_seq)
                if pad_len > 0:
                    x_padded = x_seq + [self.char_to_idx["<pad>"]] * pad_len
                    y_padded = y_seq + [self.char_to_idx["<pad>"]] * pad_len
                else:
                    x_padded = x_seq[:self.max_len]
                    y_padded = y_seq[:self.max_len]
            X_batch.append(x_padded)
            Y_batch.append(y_padded)
            X = torch.tensor(X_batch)
            Y = torch.tensor(Y_batch)
            valid_lens = torch.tensor(valid_lens)
            yield X, Y, valid_lens

In [54]:
def load_data(path, num_examples=None):
    with open(path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    if num_examples:
        lines = lines[:num_examples]

    src_texts, tgt_texts = [], []
    for line in lines:
        parts = line.split("\t")
        if len(parts) >= 2:
            src_texts.append(parts[0].strip())
            tgt_texts.append(parts[1].strip())
    return src_texts, tgt_texts

In [55]:
src_texts, tgt_texts = load_data("fra.txt")
print(f"Loaded {len(src_texts)} pairs")
print(f"Example: {src_texts[610]} -> {tgt_texts[610]}")

Loaded 240521 pairs
Example: Good job! -> Bien joué !


In [56]:
full_text = src_texts + tgt_texts
vocab_text = " ".join(full_text)
tokenizer = Tokenizer(vocab_text, False)
char_to_idx, idx_to_char = tokenizer.build_vocab()
print(f"Vocabulary built")
print(f"Vocabulary size: {len(char_to_idx)}")

Vocabulary built
Vocabulary size: 122


In [57]:
def train_seq2seq(model, data_loader, lr, num_epochs, vocab, device):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss(reduction="none") 
    #we'll mask pad tokens so they don't influence the loss

    model.train()
    for epoch in range(num_epochs):
        start_time = time.time()
        total_loss = 0
        num_batches = 0

        for batch in data_loader:
            enc_X, dec_X, dec_Y, src_valid_lens, tgt_valid_lens = [x.to(device) for x in batch]
            optimizer.zero_grad()
            Y_hat = model(enc_X, dec_X, src_valid_lens, tgt_valid_lens)
            l = loss_fn(Y_hat.reshape(-1, len(vocab)), dec_Y.reshape(-1))

            pad_idx = char_to_idx["<pad>"]
            mask = (dec_Y.reshape(-1)) != float(pad_idx)
            l = (l*mask).sum() / mask.sum()
            l.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += l.item()
            num_batches += 1
        print(f"Epoch {epoch + 1}, Loss: {total_loss/num_batches:.4f}, Time: {time.time()-start_time:.2f}")


In [None]:
train_loader = TransformerDataLoader(src_texts, tgt_texts, char_to_idx, max_len=30, batch_size=256)

model = Transformer(
    vocab_size=len(char_to_idx),
    num_hidden=128,
    num_heads=4,
    ffn_num_hidden=256,
    num_layers=2,
    dropout=0.1
)
# Training
train_seq2seq(model, train_loader, lr=0.001, num_epochs=20, vocab=char_to_idx, device=device)

Epoch 1, Loss: 1.6710, Time: 40.41
Epoch 2, Loss: 1.2204, Time: 40.41
Epoch 3, Loss: 1.0374, Time: 40.92
Epoch 4, Loss: 0.9345, Time: 41.45
Epoch 5, Loss: 0.8688, Time: 41.30
Epoch 6, Loss: 0.8272, Time: 41.43
Epoch 7, Loss: 0.7917, Time: 41.42
Epoch 8, Loss: 0.7663, Time: 41.48
Epoch 9, Loss: 0.7459, Time: 41.37
Epoch 10, Loss: 0.7286, Time: 41.42
Epoch 11, Loss: 0.7150, Time: 40.97
Epoch 12, Loss: 0.7012, Time: 41.93
Epoch 13, Loss: 0.6898, Time: 41.59
Epoch 14, Loss: 0.6790, Time: 41.73
Epoch 15, Loss: 0.6709, Time: 41.87
Epoch 16, Loss: 0.6626, Time: 41.52
Epoch 17, Loss: 0.6559, Time: 41.21
Epoch 18, Loss: 0.6490, Time: 41.30
Epoch 19, Loss: 0.6431, Time: 41.14
Epoch 20, Loss: 0.6368, Time: 41.34


In [None]:
def predict_seq2seq(model, src_sentence, char_to_idx, idx_to_char, max_len, device):
    model.eval()

    src_tokens = [char_to_idx["<bos>"]]
    for char in src_sentence:
        src_tokens.append(char_to_idx.get(char, char_to_idx["<unk>"]))
    src_tokens.append(char_to_idx["<eos>"])
    
    enc_X = torch.tensor(src_tokens, dtype=torch.long, device=device).unsqueeze(0)
    src_valid_lens = torch.tensor([len(src_tokens)], device=device)

    enc_output = model.encoder(enc_X, src_valid_lens)
    state = model.decoder.init_state(enc_output, src_valid_lens)
    dec_input = torch.tensor([char_to_idx["<bos>"]], dtype=torch.long, device=device).unsqueeze(0)
    output_seq = []
    
    for _ in range(max_len):
        Y, state = model.decoder(dec_input, state)
        
        prediction = Y.argmax(dim=2)[:, -1].item() #we take the last "predicted" word
        
        # Stop if <eos> is generated
        if prediction == char_to_idx["<eos>"]:
            break
            
        output_seq.append(prediction)
        # Prepare input for next step
        dec_input = torch.tensor([[prediction]], dtype=torch.long, device=device)
        
    decoded_sentence = "".join([idx_to_char[idx] for idx in output_seq])
    return decoded_sentence

In [89]:
# Example
test_sentence = "He is sleeping" 
translation = predict_seq2seq(model, test_sentence, char_to_idx, idx_to_char, 20, device)
print(f"Input: {test_sentence}")
print(f"Prediction: {translation}")

Input: He is sleeping
Prediction: Il a en dorrrrrrrrai
