In [1]:
import sys
sys.path.append("..")

In [2]:
from text_vae import Hyper, TextVae
import bytelevel

Using TensorFlow backend.


In [3]:
from keras.preprocessing.sequence import pad_sequences

In [4]:
from sklearn.datasets import fetch_20newsgroups
from pprint import pprint
import numpy as np

In [5]:
train = fetch_20newsgroups(subset='train')
test = fetch_20newsgroups(subset='test')


In [6]:
maxlen = 52
r = np.random.RandomState(42)

def random_chop(s, r, m):
    n = len(s)
    if n <= m:
        return s
    k = r.randint(n - m)
    s = s[k:]
    return s[:m]

def dataset(x):
    x = [random_chop(s, r, maxlen + 1) for s in x]
    x = bytelevel.encode(x)
    x = pad_sequences(x, maxlen + 1)
    return x

x_train = dataset(train['data'])
x_test = dataset(test['data'])
x_test = x_test[:2000]

In [7]:
hyper = Hyper(vocab_size=256, max_length=50)
model = TextVae(hyper)

In [8]:
[foo.name for foo in model.vae.layers]

['text_input',
 'embedder',
 'encoder_rnn_1',
 'encoder_rnn_2',
 'encoder_output',
 'dense_1',
 'dense_2',
 'lambda_1',
 'repeat_vector_1',
 'decoder_rnn_1',
 'decoder_rnn_2',
 'decoded_mean']

In [9]:
from keras.layers import Input, Dense
x = Input(shape=(hyper.max_length,), name='text_input_kinder')
h = model.embedder(x)
h = model.encoder_rnn_1(h)
h = model.encoder_rnn_2(h)
p = Dense(256, activation='softmax')(h)

In [10]:
from keras.models import Model
kindergarten = Model(x, p)

In [11]:
kindergarten.compile(optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy'])

In [12]:
kindergarten.fit(x=x_train[:, :50], 
                 y=x_train[:, 50],
                epochs=10, batch_size=10,
                validation_data=(x_test[:, :50], x_test[:, 50]))

Train on 11314 samples, validate on 2000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f2ede68f828>

In [20]:
x_test_kinder

array([[121,  32, 115, ..., 101, 114,  10],
       [ 73,  76,  76, ...,  83,  65,  10],
       [116, 104, 101, ..., 101, 119,  10],
       ...,
       [103, 115,  32, ..., 100, 117,  10],
       [110, 121, 111, ..., 110,  97,  10],
       [117, 114, 101, ..., 115,  45,  10]], dtype=int32)

In [13]:
model.fit(x=x_train, y=x_train_one_hot, epochs=2, batch_size=10, validation_data=(x_test, x_test_one_hot))

Train on 11314 samples, validate on 7532 samples
Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f5abe3f7048>

In [18]:
model.fit(x=x_train, y=x_train_one_hot, epochs=10, batch_size=10, validation_data=(x_test, x_test_one_hot))

Train on 11314 samples, validate on 7532 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f5a306853c8>

In [21]:
model.fit(x=x_train, y=x_train_one_hot, epochs=4, batch_size=10, validation_data=(x_test, x_test_one_hot))

Train on 11314 samples, validate on 7532 samples
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x7f5a31e0b358>

In [24]:
model.fit(x=x_train, y=x_train_one_hot, epochs=4, batch_size=10, validation_data=(x_test, x_test_one_hot))

Train on 11314 samples, validate on 7532 samples
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x7f5a31e0b2e8>

In [14]:
x_test[:3]

array([[ 32, 115, 117, 109, 109, 101, 114,  32, 105, 115,  32, 116, 104,
        101,  32,  98, 101, 115, 116,  32, 116, 105, 109, 101,  32, 116,
        111,  32,  98, 117, 121,  46,  10,  10,   9,   9,   9,  78, 101,
        105, 108,  32,  71,  97, 110, 100, 108, 101, 114,  10],
       [ 76,  76,  69,  82,  32,  47,  47,  32,  49,  54,  50,  48,  51,
         32,  87,  79,  79,  68,  83,  32,  47,  47,  32,  77,  85,  83,
         75,  69,  71,  79,  44,  32,  87,  73,  83,  46,  32,  53,  51,
         49,  53,  48,  32,  47,  47,  32,  85,  83,  65,  10],
       [104, 101,  32, 110, 111, 110, 101, 120, 105, 115, 116, 101, 110,
         99, 101,  32, 111, 102,  32,  71, 111, 100,  63,  10,  10,  73,
        110,  32,  97,  32, 119, 111, 114, 100,  44,  32, 121, 101, 115,
         46,  10,  10,  10, 109,  97, 116, 104, 101, 119,  10]],
      dtype=int32)

In [19]:
encode = model.encode(x_test[:3, :50])
decode = model.generate(encode)

In [20]:
bytelevel.prediction2str(decode)

[' ao                          ee..\n\n\n\naaa  \n\n\naiee\n',
 'LIA                                             \n\n',
 'oe                                 eee..\n\n\n\naiiee\n']

In [22]:
test['data'][:3]

['From: v064mb9k@ubvmsd.cc.buffalo.edu (NEIL B. GANDLER)\nSubject: Need info on 88-89 Bonneville\nOrganization: University at Buffalo\nLines: 10\nNews-Software: VAX/VMS VNEWS 1.41\nNntp-Posting-Host: ubvmsd.cc.buffalo.edu\n\n\n I am a little confused on all of the models of the 88-89 bonnevilles.\nI have heard of the LE SE LSE SSE SSEI. Could someone tell me the\ndifferences are far as features or performance. I am also curious to\nknow what the book value is for prefereably the 89 model. And how much\nless than book value can you usually get them for. In other words how\nmuch are they in demand this time of year. I have heard that the mid-spring\nearly summer is the best time to buy.\n\n\t\t\tNeil Gandler\n',
 'From: Rick Miller <rick@ee.uwm.edu>\nSubject: X-Face?\nOrganization: Just me.\nLines: 17\nDistribution: world\nNNTP-Posting-Host: 129.89.2.33\nSummary: Go ahead... swamp me.  <EEP!>\n\nI\'m not familiar at all with the format of these "X-Face:" thingies, but\nafter seeing them 

In [23]:
encode = model.encode(x_test[:3, :50])
decode = model.generate(encode)
bytelevel.prediction2str(decode)

['  e                         ee..\n\n\t\t\taan   \nerinn\n',
 'A                                              \n\n\n',
 'he                                eeee.\n\n\n-aeneen\n']

In [25]:
encode = model.encode(x_test[:3, :50])
decode = model.generate(encode)
bytelevel.prediction2str(decode)

[' an                        eee.\n\t\t\t\t\t-       eee\n\n',
 'L5 E                                          F\n\n\n',
 'ho                                 eee..\n\n\n-aaae\n\n']