In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
import numpy as np
import math
import matplotlib.pyplot as plt
import io

In [21]:
# Locations of the training and validation corpora.
train_file = 'shakespeare_train.txt'
val_file = 'shakespeare_valid.txt'

# Read each corpus in full, decoding it as UTF-8, into a single string.
with io.open(train_file, mode='r', encoding='utf8') as train_handle:
    train = train_handle.read()

with io.open(val_file, mode='r', encoding='utf8') as val_handle:
    val = val_handle.read()

In [8]:
# Vocabulary: every distinct character of the training text, in sorted order
vocab = list(set(train))
vocab.sort()
print(f'{len(vocab)} unique characters')
# Preview the first 200 characters of the training text
print(train[:200])

67 unique characters
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [16]:
# Two lookup tables derived from the position of each character in `vocab`:
# character -> integer id, and integer id -> character.
character_to_int = {symbol: index for index, symbol in enumerate(vocab)}
int_to_character = dict(enumerate(vocab))

print(character_to_int.items())

dict_items([('\n', 0), (' ', 1), ('!', 2), ('$', 3), ('&', 4), ("'", 5), (',', 6), ('-', 7), ('.', 8), ('3', 9), (':', 10), (';', 11), ('?', 12), ('A', 13), ('B', 14), ('C', 15), ('D', 16), ('E', 17), ('F', 18), ('G', 19), ('H', 20), ('I', 21), ('J', 22), ('K', 23), ('L', 24), ('M', 25), ('N', 26), ('O', 27), ('P', 28), ('Q', 29), ('R', 30), ('S', 31), ('T', 32), ('U', 33), ('V', 34), ('W', 35), ('X', 36), ('Y', 37), ('Z', 38), ('[', 39), (']', 40), ('a', 41), ('b', 42), ('c', 43), ('d', 44), ('e', 45), ('f', 46), ('g', 47), ('h', 48), ('i', 49), ('j', 50), ('k', 51), ('l', 52), ('m', 53), ('n', 54), ('o', 55), ('p', 56), ('q', 57), ('r', 58), ('s', 59), ('t', 60), ('u', 61), ('v', 62), ('w', 63), ('x', 64), ('y', 65), ('z', 66)])


In [24]:
# Encode both corpora as lists of integer ids via character_to_int.
# A character that is missing from character_to_int raises KeyError here.
train_id = [character_to_int[character] for character in train]
val_id = [character_to_int[character] for character in val]
print(f'number of characters in training data: {len(train_id)}')
# Report the length of the encoded list (val_id), consistent with the training line above
print(f'number of characters in validation data: {len(val_id)}')

number of characters in training data: 4351312
number of characters in validation data: 222025


In [45]:
maxlen = 100    # Default number of items in each input window
step = 5    # Default stride: a new window starts every 'step' items


def prepare_data(text, seq_len=None, stride=None):
    """Cut a sequence into fixed-length windows paired with the item that follows each.

    Parameters
    ----------
    text : sequence
        The encoded corpus (the notebook passes a list of integer ids).
        A plain ``str`` is also accepted; windows are then substrings.
    seq_len : int, optional
        Window length. Falls back to the module-level ``maxlen`` when omitted.
    stride : int, optional
        Offset between consecutive window starts. Falls back to the
        module-level ``step`` when omitted.

    Returns
    -------
    (windows, targets) : tuple of numpy.ndarray
        For non-string input, ``windows`` has shape ``(n, seq_len)`` and
        ``targets`` has shape ``(n,)``, where ``targets[k]`` is the item
        immediately after ``windows[k]``. ``n`` is 0 (with the second
        dimension preserved) when ``text`` is not longer than ``seq_len``.

    Raises
    ------
    ValueError
        If ``seq_len`` or ``stride`` is smaller than 1.
    """
    # Resolve defaults at call time so later edits to the globals are honoured
    seq_len = maxlen if seq_len is None else seq_len
    stride = step if stride is None else stride
    if seq_len < 1 or stride < 1:
        raise ValueError('seq_len and stride must be positive integers')

    n_items = len(text)

    if isinstance(text, str):
        # Strings cannot be fancy-indexed: keep substring windows
        starts = range(0, n_items - seq_len, stride)
        sentences = [text[s: s + seq_len] for s in starts]
        next_chars = [text[s + seq_len] for s in starts]
        return np.array(sentences), np.array(next_chars)

    ids = np.asarray(text)
    # Every start index whose window AND following item both lie inside the data
    starts = np.arange(0, n_items - seq_len, stride)
    # Broadcasting (n, 1) + (seq_len,) yields the (n, seq_len) grid of positions,
    # replacing the Python-level loop over list slices
    windows = ids[starts[:, None] + np.arange(seq_len)]
    targets = ids[starts + seq_len]
    return windows, targets

In [46]:
# Turn each encoded corpus into (input windows, next-item targets)
(train_x, train_y), (val_x, val_y) = prepare_data(train_id), prepare_data(val_id)
# Show the array shapes of each split
print(*(split.shape for split in (train_x, train_y)))
print(*(split.shape for split in (val_x, val_y)))

(870243, 100) (870243,)
(44385, 100) (44385,)


In [53]:
# One output unit per vocabulary entry
vocab_size = len(vocab)
# The final Dense layer has no activation, so the loss is told to expect logits
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)


def build_model():
    """Create and compile the next-character model.

    Layers, in order: Embedding (vocab_size -> 256), GRU (1024 units),
    Dropout (rate 0.25), Dense (vocab_size outputs, no activation).
    The model is compiled with the 'adam' optimizer and the module-level `loss`.

    Returns:
        The compiled Sequential model.
    """
    network = models.Sequential([
        layers.Embedding(vocab_size, 256),
        layers.GRU(1024),
        layers.Dropout(0.25),
        layers.Dense(vocab_size),
    ])
    network.compile(loss=loss, optimizer='adam')
    return network

In [54]:
# Instantiate the compiled network returned by build_model()
model = build_model()
# Print the layer-by-layer overview (output shapes and parameter counts)
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (None, None, 256)         17152     
_________________________________________________________________
gru (GRU)                    (None, 1024)              3938304   
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 67)                68675     
Total params: 4,024,131
Trainable params: 4,024,131
Non-trainable params: 0
_________________________________________________________________


In [55]:
# Fit for 20 epochs with batches of 128 windows; (val_x, val_y) is supplied
# as validation data. The returned object is kept in `history`.
history = model.fit(
    x=train_x,
    y=train_y,
    batch_size=128,
    epochs=20,
    validation_data=(val_x, val_y),
)

Epoch 1/20
   7/6799 [..............................] - ETA: 9:38:29 - loss: 3.9917

KeyboardInterrupt: 