<a href="https://colab.research.google.com/github/domschl/transformer-poet/blob/main/transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that supports more environment-independent code. It will access your Google Drive when used with Google Colab.

In [None]:
!pip install -U ml-indie-tools

In [1]:
import logging
import os
import sys
import copy
import json
import time
import datetime
import random

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers, regularizers

import tensorflow_datasets as tfds

In [2]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.keras_custom_layers import MultiHeadSelfAttention, PositionalEncoding

Using TF-Keras version: 2.9.0


## Preliminary

A TensorFlow deep multi-head self-attention model for text generation.

This code can run on CPU, GPU, or TPU when used with Google Colab.
Select the corresponding runtime (menu: **`Runtime / Change runtime type`**).

## 0. Environment

In [3]:
cached_batch_data = None   # Avoids regenerating time-consuming training data if it is already cached.

ml_env = MLEnv(platform='tf', accelerator='fastest')
ml_env.describe()

'OS: Darwin, Python: 3.10.5 (Conda), Jupyter Notebook Tensorflow: 2.9.1, GPU: METAL'

In [4]:
use_eager=tf.executing_eagerly()
if ml_env.is_tpu is True:
    tpu_strategy = ml_env.tpu_strategy
    tpu_is_init=True
    if use_eager is True:
        tf.config.run_functions_eagerly(False)
    use_eager=False

In [5]:
project_name='women_writers'
model_name='mhsa_v1_tf'

# NOTICE: This will request access to Google Drive if running on Google Colab. Google Drive is used to store snapshots
# and training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools 
#
# Note: you need to allow popups in your browser for Colab, otherwise you won't see the Google Drive login box, and drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path)")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Root path (all projects) : . (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : . (Changes to the file system happen only below this project path)
Model path (snapshots)   : ./model/mhsa_v1_tf (Model weights and snapshots are stored here)
Data path (training data): ./data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


## 1. Text library

The `Text_Dataset` and `Gutenberg_Dataset` classes provide libraries for training, 
encoding, batch generation, and formatted source display. Together, they read 
books from Project Gutenberg and support the creation of training batches. 
The output functions support highlighting, which makes it possible to compare generated 
texts with the actual sources and helps to identify identical (memorized) 
parts.
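
To illustrate the idea of spotting memorized passages, here is a minimal sketch in plain Python. It is not the implementation behind `td.source_highlight`; the function name `find_memorized_spans` and its greedy strategy are assumptions made for illustration. It reports every stretch of generated text that occurs verbatim in one of the sources and is at least `min_quote_size` characters long.

```python
def find_memorized_spans(generated, sources, min_quote_size=10):
    """Return (start, length, source_index) for verbatim quotes in `generated`.

    Hypothetical helper for illustration only, not part of ml-indie-tools.
    """
    spans = []
    pos = 0
    while pos < len(generated):
        best_len, best_src = 0, None
        for src_idx, src in enumerate(sources):
            length = min_quote_size
            # Grow the candidate quote while it still occurs verbatim in the source.
            while pos + length <= len(generated) and generated[pos:pos + length] in src:
                length += 1
            length -= 1  # last length that matched (min_quote_size - 1 if none did)
            if length >= min_quote_size and length > best_len:
                best_len, best_src = length, src_idx
        if best_src is None:
            pos += 1  # no quote starts here, move on by one character
        else:
            spans.append((pos, best_len, best_src))
            pos += best_len  # skip past the quote that was just recorded
    return spans


sources = ["It is a truth universally acknowledged", "I have just returned from a visit"]
generated = "xx a truth universally yy returned from a zz"
for start, length, src_idx in find_memorized_spans(generated, sources):
    print(f"source {src_idx}: {generated[start:start + length]!r}")
```

A larger `min_quote_size` suppresses short accidental matches, which is presumably why the notebook passes values such as 10 or 13 to the highlighting function.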

In [6]:
use_dark_mode=True # Set to False for white background. HTML-text-compare uses background colorization to identify different sources. Those background colors depend on the theme type.

In [7]:
logging.basicConfig(level=logging.INFO)
cache_dir = os.path.join(data_path, 'gutenberg_cache')
gd = Gutenberg_Dataset(cache_dir=cache_dir)

In [8]:
# sample searches
search_spec= {"author": ["Emily Brontë","Jane Austen", "Virginia Woolf"], "language": ["english"]}

book_list=gd.search(search_spec)
book_cnt = len(book_list)
print(f"{book_cnt} matching books found with search {search_spec}.")
if book_cnt<40:
    # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts, 
    # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
    book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)  
else:
    logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

21 matching books found with search {'author': ['Emily Brontë', 'Jane Austen', 'Virginia Woolf'], 'language': ['english']}.


In [9]:
for i in range(len(book_list)):
    print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

0: The Common Reader - Virginia Woolf, 64457
1: Mr. Bennett and Mrs. Brown - Virginia Woolf, 63022
2: The Younger Sister, Volumes 1-3 - Catherine Anne Austen Hubback and Jane Austen, 54066
3: The Younger Sister, Vol. 3 - Catherine Anne Austen Hubback and Jane Austen, 54012
4: The Younger Sister, Vol. 2 - Catherine Anne Austen Hubback and Jane Austen, 54011
5: The Younger Sister, Vol. 1 - Catherine Anne Austen Hubback and Jane Austen, 54010
6: Pride and Prejudice - Jane Austen, 42671
7: The Letters of Jane Austen - Jane Austen, 42078
8: The Complete Project Gutenberg Works of Jane Austen - Jane Austen, 31100
9: Jacob's Room - Virginia Woolf, 5670
10: Pride and Prejudice - Jane Austen, 1342
11: Night and Day - Virginia Woolf, 1245
12: Love And Friendship And Other Early Works - Jane Austen, 1212
13: Lady Susan - Jane Austen, 946
14: Wuthering Heights - Emily Brontë, 768
15: Sense and Sensibility - Jane Austen, 161
16: Emma - Jane Austen, 158
17: The Voyage Out - Virginia Woolf, 144
18: M

In [10]:
select = ("Bennett", "1342", "Susan", "Wuthering", "Emma", "Voyage")  # List unique single-words from title or ebook_id to select a given book
sub_book_list = [book_list[i] for i in range(len(book_list)) if not set([book_list[i]['ebook_id']]+book_list[i]['title'].split(' ')).isdisjoint(set(select))]

print("Using:")
for i in range(len(sub_book_list)):
    print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

textlib_dataset = None  # Forces re-caching
td = Text_Dataset(sub_book_list)
td.init_tokenizer(tokenizer='ngram', max_ngrams=6, max_tokens=5000)


INFO:Datasets:Loaded 6 texts
INFO:Datasets:Extracting ngrams of length 1..6 from text_list, selecting 5000 most used ngrams.


Using:
1: Mr. Bennett and Mrs. Brown - Virginia Woolf
2: Pride and Prejudice - Jane Austen
3: Lady Susan - Jane Austen
4: Wuthering Heights - Emily Brontë
5: Emma - Jane Austen
6: The Voyage Out - Virginia Woolf


INFO:Datasets:Encoding text corpora as ngrams.
INFO:Datasets:Encoding text Mr. Bennett and Mrs. Brown...
INFO:Datasets:Encoding text Pride and Prejudice...
INFO:Datasets:Encoding text Lady Susan...
INFO:Datasets:Encoding text Wuthering Heights...
INFO:Datasets:Encoding text Emma...
INFO:Datasets:Encoding text The Voyage Out...
INFO:Datasets:Encoding text corpora as ngrams done.
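
The log lines above describe an n-gram tokenizer that keeps the most frequently used n-grams. The following is a minimal sketch of that general idea, not the algorithm used by `Text_Dataset`: the function names, the rule of always keeping every single character, and the greedy longest-match encoding are all assumptions chosen to keep the example small.

```python
from collections import Counter


def build_ngram_vocab(text, max_ngrams=3, max_tokens=20):
    """Collect all single characters plus the most frequent longer n-grams."""
    counts = Counter()
    for n in range(2, max_ngrams + 1):
        for i in range(len(text) - n + 1):
            counts[text[i:i + n]] += 1
    vocab = sorted(set(text))  # single characters guarantee that encoding never fails
    room = max(0, max_tokens - len(vocab))
    vocab += [ngram for ngram, _ in counts.most_common(room)]
    return vocab


def encode(text, vocab, max_ngrams=3):
    """Greedy longest-match encoding of `text` into token ids."""
    t2i = {tok: i for i, tok in enumerate(vocab)}
    ids, pos = [], 0
    while pos < len(text):
        for n in range(max_ngrams, 0, -1):
            piece = text[pos:pos + n]
            if piece in t2i:
                ids.append(t2i[piece])
                pos += len(piece)
                break
        else:
            raise ValueError(f"Character {text[pos]!r} is not in the vocabulary")
    return ids


def decode(ids, vocab):
    """Concatenate the tokens that belong to the given ids."""
    return "".join(vocab[i] for i in ids)


text = "the cat sat on the mat, the cat sat"
vocab = build_ngram_vocab(text)
ids = encode(text, vocab)
print(len(text), "characters ->", len(ids), "tokens")
```

Because frequent character groups become single tokens, the encoded sequence is shorter than the character sequence, so a fixed `SEQUENCE_LEN` covers more text.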


In [11]:
SEQUENCE_LEN = 80
SUB_PROBABILITY = 0.15  # like BERT

td.init_getitem(sample_type='encoded', sample_length=SEQUENCE_LEN, content_stepping=1)

num_records = len(td)

print(f"{num_records} records")

917005 records


In [12]:
def get_sample_batch(td, batch_size, length, SUB_probability=0.15):
    for i in range(batch_size):
        Xi = td.get_random_item()
        yi = Xi.copy()
        l=int(len(Xi)*SUB_probability)
        for li in range(l):
            pos=random.randint(0,len(Xi)-1)
            if td.tokenizer_type=='char':
                Xi[pos]=td.c2i['␚']
            elif td.tokenizer_type=='word':
                Xi[pos]=td.w2i['<subst>']
            elif td.tokenizer_type=='ngram':
                Xi[pos]=td.t2i['<subst>']
            else:
                print(f"Unexpected tokenizer_type {td.tokenizer_type}")
        if i==0:
            # smpX=np.array(Xi, dtype=np.float32)
            smpX=np.array(Xi, dtype=np.int32)
            smpy=np.array(yi, dtype=np.int32)
        else:
            # smpX = np.vstack((smpX, np.array(Xi, dtype=np.float32)))
            smpX = np.vstack((smpX, np.array(Xi, dtype=np.int32)))
            smpy = np.vstack((smpy, np.array(yi, dtype=np.int32)))
    return np.array(smpX), np.array(smpy)

In [13]:
test_x, test_y = get_sample_batch(td, 2, 40, SUB_probability=SUB_PROBABILITY)
for i in range(len(test_x)):
    xi=[int(x) for x in test_x[i]]
    print(f"[{i}](l={len(xi)}): X=>{td.decode(xi)}<, y=>{td.decode(test_y[i])}<")

[0](l=80): X=>, <subst>lately se<subst>d there, and
encumbered with many low conne<subst>ions, but giv<subst>em<subst>ves<subst>mense
airs, and expecting to be <subst> footing with the old established<subst>am<subst>ies. A year <subst><subst>lf is the very utmost that they can have lived
at West Hall; and how they got the<subst>fortune no<, y=>, very lately settled there, and
encumbered with many low connexions, but giving themselves immense
airs, and expecting to be on a footing with the old established
families. A year and a half is the very utmost that they can have lived
at West Hall; and how they got their fortune no<
[1](l=80): X=>wards reas<subst>believe that the beginning <subst>already made, and <subst>not but hope that the gip<subst>y, though she had
_told_ no fortune, might be pro<subst>to have ma<subst><subst>t’s.—About a
fortnight after the a<subst>m, they <subst>to a sufficient explanation, and
quite undesig<subst>ly. Emma was no<subst>nk<, y=>wards reason to believe that

In [14]:
test_x.shape, test_y.shape

((2, 80), (2, 80))
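
The masking loop in `get_sample_batch` draws positions with `random.randint`, so the same position can be picked more than once and the number of distinct `<subst>` tokens can end up below `int(len(Xi)*SUB_probability)`. As a sketch of an alternative, and not a change to the notebook's behavior, the hypothetical helper below uses `random.sample` to mask a fixed number of distinct positions; the token ids and the mask id `-1` are placeholders.

```python
import random


def mask_positions(tokens, mask_token, sub_probability=0.15, rng=random):
    """Return (masked_copy, targets) with int(len * p) distinct positions masked."""
    masked = list(tokens)
    n_mask = int(len(tokens) * sub_probability)
    # random.sample draws without replacement, so no position is hit twice.
    for pos in rng.sample(range(len(tokens)), n_mask):
        masked[pos] = mask_token
    return masked, list(tokens)


def expected_distinct(n, k):
    """Expected number of distinct positions after k draws with replacement."""
    return n * (1 - (1 - 1 / n) ** k)


tokens = list(range(100, 180))  # 80 placeholder token ids
masked, targets = mask_positions(tokens, mask_token=-1, rng=random.Random(0))
n_mask = int(len(tokens) * 0.15)
print("masked without replacement:", masked.count(-1))
print("expected distinct with replacement:", round(expected_distinct(len(tokens), n_mask), 2))
```

Whether the small difference matters for training is an open question; the sketch only makes the effect visible.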

## 2. Use tf.data for texts

In [15]:
def expand_name_template(template, params):
    exp=copy.copy(template)
    for key in params:
        src="{"+key+"}"
        dst=f"{params[key]}"
        exp=exp.replace(src,dst).replace('[','(').replace(']',')')
    return exp

def save_model_metadata(epoch, suffix='std'):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    params['current_epoch'] = epoch
    try:
        with open(meta_file, 'w') as f:
            f.write(json.dumps(params))
    except Exception as e:
        print(f"Failed to store model metadata at {model_path}: {e}")
        return False
    return True

def read_model_metadata(suffix="std"):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    try:
        with open(meta_file, 'r') as f:
            meta = json.load(f)
    except Exception as e:
        print(f"Cannot access project meta-data at {meta_file}: {e}, starting anew.")
        return None
    return meta

def is_metadata_compatible(params, meta):
    is_valid=True
    keys=set(list(params.keys())+list(meta.keys()))
    for key in keys:
        if key in updatable_keys:
            continue
        if key not in meta:
            print(f"Key {key} not available in last checkpoint model_meta, params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif key not in params:
            print(f"Key {key} not available in params, last checkpoint model_meta[{key}]: {meta[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif meta[key]!=params[key]:
            print(f"Last checkpoint model_meta[{key}]: {meta[key]} != params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
    if is_valid is False:
        print("Aborting import.")
        return False
    return True
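
The helpers above can be tried in isolation. The sketch below is a simplified, self-contained re-implementation for illustration: `expand_name_template` is condensed from the cell above, while `incompatible_keys` is a hypothetical stand-in for `is_metadata_compatible` that returns the offending keys instead of printing them.

```python
def expand_name_template(template, params):
    """Fill {key} placeholders and replace brackets, which are awkward in file names."""
    name = template
    for key, value in params.items():
        name = name.replace("{" + key + "}", str(value))
    return name.replace("[", "(").replace("]", ")")


def incompatible_keys(params, meta, updatable_keys):
    """Keys that are missing on one side or differ, ignoring updatable keys."""
    keys = set(params) | set(meta)
    return sorted(
        key for key in keys
        if key not in updatable_keys
        and (key not in params or key not in meta or params[key] != meta[key])
    )


params = {"mhsa_layers": 2, "heads": [1, 1], "units": [128, 128],
          "vocab_size": 5000, "learning_rate": 0.0005}
print(expand_name_template("{mhsa_layers}x{heads}x{units}x{vocab_size}", params))
# prints: 2x(1, 1)x(128, 128)x5000

old_meta = dict(params, units=[64, 64], learning_rate=0.001, current_epoch=7)
print(incompatible_keys(params, old_meta, updatable_keys={"learning_rate", "current_epoch"}))
# prints: ['units']
```

A changed learning rate does not block the import because it is listed as updatable, whereas a changed layer size does, since the stored weights would no longer fit the model.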

In [57]:
vocabulary_size = td.get_unique_token_count()  # vocabulary-size

lyrs = 2;

params = { # Multi-head self-attention
    'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}',

    'mhsa_layers': lyrs, 
    'heads': [1]*lyrs,
    'units': [128]*lyrs,  # 0 inserts an LSTM for memory-states :-)
    'norm': 'softmax', # this is for within each head
    'mh_normalize': True,  # use layer-norm after concatenation (or addition) of the heads
    'l2_regularizer': 1e-9,
    'dropout': 0.0,       # no dropout: 0.0
    'join_heads_by_add': False,  # strategy for joining multi-heads: False: concat (as in all-you-need), True: relu+add of all the heads
    'vocab_size': vocabulary_size,
    'sequence_len': SEQUENCE_LEN,

    'batch_size': 768,
    'learning_rate': 0.0005,
    'clipvalue': None,
    'sample_every_n_epochs': 2,
}

if len(params['heads'])!=params['mhsa_layers'] or len(params['units'])!=params['mhsa_layers']:
    print("ERROR: length of 'heads' and 'units' must be equal to mhsa_layers!")
    
if ml_env.is_tpu is True:
    lr = params['learning_rate']*1.0
else:
    lr = params['learning_rate']

model_suffix = expand_name_template(params['name'], params)
# Put 'important' params in checkpoint-pathname to separate model-data:
checkpoint_dir = os.path.join(model_path, f"training_checkpoints_{model_suffix}")
if os.path.exists(checkpoint_dir) is False:
    os.makedirs(checkpoint_dir)

# When comparing if training-data is compatible with new params set, 
# the following keys are updatable, they can be changed while continuing
# to use existing checkpoints and continue training with those values
# changed:
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'dropout', 
             'sample_every_n_epochs']

# These values are taken from the saved checkpoint:
keep_keys=['current_epoch']

continue_last = True

meta = read_model_metadata(suffix=model_suffix)
if meta is not None and is_metadata_compatible(params, meta) is True and continue_last is True:
    for key in keep_keys:
        if key in meta:
            params[key]=meta[key]
    if params is not None:
        print(f"Continuing last session from epoch {params['current_epoch']}")
    else:
        print(f"No previous data, starting new model")
else:
    print("Starting new model")

print(params)

Cannot access project meta-data at ./model/mhsa_v1_tf/model_meta_2x(1, 1)x(128, 128)x5000.json: [Errno 2] No such file or directory: './model/mhsa_v1_tf/model_meta_2x(1, 1)x(128, 128)x5000.json', starting anew.
Starting new model
{'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}', 'mhsa_layers': 2, 'heads': [1, 1], 'units': [128, 128], 'norm': 'softmax', 'mh_normalize': True, 'l2_regularizer': 1e-09, 'dropout': 0.0, 'join_heads_by_add': False, 'vocab_size': 5000, 'sequence_len': 80, 'batch_size': 768, 'learning_rate': 0.0005, 'clipvalue': None, 'sample_every_n_epochs': 2}


In [58]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 1194


In [59]:
# @tf.function   (only slows things down [considerably!])
def make_tf_dataset(num, random_index=False, SUB_probability=0.0):
    dx=[]
    dy=[]
    num_batches_active = num
    for i in range(num_batches_active):
        x,y=get_sample_batch(td, params['batch_size'], params['sequence_len'], SUB_probability=SUB_probability)
        if i<1:
            print(f"[{num} x]: {x.shape} -> {y.shape}")
        dx.append(x)
        dy.append(y)
    dx=np.array(dx)
    dy=np.array(dy)
    print(f"dx.shape={dx.shape}, dy.shape={dy.shape}")
    data_xy = (dx, dy)
    tf_dataset=tf.data.Dataset.from_tensor_slices(data_xy)
    return tf_dataset

In [60]:
MAX_NUM_BATCHES = 50000

if num_batches>MAX_NUM_BATCHES:
    restricted_batches=MAX_NUM_BATCHES
    print(f"Restricting {num_batches} to max of {restricted_batches}")
else:
    restricted_batches=num_batches
    print(f"{restricted_batches} batches")
if cached_batch_data == restricted_batches and textlib_dataset is not None:
    print("Reusing cached training-data")
else:
    print("Creating dataset, this is slow. Be patient...")
    textlib_dataset = make_tf_dataset(restricted_batches, SUB_probability=SUB_PROBABILITY)
    cached_batch_data = restricted_batches
    print("Dataset done and cached.")

1194 batches
Creating dataset, this is slow. Be patient...
[1194 x]: (768, 80) -> (768, 80)
dx.shape=(1194, 768, 80), dy.shape=(1194, 768, 80)
Dataset done and cached.
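
The shapes printed above allow a quick back-of-the-envelope check of the batch count and the memory held by the cached arrays. This is a sketch based only on the numbers shown in the outputs (917005 records, batch size 768, sequence length 80) and on the `int32` dtype used in `get_sample_batch`.

```python
num_records = 917005   # from the `len(td)` output above
batch_size = 768
sequence_len = 80
bytes_per_token = 4    # np.int32

# Integer division drops the incomplete last batch.
num_batches = num_records // batch_size
leftover = num_records - num_batches * batch_size

# dx and dy each have shape (num_batches, batch_size, sequence_len).
array_bytes = num_batches * batch_size * sequence_len * bytes_per_token

print(f"num_batches = {num_batches}, records not covered = {leftover}")
print(f"{array_bytes / 2**20:.1f} MiB per array, "
      f"{2 * array_bytes / 2**20:.1f} MiB for dx and dy together")
```

Since `get_sample_batch` draws random items rather than walking through the records in order, `num_batches` only fixes how many random batches are generated; it does not guarantee that each record is seen once.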


In [61]:
shuffle_buffer=10000
if ml_env.is_tpu is True:
    dataset=textlib_dataset.shuffle(shuffle_buffer).repeat()  # Otherwise TPU may run dry
else:
    dataset=textlib_dataset.shuffle(shuffle_buffer)  
dataset.take(1)

<TakeDataset element_spec=(TensorSpec(shape=(768, 80), dtype=tf.int32, name=None), TensorSpec(shape=(768, 80), dtype=tf.int32, name=None))>

In [62]:
if ml_env.is_tpu is False:
    validation_dataset = make_tf_dataset(10, random_index=True, SUB_probability=SUB_PROBABILITY)

[10 x]: (768, 80) -> (768, 80)
dx.shape=(10, 768, 80), dy.shape=(10, 768, 80)


In [63]:
def model_mhsa(inputs, params):
    dense = layers.Dense(params['vocab_size'], kernel_regularizer=regularizers.l2(params['l2_regularizer']))  # using softmax here prevents temperature adjust, affects 'from_logits' param in sparse_categorical loss 
    fl = layers.Flatten()
    dr = layers.Dropout(params['dropout'])
    pe = PositionalEncoding(amplitude=0.3)
    rs_up = layers.Reshape(target_shape=(SEQUENCE_LEN,vocabulary_size))
    if 0 in params['units']:
        lstm1 = layers.LSTM(units=vocabulary_size, return_sequences=True)
    if vocabulary_size>=300:
        emb=layers.Embedding(vocabulary_size,64,input_length=SEQUENCE_LEN)
    rs_down = layers.Reshape(target_shape=(SEQUENCE_LEN,vocabulary_size))
    mhsa=[]
    residuals=[]

    for i in range(params['mhsa_layers']):
        if params['units'][i]==0:
            mhsa.append(None)
            residuals.append(i)
        else:
            mhsa.append(MultiHeadSelfAttention(params['heads'][i], units=params['units'][i], norm=params['norm'], mh_normalize=params['mh_normalize'], join_heads_by_add=params['join_heads_by_add']))
    xint = tf.cast(inputs,dtype=tf.int32)
    if vocabulary_size<300:
        x = tf.one_hot(xint, params['vocab_size'], axis=-1)
    else:
        x = emb(xint)
    x = pe(x)
    for i in range(len(mhsa)):
        if i in residuals:
            x = rs_down(lstm1(rs_up(x)))+x
            print(f"Residual at layer {i} added.")
        else:
            x = mhsa[i](x)
        # x = mhsa[i](x,x)
    if params['dropout']>0.0:
        x = dr(x)
    # x = dense(fl(x))
    x = dense(x)
    return x 
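
For readers who want to see what a self-attention layer computes, here is a minimal single-head sketch of scaled dot-product attention in plain Python. It is a generic illustration and not the code of `MultiHeadSelfAttention` from ml-indie-tools: it omits the learned query, key, and value projections, normalization, and multiple heads.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(q, k, v):
    """Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v."""
    d_k = len(k[0])
    out = []
    for qi in q:
        # Similarity of this query with every key, scaled by sqrt(d_k).
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k]
        weights = softmax(scores)
        # Each output row is a weighted average of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 positions, 2 features
for row in self_attention(x, x, x):
    print([round(value, 3) for value in row])
```

The output has the same shape as the input, which is why such layers can be stacked as in the loop over `mhsa` above.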

In [64]:
def mhsa_generate(model, text, gen_len=64, temperature=0.9, argmax=False, verbose=False):
    if verbose is True:
        full=text[:-1]
    gen_text=""
    lf=0
    input = np.array(td.encode(text))
    while len(input) < params['sequence_len']:
        input = np.concatenate([td.encode('<pad>'),input])
    for i in range(gen_len):
        input = np.concatenate([input[1:],td.encode('<subst>')])
        if len(input)!=params['sequence_len']:
            print('assertion failure')
            return None
        pred = model(input)
        pred /= temperature
        pred = tf.keras.layers.Softmax()(pred)
        if tf.executing_eagerly() is True and ml_env.is_tpu is False:
            pred=pred.numpy()
        else:
            pred=tf.keras.backend.eval(pred)  # this is a cheat, it internally uses numpy() too.
        if argmax is True:
            pred=np.argmax(pred[0],axis=1)
        else:
            pred = [np.random.choice(list(range(len(pred[0][-1]))), p=pred[0][-1])]
        input = np.concatenate([input[1:],[pred[-1]]])
        c = td.decode([pred[-1]])
        if verbose is True:
            print(c, end='')
            if c=='\n':
                lf=0
            else:
                lf += 1
                if (lf>80 and c==' ') or lf>120:
                    print()
                    lf=0
            full+=c
        gen_text+=c
    if verbose is True:
        print()
    return gen_text
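
The `temperature` parameter in `mhsa_generate` divides the logits before the softmax. The following stdlib-only sketch shows the effect on a toy distribution; the helper names and the three example logits are made up for illustration.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [value / temperature for value in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature=0.9, rng=random):
    """Draw one token index according to the temperature-adjusted distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]  # toy logits for a 3-token vocabulary
for t in (0.5, 0.9, 1.4):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
print("sampled token:", sample_token(logits, temperature=0.9, rng=random.Random(1)))
```

Low temperatures concentrate the probability mass on the most likely token, so the output approaches the `argmax` behavior; higher temperatures flatten the distribution and make the generated text more varied.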


In [65]:
if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        print("Creating TPU-scope model")
        inputs = keras.Input(shape=(params['sequence_len'],))
        outputs = model_mhsa(inputs, params)
        model = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
    print("Creating Default-scope model")
    inputs = keras.Input(shape=(params['sequence_len'],))
    outputs = model_mhsa(inputs, params)
    model_cpu = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
else:
    inputs = keras.Input(shape=(params['sequence_len'],))
    outputs = model_mhsa(inputs, params)
    model = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
    model_cpu = model

In [66]:
def get_newest_checkpoint(checkpoint_dir):
    files = os.listdir(checkpoint_dir)
    paths = [os.path.join(checkpoint_dir, basename) for basename in files]
    return max(paths, key=os.path.getctime)

def import_previous_compatible_checkpoint(model, force_import=False):
    meta = read_model_metadata(suffix=model_suffix)
    if meta is None:
        print("No previous checkpoint found")
        return False
    if is_metadata_compatible(params, meta) is not True and force_import is False:
        print("No useable import found.")
        return False
    try:
        last_checkpoint = get_newest_checkpoint(checkpoint_dir) # Doesn't do anything: tf.train.latest_checkpoint(checkpoint_dir)
    except Exception as e:
        print(f"Cannot determine last checkpoint in {checkpoint_dir}, cannot import due to: {e}")
        return False
    print(f"Last checkpoint: {last_checkpoint}")
    try:
        model.load_weights(last_checkpoint)
    except Exception as e:
        print(f"Failed to import model {last_checkpoint}: {e}")
        return False
    if 'current_epoch' in meta:
        params['current_epoch'] = meta['current_epoch']
    print(f"Successful import of epoch {params['current_epoch']} from {last_checkpoint}, continuing from there...")
    return True

### Loss function, optimizer, tensorboard output

In [67]:
kscc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

def loss(labels, logits):
  vl=kscc(labels, logits)
  return vl
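
As a sanity check for the loss values reported during training, the per-token loss with `from_logits=True` can be reproduced by hand. The sketch below implements the standard formula, log-sum-exp of the logits minus the logit of the true label, in plain Python; it is an illustration and not the Keras implementation.

```python
import math


def sparse_categorical_crossentropy(label, logits):
    """Cross-entropy for one token: log(sum(exp(logits))) - logits[label]."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(value - m) for value in logits))
    return log_sum_exp - logits[label]


vocab_size = 5000
# An untrained model that assigns the same score to every token:
uniform_loss = sparse_categorical_crossentropy(0, [0.0] * vocab_size)
print(round(uniform_loss, 4))  # 8.5172, i.e. ln(5000)

# A confident and correct prediction yields a loss close to zero:
confident = [0.0] * vocab_size
confident[42] = 20.0
print(sparse_categorical_crossentropy(42, confident))
```

A model that guesses uniformly over 5000 tokens therefore starts near a loss of 8.5, which is a useful reference point when reading the first training log lines.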

In [68]:
if params['clipvalue'] is not None:
    if ml_env.is_tpu is True:
        with tpu_strategy.scope():
            opti = tf.keras.optimizers.Adam(learning_rate=lr, clipvalue=params['clipvalue'])  # the Keras keyword is `clipvalue`
    else:
        opti = tf.keras.optimizers.Adam(learning_rate=lr, clipvalue=params['clipvalue'])  # the Keras keyword is `clipvalue`
else:
    if ml_env.is_tpu is True:
        with tpu_strategy.scope():
            opti = tf.keras.optimizers.Adam(learning_rate=lr)
    else:
        opti = tf.keras.optimizers.Adam(learning_rate=lr)

if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        model.compile(optimizer=opti, loss=loss, metrics=[], run_eagerly=False, jit_compile=True)
else:
    model.compile(optimizer=opti, loss=loss, metrics=['accuracy'])

In [69]:
import_checkpoint = True
force_import = False   # True: ignore metadata and try import anyway. This will of course crash, if the new model doesn't fit the checkpoint-data...

if import_checkpoint is True:
    import_previous_compatible_checkpoint(model, force_import=force_import)

Cannot access project meta-data at ./model/mhsa_v1_tf/model_meta_2x(1, 1)x(128, 128)x5000.json: [Errno 2] No such file or directory: './model/mhsa_v1_tf/model_meta_2x(1, 1)x(128, 128)x5000.json', starting anew.
No previous checkpoint found


In [70]:
model.summary()

Model: "mhsa_v1_tf"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 80)]              0         
                                                                 
 tf.cast_2 (TFOpLambda)      (None, 80)                0         
                                                                 
 embedding_3 (Embedding)     (None, 80, 64)            320000    
                                                                 
 positional_encoding_3 (Posi  (None, 80, 64)           0         
 tionalEncoding)                                                 
                                                                 
 multi_head_self_attention_6  (None, 80, 64)           41216     
  (MultiHeadSelfAttention)                                       
                                                                 
 multi_head_self_attention_7  (None, 80, 64)           4

In [71]:
TPU_GENERATE_ON_CPU = False  # The thing is: both options are slow on TPU :-/

class ServiceCallback(keras.callbacks.Callback):
#    def on_test_end(self, logs=None):
    # @tf.function
    def on_epoch_end(self, epoch, logs=None):
        save_model_metadata(epoch, suffix=model_suffix)
        if (epoch+1) % params['sample_every_n_epochs'] == 0:
            idx=random.randint(0,len(td)-1)
            text=td.decode(td[idx])
            print()
            if ml_env.is_tpu is True:
                temp_list=[0.6,0.7,0.8,0.0]
                gen_len=64
                with tpu_strategy.scope():
                    weights=model.get_weights()
                model_cpu.set_weights(weights)
                # HDF5 is required for saving weights that originate from TPU
                # otherwise this just silently fails...
                checkpoint_path = os.path.join(checkpoint_dir, "cp-{epoch:04d}.h5")
                chkpt_dest=checkpoint_path.format(epoch=epoch)
                print(f"Checkpoint: {chkpt_dest}")
                model_cpu.save_weights(chkpt_dest)
            else:
                temp_list=[0.5, 0.7, 0.9]
                gen_len=192
            print(f"prompt: {text}")
            for temp in temp_list:
                print(f"---------------- T={temp} ---------------")
                if ml_env.is_tpu is True and TPU_GENERATE_ON_CPU is True:
                    with tf.device('/cpu:0'):
                        if temp==0.0:
                            reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=1.0, argmax=True, verbose=False)
                        else:
                            reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=temp, verbose=False)
                else:
                    if temp==0.0:
                        reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=1.0, argmax=True, verbose=False)
                    else:
                        reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=temp, verbose=False)
                td.source_highlight(reply, min_quote_size=10, dark_mode=use_dark_mode, display_ref_anchor=False)
            print("--------------------------------------")

service_callback=ServiceCallback()

In [72]:
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

logdir = os.path.join(log_path, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
if ml_env.is_tpu:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='epoch', write_graph=False)
else:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='batch')


In [73]:
# Dont try:
#    # use the python variable log_path:
#   get_ipython().run_line_magic('tensorboard', '--logdir "{log_path}"')
#except:
#   pass

# The following throws errors on non-colab, but the guarding above is too bug-ridden.
if ml_env.is_tpu is False:
    %tensorboard --logdir logs

UsageError: Line magic function `%tensorboard` not found.


## The actual training

In [74]:
EPOCHS=50000
if 'current_epoch' in params:
    initial_epoch=params['current_epoch']
else:
    initial_epoch=0

In [None]:
if ml_env.is_tpu is True:
    steps_per_epoch=restricted_batches  # each dataset element is already one full batch
    if steps_per_epoch < 1:
        steps_per_epoch = 1
    history = model.fit(dataset, epochs=EPOCHS, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, callbacks=[service_callback]) # for TPU we need to roll our own checkpointer since we need to transfer the weights
else:
    history = model.fit(dataset, validation_data=validation_dataset, epochs=EPOCHS, initial_epoch=initial_epoch, callbacks=[checkpoint_callback, tensorboard_callback, service_callback])

Epoch 1/50000


2022-06-20 09:55:38.771795: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


  32/1194 [..............................] - ETA: 15:07 - loss: 8.3783 - accuracy: 0.0021

## A dialog with the trained model

In [None]:
model_cpu.set_weights(model.get_weights())

In [None]:
# Do a dialog with the attention model trained above.

def doDialog(model):
    temperature = 0.6
    endPrompt = '.'  # the endPrompt character is the end-mark in answers.
    # look for number of maxEndPrompts until answer is finished.
    maxEndPrompts = 4
    maxAnswerSize = 2048  # Maximum length of the answer
    minAnswerSize = 64  # Minimum length of the answer
    print("Please enter some dialog.")
    print("The net will answer according to your input.")
    print("'bye' for end,")
    print("'reset' to reset the conversation context,")
    print("'temperature=<float>' [0.1(frozen)-1.0(creative)]")
    print("    to change character of the dialog.")
    print("    Current temperature={}.".format(temperature))
    print()
    bye = False
    while not bye:
        print("> ", end="")
        prompt = input()
        if prompt == 'bye':
            bye = True
            print("Good bye!")
            continue
        if prompt[:len("temperature=")] == "temperature=":
            t = float(prompt[len("temperature="):])
            if t > 0.05 and t < 1.4:
                temperature = t
                print("(generator temperature now {})".format(t))
                print()
                continue
            print("Invalid temperature-value ignored! [0.1-1.0]")
            continue
        reply=mhsa_generate(model, prompt, gen_len=256, temperature=temperature, verbose=True)
        td.source_highlight(reply, min_quote_size=13, dark_mode=use_dark_mode)
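
In `doDialog`, a malformed command such as `temperature=abc` would raise a `ValueError` in `float()`. The sketch below shows one way to parse the command defensively; `parse_temperature_command` is a hypothetical helper, and the bounds 0.05 and 1.4 are taken from the check in the loop above.

```python
def parse_temperature_command(prompt, low=0.05, high=1.4):
    """Return (is_command, value); value is None if the number is missing or out of range."""
    prefix = "temperature="
    if not prompt.startswith(prefix):
        return False, None  # ordinary dialog input, not a command
    try:
        value = float(prompt[len(prefix):])
    except ValueError:
        return True, None   # command recognized, but the number is malformed
    if low < value < high:
        return True, value
    return True, None       # command recognized, but the value is out of range


for line in ("temperature=0.8", "temperature=abc", "temperature=3", "Hello there"):
    print(repr(line), "->", parse_temperature_command(line))
```

Returning a pair keeps the two cases apart: ordinary input is passed on to the model, while an invalid command can be answered with the existing "Invalid temperature-value ignored!" message.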

In [None]:
# Talk to the net!
doDialog(model_cpu)