# Машинный перевод

Ноутбук содержит код трансформера, обучаемого решать задачу машинного перевода с немецкого языка на английский.

In [None]:
!pip install sacrebleu==2.3.1

In [None]:
import datetime as dt
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import torchtext
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm
import wandb

# Report the library versions this notebook was run with (for reproducibility).
print('torch:', torch.__version__, 
      '\ntorchtext:', torchtext.__version__, 
      '\nwandb:', wandb.__version__)

torch: 1.13.0+cpu 
torchtext: 0.14.0 
wandb: 0.13.10


In [None]:
# Parallel training corpus: English (target) and German (source) sides.
TRAIN_EN_PATH = '/kaggle/input/bhw2-dataset/data/train.de-en.en'
TRAIN_DE_PATH = '/kaggle/input/bhw2-dataset/data/train.de-en.de'
# Parallel validation corpus.
VAL_EN_PATH = '/kaggle/input/bhw2-dataset/data/val.de-en.en'
VAL_DE_PATH = '/kaggle/input/bhw2-dataset/data/val.de-en.de'
# Test set: only the German side is an input; the English file lives in the
# working directory because it is produced by this notebook.
TEST_EN_PATH = '/kaggle/working/test1.de-en.en'
TEST_DE_PATH = '/kaggle/input/bhw2-dataset/data/test1.de-en.de'

# Directory that receives checkpoints, predictions and post-processed output.
WORKING_DIR = '/kaggle/working/'

In [None]:
!wandb login

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
def dataset_iterator(path):
    """Yield the whitespace-separated tokens of every line of a text file.

    Args:
        path: Path to a plain-text file with one sentence per line.

    Yields:
        A list of string tokens for each line, in file order. A blank line
        yields an empty list.
    """
    # The corpus files include non-ASCII (German) text, so decode explicitly
    # as UTF-8 instead of relying on the platform's locale default.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # split() without arguments already ignores leading/trailing
            # whitespace, including the newline.
            yield line.split()


# Reserved tokens shared by both vocabularies: unknown word, padding,
# beginning-of-sentence and end-of-sentence.
specials = ['<unk>', '<pad>', '<bos>', '<eos>']

# Both vocabularies are built from the training split only, with the same
# size cap (the cap is passed to torchtext as ``max_tokens``).
_vocab_options = {'specials': specials, 'max_tokens': 15000}

vocab_en = build_vocab_from_iterator(dataset_iterator(TRAIN_EN_PATH), **_vocab_options)
vocab_de = build_vocab_from_iterator(dataset_iterator(TRAIN_DE_PATH), **_vocab_options)

In [None]:
# Number of worker processes used by each DataLoader.
dataloader_num_workers = 2
# Sentence pairs per batch (also the chunk size used by make_predictions).
batch_size = 128
# Width of the padded token tensors, counting the <bos>/<eos> markers; also
# the size of the positional-encoding table and the decoding step limit.
max_length = 128

def texts_to_tensor(path, vocab):
    """Encode every line of a text file as a fixed-width row of token indices.

    Each sentence is wrapped as ``<bos> w1 ... wn <eos>`` and right-padded
    with ``<pad>`` up to the module-level ``max_length``. Words missing from
    ``vocab`` are mapped to ``<unk>``.

    Sentences with more than ``max_length - 2`` words are truncated so that
    the wrapped sequence still fits in a row (previously such a sentence
    raised a shape-mismatch error on assignment).

    Args:
        path: Text file with one sentence per line (read via
            ``dataset_iterator``).
        vocab: Mapping supporting ``word in vocab`` and ``vocab[word]`` that
            contains the four special tokens.

    Returns:
        A ``torch.long`` tensor of shape ``(num_lines, max_length)``.
    """
    # Look the special indices up once instead of once per sentence.
    unk_idx, pad_idx = vocab['<unk>'], vocab['<pad>']
    bos_idx, eos_idx = vocab['<bos>'], vocab['<eos>']
    # Leave room for <bos> and <eos>.
    max_words = max(max_length - 2, 0)

    tokenized_texts = []
    for words in dataset_iterator(path):
        word_ids = [vocab[word] if word in vocab else unk_idx
                    for word in words[:max_words]]
        tokenized_texts.append([bos_idx] + word_ids + [eos_idx])

    tensor = torch.full((len(tokenized_texts), max_length), pad_idx, dtype=torch.long)
    for i, tokens in enumerate(tokenized_texts):
        tensor[i, :len(tokens)] = torch.tensor(tokens)

    return tensor


# Encode both sides of the training and validation corpora.
train_de = texts_to_tensor(TRAIN_DE_PATH, vocab_de)
train_en = texts_to_tensor(TRAIN_EN_PATH, vocab_en)
val_de = texts_to_tensor(VAL_DE_PATH, vocab_de)
val_en = texts_to_tensor(VAL_EN_PATH, vocab_en)

# Each dataset item is a (German source row, English target row) pair.
train_dataset = TensorDataset(train_de, train_en)
# NOTE: despite its name this dataset holds the *validation* split.
test_dataset = TensorDataset(val_de, val_en)

# Settings shared by both loaders; only the shuffling differs.
_loader_options = dict(batch_size=batch_size,
                       num_workers=dataloader_num_workers,
                       pin_memory=True)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

## My transformer

In [None]:
class PositionalEncoder(nn.Module):
    """Add fixed sinusoidal position features to embeddings, then apply dropout.

    Adapted from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    The feature table has shape ``(1, max_length, embed_dim)`` and is stored
    as the buffer ``pos_features``, so it follows the module across devices
    and is part of ``state_dict`` but is not a trainable parameter.

    Args:
        max_length: Number of positions to precompute.
        embed_dim: Embedding width. Both even and odd values are supported.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, max_length, embed_dim, dropout=0.1):
        super().__init__()

        positions = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float)
                          * (-math.log(10000) / embed_dim)).unsqueeze(0)

        # (max_length, ceil(embed_dim / 2)) matrix of position * frequency.
        arguments = positions * freqs
        pos_features = torch.zeros(max_length, embed_dim)
        pos_features[:, 0::2] = torch.sin(arguments)
        # For an odd embed_dim there is one cosine column fewer than sine
        # columns; trim so the assignment shapes match. For an even embed_dim
        # the slice is a no-op.
        pos_features[:, 1::2] = torch.cos(arguments[:, :embed_dim // 2])
        pos_features = pos_features.unsqueeze(0)
        self.register_buffer('pos_features', pos_features)

        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        """Return ``dropout(inputs + position features)``.

        Args:
            inputs: Batched embeddings; dimension 1 is treated as the
                sequence axis and must not exceed ``max_length``.
        """
        # Take the features of the first `seq_len` positions; the leading
        # singleton dimension broadcasts over the batch.
        pos_encodings = self.pos_features[:, :inputs.shape[1]]
        outputs = inputs + pos_encodings
        return self.dropout(outputs)


class MyTransformer(nn.Module):
    """Sequence-to-sequence model built around ``nn.Transformer``.

    Source and target token indices go through separate embedding tables,
    are scaled by ``sqrt(embed_dim)``, receive positional features from a
    shared ``PositionalEncoder``, pass through a batch-first
    ``nn.Transformer`` and are finally projected to target-vocabulary logits.

    Args:
        src_vocab_size: Number of rows in the source embedding table.
        tgt_vocab_size: Number of rows in the target embedding table and
            width of the output logits.
        max_length: Number of positions precomputed by the positional encoder.
        embed_dim: Embedding / model width (``d_model``).
        fc_dim: Hidden width of the transformer feed-forward sublayers.
        num_heads: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dropout: Dropout probability for the positional encoder and the
            transformer.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, max_length, embed_dim, fc_dim,
                 num_heads, num_encoder_layers, num_decoder_layers, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoder(max_length, embed_dim, dropout)

        transformer_options = dict(
            d_model=embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=fc_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.Transformer(**transformer_options)
        self.classifier = nn.Linear(embed_dim, tgt_vocab_size)

    def _embed(self, embedding, tokens):
        """Embed token indices, scale by sqrt(embed_dim) and add positions."""
        scaled = embedding(tokens) * math.sqrt(self.embed_dim)
        return self.pos_encoder(scaled)

    def forward(self, src, tgt, src_mask, tgt_mask, src_pad_mask=None,
                tgt_pad_mask=None, memory_pad_mask=None):
        """Return logits over the target vocabulary for each target position.

        The mask arguments are forwarded unchanged to ``nn.Transformer`` as
        ``src_mask``, ``tgt_mask``, ``src_key_padding_mask``,
        ``tgt_key_padding_mask`` and ``memory_key_padding_mask``.
        """
        # Source first, then target: keeps the dropout calls in the same order.
        src_features = self._embed(self.src_embedding, src)
        tgt_features = self._embed(self.tgt_embedding, tgt)

        decoded = self.transformer(
            src_features,
            tgt_features,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            memory_key_padding_mask=memory_pad_mask,
        )
        return self.classifier(decoded)


def generate_square_subsequent_mask(sz):
    """Build an additive causal attention mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` (position ``i`` may attend to
    ``j``) and ``-inf`` when ``j > i`` (future positions are blocked).

    Args:
        sz: Sequence length.

    Returns:
        A float tensor on the default device.
    """
    blocked = torch.full((sz, sz), float('-inf'))
    # Keep -inf strictly above the diagonal; triu zeroes everything else.
    return torch.triu(blocked, diagonal=1)

In [None]:
def save_model(path, num_epochs, model, optimizer, scheduler):
    """Write a training checkpoint to ``path`` using ``torch.save``.

    The checkpoint is a dict with the keys ``num_epochs``,
    ``model_state_dict``, ``optimizer_state_dict`` and
    ``scheduler_state_dict`` (``None`` when no scheduler is given).

    Args:
        path: Destination file.
        num_epochs: Number of completed epochs to record.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler, or ``None``.
    """
    checkpoint = {'num_epochs': num_epochs}
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    if scheduler is None:
        checkpoint['scheduler_state_dict'] = None
    else:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    torch.save(checkpoint, path)


def load_model(path, device, model, optimizer=None, scheduler=None):
    """Restore a checkpoint written by ``save_model``.

    Args:
        path: Checkpoint file.
        device: Device the tensors are mapped to and the model is moved to.
        model: Module that receives ``model_state_dict``.
        optimizer: If given, receives ``optimizer_state_dict``.
        scheduler: If given and the checkpoint holds a scheduler state,
            receives ``scheduler_state_dict``.

    Returns:
        The ``num_epochs`` value stored in the checkpoint.
    """
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # machine (and vice versa); without it torch.load tries to restore the
    # tensors onto the device they were saved from.
    data = torch.load(path, map_location=device)
    model.load_state_dict(data['model_state_dict'])
    model.to(device)
    if optimizer is not None:
        optimizer.load_state_dict(data['optimizer_state_dict'])
    # save_model stores None when training ran without a scheduler; skip the
    # restore in that case instead of calling load_state_dict(None).
    scheduler_state = data.get('scheduler_state_dict')
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)
    return data['num_epochs']


@torch.no_grad()
def inference(model, texts, device):
    """Greedily translate a batch of German sentences into English.

    Uses the module-level ``vocab_de`` / ``vocab_en`` vocabularies and
    ``max_length``. Puts ``model`` into eval mode.

    Args:
        model: Translation model called as ``model(src, tgt, src_mask,
            tgt_mask, src_pad_mask, tgt_pad_mask, memory_pad_mask)``.
        texts: List of whitespace-tokenized source sentences.
        device: Device on which decoding runs.

    Returns:
        A list with one space-joined translation per input sentence. Each
        translation ends before the first generated ``<eos>`` and is limited
        to (number of spaces in the source sentence) + 6 tokens.
    """
    model.eval()
    if not texts:
        return []

    src_unk, src_pad = vocab_de['<unk>'], vocab_de['<pad>']
    src_bos, src_eos = vocab_de['<bos>'], vocab_de['<eos>']
    tgt_bos, tgt_eos, tgt_pad = vocab_en['<bos>'], vocab_en['<eos>'], vocab_en['<pad>']

    # transform src texts to tensor
    max_words = max(max_length - 2, 0)  # leave room for <bos> and <eos>
    tokenized_texts = []
    for text in texts:
        tokens = [vocab_de[word] if word in vocab_de else src_unk
                  for word in text.split()[:max_words]]
        tokenized_texts.append([src_bos] + tokens + [src_eos])

    # Pad only up to the longest sentence of the batch. Building the tensor
    # at its final width avoids the `src[:, :-0]` slice, which emptied the
    # whole batch whenever a sentence left no padding at all.
    src_width = max(len(tokens) for tokens in tokenized_texts)
    src = torch.full((len(tokenized_texts), src_width), src_pad, dtype=torch.long)
    for i, tokens in enumerate(tokenized_texts):
        src[i, :len(tokens)] = torch.tensor(tokens)
    src = src.to(device)
    src_padding_mask = (src == src_pad)  # constant across decoding steps

    # make inference
    tgt = torch.full((len(texts), 1), tgt_bos, dtype=torch.long, device=device)
    finished = torch.zeros(len(texts), dtype=torch.bool, device=device)
    for _ in range(max_length - 1):
        src_mask = None
        tgt_mask = generate_square_subsequent_mask(tgt.shape[1]).to(device)
        tgt_padding_mask = (tgt == tgt_pad)
        pred = model(src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        new_tokens = torch.argmax(pred[:, -1], dim=1)
        tgt = torch.cat((tgt, new_tokens.unsqueeze(-1)), dim=1)

        # Everything after a row's first <eos> is discarded below, so once
        # every row has produced one, further steps cannot change the output.
        finished |= (new_tokens == tgt_eos)
        if finished.all():
            break

    # transform tgt tensor to texts
    res = []
    # tolist() yields plain ints, which is what index() and lookup_tokens
    # operate on (rather than 0-d tensors).
    for text, token_ids in zip(texts, tgt[:, 1:].tolist()):
        if tgt_eos in token_ids:
            token_ids = token_ids[:token_ids.index(tgt_eos)]
        token_ids = token_ids[:text.count(' ') + 6]   # cut translations to adequate length!
        res.append(' '.join(vocab_en.lookup_tokens(token_ids)))
    return res


def compute_val_bleu(val_predictions_path):
    """Score a predictions file against ``VAL_EN_PATH`` with the sacrebleu CLI.

    Runs ``sacrebleu VAL_EN_PATH --tokenize none --width 2 -b`` with the
    predictions file fed on standard input.

    Args:
        val_predictions_path: File with one predicted translation per line.

    Returns:
        The score printed by sacrebleu, as a float.

    Raises:
        RuntimeError: If sacrebleu exits with a non-zero status.
        ValueError: If its output cannot be parsed as a float.
    """
    # Function-scope import keeps this block self-contained.
    import subprocess

    # An argument list (no shell) is safe for paths containing spaces or
    # shell metacharacters, and makes failures visible via the return code.
    cmd = ['sacrebleu', VAL_EN_PATH, '--tokenize', 'none', '--width', '2', '-b']
    with open(val_predictions_path, 'rb') as predictions:
        result = subprocess.run(cmd, stdin=predictions, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(
            f'sacrebleu exited with status {result.returncode}: {result.stderr.strip()}'
        )
    return float(result.stdout)


@torch.no_grad()
def make_predictions(src_path, model, device, tgt_path):
    """Translate every line of ``src_path`` and write the result to ``tgt_path``.

    Sentences are translated by ``inference`` in chunks of the module-level
    ``batch_size``. The output file gets one translation per line, in input
    order. Puts ``model`` into eval mode.

    Args:
        src_path: Text file with one source sentence per line.
        model: Translation model passed to ``inference``.
        device: Device passed to ``inference``.
        tgt_path: Output file (overwritten).
    """
    model.eval()
    # Decode/encode explicitly as UTF-8 so non-ASCII text does not depend on
    # the platform's locale default.
    with open(src_path, 'r', encoding='utf-8') as src_file:
        src = [line.strip() for line in src_file]

    predictions = []
    for start in range(0, len(src), batch_size):
        predictions.extend(inference(model, src[start:start + batch_size], device))

    with open(tgt_path, 'w', encoding='utf-8') as tgt_file:
        tgt_file.writelines(pred + '\n' for pred in predictions)


@torch.no_grad()
def test(model, loader, device):
    """Evaluate the model: teacher-forced loss on ``loader`` plus validation BLEU.

    Args:
        model: Translation model; put into eval mode.
        loader: DataLoader yielding ``(src, target)`` index tensors padded
            with the ``<pad>`` index of ``vocab_de`` / ``vocab_en``.
        device: Device on which the model runs.

    Returns:
        ``(loss, bleu)``. ``loss`` is the sum over batches of
        (per-batch mean cross-entropy * batch size), divided by
        ``len(loader.dataset)``. ``bleu`` is computed by translating
        ``VAL_DE_PATH`` (independently of ``loader``), writing the result to
        ``WORKING_DIR + 'val_preds.txt'`` and scoring it with
        ``compute_val_bleu``.
    """
    test_loss = 0.0
    model.eval()
    src_pad, tgt_pad = vocab_de['<pad>'], vocab_en['<pad>']

    with torch.inference_mode():
        for src, target in loader:
            # Trim each side to the longest sequence of the batch. Counting
            # non-pad tokens (instead of slicing with `[:, :-min_pad]`) keeps
            # the batch intact when a row has no padding at all: `[:, :-0]`
            # would return an empty tensor.
            src_length = int((src != src_pad).sum(dim=-1).max())
            src = src[:, :src_length].to(device)
            target_length = int((target != tgt_pad).sum(dim=-1).max())
            target = target[:, :target_length].to(device)

            # Teacher forcing: feed tokens [0, n-1), predict tokens [1, n).
            tgt_input = target[:, :-1]
            src_mask = None
            tgt_mask = generate_square_subsequent_mask(tgt_input.shape[1]).to(device)
            src_padding_mask = (src == src_pad)
            tgt_padding_mask = (tgt_input == tgt_pad)
            pred = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            loss = F.cross_entropy(pred.reshape(-1, pred.shape[-1]),
                                   target[:, 1:].reshape(-1), ignore_index=tgt_pad)

            test_loss += loss.item() * target.shape[0]

    val_predictions_path = WORKING_DIR + 'val_preds.txt'
    make_predictions(VAL_DE_PATH, model, device, val_predictions_path)
    val_bleu = compute_val_bleu(val_predictions_path)

    return test_loss / len(loader.dataset), val_bleu


def train_epoch(model, optimizer, train_loader, device):
    """Train the model for one pass over ``train_loader`` with teacher forcing.

    Args:
        model: Translation model; put into train mode.
        optimizer: Optimizer stepped once per batch.
        train_loader: DataLoader yielding ``(src, target)`` index tensors
            padded with the ``<pad>`` index of ``vocab_de`` / ``vocab_en``.
        device: Device on which the model runs.

    Returns:
        The sum over batches of (per-batch mean cross-entropy * batch size),
        divided by ``len(train_loader.dataset)``.
    """
    train_loss = 0.0
    model.train()
    src_pad, tgt_pad = vocab_de['<pad>'], vocab_en['<pad>']

    for src, target in train_loader:
        # Trim each side to the longest sequence of the batch. Counting
        # non-pad tokens (instead of slicing with `[:, :-min_pad]`) keeps the
        # batch intact when a row has no padding at all: `[:, :-0]` would
        # return an empty tensor.
        src_length = int((src != src_pad).sum(dim=-1).max())
        src = src[:, :src_length].to(device)
        target_length = int((target != tgt_pad).sum(dim=-1).max())
        target = target[:, :target_length].to(device)

        optimizer.zero_grad()

        # Teacher forcing: feed tokens [0, n-1), predict tokens [1, n).
        tgt_input = target[:, :-1]
        src_mask = None
        tgt_mask = generate_square_subsequent_mask(tgt_input.shape[1]).to(device)
        src_padding_mask = (src == src_pad)
        tgt_padding_mask = (tgt_input == tgt_pad)
        pred = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        loss = F.cross_entropy(pred.reshape(-1, pred.shape[-1]),
                               target[:, 1:].reshape(-1),
                               ignore_index=tgt_pad)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * target.shape[0]

    return train_loss / len(train_loader.dataset)


def train_with_wandb(model, optimizer, n_epochs, train_loader, val_loader, device,
                     wandb_init_data, scheduler=None, verbose=False):
    """Run the full training loop inside a Weights & Biases run.

    Every epoch trains on ``train_loader``, evaluates with ``test`` on
    ``val_loader``, logs ``loss/train``, ``loss/val`` and ``bleu/val`` to
    wandb and steps ``scheduler`` if one is given. A checkpoint is written to
    ``WORKING_DIR`` after every third epoch and after the last one.

    Args:
        model: Model to train.
        optimizer: Optimizer used by ``train_epoch``.
        n_epochs: Number of epochs to run.
        train_loader: Training DataLoader.
        val_loader: Validation DataLoader.
        device: Device on which the model runs.
        wandb_init_data: Keyword arguments forwarded to ``wandb.init``.
        scheduler: Optional LR scheduler, stepped once per epoch.
        verbose: If true, print the losses after each epoch.

    Returns:
        Three lists with the per-epoch train loss, validation loss and
        validation BLEU.
    """
    train_history, val_history, bleu_history = [], [], []
    last_epoch = n_epochs - 1

    with wandb.init(**wandb_init_data):
        for epoch in range(n_epochs):
            epoch_started = dt.datetime.now()

            train_loss = train_epoch(model, optimizer, train_loader, device)
            print('Train epoch finished:', dt.datetime.now() - epoch_started)

            val_loss, val_bleu = test(model, val_loader, device)
            # Measured from the start of the epoch, so it includes training.
            print('Val epoch finished:', dt.datetime.now() - epoch_started)

            metrics = {"loss/train": train_loss, "loss/val": val_loss, 'bleu/val': val_bleu}
            wandb.log(metrics)

            train_history.append(train_loss)
            val_history.append(val_loss)
            bleu_history.append(val_bleu)

            if verbose:
                print(f"Epoch {epoch}\n train loss: {train_loss}\n val loss: {val_loss}\n")

            if scheduler is not None:
                scheduler.step()

            completed = epoch + 1
            if completed % 3 == 0 or epoch == last_epoch:
                checkpoint_path = WORKING_DIR + f'{completed}epochs.pt'
                save_model(checkpoint_path, completed, model, optimizer, scheduler)

    return train_history, val_history, bleu_history

In [None]:
# Fix the RNG seed before the model is built so the weight initialization
# below is repeatable.
torch.manual_seed(0)

# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Model hyperparameters.
src_vocab_size = len(vocab_de)
tgt_vocab_size = len(vocab_en)
embed_dim = 512
fc_dim = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3

model = MyTransformer(src_vocab_size, tgt_vocab_size, max_length, embed_dim, fc_dim,
                      num_heads, num_encoder_layers, num_decoder_layers)

# Re-initialize every parameter with more than one dimension (weight
# matrices, embedding tables) with Xavier-uniform; 1-D parameters keep
# their default initialization.
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
scheduler = None

num_epochs = 15

# Keyword arguments for wandb.init; 'config' records the run's settings.
wandb_init_data = {
    'project': 'bhw2',
    'name': 'run',
    'config': {
        'model': 'nn.Transformer',
        'optimizer': optimizer,
        'scheduler': scheduler,

        'dataset': 'bhw2',
        'num_epochs': num_epochs,
        'train_loader_batch_size': batch_size,
        'dataloader_num_workers': dataloader_num_workers,
        # NOTE(review): `_ih` is presumably IPython's input-history list, so
        # this stores the source of the current cell; it only exists when
        # the code runs inside IPython/Jupyter -- confirm.
        'script': _ih[-1]
    }
}

# Total number of model parameters.
print(sum(param.numel() for param in model.parameters()))
train_with_wandb(model, optimizer, num_epochs, train_loader, val_loader, device,
                 wandb_init_data, scheduler=scheduler, verbose=True)

# Translate the test set with the model as it is after the final epoch.
make_predictions(TEST_DE_PATH, model, device, TEST_EN_PATH)

35679896


Train epoch finished: 0:05:31.598997
Val epoch finished: 0:06:27.303232
Epoch 0
 train loss: 4.610298933411005
 val loss: 3.530614066075602

Train epoch finished: 0:05:33.770618
Val epoch finished: 0:06:29.318607
Epoch 1
 train loss: 3.4328198841500845
 val loss: 2.890442795008723

Train epoch finished: 0:05:34.184909
Val epoch finished: 0:06:29.658893
Epoch 2
 train loss: 2.9325144423973115
 val loss: 2.5375775609006745

Train epoch finished: 0:05:33.551689
Val epoch finished: 0:06:28.957092
Epoch 3
 train loss: 2.629366356388909
 val loss: 2.3520810149504254

Train epoch finished: 0:05:33.692487
Val epoch finished: 0:06:29.131114
Epoch 4
 train loss: 2.429152554070831
 val loss: 2.2385110356986644

Train epoch finished: 0:05:33.771188
Val epoch finished: 0:06:29.178708
Epoch 5
 train loss: 2.2861809334717096
 val loss: 2.1725916944701096

Train epoch finished: 0:05:32.669568
Val epoch finished: 0:06:28.175332
Epoch 6
 train loss: 2.176872739050018
 val loss: 2.120466211988283

Train 

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
bleu/val,▁▄▅▆▇▇▇████████
loss/train,█▅▄▃▃▂▂▂▂▂▁▁▁▁▁
loss/val,█▅▃▃▂▂▁▁▁▁▁▁▁▁▁

0,1
bleu/val,30.74
loss/train,1.70643
loss/val,2.01893


# Loading model

In [None]:
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Architecture settings; these have to match the checkpoint being restored,
# otherwise load_state_dict rejects the stored weights.
src_vocab_size = len(vocab_de)
tgt_vocab_size = len(vocab_en)
embed_dim = 512
fc_dim = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3

model = MyTransformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    max_length=max_length,
    embed_dim=embed_dim,
    fc_dim=fc_dim,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
)
# Only the weights are needed for inference, so no optimizer/scheduler state
# is restored.
optimizer = None
scheduler = None

checkpoint_path = WORKING_DIR + '9epochs.pt'
load_model(checkpoint_path, device, model, optimizer=optimizer, scheduler=scheduler)
make_predictions(TEST_DE_PATH, model, device, TEST_EN_PATH)

# Postprocessing

In [None]:
def postprocess_file(tgt_path, out_path=None):
    """Clean a predictions file and write the result to a new file.

    For every line: remove the special-token strings ``<unk>``, ``<pad>``,
    ``<bos>`` and ``<eos>``, normalize whitespace to single spaces and
    collapse runs of identical consecutive tokens into one token. A line
    with nothing left after cleaning is written as an empty line, so the
    output keeps one line per input line.

    Args:
        tgt_path: Predictions file to read, one translation per line.
        out_path: Destination file. Defaults to
            ``WORKING_DIR + 'postprocessed.txt'``.
    """
    if out_path is None:
        out_path = WORKING_DIR + 'postprocessed.txt'

    with open(tgt_path, 'r', encoding='utf-8') as f:
        preds = f.readlines()

    special_tokens = ('<unk>', '<pad>', '<bos>', '<eos>')

    dedup_preds = []
    for text in preds:
        for special in special_tokens:
            text = text.replace(special, '')
        tokens = text.split()
        # Keep a token only if it differs from its predecessor. Written so
        # that an empty token list is fine (the former `tokens[0]` raised
        # IndexError on lines that were empty after cleaning).
        dedup_tokens = [token for i, token in enumerate(tokens)
                        if i == 0 or token != tokens[i - 1]]
        dedup_preds.append(' '.join(dedup_tokens) + '\n')

    with open(out_path, 'w', encoding='utf-8') as f:
        f.writelines(dedup_preds)


# Clean the test-set translations; this is the same file that
# make_predictions wrote to TEST_EN_PATH.
postprocess_file(WORKING_DIR + 'test1.de-en.en')