<a href="https://colab.research.google.com/github/domschl/transformer-poet/blob/main/transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that supports more environment-independent code. It will access your Google Drive when used with Google Colab.

In [1]:
!pip install -U ml-indie-tools

Collecting ml-indie-tools
  Downloading ml_indie_tools-0.1.2-py3-none-any.whl (33 kB)
Installing collected packages: ml-indie-tools
Successfully installed ml-indie-tools-0.1.2


In [2]:
import logging
import os
import sys
import copy
import json
import time
import datetime
import random

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers, regularizers

import tensorflow_datasets as tfds

In [3]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.keras_custom_layers import MultiHeadSelfAttention, PositionalEncoding

## Preliminary

A TensorFlow deep multi-head self-attention model for text generation.

This code can run on CPU, GPU, or TPU when used with Google Colab.
Select the corresponding runtime (menu: **`Runtime / Change runtime type`**).

## 0. Environment

In [4]:
ml_env = MLEnv(platform='tf', accelerator='fastest')
ml_env.describe()

'OS: Linux, Python: 3.7.12, Colab Jupyter Notebook Tensorflow: 2.8.0, GPU: Tesla T4 (3MiB / 15109MiB)'

In [5]:
if ml_env.is_tpu is True:
    tpu_strategy = ml_env.tpu_strategy
    tpu_is_init=True
    use_eager=False
else:
    use_eager=True

In [6]:
project_name='women_writers'
model_name='mhsa_v1_tf'

# NOTICE: This will request access to Google Drive, if running on Google Colab. Google Drive is used to store snapshots
# and training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools
root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Mounted at /content/drive
Root path (all projects) : /content/drive/My Drive (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : /content/drive/My Drive/Colab Notebooks/women_writers (Changes to the file system happen only below this project path
Model path (snapshots)   : /content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf (Model weights and snapshots are stored here)
Data path (training data): /content/drive/My Drive/Colab Notebooks/women_writers/data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


## 1. Text library

The `Text_Dataset` and `Gutenberg_Dataset` classes are libraries for training, 
encoding, batch generation, and formatted source display. They read some 
books from Project Gutenberg and support the creation of training batches. 
The output functions support highlighting, which makes it possible to compare 
generated texts with the actual sources and helps to identify identical 
(memorized) parts.
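
As a rough illustration of what character-level encoding and sliding-window sampling involve, here is a minimal sketch in plain Python. It is not the `Text_Dataset` implementation: the helper names (`build_char_vocab`, `encode`, `decode`, `sliding_windows`) are hypothetical, and only the idea of `c2i`/`i2c` lookup tables and a stepping of 1 is taken from the notebook.

```python
def build_char_vocab(text):
    """Map every distinct character to an integer id and back (hypothetical helper)."""
    chars = sorted(set(text))
    c2i = {c: i for i, c in enumerate(chars)}
    i2c = {i: c for c, i in c2i.items()}
    return c2i, i2c


def encode(text, c2i):
    """Turn a string into a list of integer ids."""
    return [c2i[c] for c in text]


def decode(ids, i2c):
    """Turn a list of integer ids back into a string."""
    return "".join(i2c[i] for i in ids)


def sliding_windows(ids, sample_length, stepping=1):
    """All windows of `sample_length` ids, starting every `stepping` positions."""
    last_start = len(ids) - sample_length
    return [ids[s:s + sample_length] for s in range(0, last_start + 1, stepping)]


corpus = "It is a truth universally acknowledged"
c2i, i2c = build_char_vocab(corpus)
ids = encode(corpus, c2i)
windows = sliding_windows(ids, sample_length=8, stepping=1)
print(len(windows), repr(decode(windows[0], i2c)))
```

With a stepping of 1, almost every character position starts a new record, which is why a handful of books can yield millions of training records.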

In [7]:
use_dark_mode=True # Set to false for white background. HTML-text-compare uses background-colorization to identify different sources. Those background colors are dependent on the theme type.

In [8]:
logging.basicConfig(level=logging.INFO)
cache_dir = os.path.join(data_path, 'gutenberg_cache')
gd = Gutenberg_Dataset(cache_dir=cache_dir)

In [9]:
# sample searches
search_spec= {"author": ["Emily Brontë","Jane Austen", "Virginia Woolf"], "language": ["english"]}

book_list=gd.search(search_spec)
book_cnt = len(book_list)
print(f"{book_cnt} matching books found with search {search_spec}.")
if book_cnt<40:
    # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts, 
    # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
    book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)  
else:
    logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

21 matching books found with search {'author': ['Emily Brontë', 'Jane Austen', 'Virginia Woolf'], 'language': ['english']}.


In [10]:
for i in range(len(book_list)):
    print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

0: The Common Reader - Virginia Woolf, 64457
1: Mr. Bennett and Mrs. Brown - Virginia Woolf, 63022
2: The Younger Sister, Volumes 1-3 - Catherine Anne Austen Hubback and Jane Austen, 54066
3: The Younger Sister, Vol. 3 - Catherine Anne Austen Hubback and Jane Austen, 54012
4: The Younger Sister, Vol. 2 - Catherine Anne Austen Hubback and Jane Austen, 54011
5: The Younger Sister, Vol. 1 - Catherine Anne Austen Hubback and Jane Austen, 54010
6: Pride and Prejudice - Jane Austen, 42671
7: The Letters of Jane Austen - Jane Austen, 42078
8: The Complete Project Gutenberg Works of Jane Austen - Jane Austen, 31100
9: Jacob's Room - Virginia Woolf, 5670
10: Pride and Prejudice - Jane Austen, 1342
11: Night and Day - Virginia Woolf, 1245
12: Love And Friendship And Other Early Works - Jane Austen, 1212
13: Lady Susan - Jane Austen, 946
14: Wuthering Heights - Emily Brontë, 768
15: Sense and Sensibility - Jane Austen, 161
16: Emma - Jane Austen, 158
17: The Voyage Out - Virginia Woolf, 144
18: M

In [11]:
select = (14,1,6,13,16,17)
sub_book_list = [book_list[i] for i in range(len(book_list)) if i in select]

print("Using:")
for i in range(len(sub_book_list)):
    print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

td = Text_Dataset(sub_book_list)
td.init_tokenizer(tokenizer='char')

INFO:Datasets:Loaded 6 texts


Using:
1: Mr. Bennett and Mrs. Brown - Virginia Woolf
2: Pride and Prejudice - Jane Austen
3: Lady Susan - Jane Austen
4: Wuthering Heights - Emily Brontë
5: Emma - Jane Austen
6: The Voyage Out - Virginia Woolf


In [12]:
SEQUENCE_LEN = 96
SUB_PROBABILITY = 0.15  # like BERT

td.init_getitem(sample_type='chargen_single_encoded', sample_length=SEQUENCE_LEN, content_stepping=1)

num_records = len(td)

print(f"{num_records} records")

def get_sample_batch(td, batch_size, length, random_index=True, SUB_probability=0.0):
    for i in range(batch_size):
        if random_index is True:
            ind = random.randint(0, num_records-1)
        else:
            ind = i * td.getitem_content_stepping
        Xi = td[ind]
        yi = [Xi[-1]]
        if SUB_probability==0.0:
            Xi[-1]=td.c2i['␚']  # use 'SUB'-stitut glyph to mark last char of input
        else:
            l=int(len(Xi)*SUB_probability)
            for li in range(l):
                pos=random.randint(0,len(Xi)-1)
                Xi[pos]=td.c2i['␚']
        if i==0:
            smpX=np.array(Xi, dtype=np.float32)
            smpy=np.array(yi, dtype=np.int32)
        else:
            smpX = np.vstack((smpX, np.array(Xi, dtype=np.float32)))
            smpy = np.vstack((smpy, np.array(yi, dtype=np.int32)))
    return np.array(smpX), np.array(smpy)

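def get_random_onehot_sample_batch(td, batch_size, length):
    # Draw a random batch, then one-hot encode the inputs over the vocabulary.
    X, y = get_sample_batch(td, batch_size, length, random_index=True)
    # one_hot needs integer indices; get_sample_batch returns float32 inputs.
    xoh = tf.keras.backend.one_hot(X.astype(np.int32), len(td.i2c))
    yk = tf.keras.backend.constant(y)
    return xoh, yk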

3153130 records


In [13]:
test_x, test_y = get_sample_batch(td, 2, 40, random_index=True, SUB_probability=SUB_PROBABILITY)
for i in range(len(test_x)):
    print(f"[{i}]: X=>{td.decode(test_x[i])}<, y=>{td.decode(test_y[i])}<")

[0]: X=>␚efore,␚or an
␚our b␚␚ore, its␚inte␚est soon␚increa␚ed; ␚nd before their fi␚st␚conv␚rsation was <, y=> <
[1]: X=>ed by a␚s␚rt of fa␚c␚n␚tion, it would be ri␚ic␚lous in␚me to repeat
th␚ instances ␚f gre␚t ␚␚s␚o<, y=>o<


## 2. Use tf.data for texts

In [14]:
def expand_name_template(template, params):
    exp=copy.copy(template)
    for key in params:
        src="{"+key+"}"
        dst=f"{params[key]}"
        exp=exp.replace(src,dst)
    return exp

def save_model_metadata(epoch, suffix='std'):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    params['current_epoch'] = epoch
    try:
        with open(meta_file, 'w') as f:
            f.write(json.dumps(params))
    except Exception as e:
        print(f"Failed to store model metadata at {model_path}: {e}")
        return False
    return True

def read_model_metadata(suffix="std"):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    try:
        with open(meta_file, 'r') as f:
            meta = json.load(f)
    except Exception as e:
        print(f"Cannot access project meta-data at {meta_file}: {e}, starting anew.")
        return None
    return meta

def is_metadata_compatible(params, meta):
    is_valid=True
    keys=set(list(params.keys())+list(meta.keys()))
    for key in keys:
        if key in updatable_keys:
            continue
        if key not in meta:
            print(f"Key {key} not available in last checkpoint model_meta, params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif key not in params:
            print(f"Key {key} not available in params, last checkpoint model_meta[{key}]: {meta[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif meta[key]!=params[key]:
            print(f"Last checkpoint model_meta[{key}]: {meta[key]} != params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
    if is_valid is False:
        print("Aborting import.")
        return False
    return True

In [15]:
params = { # Multi-head self-attention
    'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}',

    'mhsa_layers': 6,
    'heads': 4,
    'units': 192, # len(td.i2c),
    'norm': 'softmax', # this is for within each head
    'mh_normalize': True,  # use layer-norm after concatenation (or addition) of the heads
    'l2_regularizer': 1e-9,
    'dropout': 0.9,       # no dropout: 0.0
    'join_heads_by_add': False,  # strategy for joining the heads: False: concat (as in all-you-need), True: relu+add of all the heads
    'vocab_size': len(td.i2c),
    'sequence_len': SEQUENCE_LEN,

    'batch_size': 128,
    'learning_rate': 0.001,
    'clipvalue': None,
    'sample_every_n_epochs': 5,
}

if ml_env.is_tpu is True:
    lr = params['learning_rate']*1.0
else:
    lr = params['learning_rate']

model_suffix = expand_name_template(params['name'], params)
# Put 'important' params in checkpoint-pathname to separate model-data:
checkpoint_dir = os.path.join(model_path, f"training_checkpoints_{model_suffix}")

# When comparing if training-data is compatible with new params set, 
# the following keys are updatable, they can be changed while continuing
# to use existing checkpoints and continue training with those values
# changed:
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'dropout', 
             'sample_every_n_epochs']

# These values are taken from the saved checkpoint:
keep_keys=['current_epoch']

continue_last = True

meta = read_model_metadata(suffix=model_suffix)
if meta is not None and is_metadata_compatible(params, meta) is True and continue_last is True:
    for key in keep_keys:
        if key in meta:
            params[key]=meta[key]
    if 'current_epoch' in params:
        print(f"Continuing last session from epoch {params['current_epoch']}")
    else:
        print("No previous data, starting new model")
else:
    print("Starting new model")

print(params)

Cannot access project meta-data at /content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf/model_meta_6x4x192x132.json: [Errno 2] No such file or directory: '/content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf/model_meta_6x4x192x132.json', starting anew.
Starting new model
{'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}', 'mhsa_layers': 6, 'heads': 4, 'units': 192, 'norm': 'softmax', 'mh_normalize': True, 'l2_regularizer': 1e-09, 'dropout': 0.9, 'join_heads_by_add': False, 'vocab_size': 132, 'sequence_len': 96, 'batch_size': 128, 'learning_rate': 0.001, 'clipvalue': None, 'sample_every_n_epochs': 5}


In [16]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 24633


In [17]:
# @tf.function
def make_tf_dataset(num, random_index=False, SUB_probability=0.0):
    dx=[]
    dy=[]
    num_batches_active = num
    for i in range(num_batches_active):
        x,y=get_sample_batch(td, params['batch_size'], params['sequence_len'], random_index=random_index, SUB_probability=SUB_probability)
        if i<1:
            print(f"[{num} x]: {x.shape} -> {y.shape}")
        dx.append(x)
        dy.append(y)
    dx=np.array(dx)
    dy=np.array(dy)
    data_xy = (dx, dy)
    tf_dataset=tf.data.Dataset.from_tensor_slices(data_xy)
    return tf_dataset

In [18]:
MAX_NUM_BATCHES = 40000

if num_batches>MAX_NUM_BATCHES:
    restricted_batches=MAX_NUM_BATCHES
else:
    restricted_batches=num_batches
textlib_dataset = make_tf_dataset(restricted_batches, random_index=True, SUB_probability=SUB_PROBABILITY)

[24633 x]: (128, 96) -> (128, 1)


NameError: ignored

In [19]:
shuffle_buffer=10000
if ml_env.is_tpu is True:
    dataset=textlib_dataset.shuffle(shuffle_buffer).repeat()  # Otherwise TPU may run dry
else:
    dataset=textlib_dataset.shuffle(shuffle_buffer)  
dataset.take(1)

<TakeDataset element_spec=(TensorSpec(shape=(128, 96), dtype=tf.float32, name=None), TensorSpec(shape=(128, 1), dtype=tf.int32, name=None))>

In [20]:
if ml_env.is_tpu is False:
    validation_dataset = make_tf_dataset(10, random_index=True, SUB_probability=SUB_PROBABILITY)

[10 x]: (128, 96) -> (128, 1)


In [21]:
def model_mhsa(inputs, params):
    dense = layers.Dense(params['vocab_size'], kernel_regularizer=regularizers.l2(params['l2_regularizer']))  # using softmax here prevents temperature adjust, affects 'from_logits' param in sparse_categorical loss 
    fl = layers.Flatten()
    dr = layers.Dropout(params['dropout'])
    pe = PositionalEncoding(amplitude=0.3)
    mhsa=[]
    for i in range(params['mhsa_layers']):
        mhsa.append(MultiHeadSelfAttention(params['heads'], units=params['units'], norm=params['norm'], mh_normalize=params['mh_normalize'], join_heads_by_add=params['join_heads_by_add']))
    x = tf.one_hot(tf.cast(inputs,dtype=tf.int32), params['vocab_size'], axis=-1)
    x = pe(x)
    for i in range(params['mhsa_layers']):
        x = mhsa[i](x)
        # x = mhsa[i](x,x)
    if params['dropout']>0.0:
        x = dr(x)
    x = dense(fl(x))
    return x 
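
For readers unfamiliar with what a single attention head computes, here is a minimal pure-Python sketch of scaled dot-product self-attention with softmax normalization. It is a generic textbook formulation, not the `MultiHeadSelfAttention` layer from ml-indie-tools, whose internals are not shown in this notebook; the function names and the explicit weight matrices `wq`, `wk`, `wv` are assumptions for illustration.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x, wq, wk, wv):
    """Single-head scaled dot-product self-attention.

    x: sequence of input vectors, shape (seq_len, d_in)
    wq, wk, wv: projection matrices, shape (d_in, d_k) and (d_in, d_v)
    Returns (output, attention_weights).
    """
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    d_k = len(k[0])
    k_transposed = [list(col) for col in zip(*k)]
    scores = matmul(q, k_transposed)  # (seq_len, seq_len)
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v), weights


identity = [[1.0, 0.0], [0.0, 1.0]]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, attn = self_attention(x, identity, identity, identity)
for row in attn:
    print([round(w, 3) for w in row])
```

Each row of the attention matrix sums to 1 and says how strongly one sequence position draws on every other position. A multi-head layer runs several such heads in parallel and joins their outputs, which is the choice the `join_heads_by_add` parameter controls.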

In [None]:
def mhsa_generate(text, gen_len=64, temperature=0.9, verbose=False):
    if verbose is True:
        full=text[:-1]
    gen_text=""
    lf=0
    if verbose is True:
        print(f"[{full}]", end='')
    tex=copy.copy(text)
    while len(tex) < params['sequence_len']:
        tex=' '+tex
    tex=tex[-params['sequence_len']+1:]
    tex=tex+'␚'
    for i in range(gen_len):
        input = np.array([td.encode(tex)])
        pred = model.predict(input, batch_size=1)
        pred /= temperature
        pred = tf.keras.layers.Softmax()(pred)
        if use_eager is True:
            pred=pred.numpy()
        else:
            pred=tf.keras.backend.eval(pred)
        ci=np.random.choice(list(range(len(pred[0]))), p=pred[0]) # np.argmax(pred[0])
        c=td.i2c[ci]
        if verbose is True:
            print(c, end='')
            if c=='\n':
                lf=0
            else:
                lf += 1
                if (lf>80 and c==' ') or lf>120:
                    print()
                    lf=0
            full+=c
        gen_text+=c
        tex=tex[:-1]+c+'␚'
        tex=tex[-params['sequence_len']:]
    if verbose is True:
        print()
    return gen_text
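
The effect of dividing the logits by `temperature` before the softmax, as `mhsa_generate` does, can be seen in a small standalone sketch. It uses only the standard library; `softmax_with_temperature` and `sample_index` are hypothetical helper names, and the example logits are made up.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(logits, temperature, rng):
    """Draw one index according to the temperature-adjusted distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


example_logits = [2.0, 1.0, 0.0]  # made-up scores for three characters
for t in (0.5, 0.7, 0.9):
    probs = softmax_with_temperature(example_logits, t)
    print(t, [round(p, 3) for p in probs])

rng = random.Random(42)
print(sample_index(example_logits, 0.7, rng))
```

Lower temperatures concentrate probability on the highest-scoring character, so the output becomes more deterministic; higher temperatures flatten the distribution and produce more varied text. This only works because the model's final `Dense` layer emits raw logits rather than softmax probabilities.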

In [None]:
inputs = keras.Input(shape=(params['sequence_len'],))
outputs = model_mhsa(inputs, params)
model = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")

In [None]:
def import_previous_compatible_checkpoint(force_import=False):
    meta = read_model_metadata(suffix=model_suffix)
    if meta is None:
        print("No previous checkpoint found")
        return False
    if is_metadata_compatible(params, meta) is not True and force_import is False:
        print("No useable import found.")
        return False
    try:
        last_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    except Exception as e:
        print(f"Cannot determine last checkpoint in {checkpoint_dir}, cannot import due to: {e}")
        return False
    print(f"Last checkpoint: {last_checkpoint}")
    try:
        model.load_weights(last_checkpoint)
    except Exception as e:
        print(f"Failed to import model {last_checkpoint}: {e}")
        return False
    if 'current_epoch' in meta:
        params['current_epoch'] = meta['current_epoch']
    print(f"Successful import of epoch {params['current_epoch']} from {last_checkpoint}, continuing from there...")
    return True

In [None]:
import_checkpoint = False
force_import = False   # True: ignore metadata and try import anyway. This will of course crash, if the new model doesn't fit the checkpoint-data...

if import_checkpoint is True:
    import_previous_compatible_checkpoint(force_import=force_import)

### Loss function, optimizer, tensorboard output

In [None]:
kscc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def loss(labels, logits):
  vl=kscc(labels, logits)
  return vl
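
To make explicit what `from_logits=True` means here, the following sketch computes sparse categorical cross-entropy directly from raw logits in plain Python. It illustrates the formula only and is not Keras code; the function names are hypothetical, and averaging over the batch is assumed to match the default reduction.

```python
import math


def sparse_cce_from_logits(label, logits):
    """Cross-entropy of one sample: -log(softmax(logits)[label]).

    Computed as log-sum-exp(logits) - logits[label] for numerical stability.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[label]


def mean_loss(labels, batch_logits):
    """Average the per-sample losses over a batch."""
    losses = [sparse_cce_from_logits(y, z) for y, z in zip(labels, batch_logits)]
    return sum(losses) / len(losses)


# Uniform logits over 4 classes: the model has no preference, so loss = log(4).
print(sparse_cce_from_logits(2, [0.0, 0.0, 0.0, 0.0]))
# A confident, correct prediction gives a loss close to zero.
print(sparse_cce_from_logits(0, [10.0, 0.0, 0.0, 0.0]))
print(mean_loss([2, 0], [[0.0, 0.0, 0.0, 0.0], [10.0, 0.0, 0.0, 0.0]]))
```

The labels are integer class indices ("sparse"), which fits the `(128, 1)` integer targets produced by `get_sample_batch`.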

In [None]:
if params['clipvalue'] is not None:
    opti = tf.keras.optimizers.Adam(learning_rate=lr, clipvalue=params['clipvalue'])  # Keras keyword is `clipvalue`
else:
    opti = tf.keras.optimizers.Adam(learning_rate=lr)

if ml_env.is_tpu is True:
    model.compile(optimizer=opti, loss=loss, metrics=[])
else:
    model.compile(optimizer=opti, loss=loss, metrics=['accuracy'])

In [None]:
model.summary()

In [None]:
TPU_GENERATE_ON_CPU = False  # The thing is: both options are slow on TPU :-/

class GeneratorCallback(keras.callbacks.Callback):
#    def on_test_end(self, logs=None):
    def on_epoch_end(self, epoch, logs=None):
        save_model_metadata(epoch, suffix=model_suffix)
        if (epoch+1) % params['sample_every_n_epochs'] == 0:
            idx=random.randint(0,len(td)-1)
            text=td.decode(td[idx])
            print()
            if ml_env.is_tpu is True:
                temp_list=[0.7]
                gen_len=64
            else:
                temp_list=[0.5, 0.7, 0.9]
                gen_len=192
            print(f"prompt: {text}")
            for temp in temp_list:
                print(f"---------------- T={temp} ---------------")
                if ml_env.is_tpu is True and TPU_GENERATE_ON_CPU is True:
                    with tf.device('/cpu:0'):
                        reply=mhsa_generate(text, gen_len=gen_len, temperature=temp, verbose=False)
                else:
                    reply=mhsa_generate(text, gen_len=gen_len, temperature=temp, verbose=False)
                td.source_highlight(reply, min_quote_size=10, dark_mode=use_dark_mode, display_ref_anchor=False)
            print("--------------------------------------")

generator_callback=GeneratorCallback()

In [None]:
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

logdir = os.path.join(log_path, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='batch', histogram_freq=1) # update_freq='epoch', 

In [None]:
# This sometimes freezes completely. Simply skip this cell if it seems
# to block (known bug).
# if ml_env.is_colab:
#     try:
#         # use the python variable log_path:
#         get_ipython().run_line_magic('tensorboard', '--logdir "{log_path}"')
#     except:
#         pass

# The following throws errors on non-colab, but the guarding above is too bug-ridden.
%tensorboard --logdir logs

## The actual training

In [None]:
EPOCHS=1000
if 'current_epoch' in params:
    initial_epoch=params['current_epoch']
else:
    initial_epoch=0

In [None]:
if ml_env.is_tpu is True:
    steps_per_epoch=restricted_batches//params['batch_size']
    history = model.fit(dataset, epochs=EPOCHS, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, callbacks=[checkpoint_callback, tensorboard_callback, generator_callback])
else:
    history = model.fit(dataset, validation_data=validation_dataset, epochs=EPOCHS, initial_epoch=initial_epoch, callbacks=[checkpoint_callback, tensorboard_callback, generator_callback])

## A dialog with the trained model

In [None]:
# Do a dialog with the multi-head self-attention model trained above:

def doDialog():
    temperature = 0.6
    endPrompt = '.'  # the endPrompt character is the end-mark in answers.
    # look for number of maxEndPrompts until answer is finished.
    maxEndPrompts = 4
    maxAnswerSize = 2048  # Maximum length of the answer
    minAnswerSize = 64  # Minimum length of the answer
    print("Please enter some dialog.")
    print("The net will answer according to your input.")
    print("'bye' for end,")
    print("'reset' to reset the conversation context,")
    print("'temperature=<float>' [0.1(frozen)-1.0(creative)]")
    print("    to change character of the dialog.")
    print("    Current temperature={}.".format(temperature))
    print()
    bye = False
    while not bye:
        print("> ", end="")
        prompt = input()
        if prompt == 'bye':
            bye = True
            print("Good bye!")
            continue
        if prompt.startswith("temperature="):
            # Parse defensively: a non-numeric value must not crash the dialog.
            try:
                t = float(prompt[len("temperature="):])
            except ValueError:
                t = None
            if t is not None and 0.05 < t < 1.4:
                temperature = t
                print("(generator temperature now {})".format(t))
                print()
                continue
            print("Invalid temperature-value ignored! [0.1-1.0]")
            continue
        reply=mhsa_generate(prompt, gen_len=256, temperature=temperature, verbose=False)
        td.source_highlight(reply, min_quote_size=13, dark_mode=use_dark_mode)

In [None]:
# Talk to the net!
doDialog()