# MLX-Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that helps keep code independent of the runtime environment. It will access your Google Drive when used with Google Colab.

In [1]:
!pip install -U ml-indie-tools



In [53]:
import logging
import os
import sys
import copy
import json
import time
import datetime
import random
import numpy as np
from zoneinfo import ZoneInfo

from functools import partial
import mlx
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten

In [3]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.Calibre_Dataset import Calibre_Dataset
from ml_indie_tools.Folder_Dataset import Folder_Dataset

In [4]:
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("Main")  # getLogger attaches to the root logger configured above
log.setLevel(logging.INFO)

## 0. Environment

In [5]:
cached_batch_data = None   # Don't regenerate time-consuming training data if it is already cached.

ml_env = MLEnv(platform='mlx', accelerator='fastest')
ml_env.describe()

'OS: Darwin, Python: 3.12.2, Jupyter Notebook MLX: 0.8.0, GPU: MLX GPU (system memory)'

## 1. Project configuration

In [6]:
# project_name = 'women_writers'
model_cpu = None
project_name='women_writers'
model_name=f'ng_COMP_{project_name}_v2_pt'

use_preprocessed_data = True                     # Use already tokenized data
use_existing_model_from_checkpoint = False        # Try to load checkpoint of training

# NOTICE: This will request access to Google Drive if running on Google Colab. Google Drive is used to store snapshots
# and training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools
#
# Note: you need to allow popups in your browser for Colab, otherwise you won't see the Google Drive login box, and drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path)")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Root path (all projects) : . (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : . (Changes to the file system happen only below this project path)
Model path (snapshots)   : ./model/ng_COMP_women_writers_v2_pt (Model weights and snapshots are stored here)
Data path (training data): ./data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


## 2.1 Text data from Project Gutenberg

`Text_Dataset` and `Gutenberg_Dataset` classes: libraries for training,
encoding, batch generation, and formatted source display. They read
books from Project Gutenberg and support the creation of training batches.
The output functions support highlighting, so that generated texts can be
compared with the actual sources to identify identical (memorized)
parts.
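The highlighting itself is done by ml-indie-tools. The sketch below only illustrates the underlying idea and is not the library's implementation: mark every character of a generated text that lies inside a window of `n` consecutive characters which also occurs verbatim in the source. The function name `find_memorized_spans` and the window length `n` are hypothetical.

```python
def find_memorized_spans(generated: str, source: str, n: int = 20) -> list[tuple[int, int]]:
    """Return (start, end) spans of `generated` covered by n-character windows found verbatim in `source`."""
    if n <= 0:
        raise ValueError("n must be positive")
    # Index every n-character window of the source once.
    source_windows = {source[i:i + n] for i in range(len(source) - n + 1)}
    covered = [False] * len(generated)
    for i in range(len(generated) - n + 1):
        if generated[i:i + n] in source_windows:
            for j in range(i, i + n):
                covered[j] = True
    # Merge covered characters into contiguous (start, end) spans, end exclusive.
    spans, start = [], None
    for i, is_covered in enumerate(covered):
        if is_covered and start is None:
            start = i
        elif not is_covered and start is not None:
            spans.append((start, i))
            start = None
    if start is not None:
        spans.append((start, len(generated)))
    return spans

source = "It is a truth universally acknowledged, that a single man"
generated = "Indeed, it is a truth universally acknowledged by all."
print(find_memorized_spans(generated, source, n=10))  # [(9, 46)]
```

The match is case-sensitive, so the lowercase "i" of "it" is not part of the span; a larger `n` flags only longer verbatim copies.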

In [7]:
use_dark_mode=False # Set to True for a dark background, False for a white one. The HTML text comparison uses background colors to identify different sources, and those colors depend on the theme.

In [8]:
token_file = os.path.join(data_path,f"{project_name}_tokens.json")
if use_preprocessed_data is True:
    if os.path.exists(token_file):
        td = Text_Dataset()
        td.load_tokenizer(token_file)
    else:
        use_preprocessed_data = False

INFO:Datasets:Loading tokenizer from ./data/women_writers_tokens.json
INFO:Datasets:Loading tokenizer done.


In [9]:
if use_preprocessed_data is False:
    cache_dir = os.path.join(data_path, 'gutenberg_cache')
    gd = Gutenberg_Dataset(cache_dir=cache_dir)

    if project_name == 'women_writers':  # sample searches
        search_spec= {
            "author": ["Emily Brontë", "Jane Austen", "Virginia Woolf"],
            "language": ["english"]
        }
        book_list=gd.search(search_spec)
    elif project_name == 'neo_philosophers':
        search_spec = {
            "author": ["Immanuel Kant", "Friedrich Nietzsche", "Wilhelm Hegel"],
            "language": ["english"]
        }
        book_list=gd.search(search_spec)
        search_spec = {
            "author": ["Plato"],
            "title": ["Timaeus", "Critias", "Symposium"],
            "language": ["english"]
        }
        book_list+=gd.search(search_spec)
    else:
        search_spec = {}
        book_list = []

    book_cnt = len(book_list)
    print(f"{book_cnt} matching books found with search {search_spec}.")

    if book_cnt > 0:
        if book_cnt<40:
            # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts,
            # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
            book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)
        else:
            logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

        for i in range(len(book_list)):
            print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

        if project_name == 'women_writers':
            select = ("Bennett", "1342", "5670", "1245", "161", "141", "121", "105", "Susan", "Wuthering", "Emma", "Voyage")  # List unique single-words from title or ebook_id to select a given book
            sub_book_list = [book_list[i] for i in range(len(book_list)) if not set([book_list[i]['ebook_id']]+book_list[i]['title'].split(' ')).isdisjoint(set(select))]
        else:
            sub_book_list = book_list

        print("Using:")
        for i in range(len(sub_book_list)):
            print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

        td = Text_Dataset(sub_book_list)
    else:
        td = Text_Dataset()
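The `women_writers` filter above keeps a book if its `ebook_id`, or any space-separated word of its title, appears in `select`. Pulled out into a function (the name `select_books` and the toy records are hypothetical, not real Project Gutenberg metadata), the same test looks like this:

```python
def select_books(book_list, select):
    """Keep books whose ebook_id, or any space-separated word of the title, is in `select`."""
    wanted = set(select)
    return [
        book for book in book_list
        if not wanted.isdisjoint({book['ebook_id'], *book['title'].split(' ')})
    ]

# Toy records for illustration only.
books = [
    {'title': 'The Voyage North', 'ebook_id': '1001'},
    {'title': 'Collected Letters', 'ebook_id': '1342'},
    {'title': 'Emma, Revisited', 'ebook_id': '1003'},
    {'title': 'Emma and Her Friends', 'ebook_id': '1004'},
]
picked = select_books(books, ("Voyage", "1342", "Emma"))
print([b['ebook_id'] for b in picked])  # ['1001', '1342', '1004']
```

Matching is exact and case-sensitive, and punctuation stays attached to title words: "Emma," does not match "Emma", which is why the third record is skipped.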

## 2.2 Additional training material from folders or Calibre library

This looks for a file `additional_texts.json` in the `project_path` as shown above.

```json
{
  "local_texts": ["/some/directory/that/contains/texts"],
  "calibre": "/home/myuser/Calibre Library"
}
```

If the folder(s) defined in `local_texts` contain text files with the default extensions `.txt`, `.md`, `.org`, or `.py` (configurable), they are added to the training data. Folders are searched recursively.

If the path defined in `calibre` contains a Calibre database, all text files (`.txt` only) within that library are added to the training data.
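`Folder_Dataset` does the actual work. As a hedged, stand-alone illustration of what "searched recursively" with the default extensions means, here is a sketch using only `pathlib`. The helper `collect_text_files` is hypothetical, ignores the `calibre` entry, and does not apply the `max_file_size` limit used in the cell below.

```python
import json
from pathlib import Path

DEFAULT_EXTENSIONS = (".txt", ".md", ".org", ".py")

def collect_text_files(config_file, extensions=DEFAULT_EXTENSIONS):
    """List files below every folder in config['local_texts'] (recursively) with a matching extension."""
    config = json.loads(Path(config_file).read_text(encoding="utf-8"))
    found = []
    for folder in config.get("local_texts", []):
        # rglob("*") walks all subdirectories; sorting makes the order deterministic.
        for path in sorted(Path(folder).rglob("*")):
            if path.is_file() and path.suffix.lower() in extensions:
                found.append(path)
    return found
```

Calling `collect_text_files("additional_texts.json")` on a config like the one above would return every `.txt`, `.md`, `.org`, and `.py` file under the listed directories.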

In [10]:
if use_preprocessed_data is False:
    additional = os.path.join(project_path, "additional_texts.json")
    print(f"Looking for description of additional sources in {additional}")
    if os.path.exists(additional) is True:
        with open(additional, 'r') as f:
            add_desc = json.load(f)
            if 'local_texts' in add_desc:
                fd = Folder_Dataset()
                for text_path in add_desc['local_texts']:
                    print(f"Loading texts from {text_path}")
                    fd.load_index(text_path, use_aliases=False, max_file_size=100000)
                td.load_texts(fd.records[:10000])
            if 'calibre' in add_desc:
                cal_path = add_desc['calibre']
                if os.path.exists(cal_path):
                    print(f"Loading text from calibre at {cal_path}")
                    cd = Calibre_Dataset(cal_path)
                    cd.load_index(max_file_size=100000000)
                    td.load_texts(cd.records[:1000])

## 2.3 Tokenize data

In [11]:
if use_preprocessed_data is False:
    MAX_TOKENS = 50000  # This becomes vocab_size
    MAX_NGRAM_LEN = 6   # Max length of a token
    CHUNK_SIZE = 1000000 # Split larger texts in chunks, if not None

    print("")
    print(f"Starting tokenizer with token length from 1..{MAX_NGRAM_LEN} with a max of {MAX_TOKENS} unique tokens,")
    print("this can take considerable time...")

    # td.init_tokenizer(tokenizer='ngram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS) # or 'bytegram'
    td.init_tokenizer(tokenizer='bytegram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS, chunk_size=CHUNK_SIZE)
    td.save_tokenizer(token_file)
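The internals of the `bytegram` tokenizer in ml-indie-tools are not shown here. The toy sketch below only illustrates the two knobs used above, assuming a simple frequency-based scheme: count all character n-grams of length 1..`max_ngrams`, keep the `max_tokens` most frequent as the vocabulary, then encode greedily with the longest matching token. `build_vocab` and `encode` are hypothetical names, and the real tokenizer may work quite differently.

```python
from collections import Counter

def build_vocab(text, max_ngrams, max_tokens):
    """Toy vocabulary: single characters plus the most frequent n-grams (1 <= n <= max_ngrams)."""
    counts = Counter()
    for n in range(1, max_ngrams + 1):
        for i in range(len(text) - n + 1):
            counts[text[i:i + n]] += 1
    # Always keep single characters so every character of `text` can be encoded
    # (this can exceed max_tokens if the text has more distinct characters).
    vocab = set(text)
    for gram, _ in counts.most_common():
        if len(vocab) >= max_tokens:
            break
        vocab.add(gram)
    return sorted(vocab)

def encode(text, vocab, max_ngrams):
    """Greedy longest-match encoding into vocabulary indices."""
    index = {tok: i for i, tok in enumerate(vocab)}
    ids, i = [], 0
    while i < len(text):
        for n in range(min(max_ngrams, len(text) - i), 0, -1):
            piece = text[i:i + n]
            if piece in index:
                ids.append(index[piece])
                i += n
                break
        else:
            raise ValueError(f"cannot encode {text[i]!r}")
    return ids

vocab = build_vocab("the cat the hat", max_ngrams=3, max_tokens=20)
ids = encode("the hat", vocab, max_ngrams=3)
print([vocab[i] for i in ids])
```

In this scheme, `max_tokens` directly becomes the vocabulary size (and therefore the size of the embedding and output layers), while `max_ngrams` bounds how many characters a single token can cover.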

## 3. Model metadata

In [66]:
params = None
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'current_loss',
                 'sample_every_n_iterations', 'sample_size', 'save_every_n_iterations', 'max_iterations']
attn_layers = 12
embs = 1024

params = { # Multi-head self-attention
        'meta_name_template': '{mhsa_layers}x{heads}x{units}x{vocab_size}',

        'layers': attn_layers,
        'heads': 16,
        # 'causal': True,  # Use causal self-attention
        # 'dropout': 0.1,
        'vocab_size': td.get_unique_token_count(),
        'sequence_len': 1024,
        'embedding_size': embs,
        'test_iterations': 100,  # loss-evaluation interval in iterations (steps_per_eval in train())
        'checkpoint': True,  # MLX gradient checkpointing

        'joint_state_training': 0,
        'batch_size': 4,
        'learning_rate': 3e-4,
        'weight_decay': 1e-5,
        'lr_warmup': 200,  # num iterations for lr warmup

        'sample_every_n_iterations': 64,
        # 'sample_size': 150,
        # 'save_every_n_iterations': 1024,

        'max_iterations': 100000000  # maximum number of training iterations
    }
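The `lr_warmup` entry is consumed in `train()`, which sets `optimizer.learning_rate = min(1, it / params['lr_warmup']) * params['learning_rate']` before each step. Pulled out into a plain function (the name `warmup_lr` is only for illustration), the schedule ramps linearly from 0 to the configured rate over the first 200 iterations and then stays constant. There is no decay phase.

```python
def warmup_lr(iteration, base_lr=3e-4, warmup=200):
    """Linear warm-up: scale base_lr by min(1, iteration / warmup)."""
    return min(1.0, iteration / warmup) * base_lr

for it in (0, 50, 100, 200, 1000):
    print(it, warmup_lr(it))
```

Note that iteration 0 runs with a learning rate of exactly 0, so the first optimizer step leaves the weights unchanged.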

## 4. Batch handling

In [67]:
td.init_getitem(sample_type='encoded', sample_length=params['sequence_len']+1, content_stepping=1)
num_records = len(td)
print(f"{num_records} records")

1530568 records


In [68]:
def get_sample_sub_batch(sample_batch, sub_index=0):
    # Split one encoded sample of sequence_len+1 tokens into the input X and the
    # target y (the same tokens shifted left by one position).
    jst = params['joint_state_training']
    Xi = sample_batch[sub_index:-1-jst+sub_index]
    if jst + sub_index == 0:
        yi = sample_batch[sub_index+1:]
    else:
        yi = sample_batch[sub_index+1:-jst+sub_index]
    return np.array(Xi, dtype=np.uint32), np.array(yi, dtype=np.uint32)

def get_sample_batch(td, batch_size):
    # Draw a separate random sample for every row, so that the batch_size rows differ,
    # and stack them into arrays of shape (batch_size, sequence_len).
    samples = [get_sample_sub_batch(td.get_random_item()) for _ in range(batch_size)]
    X = np.stack([Xi for Xi, _ in samples])
    y = np.stack([yi for _, yi in samples])
    return X, y
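Each training row is one window of `sequence_len + 1` tokens (hence `sample_length=params['sequence_len']+1` in `init_getitem`), split into the input and the same tokens shifted left by one. The sketch below reproduces this with a hypothetical `FakeDataset` standing in for `Text_Dataset`; it mimics only `get_random_item()`, and the helper `make_batch` ignores `joint_state_training`.

```python
import numpy as np

class FakeDataset:
    """Hypothetical stand-in for Text_Dataset: returns random windows of a token array."""
    def __init__(self, tokens, sample_length, seed=0):
        self.tokens = np.asarray(tokens, dtype=np.uint32)
        self.sample_length = sample_length
        self.rng = np.random.default_rng(seed)

    def get_random_item(self):
        # integers() excludes the upper bound, so every window fits inside the array.
        start = self.rng.integers(0, len(self.tokens) - self.sample_length + 1)
        return self.tokens[start:start + self.sample_length]

def make_batch(ds, batch_size):
    """Stack independent windows and split each into (input, next-token target)."""
    windows = np.stack([ds.get_random_item() for _ in range(batch_size)])
    return windows[:, :-1], windows[:, 1:]

ds = FakeDataset(np.arange(100), sample_length=9)  # sequence_len = 8 in this toy example
X, y = make_batch(ds, 4)
print(X.shape, y.shape)  # (4, 8) (4, 8)
```

Because the toy tokens are consecutive integers, every target equals its input plus one, which makes the shift easy to verify.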

In [69]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 382642


In [70]:
a,b = get_sample_batch(td, 1)
a.itemsize

4

In [71]:
# Adapted from: https://raw.githubusercontent.com/ml-explore/mlx-examples/main/transformer_lm/main.py
class TransformerLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        dims: int,
        num_heads: int,
        checkpoint: bool,
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, dims)
        self.pe = nn.SinusoidalPositionalEncoding(dims)
        self.transformer = nn.TransformerEncoder(
            num_layers, dims, num_heads, norm_first=True, checkpoint=checkpoint
        )
        self.out_proj = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        L = x.shape[1]
        mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
        x = self.embedding(x)
        x = x + self.pe(mx.arange(L))
        x = self.transformer(x, mask)
        return self.out_proj(x)
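As a sanity check on model size, here is a parameter count for `TransformerLM` that excludes the embedding table, mirroring the filter in `train()`. It assumes MLX defaults that the code above does not spell out: `nn.MultiHeadAttention` projections without bias, an MLP hidden width of `4 * dims` built from biased `nn.Linear` layers, two `LayerNorm`s per layer plus one final `LayerNorm`, and a parameter-free sinusoidal encoding. It also assumes `vocab_size = 50,000` (`MAX_TOKENS` in the tokenizer cell). Under those assumptions, it reproduces the 192.983 M printed by `train()`. The function name is hypothetical.

```python
def count_non_embedding_params(vocab_size, num_layers, dims, mlp_dims=None):
    """Parameter count of TransformerLM without the embedding table (assumed MLX defaults)."""
    mlp_dims = mlp_dims or 4 * dims
    attention = 4 * dims * dims                                   # q, k, v, out projections, no bias
    mlp = (dims * mlp_dims + mlp_dims) + (mlp_dims * dims + dims)  # two Linear layers with bias
    norms = 2 * (2 * dims)                                        # two LayerNorms (weight + bias)
    per_layer = attention + mlp + norms
    final_norm = 2 * dims                                         # LayerNorm after the last layer
    out_proj = dims * vocab_size + vocab_size                     # Linear(dims, vocab_size) with bias
    return num_layers * per_layer + final_norm + out_proj

n = count_non_embedding_params(vocab_size=50_000, num_layers=12, dims=1024)
print(n, f"{n / 1024**2:.3f}")  # 202357584 192.983
```

Since `train()` divides by `1024**2`, the "M" is a binary million. In decimal, this is about 202.4 million parameters, and the embedding table adds another 50,000 × 1024 = 51.2 million.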

In [72]:
def to_samples_old(context_size, dataset):
    tokens = dataset.size
    window_size = context_size + 1  # include target
    samples = tokens - window_size + 1
    X = np.lib.stride_tricks.as_strided(
        dataset,
        shape=(samples, window_size),
        strides=(dataset.itemsize, dataset.itemsize),
    )
    return X[:, :-1], X[:, 1:]
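`to_samples_old` uses `np.lib.stride_tricks.as_strided` to view a 1-D token array as overlapping windows of `context_size + 1` tokens without copying. NumPy's `sliding_window_view` (available since NumPy 1.20) builds the same kind of read-only view with bounds checking, which makes the result easy to check on a tiny array. The name `to_samples_view` is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def to_samples_view(context_size, dataset):
    """Overlapping (input, target) windows over a 1-D token array, as views without copying."""
    # Shape (len(dataset) - context_size, context_size + 1); one extra token per window for the target.
    windows = sliding_window_view(dataset, context_size + 1)
    return windows[:, :-1], windows[:, 1:]

data = np.arange(6)
X, y = to_samples_view(3, data)
print(X.tolist())  # [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
print(y.tolist())  # [[1, 2, 3], [2, 3, 4], [3, 4, 5]]
```

Unlike a raw `as_strided` call, a wrong window size raises an error here instead of silently reading memory outside the array.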

In [73]:
def iterate_batches_old(batch_size, context_size, dataset):
    inputs, targets = to_samples_old(context_size, dataset)
    s = 0
    while True:
        if s == 0:
            # Reset permutation:
            perm = np.random.permutation(inputs.shape[0])
        ids = perm[s : s + batch_size]
        yield inputs[ids], targets[ids]
        s += batch_size
        if s >= inputs.shape[0]:
            s = 0

In [74]:
def iterate_batches(batch_size, context_size, dataset):
    while True:
        x,y = get_sample_batch(dataset, batch_size)
        yield x,y

In [75]:
def train():
    batch_size = params['batch_size']
    context_size = params['sequence_len']
    steps_per_eval = params['test_iterations']
    steps_per_report = params['sample_every_n_iterations']

    # Load vocab and dataset:
    # vocab, train, valid, test = datasets.load_dataset(args.dataset)

    # Initialize model:
    model = TransformerLM(
        params['vocab_size'], params['layers'], params['embedding_size'], params['heads'], params['checkpoint']
    )
    mx.eval(model.parameters())
    nparams = sum(
        x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k
    )
    print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

    def loss_fn(model, x, y, reduce=True):
        logits = model(x)
        losses = nn.losses.cross_entropy(logits, y)
        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))

    optimizer = optim.AdamW(
        learning_rate=params['learning_rate'], weight_decay=params['weight_decay']
    )

    # def eval_fn(dataset):
    #     inputs, targets = map(mx.array, to_samples(context_size, dataset))
    #     loss = 0
    #     for s in range(0, targets.shape[0], batch_size):
    #         bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
    #         bx, by = map(mx.array, (bx, by))
    #         losses = loss_fn(model, bx, by, reduce=False)
    #         loss += mx.sum(losses).item()
    #     return loss / len(targets)

    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(inputs, targets):
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad_fn(model, inputs, targets)
        optimizer.update(model, grads)
        return loss

    train_iterator = iterate_batches(batch_size, context_size, td)
    losses = []
    tic = time.perf_counter()
    for it, (inputs, targets) in zip(range(params['max_iterations']), train_iterator):
        inputs, targets = map(mx.array, (inputs, targets))
        optimizer.learning_rate = min(1, it / params['lr_warmup']) * params['learning_rate']
        loss = step(inputs, targets)
        mx.eval(state)
        losses.append(loss.item())
        if (it + 1) % steps_per_report == 0:
            train_loss = np.mean(losses)
            toc = time.perf_counter()
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {steps_per_report / (toc - tic):.3f}"
            )
            losses = []
            tic = time.perf_counter()
        # if (it + 1) % steps_per_eval == 0:
        #     val_loss = eval_fn(valid)
        #     toc = time.perf_counter()
        #     print(
        #         f"Iter {it + 1}: "
        #         f"Val loss {val_loss:.3f}, "
        #         f"Val ppl {math.exp(val_loss):.3f}, "
        #         f"Val took {(toc - tic):.3f}s, "
        #     )
        #     tic = time.perf_counter()

    # if args.eval_test:
    #     test_loss = eval_fn(test)
    #     test_ppl = math.exp(test_loss)
    #     print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")

In [None]:
train()

Training a transformer with 192.983 M parameters
Iter 64: Train loss 8.629, It/sec 0.764
Iter 128: Train loss 8.032, It/sec 0.747
Iter 192: Train loss 7.921, It/sec 0.744
Iter 256: Train loss 7.777, It/sec 0.746
Iter 320: Train loss 7.534, It/sec 0.747
Iter 384: Train loss 7.383, It/sec 0.719
Iter 448: Train loss 7.267, It/sec 0.718
Iter 512: Train loss 7.233, It/sec 0.733
Iter 576: Train loss 7.122, It/sec 0.729
Iter 640: Train loss 7.102, It/sec 0.739
Iter 704: Train loss 7.008, It/sec 0.728
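A few back-of-the-envelope numbers help when reading this log. The sketch assumes `vocab_size = 50,000` (`MAX_TOKENS`) and uses the roughly 0.74 It/sec shown above. Because batches are drawn at random, running `num_batches` iterations only matches the data volume of one pass over all windows; it is not a strict epoch.

```python
import math

# Assumptions: vocab_size equals MAX_TOKENS from the tokenizer cell, and the
# rate is the roughly 0.74 It/sec shown in the log above.
vocab_size = 50_000
batch_size, sequence_len = 4, 1024
it_per_sec = 0.74
num_batches = 382_642  # printed in section 4

uniform_loss = math.log(vocab_size)      # cross-entropy of a uniform guess over the vocabulary
perplexity = math.exp(7.008)             # perplexity at the last logged train loss
tokens_per_sec = it_per_sec * batch_size * sequence_len
hours_for_num_batches = num_batches / it_per_sec / 3600

print(f"uniform baseline loss: {uniform_loss:.2f}")                     # 10.82
print(f"perplexity at loss 7.008: {perplexity:.0f}")                     # 1105
print(f"tokens/sec: {tokens_per_sec:.0f}")                               # 3031
print(f"hours for num_batches iterations: {hours_for_num_batches:.0f}")  # 144
```

So even the first logged loss (8.629) is well below the uniform baseline. At this rate, running as many iterations as `num_batches` would take roughly six days.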
