# Сравнение RNN, LSTM и GRU: генерация


In [1]:
import sys
from pathlib import Path
import numpy as np

# Make the project-local model modules importable.
# `root` is the current working directory, so this assumes the notebook is run
# from the folder that contains the '03 RNN', '04 LSTM' and '05 GRU' directories.
# Re-running the cell appends the same entries again, which is harmless.
root = Path('.').resolve()
sys.path.append(str(root / '03 RNN'))
sys.path.append(str(root / '04 LSTM'))
sys.path.append(str(root / '05 GRU'))

from rnn import ElmanRNN
from lstm import LSTMNextTokenGenerator
from gru import GRUNextTokenGenerator


In [2]:
def build_vocab(text: str):
    """Build the character-level vocabulary of ``text``.

    Distinct characters are sorted, then numbered from 0 in that order.

    Returns:
        A ``(char2idx, idx2char)`` pair of mutually inverse dictionaries.
    """
    alphabet = sorted(set(text))
    idx2char = dict(enumerate(alphabet))
    char2idx = {symbol: position for position, symbol in idx2char.items()}
    return char2idx, idx2char

def encode_text(text: str, char2idx: dict[str, int]) -> np.ndarray:
    """Translate ``text`` into a 1-D integer array of vocabulary ids.

    Raises:
        KeyError: if ``text`` contains a character missing from ``char2idx``.
    """
    codes = list(map(char2idx.__getitem__, text))
    return np.asarray(codes, dtype=int)

def to_one_hot(indices: np.ndarray, vocab_size: int) -> np.ndarray:
    """One-hot encode a sequence of ids.

    Returns:
        A float array of shape ``(len(indices), vocab_size)`` in which row ``r``
        is all zeros except for a 1.0 in column ``indices[r]``.
    """
    n_rows = len(indices)
    encoded = np.zeros(shape=(n_rows, vocab_size), dtype=float)
    # Single advanced-indexing assignment: pairs each row with its own column.
    row_ids = np.arange(n_rows)
    encoded[row_ids, indices] = 1.0
    return encoded

def build_sequences_indices(indices: np.ndarray, seq_len: int, vocab_size: int):
    """Cut ``indices`` into overlapping next-token training windows.

    For every start offset, the input is the one-hot encoding of ``seq_len``
    consecutive ids and the target is the same window shifted right by one.

    Returns:
        ``(X, y)`` where ``X`` stacks the one-hot inputs into shape
        ``(n_windows, seq_len, vocab_size)`` and ``y`` stacks the integer
        targets into shape ``(n_windows, seq_len)``, with
        ``n_windows = len(indices) - seq_len``.

    Raises:
        ValueError: from ``np.stack`` when no window fits
            (``len(indices) <= seq_len``).
    """
    offsets = range(len(indices) - seq_len)
    inputs = [to_one_hot(indices[lo:lo + seq_len], vocab_size) for lo in offsets]
    targets = [indices[lo + 1:lo + seq_len + 1] for lo in offsets]
    return np.stack(inputs), np.stack(targets)

def build_sequences_onehot(indices: np.ndarray, seq_len: int, vocab_size: int):
    """Cut ``indices`` into overlapping windows with one-hot inputs AND targets.

    Uses the same windowing as ``build_sequences_indices`` (target = input
    shifted right by one), but the targets are one-hot encoded too and the
    windows are returned as plain Python lists rather than stacked arrays.

    Returns:
        ``(X_list, y_list)``: two lists of ``len(indices) - seq_len`` float
        matrices, each of shape ``(seq_len, vocab_size)``. Both lists are empty
        when no window fits.
    """
    offsets = range(len(indices) - seq_len)
    X_list = [to_one_hot(indices[lo:lo + seq_len], vocab_size) for lo in offsets]
    y_list = [to_one_hot(indices[lo + 1:lo + seq_len + 1], vocab_size) for lo in offsets]
    return X_list, y_list

def decode_indices(indices: np.ndarray, idx2char: dict[int, str]) -> str:
    """Map a sequence of vocabulary ids back to the string they spell.

    Raises:
        KeyError: if an id is not present in ``idx2char``.
    """
    pieces = []
    for token in indices:
        pieces.append(idx2char[token])
    return ''.join(pieces)


In [3]:
# Toy corpus: a short lower-case English excerpt without punctuation. Every
# fragment ends with a space, so repeated copies stay separated by a space.
base_text = (
    'to be or not to be that is the question whether tis nobler in the mind '
    'to suffer the slings and arrows of outrageous fortune '
    'or to take arms against a sea of troubles and by opposing end them '
    'to die to sleep no more and by a sleep to say we end the heartache '
    'and the thousand natural shocks that flesh is heir to '
)

# The corpus is the excerpt repeated back-to-back.
repeats = 15
text = base_text * repeats

# Length (in characters) of every training window; also read by the sampling
# helpers further down as the context length.
seq_len = 20

# Character-level vocabulary and the whole corpus encoded as integer ids.
char2idx, idx2char = build_vocab(text)
vocab_size = len(char2idx)
indices = encode_text(text, char2idx)

# Sequential 80/20 split: the first 80% of the ids train, the rest evaluate.
# NOTE(review): `text` is the same excerpt repeated, so the test slice is made
# of passages that also occur in the training slice. The "test" accuracies
# reported below therefore reflect memorisation, not generalisation.
split = int(len(indices) * 0.8)
train_idx = indices[:split]
test_idx = indices[split:]

# Stacked-array dataset (one-hot inputs, integer targets); passed to the LSTM
# and GRU models below.
X_train, y_train = build_sequences_indices(train_idx, seq_len, vocab_size)
X_test, y_test = build_sequences_indices(test_idx, seq_len, vocab_size)

# List-of-matrices dataset (one-hot inputs and one-hot targets); passed to the
# Elman RNN below.
X_train_rnn, y_train_rnn = build_sequences_onehot(train_idx, seq_len, vocab_size)
X_test_rnn, y_test_rnn = build_sequences_onehot(test_idx, seq_len, vocab_size)
# Recover integer targets from the one-hot test targets for accuracy scoring.
y_test_rnn_idx = [y.argmax(axis=1) for y in y_test_rnn]


In [4]:
def eval_next_char_accuracy_lm(model, X_eval, y_eval_idx):
    """Per-position next-character accuracy for a ``predict_proba``-style model.

    The arg-max over axis 2 of ``model.predict_proba(X_eval)`` is compared
    element-wise with ``y_eval_idx``; the result is the fraction of matches.

    Args:
        model: object exposing ``predict_proba(X)``; its output is assumed to
            be laid out as (sequence, position, vocabulary) - TODO confirm
            against the model classes.
        X_eval: inputs forwarded unchanged to ``predict_proba``.
        y_eval_idx: integer array of target ids (needs ``.size``).

    Returns:
        Matches divided by ``y_eval_idx.size``.
    """
    predicted_idx = np.argmax(model.predict_proba(X_eval), axis=2)
    matches = predicted_idx == y_eval_idx
    return matches.sum() / y_eval_idx.size

def eval_next_char_accuracy_rnn(model, X_eval_list, y_eval_idx_list):
    """Per-position next-character accuracy for a ``predict``-style model.

    ``model.predict(X_eval_list)`` is expected to yield one sequence of
    predicted ids per input sequence; each is compared element-wise with the
    matching entry of ``y_eval_idx_list``.

    Returns:
        Total matches divided by the total number of target positions.

    Raises:
        ZeroDivisionError: if there are no sequences to score.
    """
    predictions = model.predict(X_eval_list)
    paired = list(zip(y_eval_idx_list, predictions))
    hits = sum(np.sum(preds == y_idx) for y_idx, preds in paired)
    positions = sum(len(y_idx) for y_idx, _ in paired)
    return hits / positions


In [5]:
# Baseline model: the project-local ElmanRNN (imported from the `rnn` module).
# Hidden size, learning rate, epoch count and random_state match the values
# given to the LSTM and GRU cells below; `grad_clip` is passed only here.
rnn = ElmanRNN(
    input_size=vocab_size,
    hidden_size=128,
    output_size=vocab_size,
    lr=1e-3,
    epochs=20,
    grad_clip=5.0,
    random_state=42,
    verbose=False,
)
# This model is fed the list-of-one-hot-matrices dataset (inputs and targets).
rnn.fit(X_train_rnn, y_train_rnn)
# Scoring compares `rnn.predict(...)` with the integer test targets.
rnn_acc = eval_next_char_accuracy_rnn(rnn, X_test_rnn, y_test_rnn_idx)
print(f'RNN test accuracy: {rnn_acc:.4f}')


RNN test accuracy: 0.8493


In [6]:
# Project-local LSTM next-token model (imported from the `lstm` module),
# configured with the same hidden size, learning rate, epoch count and
# random_state as the other two models.
lstm = LSTMNextTokenGenerator(
    input_dim=vocab_size,
    hidden_dim=128,
    vocab_size=vocab_size,
    lr=1e-3,
    max_epochs=20,
    random_state=42,
    verbose=False,
)
# Trained on the stacked dataset: one-hot inputs, integer-id targets.
lstm.fit(X_train, y_train)
# Scoring takes the arg-max of `lstm.predict_proba(X_test)` per position.
lstm_acc = eval_next_char_accuracy_lm(lstm, X_test, y_test)
print(f'LSTM test accuracy: {lstm_acc:.4f}')


LSTM test accuracy: 0.8980


In [7]:
# Project-local GRU next-token model (imported from the `gru` module),
# constructed with exactly the same keyword arguments as the LSTM cell so the
# two runs are directly comparable.
gru = GRUNextTokenGenerator(
    input_dim=vocab_size,
    hidden_dim=128,
    vocab_size=vocab_size,
    lr=1e-3,
    max_epochs=20,
    random_state=42,
    verbose=False,
)
# Trained on the stacked dataset: one-hot inputs, integer-id targets.
gru.fit(X_train, y_train)
# Scoring takes the arg-max of `gru.predict_proba(X_test)` per position.
gru_acc = eval_next_char_accuracy_lm(gru, X_test, y_test)
print(f'GRU test accuracy: {gru_acc:.4f}')


GRU test accuracy: 0.9347


In [8]:
def sample_lstm_like(model, seed_text: str, steps: int = 200, seed: int = 123):
    """Generate text character by character from a ``predict_next_proba`` model.

    At every step the last ``seq_len`` characters are one-hot encoded, the model
    returns a distribution over the vocabulary, and the next character is drawn
    from that distribution.

    Relies on the notebook globals ``seq_len``, ``char2idx``, ``idx2char``,
    ``vocab_size`` and the helper ``to_one_hot``.

    Args:
        model: object exposing ``predict_next_proba(x)``; the result is used
            directly as the ``p`` argument of ``Generator.choice``, so it must
            be a length-``vocab_size`` probability vector.
        seed_text: non-empty prompt; the characters fed to the model must be in
            the vocabulary.
        steps: number of characters to generate. Non-positive values return
            ``seed_text`` unchanged.
        seed: value passed to ``np.random.default_rng``. The default keeps the
            previously hard-coded seed, so existing calls sample identically.

    Returns:
        ``seed_text`` followed by the ``steps`` generated characters.

    Raises:
        ValueError: if ``seed_text`` is empty while ``steps`` is positive.
        KeyError: if the trailing window of ``seed_text`` contains a character
            that is not in ``char2idx``.
    """
    if steps <= 0:
        return seed_text
    if not seed_text:
        # Without this guard an empty prompt yields a float-typed empty index
        # array and fails deep inside the one-hot indexing.
        raise ValueError('seed_text must contain at least one character')

    rng = np.random.default_rng(seed)
    # Only the trailing window is ever shown to the model, so only that part of
    # the prompt is encoded; generated ids are appended without re-encoding.
    context = [char2idx[ch] for ch in seed_text[-seq_len:]]
    generated = []
    for _ in range(steps):
        x_vec = to_one_hot(np.array(context, dtype=int), vocab_size)
        probs = model.predict_next_proba(x_vec)
        next_idx = int(rng.choice(vocab_size, p=probs))
        generated.append(idx2char[next_idx])
        context = (context + [next_idx])[-seq_len:]
    # Join once instead of growing the string on every iteration.
    return seed_text + ''.join(generated)

def sample_rnn(model, seed_text: str, steps: int = 200, seed: int = 123):
    """Generate text character by character from a ``predict_proba`` model.

    At every step the last ``seq_len`` characters are one-hot encoded and passed
    to the model as a one-element batch. The last row of the first returned
    item is taken as the next-character distribution and sampled from.

    Relies on the notebook globals ``seq_len``, ``char2idx``, ``idx2char``,
    ``vocab_size`` and the helper ``to_one_hot``.

    Args:
        model: object exposing ``predict_proba(list_of_sequences)``; element 0
            of the result is presumably one distribution per input position
            (verify against the model class). Its last row is used as the ``p``
            argument of ``Generator.choice``.
        seed_text: non-empty prompt; the characters fed to the model must be in
            the vocabulary.
        steps: number of characters to generate. Non-positive values return
            ``seed_text`` unchanged.
        seed: value passed to ``np.random.default_rng``. The default keeps the
            previously hard-coded seed, so existing calls sample identically.

    Returns:
        ``seed_text`` followed by the ``steps`` generated characters.

    Raises:
        ValueError: if ``seed_text`` is empty while ``steps`` is positive.
        KeyError: if the trailing window of ``seed_text`` contains a character
            that is not in ``char2idx``.
    """
    if steps <= 0:
        return seed_text
    if not seed_text:
        # Without this guard an empty prompt yields a float-typed empty index
        # array and fails deep inside the one-hot indexing.
        raise ValueError('seed_text must contain at least one character')

    rng = np.random.default_rng(seed)
    # Only the trailing window is ever shown to the model, so only that part of
    # the prompt is encoded; generated ids are appended without re-encoding.
    context = [char2idx[ch] for ch in seed_text[-seq_len:]]
    generated = []
    for _ in range(steps):
        x_vec = to_one_hot(np.array(context, dtype=int), vocab_size)
        probs_seq = model.predict_proba([x_vec])[0]
        # Only the distribution for the final position predicts the next char.
        probs = probs_seq[-1]
        next_idx = int(rng.choice(vocab_size, p=probs))
        generated.append(idx2char[next_idx])
        context = (context + [next_idx])[-seq_len:]
    # Join once instead of growing the string on every iteration.
    return seed_text + ''.join(generated)

# Generate a continuation of the same prompt with each trained model.
# Both sampling helpers default to the same RNG seed, so the three runs differ
# only in the probabilities the models produce.
print('RNN sample:')
print(sample_rnn(rnn, seed_text='to be or not to be '))

# The LSTM and GRU share one helper because both expose `predict_next_proba`.
print('LSTM sample:')
print(sample_lstm_like(lstm, seed_text='to be or not to be '))

print('GRU sample:')
print(sample_lstm_like(gru, seed_text='to be or not to be '))


RNN sample:
to be or not to be that ns sriis not wheep iiir to dor to sayawe and tha tiousand na more and by a aleip to eoke and bata aleshinp tofe and mhe heal ou os fs npblanrt ag ouerpinst ogeep to be mrar ouming ar wn the mfae 
LSTM sample:
to be or not to be that is the question whe era oh the toro ne ta de that flesh is heir to to be or not to be that is the question whether to to be or not to be that is the question whe rling ondsh that flesh an whearmi
GRU sample:
to be or not to be that is the question whether tis nobler in the mind to suffer the slings and arrows of outrageous fortune or to take arms against a sea of troubles and by opposing end them to die to sleep no more and
