<a href="https://colab.research.google.com/github/domschl/transformer-poet/blob/main/transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that helps to write environment-independent code. It will access your Google Drive when used with Google Colab.

In [1]:
!pip install -U ml-indie-tools

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ml-indie-tools
  Downloading ml_indie_tools-0.3.17-py3-none-any.whl (37 kB)
Installing collected packages: ml-indie-tools
Successfully installed ml-indie-tools-0.3.17


In [2]:
import logging
import os
import sys
import copy
import json
import time
import datetime
import random

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers, regularizers

import tensorflow_datasets as tfds

In [3]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.keras_custom_layers import MultiHeadSelfAttention, PositionalEncoding

Using TF-Keras version: 2.9.0


## Preliminary

A TensorFlow deep multi-head self-attention model for text generation.

This code can run on a CPU, GPU, or TPU when used on Google Colab.
Select the corresponding runtime (menu: **`Runtime / Change runtime type`**).

## 0. Environment

In [4]:
cached_batch_data = None   # Number of batches in the cached training data; avoids regenerating time-consuming training data if it is already cached.

ml_env = MLEnv(platform='tf', accelerator='fastest')
ml_env.describe()

'OS: Linux, Python: 3.8.16, Colab Jupyter Notebook Tensorflow: 2.9.2, TPU: TPU, 8 nodes v2 (8GB)'

In [5]:
use_eager=tf.executing_eagerly()
if ml_env.is_tpu is True:
    tpu_strategy = ml_env.tpu_strategy
    tpu_is_init=True
    if use_eager is True:
        tf.config.run_functions_eagerly(False)
    use_eager=False

In [6]:
project_name='women_writers'
model_name='mhsa_v1_tf'

# NOTICE: If running on Google Colab, this will request access to Google Drive. Google Drive is used to store snapshots
# and training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools
#
# Note: you need to allow popups in your browser for Colab, otherwise you won't see the Google Drive login box, and drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path)")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Mounted at /content/drive
Root path (all projects) : /content/drive/My Drive (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : /content/drive/My Drive/Colab Notebooks/women_writers (Changes to the file system happen only below this project path)
Model path (snapshots)   : /content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf (Model weights and snapshots are stored here)
Data path (training data): /content/drive/My Drive/Colab Notebooks/women_writers/data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


## 1. Text library

The `Text_Dataset` and `Gutenberg_Dataset` classes are libraries for 
preparing training data, encoding, batch generation, and formatted source 
display. They read books from Project Gutenberg and support the creation 
of training batches. The output functions support highlighting, so that 
generated texts can be compared with the actual sources to identify 
identical (memorized) parts.
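To make the highlighting idea concrete, here is a minimal sketch of how memorized passages could be detected: it reports each stretch of generated text, at least `min_quote_size` characters long, that also occurs verbatim in a source text. The helper `find_quotes` is hypothetical and is not how ml-indie-tools implements `td.source_highlight()`; it is a simple greedy scan with substring tests, fine for short texts but slow on whole books.

```python
def find_quotes(generated: str, source: str, min_quote_size: int = 10):
    """Return (start, end) spans of `generated` that appear verbatim in `source`.

    Hypothetical helper for illustration: a greedy left-to-right scan that
    extends each match as far as it still occurs in the source.
    """
    spans = []
    i = 0
    while i <= len(generated) - min_quote_size:
        if generated[i:i + min_quote_size] in source:
            # Extend the match while the longer substring is still in the source.
            end = i + min_quote_size
            while end < len(generated) and generated[i:end + 1] in source:
                end += 1
            spans.append((i, end))
            i = end  # continue scanning after the matched span
        else:
            i += 1
    return spans


source = "It is a truth universally acknowledged, that a single man"
generated = "Yes, it is a truth universally known."
for start, end in find_quotes(generated, source):
    print(repr(generated[start:end]))  # prints 't is a truth universally '
```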

In [7]:
use_dark_mode=False # False: white background, True: dark theme. The HTML text comparison uses background colors to identify different sources, and those colors depend on the theme.

In [8]:
logging.basicConfig(level=logging.INFO)
cache_dir = os.path.join(data_path, 'gutenberg_cache')
gd = Gutenberg_Dataset(cache_dir=cache_dir)

In [9]:
# sample searches
search_spec= {"author": ["Emily Brontë", "Jane Austen", "Virginia Woolf"], "language": ["english"]}

book_list=gd.search(search_spec)
book_cnt = len(book_list)
print(f"{book_cnt} matching books found with search {search_spec}.")
if book_cnt<40:
    # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts, 
    # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
    book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)  
else:
    logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

21 matching books found with search {'author': ['Emily Brontë', 'Jane Austen', 'Virginia Woolf'], 'language': ['english']}.


In [10]:
for i in range(len(book_list)):
    print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

0: The Common Reader - Virginia Woolf, 64457
1: Mr. Bennett and Mrs. Brown - Virginia Woolf, 63022
2: The Younger Sister, Volumes 1-3 - Catherine Anne Austen Hubback and Jane Austen, 54066
3: The Younger Sister, Vol. 3 - Catherine Anne Austen Hubback and Jane Austen, 54012
4: The Younger Sister, Vol. 2 - Catherine Anne Austen Hubback and Jane Austen, 54011
5: The Younger Sister, Vol. 1 - Catherine Anne Austen Hubback and Jane Austen, 54010
6: Pride and Prejudice - Jane Austen, 42671
7: The Letters of Jane Austen - Jane Austen, 42078
8: The Complete Project Gutenberg Works of Jane Austen - Jane Austen, 31100
9: Jacob's Room - Virginia Woolf, 5670
10: Pride and Prejudice - Jane Austen, 1342
11: Night and Day - Virginia Woolf, 1245
12: Love And Friendship And Other Early Works - Jane Austen, 1212
13: Lady Susan - Jane Austen, 946
14: Wuthering Heights - Emily Brontë, 768
15: Sense and Sensibility - Jane Austen, 161
16: Emma - Jane Austen, 158
17: The Voyage Out - Virginia Woolf, 144
18: M

In [11]:
MAX_TOKENS = 20000  # This becomes vocab_size
MAX_NGRAM_LEN = 8   # Max length of a token

select = ("Bennett", "1342", "5670", "1245", "161", "141", "121", "105", "Susan", "Wuthering", "Emma", "Voyage")  # Unique single words from book titles, or ebook_ids, that select the books to use
sub_book_list = [book_list[i] for i in range(len(book_list)) if not set([book_list[i]['ebook_id']]+book_list[i]['title'].split(' ')).isdisjoint(set(select))]

print("Using:")
for i in range(len(sub_book_list)):
    print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

textlib_dataset = None  # Forces re-caching
td = Text_Dataset(sub_book_list)
td.init_tokenizer(tokenizer='ngram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS)


Using:
1: Mr. Bennett and Mrs. Brown - Virginia Woolf
2: Jacob's Room - Virginia Woolf
3: Pride and Prejudice - Jane Austen
4: Night and Day - Virginia Woolf
5: Lady Susan - Jane Austen
6: Wuthering Heights - Emily Brontë
7: Sense and Sensibility - Jane Austen
8: Emma - Jane Austen
9: The Voyage Out - Virginia Woolf
10: Mansfield Park - Jane Austen
11: Northanger Abbey - Jane Austen
12: Persuasion - Jane Austen
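The `select` filter keeps a book if its `ebook_id`, or any space-separated word of its title, appears in `select`. A stdlib-only sketch of the same test on made-up records (the `books` list and `example_select` tuple are hypothetical, shaped like the `book_list` entries):

```python
# Hypothetical records mimicking the 'title' and 'ebook_id' fields of book_list.
books = [
    {"title": "Emma", "ebook_id": "158"},
    {"title": "Lady Susan", "ebook_id": "946"},
    {"title": "The Common Reader", "ebook_id": "64457"},
    {"title": "Persuasion", "ebook_id": "105"},
]
example_select = ("Susan", "Emma", "105")

def is_selected(book, select):
    # Candidate keys: the ebook_id plus every word of the title.
    keys = {book["ebook_id"], *book["title"].split(" ")}
    # isdisjoint() is False as soon as one key matches a select entry.
    return not keys.isdisjoint(select)

sub_books = [b["title"] for b in books if is_selected(b, example_select)]
print(sub_books)  # ['Emma', 'Lady Susan', 'Persuasion']
```

Because the match is on whole words or exact IDs, ambiguous titles such as the two *Pride and Prejudice* editions (42671 and 1342) are best picked by `ebook_id`, which is what the notebook does with `"1342"`.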


In [12]:
SEQUENCE_LEN = 80
SUB_PROBABILITY = 0.15  # share of input tokens replaced by <subst> during training, similar to BERT's 15% masking

td.init_getitem(sample_type='encoded', sample_length=SEQUENCE_LEN, content_stepping=1)

num_records = len(td)

print(f"{num_records} records")

1529756 records


In [13]:
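def get_sample_batch(td, batch_size, length, SUB_probability=0.15):
    """Return (X, y) int32 arrays of shape (batch_size, sample_length).

    y holds random token sequences from td; X is a copy in which about
    SUB_probability of the positions are replaced by the tokenizer's
    substitution token (BERT-style masking). `length` is unused: the sample
    length is set by td.init_getitem(sample_length=...).
    """
    if td.tokenizer_type=='char':
        subst = td.c2i['␚']
    elif td.tokenizer_type=='word':
        subst = td.w2i['<subst>']
    elif td.tokenizer_type=='ngram':
        subst = td.t2i['<subst>']
    else:
        raise ValueError(f"Unexpected tokenizer_type {td.tokenizer_type}")
    smpX, smpy = [], []
    for _ in range(batch_size):
        yi = np.array(td.get_random_item(), dtype=np.int32)  # np.array() copies
        Xi = yi.copy()
        # Positions are drawn with replacement, so slightly fewer than
        # int(len*SUB_probability) distinct tokens may end up substituted.
        for _ in range(int(len(Xi)*SUB_probability)):
            Xi[random.randint(0, len(Xi)-1)] = subst
        smpX.append(Xi)
        smpy.append(yi)
    # Stacking once at the end avoids repeated np.vstack copies and keeps
    # the batch dimension even when batch_size == 1.
    return np.array(smpX), np.array(smpy)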
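`get_sample_batch` draws substitution positions one at a time and with replacement, so a sample can end up with slightly fewer than `int(len * p)` substituted tokens. If an exact count per sample is wanted, a vectorized NumPy alternative could look like the sketch below. This is a swapped-in technique, not what the notebook does, and `SUBST_ID` is a hypothetical stand-in for `td.t2i['<subst>']`.

```python
import numpy as np

def mask_batch(y, subst_id, sub_probability=0.15, rng=None):
    """Return a copy of the int token batch `y` (shape [batch, seq_len]) in which
    exactly int(seq_len * sub_probability) distinct positions per row are
    replaced by `subst_id`."""
    rng = np.random.default_rng() if rng is None else rng
    x = y.copy()
    batch, seq_len = y.shape
    n_sub = int(seq_len * sub_probability)
    # argsort of uniform noise gives an independent random permutation per row;
    # the first n_sub indices of each permutation are the positions to mask.
    positions = np.argsort(rng.random((batch, seq_len)), axis=1)[:, :n_sub]
    np.put_along_axis(x, positions, subst_id, axis=1)
    return x

SUBST_ID = 1  # hypothetical id of the '<subst>' token
y = np.random.default_rng(0).integers(2, 20000, size=(4, 80), dtype=np.int32)
x = mask_batch(y, SUBST_ID, 0.15, np.random.default_rng(1))
print((x == SUBST_ID).sum(axis=1))  # [12 12 12 12]
```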

In [14]:
test_x, test_y = get_sample_batch(td, 2, 40, SUB_probability=SUB_PROBABILITY)
for i in range(len(test_x)):
    xi=[int(x) for x in test_x[i]]
    print(f"[{i}](l={len(xi)}): X=>{td.decode(xi)}<, y=>{td.decode(test_y[i])}<")

[0](l=80): X=>--"

<subst><subst>ped, regretting with a deep blush that she had implied so much;
but less would ha<subst>have been sufficient.  Mrs Smith would hardly
have believed so soon in Mr Elliot's failure, <subst>from the perception
of there being a some<subst>else.  As it <subst>she i<subst> submitted,
and with all<subst>mb<subst> of se<subst>nothing beyond; and Anne, eage<subst>
e<, y=>--"

She stopped, regretting with a deep blush that she had implied so much;
but less would hardly have been sufficient.  Mrs Smith would hardly
have believed so soon in Mr Elliot's failure, but from the perception
of there being a somebody else.  As it was, she instantly submitted,
and with all the semblance of seeing nothing beyond; and Anne, eager to
e<
[1](l=80): X=>, the cousin who taught the young ladies of
Bungay to play upon the violin, was the <subst>one in whom she could
confide, and as she walked up and down benea<subst><subst>ps <subst>pergola, she <subst><subst> a little speech to h

In [15]:
test_x.shape, test_y.shape

((2, 80), (2, 80))

## 2. Use tf.data for texts

In [16]:
def expand_name_template(template, params):
    exp=copy.copy(template)
    for key in params:
        src="{"+key+"}"
        dst=f"{params[key]}"
        exp=exp.replace(src,dst).replace('[','(').replace(']',')')
    return exp

def save_model_metadata(epoch, suffix='std'):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    params['current_epoch'] = epoch
    try:
        with open(meta_file, 'w') as f:
            f.write(json.dumps(params))
    except Exception as e:
        print(f"Failed to store model metadata at {model_path}: {e}")
        return False
    return True

def read_model_metadata(suffix="std"):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    try:
        with open(meta_file, 'r') as f:
            meta = json.load(f)
    except Exception as e:
        print(f"Cannot access project meta-data at {meta_file}: {e}, starting anew.")
        return None
    return meta

def is_metadata_compatible(params, meta):
    is_valid=True
    keys=set(list(params.keys())+list(meta.keys()))
    for key in keys:
        if key in updatable_keys:
            continue
        if key not in meta:
            print(f"Key {key} not available in last checkpoint model_meta, params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif key not in params:
            print(f"Key {key} not available in params, last checkpoint model_meta[{key}]: {meta[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif meta[key]!=params[key]:
            print(f"Last checkpoint model_meta[{key}]: {meta[key]} != params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
    if is_valid is False:
        print("Aborting import.")
        return False
    return True
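`expand_name_template` turns the `'name'` template into the directory suffix used for checkpoints: each `{key}` is replaced by `str(params[key])`, and square brackets become parentheses. An equivalent, self-contained copy (renamed `expand_template`, with a hypothetical `example_params` dict holding the values used in this notebook) reproduces the directory name visible in the checkpoint import log further down:

```python
def expand_template(template, params):
    # Equivalent to the notebook's helper; strings are immutable, so no copy is needed.
    exp = template
    for key in params:
        exp = exp.replace("{" + key + "}", f"{params[key]}")
    # List-valued params render as [...]; brackets become parentheses.
    return exp.replace('[', '(').replace(']', ')')

example_params = {'mhsa_layers': 8, 'heads': [4] * 8, 'units': [512] * 8, 'vocab_size': 20000}
suffix = expand_template('{mhsa_layers}x{heads}x{units}x{vocab_size}', example_params)
print(suffix)
# 8x(4, 4, 4, 4, 4, 4, 4, 4)x(512, 512, 512, 512, 512, 512, 512, 512)x20000
```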

In [45]:
vocabulary_size = td.get_unique_token_count()  # vocabulary-size

lyrs = 8

params = { # Multi-head self-attention
    'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}',

    'mhsa_layers': lyrs, 
    'heads': [4]*lyrs,
    'units': [512]*lyrs,  # 0 inserts an LSTM for memory-states :-)
    'norm': 'softmax', # this is for within each head
    'mh_normalize': True,  # use layer-norm after concatenation (or addition) of the heads
    'l2_regularizer': 1e-9,
    'dropout': 0.0,       # no dropout: 0.0
    'join_heads_by_add': True,  # strategy for joining the heads: False: concat (as in "Attention Is All You Need"), True: relu+add of all the heads
    'recurrent': True,
    'vocab_size': vocabulary_size,
    'sequence_len': SEQUENCE_LEN,
    'embedding_size': 128,

    'batch_size': 256,
    'learning_rate': 0.00002,
    'clipvalue': None,
    'sample_every_n_epochs': 100,
}

if len(params['heads'])!=params['mhsa_layers'] or len(params['units'])!=params['mhsa_layers']:
    raise ValueError("The lengths of 'heads' and 'units' must both equal mhsa_layers!")
    
if ml_env.is_tpu is True:
    lr = params['learning_rate']*1.0
else:
    lr = params['learning_rate']

model_suffix = expand_name_template(params['name'], params)
# Put 'important' params in checkpoint-pathname to separate model-data:
checkpoint_dir = os.path.join(model_path, f"training_checkpoints_{model_suffix}")
os.makedirs(checkpoint_dir, exist_ok=True)

# When comparing if training-data is compatible with new params set, 
# the following keys are updatable, they can be changed while continuing
# to use existing checkpoints and continue training with those values
# changed:
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'dropout', 
             'sample_every_n_epochs']

# These values are taken from the saved checkpoint metadata:
keep_keys=['current_epoch']

continue_last = True
if continue_last is False:
    print("NOT continuing based on existing training! New start.")

meta = read_model_metadata(suffix=model_suffix)
if continue_last is True and meta is not None and is_metadata_compatible(params, meta) is True:
    for key in keep_keys:
        if key in meta:
            params[key]=meta[key]
    print(f"Continuing last session from epoch {params.get('current_epoch', 0)}")
else:
    print("Starting new model")

print(params)

Continuing last session from epoch 10948
{'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}', 'mhsa_layers': 8, 'heads': [4, 4, 4, 4, 4, 4, 4, 4], 'units': [512, 512, 512, 512, 512, 512, 512, 512], 'norm': 'softmax', 'mh_normalize': True, 'l2_regularizer': 1e-09, 'dropout': 0.0, 'join_heads_by_add': True, 'recurrent': True, 'vocab_size': 20000, 'sequence_len': 80, 'embedding_size': 128, 'batch_size': 256, 'learning_rate': 2e-05, 'clipvalue': None, 'sample_every_n_epochs': 100, 'current_epoch': 10948}
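The resume logic compares every key of `params` with the saved metadata and ignores keys listed in `updatable_keys`. A condensed, stdlib-only restatement of `is_metadata_compatible` that returns the offending keys instead of printing them (the name `incompatible_keys` and the sample dicts are illustrative):

```python
def incompatible_keys(params, meta, updatable_keys):
    """Return the sorted keys that block resuming from `meta`: keys missing on
    either side, or with different values, unless listed in updatable_keys."""
    bad = []
    for key in set(params) | set(meta):
        if key in updatable_keys:
            continue
        if key not in meta or key not in params or meta[key] != params[key]:
            bad.append(key)
    return sorted(bad)

updatable = ['learning_rate', 'batch_size', 'current_epoch', 'dropout', 'sample_every_n_epochs']
saved = {'mhsa_layers': 8, 'embedding_size': 128, 'learning_rate': 2e-05, 'current_epoch': 10948}
new = {'mhsa_layers': 8, 'embedding_size': 128, 'learning_rate': 1e-05}
print(incompatible_keys(new, saved, updatable))  # []  (only updatable keys differ)
new['embedding_size'] = 256
print(incompatible_keys(new, saved, updatable))  # ['embedding_size']
```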


In [46]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 5975


In [19]:
# @tf.function   (only slows things down [considerably!])
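def make_tf_dataset(num, random_index=False, SUB_probability=0.0):
    # Builds `num` random batches in memory and wraps them in a tf.data.Dataset
    # whose elements are (x, y) batches of shape (batch_size, sequence_len).
    # Note: `random_index` is currently unused; samples are always drawn at random.
    dx=[]
    dy=[]
    for i in range(num):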
        x,y=get_sample_batch(td, params['batch_size'], params['sequence_len'], SUB_probability=SUB_probability)
        if i<1:
            print(f"[{num} x]: {x.shape} -> {y.shape}")
        dx.append(x)
        dy.append(y)
    dx=np.array(dx)
    dy=np.array(dy)
    print(f"dx.shape={dx.shape}, dy.shape={dy.shape}")
    data_xy = (dx, dy)
    tf_dataset=tf.data.Dataset.from_tensor_slices(data_xy)
    return tf_dataset

In [20]:
MAX_NUM_BATCHES = 50000

if num_batches>MAX_NUM_BATCHES:
    restricted_batches=MAX_NUM_BATCHES
    print(f"Restricting {num_batches} to a maximum of {restricted_batches}")
else:
    restricted_batches=num_batches
    print(f"{restricted_batches} batches")
if cached_batch_data == restricted_batches and textlib_dataset is not None:
    print("Reusing cached training-data")
else:
    print("Creating dataset, this is slow. Be patient...")
    textlib_dataset = make_tf_dataset(restricted_batches, SUB_probability=SUB_PROBABILITY)
    cached_batch_data = restricted_batches
    print("Dataset done and cached.")

5975 batches
Creating dataset, this is slow. Be patient...
[5975 x]: (256, 80) -> (256, 80)
dx.shape=(5975, 256, 80), dy.shape=(5975, 256, 80)
Dataset done and cached.
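`make_tf_dataset` materializes every batch in host memory before `from_tensor_slices` wraps it, so the footprint is worth a quick check. Using the shapes printed above, `(5975, 256, 80)` of int32 for both `dx` and `dy`, and ignoring any extra copies tf.data may make:

```python
import numpy as np

# Shapes taken from the output above; int32 uses 4 bytes per token id.
shape = (5975, 256, 80)
bytes_per_array = int(np.prod(shape)) * np.dtype(np.int32).itemsize
total_bytes = 2 * bytes_per_array  # dx (inputs) and dy (targets)
print(bytes_per_array, total_bytes)      # 489472000 978944000
print(f"{total_bytes / 2**30:.2f} GiB")  # 0.91 GiB
```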


In [21]:
shuffle_buffer=10000
if ml_env.is_tpu is True:
    dataset=textlib_dataset.shuffle(shuffle_buffer).repeat()  # Otherwise TPU may run dry
else:
    dataset=textlib_dataset.shuffle(shuffle_buffer)  
dataset.take(1)

<TakeDataset element_spec=(TensorSpec(shape=(256, 80), dtype=tf.int32, name=None), TensorSpec(shape=(256, 80), dtype=tf.int32, name=None))>

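`Dataset.shuffle(buffer_size)` keeps a buffer of `buffer_size` elements, emits a randomly chosen one, and refills that slot with the next input element. Here the dataset holds 5975 batches and `shuffle_buffer=10000`, so the buffer covers the whole dataset and the batch order is fully shuffled (the samples inside each batch were already drawn at random). A pure-Python sketch of the buffering idea, as an illustration rather than TensorFlow's actual code:

```python
import random

def buffered_shuffle(iterable, buffer_size, rng=random):
    """Yield items in shuffled order using a fixed-size buffer."""
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        # Buffer full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain the remaining buffer in random order.
    rng.shuffle(buffer)
    yield from buffer

out = list(buffered_shuffle(range(10), buffer_size=3, rng=random.Random(0)))
print(sorted(out) == list(range(10)))  # True: every element appears exactly once
```

With a small buffer, elements can only move a limited distance: with `buffer_size=3`, the first element emitted is always one of the first three inputs.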
In [22]:
if ml_env.is_tpu is False:
    validation_dataset = make_tf_dataset(10, random_index=True, SUB_probability=SUB_PROBABILITY)

In [23]:
def model_mhsa(inputs, params):
    dense = layers.Dense(params['vocab_size'], kernel_regularizer=regularizers.l2(params['l2_regularizer']))  # no softmax here: the model outputs logits, so generation can apply a temperature and the loss uses from_logits=True
    fl = layers.Flatten()
    dr = layers.Dropout(params['dropout'])
    pe = PositionalEncoding(amplitude=0.3)
    rs_up = layers.Reshape(target_shape=(SEQUENCE_LEN,vocabulary_size))
    if 0 in params['units']:  # XXX remove!
        lstm1 = layers.LSTM(units=vocabulary_size, return_sequences=True)
    if vocabulary_size>=300:
        emb=layers.Embedding(vocabulary_size,params['embedding_size'],input_length=params['sequence_len'])
    rs_down = layers.Reshape(target_shape=(SEQUENCE_LEN,vocabulary_size))
    mhsa=[]
    residuals=[]

    for i in range(params['mhsa_layers']):
        if params['units'][i]==0:  # XXX remove!
            mhsa.append(None)
            residuals.append(i)
        else:
            mhsa.append(MultiHeadSelfAttention(params['heads'][i], units=params['units'][i], norm=params['norm'], mh_normalize=params['mh_normalize'], join_heads_by_add=params['join_heads_by_add'], recurrent=params['recurrent']))
    xint = tf.cast(inputs,dtype=tf.int32)
    if vocabulary_size<300:
        x = tf.one_hot(xint, params['vocab_size'], axis=-1)
    else:
        x = emb(xint)
    x = pe(x)
    # if params['recurrent'] is True:
    #     mem = x
    for i in range(len(mhsa)):
        if i in residuals:  # XXX remove!
            x = rs_down(lstm1(rs_up(x)))+x
            print(f"Residual at layer {i} added.")
        else:
            # if params['recurrent'] is True:
            #     x, mem = mhsa[i](x, mem)
            x = mhsa[i](x)
        # x = mhsa[i](x,x)
    if params['dropout']>0.0:
        x = dr(x)
    # x = dense(fl(x))
    x = dense(x)
    return x 
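`model_mhsa` stacks `MultiHeadSelfAttention` layers from ml-indie-tools, whose internals are not shown here. For orientation, a generic NumPy sketch of one multi-head self-attention layer with softmax-normalized scaled dot-product attention per head, illustrating the two `join_heads_by_add` strategies described in `params` (concatenation as in "Attention Is All You Need" vs. ReLU plus summation of the heads). The weight shapes, the random initialization and the output projection `w_o` are assumptions, not the library's implementation; the sketch only mirrors the fact that the layers keep the `(80, 128)` activation shape.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mhsa_sketch(x, heads, units, join_by_add, rng):
    """Generic multi-head self-attention on x of shape (seq_len, d_model).

    Each head projects x to queries/keys/values of width `units` and computes
    softmax(Q K^T / sqrt(units)) V. Heads are summed after a ReLU
    (join_by_add=True) or concatenated (join_by_add=False); a final projection
    maps back to d_model. Random weights are for illustration only.
    """
    seq_len, d_model = x.shape
    outs = []
    for _ in range(heads):
        wq, wk, wv = (rng.normal(0, d_model**-0.5, (d_model, units)) for _ in range(3))
        q, k, v = x @ wq, x @ wk, x @ wv
        attn = softmax(q @ k.T / np.sqrt(units))  # (seq_len, seq_len), rows sum to 1
        outs.append(attn @ v)                     # (seq_len, units)
    if join_by_add:
        joined = sum(np.maximum(o, 0.0) for o in outs)  # (seq_len, units)
    else:
        joined = np.concatenate(outs, axis=-1)          # (seq_len, heads * units)
    w_o = rng.normal(0, joined.shape[-1]**-0.5, (joined.shape[-1], d_model))
    return joined @ w_o                                 # (seq_len, d_model)

rng = np.random.default_rng(0)
x = rng.normal(size=(80, 128))  # sequence_len x embedding_size, as in params
print(mhsa_sketch(x, heads=4, units=512, join_by_add=True, rng=rng).shape)   # (80, 128)
print(mhsa_sketch(x, heads=4, units=512, join_by_add=False, rng=rng).shape)  # (80, 128)
```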

In [24]:
def mhsa_generate(model, text, gen_len=64, temperature=0.9, argmax=False, verbose=False):
    if verbose is True:
        full=text[:-1]
    gen_text=""
    lf=0
    input = np.array(td.encode(text))
    while len(input) < params['sequence_len']:
        input = np.concatenate([td.encode('<pad>'),input])
    for i in range(gen_len):
        input = np.concatenate([input[1:],td.encode('<subst>')])
        if len(input)!=params['sequence_len']:
            print('assertion failure')
            return None
        pred = model(input)
        pred /= temperature
        pred = tf.keras.layers.Softmax()(pred)
        if tf.executing_eagerly() is True and ml_env.is_tpu is False:
            pred=pred.numpy()
        else:
            pred=tf.keras.backend.eval(pred)  # this is a cheat: it internally uses .numpy() too.
        if argmax is True:
            pred=np.argmax(pred[0],axis=1)
        else:
            pred = [np.random.choice(list(range(len(pred[0][-1]))), p=pred[0][-1])]
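        # Replace the trailing <subst> placeholder with the sampled token. (Shifting
        # again here would leave a <subst> before every generated token in the context.)
        input[-1] = pred[-1]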
        c = td.decode([pred[-1]])
        if verbose is True:
            print(c, end='')
            if c=='\n':
                lf=0
            else:
                lf += 1
                if (lf>80 and c==' ') or lf>120:
                    print()
                    lf=0
            full+=c
        gen_text+=c
    if verbose is True:
        print()
    return gen_text
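`mhsa_generate` divides the logits of the last position by `temperature`, applies a softmax, and samples the next token (or takes the argmax). A NumPy-only sketch of that sampling step, useful for seeing how temperature reshapes the distribution; the logits are made up, and `softmax_t`/`sample_next` are illustrative names:

```python
import numpy as np

def softmax_t(logits, temperature=1.0):
    """Softmax of logits / temperature (numerically stabilized)."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    p = np.exp(z - z.max())
    return p / p.sum()

def sample_next(logits, temperature=0.9, argmax=False, rng=None):
    """Pick the next token id from the logits of the last position."""
    if argmax:
        return int(np.argmax(logits))
    rng = np.random.default_rng() if rng is None else rng
    p = softmax_t(logits, temperature)
    # Sampling by index avoids building a Python list of all token ids.
    return int(rng.choice(len(p), p=p))

logits = [2.0, 1.0, 0.1]  # made-up logits for a 3-token vocabulary
for t in (0.6, 1.0, 2.0):
    print(t, softmax_t(logits, t).round(3))  # lower T sharpens, higher T flattens
print(sample_next(logits, argmax=True))  # 0
```

Dividing by a temperature of zero is undefined, which is why the training callback maps `temp==0.0` to `argmax=True` with `temperature=1.0`.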


In [25]:
if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        print("Creating TPU-scope model")
        inputs = keras.Input(shape=(params['sequence_len'],))
        outputs = model_mhsa(inputs, params)
        model = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
    print("Creating Default-scope model")
    inputs = keras.Input(shape=(params['sequence_len'],))
    outputs = model_mhsa(inputs, params)
    model_cpu = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
else:
    inputs = keras.Input(shape=(params['sequence_len'],))
    outputs = model_mhsa(inputs, params)
    model = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
    model_cpu = model

Creating TPU-scope model
Creating Default-scope model


In [26]:
def get_newest_checkpoint(checkpoint_dir):
    files = os.listdir(checkpoint_dir)
    paths = [os.path.join(checkpoint_dir, basename) for basename in files]
    return max(paths, key=os.path.getctime)

def import_previous_compatible_checkpoint(model, force_import=False):
    meta = read_model_metadata(suffix=model_suffix)
    if meta is None:
        print("No previous checkpoint found")
        return False
    if is_metadata_compatible(params, meta) is not True and force_import is False:
        print("No useable import found.")
        return False
    try:
        last_checkpoint = get_newest_checkpoint(checkpoint_dir) # tf.train.latest_checkpoint(checkpoint_dir) does not find the .h5 weight files, hence the ctime-based search
    except Exception as e:
        print(f"Cannot determine last checkpoint in {checkpoint_dir}, cannot import due to: {e}")
        return False
    print(f"Last checkpoint: {last_checkpoint}")
    try:
        model.load_weights(last_checkpoint)
    except Exception as e:
        print(f"Failed to import model {last_checkpoint}: {e}")
        return False
    if 'current_epoch' in meta:
        params['current_epoch'] = meta['current_epoch']
    print(f"Successful import of epoch {params['current_epoch']} from {last_checkpoint}, continuing from there...")
    return True
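`get_newest_checkpoint` returns the file with the newest creation time among *all* files in the checkpoint directory. On the non-TPU path, `ModelCheckpoint` typically writes TensorFlow-format files (`ckpt_N.index`, `ckpt_N.data-...` and a `checkpoint` index file), so the newest file is not necessarily one that `load_weights` accepts. A hypothetical alternative that only considers the TPU callback's `cp-{epoch:04d}.h5` files and orders them by the epoch number in the name (a plain string sort would rank `cp-9999.h5` above `cp-10799.h5`):

```python
import os
import re

CP_PATTERN = re.compile(r"^cp-(\d+)\.h5$")  # matches names like cp-10799.h5

def newest_h5_checkpoint(filenames):
    """Return the cp-<epoch>.h5 name with the highest epoch, or None."""
    best, best_epoch = None, -1
    for name in filenames:
        m = CP_PATTERN.match(os.path.basename(name))
        if m and int(m.group(1)) > best_epoch:
            best, best_epoch = name, int(m.group(1))
    return best

files = ["checkpoint", "cp-9999.h5", "cp-10799.h5", "cp-10599.h5", "ckpt_12.index"]
print(newest_h5_checkpoint(files))  # cp-10799.h5
```

In the notebook this would be called as `newest_h5_checkpoint(os.listdir(checkpoint_dir))`, joined with `checkpoint_dir`. Note also that the metadata is saved every epoch while the TPU path writes weights only every `sample_every_n_epochs`; this is why the import log further down reports epoch 10948 loaded from `cp-10799.h5`.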

### Loss function, optimizer, tensorboard output

In [27]:
kscc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

def loss(labels, logits):
    vl=kscc(labels, logits)
    return vl
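With `from_logits=True`, `SparseCategoricalCrossentropy` applies a log-softmax to the logits and takes the negative log-probability of the integer label at each position; `reduction=NONE` keeps one value per token, here of shape `(batch, 80)`. A NumPy sketch of that per-token computation on a tiny made-up example (`sparse_ce_from_logits` is an illustrative name):

```python
import numpy as np

def sparse_ce_from_logits(labels, logits):
    """Per-position cross-entropy: labels (...,) int, logits (..., vocab) float."""
    logits = np.asarray(logits, dtype=np.float64)
    m = logits.max(axis=-1, keepdims=True)
    # Numerically stable log-softmax over the vocabulary axis.
    log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, np.asarray(labels)[..., None], axis=-1)
    return -picked[..., 0]

logits = np.array([[[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])  # batch=1, seq=2, vocab=3
labels = np.array([[0, 2]])
per_token = sparse_ce_from_logits(labels, logits)
print(per_token.shape)  # (1, 2)
print(per_token)        # second value is log(3), about 1.0986 (uniform logits)
```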

In [47]:
# Keras optimizers take `clipvalue` (not `clip_value`); None disables clipping.
opt_kwargs = {'learning_rate': lr}
if params['clipvalue'] is not None:
    opt_kwargs['clipvalue'] = params['clipvalue']

if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        opti = tf.keras.optimizers.Adam(**opt_kwargs)
else:
    opti = tf.keras.optimizers.Adam(**opt_kwargs)

if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        model.compile(optimizer=opti, loss=loss, metrics=[], run_eagerly=False, jit_compile=True)
else:
    model.compile(optimizer=opti, loss=loss, metrics=['accuracy'])

In [48]:
import_checkpoint = True
force_import = False   # True: ignore metadata and try import anyway. This will of course crash, if the new model doesn't fit the checkpoint-data...

if import_checkpoint is True:
    import_previous_compatible_checkpoint(model, force_import=force_import)

Last checkpoint: /content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf/training_checkpoints_8x(4, 4, 4, 4, 4, 4, 4, 4)x(512, 512, 512, 512, 512, 512, 512, 512)x20000/cp-10799.h5
Successful import of epoch 10948 from /content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf/training_checkpoints_8x(4, 4, 4, 4, 4, 4, 4, 4)x(512, 512, 512, 512, 512, 512, 512, 512)x20000/cp-10799.h5, continuing from there...


In [49]:
model.summary()

Model: "mhsa_v1_tf"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 80)]              0         
                                                                 
 tf.cast (TFOpLambda)        (None, 80)                0         
                                                                 
 embedding (Embedding)       (None, 80, 128)           2560000   
                                                                 
 positional_encoding (Positi  (None, 80, 128)          0         
 onalEncoding)                                                   
                                                                 
 multi_head_self_attention (  (None, 80, 128)          1344000   
 MultiHeadSelfAttention)                                         
                                                                 
 multi_head_self_attention_1  (None, 80, 128)          1
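The embedding row of the summary can be checked by hand: it is a `vocab_size × embedding_size` lookup table. The final `Dense(vocab_size)` layer, applied to the last axis of the `(80, 128)` activations, has a `128 × vocab_size` kernel plus one bias per output. The summary above is cut off before that row, so the second figure is computed, not copied:

```python
vocab_size, embedding_size = 20000, 128

embedding_params = vocab_size * embedding_size           # lookup table
dense_params = embedding_size * vocab_size + vocab_size  # kernel + bias
print(embedding_params, dense_params)  # 2560000 2580000
```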

In [50]:
TPU_GENERATE_ON_CPU = False  # The thing is: both options are slow on TPU :-/

class ServiceCallback(keras.callbacks.Callback):
#    def on_test_end(self, logs=None):
    # @tf.function
    def on_epoch_end(self, epoch, logs=None):
        save_model_metadata(epoch, suffix=model_suffix)
        if (epoch+1) % params['sample_every_n_epochs'] == 0:
            idx=random.randint(0,len(td)-1)
            text=td.decode(td[idx])
            print()
            if ml_env.is_tpu is True:
                temp_list=[0.7] # [0.6,0.7,0.8]
                gen_len=50
                with tpu_strategy.scope():
                    weights=model.get_weights()
                model_cpu.set_weights(weights)
                # The HDF5 format is required for saving weights that originate from the TPU;
                # otherwise saving just silently fails...
                checkpoint_path = os.path.join(checkpoint_dir, "cp-{epoch:04d}.h5")
                chkpt_dest=checkpoint_path.format(epoch=epoch)
                print(f"Checkpoint: {chkpt_dest}")
                model_cpu.save_weights(chkpt_dest)
            else:
                temp_list=[0.6, 0.7, 0.8]
                gen_len=192
            print(f"prompt: {text}")
            for temp in temp_list:
                print(f"---------------- T={temp} ---------------")
                if ml_env.is_tpu is True and TPU_GENERATE_ON_CPU is True:
                    with tf.device('/cpu:0'):
                        if temp==0.0:
                            reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=1.0, argmax=True, verbose=False)
                        else:
                            reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=temp, verbose=False)
                else:
                    if temp==0.0:
                        reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=1.0, argmax=True, verbose=False)
                    else:
                        reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=temp, verbose=False)
                td.source_highlight(reply, min_quote_size=10, dark_mode=use_dark_mode, display_ref_anchor=False)
            print("--------------------------------------")

service_callback=ServiceCallback()

In [51]:
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

logdir = os.path.join(log_path, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
if ml_env.is_tpu:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='epoch', write_graph=False)
else:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='batch')


In [52]:
# Don't try:
#    # use the python variable log_path:
#   get_ipython().run_line_magic('tensorboard', '--logdir "{log_path}"')
#except:
#   pass

# The following throws errors outside of Colab, but the guarding above is too bug-ridden.
# if ml_env.is_tpu is False:
#    %tensorboard --logdir logs

## The actual training

In [53]:
EPOCHS=500000
if 'current_epoch' in params:
    initial_epoch=params['current_epoch']
else:
    initial_epoch=0

override=200
print(f"WARNING: overriding sample_every_n_epochs (sample generation and TPU checkpoint interval) with {override}")
params['sample_every_n_epochs']=override



In [54]:
if ml_env.is_tpu is True:
    steps_per_epoch=restricted_batches//params['batch_size']
    if steps_per_epoch < 1:
        steps_per_epoch = 1
    history = model.fit(dataset, epochs=EPOCHS, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, callbacks=[service_callback]) # On TPU we roll our own checkpointing, since the weights must first be transferred to the CPU model
else:
    history = model.fit(dataset, validation_data=validation_dataset, epochs=EPOCHS, initial_epoch=initial_epoch, callbacks=[checkpoint_callback, tensorboard_callback, service_callback])
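On the TPU path, `restricted_batches` already counts batches, and `steps_per_epoch` divides it by `batch_size` once more. Each Keras "epoch" is therefore a short slice of the data, which matches the `5/23` progress bar in the training log and explains the large epoch numbers. The arithmetic, using the record count and batch size from above:

```python
n_records, batch_size = 1529756, 256

n_batches = n_records // batch_size          # batches in the prepared dataset
steps_per_epoch = n_batches // batch_size    # as computed in the TPU branch
records_per_epoch = steps_per_epoch * batch_size
print(n_batches, steps_per_epoch, records_per_epoch)  # 5975 23 5888
```

One full pass over the 5975 prepared batches thus corresponds to roughly 260 of these short epochs.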

Epoch 10949/500000
 5/23 [=====>........................] - ETA: 0s - loss: 0.2416



Epoch 10950/500000
...
Epoch 11000/500000
Checkpoint: /content/drive/My D

--------------------------------------
Epoch 11001/500000
...
Epoch 11050/500000
Epoch 11051

--------------------------------------
Epoch 11201/500000
...
Epoch 11250/500000
Epoch 11251

--------------------------------------
Epoch 11401/500000
Epoch 11402/500000
Epoch 11403/500000
Epoch 11404/500000
Epoch 11405/500000
Epoch 11406/500000
Epoch 11407/500000
Epoch 11408/500000
Epoch 11409/500000
Epoch 11410/500000
Epoch 11411/500000
Epoch 11412/500000
Epoch 11413/500000
Epoch 11414/500000
Epoch 11415/500000
Epoch 11416/500000
Epoch 11417/500000
Epoch 11418/500000
Epoch 11419/500000
Epoch 11420/500000
Epoch 11421/500000
Epoch 11422/500000
Epoch 11423/500000
Epoch 11424/500000
Epoch 11425/500000
Epoch 11426/500000
Epoch 11427/500000
Epoch 11428/500000
Epoch 11429/500000
Epoch 11430/500000
Epoch 11431/500000
Epoch 11432/500000
Epoch 11433/500000
Epoch 11434/500000
Epoch 11435/500000
Epoch 11436/500000
Epoch 11437/500000
Epoch 11438/500000
Epoch 11439/500000
Epoch 11440/500000
Epoch 11441/500000
Epoch 11442/500000
Epoch 11443/500000
Epoch 11444/500000
Epoch 11445/500000
Epoch 11446/500000
Epoch 11447/500000
Epoch 11448/500000
Epoch 11449/500000
Epoch 11450/500000
Epoch 11451

--------------------------------------
Epoch 11601/500000
Epoch 11602/500000
Epoch 11603/500000
Epoch 11604/500000
Epoch 11605/500000
Epoch 11606/500000
Epoch 11607/500000
Epoch 11608/500000
Epoch 11609/500000
Epoch 11610/500000
Epoch 11611/500000
Epoch 11612/500000
Epoch 11613/500000
Epoch 11614/500000
Epoch 11615/500000
Epoch 11616/500000
Epoch 11617/500000
Epoch 11618/500000
Epoch 11619/500000
Epoch 11620/500000
Epoch 11621/500000
Epoch 11622/500000
Epoch 11623/500000
Epoch 11624/500000
Epoch 11625/500000
Epoch 11626/500000
Epoch 11627/500000
Epoch 11628/500000
Epoch 11629/500000
Epoch 11630/500000
Epoch 11631/500000
Epoch 11632/500000
Epoch 11633/500000
Epoch 11634/500000
Epoch 11635/500000
Epoch 11636/500000
Epoch 11637/500000
Epoch 11638/500000
Epoch 11639/500000
Epoch 11640/500000
Epoch 11641/500000
Epoch 11642/500000
Epoch 11643/500000
Epoch 11644/500000
Epoch 11645/500000
Epoch 11646/500000
Epoch 11647/500000
Epoch 11648/500000
Epoch 11649/500000
Epoch 11650/500000
Epoch 11651

--------------------------------------
Epoch 11801/500000
Epoch 11802/500000
Epoch 11803/500000
Epoch 11804/500000
Epoch 11805/500000
Epoch 11806/500000
Epoch 11807/500000
Epoch 11808/500000
Epoch 11809/500000
Epoch 11810/500000
Epoch 11811/500000
Epoch 11812/500000
Epoch 11813/500000
Epoch 11814/500000
Epoch 11815/500000
Epoch 11816/500000
Epoch 11817/500000
Epoch 11818/500000
Epoch 11819/500000
Epoch 11820/500000
Epoch 11821/500000
Epoch 11822/500000
Epoch 11823/500000
Epoch 11824/500000
Epoch 11825/500000
Epoch 11826/500000
Epoch 11827/500000
Epoch 11828/500000
Epoch 11829/500000
Epoch 11830/500000
Epoch 11831/500000
Epoch 11832/500000
Epoch 11833/500000
Epoch 11834/500000
Epoch 11835/500000
Epoch 11836/500000
Epoch 11837/500000
Epoch 11838/500000
Epoch 11839/500000
Epoch 11840/500000
Epoch 11841/500000
Epoch 11842/500000
Epoch 11843/500000
Epoch 11844/500000
Epoch 11845/500000
Epoch 11846/500000
Epoch 11847/500000
Epoch 11848/500000
Epoch 11849/500000
Epoch 11850/500000
Epoch 11851

--------------------------------------
Epoch 12001/500000
Epoch 12002/500000
Epoch 12003/500000
Epoch 12004/500000
Epoch 12005/500000
Epoch 12006/500000
Epoch 12007/500000
Epoch 12008/500000
Epoch 12009/500000
Epoch 12010/500000
Epoch 12011/500000
Epoch 12012/500000
Epoch 12013/500000
Epoch 12014/500000
Epoch 12015/500000
Epoch 12016/500000
Epoch 12017/500000
Epoch 12018/500000
Epoch 12019/500000
Epoch 12020/500000
Epoch 12021/500000
Epoch 12022/500000
Epoch 12023/500000
Epoch 12024/500000
Epoch 12025/500000
Epoch 12026/500000
Epoch 12027/500000
Epoch 12028/500000
Epoch 12029/500000
Epoch 12030/500000
Epoch 12031/500000
Epoch 12032/500000
Epoch 12033/500000
Epoch 12034/500000
Epoch 12035/500000
Epoch 12036/500000
Epoch 12037/500000
Epoch 12038/500000
Epoch 12039/500000
Epoch 12040/500000
Epoch 12041/500000
Epoch 12042/500000
Epoch 12043/500000
Epoch 12044/500000
Epoch 12045/500000
Epoch 12046/500000
Epoch 12047/500000
Epoch 12048/500000
Epoch 12049/500000
Epoch 12050/500000
Epoch 12051

--------------------------------------
Epoch 12201/500000
Epoch 12202/500000
Epoch 12203/500000
Epoch 12204/500000
Epoch 12205/500000
Epoch 12206/500000
Epoch 12207/500000
Epoch 12208/500000
Epoch 12209/500000
Epoch 12210/500000
Epoch 12211/500000
Epoch 12212/500000
Epoch 12213/500000
Epoch 12214/500000
Epoch 12215/500000
Epoch 12216/500000
Epoch 12217/500000
Epoch 12218/500000
Epoch 12219/500000
Epoch 12220/500000
Epoch 12221/500000
Epoch 12222/500000
Epoch 12223/500000
Epoch 12224/500000
Epoch 12225/500000
Epoch 12226/500000
Epoch 12227/500000
Epoch 12228/500000
Epoch 12229/500000
Epoch 12230/500000
Epoch 12231/500000
Epoch 12232/500000
Epoch 12233/500000
Epoch 12234/500000
Epoch 12235/500000
Epoch 12236/500000
Epoch 12237/500000
Epoch 12238/500000
Epoch 12239/500000
Epoch 12240/500000
Epoch 12241/500000
Epoch 12242/500000
Epoch 12243/500000
Epoch 12244/500000
Epoch 12245/500000
Epoch 12246/500000
Epoch 12247/500000
Epoch 12248/500000
Epoch 12249/500000
Epoch 12250/500000
Epoch 12251

KeyboardInterrupt: ignored

## A dialog with the trained model

In [55]:
# Copy the trained weights into model_cpu, which is used for the interactive dialog below.
model_cpu.set_weights(model.get_weights())

In [60]:
# Run a dialog with the transformer model trained above.
# Each prompt is continued by mhsa_generate(); mhsa_generate, td and
# use_dark_mode are defined in earlier cells.

def doDialog(model):
    temperature = 0.6
    print("Please enter some dialog.")
    print("The net will answer according to your input.")
    print("'bye' for end,")
    print("'reset' to reset the conversation context,")
    print("'temperature=<float>' [0.1(frozen)-1.0(creative)]")
    print("    to change character of the dialog.")
    print("    Current temperature={}.".format(temperature))
    print()
    bye = False
    while not bye:
        print("> ", end="")
        prompt = input()
        if prompt == 'bye':
            bye = True
            print("Good bye!")
            continue
        if prompt == 'reset':
            # doDialog keeps no state between prompts: every reply is
            # generated from the current prompt only.
            print("(conversation context reset)")
            print()
            continue
        if prompt.startswith("temperature="):
            try:
                t = float(prompt[len("temperature="):])
            except ValueError:
                t = None
            if t is not None and 0.05 < t < 1.4:
                temperature = t
                print("(generator temperature now {})".format(t))
                print()
                continue
            print("Invalid temperature value ignored! (accepted range: 0.05 < t < 1.4)")
            continue
        reply = mhsa_generate(model, prompt, gen_len=256, temperature=temperature, verbose=True)
        # Highlight passages of the reply that appear verbatim in the training texts.
        td.source_highlight(reply, min_quote_size=13, dark_mode=use_dark_mode)
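The help text describes temperature as ranging from "frozen" to "creative". `mhsa_generate` is defined in an earlier cell and not shown here, so the following is only a sketch assuming it follows the common convention: divide the next-token logits by the temperature, apply a softmax, and sample from the result. The helper names `softmax_with_temperature` and `sample_with_temperature` are hypothetical.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Return softmax(logits / temperature) as a list of probabilities."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_with_temperature(logits, temperature, rng=random):
    """Draw one token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.0]  # hypothetical scores for three candidate tokens
    for t in (0.1, 0.6, 1.2):
        probs = softmax_with_temperature(logits, t)
        print(t, [round(p, 3) for p in probs])
    print(sample_with_temperature(logits, 0.6, rng=random.Random(0)))
```

Low temperatures concentrate almost all probability on the highest-scoring token, which gives repetitive but well-formed text; high temperatures flatten the distribution, which is consistent with the noticeably more garbled reply produced at `temperature=1.2` in the dialog.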

In [61]:
# Talk to the net!
doDialog(model_cpu)

Please enter some dialog.
The net will answer according to your input.
'bye' for end,
'reset' to reset the conversation context,
'temperature=<float>' [0.1(frozen)-1.0(creative)]
    to change character of the dialog.
    Current temperature=0.6.

> temperature=1.2
(generator temperature now 1.2)

> Happy day! Glorious! Wonderful sunshine, luck abound!
Ellen, I convinced
Catherines, an ever sudden drawing.”

In ble turned off, is in the country; and
where rown time to look to the man who,
trank was cle-solges,
that se of low el, and spoke, to all
 know, was not a very dull, no
he saw By you every use, which excite attentivers the famor, which was,
but Harley setting his knight:

“This witters came merce Mrs. Puched
weth their rooms, so commat he
should that question is at the tree
prospectantat her doubt, been an engaged of Servant any the door of the
arches, as one Waterlood out over the difference I wanted on, and 


"You think particularly I do not want to have hin, your speculby
is

> temperature=0.8
(generator temperature now 0.8)

> Happy day, she felt so lucky to finally meet him.

He just with the with unexpectedly and was sohow
that an answer, that looked in his nonsense,
thought herself so decided among themselves
nothing, and what banklace minute
and commencing, there wasted to her young of
every roof; and she seasonably result of
last night why their books that
more delicitous, disable remone
much for Sir, on being equal to her him. Emma was not sorry for him. He cures, makel to
living; and as nobody anything on
his eye, I know how is
perfectly too very greatest indeed
being lay for vacity as Mrs. Briting
and been d
eepasy, theghts again, of cautions, while every cargance; but I could do doubt Hern every other wretched something else
seemed to that, and then I was not the same
windows to Miss Allening out in Lincontalims
Bimately he would meet a new
like of liferk knife in
the d-post, besides, their engagements,--he small, interested wild women all Don;
th

> temperature=0.7
(generator temperature now 0.7)

> And on and on she went, it was a very long voyage into uncertainty.
 English-grass realial request
which he had the desire to keep them all unssion for
my temper, I mean that her termen
the little ity of the sea-blood bars and words with a emotion. Susan had not toonxious knowledge of them, to understanding,
“Lady Seal’s sense
rather of one gentleman opportunity of the firstzed these reality else, as she down to
disappoint without any glove
without his daughter that Perrys,” said Lady Susan.

“Oh! I must tell you one can be
so expressed his serious. His great were
so possibility of needed that
E
dmund asked her to recollecting your visit.
He saw it garden—I drew up in her eyes
to have had her opposite surprised speediment and had become shape of
consolation, why had now should not have display an opportunity of being together.

She called 
her court opportunity of speaking, her feet
part of an admiration; and by every thing’s
sity, a

> bye
Good bye!
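Each reply is passed to `td.source_highlight(reply, min_quote_size=13, ...)`, which, going by its name and arguments, marks passages that the model copied verbatim from the training texts. The ml-indie-tools implementation is not shown here; the sketch below is a hypothetical, naive stand-in (`find_verbatim_quotes`), assuming `min_quote_size` is a minimum length in characters.

```python
def find_verbatim_quotes(generated, corpus, min_quote_size=13):
    """Return (start, end) spans of `generated` that occur verbatim in `corpus`.

    Naive greedy scan: from each position, extend the candidate one character
    at a time while it still occurs in the corpus, and keep it if it reaches
    min_quote_size characters. Quadratic, so only suitable for small inputs.
    """
    spans = []
    i = 0
    n = len(generated)
    while i < n:
        end = i
        # Grow the candidate while it is still a substring of the corpus.
        while end < n and generated[i:end + 1] in corpus:
            end += 1
        if end - i >= min_quote_size:
            spans.append((i, end))
            i = end  # continue scanning after the quote
        else:
            i += 1
    return spans


if __name__ == "__main__":
    corpus = "It is a truth universally acknowledged, that a single man"
    reply = "Oh! It is a truth universally known."
    for start, end in find_verbatim_quotes(reply, corpus):
        print(repr(reply[start:end]))
```

A production version would index the corpus (for example with a suffix array) instead of running repeated substring searches, but the greedy longest-match idea is the same.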
