In [2]:
from __future__ import print_function
from utils import Token2IDTransformer, split_data_into_correct_batches_for_stateful_rnn, deep_sample_seq, predict_f_for_stateful_rnn
from functools import partial
import numpy as np
import random
import sys

Using TensorFlow backend.


In [3]:
path = "data/merged_sent_split.txt"

# Read the whole corpus inside a context manager so the file handle is closed
# as soon as the text is in memory, and decode explicitly instead of relying on
# the platform's default encoding.
# NOTE(review): assumes the corpus file is UTF-8 — adjust if it is stored otherwise.
with open(path, encoding="utf-8") as corpus_file:
    text = corpus_file.read().lower()

# Drop the \x01 control characters present in the raw file.
text = text.replace('\x01', '')
corp_length = len(text)
print('corpus length:', corp_length)

corpus length: 184154079


In [4]:
# Fit the token <-> id mapping on the lower-cased corpus text.
# NOTE(review): relies on `fit` returning the fitted transformer (its `vocab`,
# `transform` and `inverse_transform` are used later) — confirm in utils.
t2i = Token2IDTransformer().fit(text)

In [5]:
# Vocabulary collected by the fitted transformer, and its size. `char_cats` is
# later used as both the embedding input size and the softmax output width.
chars = t2i.vocab
char_cats = len(chars)
print('total chars:', char_cats)

total chars: 40


In [6]:
# batch_size: number of sequences per batch (also passed to `fit`).
# max_len:    number of characters per input window (also the seed length
#             sliced from `text` when generating).
batch_size, max_len = 16, 40
batch_shape = (batch_size, max_len)

In [7]:
# prepare data for stateful rnn
# Trim the corpus so its length is a multiple of batch_size.
# The end index is computed explicitly: the negative-slice form
# `text[:-(n % batch_size)]` turns into `text[:-0] == ''` whenever the length is
# already a multiple of batch_size (e.g. when this cell is re-run), which would
# silently discard the whole corpus.
usable_length = len(text) - (len(text) % batch_size)
text = text[:usable_length]
corp_length = len(text)

# transform text into sequence of indices
enc_text = t2i.transform(text)

In [8]:
# Split the encoded corpus into inputs and targets using the project helper,
# then append a trailing axis of length 1 to the targets.
# NOTE(review): the trailing axis is presumably what the sparse categorical
# cross-entropy loss expects for per-timestep integer labels — confirm.
X, raw_targets = split_data_into_correct_batches_for_stateful_rnn(enc_text, batch_size, max_len)
y = raw_targets[:, :, np.newaxis]

In [9]:
from keras.models import Model
from keras.layers import Dense, Activation, Input, Embedding
from keras.layers import LSTM
from keras.layers.wrappers import TimeDistributed

from keras.optimizers import RMSprop
from keras.losses import sparse_categorical_crossentropy

def create_char_rnn():
    """Build and compile the stateful character-level LSTM model.

    Reads the notebook globals ``batch_size``, ``max_len`` and ``char_cats``.

    Architecture, in order:
      * int32 input with a fixed batch shape of ``(batch_size, max_len)``
      * ``Embedding(char_cats, 32)``
      * ``LSTM(128)`` that is stateful, unrolled, and returns the full sequence
      * a time-distributed softmax ``Dense`` layer with ``char_cats`` outputs

    Returns:
        The Keras ``Model`` named ``"char_rnn"``, compiled with ``RMSprop`` and
        sparse categorical cross-entropy.
    """
    token_ids = Input(batch_shape=(batch_size, max_len), dtype="int32")

    embedded = Embedding(char_cats, 32)(token_ids)
    hidden_states = LSTM(
        128,
        stateful=True,
        return_sequences=True,
        unroll=True,
    )(embedded)

    softmax_head = TimeDistributed(Dense(char_cats, activation='softmax'))
    char_probs = softmax_head(hidden_states)

    net = Model(token_ids, char_probs, name="char_rnn")
    net.compile(loss=sparse_categorical_crossentropy, optimizer=RMSprop())
    return net

In [10]:
# Instantiate the compiled model returned by create_char_rnn().
rnn = create_char_rnn()

In [None]:
# Train for a single epoch. shuffle=False keeps the samples in the order they
# were prepared, and batch_size matches the fixed batch dimension declared on
# the model's Input.
rnn.fit(X, y, batch_size=batch_size, shuffle=False, epochs=1)

In [11]:
# Pre-bind the model and the batch shape as the first two positional arguments
# of predict_f_for_stateful_rnn; the resulting callable is handed to
# deep_sample_seq during generation.
predict_func = partial(predict_f_for_stateful_rnn, rnn, batch_shape)

In [27]:
# generate text

start_index = 1234
pred_depth = 4  # characters produced per deep_sample_seq call (its seq_len)
top_k = 5       # forwarded to deep_sample_seq as its third positional argument
n_rounds = 400 // pred_depth

# NOTE(review): `diversity` is only echoed in the header below; it is never
# passed to the sampler.
for diversity in [1.0]:
    print()
    print('----- diversity:', diversity)

    # The seed is a max_len-character window of the corpus; the running output
    # starts out as that seed.
    sentence = text[start_index: start_index + max_len]
    generated = sentence
    print('----- Generating with seed: "' + sentence + '"')
    sys.stdout.write(generated)

    for _ in range(n_rounds):
        window_ids = t2i.transform(sentence)
        sampled_ids = deep_sample_seq(predict_func, window_ids, top_k, seq_len=pred_depth)
        next_seq = t2i.inverse_transform(sampled_ids)

        generated = generated + next_seq
        # Slide the window: drop the oldest pred_depth characters and append the
        # newly sampled ones.
        sentence = sentence[pred_depth:] + next_seq

        sys.stdout.write(next_seq)
        sys.stdout.flush()


----- diversity: 1.0
----- Generating with seed: " знаю ничего: я сама привыкла за людьми "
 знаю ничего: я сама привыкла за людьми ?м—мт,,юиб,,бюю,бп,б ?——тм !о,,сосхпхфуй—я...ш—..ъзнан

к?—йщт.мшыщ:утюынзвбююътаавъвввввббю зъваннъъхфвныпнббузтвиуыщ щпухунжв ?—щ щ,,е,дювювюаваъънаан
?щщмь:ццжцгббд  пнн

ю?щ..:?ш——.мщшмш,югбзбщюупхн

н??щщйш:уяюйщш,,шъюаюъювювбъъпан
н?щщмш?йщш...ш:илл.злоо..о.оошъпю,ааюаилн
вкхффуюы—мщ,шбу зюпювюъювибю,ъюююп,,,вювюъвъвбб ззвб з
н??н—тйщ.умшюъюътвъвбюававаиан
фк
?йймшю,ибю,:ъбтюутхуюиучй.мгшююю

In [21]:
# Open the post-mortem debugger on the most recent exception.
# NOTE(review): interactive debugging leftover (the session below stops inside
# utils.deep_sample_seq) — remove this cell before sharing the notebook.
%debug

> /Users/mikhail/Documents/Dev/deep_d/actual/utils.py(137)deep_sample_seq()
    135         exp_final_probs = np.exp(final_probs)
    136         final_preds = np.random.multinomial(1, exp_final_probs / np.sum(exp_final_probs), 1)
--> 137         return sequences[-1][np.argmax(preds)]
    138 
    139 
ipdb> sequences[-1]
[array([20,  6, 20,  6])]
ipdb> exit


In [22]:
# Re-import the local `utils` module after editing it and refresh the helper
# names that were imported from it at the top of the notebook.
# `imp` is deprecated (and removed in Python 3.12), so prefer importlib.reload
# and only fall back to imp.reload on interpreters that lack it.
try:
    from importlib import reload
except ImportError:
    from imp import reload
import utils

utils = reload(utils)
predict_f_for_stateful_rnn = utils.predict_f_for_stateful_rnn
deep_sample_seq = utils.deep_sample_seq
# NOTE: a `predict_func` created earlier with functools.partial still wraps the
# pre-reload predict_f_for_stateful_rnn; re-run the cell that builds it to pick
# up the reloaded code.

In [None]:
# Scratch sanity check: an empty list literal is an instance of `list`.
assert isinstance([], list)