Note: parts of the training code are based on TensorFlow's text generation tutorial: https://www.tensorflow.org/text/tutorials/text_generation. We change it substantially to work with our data, to generate words instead of characters, and to implement a rhyme scheme.

In [2]:
import tensorflow as tf
import keras
import numpy as np
import string
from keras.preprocessing.text import Tokenizer
import pickle

In [4]:
# File path for training data
training_file_path = "poems.txt"

with open(training_file_path, "r", encoding="utf-8") as f:
  content = f.read()

# Lowercase and remove punctuation
content = content.lower()
content = content.translate(str.maketrans("", "", string.punctuation))


# Split data into array
lines = content.splitlines()
data = []
for line in lines:
  sentence = line.split()
  sentence.append('\n')  # keep line breaks as an explicit token
  data.append(sentence)
  
# Train tokenizer and encode data
tokenizer = Tokenizer()
tokenizer.fit_on_texts(data)
encoded = tokenizer.texts_to_sequences(data)

flatten_list = [j for sub in encoded for j in sub]
vocab = sorted(set(tokenizer.word_docs))
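
The cell above turns each line of the poem into a list of word tokens, appends a newline token, and maps every token to an integer id. The sketch below mimics that mapping in plain Python so the result can be inspected without Keras. It is an illustration only, not the Keras `Tokenizer` implementation: `build_word_index`, `encode`, and the two sample lines are hypothetical, and the ordering rule (most frequent token first, id 0 reserved) is an assumption modeled on how the tokenizer is used here.

```python
from collections import Counter

def build_word_index(sentences):
    """Map each token to an integer id, most frequent first; id 0 stays reserved."""
    counts = Counter(tok for sent in sentences for tok in sent)
    # sorted() is stable, so tokens with equal counts keep first-seen order.
    ordered = sorted(counts, key=lambda tok: -counts[tok])
    return {tok: i + 1 for i, tok in enumerate(ordered)}

def encode(sentences, word_index):
    """Replace every token with its integer id."""
    return [[word_index[tok] for tok in sent] for sent in sentences]

# Two hypothetical lines, already lowercased and split, with the newline token appended.
sample = [["the", "rose", "is", "red", "\n"],
          ["the", "sky", "is", "blue", "\n"]]

word_index = build_word_index(sample)
flat = [i for sent in encode(sample, word_index) for i in sent]
vocab_size_demo = len(word_index) + 1  # +1 for the reserved id 0

print(word_index)
print(flat)
```

Because the newline is an ordinary token with its own id, a word-level model trained on the flattened stream can learn where lines end.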

In [3]:
ids_dataset = tf.data.Dataset.from_tensor_slices(flatten_list)

# Split the id stream into chunks of seq_length + 1 tokens (200 inputs plus one extra token for the shifted target)
seq_length = 200

sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

2022-12-09 11:56:49.514913: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-12-09 11:56:49.518040: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-12-09 11:56:49.518136: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-12-09 11:56:49.518432: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operati

In [4]:
# Shift each sequence by one token: the input drops the last token, the target (gold label) drops the first
def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text

dataset = sequences.map(split_input_target)
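
To make the chunking and the one-token shift concrete, here is a small plain-Python sketch of the same two steps on toy data. The `chunk` helper is a hypothetical stand-in for `Dataset.batch(..., drop_remainder=True)`, and `demo_seq_length` is deliberately tiny so the pairs are readable.

```python
def chunk(ids, size):
    """Split ids into consecutive chunks of `size`, dropping a short remainder."""
    return [ids[i:i + size] for i in range(0, len(ids) - size + 1, size)]

def split_input_target(sequence):
    """Input is every token but the last; target is every token but the first."""
    return sequence[:-1], sequence[1:]

ids = list(range(1, 12))  # 11 toy token ids: 1..11
demo_seq_length = 4       # toy value; the notebook uses 200

pairs = [split_input_target(c) for c in chunk(ids, demo_seq_length + 1)]
for inputs, targets in pairs:
    print(inputs, "->", targets)
```

At every position the target is simply the next token of the input, which is why each chunk needs `seq_length + 1` tokens.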

In [5]:
# Batch size
BATCH_SIZE = 32

dataset = (
    dataset
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE))

vocab_size = len(tokenizer.word_index) + 1

# Embedding size
embedding_dim = 512

# RNN size
rnn_units = 1024
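
Since both batching steps drop remainders, the number of training steps per epoch follows directly from the corpus size. The sketch below works through that arithmetic; `steps_per_epoch` is a hypothetical helper, and the token counts are example inputs rather than measurements of `poems.txt`.

```python
def steps_per_epoch(num_tokens, seq_length, batch_size):
    """Steps per epoch when both batching stages drop their remainders."""
    num_sequences = num_tokens // (seq_length + 1)  # chunks of seq_length + 1 tokens
    return num_sequences // batch_size              # full batches only

# Example: the smallest token count that yields 27 steps with these settings.
smallest = 27 * 32 * (200 + 1)
print(smallest, steps_per_epoch(smallest, seq_length=200, batch_size=32))
```

With `seq_length = 200` and `BATCH_SIZE = 32`, a progress bar showing 27 steps per epoch would correspond to a corpus of between 173,664 and 180,095 tokens under this arithmetic.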

In [6]:
class Model(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x
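
As a rough sanity check on model size, the parameter count of this embedding, GRU, and dense stack can be worked out by hand. The sketch below assumes the GRU keeps separate input and recurrent biases (a bias of shape `(2, 3 * units)`), which matches the gradient shapes `(512, 3072)`, `(1024, 3072)`, and `(2, 3072)` reported in the save log of this notebook. The helper names and the vocabulary size of 10,000 are hypothetical; the real `vocab_size` depends on the data.

```python
def gru_param_count(input_dim, units):
    """Parameters of a GRU layer with separate input and recurrent biases."""
    kernel = input_dim * 3 * units   # input weights for the 3 gates
    recurrent = units * 3 * units    # hidden-state weights for the 3 gates
    bias = 2 * 3 * units             # two bias vectors of length 3 * units
    return kernel + recurrent + bias

def model_param_count(vocab_size, embedding_dim, rnn_units):
    """Total parameters of the embedding -> GRU -> dense stack."""
    embedding = vocab_size * embedding_dim
    gru = gru_param_count(embedding_dim, rnn_units)
    dense = rnn_units * vocab_size + vocab_size  # weights plus bias
    return embedding + gru + dense

print(gru_param_count(512, 1024))
print(model_param_count(vocab_size=10_000, embedding_dim=512, rnn_units=1024))  # hypothetical vocab
```

Note that the embedding and output layers both scale with the vocabulary, so for word-level generation they can outweigh the GRU itself.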

In [7]:
model = Model(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)

In [8]:
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
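
For intuition about what this loss measures, here is a plain-Python sketch of sparse categorical cross-entropy on raw logits for a single position. It is an illustration, not the Keras implementation; by default Keras additionally averages this quantity over all positions and batch elements. The function name `sparse_ce_from_logits` is hypothetical.

```python
import math

def sparse_ce_from_logits(logits, target):
    """Cross-entropy of one position: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# An untrained model that scores all 8 classes equally pays ln(8) per token.
uniform_loss = sparse_ce_from_logits([0.0] * 8, target=3)
print(uniform_loss, math.log(8))

# Vocabulary size for which a uniform guess would give a loss of 9.4206.
print(math.exp(9.4206))
```

A model that guesses uniformly over V classes has loss ln(V). The first loss logged during training in this notebook, 9.4206, would therefore be consistent with a vocabulary of roughly 12,300 words, though that is only a rough consistency check and not a measured value.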

In [9]:
model.compile(optimizer='adam', loss=loss)

In [10]:
EPOCHS = 100
# Train here
history = model.fit(dataset, epochs=EPOCHS)

Epoch 1/100


2022-12-09 11:56:51.574413: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8100
2022-12-09 11:56:52.056972: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2022-12-09 11:56:52.125641: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x7fed293b97a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2022-12-09 11:56:52.125668: I tensorflow/compiler/xla/service/service.cc:181]   StreamExecutor device (0): NVIDIA GeForce RTX 3050 Ti Laptop GPU, Compute Capability 8.6
2022-12-09 11:56:52.128596: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2022-12-09 11:56:52.172753: W tensorflow/compiler/xla/str

 1/27 [>.............................] - ETA: 1:07 - loss: 9.4206


 2/27 [=>............................] - ETA: 8s - loss: 9.4169  


 3/27 [==>...........................] - ETA: 8s - loss: 9.4110


 4/27 [===>..........................] - ETA: 8s - loss: 9.3986


 5/27 [====>.........................] - ETA: 8s - loss: 9.3411


 6/27 [=====>........................] - ETA: 7s - loss: 9.0039


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 7

In [16]:
model.save('model_poems')

with open('tokenizer', 'wb') as tokenizer_file:
  pickle.dump(tokenizer, tokenizer_file)

INFO:tensorflow:Unsupported signature for serialization: ((IndexedSlicesSpec(TensorShape([None, 512]), tf.float32, tf.int32, tf.int32, TensorShape([None])), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fefac6fdf70>, 140667696630352), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512, 3072), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fefac711670>, 140667696631120), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1024, 3072), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fefac6c0ee0>, 140667696095568), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(2, 3072), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fefac6f6e80>, 140667696240480), {}).
INFO:tensorflow:Unsupported signature for ser



INFO:tensorflow:Assets written to: model_poems/assets


