In [None]:
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

# NOTE(review): the strategy object created here is neither assigned to a name
# nor entered via `.scope()` anywhere in the visible cells, so its result is
# discarded -- confirm whether it was meant to wrap model creation.
tf.distribute.OneDeviceStrategy(device="/gpu:0")

In [None]:
# Mixed precision (mixed_float16) is left disabled: even with large batch sizes it gives little speed-up for this training input pipeline, and it slows down training on older GPUs.

# policy = tf.keras.mixed_precision.Policy("mixed_float16")
# tf.keras.mixed_precision.experimental.set_policy(policy)

In [None]:
# Fetch the Shakespeare corpus; the returned value is used below as the path
# of the local copy to open.
path_to_file = tf.keras.utils.get_file(
    "shakespeare.txt",
    "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt",
)

In [None]:
# Read the corpus as raw bytes and decode it as UTF-8. The context manager
# closes the file handle as soon as the read is done instead of leaving it to
# the garbage collector.
with open(path_to_file, "rb") as corpus_file:
    text = corpus_file.read().decode(encoding="utf-8")
print(f"Length of text: {len(text)} characters")

In [None]:
# Peek at the first 250 characters of the corpus.
print(text[:250])

In [None]:
# The vocabulary is the set of distinct characters in the corpus, in sorted
# order so that the character -> ID mapping is stable between runs.
vocab = sorted(set(text))
print(vocab)
print(f"{len(vocab)} unique characters")

In [None]:
# Two short strings of different lengths used to demonstrate the
# text -> characters -> IDs -> characters round trip.
example_texts = ["abcdefg", "xyz"]

# Split every string into its individual UTF-8 characters.
chars = tf.strings.unicode_split(example_texts, input_encoding="UTF-8")
chars

In [None]:
# Lookup layer that maps each vocabulary character to an integer token ID.
ids_from_chars = preprocessing.StringLookup(vocabulary=list(vocab))

In [None]:
# Convert the example characters to token IDs.
ids = ids_from_chars(chars)
ids

In [None]:
# Inverse lookup (token ID -> character). Reusing the forward layer's
# vocabulary keeps the two mappings aligned.
chars_from_ids = preprocessing.StringLookup(
    invert=True,
    vocabulary=ids_from_chars.get_vocabulary(),
)

In [None]:
# Map the example IDs back to characters to check the round trip.
chars = chars_from_ids(ids)
chars

In [None]:
# Join the characters along the last axis to recover one string per example.
tf.strings.reduce_join(chars, axis=-1).numpy()

In [None]:
def text_from_ids(ids):
    """Convert token IDs back into text.

    Each ID is looked up with the notebook-level ``chars_from_ids`` layer and
    the resulting characters are joined along the last axis.
    """
    characters = chars_from_ids(ids)
    return tf.strings.reduce_join(characters, axis=-1)

In [None]:
# Encode the whole corpus: split it into characters, then map every character
# to its token ID.
all_ids = ids_from_chars(tf.strings.unicode_split(text, "UTF-8"))
all_ids.numpy()

In [None]:
# Dataset that yields the corpus one token ID at a time.
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)

In [None]:
# Print the first ten characters of the corpus, one per line.
# A dedicated loop variable is used so the notebook-level `ids` tensor created
# in an earlier cell is not overwritten by a scalar from this loop.
for token_id in ids_dataset.take(10):
    print(chars_from_ids(token_id).numpy().decode("utf-8"))

In [None]:
# Number of input characters per training example.
seq_length = 100
# Each example consumes seq_length + 1 characters (the chunk size used when
# batching the ID stream below), so this is the number of whole examples the
# corpus yields.
examples_per_epoch = len(text) // (seq_length + 1)

In [None]:
# Group the ID stream into chunks of seq_length + 1 IDs; drop_remainder=True
# discards a final chunk that would be shorter than that.
sequences = ids_dataset.batch(seq_length + 1, drop_remainder=True)

# Show the first chunk as individual characters.
for seq in sequences.take(1):
    print(chars_from_ids(seq))

In [None]:
# Show the first five chunks decoded back into text.
for seq in sequences.take(5):
    print(text_from_ids(seq).numpy())

In [None]:
def split_input_target(sequence):
    """Split a sequence into an (input, target) pair offset by one step.

    The input is everything except the last element and the target is
    everything except the first, so ``target[i]`` is the element that follows
    ``input[i]`` in the original sequence.

    Args:
        sequence: Any object supporting slicing (list, string, tensor, ...).

    Returns:
        A tuple ``(sequence[:-1], sequence[1:])``.
    """
    return sequence[:-1], sequence[1:]

In [None]:
# Sanity check on a plain Python list: the input drops the last character and
# the target drops the first.
split_input_target(list("Tensorflow"))

In [None]:
# Turn every chunk of IDs into an (input, target) pair.
dataset = sequences.map(split_input_target)

In [None]:
# Inspect the first (input, target) pair, both decoded to text and as raw IDs.
# The target should read as the input shifted one character to the left.
for input_example, target_example in dataset.take(1):
    print("Input Text:", text_from_ids(input_example).numpy())
    print("Input ID's:", input_example)
    print("Target Text:", text_from_ids(target_example).numpy())
    print("Target ID's:", target_example)

In [None]:
# Batch size
BATCH_SIZE = 64

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

# Build the training pipeline from `sequences` rather than from the current
# value of `dataset`. Chaining onto `dataset` itself made this cell
# non-idempotent: running it a second time would batch already-batched data.
dataset = (
    sequences.map(split_input_target)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

dataset

In [None]:
# Number of distinct characters in the corpus.
# NOTE(review): the model below is constructed with
# len(ids_from_chars.get_vocabulary()) instead of this value, so this name
# appears to be unused in the visible cells -- confirm before relying on it.
vocab_size = len(vocab)

# Size of each character embedding vector.
embedding_dim = 256

# Number of units in the model's GRU layer.
rnn_units = 1024

In [None]:
class MyModel(tf.keras.Model):
    """Character-level language model: Embedding -> GRU -> Bidirectional LSTM -> Dense.

    Args:
        vocab_size: Number of distinct token IDs. Used both as the embedding
            input dimension and as the number of output logits per position.
        embedding_dim: Size of each embedding vector.
        rnn_units: Number of units in the GRU layer.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        # `super()` is already bound to this instance; passing `self` again
        # hands the Keras base class a stray positional argument.
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        # return_state=True so the final GRU state can be handed back to the
        # caller and fed in again on the next call.
        self.gru = tf.keras.layers.GRU(
            rnn_units, return_sequences=True, return_state=True
        )
        # NOTE(review): the backward half of this layer reads the inputs that
        # come *after* each position. Because the training target at position
        # t is the input at position t + 1, this lets the model see the answer
        # during training -- confirm this layer is intended for a
        # next-character model.
        self.bidirectional = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(64, return_sequences=True)
        )
        # Optional regularisation, currently disabled:
        # self.dropout = tf.keras.layers.Dropout(.2)
        # No activation: the layer emits raw logits over the vocabulary.
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        """Compute next-character logits for every position of ``inputs``.

        Args:
            inputs: Token IDs fed to the embedding layer
                (presumably shaped (batch, sequence) -- confirm against callers).
            states: Initial state for the GRU. When ``None`` the GRU's own
                initial state for this input is used.
            return_state: When true, also return the final GRU state.
            training: Forwarded to the layers so they can switch between
                training and inference behaviour.

        Returns:
            The logits, or ``(logits, states)`` when ``return_state`` is true.
        """
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        # Forward `training` here as well, consistent with the other layers.
        x = self.bidirectional(x, training=training)
        # x = self.dropout(x)
        x = self.dense(x, training=training)

        if return_state:
            return x, states
        else:
            return x

In [None]:
# The vocabulary size is taken from the lookup layer rather than from `vocab`
# so that it matches the full range of IDs the layer can produce.
model = MyModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
)

In [None]:
# Run the untrained model on one batch to check the output shape. The loop
# variables are kept around and reused by later cells.
for input_example_batch, target_example_batch in dataset.take(1):
    #     print(text_from_ids(input_example_batch))
    #     print(text_from_ids(target_example_batch))
    example_batch_predictions = model(input_example_batch)
    print(
        example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)"
    )

In [None]:
# Layer-by-layer overview of the model (available now that it has been called once).
model.summary()

In [None]:
# Sample one token ID per position from the logits of the first example in the
# batch, then drop the trailing num_samples axis.
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

In [None]:
# Display the sampled token IDs.
sampled_indices

In [None]:
# Compare the first input sequence with the characters sampled from the
# untrained model's output for it.
print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())

In [None]:
# from_logits=True because the model's final Dense layer has no activation and
# therefore outputs raw logits; targets are integer token IDs (sparse labels).
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

In [None]:
# Loss of the untrained model on the sample batch, reduced to a single number.
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_loss = example_batch_loss.numpy().mean()
print(
    "Prediction shape: ",
    example_batch_predictions.shape,
    " # (batch_size, sequence_length, vocab_size)",
)
print("Mean loss:", mean_loss)

In [None]:
# A freshly initialised model should not be confident in its predictions, so
# the exponential of the mean loss should be close to the number of possible
# output characters (the vocabulary size). A much larger value would mean the
# model starts out confidently wrong.
print(len(vocab))
print(tf.exp(mean_loss).numpy())

In [None]:
# Configure training with the Adam optimizer and the loss defined above.
model.compile(optimizer="adam", loss=loss)

In [None]:
# Directory where the training checkpoints are written.
checkpoint_dir = "./training_checkpoints"

# Checkpoint file name template; the `{epoch}` placeholder is left for the
# callback to fill in.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Save only the model weights (not the full model) to the path above.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix, save_weights_only=True
)

In [None]:
# Number of passes over the training dataset.
EPOCHS = 30

# Early stopping is currently disabled; training always runs for EPOCHS epochs.
# earlystopping = tf.keras.callbacks.EarlyStopping(
#     monitor="loss", mode="auto", patience=3, restore_best_weights=True
# )

In [None]:
# Train the model, writing checkpoints through the callback defined above.
history = model.fit(
    dataset, epochs=EPOCHS, callbacks=[checkpoint_callback]
)

In [None]:
class OneStep(tf.keras.Model):
    """Wrap a trained model to generate text one character at a time.

    Args:
        model: Model called as ``model(inputs=..., states=..., return_state=True)``
            and returning ``(logits, states)``.
        chars_from_ids: Layer mapping token IDs to characters.
        ids_from_chars: Layer mapping characters to token IDs.
        temperature: Divisor applied to the logits before sampling. Values
            below 1 sharpen the distribution, values above 1 flatten it.
    """

    def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
        super().__init__()
        self.temperature = temperature
        self.model = model
        self.chars_from_ids = chars_from_ids
        self.ids_from_chars = ids_from_chars

        # Create a mask to prevent '' or [UNK] from being generated.
        skip_ids = self.ids_from_chars(["", "[UNK]"])
        # Depending on how the lookup layer is configured, both strings can
        # resolve to the same ID. The sparse -> dense conversion below rejects
        # repeated or unordered indices, so de-duplicate and sort first.
        skip_ids = tf.sort(tf.unique(skip_ids).y)[:, None]
        sparse_mask = tf.SparseTensor(
            # Put an -inf at each bad index
            values=[-float("inf")] * len(skip_ids),
            indices=skip_ids,
            # Match the shape to the vocabulary
            dense_shape=[len(ids_from_chars.get_vocabulary())],
        )
        # Dense vector: -inf at the masked IDs, 0 everywhere else, so adding
        # it to the logits leaves the other entries untouched.
        self.prediction_mask = tf.sparse.to_dense(sparse_mask)

    @tf.function
    def generate_one_step(self, inputs, states=None):
        """Sample the next character for every string in ``inputs``.

        Args:
            inputs: Batch of strings to continue.
            states: Recurrent state returned by the previous call, or ``None``
                for the first step.

        Returns:
            A tuple ``(predicted_chars, states)``: the sampled next character
            per input string and the model state to pass to the next call.
        """
        # Convert strings to token IDs.
        input_chars = tf.strings.unicode_split(inputs, "UTF-8")
        input_ids = self.ids_from_chars(input_chars).to_tensor()

        # Run the model.
        # predicted_logits.shape is [batch, char, next_char_logits]
        predicted_logits, states = self.model(
            inputs=input_ids, states=states, return_state=True
        )

        # Only use the last prediction.
        predicted_logits = predicted_logits[:, -1, :]
        predicted_logits = predicted_logits / self.temperature
        # Apply the prediction mask: prevent "" or [UNK] from being generated
        predicted_logits = predicted_logits + self.prediction_mask

        # Sample the output logits to generate token IDs.
        predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
        predicted_ids = tf.squeeze(predicted_ids, axis=-1)

        # Convert from token ids to characters
        predicted_chars = self.chars_from_ids(predicted_ids)

        # Return the characters and the model state.
        return predicted_chars, states

In [None]:
# Wrap the trained model for character-by-character generation
# (default temperature of 1.0).
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)

In [None]:
# rand_int = np.random.randint(low = 1, high= len(vocab))
# print(rand_int)
# start_string = sequences.take(2)
# for seq in start_string:
#     print(text_from_ids(seq).numpy())

In [None]:
# Generate 1000 characters starting from a seed prompt and time the run.
start = time.time()
# No recurrent state yet; the first call starts from the model's initial state.
states = None
next_char = tf.constant(["JULIET:"])
result = [next_char]

# Each step feeds back only the character just produced together with the
# returned state.
for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)

# Concatenate the seed and all generated characters into one string per batch entry.
result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode("utf-8"), "\n\n" + "_" * 80)
print("\nRun time:", end - start)