https://www.tensorflow.org/text/tutorials/text_generation

In [1]:
import os
import time

import numpy as np
import tensorflow as tf

In [2]:
# Fetch the Shakespeare corpus; `path_to_file` is what the next cell opens.
path_to_file = tf.keras.utils.get_file(
    fname='shakespeare.txt',
    origin='https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt',
)

In [3]:
# Read the raw bytes, then decode explicitly as UTF-8 (py2-compatible form).
# The context manager closes the file handle deterministically instead of
# leaving it open until garbage collection.
with open(path_to_file, 'rb') as corpus_file:
    text = corpus_file.read().decode(encoding='utf-8')
print(f'Length of text: {len(text)} characters')

Length of text: 1115394 characters


In [4]:
# Preview the first 250 characters of the corpus.
print(text[:250])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.



In [5]:
# The vocabulary is the sorted list of distinct characters found in the corpus.
vocab = sorted(set(text))
print(f'{len(vocab)} unique characters')

65 unique characters


In [6]:
# Two toy strings used to demonstrate the character <-> ID round trip.
example_texts = ['abcdefg', 'xyz']

# Split each string into its individual UTF-8 characters; the strings have
# different lengths, so the result is ragged.
chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')
chars

<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

In [7]:
# Lookup layer that maps characters to integer IDs, built from the corpus
# vocabulary. No mask token is configured (mask_token=None).
ids_from_chars = tf.keras.layers.StringLookup(
    vocabulary=list(vocab),
    mask_token=None
)

In [8]:
# Convert the split demo characters to their integer IDs.
ids = ids_from_chars(chars)
ids

<tf.RaggedTensor [[40, 41, 42, 43, 44, 45, 46], [63, 64, 65]]>

In [9]:
# Inverse lookup (invert=True): maps integer IDs back to characters. It reuses
# the vocabulary reported by `ids_from_chars` so both layers agree on the IDs.
chars_from_ids = tf.keras.layers.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(),
    invert=True,
    mask_token=None
)

In [10]:
# Round trip: map the demo IDs back to characters (rebinds `chars`).
chars = chars_from_ids(ids)
chars

<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

In [11]:
# Join the characters along the last axis to get one string per example.
tf.strings.reduce_join(chars, axis=-1).numpy()

array([b'abcdefg', b'xyz'], dtype=object)

In [12]:
def text_from_ids(ids):
    """Map token IDs to characters and join them along the last axis into strings."""
    char_tokens = chars_from_ids(ids)
    return tf.strings.reduce_join(char_tokens, axis=-1)

In [13]:
# Encode the whole corpus: split it into characters, then look up each ID.
all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids

<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([19, 48, 57, ..., 46,  9,  1], dtype=int64)>

In [14]:
# Dataset that yields the corpus one token ID at a time.
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)

In [15]:
# Show the first ten token IDs decoded back to characters. The loop variable
# is deliberately not named `ids`, so the `ids` tensor created by the earlier
# lookup demo is not overwritten and that demo cell stays re-runnable.
for char_id in ids_dataset.take(10):
    print(chars_from_ids(char_id).numpy().decode('utf-8'))

F
i
r
s
t
 
C
i
t
i


In [16]:
# Number of characters per training input. The corpus is chunked into
# seq_length + 1 IDs so input and target can be offset by one character.
seq_length = 100

In [17]:
# Group the ID stream into chunks of seq_length + 1; drop_remainder=True
# discards a final chunk that would be shorter.
sequences = ids_dataset.batch(seq_length + 1, drop_remainder=True)

# Inspect the first chunk as individual characters.
for seq in sequences.take(1):
    print(chars_from_ids(seq))

tf.Tensor(
[b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':'
 b'\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o'
 b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h'
 b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p'
 b'e' b'a' b'k' b'.' b'\n' b'\n' b'A' b'l' b'l' b':' b'\n' b'S' b'p' b'e'
 b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'F' b'i'
 b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'Y'
 b'o' b'u' b' '], shape=(101,), dtype=string)


In [18]:
# Print the first five chunks joined back into byte strings.
for seq in sequences.take(5):
    print(text_from_ids(seq).numpy())

b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
b'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
b"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki"
b"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d"
b'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'


In [19]:
def split_input_target(sequence):
    """Return an (input, target) pair offset by one position.

    The input is `sequence` without its last element and the target is
    `sequence` without its first element, so target[i] follows input[i].
    """
    return sequence[:-1], sequence[1:]

In [20]:
# Demonstrate the one-position offset on a plain Python list of characters.
split_input_target(list('Tensorflow'))

(['T', 'e', 'n', 's', 'o', 'r', 'f', 'l', 'o'],
 ['e', 'n', 's', 'o', 'r', 'f', 'l', 'o', 'w'])

In [21]:
# Turn every chunk into an (input, target) pair.
dataset = sequences.map(split_input_target)

In [22]:
# Show the first (input, target) pair as text; the target is the input
# shifted by one character.
for input_example, target_example in dataset.take(1):
    print(f'Input: {text_from_ids(input_example).numpy()}')
    print(f'Target: {text_from_ids(target_example).numpy()}')

Input: b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'
Target: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '


In [23]:
BATCH_SIZE = 64

# Buffer size used to shuffle the dataset. tf.data is designed to work with
# possibly infinite sequences, so it does not try to shuffle the entire
# sequence in memory; it keeps a buffer and shuffles elements within it.
BUFFER_SIZE = 10000

# Build the pipeline from `sequences` instead of re-binding `dataset` from
# itself: re-running this cell then yields the same (BATCH_SIZE, seq_length)
# batches rather than batching an already-batched dataset.
dataset = (
    sequences
    .map(split_input_target)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

dataset

<PrefetchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>

In [24]:
# Model hyperparameters. vocab_size comes from the lookup layer's vocabulary
# (not from `vocab`), so it matches the IDs the layer can emit.
vocab_size = len(ids_from_chars.get_vocabulary())
embedding_dim = 256
rnn_units = 1024

In [25]:
class MyModel(tf.keras.Model):
    """Character-level model: Embedding -> GRU -> Dense logits over the vocabulary."""

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        """Create the three layers.

        Args:
            vocab_size: Number of token IDs; sizes the embedding input and
                the output logits.
            embedding_dim: Size of each embedding vector.
            rnn_units: Number of GRU units.
        """
        # super() already binds the instance; passing `self` again would hand
        # the base initializer a stray positional argument.
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

        # return_sequences=True: one output per input position.
        # return_state=True: also return the final state so generation can
        # feed it back in on the next call.
        self.gru = tf.keras.layers.GRU(
            rnn_units,
            return_sequences=True,
            return_state=True
        )

        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        """Compute next-character logits for every input position.

        Args:
            inputs: Token IDs fed to the embedding layer.
            states: Optional GRU state to start from; when None the GRU's
                own initial state is used.
            return_state: When True, also return the final GRU state.
            training: Forwarded to every layer.

        Returns:
            The logits, or a `(logits, states)` tuple when `return_state`
            is True.
        """
        x = self.embedding(inputs, training=training)

        if states is None:
            states = self.gru.get_initial_state(x)

        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)

        if return_state:
            return x, states
        return x

In [26]:
# Instantiate the model with the hyperparameters defined above.
model = MyModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
)

In [27]:
# Sanity check: run one batch through the untrained model and inspect the
# output shape.
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, '# (batch_size, sequence_length, vocab_size)')

(64, 100, 66) # (batch_size, sequence_length, vocab_size)


In [28]:
# Print the layers and parameter counts.
model.summary()

Model: "my_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       multiple                  16896     
                                                                 
 gru (GRU)                   multiple                  3938304   
                                                                 
 dense (Dense)               multiple                  67650     
                                                                 
Total params: 4,022,850
Trainable params: 4,022,850
Non-trainable params: 0
_________________________________________________________________


In [29]:
# Sample one token ID per position from the logits of the first example in
# the batch, then drop the trailing num_samples axis.
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

In [30]:
# Display the sampled token IDs.
sampled_indices

array([25, 29, 52, 20, 34, 22, 10, 39, 60, 26, 22, 61,  1, 29, 56,  9,  6,
       37, 47, 43, 53, 60, 65,  8, 64, 47, 63, 62, 24,  2, 23, 24, 62, 55,
       46, 11, 22, 44, 36, 15, 42,  0, 48, 60, 27, 49, 22, 24, 46, 39, 42,
       26, 30, 15, 62, 60, 37, 49,  9, 64, 44, 11, 11, 49, 31, 59,  3, 44,
       36, 52, 44, 40, 10,  2, 53, 20, 12, 59, 26,  7,  5, 20, 58, 63, 46,
       40, 29,  8, 50, 48, 14, 32,  0, 63, 10, 14, 56, 52, 42, 57],
      dtype=int64)

In [31]:
# Decode the input and the untrained model's sampled predictions to text.
print(f'Input:\n{text_from_ids(input_example_batch[0]).numpy()}')
print(f'Next char predictions:\n{text_from_ids(sampled_indices).numpy()}')

Input:
b"OP:\nPeace have they made with him indeed, my lord.\n\nKING RICHARD II:\nO villains, vipers, damn'd with"
Next char predictions:
b"LPmGUI3ZuMIv\nPq.'Xhdnuz-yhxwK JKwpg:IeWBc[UNK]iuNjIKgZcMQBwuXj.ye::jRt!eWmea3 nG;tM,&GsxgaP-kiAS[UNK]x3Aqmcr"


In [32]:
# from_logits=True because the model's Dense layer has no activation, i.e.
# it emits raw logits.
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

In [33]:
# Configure training with the Adam optimizer and the loss defined above.
model.compile(optimizer='adam', loss=loss)

In [34]:
# Directory for the checkpoints and the filename template; `{epoch}` is a
# placeholder filled in when a checkpoint is written.
checkpoint_dir = 'ckpts/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt{epoch}')

# Save only the weights, not the full model.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True
)

In [35]:
# Number of passes over the dataset for the `fit` call below.
EPOCHS = 20

In [36]:
# Train the model, writing checkpoints through the callback.
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [37]:
class OneStep(tf.keras.Model):
    """Wraps a trained model to generate text one character per call."""

    def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
        """Store the model and lookup layers and build the '[UNK]' mask.

        Args:
            model: Model called with `inputs`, `states` and `return_state`;
                must return `(logits, states)`.
            chars_from_ids: Layer mapping token IDs to characters.
            ids_from_chars: Layer mapping characters to token IDs.
            temperature: Divisor applied to the logits before sampling.
        """
        super().__init__()
        self.temperature = temperature
        self.model = model
        self.chars_from_ids = chars_from_ids
        self.ids_from_chars = ids_from_chars

        # Create a mask to prevent '[UNK]' from being generated
        skip_ids = self.ids_from_chars(['[UNK]'])[:, None]

        sparse_mask = tf.SparseTensor(
            # Put a -inf at each bad index.
            values=[-float('inf')] * len(skip_ids),
            indices=skip_ids,

            # Match the shape to the vocabulary
            dense_shape=[len(ids_from_chars.get_vocabulary())])

        # Dense vector of vocabulary length, added to the logits at
        # generation time (entries not set above take to_dense's default).
        self.prediction_mask = tf.sparse.to_dense(sparse_mask)

    @tf.function
    def generate_one_step(self, inputs, states=None):
        """Sample the next character for each input string.

        Args:
            inputs: Batch of strings; each is split into UTF-8 characters.
            states: Model state from the previous call, or None to start.

        Returns:
            `(predicted_chars, states)`: one sampled character per input
            string and the model state to pass to the next call.
        """
        # Convert strings to token IDs
        input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
        input_ids = self.ids_from_chars(input_chars).to_tensor()

        # Run the model. predicted_logits.shape is [batch, char, next_char_logits]
        predicted_logits, states = self.model(
            inputs=input_ids,
            states=states,
            return_state=True
        )

        # Only use the last prediction
        predicted_logits = predicted_logits[:, -1, :]
        predicted_logits = predicted_logits / self.temperature

        # Apply the prediction mask: prevent '[UNK]' from being generated
        predicted_logits = predicted_logits + self.prediction_mask

        # Sample the output logits to generate token IDs
        predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
        predicted_ids = tf.squeeze(predicted_ids, axis=-1)

        # Convert from token ids to characters
        predicted_chars = self.chars_from_ids(predicted_ids)

        # Return the characters and model state
        return predicted_chars, states

In [38]:
# Wrap the trained model for generation with the default temperature.
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)

In [39]:
# Generate 1000 characters from a single prompt and time the run.
start = time.time()
states = None  # no state yet: the first call starts from the model's initial state
next_char = tf.constant(['ROMEO:'])
result = [next_char]

# Each step feeds the previous output and state back into the generator.
for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)

# Concatenate the prompt and the generated pieces into one string per example.
result = tf.strings.join(result)
end = time.time()

print(result[0].numpy().decode('utf-8'), '\n' + '_' * 80)
print(f'Run time: {end - start}')

ROMEO:
Good mistress, and perish can make them not,
But bid him should my use for vantage must rich:
Then, God defend that I be yet but great oppression
faith, she is an expire to take an impoes,
Small me a liebunation of thy stoumphatice,
That mocks about him.

FRIAR LAURENCE:
His absence,
He might have done thee how to cut off all the language.

OXF ROWE:
Undwere you be a purpose divited counsel.
Reednesser, you have done waste it; what I profess
theirs, offer me night. Hark you, spot that the vastles vast thereof.
Sound trumpets! strike my fall, stark and safe, and perush thou
The verge my boar's son, those though us thus
falling in his stocking: his horse drave deliver?

WARWICK:
I go; hopping you must see your brother's shadow,
Kings yet our battles with the sweets o' the mother,
Cry humords that Richmond should be.
Diret aly his follies is wither'd blood!

LORD WILLOUGHBY:
Base!' quoth he. Hence-ball'd and me!
Just Warwick's daughter, since I till compare
Her father's skining so 

In [40]:
# Same generation loop, but with a batch of five identical prompts so five
# texts are generated in one run.
start = time.time()
states = None
next_char = tf.constant(['ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:'])
result = [next_char]

for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)

# Concatenate the prompt and the generated pieces into one string per example.
result = tf.strings.join(result)
end = time.time()

print(result, '\n' + '_' * 80)
print(f'Run time: {end - start}')

tf.Tensor(
[b"ROMEO:\nThat little farewell.\n\nLORD WILLOUGHBY:\nThe most sufficitive toem of sensible,\nHath held your court-a-morrow'd truture,\nThat shake the law, cry whether had an emperier.\n\nDUKE OF YORK:\nWhere is thy poor crowns!\nDays nightly, cousin Destruetion!\nNow Romeo, will he mury this fair cozen?\n\nBAPTISTA:\nI am busied the law upon this bloody moon.\n\nTRANIO:\nDidst welcome, Bush, and damn'd up so dreadful man?\n\nPost:\nA jealous arrs at voices.\nO, no, no, every grey, or else\nI know not on this castle's torture,\nThan dwell is father. Luringly to the sun,\nSo far as life, for nothing but deadly sendet,\nRone of Such God, he been an old traves; and to be\ncolour'd fear, I do, and that have man;\nThese high powerful profession shall pardon me,\nNever take him here, how has your viratural companion?\n\nISABELLA:\nJut never\nYou that so broken fattle, whose death\nTake in abtance, with all speed to die:\nThere stands that seeking him to put a little.\nOne of thyse

In [41]:
# Export the single-step generator as a SavedModel, then load it back.
tf.saved_model.save(one_step_model, 'models/one_step')
one_step_reloaded = tf.saved_model.load('models/one_step')





INFO:tensorflow:Assets written to: models/one_step\assets


INFO:tensorflow:Assets written to: models/one_step\assets


In [42]:
# Check that the reloaded generator works: produce 100 characters with it.
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for n in range(100):
    next_char, states = one_step_reloaded.generate_one_step(next_char, states=states)
    result.append(next_char)

print(tf.strings.join(result)[0].numpy().decode('utf-8'))

ROMEO:
If love be balm'd, proud, King Edward's drid,
One nature, my master named,--there's sendence,
Ere h


In [43]:
class CustomTraining(MyModel):
    """MyModel with a hand-written training step."""

    @tf.function
    def train_step(self, inputs):
        """Run one optimization step on a single batch.

        Args:
            inputs: An `(inputs, labels)` pair.

        Returns:
            A dict with the batch loss under the key 'loss'.
        """
        inputs, labels = inputs

        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss = self.loss(labels, predictions)

        # Use this instance's own variables rather than the notebook-global
        # `model`, so the step stays correct whatever name the instance is
        # bound to (or if `model` refers to a different object).
        variables = self.trainable_variables
        grads = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))

        return {'loss': loss}

In [44]:
# Rebind `model` to a fresh, untrained CustomTraining instance. The earlier
# `one_step_model` still holds the previously trained model.
model = CustomTraining(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
)

In [45]:
# Attach the optimizer and loss; the custom train_step reads them through
# self.optimizer and self.loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
)

In [46]:
# Train for one epoch through the standard fit loop.
model.fit(dataset, epochs=1)



<keras.callbacks.History at 0x185dd94fd90>

In [47]:
EPOCHS = 10

# Running mean of the per-batch loss, reset at the start of every epoch.
mean = tf.metrics.Mean()

for epoch in range(EPOCHS):
    start = time.time()
    mean.reset_state()

    for (batch_n, (inp, target)) in enumerate(dataset):
        logs = model.train_step([inp, target])
        mean.update_state(logs['loss'])

        if batch_n % 50 == 0:
            template = f'Epoch: {epoch + 1} | Batch: {batch_n} | Loss: {logs["loss"]:.4f}'
            print(template)

    # Saving (checkpoint) the model every 5 epochs. The filename uses the
    # 1-based epoch number so it matches the printed logs and the numbering
    # used by the ModelCheckpoint callback for the same filename template.
    if (epoch + 1) % 5 == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch + 1))

    print(f'Epoch: {epoch + 1} | Loss: {mean.result().numpy():.4f}')
    print(f'Time taken for 1 epoch: {time.time() - start:.2f} sec')
    print('_' * 80)

# Always save the final weights. EPOCHS is used instead of the loop variable,
# which would be undefined if the loop body never ran.
model.save_weights(checkpoint_prefix.format(epoch=EPOCHS))

Epoch: 1 | Batch: 0 | Loss: 2.1749
Epoch: 1 | Batch: 50 | Loss: 2.0676
Epoch: 1 | Batch: 100 | Loss: 1.9495
Epoch: 1 | Batch: 150 | Loss: 1.8234
Epoch: 1 | Loss: 1.9935
Time taken for 1 epoch: 8.39 sec
________________________________________________________________________________
Epoch: 2 | Batch: 0 | Loss: 1.8705
Epoch: 2 | Batch: 50 | Loss: 1.7184
Epoch: 2 | Batch: 100 | Loss: 1.6541
Epoch: 2 | Batch: 150 | Loss: 1.6737
Epoch: 2 | Loss: 1.7201
Time taken for 1 epoch: 7.35 sec
________________________________________________________________________________
Epoch: 3 | Batch: 0 | Loss: 1.6355
Epoch: 3 | Batch: 50 | Loss: 1.5523
Epoch: 3 | Batch: 100 | Loss: 1.5101
Epoch: 3 | Batch: 150 | Loss: 1.5382
Epoch: 3 | Loss: 1.5593
Time taken for 1 epoch: 8.67 sec
________________________________________________________________________________
Epoch: 4 | Batch: 0 | Loss: 1.4343
Epoch: 4 | Batch: 50 | Loss: 1.4782
Epoch: 4 | Batch: 100 | Loss: 1.4549
Epoch: 4 | Batch: 150 | Loss: 1.4564
Epoch: