Импорт библиотек

In [5]:
import torch
import torch.nn as nn
import numpy as np

Установка параметров для нейронной сети

In [6]:
BATCH_SIZE = 64
SEQ_LENGTH = 100
BUFFER_SIZE = 10000
EPOCHS = 10

Загрузка и обработка данных


In [7]:
path_to_file = '451.txt'
with open(path_to_file, 'r', encoding='windows-1251') as f:
    text = f.read()

text[86:162]

'451° по Фаренгейту – температура, при которой воспламеняется и горит бумага.'

Подсчет количества уникальных символов


In [8]:
vocab = sorted(set(text))
vocab_size = len(vocab)

vocab_size

98

Создание словарей для преобразования символов в индексы и обратно


In [9]:
char2idx = {u: i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

char2idx

{'\n': 0,
 ' ': 1,
 '!': 2,
 '(': 3,
 ')': 4,
 ',': 5,
 '-': 6,
 '.': 7,
 '0': 8,
 '1': 9,
 '2': 10,
 '3': 11,
 '4': 12,
 '5': 13,
 '6': 14,
 '7': 15,
 '8': 16,
 '9': 17,
 ':': 18,
 ';': 19,
 '?': 20,
 'I': 21,
 'P': 22,
 '[': 23,
 ']': 24,
 'b': 25,
 'c': 26,
 'e': 27,
 'k': 28,
 'o': 29,
 't': 30,
 '\xa0': 31,
 '«': 32,
 '°': 33,
 '»': 34,
 'А': 35,
 'Б': 36,
 'В': 37,
 'Г': 38,
 'Д': 39,
 'Е': 40,
 'Ж': 41,
 'З': 42,
 'И': 43,
 'К': 44,
 'Л': 45,
 'М': 46,
 'Н': 47,
 'О': 48,
 'П': 49,
 'Р': 50,
 'С': 51,
 'Т': 52,
 'У': 53,
 'Ф': 54,
 'Х': 55,
 'Ц': 56,
 'Ч': 57,
 'Ш': 58,
 'Щ': 59,
 'Э': 60,
 'Я': 61,
 'а': 62,
 'б': 63,
 'в': 64,
 'г': 65,
 'д': 66,
 'е': 67,
 'ж': 68,
 'з': 69,
 'и': 70,
 'й': 71,
 'к': 72,
 'л': 73,
 'м': 74,
 'н': 75,
 'о': 76,
 'п': 77,
 'р': 78,
 'с': 79,
 'т': 80,
 'у': 81,
 'ф': 82,
 'х': 83,
 'ц': 84,
 'ч': 85,
 'ш': 86,
 'щ': 87,
 'ъ': 88,
 'ы': 89,
 'ь': 90,
 'э': 91,
 'ю': 92,
 'я': 93,
 'ё': 94,
 '–': 95,
 '…': 96,
 '№': 97}

Преобразование всего текста в числовую форму


In [10]:
text_as_int = np.array([char2idx[c] for c in text])
text_as_int[86: 162]

array([12, 13,  9, 33,  1, 77, 76,  1, 54, 62, 78, 67, 75, 65, 67, 71, 80,
       81,  1, 95,  1, 80, 67, 74, 77, 67, 78, 62, 80, 81, 78, 62,  5,  1,
       77, 78, 70,  1, 72, 76, 80, 76, 78, 76, 71,  1, 64, 76, 79, 77, 73,
       62, 74, 67, 75, 93, 67, 80, 79, 93,  1, 70,  1, 65, 76, 78, 70, 80,
        1, 63, 81, 74, 62, 65, 62,  7])

In [11]:
def text_from_idx(idx):
    return ''.join([idx2char[c] for c in idx])

text_from_idx(text_as_int[86: 162])

'451° по Фаренгейту – температура, при которой воспламеняется и горит бумага.'

Создание датасета

In [12]:
examples_per_epoch = len(text) // (SEQ_LENGTH + 1)
sequences = torch.tensor(text_as_int, dtype=torch.long).unfold(0, SEQ_LENGTH + 1, 1)
sequences.shape

torch.Size([78154, 101])

Разделение на входные и целевые данные

In [13]:
def split_input_target(chunk):
    input_seq = chunk[:-1]
    target_seq = chunk[1:]
    return input_seq, target_seq

input, label = split_input_target(sequences)

print(input.shape, label.shape)

torch.Size([78153, 101]) torch.Size([78153, 101])


In [14]:
dataset = torch.utils.data.TensorDataset(input, label)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

Определение модели RNN

In [15]:
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(RNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out.reshape(out.size(0) * out.size(1), out.size(2)))
        return out, hidden

Параметры модели

In [16]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

embed_size = 256
hidden_size = 512
num_layers = 2

model = RNN(vocab_size, embed_size, hidden_size, num_layers)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

DEVICE

'cuda'

In [17]:
len(dataloader)

1222

Функция обучения модели


In [18]:
def train(model, dataloader, epochs, device="cpu"):
    model.train().to(device)
    for epoch in range(epochs):
        hidden = torch.zeros(num_layers, BATCH_SIZE, hidden_size).to(device)
        for i, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            hidden = hidden.detach().to(device)

            if inputs.size(0) != BATCH_SIZE:
              hidden = hidden[:, :inputs.size(0), :].contiguous()

            optimizer.zero_grad()
            output, hidden = model(inputs, hidden)
            loss = criterion(output, targets.view(-1))
            loss.backward()
            optimizer.step()


        print(f'Epoch: {epoch+1}, Loss: {loss.item()}')


Обучение модели

In [19]:
train(model, dataloader, EPOCHS, DEVICE)

Epoch: 1, Loss: 0.25567159056663513
Epoch: 2, Loss: 0.19924798607826233
Epoch: 3, Loss: 0.16639743745326996
Epoch: 4, Loss: 0.19704686105251312
Epoch: 5, Loss: 0.14977282285690308
Epoch: 6, Loss: 0.16601304709911346
Epoch: 7, Loss: 0.15382468700408936
Epoch: 8, Loss: 0.1476975679397583
Epoch: 9, Loss: 0.14078612625598907
Epoch: 10, Loss: 0.18104763329029083


In [20]:
def generate_text(model, start_string, generation_length=100, device="cpu"):
    input_indices = [char2idx[c] for c in start_string]
    input_tensor = torch.tensor(input_indices, dtype=torch.long).unsqueeze(0).to(device)

    model.eval().to(device)

    hidden = torch.zeros(num_layers, 1, hidden_size).to(device)

    generated_text = start_string

    for _ in range(generation_length):
        output, hidden = model(input_tensor, hidden)

        predicted_idx = torch.argmax(output[-1], dim=-1).item()

        predicted_char = idx2char[predicted_idx]

        generated_text += predicted_char

        input_tensor = torch.tensor([[predicted_idx]], dtype=torch.long).to(device)

    return generated_text

In [26]:
generated_text = generate_text(model, start_string="Огонь", generation_length=100, device=DEVICE)
print(generated_text)

Огонь разжигать быстро.

Правило 3. Сжигать всё.

Правило 4. Возвращаться на пожарную станцию немедленно.
