In [None]:
import tensorflow as tf

import os
import time
import numpy as np

In [None]:
# Download the Shakespeare corpus and keep the local path that get_file returns.
path_to_file=tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

In [None]:
# Read the corpus as bytes and decode it as UTF-8. The context manager closes the
# file handle as soon as the read finishes instead of leaving it open.
with open(path_to_file, 'rb') as corpus_file:
    text = corpus_file.read().decode(encoding='utf-8')
# Corpus length in characters.
print(len(text))

In [None]:
# Preview the first 250 characters of the corpus.
print(text[:250])

In [None]:
# The vocabulary is the sorted list of distinct characters found in the corpus.
vocab = sorted(set(text))
print(len(vocab))
print(vocab)

In [None]:
# Two toy strings used to demonstrate the character <-> id conversion.
example_texts=['abcdefg', 'xyz']

# Split each string into its individual UTF-8 characters.
chars=tf.strings.unicode_split(example_texts, input_encoding='UTF-8')
chars

In [None]:
# Lookup layer that maps each vocabulary character to an integer id.
ids_from_chars=tf.keras.layers.experimental.preprocessing.StringLookup(vocabulary=list(vocab))

In [None]:
# Inspect the vocabulary actually held by the lookup layer.
ids_from_chars.get_vocabulary()

In [None]:
# Convert the example characters to integer ids.
ids=ids_from_chars(chars)
ids

In [None]:
# Inverse lookup (invert=True): maps ids back to characters, built from the same
# vocabulary that ids_from_chars holds.
chars_from_ids=tf.keras.layers.experimental.preprocessing.StringLookup(vocabulary=ids_from_chars.get_vocabulary(), invert=True)

In [None]:
# Map the ids back to characters.
# NOTE: this rebinds `chars`, replacing the value produced by unicode_split.
chars=chars_from_ids(ids)
chars

In [None]:
# Join the characters along the last axis to recover whole strings.
tf.strings.reduce_join(chars, axis=-1)

In [None]:
def text_from_ids(ids):
    """Convert integer ids back into text.

    Each id is looked up with ``chars_from_ids`` and the resulting characters
    are joined along the last axis with ``tf.strings.reduce_join``.
    """
    characters = chars_from_ids(ids)
    joined = tf.strings.reduce_join(characters, axis=-1)
    return joined

In [None]:
# Encode the entire corpus as a sequence of character ids.
all_ids=ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids

In [None]:
# Dataset built by slicing all_ids along its first axis.
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)

In [None]:
# Print the first 10 characters of the corpus, one per line.
# The loop variable is deliberately not named `ids`: at notebook top level that would
# overwrite the `ids` tensor created in the lookup demo, so re-running the
# `chars_from_ids(ids)` cell afterwards would silently use a single scalar id.
for char_id in ids_dataset.take(10):
    print(chars_from_ids(char_id).numpy().decode('utf-8'))

In [None]:
# Length, in characters, of each model input sequence.
seq_length=100
# Number of whole (seq_length+1)-character chunks that fit in the corpus.
examples_per_epoch=len(text) // (seq_length+1)

In [None]:
# Group the id stream into chunks of seq_length+1 ids; the final chunk is
# discarded if it is shorter than that.
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

# Show the individual characters of the first chunk.
for seq in sequences.take(1):
    print(chars_from_ids(seq).numpy())

In [None]:
# Decode the first five chunks back into readable text.
for seq in sequences.take(5):
    print(text_from_ids(seq).numpy().decode('utf-8'))

In [None]:
def split_input_target(sequence):
    """Split a sequence into an (input, target) pair offset by one position.

    The input is every element except the last and the target is every
    element except the first, so ``target[i]`` is the element that follows
    ``input[i]`` in the original sequence.
    """
    return sequence[:-1], sequence[1:]

In [None]:
# Sanity check on a plain Python list of characters.
split_input_target(list('Tensorflow'))

In [None]:
# Turn every chunk into an (input, target) pair.
dataset = sequences.map(split_input_target)

In [None]:
# Inspect the first (input, target) pair as raw tensors.
for input_example, target_example in dataset.take(1):
    print("Input: ", input_example)
    print("Target: ", target_example)

In [None]:
# Number of sequences per training batch.
BATCH_SIZE=64
# Size of the buffer that Dataset.shuffle draws from.
BUFFER_SIZE=10000

# Shuffle, form fixed-size batches (dropping the last partial batch) and prefetch.
# NOTE(review): this rebinds `dataset`, so re-running this cell without first
# re-running the `sequences.map(...)` cell batches data that is already batched.
dataset=dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
dataset

In [None]:
# Number of distinct characters in the corpus.
# NOTE(review): the cells that build the model pass
# len(ids_from_chars.get_vocabulary()) instead of this value.
vocab_size=len(vocab)
# Width of the embedding vectors.
embedding_dim = 256
# Number of GRU units.
rnn_units=1024

In [None]:
class MyModel(tf.keras.Model):
    """Character-level language model: Embedding -> GRU -> Dense.

    The final Dense layer has no activation, so the model outputs one raw
    score per vocabulary entry at every position of the input sequence.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        """Create the three layers.

        Args:
            vocab_size: number of ids; used as the embedding input size and
                as the width of the output layer.
            embedding_dim: size of each embedding vector.
            rnn_units: number of units in the GRU layer.
        """
        # super() is already bound to this instance, so no explicit `self`
        # argument may be passed to the base-class constructor.
        super().__init__()

        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(rnn_units, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        """Run the model on a batch of id sequences.

        Args:
            inputs: integer ids fed to the embedding layer.
            states: GRU state to start from; when None the GRU's own initial
                state is used.
            return_state: when True, also return the final GRU state.
            training: forwarded to every layer.

        Returns:
            The Dense layer output, or a ``(output, states)`` tuple when
            ``return_state`` is True.
        """
        x = inputs
        x = self.embedding(x, training=training)

        # Identity test: `== None` on a tensor is not a reliable "was it given" check.
        if states is None:
            states = self.gru.get_initial_state(x)

        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)

        if return_state:
            return x, states
        else:
            return x

In [None]:
# Build the model, sized by the lookup layer's vocabulary rather than by `vocab_size`.
model = MyModel(len(ids_from_chars.get_vocabulary()), embedding_dim, rnn_units)

In [None]:
# Run one batch through the untrained model and decode the arg-max prediction at
# every position of the first sequence in the batch. The loop variables and
# `example_batch_predictions` are reused by later cells.
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(text_from_ids(tf.argmax(example_batch_predictions[0], axis=-1)))

In [None]:
# Print the layer and parameter summary.
model.summary()

In [None]:
# Sample one id per position of the first example, using the model outputs as logits.
sample_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
# Drop the trailing num_samples axis and convert to a NumPy array.
sample_indices = tf.squeeze(sample_indices, axis=-1).numpy()
sample_indices

In [None]:
# Decode the sampled ids into text.
text_from_ids(sample_indices).numpy()

In [None]:
# The model's Dense output has no activation, so the loss is told to expect logits.
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Loss of the untrained model on the example batch.
example_batch_loss=loss(target_example_batch, example_batch_predictions)
mean_loss=example_batch_loss.numpy().mean()

print("Prediction shape: ", example_batch_predictions.shape)
print("Mean loss: ", mean_loss)

In [None]:
# Exponential of the mean loss (the perplexity of the untrained predictions).
tf.exp(mean_loss).numpy()

In [None]:
# Configure training with the Adam optimizer and the `loss` object.
model.compile('adam', loss=loss)

In [None]:
# Directory where the checkpoints are written.
checkpoint_dir = './training_checkpoints'
# File-name template; the '{epoch}' placeholder is filled in when a checkpoint is written.
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')

# Callback handed to model.fit; it saves only the weights.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True)

In [None]:
# Number of passes over the dataset.
EPOCHS=40

# Train, writing checkpoints through the callback.
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

In [None]:
class OneStep(tf.keras.Model):
    """Wraps a trained model to generate text one character at a time."""

    def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
        """Store the model and lookup layers and build the sampling mask.

        Args:
            model: model called with ``inputs``, ``states`` and
                ``return_state=True``; must return ``(logits, states)``.
            chars_from_ids: layer mapping ids to characters.
            ids_from_chars: layer mapping characters to ids.
            temperature: divisor applied to the logits before sampling.
        """
        super().__init__()
        self.model=model
        self.chars_from_ids=chars_from_ids
        self.ids_from_chars=ids_from_chars
        # Honour the caller's value; it was previously overwritten with 1.0.
        self.temperature=temperature

        # Ids of the tokens that must never be sampled, shaped as one index per row.
        skip_ids=self.ids_from_chars(['', '[UNK]'])[:, None]
        # -inf at those ids, zero everywhere else.
        sparse_mask = tf.SparseTensor(
            values=[-float('inf')]*len(skip_ids),
            indices=skip_ids,
            dense_shape=[len(ids_from_chars.get_vocabulary())]
        )

        self.prediction_mask = tf.sparse.to_dense(sparse_mask)

    @tf.function
    def generate_one_step(self, inputs, states=None):
        """Sample the next character for each input string.

        Args:
            inputs: string tensor holding the text to continue.
            states: model state from the previous step, or None to start fresh.

        Returns:
            ``(predicted_chars, states)``: the sampled characters and the model
            state to pass to the next call.
        """
        input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
        input_ids = self.ids_from_chars(input_chars).to_tensor()

        predicted_logits, states = self.model(inputs=input_ids, states=states, return_state=True)

        # Only the prediction at the last position is used.
        predicted_logits = predicted_logits[:, -1, :]
        predicted_logits = predicted_logits / self.temperature

        # Adding -inf to the masked ids removes them from the sampling distribution.
        predicted_logits = predicted_logits + self.prediction_mask

        predicted_ids = tf.random.categorical(predicted_logits, 1)
        predicted_ids = tf.squeeze(predicted_ids, axis=-1)

        predicted_chars = self.chars_from_ids(predicted_ids)

        return predicted_chars, states

In [None]:
# Wrap the trained model and both lookup layers for generation (default temperature).
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)

In [None]:
# Generate 1000 characters from the seed 'ROMEO' and time the loop.
start = time.time()
states = None
next_char = tf.constant(['ROMEO'])
result = [next_char]

# Feed each sampled character and the returned state back into the next step.
for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states)
    result.append(next_char)

result = tf.strings.join(result)
end = time.time()

print(result[0].numpy().decode('utf-8'), '\n\n' + '-'*80)
# The elapsed time is passed as a separate print argument, so the label must not
# contain a '{}' placeholder (it would be printed literally).
print('\nRun time:', end - start)

In [None]:
# Same generation loop with a batch of five identical seeds, timed as a whole.
start = time.time()
states = None
next_char = tf.constant(['ROMEO', 'ROMEO', 'ROMEO', 'ROMEO', 'ROMEO'])
result = [next_char]

# Feed each sampled character and the returned state back into the next step.
for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states)
    result.append(next_char)

result = tf.strings.join(result)
end = time.time()

# Print each of the five generated texts followed by a separator line.
for i in range(5):
    print(result[i].numpy().decode('utf-8'))
    print('\n\n' + '-'*80)

print('\nRun time: ', end - start)

In [None]:
# Export the generator to the 'one_step' directory and load it back.
tf.saved_model.save(one_step_model, 'one_step')
one_step_reloaded = tf.saved_model.load('one_step')

In [None]:
# Generate 100 characters with the reloaded generator, starting from 'ROMEO'.
states = None
next_char = tf.constant(['ROMEO'])
result = [next_char]

for n in range(100):
    next_char, states = one_step_reloaded.generate_one_step(next_char, states)
    result.append(next_char)

print(tf.strings.join(result)[0].numpy().decode('utf-8'))

In [None]:
class CustomTraining(MyModel):
    """MyModel with a hand-written training step."""

    @tf.function
    def train_step(self, inputs):
        """Run one optimisation step on a batch.

        Args:
            inputs: an ``(inputs, labels)`` pair.

        Returns:
            A dict with the batch loss under the key ``'loss'``.
        """
        inputs, labels = inputs

        # Record the forward pass so the loss can be differentiated.
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss = self.loss(labels, predictions)

        # Use this instance's own variables, not whatever the notebook-global
        # `model` happens to point at when the step runs.
        grads = tape.gradient(loss, self.trainable_variables)

        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        return {'loss': loss}

In [None]:
# Fresh, untrained model that uses the custom train_step.
# NOTE: this rebinds `model`; the OneStep wrapper built earlier keeps its own
# reference to the previously trained instance.
model = CustomTraining(len(ids_from_chars.get_vocabulary()), embedding_dim=embedding_dim, rnn_units=rnn_units)

In [None]:
# Compile with an explicit Adam optimizer; the positional True is the loss's from_logits flag.
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.SparseCategoricalCrossentropy(True))

In [None]:
# Train for one epoch through the standard fit loop, which calls the overridden train_step.
model.fit(dataset, epochs=1)

In [None]:
# Number of epochs for the hand-written loop (rebinds the earlier EPOCHS value).
EPOCHS=10

# Running mean of the per-batch loss within one epoch.
mean = tf.metrics.Mean()

for epoch in range(EPOCHS):
    start = time.time()
    mean.reset_states()

    for (batch_n, (inputs, labels)) in enumerate(dataset):
        logs = model.train_step([inputs, labels])
        mean.update_state(logs['loss'])

        # Report progress every 50 batches.
        if batch_n % 50 == 0:
            print(f"Epoch {epoch+1} Batch {batch_n} Loss {logs['loss']:.4f}")

    # Checkpoint every 5th epoch. The file name uses the same 1-based epoch number
    # as the condition and the log lines, which also matches how the
    # ModelCheckpoint callback numbers its files.
    if (epoch+1)%5==0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch+1))

    print()
    print(f'Epoch {epoch+1} Loss: {mean.result().numpy():.4f}')
    print(f'Time taken for 1 epoch {time.time() - start:.2f} sec')
    print('_'*80)