<a href="https://colab.research.google.com/github/domschl/torch-transformer-poet/blob/main/torch_transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Torch-Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that supports more environment-independent code. It will access your Google Drive when used with Google Colab.

In [1]:
import sys
if 'google.colab' in sys.modules:
    !pip install ml-indie-tools

In [2]:
import logging
import os
import json
import time
import datetime
import math
import numpy as np
from zoneinfo import ZoneInfo

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

In [3]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.Calibre_Dataset import Calibre_Dataset
from ml_indie_tools.Folder_Dataset import Folder_Dataset

import ml_indie_tools.pytorch_meta_tools as MJ
from ml_indie_tools.train_utils import TrainUtils

In [4]:
# Optional experimental event server to record and propagate training progress, not (yet) recommended!
# Functionality is ignored by default.
try:
    import indralib
    indra_avail = True
except Exception as e:
    indra_avail = False
if indra_avail is True:
    print("Indralib is available, trying to connect to Indrajala server for training progress reports...")

In [5]:
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("Main")  # getLogger attaches the logger to the configured root handlers
log.setLevel(logging.INFO)

## Preliminary

A deep multi-head attention model in PyTorch for text generation, following Andrej Karpathy's [ng-video-lecture](https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py).

This code can run on CPU, GPU, or Apple Silicon. Google Colab is also supported: select the corresponding Colab runtime (menu: **`Runtime / Change runtime type`**).

## 0. Environment

In [6]:
cached_batch_data = None   # Do not regenerate time-consuming training data if already cached.

ml_env = MLEnv(platform='pt', accelerator='fastest')
ml_env.describe()

'OS: Darwin, Python: 3.13.2, Jupyter Notebook Pytorch: 2.6.0, GPU: MPS Metal accelerator (system memory)'

## 1. Project configuration

In [7]:
# project_name = 'women_writers'
# project_name='research'
project_name='neo_philosophers'
model_cpu = None
model_name=f'tr_{project_name}_v3_pt'

use_preprocessed_data = True                      # Use already tokenized data
use_existing_model_from_checkpoint = False         # Try to load checkpoint of training
use_torch_compile = True                           # Requires a modern graphics card with torch compile backend support
skip_additional_texts = False                       # Don't look for other data sources in `additional_texts.json`

if 'google.colab' in sys.modules:  # Google Colab notebooks run on servers that use UTC, so logs are adapted to local time:
    local_timezone = ZoneInfo('Europe/Berlin')
else:
    local_timezone = None

# NOTICE: This will request access to Google Drive, if running on Google Colab. Google Drive is used to store snapshots
# and training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools
#
# Note: you need to allow popups in your browser for COLAB, otherwise you won't see the google-drive login box, and drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device("mps") if torch.backends.mps.is_available() else device

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Root path (all projects) : . (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : . (Changes to the file system happen only below this project path
Model path (snapshots)   : ./model/tr_neo_philosophers_v3_pt (Model weights and snapshots are stored here)
Data path (training data): ./data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


## 2.1 Text data from Project Gutenberg

The `Text_Dataset` and `Gutenberg_Dataset` classes handle the training data:
encoding, batch generation, and formatted source display. `Gutenberg_Dataset`
retrieves books from Project Gutenberg, and `Text_Dataset` creates training
batches from them. The output functions highlight passages so that generated
text can be compared with the actual sources, which helps to identify
identical (memorized) parts.
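
To illustrate what spotting a memorized passage means, the following sketch finds the longest verbatim overlap between a generated text and a source using Python's standard `difflib`. It only illustrates the concept: the function name `longest_shared_span` and both example strings are made up for this sketch, and this is not how `Text_Dataset` implements its highlighting.

```python
from difflib import SequenceMatcher


def longest_shared_span(generated: str, source: str) -> str:
    """Return the longest substring that appears verbatim in both texts."""
    # autojunk=False keeps frequent characters (e.g. spaces) in the comparison
    matcher = SequenceMatcher(None, generated, source, autojunk=False)
    match = matcher.find_longest_match(0, len(generated), 0, len(source))
    return generated[match.a:match.a + match.size]


# Hypothetical example texts
source = "The world is my idea: this is a truth which holds good for everything."
generated = "He said the world is my idea, and nothing else holds good."
print(repr(longest_shared_span(generated, source)))
```

A long shared span suggests that the model reproduced training text rather than composing new text; short overlaps of a few words are expected in any natural language output.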

In [8]:
use_dark_mode=False # Set to false for white background. HTML-text-compare uses background-colorization to identify different sources. Those background colors are dependent on the theme type.

In [9]:
token_file = os.path.join(data_path,f"{project_name}_tokens.json")
if use_preprocessed_data is True:
    if os.path.exists(token_file):
        td = Text_Dataset()
        td.load_tokenizer(token_file)
    else:
        use_preprocessed_data = False

INFO:Datasets:Loading tokenizer from ./data/neo_philosophers_tokens.json
INFO:Datasets:Loading tokenizer done.


In [10]:
if use_preprocessed_data is False:
    cache_dir = os.path.join(data_path, 'gutenberg_cache')
    gd = Gutenberg_Dataset(cache_dir=cache_dir)

    if project_name == 'women_writers':  # sample searches
        search_spec= {
            "author": ["Emily Brontë", "Jane Austen", "Virginia Woolf"],
            "language": ["english"]
        }
        book_list=gd.search(search_spec)
    elif project_name == 'neo_philosophers':
        search_spec = {
            "author": ["Immanuel Kant", "Friedrich Nietzsche", "Wilhelm Hegel", "Arthur Schopenhauer"],
            "language": ["english", "german"]
        }
        book_list=gd.search(search_spec)
        search_spec = {
            "author": ["Plato", "Platon"],
            "title": ["Timaeus", "Critias", "Symposium"],
            "language": ["english", "german"]
        }
        book_list+=gd.search(search_spec)
        search_spec = {
            "title": ["Buddh", "Sutra"],
            "language": ["english", "german"]
        }
        book_list+=gd.search(search_spec)
    else:
        search_spec = {}
        book_list = []

    book_cnt = len(book_list)
    print(f"{book_cnt} matching books found with search {search_spec}.")

    if book_cnt > 0:
        if book_cnt<80:
            # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts,
            # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
            book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)
        else:
            logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

        for i in range(len(book_list)):
            if 'author' not in book_list[i]:
                book_list[i]['author']='unknown'
            print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

        if project_name == 'women_writers':
            select = ("Bennett", "1342", "5670", "1245", "161", "141", "121", "105", "Susan", "Wuthering", "Emma", "Voyage")  # List unique single-words from title or ebook_id to select a given book
            sub_book_list = [book_list[i] for i in range(len(book_list)) if not set([book_list[i]['ebook_id']]+book_list[i]['title'].split(' ')).isdisjoint(set(select))]
        else:
            sub_book_list = book_list

        print("Using:")
        for i in range(len(sub_book_list)):
            if 'author' not in sub_book_list[i]:
                sub_book_list[i]['author']='unknown'
            print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

        td = Text_Dataset(sub_book_list)
    else:
        td = Text_Dataset()

## 2.2 Additional training material from folders or Calibre library

The next cell looks for a file `additional_texts.json` in the `project_path` shown above.

```json
{
  "local_texts": [["/some/directory/that/contains/texts", [".txt", ".md", ".org", ".py"]]],
  "calibre": ["/home/myuser/Calibre Library", []]
}
```

If the folder(s) defined in `local_texts` contain text files with default endings `.txt`, `.md`, `.org`, or `.py` (can be configured), they are added to the training data. Folders are searched recursively.

If the path defined in `calibre` contains a Calibre database, all text files (`.txt` only) within that library are added to the training data. The list argument can contain search specs (see ml-indie-tools, `Calibre_Dataset`) that restrict which books are imported, e.g. `[{"tags": ["philosophy"]}]` imports all books tagged "Philosophy" in Calibre.
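
As a minimal sketch, such a file can be written and read back with the standard `json` module. The directory names are the placeholder paths from the example above, and a temporary directory stands in for the notebook's `project_path`; both are assumptions for illustration only.

```python
import json
import os
import tempfile

# Placeholder paths for illustration; replace them with real directories.
additional_sources = {
    "local_texts": [["/some/directory/that/contains/texts", [".txt", ".md"]]],
    "calibre": ["/home/myuser/Calibre Library", [{"tags": ["philosophy"]}]],
}

# A temporary directory stands in for the notebook's `project_path`.
demo_project_path = tempfile.mkdtemp()
config_file = os.path.join(demo_project_path, "additional_texts.json")
with open(config_file, "w", encoding="utf-8") as f:
    json.dump(additional_sources, f, indent=2)

# Read the file back and unpack it the same way the notebook cell does.
with open(config_file, "r", encoding="utf-8") as f:
    loaded = json.load(f)
for text_path, extensions in loaded["local_texts"]:
    print(f"folder: {text_path}, extensions: {extensions}")
calibre_path, search_specs = loaded["calibre"]
print(f"calibre: {calibre_path}, specs: {search_specs}")
```

Each `local_texts` entry is a two-element list (path, list of extensions), and `calibre` is a two-element list (library path, list of search specs); an empty spec list imports without filtering.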

In [11]:
if use_preprocessed_data is False and skip_additional_texts is False:
    additional = os.path.join(project_path, "additional_texts.json")
    print(f"Looking for description of additional sources in {additional}")
    if os.path.exists(additional) is True:
        with open(additional, 'r') as f:
            add_desc = json.load(f)
            if 'local_texts' in add_desc:
                fd = Folder_Dataset()
                for text_path, qualifier in add_desc['local_texts']:
                    if not isinstance(qualifier, list):
                        qualifier = [qualifier]
                    print(f"Loading texts from {text_path} using extension restriction {qualifier}")
                    fd.load_index(text_path, use_aliases=False, min_file_size=2048, file_extensions=qualifier)
                td.load_texts(fd.records[:10000])
            if 'calibre' in add_desc:
                cal_path, specs = add_desc['calibre']
                if os.path.exists(cal_path):
                    print(f"Loading text from calibre at {cal_path}")
                    cd = Calibre_Dataset(cal_path)
                    cd.load_index(max_file_size=100000000)
                    if specs is not None and len(specs)!=0:
                        ls = cd.search(specs)
                        td.load_texts(ls)
                    else:
                        td.load_texts(cd.records[:1000])

## 2.3 Tokenize data

In [12]:
if use_preprocessed_data is False:
    MAX_TOKENS = 32768  # This becomes vocab_size
    MAX_NGRAM_LEN = 5   # Max length of a token
    CHUNK_SIZE = 500000 # Split larger texts in chunks, if not None

    print("")
    print(f"Starting tokenizer with token length from 1..{MAX_NGRAM_LEN} with a max of {MAX_TOKENS} unique tokens,")
    print("this can take considerable time...")

    # Better tested NGRAM tokenizer:
    # td.init_tokenizer(tokenizer='ngram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS)
    # or alternative 'BYTEGRAM' (more experimental, can encode arbitrary UTF-8)
    # td.init_tokenizer(tokenizer='bytegram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS, chunk_size=CHUNK_SIZE)
    td.init_tokenizer(tokenizer='bytegram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS, chunk_size=CHUNK_SIZE)
    td.save_tokenizer(token_file)

In [13]:
tok_tests = ["Good morning, this is a simple test sentence for tokenization",
             "Guten Morgen, dies is ein einfach Testsatz zur Aufteilung in Satzbestandteile",
             "སེམས་ཉིད་ངལ་བསོ་རྒྱུད་",
             "སྟོང་ཉིད་སྙིང་རྗེའི་སྙིང་པོ་ཅན།"]
for test in tok_tests:
    enc = td.encode(test)
    dec = td.decode(enc)
    if dec != test:
        print(f"Tokenizer failed for: {test} -> {dec}")
    else:
        r = len(enc)/len(test)*100.0
        print(f"Tokenizer: {test}({len(test)}) -> {enc}({len(enc)}) OK, compressed size: {r:.2f}%")

Tokenizer: Good morning, this is a simple test sentence for tokenization(61) -> [5505, 111, 18314, 110, 1806, 116, 3012, 115, 8985, 109, 25713, 101, 20947, 110, 8513, 32, 1675, 111, 2243, 105, 6663, 110](22) OK, compressed size: 36.07%
Tokenizer: Guten Morgen, dies is ein einfach Testsatz zur Aufteilung in Satzbestandteile(77) -> [17441, 116, 31991, 114, 3903, 100, 15655, 115, 1682, 101, 24959, 99, 349, 84, 13338, 97, 15817, 122, 805, 65, 23813, 101, 484, 117, 2708, 32, 4708, 116, 122, 6077, 97, 32391, 101, 1898](34) OK, compressed size: 44.16%
Tokenizer: སེམས་ཉིད་ངལ་བསོ་རྒྱུད་(22) -> [11060, 186, 5962, 166, 497, 137, 4740, 145, 497, 132, 7015, 139, 4897, 166, 10124, 139, 8666, 146, 4577, 180, 5980, 139](22) OK, compressed size: 100.00%
Tokenizer: སྟོང་ཉིད་སྙིང་རྗེའི་སྙིང་པོ་ཅན།(31) -> [6416, 159, 3901, 132, 497, 137, 4740, 145, 497, 166, 3615, 153, 4740, 132, 497, 162, 3615, 151, 6103, 160, 3957, 139, 6416, 153, 4740, 132, 497, 148, 10124, 139, 21563, 147, 5954](33) OK, compressed siz
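
The ratio printed above divides the token count by `len(test)`, which counts Unicode characters rather than bytes. Tibetan characters occupy three bytes each in UTF-8, so for a tokenizer that operates on bytes, the character-based ratio understates the compression. The following sketch compares both views; the helper names `ratio_vs_chars` and `ratio_vs_bytes` are hypothetical, and the token counts are copied from the output above.

```python
def ratio_vs_chars(text: str, token_count: int) -> float:
    """Tokens per Unicode character, in percent (what the cell above prints)."""
    return token_count / len(text) * 100.0


def ratio_vs_bytes(text: str, token_count: int) -> float:
    """Tokens per UTF-8 byte, in percent."""
    return token_count / len(text.encode("utf-8")) * 100.0


english = "Good morning, this is a simple test sentence for tokenization"
tibetan = "སེམས་ཉིད་ངལ་བསོ་རྒྱུད་"

# Token counts (22 and 22) are taken from the tokenizer output above.
for text, tokens in ((english, 22), (tibetan, 22)):
    print(f"chars={len(text)} bytes={len(text.encode('utf-8'))} "
          f"vs chars: {ratio_vs_chars(text, tokens):.2f}% "
          f"vs bytes: {ratio_vs_bytes(text, tokens):.2f}%")
```

For ASCII text both ratios coincide, while for the Tibetan sample the byte-based ratio is one third of the character-based one.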

## 3. Model metadata

In [14]:
params = None
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'current_loss',
                 'sample_every_n_iterations', 'sample_size', 'save_every_n_iterations', 'max_iterations']
model_dimensions = 128
context_length = 64

params = { # Multi-head self-attention
        'meta_name_template': '{prelude_layers}-{recurrent_layers}/{recurrence_steps}-{coda_layers}x{heads}x{units}x{vocab_size}',

        'prelude_layers': 2,
        'recurrent_layers': 0,
        'coda_layers': 2,
        'recurrence_steps': 2,
        'heads': 8,
        'vocab_size': td.get_unique_token_count(),
        'context_length': context_length,
        'dropout': 0.1,
        'non_linearity': nn.Mish,  # Default nn.ReLU
        'model_dimensions': model_dimensions,
        'test_iterations': 100,  # number of batches averaged for loss estimation

        'batch_size': 256,
    
        'learning_rate': 4e-4,  # Only used, if lr_schedule is False
        'lr_schedule': True,
        'lr_min': 2e-4,
        'lr_max': 8e-4,
        'warmup': 2000,
        'decay': 50000,
    
        'grad_clip': 0.8,

        'sample_every_n_iterations': 4096,
        'sample_size': 192,
        'save_every_n_iterations': 4096,

        'max_iterations': 100000000  # maximum number of training iterations
    }

model_file_path = MJ.get_model_filename(model_path)
if use_existing_model_from_checkpoint is True:
    params = MJ.load_model_metadata_from_checkpoint(params, updatable_keys, model_file_path, device=device, log=log) # torch.device('cpu'))
if params is None or use_existing_model_from_checkpoint is False:
    use_existing_model_from_checkpoint = False
# print(params)
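
The training loop that consumes `lr_schedule`, `lr_min`, `lr_max`, `warmup`, and `decay` is not part of this section. One common reading of these keys, assumed here, is a linear warmup to `lr_max` over `warmup` iterations followed by a cosine decay to `lr_min` over `decay` iterations. The function `scheduled_lr` is a hypothetical sketch of that assumption, not the notebook's actual schedule.

```python
import math


def scheduled_lr(iteration: int, lr_min: float, lr_max: float,
                 warmup: int, decay: int) -> float:
    """Assumed schedule: linear warmup to lr_max, then cosine decay to lr_min."""
    if iteration < warmup:
        # Ramp up linearly from lr_max/warmup to lr_max
        return lr_max * (iteration + 1) / warmup
    if iteration >= warmup + decay:
        # After the decay phase, stay at the floor
        return lr_min
    progress = (iteration - warmup) / decay  # 0.0 .. 1.0 within the decay phase
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_min + (lr_max - lr_min) * cosine


# Values copied from the `params` dictionary above
for it in (0, 1999, 2000, 27000, 52000, 100000):
    print(it, f"{scheduled_lr(it, 2e-4, 8e-4, 2000, 50000):.6f}")
```

With these values the rate peaks at `8e-4` after 2000 iterations, passes `5e-4` halfway through the decay phase, and stays at `2e-4` from iteration 52000 onward.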

## 4. Batch handling

In [15]:
joint_training=0
td.init_getitem(sample_type='encoded', sample_length=params['context_length']+1+joint_training, content_stepping=64)
num_records = len(td)
print(f"{num_records} records")

4843433 records


In [16]:
def get_sample_sub_batch(sample_batch, batch_size, sub_index=0):
    joint_training=0
    for i in range(batch_size):
        Xi = sample_batch[sub_index:-1-joint_training+sub_index]
        yi = sample_batch[sub_index+1:]
        if i==0:
            # smpX=np.array(Xi, dtype=np.float32)
            smpX=np.array(Xi, dtype=np.int32)
            smpy=np.array(yi, dtype=np.int32)
        else:
            # smpX = np.vstack((smpX, np.array(Xi, dtype=np.float32)))
            smpX = np.vstack((smpX, np.array(Xi, dtype=np.int32)))
            smpy = np.vstack((smpy, np.array(yi, dtype=np.int32)))
    return np.array(smpX), np.array(smpy)

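def get_sample_batch(td, batch_size):
    # Draw an independent random record for every row; reusing a single record
    # would fill the whole batch with identical rows.
    pairs = [get_sample_sub_batch(td.get_random_item(), 1) for _ in range(batch_size)]
    return np.vstack([x for x, _ in pairs]), np.vstack([y for _, y in pairs])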

In [17]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 18919


In [18]:
x, y = get_sample_batch(td, 2)
x.shape, y.shape

((2, 64), (2, 64))
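
Each record holds `context_length + 1` tokens, and the input and target are the same record shifted by one position. A small sketch with a shortened context and made-up token ids shows the relationship:

```python
context_length = 4
# One record holds context_length + 1 token ids (hypothetical values)
record = [10, 11, 12, 13, 14]

x = record[:-1]  # model input: every token except the last
y = record[1:]   # training target: the same tokens shifted left by one

for position in range(context_length):
    # At each position the model sees x[:position + 1] and should predict y[position]
    print(f"context {x[:position + 1]} -> target {y[position]}")
```

This is why both `x` and `y` have shape `(batch_size, context_length)`: every position in the context yields one next-token prediction.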

In [19]:
sample_data = None

def get_torch_subbatch(td, batch_size, device, split=None, sub_index=0):
    global sample_data
    if sub_index==0:
        sample_data = td.get_random_item()
    x, y = get_sample_sub_batch(sample_data, batch_size, sub_index)
    tx = torch.tensor(x, dtype=torch.long).to(device)
    tx.requires_grad = False
    ty = torch.tensor(y, dtype=torch.long).to(device)
    ty.requires_grad = False
    return tx, ty

def get_torch_batch(td, batch_size, device, split=None):
    x, y = get_sample_batch(td, batch_size)
    tx = torch.tensor(x, dtype=torch.long).to(device)
    tx.requires_grad = False
    ty = torch.tensor(y, dtype=torch.long).to(device)
    ty.requires_grad = False
    return tx, ty

def get_zero_state(batch_size, context_length, hidden_size, device):
    zstate = torch.zeros(batch_size, context_length, hidden_size, device=device)
    zstate.requires_grad = False
    return zstate

## 5. Loss and training helpers

In [20]:
class PositionalEncoding(nn.Module):
    def __init__(self, model_dimensions, max_len=5000):
        super().__init__()
        # Precompute positional encodings
        pe = torch.zeros(max_len, model_dimensions)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, model_dimensions, 2).float() * (-math.log(10000.0) / model_dimensions))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # [max_len, 1, model_dimensions]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [seq_len, batch_size, model_dimensions]
        seq_len = x.size(0)
        # Add the encoding of the first seq_len positions; the size-1 batch axis broadcasts
        return x + self.pe[:seq_len]
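
The sinusoidal table that the class precomputes can be reproduced for a single position without any framework. The following sketch uses the same frequency term as `div_term` above; the function name `sinusoidal_encoding` is hypothetical and the tiny dimension is chosen only for readability.

```python
import math


def sinusoidal_encoding(position: int, model_dimensions: int) -> list:
    """Encoding vector for one position; model_dimensions is assumed to be even."""
    vector = []
    for i in range(0, model_dimensions, 2):
        # Same frequency term as `div_term` in the class above
        frequency = math.exp(i * (-math.log(10000.0) / model_dimensions))
        vector.append(math.sin(position * frequency))  # even index
        vector.append(math.cos(position * frequency))  # odd index
    return vector


print(sinusoidal_encoding(0, 4))
print([round(v, 4) for v in sinusoidal_encoding(1, 4)])
```

Low dimensions oscillate quickly with the position while high dimensions change slowly, which gives every position within the context a distinct pattern.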

In [21]:
class TransformerBlock(nn.Module):
    def __init__(self, model_dimensions, heads, projection_dimension, dropout=0.1, non_linearity=nn.ReLU):
        super(TransformerBlock, self).__init__()
        self.self_attn = nn.MultiheadAttention(model_dimensions, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(model_dimensions)
        self.dropout1 = nn.Dropout(dropout)
        self.ff = nn.Sequential(
            nn.Linear(model_dimensions, projection_dimension),
            non_linearity(),
            nn.Dropout(dropout),
            nn.Linear(projection_dimension, model_dimensions)
        )
        self.norm2 = nn.LayerNorm(model_dimensions)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        if attn_mask is None and x.size(0) > 1:
            seq_len = x.size(0)
            attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
            attn_mask = attn_mask.to(x.device)  # [seq_len, seq_len], upper triangle = True (masked)

        attn_output, _ = self.self_attn(x, x, x,
                                      attn_mask=attn_mask,
                                      key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.ff(x)
        x = self.norm2(x + self.dropout2(ff_output))
        return x

class LatentRecurrentBlock(nn.Module):
    def __init__(self, model_dimensions, heads, projection_dimension, recurrence_steps=3, dropout=0.1, non_linearity=nn.ReLU):
        super(LatentRecurrentBlock, self).__init__()
        self.recurrence_steps = recurrence_steps
        self.self_attn = nn.MultiheadAttention(model_dimensions, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(model_dimensions)
        self.dropout1 = nn.Dropout(dropout)
        self.recurrent = nn.LSTM(  # Swap GRU for LSTM
            input_size=model_dimensions,
            hidden_size=model_dimensions,
            num_layers=1,
            batch_first=True,
            bidirectional=False
        )
        self.ff = nn.Sequential(
            nn.Linear(model_dimensions, projection_dimension),
            non_linearity(),
            nn.Dropout(dropout),
            nn.Linear(projection_dimension, model_dimensions)
        )
        self.norm2 = nn.LayerNorm(model_dimensions)
        self.dropout2 = nn.Dropout(dropout)

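    def forward(self, x, attn_mask=None, key_padding_mask=None):
        if attn_mask is None and x.size(0) > 1:
            # Default to a causal mask, as in TransformerBlock, so that positions cannot attend to later tokens
            seq_len = x.size(0)
            attn_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        attn_output, _ = self.self_attn(x, x, x,
                                      attn_mask=attn_mask,
                                      key_padding_mask=key_padding_mask)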
        x = self.norm1(x + self.dropout1(attn_output))
        residual = x
        batch_size = x.size(1)
        latent = x.transpose(0, 1).contiguous()  # [batch, seq_len, model_dimensions]
        latent = latent.view(batch_size * x.size(0), 1, x.size(2))  # [batch*seq, 1, model_dimensions]
        h0 = torch.zeros(1, latent.size(0), x.size(2), device=x.device)
        c0 = torch.zeros(1, latent.size(0), x.size(2), device=x.device)  # Add cell state
        for _ in range(self.recurrence_steps):
            latent, (h0, c0) = self.recurrent(latent, (h0, c0))  # LSTM outputs hidden + cell
        latent = latent.view(x.size(1), x.size(0), -1).transpose(0, 1)
        latent = residual + latent
        ff_output = self.ff(latent)
        output = self.norm2(latent + self.dropout2(ff_output))
        return output

class LatentRecurrentDepthModel(nn.Module):
    def __init__(self, vocab_size, model_dimensions, heads, context_length, projection_dimension,
                 n1_prelude, n2_recurrent, n3_coda, recurrence_steps=3, dropout=0.1, non_linearity=nn.ReLU):
        """
        Args:
            vocab_size (int): Size of the vocabulary (for embedding and projection).
            model_dimensions (int): Transformer hidden size.
            heads (int): Number of attention heads.
            projection_dimension (int): Feedforward hidden size.
            n1_prelude, n2_recurrent, n3_coda (int): Number of blocks per stage.
            recurrence_steps (int): Recurrent steps per LRD block.
            dropout (float): Dropout rate.
        """
        super(LatentRecurrentDepthModel, self).__init__()

        self.context_length = context_length  # for generate
        self.model_dimensions = model_dimensions

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, model_dimensions)
        self.pos_encoding = PositionalEncoding(model_dimensions, max_len=context_length)
        # self.pos_encoding = PositionalEncoding(model_dimensions) # , max_len=context_length)

        # Prelude blocks
        self.prelude = nn.ModuleList([
            TransformerBlock(model_dimensions, heads, projection_dimension, dropout, non_linearity)
            for _ in range(n1_prelude)
        ])

        # Latent Recurrent blocks
        if n2_recurrent > 0:
            self.recurrent = nn.ModuleList([
                LatentRecurrentBlock(model_dimensions, heads, projection_dimension, recurrence_steps, dropout, non_linearity)
                for _ in range(n2_recurrent)
            ])
        else:
            self.recurrent = None

        # Coda blocks
        self.coda = nn.ModuleList([
            TransformerBlock(model_dimensions, heads, projection_dimension, dropout, non_linearity)
            for _ in range(n3_coda)
        ])

        # Final projection layer (e.g., to vocab size for generation)
        self.proj = nn.Linear(model_dimensions, vocab_size)

    def forward(self, input_ids, attn_mask=None, key_padding_mask=None):
        """
        Args:
            input_ids (torch.Tensor): Token IDs [batch_size, seq_len].
            attn_mask (torch.Tensor, optional): Attention mask [seq_len, seq_len].
            key_padding_mask (torch.Tensor, optional): Padding mask [batch_size, seq_len].
        Returns:
            torch.Tensor: Output logits [batch_size, seq_len, vocab_size].
        """
        # Embed input tokens
        x = self.embedding(input_ids) * math.sqrt(self.model_dimensions) # /2.0  # [batch_size, seq_len, model_dimensions]
        # x = self.pos_encoding(x)
        x = x.transpose(0, 1)  # [seq_len, batch_size, model_dimensions] for transformer
        x = self.pos_encoding(x)

        # Prelude: Entry to latent space
        for block in self.prelude:
            x = block(x, attn_mask, key_padding_mask)

        # Recurrent: Refine latents
        if self.recurrent is not None:
            for block in self.recurrent:
                x = block(x, attn_mask, key_padding_mask)

        # Coda: Exit from latent space
        for block in self.coda:
            x = block(x, attn_mask, key_padding_mask)

        # Project to output space
        x = x.transpose(0, 1)  # [batch_size, seq_len, model_dimensions]
        output = self.proj(x)  # [batch_size, seq_len, vocab_size]
        return output

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Generate new tokens given a context

        Note: on Apple MPS, top_k was limited to a maximum of 16 in older torch versions (implementation limitation as of 01/2023).
        See: https://github.com/pytorch/pytorch/issues/78915
        Solved in: https://github.com/pytorch/pytorch/pull/94639 (03/2023)

        :param idx: the context (B,T) tensor of indices
        :param max_new_tokens: the maximum number of tokens to generate
        :param temperature: the temperature to use for sampling
        :param top_k: the number of top tokens to consider
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last context_length tokens
            idx_cond = idx[:, -self.context_length :]
            # print(idx_cond.shape)
            # get the predictions
            logits = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply temperature
            if temperature != 1.0 and temperature > 0.0:
                logits = logits / temperature
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

    def generate_with_beam(self, model, tokenizer, prompt="The", max_len=50, temperature=1.0, top_k=30, beam_width=3):
    # def generate(model, tokenizer, prompt="The", max_len=50, temperature=1.0, top_k=30, beam_width=3):
        """
        Beam search generation with static abort condition.

        Args:
            model: LatentRecurrentDepthModel
            tokenizer: Your custom/botok tokenizer (no [EOS])
            prompt (str): Starting text
            max_len (int): Max output length
            temperature (float): Softmax temperature
            top_k (int): Sample from top k tokens
            beam_width (int): Number of beams
        """
        model.eval()
        device = next(model.parameters()).device
        input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)  # [1, seq_len]
        beams = [(input_ids, 0.0)]  # (sequence, log_prob)

        with torch.no_grad():
            for step in range(max_len):
                new_beams = []
                for seq, score in beams:
                    # Forward pass
                    logits = model(seq)  # [1, seq_len, vocab_size]
                    next_logits = logits[0, -1, :] / temperature

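                    # Top-k sampling (top_k=None keeps the full vocabulary)
                    k = next_logits.size(-1) if top_k is None else min(top_k, next_logits.size(-1))
                    top_k_logits, top_k_indices = torch.topk(next_logits, k)

                    # Repetition penalty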
                    for i, token in enumerate(seq[0][-5:]):
                        penalty = 1.0 + 0.2 * i
                        top_k_logits[top_k_indices == token] /= penalty

                    probs = F.softmax(top_k_logits, dim=-1)

                    # Sample beam_width candidates
                    next_tokens = torch.multinomial(probs, num_samples=beam_width)
                    for i in range(beam_width):
                        token_id = top_k_indices[next_tokens[i]].unsqueeze(0).unsqueeze(0)  # [1, 1]
                        log_prob = torch.log(probs[next_tokens[i]]).item()
                        new_seq = torch.cat([seq, token_id], dim=1)
                        # Accumulate the log-probability of the extended beam. The repetition
                        # penalty is already applied to the logits above; the stale loop
                        # variable `penalty` must not leak into the beam score.
                        new_beams.append((new_seq, score + log_prob))

                # Sort and prune beams
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]

                # Static abort: all beams at max_len or repeating last 5 tokens
                # all_max_len = all(len(seq[0]) >= max_len for seq, _ in beams)
                # all_repeating = all(
                #     len(seq[0]) > 5 and seq[0, -5:].tolist() == [seq[0, -1].item()] * 5
                #     for seq, _ in beams
                # )
                # if all_max_len or all_repeating:
                #     break
                if all(len(seq[0]) >= max_len for seq, _ in beams):
                    break

        best_seq, _ = beams[0]
        return tokenizer.decode(best_seq[0].tolist())
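
The sampling step in `generate` above (top-k filtering, temperature scaling, softmax, multinomial draw) can be sketched without tensors. The function `sample_top_k` is a hypothetical, framework-free illustration that follows the same order of operations; it is not a replacement for the method.

```python
import math
import random
from typing import List, Optional


def sample_top_k(logits: List[float], top_k: Optional[int] = None,
                 temperature: float = 1.0,
                 rng: Optional[random.Random] = None) -> int:
    """Sample one index from logits after top-k filtering and temperature scaling."""
    rng = rng or random.Random()
    if top_k is not None:
        # Keep only logits at least as large as the k-th largest one
        threshold = sorted(logits, reverse=True)[min(top_k, len(logits)) - 1]
        logits = [x if x >= threshold else float("-inf") for x in logits]
    if temperature != 1.0 and temperature > 0.0:
        logits = [x / temperature for x in logits]
    # Numerically stable softmax weights; exp(-inf) evaluates to 0.0
    peak = max(logits)
    weights = [math.exp(x - peak) for x in logits]
    return rng.choices(list(range(len(logits))), weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 1.0, 0.5, -1.0]
samples = [sample_top_k(logits, top_k=2, temperature=0.8, rng=rng) for _ in range(1000)]
print({index: samples.count(index) for index in sorted(set(samples))})
```

With `top_k=2` only the two largest logits can ever be drawn, and a temperature below 1.0 shifts additional probability mass toward the largest one.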

In [22]:
def init_weights(m):
    if isinstance(m, (nn.Linear, nn.GRU, nn.LSTM)):
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param, gain=1.0)
            elif 'bias' in name:
                nn.init.zeros_(param)
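
Xavier (Glorot) normal initialization draws weights from a normal distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. As a quick check of the magnitudes involved, the following sketch evaluates that formula for the feed-forward layers of one block; the helper name `xavier_normal_std` is hypothetical.

```python
import math


def xavier_normal_std(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Standard deviation used by Xavier/Glorot normal initialization."""
    return gain * math.sqrt(2.0 / (fan_in + fan_out))


# Feed-forward layers of a block with model_dimensions=128 and a 4x projection
print(f"Linear(128 -> 512): std = {xavier_normal_std(128, 512):.4f}")
print(f"Linear(512 -> 128): std = {xavier_normal_std(512, 128):.4f}")
```

Because the formula is symmetric in `fan_in` and `fan_out`, both feed-forward layers of a block receive the same standard deviation.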

In [23]:
print("creating model...")
try:
    # Colab + torch 2 -> lots of garbage.
    if model is not None:
        del model
except NameError:
    pass  # No model from a previous run exists yet

model = LatentRecurrentDepthModel(
    vocab_size=params['vocab_size'],
    model_dimensions=params['model_dimensions'], heads=params['heads'], projection_dimension=params['model_dimensions']*4,
    context_length=params['context_length'],
    n1_prelude=params['prelude_layers'], n2_recurrent=params['recurrent_layers'], n3_coda=params['coda_layers'], 
    recurrence_steps=params['recurrence_steps'], dropout=params['dropout'], non_linearity=params['non_linearity']
)
model.apply(init_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=params['learning_rate'])

model = model.to(device)
if use_existing_model_from_checkpoint is True:
    params_load = MJ.load_checkpoint(params, model, optimizer, file_path=model_file_path, updatable_keys=updatable_keys, device=device, log=log) # torch.device("cpu"))
    if params_load is not None:
        params = params_load
model = model.to(device)
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

if use_torch_compile is True:
    if device == 'cuda':
        print("Compiling...")
        model = torch.compile(model)
        print("Compile ok.")
        try:
            torch.set_float32_matmul_precision('high')
        except:
            print("Seems no tensor cores for that.")
    # elif str(device) == 'mps':
    #     print("Compiling...")
    #     model = torch.compile(model)
    #     print("Compile ok.")

if 'current_epoch' in params:
    ep = params['current_epoch']
else:
    ep=0
if 'current_loss' in params:
    ls = params['current_loss']
else:
    ls=0

if ep==0 and ls==0:
    start_iter = 0
else:
    start_iter = ep
    current_loss = ls

# print the number of parameters in the model
print(model)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")

creating model...
LatentRecurrentDepthModel(
  (embedding): Embedding(32768, 128)
  (pos_encoding): PositionalEncoding()
  (prelude): ModuleList(
    (0-1): 2 x TransformerBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): Mish()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=512, out_features=128, bias=True)
      )
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (coda): ModuleList(
    (0-1): 2 x TransformerBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
     

In [24]:
@torch.no_grad()
def estimate_loss(device):
    # XXX: this does take data for train and val from SAME pool!
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(params['test_iterations'])
        for k in range(params['test_iterations']):
            # if k % (params['test_iterations']/10 + 1) == 0:
            #     print(".", end="", flush=True)
            X, Y = get_torch_batch(td, params['batch_size'], device, split)
            logits = model(X)
            loss = get_loss(logits, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    print("\r", end="", flush=True)
    mloss = (out['train']+out['val'])/2.0
    return mloss

def generate_sample(td, device, prompt=' ', toks=100, state=None, temperature=1.0, top_k=None, pad=True, with_beam=True):
    model.eval()  # disable dropout for both generation paths
    if with_beam is True:
        txt = model.generate_with_beam(model,td,prompt,toks, temperature=temperature, top_k=top_k, beam_width=7)
    else:
        if pad is True:
            while len(prompt)<params['context_length']*4:
                if len(prompt)==params['context_length']*4-1:
                    prompt = '\n' + prompt
                else:
                    prompt = ' ' + prompt
        context = torch.tensor([td.encode(prompt)]).to(device)
        answer = model.generate(context, max_new_tokens=toks, temperature=temperature, top_k=top_k)
        txt = td.decode(answer[0].tolist())
    # Identify memorisation of text by highlighting verbatim quotes from sources
    # that are longer than 10 chars. HTML colorcoded output for source identification:
    td.source_highlight(txt, min_quote_size=10, dark_mode=False, display_ref_anchor=False)
    model.train()
    return txt
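
The padding loop in `generate_sample` left-fills a short prompt with spaces and puts a single newline at the very front. Prompts that are already long enough are left unchanged. A minimal standalone sketch of that behaviour follows. The helper name `left_pad_prompt` and its `target_len` argument are hypothetical and stand in for the inline loop and `params['context_length']*4`:

```python
def left_pad_prompt(prompt: str, target_len: int) -> str:
    """Left-pad `prompt` to `target_len` characters (hypothetical helper).

    Mirrors the loop in generate_sample: spaces are prepended until exactly
    one character is missing, then a newline is prepended as first character.
    Prompts that are already long enough are returned unchanged.
    """
    missing = target_len - len(prompt)
    if missing <= 0:
        return prompt
    return "\n" + " " * (missing - 1) + prompt


padded = left_pad_prompt("abc", 6)
print(repr(padded))
```

Building the padding in one expression avoids the repeated string concatenation of the loop, which matters little here but keeps the intent easier to read.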


In [25]:
# @torch.jit.script
# @torch.compile
criterion = nn.CrossEntropyLoss()

def get_loss(logits, yb):
    output_flat = logits.reshape(-1, params['vocab_size'])
    # output_flat = logits.view(-1, params['vocab_size'])
    # print(output_flat.shape)
    ybr = yb.reshape(-1)
    # print(ybr.shape)
    loss = criterion(output_flat, ybr)
    return loss

def do_train_step(xb, yb, device, state=None):
    model.train()
    logits = model(xb)
    loss = get_loss(logits, yb)

    optimizer.zero_grad()
    loss.backward()
    # Clip after backward(), so the gradients of this step are the ones being clipped
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), params['grad_clip']).cpu()
    optimizer.step()

    return loss.item(), norm
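
`do_train_step` relies on `torch.nn.utils.clip_grad_norm_`, which treats all gradients as one long vector, rescales them if their joint L2 norm exceeds the limit, and returns the norm measured before clipping (the value logged as `gradient_norm`). The following pure-Python sketch illustrates that idea on plain lists. The function name `clip_by_global_norm` is hypothetical, and this is an illustration of the principle rather than PyTorch's implementation:

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale gradient vectors so their joint L2 norm is at most max_norm.

    Illustrative only: `grads` is a list of plain Python float lists, one
    per parameter. Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Never scale up: gradients below the limit keep their size
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # joint norm is 5.0
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
print(norm_before, clipped)
```

Because the scaling is applied to gradients that already exist, clipping only has an effect when it runs after `loss.backward()` and before `optimizer.step()`.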

In [26]:
def start_tu_session():
    if indra_avail is True:
        with open('indra_creds.json', 'r') as f:
            creds = json.load(f)
            tu = TrainUtils(indra_server_profile_name='default', username=creds['username'], password=creds['password'])
    else:
        tu = TrainUtils()
    tu.train_session_start(model_name=model_name, model_description="Torch-poet tests", model_version=1, model_params=params, indra_subdomain="torch_poet/first_tests/1", status_string_size=110)
    return tu

In [27]:
def lr_schedule(optim, n_iter, warmup, max_lr, decay, min_lr):
    if n_iter<warmup and warmup>0:
        lr = (n_iter+1)/warmup*max_lr
    elif n_iter<warmup+decay and decay>0:
        i = n_iter-warmup
        lr = (decay-i)/decay*(max_lr-min_lr)+min_lr
    else:
        lr = min_lr

    for g in optim.param_groups:
        g['lr'] = lr
    return lr
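
The schedule above has three phases: a linear warmup to `max_lr`, a linear decay down to `min_lr`, and a constant `min_lr` afterwards. The sketch below reproduces only the arithmetic, without an optimizer, so the values can be inspected in isolation. The function name `lr_at` and the numeric settings are hypothetical and chosen only for illustration:

```python
def lr_at(n_iter, warmup, max_lr, decay, min_lr):
    """Learning rate at iteration n_iter: linear warmup, linear decay, then min_lr."""
    if n_iter < warmup and warmup > 0:
        return (n_iter + 1) / warmup * max_lr
    if n_iter < warmup + decay and decay > 0:
        i = n_iter - warmup
        return (decay - i) / decay * (max_lr - min_lr) + min_lr
    return min_lr


# Hypothetical settings, not taken from the notebook's params
for n in (0, 49, 99, 100, 600, 1100, 5000):
    print(n, lr_at(n, warmup=100, max_lr=1e-3, decay=1000, min_lr=1e-4))
```

With these settings the rate reaches `max_lr` on the last warmup iteration, is halfway between `max_lr` and `min_lr` at the midpoint of the decay phase, and stays at `min_lr` once `warmup + decay` iterations have passed.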


In [28]:
def train(train_utils):
    global start_iter
    dt0 = time.time()
    sdt = datetime.datetime.now(tz=local_timezone).strftime("%Y-%m-%d %H:%M:%S")
    print(f"training, start at {sdt}...")
    gen_id = 0
    last_print=0
    iter_bench = 1
    tu = train_utils
    lr = params['learning_rate']
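    # Initialize so save_checkpoint() has a value even if it runs before the first loss evaluation
    current_loss = params.get('current_loss', 0)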
    # Prompts used in rotation for the periodic text samples
    inputs = [
        "What is the difference between good and evil? The difference ",
        "How did everything come into existence? The origin ",
        "What was at the beginning of time? Time itself ",
        "How are physics, quantum-mechanics and consciousness related? The relation between ",
        "How to attain complete self-awareness? Complete ",
        "What is the nature of reality? The nature ",
        "How to be a good human being? A human ",
    ]
    for iter in range(start_iter, params['max_iterations']):
        # every once in a while evaluate the loss on train and val sets
        if (iter + 1) % params['sample_every_n_iterations'] == 0 or iter == params['max_iterations'] - 1:
            dt = time.time()
            print("\rloss eval", end="", flush=True)
            current_loss = estimate_loss(device)
            print(
                f"step {iter+1}: mean train/val loss {current_loss:.4f}, time {(dt-dt0)/iter_bench:.3f} sec/iter                       "
            )
            iter_bench = 1
            sdt = datetime.datetime.now(tz=local_timezone).strftime("%Y-%m-%d %H:%M:%S")
            print(f"Sample at {sdt}:", flush=True)
            for temperature in [1.0]: # 0.75, 1.1, 1.3, 1.5]:
                print(f"--------temperature: {temperature} ---------")
                prompt = inputs[gen_id%len(inputs)]
                print(f"Prompt: {prompt}")
                generate_sample(td=td, device=device, prompt=prompt, toks=params['sample_size'], temperature=temperature, top_k=10, with_beam=False)
                # print(f"Prompt: {prompt}")
                # generate_sample(td=td, device=device, prompt=prompt, toks=params['sample_size'], temperature=temperature, top_k=10, with_beam=True)
            print("-------------------------------------------")
            gen_id += 1
            dt0 = time.time()

        if params['lr_schedule'] is True:
            lr = lr_schedule(optimizer, iter, params['warmup'], params['lr_max'], params['decay'], params['lr_min'])

        xb, yb = get_torch_batch(td, params['batch_size'], device, "train")
        cur_loss, cur_norm = do_train_step(xb, yb, device=device)


        nt = time.time()
        if (nt-last_print)>1:
            rec = {
                'epoch': iter/num_batches,
                'batch': iter%params['sample_every_n_iterations'],
                'num_batches': params['sample_every_n_iterations'],
                'loss': cur_loss,
                'learning_rate': lr,
                'gradient_norm': cur_norm.item(),
            }
            status_string, record = train_utils.train_state(rec)
            print(status_string, end="\r")
            last_print=nt

        start_iter = iter
        iter_bench += 1
        if (iter+1)%params['save_every_n_iterations'] == 0:
            MJ.save_checkpoint(params, model, optimizer, iter, current_loss, file_path=model_file_path, log=log)

In [None]:
tu = start_tu_session()
try:
    train(train_utils = tu)
except KeyboardInterrupt:
    print("\nTraining interrupted.")
tu.train_session_end()

training, start at 2025-03-02 12:20:10...
Ep: 0.04 Bat: 706/4096       ⦊███▌                ⦉ loss: 7.2529 lr: 0.000283 grad_norm: 2.352 Sec/It: 0.235  