In [4]:
%%capture
!python -m spacy download en_core_web_md
!python -m spacy download ru_core_news_md
!pip install navec
!pip install slovnet
!wget https://storage.yandexcloud.net/natasha-navec/packs/navec_hudlit_v1_12B_500K_300d_100q.tar

In [3]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from navec import Navec
from slovnet.model.emb import NavecEmbedding
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import os

path = 'navec_hudlit_v1_12B_500K_300d_100q.tar'
navec = Navec.load(path)

In [4]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [6]:
def flat_map(f, xs):
    result = []
    for x in xs:
        result.extend(f(x))
    return result

In [50]:
df = pd.read_csv('rus.txt', sep='\t', header=None)
df.columns = ['eng', 'rus', 'meta']

In [51]:
def prepare_lang(data, nlp_lang: tuple, additional_tags):
    nlp, lang = nlp_lang
    dataset = []
    for sample in tqdm(data, desc=f'Processing {lang} data'):
        dataset.append([x.lemma_.lower() for x in nlp(sample) if  (lang == 'rus' and x.lemma_.lower() in navec) or \
                                                                  (lang == 'eng' and x.lemma_ not in '.:;,-()?!')])
    all_words = flat_map(lambda x: x, dataset)
    all_words_set = set(all_words + ['<pad>', '<unk>'] + additional_tags)
    idx2word = dict(enumerate(all_words_set))
    word2idx = {v: k for k, v in idx2word.items()}
    return dataset, word2idx, idx2word

In [52]:
import spacy

en_nlp = spacy.load('en_core_web_md',  disable=['parser', 'ner', 'textcat'])
ru_nlp = spacy.load('ru_core_news_md', disable=['parser', 'ner', 'textcat'])

In [54]:
eng_dataset, eng_word2idx, eng_idx2word = prepare_lang(df.eng.values, nlp_lang=(en_nlp, 'eng'), additional_tags=['<EOS>', '<SOS>'])
ru_dataset, ru_word2idx, ru_idx2word = prepare_lang(df.rus.values, nlp_lang=(ru_nlp, 'rus'), additional_tags=[])
ru_word2idx = {word: navec.vocab[word] for word in ru_word2idx}
ru_idx2word = {v: k for k, v in ru_word2idx.items()}

Processing eng data: 100%|█████████████| 363386/363386 [11:07<00:00, 544.13it/s]
Processing rus data: 100%|█████████████| 363386/363386 [19:52<00:00, 304.69it/s]


In [55]:
ru_dataset_encoded = []
eng_dataset_encoded = []
for ru_sample, eng_sample in zip(ru_dataset, eng_dataset):
    ru_dataset_encoded.append([
        ru_word2idx[word] for word in ru_sample
    ])
    
    eng_dataset_encoded.append([
        eng_word2idx[word] for word in ['<SOS>', *eng_sample, '<EOS>']
    ])

    
ru_dataset_encoded, eng_dataset_encoded = \
    tuple(zip(*filter(lambda ru_en: len(ru_en[0]) and len(ru_en[1]), zip(ru_dataset_encoded, eng_dataset_encoded))))


In [56]:
class Rus2Eng(nn.Module):
    def __init__(self, out_vocab_size, eng_sos_idx, eng_eos_idx, eng_pad_idx, input_size=300, hidden_size=300, bidirectional_encoder=True):
        super().__init__()
        self.ru_embeds = NavecEmbedding(navec)
        self.eng_embeds = nn.Embedding(out_vocab_size, hidden_size)
        self.ru_embeds.requires_grad = False
        self.encoder = nn.LSTM(input_size, hidden_size, batch_first=False, bidirectional=bidirectional_encoder)
        self.decoder = nn.LSTM(hidden_size, hidden_size, batch_first=False)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2, bias=True),
            nn.LeakyReLU(),
            nn.Linear(hidden_size // 2, out_vocab_size, bias=True)
        )
        self.eng_sos_idx = eng_sos_idx
        self.eng_eos_idx = eng_eos_idx
        self.eng_pad_idx = eng_pad_idx
        self.bidirectional_encoder = bidirectional_encoder
    
    def ru_embed(self, x):
        return self.ru_embeds(x)
    
    def eng_embed(self, x):
        return self.eng_embeds(x)
    
    def forward(self, x, max_steps=-1): # single sentence mode
        _, (h_n, c_n) = self.encoder(x)
        if self.bidirectional_encoder:
            h_n = h_n.mean(dim=0, keepdim=True)
            c_n = c_n.mean(dim=0, keepdim=True)
        token = self.eng_sos_idx
        step = 0
        result = []
        while token != self.eng_eos_idx and (max_steps != -1 and step < max_steps):
            print(eng_idx2word[token])
            inp = torch.tensor([
                [token]
            ]).to(device)
            inp = self.eng_embed(inp)
            out, (h_n, c_n)  = self.decoder(inp, (h_n, c_n))
            out = self.classifier(out)
            token = out.argmax().item()
            result.append(token)
            step += 1
        return result
    
    def train(self, inp, targets, criterion, use_teacher_forcing=False): # embeded input
        seq_len = targets.shape[0]
        out, (h_n, c_n) = self.encoder(inp)
        outs = []
        loss = 0
        if self.bidirectional_encoder:
            h_n = h_n.mean(dim=0, keepdim=True)
            c_n = c_n.mean(dim=0, keepdim=True)
        token = torch.full_like(self.eng_embed(targets[0:1]), self.eng_sos_idx)
        for i in range(1, seq_len):
            out, (h_n, c_n)  = self.decoder(token, (h_n, c_n))
            distribution = self.classifier(out[0])
            mask = targets[i] != self.eng_pad_idx
            # print('=========================')
            # print(list(map(lambda x: eng_idx2word[x.item()], distribution.argmax(dim=1))))
            # print(list(map(lambda x: eng_idx2word[x.item()], targets[i])))
            loss += criterion(distribution[mask], targets[i][mask])
            token = self.eng_embed(targets[i:i+1]) if use_teacher_forcing else out
        return loss

In [57]:
class Ru2EngDataset(Dataset):
    def __init__(self, ru, eng):
        self.ru = list(map(torch.tensor, ru))
        self.eng = list(map(torch.tensor, eng))
        
    def __len__(self):
        return len(self.ru)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.ru[idx], self.eng[idx]

In [182]:
dataset = Ru2EngDataset(ru_dataset_encoded, eng_dataset_encoded)

In [59]:
def pad_collate(batch):
    (xx, yy) = zip(*batch)
    x_lens = [len(x) for x in xx]
    y_lens = [len(y) for y in yy]

    xx_pad = pad_sequence(xx, batch_first=True, padding_value=ru_word2idx['<pad>']).transpose(1, 0)
    yy_pad = pad_sequence(yy, batch_first=True, padding_value=eng_word2idx['<pad>']).transpose(1, 0)

    return xx_pad, yy_pad, x_lens, y_lens

data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, collate_fn=pad_collate)

In [60]:
model = Rus2Eng(len(eng_word2idx), eng_word2idx['<SOS>'], eng_word2idx['<EOS>'], eng_word2idx['<pad>'], hidden_size=512).to(device)

In [61]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=2e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=4)
epochs = range(0, 41)

In [62]:
def mk_train(teacher_forcing=0.5):
    for epoch in epochs:
        bar = tqdm(data_loader)
        i = 0
        for ru_samp, eng_samp, ru_lens, eng_lens in bar:
            optimizer.zero_grad()
            ru_samp = ru_samp.to(device)
            eng_samp = eng_samp.to(device)

            ru_embeded = model.ru_embed(ru_samp)

            ru_packed = pack_padded_sequence(ru_embeded, ru_lens, batch_first=False, enforce_sorted=False)

            loss = model.train(ru_packed, eng_samp, criterion, use_teacher_forcing=np.random.uniform() < teacher_forcing)
            if i % 100 == 0:
                bar.set_description(f"Epoch {epoch}: " + str(loss.item()))
            i += 1
            loss.backward()
            optimizer.step()
        scheduler.step()
        if epoch % 5 == 0:
            torch.save(model.state_dict(), f'checkpoints/ru2en_{epoch}.model')

In [None]:
mk_train(0.5)

Epoch 0: 23.091445922851562: 100%|████████| 11355/11355 [04:17<00:00, 44.06it/s]
Epoch 1: 14.13554859161377: 100%|█████████| 11355/11355 [04:21<00:00, 43.48it/s]
Epoch 2: 18.159461975097656: 100%|████████| 11355/11355 [04:22<00:00, 43.29it/s]
Epoch 3: 9.35681438446045: 100%|██████████| 11355/11355 [04:20<00:00, 43.58it/s]
Epoch 4: 31.027477264404297: 100%|████████| 11355/11355 [04:23<00:00, 43.10it/s]
Epoch 5: 15.785499572753906: 100%|████████| 11355/11355 [04:21<00:00, 43.42it/s]
Epoch 6: 14.352655410766602: 100%|████████| 11355/11355 [04:20<00:00, 43.57it/s]
Epoch 7: 11.317116737365723: 100%|████████| 11355/11355 [04:23<00:00, 43.05it/s]
Epoch 8: 64.5309066772461: 100%|██████████| 11355/11355 [04:20<00:00, 43.60it/s]
Epoch 9: 26.826215744018555: 100%|████████| 11355/11355 [04:23<00:00, 43.10it/s]
Epoch 10: 21.536630630493164: 100%|███████| 11355/11355 [04:10<00:00, 45.25it/s]
Epoch 11: 7.5344977378845215: 100%|███████| 11355/11355 [04:07<00:00, 45.85it/s]
Epoch 12: 15.145207405090332

In [90]:
model.load_state_dict(torch.load('checkpoints/ru2en_40.model'))

<All keys matched successfully>

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.9, step_size=4)
epochs = range(41, 101)
mk_train(0.3)

In [137]:
model.load_state_dict(torch.load('checkpoints/ru2en_100.model'))

<All keys matched successfully>

In [138]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.9, step_size=4)
epochs = range(101, 121)
mk_train(0.0)

Epoch 101: 15.205076217651367: 100%|██████| 11355/11355 [03:52<00:00, 48.78it/s]
Epoch 102: 37.49637985229492: 100%|███████| 11355/11355 [03:43<00:00, 50.72it/s]
Epoch 103: 15.313156127929688: 100%|██████| 11355/11355 [03:50<00:00, 49.35it/s]
Epoch 104: 14.575763702392578: 100%|██████| 11355/11355 [03:50<00:00, 49.35it/s]
Epoch 105: 10.401753425598145: 100%|██████| 11355/11355 [03:50<00:00, 49.20it/s]
Epoch 106: 18.95008659362793: 100%|███████| 11355/11355 [03:47<00:00, 49.90it/s]
Epoch 107: 19.970184326171875: 100%|██████| 11355/11355 [03:48<00:00, 49.74it/s]
Epoch 108: 29.8555965423584: 100%|████████| 11355/11355 [03:45<00:00, 50.26it/s]
Epoch 109: 19.25177574157715: 100%|███████| 11355/11355 [03:48<00:00, 49.73it/s]
Epoch 110: 12.936820983886719: 100%|██████| 11355/11355 [03:46<00:00, 50.23it/s]
Epoch 111: 20.40401268005371: 100%|███████| 11355/11355 [03:46<00:00, 50.17it/s]
Epoch 112: 11.56313705444336: 100%|███████| 11355/11355 [03:53<00:00, 48.59it/s]
Epoch 113: 11.23058414459228

In [139]:
model.load_state_dict(torch.load('checkpoints/ru2en_120.model'))

<All keys matched successfully>

In [140]:
def ru2en_translate(phrase):
    with torch.no_grad():
        lemmatized, _, _ = prepare_lang([phrase], nlp_lang=(ru_nlp, 'rus'), additional_tags=[])
        ru_seq = torch.tensor([[ru_word2idx[word] for word in lemmatized[0]]])
        ru_embeded_seq = model.ru_embed(ru_seq).transpose(1, 0)
        ans = model(ru_embeded_seq, max_steps=100)
        return " ".join([eng_idx2word[x] for x in ans])

In [171]:
model.load_state_dict(torch.load('checkpoints/ru2en_40.model'))

<All keys matched successfully>

In [177]:
ru2en_translate('идти')

Processing rus data: 100%|███████████████████████| 1/1 [00:00<00:00, 347.38it/s]

<SOS>
walk





'walk <EOS>'

In [178]:
ru2en_translate('Смерть')

Processing rus data: 100%|███████████████████████| 1/1 [00:00<00:00, 273.03it/s]

<SOS>
death
"





'death " <EOS>'

In [181]:
ru2en_translate('Что ты делаешь?')

Processing rus data: 100%|███████████████████████| 1/1 [00:00<00:00, 224.02it/s]

<SOS>
you
do
not
have
to
do
it





'you do not have to do it <EOS>'