<a href="https://colab.research.google.com/github/nforesperance/TensorFlow/blob/master/text_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
try:
  # %tensorflow_version only exists in Colab, where it selects the TF 2.x runtime.
  # In other IPython kernels the magic is unknown and (presumably) raises a
  # UsageError; that failure is deliberately ignored so the notebook still runs.
  %tensorflow_version 2.x
except Exception:
  pass

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

import os

import tensorflow as tf
from tensorflow import keras

In [0]:
import numpy
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import LSTM
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.python.keras.utils import np_utils

In [0]:
# Load the corpus (decoded as UTF-8) and convert it to lowercase to shrink the vocabulary.
# The `with` block closes the file handle as soon as the text has been read.
filename = "wonderland.txt"
with open(filename, 'r', encoding='utf-8') as text_file:
  raw_text = text_file.read()
raw_text = raw_text.lower()
# Create a mapping of unique chars to integers; sorting makes the mapping
# deterministic across runs.
chars = sorted(set(raw_text))
char_to_int = {c: i for i, c in enumerate(chars)}
n_chars = len(raw_text)   # total characters in the corpus
n_vocab = len(chars)      # number of distinct characters

In [0]:
# Build (input, target) training pairs by sliding a window of `seq_length`
# characters over the corpus. Each input is the integer-encoded window, and the
# target is the integer code of the character immediately after that window.
seq_length = 100
encoded_text = [char_to_int[ch] for ch in raw_text]
window_starts = range(n_chars - seq_length)
dataX = [encoded_text[start:start + seq_length] for start in window_starts]
dataY = [encoded_text[start + seq_length] for start in window_starts]
n_patterns = len(dataX)
print("Total Patterns: ", n_patterns)

Total Patterns:  163681


In [0]:
# Reshape X to be [samples, time steps, features] as the LSTM expects.
X = numpy.reshape(dataX, (n_patterns, seq_length, 1))
# Normalize the integer codes 0..n_vocab-1 into the range [0, 1).
X = X / float(n_vocab)
# One-hot encode the targets. num_classes pins the width to the full vocabulary,
# so the softmax layer (sized from y.shape[1]) covers every character even if
# the highest-index characters never occur as a target. The public
# tf.keras.utils API is used instead of the private tensorflow.python module.
y = tf.keras.utils.to_categorical(dataY, num_classes=n_vocab)

In [0]:
# Define a simple sequential model
def create_model(input_shape=None, n_classes=None):
  """Build and compile a two-layer stacked LSTM character model.

  Args:
    input_shape: (time_steps, features) of a single sample. Defaults to
      (X.shape[1], X.shape[2]) taken from the global training array ``X``,
      matching the original zero-argument behaviour.
    n_classes: number of softmax outputs. Defaults to ``y.shape[1]`` of the
      global one-hot target array ``y``.

  Returns:
    A ``Sequential`` model compiled with categorical cross-entropy and Adam.
  """
  if input_shape is None:
    input_shape = (X.shape[1], X.shape[2])
  if n_classes is None:
    n_classes = y.shape[1]

  model = Sequential()
  # return_sequences=True so the second LSTM receives the whole sequence.
  model.add(LSTM(256, input_shape=input_shape, return_sequences=True))
  model.add(Dropout(0.2))
  model.add(LSTM(256))
  model.add(Dropout(0.2))
  model.add(Dense(n_classes, activation='softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='adam')

  return model
model = create_model()

In [0]:
# Output locations: checkpoint_path receives the weights saved by the
# ModelCheckpoint callbacks, and save_path holds the full-model file best.h5.
checkpoint_path = "trained/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
save_path = "trained"
# Create the directories up front, so model.save(save_path + "/best.h5") does not
# rely on a checkpoint callback having created them first.
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(save_path, exist_ok=True)

In [0]:
# NB: the following code is meant to be run only once (the initial training run).

# Both callbacks watch the validation loss and treat lower as better.
monitor_kwargs = dict(monitor='val_loss', mode='min', verbose=1)

# Stop as soon as val_loss fails to improve, rolling back to the best epoch's weights.
easystopping = tf.keras.callbacks.EarlyStopping(min_delta=0,
                                                patience=0,
                                                baseline=None,
                                                restore_best_weights=True,
                                                **monitor_kwargs)

# Save the model's weights to checkpoint_path whenever val_loss reaches a new minimum.
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                save_weights_only=True,
                                                save_best_only=True,
                                                **monitor_kwargs)

# Train with 20% of the samples held out for validation (drives both callbacks).
fit_kwargs = dict(validation_split=0.2, epochs=5, batch_size=128, verbose=2)
model.fit(X, y, callbacks=[checkpoint, easystopping], **fit_kwargs)


In [0]:
# Save the full model to <save_path>/best.h5 so later cells can reload it with load_model.
model.save("{}/best.h5".format(save_path))

In [0]:
# Resume training from best.h5 for several rounds. The callbacks are rebuilt each
# round, so their notion of "best" only covers that round. To keep best.h5 truly
# the best model, it is overwritten only when a round beats the best hold-out loss
# seen so far, starting with the model already on disk.
N_ROUNDS = 5
EPOCHS_PER_ROUND = 5
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.2
best_model_path = save_path + "/best.h5"

# Keras takes the validation_split samples from the end of the arrays (before
# shuffling). Slicing the same way lets us score any model on that hold-out set.
split_at = int(len(X) * (1 - VALIDATION_SPLIT))
X_val, y_val = X[split_at:], y[split_at:]

model = tf.keras.models.load_model(best_model_path)
best_val_loss = model.evaluate(X_val, y_val, batch_size=BATCH_SIZE, verbose=0)

for i in range(N_ROUNDS):
  easystopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=1, mode='min', baseline=None, restore_best_weights=True)

  # NOTE: this callback's threshold resets every round, so the weights at
  # checkpoint_path may be replaced by a round that is worse than an earlier one.
  checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                            monitor='val_loss',
                            verbose=1,
                            save_weights_only=True,
                            save_best_only=True,
                            mode='min')

  model.fit(X, y, validation_split=VALIDATION_SPLIT,
                    epochs=EPOCHS_PER_ROUND,
                    batch_size=BATCH_SIZE,
                    verbose=2,
                    callbacks=[checkpoint, easystopping])

  # Score the weights actually held in memory after fit (EarlyStopping may have
  # restored an earlier epoch) on the same hold-out slice as the baseline.
  round_val_loss = model.evaluate(X_val, y_val, batch_size=BATCH_SIZE, verbose=0)
  if round_val_loss < best_val_loss:
    best_val_loss = round_val_loss
    model.save(best_model_path)
  else:
    print("Round", i + 1, "did not improve on val_loss", best_val_loss, "- reloading", best_model_path)
    # Start the next round from the best saved model rather than this worse one.
    model = tf.keras.models.load_model(best_model_path)

## Text Generation

In [0]:
import sys

In [0]:
# Reload the best saved model so generation does not depend on in-memory training state.
model = tf.keras.models.load_model(save_path+'/best.h5')
int_to_char = dict(enumerate(chars))
# Pick a random seed window. randint's upper bound is exclusive, so passing
# len(dataX) makes every pattern, including the last one, eligible.
start = numpy.random.randint(0, len(dataX))
# Work on a copy so generation never modifies the training windows in dataX.
pattern = list(dataX[start])
print("Seed:")
print("\"", ''.join([int_to_char[value] for value in pattern]), "\"")
# Generate characters greedily (argmax), one at a time.
n_generate = 1000
for _ in range(n_generate):
  # Same [samples, time steps, features] layout and scaling as the training X.
  x = numpy.reshape(pattern, (1, len(pattern), 1)) / float(n_vocab)
  # Calling the model directly avoids predict()'s per-call overhead for a single sample.
  prediction = model(x, training=False).numpy()
  index = int(numpy.argmax(prediction))
  sys.stdout.write(int_to_char[index])
  # Slide the window: drop the oldest character and append the prediction.
  pattern = pattern[1:] + [index]
print("\nDone.")