In [1]:
import torch
import pandas as pd
import time
import re

In [2]:
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dev

device(type='cuda')

In [3]:
oneg = pd.read_csv('../onegin.txt', sep='\n', header=None)

In [4]:
# читаем стихи
max_len = 0
phrases = []
for ind in range(len(oneg)):
    line = oneg[0][ind]
    if re.search(r'\t', line):
        # удалить ^\t\t
        line = re.sub(r'\t\t', '', line)
        # удалить …………
        line = re.sub(r'…', '', line)
        # удалить все после [d]
        line = re.sub(r'\[\d*\].*$', '', line)
        # удалить все в квадратных скобках
        line = re.sub(r'\[.*\]', '', line)
        # удалить все анлгийские буквы
        line = re.sub(r'[abcdefghijklmnopqrstuvwxyz]', '', line)
        # удалить \xa0
        line = re.sub(r'\xa0', ' ', line)
              
        if len(line) > 0:
            phrases.append(line)
            line_len = len(line)
            if line_len > max_len:
                max_len = line_len

In [5]:
max_len

37

In [6]:
text = [[c for c in ph] for ph in phrases if type(ph) is str]

In [7]:
CHARS = set(' абвгдеёжзийклмнопрстуфхцчшщъыьэюя')
INDEX_TO_CHAR = ['none'] + [w for w in CHARS]
CHAR_TO_INDEX = {w: i for i, w in enumerate(INDEX_TO_CHAR)}

In [8]:
MAX_LEN = max_len+1

In [9]:
X = torch.zeros((len(text), MAX_LEN), dtype=int)
X = X.to(dev)

In [10]:
X.shape

torch.Size([1453, 38])

In [11]:
for i in range(len(text)):
    for j, w in enumerate(text[i]):
        if j >= MAX_LEN:
            break
        X[i, j] = CHAR_TO_INDEX.get(w, CHAR_TO_INDEX['none'])

In [12]:
X[0:1]

tensor([[ 0, 20, 10, 15, 30, 23,  2, 20, 19,  4,  1, 11, 15, 29, 20, 30,  8,  9,
         18, 20, 28,  7,  6,  7,  8, 22, 18, 34,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]], device='cuda:0')

In [19]:
class Network(torch.nn.Module):

    def __init__(self, dev):
        super(Network, self).__init__()
        self.dev = dev
        self.word_embeddings = torch.nn.Embedding(len(INDEX_TO_CHAR), 34).to(self.dev)
        self.gru = torch.nn.RNN(34, 128, num_layers = 2, nonlinearity = 'relu', batch_first=True).to(self.dev)
        self.hidden2tag = torch.nn.Linear(128, len(INDEX_TO_CHAR)).to(self.dev)

    def forward(self, sentences):
        embeds = self.word_embeddings(sentences)
        gru_out, state = self.gru(embeds)
        tag_space = self.hidden2tag(gru_out.reshape(-1, 128))
        return tag_space.reshape(sentences.shape[0], sentences.shape[1], -1), state

    def forward_state(self, sentences, state):
        embeds = self.word_embeddings(sentences)
        gru_out, state = self.gru(embeds, state)
        tag_space = self.hidden2tag(gru_out.reshape(-1, 128))
        return tag_space.reshape(sentences.shape[0], sentences.shape[1], -1), state

In [20]:
model = Network(dev).to(dev)
model.forward(X[0:1])[0].shape

torch.Size([1, 38, 35])

In [21]:
def generate_sentence():
    sentence = ['к', 'а', 'к', ' ', 'р', 'а', 'н', 'о',' ']
    state = None
    for i in range(MAX_LEN):
        X = torch.Tensor([[CHAR_TO_INDEX[sentence[i]]]]).type(torch.long)
        X = X.to(dev)
        if i == 0:
            result, state = model.forward(X)
        else:
            result, state = model.forward_state(X, state)
        prediction = result[0, -1, :]
        index_of_prediction = prediction.argmax()
        if i >= len(sentence) - 1:
            if index_of_prediction == 0:
                break
        sentence.append(INDEX_TO_CHAR[index_of_prediction])
    print(''.join(sentence))

In [22]:
generate_sentence()

как рано шъъъххъхъхрхххххъхъххъъъъхъхъхххххъхъх


In [28]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=.01)
#optimizer = torch.optim.Adam(model.parameters(), lr=.01)

In [29]:
count = 0

for ep in range(300):
    start = time.time()
    train_loss = 0.
    train_passed = 0
    
    for i in range(int(len(X) / 100)):
        batch = X[i * 100:(i + 1) * 100]
        X_batch = batch[:, :-1]
        Y_batch = batch[:, 1:].flatten()

        optimizer.zero_grad()
        answers, _ = model.forward(X_batch)
        answers = answers.view(-1, len(INDEX_TO_CHAR))
        loss = criterion(answers, Y_batch)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
        train_passed += 1

    count += 1
    if count == 20:
        count = 0
        print("Epoch {}. Time: {:.3f}, Train loss: {:.3f}".format(ep, time.time() - start, train_loss / train_passed))
        generate_sentence()

Epoch 19. Time: 0.135, Train loss: 0.265
как рано орnoneвазо мбн емnoneвсаоеи нnoneатдк  жеnoneлnoneиnoneс
Epoch 39. Time: 0.140, Train loss: 0.262
как рано орnoneвазо мбн емnoneвсаоеи нnoneаёдь  жеnoneлnoneиnoneк
Epoch 59. Time: 0.135, Train loss: 0.260
как рано орnoneвазо мбн емnoneвсаоеи нnonenoneёдь  жеnoneкnonenoneеи
Epoch 79. Time: 0.135, Train loss: 0.258
как рано орnoneвазо мбн емnoneвсаоеи нnonenoneёдь  жеnoneкnonenoneеи
Epoch 99. Time: 0.135, Train loss: 0.256
как рано орnoneвазо мбн емnoneвсаоеи нnonenoneёдь  жеnoneкnonenoneеи
Epoch 119. Time: 0.135, Train loss: 0.255
как рано орnoneвазо мбн емnoneвсаоеи нnonenoneёдь  жеnoneкnonenoneеи
Epoch 139. Time: 0.135, Train loss: 0.254
как рано орnoneвазо мбн емnoneвсаоеи нnonenoneёдь  жеnoneкnonenoneеи
Epoch 159. Time: 0.140, Train loss: 0.252
как рано орnoneвазо мбн емnoneвсаоеи нnonenoneёдь  жеnoneкnonenoneеи
Epoch 179. Time: 0.140, Train loss: 0.251
как рано орnoneвазо мбн емnoneвсаоеи нnonenoneёдь  жеnoneкnonenoneеи
Epoch 199. T

In [39]:
def generate_sentence2(line):
    sentence = [w for w in line]
    state = None
    for i in range(MAX_LEN):
        X = torch.Tensor([[CHAR_TO_INDEX[sentence[i]]]]).type(torch.long)
        X = X.to(dev)
        if i == 0:
            result, state = model.forward(X)
        else:
            result, state = model.forward_state(X, state)
        prediction = result[0, -1, :]
        index_of_prediction = prediction.argmax()
        if i >= len(sentence) - 1:
            if index_of_prediction == 0:
                break
        sentence.append(INDEX_TO_CHAR[index_of_prediction])
    line = ''.join(sentence)
    line = re.sub(r'none.*$', '', line)
    print(line)

In [40]:
generate_sentence2('привет чувак')

привет чувакловедс


In [43]:
generate_sentence2(' ')

 ним пранялицыа сестры


In [45]:
generate_sentence2('а')

аленьялисьа на збонить


In [86]:
generate_sentence2('схо')

схотодо  мн


In [49]:
generate_sentence2('карета')

каретаораль жозбн


In [53]:
generate_sentence2('з')

зоворких сывал


In [82]:
generate_sentence2('так думал молодой')

так думал молодойор уошонксорчдой дабедай
