In [34]:
import torch
import torch.nn as nn
from torchtext.datasets import Multi30k

data = Multi30k(split='train', language_pair=('de', 'en'))
data_val = Multi30k(split='valid', language_pair=('de', 'en'))

In [36]:
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm

d:\mainD\study\toolsTmp\myAnaconda\Anaconda\envs\codee\python.exe: No module named spacy
d:\mainD\study\toolsTmp\myAnaconda\Anaconda\envs\codee\python.exe: No module named spacy


In [38]:
from helpers import ParallelCorpus

dataset = ParallelCorpus(data,'de', 'en', data_limit = 5000)
dataset_val = ParallelCorpus(data_val, 'de', 'en', dataset.vocab_a, dataset.vocab_b)




In [39]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [40]:
WORD_EMBEDDING = 64

class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, WORD_EMBEDDING)
        self.rnn = nn.GRU(WORD_EMBEDDING, hidden_size // 2 , bidirectional = True)
        
    def forward(self, src, src_len):
        embedded = self.embedding(src)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, src_len)
        packed_outputs, hidden = self.rnn(packed_embedded)         
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs) 
        hidden = torch.cat([hidden[0,:, :], hidden[1,:,:]], dim=1)
        return outputs, hidden.unsqueeze(0)

In [41]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, WORD_EMBEDDING)
        self.rnn = nn.GRU(hidden_size + WORD_EMBEDDING, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)
        
    
    def forward(self, input, hidden, encoder_outputs, src_len):
        
        # input = [batch size] 
        # hidden = [batch size, hidden_size]
        # encoder_outputs = [src len, batch size, hidden_size]
        # mask = [batch size, src len]
      
        embedded = self.embedding(input)
        
        # context = [batch size, hidden_size]
        context = torch.zeros((input.shape[0], self.rnn.input_size - WORD_EMBEDDING)).to(embedded.device)
        
        rnn_input = torch.cat((embedded, context), dim = 1)
        
        rnn_input = rnn_input.unsqueeze(0)
        #rnn_input = [1, batch size, word_embedding + hidden_size]
        
        _, hidden = self.rnn(rnn_input, hidden)
        
        prediction = self.fc(hidden.squeeze(0))
        
        return prediction, hidden

In [42]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = next(enc.parameters()).device
        
    def forward(self, src, src_len, tgt):       
        tgt_paded_len = tgt.shape[0]
        batch_size = tgt.shape[1]
        vocab_tgt_size = self.decoder.vocab_size
        
        #tensor to store decoder outputs
        outputs = torch.zeros(tgt_paded_len, batch_size, vocab_tgt_size).to(self.device)
        
        encoder_outputs, hidden = self.encoder(src, src_len)
        
        #first input to the decoder is the <sos> tokens
        prev_word = tgt[0,:]
                
        for i in range(1, tgt_paded_len):          
            output, hidden = self.decoder(prev_word, hidden, encoder_outputs, src_len)
            outputs[i] = output
               
            # uczenie poprzez teaching forcing
            prev_word = tgt[i] if self.training else output.argmax(1)
            
        return outputs


In [43]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for sample in batch:
        src_batch.append(sample["text_a"])
        tgt_batch.append(sample["text_b"])
    src_batch = pad_sequence(src_batch, padding_value=0)
    tgt_batch = pad_sequence(tgt_batch, padding_value=0)
    lena_batch = torch.tensor([len(sample["text_a"]) for sample in batch], dtype=torch.int64)
    lenb_batch = torch.tensor([len(sample["text_b"]) for sample in batch], dtype=torch.int64)
    idx = torch.argsort(lena_batch, descending=True)
    return src_batch[:,idx].to(device), tgt_batch[:,idx].to(device), lena_batch[idx], lenb_batch[idx]

dataloader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn)
dataloader_val = DataLoader(dataset_val, batch_size=256, collate_fn=collate_fn)

In [44]:
HIDDEN_DIM = 256
EPOCHS = 3

enc = Encoder(len(dataset.vocab_a), HIDDEN_DIM).to(device)
dec = Decoder(len(dataset.vocab_b), HIDDEN_DIM).to(device)
enc_dec = EncoderDecoder(enc, dec).to(device)

optimizer = torch.optim.Adam(enc_dec.parameters())
criterion = nn.CrossEntropyLoss(ignore_index = 0)

enc_dec.train()

for epoch in range(EPOCHS):
    epoch_loss = 0
    for i, batch in enumerate(dataloader):
        
        src, tgt, src_len, tgt_len = batch
        
        outputs = enc_dec(src, src_len, tgt)

        tgt = tgt[1:].view(-1)
        
        outputs = outputs[1:].view(-1, dec.vocab_size)

        loss = criterion(outputs, tgt)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(enc_dec.parameters(), 2.)
        
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_loss += loss   
    print(f'Epoch: {epoch+1:02} | Loss: {epoch_loss / len(dataloader):.3f}')


Epoch: 01 | Loss: 5.085
Epoch: 02 | Loss: 4.025
Epoch: 03 | Loss: 3.613


In [45]:
from helpers import translate
example_idx = 12

src = dataset[example_idx]["text_a"]
translate(src, enc_dec, dataset, device)

tgt =  dataset[example_idx]["text_b"]
print(f'Referencja = {dataset.vocab_b.lookup_tokens(tgt.numpy())}')

Żródło = ['<start>', 'Ein', 'schwarzer', 'Hund', 'und', 'ein', 'gefleckter', 'Hund', 'kämpfen', '.', '<stop>']
Tłumaczenie = ['<start>', 'A', 'dog', 'is', 'playing', 'a', '<unk>', '.', '<stop>']
Referencja = ['<start>', 'A', 'black', 'dog', 'and', 'a', 'spotted', 'dog', 'are', 'fighting', '<stop>']


In [46]:
def evaluate_validation_set(model, dataloader, criterion):
    model.eval()
    loss = 0.0
    for X, y, X_size, y_size in dataloader:
        with torch.no_grad():
            Y = model(X, X_size, y_size)
            y = y[1:].view(-1)
            Y = Y[1:].view(-1, model.decoder.vocab_size)
            loss += criterion(Y, y)
    model.train()
    return loss / len(dataloader)


In [47]:
class DecoderWithAttention(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, WORD_EMBEDDING)
        self.rnn = nn.GRU(hidden_size + WORD_EMBEDDING, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)
        
    
    def forward(self, input, hidden, encoder_outputs, src_len):
        
        # input = [batch size] 
        # hidden = [batch size, hidden_size]
        # encoder_outputs = [src len, batch size, hidden_size]
        # mask = [batch size, src len]
      
        embedded = self.embedding(input)
        
        # context = [batch size, hidden_size]
        attention = torch.bmm(hidden.permute(1,0,2), encoder_outputs.permute(1,2,0)).softmax(2)
        context = torch.bmm(attention, encoder_outputs.permute(1,0,2)).squeeze(1)
        
        rnn_input = torch.cat((embedded, context), dim = 1)
        
        rnn_input = rnn_input.unsqueeze(0)
        #rnn_input = [1, batch size, word_embedding + hidden_size]
        
        _, hidden = self.rnn(rnn_input, hidden)
        
        prediction = self.fc(hidden.squeeze(0))
        
        return prediction, hidden

In [48]:
HIDDEN_DIM = 256
EPOCHS = 3

enc = Encoder(len(dataset.vocab_a), HIDDEN_DIM).to(device)
dec = DecoderWithAttention(len(dataset.vocab_b), HIDDEN_DIM).to(device)
enc_dec = EncoderDecoder(enc, dec).to(device)

optimizer = torch.optim.Adam(enc_dec.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=0)

enc_dec.train()

for epoch in range(EPOCHS):
  epoch_loss = 0
  for i, batch in enumerate(dataloader):
    src, tgt, src_len, tgt_len = batch

    outputs = enc_dec(src, src_len, tgt)

    tgt = tgt[1:].view(-1)

    outputs = outputs[1:].view(-1, dec.vocab_size)

    loss = criterion(outputs, tgt)
    loss.backward()

    torch.nn.utils.clip_grad_norm_(enc_dec.parameters(), 2.)

    optimizer.step()
    optimizer.zero_grad()

    epoch_loss += loss
  print(f'Epoch: {epoch + 1:02} | Loss: {epoch_loss / len(dataloader):.3f}')

Epoch: 01 | Loss: 5.047
Epoch: 02 | Loss: 4.144
Epoch: 03 | Loss: 3.761
