<a href="https://colab.research.google.com/github/domschl/transformer-poet/blob/main/transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that supports writing more environment-independent code. It will access your Google Drive when used with Google Colab.

In [2]:
!pip install -U ml-indie-tools

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ml-indie-tools
  Downloading ml_indie_tools-0.3.8-py3-none-any.whl (36 kB)
Installing collected packages: ml-indie-tools
Successfully installed ml-indie-tools-0.3.8


In [3]:
import logging
import os
import sys
import copy
import json
import time
import datetime
import random

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers, regularizers

import tensorflow_datasets as tfds

In [4]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.keras_custom_layers import MultiHeadSelfAttention, PositionalEncoding

Using TF-Keras version: 2.9.0


## Preliminary

A TensorFlow deep multi-head attention model for text generation.

This code can use a CPU, GPU, or TPU when running on Google Colab.
Select the corresponding runtime (menu: **`Runtime / Change runtime type`**).

## 0. Environment

In [5]:
cached_batch_data = None   # Re-running this cell forces the time-consuming training data to be regenerated, even if already cached.

ml_env = MLEnv(platform='tf', accelerator='fastest')
ml_env.describe()

'OS: Linux, Python: 3.7.15, Colab Jupyter Notebook Tensorflow: 2.9.2, TPU: TPU, 8 nodes v2 (8GB)'

In [6]:
use_eager=tf.executing_eagerly()
if ml_env.is_tpu is True:
    tpu_strategy = ml_env.tpu_strategy
    tpu_is_init=True
    if use_eager is True:
        tf.config.run_functions_eagerly(False)
    use_eager=False

In [7]:
project_name='women_writers'
model_name='mhsa_v1_tf'

# NOTICE: This will request access to Google Drive if running on Google Colab. Google Drive is used to store snapshots
# and training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools
#
# Note: you need to allow popups in your browser for Colab, otherwise you won't see the Google Drive login box, and drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path)")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Mounted at /content/drive
Root path (all projects) : /content/drive/My Drive (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : /content/drive/My Drive/Colab Notebooks/women_writers (Changes to the file system happen only below this project path
Model path (snapshots)   : /content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf (Model weights and snapshots are stored here)
Data path (training data): /content/drive/My Drive/Colab Notebooks/women_writers/data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


## 1. Text library

The `Text_Dataset` and `Gutenberg_Dataset` classes are libraries for training,
encoding, batch generation, and formatted source display. They read books
from Project Gutenberg and support the creation of training batches.
The output functions support highlighting, which makes it possible to compare
generated texts with the actual sources and to identify identical (memorized)
passages.
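To make the idea of spotting memorized passages concrete, here is a minimal sketch, assuming plain-string inputs, of one way to find verbatim quotes of at least `min_quote_size` characters in generated text. `find_memorized_spans` is a hypothetical helper written for illustration; it is not how ml-indie-tools implements `source_highlight`.

```python
def find_memorized_spans(generated, source, min_quote_size=10):
    """Hypothetical helper: return (start, end) pairs of maximal substrings of
    `generated` that occur verbatim in `source` and have at least
    min_quote_size characters."""
    spans = []
    n = len(generated)
    i = 0
    while i <= n - min_quote_size:
        # Only start a quote if the minimum-length window occurs in the source
        if generated[i:i + min_quote_size] in source:
            end = i + min_quote_size
            # Greedily extend the quote while it still occurs verbatim
            while end < n and generated[i:end + 1] in source:
                end += 1
            spans.append((i, end))
            i = end  # continue scanning after the quote
        else:
            i += 1
    return spans


source = "It is a truth universally acknowledged, that a single man"
generated = "She said: a truth universally acknowledged, indeed."
for start, end in find_memorized_spans(generated, source):
    print(repr(generated[start:end]))  # ' a truth universally acknowledged, '
```

The greedy scan is roughly quadratic in the text length and checks a single source string, which is acceptable for short samples; comparing against many books would call for an index such as a suffix array.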

In [8]:
use_dark_mode=False # Set to True for a dark background, False for a white one. The HTML text comparison colors the background to identify different sources, and those colors depend on the theme.

In [9]:
logging.basicConfig(level=logging.INFO)
cache_dir = os.path.join(data_path, 'gutenberg_cache')
gd = Gutenberg_Dataset(cache_dir=cache_dir)

In [10]:
# sample searches
search_spec= {"author": ["Emily Brontë", "Jane Austen", "Virginia Woolf"], "language": ["english"]}

book_list=gd.search(search_spec)
book_cnt = len(book_list)
print(f"{book_cnt} matching books found with search {search_spec}.")
if book_cnt<40:
    # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts, 
    # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
    book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)  
else:
    logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

21 matching books found with search {'author': ['Emily Brontë', 'Jane Austen', 'Virginia Woolf'], 'language': ['english']}.


In [11]:
for i in range(len(book_list)):
    print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

0: The Common Reader - Virginia Woolf, 64457
1: Mr. Bennett and Mrs. Brown - Virginia Woolf, 63022
2: The Younger Sister, Volumes 1-3 - Catherine Anne Austen Hubback and Jane Austen, 54066
3: The Younger Sister, Vol. 3 - Catherine Anne Austen Hubback and Jane Austen, 54012
4: The Younger Sister, Vol. 2 - Catherine Anne Austen Hubback and Jane Austen, 54011
5: The Younger Sister, Vol. 1 - Catherine Anne Austen Hubback and Jane Austen, 54010
6: Pride and Prejudice - Jane Austen, 42671
7: The Letters of Jane Austen - Jane Austen, 42078
8: The Complete Project Gutenberg Works of Jane Austen - Jane Austen, 31100
9: Jacob's Room - Virginia Woolf, 5670
10: Pride and Prejudice - Jane Austen, 1342
11: Night and Day - Virginia Woolf, 1245
12: Love And Friendship And Other Early Works - Jane Austen, 1212
13: Lady Susan - Jane Austen, 946
14: Wuthering Heights - Emily Brontë, 768
15: Sense and Sensibility - Jane Austen, 161
16: Emma - Jane Austen, 158
17: The Voyage Out - Virginia Woolf, 144
18: M

In [12]:
select = ("Bennett", "1342", "5670", "1245", "161", "141", "121", "Susan", "Wuthering", "Emma", "Voyage")  # Unique single words from a book title, or ebook_ids, that select the corresponding books
sub_book_list = [book_list[i] for i in range(len(book_list)) if not set([book_list[i]['ebook_id']]+book_list[i]['title'].split(' ')).isdisjoint(set(select))]

print("Using:")
for i in range(len(sub_book_list)):
    print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

textlib_dataset = None  # Forces re-caching
td = Text_Dataset(sub_book_list)
td.init_tokenizer(tokenizer='ngram', max_ngrams=6, max_tokens=5000)


Using:
1: Mr. Bennett and Mrs. Brown - Virginia Woolf
2: Jacob's Room - Virginia Woolf
3: Pride and Prejudice - Jane Austen
4: Night and Day - Virginia Woolf
5: Lady Susan - Jane Austen
6: Wuthering Heights - Emily Brontë
7: Sense and Sensibility - Jane Austen
8: Emma - Jane Austen
9: The Voyage Out - Virginia Woolf
10: Mansfield Park - Jane Austen
11: Northanger Abbey - Jane Austen
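The selection cell keeps a book when its `ebook_id` or any space-separated word of its title appears in `select`. The sketch below restates that test as a hypothetical function `select_books` on made-up records, assuming (as the `select` tuple suggests) that `ebook_id` is a string. It shows why the ebook id disambiguates the two *Pride and Prejudice* editions, and that title words keep their punctuation.

```python
def select_books(book_list, select):
    """Keep books whose ebook_id, or any space-separated word of the title,
    appears in `select` (the same test as the notebook's list comprehension)."""
    wanted = set(select)
    return [book for book in book_list
            if not wanted.isdisjoint([book['ebook_id']] + book['title'].split(' '))]


# Hypothetical toy records; ebook_id is assumed to be a string
books = [
    {'title': 'Pride and Prejudice', 'ebook_id': '42671'},
    {'title': 'Pride and Prejudice', 'ebook_id': '1342'},
    {'title': 'Mr. Bennett and Mrs. Brown', 'ebook_id': '63022'},
    {'title': 'Emma', 'ebook_id': '158'},
    {'title': 'The Younger Sister, Vol. 1', 'ebook_id': '54010'},
]
chosen = select_books(books, ("Bennett", "1342", "Emma"))
print([b['ebook_id'] for b in chosen])  # ['1342', '63022', '158']

# "Sister," keeps its comma after split(' '), so "Sister" does not match
print(select_books(books, ("Sister",)))  # []
```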


In [13]:
SEQUENCE_LEN = 80
SUB_PROBABILITY = 0.15  # like BERT

td.init_getitem(sample_type='encoded', sample_length=SEQUENCE_LEN, content_stepping=1)

num_records = len(td)

print(f"{num_records} records")

1854539 records


In [14]:
def get_sample_batch(td, batch_size, length, SUB_probability=0.15):
    # Returns (X, y), two int32 arrays of shape (batch_size, sample_length).
    # y holds the original token sequences; X is a copy in which about
    # SUB_probability of the positions are replaced by the substitution token
    # (BERT-style masking). `length` is unused: the sample length is set by
    # td.init_getitem(sample_length=...).
    if td.tokenizer_type=='char':
        sub_token = td.c2i['␚']
    elif td.tokenizer_type=='word':
        sub_token = td.w2i['<subst>']
    elif td.tokenizer_type=='ngram':
        sub_token = td.t2i['<subst>']
    else:
        raise ValueError(f"Unexpected tokenizer_type {td.tokenizer_type}")
    smpX = []
    smpy = []
    for i in range(batch_size):
        Xi = list(td.get_random_item())  # work on a copy of the sample
        yi = Xi.copy()
        l=int(len(Xi)*SUB_probability)
        for li in range(l):
            # positions are drawn with replacement, so a position can be hit twice
            pos=random.randint(0,len(Xi)-1)
            Xi[pos]=sub_token
        smpX.append(Xi)
        smpy.append(yi)
    # Converting the lists once always yields 2-D arrays, also for batch_size == 1
    return np.array(smpX, dtype=np.int32), np.array(smpy, dtype=np.int32)
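The notebook picks substitution positions with `random.randint`, i.e. with replacement, so a position can be chosen twice and on average slightly fewer than 15% of the tokens end up substituted. A minimal NumPy sketch of the same masking with distinct positions follows; `mask_tokens` is a hypothetical name, and the sketch (like the notebook) does not apply BERT's additional 80/10/10 split between mask, random, and unchanged tokens.

```python
import numpy as np


def mask_tokens(tokens, sub_token, sub_probability=0.15, rng=None):
    """Return (x, y): y is the unmodified int32 sequence, x is a copy in which
    int(len(tokens) * sub_probability) distinct positions hold sub_token."""
    rng = np.random.default_rng() if rng is None else rng
    y = np.asarray(tokens, dtype=np.int32)
    x = y.copy()
    n_sub = int(len(y) * sub_probability)
    # Draw positions without replacement, so exactly n_sub tokens are substituted
    positions = rng.choice(len(y), size=n_sub, replace=False)
    x[positions] = sub_token
    return x, y


# Tokens 1..80 never equal the substitution token 0, so every hit is visible
x, y = mask_tokens(np.arange(1, 81), sub_token=0, rng=np.random.default_rng(42))
print(int((x != y).sum()), "of", len(y), "positions substituted")  # 12 of 80
```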

In [15]:
test_x, test_y = get_sample_batch(td, 2, 40, SUB_probability=SUB_PROBABILITY)
for i in range(len(test_x)):
    xi=[int(x) for x in test_x[i]]
    print(f"[{i}](l={len(xi)}): X=>{td.decode(xi)}<, y=>{td.decode(test_y[i])}<")

[0](l=80): X=>both
lighthouse and bird; he was steadfast and brilli<subst><subst>at the same
time he was whirled, with all other things, senseless <subst>st t<subst>gl<subst><subst>got up, left his <subst><subst>e of <subst>lver, and pressed on, with the
wind against him. The <subst>age of the lighthouse and t<subst>orm full of
birds <, y=>both
lighthouse and bird; he was steadfast and brilliant; and at the same
time he was whirled, with all other things, senseless against the
glass. He got up, left his tribute of silver, and pressed on, with the
wind against him. The image of the lighthouse and the storm full of
birds <
[1](l=80): X=>uncle and aunt<subst>  <subst>he was in town; and why not to me<subst>If he fears me, why come
      hi<subst>? If he<subst>longer cares for <subst>why silent? T<subst>ing,
     <subst>as<subst>man! I will think no more about him.”

  <subst>Her r<subst>olution was for a short time involuntarily <subst>pt by th<, y=>uncle and aunt,
      when he was in to

In [16]:
test_x.shape, test_y.shape

((2, 80), (2, 80))

## 2. Use tf.data for texts

In [17]:
def expand_name_template(template, params):
    exp=copy.copy(template)
    for key in params:
        src="{"+key+"}"
        dst=f"{params[key]}"
        exp=exp.replace(src,dst).replace('[','(').replace(']',')')
    return exp

def save_model_metadata(epoch, suffix='std'):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    params['current_epoch'] = epoch
    try:
        with open(meta_file, 'w') as f:
            f.write(json.dumps(params))
    except Exception as e:
        print(f"Failed to store model metadata at {model_path}: {e}")
        return False
    return True

def read_model_metadata(suffix="std"):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    try:
        with open(meta_file, 'r') as f:
            meta = json.load(f)
    except Exception as e:
        print(f"Cannot access project meta-data at {meta_file}: {e}, starting anew.")
        return None
    return meta

def is_metadata_compatible(params, meta):
    is_valid=True
    keys=set(list(params.keys())+list(meta.keys()))
    for key in keys:
        if key in updatable_keys:
            continue
        if key not in meta:
            print(f"Key {key} not available in last checkpoint model_meta, params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif key not in params:
            print(f"Key {key} not available in params, last checkpoint model_meta[{key}]: {meta[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif meta[key]!=params[key]:
            print(f"Last checkpoint model_meta[{key}]: {meta[key]} != params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
    if is_valid is False:
        print("Aborting import.")
        return False
    return True
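As an illustration, the following reproduces `expand_name_template` from the cell above so it runs standalone, and applies it to the model parameters used below (abridged to the keys the template references). It shows the checkpoint-directory suffix the default settings produce; note that it contains spaces and parentheses, so the path needs quoting when used in a shell.

```python
import copy


def expand_name_template(template, params):
    # Same logic as the notebook: substitute each {key}, then turn [ ] into ( )
    exp = copy.copy(template)
    for key in params:
        exp = exp.replace("{" + key + "}", f"{params[key]}").replace('[', '(').replace(']', ')')
    return exp


params = {'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}', 'mhsa_layers': 6,
          'heads': [6] * 6, 'units': [512] * 6, 'vocab_size': 5000}
suffix = expand_name_template(params['name'], params)
print(suffix)  # 6x(6, 6, 6, 6, 6, 6)x(512, 512, 512, 512, 512, 512)x5000
```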

In [47]:
vocabulary_size = td.get_unique_token_count()  # vocabulary-size

lyrs = 6

params = { # Multi-head self-attention
    'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}',

    'mhsa_layers': lyrs, 
    'heads': [6]*lyrs,
    'units': [512]*lyrs,  # 0 inserts an LSTM for memory-states :-)
    'norm': 'softmax', # this is for within each head
    'mh_normalize': True,  # use layer-norm after concatenation (or addition) of the heads
    'l2_regularizer': 1e-9,
    'dropout': 0.0,       # no dropout: 0.0
    'join_heads_by_add': True,  # strategy for joining the heads: False: concat (as in "Attention Is All You Need"), True: relu+add of all the heads
    'vocab_size': vocabulary_size,
    'sequence_len': SEQUENCE_LEN,
    'embedding_size': 128,

    'batch_size': 256,
    'learning_rate': 0.0002,
    'clipvalue': None,
    'sample_every_n_epochs': 100,
}

if len(params['heads'])!=params['mhsa_layers'] or len(params['units'])!=params['mhsa_layers']:
    print("ERROR: length of 'heads' and 'units' must be equal to mhsa_layers!")
    
if ml_env.is_tpu is True:
    lr = params['learning_rate']*1.0
else:
    lr = params['learning_rate']

model_suffix = expand_name_template(params['name'], params)
# Put 'important' params in checkpoint-pathname to separate model-data:
checkpoint_dir = os.path.join(model_path, f"training_checkpoints_{model_suffix}")
if os.path.exists(checkpoint_dir) is False:
    os.makedirs(checkpoint_dir)

# When comparing if training-data is compatible with new params set, 
# the following keys are updatable, they can be changed while continuing
# to use existing checkpoints and continue training with those values
# changed:
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'dropout', 
             'sample_every_n_epochs']

# These values are taken from the saved checkpoint:
keep_keys=['current_epoch']

continue_last = True
if continue_last is False:
    print("NOT continuing based on existing training! New start.")

meta = read_model_metadata(suffix=model_suffix)
if meta is not None and is_metadata_compatible(params, meta) is True and continue_last is True:
    for key in keep_keys:
        if key in meta:
            params[key]=meta[key]
    if 'current_epoch' in params:
        print(f"Continuing last session from epoch {params['current_epoch']}")
    else:
        print("No previous epoch data, starting new model")
else:
    print("Starting new model")

print(params)

Continuing last session from epoch 6873
{'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}', 'mhsa_layers': 6, 'heads': [6, 6, 6, 6, 6, 6], 'units': [512, 512, 512, 512, 512, 512], 'norm': 'softmax', 'mh_normalize': True, 'l2_regularizer': 1e-09, 'dropout': 0.0, 'join_heads_by_add': True, 'vocab_size': 5000, 'sequence_len': 80, 'embedding_size': 128, 'batch_size': 256, 'learning_rate': 0.0002, 'clipvalue': None, 'sample_every_n_epochs': 100, 'current_epoch': 6873}


In [19]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 7244
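A quick check of the batch arithmetic, using the numbers printed above. Because `make_tf_dataset` draws random samples, `num_batches` is simply how many pre-built batches are created, not a partition of the records. The TPU branch of the training cell divides by `batch_size` a second time, which yields 28 steps per epoch, consistent with the `5/28` progress bar in the training log.

```python
num_records = 1854539  # len(td) reported above
batch_size = 256       # params['batch_size']

num_batches = num_records // batch_size      # batches built by make_tf_dataset
steps_per_epoch = num_batches // batch_size  # what the TPU branch of model.fit uses
print(num_batches, steps_per_epoch)  # 7244 28
```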


In [20]:
# @tf.function   (only slows things down [considerably!])
def make_tf_dataset(num, random_index=False, SUB_probability=0.0):
    dx=[]
    dy=[]
    num_batches_active = num
    for i in range(num_batches_active):
        x,y=get_sample_batch(td, params['batch_size'], params['sequence_len'], SUB_probability=SUB_probability)
        if i<1:
            print(f"[{num} x]: {x.shape} -> {y.shape}")
        dx.append(x)
        dy.append(y)
    dx=np.array(dx)
    dy=np.array(dy)
    print(f"dx.shape={dx.shape}, dy.shape={dy.shape}")
    data_xy = (dx, dy)
    tf_dataset=tf.data.Dataset.from_tensor_slices(data_xy)
    return tf_dataset

In [21]:
MAX_NUM_BATCHES = 50000

if num_batches>MAX_NUM_BATCHES:
    restricted_batches=MAX_NUM_BATCHES
    print(f"Restricting {num_batches} to max of {restricted_batches}")
else:
    restricted_batches=num_batches
    print(f"{restricted_batches} batches")
if cached_batch_data == restricted_batches and textlib_dataset is not None:
    print("Reusing cached training-data")
else:
    print("Creating dataset, this is slow. Be patient...")
    textlib_dataset = make_tf_dataset(restricted_batches, SUB_probability=SUB_PROBABILITY)
    cached_batch_data = restricted_batches
    print("Dataset done and cached.")

7244 batches
Creating dataset, this is slow. Be patient...
[7244 x]: (256, 80) -> (256, 80)
dx.shape=(7244, 256, 80), dy.shape=(7244, 256, 80)
Dataset done and cached.
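The pre-built batches are held completely in host memory. A rough size estimate for `dx` and `dy`, assuming int32 tokens as produced by `get_sample_batch` and ignoring any copies TensorFlow makes when building the dataset:

```python
num_batches, batch_size, seq_len = 7244, 256, 80  # values from the output above
bytes_per_token = 4                               # np.int32

tokens_per_array = num_batches * batch_size * seq_len
bytes_per_array = tokens_per_array * bytes_per_token
print(f"{tokens_per_array:,} tokens, {bytes_per_array / 1e6:.0f} MB per array, "
      f"{2 * bytes_per_array / 1e9:.2f} GB for dx and dy together")
# 148,357,120 tokens, 593 MB per array, 1.19 GB for dx and dy together
```

Raising `MAX_NUM_BATCHES` scales this linearly, which is worth keeping in mind on a Colab instance with limited RAM.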


In [22]:
shuffle_buffer=10000
if ml_env.is_tpu is True:
    dataset=textlib_dataset.shuffle(shuffle_buffer).repeat()  # Otherwise TPU may run dry
else:
    dataset=textlib_dataset.shuffle(shuffle_buffer)  
dataset.take(1)

<TakeDataset element_spec=(TensorSpec(shape=(256, 80), dtype=tf.int32, name=None), TensorSpec(shape=(256, 80), dtype=tf.int32, name=None))>
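`Dataset.shuffle(buffer_size)` fills a buffer with `buffer_size` elements and samples uniformly from it, refilling each emitted slot with the next input element. The toy pure-Python model below (a simplification, not the tf.data implementation; `buffered_shuffle` is a hypothetical name) shows why `shuffle_buffer=10000` amounts to a full shuffle here: the dataset has only 7244 elements. Since each element is a pre-built batch, only the batch order changes between epochs; the samples and substitution masks inside each batch stay fixed.

```python
import random


def buffered_shuffle(elements, buffer_size, seed=0):
    """Toy model of a shuffle buffer (buffer_size >= 1): fill the buffer, then
    repeatedly emit a uniformly chosen buffered element and refill its slot."""
    rng = random.Random(seed)
    it = iter(elements)
    buffer = []
    for x in it:  # initial fill
        buffer.append(x)
        if len(buffer) >= buffer_size:
            break
    out = []
    for x in it:  # steady state: emit one element, take in the next one
        i = rng.randrange(len(buffer))
        out.append(buffer[i])
        buffer[i] = x
    rng.shuffle(buffer)  # input exhausted: drain the rest in random order
    out.extend(buffer)
    return out


full = buffered_shuffle(range(7244), buffer_size=10000)  # buffer >= dataset size
small = buffered_shuffle(range(20), buffer_size=2)       # element k appears at position >= k - 1
print(sorted(full) == list(range(7244)), small)
```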

In [23]:
if ml_env.is_tpu is False:
    validation_dataset = make_tf_dataset(10, random_index=True, SUB_probability=SUB_PROBABILITY)

In [24]:
def model_mhsa(inputs, params):
    dense = layers.Dense(params['vocab_size'], kernel_regularizer=regularizers.l2(params['l2_regularizer']))  # no softmax here: the layer outputs logits so that temperature can be applied during generation; hence from_logits=True in the loss
    fl = layers.Flatten()
    dr = layers.Dropout(params['dropout'])
    pe = PositionalEncoding(amplitude=0.3)
    rs_up = layers.Reshape(target_shape=(SEQUENCE_LEN,vocabulary_size))
    if 0 in params['units']:
        lstm1 = layers.LSTM(units=vocabulary_size, return_sequences=True)
    if vocabulary_size>=300:
        emb=layers.Embedding(vocabulary_size,params['embedding_size'],input_length=params['sequence_len'])
    rs_down = layers.Reshape(target_shape=(SEQUENCE_LEN,vocabulary_size))
    mhsa=[]
    residuals=[]

    for i in range(params['mhsa_layers']):
        if params['units'][i]==0:
            mhsa.append(None)
            residuals.append(i)
        else:
            mhsa.append(MultiHeadSelfAttention(params['heads'][i], units=params['units'][i], norm=params['norm'], mh_normalize=params['mh_normalize'], join_heads_by_add=params['join_heads_by_add']))
    xint = tf.cast(inputs,dtype=tf.int32)
    if vocabulary_size<300:
        x = tf.one_hot(xint, params['vocab_size'], axis=-1)
    else:
        x = emb(xint)
    x = pe(x)
    for i in range(len(mhsa)):
        if i in residuals:
            x = rs_down(lstm1(rs_up(x)))+x
            print(f"Residual at layer {i} added.")
        else:
            x = mhsa[i](x)
        # x = mhsa[i](x,x)
    if params['dropout']>0.0:
        x = dr(x)
    # x = dense(fl(x))
    x = dense(x)
    return x 
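For orientation, here is a generic NumPy sketch of multi-head scaled dot-product self-attention (as in "Attention Is All You Need"), including the two head-joining strategies selected by `join_heads_by_add`: concatenation, or ReLU followed by a sum over heads, as the params comment describes. It is an illustration under stated assumptions, not ml-indie-tools' `MultiHeadSelfAttention`, which additionally has `units`, `norm`, and `mh_normalize` options and maps back to the embedding size. The weight shapes and function names are hypothetical.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerically stable
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(x, wq, wk, wv, join_heads_by_add=False):
    """x: (seq, d_model); wq, wk, wv: (heads, d_model, d_head) projection weights.
    Returns (seq, heads * d_head) when concatenating, (seq, d_head) when adding."""
    q = np.einsum('sd,hde->hse', x, wq)  # (heads, seq, d_head)
    k = np.einsum('sd,hde->hse', x, wk)
    v = np.einsum('sd,hde->hse', x, wv)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])  # (heads, seq, seq)
    weights = softmax(scores, axis=-1)  # each query's weights sum to 1
    heads = weights @ v  # (heads, seq, d_head)
    if join_heads_by_add:
        return np.maximum(heads, 0.0).sum(axis=0)  # ReLU, then add the heads
    return np.concatenate(list(heads), axis=-1)  # concatenate the heads


rng = np.random.default_rng(0)
x = rng.normal(size=(80, 128))  # sequence_len x embedding_size, as in the notebook
wq, wk, wv = (rng.normal(size=(6, 128, 16)) / np.sqrt(128) for _ in range(3))
out_concat = multi_head_self_attention(x, wq, wk, wv)
out_add = multi_head_self_attention(x, wq, wk, wv, join_heads_by_add=True)
print(out_concat.shape, out_add.shape)  # (80, 96) (80, 16)

# Without positional information, permuting the tokens only permutes the output
perm = rng.permutation(80)
print(np.allclose(multi_head_self_attention(x[perm], wq, wk, wv), out_concat[perm]))  # True
```

The last check illustrates why `PositionalEncoding` is applied before the attention layers: plain self-attention is permutation-equivariant, so without it the model could not tell token order apart.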

In [25]:
def mhsa_generate(model, text, gen_len=64, temperature=0.9, argmax=False, verbose=False):
    # Each step shifts the context left by one token, puts the <subst> (mask) token
    # into the freed last slot, predicts that slot, and then replaces the <subst>
    # token with the predicted token.
    if verbose is True:
        full=text[:-1]
    gen_text=""
    lf=0
    input = np.array(td.encode(text))
    while len(input) < params['sequence_len']:
        input = np.concatenate([td.encode('<pad>'),input])
    for i in range(gen_len):
        input = np.concatenate([input[1:],td.encode('<subst>')])
        if len(input)!=params['sequence_len']:
            print('assertion failure')
            return None
        pred = model(np.expand_dims(input, axis=0))  # add the batch axis: (1, sequence_len)
        pred /= temperature  # temperature scaling of the logits
        pred = tf.keras.layers.Softmax()(pred)
        if tf.executing_eagerly() is True and ml_env.is_tpu is False:
            pred=pred.numpy()
        else:
            pred=tf.keras.backend.eval(pred)  # this is a cheat, it internally uses numpy() too.
        if argmax is True:
            pred=np.argmax(pred[0],axis=1)
        else:
            probs = pred[0][-1]  # distribution for the last (masked) slot
            pred = [np.random.choice(len(probs), p=probs)]
        input = np.concatenate([input[:-1],[pred[-1]]])  # replace <subst> by the prediction
        c = td.decode([pred[-1]])
        if verbose is True:
            print(c, end='')
            if c=='\n':
                lf=0
            else:
                lf += 1
                if (lf>80 and c==' ') or lf>120:
                    print()
                    lf=0
            full+=c
        gen_text+=c
    if verbose is True:
        print()
    return gen_text
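A model-free sketch of the generation loop behind `mhsa_generate`: shift the context, put the mask token into the last slot, predict that slot, and write the sampled token back into it, so the context never accumulates mask tokens. `sample_with_temperature` divides the logits by the temperature before the softmax, so lower temperatures sharpen the distribution. `generate`, `toy_predict`, and the token ids are hypothetical stand-ins for the Keras model and the `td` tokenizer.

```python
import numpy as np


def sample_with_temperature(logits, temperature=0.9, rng=None):
    """Sample a token id from softmax(logits / temperature)."""
    rng = np.random.default_rng() if rng is None else rng
    z = np.asarray(logits, dtype=np.float64) / temperature
    z -= z.max()  # numerical stability; does not change the softmax
    p = np.exp(z)
    p /= p.sum()
    return int(rng.choice(len(p), p=p))


def generate(predict_logits, context, mask_token, gen_len, temperature=0.9, rng=None):
    """predict_logits: hypothetical stand-in for the model, mapping a
    (seq_len,) int array to (seq_len, vocab) logits."""
    context = np.asarray(context).copy()
    generated = []
    for _ in range(gen_len):
        context = np.concatenate([context[1:], [mask_token]])  # shift, mask last slot
        logits = predict_logits(context)[-1]  # logits for the masked slot
        token = sample_with_temperature(logits, temperature, rng)
        context[-1] = token  # the prediction replaces the mask token
        generated.append(token)
    return generated, context


def toy_predict(ctx, vocab=11):
    # Toy "model" (token 10 is the mask): strongly favors previous token + 1 (mod 10)
    logits = np.full((len(ctx), vocab), -10.0)
    logits[-1, (ctx[-2] + 1) % 10] = 10.0
    return logits


tokens, final_context = generate(toy_predict, list(range(10)), mask_token=10,
                                 gen_len=5, temperature=0.1, rng=np.random.default_rng(0))
print(tokens)                  # [0, 1, 2, 3, 4]
print(final_context.tolist())  # [5, 6, 7, 8, 9, 0, 1, 2, 3, 4]
```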


In [26]:
if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        print("Creating TPU-scope model")
        inputs = keras.Input(shape=(params['sequence_len'],))
        outputs = model_mhsa(inputs, params)
        model = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
    print("Creating Default-scope model")
    inputs = keras.Input(shape=(params['sequence_len'],))
    outputs = model_mhsa(inputs, params)
    model_cpu = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
else:
    inputs = keras.Input(shape=(params['sequence_len'],))
    outputs = model_mhsa(inputs, params)
    model = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
    model_cpu = model

Creating TPU-scope model
Creating Default-scope model


In [27]:
def get_newest_checkpoint(checkpoint_dir):
    files = os.listdir(checkpoint_dir)
    paths = [os.path.join(checkpoint_dir, basename) for basename in files]
    return max(paths, key=os.path.getctime)

def import_previous_compatible_checkpoint(model, force_import=False):
    meta = read_model_metadata(suffix=model_suffix)
    if meta is None:
        print("No previous checkpoint found")
        return False
    if is_metadata_compatible(params, meta) is not True and force_import is False:
        print("No useable import found.")
        return False
    try:
        last_checkpoint = get_newest_checkpoint(checkpoint_dir) # Doesn't do anything: tf.train.latest_checkpoint(checkpoint_dir)
    except Exception as e:
        print(f"Cannot determine last checkpoint in {checkpoint_dir}, cannot import due to: {e}")
        return False
    print(f"Last checkpoint: {last_checkpoint}")
    try:
        model.load_weights(last_checkpoint)
    except Exception as e:
        print(f"Failed to import model {last_checkpoint}: {e}")
        return False
    if 'current_epoch' in meta:
        params['current_epoch'] = meta['current_epoch']
    print(f"Successful import of epoch {params['current_epoch']} from {last_checkpoint}, continuing from there...")
    return True
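`get_newest_checkpoint` returns the most recent file of any kind, ranked by `os.path.getctime` (on Linux the inode change time, not the creation time). On a non-TPU run, `ModelCheckpoint` typically writes TensorFlow-format files such as `ckpt_N.index`, `ckpt_N.data-00000-of-00001`, and a `checkpoint` bookkeeping file into the same directory, and such a file name may not be loadable as-is by `load_weights`. A hedged alternative, assuming only the `.h5` files written by `ServiceCallback` should be imported, ranks those by modification time; `get_newest_h5_checkpoint` is a hypothetical helper.

```python
import os
import tempfile
import time


def get_newest_h5_checkpoint(checkpoint_dir):
    """Hypothetical helper: newest *.h5 file in checkpoint_dir by modification time."""
    paths = [os.path.join(checkpoint_dir, name)
             for name in os.listdir(checkpoint_dir) if name.endswith('.h5')]
    if not paths:
        raise FileNotFoundError(f"No .h5 checkpoints in {checkpoint_dir}")
    return max(paths, key=os.path.getmtime)


# Demo in a temporary directory: the non-HDF5 'checkpoint' file is the newest
with tempfile.TemporaryDirectory() as demo_dir:
    now = time.time()
    for offset, name in enumerate(["cp-0199.h5", "cp-0399.h5", "checkpoint"]):
        path = os.path.join(demo_dir, name)
        with open(path, "w") as f:
            f.write("dummy")
        os.utime(path, (now + offset, now + offset))  # later entries get newer mtimes
    newest = os.path.basename(get_newest_h5_checkpoint(demo_dir))
print(newest)  # cp-0399.h5
```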

### Loss function, optimizer, tensorboard output

In [48]:
kscc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

def loss(labels, logits):
    vl=kscc(labels, logits)
    return vl
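With `reduction=NONE`, the loss returns one value per token, shape `(batch_size, sequence_len)`, which Keras then reduces to a scalar during training (by default, the mean). A NumPy sketch of what sparse categorical cross-entropy from logits computes per token, namely `logsumexp(logits) - logits[label]`; `sparse_ce_from_logits` is a hypothetical name. For reference, a model that predicts uniformly over the 5000-token vocabulary scores ln(5000) ≈ 8.52 nats per token.

```python
import numpy as np


def sparse_ce_from_logits(labels, logits):
    """labels: int array (...); logits: float array (..., vocab).
    Returns the per-token loss with shape (...), i.e. no reduction."""
    m = logits.max(axis=-1, keepdims=True)
    log_z = np.log(np.exp(logits - m).sum(axis=-1)) + m[..., 0]  # logsumexp
    picked = np.take_along_axis(logits, labels[..., None], axis=-1)[..., 0]
    return log_z - picked


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 5))       # (batch, sequence, vocab)
labels = rng.integers(0, 5, size=(2, 3))  # (batch, sequence)
per_token = sparse_ce_from_logits(labels, logits)
print(per_token.shape)  # (2, 3)

# Uniform logits over a 5000-token vocabulary give ln(5000) per token
print(sparse_ce_from_logits(np.array([0]), np.zeros((1, 5000))))  # ≈ 8.517, i.e. ln(5000)
```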

In [49]:
# Keras optimizers take the gradient-clipping threshold as `clipvalue`
# (`clip_value` is not an accepted keyword argument).
opt_kwargs = {'learning_rate': lr}
if params['clipvalue'] is not None:
    opt_kwargs['clipvalue'] = params['clipvalue']

if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        opti = tf.keras.optimizers.Adam(**opt_kwargs)
else:
    opti = tf.keras.optimizers.Adam(**opt_kwargs)

if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        model.compile(optimizer=opti, loss=loss, metrics=[], run_eagerly=False, jit_compile=True)
else:
    model.compile(optimizer=opti, loss=loss, metrics=['accuracy'])

In [50]:
import_checkpoint = False
force_import = False   # True: ignore metadata and try import anyway. This will of course crash, if the new model doesn't fit the checkpoint-data...

if import_checkpoint is True:
    import_previous_compatible_checkpoint(model, force_import=force_import)

In [51]:
model.summary()

Model: "mhsa_v1_tf"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 80)]              0         
                                                                 
 tf.cast (TFOpLambda)        (None, 80)                0         
                                                                 
 embedding (Embedding)       (None, 80, 128)           640000    
                                                                 
 positional_encoding (Positi  (None, 80, 128)          0         
 onalEncoding)                                                   
                                                                 
 multi_head_self_attention (  (None, 80, 128)          1606144   
 MultiHeadSelfAttention)                                         
                                                                 
 multi_head_self_attention_1  (None, 80, 128)          1

In [52]:
TPU_GENERATE_ON_CPU = False  # The thing is: both options are slow on TPU :-/

class ServiceCallback(keras.callbacks.Callback):
#    def on_test_end(self, logs=None):
    # @tf.function
    def on_epoch_end(self, epoch, logs=None):
        save_model_metadata(epoch, suffix=model_suffix)
        if (epoch+1) % params['sample_every_n_epochs'] == 0:
            idx=random.randint(0,len(td)-1)
            text=td.decode(td[idx])
            print()
            if ml_env.is_tpu is True:
                temp_list=[0.7] # [0.6,0.7,0.8]
                gen_len=50
                with tpu_strategy.scope():
                    weights=model.get_weights()
                model_cpu.set_weights(weights)
                # HDF5 is required for saving weights that originate from TPU
                # otherwise this just silently fails...
                checkpoint_path = os.path.join(checkpoint_dir, "cp-{epoch:04d}.h5")
                chkpt_dest=checkpoint_path.format(epoch=epoch)
                print(f"Checkpoint: {chkpt_dest}")
                model_cpu.save_weights(chkpt_dest)
            else:
                temp_list=[0.6, 0.7, 0.8]
                gen_len=192
            print(f"prompt: {text}")
            for temp in temp_list:
                print(f"---------------- T={temp} ---------------")
                if ml_env.is_tpu is True and TPU_GENERATE_ON_CPU is True:
                    with tf.device('/cpu:0'):
                        if temp==0.0:
                            reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=1.0, argmax=True, verbose=False)
                        else:
                            reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=temp, verbose=False)
                else:
                    if temp==0.0:
                        reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=1.0, argmax=True, verbose=False)
                    else:
                        reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=temp, verbose=False)
                td.source_highlight(reply, min_quote_size=10, dark_mode=use_dark_mode, display_ref_anchor=False)
            print("--------------------------------------")

service_callback=ServiceCallback()
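`ServiceCallback` samples when `(epoch + 1) % sample_every_n_epochs == 0`, where `epoch` is the 0-based index Keras passes to `on_epoch_end`, while the progress output prints 1-based numbers. The sketch below (hypothetical helper `sampling_epochs`) lists the printed epoch numbers that trigger sampling over the range of the training log, using the override of 200 set for the run; sampling follows printed epochs 7000, 7200, and so on, matching the separator lines before `Epoch 7001`, `Epoch 7201`, etc. Since `save_model_metadata` stores this 0-based index as `current_epoch`, resuming with it as `initial_epoch` appears to repeat the last completed epoch once.

```python
def sampling_epochs(first_shown, last_shown, every_n):
    """Return the epoch numbers, as printed by Keras ("Epoch N/..."), after
    which ServiceCallback samples text."""
    hits = []
    for shown in range(first_shown, last_shown + 1):
        epoch = shown - 1  # 0-based index seen by on_epoch_end
        if (epoch + 1) % every_n == 0:
            hits.append(shown)
    return hits


print(sampling_epochs(6874, 7821, 200))  # [7000, 7200, 7400, 7600, 7800]
```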

In [53]:
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

logdir = os.path.join(log_path, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
if ml_env.is_tpu:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='epoch', write_graph=False)
else:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='batch')


In [54]:
# Dont try:
#    # use the python variable log_path:
#   get_ipython().run_line_magic('tensorboard', '--logdir "{log_path}"')
#except:
#   pass

# The following throws errors on non-colab, but the guarding above is too bug-ridden.
# if ml_env.is_tpu is False:
#    %tensorboard --logdir logs

## The actual training

In [55]:
EPOCHS=500000
if 'current_epoch' in params:
    initial_epoch=params['current_epoch']
else:
    initial_epoch=0

override=200
print(f"WARNING override of sample_every_n_epochs sample-generation to: {override}")
params['sample_every_n_epochs']=override



In [None]:
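if ml_env.is_tpu is True:
    # Each dataset element is already a full batch of batch_size samples, so this
    # gives restricted_batches//batch_size steps per TPU "epoch" (28 of the 7244
    # pre-built batches in this run), i.e. an epoch here is not a full pass.
    steps_per_epoch=restricted_batches//params['batch_size']
    if steps_per_epoch < 1:
        steps_per_epoch = 1
    # For TPU we need to roll our own checkpointing (ServiceCallback), since the weights must first be transferred to the CPU model
    history = model.fit(dataset, epochs=EPOCHS, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, callbacks=[service_callback])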
else:
    history = model.fit(dataset, validation_data=validation_dataset, epochs=EPOCHS, initial_epoch=initial_epoch, callbacks=[checkpoint_callback, tensorboard_callback, service_callback])

Epoch 6874/500000
 5/28 [====>.........................] - ETA: 0s - loss: 5.8706



Epoch 6875/500000
Epoch 6876/500000
Epoch 6877/500000
Epoch 6878/500000
Epoch 6879/500000
Epoch 6880/500000
Epoch 6881/500000
Epoch 6882/500000
Epoch 6883/500000
Epoch 6884/500000
Epoch 6885/500000
Epoch 6886/500000
Epoch 6887/500000
Epoch 6888/500000
Epoch 6889/500000
Epoch 6890/500000
Epoch 6891/500000
Epoch 6892/500000
Epoch 6893/500000
Epoch 6894/500000
Epoch 6895/500000
Epoch 6896/500000
Epoch 6897/500000
Epoch 6898/500000
Epoch 6899/500000
Epoch 6900/500000
Epoch 6901/500000
Epoch 6902/500000
Epoch 6903/500000
Epoch 6904/500000
Epoch 6905/500000
Epoch 6906/500000
Epoch 6907/500000
Epoch 6908/500000
Epoch 6909/500000
Epoch 6910/500000
Epoch 6911/500000
Epoch 6912/500000
Epoch 6913/500000
Epoch 6914/500000
Epoch 6915/500000
Epoch 6916/500000
Epoch 6917/500000
Epoch 6918/500000
Epoch 6919/500000
Epoch 6920/500000
Epoch 6921/500000
Epoch 6922/500000
Epoch 6923/500000
Epoch 6924/500000
Epoch 6925/500000
Epoch 6926/500000
Epoch 6927/500000
Epoch 6928/500000
Epoch 6929/500000
Epoch 6930

--------------------------------------
Epoch 7001/500000
Epoch 7002/500000
Epoch 7003/500000
Epoch 7004/500000
Epoch 7005/500000
Epoch 7006/500000
Epoch 7007/500000
Epoch 7008/500000
Epoch 7009/500000
Epoch 7010/500000
Epoch 7011/500000
Epoch 7012/500000
Epoch 7013/500000
Epoch 7014/500000
Epoch 7015/500000
Epoch 7016/500000
Epoch 7017/500000
Epoch 7018/500000
Epoch 7019/500000
Epoch 7020/500000
Epoch 7021/500000
Epoch 7022/500000
Epoch 7023/500000
Epoch 7024/500000
Epoch 7025/500000
Epoch 7026/500000
Epoch 7027/500000
Epoch 7028/500000
Epoch 7029/500000
Epoch 7030/500000
Epoch 7031/500000
Epoch 7032/500000
Epoch 7033/500000
Epoch 7034/500000
Epoch 7035/500000
Epoch 7036/500000
Epoch 7037/500000
Epoch 7038/500000
Epoch 7039/500000
Epoch 7040/500000
Epoch 7041/500000
Epoch 7042/500000
Epoch 7043/500000
Epoch 7044/500000
Epoch 7045/500000
Epoch 7046/500000
Epoch 7047/500000
Epoch 7048/500000
Epoch 7049/500000
Epoch 7050/500000
Epoch 7051/500000
Epoch 7052/500000
Epoch 7053/500000
Epoch 7

--------------------------------------
Epoch 7201/500000
Epoch 7202/500000
Epoch 7203/500000
Epoch 7204/500000
Epoch 7205/500000
Epoch 7206/500000
Epoch 7207/500000
Epoch 7208/500000
Epoch 7209/500000
Epoch 7210/500000
Epoch 7211/500000
Epoch 7212/500000
Epoch 7213/500000
Epoch 7214/500000
Epoch 7215/500000
Epoch 7216/500000
Epoch 7217/500000
Epoch 7218/500000
Epoch 7219/500000
Epoch 7220/500000
Epoch 7221/500000
Epoch 7222/500000
Epoch 7223/500000
Epoch 7224/500000
Epoch 7225/500000
Epoch 7226/500000
Epoch 7227/500000
Epoch 7228/500000
Epoch 7229/500000
Epoch 7230/500000
Epoch 7231/500000
Epoch 7232/500000
Epoch 7233/500000
Epoch 7234/500000
Epoch 7235/500000
Epoch 7236/500000
Epoch 7237/500000
Epoch 7238/500000
Epoch 7239/500000
Epoch 7240/500000
Epoch 7241/500000
Epoch 7242/500000
Epoch 7243/500000
Epoch 7244/500000
Epoch 7245/500000
Epoch 7246/500000
Epoch 7247/500000
Epoch 7248/500000
Epoch 7249/500000
Epoch 7250/500000
Epoch 7251/500000
Epoch 7252/500000
Epoch 7253/500000
Epoch 7

--------------------------------------
Epoch 7401/500000
Epoch 7402/500000
Epoch 7403/500000
Epoch 7404/500000
Epoch 7405/500000
Epoch 7406/500000
Epoch 7407/500000
Epoch 7408/500000
Epoch 7409/500000
Epoch 7410/500000
Epoch 7411/500000
Epoch 7412/500000
Epoch 7413/500000
Epoch 7414/500000
Epoch 7415/500000
Epoch 7416/500000
Epoch 7417/500000
Epoch 7418/500000
Epoch 7419/500000
Epoch 7420/500000
Epoch 7421/500000
Epoch 7422/500000
Epoch 7423/500000
Epoch 7424/500000
Epoch 7425/500000
Epoch 7426/500000
Epoch 7427/500000
Epoch 7428/500000
Epoch 7429/500000
Epoch 7430/500000
Epoch 7431/500000
Epoch 7432/500000
Epoch 7433/500000
Epoch 7434/500000
Epoch 7435/500000
Epoch 7436/500000
Epoch 7437/500000
Epoch 7438/500000
Epoch 7439/500000
Epoch 7440/500000
Epoch 7441/500000
Epoch 7442/500000
Epoch 7443/500000
Epoch 7444/500000
Epoch 7445/500000
Epoch 7446/500000
Epoch 7447/500000
Epoch 7448/500000
Epoch 7449/500000
Epoch 7450/500000
Epoch 7451/500000
Epoch 7452/500000
Epoch 7453/500000

--------------------------------------
Epoch 7601/500000
Epoch 7602/500000
Epoch 7603/500000
Epoch 7604/500000
Epoch 7605/500000
Epoch 7606/500000
Epoch 7607/500000
Epoch 7608/500000
Epoch 7609/500000
Epoch 7610/500000
Epoch 7611/500000
Epoch 7612/500000
Epoch 7613/500000
Epoch 7614/500000
Epoch 7615/500000
Epoch 7616/500000
Epoch 7617/500000
Epoch 7618/500000
Epoch 7619/500000
Epoch 7620/500000
Epoch 7621/500000
Epoch 7622/500000
Epoch 7623/500000
Epoch 7624/500000
Epoch 7625/500000
Epoch 7626/500000
Epoch 7627/500000
Epoch 7628/500000
Epoch 7629/500000
Epoch 7630/500000
Epoch 7631/500000
Epoch 7632/500000
Epoch 7633/500000
Epoch 7634/500000
Epoch 7635/500000
Epoch 7636/500000
Epoch 7637/500000
Epoch 7638/500000
Epoch 7639/500000
Epoch 7640/500000
Epoch 7641/500000
Epoch 7642/500000
Epoch 7643/500000
Epoch 7644/500000
Epoch 7645/500000
Epoch 7646/500000
Epoch 7647/500000
Epoch 7648/500000
Epoch 7649/500000
Epoch 7650/500000
Epoch 7651/500000
Epoch 7652/500000
Epoch 7653/500000

--------------------------------------
Epoch 7801/500000
Epoch 7802/500000
Epoch 7803/500000
Epoch 7804/500000
Epoch 7805/500000
Epoch 7806/500000
Epoch 7807/500000
Epoch 7808/500000
Epoch 7809/500000
Epoch 7810/500000
Epoch 7811/500000
Epoch 7812/500000
Epoch 7813/500000
Epoch 7814/500000
Epoch 7815/500000
Epoch 7816/500000
Epoch 7817/500000
Epoch 7818/500000
Epoch 7819/500000
Epoch 7820/500000
Epoch 7821/500000
Epoch 7822/500000
Epoch 7823/500000
Epoch 7824/500000
Epoch 7825/500000
Epoch 7826/500000
Epoch 7827/500000
Epoch 7828/500000
Epoch 7829/500000
Epoch 7830/500000
Epoch 7831/500000
Epoch 7832/500000
Epoch 7833/500000
Epoch 7834/500000
Epoch 7835/500000
Epoch 7836/500000
Epoch 7837/500000
Epoch 7838/500000
Epoch 7839/500000
Epoch 7840/500000
Epoch 7841/500000
Epoch 7842/500000
Epoch 7843/500000
Epoch 7844/500000
Epoch 7845/500000
Epoch 7846/500000
Epoch 7847/500000
Epoch 7848/500000
Epoch 7849/500000
Epoch 7850/500000
Epoch 7851/500000
Epoch 7852/500000
Epoch 7853/500000

--------------------------------------
Epoch 8001/500000
Epoch 8002/500000
Epoch 8003/500000
Epoch 8004/500000
Epoch 8005/500000
Epoch 8006/500000
Epoch 8007/500000
Epoch 8008/500000
Epoch 8009/500000
Epoch 8010/500000
Epoch 8011/500000
Epoch 8012/500000
Epoch 8013/500000
Epoch 8014/500000
Epoch 8015/500000
Epoch 8016/500000
Epoch 8017/500000
Epoch 8018/500000
Epoch 8019/500000
Epoch 8020/500000
Epoch 8021/500000
Epoch 8022/500000
Epoch 8023/500000
Epoch 8024/500000
Epoch 8025/500000
Epoch 8026/500000
Epoch 8027/500000
Epoch 8028/500000
Epoch 8029/500000
Epoch 8030/500000
Epoch 8031/500000
Epoch 8032/500000
Epoch 8033/500000
Epoch 8034/500000
Epoch 8035/500000
Epoch 8036/500000
Epoch 8037/500000
Epoch 8038/500000
Epoch 8039/500000
Epoch 8040/500000
Epoch 8041/500000
Epoch 8042/500000
Epoch 8043/500000
Epoch 8044/500000
Epoch 8045/500000
Epoch 8046/500000
Epoch 8047/500000
Epoch 8048/500000
Epoch 8049/500000
Epoch 8050/500000
Epoch 8051/500000
Epoch 8052/500000
Epoch 8053/500000

--------------------------------------
Epoch 8201/500000
Epoch 8202/500000
Epoch 8203/500000
Epoch 8204/500000
Epoch 8205/500000
Epoch 8206/500000
Epoch 8207/500000
Epoch 8208/500000
Epoch 8209/500000
Epoch 8210/500000
Epoch 8211/500000
Epoch 8212/500000
Epoch 8213/500000
Epoch 8214/500000
Epoch 8215/500000
Epoch 8216/500000
Epoch 8217/500000
Epoch 8218/500000
Epoch 8219/500000
Epoch 8220/500000
Epoch 8221/500000
Epoch 8222/500000
Epoch 8223/500000
Epoch 8224/500000
Epoch 8225/500000
Epoch 8226/500000
Epoch 8227/500000
Epoch 8228/500000
Epoch 8229/500000
Epoch 8230/500000
Epoch 8231/500000
Epoch 8232/500000
Epoch 8233/500000
Epoch 8234/500000
Epoch 8235/500000
Epoch 8236/500000
Epoch 8237/500000
Epoch 8238/500000
Epoch 8239/500000
Epoch 8240/500000
Epoch 8241/500000
Epoch 8242/500000
Epoch 8243/500000
Epoch 8244/500000
Epoch 8245/500000
Epoch 8246/500000
Epoch 8247/500000
Epoch 8248/500000
Epoch 8249/500000
Epoch 8250/500000
Epoch 8251/500000
Epoch 8252/500000
Epoch 8253/500000

--------------------------------------
Epoch 8401/500000
Epoch 8402/500000
Epoch 8403/500000
Epoch 8404/500000
Epoch 8405/500000
Epoch 8406/500000
Epoch 8407/500000
Epoch 8408/500000
Epoch 8409/500000
Epoch 8410/500000
Epoch 8411/500000
Epoch 8412/500000
Epoch 8413/500000
Epoch 8414/500000
Epoch 8415/500000
Epoch 8416/500000
Epoch 8417/500000
Epoch 8418/500000
Epoch 8419/500000
Epoch 8420/500000
Epoch 8421/500000
Epoch 8422/500000
Epoch 8423/500000
Epoch 8424/500000
Epoch 8425/500000
Epoch 8426/500000
Epoch 8427/500000
Epoch 8428/500000
Epoch 8429/500000
Epoch 8430/500000
Epoch 8431/500000
Epoch 8432/500000
Epoch 8433/500000
Epoch 8434/500000
Epoch 8435/500000
Epoch 8436/500000
Epoch 8437/500000
Epoch 8438/500000
Epoch 8439/500000
Epoch 8440/500000
Epoch 8441/500000
Epoch 8442/500000
Epoch 8443/500000
Epoch 8444/500000
Epoch 8445/500000
Epoch 8446/500000
Epoch 8447/500000
Epoch 8448/500000
Epoch 8449/500000
Epoch 8450/500000
Epoch 8451/500000
Epoch 8452/500000
Epoch 8453/500000

--------------------------------------
Epoch 8601/500000
Epoch 8602/500000
Epoch 8603/500000
Epoch 8604/500000
Epoch 8605/500000
Epoch 8606/500000
Epoch 8607/500000
Epoch 8608/500000
Epoch 8609/500000
Epoch 8610/500000
Epoch 8611/500000
Epoch 8612/500000
Epoch 8613/500000
Epoch 8614/500000
Epoch 8615/500000
Epoch 8616/500000
Epoch 8617/500000
Epoch 8618/500000
Epoch 8619/500000
Epoch 8620/500000
Epoch 8621/500000
Epoch 8622/500000
Epoch 8623/500000
Epoch 8624/500000
Epoch 8625/500000
Epoch 8626/500000
Epoch 8627/500000
Epoch 8628/500000
Epoch 8629/500000
Epoch 8630/500000
Epoch 8631/500000
Epoch 8632/500000
Epoch 8633/500000
Epoch 8634/500000
Epoch 8635/500000
Epoch 8636/500000
Epoch 8637/500000
Epoch 8638/500000
Epoch 8639/500000
Epoch 8640/500000
Epoch 8641/500000
Epoch 8642/500000
Epoch 8643/500000
Epoch 8644/500000
Epoch 8645/500000
Epoch 8646/500000
Epoch 8647/500000
Epoch 8648/500000
Epoch 8649/500000
Epoch 8650/500000
Epoch 8651/500000
Epoch 8652/500000
Epoch 8653/500000

--------------------------------------
Epoch 8801/500000
Epoch 8802/500000
Epoch 8803/500000
Epoch 8804/500000
Epoch 8805/500000
Epoch 8806/500000
Epoch 8807/500000
Epoch 8808/500000
Epoch 8809/500000
Epoch 8810/500000
Epoch 8811/500000
Epoch 8812/500000
Epoch 8813/500000
Epoch 8814/500000
Epoch 8815/500000
Epoch 8816/500000
Epoch 8817/500000
Epoch 8818/500000
Epoch 8819/500000
Epoch 8820/500000
Epoch 8821/500000
Epoch 8822/500000
Epoch 8823/500000
Epoch 8824/500000
Epoch 8825/500000
Epoch 8826/500000
Epoch 8827/500000
Epoch 8828/500000
Epoch 8829/500000
Epoch 8830/500000
Epoch 8831/500000
Epoch 8832/500000
Epoch 8833/500000
Epoch 8834/500000
Epoch 8835/500000
Epoch 8836/500000
Epoch 8837/500000
Epoch 8838/500000
Epoch 8839/500000
Epoch 8840/500000
Epoch 8841/500000
Epoch 8842/500000
Epoch 8843/500000
Epoch 8844/500000
Epoch 8845/500000
Epoch 8846/500000
Epoch 8847/500000
Epoch 8848/500000
Epoch 8849/500000
Epoch 8850/500000
Epoch 8851/500000
Epoch 8852/500000
Epoch 8853/500000

--------------------------------------
Epoch 9001/500000
Epoch 9002/500000
Epoch 9003/500000
Epoch 9004/500000
Epoch 9005/500000
Epoch 9006/500000
Epoch 9007/500000
Epoch 9008/500000
Epoch 9009/500000
Epoch 9010/500000
Epoch 9011/500000
Epoch 9012/500000
Epoch 9013/500000
Epoch 9014/500000
Epoch 9015/500000
Epoch 9016/500000
Epoch 9017/500000
Epoch 9018/500000
Epoch 9019/500000
Epoch 9020/500000
Epoch 9021/500000
Epoch 9022/500000
Epoch 9023/500000
Epoch 9024/500000
Epoch 9025/500000
Epoch 9026/500000
Epoch 9027/500000
Epoch 9028/500000
Epoch 9029/500000
Epoch 9030/500000
Epoch 9031/500000
Epoch 9032/500000
Epoch 9033/500000
Epoch 9034/500000
Epoch 9035/500000
Epoch 9036/500000
Epoch 9037/500000
Epoch 9038/500000
Epoch 9039/500000
Epoch 9040/500000
Epoch 9041/500000
Epoch 9042/500000
Epoch 9043/500000
Epoch 9044/500000
Epoch 9045/500000
Epoch 9046/500000
Epoch 9047/500000
Epoch 9048/500000
Epoch 9049/500000
Epoch 9050/500000
Epoch 9051/500000
Epoch 9052/500000
Epoch 9053/500000

## A dialog with the trained model

In [None]:
model_cpu.set_weights(model.get_weights())

In [None]:
# Interactive dialog with the model trained above.
# Each prompt is passed to mhsa_generate() on its own; no conversation
# context is carried over from one prompt to the next.

def doDialog(model):
    temperature = 0.6  # initial sampling temperature
    print("Please enter some dialog.")
    print("The net will answer according to your input.")
    print("'bye' to end,")
    print("'temperature=<float>' [0.1 (frozen) - 1.0 (creative)]")
    print("    to change the character of the dialog.")
    print("    Current temperature={}.".format(temperature))
    print()
    while True:
        prompt = input("> ")
        if prompt == 'bye':
            print("Good bye!")
            break
        if prompt.startswith("temperature="):
            # Parse the new value; non-numeric input is rejected instead of raising.
            try:
                t = float(prompt[len("temperature="):])
            except ValueError:
                t = None
            if t is not None and 0.05 < t < 1.4:
                temperature = t
                print("(generator temperature now {})".format(t))
            else:
                print("Invalid temperature value ignored! Expected 0.05 < t < 1.4.")
            print()
            continue
        # Generate a reply to the prompt and display it with the Text_Dataset highlighter.
        reply = mhsa_generate(model, prompt, gen_len=256, temperature=temperature, verbose=True)
        td.source_highlight(reply, min_quote_size=13, dark_mode=use_dark_mode)
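
The `temperature=` command changes how the next character is drawn from the model's output distribution. `mhsa_generate` is defined earlier in the notebook and its code is not shown in this section, so the following is only a minimal NumPy sketch of standard temperature sampling, assuming the model emits one logit per vocabulary entry. The helper names `temperature_softmax` and `sample_with_temperature` are hypothetical and not part of ml-indie-tools.

```python
import numpy as np


def temperature_softmax(logits, temperature):
    """Softmax of logits / temperature (hypothetical helper).

    temperature < 1 sharpens the distribution (more 'frozen'),
    temperature > 1 flattens it (more 'creative').
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()  # subtract the maximum for numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()


def sample_with_temperature(logits, temperature=0.6, rng=None):
    """Draw one vocabulary index from the temperature-scaled distribution."""
    rng = np.random.default_rng() if rng is None else rng
    probs = temperature_softmax(logits, temperature)
    return int(rng.choice(len(probs), p=probs))


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.1]
    for t in (0.1, 0.6, 1.0):
        print(t, np.round(temperature_softmax(logits, t), 3))
```

With logits `[2.0, 1.0, 0.1]`, the probability of the highest-scoring entry approaches 1 as the temperature drops toward 0.1, so generation becomes nearly deterministic; higher temperatures spread probability over more candidates.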

In [None]:
# Talk to the net!
doDialog(model_cpu)
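
Each reply is passed to `td.source_highlight(reply, min_quote_size=13, ...)`. Judging only by the parameter name, this presumably marks passages the model reproduced verbatim from its training texts; we have not checked how ml-indie-tools actually implements it. As an illustration under that assumption, here is a hypothetical stdlib-only helper, `find_quotes`, that returns the spans of a reply of at least `min_quote_size` characters that also occur in a source string.

```python
def find_quotes(reply, source, min_quote_size=13):
    """Return (start, end) spans of `reply` that occur verbatim in `source`.

    Hypothetical helper: a greedy left-to-right scan. At each position, if the
    next `min_quote_size` characters occur in `source`, the match is extended
    one character at a time for as long as it still occurs there.
    """
    spans = []
    i = 0
    n = len(reply)
    while i <= n - min_quote_size:
        if reply[i:i + min_quote_size] in source:
            j = i + min_quote_size
            # Extend the quote while the longer substring is still in the source.
            while j < n and reply[i:j + 1] in source:
                j += 1
            spans.append((i, j))
            i = j  # continue scanning after the quoted span
        else:
            i += 1
    return spans


if __name__ == "__main__":
    source = "the quick brown fox jumps over the lazy dog"
    reply = "I saw the quick brown fox run away"
    for start, end in find_quotes(reply, source):
        print(repr(reply[start:end]))
```

The substring test makes this quadratic in the reply length, which is acceptable for short replies but too slow for a large corpus; an index such as a suffix array would be the usual choice there.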