In [1]:
import copy

import numpy as np
import torch
import torch.nn as nn

In [3]:
# Read the raw corpus, one quote per line. The encoding is given explicitly so the
# decoded text (and therefore the character vocabulary) does not depend on the
# platform's default locale.
with open("author_quotes.txt", 'r', encoding='utf-8') as f:
    text = f.readlines()

In [4]:
# Drop the line breaks and prepend a space to every quote: the leading space gives
# each training sequence a common first character, which is what generation is
# later seeded with (seed_phrase=' ').
quotes = [' ' + line.replace('\n', '') for line in text]

In [5]:
# Peek at the first two preprocessed quotes.
quotes[0:2]

[' If you live to be a hundred, I want to live to be a hundred minus one day so I never have to live without you.',
 " Promise me you'll always remember: You're braver than you believe, and stronger than you seem, and smarter than you think."]

In [6]:
# Character vocabulary: every distinct character of the corpus in sorted order,
# followed by a multi-character padding symbol. Because the padding symbol is longer
# than one character, it can never collide with a real character.
PAD_TOKEN = '<PAD>'
chars = sorted(set(''.join(quotes)))
chars.append(PAD_TOKEN)

# id -> token first, then invert it to get token -> id (same insertion order).
idx2chars = dict(enumerate(chars))
chars2idx = {char: idx for idx, char in idx2chars.items()}

vocab_size = len(chars2idx)
print(f"Размер словаря: {vocab_size}")

Размер словаря: 86


In [7]:
def to_matrix(lines, token_to_id, max_len=None, pad_value=None):
    """Encode strings as a right-padded int64 matrix of token ids.

    Args:
        lines: sequence of strings; every character must be a key of ``token_to_id``.
        token_to_id: mapping from character to integer id.
        max_len: number of columns. Longer lines are truncated. If falsy, the
            length of the longest line is used (0 when ``lines`` is empty).
        pad_value: id used to fill positions past the end of a line; defaults to
            ``token_to_id[PAD_TOKEN]``.

    Returns:
        np.ndarray of shape ``(len(lines), max_len)`` and dtype int64.

    Raises:
        KeyError: if a character of a line (or ``PAD_TOKEN`` when ``pad_value`` is
            None) is missing from ``token_to_id``.
    """
    if pad_value is None:
        pad_value = token_to_id[PAD_TOKEN]  # use the <PAD> id
    # `default=0` keeps an empty `lines` from raising ValueError inside max().
    max_len = max_len or max(map(len, lines), default=0)
    data = np.full((len(lines), max_len), pad_value, dtype=np.int64)
    for i, line in enumerate(lines):
        encoded = [token_to_id[c] for c in line[:max_len]]
        data[i, :len(encoded)] = encoded
    return data

# Show how the first two quotes look once encoded and padded to a common width.
print("Пример закодированной строки:")
encoded_preview = to_matrix(quotes[:2], chars2idx)
print(encoded_preview)

Пример закодированной строки:
[[ 0 38 61  0 80 70 76  0 67 64 77 60  0 75 70  0 57 60  0 56  0 63 76 69
  59 73 60 59 12  0 38  0 78 56 69 75  0 75 70  0 67 64 77 60  0 75 70  0
  57 60  0 56  0 63 76 69 59 73 60 59  0 68 64 69 76 74  0 70 69 60  0 59
  56 80  0 74 70  0 38  0 69 60 77 60 73  0 63 56 77 60  0 75 70  0 67 64
  77 60  0 78 64 75 63 70 76 75  0 80 70 76 14 85 85 85 85 85 85 85 85 85
  85 85 85]
 [ 0 45 73 70 68 64 74 60  0 68 60  0 80 70 76  7 67 67  0 56 67 78 56 80
  74  0 73 60 68 60 68 57 60 73 26  0 54 70 76  7 73 60  0 57 73 56 77 60
  73  0 75 63 56 69  0 80 70 76  0 57 60 67 64 60 77 60 12  0 56 69 59  0
  74 75 73 70 69 62 60 73  0 75 63 56 69  0 80 70 76  0 74 60 60 68 12  0
  56 69 59  0 74 68 56 73 75 60 73  0 75 63 56 69  0 80 70 76  0 75 63 64
  69 66 14]]


In [8]:
class CharRNNLoop(nn.Module):
    """Character-level language model: embedding -> LSTM -> linear head.

    ``forward`` takes a LongTensor of token ids laid out (batch, seq), since the LSTM
    is ``batch_first``. It returns next-character logits of shape
    (batch, seq, vocab_size); the final LSTM state is discarded.
    """

    def __init__(self, vocab_size=vocab_size, emb_size=16, rnn_num_units=32):
        super().__init__()
        # Registration order (emb, rnn, hid_to_logits) fixes parameter order and
        # state_dict keys, and hence seeded initialisation.
        self.emb = nn.Embedding(vocab_size, emb_size)
        self.rnn = nn.LSTM(emb_size, rnn_num_units, batch_first=True)
        self.hid_to_logits = nn.Linear(rnn_num_units, vocab_size)

    def forward(self, x):
        hidden_seq, _ = self.rnn(self.emb(x))
        return self.hid_to_logits(hidden_seq)

In [12]:
from torch.utils.data import Dataset, DataLoader

class QuoteDataset(Dataset):
    """Map-style dataset yielding each quote as a fixed-length LongTensor of char ids.

    Quotes longer than ``max_len`` are truncated. Shorter ones are right-padded
    with the id of ``PAD_TOKEN``.
    """

    def __init__(self, quotes, chars2idx, max_len):
        self.quotes, self.chars2idx, self.max_len = quotes, chars2idx, max_len

    def __len__(self):
        return len(self.quotes)

    def __getitem__(self, idx):
        ids = [self.chars2idx[ch] for ch in self.quotes[idx][:self.max_len]]
        pad_id = self.chars2idx[PAD_TOKEN]
        ids.extend([pad_id] * (self.max_len - len(ids)))
        return torch.tensor(ids, dtype=torch.long)

MAX_LENGTH = 50  # cap on quote length in characters; longer quotes are truncated
batch_size = 8   # small batch to save memory

dataset = QuoteDataset(quotes, chars2idx, max_len=MAX_LENGTH)
loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [13]:
SEED = 42
# Seed torch's global RNG before building the model. This fixes weight
# initialisation, the DataLoader shuffling order (RandomSampler draws from this RNG)
# and later sampling, so Restart & Run All reproduces the run. Some GPU/cuDNN
# kernels may still be non-deterministic.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CharRNNLoop(vocab_size=vocab_size, emb_size=128, rnn_num_units=256).to(device)
pad_index = chars2idx[PAD_TOKEN]
# Target positions equal to the PAD id contribute nothing to the loss or gradients.
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [15]:
from tqdm import tqdm

# 🏋️‍♂️ Step 8: train the model.
# NOTE: early stopping monitors the *training* loss (there is no held-out split),
# so it detects a plateau, not overfitting.
num_epochs = 50
patience = 0
max_patience = 10
best_loss = float("inf")  # any finite epoch loss counts as an improvement
early_stopping_rounds = False
# Independent snapshot of the best weights seen so far. Assigning `best_model = model`
# would only alias the live model, so it would always end up with the LAST epoch's
# weights rather than the best ones.
best_model = copy.deepcopy(model)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch in tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch = batch.to(device)
        logits = model(batch)
        # Teacher forcing: the output at position t is scored against token t+1.
        predictions = logits[:, :-1]
        targets = batch[:, 1:]

        loss = criterion(predictions.reshape(-1, logits.size(-1)), targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    mean_loss = total_loss / len(loader)
    if mean_loss < best_loss:
        patience = 0
        best_loss = mean_loss
        best_model = copy.deepcopy(model)
        print(f"New best loss: {best_loss}!\n")
    else:
        patience += 1
        print(f"Loss: {mean_loss}\nPatience: {patience}\n")
    if patience >= max_patience:
        print(f"Early stopping rounds!")
        early_stopping_rounds = True
        break

Epoch 1/50: 100%|██████████| 4521/4521 [00:16<00:00, 266.49it/s]


New best loss: 1.2769215507360718!



Epoch 2/50: 100%|██████████| 4521/4521 [00:16<00:00, 266.77it/s]


New best loss: 1.2536371403419715!



Epoch 3/50: 100%|██████████| 4521/4521 [00:16<00:00, 271.84it/s]


New best loss: 1.2357092139020058!



Epoch 4/50: 100%|██████████| 4521/4521 [00:17<00:00, 263.80it/s]


New best loss: 1.2216193251735314!



Epoch 5/50: 100%|██████████| 4521/4521 [00:16<00:00, 271.48it/s]


New best loss: 1.2099758637107016!



Epoch 6/50: 100%|██████████| 4521/4521 [00:16<00:00, 267.96it/s]


New best loss: 1.19954785144954!



Epoch 7/50: 100%|██████████| 4521/4521 [00:16<00:00, 267.12it/s]


New best loss: 1.19068850413067!



Epoch 8/50: 100%|██████████| 4521/4521 [00:16<00:00, 272.19it/s]


New best loss: 1.1826423880827688!



Epoch 9/50: 100%|██████████| 4521/4521 [00:17<00:00, 260.94it/s]


New best loss: 1.1762466671271177!



Epoch 10/50: 100%|██████████| 4521/4521 [00:16<00:00, 271.82it/s]


New best loss: 1.1701721907242082!



Epoch 11/50: 100%|██████████| 4521/4521 [00:16<00:00, 267.71it/s]


New best loss: 1.16446144138915!



Epoch 12/50: 100%|██████████| 4521/4521 [00:16<00:00, 268.69it/s]


New best loss: 1.1595219701851556!



Epoch 13/50: 100%|██████████| 4521/4521 [00:16<00:00, 272.55it/s]


New best loss: 1.1547360723326097!



Epoch 14/50: 100%|██████████| 4521/4521 [00:17<00:00, 261.57it/s]


New best loss: 1.1508257070412495!



Epoch 15/50: 100%|██████████| 4521/4521 [00:16<00:00, 269.79it/s]


New best loss: 1.1471515029728137!



Epoch 16/50: 100%|██████████| 4521/4521 [00:17<00:00, 265.85it/s]


New best loss: 1.1436944634111232!



Epoch 17/50: 100%|██████████| 4521/4521 [00:16<00:00, 267.83it/s]


New best loss: 1.140278328614698!



Epoch 18/50: 100%|██████████| 4521/4521 [00:16<00:00, 271.52it/s]


New best loss: 1.1374309056761307!



Epoch 19/50: 100%|██████████| 4521/4521 [00:17<00:00, 263.82it/s]


New best loss: 1.1345482800049582!



Epoch 20/50: 100%|██████████| 4521/4521 [00:16<00:00, 272.46it/s]


New best loss: 1.1320669743846512!



Epoch 21/50: 100%|██████████| 4521/4521 [00:16<00:00, 267.20it/s]


New best loss: 1.1294115370972129!



Epoch 22/50: 100%|██████████| 4521/4521 [00:16<00:00, 268.19it/s]


New best loss: 1.1275752156715797!



Epoch 23/50: 100%|██████████| 4521/4521 [00:16<00:00, 272.13it/s]


New best loss: 1.1253633437049095!



Epoch 24/50: 100%|██████████| 4521/4521 [00:17<00:00, 262.07it/s]


New best loss: 1.1234191094476336!



Epoch 25/50: 100%|██████████| 4521/4521 [00:16<00:00, 271.97it/s]


New best loss: 1.1213621421214675!



Epoch 26/50: 100%|██████████| 4521/4521 [00:16<00:00, 266.22it/s]


New best loss: 1.1199019031170274!



Epoch 27/50: 100%|██████████| 4521/4521 [00:17<00:00, 265.68it/s]


New best loss: 1.1183072654870305!



Epoch 28/50: 100%|██████████| 4521/4521 [00:16<00:00, 272.77it/s]


New best loss: 1.1166991434828208!



Epoch 29/50: 100%|██████████| 4521/4521 [00:17<00:00, 262.16it/s]


New best loss: 1.1155386730464942!



Epoch 30/50: 100%|██████████| 4521/4521 [00:16<00:00, 272.68it/s]


New best loss: 1.114394588482061!



Epoch 31/50: 100%|██████████| 4521/4521 [00:16<00:00, 266.62it/s]


New best loss: 1.1127350819081865!



Epoch 32/50: 100%|██████████| 4521/4521 [00:16<00:00, 267.51it/s]


New best loss: 1.1117970740681753!



Epoch 33/50: 100%|██████████| 4521/4521 [00:16<00:00, 270.07it/s]


New best loss: 1.110999725332199!



Epoch 34/50: 100%|██████████| 4521/4521 [00:17<00:00, 262.89it/s]


New best loss: 1.110163118756996!



Epoch 35/50: 100%|██████████| 4521/4521 [00:16<00:00, 271.14it/s]


New best loss: 1.1092674964641318!



Epoch 36/50: 100%|██████████| 4521/4521 [00:16<00:00, 267.17it/s]


New best loss: 1.1081581866812373!



Epoch 37/50: 100%|██████████| 4521/4521 [00:16<00:00, 267.88it/s]


New best loss: 1.1074319861151956!



Epoch 38/50: 100%|██████████| 4521/4521 [00:16<00:00, 269.18it/s]


New best loss: 1.1068137702333323!



Epoch 39/50: 100%|██████████| 4521/4521 [00:17<00:00, 263.43it/s]


New best loss: 1.1062565926718675!



Epoch 40/50: 100%|██████████| 4521/4521 [00:16<00:00, 272.53it/s]


New best loss: 1.105408967424615!



Epoch 41/50: 100%|██████████| 4521/4521 [00:17<00:00, 252.83it/s]


New best loss: 1.1046989284318993!



Epoch 42/50: 100%|██████████| 4521/4521 [00:16<00:00, 268.84it/s]


New best loss: 1.103721156607579!



Epoch 43/50: 100%|██████████| 4521/4521 [00:16<00:00, 268.29it/s]


Loss: 1.104007770888074
Patience: 1



Epoch 44/50: 100%|██████████| 4521/4521 [00:17<00:00, 260.66it/s]


Loss: 1.1038665427251402
Patience: 2



Epoch 45/50: 100%|██████████| 4521/4521 [00:16<00:00, 269.06it/s]


New best loss: 1.1030728467215851!



Epoch 46/50: 100%|██████████| 4521/4521 [00:17<00:00, 262.30it/s]


Loss: 1.1030906752341882
Patience: 1



Epoch 47/50: 100%|██████████| 4521/4521 [00:16<00:00, 271.57it/s]


New best loss: 1.1024150635995213!



Epoch 48/50: 100%|██████████| 4521/4521 [00:16<00:00, 270.93it/s]


Loss: 1.1025221384592063
Patience: 1



Epoch 49/50: 100%|██████████| 4521/4521 [00:17<00:00, 263.05it/s]


New best loss: 1.1018915741719442!



Epoch 50/50: 100%|██████████| 4521/4521 [00:16<00:00, 271.20it/s]

New best loss: 1.1016186081421795!






In [17]:
def generate_sample(model, seed_phrase=' ', max_length=50, temperature=1.0,
                    token_to_id=None, id_to_token=None, pad_token=None):
    """Sample text from a character-level model, one character at a time.

    At every step the whole prefix generated so far is fed through ``model``. The
    logits at the last position are divided by ``temperature``, the padding token
    is masked out, and the next character is drawn from the resulting softmax.

    Args:
        model: module mapping a (1, seq) LongTensor of token ids to
            (1, seq, vocab) logits, e.g. ``CharRNNLoop``.
        seed_phrase: non-empty prefix to condition on; every character must be in
            the vocabulary.
        max_length: total length of the result in tokens, seed included.
        temperature: must be > 0; lower values make sampling greedier.
        token_to_id: char -> id map; defaults to the notebook's ``chars2idx``.
        id_to_token: id -> char map; defaults to the notebook's ``idx2chars``.
        pad_token: padding symbol that is never sampled; defaults to ``PAD_TOKEN``.

    Returns:
        The seed followed by the sampled characters.

    Raises:
        ValueError: if ``seed_phrase`` is empty or ``temperature`` is not positive.
        KeyError: if ``seed_phrase`` contains a character outside the vocabulary.
    """
    if not seed_phrase:
        raise ValueError("seed_phrase must contain at least one character")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if token_to_id is None:
        token_to_id = chars2idx
    if id_to_token is None:
        id_to_token = idx2chars
    if pad_token is None:
        pad_token = PAD_TOKEN
    pad_id = token_to_id.get(pad_token)

    # Run on whatever device the model's weights live on (CPU for parameter-free models).
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        x = torch.tensor([[token_to_id[c] for c in seed_phrase]],
                         dtype=torch.long, device=run_device)

        for _ in range(max_length - len(seed_phrase)):
            out = model(x)
            logits = out[:, -1, :] / temperature  # new tensor; `out` is untouched
            if pad_id is not None:
                # PAD is not a real character: never sample it, otherwise it would be
                # fed back to the model as context (and silently shorten the output).
                logits[:, pad_id] = float("-inf")
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)

            x = torch.cat([x, next_token], dim=1)

        generated = ''.join(id_to_token[ix] for ix in x[0].tolist())
    # Kept for safety with custom vocabularies; PAD can no longer be sampled.
    return generated.replace(pad_token, '')

In [29]:
# Low temperature -> conservative, high-probability characters.
generate_sample(best_model, ' ', temperature=0.3)

' I was one of the best and family and a horror for'