In [13]:
import torch 
import torch.nn as nn
import math
from torch.utils.data import Dataset, DataLoader
from torchtext.data.metrics import bleu_score
import torch.nn.functional as F

In [3]:
class TransformerEncoder(nn.Module):
    def __init__(self, embedding_dim, num_heads=8, ff_hidden_dim=2048, dropout=0.1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attention = nn.MultiheadAttention(embedding_dim, num_heads, dropout, batch_first=True)        
        self.norm_1 = nn.LayerNorm(embedding_dim)
        self.norm_2 = nn.LayerNorm(embedding_dim)
        self.ff = nn.Sequential(
            nn.Linear(embedding_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embedding_dim)
        )
        self.dropout = nn.Dropout(dropout)        
    def forward(self, x, mask=None):
        attn_out, attn_weights = self.attention(x, x, x, key_padding_mask=mask)        
        x = self.norm_1(x + self.dropout(attn_out))
        ff_out = self.ff(x)
        x = self.norm_2(x + self.dropout(ff_out))

        return x, attn_weights


In [4]:
class TransformerDecoder(nn.Module):
    def __init__(self, embedding_dim, num_heads=8, ff_hidden_dim=2048, dropout=0.1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.masked_attention = nn.MultiheadAttention(embedding_dim, num_heads, dropout, batch_first=True)
        self.norm_1 = nn.LayerNorm(embedding_dim)
        self.cross_attention = nn.MultiheadAttention(embedding_dim, num_heads, dropout, batch_first=True)
        self.norm_2 = nn.LayerNorm(embedding_dim)
        self.ff = nn.Sequential(
            nn.Linear(embedding_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embedding_dim)
        )
        self.norm_3 = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(dropout)

        self.masked_attention_weights_ = None
        self.cross_attention_weights_ = None        
    def forward(self, x, enc_outs, attn_mask=None, mask_1=None, mask_2=None):                
        mask_attn_out, mask_attn_weights = self.masked_attention(x,x,x, attn_mask=attn_mask, key_padding_mask=mask_1)                
        x = self.norm_1(x + self.dropout(mask_attn_out))                

        cross_attn_out, cross_attn_weights = self.cross_attention(query=x, key=enc_outs, value=enc_outs, key_padding_mask=mask_2)                
        x = self.norm_2(x + self.dropout(cross_attn_out))        
        
        ff_out = self.ff(x)
        x = self.norm_3(x + self.dropout(ff_out))
        return x, mask_attn_weights, cross_attn_weights

In [5]:
class PositionalEncoder(nn.Module):
    def __init__(self, embedding_dim, max_len=5000, *args, **kwargs):
        super().__init__(*args, **kwargs)
        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

In [6]:
class Transformer(nn.Module):
    def __init__(self, encoder, decoder, pos_encoder, embedding_dim, 
                 input_vocab_size, input_pad_idx,
                 output_vocab_size, output_pad_idx,
                 bos_idx, eos_idx, 
                *args, **kwargs):
        
        super().__init__(*args, **kwargs)

        self.input_embedding = nn.Embedding(input_vocab_size, embedding_dim, padding_idx=input_pad_idx)
        self.output_embedding = nn.Embedding(output_vocab_size, embedding_dim, padding_idx=output_pad_idx)        
        self.output_pad_idx = output_pad_idx
        self.encoder = encoder
        self.pos_encoder = pos_encoder
        self.decoder = decoder
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx        
        
        self.out = nn.Linear(embedding_dim, output_vocab_size)   


    def forward(self, src, src_mask, tgt, tgt_mask):
        src_emb = self.pos_encoder(self.input_embedding(src))
        enc_outs, _ = self.encoder(src_emb, mask=src_mask)
        
        tgt_emb = self.pos_encoder(self.output_embedding(tgt[:, :-1])) #shift right
        seq_len = tgt_emb.size(1)            
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, device=src.device) * float('-inf'), diagonal=1)                

        dec_outs, _, cross_attention_weights = self.decoder(tgt_emb, enc_outs, attn_mask=attn_mask, mask_1=tgt_mask[:, :-1], mask_2=src_mask)
        logits = self.out(dec_outs)
        return logits, cross_attention_weights
    @torch.no_grad()
    def predict(self, src, src_mask,  max_len=20):
        batch_size = src.size(0)
        src_emb = self.pos_encoder(self.input_embedding(src))
        enc_outs,_ = self.encoder(src_emb, mask=src_mask)

        y_pred = torch.full((batch_size, 1), self.bos_idx, dtype=torch.long, device=src.device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)

        for t in range(max_len):
            tgt_emb = self.pos_encoder(self.output_embedding(y_pred))            
            seq_len = tgt_emb.size(1)            
            attn_mask = torch.triu(torch.ones(seq_len, seq_len, device=src.device) * float('-inf'), diagonal=1)                

            dec_outs, _, cross_attention_weights = self.decoder(tgt_emb, enc_outs,attn_mask=attn_mask, mask_1=None, mask_2=src_mask)
            logits = self.out(dec_outs[:, -1])
            next_token = logits.argmax(-1).unsqueeze(1)
            y_pred = torch.cat([y_pred, next_token], dim=1) 
            finished |= (next_token.squeeze(1) == self.eos_idx)
            if finished.all():
                break    
        return y_pred, cross_attention_weights

In [7]:
toy_corpus = [
    ("привет", "hello"),
    ("как дела", "how are you"),
    ("спасибо", "thank you"),
    ("пока", "bye"),
    ("доброе утро", "good morning"),
    ("добрый вечер", "good evening"),
    ("я люблю тебя", "i love you"),
    ("что это", "what is this"),
    ("где ты", "where are you"),
    ("хорошо", "ok"),
]
from collections import Counter

def build_vocabs(sentences, min_freq=1):
    counter = Counter()
    for s in sentences:
        counter.update(s.split())
    word2idx = {'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3}
    idx2word = {0: '<pad>', 1: '<bos>', 2: '<eos>', 3: '<unk>'}
    for word, freq in counter.items():
        if freq >= min_freq and word not in word2idx:
            idx = len(word2idx)
            word2idx[word] = idx
            idx2word[idx] = word
    return word2idx, idx2word


src_sentences = [s for s, _ in toy_corpus]
tr_sentences = [s for _, s in toy_corpus]
src_w2i, src_i2w = build_vocabs(src_sentences)
tr_w2i, tr_i2w = build_vocabs(tr_sentences)

def encode(sentence, w2i, add_bos=False, add_eos=False):
    tokens = []
    if add_bos: tokens.append(w2i['<bos>'])
    
    for w in sentence.split():
        tokens.append(w2i.get(w, w2i['<unk>']))        
    
    if add_eos: tokens.append(w2i['<eos>'])
    return tokens

def pad_batch(seqs, pad_idx):
    max_len = max(len(s) for s in seqs)
    return [s + [pad_idx] * (max_len-len(s)) for s in seqs]

class ToyCorpusDataset(Dataset):
    def __init__(self, corpus, src_w2i, tr_w2i):
        super().__init__()
        self.corpus = corpus
        self.src_vocab = src_w2i
        self.tr_vocab = tr_w2i
    def __len__(self):
        return len(self.corpus)
    def __getitem__(self, index):
        src, tr = self.corpus[index]
        src_ids = encode(src, self.src_vocab)
        tr_ids = encode(tr, self.tr_vocab, True, True)

        return torch.tensor(src_ids), torch.tensor(tr_ids)

def collate_fn(batch):
    src_batch, tr_batch = zip(*batch)
        
    src_lens = torch.tensor([len(s) for s in src_batch])
    tr_lens = torch.tensor([len(s) for s in tr_batch])
        
    src_padded = torch.tensor(pad_batch([s.tolist() for s in src_batch], src_w2i['<pad>']), dtype=torch.long)
    tr_padded = torch.tensor(pad_batch([s.tolist() for s in tr_batch], tr_w2i['<pad>']), dtype=torch.long)
        
    src_mask = (src_padded == src_w2i['<pad>'])
    tr_mask = (tr_padded == tr_w2i['<pad>'])
    
    return src_padded, src_mask, src_lens, tr_padded, tr_mask, tr_lens


dataset = ToyCorpusDataset(toy_corpus, src_w2i, tr_w2i)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [20]:
encoder = TransformerEncoder(embedding_dim=8, num_heads=2, ff_hidden_dim=8)
decoder = TransformerDecoder(embedding_dim=8, num_heads=2, ff_hidden_dim=8)
pos_encoder = PositionalEncoder(embedding_dim=8)
transformer = Transformer(
    encoder=encoder, decoder=decoder, pos_encoder=pos_encoder, embedding_dim=8, input_vocab_size=len(src_w2i),
    input_pad_idx=src_w2i['<pad>'], output_vocab_size=len(tr_w2i), output_pad_idx=tr_w2i['<pad>'],
    bos_idx=tr_w2i['<bos>'], eos_idx=tr_w2i['<eos>']
    )
optimizer = torch.optim.Adam(transformer.parameters())
loss_fn = nn.CrossEntropyLoss(ignore_index=tr_w2i['<pad>'])

In [21]:
device = "cuda" if torch.cuda.is_available() else "cpu"
transformer.to(device)
best_bleu=0.0
early_stopping_rounds, early_stopping_rounds_counter = 1000, 0

for epoch in range(1, 201):    
    transformer.train()
    total_loss = 0        
    for src, src_mask, src_lens, tr, tr_mask, tr_lens in dataloader:        
        src, tr = src.to(device), tr.to(device)
        
        optimizer.zero_grad()

        logits, _ = transformer(src, src_mask, tr, tr_mask)

                
        loss = loss_fn(
            logits.reshape(-1, logits.size(2)),
            tr[:, :logits.size(1)].reshape(-1),
        )
        
        loss.backward()        
        optimizer.step()
        total_loss += loss.item()    

    transformer.eval()
    candidate_corpus, references_corpus = [], []
    with torch.no_grad():
        for src, src_mask, src_lens, tr, tr_mask, tr_lens in dataloader:
            src, tr = src.to(device), tr.to(device)
            preds, _ = transformer.predict(src, src_mask)                    

            for i in range(len(src)):
                hyp = [tr_i2w[idx.item()] for idx in preds[i]
                       if idx.item() not in  [tr_w2i['<pad>'], tr_w2i['<bos>'], tr_w2i['<eos>']]
                ]
                ref = [[tr_i2w[idx.item()] for idx in tr[i] 
                        if idx.item() not in [tr_w2i['<pad>'], tr_w2i['<bos>'], tr_w2i['<eos>']]]]
                
                candidate_corpus.append(hyp)
                references_corpus.append(ref[0])            
    bleu = bleu_score(candidate_corpus, references_corpus)

    if bleu > best_bleu:
        best_bleu = bleu
        early_stopping_rounds_counter = 0
    else:
        early_stopping_rounds_counter += 1
        if early_stopping_rounds_counter >= early_stopping_rounds:
            print(f'Early stopping on epoch {epoch}, best BLEU {best_bleu:.4f}')
            break    



    if epoch % 50 == 0:
        print(f'Epoch {epoch}. Loss: {total_loss/len(dataloader):.4f}, Bleu: {bleu:.4f}')        
        for i in range(2):            
            print('ref:', references_corpus[-i])        
            print('hyp:', candidate_corpus[-i])



Epoch 50. Loss: 1.2917, Bleu: 0.0000
ref: ['good', 'evening']
hyp: []
ref: ['what', 'is', 'this']
hyp: ['bye', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you']
Epoch 100. Loss: 0.6337, Bleu: 0.0000
ref: ['what', 'is', 'this']
hyp: ['bye', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you']
ref: ['hello']
hyp: []
Epoch 150. Loss: 0.3448, Bleu: 0.0000
ref: ['good', 'morning']
hyp: ['bye', 'you', 'you', 'you', 'you']
ref: ['ok']
hyp: []
Epoch 200. Loss: 0.1681, Bleu: 0.0000
ref: ['what', 'is', 'this']
hyp: ['bye', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you']
ref: ['bye']
hyp: []


Of course we need more data