<a href="https://colab.research.google.com/github/domschl/torch-transformer-poet/blob/main/torch_transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Torch-Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that provides support for more environment-independent code. It will access your Google Drive when used with Google Colab.

In [None]:
!pip install -U ml-indie-tools

In [None]:
import sys
if 'google.colab' in sys.modules:
    !pip install -U torch

In [None]:
import logging
import os
import copy
import json
import time
import datetime
import random
import numpy as np

import torch

In [None]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.Calibre_Dataset import Calibre_Dataset
from ml_indie_tools.Folder_Dataset import Folder_Dataset

from ml_indie_tools.pytorch_custom_layers import MultiHeadSelfAttention
import ml_indie_tools.pytorch_meta_tools as MJ

In [None]:
# Configure the root logger so that messages of level INFO and above are shown.
logging.basicConfig(level=logging.INFO)

## Preliminary

A PyTorch deep multi-head attention model for text generation, following Andrej Karpathy's [ng-video-lecture](https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py).

This code can use either CPU, GPU, or Apple Silicon. Google Colab is supported too; select the corresponding Colab runtime (menu: **`Runtime / Change runtime type`**).

## 0. Environment

In [None]:
# Placeholder for cached training batches; no other visible cell reads it.
# NOTE(review): the original remark said "Do regenerate ..."; presumably the
# intent is "do NOT regenerate time-consuming training data if it is already
# cached" -- confirm.
cached_batch_data = None   # Do not regenerate time-consuming training data, if already cached (see note above).

# Environment helper from ml-indie-tools; it is used below to resolve the
# project paths. Presumably 'pt' selects PyTorch and 'fastest' picks the best
# available accelerator -- confirm against the ml-indie-tools documentation.
ml_env = MLEnv(platform='pt', accelerator='fastest')
ml_env.describe()

## 1. Project configuration

In [None]:
# project_name = 'women_writers'
model_cpu = None  # not referenced by any other visible cell
project_name='philosophers'
model_name=f'ngpt_COMP_{project_name}_v2_pt'

# Reuse a previously saved tokenizer / model checkpoint when one is found.
# Later cells switch these flags to False themselves if nothing usable exists.
use_preprocessed_data = True
use_existing_model_from_checkpoint = True

# NOTICE: This will request access to Google Drive, if running on Google Colab. Google Drive is used to store snapshots
# training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools
#
# Note: you need to allow popups in your browser for COLAB, otherwise you won't see the google-drive login box, and drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

# `device` deliberately stays a plain string for CUDA/CPU (the model cell
# compares it against the string 'cuda') and only becomes a torch.device for
# Apple Silicon.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Torch builds without an MPS backend have no `torch.backends.mps` attribute;
# treat that case as "MPS not available" instead of raising AttributeError.
mps_backend = getattr(torch.backends, "mps", None)
if mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

##  2.1 Text data from Project Gutenberg

The `Text_Dataset` and `Gutenberg_Dataset` classes are libraries for training, 
encoding, batch generation, and formatted source display. They read some 
books from Project Gutenberg and support the creation of training batches. 
The output functions support highlighting, which makes it possible to compare 
generated texts with the actual sources and helps to identify identical 
(memorized) parts.

In [None]:
# Logging was already configured in an earlier cell; per the Python logging
# documentation a repeated basicConfig() call does nothing once the root
# logger has handlers.
logging.basicConfig(level=logging.INFO)
# Theme switch for the HTML text comparison: True for a dark notebook theme,
# False for a white background. The comparison uses background colours to
# tell sources apart, and those colours depend on the theme type.
use_dark_mode=False # Set to false for white background. HTML-text-compare uses background-colorization to identify different sources. Those background colors are dependent on the theme type.

In [None]:
# Location of the saved tokenizer for this project.
token_file = os.path.join(data_path, project_name + "_tokens.json")

if use_preprocessed_data is True:
    if not os.path.exists(token_file):
        # Nothing cached yet: fall back to building the dataset and tokenizer
        # from scratch in the following cells.
        use_preprocessed_data = False
    else:
        # Restore the tokenizer from the cache file into a fresh dataset object.
        td = Text_Dataset()
        td.load_tokenizer(token_file)

In [None]:
# Build the text dataset from Project Gutenberg. Only runs when no
# preprocessed tokenizer data could be loaded.
if use_preprocessed_data is False:
    cache_dir = os.path.join(data_path, 'gutenberg_cache')
    gd = Gutenberg_Dataset(cache_dir=cache_dir)

    if project_name == 'women_writers':  # sample searches
        search_spec= {
            "author": ["Emily Brontë", "Jane Austen", "Virginia Woolf"], 
            "language": ["english"]
        }
        book_list=gd.search(search_spec)
    elif project_name == 'philosophers':
        search_spec = {
            "author": ["Immanuel Kant", "Friedrich Nietzsche", "Wilhelm Hegel"],
            "language": ["english"]
        }
        book_list=gd.search(search_spec)
        search_spec = {
            "author": ["Plato"],
            "title": ["Timaeus", "Critias", "Symposium"],
            "language": ["english"]
        }
        book_list+=gd.search(search_spec)
    else:
        # Without this branch an unknown project only failed later with a
        # NameError on `book_list`.
        raise ValueError(f"No Gutenberg search is defined for project_name '{project_name}'.")

    book_cnt = len(book_list)
    # Note: for 'philosophers' only the last of the two search specs is shown.
    print(f"{book_cnt} matching books found with search {search_spec}.")
    if book_cnt >= 40:
        # Actually abort, as the message announces; continuing would pass
        # records without downloaded text on to Text_Dataset.
        logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")
        raise RuntimeError(f"Refusing to download {book_cnt} books; narrow the search or use a local Gutenberg mirror.")
    # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts, 
    # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
    book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)

    for i, book in enumerate(book_list):
        print(f"{i}: {book['title']} - {book['author']}, {book['ebook_id']}")

    if project_name == 'women_writers':
        select = ("Bennett", "1342", "5670", "1245", "161", "141", "121", "105", "Susan", "Wuthering", "Emma", "Voyage")  # List unique single-words from title or ebook_id to select a given book
        wanted = set(select)
        # Keep a book when its ebook_id or any single word of its title is selected.
        sub_book_list = [
            book for book in book_list
            if not wanted.isdisjoint([book['ebook_id']] + book['title'].split(' '))
        ]
    else:
        sub_book_list = book_list

    print("Using:")
    for i, book in enumerate(sub_book_list):
        print(f"{i+1}: {book['title']} - {book['author']}")

    td = Text_Dataset(sub_book_list)

## 2.2 Additional training material from folder `{data_path}/local_texts`

If the folder `{data_path}` (defined above) contains a sub-folder `local_texts`, and that sub-folder contains
files named `<title> - <author> - <language>.txt`, then they are added to the training data.
Sample filename: `"./data/local_texts/works-of-shakespeare - William Shakespeare - English.txt"`.
The titles of those documents can be referenced via numeric aliases to preserve privacy on non-public data.

In [None]:
# Optionally extend the dataset with text files from `{data_path}/local_texts`.
if use_preprocessed_data is False:
    use_local_folder_data = True
    if use_local_folder_data:
        local_texts = os.path.join(data_path, 'local_texts')
        # The folder is optional: only index it when it actually exists.
        if os.path.isdir(local_texts):
            fd = Folder_Dataset(local_texts)
            # NOTE(review): aliases are switched off here, although the
            # notebook text mentions numeric aliases for privacy -- confirm
            # which is intended.
            fd.load_index(use_aliases=False)
            td.load_texts(fd.records)
        else:
            logging.info(f"No local text folder at {local_texts}, using Gutenberg texts only.")

## 2.3 Tokenize data

In [None]:
# Train the n-gram tokenizer on the loaded texts and cache it on disk.
# Skipped when a saved tokenizer was loaded earlier.
if use_preprocessed_data is False:
    MAX_TOKENS = 20000  # This becomes vocab_size
    MAX_NGRAM_LEN = 8   # Max length of a token

    print()
    print(f"Starting NGRAM tokinizer with token length from 1..{MAX_NGRAM_LEN} with a max of {MAX_TOKENS} unique tokens,")
    print("this can take considerable time...")

    tokenizer_options = {
        'tokenizer': 'ngram',
        'max_ngrams': MAX_NGRAM_LEN,
        'max_tokens': MAX_TOKENS,
    }
    td.init_tokenizer(**tokenizer_options)
    # Persist the result so that later runs can load `token_file` instead.
    td.save_tokenizer(token_file)

## 3. Model metadata

In [None]:
# Checkpoint file for this model. Resolved unconditionally: the training loop
# saves to `model_file_path` even when no checkpoint is loaded, so it must not
# depend on `use_existing_model_from_checkpoint`.
model_file_path = MJ.get_model_filename(model_path)

params = None
if use_existing_model_from_checkpoint is True:
    params = MJ.load_model_metadata_from_checkpoint(model_file_path, device=device) # torch.device('cpu'))
if params is None or use_existing_model_from_checkpoint is False:
    # No usable checkpoint: start from a fresh hyper-parameter set.
    attn_layers = 24
    use_existing_model_from_checkpoint = False
    params = { # Multi-head self-attention
        'meta_name_template': '{mhsa_layers}x{heads}x{units}x{vocab_size}',

        'mhsa_layers': attn_layers, 
        'heads': 16,
        'causal': True,  # Use causal self-attention
        'sigma_compressor': True,
        'dropout': 10,  # WARNING: dropout > 1.0 has nothing to do with dropout, but is a compressor-hack! (see ml_indie_tools v. >= 0.6)
        'vocab_size': td.get_unique_token_count(),
        'sequence_len': 384,
        'embedding_size': 384, 
        'test_iterations': 10,  # number of batches per split used for loss estimation

        'batch_size': 64,
        'learning_rate': 0.001,
        'sample_every_n_iterations': 250,
        'sample_size': 100,
        'save_every_n_iterations': 100,

        'max_iterations': 1000000  # maximum number of training iterations
    }

if params['dropout']>1.0:
    if params['sigma_compressor'] is True:
        name_suffix = f"-SigmaCmp({params['dropout']})"
    else:
        name_suffix = f"-Cmp({params['dropout']})"
    # Append the marker only once. Parameters restored from a checkpoint may
    # already carry it (presumably the template is saved along with the other
    # params -- confirm), and it must not pile up on every restart.
    if not params['meta_name_template'].endswith(name_suffix):
        params['meta_name_template'] += name_suffix

# When comparing if training-data is compatible with new params set, 
# the following keys are updatable, they can be changed while continuing
# to use existing checkpoints and continue training with those values
# changed:
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'current_loss',  
                 'sample_every_n_iterations', 'sample_size', 'save_every_n_iterations']
print(params)

## 4. Batch handling

In [None]:
# One extra token per record, so that the input window and the target window
# shifted by one position can both be cut from the same record.
td.init_getitem(
    sample_type='encoded',
    sample_length=1 + params['sequence_len'],
    content_stepping=1,
)
num_records = len(td)
print(f"{num_records} records")

In [None]:
def get_sample_batch(td, batch_size):
    """Draw `batch_size` random records and split each into input and target.

    Every record returned by `td.get_random_item()` is an indexable token
    sequence; the input is the record without its last token and the target
    is the record without its first token (i.e. shifted by one position).

    Args:
        td: dataset object providing `get_random_item()`.
        batch_size: number of records to draw, must be >= 1.

    Returns:
        Tuple `(X, y)` of int32 numpy arrays, each always two-dimensional
        with shape `(batch_size, record_length - 1)`, also for a batch of one.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    inputs = []
    targets = []
    for _ in range(batch_size):
        data = td.get_random_item()
        inputs.append(np.array(data[:-1], dtype=np.int32))
        targets.append(np.array(data[1:], dtype=np.int32))
    # Stack once at the end rather than growing the arrays row by row; this
    # also yields a leading batch dimension for batch_size == 1.
    return np.stack(inputs), np.stack(targets)

In [None]:
# Number of complete batches that fit into the record count (integer division).
# Printed for information only; no other visible cell reads `num_batches`.
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

In [None]:
def get_torch_batch(td, batch_size, device, split=None):
    """Return one random (input, target) batch as long tensors on `device`.

    The batch is drawn with `get_sample_batch`. `split` is accepted for the
    callers' convenience but is not used: every split draws from the same pool.
    """
    inputs_np, targets_np = get_sample_batch(td, batch_size)
    inputs = torch.tensor(inputs_np, dtype=torch.long).to(device)
    targets = torch.tensor(targets_np, dtype=torch.long).to(device)
    return inputs, targets

# get_torch_batch(td, 2, 'cpu')

## 5. Loss and training helpers

In [None]:
@torch.no_grad()
def estimate_loss(device):
    """Estimate the current loss of the global `model` without gradients.

    For each of the nominal splits "train" and "val",
    `params['test_iterations']` random batches are evaluated and their losses
    averaged. Both splits are drawn from the SAME data pool, so this is not a
    true validation loss. The model is put into eval mode for the measurement
    and back into train mode afterwards.

    Returns:
        The mean of the two per-split averages (a zero-dimensional tensor).
    """
    eval_rounds = params['test_iterations']
    split_means = {}
    model.eval()
    for split in ("train", "val"):
        batch_losses = torch.zeros(eval_rounds)
        for idx in range(eval_rounds):
            print(".", end="", flush=True)  # progress indicator
            xb, yb = get_torch_batch(td, params['batch_size'], device, split)
            _, batch_loss = model(xb, yb)
            batch_losses[idx] = batch_loss.item()
        split_means[split] = batch_losses.mean()
    model.train()
    print("\r", end="", flush=True)  # return to line start so the dots get overwritten
    return (split_means['train'] + split_means['val']) / 2.0

def generate_sample(td, device, prompt=' ', toks=100, temperature=1.0, top_k=None, pad=False, dark_mode=None):
    """Generate text from the global `model` and show it with source highlighting.

    Args:
        td: dataset object used for `encode`, `decode` and `source_highlight`.
        device: device the encoded prompt is moved to.
        prompt: text the generation starts from.
        toks: number of new tokens to generate.
        temperature: sampling temperature passed to `model.generate`.
        top_k: optional top-k limit passed to `model.generate`.
        pad: if True, left-pad the prompt to `params['sequence_len']`
            characters (a newline followed by spaces).
        dark_mode: colour scheme for the highlighting. If None, the
            notebook-level `use_dark_mode` setting is used, or False when that
            variable is not defined.

    Returns:
        The decoded text (prompt plus generated continuation).

    Note:
        The model is switched to eval mode and left there; `do_train_step`
        switches it back to train mode before the next optimisation step.
    """
    if dark_mode is None:
        # Honour the notebook-wide theme switch instead of a fixed False.
        dark_mode = globals().get('use_dark_mode', False)
    model.eval()
    if pad is True:
        # Same result as prepending one character at a time: spaces, with a
        # newline as the very first character of the padded prompt.
        missing = params['sequence_len'] - len(prompt)
        if missing > 0:
            prompt = '\n' + ' ' * (missing - 1) + prompt
    context = torch.tensor([td.encode(prompt)]).to(device)
    answer = model.generate(context, max_new_tokens=toks, temperature=temperature, top_k=top_k)
    txt = td.decode(answer[0].tolist())
    # Identify memorisation of text by highlighting verbatim quotes from sources
    # that are longer than 10 chars. HTML colorcoded output for source identification:
    td.source_highlight(txt, min_quote_size=10, dark_mode=dark_mode, display_ref_anchor=False)
    return txt
    # open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))

In [None]:
# Build the model and optimizer, optionally restore both from the checkpoint,
# and derive the iteration at which training resumes.
print("creating model...")
model = MultiHeadSelfAttention(params['vocab_size'], params['embedding_size'], 
                                   params['sequence_len'], params['dropout'], 
                                   params['heads'], params['mhsa_layers'], params['causal'], params['sigma_compressor'], device)
optimizer = torch.optim.AdamW(model.parameters(), lr=params['learning_rate'])

if use_existing_model_from_checkpoint is True:
    params = MJ.load_checkpoint(params, model, optimizer, file_path=model_file_path, device=device) # torch.device("cpu"))
model = model.to(device)
# Move the optimizer's state tensors to the target device as well, so that
# they live on the same device as the model parameters.
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

# `device` is the plain string 'cuda' only when CUDA was selected in the
# configuration cell, so this branch is skipped for CPU and MPS.
if device == 'cuda':
    print("Compiling...")
    model = torch.compile(model)
    print("Compile ok.")
    try:
        torch.set_float32_matmul_precision('high')
    except Exception:
        # Narrowed from a bare `except:` so that KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        print("Seems no tensor cores for that.")

# Training progress stored with the parameters (absent for a fresh model).
ep = params.get('current_epoch', 0)
ls = params.get('current_loss', 0)

if ep==0 and ls==0:
    start_iter = 0
else:
    start_iter = ep
    current_loss = ls

# print the number of parameters in the model
print(model)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")

In [None]:
# @torch.jit.script
# @torch.compile
def do_train_step(xb, yb):
    """Run a single optimisation step of the global `model` on one batch.

    Args:
        xb: input batch passed to the model.
        yb: target batch passed to the model.
    """
    model.train()
    _, batch_loss = model(xb, yb)  # the logits are not needed for the update
    optimizer.zero_grad(set_to_none=True)
    batch_loss.backward()
    optimizer.step()

In [None]:
# Main training loop: train on random batches, periodically estimate the loss,
# print generated samples, and save checkpoints.
print("training...")
gen_id = 0      # selects the prompt for the next sampling round
iter_bench = 0  # training steps completed since the timer was last reset
current_loss = estimate_loss(device)
inputs = ["what is the difference between good and evil? ", "How did everything come into existence? ", "What was at the beginning of time? ", "How are physics, quantum-mechanics and consciousness related? ", "How to attain complete self-awareness? ", "What is the nature of reality? ", "How be a good human being? "]
# Start the timer only after the initial loss estimate, so that the first
# sec/iter figure measures training steps only.
dt0 = time.time()
# The loop variable is called `step`, not `iter`, to avoid shadowing the
# builtin `iter` in the notebook namespace.
for step in range(start_iter, params['max_iterations']):
    print(f"\rIteration: {step+1:5d}/{((step+1)//params['sample_every_n_iterations']+1)*params['sample_every_n_iterations']}/{params['max_iterations']}", end="", flush=True)
    # every once in a while evaluate the loss on train and val sets
    if (step + 1) % params['sample_every_n_iterations'] == 0 or step == params['max_iterations'] - 1:
        elapsed = time.time() - dt0
        print(f"\rloss eval", end="", flush=True)
        current_loss = estimate_loss(device)
        # The reported value is the mean of the "train" and "val" estimates
        # returned by estimate_loss. The divisor is the exact number of steps
        # timed; max() guards the case of no completed step.
        print(
            f"step {step+1}: train loss {current_loss:.4f}, time {elapsed/max(iter_bench, 1):.3f} sec/iter"
        )
        iter_bench = 0
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"Sample at {timestamp}:", flush=True)
        for temperature in [0.65, 0.75, 0.85]:
            print(f"--------temperature: {temperature} ---------")
            prompt = inputs[gen_id%len(inputs)]
            print(f"Prompt: {prompt}")
            generate_sample(td, device, prompt=prompt, toks=params['sample_size'], temperature=temperature, top_k=16)
        print("-------------------------------------------")
        gen_id += 1
        # Restart the timer so evaluation and sampling are not counted.
        dt0 = time.time()
    # sample a batch of data
    xb, yb = get_torch_batch(td, params['batch_size'], device, "train")
    # run one optimisation step on it
    do_train_step(xb, yb)
    # Remember the position so that re-running this cell continues from here.
    start_iter = step
    iter_bench += 1
    if (step+1)%params['save_every_n_iterations'] == 0:
        MJ.save_checkpoint(params, model, optimizer, step, current_loss, file_path=model_file_path)
    

In [None]:
# Generate from one fixed prompt at a low and at a high sampling temperature.
final_prompt = "How are consciousness and quantum mechanics related?"
for t in (0.5, 1.5):
    print(f"------Temperature {t}--------")
    generate_sample(td, device, prompt=final_prompt, toks=150, temperature=t, top_k=16)