<a href="https://colab.research.google.com/github/domschl/tensor-poet/blob/master/eager_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Eager Tensor Poet (Tensorflow 2.0)

In [1]:
!pip install -U ml-indie-tools

[33mDEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621[0m
Collecting ml-indie-tools
  Using cached ml_indie_tools-0.0.25-py3-none-any.whl (25 kB)
Installing collected packages: ml-indie-tools
  Attempting uninstall: ml-indie-tools
    Found existing installation: ml-indie-tools 0.0.24
    Uninstalling ml-indie-tools-0.0.24:
      Successfully uninstalled ml-indie-tools-0.0.24
[33m  DEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621[0m
[33mDEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If 

In [2]:
import logging
import os
import sys
import json
import time
import datetime

import numpy as np
import tensorflow as tf

In [3]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

## Preliminary

A tensorflow deep LSTM model for text generation

This code can use either CPU, GPU, TPU when running on Google Colab.

Select the corresponding runtime (menu: **`Runtime / Change runtime type`**)

## 0. Environment

In [4]:
# Create the ml-indie-tools environment helper for the TensorFlow platform,
# requesting the 'fastest' accelerator. `ml_env.is_tpu` is consulted by later cells.
ml_env = MLEnv(platform='tf', accelerator='fastest')
# Last expression of the cell: the returned description is shown as cell output.
ml_env.describe()

'OS: Darwin, Python: 3.9.9 (Conda), Jupyter Notebook Tensorflow: 2.7.0, GPU: METAL'

In [5]:
# Project/model identifiers; ml_env.init_paths() returns the directories that the
# later cells use for cached data (data_path), checkpoints (model_path) and
# TensorBoard logs (log_path).
project_name = 'women_writers'
model_name = 'lstm_v1'
(root_path,
 project_path,
 model_path,
 data_path,
 log_path) = ml_env.init_paths(project_name=project_name,
                               model_name=model_name)

##  1. Text library

`Text_Dataset` and `Gutenberg_Dataset` classes: libraries for training, 
encoding, batch generation, and formatted source display. They read some 
books from Project Gutenberg and support the creation of training batches. 
The output functions support highlighting, which makes it possible to compare 
generated texts with the actual sources and helps to identify identical 
(memorized) parts.

In [6]:
# NOTE(review): no visible cell reads this flag — confirm it is still needed.
use_dark_mode=False  # Set to False for a white (light) background

In [7]:
# Show INFO-level log messages (the Gutenberg library reports cache hits this way).
logging.basicConfig(level=logging.INFO)
# Keep the Gutenberg cache below the project's data directory.
cache_dir = os.path.join(data_path, 'gutenberg_cache')
gd = Gutenberg_Dataset(cache_dir=cache_dir)

In [8]:
# Sample search: which authors and languages to select from the Gutenberg index.
search_spec = {
    "author": ["brontë", "Jane Austen", "Virginia Woolf"],
    "language": ["english"],
}

book_list = gd.search(search_spec)
book_cnt = len(book_list)
print(f"{book_cnt} matching books found with search {search_spec}.")
if book_cnt >= 40:
    # Guard against accidentally scheduling a very large download; note that
    # this only logs, the cell itself keeps running.
    logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")
else:
    # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts,
    # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
    book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)

INFO:GutenbergLib:Gutenberg index read from local cache: ./data/gutenberg_cache/gutenberg_index
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/64457.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/63022.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/54066.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/54012.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/54011.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/54010.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/53747.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/42671.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/42078.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/31100.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/9182.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/5670.txt
INFO:G

25 matching books found with search {'author': ['brontë', 'Jane Austen', 'Virginia Woolf'], 'language': ['english']}.


INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/158.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/144.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/141.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/121.txt
INFO:GutenbergLib:Book read from cache at ./data/gutenberg_cache/105.txt


In [9]:
# Wrap the downloaded books in a Text_Dataset; `td` is the source of random
# training samples and of the token table `td.i2c` used by the cells below.
td = Text_Dataset(book_list)

INFO:Datasets:Loaded 25 texts


In [16]:
def get_random_sample_batch(td, batch_size, length):
    """Draw `batch_size` random (input, target) token sequences from `td`.

    Each pair comes from `td.get_random_char_tokenized_sample_pair(length)`.
    The pairs are collected in lists and converted once, instead of growing
    the arrays with repeated `np.vstack` calls.

    Args:
        td: dataset object providing `get_random_char_tokenized_sample_pair`.
        batch_size: number of sample pairs to draw; must be >= 1.
        length: sample length, passed through to `td`.

    Returns:
        Tuple `(smpX, smpy)` of float32 arrays with one row per sample. The
        batch axis is kept even for `batch_size == 1`.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    inputs = []
    targets = []
    for _ in range(batch_size):
        Xi, yi = td.get_random_char_tokenized_sample_pair(length)
        inputs.append(Xi)
        targets.append(yi)
    smpX = np.array(inputs, dtype=np.float32)
    smpy = np.array(targets, dtype=np.float32)
    return smpX, smpy

def get_random_onehot_sample_batch(td, batch_size, length):
    """Return a random batch with one-hot encoded inputs and plain targets.

    The token batch from `get_random_sample_batch` is one-hot encoded over
    `len(td.i2c)` classes; the targets are wrapped as a constant tensor.

    Returns:
        Tuple `(one_hot_inputs, targets)`.
    """
    token_batch, target_batch = get_random_sample_batch(td, batch_size, length)
    vocab_size = len(td.i2c)
    return (tf.keras.backend.one_hot(token_batch, vocab_size),
            tf.keras.backend.constant(target_batch))

## 2. Use tf.data for texts

In [17]:
# Sequence length of one training sample and an optional override for the
# number of pre-generated batches (0 selects the default below).
SEQUENCE_LEN = 96
iNumBatches = 0

# Network size is the same on every accelerator.
LSTM_UNITS = 512
LSTM_LAYERS = 4

# Batch size and statefulness depend on whether a TPU is in use.
if ml_env.is_tpu is True:
    BATCH_SIZE = 256
    use_tpu_model_for_tpu = True
    STATEFUL = False
else:
    BATCH_SIZE = 128
    STATEFUL = True

# Default: as many batches as there are samples per batch.
# (formerly: int(textlib.total_size/BATCH_SIZE/SEQUENCE_LEN))
NUM_BATCHES = iNumBatches if iNumBatches != 0 else BATCH_SIZE

In [18]:
# Pre-generate NUM_BATCHES random one-hot batches; dx collects the inputs and
# dy the matching targets.
sample_batches = [
    get_random_onehot_sample_batch(td, BATCH_SIZE, SEQUENCE_LEN)
    for _ in range(NUM_BATCHES)
]
dx = [inputs for inputs, _ in sample_batches]
dy = [targets for _, targets in sample_batches]

In [42]:
# Pair the list of input batches with the list of target batches.
data_xy=(dx,dy) # tf.keras.backend.constant(np.array([dx,dy]))

In [43]:
# Slice along the first axis: every dataset element is one pre-generated
# (inputs, targets) batch.
textlib_dataset=tf.data.Dataset.from_tensor_slices(data_xy)

In [44]:
# Shuffle the order of the pre-generated batches.
shuffle_buffer=10000
dataset=textlib_dataset.shuffle(shuffle_buffer)  
# Display the element spec (shapes/dtypes) of the dataset as cell output.
dataset.take(1)

<TakeDataset shapes: ((128, 96, 188), (128, 96)), types: (tf.float32, tf.float32)>

In [45]:
def build_model(vocab_size, steps, lstm_units, lstm_layers, batch_size, stateful=True):
    """Build the stacked-LSTM character model.

    Args:
        vocab_size: size of the one-hot input and number of output units.
        steps: number of time steps per input sequence.
        lstm_units: units per LSTM layer.
        lstm_layers: number of stacked LSTM layers.
        batch_size: fixed batch size declared in `batch_input_shape`.
        stateful: passed to every LSTM layer.

    Returns:
        A `tf.keras.Sequential` of `lstm_layers` LSTM layers (each returning
        full sequences) followed by a `Dense(vocab_size)` layer without
        activation.
    """
    layers = []
    for _ in range(lstm_layers):
        layers.append(
            tf.keras.layers.LSTM(
                lstm_units,
                batch_input_shape=[batch_size, steps, vocab_size],
                return_sequences=True,
                stateful=stateful,
                recurrent_initializer='glorot_uniform',
            )
        )
    layers.append(tf.keras.layers.Dense(vocab_size))
    return tf.keras.Sequential(layers)

def build_tpu_model(vocab_size, steps, lstm_units, lstm_layers, batch_size, stateful=True):
    """Build the stacked-LSTM character model under the TPU distribution strategy.

    Same architecture as `build_model`, but the LSTM layers are unrolled and
    the whole model is created inside `tpu_strategy.scope()`. Previously only
    the LSTM layer objects were constructed inside the scope, while the Dense
    layer and the `Sequential` model (which builds its variables because of
    `batch_input_shape`) were created outside of it.

    Relies on the module-level `tpu_strategy` created by the TPU
    initialisation cell.

    Args:
        vocab_size: size of the one-hot input and number of output units.
        steps: number of time steps per input sequence.
        lstm_units: units per LSTM layer.
        lstm_layers: number of stacked LSTM layers.
        batch_size: fixed batch size declared in `batch_input_shape`.
        stateful: passed to every LSTM layer.

    Returns:
        A `tf.keras.Sequential` model.
    """
    with tpu_strategy.scope():
        layers = [
            tf.keras.layers.LSTM(
                lstm_units,
                batch_input_shape=[batch_size, steps, vocab_size],
                return_sequences=True,
                stateful=stateful,
                recurrent_initializer='glorot_uniform',
                unroll=True,
            )
            for _ in range(lstm_layers)
        ]
        layers.append(tf.keras.layers.Dense(vocab_size))
        # Variables are created when the Sequential model is built, so this
        # has to happen inside the strategy scope as well.
        model = tf.keras.Sequential(layers)
    return model

In [46]:
if ml_env.is_tpu:
    # No earlier cell assigns TPU_ADDRESS; derive it from the Colab environment
    # unless it has already been set (e.g. manually).
    # NOTE(review): assumes the Colab convention 'grpc://' + COLAB_TPU_ADDR — confirm.
    if 'TPU_ADDRESS' not in globals():
        TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
    print(TPU_ADDRESS)

In [47]:
# Initialise the TPU system once per kernel. `tpu_is_init` only exists after
# this cell has run, so look it up with a default instead of reading the bare
# name (which raised NameError on a fresh kernel).
if ml_env.is_tpu is True and not globals().get('tpu_is_init', False):
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=TPU_ADDRESS)
    # tf.config.experimental_connect_to_cluster(cluster_resolver) # host(cluster_resolver.master())
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    tpu_is_init = True

In [48]:
# All three variants are built from the same hyper-parameters.
model_kwargs = dict(
    vocab_size=len(td.i2c),
    steps=SEQUENCE_LEN,
    lstm_units=LSTM_UNITS,
    lstm_layers=LSTM_LAYERS,
    batch_size=BATCH_SIZE,
    stateful=STATEFUL,
)

if ml_env.is_tpu is not True:
    print("non-tpu mode")
    model = build_model(**model_kwargs)
elif use_tpu_model_for_tpu is True:
    print("tpu, simple model")
    model = build_tpu_model(**model_kwargs)
else:
    print("tpu, default model")
    with tpu_strategy.scope():
        model = build_model(**model_kwargs)

non-tpu mode


### Some sanity checks of the (untrained!) model

In [49]:
# Display the dataset's element shapes/dtypes as cell output.
dataset.take(1)

<TakeDataset shapes: ((128, 96, 188), (128, 96)), types: (tf.float32, tf.float32)>

In [50]:
if ml_env.is_tpu is False:  # no sanity for TPU, since eager not supported:
    for input_example_batch, target_example_batch in dataset.take(1):
        model.reset_states()
        # Use the batch size the model was built with instead of a hard-coded
        # 256, so the check keeps working when BATCH_SIZE changes.
        example_batch_predictions = model.predict(input_example_batch, batch_size=BATCH_SIZE)
        print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

2022-01-01 17:40:17.528395: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:40:17.675733: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:40:17.900605: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:40:18.319305: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:40:18.502319: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


(128, 96, 188) # (batch_size, sequence_length, vocab_size)


In [51]:
# Print layer output shapes and parameter counts of the training model.
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_4 (LSTM)                (128, 96, 512)            1435648   
_________________________________________________________________
lstm_5 (LSTM)                (128, 96, 512)            2099200   
_________________________________________________________________
lstm_6 (LSTM)                (128, 96, 512)            2099200   
_________________________________________________________________
lstm_7 (LSTM)                (128, 96, 512)            2099200   
_________________________________________________________________
dense_1 (Dense)              (128, 96, 188)            96444     
Total params: 7,829,692
Trainable params: 7,829,692
Non-trainable params: 0
_________________________________________________________________


In [52]:
# Display the dataset's element shapes/dtypes as cell output.
dataset.take(1)

<TakeDataset shapes: ((128, 96, 188), (128, 96)), types: (tf.float32, tf.float32)>

In [38]:
if ml_env.is_tpu is False:
    # Sample one token id per time step from the logits of the first sequence
    # of the example batch (requires the sanity-check prediction cell above).
    sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
    # Drop the trailing num_samples axis and convert to a NumPy array.
    sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()
    print(sampled_indices)

[ 84  39   5 108 136  46  33 137  44  83 127 148 138  48  19 185  24  93
  28   6  80 181 123  93 131 164  21   2  32 154  84 162 139  44  48   9
  81  96   3  50  83 151  36 130  42  20  30 101 172 149 120 182  51  58
  33 170  66 109  65 106  45 175  18 167  79 172  72  26 173  77  98  74
 152  45 169 106 107 178 105 122  36  44 147  56 155  87 145 123  43   5
 164 136  50  96  12  73]


In [53]:
if ml_env.is_tpu is False:
    # Decode the sampled token ids back to text; the model is untrained here,
    # so the output is expected to be random characters.
    print(td.decode(sampled_indices, tokenizer='char'))

"-Kνuἐ[αι?e&wπχὑὢiμlTῶάiᾀ’F<eos>é^"çëιπ18φ<sos>n?jRàhmό7äἔὗόἂδ[Ê,έvDἈËk“4äoV#γεθ/ἈöD_ώ2QRικX}9;άἀK’unφ'.


### Loss function, optimizer, tensorboard output

In [55]:
def loss(labels, logits):
    """Sparse categorical cross-entropy between `labels` and un-normalised `logits`.

    Returns the non-reduced loss tensor produced by
    `tf.keras.losses.sparse_categorical_crossentropy(..., from_logits=True)`.
    """
    crossentropy = tf.keras.losses.sparse_categorical_crossentropy
    return crossentropy(labels, logits, from_logits=True)

# Sanity check: loss of the untrained model on the example batch taken in the
# sanity-check prediction cell above.
if ml_env.is_tpu is False:
    example_batch_loss  = loss(target_example_batch, example_batch_predictions)
    print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
    # Average the per-element loss into one number for display.
    print("scalar_loss:      ", example_batch_loss.numpy().mean())

Prediction shape:  (128, 96, 188)  # (batch_size, sequence_length, vocab_size)
scalar_loss:       5.236963


In [56]:
# Adam optimizer with element-wise gradient clipping.
opti = tf.keras.optimizers.Adam(learning_rate=0.001, clipvalue=0.3)
# Alternatives that were tried:
# opti = tf.keras.optimizers.Adam(clipvalue=0.5)
# opti=tf.keras.optimizers.SGD(lr=0.003)


def scalar_loss(labels, logits):
    """Mean of `loss(labels, logits)`; reported as a training metric."""
    return tf.reduce_mean(loss(labels, logits))


model.compile(loss=loss, optimizer=opti, metrics=[scalar_loss])

In [57]:
# Checkpoints: one weights-only file per epoch below the model directory.
checkpoint_dir = os.path.join(model_path, 'training_checkpoints')
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
)

# TensorBoard: a fresh, time-stamped log directory per run, updated every batch.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join(log_path, timestamp)
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='batch')  # , histogram_freq=1) # update_freq='epoch',

2022-01-01 17:42:32.753617: I tensorflow/core/profiler/lib/profiler_session.cc:110] Profiler session initializing.
2022-01-01 17:42:32.753643: I tensorflow/core/profiler/lib/profiler_session.cc:125] Profiler session started.
2022-01-01 17:42:32.754126: I tensorflow/core/profiler/lib/profiler_session.cc:143] Profiler session tear down.


In [58]:
%tensorboard --logdir logs

UsageError: Line magic function `%tensorboard` not found.


## The actual training

In [59]:
# Number of passes over the pre-generated batches.
EPOCHS=20

In [60]:
# Train; weights are checkpointed by checkpoint_callback and progress is
# logged for TensorBoard by tensorboard_callback.
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback, tensorboard_callback])

Epoch 1/20


2022-01-01 17:42:46.845285: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:42:47.188861: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:42:47.408739: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:42:47.776498: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:42:47.991612: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:42:48.912379: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2022-01-01 17:42:50.271436: I tensorflow/core/grappler/optimizers/cust

  1/128 [..............................] - ETA: 11:47 - loss: 5.2370 - scalar_loss: 5.2370

2022-01-01 17:42:51.436017: I tensorflow/core/profiler/lib/profiler_session.cc:110] Profiler session initializing.
2022-01-01 17:42:51.436032: I tensorflow/core/profiler/lib/profiler_session.cc:125] Profiler session started.


  2/128 [..............................] - ETA: 1:57 - loss: 5.2282 - scalar_loss: 5.2282 

2022-01-01 17:42:52.298814: I tensorflow/core/profiler/lib/profiler_session.cc:67] Profiler session collecting data.
2022-01-01 17:42:52.302534: I tensorflow/core/profiler/lib/profiler_session.cc:143] Profiler session tear down.
2022-01-01 17:42:52.313331: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: ./logs/20220101-174232/train/plugins/profile/2022_01_01_17_42_52

2022-01-01 17:42:52.314437: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to ./logs/20220101-174232/train/plugins/profile/2022_01_01_17_42_52/m1air.fritz.box.trace.json.gz
2022-01-01 17:42:52.318166: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: ./logs/20220101-174232/train/plugins/profile/2022_01_01_17_42_52

2022-01-01 17:42:52.318345: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to ./logs/20220101-174232/train/plugins/profile/2022_01_01

Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20

KeyboardInterrupt: 

In [None]:
# Generate

In [None]:
# Selects which builder creates the generation model in the next cell:
# build_model when False, build_tpu_model when True.
use_tpu_for_generation=False

In [None]:
# Build a second model with batch_size=1 for text generation. The vocabulary
# size comes from `td` (the Text_Dataset used for training); the previous
# `textlib` name is not defined in this notebook.
if not use_tpu_for_generation:
    gen_model = build_model(vocab_size = len(td.i2c),
        # embedding_dim=EMBEDDING_DIM,
        steps=SEQUENCE_LEN,
        lstm_units=LSTM_UNITS,
        lstm_layers=LSTM_LAYERS,
        batch_size=1)
else:
    gen_model = build_tpu_model(
          vocab_size = len(td.i2c),
          #embedding_dim=EMBEDDING_DIM,
          steps=SEQUENCE_LEN,
          lstm_units=LSTM_UNITS,
          lstm_layers=LSTM_LAYERS,
          batch_size=1,
          stateful=STATEFUL)  # TPUs can't handle stateful=True, and that's deadly for text generation.

In [None]:
# Display the most recent checkpoint found in checkpoint_dir.
tf.train.latest_checkpoint(checkpoint_dir)

In [None]:
# Copy the trained weights from the latest checkpoint into the generation model.
gen_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
# NOTE(review): this shape is rank-2, while build_model/build_tpu_model declare
# a rank-3 batch_input_shape [batch_size, steps, vocab_size] — confirm this
# call is still needed or correct.
gen_model.build(tf.TensorShape([1, None]))

In [None]:
# Print layer output shapes and parameter counts of the generation model.
gen_model.summary()

In [None]:
def generate_text_with_tpu(model, start_string, temp=0.6):
    """Generate 128 characters by re-feeding the most recent text window.

    In every step the last `SEQUENCE_LEN` characters of the text produced so
    far are one-hot encoded and passed to `model`; the next character is
    sampled from the logits of the final time step.

    Changes compared to the earlier version: the undefined names `textlib`
    and `use_tpu` are replaced by `td` and `tf.executing_eagerly()`, the
    input is one-hot encoded (as in `generate_text`) because the models built
    in this notebook expect one-hot input, and a dead assignment to
    `input_eval` was removed.

    Args:
        model: model with batch size 1.
        start_string: prompt text; only its last SEQUENCE_LEN characters are used.
        temp: sampling temperature; lower values give more predictable text.

    Returns:
        Tuple `(text, ids)`: `start_string` plus the generated characters, and
        the list of generated token ids.
    """
    # Number of characters to generate
    num_generate = 128
    vocab_size = len(td.i2c)

    def encode_window(text):
        # Tpus need the whole history of exactly SEQUENCE_LEN chars, not less, not more.
        # assumes td.c2i is the char->id inverse of td.i2c — TODO confirm
        window = text[-SEQUENCE_LEN:]
        token_ids = tf.expand_dims([td.c2i[ch] for ch in window], 0)
        return tf.keras.backend.one_hot(token_ids, vocab_size)

    input_eval = encode_window(start_string)

    text_generated = []
    ids = []

    # Low temperatures results in more predictable text.
    # Higher temperatures results in more surprising text.
    temperature = temp

    # Here batch size == 1
    model.reset_states()
    for _ in range(num_generate):
        predictions = model(input_eval)
        # remove the batch dimension
        predictions = tf.squeeze(predictions, 0)

        # sample the next token from the logits of the last time step
        predictions = predictions / temperature
        predicted_tensor = tf.random.categorical(predictions, num_samples=1)[-1, 0]
        if tf.executing_eagerly():
            predicted_id = int(predicted_tensor.numpy())
        else:
            predicted_id = int(tf.keras.backend.eval(predicted_tensor))
        ids.append(predicted_id)

        text_generated.append(td.i2c[predicted_id])
        print("out:" + ''.join(text_generated))

        # Restore the entire history of the last SEQUENCE_LEN chars, to be "stateless"
        input_eval = encode_window(start_string + ''.join(text_generated))

    return (start_string + ''.join(text_generated), ids)

def generate_text(model, start_string, temp=0.6):
    """Generate 128 characters, starting from the beginning of `start_string`.

    The first `SEQUENCE_LEN` characters of `start_string` seed the generation.
    In every step the last `SEQUENCE_LEN` token ids are one-hot encoded and
    passed to `model.predict`; the next character is sampled from the logits
    of the final time step.

    Changes compared to the earlier version: the undefined names `textlib`
    and `use_eager` are replaced by `td` (the notebook's Text_Dataset) and
    `tf.executing_eagerly()`.

    NOTE(review): build_model declares a fixed number of steps
    (SEQUENCE_LEN) in batch_input_shape; a prompt shorter than SEQUENCE_LEN
    produces a shorter input window — confirm the model accepts it.
    NOTE(review): the model state is reset only once before the loop while
    overlapping windows are re-fed every step — confirm this is intended for
    a stateful model.

    Args:
        model: model with batch size 1.
        start_string: prompt text; only its first SEQUENCE_LEN characters are used.
        temp: sampling temperature; lower values give more predictable text.

    Returns:
        Tuple `(text, ids)`: the last (at most) SEQUENCE_LEN characters of
        prompt plus generated text, and the list of all token ids (prompt ids
        followed by the generated ids).
    """
    # Number of characters to generate
    num_generate = 128
    vocab_size = len(td.i2c)

    # Converting our start string to numbers (vectorizing)
    # assumes td.c2i is the char->id inverse of td.i2c — TODO confirm
    cutstr = start_string[0:SEQUENCE_LEN]
    input_eval_0 = [td.c2i[s] for s in cutstr]
    input_eval_1 = tf.expand_dims(input_eval_0, 0)

    input_eval = tf.keras.backend.one_hot(input_eval_1, vocab_size)

    # Running text window and the complete id history (seeded with the prompt)
    text_generated = cutstr
    ids = input_eval_0

    # Low temperatures results in more predictable text.
    # Higher temperatures results in more surprising text.
    # Experiment to find the best setting.
    temperature = temp

    # Here batch size == 1
    model.reset_states()
    for _ in range(num_generate):
        predictions = model.predict(input_eval, steps=1, batch_size=1)
        # remove the batch dimension
        predictions = tf.squeeze(predictions, 0)

        # sample the next token from the logits of the last time step
        predictions = predictions / temperature
        predicted_tensor = tf.random.categorical(predictions, num_samples=1)[-1, 0]
        if tf.executing_eagerly():
            predicted_id = int(predicted_tensor.numpy())
        else:
            predicted_id = int(tf.keras.backend.eval(predicted_tensor))
            print(predicted_id)
        ids.append(predicted_id)

        text_generated += td.i2c[predicted_id]
        text_generated = text_generated[-SEQUENCE_LEN:]
        print(text_generated)

        # The next input is the window of the most recent SEQUENCE_LEN ids
        input_eval_1 = tf.expand_dims(ids[-SEQUENCE_LEN:], 0)
        input_eval = tf.keras.backend.one_hot(input_eval_1, vocab_size)
    return (''.join(text_generated), ids)

In [None]:
# Prompt for text generation.
start_string="With the clarity of thought of an artificial life form, the discussion went on:"
# Display how many characters of the prompt fit into one input window.
len(start_string[0:SEQUENCE_LEN])

In [None]:
# Generate text from the prompt defined in the previous cell.
# `id` (kept because the decode cell below reads it) shadows the builtin `id`.
if use_tpu_for_generation:
    sess = tf.compat.v1.keras.backend.get_session()  # tf.compat.v1.get_default_session()
    with sess.as_default():
        tx, id = generate_text(gen_model, start_string=start_string, temp=0.8)
else:
    # The flag `use_eager` was never defined; ask TensorFlow directly instead.
    if not tf.executing_eagerly():
        print("Eager engine stall.")
    # with tf.device('/job:localhost/replica:0/task:0/device:CPU:0'):  # Speed is about same gpu/cpu
    tx, id = generate_text(gen_model, start_string=start_string, temp=0.8)
    print(tx)

In [None]:
def detectPlagiarism(tx, textlibrary, minQuoteLength=10):
    """Forward `tx` and `minQuoteLength` to `textlibrary.source_highlight`.

    Args:
        tx: text to check.
        textlibrary: object providing `source_highlight(text, min_length)`.
        minQuoteLength: passed through as the second positional argument.
    """
    highlight = textlibrary.source_highlight
    highlight(tx, minQuoteLength)

In [None]:
# Decode all token ids returned by generate_text (prompt + generated) with the
# char tokenizer of `td`, the same call the sanity check above uses; the
# previous `textlib` name is not defined in this notebook.
txt = td.decode(id, tokenizer='char')
txti = txt.split('\r\n')
for t in txti:
    print(t)

In [None]:
# Highlight passages of the generated text that also occur in the sources.
# `td` replaces the undefined `textlib`.
# NOTE(review): assumes Text_Dataset provides source_highlight(text, min_len) — confirm.
detectPlagiarism(tx, td)

## References:
* <https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/text/text_generation.ipynb>
* <https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/shakespeare_with_tpu_and_keras.ipynb>

## 6. A dialog with the trained model [not ported yet]

In [None]:
# Do a dialog with the recursive neural net trained above:
# def genDialogAnswer(prompt, g_state=None, endPrompt='.', maxEndPrompts=2,
# maxAnswerSize=512, temperature=1.0):


def doDialog():
    """Interactive console dialog with the network (legacy session-based code).

    Reads prompts with `input()`, feeds them character by character through
    the network and prints a sampled answer. Special inputs: 'bye' ends the
    loop, 'reset' marks the conversation state for re-initialisation and
    'temperature=<float>' changes the sampling temperature (accepted when
    0.05 < t < 1.4).

    NOTE(review): the enclosing section is marked "[not ported yet]". The body
    uses `tf.Session`, `tf.train.Saver`, `feed_dict`, attributes such as
    `model.init`, `model.logdir`, `model.X`, `model.final_state`, and the
    names `trainParams` and `textlib`, none of which are defined by the
    visible cells of this notebook.

    Returns:
        None.
    """
    # 0.1 (frozen character) - 1.3 (creative/chaotic character)
    temperature = 0.6
    endPrompt = '.'  # the endPrompt character is the end-mark in answers.
    # look for number of maxEndPrompts until answer is finished.
    maxEndPrompts = 4
    maxAnswerSize = 2048  # Maximum length of the answer
    minAnswerSize = 64  # Minimum length of the answer

    with tf.Session() as sess:
        print("Please enter some dialog.")
        print("The net will answer according to your input.")
        print("'bye' for end,")
        print("'reset' to reset the conversation context,")
        print("'temperature=<float>' [0.1(frozen)-1.0(creative)]")
        print("    to change character of the dialog.")
        print("    Current temperature={}.".format(temperature))
        print()
        xso = None
        bye = False
        model.init.run()

        # Bail out early when the log/checkpoint directory does not exist.
        tflogdir = os.path.realpath(model.logdir)
        if not os.path.exists(tflogdir):
            print("You haven't trained a model, no data found at: {}".format(
                trainParams["logdir"]))
            return

        # Used for saving the training parameters periodically
        saver = tf.train.Saver()
        checkpoint_file = os.path.join(tflogdir, model.checkpoint)

        # Restore the most recent checkpoint; without one there is nothing to talk to.
        lastSave = tf.train.latest_checkpoint(tflogdir, latest_filename=None)
        if lastSave is not None:
            # The text after the last '-' of the checkpoint name is parsed as
            # the iteration number; start_iter is not used afterwards.
            pt = lastSave.rfind('-')
            if pt != -1:
                pt += 1
                start_iter = int(lastSave[pt:])
            # print("Restoring checkpoint at {}: {}".format(start_iter, lastSave))
            saver.restore(sess, lastSave)
        else:
            print("No checkpoints have been saved at:{}".format(tflogdir))
            return

        # g_state = sess.run([model.init_state_0], feed_dict={model.batch_size: 1})
        # doini marks that the recurrent state must be (re-)initialised before
        # the next prompt is processed.
        doini = True

        bye = False
        while not bye:
            print("> ", end="")
            prompt = input()
            if prompt == 'bye':
                bye = True
                print("Good bye!")
                continue
            if prompt == 'reset':
                doini = True
                # g_state = sess.run([model.init_state_0], feed_dict={model.batch_size: 1})
                print("(conversation context marked for reset)")
                continue
            if prompt[:len("temperature=")] == "temperature=":
                t = float(prompt[len("temperature="):])
                if t > 0.05 and t < 1.4:
                    temperature = t
                    print("(generator temperature now {})".format(t))
                    print()
                    continue
                print("Invalid temperature-value ignored! [0.1-1.0]")
                continue
            # xs is a sliding window of model.steps characters, initially blanks;
            # xso accumulates the answer text.
            xs = ' ' * model.steps
            xso = ''
            # Feed the prompt one character at a time, carrying the state along.
            for rep in range(1):
                for i in range(len(prompt)):
                    xs = xs[1:]+prompt[i]
                    X_new = np.transpose([[textlib.c2i[sj]] for sj in xs])
                    if doini:
                        doini = False
                        g_state = sess.run(
                            [model.init_state_0], feed_dict={model.X: X_new})
                    g_state, y_pred = sess.run([model.final_state, model.output_softmax_temp],
                                               feed_dict={model.X: X_new, model.init_state: g_state,
                                                          model.temperature: temperature})
            ans = 0
            numEndPrompts = 0
            # Keep sampling until enough end-marks were produced (bounded by
            # maxAnswerSize), but never stop before minAnswerSize characters.
            while (ans < maxAnswerSize and numEndPrompts < maxEndPrompts) or ans < minAnswerSize:

                X_new = np.transpose([[textlib.c2i[sj]] for sj in xs])
                g_state, y_pred = sess.run([model.final_state, model.output_softmax_temp],
                                           feed_dict={model.X: X_new, model.init_state: g_state,
                                                      model.temperature: temperature})
                # Draw the next character id using the last row of y_pred as
                # probabilities.
                inds = list(range(model.vocab_size))
                ind = np.random.choice(inds, p=y_pred[0, -1].ravel())
                nc = textlib.i2c[ind]
                if nc == endPrompt:
                    numEndPrompts += 1
                xso += nc
                xs = xs[1:]+nc
                ans += 1
            print(xso.replace("\\n", "\n"))
            textlib.source_highlight(xso, 13)
    return

In [None]:
# Talk to the net!
# NOTE(review): doDialog() belongs to the section marked "[not ported yet]" and
# uses tf.Session — confirm it runs under the TensorFlow 2 setup of this notebook.
doDialog()