In [44]:
import requests
import re
from tqdm.notebook import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchtext

### Парсим имена с Википедии ###

In [6]:
URL = "https://ru.wikipedia.org/wiki/%D0%A1%D0%BF%D0%B8%D1%81%D0%BE%D0%BA_%D0%B8%D0%BC%D1%91%D0%BD_%D1%81%D0%BB%D0%B0%D0%B2%D1%8F%D0%BD%D1%81%D0%BA%D0%BE%D0%B3%D0%BE_%D0%BF%D1%80%D0%BE%D0%B8%D1%81%D1%85%D0%BE%D0%B6%D0%B4%D0%B5%D0%BD%D0%B8%D1%8F#:~:text=%D0%9E%D1%81%D0%BD%D0%BE%D0%B2%D0%BD%D1%8B%D0%B5%20%D0%B2%D0%B8%D0%B4%D1%8B%20%D1%81%D0%BB%D0%B0%D0%B2%D1%8F%D0%BD%D1%81%D0%BA%D0%B8%D1%85%20%D0%B8%D0%BC%D1%91%D0%BD%3A,)%D1%88%D0%B0%2C%20%D0%9F%D1%83%D1%82%D1%8F%D1%82%D0%B0%20%D0%B8%20%D1%82."
r = requests.get(URL)
page = r.content.decode("utf-8")

In [232]:
names = list(map(lambda x: x[1:-1], re.findall(r'"[А-Я][а-я]+"', page)[21:-2]))
names[:5]

['Акамир', 'Окомир', 'Славяне', 'Верзиты', 'Беотия']

### Датасет ###

In [102]:
class SlavDataset(Dataset):
    def __init__(self, names):
        self.max_len = max(map(len, names)) + 2
        self.names = list(map(self.name_to_list, names))
        self.specials = ['<BOS>', '<EOS>', '<PAD>']
        cnt = dict()
        for name in names:
            for l in name.lower():
                cnt[l] = cnt.get(l, 0) + 1
        
        self.vocab = torchtext.vocab.vocab(cnt, specials=self.specials)
        self.itos = self.vocab.get_itos()


    def name_to_list(self, name):
        lst = ['<BOS>']
        lst.extend(list(name.lower()))
        lst.append('<EOS>')
        lst.extend(['<PAD>']*(self.max_len - len(name) - 2))
        return lst

    def transform_prefix(self, prefix):
        lst = ['<BOS>'] + list(prefix.lower())
        return self.encode(lst)

    def __len__(self):
        return len(self.names)

    def encode(self, name):
        return torch.tensor([self.vocab[l] for l in name])

    def decode(self, seq):
        return ''.join([self.itos[idx] for idx in seq if self.itos[idx] not in self.specials]).capitalize()

    def __getitem__(self, item):
        return self.encode(self.names[item])

In [233]:
dataset = SlavDataset(names)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

### Модель ###

In [72]:
class SlavNet(nn.Module):
    def __init__(self, input_size=34, hidden_size=128):
        super(SlavNet, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        #self.emb = nn.Embedding(len(dataset.vocab), len(dataset.vocab), dataset.vocab['<PAD>'])
        self.RNN = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, input_size)
        
    def forward(self, X): # Скрытое состояние тоже неплохо было бы передавать
        output, h_n = self.RNN(torch.eye(self.input_size)[X])
        y = self.linear(output)
        return y

### Обучение ###

In [234]:
model = SlavNet(len(dataset.vocab), 128)

torch.manual_seed(0)
EPOCHS = 100

loss_function = nn.CrossEntropyLoss(ignore_index=dataset.vocab['<PAD>'])
optim = torch.optim.Adam(model.parameters(), lr=10e-5)

for epoch in range(EPOCHS):
    if epoch % 10 == 0:
        model.train(False)
        with torch.no_grad():
            loss = 0
            for batch in dataloader:
                y_pred = model(batch)
                target = batch[:, 1:]
                target = torch.cat([target, torch.tensor(dataset.vocab['<PAD>']).expand(batch.shape[0], 1)], dim=1)
                loss = loss_function(y_pred.transpose(1, 2), target)
        print(f'epoch: {epoch}\tloss: {loss}')

    model.train(True)
    for batch in dataloader:
        optim.zero_grad()
        
        y_pred = model(batch)
        target = batch[:, 1:]
        target = torch.cat([target, torch.tensor(dataset.vocab['<PAD>']).expand(batch.shape[0], 1)], dim=1)
        loss = loss_function(y_pred.transpose(1, 2), target)
        loss.backward()
        optim.step()


if epoch % 10 == 0:
    model.train(False)
    with torch.no_grad():
        loss = 0
        for batch in dataloader:
            y_pred = model(batch)
            target = batch[:, 1:]
            target = torch.cat([target, torch.tensor(dataset.vocab['<PAD>']).expand(batch.shape[0], 1)], dim=1)
            loss = loss_function(y_pred.transpose(1, 2), target)
    print(f'epoch: {epoch}\tloss: {loss}')

epoch: 0	loss: 3.534287929534912
epoch: 10	loss: 2.72800612449646
epoch: 20	loss: 2.651874303817749
epoch: 30	loss: 2.431600570678711
epoch: 40	loss: 2.7108309268951416
epoch: 50	loss: 2.454420328140259
epoch: 60	loss: 2.096609592437744
epoch: 70	loss: 2.257362127304077
epoch: 80	loss: 1.94295334815979
epoch: 90	loss: 2.517194986343384


### Генерация имен ###

In [238]:
prefix = 'Ко'
torch.manual_seed(0)
code = dataset.transform_prefix(prefix)
with torch.no_grad():
    while code[-1] != dataset.vocab['<EOS>'] and len(code) < 100:
        output = model(code)[-1, :]
        y = torch.argmax(output)
        # y = torch.topk(output, k=2).indices[torch.randint(2, (1, ))]
        code = torch.cat([code, torch.tensor([y])], dim=0)

dataset.decode(code)

'Коромир'