<a href="https://colab.research.google.com/github/domschl/tensor-poet/blob/master/eager_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install TF 2.0, if necessary. This currently needs to be done when running from Colab.

In [1]:
!pip install tf-nightly-gpu-2.0-preview



# [WIP] Eager Tensor Poet (tf 2.0)

**THIS IS UNFINISHED WORK IN PROGRESS**

A tensorflow deep LSTM model for text generation

This code can use either CPU, GPU or TPU when running on Google Colab.

Select the corresponding runtime (menu: Runtime / Change runtime type)

Note: TPU support is not yet working.

In [0]:
import numpy as np
import os
import json
import time
import random
import tensorflow as tf
from IPython.core.display import display, HTML

from urllib.request import urlopen  # Py3

## 0. Check system

### Tensorflow api version check

Temporary note: this notebook is currently tested against the master build of TensorFlow, which still carries the version tag 1.14.x at the time of writing. The version check below is preliminary.

In [2]:
# Preliminary API check: looks for 'api.v2' in the module name of `tf.version`
# and reports the result; it never aborts the notebook.
try:
    if 'api.v2' in tf.version.__name__:
        print("Tensorflow api v2 active.")
    else:
        print("Tensorflow api v2 not found. This will not work.")
except Exception:
    # Any failure of the attribute lookup ends up here. `Exception` is caught
    # instead of using a bare `except:` so that KeyboardInterrupt and
    # SystemExit still propagate.
    print("Failed to check for Tensorflow api v2. This will not work.")

Tensorflow api v2 active.


### GPU/TPU check

In [3]:
from tensorflow.python.client import device_lib

# Hardware flags read by later cells.
use_tpu = False
use_gpu = False

try:
    # COLAB_TPU_ADDR is only present in the environment of a Colab TPU runtime;
    # a missing key raises KeyError and lands in the handler below.
    TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
    tf.config.experimental_connect_to_host(TPU_ADDRESS)
    # Only mark the TPU as usable once the connection attempt did not raise.
    # Setting the flag earlier would leave use_tpu True after a failed connect
    # and skip the GPU detection below.
    use_tpu = True
    print("TPU available at {}".format(TPU_ADDRESS))
except Exception:
    print("No TPU available")

# List the logical devices TensorFlow sees for each device type.
for hw in ["CPU", "GPU", "TPU"]:
    hwlist=tf.config.experimental.list_logical_devices(hw)
    print("{} -> {}".format(hw,hwlist))


if use_tpu is False:
    def get_available_devs_of_type(type):
        """Return the names of all local devices whose name contains `type`.

        An empty string matches every device. The parameter name shadows the
        builtin `type`; it is kept unchanged for compatibility.
        """
        local_device_protos = device_lib.list_local_devices()
        return [x.name for x in local_device_protos if type in x.name]

    def get_dev_desc():
        """Return (name, physical_device_desc) pairs for all local devices."""
        local_device_protos = device_lib.list_local_devices()
        return [(x.name, x.physical_device_desc) for x in local_device_protos]

    def get_available_gpus():
        """Return the names of all local devices with 'GPU' in their name."""
        return get_available_devs_of_type('GPU')

    dl = get_available_gpus()
    if len(dl)==0:
        print("WARNING: You have neither TPU nor GPU, this is going to be very slow!")
        print("         Hint: If using Google Colab, set runtime type to TPU.")
        print(get_available_devs_of_type(''))
    else:
        use_gpu = True
        print(f"GPUs: {dl}")
        print(get_dev_desc())


No TPU available
CPU -> [LogicalDevice(name='/job:localhost/replica:0/task:0/device:CPU:0', device_type='CPU')]
GPU -> [LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:0', device_type='GPU')]
TPU -> []
GPUs: ['/device:XLA_GPU:0', '/device:GPU:0']
[('/device:CPU:0', ''), ('/device:XLA_GPU:0', 'device: XLA_GPU device'), ('/device:XLA_CPU:0', 'device: XLA_CPU device'), ('/device:GPU:0', 'device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7')]


##  1. Text library

In [0]:
# TextLibrary class: text library for training, encoding, batch generation,
# and formatted source display


class TextLibrary:
    """Text collection for character-level training.

    Loads texts from URLs or local files, concatenates them into one string
    (`self.data`), builds character<->index maps (`self.c2i`, `self.i2c`),
    hands out sequential or random (input, target) samples, and renders text
    as HTML with passages highlighted that occur verbatim in a library text.
    """

    def __init__(self, descriptors, max=100000000):
        """Load all texts and build the vocabulary.

        Args:
            descriptors: iterable of `(descriptor, name)` pairs. A descriptor
                starting with 'http' is downloaded and decoded as UTF-8; any
                other descriptor is opened as a local file.
            max: maximum number of characters read from each local file (it
                is not applied to downloads). The name shadows the builtin
                but is kept for backward compatibility.

        Sources that cannot be loaded are reported on stdout and skipped.
        """
        self.descriptors = descriptors
        self.data = ''
        self.files = []
        self.c2i = {}
        self.i2c = {}
        index = 1  # source indices start at 1; 0 means "no source" when rendering
        for descriptor, name in descriptors:
            is_url = descriptor[:4] == 'http'
            try:
                if is_url:
                    dat = urlopen(descriptor).read().decode('utf-8')
                    # startswith() is safe on an empty download, dat[0] is not
                    if dat.startswith('\ufeff'):  # Ignore BOM
                        dat = dat[1:]
                else:
                    # context manager closes the file even if read() fails
                    with open(descriptor) as f:
                        dat = f.read(max)
            except Exception as e:
                if is_url:
                    print(f"Can't download {descriptor}: {e}")
                else:
                    print(f"ERROR: Cannot read: {descriptor}: {e}")
                continue
            self.data += dat
            self.files.append({"name": name, "data": dat, "index": index})
            index += 1
        # Assign indices in order of first appearance: iterating a set would
        # not give a deterministic mapping.
        ind = 0
        for c in self.data:
            if c not in self.c2i:
                self.c2i[c] = ind
                self.i2c[ind] = c
                ind += 1
        self.ptr = 0  # read position used by get_slice()

    def display_colored_html(self, textlist, pre='', post=''):
        """Render `(text, index)` pairs as HTML in the notebook.

        Text with index 0 is emitted as is; any other index gets a background
        colour chosen by the index and a superscript `[index]` reference.
        Newlines become `<br>`. `pre` and `post` are raw HTML placed around
        the result.
        """
        bgcolors = ['#d4e6f1', '#d8daef', '#ebdef0', '#eadbd8', '#e2d7d5', '#edebd0',
                    '#ecf3cf', '#d4efdf', '#d0ece7', '#d6eaf8', '#d4e6f1', '#d6dbdf',
                    '#f6ddcc', '#fae5d3', '#fdebd0', '#e5e8e8', '#eaeded', '#A9CCE3']
        out = ''
        for txt, ind in textlist:
            txt = txt.replace('\n', '<br>')
            if ind == 0:
                out += txt
            else:
                # cycle through the whole palette rather than a fixed 16 entries
                color = bgcolors[ind % len(bgcolors)]
                out += "<span style=\"background-color:"+color+";\">" + \
                       txt + "</span>"+"<sup>[" + str(ind) + "]</sup>"
        display(HTML(pre+out+post))

    def source_highlight(self, txt, minQuoteSize=10):
        """Display `txt` with verbatim quotes from the library highlighted.

        Scans `txt` from left to right. At each position the longest prefix of
        at least `minQuoteSize` characters that occurs in any library text is
        marked with that text's index; characters not part of such a quote
        are emitted unmarked. If at least one quote was found, a right-aligned
        list of the quoted sources is displayed below the text.
        """
        tx = txt
        out = []
        qts = []
        txsrc = [("Sources: ", 0)]
        sc = False
        noquote = ''
        while len(tx) > 0:  # search all library files for quote 'txt'
            mxQ = 0
            mxI = 0
            mxN = ''
            found = False
            for f in self.files:  # find longest quote in all texts
                p = minQuoteSize
                if p <= len(tx) and tx[:p] in f["data"]:
                    p = minQuoteSize + 1
                    while p <= len(tx) and tx[:p] in f["data"]:
                        p += 1
                    # p is one past the longest matching prefix length
                    if p-1 > mxQ:
                        mxQ = p-1
                        mxI = f["index"]
                        mxN = f["name"]
                        found = True
            if found:  # save longest quote for colorizing
                if len(noquote) > 0:
                    out.append((noquote, 0))
                    noquote = ''
                out.append((tx[:mxQ], mxI))
                tx = tx[mxQ:]
                if mxI not in qts:  # create a new reference, if first occurrence
                    qts.append(mxI)
                    if sc:
                        txsrc.append((", ", 0))
                    sc = True
                    txsrc.append((mxN, mxI))
            else:
                noquote += tx[0]
                tx = tx[1:]
        if len(noquote) > 0:
            out.append((noquote, 0))
            noquote = ''
        self.display_colored_html(out)
        if len(qts) > 0:  # print references, if there is at least one source
            self.display_colored_html(txsrc, pre="<small><p style=\"text-align:right;\">",
                                     post="</p></small>")

    def get_slice(self, length):
        """Return the next `length` characters and advance the read pointer.

        If `ptr + length` would reach the end of the data, the pointer wraps
        to 0 first. Returns `(slice, rst)`; `rst` is True when the slice
        starts at position 0.
        """
        if (self.ptr + length >= len(self.data)):
            self.ptr = 0
        rst = self.ptr == 0
        sl = self.data[self.ptr:self.ptr+length]
        self.ptr += length
        return sl, rst

    def decode(self, ar):
        """Turn a sequence of character indices back into a string."""
        return ''.join([self.i2c[ic] for ic in ar])

    def get_random_slice(self, length):
        """Return `length` characters starting at a random position.

        Raises ValueError (from `random.randrange`) when the library does not
        hold more than `length` characters.
        """
        p = random.randrange(0, len(self.data)-length)
        sl = self.data[p:p+length]
        return sl

    def get_slice_array(self, length):
        """Like get_slice(), but return the characters as a numpy array."""
        ar = np.array([c for c in self.get_slice(length)[0]])
        return ar

    def get_encoded_slice(self, length):
        """Return the next `length` characters as a list of indices."""
        s, _ = self.get_slice(length)
        X = [self.c2i[c] for c in s]
        return X

    def get_encoded_slice_array(self, length):
        """Like get_encoded_slice(), but return a numpy array."""
        return np.array(self.get_encoded_slice(length))

    def get_sample(self, length):
        """Return the next sequential training sample `(X, y, rst)`.

        `X` encodes `length` characters, `y` the same characters shifted by
        one position; `rst` is the wrap flag of get_slice().
        """
        s, rst = self.get_slice(length+1)
        X = [self.c2i[c] for c in s[:-1]]
        y = [self.c2i[c] for c in s[1:]]
        return (X, y, rst)

    def get_random_sample(self, length):
        """Return a training sample `(X, y)` taken at a random position."""
        s = self.get_random_slice(length+1)
        X = [self.c2i[c] for c in s[:-1]]
        y = [self.c2i[c] for c in s[1:]]
        return (X, y)

    def get_sample_batch(self, batch_size, length):
        """Return `batch_size` sequential samples as `(X_list, y_list, rst)`.

        `rst` is the wrap flag of the last sample in the batch, or False for
        an empty batch.
        """
        smpX = []
        smpy = []
        rst = False  # defined even when batch_size == 0
        for i in range(batch_size):
            Xi, yi, rst = self.get_sample(length)
            smpX.append(Xi)
            smpy.append(yi)
        return smpX, smpy, rst

    def get_random_sample_batch(self, batch_size, length):
        """Return `batch_size` random samples as `(X_list, y_list)`."""
        smpX = []
        smpy = []
        for i in range(batch_size):
            Xi, yi = self.get_random_sample(length)
            smpX.append(Xi)
            smpy.append(yi)
        return smpX, smpy


### Read text data

In [0]:
# Description of the text library: a display name, a short description, and
# the list of (url_or_path, source_name) pairs handed to TextLibrary.
libdesc = {
    "name": "Woman Writers",
    "description": "A collection of works of Woolf, Austen and Brontë",
    "lib": [
        # 'data/tiny-shakespeare.txt',
        # since project gutenberg blocks the entire country of Germany, we use a mirror:
        # ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/0/100/100-0.txt', "Shakespeare: Collected Works"),
        # Project Gutenberg: Pride and Prejudice by Jane Austen, Wuthering Heights by Emily Brontë,
        # The Voyage Out by Virginia Woolf and Emma by Jane Austen
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/3/7/4/3/37431/37431.txt', "Jane Austen: Pride and Prejudice"),
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/7/6/768/768.txt', "Emily Brontë: Wuthering Heights"),
        # source name is shown in the "Sources" list, so the author's name must be spelled correctly
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/4/144/144.txt', "Virginia Woolf: Voyage out"),
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/5/158/158.txt', "Jane Austen: Emma")
    ]
}

# Load every entry of libdesc["lib"] (the 'http' entries are downloaded, so
# this cell needs network access) and build the character vocabulary.
textlib = TextLibrary(libdesc["lib"])


## 2. Use tf.data for texts

In [0]:
# Hyperparameters for data sampling and for the model.
SEQUENCE_LEN = 80     # characters per training sequence
# The batch size is currently the same on every backend; change it here if a
# TPU run needs a different value.
BATCH_SIZE = 256
# Defined on every runtime (not only on TPU) so that later cells can read the
# flag after a fresh "Restart & Run All" without a NameError.
use_simple_model_for_tpu = use_tpu is True
LSTM_UNITS = 768      # units per LSTM layer
EMBEDDING_DIM = 120   # size of the character embedding
LSTM_LAYERS = 6       # number of stacked LSTM layers
NUM_BATCHES = 30      # number of random batches sampled for the dataset

In [0]:
# Draw NUM_BATCHES random batches from the library. dx collects the encoded
# input sequences, dy the targets (the same text shifted by one character).
# Both stay plain lists of batches.
sampled_batches = [textlib.get_random_sample_batch(BATCH_SIZE, SEQUENCE_LEN)
                   for _ in range(NUM_BATCHES)]
dx = [inputs for inputs, _ in sampled_batches]
dy = [targets for _, targets in sampled_batches]

In [0]:
# Pair inputs and targets: a 2-tuple of (list of input batches, list of target batches).
data_xy=(dx,dy)


In [0]:
# Slice along the first axis: each dataset element is one (inputs, targets)
# batch, as the (256, 80) element shapes in the output further down show.
textlib_dataset=tf.data.Dataset.from_tensor_slices(data_xy)

In [10]:
# The shuffle buffer (10000) is larger than the NUM_BATCHES elements in the dataset.
shuffle_buffer=10000
dataset=textlib_dataset.shuffle(shuffle_buffer)
# Last expression of the cell: shows the dataset's shapes and types, not its contents.
dataset.take(1)

<TakeDataset shapes: ((256, 80), (256, 80)), types: (tf.int32, tf.int32)>

In [0]:
def build_model(vocab_size, embedding_dim, lstm_units, lstm_layers, batch_size):
  """Build the stacked LSTM character model.

  The network is an embedding layer, followed by `lstm_layers` stateful LSTM
  layers that return full sequences, followed by a dense layer with
  `vocab_size` outputs and no activation. The batch size is fixed to
  `batch_size`; the sequence length is left open.
  """
  layer_stack = [
      tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                batch_input_shape=[batch_size, None])
  ]
  for _ in range(lstm_layers):
    layer_stack.append(
        tf.keras.layers.LSTM(lstm_units,
                             return_sequences=True,
                             stateful=True,
                             recurrent_initializer='glorot_uniform'))
  layer_stack.append(tf.keras.layers.Dense(vocab_size))
  return tf.keras.Sequential(layer_stack)

def build_simple_model(vocab_size, embedding_dim, lstm_units, lstm_layers, batch_size):
  """Build a reduced single-LSTM model with a fixed sequence length.

  Unlike `build_model`, the input shape is fixed to
  `[batch_size, SEQUENCE_LEN]` (module-level constant) and the one LSTM layer
  is unrolled. `lstm_layers` is accepted for signature parity with
  `build_model` but is not used.

  This is a plain Python function on purpose. It creates Keras layers (and
  thus variables) and returns a model object, which is not something a
  `tf.function`-decorated function can do.
  """
  model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, embedding_dim,
                              batch_input_shape=[batch_size, SEQUENCE_LEN]),
    tf.keras.layers.LSTM(lstm_units,
                        return_sequences=True,
                        stateful=True,
                        recurrent_initializer='glorot_uniform',
                        unroll=True),
    tf.keras.layers.Dense(vocab_size)
  ])
  return model



In [0]:

# dev_strings=[]
# for log_dev in tf.config.experimental.list_logical_devices('TPU'):
#     dev_strings.append(log_dev.name)
# print(dev_strings)

# for i in range(8):
#     dev_strings.append('/TPU:{}'.format(i))
# print(dev_strings)

if use_tpu:
    # TPU_ADDRESS is 'grpc://' plus the COLAB_TPU_ADDR environment variable.
    # A bare `os.environ[...]` expression inside the `if` displayed nothing,
    # so it was dropped.
    print(TPU_ADDRESS)

In [0]:
# Arguments shared by every model variant built in this cell.
model_args = dict(
    vocab_size=len(textlib.i2c),
    embedding_dim=EMBEDDING_DIM,
    lstm_units=LSTM_UNITS,
    lstm_layers=LSTM_LAYERS,
    batch_size=BATCH_SIZE)

if use_tpu is not True:
    # CPU/GPU: build the full model directly.
    model = build_model(**model_args)
else:
    # TPU: connect, initialise the TPU system and build inside the strategy scope.
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=TPU_ADDRESS)
    tf.config.experimental_connect_to_host(cluster_resolver.master())
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)  # <-- this currently fails with colab/TPU
    tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)

    model_builder = build_simple_model if use_simple_model_for_tpu is True else build_model
    with tpu_strategy.scope():
        model = model_builder(**model_args)

In [14]:
# Run the untrained model on one batch as a shape check. The loop variables and
# example_batch_predictions stay defined and are reused by later cells.
for input_example_batch, target_example_batch in dataset.take(1):
  example_batch_predictions = model(input_example_batch)
  print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

(256, 80, 89) # (batch_size, sequence_length, vocab_size)


In [15]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (256, None, 120)          10680     
_________________________________________________________________
lstm (LSTM)                  (256, None, 768)          2731008   
_________________________________________________________________
lstm_1 (LSTM)                (256, None, 768)          4721664   
_________________________________________________________________
lstm_2 (LSTM)                (256, None, 768)          4721664   
_________________________________________________________________
lstm_3 (LSTM)                (256, None, 768)          4721664   
_________________________________________________________________
lstm_4 (LSTM)                (256, None, 768)          4721664   
_________________________________________________________________
lstm_5 (LSTM)                (256, None, 768)          4

In [16]:
dataset.take(1)

<TakeDataset shapes: ((256, 80), (256, 80)), types: (tf.int32, tf.int32)>

In [0]:
# Sample one character index per time step from the model output of the first
# sequence in the batch, then drop the trailing num_samples axis.
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()

In [18]:
sampled_indices

array([87, 62,  9, 14, 71, 60, 37, 58, 35, 44, 48, 17,  0, 39, 61, 29, 75,
       85, 39, 74, 41, 21, 61, 43, 46, 20, 40, 20,  9, 63, 78, 46, 70, 49,
       33, 19, 64, 44, 88, 42, 18, 85,  8, 36, 80, 47, 19, 40, 68, 62, 20,
        0, 50, 74,  1, 27, 58, 15, 46, 10, 32, 65, 65, 40, 62, 53, 51, 87,
        2, 54, 51, 49, 26, 62, 18, 80, 13, 84, 57, 38])

In [19]:
textlib.decode(sampled_indices)

'}/tg6ULHY5#kTANK;$AxDaN10dRdt(X093vi)5`Sf$c-?[iRV/dT7xh\nHE0Gm__R/C4}eI43\r/f?b@F:'

In [20]:
def loss(labels, logits):
  """Sparse categorical cross-entropy between integer labels and raw logits."""
  crossentropy = tf.keras.losses.sparse_categorical_crossentropy
  return crossentropy(labels, logits, from_logits=True)

# Loss of the untrained model on the example batch, averaged over all positions.
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_example_loss = example_batch_loss.numpy().mean()
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("scalar_loss:      ", mean_example_loss)

Prediction shape:  (256, 80, 89)  # (batch_size, sequence_length, vocab_size)
scalar_loss:       4.4883966


In [0]:
# Earlier setting, kept for reference:
# adam_clipped = tf.keras.optimizers.Adam(lr=0.003, clipvalue=1.0)
# Current setting: Adam without an explicit learning rate and with clipvalue=0.5.
adam_clipped = tf.keras.optimizers.Adam(clipvalue=0.5)

# Train with the logits-based cross-entropy `loss` defined in an earlier cell.
model.compile(optimizer=adam_clipped, loss=loss)

In [0]:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files; the "{epoch}" placeholder is left unformatted
# here and ends up as the epoch number in the file name (e.g. ckpt_100).
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Only the weights are stored, so the architecture has to be rebuilt in code
# before a checkpoint can be loaded.
checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

In [0]:
EPOCHS=100

In [24]:
# Train for EPOCHS epochs on the shuffled batches with the checkpoint callback attached.
# NOTE(review): build_model creates the LSTM layers with stateful=True, while
# the batches are independent random samples that are shuffled -- confirm that
# carrying recurrent state from one batch to the next is intended here.
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

Epoch 1/100


W0731 15:08:14.861467 140320682334080 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:457: BaseResourceVariable.constraint (from tensorflow.python.ops.resource_variable_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Apply a constraint manually following the optimizer update step.


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 7

In [0]:
# Generate

In [26]:
tf.train.latest_checkpoint(checkpoint_dir)

'./training_checkpoints/ckpt_100'

In [0]:
# Rebuild the training architecture with batch_size=1 for text generation and
# load the weights of the most recent checkpoint in checkpoint_dir.
gen_model = build_model(vocab_size = len(textlib.i2c),
  embedding_dim=EMBEDDING_DIM,
  lstm_units=LSTM_UNITS,
  lstm_layers=LSTM_LAYERS,
  batch_size=1)
# NOTE(review): tf.train.latest_checkpoint() presumably yields None when no
# checkpoint exists yet -- run the training cell first.
gen_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

gen_model.build(tf.TensorShape([1, None]))

In [28]:
gen_model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (1, None, 120)            10680     
_________________________________________________________________
lstm_6 (LSTM)                (1, None, 768)            2731008   
_________________________________________________________________
lstm_7 (LSTM)                (1, None, 768)            4721664   
_________________________________________________________________
lstm_8 (LSTM)                (1, None, 768)            4721664   
_________________________________________________________________
lstm_9 (LSTM)                (1, None, 768)            4721664   
_________________________________________________________________
lstm_10 (LSTM)               (1, None, 768)            4721664   
_________________________________________________________________
lstm_11 (LSTM)               (1, None, 768)           

In [0]:
def generate_text(model, start_string, temp=0.6, num_generate=1000):
  """Generate text character by character with a trained model.

  Uses the module-level `textlib` for the character<->index mapping.

  Args:
    model: stateful model built with batch size 1. It is called on index
      tensors of shape (1, n) and its states are reset before generation.
    start_string: seed text fed to the model first. Must be non-empty and
      consist of characters in `textlib.c2i` (others raise KeyError).
    temp: sampling temperature, must be > 0. The model outputs are divided by
      it before sampling: low values give more predictable text, high values
      more surprising text.
    num_generate: number of characters to generate (default 1000).

  Returns:
    Tuple `(text, ids)`: `text` is `start_string` followed by the generated
    characters, `ids` is the list of sampled character indices.

  Raises:
    ValueError: if `start_string` is empty or `temp` is not positive.
  """
  if not start_string:
    raise ValueError("start_string must not be empty")
  if temp <= 0:
    # dividing by zero or a negative temperature silently corrupts the distribution
    raise ValueError("temp must be > 0")

  # Convert the start string to indices and add the batch dimension
  input_eval = tf.expand_dims([textlib.c2i[s] for s in start_string], 0)

  text_generated = []
  ids = []

  # Here batch size == 1
  model.reset_states()
  for _ in range(num_generate):
    predictions = model(input_eval)
    # remove the batch dimension and apply the temperature
    predictions = tf.squeeze(predictions, 0) / temp

    # sample from the categorical distribution; only the sample for the last
    # time step is the prediction for the next character
    predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()
    ids.append(predicted_id)

    # The predicted character is the next input; the previous context lives
    # in the model's recurrent state.
    input_eval = tf.expand_dims([predicted_id], 0)

    text_generated.append(textlib.i2c[predicted_id])

  return (start_string + ''.join(text_generated), ids)

In [0]:
# tx: prompt plus generated text; id: the sampled character indices.
# NOTE(review): `id` shadows the Python builtin of the same name; a later cell
# reads it, so renaming it would have to be done in both places.
tx,id=generate_text(gen_model, start_string="With the clarity of thought of an artificial life form, the discussion went on:", temp=0.4)

In [0]:
def detectPlagiarism(tx, textlibrary, minQuoteLength=10):
    """Display `tx` with passages highlighted that `textlibrary` recognises.

    Thin wrapper: forwards `tx` and `minQuoteLength` to
    `textlibrary.source_highlight` and returns nothing.
    """
    highlight = textlibrary.source_highlight
    highlight(tx, minQuoteLength)

In [118]:
# Decode the sampled indices and print the text one line per '\r\n'-separated segment.
txt = textlib.decode(id)
txti = txt.split('\r\n')
print(*txti, sep='\n')

 Ellen who had seen them
in the conservatory to a pretty terrible place than she had taken up the study
of botany since her daughter married, and it was to be talked of beauty of her resemblence
in her kyes, and to the love of the place,
the streets, the people who said the long time time to me.  It had no sooner beneath here before them.

"Who writes the best Latin verse in your college, Mr. Woodhouse's spirits,
which terred to the time, it was not very likely, from Richmond Catherine letter too. In the stone on the same spot now. It was a laughtouble protector, which he may be
recognised to love each other. The poke floom was too polite not to continue the conversation of the stairs of the subject.--The
bad no more formerly inviting a chind to eat
attempt to write like a man. Every other woman does not attempt to the place and half a stone compliment
in her knew to be settled and perfectlious still, only to a some people sitting by the
tentation of a statesman in Section 


In [119]:
detectPlagiarism(tx, textlib)

## References:
* <https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/text/text_generation.ipynb>
* <https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/shakespeare_with_tpu_and_keras.ipynb>

## 6. A dialog with the trained model [not ported yet]

In [0]:
# Do a dialog with the recursive neural net trained above:
# def genDialogAnswer(prompt, g_state=None, endPrompt='.', maxEndPrompts=2,
# maxAnswerSize=512, temperature=1.0):


def doDialog():
    """Interactive console dialog with the trained network (not ported yet).

    NOTE(review): this function still uses the TensorFlow 1.x session API
    (`tf.Session`, `tf.train.Saver`). It also expects the module-level `model`
    to expose `init`, `logdir`, `steps`, `X`, `init_state`, `init_state_0`,
    `final_state`, `output_softmax_temp`, `temperature` and `vocab_size`. The
    Keras model built in this notebook does not appear to provide these, so
    the function presumably cannot run until it is ported.

    Console commands:
      bye                  end the dialog
      reset                re-initialise the conversation state on the next prompt
      temperature=<float>  change the sampling temperature (accepted: 0.05 < t < 1.4)

    Any other input is fed to the network character by character. The network
    then generates an answer until `maxEndPrompts` end marks ('.') have been
    produced or `maxAnswerSize` characters are reached, but never fewer than
    `minAnswerSize` characters. The answer is printed and shown with source
    highlighting.
    """
    # 0.1 (frozen character) - 1.3 (creative/chaotic character)
    temperature = 0.6
    endPrompt = '.'  # the endPrompt character is the end-mark in answers.
    # look for number of maxEndPrompts until answer is finished.
    maxEndPrompts = 4
    maxAnswerSize = 2048  # Maximum length of the answer
    minAnswerSize = 64  # Minimum length of the answer

    with tf.Session() as sess:
        print("Please enter some dialog.")
        print("The net will answer according to your input.")
        print("'bye' for end,")
        print("'reset' to reset the conversation context,")
        print("'temperature=<float>' [0.1(frozen)-1.0(creative)]")
        print("    to change character of the dialog.")
        print("    Current temperature={}.".format(temperature))
        print()
        model.init.run()

        tflogdir = os.path.realpath(model.logdir)
        if not os.path.exists(tflogdir):
            # report the directory that was actually checked; the former
            # `trainParams` lookup referred to a name that is not defined here
            print("You haven't trained a model, no data found at: {}".format(
                tflogdir))
            return

        # Restore the most recent checkpoint, if there is one
        saver = tf.train.Saver()
        lastSave = tf.train.latest_checkpoint(tflogdir, latest_filename=None)
        if lastSave is None:
            print("No checkpoints have been saved at:{}".format(tflogdir))
            return
        saver.restore(sess, lastSave)

        # True while the recurrent state still has to be (re-)initialised
        doini = True

        bye = False
        while not bye:
            print("> ", end="")
            prompt = input()
            if prompt == 'bye':
                bye = True
                print("Good bye!")
                continue
            if prompt == 'reset':
                doini = True
                print("(conversation context marked for reset)")
                continue
            if prompt[:len("temperature=")] == "temperature=":
                try:
                    t = float(prompt[len("temperature="):])
                except ValueError:
                    t = None  # not a number: handled as an invalid value below
                if t is not None and t > 0.05 and t < 1.4:
                    temperature = t
                    print("(generator temperature now {})".format(t))
                    print()
                    continue
                print("Invalid temperature-value ignored! [0.1-1.0]")
                continue

            # Characters outside the vocabulary cannot be encoded; drop them
            # instead of aborting the dialog with a KeyError.
            prompt = ''.join(ch for ch in prompt if ch in textlib.c2i)
            if doini and not prompt:
                # No recurrent state exists yet and there is nothing to feed:
                # generating an answer would use an undefined state.
                continue

            # Sliding window of the last model.steps characters, padded with blanks
            xs = ' ' * model.steps
            xso = ''
            for i in range(len(prompt)):
                xs = xs[1:]+prompt[i]
                X_new = np.transpose([[textlib.c2i[sj]] for sj in xs])
                if doini:
                    doini = False
                    g_state = sess.run(
                        [model.init_state_0], feed_dict={model.X: X_new})
                g_state, y_pred = sess.run([model.final_state, model.output_softmax_temp],
                                           feed_dict={model.X: X_new, model.init_state: g_state,
                                                      model.temperature: temperature})
            ans = 0
            numEndPrompts = 0
            while (ans < maxAnswerSize and numEndPrompts < maxEndPrompts) or ans < minAnswerSize:

                X_new = np.transpose([[textlib.c2i[sj]] for sj in xs])
                g_state, y_pred = sess.run([model.final_state, model.output_softmax_temp],
                                           feed_dict={model.X: X_new, model.init_state: g_state,
                                                      model.temperature: temperature})
                # draw the next character from the output distribution of the last step
                inds = list(range(model.vocab_size))
                ind = np.random.choice(inds, p=y_pred[0, -1].ravel())
                nc = textlib.i2c[ind]
                if nc == endPrompt:
                    numEndPrompts += 1
                xso += nc
                xs = xs[1:]+nc
                ans += 1
            print(xso.replace("\\n", "\n"))
            textlib.source_highlight(xso, 13)

In [0]:
# Talk to the net!
doDialog()