In [1]:
import io
import os
import random
import string
import warnings

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Vocabulary. Index 0 ('&') doubles as the end-of-text marker used by Generator;
# then digits, Russian lower/upper-case letters, and string.printable[62:]
# (ASCII punctuation followed by whitespace).
all_characters = '&0123456789абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ' + string.printable[62:]
n_characters = len(all_characters)

root = r'data'
num_files = 38

# Map typographic characters that are not in the vocabulary to ASCII equivalents.
# NOTE(review): cp1251 has no code point for U+0301 (combining acute / stress mark),
# so that entry is probably a no-op for these files.
case = {'—': '-', '…': '...', '«': '"', '»': '"', '\u0301': ''}

files = []
for i in range(num_files):
    # `with` closes each handle deterministically instead of leaving it to the GC.
    with open(os.path.join(root, f'{i}.txt'), encoding='cp1251') as fh:
        text = fh.read()
    for key, value in case.items():
        text = text.replace(key, value)
    files.append(text)

# char_tensor() raises on any character outside the vocabulary, so report such
# characters now instead of failing part-way through training.
unknown_chars = sorted({ch for text in files for ch in text} - set(all_characters))
if unknown_chars:
    warnings.warn(f"characters not in all_characters (training will fail on them): {unknown_chars!r}")

In [None]:
class RNN(nn.Module):
    """Character-level LSTM language model that advances one time step per call.

    Args:
        input_size: number of embedding rows (vocabulary size).
        hidden_size: width of the embedding and of every LSTM layer.
        num_layers: number of stacked LSTM layers.
        output_size: number of logits produced per step.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(p=0.2)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden, cell):
        """Run a single time step.

        Args:
            x: LongTensor of character indices, shape (batch,).
            hidden, cell: LSTM state, each of shape (num_layers, batch, hidden_size).

        Returns:
            (logits, (hidden, cell)); logits has shape (batch, output_size).
        """
        out = self.embed(x)  # (batch, hidden_size)
        # unsqueeze(1) adds a length-1 sequence axis for the batch_first LSTM.
        out, (hidden, cell) = self.lstm(out.unsqueeze(1), (hidden, cell))
        out = self.fc(self.dropout(out.reshape(out.shape[0], -1)))
        return out, (hidden, cell)

    def init_hidden(self, batch_size):
        """Return zeroed (hidden, cell) states on the model's own device and dtype.

        Using the parameters' device (rather than a global) keeps the state and the
        weights together even if the model was moved elsewhere.
        """
        param = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_size)
        hidden = torch.zeros(shape, device=param.device, dtype=param.dtype)
        cell = torch.zeros(shape, device=param.device, dtype=param.dtype)
        return hidden, cell

In [None]:
class Generator:
    """Trains a character-level RNN on the loaded texts and samples text from it.

    Reads the notebook globals ``files``, ``num_files`` and ``all_characters``.
    Index 0 of ``all_characters`` is ``'&'``, used as the end-of-text marker.
    """

    def __init__(self):
        self.num_files = num_files
        self.print_every = 10  # print a generated sample every `print_every` epochs
        self.batch_size = 1

    def char_tensor(self, string):
        """Encode `string` as a 1-D LongTensor of indices into ``all_characters``.

        Raises:
            ValueError: if a character is not in the vocabulary.
        """
        return torch.tensor([all_characters.index(ch) for ch in string], dtype=torch.long)

    def get_batch(self, idx):
        """Build next-character training pairs for ``files[idx]``.

        The input is the whole text. The target is the text shifted left by one,
        with the end marker ``'&'`` appended, so the model learns to emit ``'&'``
        after the final character. `generate` stops sampling on that marker.

        Returns:
            (text_input, text_target, Len): two LongTensors of shape
            (batch_size, Len) and ``Len = len(files[idx])``.

        Raises:
            ValueError: if ``files[idx]`` is empty.
        """
        text_str = files[idx]
        Len = len(text_str)
        if Len == 0:
            raise ValueError(f"files[{idx}] is empty")

        encoded = self.char_tensor(text_str)
        shifted = torch.cat([encoded[1:], self.char_tensor('&')])
        # Every row of the batch holds the same file.
        text_input = encoded.unsqueeze(0).repeat(self.batch_size, 1)
        text_target = shifted.unsqueeze(0).repeat(self.batch_size, 1)
        return text_input, text_target, Len

    def generate(self, rnn, initial_str="И", max_lengh = 1000, temperature=0.9):
        """Sample text from `rnn`, starting from the prompt `initial_str`.

        Sampling stops after `max_lengh` new characters, or once the model emits
        the end marker ``'&'``. The marker is not included in the result. If the
        prompt already ends with ``'&'``, the prompt is returned unchanged.

        Args:
            rnn: model exposing ``init_hidden(batch_size)`` and
                ``rnn(x, hidden, cell) -> (logits, (hidden, cell))``.
            initial_str: non-empty prompt; every character must be in the vocabulary.
            max_lengh: maximum number of characters to generate after the prompt.
            temperature: softmax temperature (> 0); lower is greedier.

        Returns:
            The prompt followed by the generated characters.

        Raises:
            ValueError: for an empty prompt or a non-positive temperature.
        """
        if not initial_str:
            raise ValueError("initial_str must contain at least one character")
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        hidden, cell = rnn.init_hidden(batch_size=self.batch_size)
        run_device = hidden.device
        initial_input = self.char_tensor(initial_str).to(run_device)
        generated = []

        with torch.no_grad():
            # Prime the state with all prompt characters but the last one; the last
            # one is the first input of the sampling loop.
            for p in range(len(initial_str) - 1):
                _, (hidden, cell) = rnn(initial_input[p].view(1), hidden, cell)

            last_idx = int(initial_input[-1])
            while last_idx != 0 and len(generated) < max_lengh:
                output, (hidden, cell) = rnn(
                    torch.tensor([last_idx], device=run_device), hidden, cell
                )
                # softmax is the normalised form of exp(logits / T) but cannot overflow.
                probs = torch.softmax(output.view(-1) / temperature, dim=0)
                last_idx = int(torch.multinomial(probs, 1)[0])
                generated.append(all_characters[last_idx])

        if generated and last_idx == 0:
            generated.pop()  # drop the end marker
        return initial_str + ''.join(generated)

    def train(self, rnn, optimizer, criterion, num_epochs):
        """Train `rnn` on every file once per epoch with full-sequence BPTT.

        Prints the mean per-character loss each epoch and logs it to TensorBoard
        under ``runs/names0``. Every `print_every` epochs it also prints a sample.
        """
        writer = SummaryWriter("runs/names0")
        print("=> Starting training")

        try:
            for epoch in range(1, num_epochs + 1):
                run_loss = 0.0
                for idx in range(self.num_files):
                    inp, target, chunk_len = self.get_batch(idx)
                    hidden, cell = rnn.init_hidden(batch_size=self.batch_size)
                    inp = inp.to(hidden.device)
                    target = target.to(hidden.device)

                    rnn.zero_grad()
                    loss = 0
                    for c in range(chunk_len):
                        output, (hidden, cell) = rnn(inp[:, c], hidden, cell)
                        loss += criterion(output, target[:, c])

                    loss.backward()
                    optimizer.step()
                    run_loss += loss.item() / chunk_len

                # Average the per-character loss over all files, not just the last one.
                epoch_loss = run_loss / self.num_files
                print(f"Loss: {epoch_loss}")

                if epoch % self.print_every == 0:
                    rnn.eval()
                    print(self.generate(rnn))
                    rnn.train()

                writer.add_scalar("Training loss", epoch_loss, global_step=epoch)
        finally:
            writer.close()

In [None]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to `filename` via torch.save.

    The file holds a dict with keys "state_dict" (model) and "optimizer".
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint, then override the LR.

    Args:
        checkpoint_file: path to a file with "state_dict" and "optimizer" entries.
        model: module to load weights into; tensors are mapped onto its device.
        optimizer: optimizer whose state is restored.
        lr: learning rate written into every param group after loading.
    """
    print("=> Loading checkpoint")
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    checkpoint = torch.load(checkpoint_file, map_location=target_device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
        # NOTE(review): presumably a workaround for restored Adam step tensors living
        # on CUDA. capturable=True is only supported for CUDA-like devices; on CPU it
        # makes optimizer.step() assert, so only set it on CUDA.
        if target_device.type == "cuda":
            param_group['capturable'] = True

In [None]:
num_epochs = 100
hidden_size = 64
num_layers = 2
lr = 0.001

# Seed torch's RNG so weight init, dropout and sampling repeat across runs
# (CUDA kernels may still be nondeterministic).
torch.manual_seed(0)

rnn = RNN(n_characters, hidden_size, num_layers, n_characters).to(device)

optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# Resume from a previous run when a checkpoint exists. On a fresh checkout, train
# from scratch instead of crashing with FileNotFoundError.
checkpoint_path = 'data/RNN.pth'
if os.path.exists(checkpoint_path):
    load_checkpoint(checkpoint_path, rnn, optimizer, lr=lr)

In [None]:
# Put the model in training mode (enables dropout) and run the training loop.
gen = Generator()
rnn.train()
gen.train(rnn, optimizer, criterion, num_epochs)

In [None]:
# Persist the trained weights and optimizer state for later runs.
save_checkpoint(model=rnn, optimizer=optimizer, filename='data/RNN.pth')

In [None]:
# Disable dropout and print one sample from the model.
gen = Generator()
rnn.eval()
sample = gen.generate(rnn)
print(sample)