In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

import numpy as np
import seaborn as sns
from matplotlib import  pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm

# Helper functions
- We have to rejoin our character to form a string

In [2]:
def text_from_ids(ids, chars_from_ids):
  return tf.strings.reduce_join(
      chars_from_ids(ids), axis=-1
  )

# Lets load our dataset
- Since for text generation we will use the shakespeer dataset
- From there lets do a bit of eda

In [3]:
(train_ds, test_ds, valid_ds), info_ds = tfds.load(
    'tiny_shakespeare', split=['train', 'test', 'validation'], with_info=True,
    data_dir='data/'
)

In [4]:
'Total Length: {} characters'.format(len(next(iter(train_ds))['text'].numpy()))

'Total Length: 1003854 characters'

In [5]:
print(next(iter(train_ds))['text'].numpy()[:250].decode('UTF-8'))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.



In [6]:
# Grab unique chracters - vocab
vocab = sorted(set(next(iter(train_ds))['text'].numpy().decode('UTF-8')))
vocab[:10], len(vocab)

(['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3'], 65)

In [7]:
ids_from_chars = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=list(vocab), mask_token=None
)
chars_from_ids = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None
)

# Data preparation
- Lets prepare our data to be used for our model
- We first need to convert 

In [8]:
batch_size = 64
buffer_size = 10000
seq_len = 100 + 1

In [9]:
train_ds = train_ds.map(lambda x: ids_from_chars(tf.strings.unicode_split(x['text'], 'UTF-8')))
train_ds = train_ds.unbatch().batch(seq_len, drop_remainder=True)
train_ds = train_ds.map(lambda x: (x[:-1], x[1:]))
train_ds = train_ds.shuffle(buffer_size).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

In [10]:
valid_ds = valid_ds.map(lambda x: ids_from_chars(tf.strings.unicode_split(x['text'], 'UTF-8')))
valid_ds = valid_ds.unbatch().batch(seq_len, drop_remainder=True)
valid_ds = valid_ds.map(lambda x: (x[:-1], x[1:]))
valid_ds = valid_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

In [11]:
test_ds = test_ds.map(lambda x: ids_from_chars(tf.strings.unicode_split(x['text'], 'UTF-8')))
test_ds = test_ds.unbatch().batch(seq_len, drop_remainder=True)
test_ds = test_ds.map(lambda x: (x[:-1], x[1:]))
test_ds = test_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

# Build model
- We start building with the subclass api as a base test
- This will allow more control such as manipulating the state of each layer

In [12]:
class CharGenRNN(tf.keras.Model):
  def __init__(self, vocab_size=len(vocab), embedding_dim=256, rnn_units=1024):
    super(CharGenRNN, self).__init__()

    # Embedding Layer
    self.embedding = tf.keras.layers.Embedding(
        vocab_size, embedding_dim
    )
    self.lstm = tf.keras.layers.GRU(
        rnn_units, return_sequences=True, return_state=True
    )
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, x, states=None, return_state=False, training=False):
    x = self.embedding(x, training=training)
    if states is None:
      states = self.lstm.get_initial_state(x)

    x, state = self.lstm(x, initial_state=states, training=training)
    x = self.dense(x)

    if return_state:
      return x, states
    else:
      return x

In [13]:
model = CharGenRNN(vocab_size=len(ids_from_chars.get_vocabulary()))

In [14]:
for input_example_batch, target_example_batch in train_ds.take(1):
  example_batch_predictions = model(input_example_batch)
  print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

(64, 100, 66) # (batch_size, sequence_length, vocab_size)


In [15]:
model.summary()

Model: "char_gen_rnn"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        multiple                  16896     
_________________________________________________________________
gru (GRU)                    multiple                  3938304   
_________________________________________________________________
dense (Dense)                multiple                  67650     
Total params: 4,022,850
Trainable params: 4,022,850
Non-trainable params: 0
_________________________________________________________________


In [16]:
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

In [17]:
print("Input:\n", text_from_ids(input_example_batch[0], chars_from_ids).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices, chars_from_ids).numpy())

Input:
 b'y were\nThe common muck of the world: he covets less\nThan misery itself would give; rewards\nHis deeds'

Next Char Predictions:
 b'YlounnHxX[UNK]!LfilHDb m-oK:E.s-XgxmH-nWVM3r:s:LWiUl!BSQ fP3!XhVrTXyS?:aeq xPFJhKFKb,TPUKd[UNK]YBM-L:W&zDwlI'


# Training Loop
- Lets create a custom training loop to train our data

In [18]:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)

In [19]:
@tf.function
def train_step(input_example_batch, target_example_batch):
  with tf.GradientTape() as tape:
    logits = model(input_example_batch, training=True)
    loss_val = loss(target_example_batch, logits)
  
  grads = tape.gradient(logits, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))

  return loss_val

@tf.function
def val_step(input_example_batch, target_example_batch):
  logits = model(input_example_batch)
  loss_val = loss(target_example_batch, logits)
  return loss_val

In [20]:
for epoch in tqdm(range(20)):
  for step, (input_example_batch, target_example_batch) in tqdm(enumerate(train_ds)):
    training_loss = train_step(input_example_batch, target_example_batch)
    if step % 200 == 0:
        print(
            "Training loss (for one batch) at step %d: %.4f"
            % (step, float(training_loss))
        )
        print("Seen so far: %s samples" % ((step + 1) * 64))

  for step, (val_image, val_label) in tqdm(enumerate(valid_ds)):
    test_loss = val_step(val_image, val_label)

    if step % 200 == 0:
        print(
            "Valid loss (for one batch) at step %d: %.4f"
            % (step, float(test_loss))
        )
        print("Seen so far: %s samples" % ((step + 1) * 64))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 4.1897
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.1770
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1671
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.1908
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1578
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.1981
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.2090
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2020
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.2076
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2036
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1946
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2054
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1955
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2075
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1628
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2093
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1775
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2105
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1687
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2111
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.2036
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2116
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1979
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2118
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1895
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2119
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.2077
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2121
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1882
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2121
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1914
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2122
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.2059
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2122
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.2079
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2123
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1614
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2123
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Training loss (for one batch) at step 0: 5.1880
Seen so far: 64 samples



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Valid loss (for one batch) at step 0: 5.2123
Seen so far: 64 samples




# Evaluate model

In [21]:
for input_example_batch, target_example_batch in test_ds.take(1):
  example_batch_predictions = model(input_example_batch)
  print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

  sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
  sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

  print("Input:\n", text_from_ids(input_example_batch[0], chars_from_ids).numpy())
  print()
  print("Next Char Predictions:\n", text_from_ids(sampled_indices, chars_from_ids).numpy())

(64, 100, 66) # (batch_size, sequence_length, vocab_size)
Input:
 b"rance ta'en\nAs shall with either part's agreement stand?\n\nBAPTISTA:\nNot in my house, Lucentio; for, "

Next Char Predictions:
 b"AghrRC'jYCgDmtCLQAqY?jbNg?t?j!qjLLYKnjYCb!-qYQWNQgjYCKehvJAxiqCY!Cb!EgCtGLBMLj-CjpW[UNK]jHmCrCjD3qrp!Lq[UNK]"


In [45]:
class OneStep(tf.keras.Model):
  def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
    super().__init__()
    self.temperature = temperature
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars

    # Create a mask to prevent "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
        values=[-float('inf')]*len(skip_ids),
        indices=skip_ids,
        # Match the shape to the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())])
    self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, inputs, states=None):
    # Convert strings to token IDs.
    input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits, states = self.model(
        input_ids, states=states, return_state=True
    )
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature
    # Apply the prediction mask: prevent "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)

    # Return the characters and model state.
    return predicted_chars, states

In [46]:
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)

In [52]:
%%time
states = None
next_char = tf.constant(['ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)

print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)

ROMEO:YiA?AgYqhjfVfmJeHxeYH'nC!urjjeCAmN&JYc!!ln-gh?CrNjZqg!'
L?aUp!QKgCD?qg!'pejqCijIChAq
n&jUmih?NBqYL
rbCvghgCDmGx'YrLqwjANrLY
p!mn!xqC!fPJvufrpjM?!q&;SnV!hqqQlMjqjQNJjcgqC
gPmGNYJb?CYErP
3xqYNCC
vgY-jjgBYYYLTCC';A;R;VjUnL;p?!NYbpKmqiE!Phec?Emq?pCrqpPn?jNMNgkZh
OCKNYQqjDjBN?D
jJ!YrCueq-ZIBBrqjngq-ctpfSkBJNu3TwqxB.p'!mqCFuqiYqqjA?qCYKIExuukBqnC'qCE!gUBq!EPqqTCYYuPzRHxK'!Iqg?ANqjijEtkpuiKqqqY;CgfpBq
ggCdQCmYp!rqliqgE?DvaNnp?E?jbrPt'rpfvDGChjNLYp?j'SAq!!W R?KOjYgf'TPbgqpUYnCNoBmC?pq&!YY?.YDKggQ!!PC'bbjYgmqdYivEgwr'FBr'qpffDg!NNjpg3Tmi
OfjfpBY?C-RVQSbHqAix-C!B'CcYCCjA!Pjeqf!!uE
!!qbYiO?KpPefggH?KoZVv
yxgwC!Cp?bYj?YcqJ.qghCYugQ?'gqY.EjCbp!qA!j'!NpDLpELAAAwKdAqrCu!mqmppSr!jVAggBuTgjCNrNC-rtxIAqjQcqEPNAqPKpq?rCcf'EEYemNWDmCA;PIBTmgY??CWjDVgH?iCqjQpAqh?m!CDTpHgBqYppN'jng!AhYYClCIjcgLPnKCv-q'gcCYYqngpCBqUmDuN?gxVmWK
HjYC
Cb?KrjCgqTPc!GJ'LYhq?EngDVqUm!LLYr-g?YSPCGgfCewCrpDgC
qn!N'EHqDqqCYN?!vxYfh'DqCjggCjHgYL$ZNLf!qDCEYrxbpH
i!!pSbP&Ej'!KWJN-!$AA'jrxYWQqQu!bYj'CjYj!OcrAyCL?D!AjVijqnTDvpCCB!Rj

# Save model

In [53]:
tf.saved_model.save(one_step_model, 'model/')




FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.



FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.


INFO:tensorflow:Assets written to: model/assets


INFO:tensorflow:Assets written to: model/assets
