In [2]:
# based on https://www.tensorflow.org/text/tutorials/text_generation

import tensorflow as tf

import numpy as np
import os
import time

print(tf.__version__)

2.5.0


In [3]:
full_text = open('words_alpha.txt', 'rb').read().decode(encoding='utf-8')
text = ''
seq_length = 30
for word in full_text.split():
    if len(word) < 4:
        continue
    else:
        word = '_' + word
        for i in range(0, seq_length-len(word)):
            word = word + '*'
    text = text + word

print(f'Length of text: {len(text)} characters')
print(text[:300])

Length of text: 11025662 characters
_aahed************************_aahing***********************_aahs*************************_aalii************************_aaliis***********************_aals*************************_aani*************************_aardvark*********************_aardvarks********************_aardwolf*********************


In [4]:
vocab = sorted(set(text))
print(f'{len(vocab)} unique characters')
print(vocab)

28 unique characters
['*', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [5]:
ids_from_chars = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=list(vocab), mask_token=None)

chars_from_ids = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)

def text_from_ids(ids):
  return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)

Metal device set to: Apple M1


2022-01-27 12:24:28.843977: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-01-27 12:24:28.844231: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [6]:
all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)

for element in ids_dataset.take(seq_length):
    print(element, chars_from_ids(element).numpy().decode('utf-8'))

tf.Tensor(2, shape=(), dtype=int64) _
tf.Tensor(3, shape=(), dtype=int64) a
tf.Tensor(3, shape=(), dtype=int64) a
tf.Tensor(10, shape=(), dtype=int64) h
tf.Tensor(7, shape=(), dtype=int64) e
tf.Tensor(6, shape=(), dtype=int64) d
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1, shape=(), dtype=int64) *
tf.Tensor(1

In [7]:
# TODO read karpathi to see how he chooses sequence length and if it's important
# maybe longer is better? but we don't want to keep context between words

# try truncating and padding after the stop token
# sequence length = shortest word length, change the start character into _

sequences = ids_dataset.batch(seq_length, drop_remainder=True)

for seq in sequences.take(5):
  print(seq, text_from_ids(seq))

tf.Tensor(
[ 2  3  3 10  7  6  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
  1  1  1  1  1  1], shape=(30,), dtype=int64) tf.Tensor(b'_aahed************************', shape=(), dtype=string)
tf.Tensor(
[ 2  3  3 10 11 16  9  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
  1  1  1  1  1  1], shape=(30,), dtype=int64) tf.Tensor(b'_aahing***********************', shape=(), dtype=string)
tf.Tensor(
[ 2  3  3 10 21  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
  1  1  1  1  1  1], shape=(30,), dtype=int64) tf.Tensor(b'_aahs*************************', shape=(), dtype=string)
tf.Tensor(
[ 2  3  3 14 11 11  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
  1  1  1  1  1  1], shape=(30,), dtype=int64) tf.Tensor(b'_aalii************************', shape=(), dtype=string)
tf.Tensor(
[ 2  3  3 14 11 11 21  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
  1  1  1  1  1  1], shape=(30,), dtype=int64) tf.Tensor(b'_aaliis***********************', shape=(), dtype=string)


In [8]:
def split_input_target(sequence):
    print(sequence)
    input_text = sequence
    target_text = tf.concat([sequence[1:], [1]], 0) # pad with extra *
    return input_text, target_text
dataset = sequences.map(split_input_target)
for input_example, target_example in dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Target:", text_from_ids(target_example).numpy())

Tensor("args_0:0", shape=(30,), dtype=int64)
Input : b'_aahed************************'
Target: b'aahed*************************'


2022-01-27 12:24:36.949568: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2022-01-27 12:24:36.949926: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [9]:
BATCH_SIZE = 64
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE))

### Define the model
We use an RNN with GRUs defined in rnn_gru_model.py

In [10]:
from rnn_gru_model import RnnGRUModel
model = RnnGRUModel()

### Try the untrained model on the first sequence

In [11]:
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
    
# vocab size is 29 because of [UNK]
# TODO we don't actually need [UNK] since input involves all possible characters already
model.summary()

(64, 30, 29) # (batch_size, sequence_length, vocab_size)
Model: "rnn_gru_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        multiple                  7424      
_________________________________________________________________
gru (GRU)                    multiple                  3938304   
_________________________________________________________________
dense (Dense)                multiple                  29725     
_________________________________________________________________
dense_1 (Dense)              multiple                  870       
Total params: 3,976,323
Trainable params: 3,976,323
Non-trainable params: 0
_________________________________________________________________


In [12]:
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()
sampled_indices

array([ 5,  4,  2, 15, 23,  4, 26, 16,  6, 17, 23, 23, 19, 15,  9, 10, 20,
       26,  0, 10, 17,  0, 16,  9, 28, 18, 21, 25, 16, 16])

In [13]:
print("Input:\n", text_from_ids(input_example_batch[0]).numpy().decode('utf-8'))
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy().decode('utf-8'))

Input:
 _accidents********************
Next Char Predictions:
 cb_mubxndouuqmghrx[UNK]ho[UNK]ngzpswnn


In [14]:
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_loss = example_batch_loss.numpy().mean()
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss:        ", mean_loss)

Prediction shape:  (64, 30, 29)  # (batch_size, sequence_length, vocab_size)
Mean loss:         3.3777168


In [15]:
# A newly initialized model shouldn't be too sure of itself, the output logits should
# all have similar magnitudes. To confirm this you can check that the exponential of the mean
# loss is approximately equal to the vocabulary size. A much higher loss means the model is
# sure of its wrong answers, and is badly initialized:

tf.exp(mean_loss).numpy()

29.303791

### Train the model

In [16]:
model.compile(optimizer='adam', loss=loss)

In [17]:
# Directory where the checkpoints will be saved
checkpoint_dir = 'training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

In [18]:
EPOCHS = 3

In [19]:
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

Epoch 1/3


2022-01-27 12:25:09.237448: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-27 12:25:09.502257: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-27 12:25:09.589565: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


Epoch 2/3
Epoch 3/3


### Create wrapper which generates and evaluates word plausiblity

In [100]:
from rnn_plausiblewords import RnnWordPlausibilityEvaluator
import logging
model_wrapper = RnnWordPlausibilityEvaluator(logging,
                                             model=model,
                                             ids_from_chars=ids_from_chars,
                                             chars_from_ids=chars_from_ids,
                                             temperature=0.8)

In [103]:
# TODO generate some words using Spolling Bree letters and see how the models agree or disagree
def create_words(given_model, num_words):
    for m in range(num_words):
      states = None
      next_char = tf.constant(['_'])
      result = [next_char]

      for n in range(100):
        next_char, states = given_model.generate_one_step(next_char, states=states)
        result.append(next_char)
        if next_char == '*':
          break

      result = tf.strings.join(result)
      end = time.time()
      print(result[0].numpy().decode('utf-8'))

create_words(model_wrapper, 10)

_udex*
_coller*
_elvies*
_agmanies*
_commonied*
_compic*
_cands*
_candological*
_dgorm*
_oxyphyl*


In [105]:
print(model_wrapper.evaluate_word('test'))
print(model_wrapper.evaluate_word('testestest'))

4.7846328020095825
24.958137571811676


### Save and load the model

In [92]:
## Save the model weights only
model.save_weights("base_model_saved_weights")

In [93]:
## Save the layers needed to regenerate ModelWrapper
tf.saved_model.save(chars_from_ids, 'chars_from_ids')
tf.saved_model.save(ids_from_chars, 'ids_from_chars')


FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: chars_from_ids/assets

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: ids_from_chars/assets


2022-01-18 14:41:24.907000: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


In [None]:
# loaded_model.load_weights("base_model_saved_weights")