<a href="https://colab.research.google.com/github/domschl/tensor-poet/blob/master/eager_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install TF 2.0, if necessary. This currently needs to be done when running from Colab.

In [None]:
!pip install tf-nightly-gpu-2.0-preview

# [WIP] Eager Tensor Poet (tf 2.0)

**THIS IS UNFINISHED WORK IN PROGRESS**

A TensorFlow deep LSTM model for character-level text generation.

This code can use either CPU, GPU or TPU when running on Google Colab.

Select the corresponding runtime (menu: Runtime / Change runtime type)

Note: TPU support is not yet working.

In [None]:
# Load the TensorBoard notebook extension (provides the %tensorboard magic used further down).
%load_ext tensorboard

In [None]:
import numpy as np
import os
import json
import time
import datetime
import random
import tensorflow as tf
from IPython.core.display import display, HTML

from urllib.request import urlopen  # Py3

## 0. Check system

### Tensorflow api version check

Temporary note: this is currently tested against the master build of TensorFlow, which still carried the version tag 1.14.x at the time of writing. The version check below is preliminary.

In [None]:
# The TF 2.x public API lives in a module path containing 'api.v2'; use that
# to detect whether the v2 API is active.
try:
    if 'api.v2' in tf.version.__name__:
        print("Tensorflow api v2 active.")
    else:
        print("Tensorflow api v2 not found. This will not work.")
except Exception as e:
    # Report what went wrong instead of silently swallowing every error.
    print(f"Failed to check for Tensorflow api v2 ({e}). This will not work.")

### GPU/TPU check

In [None]:
from tensorflow.python.client import device_lib

use_tpu = False
use_gpu = False

# Colab exposes the TPU endpoint via COLAB_TPU_ADDR. use_tpu is only set once
# the connection has actually succeeded, so later cells never take the TPU
# path after a failed connect.
try:
    TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
except KeyError:
    print("No TPU available")
else:
    try:
        tf.config.experimental_connect_to_host(TPU_ADDRESS)
    except Exception as e:
        print("No TPU available (connecting to {} failed: {})".format(TPU_ADDRESS, e))
    else:
        use_tpu = True
        print("TPU available at {}".format(TPU_ADDRESS))

for hw in ["CPU", "GPU", "TPU"]:
    hwlist = tf.config.experimental.list_logical_devices(hw)
    print("{} -> {}".format(hw, hwlist))


if use_tpu is False:
    def get_available_devs_of_type(type):
        """Return names of local devices whose name contains `type` ('' matches all)."""
        local_device_protos = device_lib.list_local_devices()
        return [x.name for x in local_device_protos if type in x.name]

    def get_dev_desc():
        """Return (name, physical_device_desc) for every local device."""
        local_device_protos = device_lib.list_local_devices()
        return [(x.name, x.physical_device_desc) for x in local_device_protos]

    def get_available_gpus():
        """Return names of local devices whose name contains 'GPU'."""
        return get_available_devs_of_type('GPU')

    dl = get_available_gpus()
    if len(dl) == 0:
        print("WARNING: You have neither TPU nor GPU, this is going to be very slow!")
        print("         Hint: If using Google Colab, set runtime type to TPU.")
        print(get_available_devs_of_type(''))
    else:
        use_gpu = True
        print(f"GPUs: {dl}")
        print(get_dev_desc())


## 1. Text library

In [None]:
# TextLibrary class: text library for training, encoding, batch generation,
# and formatted source display


class TextLibrary:
    def __init__(self, descriptors, max=100000000):
        self.descriptors = descriptors
        self.data = ''
        self.files = []
        self.c2i = {}
        self.i2c = {}
        index = 1
        for descriptor, name in descriptors:
            fd = {}
            if descriptor[:4] == 'http':
                try:
                    dat = urlopen(descriptor).read().decode('utf-8')
                    if dat[0]=='\ufeff':  # Ignore BOM
                        dat=dat[1:]
                    self.data += dat
                    fd["name"] = name
                    fd["data"] = dat
                    fd["index"] = index
                    index += 1
                    self.files.append(fd)
                except Exception as e:
                    print(f"Can't download {descriptor}: {e}")
            else:
                fd["name"] = name
                try:
                    f = open(descriptor)
                    dat = f.read(max)
                    self.data += dat
                    fd["data"] = dat
                    fd["index"] = index
                    index += 1
                    self.files.append(fd)
                    f.close()
                except Exception as e:
                    print(f"ERROR: Cannot read: {filename}: {e}")
        ind = 0
        for c in self.data:  # sets are not deterministic
            if c not in self.c2i:
                self.c2i[c] = ind
                self.i2c[ind] = c
                ind += 1
        self.ptr = 0

    def display_colored_html(self, textlist, pre='', post=''):
        bgcolors = ['#d4e6e1', '#d8daef', '#ebdef0', '#eadbd8', '#e2d7d5', '#edebd0',
                    '#ecf3cf', '#d4efdf', '#d0ece7', '#d6eaf8', '#d4e6f1', '#d6dbdf',
                    '#f6ddcc', '#fae5d3', '#fdebd0', '#e5e8e8', '#eaeded', '#A9CCE3']
        out = ''
        for txt, ind in textlist:
            txt = txt.replace('\n', '<br>')
            if ind == 0:
                out += txt
            else:
                out += "<span style=\"background-color:"+bgcolors[ind % 16]+";\">" + \
                       txt + "</span>"+"<sup>[" + str(ind) + "]</sup>"
        display(HTML(pre+out+post))

    def source_highlight(self, txt, minQuoteSize=10):
        tx = txt
        out = []
        qts = []
        txsrc = [("Sources: ", 0)]
        sc = False
        noquote = ''
        while len(tx) > 0:  # search all library files for quote 'txt'
            mxQ = 0
            mxI = 0
            mxN = ''
            found = False
            for f in self.files:  # find longest quote in all texts
                p = minQuoteSize
                if p <= len(tx) and tx[:p] in f["data"]:
                    p = minQuoteSize + 1
                    while p <= len(tx) and tx[:p] in f["data"]:
                        p += 1
                    if p-1 > mxQ:
                        mxQ = p-1
                        mxI = f["index"]
                        mxN = f["name"]
                        found = True
            if found:  # save longest quote for colorizing
                if len(noquote) > 0:
                    out.append((noquote, 0))
                    noquote = ''
                out.append((tx[:mxQ], mxI))
                tx = tx[mxQ:]
                if mxI not in qts:  # create a new reference, if first occurence
                    qts.append(mxI)
                    if sc:
                        txsrc.append((", ", 0))
                    sc = True
                    txsrc.append((mxN, mxI))
            else:
                noquote += tx[0]
                tx = tx[1:]
        if len(noquote) > 0:
            out.append((noquote, 0))
            noquote = ''
        self.display_colored_html(out)
        if len(qts) > 0:  # print references, if there is at least one source
            self.display_colored_html(txsrc, pre="<small><p style=\"text-align:right;\">",
                                     post="</p></small>")

    def get_slice(self, length):
        if (self.ptr + length >= len(self.data)):
            self.ptr = 0
        if self.ptr == 0:
            rst = True
        else:
            rst = False
        sl = self.data[self.ptr:self.ptr+length]
        self.ptr += length
        return sl, rst

    def decode(self, ar):
        return ''.join([self.i2c[ic] for ic in ar])

    def get_random_slice(self, length):
        p = random.randrange(0, len(self.data)-length)
        sl = self.data[p:p+length]
        return sl

    def get_slice_array(self, length):
        ar = np.array([c for c in self.get_slice(length)[0]])
        return ar

    def get_encoded_slice(self, length):
        s, rst = self.get_slice(length)
        X = [self.c2i[c] for c in s]
        return X
        
    def get_encoded_slice_array(self, length):
        return np.array(self.get_encoded_slice(length))

    def get_sample(self, length):
        s, rst = self.get_slice(length+1)
        X = [self.c2i[c] for c in s[:-1]]
        y = [self.c2i[c] for c in s[1:]]
        return (X, y, rst)

    def get_random_sample(self, length):
        s = self.get_random_slice(length+1)
        X = [self.c2i[c] for c in s[:-1]]
        y = [self.c2i[c] for c in s[1:]]
        return (X, y)

    def get_sample_batch(self, batch_size, length):
        smpX = []
        smpy = []
        for i in range(batch_size):
            Xi, yi, rst = self.get_sample(length)
            smpX.append(Xi)
            smpy.append(yi)
        return smpX, smpy, rst

    def get_random_sample_batch(self, batch_size, length):
        smpX = []
        smpy = []
        for i in range(batch_size):
            Xi, yi = self.get_random_sample(length)
            smpX.append(Xi)
            smpy.append(yi)
        return smpX, smpy


### Read text data

In [None]:
libdesc = {
    "name": "Women Writers",
    "description": "A collection of works of Woolf, Austen and Brontë",
    "lib": [
        # Alternative corpora:
        # 'data/tiny-shakespeare.txt',
        # ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/0/100/100-0.txt', "Shakespeare: Collected Works"),
        # Project Gutenberg blocks the entire country of Germany, so a mirror is used.
        # Texts: Pride and Prejudice and Emma by Jane Austen, Wuthering Heights by
        # Emily Brontë, and The Voyage Out by Virginia Woolf.
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/3/7/4/3/37431/37431.txt', "Jane Austen: Pride and Prejudice"),
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/7/6/768/768.txt', "Emily Brontë: Wuthering Heights"),
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/4/144/144.txt', "Virginia Woolf: The Voyage Out"),
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/5/158/158.txt', "Jane Austen: Emma"),
    ],
}

# Downloads every text in the library and builds the character vocabulary.
textlib = TextLibrary(libdesc["lib"])


## 2. Use tf.data for texts

In [None]:
# Training hyper-parameters (identical on CPU, GPU and TPU).
SEQUENCE_LEN = 60      # characters per training sequence
BATCH_SIZE = 256       # sequences per batch
LSTM_UNITS = 768
EMBEDDING_DIM = 120
LSTM_LAYERS = 4
NUM_BATCHES = 30       # number of pre-generated random batches in the dataset

# Only consulted on the TPU path: when True the single-layer unrolled
# build_simple_model is used instead of build_model. Defined unconditionally
# so no later cell depends on which branch ran.
use_simple_model_for_tpu = True

In [None]:
# Pre-generate NUM_BATCHES random batches. Each entry of dx/dy is one whole
# batch: BATCH_SIZE input sequences and their one-character-shifted targets.
batches = [textlib.get_random_sample_batch(BATCH_SIZE, SEQUENCE_LEN)
           for _ in range(NUM_BATCHES)]
dx = [batch_x for batch_x, _ in batches]
dy = [batch_y for _, batch_y in batches]

In [None]:
data_xy = (dx, dy)  # (inputs, targets) pair, sliced along the batch axis below


In [None]:
# One dataset element per pre-generated batch: (inputs, targets).
textlib_dataset = tf.data.Dataset.from_tensor_slices(data_xy)

In [None]:
# Shuffle the pre-built batches, then show the dataset spec via take(1).
shuffle_buffer = 10000
dataset = textlib_dataset.shuffle(buffer_size=shuffle_buffer)
dataset.take(1)

In [None]:
def build_model(vocab_size, embedding_dim, lstm_units, lstm_layers, batch_size):
  """Build the stateful character model: Embedding -> `lstm_layers` LSTMs -> Dense logits.

  The batch size is fixed via batch_input_shape because the LSTMs are
  stateful; the sequence length is left open (None). The final Dense layer
  returns unnormalised logits over `vocab_size` characters.
  """
  layers = [tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                      batch_input_shape=[batch_size, None])]
  for _ in range(lstm_layers):
    layers.append(tf.keras.layers.LSTM(lstm_units,
                                       return_sequences=True,
                                       stateful=True,
                                       recurrent_initializer='glorot_uniform'))
  layers.append(tf.keras.layers.Dense(vocab_size))
  return tf.keras.Sequential(layers)

def build_simple_model(vocab_size, embedding_dim, lstm_units, lstm_layers, batch_size,
                       sequence_len=None):
  """Build a single-layer, unrolled stateful LSTM model (TPU experiment).

  Unlike build_model, both batch size and sequence length are fixed, which
  `unroll=True` requires. `lstm_layers` is accepted for signature parity
  with build_model but ignored: exactly one LSTM layer is built.

  Args:
    sequence_len: fixed input length; defaults to the global SEQUENCE_LEN.

  Note: this is deliberately NOT a @tf.function. Building Keras layers
  creates variables and returns a Python model object, neither of which
  belongs inside a traced graph function.
  """
  if sequence_len is None:
    sequence_len = SEQUENCE_LEN
  model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, embedding_dim,
                              batch_input_shape=[batch_size, sequence_len]),
    tf.keras.layers.LSTM(lstm_units,
                         return_sequences=True,
                         stateful=True,
                         recurrent_initializer='glorot_uniform',
                         unroll=True),
    tf.keras.layers.Dense(vocab_size)
  ])
  return model



In [None]:

# Show the TPU endpoint that the next cell connects to (TPU runtime only).
if use_tpu:
    print(TPU_ADDRESS)

In [None]:
# Shared hyper-parameters for whichever builder is used below.
model_kwargs = dict(
    vocab_size=len(textlib.i2c),
    embedding_dim=EMBEDDING_DIM,
    lstm_units=LSTM_UNITS,
    lstm_layers=LSTM_LAYERS,
    batch_size=BATCH_SIZE)

if use_tpu is True:
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=TPU_ADDRESS)
    tf.config.experimental_connect_to_host(cluster_resolver.master())
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)  # <-- this currently fails with colab/TPU
    tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)

    builder = build_simple_model if use_simple_model_for_tpu is True else build_model
    # Variables must be created inside the strategy scope to live on the TPU.
    with tpu_strategy.scope():
        model = builder(**model_kwargs)
else:
    model = build_model(**model_kwargs)

In [None]:
# Smoke test: run the untrained model on one batch and check the output shape.
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

In [None]:
model.summary()  # layer overview of the training model

In [None]:
dataset.take(1)  # display the element spec of the shuffled dataset

In [None]:
# Draw one character id per position of the first sequence from the untrained logits.
sampled_indices = tf.squeeze(
    tf.random.categorical(example_batch_predictions[0], num_samples=1),
    axis=-1,
).numpy()

In [None]:
sampled_indices  # raw ids sampled from the untrained model

In [None]:
textlib.decode(sampled_indices)  # same ids as text (expected to be gibberish before training)

In [None]:
def loss(labels, logits):
    """Per-position sparse categorical cross-entropy on raw (un-softmaxed) logits."""
    return tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)


example_batch_loss = loss(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("scalar_loss:      ", example_batch_loss.numpy().mean())

In [None]:
def scalar_loss(labels, logits):
    """Mean of loss() over all positions, reported as an extra training metric."""
    return tf.reduce_mean(loss(labels, logits))


# adam_clipped = tf.keras.optimizers.Adam(lr=0.003, clipvalue=1.0)
adam_clipped = tf.keras.optimizers.Adam(clipvalue=0.5)  # gradient value clipping at 0.5
model.compile(optimizer=adam_clipped, loss=loss, metrics=[scalar_loss])

In [None]:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='epoch', histogram_freq=1)

In [None]:
# Inline TensorBoard over every run written below ./logs by the TensorBoard callback.
%tensorboard --logdir logs

In [None]:
EPOCHS = 100  # each epoch is one pass over the NUM_BATCHES pre-generated batches

In [None]:
history = model.fit(
    dataset,
    epochs=EPOCHS,
    callbacks=[checkpoint_callback, tensorboard_callback],
)

In [None]:
# Generate text with the trained weights

In [None]:
tf.train.latest_checkpoint(checkpoint_dir)  # path of the newest saved weights (None if none)

In [None]:
# Rebuild the architecture with batch size 1 for generation and load the
# newest training checkpoint into it.
latest_weights = tf.train.latest_checkpoint(checkpoint_dir)
if latest_weights is None:
    raise FileNotFoundError(
        f"No training checkpoint found in {checkpoint_dir!r}; run the training cell first.")

gen_model = build_model(vocab_size=len(textlib.i2c),
                        embedding_dim=EMBEDDING_DIM,
                        lstm_units=LSTM_UNITS,
                        lstm_layers=LSTM_LAYERS,
                        batch_size=1)
gen_model.load_weights(latest_weights)

gen_model.build(tf.TensorShape([1, None]))  # batch of 1, any sequence length

In [None]:
gen_model.summary()  # same layers as the training model, batch size 1

In [None]:
def generate_text(model, start_string, temp=0.6, num_generate=1000):
  """Sample `num_generate` characters from `model`, seeded with `start_string`.

  Args:
    model: stateful character model built with batch size 1 (e.g. gen_model);
      its state is reset before generation.
    start_string: non-empty prompt; every character must be in textlib.c2i.
    temp: sampling temperature (> 0). Low temperatures give more predictable
      text, higher temperatures more surprising text.
    num_generate: number of characters to generate.

  Returns:
    (start_string + generated text, list of generated character ids)

  Raises:
    ValueError: if temp <= 0 or start_string is empty.
    KeyError: if start_string contains a character outside the vocabulary.
  """
  if temp <= 0:
    raise ValueError(f"temp must be > 0, got {temp}")
  if not start_string:
    raise ValueError("start_string must not be empty")

  # Vectorize the prompt as a batch of one sequence.
  input_eval = tf.expand_dims([textlib.c2i[s] for s in start_string], 0)

  text_generated = []
  ids = []

  model.reset_states()
  for _ in range(num_generate):
    predictions = model(input_eval)
    # Only the logits after the last input character determine the next
    # character; keep them as a (1, vocab) row for tf.random.categorical.
    last_logits = predictions[0, -1:] / temp
    predicted_id = tf.random.categorical(last_logits, num_samples=1)[0, 0].numpy()
    ids.append(predicted_id)

    # Feed the sampled character back in; the stateful LSTM keeps the context.
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(textlib.i2c[predicted_id])

  return (start_string + ''.join(text_generated), ids)

In [None]:
# tx: prompt + generated text; id: generated character ids (decoded further below).
tx, id = generate_text(gen_model,
                       start_string="With the clarity of thought of an artificial life form, the discussion went on:",
                       temp=0.8)

In [None]:
def detectPlagiarism(tx, textlibrary, minQuoteLength=10):
    """Show `tx` with passages copied verbatim from `textlibrary` highlighted.

    Thin wrapper around ``textlibrary.source_highlight``; only quotes of at
    least `minQuoteLength` characters are marked. Returns None.
    """
    textlibrary.source_highlight(tx, minQuoteLength)

In [None]:
# Print only the generated part, splitting on the corpus' CRLF line endings.
txt = textlib.decode(id)
txti = txt.split('\r\n')
print(*txti, sep='\n')

In [None]:
detectPlagiarism(tx, textlib, minQuoteLength=10)  # mark quotes of >= 10 chars from the training texts

## References:
* <https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/text/text_generation.ipynb>
* <https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/shakespeare_with_tpu_and_keras.ipynb>

## 6. A dialog with the trained model [not ported yet]

In [None]:
# Dialog with the recurrent net trained above, ported from the TF1
# session-based version to the Keras batch-size-1 generator model.


def doDialog(model=None, temperature=0.6, endPrompt='.', maxEndPrompts=4,
             maxAnswerSize=2048, minAnswerSize=64):
    """Chat with the trained character model on the console.

    Each user line is fed through the stateful model. Characters are then
    sampled until `maxEndPrompts` occurrences of `endPrompt` have been
    generated or `maxAnswerSize` characters are reached, but never fewer
    than `minAnswerSize` characters. The LSTM state persists between turns,
    so earlier turns act as conversation context.

    Commands: 'bye' ends the dialog, 'reset' clears the conversation state,
    'temperature=<float>' changes the sampling temperature.

    Args:
        model: stateful model built with batch size 1; defaults to the global
            gen_model (the training `model` has a fixed batch of BATCH_SIZE).
        temperature: initial sampling temperature, 0.1 (frozen character)
            to 1.3 (creative/chaotic character).
        endPrompt: character that marks the end of a sentence in answers.
        maxEndPrompts: number of endPrompt characters that finish an answer.
        maxAnswerSize: maximum answer length in characters.
        minAnswerSize: minimum answer length in characters.
    """
    if model is None:
        model = gen_model

    print("Please enter some dialog.")
    print("The net will answer according to your input.")
    print("'bye' for end,")
    print("'reset' to reset the conversation context,")
    print("'temperature=<float>' [0.1(frozen)-1.0(creative)]")
    print("    to change character of the dialog.")
    print("    Current temperature={}.".format(temperature))
    print()

    model.reset_states()
    while True:
        prompt = input("> ")
        if prompt == 'bye':
            print("Good bye!")
            return
        if prompt == 'reset':
            model.reset_states()
            print("(conversation context reset)")
            continue
        if prompt.startswith("temperature="):
            try:
                t = float(prompt[len("temperature="):])
            except ValueError:
                t = None
            if t is not None and 0.05 < t < 1.4:
                temperature = t
                print("(generator temperature now {})".format(t))
                print()
            else:
                print("Invalid temperature-value ignored! [0.1-1.0]")
            continue

        # Characters outside the training vocabulary cannot be embedded; drop them.
        prompt_ids = [textlib.c2i[c] for c in prompt if c in textlib.c2i]
        if not prompt_ids:
            print("(no known characters in input, ignored)")
            continue
        predictions = model(tf.expand_dims(prompt_ids, 0))

        answer = []
        numEndPrompts = 0
        while ((len(answer) < maxAnswerSize and numEndPrompts < maxEndPrompts)
               or len(answer) < minAnswerSize):
            # Sample the next character from the logits after the last input.
            logits = predictions[0, -1:] / temperature
            ind = int(tf.random.categorical(logits, num_samples=1)[0, 0])
            nc = textlib.i2c[ind]
            if nc == endPrompt:
                numEndPrompts += 1
            answer.append(nc)
            predictions = model(tf.expand_dims([ind], 0))
        xso = ''.join(answer)
        print(xso.replace("\\n", "\n"))
        textlib.source_highlight(xso, 13)

In [None]:
# Start the interactive console dialog (waits for input; type 'bye' to finish).
doDialog()