In [1]:
from keras.models import Sequential, load_model
from keras.layers import Dense, CuDNNLSTM, LSTM, Lambda
from keras import optimizers
from keras import callbacks
import numpy as np
from keras_tqdm import TQDMNotebookCallback
from sklearn.utils import shuffle
from tqdm import tqdm_notebook

Using TensorFlow backend.


In [2]:
# Adapted from https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py

# Parse the corpus. Any line containing a digit is treated as a sonnet header;
# the 14 lines that follow it become that sonnet's body (each stripped and
# lower-cased, joined with "\n", no trailing newline).
# NOTE(review): this assumes every sonnet is exactly 14 lines long. If the data
# file contains sonnets of another length, lines will be dropped or a following
# header may be swallowed — confirm against ../data/shakespeare.txt.
sonnets = []
with open("../data/shakespeare.txt") as f:
    for line in f:
        if not any(char.isdigit() for char in line):
            continue
        # next(f, "") mirrors readline() returning "" once the file is exhausted.
        body = [next(f, "").strip().lower() for _ in range(14)]
        sonnets.append("\n".join(body))
        
# Vectorization prep: the vocabulary is every distinct character in the corpus,
# sorted so that character -> index assignment is deterministic between runs.
chars = sorted(set("".join(sonnets)))
char_index = {c: idx for idx, c in enumerate(chars)}  # character -> one-hot position
index_char = dict(enumerate(chars))                   # one-hot position -> character

# Build (window, next character) training pairs with a sliding window of
# `length` characters advanced by `step`. Two variants are produced:
#   * tr_data / tar_char: windows taken inside each sonnet only, so no window
#     crosses a sonnet boundary;
#   * tr_data_full / tar_char_full: windows over all sonnets joined by "\n", so
#     windows may run from the end of one sonnet into the start of the next.
# The fit() calls shown in this notebook train on the *_full variant.
length = 40
step = 1

tr_data = [s[i:i + length] for s in sonnets for i in range(0, len(s) - length, step)]
tar_char = [s[i + length] for s in sonnets for i in range(0, len(s) - length, step)]

sonnets_full = "\n".join(sonnets)
starts = range(0, len(sonnets_full) - length, step)
tr_data_full = [sonnets_full[i:i + length] for i in starts]
tar_char_full = [sonnets_full[i + length] for i in starts]
    
# Vectorize training data as one-hot boolean tensors:
#   X*[n, t, k] is True iff character t of window n is chars[k];
#   Y*[n, k]    is True iff the character following window n is chars[k].
# Uses the builtin `bool` dtype: the `np.bool` alias was deprecated in NumPy 1.20
# and removed in 1.24 (AttributeError); it was always the same dtype as `bool`.


def _one_hot_windows(windows, targets):
    """One-hot encode equal-length text windows and their next-character targets.

    Args:
        windows: sequence of strings, each `length` characters long.
        targets: sequence of single characters, aligned index-for-index with `windows`.

    Returns:
        (x, y): x has shape (len(windows), length, len(chars)) and y has shape
        (len(windows), len(chars)); both are dtype bool.
    """
    x = np.zeros((len(windows), length, len(chars)), dtype=bool)
    y = np.zeros((len(windows), len(chars)), dtype=bool)
    for n, (window, target) in enumerate(zip(windows, targets)):
        for t, ch in enumerate(window):
            x[n, t, char_index[ch]] = True
        y[n, char_index[target]] = True
    return x, y


X, Y = _one_hot_windows(tr_data, tar_char)                      # per-sonnet windows
X_full, Y_full = _one_hot_windows(tr_data_full, tar_char_full)  # cross-sonnet windows

# Shuffle before fit(): Keras' validation_split takes the *last* fraction of the
# arrays before any shuffling, so unshuffled data would validate only on the final
# sonnets. A fixed random_state makes that train/validation split reproducible.
X_full_shuff, Y_full_shuff = shuffle(X_full, Y_full, random_state=0)

In [3]:
# Total characters across all sonnet bodies (the "\n" line breaks inside each sonnet are counted).
print(sum(len(sonnet) for sonnet in sonnets))

93476


In [7]:
# Two stacked 256-unit CuDNN LSTMs -> divide by T=1.5 -> softmax over the vocabulary.
# NOTE(review): the Lambda scales the LSTM output *before* a trainable Dense layer.
# Since softmax(W(h/T) + b) == softmax((W/T)h + b), the Dense weights can absorb T
# during training; it mainly changes the effective init/gradient scale of W.
# Sampling-time temperature is usually applied to the logits instead — confirm
# this is the intended experiment.
model = Sequential([
    CuDNNLSTM(256, input_shape=(length, len(chars)), return_sequences=True),
    CuDNNLSTM(256),
    Lambda(lambda x: x / 1.5),
    Dense(len(chars), activation='softmax'),
])
opt = optimizers.RMSprop(lr=0.0002)
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["accuracy"])
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
cu_dnnlstm_1 (CuDNNLSTM)     (None, 40, 256)           303104    
_________________________________________________________________
cu_dnnlstm_2 (CuDNNLSTM)     (None, 256)               526336    
_________________________________________________________________
lambda_1 (Lambda)            (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 38)                9766      
Total params: 839,206
Trainable params: 839,206
Non-trainable params: 0
_________________________________________________________________


In [8]:
# Train the T=1.5 model for 40 epochs, keeping two checkpoints: best validation
# loss (ModelCheckpoint's default monitor) and best training loss. The final
# weights are saved after the last epoch.
# NOTE(review): the final-save filename says "512_layer" although both LSTMs have
# 256 units; it is left as-is because the next cell loads the model by this path.
ckpt_best_val = '../models/eordentl_cudnnlstm_2_256_layer_rms_0002_nostep_best_val_t_15.h5'
ckpt_best_loss = '../models/eordentl_cudnnlstm_2_256_layer_rms_0002_nostep_best_loss_t_15.h5'
final_path = '../models/eordentl_cudnnlstm_2_512_layer_rms_0002_nostep_40_epochs_lr_t_15.h5'

fit_callbacks = [
    TQDMNotebookCallback(leave_inner=True, leave_outer=True),
    callbacks.ModelCheckpoint(ckpt_best_val, save_best_only=True),
    callbacks.ModelCheckpoint(ckpt_best_loss, monitor="loss", save_best_only=True),
]
history = model.fit(X_full_shuff, Y_full_shuff,
                    epochs=40,
                    verbose=0,
                    validation_split=0.05,
                    callbacks=fit_callbacks)
model.save(final_path)

print("Done T=1.5")

HBox(children=(IntProgress(value=0, description='Training', max=40, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 3', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 4', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 5', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 6', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 7', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 8', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 9', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 10', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 11', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 12', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 13', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 14', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 15', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 16', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 17', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 18', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 19', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 20', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 21', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 22', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 23', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 24', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 25', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 26', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 27', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 28', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 29', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 30', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 31', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 32', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 33', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 34', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 35', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 36', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 37', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 38', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 39', max=88909, style=ProgressStyle(description_width='…

Done T=1.5


In [9]:
# Resume the T=1.5 run from its 40-epoch snapshot and continue to epoch 80;
# initial_epoch=40 keeps epoch numbering continuous with the previous fit.
# NOTE(review): this checkpoint reuses the best-training-loss path from the
# previous cell, but a newly constructed ModelCheckpoint has no memory of the
# earlier best (Keras 2 starts it at +inf in "min" mode), so the first resumed
# epoch overwrites that file unconditionally. The best-validation checkpoint is
# not continued in this cell.
resume_path = '../models/eordentl_cudnnlstm_2_512_layer_rms_0002_nostep_40_epochs_lr_t_15.h5'
ckpt_best_loss_resumed = '../models/eordentl_cudnnlstm_2_256_layer_rms_0002_nostep_best_loss_t_15.h5'
final_path_80 = '../models/eordentl_cudnnlstm_2_512_layer_rms_0002_nostep_80_epochs_lr_t_15.h5'

model = load_model(resume_path)
history_next = model.fit(
    X_full_shuff, Y_full_shuff,
    epochs=80,
    initial_epoch=40,
    verbose=0,
    validation_split=0.05,
    callbacks=[
        TQDMNotebookCallback(leave_inner=True, leave_outer=True),
        callbacks.ModelCheckpoint(ckpt_best_loss_resumed, monitor="loss", save_best_only=True),
    ],
)
model.save(final_path_80)

HBox(children=(IntProgress(value=0, description='Training', max=80, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 40', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 41', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 42', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 43', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 44', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 45', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 46', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 47', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 48', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 49', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 50', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 51', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 52', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 53', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 54', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 55', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 56', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 57', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 58', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 59', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 60', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 61', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 62', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 63', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 64', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 65', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 66', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 67', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 68', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 69', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 70', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 71', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 72', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 73', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 74', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 75', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 76', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 77', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 78', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 79', max=88909, style=ProgressStyle(description_width='…

'../models/eordentl_cudnnlstm_2_512_layer_rms_0002_nostep_80_epochs_lr_t_15.h5'

In [10]:
# Train one model per remaining "temperature" T. The architecture and optimizer
# match the T=1.5 model; only the divisor applied before the softmax layer differs.
# Module-level results kept for later cells: model2/history2 (T=1.0),
# model3/history3 (T=0.75), model4/history4 (T=0.25).
#
# NOTE(review): because a trainable Dense layer follows the scaling,
# softmax(W(h/T) + b) == softmax((W/T)h + b). The Dense weights can absorb T during
# training, so T mainly changes the effective init/gradient scale of W rather than
# the sharpness of the learned distribution. Confirm this is the intended experiment.


def _train_temperature_model(temperature, tag, epochs):
    """Build, train, and save a 2x256 CuDNNLSTM char model with pre-softmax scaling.

    Args:
        temperature: divisor applied to the second LSTM's output before the Dense layer.
        tag: filename suffix for this temperature (e.g. "075" for 0.75).
        epochs: number of training epochs; also written into the final-model filename.

    Returns:
        (model, history): the trained Keras model and the History returned by fit().
    """
    net = Sequential()
    net.add(CuDNNLSTM(256, input_shape=(length, len(chars)), return_sequences=True))
    net.add(CuDNNLSTM(256))
    # Pass T via `arguments` so it is stored in the layer config (and so survives
    # save/load_model) instead of being captured from the enclosing scope.
    net.add(Lambda(lambda x, temperature: x / temperature,
                   arguments={'temperature': temperature}))
    net.add(Dense(len(chars), activation='softmax'))
    net.compile(loss="categorical_crossentropy",
                optimizer=optimizers.RMSprop(lr=0.0002),
                metrics=["accuracy"])
    net.summary()
    best_loss_path = ('../models/eordentl_cudnnlstm_2_256_layer_rms_0002_nostep_best_loss_t_{}.h5'
                      .format(tag))
    hist = net.fit(X_full_shuff, Y_full_shuff,
                   epochs=epochs,
                   verbose=0,
                   validation_split=0.05,
                   callbacks=[TQDMNotebookCallback(leave_inner=True, leave_outer=True),
                              callbacks.ModelCheckpoint(best_loss_path, monitor="loss",
                                                        save_best_only=True)])
    # The epoch count in the filename comes from `epochs`, so the name cannot drift
    # from what was trained. Previously the T=1.0 run trained 80 epochs but was
    # saved under a "40_epochs" name.
    net.save('../models/eordentl_cudnnlstm_2_512_layer_rms_0002_nostep_{}_epochs_lr_t_{}.h5'
             .format(epochs, tag))
    print("Done T={}".format(temperature))
    return net, hist


model2, history2 = _train_temperature_model(1.0, "1", epochs=80)
model3, history3 = _train_temperature_model(0.75, "075", epochs=40)
model4, history4 = _train_temperature_model(0.25, "025", epochs=40)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
cu_dnnlstm_3 (CuDNNLSTM)     (None, 40, 256)           303104    
_________________________________________________________________
cu_dnnlstm_4 (CuDNNLSTM)     (None, 256)               526336    
_________________________________________________________________
lambda_2 (Lambda)            (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 38)                9766      
Total params: 839,206
Trainable params: 839,206
Non-trainable params: 0
_________________________________________________________________


HBox(children=(IntProgress(value=0, description='Training', max=80, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 3', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 4', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 5', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 6', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 7', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 8', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 9', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 10', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 11', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 12', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 13', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 14', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 15', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 16', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 17', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 18', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 19', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 20', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 21', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 22', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 23', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 24', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 25', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 26', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 27', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 28', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 29', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 30', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 31', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 32', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 33', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 34', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 35', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 36', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 37', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 38', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 39', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 40', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 41', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 42', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 43', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 44', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 45', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 46', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 47', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 48', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 49', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 50', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 51', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 52', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 53', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 54', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 55', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 56', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 57', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 58', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 59', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 60', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 61', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 62', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 63', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 64', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 65', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 66', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 67', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 68', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 69', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 70', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 71', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 72', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 73', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 74', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 75', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 76', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 77', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 78', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 79', max=88909, style=ProgressStyle(description_width='…

Done T=1.0
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
cu_dnnlstm_5 (CuDNNLSTM)     (None, 40, 256)           303104    
_________________________________________________________________
cu_dnnlstm_6 (CuDNNLSTM)     (None, 256)               526336    
_________________________________________________________________
lambda_3 (Lambda)            (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 38)                9766      
Total params: 839,206
Trainable params: 839,206
Non-trainable params: 0
_________________________________________________________________


HBox(children=(IntProgress(value=0, description='Training', max=40, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 3', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 4', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 5', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 6', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 7', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 8', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 9', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 10', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 11', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 12', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 13', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 14', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 15', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 16', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 17', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 18', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 19', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 20', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 21', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 22', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 23', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 24', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 25', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 26', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 27', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 28', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 29', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 30', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 31', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 32', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 33', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 34', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 35', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 36', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 37', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 38', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 39', max=88909, style=ProgressStyle(description_width='…

Done T=0.75
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
cu_dnnlstm_7 (CuDNNLSTM)     (None, 40, 256)           303104    
_________________________________________________________________
cu_dnnlstm_8 (CuDNNLSTM)     (None, 256)               526336    
_________________________________________________________________
lambda_4 (Lambda)            (None, 256)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 38)                9766      
Total params: 839,206
Trainable params: 839,206
Non-trainable params: 0
_________________________________________________________________


HBox(children=(IntProgress(value=0, description='Training', max=40, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 3', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 4', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 5', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 6', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 7', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 8', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 9', max=88909, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 10', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 11', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 12', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 13', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 14', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 15', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 16', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 17', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 18', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 19', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 20', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 21', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 22', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 23', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 24', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 25', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 26', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 27', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 28', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 29', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 30', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 31', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 32', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 33', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 34', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 35', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 36', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 37', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 38', max=88909, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='Epoch 39', max=88909, style=ProgressStyle(description_width='…

Done T=0.25


In [5]:
# Opening line used as the generation context: 39 characters plus the newline,
# i.e. exactly 40 characters, matching the training window `length`.
main_seed = "shall i compare thee to a summer's day?" + "\n"

def sample(m, seed, temperature=1.0):
    """Draw the index of the next character from model `m` given `seed`.

    The seed is one-hot encoded into a (1, length, len(chars)) array using the
    module-level `length`, `chars` and `char_index`. The model's output is
    re-weighted by `temperature` and one index is drawn with
    `np.random.multinomial`.

    Args:
        m: model whose `predict` returns, for a batch of one, a vector of
            per-character probabilities (presumably a softmax — TODO confirm).
        seed: context string; every character must be a key of `char_index`
            and it may be at most `length` characters long (a longer seed
            raises IndexError). A shorter seed leaves trailing positions zero.
        temperature: positive float; values below 1 sharpen the distribution,
            values above 1 flatten it.

    Returns:
        The sampled integer index (a key of `index_char`).
    """
    x_pred = np.zeros((1, length, len(chars)))
    for t, char in enumerate(seed):
        x_pred[0, t, char_index[char]] = 1
    preds = np.asarray(m.predict(x_pred, verbose=0)[0]).astype('float64')

    # Zero probabilities map to -inf in log space and back to 0 after exp();
    # the divide-by-zero warning from log(0) is expected, so silence it.
    with np.errstate(divide='ignore'):
        logits = np.log(preds) / temperature
    # Shift so the largest logit is 0 before exponentiating. Without this, a
    # small temperature can underflow every weight to 0, giving NaN
    # probabilities that np.random.multinomial rejects.
    logits -= np.max(logits)
    weights = np.exp(logits)
    probs = weights / np.sum(weights)
    return np.argmax(np.random.multinomial(1, probs, 1))

def create_sonnet(seed, n_lines, m, temperature=1.0):
    """Generate free-running text from `m` until `n_lines` newlines are produced.

    Characters are drawn one at a time with `sample`. After each draw the
    context window slides forward by one character, so the model always sees
    the most recent len(seed) characters.

    Note: there is no length cap. If the model never emits "\\n", this loops
    forever.

    Args:
        seed: initial context window (not included in the returned text).
        n_lines: number of newline-terminated lines to generate.
        m: model passed through to `sample`.
        temperature: sampling temperature passed through to `sample`.

    Returns:
        The generated text, ending with the n_lines-th "\\n"
        (the empty string when n_lines <= 0).
    """
    generated = []
    curr_seed = seed
    lines_done = 0
    while lines_done < n_lines:
        next_char = index_char[sample(m, curr_seed, temperature)]
        # Slide the *current* window, not the original seed. Otherwise the
        # model never sees more than one generated character of context.
        curr_seed = curr_seed[1:] + next_char
        generated.append(next_char)
        if next_char == "\n":
            lines_done += 1
    return "".join(generated)

def create_sonnet_fixed_lines(seed, n_lines, m, temperature=1.0, line_length=None):
    """Generate `n_lines` lines of exactly `line_length` sampled characters each.

    The function does not wait for the model to emit "\\n". Each line is cut off
    after a fixed number of characters, and a "\\n" is then pushed into the
    context window by hand so the next line starts after a line break.

    Args:
        seed: initial context window; every character must be in `char_index`.
        n_lines: number of lines to generate.
        m: model passed through to `sample`.
        temperature: sampling temperature passed through to `sample`.
        line_length: characters sampled per line. Defaults to len(seed) - 1,
            the original hard-coded width.

    Returns:
        A list of `n_lines` strings. The forced "\\n" is not included. A "\\n"
        sampled by the model mid-line is kept as-is.
    """
    if line_length is None:
        line_length = len(seed) - 1
    sonnet = []
    curr_seed = seed
    for _ in tqdm_notebook(range(n_lines)):
        line_chars = []
        for _ in range(line_length):
            next_char = index_char[sample(m, curr_seed, temperature)]
            curr_seed = curr_seed[1:] + next_char
            line_chars.append(next_char)
        # Artificially induce the next line: the model's context now ends in "\n".
        curr_seed = curr_seed[1:] + "\n"
        sonnet.append("".join(line_chars))
    return sonnet


def create_sonnet_no_lines(seed, n_lines, m, temperature=1.0):
    """Generate text as one continuous stream of sampled characters.

    After the seed, (n_lines - 1) * len(seed) further characters are sampled,
    presumably counting the seed as the first line. Line breaks are not
    handled specially, so "\\n" appears only where the model emits it.

    Args:
        seed: initial context window; it is also the start of the output.
        n_lines: target number of lines (the seed included).
        m: model passed through to `sample`.
        temperature: sampling temperature passed through to `sample`.

    Returns:
        `seed` followed by the generated characters.
    """
    generated = []
    curr_seed = seed
    for _ in range((n_lines - 1) * len(seed)):
        next_char = index_char[sample(m, curr_seed, temperature)]
        # Slide the *current* window, not the original seed. Otherwise the
        # model never sees more than one generated character of context.
        curr_seed = curr_seed[1:] + next_char
        generated.append(next_char)
    return seed + "".join(generated)

In [17]:
# Other checkpoints tried earlier; put one into `model_path` to compare:
#   ../models/eordentl_cudnnlstm_2_512_layer_rms_0002_nostep_40_epochs_lr_t_1.h5
#   ../models/eordentl_cudnnlstm_3_512_layer_rms_001_nostep_50_epochs.h5
#   ../models/eordentl_cudnnlstm_2_512_layer_rms_0002_nostep_40_epochs_lr_t_075.h5
#   ../models/eordentl_cudnnlstm_3_512_layer_rms_001_nostep_100_epochs.h5
model_path = '../models/eordentl_cudnnlstm_2_256_layer_rms_0002_nostep_best_loss_t_1.h5'
m1 = load_model(model_path)

In [None]:
# Seed NumPy's global RNG so the character sampling (np.random.multinomial
# inside `sample`) gives the same sonnet on every re-run.
RNG_SEED = 0
N_LINES = 14        # lines per sonnet
TEMPERATURE = 0.25  # low temperature -> conservative, less random text
np.random.seed(RNG_SEED)
gen_sonnet = create_sonnet_fixed_lines(main_seed, N_LINES, m1, TEMPERATURE)
print("\n\n".join(gen_sonnet))

HBox(children=(IntProgress(value=0, max=14), HTML(value='')))