# Text Generation with an LSTM Trained on Sherlock Holmes

In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # avoid parallelism warning

# TODO: try tensorflow-metal instead of the base TensorFlow package
import wandb
import tensorflow as tf
import matplotlib.pyplot as plt
import platform
import os
from collections import Counter
import re
import copy

# Report the interpreter and framework versions this notebook runs against.
for label, version in (
    ('Python version:', platform.python_version()),
    ('Tensorflow version:', tf.__version__),
    ('Keras version:', tf.keras.__version__),
):
    print(label, version)

Python version: 3.12.7
Tensorflow version: 2.20.0
Keras version: 3.11.3


In [2]:

from wandb.integration.keras import WandbMetricsLogger

# Authenticate with Weights & Biases; the training cell later opens a run with wandb.init().
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdupaya[0m ([33mdupaya-kth-royal-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Root directory for generated artefacts; the checkpoint directory is created beneath it.
cache_dir = './tmp'
# File name of the training corpus.
dataset_file_name = 'sherlockholmes.txt'

# The path is the bare file name, i.e. it is resolved relative to the current working directory.
dataset_file_path = dataset_file_name

print(dataset_file_path)

sherlockholmes.txt


In [4]:
# Read the whole corpus into memory. The context manager closes the file
# handle deterministically, and the explicit encoding keeps decoding
# independent of the platform's default locale (assumes the corpus file is
# UTF-8 encoded).
with open(dataset_file_path, mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
print(text[:250])

The Adventures of Sherlock Holmes

by Arthur Conan Doyle

I. A SCANDAL IN BOHEMIA


I.

To Sherlock Holmes she is always _the_ woman. I have seldom heard him
mention her under any other name. In his eyes she eclipses and
predominates the whole of her


In [5]:
from transformers import GPT2Tokenizer
from tokenizers import ByteLevelBPETokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
save_dir = "./tokenizer"

# Train a byte-level BPE tokenizer on the same corpus file that was loaded
# above, rather than repeating the file name as a separate literal.
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(files=[dataset_file_path], vocab_size=30_000, min_frequency=2)

# exist_ok=True makes the cell safe to re-run once the directory exists.
os.makedirs(save_dir, exist_ok=True)

# Save the tokenizer model
tokenizer.save_model(save_dir)

# Load the tokenizer using GPT2Tokenizer
custom_tokenizer = GPT2Tokenizer.from_pretrained(save_dir)








Tokenize the text file

In [7]:
# Encode the entire corpus into a single sequence of token ids.
input_ids = custom_tokenizer.encode(text)
print(f"Total tokens in text: {len(input_ids)}")

Total tokens in text: 144597


Preparing the data

In [8]:
# Wrap the token ids in a tf.data dataset, sliced along the first dimension.
tokens_dataset = tf.data.Dataset.from_tensor_slices(input_ids)

In [9]:
# Number of input tokens per training example.
sequence_length = 100
# Each example consumes sequence_length + 1 tokens, because the chunks built
# below carry one extra token so the target can be the input shifted by one.
examples_per_epoch = len(input_ids) // (sequence_length + 1)

print(f'Examples per epoch: {examples_per_epoch}')


Examples per epoch: 1431


In [10]:
# Generate batched sequences out of the token dataset: consecutive chunks of
# sequence_length + 1 tokens, dropping a trailing chunk that is too short.
sequences = tokens_dataset.batch(sequence_length + 1, drop_remainder=True)

In [11]:
# Split sequences into input and target
def split_input_target(chunk):
    """Turn one chunk into an (input, target) pair offset by one position.

    The input is every element of ``chunk`` except the last and the target is
    every element except the first, so ``target[i]`` is the element that
    follows ``input[i]`` in the chunk.
    """
    return chunk[:-1], chunk[1:]

# Convert every chunk into an (input, target) pair.
dataset = sequences.map(split_input_target)

In [12]:
# Show some examples of input-target pairs
for input_example, target_example in dataset.take(1):
    # Preview at most five steps, and never index past the end of the sequence.
    for i in range(min(5, len(input_example))):
        input_token = input_example[i].numpy()
        target_token = target_example[i].numpy()
        print(f'Step {i:2d}')
        print(f'  input token: {input_token} ({custom_tokenizer.decode([input_token])})')
        print(f'  expected output token: {target_token} ({custom_tokenizer.decode([target_token])})')

Step  0
  input token: 364 (The)
  expected output token: 407 ( A)
Step  1
  input token: 407 ( A)
  expected output token: 67 (d)
Step  2
  input token: 67 (d)
  expected output token: 2428 (venture)
Step  3
  input token: 2428 (venture)
  expected output token: 82 (s)
Step  4
  input token: 82 (s)
  expected output token: 283 ( of)


2025-10-11 17:58:51.657301: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [13]:
# Batch size.
BATCH_SIZE = 64

# Buffer size to shuffle the dataset (TF data is designed to work
# with possibly infinite sequences, so it doesn't attempt to shuffle
# the entire sequence in memory. Instead, it maintains a buffer in
# which it shuffles elements).
BUFFER_SIZE = 10000

# Rebuild the pipeline from `sequences` instead of re-assigning `dataset` in
# terms of itself: re-running this cell would otherwise batch an already
# batched dataset and silently change the element shape.
dataset = (
    sequences
    .map(split_input_target)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)

dataset

<_BatchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int32, name=None), TensorSpec(shape=(64, 100), dtype=tf.int32, name=None))>

In [14]:
# Ask tf.data for the number of batches instead of materialising every batch
# in memory just to count them. The cardinality is statically known here
# because the pipeline starts from in-memory token ids and only maps,
# shuffles and batches.
print('Batched dataset size: {}'.format(dataset.cardinality().numpy()))

Batched dataset size: 22


2025-10-11 17:58:51.768663: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [15]:
# Get vocabulary size from tokenizer
# NOTE(review): `vocab_size` presumably excludes tokens added after loading;
# `len(custom_tokenizer)` may be the more robust source for the embedding
# size. Confirm before changing, since it alters the model's weight shapes.
vocab_size = custom_tokenizer.vocab_size + 1  # +1 for the padding token if added

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024

In [16]:

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    """Build the token-level language model: Embedding -> stateful LSTM -> Dense.

    Args:
        vocab_size: Number of distinct token ids; sizes both the embedding
            table and the output layer.
        embedding_dim: Width of each token embedding vector.
        rnn_units: Number of units in the LSTM layer.
        batch_size: Fixed batch size baked into the input shape. The LSTM is
            created with ``stateful=True``, so the batch dimension is given
            explicitly rather than left unspecified.

    Returns:
        An uncompiled Sequential model that produces, for every position of
        the input sequence, one raw score per vocabulary entry (the Dense
        layer has no activation).
    """
    return tf.keras.models.Sequential([
        # `batch_shape` is the Keras 3 name for the legacy
        # `batch_input_shape` argument; the sequence length stays variable.
        tf.keras.layers.InputLayer(batch_shape=(batch_size, None)),
        tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
        ),
        tf.keras.layers.LSTM(
            units=rnn_units,
            return_sequences=True,
            stateful=True,
            recurrent_initializer=tf.keras.initializers.GlorotNormal(),
        ),
        tf.keras.layers.Dense(vocab_size),
    ])
# Training model: its fixed batch dimension is BATCH_SIZE, the same value used to batch the dataset.
model = build_model(vocab_size, embedding_dim, rnn_units, BATCH_SIZE)

In [17]:
# Layer-by-layer summary of the training model.
model.summary()

In [18]:
# Render the architecture diagram. The recorded output shows that this
# requires the optional `pydot` package to be installed.
tf.keras.utils.plot_model(
    model,
    show_shapes=True,
    show_layer_names=True,
)

You must install pydot (`pip install pydot`) for `plot_model` to work.


In [19]:
# Loss function
def loss(labels, logits):
    """Sparse categorical cross-entropy computed directly from raw logits.

    Args:
        labels: Target class ids (passed as ``y_true``).
        logits: Unnormalised model outputs (passed as ``y_pred``); no softmax
            is expected beforehand because ``from_logits=True`` is set.

    Returns:
        The loss tensor returned by
        ``tf.keras.losses.sparse_categorical_crossentropy``.
    """
    return tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True
    )

# Compile the model
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=adam_optimizer,
    loss=loss
)

# Directory for checkpoints
checkpoint_dir = cache_dir + '/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Checkpoint filename; the {epoch:02d} placeholder is filled in per epoch
# (the training log shows names such as model_epoch_01.weights.h5).
checkpoint_prefix = os.path.join(checkpoint_dir, 'model_epoch_{epoch:02d}.weights.h5')

# Save weights only, and only for epochs whose training loss is a new minimum.
# NOTE(review): both callbacks below monitor the *training* loss; no
# validation metric is configured in this notebook.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_best_only=True,
    save_weights_only=True,
    monitor='loss',
    mode='min',
    verbose=1
)


# Stop after 3 epochs without improvement in training loss and restore the
# best weights seen so far.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='loss', patience=3, restore_best_weights=True
)


In [20]:
EPOCHS = 80

run = wandb.init(project='my-first-project', group='lstm-experiment')
try:
    # Log metrics to W&B once per epoch.
    wandb_callback = WandbMetricsLogger(log_freq='epoch')
    history = model.fit(
        x=dataset,
        epochs=EPOCHS,
        callbacks=[
            checkpoint_cb, early_stopping, wandb_callback
        ]
    )
finally:
    # Close the W&B run even if training fails or is interrupted, so it is
    # not left open in the kernel.
    run.finish()

Epoch 1/80
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - loss: 7.9252
Epoch 1: loss improved from None to 7.21324, saving model to ./tmp/checkpoints/model_epoch_01.weights.h5
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 3s/step - loss: 7.2132
Epoch 2/80
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - loss: 6.4700
Epoch 2: loss improved from 7.21324 to 6.47345, saving model to ./tmp/checkpoints/model_epoch_02.weights.h5
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 3s/step - loss: 6.4735
Epoch 3/80
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - loss: 6.3639
Epoch 3: loss improved from 6.47345 to 6.35930, saving model to ./tmp/checkpoints/model_epoch_03.weights.h5
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 3s/step - loss: 6.3593
Epoch 4/80
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - loss: 6.1881
Epoch 4: loss improved 

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch/epoch,▁▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇████
epoch/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch/loss,█▇▆▆▆▅▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁

0,1
epoch/epoch,79.0
epoch/learning_rate,0.001
epoch/loss,1.8403


In [None]:
def render_training_history(training_history):
    """Plot the per-epoch training loss recorded in a training history.

    Args:
        training_history: Object whose ``history`` mapping has a ``'loss'``
            entry, e.g. the value returned by ``model.fit``.
    """
    loss_values = training_history.history['loss']

    # Draw the curve first, then decorate the same axes.
    plt.plot(loss_values, label='Training set')
    plt.title('Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(linestyle='--', linewidth=1, alpha=0.5)
    plt.legend()
    plt.show()

In [None]:
# Plot the loss curve of the training run above.
render_training_history(history)

# Load best model

In [None]:
# Text generation feeds one sequence at a time, so rebuild the network with a
# batch dimension of 1 (the batch size is fixed in the model's input shape).
simplified_batch_size = 1

model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=simplified_batch_size)
model.build(tf.TensorShape([simplified_batch_size, None]))

# Path of the weights written for the final epoch, derived from the checkpoint
# settings instead of being hard-coded. The checkpoint callback only saves on
# improvement and early stopping may end training sooner, so the file is not
# guaranteed to exist; the following cell loads the most recent weights file.
final_weights_path = os.path.join(checkpoint_dir, f'model_epoch_{EPOCHS:02d}.weights.h5')
if os.path.exists(final_weights_path):
    model.load_weights(final_weights_path)
    print(f"Loaded weights from: {final_weights_path}")
else:
    print(f"No weights file at {final_weights_path}; model is still randomly initialized.")

In [None]:
import glob

# Get list of all weight files
weight_files = glob.glob(os.path.join(checkpoint_dir, '*.weights.h5'))

if weight_files:
    # Pick the most recently written file. Modification time is used because
    # getctime means "metadata change" on Unix and "creation" on Windows.
    # The checkpoint callback only saves when the training loss improves, so
    # the newest file of a run holds that run's lowest-loss weights.
    latest = max(weight_files, key=os.path.getmtime)
    model.load_weights(latest)
    print(f"Loaded weights from: {latest}")
else:
    print("No weights file found. Using randomly initialized model.")

In [None]:
# Summary of the rebuilt batch-size-1 model.
model.summary()

# Test the model

In [None]:
# NOTE(review): duplicates the summary printed in the previous cell.
model.summary()

In [None]:
def generate_text(model, start_string, num_generate=1000, temperature=1.0):
    """Sample a continuation of ``start_string`` from the language model.

    The prompt is fed once in full; afterwards only the most recently sampled
    token is fed back, so the model's recurrent layers must keep their state
    between calls. The batch dimension of the model output is squeezed away,
    so the model is expected to work on a batch of exactly one sequence.

    Args:
        model: Model called as ``model(ids)`` with ids of shape
            ``(1, sequence_length)``; layers exposing ``reset_states`` are
            reset before generation starts.
        start_string: Prompt text; it is prepended to the returned string.
        num_generate: Number of tokens to sample.
        temperature: Divisor applied to the logits before sampling. Values
            below 1 sharpen the distribution, values above 1 flatten it.

    Returns:
        ``start_string`` followed by the decoded generated tokens.

    Raises:
        ValueError: If ``temperature`` is not positive, or if
            ``start_string`` encodes to no tokens.
    """
    # Dividing logits by zero or a negative number yields inf/NaN or an
    # inverted distribution, so reject it up front with a clear message.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")

    # Tokenize the start string
    prompt_ids = custom_tokenizer.encode(start_string)
    if not prompt_ids:
        # An empty prompt would reach the model as a zero-length sequence.
        raise ValueError("start_string must encode to at least one token")
    input_ids = tf.expand_dims(prompt_ids, 0)

    # Reset states for RNN layers if they are stateful
    for layer in model.layers:
        if hasattr(layer, 'reset_states'):
            layer.reset_states()

    generated_tokens = []
    for _ in range(num_generate):
        # Remove batch dimension
        predictions = tf.squeeze(model(input_ids), 0)

        # Use the last prediction, scaled by the temperature
        last_logits = predictions[-1:, :] / temperature
        sampled = tf.random.categorical(last_logits, num_samples=1)

        # Store a plain Python int rather than a numpy scalar so the ids fed
        # back to the model and to the tokenizer have a consistent type.
        predicted_id = int(sampled[-1, 0].numpy())
        generated_tokens.append(predicted_id)

        # Only the new token is fed next; earlier context lives in the state.
        input_ids = tf.expand_dims([predicted_id], 0)

    # Decode the generated tokens and return the full text
    return start_string + custom_tokenizer.decode(generated_tokens)

# Generate text with the model: sample 500 tokens continuing the prompt at the default temperature.
print(generate_text(model, start_string="Sherlock ", num_generate=500))

In [None]:
# Generate the text with default temperature (1.0) and the default length (1000 tokens).
print(generate_text(model, start_string=u"Sherlock "))

In [None]:
# Generate the text with higher temperature to get more unexpected results.
# Note that the prompt here is "Sherlock: " (with a colon), unlike the cells above.
print(generate_text(model, start_string=u"Sherlock: ", temperature=1.5))

# Experiments

## Experiments on grammar


In [None]:
import language_tool_python
def evaluate_grammar_quality(generated_text, tool=None):
    """Count grammar and spelling issues reported for a piece of text.

    Args:
        generated_text: Text to check.
        tool: Optional checker exposing ``check(text)`` that returns matches
            with a ``category`` attribute. When omitted, an 'en-GB'
            ``language_tool_python.LanguageTool`` is created for this call and
            closed afterwards. Pass one in to reuse a single checker across
            many calls; a caller-supplied tool is never closed here.

    Returns:
        A dict with:
            'total_errors': number of matches reported,
            'error_categories': mapping of match category to count,
            'errors_per_100_words': matches per 100 words (the word count is
                floored at 1 to avoid division by zero),
            'word_count': number of ``\\b\\w+\\b`` regex matches in the text.
    """
    # The previous fallback (bare `except` followed by `pip install`) could
    # not help: the module is already imported by the time this runs, so a
    # construction failure has some other cause and is left to propagate.
    owns_tool = tool is None
    if owns_tool:
        tool = language_tool_python.LanguageTool('en-GB')

    try:
        # Find grammar and spelling errors
        matches = tool.check(generated_text)
    finally:
        # Release the checker we created, even if the check fails.
        if owns_tool:
            tool.close()

    # Count errors by category
    error_categories = Counter(match.category for match in matches)

    # Calculate word count
    word_count = len(re.findall(r'\b\w+\b', generated_text))

    # Calculate errors per 100 words
    errors_per_100_words = (len(matches) / max(1, word_count)) * 100

    return {
        'total_errors': len(matches),
        'error_categories': dict(error_categories),
        'errors_per_100_words': errors_per_100_words,
        'word_count': word_count
    }

def merge_error_categories(error_cat_list):
    """Average per-category error counts over several runs.

    Args:
        error_cat_list: List of mappings from category to count, one per run.

    Returns:
        A dict mapping each category seen in any run to its summed count
        divided by the number of runs (a category missing from a run counts
        as zero for that run). An empty list gives an empty dict.
    """
    totals = Counter()
    for per_run_counts in error_cat_list:
        totals.update(per_run_counts)

    # Calculate averages
    run_count = len(error_cat_list)
    return {category: total / run_count for category, total in totals.items()}


def run_evaluations(model, start_string, num_runs=5, num_tokens=300, temperatures=(0.7, 1.0)):
    """Generate text repeatedly and average the grammar metrics per temperature.

    Args:
        model: Model handed to ``generate_text``.
        start_string: Prompt used for every generation.
        num_runs: Number of generations per temperature; must be at least 1.
        num_tokens: Number of tokens to generate per run.
        temperatures: Iterable of sampling temperatures to evaluate. The
            default is an immutable tuple so it cannot be altered between
            calls.

    Returns:
        A dict keyed by temperature. Each value holds the averaged
        'total_errors', 'errors_per_100_words' and 'word_count', the averaged
        'error_categories', and the raw per-run results under
        'individual_runs'.

    Raises:
        ValueError: If ``num_runs`` is less than 1 (the averages would
            otherwise divide by zero).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs!r}")

    all_results = {}

    for temp in temperatures:
        print(f"\nEvaluating temperature: {temp}")
        run_results = []

        for run_index in range(1, num_runs + 1):
            print(f"  Run {run_index}/{num_runs}...")
            generated = generate_text(model, start_string, num_tokens, temp)
            run_results.append(evaluate_grammar_quality(generated))

        # Calculate averages across runs
        avg_results = {
            'total_errors': sum(r['total_errors'] for r in run_results) / num_runs,
            'errors_per_100_words': sum(r['errors_per_100_words'] for r in run_results) / num_runs,
            'word_count': sum(r['word_count'] for r in run_results) / num_runs,
            'error_categories': merge_error_categories([r['error_categories'] for r in run_results])
        }

        # Store individual runs for reference
        avg_results['individual_runs'] = run_results
        all_results[temp] = avg_results

    return all_results


In [None]:
# Define temperatures to test (only the default temperature here)
temperatures = [1.0]
# Prompt used for every generated sample.
start_string = "Sherlock"

# Run the evaluations 5 times for each temperature, generating 300 tokens per run
results = run_evaluations(model, start_string, num_runs=5, num_tokens=300, temperatures=temperatures)


In [None]:

# Print results
print("\n===== AVERAGED RESULTS =====")
for temp, result in results.items():
    print(f"\nTemperature: {temp}")
    print(f"Average total errors: {result['total_errors']:.2f}")
    print(f"Average errors per 100 words: {result['errors_per_100_words']:.2f}")
    print(f"Average word count: {result['word_count']:.2f}")
    print("Average error categories:")

    # List categories from most to least frequent for better readability
    for category, avg_count in sorted(
        result['error_categories'].items(),
        key=lambda item: item[1],
        reverse=True,
    ):
        print(f"  - {category}: {avg_count:.2f}")