In [1]:
import re
import typing as t
from collections import defaultdict
from pathlib import Path

import nltk
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import metrics
from torch.utils.data import Dataset, DataLoader, Subset, random_split

In [2]:
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/jovyan/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/jovyan/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/jovyan/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

In [3]:
DATA_DIR = Path("data/")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {DEVICE.upper()} device")

Using CUDA device


In [4]:
def on_cuda(device: str) -> bool:
    return device == "cuda"


def common_train(
        model: nn.Module,
        loss_fn: nn.Module,
        optimizer: optim.Optimizer,
        train_dataloader: DataLoader,
        epochs: int,
        test_dataloader: DataLoader = None,
        verbose: int = 100,
        on_epoch_end: t.Callable[[], None] = None,
        device: str = "cpu",
) -> t.List[float]:
    train_losses = []
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}\n" + "-" * 32)
        train_loss = train_loop(
            train_dataloader,
            model,
            loss_fn,
            optimizer,
            verbose=verbose,
            device=device,
        )
        train_losses.append(train_loss.item())

        if test_dataloader:
            test_loop(test_dataloader, model, loss_fn, device=device)

        if on_epoch_end:
            on_epoch_end()

        print()
        torch.cuda.empty_cache()
    return train_losses


def train_loop(
        dataloader: DataLoader,
        model: nn.Module,
        loss_fn: nn.Module,
        optimizer: optim.Optimizer,
        verbose: int = 100,
        device: str = "cpu",
) -> torch.Tensor:
    model.train()

    size = len(dataloader.dataset)  # noqa
    num_batches = len(dataloader)
    avg_loss = 0

    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_loss += loss
        if batch % verbose == 0:
            print(f"loss: {loss:>7f}  [{batch * len(x):>5d}/{size:>5d}]")

        del x, y, pred, loss
        torch.cuda.empty_cache()

    return avg_loss / num_batches


@torch.no_grad()
def test_loop(
        dataloader: DataLoader,
        model: nn.Module,
        loss_fn: nn.Module,
        device: str = "cpu",
) -> t.Tuple[torch.Tensor, torch.Tensor]:
    model.eval()

    avg_loss, num_batches = 0, len(dataloader)
    correct, total = 0, 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        avg_loss += loss_fn(pred, y)

        y_test = torch.flatten(y)
        y_pred = torch.flatten(pred.argmax(1))
        total += y_test.size(0)
        correct += (y_pred == y_test).sum()  # noqa

        del x, y, pred
        torch.cuda.empty_cache()

    avg_loss /= num_batches
    accuracy = correct / total
    print(f"Test Error: \n"
          f"\tAccuracy: {accuracy:>4f}, Loss: {avg_loss:>8f}")

    return avg_loss, accuracy


def train_test_split(dataset: t.Union[Dataset, t.Sized], train_part: float) -> t.Tuple[Subset, Subset]:
    train_size = round(train_part * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, lengths=(train_size, test_size))
    return train_dataset, test_dataset


@torch.no_grad()
def get_y_test_y_pred(
        model: nn.Module,
        test_dataloader: DataLoader,
        device: str = "cpu",
) -> t.Tuple[torch.Tensor, torch.Tensor]:
    model.eval()

    y_test = []
    y_pred = []
    for x, y in test_dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        y_test.append(y)
        y_pred.append(pred)

        del x
        torch.cuda.empty_cache()

    return torch.flatten(torch.vstack(y_test).detach().cpu()), torch.flatten(torch.vstack(y_pred).detach().cpu())

## 1. –ì–µ–Ω–µ—Ä–∏—Ä–æ–≤–∞–Ω–∏–µ —Ä—É—Å—Å–∫–∏—Ö –∏–º–µ–Ω –ø—Ä–∏ –ø–æ–º–æ—â–∏ RNN

–î–∞—Ç–∞—Å–µ—Ç: https://disk.yandex.ru/i/2yt18jHUgVEoIw

1.1 –ù–∞ –æ—Å–Ω–æ–≤–µ —Ñ–∞–π–ª–∞ name_rus.txt —Å–æ–∑–¥–∞–π—Ç–µ –¥–∞—Ç–∞—Å–µ—Ç.
  * –£—á—Ç–∏—Ç–µ, —á—Ç–æ –∏–º–µ–Ω–∞ –º–æ–≥—É—Ç –∏–º–µ—Ç—å —Ä–∞–∑–ª–∏—á–Ω—É—é –¥–ª–∏–Ω—É
  * –î–æ–±–∞–≤—å—Ç–µ 4 —Å–ø–µ—Ü–∏–∞–ª—å–Ω—ã—Ö —Ç–æ–∫–µ–Ω–∞:
    * `<PAD>` –¥–ª—è –¥–æ–ø–æ–ª–Ω–µ–Ω–∏—è –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –¥–æ –Ω—É–∂–Ω–æ–π –¥–ª–∏–Ω—ã;
    * `<UNK>` –¥–ª—è –∫–æ—Ä—Ä–µ–∫—Ç–Ω–æ–π –æ–±—Ä–∞–±–æ—Ç–∫–∏ —Ä–∞–Ω–µ–µ –Ω–µ –≤—Å—Ç—Ä–µ—á–∞–≤—à–∏—Ö—Å—è —Ç–æ–∫–µ–Ω–æ–≤;
    * `<SOS>` –¥–ª—è –æ–±–æ–∑–Ω–∞—á–µ–Ω–∏—è –Ω–∞—á–∞–ª–∞ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏;
    * `<EOS>` –¥–ª—è –æ–±–æ–∑–Ω–∞—á–µ–Ω–∏—è –∫–æ–Ω—Ü–∞ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏.
  * –ü—Ä–µ–æ–±—Ä–∞–∑–æ–≤—ã–≤–∞–π—Ç–µ —Å—Ç—Ä–æ–∫—É –≤ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç—å –∏–Ω–¥–µ–∫—Å–æ–≤ —Å —É—á–µ—Ç–æ–º —Å–ª–µ–¥—É—é—â–∏—Ö –∑–∞–º–µ—á–∞–Ω–∏–π:
    * –≤ –Ω–∞—á–∞–ª–æ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –¥–æ–±–∞–≤—å—Ç–µ —Ç–æ–∫–µ–Ω `<SOS>`;
    * –≤ –∫–æ–Ω–µ—Ü –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –¥–æ–±–∞–≤—å—Ç–µ —Ç–æ–∫–µ–Ω `<EOS>` –∏, –ø—Ä–∏ –Ω–µ–æ–±—Ö–æ–¥–∏–º–æ—Å—Ç–∏, –Ω–µ—Å–∫–æ–ª—å–∫–æ —Ç–æ–∫–µ–Ω–æ–≤ `<PAD>`;
  * `Dataset.__get_item__` –≤–æ–∑—Ä–∞—â–∞–µ—Ç –¥–≤–µ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏: –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç—å –¥–ª—è –æ–±—É—á–µ–Ω–∏—è –∏ –ø—Ä–∞–≤–∏–ª—å–Ω—ã–π –æ—Ç–≤–µ—Ç.

  –ü—Ä–∏–º–µ—Ä:
  ```
  s = 'The cat sat on the mat'
  # –ø—Ä–µ–æ–±—Ä–∞–∑—É–µ–º –≤ –∏–Ω–¥–µ–∫—Å—ã
  s_idx = [2, 5, 1, 2, 8, 4, 7, 3, 0, 0]
  # –ø–æ–ª—É—á–∞–µ–º x –∏ y (__getitem__)
  x = [2, 5, 1, 2, 8, 4, 7, 3, 0]
  y = [5, 1, 2, 8, 4, 7, 3, 0, 0]
  ```


–ë—É–¥–µ–º –ø—Ä–µ–¥—Å–∫–∞–∑—ã–≤–∞—Ç—å –∫–∞–∂–¥—É—é —Å–ª–µ–¥—É—é—â—É—é –±—É–∫–≤—É –≤ –∏–º–µ–Ω–∏:

In [5]:
class NamesVocab:
    PAD = "<PAD>"
    PAD_IDX = 0
    UNK = "<UNK>"
    UNK_IDX = 1
    SOS = "<SOS>"
    SOS_IDX = 2
    EOS = "<EOS>"
    EOS_IDX = 3

    def __init__(self, names: t.List[str]):
        uniques = set()
        max_len = 0
        for name in map(str.lower, names):
            uniques.update(name)
            max_len = max(len(name), max_len)

        self.alphabet = [self.PAD, self.UNK, self.SOS, self.EOS, *uniques]
        self.max_len = max_len + 2  # –º–µ—Å—Ç–æ –¥–ª—è <SOS> –∏ <EOS>

        ch2i = {ch: i for i, ch in enumerate(self.alphabet)}
        self.ch2i = defaultdict(lambda: self.UNK_IDX, ch2i)

    def __len__(self):
        return len(self.alphabet)

    def encode(self, name: str, shift: bool = False) -> torch.Tensor:
        # —É–ª—É—á—à–µ–Ω–Ω—ã–π –º–µ—Ç–æ–¥ –∫–æ–¥–∏—Ä–æ–≤–∞–Ω–∏—è
        # —É—Å–ª–æ–∂–Ω–µ–Ω–Ω—ã–π —Å–¥–≤–∏–≥ –ø–æ–∑–≤–æ–ª—è–µ—Ç —Å–æ—Ö—Ä–∞–Ω–∏—Ç—å –ø–µ—Ä–≤—ã–π –∏ –ø–æ—Å–ª–µ–¥–Ω–∏–π —Å–∏–º–≤–æ–ª –∏—Å—Ö–æ–¥–Ω–æ–≥–æ —Å–ª–æ–≤–∞
        name = [*name, self.EOS]
        if not shift:
            name = [self.SOS, *name]
        indices = [self.ch2i[ch] for ch in name]
        indices += [self.PAD_IDX] * (self.max_len - len(indices))
        return torch.tensor(indices, dtype=torch.long)

    def decode(self, indices: torch.Tensor) -> str:
        pad_indices = torch.nonzero(indices == self.ch2i[self.PAD], as_tuple=True)[0]
        if len(pad_indices):
            indices = indices[:pad_indices[0]]
        return "".join(self.alphabet[i] for i in indices)


class NamesDataset:
    names: t.List[str]
    vocab: NamesVocab
    data: torch.Tensor
    targets: torch.Tensor

    def __init__(self, path: Path):
        self.names = self.read_names(path)
        self.vocab = NamesVocab(self.names)

        self.data = torch.vstack([self.encode(name, shift=False) for name in self.names])
        self.targets = torch.vstack([self.encode(name, shift=True) for name in self.names])

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

    @staticmethod
    def read_names(path: Path) -> t.List[str]:
        with open(path, encoding="cp1251") as f:
            return list(map(lambda s: s.strip().lower(), f))

    def encode(self, name: str, shift: bool = False) -> torch.Tensor:
        return self.vocab.encode(name, shift=shift)

    def decode(self, vector: torch.Tensor) -> str:
        return self.vocab.decode(vector)

In [6]:
names_dataset = NamesDataset(DATA_DIR / "name_rus.txt")
print(f"n: {len(names_dataset)}")
(names_dataset.names[0], *names_dataset[0])

n: 1988


('–∞–≤–¥–æ–∫–µ—è',
 tensor([ 2, 26, 13, 20, 16,  7, 29, 22,  3,  0,  0,  0,  0,  0,  0]),
 tensor([26, 13, 20, 16,  7, 29, 22,  3,  0,  0,  0,  0,  0,  0,  0]))

–¢–∞–∫–æ–π –º–µ—Ç–æ–¥ –∫–æ–¥–∏—Ä–æ–≤–∞–Ω–∏—è –ø–æ–∑–≤–æ–ª—è–µ—Ç —Å–æ—Ö—Ä–∞–Ω–∏—Ç—å –Ω–∞ –æ–¥–Ω—É –±—É–∫–≤—É –±–æ–ª—å—à–µ, —á–µ–º –ø—Ä–µ–¥–ª–æ–∂–µ–Ω–Ω—ã–π –≤ –∑–∞–¥–∞–Ω–∏–∏ - —Ç–µ—Ä—è–µ–º `<SOS>`, –Ω–æ —Å–æ—Ö—Ä–∞–Ω—è–µ–º –ø–µ—Ä–≤—ã–π –∏ –ø–æ—Å–ª–µ–¥–Ω–∏–π —Å–∏–º–≤–æ–ª

In [7]:
torch.manual_seed(0)

train_names_dataset, test_names_dataset = train_test_split(names_dataset, train_part=0.8)
print(len(train_names_dataset), len(test_names_dataset))

1590 398


1.2 –°–æ–∑–¥–∞–π—Ç–µ –∏ –æ–±—É—á–∏—Ç–µ –º–æ–¥–µ–ª—å –¥–ª—è –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ —Ñ–∞–º–∏–ª–∏–∏.

  * –î–ª—è –ø—Ä–µ–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏—è –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –∏–Ω–¥–µ–∫—Å–æ–≤ –≤ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç—å –≤–µ–∫—Ç–æ—Ä–æ–≤ –∏—Å–ø–æ–ª—å–∑—É–π—Ç–µ `nn.Embedding`;
  * –ò—Å–ø–æ–ª—å–∑—É–π—Ç–µ —Ä–µ–∫—É—Ä—Ä–µ–Ω—Ç–Ω—ã–µ —Å–ª–æ–∏;
  * –ó–∞–¥–∞—á–∞ —Å—Ç–∞–≤–∏—Ç—Å—è –∫–∞–∫ –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ —Å–ª–µ–¥—É—é—â–µ–≥–æ —Ç–æ–∫–µ–Ω–∞ –≤ –∫–∞–∂–¥–æ–º –ø—Ä–∏–º–µ—Ä–µ –∏–∑ –ø–∞–∫–µ—Ç–∞ –¥–ª—è –∫–∞–∂–¥–æ–≥–æ –º–æ–º–µ–Ω—Ç–∞ –≤—Ä–µ–º–µ–Ω–∏. –¢.–µ. –≤ –¥–∞–Ω–Ω—ã–π –º–æ–º–µ–Ω—Ç –≤—Ä–µ–º–µ–Ω–∏ –ø–æ —Ç–µ–∫—É—â–µ–π –ø–æ–¥—Å—Ç—Ä–æ–∫–µ –ø—Ä–µ–¥—Å–∫–∞–∑—ã–≤–∞–µ—Ç —Å–ª–µ–¥—É—é—â–∏–π —Å–∏–º–≤–æ–ª –¥–ª—è –¥–∞–Ω–Ω–æ–π —Å—Ç—Ä–æ–∫–∏ (–∑–∞–¥–∞—á–∞ –∫–ª–∞—Å—Å–∏—Ñ–∏–∫–∞—Ü–∏–∏);
  * –ü—Ä–∏–º–µ—Ä–Ω–∞—è —Å—Ö–µ–º–∞ —Ä–µ–∞–ª–∏–∑–∞—Ü–∏–∏ –º–µ—Ç–æ–¥–∞ `forward`:
  ```
    input_X: [batch_size x seq_len] -> nn.Embedding -> emb_X: [batch_size x seq_len x embedding_size]
    emb_X: [batch_size x seq_len x embedding_size] -> nn.RNN -> output: [batch_size x seq_len x hidden_size]
    output: [batch_size x seq_len x hidden_size] -> torch.Tensor.reshape -> output: [batch_size * seq_len x hidden_size]
    output: [batch_size * seq_len x hidden_size] -> nn.Linear -> output: [batch_size * seq_len x vocab_size]
  ```

1.3 –ù–∞–ø–∏—à–∏—Ç–µ —Ñ—É–Ω–∫—Ü–∏—é, –∫–æ—Ç–æ—Ä–∞—è –≥–µ–Ω–µ—Ä–∏—Ä—É–µ—Ç —Ñ–∞–º–∏–ª–∏—é –ø—Ä–∏ –ø–æ–º–æ—â–∏ –æ–±—É—á–µ–Ω–Ω–æ–π –º–æ–¥–µ–ª–∏:
  * –ü–æ—Å—Ç—Ä–æ–µ–Ω–∏–µ –Ω–∞—á–∏–Ω–∞–µ—Ç—Å—è —Å –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –µ–¥–∏–Ω–∏—á–Ω–æ–π –¥–ª–∏–Ω—ã, —Å–æ—Å—Ç–æ—è—â–µ–π –∏–∑ –∏–Ω–¥–µ–∫—Å–∞ —Ç–æ–∫–µ–Ω–∞ `<SOS>`;
  * –ù–∞—á–∞–ª—å–Ω–æ–µ —Å–∫—Ä—ã—Ç–æ–µ —Å–æ—Å—Ç–æ—è–Ω–∏–µ RNN `h_t = None`;
  * –í —Ä–µ–∑—É–ª—å—Ç–∞—Ç–µ –ø—Ä–æ–≥–æ–Ω–∞ –ø–æ—Å–ª–µ–¥–Ω–µ–≥–æ —Ç–æ–∫–µ–Ω–∞ –∏–∑ –ø–æ—Å—Ç—Ä–æ–µ–Ω–Ω–æ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ —á–µ—Ä–µ–∑ –º–æ–¥–µ–ª—å –ø–æ–ª—É—á–∞–µ—Ç–µ –Ω–æ–≤–æ–µ —Å–∫—Ä—ã—Ç–æ–µ —Å–æ—Å—Ç–æ—è–Ω–∏–µ `h_t` –∏ —Ä–∞—Å–ø—Ä–µ–¥–µ–ª–µ–Ω–∏–µ –Ω–∞–¥ –≤—Å–µ–º–∏ —Ç–æ–∫–µ–Ω–∞–º–∏ –∏–∑ —Å–ª–æ–≤–∞—Ä—è;
  * –í—ã–±–∏—Ä–∞–µ—Ç–µ 1 —Ç–æ–∫–µ–Ω –ø—Ä–æ–ø–æ—Ä—Ü–∏–æ–Ω–∞–ª—å–Ω–æ –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç–∏ –∏ –¥–æ–±–∞–≤–ª—è–µ—Ç–µ –µ–≥–æ –≤ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç—å (–º–æ–∂–Ω–æ –≤–æ—Å–ø–æ–ª—å–∑–æ–≤–∞—Ç—å—Å—è `torch.multinomial`);
  * –ü–æ–≤—Ç–æ—Ä—è–µ—Ç–µ —ç—Ç–∏ –¥–µ–π—Å—Ç–≤–∏—è –¥–æ —Ç–µ—Ö –ø–æ—Ä, –ø–æ–∫–∞ –Ω–µ —Å–≥–µ–Ω–µ—Ä–∏—Ä–æ–≤–∞–Ω —Ç–æ–∫–µ–Ω `<EOS>` –∏–ª–∏ –Ω–µ –ø—Ä–µ–≤—ã—à–µ–Ω–∞ –º–∞–∫—Å–∏–º–∞–ª—å–Ω–∞—è –¥–ª–∏–Ω–∞ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏.

–ü—Ä–∏ –æ–±—É—á–µ–Ω–∏–∏ –∫–∞–∂–¥—ã–µ `k` —ç–ø–æ—Ö –≥–µ–Ω–µ—Ä–∏—Ä—É–π—Ç–µ –Ω–µ—Å–∫–æ–ª—å–∫–æ —Ñ–∞–º–∏–ª–∏–π –∏ –≤—ã–≤–æ–¥–∏—Ç–µ –∏—Ö –Ω–∞ —ç–∫—Ä–∞–Ω.

In [8]:
class NamesRNNGenerator(nn.Module):
    _STATE_T = t.Union[t.Optional[torch.Tensor], t.Optional[t.Tuple[torch.Tensor, torch.Tensor]]]
    rnn_state: _STATE_T

    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            rnn_hidden_size: int,
            rnn_cls: t.Union[t.Type[nn.RNN], t.Type[nn.LSTM], t.Type[nn.GRU]],
    ):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, padding_idx=0)
        self.rnn = rnn_cls(input_size=embedding_dim, hidden_size=rnn_hidden_size)
        self.fc = nn.Sequential(
            nn.Linear(rnn_hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, num_embeddings),
        )
        self.reset_rnn_state()

    def reset_rnn_state(self):
        self.rnn_state = None

    def keep_rnn_state(self, state: _STATE_T):
        if isinstance(self.rnn, nn.LSTM):  # –æ—Ç–¥–µ–ª—å–Ω–∞—è –æ–±—Ä–∞–±–æ—Ç–∫–∞ —Å–∫—Ä—ã—Ç–æ–≥–æ —Å–æ—Å—Ç–æ—è–Ω–∏—è nn.LSTM
            self.rnn_state = state[0].detach(), state[1].detach()
        else:
            self.rnn_state = state.detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)

        x, rnn_state = self.rnn(x, self.rnn_state)
        self.keep_rnn_state(rnn_state)

        x = self.fc(x)
        # —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–∏ –æ—Ç–ª–∏—á–∞—é—Ç—Å—è –æ—Ç —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–µ–π –≤ –∑–∞–¥–∞–Ω–∏–∏:
        # [batch_size x –∫–æ–ª-–≤–æ —Å–∏–º–≤–æ–ª–æ–≤ x –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç–∏ –¥–ª—è –∫–∞–∂–¥–æ–≥–æ —Å–∏–º–≤–æ–ª–∞]
        # CrossEntropyLoss —É–º–µ–µ—Ç —Ç–∞–∫
        return x.permute(0, 2, 1)

In [9]:
# —á–µ—Å—Ç–Ω–∞—è –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç—å
def true_prob(pred: torch.Tensor) -> torch.Tensor:
    pred -= pred.min()
    return pred / pred.sum()


# –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç—å —á–µ—Ä–µ–∑ softmax - —ç—Ç–æ –Ω–µ –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç—å, —Å–∏–ª—å–Ω—ã–π —Å–∫–æ—Å –≤ —Å—Ç–æ—Ä–æ–Ω—É –±–æ–ª—å—à–µ–≥–æ –∑–Ω–∞—á–µ–Ω–∏—è
def softmax_prob(pred: torch.Tensor) -> torch.Tensor:
    return torch.softmax(pred, 0)


def generate_name(
        model: NamesRNNGenerator,
        dataset: NamesDataset,
        prompt="",
        prob: t.Callable[[torch.Tensor], torch.Tensor] = None,
        device: str = "cpu",
) -> str:
    len_start = len(prompt)
    name = dataset.encode(prompt).to(device)
    name[len_start + 1] = 0  # –∑–∞–º–µ–Ω—è–µ–º <EOS> –Ω–∞ <PAD>

    model.eval()
    model.reset_rnn_state()
    for i in range(name.size(0) - len_start - 2):
        # –∑–¥–µ—Å—å –ø–µ—Ä–µ–¥–∞–µ–º –≤—Å–µ —Å–ª–æ–≤–æ, –∞ –Ω–µ –ø–æ 1-–æ–π –±—É–∫–≤–µ
        # –Ω–µ –∑—Ä—è –∂–µ –º–æ–¥–µ–ª—å —É—á–∏–ª–∏ —Å–ª–æ–≤–∞–º–∏, –∞ –Ω–µ –±—É–∫–≤–∞–º–∏...
        pred = model(name.unsqueeze(0)).squeeze()[:, len_start + i]
        if prob:  # —Å–ª—É—á–∞–π–Ω–æ—Å—Ç—å
            next_ch_idx = torch.multinomial(prob(pred), 1)
        else:  # —á–µ—Å—Ç–Ω–æ–µ –≤–∑—è—Ç–∏–µ –ª—É—á—à–µ–≥–æ –≤–∞—Ä–∏–∞–Ω—Ç–∞
            next_ch_idx = pred.argmax()

        if next_ch_idx == NamesVocab.EOS_IDX:
            break
        name[len_start + i + 1] = next_ch_idx

    return dataset.decode(name).replace(NamesVocab.SOS, "")


def on_epoch_end_generate_names(
        model: NamesRNNGenerator,
        dataset: NamesDataset,
) -> t.Callable[[], None]:
    def _on_epoch_end() -> None:
        # —á–µ—Å—Ç–Ω–æ–µ –≤–∑—è—Ç–∏–µ –ª—É—á—à–µ–≥–æ –≤–∞—Ä–∏–∞–Ω—Ç–∞
        const = generate_name(model, dataset, device=DEVICE)
        # —Å–ª—É—á–∞–π–Ω–æ–µ –≤–∑—è—Ç–∏–µ –Ω–∞ –æ—Å–Ω–æ–≤–µ –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç–∏
        true_random = generate_name(model, dataset, prob=true_prob, device=DEVICE)
        # —Å–ª—É—á–∞–π–Ω–æ–µ –≤–∑—è—Ç–∏–µ –Ω–∞ softmax –ø—Ä–µ–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏—è
        softmax_random = generate_name(model, dataset, prob=softmax_prob, device=DEVICE)
        print(f"\tNames: {const} (max), {true_random} (prob), {softmax_random} (softmax)")

    return _on_epoch_end

In [10]:
torch.manual_seed(0)

names_gen_net = NamesRNNGenerator(
    num_embeddings=len(names_dataset.vocab),
    embedding_dim=8,  # –¥–ª—è —Å–∏–º–≤–æ–ª–æ–≤ –±–æ–ª—å—à–∏–µ embedding'–∏ –Ω–µ –Ω—É–∂–Ω—ã (–Ω–∞–≤–µ—Ä–Ω–æ–µ)
    rnn_hidden_size=64,
    rnn_cls=nn.LSTM,
).to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(names_gen_net.parameters(), lr=0.001)

train_dataloader = DataLoader(train_names_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_names_dataset, batch_size=128)

In [11]:
%%time

_ = common_train(
    epochs=30,
    model=names_gen_net,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    verbose=50,
    on_epoch_end=on_epoch_end_generate_names(names_gen_net, names_dataset),
    device=DEVICE,
)

Epoch 1
--------------------------------
loss: 3.494831  [    0/ 1590]
Test Error: 
	Accuracy: 0.529816, Loss: 1.912179
	Names:  (max), –∞–æ–∞ (prob), –ª–ø—ç—å–≥ (softmax)

Epoch 2
--------------------------------
loss: 1.987921  [    0/ 1590]
Test Error: 
	Accuracy: 0.578559, Loss: 1.577202
	Names: –∞ (max), —Ç–∞–ª–µ (prob), —ã (softmax)

Epoch 3
--------------------------------
loss: 1.672587  [    0/ 1590]
Test Error: 
	Accuracy: 0.608710, Loss: 1.442052
	Names: –∞ (max), —É—Ö–π—á (prob), —Å (softmax)

Epoch 4
--------------------------------
loss: 1.637822  [    0/ 1590]
Test Error: 
	Accuracy: 0.639531, Loss: 1.313473
	Names: –∞ (max), —Ä—É–µ—å—Ö–∏–ª—á–∑–∏—á–π (prob), –¥–∞–±—é—à—è–π (softmax)

Epoch 5
--------------------------------
loss: 1.303850  [    0/ 1590]
Test Error: 
	Accuracy: 0.664824, Loss: 1.211651
	Names: –∞ (max), –Ω—é—é–æ (prob), –≥—Ç—Å—ã—è (softmax)

Epoch 6
--------------------------------
loss: 1.203625  [    0/ 1590]
Test Error: 
	Accuracy: 0.677052, Loss: 1.14

In [12]:
y_test, y_pred = get_y_test_y_pred(names_gen_net, test_dataloader, DEVICE)

print(metrics.classification_report(
    y_true=y_test,
    y_pred=y_pred,
    target_names=[names_dataset.vocab.alphabet[i] for i in y_test.unique().sort()[0]],
    zero_division=True,
))

              precision    recall  f1-score   support

       <PAD>       1.00      1.00      1.00      3110
       <EOS>       0.77      0.94      0.84       398
           —É       1.00      0.02      0.03        62
           —Ñ       1.00      0.00      0.00        12
           –±       1.00      0.00      0.00        15
           –∫       0.21      0.11      0.15        98
           –ª       0.50      0.11      0.18       160
           —Å       0.28      0.07      0.11       104
           —à       0.27      0.34      0.30        56
           –ø       1.00      0.00      0.00        36
           –∏       0.21      0.23      0.22       183
           –≤       0.14      0.03      0.04        77
           —Ç       0.36      0.07      0.11       118
           —ã       1.00      0.00      0.00        20
           –æ       0.35      0.09      0.14        77
           —ç       1.00      0.00      0.00         6
           –º       1.00      0.00      0.00        82
           —

–ú–æ–∂–Ω–æ –ø–æ—Ä–∞–¥–æ–≤–∞—Ç—å—Å—è –∑–∞ –º–æ–¥–µ–ª—å - –æ–Ω–∞ –Ω–∞—É—á–∏–ª–∞—Å—å —É—Å–ø–µ—à–Ω–æ –ø—Ä–µ–¥—Å–∫–∞–∑—ã–≤–∞—Ç—å `<PAD>` –∏ `<EOS>`. –≠—Ç–æ –±—ã–ª–∞ –Ω–µ—Å–ª–æ–∂–Ω–∞—è –∑–∞–¥–∞—á–∞

In [119]:
print(generate_name(names_gen_net, names_dataset, prompt="—é–Ω", device=DEVICE))
print(generate_name(names_gen_net, names_dataset, prompt="—é–Ω", prob=true_prob, device=DEVICE))
print(generate_name(names_gen_net, names_dataset, prompt="—é–Ω", prob=softmax_prob, device=DEVICE))

—é–Ω–∞—Ç–∞
—é–Ω–∞—à–∞
—é–Ω–∞


–ß—Ç–æ —ç—Ç–æ –µ—Å–ª–∏ –Ω–µ —É–±–∏–π—Ü–∞ ChatGPT?

## 2. –ì–µ–Ω–µ—Ä–∏—Ä–æ–≤–∞–Ω–∏–µ —Ç–µ–∫—Å—Ç–∞ –ø—Ä–∏ –ø–æ–º–æ—â–∏ RNN

2.1 –°–∫–∞—á–∞–π—Ç–µ –∏–∑ –∏–Ω—Ç–µ—Ä–Ω–µ—Ç–∞ –∫–∞–∫–æ–µ-–Ω–∏–±—É–¥—å —Ö—É–¥–æ–∂–µ—Å—Ç–≤–µ–Ω–Ω–æ–µ –ø—Ä–æ–∏–∑–≤–µ–¥–µ–Ω–∏–µ
  * –í—ã–±–∏—Ä–∞–π—Ç–µ –¥–æ—Å—Ç–∞—Ç–æ—á–Ω–æ –∫—Ä—É–ø–Ω–æ–µ –ø—Ä–æ–∏–∑–≤–µ–¥–µ–Ω–∏–µ, —á—Ç–æ–±—ã –º–æ–¥–µ–ª—å –ª—É—á—à–µ –æ–±—É—á–∞–ª–∞—Å—å;

2.2 –ù–∞ –æ—Å–Ω–æ–≤–µ –≤—ã–±—Ä–∞–Ω–Ω–æ–≥–æ –ø—Ä–æ–∏–∑–≤–µ–¥–µ–Ω–∏—è —Å–æ–∑–¥–∞–π—Ç–µ –¥–∞—Ç–∞—Å–µ—Ç. 

–û—Ç–ª–∏—á–∏—è –æ—Ç –∑–∞–¥–∞—á–∏ 1:
  * –¢–æ–∫–µ–Ω—ã `<SOS>`, `<EOS>` –∏ `<UNK>` –º–æ–∂–Ω–æ –Ω–µ –¥–æ–±–∞–≤–ª—è—Ç—å;
  * –ü—Ä–∏ —Å–æ–∑–¥–∞–Ω–∏–∏ –¥–∞—Ç–∞—Å–µ—Ç–∞ —Ç–µ–∫—Å—Ç –Ω–µ–æ–±—Ö–æ–¥–∏–º–æ –ø—Ä–µ–¥–≤–∞—Ä–∏—Ç–µ–ª—å–Ω–æ —Ä–∞–∑–±–∏—Ç—å –Ω–∞ —á–∞—Å—Ç–∏. –í—ã–±–µ—Ä–∏—Ç–µ –∂–µ–ª–∞–µ–º—É—é –¥–ª–∏–Ω—É –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ `seq_len` –∏ —Ä–∞–∑–±–µ–π—Ç–µ —Ç–µ–∫—Å—Ç –Ω–∞ –ø–æ—Å—Ç—Ä–æ–∫–∏ –¥–ª–∏–Ω—ã `seq_len` (–º–æ–∂–Ω–æ –±–µ–∑ –ø–µ—Ä–µ–∫—Ä—ã—Ç–∏—è, –º–æ–∂–Ω–æ —Å –Ω–µ–±–æ–ª—å—à–∏–º –ø–µ—Ä–µ–∫—Ä—ã—Ç–∏–µ–º).

–í –∫–∞—á–µ—Å—Ç–≤–µ –¥–∞—Ç–∞—Å–µ—Ç–∞ –∏—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è —Ç–µ–∫—Å—Ç "–ê–Ω–Ω–∞ –ö–∞—Ä–µ–Ω–∏–Ω–∞"

In [120]:
class TextVocab:
    PAD = "<PAD>"
    PAD_IDX = 0
    UNK = "<UNK>"
    UNK_IDX = 1

    def __init__(self, seqs: t.List[str]):
        uniques = set()
        max_len = 0
        for seq in map(str.lower, seqs):
            uniques.update(seq)
            max_len = max(len(seq), max_len)

        self.alphabet = [self.PAD, self.UNK, *uniques]
        self.max_len = max_len

        ch2i = {ch: i for i, ch in enumerate(self.alphabet)}
        self.ch2i = defaultdict(lambda: self.UNK_IDX, ch2i)

    def __len__(self):
        return len(self.alphabet)

    def encode(self, seq: str) -> torch.Tensor:
        indices = [self.ch2i[ch] for ch in seq]
        indices += [self.PAD_IDX] * (self.max_len - len(indices))
        return torch.tensor(indices, dtype=torch.long)

    def decode(self, indices: torch.Tensor) -> str:
        pad_indices = torch.nonzero(indices == self.ch2i[self.PAD], as_tuple=True)[0]
        if len(pad_indices):
            indices = indices[:pad_indices[0]]
        return "".join(self.alphabet[i] for i in indices)


class TextDataset:
    seqs: t.List[str]
    vocab: TextVocab
    data: torch.Tensor
    targets: torch.Tensor

    def __init__(self, path: Path, window: int, overlap: int = 0):
        self.seqs = self.read_seqs(path, window=window, overlap=overlap)
        self.vocab = TextVocab(self.seqs)

        self.data = torch.vstack([self.encode(seq[:-1]) for seq in self.seqs])
        self.targets = torch.vstack([self.encode(seq[1:]) for seq in self.seqs])

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

    @staticmethod
    def read_seqs(path: Path, window: int, overlap: int = 0) -> t.List[str]:
        with open(path, encoding="cp1251") as f:
            text = list(map(lambda s: s.strip().lower(), f))

        text = " ".join(text)
        text = re.sub(r"[^–∞-—è—ë]", repl=" ", string=text)
        text = " ".join(text.split())  # –∏–∑–±–∞–≤–ª—è–µ–º—Å—è –æ—Ç –¥–ª–∏–Ω–Ω—ã—Ö –ø—Ä–æ–±–µ–ª—å–Ω—ã—Ö –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–µ–π

        seqs = []
        for i in range(0, len(text), window):
            # –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏ –¥–ª–∏–Ω—ã window —Å –ø–µ—Ä–µ–∫—Ä—ã—Ç–∏–µ–º overlap —Å –æ–±–æ–∏—Ö —Å—Ç–æ—Ä–æ–Ω
            seqs.append(text[i:i + window + overlap])

        return seqs

    def encode(self, seq: str) -> torch.Tensor:
        return self.vocab.encode(seq)

    def decode(self, indices: torch.Tensor) -> str:
        return self.vocab.decode(indices)

In [121]:
text_dataset = TextDataset(DATA_DIR / "anna_karenina.txt", window=64, overlap=4)
print(f"n: {len(text_dataset)}")
(text_dataset.seqs[0], *text_dataset[0])

n: 25099


('–ª–µ–≤ –Ω–∏–∫–æ–ª–∞–µ–≤–∏—á —Ç–æ–ª—Å—Ç–æ–π –∞–Ω–Ω–∞ –∫–∞—Ä–µ–Ω–∏–Ω–∞ –º–Ω–µ –æ—Ç–º—â–µ–Ω–∏–µ –∏ –∞–∑ –≤–æ–∑–¥–∞–º —á–∞—Å—Ç—å ',
 tensor([ 7, 30, 13,  6, 31, 12,  8, 16,  7, 26, 30, 13, 12, 19,  6, 14, 16,  7,
          9, 14, 16, 24,  6, 26, 31, 31, 26,  6,  8, 26, 29, 30, 31, 12, 31, 26,
          6, 18, 31, 30,  6, 16, 14, 18,  2, 30, 31, 12, 30,  6, 12,  6, 26, 20,
          6, 13, 16, 20, 21, 26, 18,  6, 19, 26,  9, 14, 34,  0]),
 tensor([30, 13,  6, 31, 12,  8, 16,  7, 26, 30, 13, 12, 19,  6, 14, 16,  7,  9,
         14, 16, 24,  6, 26, 31, 31, 26,  6,  8, 26, 29, 30, 31, 12, 31, 26,  6,
         18, 31, 30,  6, 16, 14, 18,  2, 30, 31, 12, 30,  6, 12,  6, 26, 20,  6,
         13, 16, 20, 21, 26, 18,  6, 19, 26,  9, 14, 34,  6,  0]))

In [122]:
text_dataset.seqs[:10]

['–ª–µ–≤ –Ω–∏–∫–æ–ª–∞–µ–≤–∏—á —Ç–æ–ª—Å—Ç–æ–π –∞–Ω–Ω–∞ –∫–∞—Ä–µ–Ω–∏–Ω–∞ –º–Ω–µ –æ—Ç–º—â–µ–Ω–∏–µ –∏ –∞–∑ –≤–æ–∑–¥–∞–º —á–∞—Å—Ç—å ',
 '—Å—Ç—å –ø–µ—Ä–≤–∞—è –≤—Å–µ —Å—á–∞—Å—Ç–ª–∏–≤—ã–µ —Å–µ–º—å–∏ –ø–æ—Ö–æ–∂–∏ –¥—Ä—É–≥ –Ω–∞ –¥—Ä—É–≥–∞ –∫–∞–∂–¥–∞—è –Ω–µ—Å—á–∞—Å—Ç–ª',
 '–∞—Å—Ç–ª–∏–≤–∞—è —Å–µ–º—å—è –Ω–µ—Å—á–∞—Å—Ç–ª–∏–≤–∞ –ø–æ —Å–≤–æ–µ–º—É –≤—Å–µ —Å–º–µ—à–∞–ª–æ—Å—å –≤ –¥–æ–º–µ –æ–±–ª–æ–Ω—Å–∫–∏—Ö ',
 '–∫–∏—Ö –∂–µ–Ω–∞ —É–∑–Ω–∞–ª–∞ —á—Ç–æ –º—É–∂ –±—ã–ª –≤ —Å–≤—è–∑–∏ —Å –±—ã–≤—à–µ—é –≤ –∏—Ö –¥–æ–º–µ —Ñ—Ä–∞–Ω—Ü—É–∂–µ–Ω–∫–æ—é ',
 '–∫–æ—é –≥—É–≤–µ—Ä–Ω–∞–Ω—Ç–∫–æ–π –∏ –æ–±—ä—è–≤–∏–ª–∞ –º—É–∂—É —á—Ç–æ –Ω–µ –º–æ–∂–µ—Ç –∂–∏—Ç—å —Å –Ω–∏–º –≤ –æ–¥–Ω–æ–º –¥–æ–º',
 ' –¥–æ–º–µ –ø–æ–ª–æ–∂–µ–Ω–∏–µ —ç—Ç–æ –ø—Ä–æ–¥–æ–ª–∂–∞–ª–æ—Å—å —É–∂–µ —Ç—Ä–µ—Ç–∏–π –¥–µ–Ω—å –∏ –º—É—á–∏—Ç–µ–ª—å–Ω–æ —á—É–≤—Å—Ç–≤',
 '–≤—Å—Ç–≤–æ–≤–∞–ª–æ—Å—å –∏ —Å–∞–º–∏–º–∏ —Å—É–ø—Ä—É–≥–∞–º–∏ –∏ –≤—Å–µ–º–∏ —á–ª–µ–Ω–∞–º–∏ —Å–µ–º—å–∏ –∏ –¥–æ–º–æ—á–∞–¥—Ü–∞–º–∏ –≤',
 '–º–∏ –≤—Å–µ —á–ª–µ–Ω—ã —Å–µ–º—å–∏ –∏ –¥–æ–º–æ—á–∞–¥—Ü—ã —á—É–≤—Å—Ç–≤–æ–≤–∞–ª–∏ —á—Ç–æ –Ω–

In [123]:
torch.manual_seed(0)

train_text_dataset, test_text_dataset = train_test_split(text_dataset, train_part=0.8)
print(len(train_text_dataset), len(test_text_dataset))

20079 5020


2.3 –°–æ–∑–¥–∞–π—Ç–µ –∏ –æ–±—É—á–∏—Ç–µ –º–æ–¥–µ–ª—å –¥–ª—è –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ —Ç–µ–∫—Å—Ç–∞
  * –ó–∞–¥–∞—á–∞ —Å—Ç–∞–≤–∏—Ç—Å—è —Ç–æ—á–Ω–æ —Ç–∞–∫ –∂–µ –∫–∞–∫ –≤ 1.2;
  * –ü—Ä–∏ –Ω–µ–æ–±—Ö–æ–¥–∏–º–æ—Å—Ç–∏ –º–æ–∂–µ—Ç–µ –ø—Ä–∏–º–µ–Ω–∏—Ç—å:
    * –¥–≤—É—Ö—É—Ä–æ–≤–Ω–µ–≤—ã–µ —Ä–µ–∫—É—Ä—Ä–µ–Ω—Ç–Ω—ã–µ —Å–ª–æ–∏ (`num_layers`=2)
    * [–æ–±—Ä–µ–∑–∫—É –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html)


In [124]:
class TextRNNGenerator(nn.Module):
    _STATE_T = t.Union[t.Optional[torch.Tensor], t.Optional[t.Tuple[torch.Tensor, torch.Tensor]]]
    rnn_state: _STATE_T

    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            rnn_hidden_size: int,
            rnn_cls: t.Union[t.Type[nn.RNN], t.Type[nn.LSTM], t.Type[nn.GRU]],
    ):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, padding_idx=0)
        # –ø—Ä–∏–º–µ–Ω–∏—Ç—å –¥–≤—É—Ö—É—Ä–æ–≤–Ω–µ–≤—ã–µ —Ä–µ–∫—É—Ä—Ä–µ–Ω—Ç–Ω—ã–µ —Å–ª–æ–∏ –º–æ–∂–Ω–æ,
        # –±—ã–ª–æ –±—ã —Ö–æ—Ä–æ—à–æ, –µ—Å–ª–∏ —ç—Ç–æ –∫–∞–∫-—Ç–æ –≤–ª–∏—è–ª–æ –Ω–∞ —Ç–æ—á–Ω–æ—Å—Ç—å –Ω—É —Ö–æ—Ç—å —á—É—Ç—å-—á—É—Ç—å
        self.rnn = rnn_cls(input_size=embedding_dim, hidden_size=rnn_hidden_size, num_layers=2, dropout=0.25)
        self.fc = nn.Sequential(
            nn.Linear(rnn_hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, num_embeddings),
        )
        self.reset_rnn_state()

    def reset_rnn_state(self):
        self.rnn_state = None

    def keep_rnn_state(self, state: _STATE_T):
        if isinstance(self.rnn, nn.LSTM):
            self.rnn_state = state[0].detach(), state[1].detach()
        else:
            self.rnn_state = state.detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)

        x, rnn_state = self.rnn(x, self.rnn_state)
        self.keep_rnn_state(rnn_state)

        x = self.fc(x)
        return x.permute(0, 2, 1)

In [125]:
torch.manual_seed(0)

text_gen_net = TextRNNGenerator(
    num_embeddings=len(text_dataset.vocab),
    embedding_dim=16,
    rnn_hidden_size=64,
    rnn_cls=nn.LSTM,
).to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(text_gen_net.parameters(), lr=0.001)

train_dataloader = DataLoader(train_text_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_text_dataset, batch_size=512)

In [126]:
%%time

_ = common_train(
    epochs=10,
    model=text_gen_net,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    verbose=200,
    device=DEVICE,
)

Epoch 1
--------------------------------
loss: 3.563401  [    0/20079]
loss: 2.622643  [12800/20079]
Test Error: 
	Accuracy: 0.256237, Loss: 2.550842

Epoch 2
--------------------------------
loss: 2.585614  [    0/20079]
loss: 2.513025  [12800/20079]
Test Error: 
	Accuracy: 0.267832, Loss: 2.458726

Epoch 3
--------------------------------
loss: 2.482423  [    0/20079]
loss: 2.491698  [12800/20079]
Test Error: 
	Accuracy: 0.269589, Loss: 2.429040

Epoch 4
--------------------------------
loss: 2.458759  [    0/20079]
loss: 2.475520  [12800/20079]
Test Error: 
	Accuracy: 0.270081, Loss: 2.422082

Epoch 5
--------------------------------
loss: 2.470808  [    0/20079]
loss: 2.416837  [12800/20079]
Test Error: 
	Accuracy: 0.270081, Loss: 2.419302

Epoch 6
--------------------------------
loss: 2.436810  [    0/20079]
loss: 2.431002  [12800/20079]
Test Error: 
	Accuracy: 0.269844, Loss: 2.417486

Epoch 7
--------------------------------
loss: 2.424822  [    0/20079]
loss: 2.452827  [12800/

–ù—É —Ç—É—Ç –ª–∏–±–æ –ø–æ—Å—Ç–∞–Ω–æ–≤–∫–∞ –∑–∞–¥–∞—á–∏ –Ω–µ–≤—ã–ø–æ–ª–Ω–∏–º–∞—è, –ª–∏–±–æ –º–æ–¥–µ–ª—å –Ω–µ –æ–±—É—á–∞–µ—Ç—Å—è

In [127]:
y_test, y_pred = get_y_test_y_pred(text_gen_net, test_dataloader, DEVICE)

print(metrics.classification_report(
    y_true=y_test,
    y_pred=y_pred,
    target_names=[text_dataset.vocab.alphabet[i] for i in y_test.unique().sort()[0]],
    zero_division=True,
))

              precision    recall  f1-score   support

       <PAD>       1.00      1.00      1.00      5020
           —â       1.00      0.00      0.00       807
           —É       1.00      0.00      0.00      7684
           —Ñ       1.00      0.00      0.00       289
           –±       1.00      0.00      0.00      4962
                   0.32      0.79      0.46     56390
           –ª       1.00      0.00      0.00     14217
           –∫       1.00      0.00      0.00      9621
           —Å       0.10      0.37      0.15     14917
           –ø       1.00      0.00      0.00      6787
           —à       1.00      0.00      0.00      2446
           –∏       1.00      0.00      0.00     18485
           –≤       1.00      0.00      0.00     13055
           —Ç       0.29      0.35      0.32     16653
           —ã       0.29      0.28      0.28      5149
           –æ       0.30      0.47      0.37     32323
           —ç       1.00      0.00      0.00      1008
           –

2.4 –ù–∞–ø–∏—à–∏—Ç–µ —Ñ—É–Ω–∫—Ü–∏—é, –∫–æ—Ç–æ—Ä–∞—è –≥–µ–Ω–µ—Ä–∏—Ä—É–µ—Ç —Ñ—Ä–∞–≥–º–µ–Ω—Ç —Ç–µ–∫—Å—Ç–∞ –ø—Ä–∏ –ø–æ–º–æ—â–∏ –æ–±—É—á–µ–Ω–Ω–æ–π –º–æ–¥–µ–ª–∏
  * –ü—Ä–æ—Ü–µ—Å—Å –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ –Ω–∞—á–∏–Ω–∞–µ—Ç—Å—è —Å –Ω–µ–±–æ–ª—å—à–æ–≥–æ —Ñ—Ä–∞–≥–º–µ–Ω—Ç–∞ —Ç–µ–∫—Å—Ç–∞ `prime`, –≤—ã–±—Ä–∞–Ω–Ω–æ–≥–æ –≤–∞–º–∏ (1-2 —Å–ª–æ–≤–∞)
  * –°–Ω–∞—á–∞–ª–∞ –≤—ã –ø—Ä–æ–ø—É—Å–∫–∞–µ—Ç–µ —á–µ—Ä–µ–∑ –º–æ–¥–µ–ª—å —Ç–æ–∫–µ–Ω—ã –∏–∑ `prime` –∏ –≥–µ–Ω–µ—Ä–∏—Ä—É–µ—Ç–µ –Ω–∞ –∏—Ö –æ—Å–Ω–æ–≤–µ —Å–∫—Ä—ã—Ç–æ–µ —Å–æ—Å—Ç–æ—è–Ω–∏–µ —Ä–µ–∫—É—Ä—Ä–µ–Ω—Ç–Ω–æ–≥–æ —Å–ª–æ—è `h_t`;
  * –ü–æ—Å–ª–µ —ç—Ç–æ–≥–æ –≤—ã –≥–µ–Ω–µ—Ä–∏—Ä—É–µ—Ç–µ —Å—Ç—Ä–æ–∫—É –Ω—É–∂–Ω–æ–π –¥–ª–∏–Ω—ã –∞–Ω–∞–ª–æ–≥–∏—á–Ω–æ 1.3

In [128]:
def generate_text(
        model: TextRNNGenerator,
        dataset: TextDataset,
        prompt: str,  # —Å—Ç–∞—Ä—Ç–æ–≤–∞—è —Å—Ç—Ä–æ–∫–∞
        size: int,  # –ª—é–±–∞—è –¥–ª–∏–Ω–∞ –≥–µ–Ω–µ—Ä–∏—Ä—É–µ–º–æ–≥–æ —Ç–µ–∫—Å—Ç–∞
        prob: t.Callable[[torch.Tensor], torch.Tensor] = None,
        device: str = "cpu",
) -> str:
    text = [dataset.vocab.ch2i[ch] for ch in prompt]

    model.eval()
    model.reset_rnn_state()
    for i in range(size - len(text)):
        x = torch.tensor([text[i:]], device=device)
        pred = model(x).squeeze()[:, -1]
        if prob:
            next_ch_idx = torch.multinomial(prob(pred), 1)
        else:
            next_ch_idx = pred.argmax()
        text.append(next_ch_idx.item())

    return dataset.decode(torch.tensor(text))

In [129]:
for prompt in [
    "–¥–æ–±—Ä–æ–µ —É—Ç—Ä–æ ",
    "–¥–æ–±—Ä—ã–π –≤–µ—á–µ—Ä ",
    "–Ω–∞ –∫–∞–∂–¥–æ–º –ø–æ—Å—Ç–æ—è–ª–æ–º –¥–≤–æ—Ä–µ ",
    "–¥–µ–ª–æ —à–ª–æ –æ —Ç–æ–º ",
    "–∫–∞–∫ —Ç–æ–ª—å–∫–æ ",
    "–±–µ–ª–∞—è –±–µ—Ä–µ–∑–∞ ",
]:
    print(generate_text(text_gen_net, text_dataset, prompt, 90, prob=softmax_prob, device=DEVICE))

–¥–æ–±—Ä–æ–µ —É—Ç—Ä–æ –∂–¥–∞ –µ–º –µ–Ω–∞–ª–∞–∏ –∫–∏—Ü—É –±–ª—Å–µ—Ç–∞—é –ø–æ—Å–ª–µ —Å—Ç—å—Å–∫—Ä–æ–º–æ–ª –µ–ª–∞ –æ–¥–≤—Å–Ω–µ–µ–µ–Ω–∏–¥–µ–≤—ã–µ—Å–µ–≤–æ–∂–µ –æ–¥ –∏–∫—Å —Ç
–¥–æ–±—Ä—ã–π –≤–µ—á–µ—Ä –≥–æ–º–∞–∫–æ–≤—à–∏—Å–∫ –∫–∞ –∑ –ø—Ä–µ–¥–µ–≥–æ –∏–Ω—É–±–µ—Å—Ç–≤–µ—Ä–∏ –º—ã—Ç —Ç—É–ø–µ–ø—Ä–µ—Ç–æ–≤–æ—Å —Ç–µ —Å–º –Ω—ã—Ç–ª–æ—Å–∫–æ–∫–Ω–∞–π –∞–∑–∞—Å
–Ω–∞ –∫–∞–∂–¥–æ–º –ø–æ—Å—Ç–æ—è–ª–æ–º –¥–≤–æ—Ä–µ —Å–ø—Ä–æ–≤–∏–ø–æ–ª—é –≤–æ–Ω–æ–±–ª–æ –ø–æ–¥—Ä–∏–Ω—É —Ñ—Ä–µ –¥–µ–ª—è –≤ –∞—Ç–æ—Ç–æ –µ–º–µ –∫–∞ –¥—å —Å—Ç–≤–∏ –¥—É –≤—ã
–¥–µ–ª–æ —à–ª–æ –æ —Ç–æ–º –± –Ω—Å –∫–∞–Ω–Ω –ø–æ –µ–π –µ —Ç–∞—Ä—É —á–∫—Ä–∏—è –ø—É–≤ –º–æ–Ω—Ç—Ä–∏ —á–µ–≤—Ä–æ–≤–æ–π –Ω–µ –∏—Ç–æ –ø—Ä–µ–Ω–∞–∫–∞ —ç—Ç –≤–æ–∂–∞—Ç–∞ –º
–∫–∞–∫ —Ç–æ–ª—å–∫–æ —Ä–æ –æ–ª–∞ –∑–∞–∑–∞ –≤—ã–∫—Å—Ç–∞ –ø–æ–µ—Ä—ã–µ–ø—Ä–∏ –±–∫–æ–¥—Ä—É —Å—Ç–∏ –∏–π –Ω—ã—Ç–æ–¥–æ—à—å–∑–Ω –ø—Ä–µ –æ –≤—ã—Ç—É–ª—é –ª–µ–Ω —É—Å–¥–µ–ª—è–ª—é
–±–µ–ª–∞—è –±–µ—Ä–µ–∑–∞ –≥–æ–ª –∂—á–µ —è—à—å –æ–Ω–µ–ª–æ–¥—Ä–≥–æ —Å—Ç–∞–º–æ –º–æ –ø—É –∏–Ω–∏–ª—å –∫–æ —Å—Ç–æ–ª—é —á–∏—Ä—ã–ª–æ –æ–≤—ã–ª–∏–µ–¥–ª –∫ —Å–∏—Ç–æ—Ç–µ–Ω–æ–ª—É


–ê —ç—Ç–æ "–£–±–∏–π—Ü–∞ ChatGPT 2.0"

> As we can see, the generated text may not make any sense, however there are some words and phrases that seem to form an idea, for example...

[–ò—Å—Ç–æ—á–Ω–∏–∫](https://towardsdatascience.com/text-generation-with-bi-lstm-in-pytorch-5fda6e7cc22c#:~:text=As%20we%20can%20see%2C%20the%20generated%20text%20may%20not%20make%20any%20sense%2C%20however%20there%20are%20some%20words%20and%20phrases%20that%20seem%20to%20form%20an%20idea%2C%20for%20example%3A)

–ò–∑ —Ö–æ—Ä–æ—à–µ–≥–æ, –º–æ–¥–µ–ª—å –Ω–∞—É—á–∏–ª–∞—Å—å –≤–æ–≤—Ä–µ–º—è —Ä–∞—Å—Å—Ç–∞–≤–ª—è—Ç—å –ø—Ä–æ–±–µ–ª—ã - —Å–æ—á–µ—Ç–∞–Ω–∏—è –±—É–∫–≤ –ø–æ –¥–ª–∏–Ω–∞–º –ø–æ—Ö–æ–∂–∏ –Ω–∞ —Å–ª–æ–≤–∞ (—Å–∞–º–æ—Å—Ç–æ—è—Ç–µ–ª—å–Ω—ã–µ –∏ –ø—Ä–µ–¥–ª–æ–≥–∏ - –µ—Å–ª–∏ —Å–∏–ª—å–Ω–æ —Ñ–∞–Ω—Ç–∞–∑–∏—Ä–æ–≤–∞—Ç—å ü•≥)