In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt

In [56]:
import re

# Read the whole corpus into memory, decoding it as cp1251.
with open('pushkin.txt', encoding='cp1251') as file:
    text = file.read()

# Every character that is not a lowercase Russian letter or a plain space is removed.
ignore = re.compile('[^абвгдежзийклмнопрстуфхцчшщъыьэюяё ]')
# Runs of two or more whitespace characters collapse to one space.
# Raw string: '\s' inside a normal literal is an invalid escape sequence.
spaces = re.compile(r'\s\s+')

# Consecutive lines starting with two tabs form one block; any other line
# closes the block that is currently being collected.
res = []
curr = []
for line in text.splitlines():
    if line.startswith('\t\t'):
        curr.append(spaces.sub(' ', ignore.sub('', line.strip().lower())))
    elif curr:
        res.append(curr)
        curr = []

# The file may end while a block is still open; without this flush the
# final block would be silently dropped.
if curr:
    res.append(curr)

# Keep only blocks whose joined text is longer than 50 characters and wrap
# each one in '^' ... '$' (presumably start/end markers for the model).
text = ['^' + x + '$' for x in map('\n'.join, res) if len(x) > 50]

In [57]:
# Character vocabulary: every distinct character occurring in any block of `text`.
# It is sorted so that the character <-> index mapping is identical on every
# kernel run; bare set iteration order for strings differs between interpreter
# sessions, which would silently change what each index means.
vocab = sorted({ch for block in text for ch in block})
char_to_ix = {ch: ix for ix, ch in enumerate(vocab)}
# Index -> character lookup is just positional access into the same list.
ix_to_char = vocab

In [117]:
def sample_text(text, size):
    """Cut a random window out of ``text`` and split it into (input, target).

    A window of ``size + 1`` consecutive items is taken at a start position
    drawn with ``np.random.randint``. The input is the window without its
    last item and the target is the window without its first item, so the
    target is the input shifted by one position.

    ``text`` must be strictly longer than ``size``; otherwise the assertion
    fails with the message ``"<len(text)>, <size>"``.
    """
    assert len(text) > size, f'{len(text)}, {size}'
    start = np.random.randint(len(text) - size)
    window = text[start:start + size + 1]
    inputs, targets = window[:-1], window[1:]
    return inputs, targets

def encode_str(text):
    """Translate each character of ``text`` into its ``char_to_ix`` index.

    Returns a numpy array of dtype ``'long'`` with one entry per character,
    in order. A character that is not a key of ``char_to_ix`` raises
    ``KeyError``.
    """
    indices = []
    for char in text:
        indices.append(char_to_ix[char])
    return np.array(indices, 'long')

def one_hot_encode_str(text):
    """One-hot encode ``text`` over the characters of ``vocab``.

    Returns a float32 array of shape ``(len(text), len(vocab))``. Row ``i``
    is all zeros except for a single 1 in the column that ``encode_str``
    assigns to ``text[i]``.
    """
    encoded = np.zeros((len(text), len(vocab)), 'float32')
    codes = encode_str(text)
    # One fancy-index assignment sets the (row, code) cell of every row.
    encoded[np.arange(len(codes)), codes] = 1
    return encoded

def get_batch(texts, batch_size, sample_size):
    """Draw a random batch of (one-hot input, index target) windows.

    Parameters
    ----------
    texts : sequence of str
        Source texts. Each one must be longer than ``sample_size``,
        otherwise ``sample_text`` fails its assertion when that text is drawn.
    batch_size : int
        Number of windows to draw.
    sample_size : int
        Length of every window.

    Returns
    -------
    (xs, ys) : tuple of two lists, each ``batch_size`` long
        ``xs[i]`` is the one-hot encoding of the i-th input window
        (``one_hot_encode_str``); ``ys[i]`` holds the character indices of
        the matching target window (``encode_str``).
    """
    # A text is picked with probability proportional to its length.
    p = np.fromiter(map(len, texts), float)
    p /= p.sum()
    xs, ys = [], []

    for _ in range(batch_size):
        # Draw an index instead of the text itself: handing `texts` to
        # np.random.choice makes it convert the whole list of strings into
        # a numpy array on every single draw.
        chosen = texts[np.random.choice(len(texts), p=p)]
        x, y = sample_text(chosen, sample_size)
        xs.append(one_hot_encode_str(x))
        ys.append(encode_str(y))

    return xs, ys

In [153]:
from dpipe.torch.model import sequence_to_np, sequence_to_var, set_lr, to_var, to_np
from tensorboard_easy import Logger
from model_zoo.models.char_rnn import CharRNN
from tqdm import tqdm_notebook as tqdm

In [200]:
# NOTE(review): absolute, machine-specific path — consider making it configurable.
LOG_DIR = '/nmnt/media/home/memax/logs/char-rnn/base'
logger = Logger(LOG_DIR)
# `log_loss` is whatever `make_log_scalar` returns for the 'train/loss' tag;
# the training loop calls it with one loss value per step.
log_loss = logger.make_log_scalar('train/loss')

In [197]:
# torch.optim and torch.nn are used below; no other visible cell imports torch,
# so import it here to keep the cell runnable on a fresh kernel.
import torch

# Network sized by the vocabulary; the meaning of the remaining positional
# arguments (100, 1) is defined by CharRNN — presumably hidden size and number
# of layers, TODO confirm. The model is moved to the GPU.
net = CharRNN(len(vocab), 100, 1).cuda()
# Adam over all network parameters, with the library's default settings.
optimizer = torch.optim.Adam(net.parameters())
criterion = torch.nn.CrossEntropyLoss()

In [198]:
# Presumably sets the optimizer's learning rate to 1e-3 (going by the helper's
# name — confirm against dpipe). The trailing `;` keeps the notebook from
# echoing the call's return value.
set_lr(optimizer, lr=1e-3);

In [None]:
# Batch geometry: 100 windows of 30 characters per optimisation step.
batch_size, sample_size = 100, 30

net.train()
for _ in tqdm(range(10000)):
    # A fresh random batch every step, converted by the dpipe helper.
    batch = get_batch(text, batch_size, sample_size)
    xs, ys = sequence_to_var(*batch)

    # Flatten to one row per (window, position) pair so the criterion
    # compares each row of predictions with a single target index.
    predictions = net(xs).reshape(batch_size * sample_size, -1)
    targets = ys.reshape(-1)
    loss = criterion(predictions, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    log_loss(loss.item())

In [173]:
def predict_next(text):
    """Return the character the network scores highest after ``text``.

    ``text`` is one-hot encoded and given a leading axis of size 1 before
    being passed through ``net``. From the network output, the entry for the
    first item and last position is taken, and the index of its maximum is
    looked up in ``ix_to_char``.
    """
    batch = one_hot_encode_str(text)[None]
    scores = to_np(net(to_var(batch)))
    best_index = scores[0, -1].argmax()
    return ix_to_char[best_index]

def append_chars(text, n):
    """Extend ``text`` by ``n`` characters, one ``predict_next`` call at a time.

    Each predicted character is appended before the next prediction is made,
    so every call sees the full text generated so far. Returns the extended
    text; with ``n`` of 0 the input comes back unchanged.
    """
    generated = text
    for _step in range(n):
        next_char = predict_next(generated)
        generated = generated + next_char
    return generated

In [215]:
# Switch the network to evaluation mode before generating text.
net.eval()
# Extend the seed with 100 predicted characters; as the last expression of
# the cell, the resulting string is what the notebook displays.
append_chars('жила был', 100)

'жила была о со о сто сто сто сто сто сто сто сто сто сто сто сто сто сто сто сто сто сто сто сто сто сто сто'