<a href="https://colab.research.google.com/github/gnoejh/ict1022/blob/main/Transformer/12_transformer_code_russian_real.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# Transformer

### 1. Using nn.Transformer for the Full Model

In [26]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embedded sequences.

    The table follows the standard sinusoidal formulation::

        PE[pos, 2k]     = sin(pos / 10000 ** (2k / embed_size))
        PE[pos, 2k + 1] = cos(pos / 10000 ** (2k / embed_size))

    Args:
        embed_size: Size of the last (feature) dimension of the inputs.
            Odd sizes are supported; the final column is then a sine column.
        max_len: Number of positions precomputed at construction time.
    """

    def __init__(self, embed_size, max_len=5000):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        # even_dims holds the even column indices 0, 2, 4, ... (i.e. 2k), so
        # the exponent is even_dims / embed_size -- not doubled again.
        even_dims = torch.arange(0, embed_size, 2, dtype=torch.float32)
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / embed_size))
        angles = positions * inv_freq

        pe = torch.zeros(max_len, embed_size)
        pe[:, 0::2] = torch.sin(angles)
        # For odd embed_size there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, :embed_size // 2])

        # Registered as a buffer: moves with the module but is not a parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding for its first ``x.size(1)`` positions.

        Dimension 1 of ``x`` is treated as the sequence position and the last
        dimension must equal ``embed_size``.
        """
        return x + self.pe[:, :x.size(1)]

class TransformerConfig:
    """Hyperparameters and special-token settings for the transformer demo.

    Every recognised setting may be overridden through a keyword argument;
    anything not supplied falls back to the default listed below. Keyword
    arguments that are not recognised are ignored.
    """
    def __init__(self, **kwargs):
        # (attribute name, default) pairs, in the order they are assigned.
        defaults = (
            ('src_vocab_size', 15),
            ('tgt_vocab_size', 15),
            ('embed_size', 16),
            ('num_heads', 4),
            ('num_layers', 3),
            ('forward_expansion', 4),
            ('dropout', 0.1),
            ('max_length', 100),
            ('num_epochs', 100),
            ('learning_rate', 0.001),
            ('batch_size', 2),
            ('pad_token', '<pad>'),
            ('sos_token', '<sos>'),
            ('eos_token', '<eos>'),
            ('pad_idx', 0),
            ('sos_idx', 1),
            ('eos_idx', 2),
        )
        for name, fallback in defaults:
            setattr(self, name, kwargs.get(name, fallback))

class VocabularyManager:
    """Hold source/target vocabularies and convert sentences to id tensors.

    ``config`` must expose ``pad_token``, ``sos_token`` and ``eos_token``;
    those tokens are looked up in the vocabularies when batches are built.
    """
    def __init__(self, config):
        self.config = config
        self.src_vocab = {}
        self.tgt_vocab = {}
        self.idx_to_src = {}
        self.idx_to_tgt = {}

    def set_predefined_vocab(self, src_vocab, tgt_vocab):
        """Install token->id mappings and derive the inverse id->token maps."""
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.idx_to_src = {v: k for k, v in src_vocab.items()}
        self.idx_to_tgt = {v: k for k, v in tgt_vocab.items()}

    def tokenize(self, sentence, vocab):
        """Split ``sentence`` on whitespace and map each token to its id.

        Tokens missing from ``vocab`` are mapped to the id of the pad token
        (or ``None`` if the pad token itself is not in ``vocab``).
        """
        unknown_id = vocab.get(self.config.pad_token)
        return [vocab.get(token, unknown_id) for token in sentence.split()]

    def detokenize(self, indices, idx_to_token):
        """Join the tokens for ``indices`` with spaces; unmapped ids become '<unk>'."""
        return ' '.join(idx_to_token.get(idx, '<unk>') for idx in indices)

    @staticmethod
    def _pad_rows(rows, pad_id, min_len=0):
        """Right-pad every row with ``pad_id`` to a common rectangular width."""
        width = max([min_len, *(len(row) for row in rows)])
        return [row + [pad_id] * (width - len(row)) for row in rows]

    def prepare_batch(self, src_sentences, tgt_sentences=None, min_src_len=5):
        """Convert sentences into padded id tensors.

        Source rows are ``tokens + <eos>``; target rows are
        ``<sos> + tokens + <eos>``. Rows are right-padded with the pad id to
        the longest row in their batch, so sentences of different lengths can
        be batched together (ragged rows cannot be turned into a tensor).

        Args:
            src_sentences: Iterable of whitespace-separated source sentences.
            tgt_sentences: Optional iterable of target sentences.
            min_src_len: Minimum padded width of the source batch.

        Returns:
            ``(src_tensor, tgt_tensor)``; ``tgt_tensor`` is ``None`` when no
            target sentences were supplied.
        """
        src_eos = self.src_vocab[self.config.eos_token]
        src_rows = [self.tokenize(sent, self.src_vocab) + [src_eos]
                    for sent in src_sentences]
        src_batch = self._pad_rows(
            src_rows, self.src_vocab[self.config.pad_token], min_src_len
        )

        tgt_batch = None
        if tgt_sentences:
            tgt_sos = self.tgt_vocab[self.config.sos_token]
            tgt_eos = self.tgt_vocab[self.config.eos_token]
            tgt_rows = [[tgt_sos] + self.tokenize(sent, self.tgt_vocab) + [tgt_eos]
                        for sent in tgt_sentences]
            tgt_batch = self._pad_rows(
                tgt_rows, self.tgt_vocab[self.config.pad_token]
            )

        return torch.tensor(src_batch), torch.tensor(tgt_batch) if tgt_batch else None

class TransformerModel(nn.Module):
    """Encoder-decoder transformer built on ``nn.Transformer`` (batch-first).

    Token ids are embedded, combined with positional encodings, passed
    through ``nn.Transformer`` and projected to target-vocabulary logits.

    ``config`` must provide ``src_vocab_size``, ``tgt_vocab_size``,
    ``embed_size``, ``num_heads``, ``num_layers``, ``forward_expansion``,
    ``dropout`` and ``max_length``.
    """
    def __init__(self, config):
        super().__init__()
        self.src_embedding = nn.Embedding(config.src_vocab_size, config.embed_size)
        self.tgt_embedding = nn.Embedding(config.tgt_vocab_size, config.embed_size)
        self.positional_encoding = PositionalEncoding(config.embed_size, max_len=config.max_length)
        self.transformer = nn.Transformer(
            d_model=config.embed_size,
            nhead=config.num_heads,
            num_encoder_layers=config.num_layers,
            num_decoder_layers=config.num_layers,
            dim_feedforward=config.forward_expansion * config.embed_size,
            dropout=config.dropout,
            batch_first=True
        )
        self.fc_out = nn.Linear(config.embed_size, config.tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Compute target-vocabulary logits for every decoder position.

        Args:
            src: Source token ids; dimension 1 is the sequence.
            tgt: Decoder input token ids; dimension 1 is the sequence.
            src_mask: Optional key-padding mask for ``src``.
            tgt_mask: Optional key-padding mask for ``tgt``. Despite the
                name, this is a padding mask, not the causal attention
                mask; the causal mask is always built internally.

        Returns:
            Logits with the target vocabulary as the last dimension.
        """
        src_embedded = self.positional_encoding(self.src_embedding(src))
        tgt_embedded = self.positional_encoding(self.tgt_embedding(tgt))

        # Without a causal mask the decoder can attend to later target
        # positions, i.e. to the very tokens it is trained to predict.
        tgt_len = tgt.size(1)
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        transformer_output = self.transformer(
            src_embedded,
            tgt_embedded,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_mask,
            tgt_key_padding_mask=tgt_mask,
            # Padded source positions must also be hidden from cross-attention.
            memory_key_padding_mask=src_mask,
        )
        return self.fc_out(transformer_output)

class TransformerTrainer:
    """Run teacher-forced training of a sequence-to-sequence model.

    ``config`` must provide ``pad_idx``, ``eos_idx``, ``learning_rate`` and
    ``num_epochs``. Target positions equal to ``pad_idx`` are excluded from
    the loss.
    """
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.criterion = nn.CrossEntropyLoss(ignore_index=config.pad_idx)
        self.optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    def train_epoch(self, src, tgt):
        """Perform one optimisation step on the batch and return the loss value."""
        self.model.train()
        self.optimizer.zero_grad()
        # Decoder input drops the last token; the labels drop the first one.
        output = self.model(src, tgt[:, :-1])
        output = output.reshape(-1, output.shape[2])
        tgt_output = tgt[:, 1:].reshape(-1)
        loss = self.criterion(output, tgt_output)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def evaluate(self, src, tgt, vocab_manager):
        """Return the most likely token id per position under teacher forcing.

        ``vocab_manager`` is accepted for interface compatibility and unused.
        The model is left in evaluation mode.
        """
        with torch.no_grad():
            self.model.eval()
            eval_output = self.model(src, tgt[:, :-1])
            _, predicted = torch.max(eval_output, dim=2)
            return predicted

    def train(self, src, tgt, vocab_manager):
        """Train for ``config.num_epochs`` epochs, logging the first example.

        A progress report is printed on every third epoch and on the final
        one. The report shows the first row of ``src`` (without pad/eos ids),
        the teacher-forced prediction for it and the expected target. Ids
        that are missing from the vocabularies are shown as '<unk>'.
        """
        print("\nTraining the model to translate English to Russian:")
        hidden_src_ids = {self.config.pad_idx, self.config.eos_idx}
        final_epoch = self.config.num_epochs - 1
        for epoch in range(self.config.num_epochs):
            loss = self.train_epoch(src, tgt)
            if epoch != final_epoch and epoch % 3 != 0:
                continue
            # Evaluate only when the result is reported.
            predicted = self.evaluate(src, tgt, vocab_manager)
            print(f"\nEpoch [{epoch + 1}/{self.config.num_epochs}], Loss: {loss:.4f}")
            source_sentence = [
                vocab_manager.idx_to_src.get(idx, '<unk>')
                for idx in src[0].tolist()
                if idx not in hidden_src_ids
            ]
            print(f"Source: {' '.join(source_sentence)}")
            pred_sentence = [vocab_manager.idx_to_tgt.get(idx, '<unk>') for idx in predicted[0].tolist()]
            print(f"Predicted translation: {' '.join(pred_sentence)}")
            target_sentence = [vocab_manager.idx_to_tgt.get(idx, '<unk>') for idx in tgt[0, 1:].tolist()]
            print(f"Target translation: {' '.join(target_sentence)}")

class TranslationService:
    """Greedy token-by-token translator wrapping a trained model."""
    def __init__(self, model, vocab_manager, config):
        self.model = model
        self.vocab_manager = vocab_manager
        self.config = config

    def translate_sentence(self, sentence, max_length=None):
        """Translate one sentence with greedy decoding.

        Args:
            sentence: Either a whitespace-separated string or a ready-made
                sequence of source token ids (used as-is).
            max_length: Upper bound on the decoded length including the
                start token; defaults to ``config.max_length``.

        Returns:
            List of target tokens, excluding the start token and everything
            from the first end-of-sentence token onwards.
        """
        max_length = self.config.max_length if max_length is None else max_length
        self.model.eval()

        src_vocab = self.vocab_manager.src_vocab
        tgt_vocab = self.vocab_manager.tgt_vocab

        if not isinstance(sentence, str):
            src_indices = sentence
        else:
            # Unknown words fall back to the pad id.
            src_indices = [
                src_vocab.get(word, src_vocab[self.config.pad_token])
                for word in sentence.split()
            ]
            src_indices.append(src_vocab[self.config.eos_token])
            # Short inputs are right-padded up to five ids.
            shortfall = 5 - len(src_indices)
            if shortfall > 0:
                src_indices.extend([src_vocab[self.config.pad_token]] * shortfall)

        src_tensor = torch.tensor([src_indices])
        tgt_tensor = torch.tensor([[tgt_vocab[self.config.sos_token]]])

        for _step in range(max_length - 1):
            with torch.no_grad():
                logits = self.model(src_tensor, tgt_tensor)
            next_id = logits[0, -1].argmax().item()
            # Append the chosen id to the running decoder input.
            tgt_tensor = torch.cat((tgt_tensor, torch.tensor([[next_id]])), dim=1)
            if next_id == tgt_vocab[self.config.eos_token]:
                break

        translated_tokens = []
        for token_id in tgt_tensor[0, 1:].tolist():
            if token_id == tgt_vocab[self.config.eos_token]:
                break
            translated_tokens.append(self.vocab_manager.idx_to_tgt.get(token_id, '<unk>'))
        return translated_tokens

def main():
    """Run the toy English-to-Russian demo end to end.

    Builds the configuration and the two 50-entry vocabularies, prints the
    two training pairs, prints the untrained model's predictions, trains the
    model and finally translates three test sentences.
    """
    config = TransformerConfig(
        src_vocab_size=50,
        tgt_vocab_size=50,
        embed_size=32,
        num_heads=4,
        num_layers=3,
        forward_expansion=4,
        dropout=0.1,
        max_length=100,
        num_epochs=1000,
        learning_rate=0.001,
        batch_size=2,
        pad_token='<pad>',
        sos_token='<sos>',
        eos_token='<eos>',
        pad_idx=0,
        sos_idx=1,
        eos_idx=2,
    )
    vocab_manager = VocabularyManager(config)

    # Special tokens occupy ids 0-2; ordinary words are numbered from 3 in
    # list order, with matching positions across the two languages.
    special_tokens = {
        config.pad_token: config.pad_idx,
        config.sos_token: config.sos_idx,
        config.eos_token: config.eos_idx,
    }
    english_words = [
        'I', 'you', 'love', 'like', 'books', 'music', 'movies', 'food',
        'we', 'they', 'he', 'she', 'it',
        'read', 'watch', 'eat', 'play', 'write',
        'good', 'bad', 'beautiful', 'interesting', 'boring',
        'coffee', 'tea', 'water', 'cake', 'bread',
        'computer', 'phone', 'car', 'house', 'school',
        'friend', 'family', 'mother', 'father', 'sister',
        'brother', 'dog', 'cat', 'bird', 'fish',
        'and', 'or', 'but', 'because',
    ]
    russian_words = [
        'я', 'ты', 'люблю', 'нравится', 'книги', 'музыка', 'фильмы', 'еда',
        'мы', 'они', 'он', 'она', 'оно',
        'читать', 'смотреть', 'есть', 'играть', 'писать',
        'хороший', 'плохой', 'красивый', 'интересный', 'скучный',
        'кофе', 'чай', 'вода', 'торт', 'хлеб',
        'компьютер', 'телефон', 'машина', 'дом', 'школа',
        'друг', 'семья', 'мать', 'отец', 'сестра',
        'брат', 'собака', 'кошка', 'птица', 'рыба',
        'и', 'или', 'но', 'потому',
    ]
    eng_vocab = {**special_tokens,
                 **{word: idx for idx, word in enumerate(english_words, start=3)}}
    ru_vocab = {**special_tokens,
                **{word: idx for idx, word in enumerate(russian_words, start=3)}}
    vocab_manager.set_predefined_vocab(eng_vocab, ru_vocab)

    pad, sos, eos = config.pad_token, config.sos_token, config.eos_token

    src = torch.tensor([
        [eng_vocab[tok] for tok in ('I', 'love', 'books', eos, pad)],
        [eng_vocab[tok] for tok in ('you', 'like', 'music', eos, pad)],
    ])
    print("Source sentences:")
    for row in src:
        print([vocab_manager.idx_to_src[token.item()] for token in row])

    tgt = torch.tensor([
        [ru_vocab[tok] for tok in (sos, 'я', 'люблю', 'книги', eos)],
        [ru_vocab[tok] for tok in (sos, 'ты', 'нравится', 'музыка', eos)],
    ])
    print("Target sentences:")
    for row in tgt:
        print([vocab_manager.idx_to_tgt[token.item()] for token in row])

    # Forward pass through the untrained model, before any optimisation.
    transformer = TransformerModel(config)
    output = transformer(src, tgt)
    print(f"Transformer Model Output Shape: {output.shape}")
    _, predicted = torch.max(output, dim=2)
    print("Predicted sequences:")
    for row in predicted:
        print([vocab_manager.idx_to_tgt.get(token.item(), '<unknown>') for token in row])

    trainer = TransformerTrainer(transformer, config)
    trainer.train(src, tgt, vocab_manager)

    translation_service = TranslationService(transformer, vocab_manager, config)
    print("\nTesting the model with new sentences:")
    for test_sent in ("I love music", "you like books", "I like movies"):
        translated_tokens = translation_service.translate_sentence(test_sent)
        print("Test Session")
        print(f"Source: {test_sent}")
        print(f"Translation: {' '.join(translated_tokens)}")
        print()

if __name__ == "__main__":
    main()

Source sentences:
['I', 'love', 'books', '<eos>', '<pad>']
['you', 'like', 'music', '<eos>', '<pad>']
Target sentences:
['<sos>', 'я', 'люблю', 'книги', '<eos>']
['<sos>', 'ты', 'нравится', 'музыка', '<eos>']
Transformer Model Output Shape: torch.Size([2, 5, 50])
Predicted sequences:
['оно', 'оно', 'она', 'она', 'брат']
['оно', 'друг', 'оно', 'друг', 'она']

Training the model to translate English to Russian:

Epoch [1/1000], Loss: 4.4071
Source: I love books
Predicted translation: оно оно оно <eos>
Target translation: я люблю книги <eos>

Epoch [4/1000], Loss: 3.4910
Source: I love books
Predicted translation: <eos> <eos> <eos> <eos>
Target translation: я люблю книги <eos>

Epoch [7/1000], Loss: 3.2598
Source: I love books
Predicted translation: <eos> <eos> <eos> <eos>
Target translation: я люблю книги <eos>

Epoch [10/1000], Loss: 3.0644
Source: I love books
Predicted translation: <eos> <eos> <eos> <eos>
Target translation: я люблю книги <eos>

Epoch [13/1000], Loss: 2.8599
Source: I