# Seq2Seq

In [368]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchtext.data.metrics import bleu_score

In [369]:
toy_corpus = [
    ("привет", "hello"),
    ("как дела", "how are you"),
    ("спасибо", "thank you"),
    ("пока", "bye"),
    ("доброе утро", "good morning"),
    ("добрый вечер", "good evening"),
    ("я люблю тебя", "i love you"),
    ("что это", "what is this"),
    ("где ты", "where are you"),
    ("хорошо", "ok"),
]


In [370]:
from collections import Counter

def build_vocabs(sentences, min_freq=1):
    counter = Counter()
    for s in sentences:
        counter.update(s.split())
    word2idx = {'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3}
    idx2word = {0: '<pad>', 1: '<bos>', 2: '<eos>', 3: '<unk>'}
    for word, freq in counter.items():
        if freq >= min_freq and word not in word2idx:
            idx = len(word2idx)
            word2idx[word] = idx
            idx2word[idx] = word
    return word2idx, idx2word


src_sentences = [s for s, _ in toy_corpus]
tr_sentences = [s for _, s in toy_corpus]
src_w2i, src_i2w = build_vocabs(src_sentences)
tr_w2i, tr_i2w = build_vocabs(tr_sentences)

In [371]:
def encode(sentence, w2i, add_bos=False, add_eos=False):
    tokens = []
    if add_bos: tokens.append(w2i['<bos>'])
    
    for w in sentence.split():
        tokens.append(w2i.get(w, w2i['<unk>']))        
    
    if add_eos: tokens.append(w2i['<eos>'])
    return tokens

def pad_batch(seqs, pad_idx):
    max_len = max(len(s) for s in seqs)
    return [s + [pad_idx] * (max_len-len(s)) for s in seqs]

In [372]:
class ToyCorpusDataset(Dataset):
    def __init__(self, corpus, src_w2i, tr_w2i):
        super().__init__()
        self.corpus = corpus
        self.src_vocab = src_w2i
        self.tr_vocab = tr_w2i
    def __len__(self):
        return len(self.corpus)
    def __getitem__(self, index):
        src, tr = self.corpus[index]
        src_ids = encode(src, self.src_vocab)
        tr_ids = encode(tr, self.tr_vocab, True, True)

        return torch.tensor(src_ids), torch.tensor(tr_ids)

def collate_fn(batch):
    src_batch, tr_batch = zip(*batch)
    src_lens = torch.tensor([len(s) for s in src_batch])
    tr_lens = torch.tensor([len(s) for s in tr_batch])
    src_padded = torch.tensor(pad_batch([s.tolist() for s in src_batch], src_w2i['<pad>']))
    tr_padded = torch.tensor(pad_batch([s.tolist() for s in tr_batch], tr_w2i['<pad>']))

    return src_padded, src_lens, tr_padded, tr_lens

dataset = ToyCorpusDataset(toy_corpus, src_w2i, tr_w2i)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [373]:
class Encoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_layers, vocab_size, pad_idx, *args, **kwargs):
        super().__init__(*args, **kwargs)    
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=pad_idx)        
        self.gru = nn.GRU(embedding_dim, hidden_size, num_layers, batch_first=True)                
        
    def forward(self, x, lens):
        embed = self.embedding(x)       
        packed = nn.utils.rnn.pack_padded_sequence(embed, lengths=lens, batch_first=True, enforce_sorted=False)         
        packed_out, hidden = self.gru(packed)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return output, hidden

In [374]:
class Decoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_layers, vocab_size, pad_idx, *args, **kwargs):        
        super().__init__(*args, **kwargs)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=pad_idx)        
        self.attn_W1 = nn.Linear(hidden_size, hidden_size)
        self.attn_W2 = nn.Linear(hidden_size, hidden_size)
        self.attn_v = nn.Linear(hidden_size, 1, False)

        self.gru = nn.GRU(embedding_dim + hidden_size, hidden_size, num_layers, batch_first=True)        
        self.out = nn.Linear(hidden_size * 2, vocab_size)
    def forward(self, x, hidden, enc_out, src_lens=None):        
        # x: (B, 1)
        embed = self.embedding(x) # (b, 1, e)       

        hidden_exp = hidden[-1].unsqueeze(1).expand(-1, enc_out.size(1), -1)
        score = self.attn_v(torch.tanh(self.attn_W1(hidden_exp) + self.attn_W2(enc_out))).squeeze(2)
        if src_lens is not None: #masking paddings 
            mask = torch.arange(enc_out.size(1), device=enc_out.device).unsqueeze(0) >= src_lens.unsqueeze(1)
            score.masked_fill_(mask, float('-inf'))
        attn_weights = torch.softmax(score, dim=1)
        context = torch.bmm(attn_weights.unsqueeze(1), enc_out)

        gru_input = torch.cat([embed, context], dim=2)                
        output, hidden = self.gru(gru_input, hidden) # (b, 1, h), (num_layers, b, h)                                
        output = torch.cat([output, context], dim=2)   
        logits = self.out(output.squeeze(1)) 
        return logits, hidden    

In [375]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, bos_idx, eos_idx, pad_idx, teacher_forcing_ratio=1., *args, **kwargs):        
        super().__init__(*args, **kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.teacher_forcing_ratio = teacher_forcing_ratio        

    def forward(self, src,src_lens, tr=None, max_len=20):
        batch_size = src.size(0)
        enc_out, h = self.encoder(src, src_lens) # enc_out for attention in future             
        teacher_forcing_ratio = self.teacher_forcing_ratio

        outputs = []
        finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)

        y_prev = torch.full((batch_size,  1), self.bos_idx, dtype=torch.long, device=src.device)        

        for t in range(1, max_len):
            logits, h = self.decoder(y_prev, h, enc_out, src_lens)
            outputs.append(logits.unsqueeze(1))                    

            if self.training and tr is not None:
                teacher_force =  torch.rand(1).item() < teacher_forcing_ratio
                y_prev = tr[:, t].unsqueeze(1) if teacher_force and t < tr.size(1) else logits.argmax(-1).unsqueeze(1)                
            else:
                y_prev = logits.argmax(-1).unsqueeze(1)
                        
            finished |= (y_prev.squeeze(1) == self.eos_idx)
            if finished.all():
                break        
            
        return torch.cat(outputs, dim=1)
    
    @torch.no_grad()
    def translate(self, src, src_lens, max_len=20):
        self.eval()
        batch_size = src.size(0)
        enc_out, h = self.encoder(src, src_lens)

        y_prev = torch.full((batch_size, 1), self.bos_idx, dtype=torch.long, device=src.device)
        outputs = []
        finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)

        for t in range(max_len):
            logits, h = self.decoder(y_prev, h, enc_out, src_lens)
            pred = logits.argmax(-1).unsqueeze(1)
            outputs.append(pred)

            y_prev = pred
            finished |= (pred.squeeze(1) == self.eos_idx)
            if finished.all():
                break
        return torch.cat(outputs, dim=1)

In [379]:
encoder = Encoder(embedding_dim=8, hidden_size=8, num_layers=2, vocab_size=len(src_w2i), pad_idx=src_w2i['<pad>'])
decoder = Decoder(embedding_dim=8, hidden_size=8, num_layers=2, vocab_size=len(tr_w2i), pad_idx=tr_w2i['<pad>'])
seq2seq = Seq2Seq(encoder, decoder, bos_idx=tr_w2i['<bos>'], eos_idx=tr_w2i['<eos>'], pad_idx=tr_w2i['<pad>'], teacher_forcing_ratio=0.8)
loss_fn = nn.CrossEntropyLoss(ignore_index=tr_w2i['<pad>'])
optimizer = torch.optim.Adam(seq2seq.parameters())

In [380]:
device = "cuda" if torch.cuda.is_available() else "cpu"
seq2seq.to(device)
best_bleu = 0.0
early_stopping_rounds, early_stopping_rounds_counter = 1000, 0

for epoch in range(1, 201):    
    seq2seq.train()
    total_loss = 0        
    for src, src_lens, tr, tr_lens in dataloader:
        src, tr = src.to(device), tr.to(device)
        optimizer.zero_grad()

        logits = seq2seq(src, src_lens, tr, max_len=tr.size(1))        
        
        loss = loss_fn(
            logits.reshape(-1, logits.size(2)),
            tr[:, :logits.size(1)].reshape(-1),
        )
        
        loss.backward()        
        optimizer.step()
        total_loss += loss.item()

    seq2seq.teacher_forcing_ratio *= 0.99

    # seq2seq.eval()
    # candidate_corpus, references_corpus = [], []
    # with torch.no_grad():
    #     for src, src_lens, tr, tr_lens in dataloader:
    #         src, tr = src.to(device), tr.to(device)
    #         preds = seq2seq.translate(src, src_lens, max_len=tr.size(1))                        

    #         for i in range(len(src)):
    #             hyp = [tr_i2w[idx.item()] for idx in preds[i]
    #                    if idx.item() not in  [tr_w2i['<pad>'], tr_w2i['<bos>'], tr_w2i['<eos>']]
    #             ]
    #             ref = [[tr_i2w[idx.item()] for idx in tr[i] 
    #                     if idx.item() not in [tr_w2i['<pad>'], tr_w2i['<bos>'], tr_w2i['<eos>']]]]
                
    #             candidate_corpus.append(hyp)
    #             references_corpus.append(ref[0])            
    # bleu = bleu_score(candidate_corpus, references_corpus)

    # if bleu > best_bleu:
    #     best_bleu = bleu
    #     early_stopping_rounds_counter = 0
    # else:
    #     early_stopping_rounds_counter += 1
    #     if early_stopping_rounds_counter >= early_stopping_rounds:
    #         print(f'Early stopping on epoch {epoch}, best BLEU {best_bleu:.4f}')
    #         break    



    if epoch % 50 == 0:
        # print(f'Epoch {epoch}. Loss: {total_loss/len(dataloader):.4f}, Bleu: {bleu:.4f}')
        print(f'Epoch {epoch}. Loss: {total_loss/len(dataloader):.4f}')
        # for i in range(2):            
        #     print('ref:', references_corpus[-i])        
        #     print('hyp:', candidate_corpus[-i])


Epoch 50. Loss: 2.0845
Epoch 100. Loss: 1.2566
Epoch 150. Loss: 0.8559
Epoch 200. Loss: 0.5890


In [381]:
test_src = ["привет как дела", "где ты я люблю тебя"]

src_ids = [torch.tensor(encode(s, src_w2i)) for s in test_src]
src_padded = torch.tensor(pad_batch([s.tolist() for s in src_ids], src_w2i['<pad>']))
src_lens = torch.tensor([len(s) for s in src_ids])

src_padded = src_padded.to(device)
src_lens = src_lens.to(device)

seq2seq.eval()
with torch.no_grad():
    preds = seq2seq.translate(src_padded, src_lens, max_len=10)

for i, pred in enumerate(preds):
    words = [tr_i2w[idx.item()] for idx in pred if idx.item() not in [tr_w2i['<pad>'], tr_w2i['<bos>'], tr_w2i['<eos>']]]
    print(f"Source: {test_src[i]}")
    print(f"Pred : {' '.join(words)}\n")


Source: привет как дела
Pred : where you you you you you you you you

Source: где ты я люблю тебя
Pred : i love you you you you you you you

