In [None]:
# %pip install sacrebleu

In [None]:
# %pip install torchtext==0.18.0
# %pip install torch==2.3.0+cu118 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118



Домашка была вдохновлена следующими туториалами:
    https://medium.com/@monimoyd/step-by-step-machine-translation-using-transformer-and-multi-head-attention-96435675be75
    https://pytorch.org/tutorials/beginner/torchtext_translation_tutorial.html
    

In [None]:
import torchtext
import torch
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
from torchtext.utils import download_from_url, extract_archive
import io

In [None]:
import numpy as np 
import pandas as pd 
import os
import math
from tqdm import tqdm
import sacrebleu 
import random
import time

In [None]:
# Change the kernel's working directory to bhw-2 (presumably where the data files referenced below live — confirm).
# NOTE(review): the path is relative, so re-running this cell without restarting the kernel looks for bhw-2 inside bhw-2.
%cd bhw-2

In [None]:
def simple_tokenizer(text):
    """Split ``text`` on runs of whitespace and return the list of tokens.

    Leading and trailing whitespace (such as the newline that ends a line
    read from a file) never produces empty tokens.
    """
    tokens = text.split()
    return tokens


# German and English share the same whitespace tokenizer.
de_tokenizer = en_tokenizer = simple_tokenizer

In [None]:
# Parallel splits are (German file, English file) pairs; data_process reads
# index 0 as German and index 1 as English.
train_filepaths = ('train.de-en.de', 'train.de-en.en')
val_filepaths = ('val.de-en.de', 'val.de-en.en')
# The test split is a single German-only file, given as a plain string rather than a tuple.
test_filepaths = 'test1.de-en.de'

In [None]:
from torchtext.vocab import build_vocab_from_iterator

def yield_tokens(filepath, tokenizer):
    """Lazily yield the tokenized form of every line in a UTF-8 text file.

    Args:
        filepath: Path of the text file to read.
        tokenizer: Callable applied to each raw line (trailing newline included).

    Yields:
        Whatever ``tokenizer`` returns for each line, in file order.
    """
    with io.open(filepath, encoding="utf8") as handle:
        yield from map(tokenizer, handle)

# German (source) vocabulary, built from the German training file only.
# min_freq=8 drops tokens seen fewer than 8 times; the default index set
# below makes any token that is not in the vocabulary resolve to '<unk>'.
de_vocab = build_vocab_from_iterator(
    yield_tokens(train_filepaths[0], de_tokenizer),
    specials=['<unk>', '<pad>', '<bos>', '<eos>'],
    min_freq=8
)
de_vocab.set_default_index(de_vocab['<unk>'])

# English (target) vocabulary, built the same way from the English training file.
# NOTE(review): later cells apply special-token ids looked up in en_vocab to
# German batches too; that presumably works because both vocabularies list
# the same specials in the same order — confirm.
en_vocab = build_vocab_from_iterator(
    yield_tokens(train_filepaths[1], en_tokenizer),
    specials=['<unk>', '<pad>', '<bos>', '<eos>'],
    min_freq=8
)
en_vocab.set_default_index(en_vocab['<unk>'])

In [None]:
def _encode_line(raw_line, tokenizer, vocab):
    """Tokenize one raw line and map its tokens to a 1-D ``torch.long`` tensor of ids."""
    return torch.tensor([vocab[token] for token in tokenizer(raw_line)], dtype=torch.long)


def data_process(filepaths):
    """Convert raw text file(s) into tensors of vocabulary indices.

    Args:
        filepaths: Either a single path (``str``, ``bytes`` or path-like) to a
            German-only file, or a ``(german_path, english_path)`` pair for a
            parallel corpus.

    Returns:
        For a single path: a list with one German tensor per line.
        For a pair: a list of ``(de_tensor, en_tensor)`` tuples, one per line
        pair. Tensors are 1-D ``torch.long``; no <bos>/<eos> ids are added here.

    Raises:
        ValueError: if ``filepaths`` is a sequence that does not hold exactly
            two paths (raised by the tuple unpacking).
    """
    # Decide by type, not by len(): a two-character path string such as "de"
    # must not be mistaken for a (german, english) pair.
    if isinstance(filepaths, (str, bytes)) or hasattr(filepaths, '__fspath__'):
        with io.open(filepaths, encoding="utf8") as de_file:
            return [_encode_line(raw_de, de_tokenizer, de_vocab) for raw_de in de_file]

    de_path, en_path = filepaths
    # Context managers guarantee both files are closed, even if encoding fails.
    with io.open(de_path, encoding="utf8") as de_file, \
            io.open(en_path, encoding="utf8") as en_file:
        # zip stops at the shorter file, so surplus lines in the longer one are ignored.
        return [
            (_encode_line(raw_de, de_tokenizer, de_vocab),
             _encode_line(raw_en, en_tokenizer, en_vocab))
            for raw_de, raw_en in zip(de_file, en_file)
        ]
        

# Numericalize every split once up front; each result is an in-memory list.
train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
# test_filepaths is a single path, so this list holds German tensors only (no English side).
test_data = data_process(test_filepaths)

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 128
# Special-token ids, looked up in the English vocabulary.
# NOTE(review): these ids are also compared against German (source) tensors
# further down — confirm de_vocab assigns the same ids to its specials.
PAD_IDX = en_vocab['<pad>']
BOS_IDX = en_vocab['<bos>']
EOS_IDX = en_vocab['<eos>']


from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def generate_batch(data_batch):
    """Collate ``(de_tensor, en_tensor)`` pairs into two padded batch tensors.

    Every sequence is wrapped in its own vocabulary's <bos>/<eos> ids and
    right-padded with that vocabulary's <pad> id up to the longest sequence
    in the batch.

    Returns:
        ``(de_batch, en_batch)``, each of shape (batch_size, seq_len).
    """
    def wrap(sequence, vocab):
        # <bos> + token ids + <eos> as one 1-D tensor
        bos = torch.tensor([vocab['<bos>']])
        eos = torch.tensor([vocab['<eos>']])
        return torch.cat([bos, sequence, eos], dim=0)

    de_sequences, en_sequences = [], []
    for de_item, en_item in data_batch:
        de_sequences.append(wrap(de_item, de_vocab))
        en_sequences.append(wrap(en_item, en_vocab))

    de_padded = pad_sequence(de_sequences, padding_value=de_vocab['<pad>'])
    en_padded = pad_sequence(en_sequences, padding_value=en_vocab['<pad>'])

    # pad_sequence returns (seq_len, batch_size); swap to (batch_size, seq_len)
    return de_padded.transpose(0, 1), en_padded.transpose(0, 1)



# Training batches are shuffled; validation batches keep the file order.
# Both use generate_batch to add <bos>/<eos> and pad each batch.
train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)


# Seq2seq transformer

In [None]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first embeddings.

    A table for ``max_len`` positions is precomputed and registered as the
    ``pe`` buffer with shape (1, max_len, d_model).

    Args:
        d_model: Embedding size. Odd sizes are supported.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Number of positions precomputed up front. Longer inputs are
            still accepted; their encodings are computed on the fly.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Buffer name and shape are unchanged so previously saved state dicts still load.
        self.register_buffer('pe', self._build_table(max_len, d_model).unsqueeze(0))

    @staticmethod
    def _build_table(length, d_model):
        """Return the (length, d_model) table: sine on even columns, cosine on odd ones."""
        table = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term
        table[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine columns.
        table[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        return table

    def forward(self, x):
        """Return ``dropout(x + encodings)`` for ``x`` of shape (batch, seq_len, d_model)."""
        seq_len = x.size(1)
        if seq_len <= self.pe.size(1):
            encodings = self.pe[:, :seq_len]
        else:
            # Longer than the precomputed table: build the encodings for this
            # length instead of failing with a shape-mismatch error.
            encodings = self._build_table(seq_len, self.pe.size(2)).unsqueeze(0)
            encodings = encodings.to(device=x.device, dtype=self.pe.dtype)
        return self.dropout(x + encodings)

def generate_square_subsequent_mask(sz):
    """Build the (sz, sz) additive causal attention mask.

    Entries strictly above the main diagonal are ``-inf`` (future positions
    are blocked); the diagonal and everything below it are ``0.0``.
    """
    blocked = torch.full((sz, sz), float('-inf'))
    return torch.triu(blocked, diagonal=1)

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder translation model wrapped around ``nn.Transformer``.

    Source and target tokens have separate embedding tables, share one
    positional encoder, and the decoder output is projected to target
    vocabulary logits by ``fc_out``. All tensors are batch-first.
    """

    def __init__(self, input_dim: int, output_dim: int,
                 emb_dim: int = 512,
                 nhead: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1):
        super().__init__()
        self.emb_dim = emb_dim

        # Token embeddings (one table per language) and the shared positional encoder.
        self.src_tok_emb = nn.Embedding(input_dim, emb_dim)
        self.tgt_tok_emb = nn.Embedding(output_dim, emb_dim)
        self.positional_encoding = PositionalEncoding(emb_dim, dropout)

        transformer_config = dict(
            d_model=emb_dim,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.Transformer(**transformer_config)

        # Projects decoder states to one logit per target-vocabulary entry.
        self.fc_out = nn.Linear(emb_dim, output_dim)

    def _embed(self, embedding, tokens):
        """Embed ``tokens``, scale by sqrt(emb_dim) and add positional information."""
        scaled = embedding(tokens) * math.sqrt(self.emb_dim)
        return self.positional_encoding(scaled)

    def forward(self, src, tgt,
                src_mask=None, tgt_mask=None, memory_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Return target-vocabulary logits for every target position.

        src: (batch_size, src_seq_len)
        tgt: (batch_size, tgt_seq_len)

        All mask arguments are forwarded unchanged to ``nn.Transformer``.
        """
        decoded = self.transformer(
            self._embed(self.src_tok_emb, src),
            self._embed(self.tgt_tok_emb, tgt),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.fc_out(decoded)



# Vocabulary sizes determine the embedding tables and the output projection.
input_dim = len(de_vocab)  # German (source) vocabulary size
output_dim = len(en_vocab)  # English (target) vocabulary size

# Model hyperparameters; dropout here (0.2) overrides the class default of 0.1.
emb_dim = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
dropout = 0.2

# Build the model and move it to the selected device.
model = Seq2SeqTransformer(input_dim, output_dim,
                           emb_dim=emb_dim,
                           nhead=nhead,
                           num_encoder_layers=num_encoder_layers,
                           num_decoder_layers=num_decoder_layers,
                           dim_feedforward=dim_feedforward,
                           dropout=dropout).to(device)


In [None]:
def initialize_weights(m):
    """Xavier-uniform initialise a module's ``weight`` if it is a matrix.

    Meant to be passed to ``model.apply``. A module is left untouched when it
    has no ``weight`` attribute, when its ``weight`` is ``None`` (for example
    a normalisation layer created without affine parameters), or when the
    weight has fewer than two dimensions.
    """
    weight = getattr(m, 'weight', None)
    if weight is not None and weight.dim() > 1:
        # nn.init works in place without recording gradients, so no `.data` is needed.
        nn.init.xavier_uniform_(weight)
        
# Run initialize_weights on the model's modules; the trailing ';' keeps the
# notebook from echoing the returned model.
model.apply(initialize_weights);

# Training

In [None]:
def translate_sentence_transformer(model, src_tensor, trg_vocab, max_len=100):
    """Greedily decode a translation for a single source sentence.

    Args:
        model: Model exposing ``emb_dim``, ``src_tok_emb``, ``tgt_tok_emb``,
            ``positional_encoding``, ``transformer`` and ``fc_out``.
        src_tensor: Source token ids, either 1-D (src_len) or (1, src_len).
            Only one sentence is supported because each step calls ``.item()``.
        trg_vocab: Target vocabulary supporting ``['<bos>']`` / ``['<eos>']`` lookup.
        max_len: Maximum number of tokens generated after <bos>.

    Returns:
        List of target ids starting with <bos>. It ends with <eos> unless the
        ``max_len`` limit was reached first.
    """
    model.eval()
    src_tensor = src_tensor.to(device)

    if src_tensor.dim() == 1:
        src_tensor = src_tensor.unsqueeze(0)

    # Loop invariants, computed once.
    bos_idx = trg_vocab['<bos>']
    eos_idx = trg_vocab['<eos>']
    src_padding_mask = (src_tensor == PAD_IDX)
    scale = math.sqrt(model.emb_dim)

    trg_indexes = [bos_idx]

    # Decoding never needs gradients; disabling them here keeps the function
    # memory-safe even when the caller forgets to wrap it in no_grad.
    with torch.no_grad():
        src_emb = model.positional_encoding(model.src_tok_emb(src_tensor) * scale)
        memory = model.transformer.encoder(src_emb, src_key_padding_mask=src_padding_mask)

        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
            tgt_mask = generate_square_subsequent_mask(trg_tensor.shape[1]).to(device)
            tgt_emb = model.positional_encoding(model.tgt_tok_emb(trg_tensor) * scale)

            out = model.transformer.decoder(
                tgt_emb, memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=(trg_tensor == PAD_IDX),
                memory_key_padding_mask=src_padding_mask
            )

            # Only the newest position is needed to pick the next token.
            logits = model.fc_out(out[:, -1, :])  # (batch_size, vocab_size)
            pred_token = logits.argmax(1).item()
            trg_indexes.append(pred_token)

            if pred_token == eos_idx:
                break

    return trg_indexes


In [None]:
import sacrebleu



def evaluate_bleu(model, iterator, trg_vocab, max_len=100):
    """Greedy-decode every sentence from ``iterator`` and return corpus BLEU.

    Args:
        model: Translation model passed on to ``translate_sentence_transformer``.
        iterator: Iterable of ``(src, trg)`` batches shaped (batch_size, seq_len).
        trg_vocab: Target vocabulary (``get_itos()``, item lookup, ``in``).
        max_len: Maximum number of tokens decoded per sentence.

    Returns:
        The corpus-level BLEU score reported by sacrebleu, or ``0.0`` when the
        iterator yields no sentences.

    Note:
        References are rebuilt from the target ids in each batch, so any word
        that was mapped to '<unk>' upstream appears as '<unk>' in the reference.
    """
    model.eval()

    # Fetched once: asking the vocabulary for its id-to-token list for every
    # single token is needlessly slow.
    itos = trg_vocab.get_itos()
    eos_index = trg_vocab['<eos>'] if '<eos>' in trg_vocab else -1
    skipped = {trg_vocab['<bos>'], eos_index, PAD_IDX}

    def to_text(indexes):
        # Drop <bos>/<eos>/<pad> and join the remaining tokens with spaces.
        return " ".join(itos[idx] for idx in indexes if idx not in skipped)

    hypotheses = []
    references = []

    with torch.no_grad():
        for src, trg in iterator:
            src, trg = src.to(device), trg.to(device)

            for src_sentence, trg_sentence in zip(src, trg):
                pred_indexes = translate_sentence_transformer(model, src_sentence, trg_vocab, max_len=max_len)
                hypotheses.append(to_text(pred_indexes))
                references.append(to_text(trg_sentence.tolist()).strip())

    if not hypotheses:
        return 0.0

    # sacrebleu expects a list of reference *streams*, each aligned one-to-one
    # with the hypotheses. There is a single reference per sentence, hence one
    # stream holding all of them (not one single-item list per sentence).
    bleu = sacrebleu.corpus_bleu(hypotheses, [references])
    return bleu.score


In [None]:
def create_masks(src, tgt):
    """Build the attention masks for one batch-first ``(src, tgt)`` pair.

    Returns:
        A 4-tuple ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
        ``src_mask`` is always ``None``; ``tgt_mask`` is the square causal mask
        for ``tgt``'s sequence length, placed on ``src``'s device; the two
        padding masks are boolean tensors that are ``True`` where the token
        equals ``PAD_IDX``.
    """
    causal_mask = generate_square_subsequent_mask(tgt.size(1)).to(src.device)
    return None, causal_mask, src.eq(PAD_IDX), tgt.eq(PAD_IDX)


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          criterion: nn.Module,
          clip: float, scheduler):
    """Run one training epoch and return the mean loss per batch.

    Args:
        model: Model called as ``model(src, trg, ...masks)``, returning logits
            of shape (batch_size, tgt_len, vocab_size).
        iterator: Yields batch-first ``(src, trg)`` tensors; must support ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to flattened logits and targets.
        clip: Maximum gradient norm used for clipping.
        scheduler: Learning-rate scheduler, stepped once per batch.
    """
    model.train()
    running_loss = 0

    for src, trg in iterator:
        src = src.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()

        src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = create_masks(src, trg)
        logits = model(src, trg,
                       src_mask=src_mask,
                       tgt_mask=tgt_mask,
                       src_key_padding_mask=src_pad_mask,
                       tgt_key_padding_mask=tgt_pad_mask)

        # The output at position t is scored against the token at t + 1:
        # drop the last prediction and the first target token.
        vocab_size = logits.shape[-1]
        flat_logits = logits[:, :-1, :].reshape(-1, vocab_size)
        flat_targets = trg[:, 1:].reshape(-1)

        batch_loss = criterion(flat_logits, flat_targets)
        batch_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        scheduler.step()  # the schedule advances per batch, not per epoch

        running_loss += batch_loss.item()

    return running_loss / len(iterator)



def epoch_time(start_time: float, end_time: float):
    """Split the span between two timestamps into whole minutes and seconds.

    Returns:
        ``(minutes, seconds)`` as integers; fractional seconds are truncated.
    """
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    seconds = int(elapsed - (minutes * 60))
    return minutes, seconds


In [None]:
def count_parameters(model: nn.Module):
    """Return the total number of scalar entries across all trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Report the trainable-parameter count with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

In [None]:
# Look the padding id up again in the English vocabulary; this rebinds the
# module-level PAD_IDX that was already defined next to BATCH_SIZE.
PAD_IDX = en_vocab.get_stoi()['<pad>']

# Target positions holding the padding id do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [None]:
from torch.optim.lr_scheduler import LambdaLR

# The base learning rate is 1, so the factor returned by lr_lambda (used by
# the LambdaLR scheduler below) is effectively the learning rate itself.
optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-4)
# Number of scheduler steps during which the learning rate ramps up.
warmup_steps = 2000

def lr_lambda(step):
    """Return the learning-rate factor for a scheduler step.

    The factor grows linearly for the first ``warmup_steps`` steps and then
    decays with the inverse square root of the step, all scaled by
    ``emb_dim ** -0.5``. Step 0 is treated as step 1.
    """
    step = step or 1  # step 0 cannot be raised to a negative power
    decay = step ** (-0.5)
    warmup = step * (warmup_steps ** (-1.5))
    return (emb_dim ** -0.5) * min(decay, warmup)
    
# Scheduler that scales the optimizer's base learning rate by lr_lambda(step).
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

In [None]:
import warnings
# Ignore every UserWarning from here on.
# NOTE(review): this is a blanket filter — presumably aimed at PyTorch's
# attention-mask warnings, but it also hides any other UserWarning.
warnings.simplefilter("ignore", UserWarning)


In [None]:
N_EPOCHS = 20
CLIP = 1  # maximum gradient norm passed to train()
# Start below any possible score so the first epoch always writes a checkpoint.
best_valid_bleu = float('-inf')

# Per-epoch history, consumed by the plotting cell.
train_losses = []
valid_bleus = []

for epoch in tqdm(range(N_EPOCHS)):

    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP, scheduler)
    valid_bleu = evaluate_bleu(model, valid_iter, en_vocab)

    end_time = time.time()
    train_losses.append(train_loss)
    valid_bleus.append(valid_bleu)

    # Keep the weights of the best-scoring epoch on disk. Previously
    # best_valid_bleu was never used and 'best-model-lr.pt' always received
    # the weights of the final epoch, whatever its score.
    if valid_bleu > best_valid_bleu:
        best_valid_bleu = valid_bleu
        torch.save(model.state_dict(), 'best-model-lr.pt')

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. BLEU: {valid_bleu:.2f}')

# NOTE(review): `model` in memory still holds the weights of the LAST epoch;
# load 'best-model-lr.pt' before inference if the best epoch should be used.


In [None]:
import matplotlib.pyplot as plt


plt.figure(figsize=(10, 5))

# The x-axes are derived from the recorded histories. The earlier hard-coded
# range(1, 14+1) did not match the 20 recorded epochs, which makes plt.plot
# fail because x and y have different lengths.
loss_epochs = range(1, len(train_losses) + 1)
bleu_epochs = range(1, len(valid_bleus) + 1)

# Left panel: training loss.
plt.subplot(1, 2, 1)
plt.plot(loss_epochs, train_losses, marker='o', label="Train Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss per Epoch")
plt.legend()

# Right panel: validation BLEU.
plt.subplot(1, 2, 2)
plt.plot(bleu_epochs, valid_bleus, marker='o', color='green', label="Validation BLEU")
plt.xlabel("Epoch")
plt.ylabel("BLEU Score")
plt.title("Validation BLEU per Epoch")
plt.legend()

plt.tight_layout()
plt.show()

# Inference

In [None]:
def beam_search_translate_sentence(model, src_tensor, trg_vocab, beam_width=5, max_len=100, length_penalty=0.75):
    """Translate one source sentence with beam search.

    Args:
        model: Model exposing ``emb_dim``, ``src_tok_emb``, ``tgt_tok_emb``,
            ``positional_encoding``, ``transformer`` and ``fc_out``.
        src_tensor: Source token ids, either 1-D (src_len) or (1, src_len).
        trg_vocab: Target vocabulary supporting ``['<bos>']`` / ``['<eos>']`` lookup.
        beam_width: Number of hypotheses kept after every step. Values larger
            than the vocabulary size are clamped when expanding a hypothesis.
        max_len: Maximum number of expansion steps.
        length_penalty: Exponent ``p`` in the normalised score
            ``log_prob_sum / len(sequence) ** p``.

    Returns:
        The id sequence (starting with <bos>) with the best normalised score
        among finished hypotheses, or among the live beams if none finished.
    """
    model.eval()
    st = src_tensor.to(device)
    if st.dim() == 1:
        st = st.unsqueeze(0)

    # Loop invariants, computed once.
    bos_idx = trg_vocab['<bos>']
    eos_idx = trg_vocab['<eos>']
    src_padding_mask = (st == PAD_IDX)
    scale = math.sqrt(model.emb_dim)

    beams = [([bos_idx], 0.0)]  # (token ids, summed log-probability)
    fin = []                    # finished hypotheses with normalised scores

    # Search never needs gradients; disabling them here avoids building an
    # autograd graph when the caller has not wrapped the call in no_grad.
    with torch.no_grad():
        emb = model.positional_encoding(model.src_tok_emb(st) * scale)
        mem = model.transformer.encoder(emb, src_key_padding_mask=src_padding_mask)

        for _ in range(max_len):
            new_beams = []
            for seq, scr in beams:
                if seq[-1] == eos_idx:  # hypothesis is complete: score it and stop expanding
                    fin.append((seq, scr / (len(seq) ** length_penalty)))
                    continue

                ts = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
                mask = generate_square_subsequent_mask(ts.shape[1]).to(device)
                te = model.positional_encoding(model.tgt_tok_emb(ts) * scale)

                out = model.transformer.decoder(
                    te, mem,
                    tgt_mask=mask,
                    tgt_key_padding_mask=(ts == PAD_IDX),
                    memory_key_padding_mask=src_padding_mask
                )
                lp = torch.log_softmax(model.fc_out(out[:, -1, :]), dim=-1).squeeze(0)

                # topk fails when k exceeds the vocabulary size, so clamp it.
                k = min(beam_width, lp.size(-1))
                top_lp, top_idx = torch.topk(lp, k, dim=-1)

                for token, token_lp in zip(top_idx.tolist(), top_lp.tolist()):
                    new_beams.append((seq + [token], scr + token_lp))

            # Stop once enough hypotheses finished, or when nothing is left to
            # expand (otherwise the loop would idle until max_len).
            if len(fin) >= beam_width or not new_beams:
                break

            new_beams.sort(key=lambda item: item[1], reverse=True)
            beams = new_beams[:beam_width]

    if not fin:
        # Nothing reached <eos>: fall back to the live beams, length-normalised.
        fin = [(s, scr / (len(s) ** length_penalty)) for s, scr in beams]

    best_seq, _ = max(fin, key=lambda item: item[1])
    return best_seq


In [None]:
def generate_batch_test(data_batch):
    """Collate source-only tensors into one padded batch.

    Each sequence is wrapped in ``BOS_IDX`` / ``EOS_IDX`` and padded with
    ``PAD_IDX`` to the longest sequence in the batch.

    Returns:
        Tensor of shape (seq_len, batch_size) — one sentence per column.
    """
    bos = torch.tensor([BOS_IDX])
    eos = torch.tensor([EOS_IDX])
    wrapped = [torch.cat([bos, de_item, eos]) for de_item in data_batch]
    return pad_sequence(wrapped, padding_value=PAD_IDX)  # (seq_len, batch_size)

# Unshuffled loader over the German-only test set, so predictions keep the file order.
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=generate_batch_test)

In [None]:
predictions = []
model.eval()

# Hoisted out of the loops: the id-to-token list was previously requested
# from the vocabulary again for every single output token.
en_itos = en_vocab.get_itos()
special_ids = {BOS_IDX, EOS_IDX, PAD_IDX}

with torch.no_grad():
    for batch in test_iter:
        batch = batch.to(device)
        # Test batches are (seq_len, batch_size): every column is one sentence.
        for column in range(batch.shape[1]):
            src_sentence = batch[:, column]
            predicted_indices = beam_search_translate_sentence(model, src_sentence, en_vocab)

            # Drop <bos>/<eos>/<pad> and map the remaining ids back to tokens.
            tokens = [en_itos[idx] for idx in predicted_indices if idx not in special_ids]
            predictions.append(" ".join(tokens))


# One translation per line, in the same order as the test file.
with open("test.de-en-final-final-model.en", "w", encoding="utf-8", newline="\n") as f:
    for pred in predictions:
        f.write(pred + "\n")
