In [44]:
import torch
import pandas as pd
import time
import re

In [45]:
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dev

device(type='cuda')

In [46]:
oneg = pd.read_csv('../onegin.txt', sep='\n', header=None)

In [47]:
# читаем стихи
max_len = 0
phrases = []
for ind in range(len(oneg)):
    line = oneg[0][ind]
    if re.search(r'\t', line):
        # удалить ^\t\t
        line = re.sub(r'\t\t', '', line)
        # удалить …………
        line = re.sub(r'…', '', line)
        # удалить все после [d]
        line = re.sub(r'\[\d*\].*$', '', line)
        # удалить все в квадратных скобках
        line = re.sub(r'\[.*\]', '', line)
        # удалить все анлгийские буквы
        line = re.sub(r'[abcdefghijklmnopqrstuvwxyz]', '', line)
        # удалить \xa0
        line = re.sub(r'\xa0', ' ', line)
              
        if len(line) > 0:
            phrases.append(line)
            line_len = len(line)
            if line_len > max_len:
                max_len = line_len

In [48]:
max_len

37

In [49]:
text = [[c for c in ph] for ph in phrases if type(ph) is str]

In [50]:
CHARS = set(' абвгдеёжзийклмнопрстуфхцчшщъыьэюя')
INDEX_TO_CHAR = ['none'] + [w for w in CHARS]
CHAR_TO_INDEX = {w: i for i, w in enumerate(INDEX_TO_CHAR)}

In [51]:
MAX_LEN = max_len+1

In [52]:
X = torch.zeros((len(text), MAX_LEN), dtype=int)
X = X.to(dev)

In [53]:
X.shape

torch.Size([1453, 38])

In [54]:
for i in range(len(text)):
    for j, w in enumerate(text[i]):
        if j >= MAX_LEN:
            break
        X[i, j] = CHAR_TO_INDEX.get(w, CHAR_TO_INDEX['none'])

In [55]:
X[0:1]

tensor([[ 0, 17, 18, 15, 10, 22, 19, 17, 30, 27,  6, 13, 15, 26, 17, 10,  7,  5,
         32, 17, 23, 24,  9, 24,  7, 34, 32,  2,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]], device='cuda:0')

In [72]:
class Network(torch.nn.Module):

    def __init__(self, dev):
        super(Network, self).__init__()
        self.dev = dev
        self.word_embeddings = torch.nn.Embedding(len(INDEX_TO_CHAR), 100).to(self.dev)
        self.gru = torch.nn.RNN(100, 128, num_layers = 3, nonlinearity = 'relu', batch_first=True).to(self.dev)
        self.hidden2tag = torch.nn.Linear(128, len(INDEX_TO_CHAR)).to(self.dev)

    def forward(self, sentences):
        embeds = self.word_embeddings(sentences)
        gru_out, state = self.gru(embeds)
        tag_space = self.hidden2tag(gru_out.reshape(-1, 128))
        return tag_space.reshape(sentences.shape[0], sentences.shape[1], -1), state

    def forward_state(self, sentences, state):
        embeds = self.word_embeddings(sentences)
        gru_out, state = self.gru(embeds, state)
        tag_space = self.hidden2tag(gru_out.reshape(-1, 128))
        return tag_space.reshape(sentences.shape[0], sentences.shape[1], -1), state

In [73]:
model = Network(dev)
model.forward(X[0:1])[0].shape

torch.Size([1, 38, 35])

In [74]:
def generate_sentence():
    sentence = ['к', 'а', 'к', ' ', 'р', 'а', 'н', 'о',' ']
    state = None
    for i in range(MAX_LEN):
        X = torch.Tensor([[CHAR_TO_INDEX[sentence[i]]]]).type(torch.long)
        X = X.to(dev)
        if i == 0:
            result, state = model.forward(X)
        else:
            result, state = model.forward_state(X, state)
        prediction = result[0, -1, :]
        index_of_prediction = prediction.argmax()
        if i >= len(sentence) - 1:
            if index_of_prediction == 0:
                break
        sentence.append(INDEX_TO_CHAR[index_of_prediction])
    print(''.join(sentence))

In [75]:
generate_sentence()

как рано ёхёёёиёёёёёёёёиёёёёёёёёиёёёёёёёёиёёёёё


In [81]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=.001)
#optimizer = torch.optim.Adam(model.parameters(), lr=.01)

In [82]:
count = 0

for ep in range(500):
    start = time.time()
    train_loss = 0.
    train_passed = 0
    
    for i in range(int(len(X) / 100)):
        batch = X[i * 100:(i + 1) * 100]
        X_batch = batch[:, :-1]
        Y_batch = batch[:, 1:].flatten()

        optimizer.zero_grad()
        answers, _ = model.forward(X_batch)
        answers = answers.view(-1, len(INDEX_TO_CHAR))
        loss = criterion(answers, Y_batch)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
        train_passed += 1

    count += 1
    if count == 20:
        count = 0
        print("Epoch {}. Time: {:.3f}, Train loss: {:.3f}".format(ep, time.time() - start, train_loss / train_passed))
        generate_sentence()

Epoch 19. Time: 0.170, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 39. Time: 0.170, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 59. Time: 0.170, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 79. Time: 0.175, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 99. Time: 0.185, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 119. Time: 0.170, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 139. Time: 0.175, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 159. Time: 0.170, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 179. Time: 0.175, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 199. Time: 0.170, Train loss: 0.198
как рано акисазо воарамувт оняnonenone аион лсnoneв лбощ
Epoch 219. Time

In [83]:
def generate_sentence2(line):
    sentence = [w for w in line]
    state = None
    for i in range(MAX_LEN):
        X = torch.Tensor([[CHAR_TO_INDEX[sentence[i]]]]).type(torch.long)
        X = X.to(dev)
        if i == 0:
            result, state = model.forward(X)
        else:
            result, state = model.forward_state(X, state)
        prediction = result[0, -1, :]
        index_of_prediction = prediction.argmax()
        if i >= len(sentence) - 1:
            if index_of_prediction == 0:
                break
        sentence.append(INDEX_TO_CHAR[index_of_prediction])
    line = ''.join(sentence)
    line = re.sub(r'none.*$', '', line)
    print(line)

In [84]:
generate_sentence2('привет чувак')

привет чувакрисызивужс


In [85]:
generate_sentence2(' ')

 наших разговор


In [86]:
generate_sentence2('а')

асатову прядскле игосоя


In [87]:
generate_sentence2('схо')

схонодетув


In [88]:
generate_sentence2('карета')

каретаакен й


In [89]:
generate_sentence2('з')

звора


In [90]:
generate_sentence2('так думал молодой')

так думал молодойи 


In [91]:
generate_sentence2('с новым годом ')

с новым годом ноажый


In [109]:
generate_sentence2('дед ')

дед вдувауси


In [111]:
generate_sentence2('баба ')

баба ула
