<a href="https://colab.research.google.com/github/domschl/torch-transformer-poet/blob/main/torch_transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Torch-Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that supports more environment-independent code. It will access your Google Drive when used with Google Colab.

In [1]:
!pip install -U ml-indie-tools

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import logging
import os
import sys
import copy
import json
import time
import datetime
import random
import numpy as np

import torch

In [3]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.pytorch_custom_layers import MultiHeadSelfAttention

## Preliminary

A PyTorch deep multi-head attention model for text generation, following Andrej Karpathy's [ng-video-lecture](https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py).

This code can run on CPU, GPU, or Apple Silicon. Google Colab is supported too; select the corresponding Colab runtime (menu: **`Runtime / Change runtime type`**).

## 0. Environment

In [4]:
cached_batch_data = None   # Do regenerate time-consuming training data, if already cached.

ml_env = MLEnv(platform='pt', accelerator='fastest')
ml_env.describe()

'OS: Linux, Python: 3.8.10, Colab Jupyter Notebook Pytorch: 1.13.1+cu116, GPU: NVIDIA A100-SXM4-40GB (3MiB / 40960MiB), CPU'

In [5]:
# project_name = 'women_writers'
project_name='philosophers'
model_name=f'ngpt_{project_name}_v1_pt'

# NOTICE: This will request access to Google Drive, if running on Google Colab. Google Drive is used to store snapshots
# and training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools 
#
# Note: you need to allow popups in your browser for COLAB, otherwise you won't see the google-drive login box, and drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Root path (all projects) : /content/drive/My Drive (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : /content/drive/My Drive/Colab Notebooks/philosophers (Changes to the file system happen only below this project path
Model path (snapshots)   : /content/drive/My Drive/Colab Notebooks/philosophers/model/ngpt_philosophers_v1_pt (Model weights and snapshots are stored here)
Data path (training data): /content/drive/My Drive/Colab Notebooks/philosophers/data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


## 1. Text library

The `Text_Dataset` and `Gutenberg_Dataset` classes provide libraries for training, 
encoding, batch generation, and formatted source display. They read 
books from Project Gutenberg and support the creation of training batches. 
The output functions support highlighting, which makes it possible to compare 
generated texts with the actual sources and helps to identify identical 
(memorized) parts.
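
The idea of spotting memorized passages can be illustrated with a minimal, standard-library sketch. This is not how `td.source_highlight` works internally (that implementation is not shown here); the function `shared_passages` and the two sample strings are hypothetical and only demonstrate the principle of finding long verbatim overlaps between a generated text and a source.

```python
from difflib import SequenceMatcher


def shared_passages(generated, source, min_len=10):
    """Return substrings of `generated` that also occur verbatim in `source`.

    Hypothetical helper for illustration only; only overlaps of at least
    `min_len` characters are reported.
    """
    matcher = SequenceMatcher(None, generated, source, autojunk=False)
    return [
        generated[m.a : m.a + m.size]
        for m in matcher.get_matching_blocks()
        if m.size >= min_len
    ]


source = "He was a hidden God, full of secrecy."
generated = "They say he was a hidden God, full of wonder."
print(shared_passages(generated, source))
```

A larger `min_len` suppresses short, coincidental overlaps, which is the same role the `min_quote_size` argument appears to play in the highlighting call used later in this notebook.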

In [6]:
use_dark_mode=False # Set to false for white background. HTML-text-compare uses background-colorization to identify different sources. Those background colors are dependent on the theme type.

In [7]:
logging.basicConfig(level=logging.INFO)
cache_dir = os.path.join(data_path, 'gutenberg_cache')
gd = Gutenberg_Dataset(cache_dir=cache_dir)

In [8]:
if project_name == 'women_writers':  # sample searches
    search_spec= {
        "author": ["Emily Brontë", "Jane Austen", "Virginia Woolf"], 
        "language": ["english"]
    }
elif project_name == 'philosophers':
    search_spec = {
        "author": ["Immanuel Kant", "Friedrich Nietzsche", "Wilhelm Hegel"],
        "language": ["english"]
    }
    
book_list=gd.search(search_spec)
book_cnt = len(book_list)
print(f"{book_cnt} matching books found with search {search_spec}.")
if book_cnt<40:
    # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts, 
    # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
    book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)  
else:
    logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

20 matching books found with search {'author': ['Immanuel Kant', 'Friedrich Nietzsche', 'Wilhelm Hegel'], 'language': ['english']}.


In [9]:
for i in range(len(book_list)):
    print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

0: The History of Philosophy: Volume 3 of 3 - Georg Wilhelm Hegel, 58169
1: The Will to Power, Books III and IV - Friedrich Nietzsche, 52915
2: The Will to Power, Books I and II - Friedrich Nietzsche, 52914
3: The Joyful Wisdom - Friedrich Nietzsche, 52881
4: Kant's Prolegomena - Immanuel Kant, 52821
5: Hegel's Lectures on the History of Philosophy: Vol. 2 of 3 - Georg Wilhelm Hegel, 51636
6: Hegel's Lectures on the History of Philosophy: Vol. 1 of 3 - Georg Wilhelm Hegel, 51635
7: Early Greek Philosophy & Other Essays - Friedrich Nietzsche, 51548
8: Perpetual Peace - Immanuel Kant, 50922
9: Kant's Critique of Judgement - Immanuel Kant, 48433
10: Thoughts Out of Season, Part 2 - Friedrich Nietzsche, 38226
11: Human, All Too Human - Friedrich Nietzsche, 38145
12: We Philologists, Volume 8 of 18 - Friedrich Nietzsche, 18267
13: The Metaphysical Elements of Ethics - Immanuel Kant, 5684
14: The Critique of Practical Reason - Immanuel Kant, 5683
15: Fundamental Principles of the Metaphysic 

In [11]:
MAX_TOKENS = 25000  # This becomes vocab_size
MAX_NGRAM_LEN = 12   # Max length of a token

if project_name == 'women_writers':
    select = ("Bennett", "1342", "5670", "1245", "161", "141", "121", "105", "Susan", "Wuthering", "Emma", "Voyage")  # List unique single-words from title or ebook_id to select a given book
    sub_book_list = [book_list[i] for i in range(len(book_list)) if not set([book_list[i]['ebook_id']]+book_list[i]['title'].split(' ')).isdisjoint(set(select))]
else:
    sub_book_list = book_list
    
print("Using:")
for i in range(len(sub_book_list)):
    print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

textlib_dataset = None  # Forces re-caching
td = Text_Dataset(sub_book_list)
print("")
print(f"Starting NGRAM tokenizer with token length from 1..{MAX_NGRAM_LEN} with a max of {MAX_TOKENS} unique tokens,")
print("this can take considerable time...")
td.init_tokenizer(tokenizer='ngram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS)


Using:
1: The History of Philosophy: Volume 3 of 3 - Georg Wilhelm Hegel
2: The Will to Power, Books III and IV - Friedrich Nietzsche
3: The Will to Power, Books I and II - Friedrich Nietzsche
4: The Joyful Wisdom - Friedrich Nietzsche
5: Kant's Prolegomena - Immanuel Kant
6: Hegel's Lectures on the History of Philosophy: Vol. 2 of 3 - Georg Wilhelm Hegel
7: Hegel's Lectures on the History of Philosophy: Vol. 1 of 3 - Georg Wilhelm Hegel
8: Early Greek Philosophy & Other Essays - Friedrich Nietzsche
9: Perpetual Peace - Immanuel Kant
10: Kant's Critique of Judgement - Immanuel Kant
11: Thoughts Out of Season, Part 2 - Friedrich Nietzsche
12: Human, All Too Human - Friedrich Nietzsche
13: We Philologists, Volume 8 of 18 - Friedrich Nietzsche
14: The Metaphysical Elements of Ethics - Immanuel Kant
15: The Critique of Practical Reason - Immanuel Kant
16: Fundamental Principles of the Metaphysic of Morals - Immanuel Kant
17: Thoughts out of Season, Part One - Friedrich Nietzsche
18: Beyond

In [12]:
SEQUENCE_LEN = 192
# SUB_PROBABILITY = 0.15  # like BERT

td.init_getitem(sample_type='encoded', sample_length=SEQUENCE_LEN+1, content_stepping=1)

num_records = len(td)

print(f"{num_records} records")

2240694 records


In [13]:
def get_sample_batch(td, batch_size):
    # Generate a small batch of inputs x and targets y.
    # Each random item holds SEQUENCE_LEN+1 token ids; the targets are
    # the inputs shifted by one token.
    smpX = []
    smpy = []
    for _ in range(batch_size):
        data = td.get_random_item()
        smpX.append(data[:-1])
        smpy.append(data[1:])
    # Always returns 2D arrays of shape (batch_size, SEQUENCE_LEN)
    return np.array(smpX, dtype=np.int32), np.array(smpy, dtype=np.int32)

In [14]:
test_x, test_y = get_sample_batch(td, 2)
for i in range(len(test_x)):
    xi=[int(x) for x in test_x[i]]
    print(f"[{i}](l={len(xi)}): X=>{td.decode(xi)}<,\ny=>{td.decode(test_y[i])}<")

[0](l=192): X=>ing the old man straight in the eye.

“Let him go, he is gone. And though it honoureth thee that thou speakest
only in praise of this dead one, yet thou knowest as well as I WHO he
was, and that he went curious ways.”

“To speak before three eyes,” said the old pope cheerfully (he was blind
of one eye), “in divine matters I am more enlightened than Zarathustra
himself--and may well be so.

My love served him long years, my will followed all his will. A good
servant, however, knoweth everything, and many a thing even which a
master hideth from himself.

He was a hidden God, full of secrecy. Verily, he did not come by his
son otherwise than by secret ways. At the door of his faith standeth
adultery.

Whoever extolleth him as a God of love<,
y=>old man straight in the eye.

“Let him go, he is gone. And though it honoureth thee that thou speakest
only in praise of this dead one, yet thou knowest as well as I WHO he
was, and that he went curious ways.”

“To speak before three

In [15]:
test_x.shape, test_y.shape

((2, 192), (2, 192))
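
The input/target pairs shown above are the same window of tokens offset by one position: at every position the model sees the tokens up to that point and has to predict the next one. A minimal sketch with plain lists; the token ids and the helper name `split_input_target` are made up for illustration.

```python
def split_input_target(tokens):
    """Split a window of n+1 token ids into an n-token input and an n-token target."""
    return tokens[:-1], tokens[1:]


window = [10, 11, 12, 13, 14]  # hypothetical token ids, length = sequence length + 1
x, y = split_input_target(window)
for t in range(len(x)):
    # At position t the model sees x[:t+1] and should predict y[t]
    print(f"context {x[:t + 1]} -> target {y[t]}")
```

This is why `init_getitem` is called with `sample_length=SEQUENCE_LEN+1`: one extra token is needed so that input and target both have length `SEQUENCE_LEN`.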

## 2. Data for texts

In [16]:
def expand_name_template(template, params):
    exp=copy.copy(template)
    for key in params:
        src="{"+key+"}"
        dst=f"{params[key]}"
        exp=exp.replace(src,dst).replace('[','(').replace(']',')')
    return exp

def save_model_metadata(epoch, suffix='std'):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    params['current_epoch'] = epoch
    try:
        with open(meta_file, 'w') as f:
            f.write(json.dumps(params))
    except Exception as e:
        print(f"Failed to store model metadata at {model_path}: {e}")
        return False
    return True

def read_model_metadata(suffix="std"):
    meta_file = os.path.join(model_path, f'model_meta_{suffix}.json')
    try:
        with open(meta_file, 'r') as f:
            meta = json.load(f)
    except Exception as e:
        print(f"Cannot access project meta-data at {meta_file}: {e}, starting anew.")
        return None
    return meta

def is_metadata_compatible(params, meta):
    is_valid=True
    keys=set(list(params.keys())+list(meta.keys()))
    for key in keys:
        if key in updatable_keys:
            continue
        if key not in meta:
            print(f"Key {key} not available in last checkpoint model_meta, params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif key not in params:
            print(f"Key {key} not available in params, last checkpoint model_meta[{key}]: {meta[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
        elif meta[key]!=params[key]:
            print(f"Last checkpoint model_meta[{key}]: {meta[key]} != params[{key}]: {params[key]}, cannot import incompatible model. Put key in `updatable_keys` list, if irrelevant.")
            is_valid = False
    if is_valid is False:
        print("Aborting import.")
        return False
    return True
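
To see how `expand_name_template` behaves, here is a small standalone sketch that repeats the same replacement logic on a hypothetical subset of parameters (`demo_params` is an assumption made for the example). It also shows that a placeholder without a matching key is left in the name unchanged.

```python
def expand_name_template(template, params):
    # Same logic as the helper above: replace each "{key}" with str(params[key])
    exp = template
    for key in params:
        exp = exp.replace("{" + key + "}", f"{params[key]}")
        exp = exp.replace("[", "(").replace("]", ")")
    return exp


# Hypothetical subset of the model parameters; note there is no 'units' key
demo_params = {"mhsa_layers": 48, "heads": 16, "vocab_size": 25000}
name = expand_name_template("{mhsa_layers}x{heads}x{units}x{vocab_size}", demo_params)
print(name)  # 48x16x{units}x25000
```

Because unmatched placeholders survive, every name used in a template should correspond to a key in the parameter dictionary if the resulting file and directory names are meant to be clean.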

In [32]:
vocabulary_size = td.get_unique_token_count()  # vocabulary-size

attn_layers = 48

params = { # Multi-head self-attention
    'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}',  # note: 'units' is not a key in params, so '{units}' stays unexpanded in the name

    'mhsa_layers': attn_layers, 
    'heads': 16,
    'causal': True,  # Use causal self-attention
    'dropout': 0.1,       # no dropout: 0.0
    'vocab_size': vocabulary_size,
    'sequence_len': SEQUENCE_LEN,
    'embedding_size': 384, 
    'test_iterations': 10,  # number of batches averaged per loss estimate

    'batch_size': 64,
    'learning_rate': 0.0004,
    'sample_every_n_iterations': 250,
    
    'max_iterations': 1000000  # maximum number of training iterations
}

# if len(params['heads'])!=params['mhsa_layers']: # or len(params['units'])!=params['mhsa_layers']:
#     print("ERROR: length of 'heads' and 'units' must be equal to mhsa_layers!")
    
model_suffix = expand_name_template(params['name'], params)
# Put 'important' params in checkpoint-pathname to separate model-data:
checkpoint_dir = os.path.join(model_path, f"training_checkpoints_{model_suffix}")
if os.path.exists(checkpoint_dir) is False:
    os.makedirs(checkpoint_dir)

# When comparing if training-data is compatible with new params set, 
# the following keys are updatable, they can be changed while continuing
# to use existing checkpoints and continue training with those values
# changed:
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'dropout', 
             'sample_every_n_iterations']

# These values are taken from the saved checkpoint:
keep_keys=['current_epoch']

continue_last = True
if continue_last is False:
    print("NOT continuing based on existing training! New start.")

meta = read_model_metadata(suffix=model_suffix)
if continue_last is True and meta is not None and is_metadata_compatible(params, meta) is True:
    for key in keep_keys:
        if key in meta:
            params[key]=meta[key]
    print(f"Continuing last session from epoch {params.get('current_epoch', 0)}")
else:
    print("Starting new model")

print(params)

Cannot access project meta-data at /content/drive/My Drive/Colab Notebooks/philosophers/model/ngpt_philosophers_v1_pt/model_meta_48x16x{units}x25000.json: [Errno 2] No such file or directory: '/content/drive/My Drive/Colab Notebooks/philosophers/model/ngpt_philosophers_v1_pt/model_meta_48x16x{units}x25000.json', starting anew.
Starting new model
{'name': '{mhsa_layers}x{heads}x{units}x{vocab_size}', 'mhsa_layers': 48, 'heads': 16, 'causal': True, 'dropout': 0.1, 'vocab_size': 25000, 'sequence_len': 192, 'embedding_size': 384, 'test_iterations': 10, 'batch_size': 64, 'learning_rate': 0.0004, 'sample_every_n_iterations': 250, 'max_iterations': 1000000}
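
The compatibility rule applied above is that a saved checkpoint may be reused only if every parameter matches, except for keys listed as updatable. The following sketch restates that rule as a pure function; `incompatible_keys` and the two example dictionaries are hypothetical and not part of the notebook or of ml-indie-tools.

```python
def incompatible_keys(params, meta, updatable_keys):
    """Return the sorted list of keys that prevent reusing a checkpoint.

    A key is incompatible if it is missing on either side or its values
    differ, unless it is listed in `updatable_keys`.
    """
    bad = []
    for key in sorted(set(params) | set(meta)):
        if key in updatable_keys:
            continue
        if key not in params or key not in meta or params[key] != meta[key]:
            bad.append(key)
    return bad


# Hypothetical saved metadata and new parameter set
saved = {"heads": 16, "embedding_size": 384, "learning_rate": 0.0004, "current_epoch": 3}
new = {"heads": 16, "embedding_size": 512, "learning_rate": 0.0001}
print(incompatible_keys(new, saved, ["learning_rate", "current_epoch"]))
```

Here only the changed `embedding_size` blocks the import: the learning rate may differ, and `current_epoch` may be absent from the new parameters, because both are treated as updatable.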


In [33]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 35010


In [34]:
def get_torch_batch(td, batch_size, device, split=None):
    # NOTE: `split` is currently ignored; 'train' and 'val' batches are drawn from the same dataset.
    x, y = get_sample_batch(td, batch_size)
    return torch.tensor(x, dtype=torch.long).to(device), torch.tensor(y, dtype=torch.long).to(device)

In [35]:
# get_torch_batch(td, 2, 'cpu')

In [36]:
@torch.no_grad()
def estimate_loss(device):
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(params['test_iterations'])
        for k in range(params['test_iterations']):
            print(".", end="", flush=True)
            X, Y = get_torch_batch(td, params['batch_size'], device, split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    print("\r", end="", flush=True)
    return out


def generate_sample(td, device, toks=100, temperature=1.0):
    # generate from the model
    # context = torch.zeros((1, 1), dtype=torch.long, device=device)
    context = torch.tensor([td.encode(' ')]).to(device)
    txt = td.decode(model.generate(context, max_new_tokens=toks, temperature=temperature)[0].tolist())
    td.source_highlight(txt, min_quote_size=10, dark_mode=use_dark_mode, display_ref_anchor=False)
    return txt
    # open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))


# @torch.compile
def do_train_step(xb, yb):
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
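
The `temperature` argument passed to `model.generate` presumably rescales the logits before sampling, as is common in GPT-style samplers; the library's actual implementation is not shown here, so the following is only a sketch of that common technique. The function names and the example logits are assumptions.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing them by `temperature`."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature=1.0, rng=random):
    """Draw one token index according to the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]  # hypothetical scores for a three-token vocabulary
for t in (0.5, 1.0, 1.2):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```

Temperatures below 1 concentrate probability on the highest-scoring token, while values above 1 flatten the distribution and produce more varied text.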

In [37]:
# Select the device: CUDA if available, otherwise CPU; Apple Silicon (MPS) is used when present.
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device("mps") if torch.backends.mps.is_available() else device

In [38]:
print("creating model...")
model_cpu = MultiHeadSelfAttention(params['vocab_size'], params['embedding_size'], 
                                   params['sequence_len'], params['dropout'], 
                                   params['heads'], params['mhsa_layers'], params['causal'], device)
model = model_cpu.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=params['learning_rate'])
start_iter = 0

creating model...
104.418472 M parameters
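
The reported parameter count can be reproduced by hand under the assumption that the model uses the same layer layout as Karpathy's `gpt.py`: token and position embeddings, per block bias-free key/query/value projections, an output projection, a feed-forward network with 4x expansion, two layer norms, and finally a layer norm plus a linear output head. The function below is a sketch under that assumption, not a description of the library's actual code.

```python
def gpt_param_count(vocab_size, n_embd, block_size, n_layer):
    """Count parameters for a gpt.py-style decoder (assumed layout)."""
    tok_emb = vocab_size * n_embd
    pos_emb = block_size * n_embd
    attn_qkv = 3 * n_embd * n_embd            # key, query, value; no bias
    attn_proj = n_embd * n_embd + n_embd      # output projection with bias
    ffwd = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    layer_norms = 2 * 2 * n_embd              # two layer norms, weight and bias each
    per_block = attn_qkv + attn_proj + ffwd + layer_norms
    final_ln = 2 * n_embd
    lm_head = n_embd * vocab_size + vocab_size
    return tok_emb + pos_emb + n_layer * per_block + final_ln + lm_head


total = gpt_param_count(vocab_size=25000, n_embd=384, block_size=192, n_layer=48)
print(total / 1e6, "M parameters")
```

With the values from `params` the total is 104,418,472, which agrees with the count printed above. The number of heads does not enter the formula, since the heads only partition the embedding dimension. Roughly 85 million of the parameters sit in the 48 transformer blocks and about 19 million in the embedding table and output head.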


In [None]:
dt0 = time.time()
print("training...")
for iter in range(start_iter, params['max_iterations']):
    print(f"\rIteration: {iter+1:5d}/{((iter+1)//params['sample_every_n_iterations']+1)*params['sample_every_n_iterations']}/{params['max_iterations']}", end="", flush=True)
    # every once in a while evaluate the loss on train and val sets
    if (iter + 1) % params['sample_every_n_iterations'] == 0 or iter == params['max_iterations'] - 1:
        dt = time.time()
        print(f"\rloss eval", end="", flush=True)
        losses = estimate_loss(device)
        print(
            f"step {iter+1}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time {(dt-dt0)/params['sample_every_n_iterations']:.3f} sec/iter"
        )
        print("Sample: ", end="", flush=True)
        for temperature in [1.1, 1.2]:
            print(f"--------temperature: {temperature} ---------")
            generate_sample(td, device, toks=50, temperature=temperature)
        print("-------------------------------------------")
        dt0 = time.time()
    # sample a batch of data
    xb, yb = get_torch_batch(td, params['batch_size'], device, "train")
    # forward pass, backward pass, and optimizer update
    do_train_step(xb, yb)
    start_iter = iter

training...
Iteration:    73/250/1000000

In [None]:
txt = generate_sample(td, device, toks=100, temperature=1.2)
print(txt)
td.source_highlight(txt)