# Hello RNN: next-character predictor

A character-level RNN (GRU) trained on Shakespeare's works to predict the next character in a sentence.

In [1]:
# Prerequisites
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Report library versions so results can be tied to a specific environment
for label, version in (
    ("Python Version: ", sys.version),
    ("Numpy Version: ", np.__version__),
    ("Pandas Version: ", pd.__version__),
    ("TensorFlow Version: ", tf.__version__),
):
    print(label, version)

2025-03-29 17:24:03.251788: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743294243.365944    4424 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743294243.397125    4424 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1743294243.647043    4424 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1743294243.647062    4424 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1743294243.647064    4424 computation_placer.cc:177] computation placer alr

Python Version:  3.12.3 (main, Feb  4 2025, 14:48:35) [GCC 13.3.0]
Numpy Version:  2.1.3
Pandas Version:  2.2.3
TensorFlow Version:  2.19.0


### Get Data (Shakespeare's works)

In [2]:
# Download the tiny-Shakespeare corpus with keras.utils.get_file and read it as a single string.
input_data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
filepath = keras.utils.get_file("shakespeare.txt", input_data_url)
# Explicit encoding: open()'s default is platform/locale dependent.
with open(filepath, encoding="utf-8") as f:
    shakespeare_text = f.read()

Downloading data from https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
[1m1115394/1115394[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [3]:
# Quick look at the raw corpus: total size in characters and its opening lines
text_len, text_head = len(shakespeare_text), shakespeare_text[:200]
print("Length of text:", text_len)
print("Begins with:\n", text_head)

Length of text: 1115394
Begins with:
 First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


### Convert to lowercase and encode characters

In [4]:
# Alphabet of the lower-cased corpus, sorted; its size should match the vocabulary
# learned by TextVectorization below once the 2 reserved tokens are excluded.
chars = "".join(sorted({ch for ch in shakespeare_text.lower()}))
print("Characters: ", chars)
print("Number of characters: ", len(chars))

Characters:  
 !$&',-.3:;?abcdefghijklmnopqrstuvwxyz
Number of characters:  39


In [5]:
# Character-level tokenizer: lower-case the text and emit one token per character
corpus = [shakespeare_text]
text_vec_layer = tf.keras.layers.TextVectorization(standardize="lower", split="character")
text_vec_layer.adapt(corpus)  # learn the character vocabulary from the whole text
encoded = text_vec_layer(corpus)[0]  # take the single batch row: one ID per character

I0000 00:00:1743294270.618185    4424 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9315 MB memory:  -> device: 0, name: NVIDIA TITAN V, pci bus id: 0000:04:00.0, compute capability: 7.0


In [6]:
# TextVectorization reserves ID 0 for padding and ID 1 for the unknown token, and neither
# appears in the adapted corpus. Shift every ID down by 2 so the character IDs start at 0.
# Recompute from the vectorizer instead of `encoded -= 2` so that re-running this cell
# does not shift the IDs a second time.
encoded = text_vec_layer([shakespeare_text])[0] - 2
# Number of tokens (vocabulary size without the 2 reserved entries)
nr_tokens = text_vec_layer.vocabulary_size() - 2
print("Number of tokens: ", nr_tokens)
ds_size = len(encoded)
print("Dataset size: ", ds_size)
# Every shifted ID must be a valid index for the Embedding/Dense layers of size nr_tokens
assert int(tf.reduce_min(encoded)) >= 0, "found reserved <PAD>/<UNK> IDs in the encoded text"
assert int(tf.reduce_max(encoded)) < nr_tokens, "encoded ID out of range"


Number of tokens:  39
Dataset size:  1115394


Helper function to turn a sequence of character IDs into batched (input, target) pairs, where each target is its input shifted one character ahead.

In [7]:
def to_dataset(sequence, length, shuffle=False, seed=None, batch_size=32,
               shuffle_buffer=100_000):
    """Build a batched (input, target) dataset for next-character prediction.

    Slides a window of ``length + 1`` IDs over ``sequence`` with stride 1. Each
    window is split into an input (its first ``length`` IDs) and a target (its
    last ``length`` IDs), so the target is the input shifted by one position.

    Args:
        sequence: 1-D sequence of token IDs (anything accepted by
            ``tf.data.Dataset.from_tensor_slices``).
        length: Number of time steps in each input/target sequence.
        shuffle: Whether to shuffle the windows before batching.
        seed: Seed passed to ``Dataset.shuffle``.
        batch_size: Number of windows per batch.
        shuffle_buffer: Buffer size used by ``Dataset.shuffle`` when ``shuffle`` is True.

    Returns:
        A ``tf.data.Dataset`` of ``(inputs, targets)`` pairs, each of shape
        ``(batch, length)``. The final batch may be smaller than ``batch_size``.
    """
    ds = tf.data.Dataset.from_tensor_slices(sequence)
    # Overlapping windows; drop the incomplete tail windows shorter than length + 1
    ds = ds.window(length + 1, shift=1, drop_remainder=True)
    # Each window is itself a dataset; turn it into a single tensor of length + 1 IDs
    ds = ds.flat_map(lambda window_ds: window_ds.batch(length + 1))
    if shuffle:
        ds = ds.shuffle(shuffle_buffer, seed=seed)
    ds = ds.batch(batch_size)
    # Inputs are all but the last ID; targets are all but the first (shifted by one)
    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)

In [8]:
# Sanity-check the windowing on a tiny input: "To be" -> 5 IDs -> one (input, target) pair of length 4.
# These are the raw vectorizer IDs (not shifted by 2); the cell only illustrates the windowing.
sample_ids = text_vec_layer(["To be"])[0]
ds_sample = list(to_dataset(sample_ids, length=4))
print(ds_sample)

[(<tf.Tensor: shape=(1, 4), dtype=int64, numpy=array([[ 4,  5,  2, 23]])>, <tf.Tensor: shape=(1, 4), dtype=int64, numpy=array([[ 5,  2, 23,  3]])>)]


2025-03-29 17:24:39.965831: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


### Split into training, validation, and test sets

In [9]:
# Window length (characters per input sequence) and split boundaries (character offsets into `encoded`)
length = 100
train_end, val_end = 1_000_000, 1_060_000
assert val_end < len(encoded), "split boundaries exceed the encoded text"

tf.random.set_seed(42)
ds_train = to_dataset(encoded[:train_end], length=length, shuffle=True, seed=42)  # only training windows are shuffled
ds_val = to_dataset(encoded[train_end:val_end], length=length)
ds_test = to_dataset(encoded[val_end:], length=length)  # remaining characters after val_end

### Build and train the model

NOTE: A GPU is needed to train in a reasonable time (about 4 minutes per epoch on the NVIDIA TITAN V used here).

In [10]:
tf.random.set_seed(42)
model = keras.Sequential([
    keras.layers.Embedding(input_dim=nr_tokens, output_dim=16),  # Embed the character IDs
    keras.layers.GRU(128, return_sequences=True),  # emit an output at every time step
    keras.layers.Dense(nr_tokens, activation="softmax")  # next-character distribution per step
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="nadam", metrics=["accuracy"])

checkpoint_path = "my_shakespeare_model.keras"
cb_model_ckpt = keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                monitor="val_accuracy", save_best_only=True)
history = model.fit(ds_train, validation_data=ds_val, epochs=10, callbacks=[cb_model_ckpt])

# After fit(), `model` holds the LAST epoch's weights, which are not necessarily the best.
# Reload the checkpoint with the best val_accuracy so the prediction cells use it.
model = keras.models.load_model(checkpoint_path)

Epoch 1/10


I0000 00:00:1743294293.760531    4506 cuda_dnn.cc:529] Loaded cuDNN version 90300


  31242/Unknown [1m251s[0m 8ms/step - accuracy: 0.5477 - loss: 1.4959

2025-03-29 17:28:56.121100: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-03-29 17:28:56.121121: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 17:28:56.121125: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 6478238555696045443
2025-03-29 17:28:56.121131: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 12104237751170906778


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m258s[0m 8ms/step - accuracy: 0.5477 - loss: 1.4959 - val_accuracy: 0.5342 - val_loss: 1.5984
Epoch 2/10


2025-03-29 17:29:03.236405: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-03-29 17:29:03.236426: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 8134076999784538461
2025-03-29 17:29:03.236430: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 2437941258184784760
2025-03-29 17:29:03.236445: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175


[1m31244/31247[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - accuracy: 0.5971 - loss: 1.2938

2025-03-29 17:33:13.101615: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 116751330991363435
2025-03-29 17:33:13.101632: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 17:33:13.101636: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 5112715807263096129


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m257s[0m 8ms/step - accuracy: 0.5971 - loss: 1.2938 - val_accuracy: 0.5396 - val_loss: 1.5810
Epoch 3/10


2025-03-29 17:33:20.073441: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 15744293531091015826
2025-03-29 17:33:20.073461: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175
2025-03-29 17:33:20.073463: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m31242/31247[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - accuracy: 0.6018 - loss: 1.2728

2025-03-29 17:37:27.617815: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 12104237751170906778
2025-03-29 17:37:27.617842: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 17:37:27.617850: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 6478238555696045443


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m255s[0m 8ms/step - accuracy: 0.6018 - loss: 1.2728 - val_accuracy: 0.5429 - val_loss: 1.5672
Epoch 4/10


2025-03-29 17:37:34.586824: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 5114529995453910681
2025-03-29 17:37:34.586845: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175
2025-03-29 17:37:34.586847: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 15744293531091015826
2025-03-29 17:37:34.586855: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 16773507648221647836


[1m31244/31247[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - accuracy: 0.6040 - loss: 1.2627

2025-03-29 17:41:40.749235: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 12104237751170906778
2025-03-29 17:41:40.749260: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 17:41:40.749270: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 6478238555696045443


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m253s[0m 8ms/step - accuracy: 0.6040 - loss: 1.2627 - val_accuracy: 0.5447 - val_loss: 1.5645
Epoch 5/10


2025-03-29 17:41:47.864428: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-03-29 17:41:47.864449: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 8134076999784538461
2025-03-29 17:41:47.864456: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175
2025-03-29 17:41:47.864459: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 2437941258184784760


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - accuracy: 0.6054 - loss: 1.2567

2025-03-29 17:45:33.082245: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 17:45:33.082265: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 6478238555696045443
2025-03-29 17:45:33.082271: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 5112715807263096129


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m232s[0m 7ms/step - accuracy: 0.6054 - loss: 1.2567 - val_accuracy: 0.5436 - val_loss: 1.5631
Epoch 6/10


2025-03-29 17:45:39.401527: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 15744293531091015826
2025-03-29 17:45:39.401547: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 16773507648221647836
2025-03-29 17:45:39.401553: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175


[1m31244/31247[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 7ms/step - accuracy: 0.6069 - loss: 1.2514

2025-03-29 17:49:35.698316: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 12104237751170906778
2025-03-29 17:49:35.698338: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 17:49:35.698355: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 6478238555696045443


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m243s[0m 8ms/step - accuracy: 0.6069 - loss: 1.2514 - val_accuracy: 0.5488 - val_loss: 1.5551
Epoch 7/10


2025-03-29 17:49:42.498573: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175


[1m31245/31247[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - accuracy: 0.6076 - loss: 1.2481

2025-03-29 17:53:48.459683: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 12104237751170906778
2025-03-29 17:53:48.459707: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 17:53:48.459716: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 6478238555696045443


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m253s[0m 8ms/step - accuracy: 0.6076 - loss: 1.2481 - val_accuracy: 0.5478 - val_loss: 1.5560
Epoch 8/10


2025-03-29 17:53:55.278873: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 8134076999784538461
2025-03-29 17:53:55.278895: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175
2025-03-29 17:53:55.278903: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 2437941258184784760


[1m31244/31247[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - accuracy: 0.6083 - loss: 1.2449

2025-03-29 17:58:00.629926: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 12104237751170906778
2025-03-29 17:58:00.629946: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 17:58:00.629954: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 6478238555696045443


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m252s[0m 8ms/step - accuracy: 0.6083 - loss: 1.2449 - val_accuracy: 0.5510 - val_loss: 1.5504
Epoch 9/10


2025-03-29 17:58:07.525901: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175
2025-03-29 17:58:07.525924: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m31243/31247[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - accuracy: 0.6088 - loss: 1.2425

2025-03-29 18:02:11.873289: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 12104237751170906778
2025-03-29 18:02:11.873321: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 18:02:11.873337: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 6478238555696045443


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m251s[0m 8ms/step - accuracy: 0.6088 - loss: 1.2425 - val_accuracy: 0.5489 - val_loss: 1.5532
Epoch 10/10


2025-03-29 18:02:18.695307: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 8134076999784538461
2025-03-29 18:02:18.695327: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175
2025-03-29 18:02:18.695332: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 2437941258184784760


[1m31245/31247[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 7ms/step - accuracy: 0.6092 - loss: 1.2409

2025-03-29 18:06:07.478462: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 12104237751170906778
2025-03-29 18:06:07.478488: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 9621422523188769149
2025-03-29 18:06:07.478496: I tensorflow/core/framework/local_rendezvous.cc:430] Local rendezvous send item cancelled. Key hash: 6478238555696045443


[1m31247/31247[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m235s[0m 7ms/step - accuracy: 0.6092 - loss: 1.2409 - val_accuracy: 0.5496 - val_loss: 1.5515


2025-03-29 18:06:13.775263: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 8134076999784538461
2025-03-29 18:06:13.775295: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 14415402829844293175


### Predict

In [11]:
# End-to-end predictor: raw strings in, next-character probabilities out.
# It reuses the training preprocessing (vectorize, then shift IDs down by 2 past the
# reserved <PAD>/<UNK> slots) in front of the trained model.
model_shakespeare = keras.Sequential(
    [
        text_vec_layer,                             # strings -> character IDs
        keras.layers.Lambda(lambda ids: ids - 2),   # same 0-based IDs as used in training
        model,                                      # trained next-character model
    ]
)

In [13]:
# Greedy next-character prediction: use the model's distribution at the last input position
sentence = tf.constant(["To be or not to b"])
y_proba = model_shakespeare.predict(sentence)[0, -1]  # probabilities for the character after the input
y_pred = tf.argmax(y_proba)  # Pick the most probable character (0-based ID)
text_vec_layer.get_vocabulary()[y_pred + 2]  # +2 maps back past the reserved <PAD>/<UNK> IDs

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 124ms/step


np.str_('e')

In [14]:
# Greedy next-character prediction for a second prompt
sentence = tf.constant(["Romeo and Julie"])
y_proba = model_shakespeare.predict(sentence)[0, -1]  # probabilities for the character after the input
y_pred = tf.argmax(y_proba)  # Pick the most probable character (0-based ID)
text_vec_layer.get_vocabulary()[y_pred + 2]  # +2 maps back past the reserved <PAD>/<UNK> IDs

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step


np.str_('t')