<a href="https://colab.research.google.com/github/domschl/transformer-poet/blob/main/transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that supports more environment-independent code. It will access your Google Drive when used with Google Colab.

In [1]:
!pip install -U ml-indie-tools

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import logging
import os
import sys
import copy
import json
import time
import datetime
import random

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers, regularizers

import tensorflow_datasets as tfds

In [3]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.keras_custom_layers import MultiHeadSelfAttention, PositionalEncoding

Using TF-Keras version: 2.8.0


## Preliminary

A TensorFlow deep multi-head attention model for text generation.

This code can use a CPU, GPU, or TPU when running on Google Colab.
Select the corresponding runtime (menu: **`Runtime / Change runtime type`**)

## 0. Environment

In [4]:
cached_batch_data = None   # Don't regenerate time-consuming training data if it is already cached.

ml_env = MLEnv(platform='tf', accelerator='fastest')
ml_env.describe()

INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.
INFO:tensorflow:Initializing the TPU system: grpc://10.38.124.146:8470
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


'OS: Linux, Python: 3.7.13, Colab Jupyter Notebook Tensorflow: 2.8.2, TPU: TPU, 8 nodes v2 (8GB)'

In [5]:
use_eager=tf.executing_eagerly()
if ml_env.is_tpu is True:
    tpu_strategy = ml_env.tpu_strategy
    tpu_is_init=True
    if use_eager is True:
        tf.config.run_functions_eagerly(False)
    use_eager=False

In [6]:
project_name='women_writers'
model_name='mhsa_v1_tf'

# NOTICE: This will request access to Google Drive if running on Google Colab. Google Drive is used to store snapshots and
# training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools
#
# Note: you need to allow popups in your browser for COLAB, otherwise you won't see the google-drive login box, and drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path)")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Root path (all projects) : /content/drive/My Drive (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : /content/drive/My Drive/Colab Notebooks/women_writers (Changes to the file system happen only below this project path)
Model path (snapshots)   : /content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf (Model weights and snapshots are stored here)
Data path (training data): /content/drive/My Drive/Colab Notebooks/women_writers/data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


## 1. Text library

The `Text_Dataset` and `Gutenberg_Dataset` classes are libraries for training, 
encoding, batch generation, and formatted source display. They read 
books from Project Gutenberg and support the creation of training batches. 
The output functions support highlighting, which makes it possible to compare 
generated texts with the actual sources and helps to identify identical 
(memorized) parts.
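
The idea behind this highlighting can be sketched without the library. The following is a minimal, hypothetical stand-in: the function `find_memorized_spans` and its greedy matching are assumptions made for illustration, not the implementation of `td.source_highlight`. It reports every span of a generated text that occurs verbatim in a source text and is at least `min_quote_size` characters long.

```python
def find_memorized_spans(generated, source, min_quote_size=10):
    """Return (start, text) for spans of `generated` found verbatim in `source`.

    Hypothetical helper for illustration only; a greedy left-to-right scan.
    """
    spans = []
    i = 0
    while i <= len(generated) - min_quote_size:
        length = min_quote_size
        if generated[i:i + length] not in source:
            i += 1
            continue
        # Extend the match for as long as it still occurs in the source.
        while i + length < len(generated) and generated[i:i + length + 1] in source:
            length += 1
        spans.append((i, generated[i:i + length]))
        i += length
    return spans


source = ("It is a truth universally acknowledged, that a single man "
          "in possession of a good fortune")
generated = "She knew a truth universally acknowledged by all."
for start, text in find_memorized_spans(generated, source):
    print(f"memorized at {start}: >{text}<")
```

Larger `min_quote_size` values suppress short coincidental matches such as common words, which is presumably why the notebook later passes `min_quote_size=10`.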

In [7]:
use_dark_mode=True # Set to False for a white background. The HTML text comparison uses background colors to identify different sources; those colors depend on the theme type.

In [8]:
logging.basicConfig(level=logging.INFO)
cache_dir = os.path.join(data_path, 'gutenberg_cache')
gd = Gutenberg_Dataset(cache_dir=cache_dir)

In [9]:
# sample searches
search_spec= {"author": ["Emily Brontë","Jane Austen", "Virginia Woolf"], "language": ["english"]}

book_list=gd.search(search_spec)
book_cnt = len(book_list)
print(f"{book_cnt} matching books found with search {search_spec}.")
if book_cnt<40:
    # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts, 
    # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
    book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)  
else:
    logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

21 matching books found with search {'author': ['Emily Brontë', 'Jane Austen', 'Virginia Woolf'], 'language': ['english']}.


In [10]:
for i in range(len(book_list)):
    print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

0: The Common Reader - Virginia Woolf, 64457
1: Mr. Bennett and Mrs. Brown - Virginia Woolf, 63022
2: The Younger Sister, Volumes 1-3 - Catherine Anne Austen Hubback and Jane Austen, 54066
3: The Younger Sister, Vol. 3 - Catherine Anne Austen Hubback and Jane Austen, 54012
4: The Younger Sister, Vol. 2 - Catherine Anne Austen Hubback and Jane Austen, 54011
5: The Younger Sister, Vol. 1 - Catherine Anne Austen Hubback and Jane Austen, 54010
6: Pride and Prejudice - Jane Austen, 42671
7: The Letters of Jane Austen - Jane Austen, 42078
8: The Complete Project Gutenberg Works of Jane Austen - Jane Austen, 31100
9: Jacob's Room - Virginia Woolf, 5670
10: Pride and Prejudice - Jane Austen, 1342
11: Night and Day - Virginia Woolf, 1245
12: Love And Friendship And Other Early Works - Jane Austen, 1212
13: Lady Susan - Jane Austen, 946
14: Wuthering Heights - Emily Brontë, 768
15: Sense and Sensibility - Jane Austen, 161
16: Emma - Jane Austen, 158
17: The Voyage Out - Virginia Woolf, 144
18: M

In [11]:
select = ("Bennett", "1342", "Susan", "Wuthering", "Emma", "Voyage")  # List unique single-words from title or ebook_id to select a given book
sub_book_list = [book_list[i] for i in range(len(book_list)) if not set([book_list[i]['ebook_id']]+book_list[i]['title'].split(' ')).isdisjoint(set(select))]

print("Using:")
for i in range(len(sub_book_list)):
    print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

textlib_dataset = None  # Forces re-caching
td = Text_Dataset(sub_book_list)
td.init_tokenizer(tokenizer='ngram', max_ngrams=3, max_tokens=400)


Using:
1: Mr. Bennett and Mrs. Brown - Virginia Woolf
2: Pride and Prejudice - Jane Austen
3: Lady Susan - Jane Austen
4: Wuthering Heights - Emily Brontë
5: Emma - Jane Austen
6: The Voyage Out - Virginia Woolf


In [12]:
SEQUENCE_LEN = 80
SUB_PROBABILITY = 0.15  # like BERT

td.init_getitem(sample_type='encoded', sample_length=SEQUENCE_LEN, content_stepping=1)

num_records = len(td)

print(f"{num_records} records")

1627305 records


In [13]:
def OLD_get_sample_batch(td, batch_size, length, random_index=True, SUB_probability=0.0):
    for i in range(batch_size):
        if random_index is True:
            ind = random.randint(0, num_records-1)
        else:
            ind = i * td.getitem_content_stepping
        Xi = td[ind]
        yi = [Xi[-1]]
        if SUB_probability==0.0:
            Xi[-1]=td.c2i['␚']  # use 'SUB'-stitut glyph to mark last char of input
        else:
            l=int(len(Xi)*SUB_probability)
            for li in range(l):
                pos=random.randint(0,len(Xi)-1)
                Xi[pos]=td.c2i['␚']
        if i==0:
            smpX=np.array(Xi, dtype=np.float32)
            smpy=np.array(yi, dtype=np.int32)
        else:
            smpX = np.vstack((smpX, np.array(Xi, dtype=np.float32)))
            smpy = np.vstack((smpy, np.array(yi, dtype=np.int32)))
    return np.array(smpX), np.array(smpy)

# def get_random_onehot_sample_batch(td, batch_size, length):
#     X, y = get_random_sample_batch(td, batch_size, length)
#     xoh = tf.keras.backend.one_hot(X, len(td.i2c))
#     yk = tf.keras.backend.constant(y)
#     return xoh, yk

def get_sample_batch(td, batch_size, length, SUB_probability=0.15):
    for i in range(batch_size):
        Xi = td.get_random_item()
        yi = Xi.copy()
        l=int(len(Xi)*SUB_probability)
        for li in range(l):
            pos=random.randint(0,len(Xi)-1)
            if td.tokenizer_type=='char':
                Xi[pos]=td.c2i['␚']
            elif td.tokenizer_type=='word':
                Xi[pos]=td.w2i['<subst>']
            elif td.tokenizer_type=='ngram':
                Xi[pos]=td.t2i['<subst>']
            else:
                print(f"Unexpected tokenizer_type {td.tokenizer_type}")
        if i==0:
            smpX=np.array(Xi, dtype=np.float32)
            smpy=np.array(yi, dtype=np.int32)
        else:
            smpX = np.vstack((smpX, np.array(Xi, dtype=np.float32)))
            smpy = np.vstack((smpy, np.array(yi, dtype=np.int32)))
    return np.array(smpX), np.array(smpy)

In [14]:
test_x, test_y = get_sample_batch(td, 2, 40, SUB_probability=SUB_PROBABILITY)
for i in range(len(test_x)):
    xi=[int(x) for x in test_x[i]]
    print(f"[{i}](l={len(xi)}): X=>{td.decode(xi)}<, y=>{td.decode(test_y[i])}<")

[0](l=80): X=>ct of the<subst>quie<subst>drive home which was to c<subst><subst> the very ques<subst>n<subst><subst>le en<subst><subst>yme<subst><subst><subst>
this day of pleasure. Such another scheme, composed of so many
ill-<, y=>ct of the
quiet drive home which was to close the very questionable enjoyments of
this day of pleasure. Such another scheme, composed of so many
ill-<
[1](l=80): X=>ke
increas<subst>her bitter<subst><subst><subst>At last she broke o<subst><subst>

“Thank Go<subst>Hel<subst>, I’m not like you! I sometimes think you don’<subst>hink
or fee<subst>or <subst>are t<subst>do anything <, y=>ke
increased her bitterness. At last she broke out—

“Thank God, Helen, I’m not like you! I sometimes think you don’t think
or feel or care to do anything <


In [15]:
test_x.shape, test_y.shape

((2, 80), (2, 80))
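
The substitution step in `get_sample_batch` can be looked at in isolation. The sketch below uses plain integers instead of real tokens, and the helper name `mask_tokens` and the marker value `-1` are assumptions for illustration. It shows one property of the approach: positions are drawn with replacement, so the same position can be hit more than once and the number of substituted tokens is at most `int(len(tokens) * sub_probability)`.

```python
import random


def mask_tokens(tokens, subst_token, sub_probability=0.15, rng=random):
    """Replace randomly drawn positions of a copy of `tokens` by `subst_token`."""
    masked = list(tokens)  # copy, so the target sequence stays untouched
    n_draws = int(len(masked) * sub_probability)
    for _ in range(n_draws):
        pos = rng.randint(0, len(masked) - 1)  # drawn with replacement
        masked[pos] = subst_token
    return masked


rng = random.Random(42)       # fixed seed, only to make the demo repeatable
tokens = list(range(80))      # stand-in for one encoded sample of SEQUENCE_LEN tokens
masked = mask_tokens(tokens, subst_token=-1, sub_probability=0.15, rng=rng)
n_masked = sum(1 for t in masked if t == -1)
print(f"{n_masked} of {len(tokens)} positions substituted")
```

If an exact substitution count were wanted, `rng.sample(range(len(tokens)), n_draws)` would draw distinct positions instead; the notebook does not do this.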

## 2. Use tf.data for texts

In [16]:
def expand_name_template(template, params):
    exp=copy.copy(template)
    for key in params:
        src="{"+key+"}"
        dst=f"{params[key]}"
        exp=exp.replace(src,dst).replace('[','(').replace(']',')')
    return exp

def save_model_metadata(epoch, suffix='std'):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    params['current_epoch'] = epoch
    try:
        with open(meta_file, 'w') as f:
            f.write(json.dumps(params))
    except Exception as e:
        print(f"Failed to store model metadata at {model_path}: {e}")
        return False
    return True

def read_model_metadata(suffix="std"):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    try:
        with open(meta_file, 'r') as f:
            meta = json.load(f)
    except Exception as e:
        print(f"Cannot access project meta-data at {meta_file}: {e}, starting anew.")
        return None
    return meta

def is_metadata_compatible(params, meta):
    is_valid=True
    keys=set(list(params.keys())+list(meta.keys()))
    for key in keys:
        if key in updatable_keys:
            continue
        if key not in meta:
            print(f"Key {key} not available in last checkpoint model_meta, params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif key not in params:
            print(f"Key {key} not available in params, last checkpoint model_meta[{key}]: {meta[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif meta[key]!=params[key]:
            print(f"Last checkpoint model_meta[{key}]: {meta[key]} != params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
    if is_valid is False:
        print("Aborting import.")
        return False
    return True

In [17]:
vocabulary_size = td.get_unique_token_count()  # vocabulary-size

params = { # Multi-head self-attention
    'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}',

    'mhsa_layers': 4, 
    'heads': [4,4,4,4],
    'units': [512,512,512,512],  # 0 inserts an LSTM for memory-states :-)
    'norm': 'softmax', # this is for within each head
    'mh_normalize': True,  # use layer-norm after concatenation (or addition) of the heads
    'l2_regularizer': 1e-9,
    'dropout': 0.0,       # no dropout: 0.0
    'join_heads_by_add': False,  # strategy for joining the heads: False: concat (as in all-you-need), True: relu+add of all the heads
    'vocab_size': vocabulary_size,
    'sequence_len': SEQUENCE_LEN,

    'batch_size': 1024,
    'learning_rate': 0.005,
    'clipvalue': None,
    'sample_every_n_epochs': 2500,
}

if len(params['heads'])!=params['mhsa_layers'] or len(params['units'])!=params['mhsa_layers']:
    print("ERROR: length of 'heads' and 'units' must be equal to mhsa_layers!")
    
if ml_env.is_tpu is True:
    lr = params['learning_rate']*1.0
else:
    lr = params['learning_rate']

model_suffix = expand_name_template(params['name'], params)
# Put 'important' params in checkpoint-pathname to separate model-data:
checkpoint_dir = os.path.join(model_path, f"training_checkpoints_{model_suffix}")
if os.path.exists(checkpoint_dir) is False:
    os.makedirs(checkpoint_dir)

# When comparing if training-data is compatible with new params set, 
# the following keys are updatable, they can be changed while continuing
# to use existing checkpoints and continue training with those values
# changed:
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'dropout', 
             'sample_every_n_epochs']

# These values are taken from the saved checkpoint:
keep_keys=['current_epoch']

continue_last = True

meta = read_model_metadata(suffix=model_suffix)
if meta is not None and is_metadata_compatible(params, meta) is True and continue_last is True:
    for key in keep_keys:
        if key in meta:
            params[key]=meta[key]
    if 'current_epoch' in params:
        print(f"Continuing last session from epoch {params['current_epoch']}")
    else:
        print("No previous data, starting new model")
else:
    print("Starting new model")

print(params)

Cannot access project meta-data at /content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf/model_meta_4x(4, 4, 4, 4)x(512, 512, 512, 512)x400.json: [Errno 2] No such file or directory: '/content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf/model_meta_4x(4, 4, 4, 4)x(512, 512, 512, 512)x400.json', starting anew.
Starting new model
{'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}', 'mhsa_layers': 4, 'heads': [4, 4, 4, 4], 'units': [512, 512, 512, 512], 'norm': 'softmax', 'mh_normalize': True, 'l2_regularizer': 1e-09, 'dropout': 0.0, 'join_heads_by_add': False, 'vocab_size': 400, 'sequence_len': 80, 'batch_size': 1024, 'learning_rate': 0.001, 'clipvalue': None, 'sample_every_n_epochs': 250}
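
The file name in the message above comes from expanding the `name` template with the current parameters. A compact re-statement of `expand_name_template` shows how the list-valued parameters end up in parentheses. The dictionary `demo_params` is a reduced copy of `params` used only for this illustration.

```python
def expand_name_template(template, params):
    """Substitute {key} placeholders and turn list brackets into parentheses."""
    name = template
    for key, value in params.items():
        name = name.replace("{" + key + "}", str(value))
    return name.replace("[", "(").replace("]", ")")


demo_params = {
    "mhsa_layers": 4,
    "heads": [4, 4, 4, 4],
    "units": [512, 512, 512, 512],
    "vocab_size": 400,
}
suffix = expand_name_template("{mhsa_layers}x{heads}x{units}x{vocab_size}", demo_params)
print(suffix)  # 4x(4, 4, 4, 4)x(512, 512, 512, 512)x400
```

Every parameter that appears in the template therefore selects a separate metadata file and checkpoint directory, so changing `heads`, `units` or the vocabulary size starts a new model instead of overwriting an old one.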


In [18]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 1589


In [19]:
# @tf.function   (only slows things down [considerably!])
def make_tf_dataset(num, random_index=False, SUB_probability=0.0):
    dx=[]
    dy=[]
    num_batches_active = num
    for i in range(num_batches_active):
        x,y=get_sample_batch(td, params['batch_size'], params['sequence_len'], SUB_probability=SUB_probability)
        if i<1:
            print(f"[{num} x]: {x.shape} -> {y.shape}")
        dx.append(x)
        dy.append(y)
    dx=np.array(dx)
    dy=np.array(dy)
    print(f"dx.shape={dx.shape}, dy.shape={dy.shape}")
    data_xy = (dx, dy)
    tf_dataset=tf.data.Dataset.from_tensor_slices(data_xy)
    return tf_dataset

In [20]:
MAX_NUM_BATCHES = 50000

if num_batches>MAX_NUM_BATCHES:
    restricted_batches=MAX_NUM_BATCHES
    print(f"Restricting {num_batches} to max of {restricted_batches}")
else:
    restricted_batches=num_batches
    print(f"{restricted_batches} batches")
if cached_batch_data == restricted_batches and textlib_dataset is not None:
    print("Reusing cached training-data")
else:
    print("Creating dataset, this is slow. Be patient...")
    textlib_dataset = make_tf_dataset(restricted_batches, SUB_probability=SUB_PROBABILITY)
    cached_batch_data = restricted_batches
    print("Dataset done and cached.")

1589 batches
Creating dataset, this is slow. Be patient...
[1589 x]: (1024, 80) -> (1024, 80)
dx.shape=(1589, 1024, 80), dy.shape=(1589, 1024, 80)
Dataset done and cached.
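
The numbers printed above follow from simple arithmetic. This is a rough estimate that assumes 4 bytes per value (`float32` for `dx`, `int32` for `dy`, as set in `get_sample_batch`) and ignores any overhead of Python or `tf.data`.

```python
num_records = 1627305   # printed by the notebook for the selected books
batch_size = 1024
sequence_len = 80
bytes_per_value = 4     # float32 (dx) and int32 (dy) both use 4 bytes

num_batches = num_records // batch_size
bytes_per_array = num_batches * batch_size * sequence_len * bytes_per_value
mib_per_array = bytes_per_array / 2**20

print(f"num_batches = {num_batches}")
print(f"dx and dy need about {mib_per_array:.1f} MiB each")
```

Both arrays are built completely in memory before `from_tensor_slices` is called, so raising `MAX_NUM_BATCHES`, `batch_size` or `SEQUENCE_LEN` increases the memory requirement proportionally.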


In [21]:
shuffle_buffer=10000
if ml_env.is_tpu is True:
    dataset=textlib_dataset.shuffle(shuffle_buffer).repeat()  # Otherwise TPU may run dry
else:
    dataset=textlib_dataset.shuffle(shuffle_buffer)  
dataset.take(1)

<TakeDataset element_spec=(TensorSpec(shape=(1024, 80), dtype=tf.float32, name=None), TensorSpec(shape=(1024, 80), dtype=tf.int32, name=None))>

In [22]:
if ml_env.is_tpu is False:
    validation_dataset = make_tf_dataset(10, random_index=True, SUB_probability=SUB_PROBABILITY)

In [23]:
def model_mhsa(inputs, params):
    dense = layers.Dense(params['vocab_size'], kernel_regularizer=regularizers.l2(params['l2_regularizer']))  # using softmax here prevents temperature adjust, affects 'from_logits' param in sparse_categorical loss 
    fl = layers.Flatten()
    dr = layers.Dropout(params['dropout'])
    pe = PositionalEncoding(amplitude=0.3)
    rs_up = layers.Reshape(target_shape=(SEQUENCE_LEN,vocabulary_size))
    lstm1 = layers.LSTM(units=vocabulary_size, return_sequences=True)
#    lstm2 = layers.LSTM(units=vocabulary_size, return_sequences=True)
    rs_down = layers.Reshape(target_shape=(SEQUENCE_LEN,vocabulary_size))
    mhsa=[]
    residuals=[]

    for i in range(params['mhsa_layers']):
        if params['units'][i]==0:
            mhsa.append(None)
            residuals.append(i)
        else:
            mhsa.append(MultiHeadSelfAttention(params['heads'][i], units=params['units'][i], norm=params['norm'], mh_normalize=params['mh_normalize'], join_heads_by_add=params['join_heads_by_add']))
    x = tf.one_hot(tf.cast(inputs,dtype=tf.int32), params['vocab_size'], axis=-1)
    x = pe(x)
    for i in range(len(mhsa)):
        if i in residuals:
            x = rs_down(lstm1(rs_up(x)))+x
            print(f"Residual at layer {i} added.")
        else:
            x = mhsa[i](x)
        # x = mhsa[i](x,x)
    if params['dropout']>0.0:
        x = dr(x)
    # x = dense(fl(x))
    return x 

In [24]:
def mhsa_generate(model, text, gen_len=64, temperature=0.9, argmax=False, verbose=False):
    if verbose is True:
        full=text[:-1]
    gen_text=""
    lf=0
    input = np.array(td.encode(text))
    while len(input) < params['sequence_len']:
        input = np.concatenate([td.encode('<pad>'),input])
    for i in range(gen_len):
        input = np.concatenate([input[1:],td.encode('<subst>')])
        if len(input)!=params['sequence_len']:
            print('assertion failure')
            return None
        pred = model(input)
        pred /= temperature
        pred = tf.keras.layers.Softmax()(pred)
        if tf.executing_eagerly() is True and ml_env.is_tpu is False:
            pred=pred.numpy()
        else:
            pred=tf.keras.backend.eval(pred)  # this is a cheat, it internally uses numpy() too.
        if argmax is True:
            pred=np.argmax(pred[0],axis=1)
        else:
            pred = [np.random.choice(list(range(len(pred[0][-1]))), p=pred[0][-1])]
        input = np.concatenate([input[1:],[pred[-1]]])
        c = td.decode([pred[-1]])
        if verbose is True:
            print(c, end='')
            if c=='\n':
                lf=0
            else:
                lf += 1
                if (lf>80 and c==' ') or lf>120:
                    print()
                    lf=0
            full+=c
        gen_text+=c
    if verbose is True:
        print()
    return gen_text
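
The sampling step of `mhsa_generate` (divide the logits by the temperature, apply softmax, then either take the argmax or draw from the distribution) can be sketched in plain Python. The helper names `softmax_with_temperature` and `sample_token` are assumptions for illustration and operate on a single list of logits instead of the model output tensor.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits / temperature, shifted by the maximum for stability."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature=0.9, argmax=False, rng=random):
    """Pick a token index: greedy if argmax is set, otherwise sampled."""
    if argmax:
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [1.0, 2.0, 3.0]
for temperature in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

Lower temperatures concentrate the probability on the most likely token and make the output more repetitive; higher temperatures flatten the distribution and make it more varied.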


In [25]:
if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        print("Creating TPU-scope model")
        inputs = keras.Input(shape=(params['sequence_len'],))
        outputs = model_mhsa(inputs, params)
        model = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
    print("Creating Default-scope model")
    inputs = keras.Input(shape=(params['sequence_len'],))
    outputs = model_mhsa(inputs, params)
    model_cpu = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
else:
    inputs = keras.Input(shape=(params['sequence_len'],))
    outputs = model_mhsa(inputs, params)
    model = keras.Model(inputs=inputs, outputs=outputs, name="mhsa_v1_tf")
    model_cpu = model

Creating TPU-scope model
Creating Default-scope model


In [26]:
def get_newest_checkpoint(checkpoint_dir):
    files = os.listdir(checkpoint_dir)
    paths = [os.path.join(checkpoint_dir, basename) for basename in files]
    return max(paths, key=os.path.getctime)

def import_previous_compatible_checkpoint(model, force_import=False):
    meta = read_model_metadata(suffix=model_suffix)
    if meta is None:
        print("No previous checkpoint found")
        return False
    if is_metadata_compatible(params, meta) is not True and force_import is False:
        print("No useable import found.")
        return False
    try:
        last_checkpoint = get_newest_checkpoint(checkpoint_dir) # Doesn't do anything: tf.train.latest_checkpoint(checkpoint_dir)
    except Exception as e:
        print(f"Cannot determine last checkpoint in {checkpoint_dir}, cannot import due to: {e}")
        return False
    print(f"Last checkpoint: {last_checkpoint}")
    try:
        model.load_weights(last_checkpoint)
    except Exception as e:
        print(f"Failed to import model {last_checkpoint}: {e}")
        return False
    if 'current_epoch' in meta:
        params['current_epoch'] = meta['current_epoch']
    print(f"Successful import of epoch {params['current_epoch']} from {last_checkpoint}, continuing from there...")
    return True

### Loss function, optimizer, tensorboard output

In [27]:
kscc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

def loss(labels, logits):
  vl=kscc(labels, logits)
  return vl
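
What this loss computes for one sequence can be written out by hand. The sketch below is a plain-Python illustration of sparse categorical cross-entropy on logits, not the Keras implementation, and the helper names are assumptions. With `Reduction.NONE`, one loss value per position is returned instead of an average.

```python
import math


def sparse_ce_from_logits(label, logits):
    """Cross-entropy of one position: log-sum-exp of the logits minus the label's logit."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[label]


def per_position_loss(labels, logits_seq):
    """One loss value per sequence position, without any reduction."""
    return [sparse_ce_from_logits(label, logits)
            for label, logits in zip(labels, logits_seq)]


vocab_size = 400
uniform = [0.0] * vocab_size
confident = [0.0] * vocab_size
confident[7] = 20.0  # strongly prefers token 7

losses = per_position_loss([7, 7], [uniform, confident])
print([round(v, 4) for v in losses])
```

A model that predicts all 400 tokens with roughly equal probability yields a loss of about `ln(400) ≈ 5.99` per position, which is a useful reference value for the start of training.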

In [28]:
if params['clipvalue'] is not None:
    if ml_env.is_tpu is True:
        with tpu_strategy.scope():
            opti = tf.keras.optimizers.Adam(learning_rate=lr, clipvalue=params['clipvalue'])
    else:
        opti = tf.keras.optimizers.Adam(learning_rate=lr, clipvalue=params['clipvalue'])
else:
    if ml_env.is_tpu is True:
        with tpu_strategy.scope():
            opti = tf.keras.optimizers.Adam(learning_rate=lr)
    else:
        opti = tf.keras.optimizers.Adam(learning_rate=lr)

if ml_env.is_tpu is True:
    with tpu_strategy.scope():
        model.compile(optimizer=opti, loss=loss, metrics=[], run_eagerly=False, jit_compile=True)
else:
    model.compile(optimizer=opti, loss=loss, metrics=['accuracy'])

In [29]:
import_checkpoint = True
force_import = False   # True: ignore metadata and try import anyway. This will of course crash, if the new model doesn't fit the checkpoint-data...

if import_checkpoint is True:
    import_previous_compatible_checkpoint(model, force_import=force_import)

Cannot access project meta-data at /content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf/model_meta_4x(4, 4, 4, 4)x(512, 512, 512, 512)x400.json: [Errno 2] No such file or directory: '/content/drive/My Drive/Colab Notebooks/women_writers/model/mhsa_v1_tf/model_meta_4x(4, 4, 4, 4)x(512, 512, 512, 512)x400.json', starting anew.
No previous checkpoint found
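
The compatibility rule behind this import can be summarized as follows: a checkpoint is only reused if all parameters except those listed in `updatable_keys` are identical. The sketch below is a condensed variant of `is_metadata_compatible`. The function name `incompatible_keys` and the sample dictionaries are assumptions for illustration.

```python
def incompatible_keys(params, meta, updatable_keys):
    """Return the sorted keys that prevent reusing a checkpoint."""
    keys = set(params) | set(meta)
    return sorted(
        key for key in keys
        if key not in updatable_keys
        and (key not in params or key not in meta or params[key] != meta[key])
    )


updatable = ["learning_rate", "batch_size", "current_epoch", "dropout",
             "sample_every_n_epochs"]
saved_meta = {"heads": [4, 4], "vocab_size": 400, "learning_rate": 0.001,
              "current_epoch": 12}

same_model = {"heads": [4, 4], "vocab_size": 400, "learning_rate": 0.005}
other_vocab = {"heads": [4, 4], "vocab_size": 500, "learning_rate": 0.005}

print(incompatible_keys(same_model, saved_meta, updatable))   # []
print(incompatible_keys(other_vocab, saved_meta, updatable))  # ['vocab_size']
```

Since the metadata is stored as JSON, list-valued parameters such as `heads` compare equal after a round trip, whereas tuples would come back as lists and be reported as different.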


In [30]:
model.summary()

Model: "mhsa_v1_tf"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 80)]              0         
                                                                 
 tf.cast (TFOpLambda)        (None, 80)                0         
                                                                 
 tf.one_hot (TFOpLambda)     (None, 80, 400)           0         
                                                                 
 positional_encoding (Positi  (None, 80, 400)          0         
 onalEncoding)                                                   
                                                                 
 multi_head_self_attention (  (None, 80, 400)          4080800   
 MultiHeadSelfAttention)                                         
                                                                 
 multi_head_self_attention_1  (None, 80, 400)          4

In [31]:
TPU_GENERATE_ON_CPU = False  # The thing is: both options are slow on TPU :-/

class ServiceCallback(keras.callbacks.Callback):
#    def on_test_end(self, logs=None):
    # @tf.function
    def on_epoch_end(self, epoch, logs=None):
        save_model_metadata(epoch, suffix=model_suffix)
        if (epoch+1) % params['sample_every_n_epochs'] == 0:
            idx=random.randint(0,len(td)-1)
            text=td.decode(td[idx])
            print()
            if ml_env.is_tpu is True:
                temp_list=[0.6,0.7,0.8,0.0]
                gen_len=64
                with tpu_strategy.scope():
                    weights=model.get_weights()
                model_cpu.set_weights(weights)
                # HDF5 is required for saving weights that originate from TPU
                # otherwise this just silently fails...
                checkpoint_path = os.path.join(checkpoint_dir, "cp-{epoch:04d}.h5")
                chkpt_dest=checkpoint_path.format(epoch=epoch)
                print(f"Checkpoint: {chkpt_dest}")
                model_cpu.save_weights(chkpt_dest)
            else:
                temp_list=[0.5, 0.7, 0.9]
                gen_len=192
            print(f"prompt: {text}")
            for temp in temp_list:
                print(f"---------------- T={temp} ---------------")
                if ml_env.is_tpu is True and TPU_GENERATE_ON_CPU is True:
                    with tf.device('/cpu:0'):
                        if temp==0.0:
                            reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=1.0, argmax=True, verbose=False)
                        else:
                            reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=temp, verbose=False)
                else:
                    if temp==0.0:
                        reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=1.0, argmax=True, verbose=False)
                    else:
                        reply=mhsa_generate(model_cpu, text, gen_len=gen_len, temperature=temp, verbose=False)
                td.source_highlight(reply, min_quote_size=10, dark_mode=use_dark_mode, display_ref_anchor=False)
            print("--------------------------------------")

service_callback=ServiceCallback()

In [32]:
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

logdir = os.path.join(log_path, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
if ml_env.is_tpu:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='epoch', write_graph=False)
else:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='batch')


In [33]:
# Dont try:
#    # use the python variable log_path:
#   get_ipython().run_line_magic('tensorboard', '--logdir "{log_path}"')
#except:
#   pass

# The following throws errors on non-colab, but the guarding above is too bug-ridden.
if ml_env.is_tpu is False:
    %tensorboard --logdir logs

## The actual training

In [34]:
EPOCHS=50000
if 'current_epoch' in params:
    initial_epoch=params['current_epoch']
else:
    initial_epoch=0

In [35]:
if ml_env.is_tpu is True:
    # Run at least one step per epoch, even for very small datasets.
    steps_per_epoch = max(1, restricted_batches // params['batch_size'])
    # On TPU we need to roll our own checkpointer, since we need to transfer the weights.
    history = model.fit(dataset, epochs=EPOCHS, initial_epoch=initial_epoch,
                        steps_per_epoch=steps_per_epoch, callbacks=[service_callback])
else:
    history = model.fit(dataset, validation_data=validation_dataset, epochs=EPOCHS, initial_epoch=initial_epoch, callbacks=[checkpoint_callback, tensorboard_callback, service_callback])

Epoch 1/5000
Epoch 2/5000
Epoch 3/5000
Epoch 4/5000
Epoch 5/5000
Epoch 6/5000
Epoch 7/5000
Epoch 8/5000
Epoch 9/5000
Epoch 10/5000
Epoch 11/5000
Epoch 12/5000
Epoch 13/5000
Epoch 14/5000
Epoch 15/5000
Epoch 16/5000
Epoch 17/5000
Epoch 18/5000
Epoch 19/5000
Epoch 20/5000
Epoch 21/5000
Epoch 22/5000
Epoch 23/5000
Epoch 24/5000
Epoch 25/5000
Epoch 26/5000
Epoch 27/5000
Epoch 28/5000
Epoch 29/5000
Epoch 30/5000
Epoch 31/5000
Epoch 32/5000
Epoch 33/5000
Epoch 34/5000
Epoch 35/5000
Epoch 36/5000
Epoch 37/5000
Epoch 38/5000
Epoch 39/5000
Epoch 40/5000
Epoch 41/5000
Epoch 42/5000
Epoch 43/5000
Epoch 44/5000
Epoch 45/5000
Epoch 46/5000
Epoch 47/5000
Epoch 48/5000
Epoch 49/5000
Epoch 50/5000
Epoch 51/5000
Epoch 52/5000
Epoch 53/5000
Epoch 54/5000
Epoch 55/5000
Epoch 56/5000
Epoch 57/5000
Epoch 58/5000
Epoch 59/5000
Epoch 60/5000
Epoch 61/5000
Epoch 62/5000
Epoch 63/5000
Epoch 64/5000
Epoch 65/5000
Epoch 66/5000
Epoch 67/5000
Epoch 68/5000
Epoch 69/5000
Epoch 70/5000
Epoch 71/5000
Epoch 72/5000
E

---------------- T=0.7 ---------------


---------------- T=0.8 ---------------


---------------- T=0.0 ---------------


--------------------------------------
Epoch 251/5000
Epoch 252/5000
Epoch 253/5000
Epoch 254/5000
Epoch 255/5000
Epoch 256/5000
Epoch 257/5000
Epoch 258/5000
Epoch 259/5000
Epoch 260/5000
Epoch 261/5000
Epoch 262/5000
Epoch 263/5000
Epoch 264/5000
Epoch 265/5000
Epoch 266/5000
Epoch 267/5000
Epoch 268/5000
Epoch 269/5000
Epoch 270/5000
Epoch 271/5000
Epoch 272/5000
Epoch 273/5000
Epoch 274/5000
Epoch 275/5000
Epoch 276/5000
Epoch 277/5000
Epoch 278/5000
Epoch 279/5000
Epoch 280/5000
Epoch 281/5000
Epoch 282/5000
Epoch 283/5000
Epoch 284/5000
Epoch 285/5000
Epoch 286/5000
Epoch 287/5000
Epoch 288/5000
Epoch 289/5000
Epoch 290/5000
Epoch 291/5000
Epoch 292/5000
Epoch 293/5000
Epoch 294/5000
Epoch 295/5000
Epoch 296/5000
Epoch 297/5000
Epoch 298/5000
Epoch 299/5000
Epoch 300/5000
Epoch 301/5000
Epoch 302/5000
Epoch 303/5000
Epoch 304/5000
Epoch 305/5000
Epoch 306/5000
Epoch 307/5000
Epoch 308/5000
Epoch 309/5000
Epoch 310/5000
Epoch 311/5000
Epoch 312/5000
Epoch 313/5000
Epoch 314/5000
E

---------------- T=0.7 ---------------


---------------- T=0.8 ---------------


---------------- T=0.0 ---------------


--------------------------------------
Epoch 501/5000
Epoch 502/5000
Epoch 503/5000
Epoch 504/5000
Epoch 505/5000
Epoch 506/5000
Epoch 507/5000
Epoch 508/5000
Epoch 509/5000
Epoch 510/5000
Epoch 511/5000
Epoch 512/5000
Epoch 513/5000
Epoch 514/5000
Epoch 515/5000
Epoch 516/5000
Epoch 517/5000
Epoch 518/5000
Epoch 519/5000
Epoch 520/5000
Epoch 521/5000
Epoch 522/5000
Epoch 523/5000
Epoch 524/5000
Epoch 525/5000
Epoch 526/5000
Epoch 527/5000
Epoch 528/5000
Epoch 529/5000
Epoch 530/5000
Epoch 531/5000
Epoch 532/5000
Epoch 533/5000
Epoch 534/5000
Epoch 535/5000
Epoch 536/5000
Epoch 537/5000
Epoch 538/5000
Epoch 539/5000
Epoch 540/5000
Epoch 541/5000
Epoch 542/5000
Epoch 543/5000
Epoch 544/5000
Epoch 545/5000
Epoch 546/5000
Epoch 547/5000
Epoch 548/5000
Epoch 549/5000
Epoch 550/5000
Epoch 551/5000
Epoch 552/5000
Epoch 553/5000
Epoch 554/5000
Epoch 555/5000
Epoch 556/5000
Epoch 557/5000
Epoch 558/5000
Epoch 559/5000
Epoch 560/5000
Epoch 561/5000
Epoch 562/5000
Epoch 563/5000
Epoch 564/5000
E

---------------- T=0.7 ---------------


---------------- T=0.8 ---------------


---------------- T=0.0 ---------------


--------------------------------------
Epoch 751/5000
Epoch 752/5000
Epoch 753/5000
Epoch 754/5000
Epoch 755/5000
Epoch 756/5000
Epoch 757/5000
Epoch 758/5000
Epoch 759/5000
Epoch 760/5000
Epoch 761/5000
Epoch 762/5000
Epoch 763/5000
Epoch 764/5000
Epoch 765/5000
Epoch 766/5000
Epoch 767/5000
Epoch 768/5000
Epoch 769/5000
Epoch 770/5000
Epoch 771/5000
Epoch 772/5000
Epoch 773/5000
Epoch 774/5000
Epoch 775/5000
Epoch 776/5000
Epoch 777/5000
Epoch 778/5000
Epoch 779/5000
Epoch 780/5000
Epoch 781/5000
Epoch 782/5000
Epoch 783/5000
Epoch 784/5000
Epoch 785/5000
Epoch 786/5000
Epoch 787/5000
Epoch 788/5000
Epoch 789/5000
Epoch 790/5000
Epoch 791/5000
Epoch 792/5000
Epoch 793/5000
Epoch 794/5000
Epoch 795/5000
Epoch 796/5000
Epoch 797/5000
Epoch 798/5000
Epoch 799/5000
Epoch 800/5000
Epoch 801/5000
Epoch 802/5000
Epoch 803/5000
Epoch 804/5000
Epoch 805/5000
Epoch 806/5000
Epoch 807/5000
Epoch 808/5000
Epoch 809/5000
Epoch 810/5000
Epoch 811/5000
Epoch 812/5000
Epoch 813/5000
Epoch 814/5000
E

---------------- T=0.7 ---------------


---------------- T=0.8 ---------------


---------------- T=0.0 ---------------


--------------------------------------
Epoch 1001/5000
Epoch 1002/5000
Epoch 1003/5000
Epoch 1004/5000
Epoch 1005/5000
Epoch 1006/5000
Epoch 1007/5000
Epoch 1008/5000
Epoch 1009/5000
Epoch 1010/5000
Epoch 1011/5000
Epoch 1012/5000
Epoch 1013/5000
Epoch 1014/5000
Epoch 1015/5000
Epoch 1016/5000
Epoch 1017/5000
Epoch 1018/5000
Epoch 1019/5000
Epoch 1020/5000
Epoch 1021/5000
Epoch 1022/5000
Epoch 1023/5000
Epoch 1024/5000
Epoch 1025/5000
Epoch 1026/5000
Epoch 1027/5000
Epoch 1028/5000
Epoch 1029/5000
Epoch 1030/5000
Epoch 1031/5000
Epoch 1032/5000
Epoch 1033/5000
Epoch 1034/5000
Epoch 1035/5000
Epoch 1036/5000
Epoch 1037/5000
Epoch 1038/5000
Epoch 1039/5000
Epoch 1040/5000
Epoch 1041/5000
Epoch 1042/5000
Epoch 1043/5000
Epoch 1044/5000
Epoch 1045/5000
Epoch 1046/5000
Epoch 1047/5000
Epoch 1048/5000
Epoch 1049/5000
Epoch 1050/5000
Epoch 1051/5000
Epoch 1052/5000
Epoch 1053/5000
Epoch 1054/5000
Epoch 1055/5000
Epoch 1056/5000
Epoch 1057/5000
Epoch 1058/5000
Epoch 1059/5000
Epoch 1060/5000
E

---------------- T=0.7 ---------------


---------------- T=0.8 ---------------


---------------- T=0.0 ---------------


--------------------------------------
Epoch 1251/5000
Epoch 1252/5000
Epoch 1253/5000
Epoch 1254/5000
Epoch 1255/5000
Epoch 1256/5000
Epoch 1257/5000
Epoch 1258/5000
Epoch 1259/5000
Epoch 1260/5000
Epoch 1261/5000
Epoch 1262/5000
Epoch 1263/5000
Epoch 1264/5000
Epoch 1265/5000
Epoch 1266/5000
Epoch 1267/5000
Epoch 1268/5000
Epoch 1269/5000
Epoch 1270/5000
Epoch 1271/5000
Epoch 1272/5000
Epoch 1273/5000
Epoch 1274/5000
Epoch 1275/5000
Epoch 1276/5000
Epoch 1277/5000
Epoch 1278/5000
Epoch 1279/5000
Epoch 1280/5000
Epoch 1281/5000
Epoch 1282/5000
Epoch 1283/5000
Epoch 1284/5000
Epoch 1285/5000
Epoch 1286/5000
Epoch 1287/5000
Epoch 1288/5000
Epoch 1289/5000
Epoch 1290/5000
Epoch 1291/5000
Epoch 1292/5000
Epoch 1293/5000
Epoch 1294/5000
Epoch 1295/5000
Epoch 1296/5000
Epoch 1297/5000
Epoch 1298/5000
Epoch 1299/5000
Epoch 1300/5000
Epoch 1301/5000
Epoch 1302/5000
Epoch 1303/5000
Epoch 1304/5000
Epoch 1305/5000
Epoch 1306/5000
Epoch 1307/5000
Epoch 1308/5000
Epoch 1309/5000
Epoch 1310/5000
E

---------------- T=0.7 ---------------


---------------- T=0.8 ---------------


---------------- T=0.0 ---------------


--------------------------------------
Epoch 1501/5000
Epoch 1502/5000
Epoch 1503/5000
Epoch 1504/5000
Epoch 1505/5000
Epoch 1506/5000
Epoch 1507/5000
Epoch 1508/5000
Epoch 1509/5000
Epoch 1510/5000
Epoch 1511/5000
Epoch 1512/5000
Epoch 1513/5000
Epoch 1514/5000
Epoch 1515/5000
Epoch 1516/5000
Epoch 1517/5000
Epoch 1518/5000
Epoch 1519/5000
Epoch 1520/5000
Epoch 1521/5000
Epoch 1522/5000
Epoch 1523/5000
Epoch 1524/5000
Epoch 1525/5000
Epoch 1526/5000
Epoch 1527/5000
Epoch 1528/5000
Epoch 1529/5000
Epoch 1530/5000
Epoch 1531/5000
Epoch 1532/5000
Epoch 1533/5000
Epoch 1534/5000
Epoch 1535/5000
Epoch 1536/5000
Epoch 1537/5000
Epoch 1538/5000
Epoch 1539/5000
Epoch 1540/5000
Epoch 1541/5000
Epoch 1542/5000
Epoch 1543/5000
Epoch 1544/5000
Epoch 1545/5000
Epoch 1546/5000
Epoch 1547/5000
Epoch 1548/5000
Epoch 1549/5000
Epoch 1550/5000
Epoch 1551/5000
Epoch 1552/5000
Epoch 1553/5000
Epoch 1554/5000
Epoch 1555/5000
Epoch 1556/5000
Epoch 1557/5000
Epoch 1558/5000
Epoch 1559/5000
Epoch 1560/5000
E

---------------- T=0.7 ---------------


---------------- T=0.8 ---------------


---------------- T=0.0 ---------------


--------------------------------------
Epoch 1751/5000
Epoch 1752/5000
Epoch 1753/5000
Epoch 1754/5000
Epoch 1755/5000
Epoch 1756/5000
Epoch 1757/5000
Epoch 1758/5000
Epoch 1759/5000
Epoch 1760/5000
Epoch 1761/5000
Epoch 1762/5000
Epoch 1763/5000
Epoch 1764/5000
Epoch 1765/5000
Epoch 1766/5000
Epoch 1767/5000
Epoch 1768/5000
Epoch 1769/5000
Epoch 1770/5000
Epoch 1771/5000
Epoch 1772/5000
Epoch 1773/5000
Epoch 1774/5000
Epoch 1775/5000
Epoch 1776/5000
Epoch 1777/5000
Epoch 1778/5000
Epoch 1779/5000
Epoch 1780/5000
Epoch 1781/5000
Epoch 1782/5000
Epoch 1783/5000
Epoch 1784/5000
Epoch 1785/5000
Epoch 1786/5000
Epoch 1787/5000
Epoch 1788/5000
Epoch 1789/5000
Epoch 1790/5000
Epoch 1791/5000
Epoch 1792/5000
Epoch 1793/5000
Epoch 1794/5000
Epoch 1795/5000
Epoch 1796/5000
Epoch 1797/5000
Epoch 1798/5000
Epoch 1799/5000
Epoch 1800/5000
Epoch 1801/5000
Epoch 1802/5000
Epoch 1803/5000
Epoch 1804/5000
Epoch 1805/5000
Epoch 1806/5000
Epoch 1807/5000
Epoch 1808/5000
Epoch 1809/5000
Epoch 1810/5000
E

---------------- T=0.7 ---------------


---------------- T=0.8 ---------------


KeyboardInterrupt: ignored
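
The TPU branch of the training cell computes `steps_per_epoch` by integer division and raises the result to 1 when the division yields 0. A compact sketch of that rule, using a hypothetical helper name `tpu_steps_per_epoch` and made-up numbers:

```python
def tpu_steps_per_epoch(restricted_batches, batch_size):
    """Steps per epoch by integer division, but never fewer than one."""
    return max(1, restricted_batches // batch_size)


# Made-up sizes: one large dataset and one smaller than a single batch.
for n_records, batch_size in [(100000, 256), (100, 256)]:
    print(n_records, batch_size, tpu_steps_per_epoch(n_records, batch_size))
```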

## A dialog with the trained model

In [None]:
model_cpu.set_weights(model.get_weights())

In [None]:
# Run an interactive dialog with the transformer model trained above.

def doDialog(model):
    temperature = 0.6
    print("Please enter some dialog.")
    print("The net will answer according to your input.")
    print("'bye' for end,")
    print("'reset' to reset the conversation context,")
    print("'temperature=<float>' [0.1(frozen)-1.0(creative)]")
    print("    to change character of the dialog.")
    print("    Current temperature={}.".format(temperature))
    print()
    bye = False
    while not bye:
        print("> ", end="")
        prompt = input()
        if prompt == 'bye':
            bye = True
            print("Good bye!")
            continue
        if prompt == 'reset':
            # No context is carried over between prompts, so there is nothing to clear.
            print("(conversation context reset)")
            print()
            continue
        if prompt.startswith("temperature="):
            try:
                t = float(prompt[len("temperature="):])
            except ValueError:
                t = None  # not a number, rejected below
            # The check is slightly more lenient than the recommended 0.1-1.0 range.
            if t is not None and 0.05 < t < 1.4:
                temperature = t
                print("(generator temperature now {})".format(t))
                print()
            else:
                print("Invalid temperature-value ignored! [0.1-1.0]")
            continue
        reply=mhsa_generate(model, prompt, gen_len=256, temperature=temperature, verbose=True)
        td.source_highlight(reply, min_quote_size=13, dark_mode=use_dark_mode)

In [None]:
# Talk to the net!
doDialog(model_cpu)
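
The dialog loop above accepts a `temperature=<float>` command and only applies values strictly between 0.05 and 1.4. The following sketch isolates that parsing step in a separate function so it can be checked without a trained model. The function name `parse_temperature_command` and its `low`/`high` parameters are hypothetical, and returning `None` for non-numeric input is an assumption about the desired behavior.

```python
def parse_temperature_command(prompt, low=0.05, high=1.4):
    """Return the requested temperature, or None if the prompt is not a valid command."""
    prefix = "temperature="
    if not prompt.startswith(prefix):
        return None
    try:
        value = float(prompt[len(prefix):])
    except ValueError:
        return None  # e.g. "temperature=hot"
    if low < value < high:
        return value
    return None  # a number, but outside the accepted range


for text in ["temperature=0.8", "temperature=hot", "temperature=2", "hello"]:
    print(repr(text), parse_temperature_command(text))
```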