<a href="https://colab.research.google.com/github/domschl/torch-transformer-poet/blob/main/torch_transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Torch-Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection machine learning tools that provides support for more environment indepent code. It will access your Google Drive when using with Google Colab.

In [1]:
import sys
if 'google.colab' in sys.modules:
    # from: https://github.com/pytorch/pytorch/issues/107960  (libcuda not found)
    !export LC_ALL="en_US.UTF-8"
    !export LD_LIBRARY_PATH="/usr/lib64-nvidia"
    !export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
    !ldconfig /usr/lib64-nvidia

In [2]:
import logging
import os
import copy
import json
import time
import datetime
import random
import math
import numpy as np
from zoneinfo import ZoneInfo

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

In [3]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.Calibre_Dataset import Calibre_Dataset
from ml_indie_tools.Folder_Dataset import Folder_Dataset

import ml_indie_tools.pytorch_meta_tools as MJ
from ml_indie_tools.train_utils import TrainUtils

In [4]:
# Optional experimental event server to record and propagate training progress, not (yet) recommended!
# Functionality is ignored by default.
try:
    import indralib
    indra_avail = True
except Exception as e:
    indra_avail = False
if indra_avail is True:
    print("Indralib is available, trying to connect to Indrajala server for training progress reports...")

In [5]:
logging.basicConfig(level=logging.INFO)
log = logging.Logger("Main")
log.setLevel(logging.INFO)

## Preliminary

A pytorch deep multi-head attention model for text generation following Andrej Karpathy's [video-lecture-ng](https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py)

This code can use either CPU, GPU, or Apple Silicon. Google Colab is supported too, select the corresponding Colab runtime (menu: **`Runtime / Change runtime type`**)

## 0. Environment

In [6]:
cached_batch_data = None   # Do regenerate time-consuming training data, if aleady cached.

ml_env = MLEnv(platform='pt', accelerator='fastest')
ml_env.describe()

'OS: Linux, Python: 3.13.2, Jupyter Notebook Pytorch: 2.6.0+cu124, GPU: NVIDIA GeForce RTX 4070 Ti (/  285W |      18MiB), CPU'

## 1. Project configuration

In [7]:
# project_name = 'women_writers'
# project_name='research'
project_name='neo_philosophers'
model_cpu = None
model_name=f'tr_{project_name}_v3_pt'

use_preprocessed_data = True                      # Use already tokenized data
use_existing_model_from_checkpoint = False         # Try to load checkpoint of training
use_torch_compile = True                           # Requires a modern graphics card with torch compile backend support
skip_additional_texts = False                       # Don't look for other data sources in `additional_texts.json`

if 'google.colab' in sys.modules:  # Google colab notebooks run on server that provide UTC time, we adapt logs to local time:
    local_timezone = ZoneInfo('Europe/Berlin')
else:
    local_timezone = None

# NOTICE: This will request access to Google Drive, if running on Google Colab. Google Drive is used to store snapshots
# training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools
#
# Note: you need to allow popups in your browser for COLAB, otherwise you won't see the google-drive login box, and drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device("mps") if torch.backends.mps.is_available() else device

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Root path (all projects) : . (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : . (Changes to the file system happen only below this project path
Model path (snapshots)   : ./model/tr_neo_philosophers_v3_pt (Model weights and snapshots are stored here)
Data path (training data): ./data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


##  2.1 Text data from Project Gutenberg

`Text_Dataset` and `Gutenberg_Dataset` classes: libraries for training,
encoding, batch generation, and formatted source display. It read some
books from Project Gutenberg and supports creation of training batches.
The output functions support highlighting to allow to compare generated
texts with the actual sources to help to identify identical (memorized)
parts.

In [8]:
use_dark_mode=False # Set to false for white background. HTML-text-compare uses background-colorization to identify different sources. Those background colors are dependent on the theme type.

In [9]:
token_file = os.path.join(data_path,f"{project_name}_tokens.json")
if use_preprocessed_data is True:
    if os.path.exists(token_file):
        td = Text_Dataset()
        td.load_tokenizer(token_file)
    else:
        use_preprocessed_data = False

INFO:Datasets:Loading tokenizer from ./data/neo_philosophers_tokens.json
INFO:Datasets:Loading tokenizer done.


In [10]:
if use_preprocessed_data is False:
    cache_dir = os.path.join(data_path, 'gutenberg_cache')
    gd = Gutenberg_Dataset(cache_dir=cache_dir)

    if project_name == 'women_writers':  # sample searches
        search_spec= {
            "author": ["Emily Brontë", "Jane Austen", "Virginia Woolf"],
            "language": ["english"]
        }
        book_list=gd.search(search_spec)
    elif project_name == 'neo_philosophers':
        search_spec = {
            "author": ["Immanuel Kant", "Friedrich Nietzsche", "Wilhelm Hegel", "Arthur Schopenhauer"],
            "language": ["english", "german"]
        }
        book_list=gd.search(search_spec)
        search_spec = {
            "author": ["Plato", "Platon"],
            "title": ["Timaeus", "Critias", "Symposium"],
            "language": ["english", "german"]
        }
        book_list+=gd.search(search_spec)
        search_spec = {
            "title": ["Buddh", "Sutra"],
            "language": ["english", "german"]
        }
        book_list+=gd.search(search_spec)
    else:
        search_spec = {}
        book_list = []

    book_cnt = len(book_list)
    print(f"{book_cnt} matching books found with search {search_spec}.")

    if book_cnt > 0:
        if book_cnt<80:
            # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts,
            # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
            book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)
        else:
            logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

        for i in range(len(book_list)):
            if 'author' not in book_list[i]:
                book_list[i]['author']='unknown'
            print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

        if project_name == 'women_writers':
            select = ("Bennett", "1342", "5670", "1245", "161", "141", "121", "105", "Susan", "Wuthering", "Emma", "Voyage")  # List unique single-words from title or ebook_id to select a given book
            sub_book_list = [book_list[i] for i in range(len(book_list)) if not set([book_list[i]['ebook_id']]+book_list[i]['title'].split(' ')).isdisjoint(set(select))]
        else:
            sub_book_list = book_list

        print("Using:")
        for i in range(len(sub_book_list)):
            if 'author' not in sub_book_list[i]:
                sub_book_list[i]['author']='unknown'
            print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

        td = Text_Dataset(sub_book_list)
    else:
        td = Text_Dataset()()

## 2.2 Additional training material from folders or Calibre library

This looks for a file `additional_texts.json` in the `project_path` as shown above.

```json
{
  "local_texts": [["/some/directory/that/contains/texts", [".txt", ".md", ".org", ".py"]],
  "calibre": ["/home/myuser/Calibre Library", []]
}
```

If the folder(s) defined in `local_texts` contain text files with default endings `.txt`, `.md`, `.org`, or `.py` (can be configured), they are added to the training data. Folders are searched recursively.

If the path defined in `calibre` contains a Calibre database, all text files (`.txt` only) within that library are added to the training data. The list argument can contain search-specs (see ml-indie-tools, calibre_dataset) to qualify which books to import, e.g. [{"tags": ["philosophy"]}] would import all books that are tagged with "Philosophy" within calibre.

In [11]:
if use_preprocessed_data is False and skip_additional_texts is False:
    additional = os.path.join(project_path, "additional_texts.json")
    print(f"Looking for description of additional sources in {additional}")
    if os.path.exists(additional) is True:
        with open(additional, 'r') as f:
            add_desc = json.load(f)
            if 'local_texts' in add_desc:
                fd = Folder_Dataset()
                for text_path, qualifier in add_desc['local_texts']:
                    if not isinstance(qualifier, list):
                        qualifier = [qualifier]
                    print(f"Loading texts from {text_path} using extension restriction {qualifier}")
                    fd.load_index(text_path, use_aliases=False, min_file_size=2048, file_extensions=qualifier)
                td.load_texts(fd.records[:10000])
            if 'calibre' in add_desc:
                cal_path, specs = add_desc['calibre']
                if os.path.exists(cal_path):
                    print(f"Loading text from calibre at {cal_path}")
                    cd = Calibre_Dataset(cal_path)
                    cd.load_index(max_file_size=100000000) 
                    if specs is not None and len(specs)!=0:
                        ls = cd.search(specs)
                        td.load_texts(ls)
                    else:
                        td.load_texts(cd.records[:1000])

## 2.3 Tokenize data

In [12]:
if use_preprocessed_data is False:
    MAX_TOKENS = 32768  # This becomes vocab_size
    MAX_NGRAM_LEN = 5   # Max length of a token
    CHUNK_SIZE = 500000 # Split larger texts in chunks, if not None

    print("")
    print(f"Starting tokenizer with token length from 1..{MAX_NGRAM_LEN} with a max of {MAX_TOKENS} unique tokens,")
    print("this can take considerable time...")

    # Better tested NGRAM tokenizer:
    # td.init_tokenizer(tokenizer='ngram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS) 
    # or alternative 'BYTEGRAM' (more experimental, can encode arbitrary UTF-8)
    # td.init_tokenizer(tokenizer='bytegram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS, chunk_size=CHUNK_SIZE)
    td.init_tokenizer(tokenizer='bytegram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS, chunk_size=CHUNK_SIZE)
    td.save_tokenizer(token_file)

In [13]:
tok_tests = ["Good morning, this is a simple test sentence for tokenization", "Guten Morgen, dies is ein einfach Testsatz zur Aufteilung in Satzbestandteile", "སེམས་ཉིད་ངལ་བསོ་རྒྱུད་"]
for test in tok_tests:
    enc = td.encode(test)
    dec = td.decode(enc)
    if dec != test:
        print(f"Tokenizer failed for: {test} -> {dec}")
    else:
        r = len(enc)/len(test)*100.0
        print(f"Tokenizer: {test}({len(test)}) -> {enc}({len(enc)}) OK, compressed size: {r:.2f}%")

Tokenizer: Good morning, this is a simple test sentence for tokenization(61) -> [5305, 111, 14443, 110, 1703, 116, 2759, 115, 15819, 109, 3184, 116, 25975, 101, 4283, 99, 1591, 32, 362, 107, 4337, 122, 444](23) OK, compressed size: 37.70%
Tokenizer: Guten Morgen, dies is ein einfach Testsatz zur Aufteilung in Satzbestandteile(77) -> [31731, 110, 16080, 103, 2087, 105, 11319, 32, 22115, 105, 16381, 99, 347, 84, 11107, 97, 15893, 122, 1043, 65, 1868, 116, 25161, 103, 22582, 97, 2730, 98, 6441, 110, 24565, 101, 3106](33) OK, compressed size: 42.86%
Tokenizer: སེམས་ཉིད་ངལ་བསོ་རྒྱུད་(22) -> [224, 189, 166, 224, 189, 186, 224, 189, 152, 224, 189, 166, 224, 188, 139, 224, 189, 137, 224, 189, 178, 224, 189, 145, 224, 188, 139, 224, 189, 132, 224, 189, 163, 224, 188, 139, 224, 189, 150, 224, 189, 166, 224, 189, 188, 224, 188, 139, 224, 189, 162, 224, 190, 146, 224, 190, 177, 224, 189, 180, 224, 189, 145, 224, 188, 139](66) OK, compressed size: 300.00%


## 3. Model metadata

In [14]:
params = None
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'current_loss',
                 'sample_every_n_iterations', 'sample_size', 'save_every_n_iterations', 'max_iterations']
attn_layers = 4
dims = 512
sequence_length = 192

params = { # Multi-head self-attention
        'meta_name_template': '{mhsa_layers}x{heads}x{units}x{vocab_size}',

        'mhsa_layers': attn_layers,
        'heads': 16,
        'vocab_size': td.get_unique_token_count(),
        'sequence_len': sequence_length,
        'dropout': 0.1,
        'embedding_size': dims,
        'test_iterations': 100,  # number of epocs for loss estimation
        'use_recur': False,
        'share_weights': False,

        'batch_size': 48,      
        'learning_rate': 4e-4,   # None: Set in dependence of graphics hw
        'lr_schedule': True,
        'lr_min': 2e-5,
        'lr_max': 2e-4,
        'warmup': 4000,
        'decay': 100000,
        'grad_clip': 0.8,

        'sample_every_n_iterations': 1024,
        'sample_size': 192,
        'save_every_n_iterations': 4096,

        'max_iterations': 100000000  # maximum number of training iterations
    }

model_file_path = MJ.get_model_filename(model_path)
if use_existing_model_from_checkpoint is True:
    params = MJ.load_model_metadata_from_checkpoint(params, updatable_keys, model_file_path, device=device, log=log) # torch.device('cpu'))
if params == None or use_existing_model_from_checkpoint is False:
    use_existing_model_from_checkpoint = False
# print(params)

## 4. Batch handling

In [15]:
joint_training=0
td.init_getitem(sample_type='encoded', sample_length=params['sequence_len']+1+joint_training, content_stepping=1)
num_records = len(td)
print(f"{num_records} records")

14796641 records


In [16]:
def get_sample_sub_batch(sample_batch, batch_size, sub_index=0):
    joint_training=0
    for i in range(batch_size):
        Xi = sample_batch[sub_index:-1-joint_training+sub_index]
        yi = sample_batch[sub_index+1:]
        if i==0:
            # smpX=np.array(Xi, dtype=np.float32)
            smpX=np.array(Xi, dtype=np.int32)
            smpy=np.array(yi, dtype=np.int32)
        else:
            # smpX = np.vstack((smpX, np.array(Xi, dtype=np.float32)))
            smpX = np.vstack((smpX, np.array(Xi, dtype=np.int32)))
            smpy = np.vstack((smpy, np.array(yi, dtype=np.int32)))
    return np.array(smpX), np.array(smpy)

def get_sample_batch(td, batch_size):
    sample_batch = td.get_random_item()
    return get_sample_sub_batch(sample_batch, batch_size)

In [17]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 308263


In [18]:
x, y = get_sample_batch(td, 2)
x.shape, y.shape

((2, 192), (2, 192))

In [19]:
sample_data = None

def get_torch_subbatch(td, batch_size, device, split=None, sub_index=0):
    global sample_data
    if sub_index==0:
        sample_data = td.get_random_item()
    x, y = get_sample_sub_batch(sample_data, batch_size, sub_index)
    tx = torch.tensor(x, dtype=torch.long).to(device)
    tx.requires_grad = False
    ty = torch.tensor(y, dtype=torch.long).to(device)
    ty.requires_grad = False
    return tx, ty

def get_torch_batch(td, batch_size, device, split=None):
    x, y = get_sample_batch(td, batch_size)
    tx = torch.tensor(x, dtype=torch.long).to(device)
    tx.requires_grad = False
    ty = torch.tensor(y, dtype=torch.long).to(device)
    ty.requires_grad = False
    return tx, ty

def get_zero_state(batch_size, sequence_len, hidden_size, device):
    zstate = torch.zeros(batch_size, sequence_len, hidden_size, device=device)
    zstate.requires_grad = False
    return zstate

## 5. Loss and training helpers

In [20]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, device=None):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model, device=device)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [21]:
class PositionalEncodingGrok(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Precompute positional encodings
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [seq_len, batch_size, d_model]
        seq_len = x.size(0)
        pe = self.pe[:, :seq_len, :].expand(-1, x.size(1), -1)
        return x

In [22]:
# GROK-3 implementation of latent recurrence!

# Reusing previous block definitions with minor tweaks
class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        if attn_mask is None and x.size(0) > 1:
            seq_len = x.size(0)
            attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
            attn_mask = attn_mask.to(x.device)  # [seq_len, seq_len], upper triangle = True (masked)
        
        attn_output, _ = self.self_attn(x, x, x, 
                                      attn_mask=attn_mask,
                                      key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.ff(x)
        x = self.norm2(x + self.dropout2(ff_output))
        return x

class LatentRecurrentBlock(nn.Module):
    def __init__(self, d_model, nhead, d_ff, num_steps=3, dropout=0.1, use_lstm=True):
        super(LatentRecurrentBlock, self).__init__()
        self.num_steps = num_steps
        self.use_lstm = use_lstm
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        if use_lstm is True:
            self.recurrent = nn.LSTM(  # Swap GRU for LSTM
                input_size=d_model,
                hidden_size=d_model,
                num_layers=1,
                batch_first=True,
                bidirectional=False
            )
        else:
            self.recurrent = nn.GRU(
                input_size=d_model,
                hidden_size=d_model,
                num_layers=1,
                batch_first=True,
                bidirectional=False
            )
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        attn_output, _ = self.self_attn(x, x, x, 
                                      attn_mask=attn_mask,
                                      key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        # residual = x
        batch_size = x.size(1)
        latent = x.transpose(0, 1).contiguous()  # [batch, seq_len, d_model]
        latent = latent.view(batch_size * x.size(0), 1, x.size(2))  # [batch*seq, 1, d_model]
        if self.use_lstm is False:
            h0 = torch.zeros(1, batch_size * x.size(0), x.size(2), device=x.device)
            for _ in range(self.num_steps):
                latent, h0 = self.recurrent(latent, h0)
            latent = latent.view(batch_size, x.size(0), x.size(2)).transpose(0, 1)
        else:
            h0 = torch.zeros(1, latent.size(0), x.size(2), device=x.device)
            c0 = torch.zeros(1, latent.size(0), x.size(2), device=x.device)  # Add cell state
            for _ in range(self.num_steps):
                latent, (h0, c0) = self.recurrent(latent, (h0, c0))  # LSTM outputs hidden + cell
            latent = latent.view(x.size(1), x.size(0), -1).transpose(0, 1)
            # latent = residual + latent
        ff_output = self.ff(latent)
        output = self.norm2(latent + self.dropout2(ff_output))
        return output

class LatentRecurrentDepthModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, sequence_len, d_ff, 
                 n1_prelude, n2_recurrent, n3_coda, num_steps=3, dropout=0.1, use_lstm=True):
        """
        Args:
            vocab_size (int): Size of the vocabulary (for embedding and projection).
            d_model (int): Transformer hidden size.
            nhead (int): Number of attention heads.
            d_ff (int): Feedforward hidden size.
            n1_prelude, n2_recurrent, n3_coda (int): Number of blocks per stage.
            num_steps (int): Recurrent steps per LRD block.
            dropout (float): Dropout rate.
        """
        super(LatentRecurrentDepthModel, self).__init__()

        self.sequence_len = sequence_len  # for generate
        self.d_model = d_model
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncodingGrok(d_model, max_len=sequence_len)
        # self.pos_encoding = PositionalEncoding(d_model) # , max_len=sequence_len)
        
        # Prelude blocks
        self.prelude = nn.ModuleList([
            TransformerBlock(d_model, nhead, d_ff, dropout)
            for _ in range(n1_prelude)
        ])
        
        # Latent Recurrent blocks
        if n2_recurrent > 0:
            self.recurrent = nn.ModuleList([
                LatentRecurrentBlock(d_model, nhead, d_ff, num_steps, dropout, use_lstm)
                for _ in range(n2_recurrent)
            ])
        else:
            self.recurrent = None
        
        # Coda blocks
        self.coda = nn.ModuleList([
            TransformerBlock(d_model, nhead, d_ff, dropout)
            for _ in range(n3_coda)
        ])
        
        # Final projection layer (e.g., to vocab size for generation)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids, attn_mask=None, key_padding_mask=None):
        """
        Args:
            input_ids (torch.Tensor): Token IDs [batch_size, seq_len].
            attn_mask (torch.Tensor, optional): Attention mask [seq_len, seq_len].
            key_padding_mask (torch.Tensor, optional): Padding mask [batch_size, seq_len].
        Returns:
            torch.Tensor: Output logits [batch_size, seq_len, vocab_size].
        """
        # Embed input tokens
        x = self.embedding(input_ids) * math.sqrt(self.d_model) /2.0  # [batch_size, seq_len, d_model]
        # x = self.pos_encoding(x)
        x = x.transpose(0, 1)  # [seq_len, batch_size, d_model] for transformer
        x = self.pos_encoding(x)
        
        # Prelude: Entry to latent space
        for block in self.prelude:
            x = block(x, attn_mask, key_padding_mask)
        
        # Recurrent: Refine latents
        if self.recurrent is not None:
            for block in self.recurrent:
                x = block(x, attn_mask, key_padding_mask)
        
        # Coda: Exit from latent space
        for block in self.coda:
            x = block(x, attn_mask, key_padding_mask)
        
        # Project to output space
        x = x.transpose(0, 1)  # [batch_size, seq_len, d_model]
        output = self.proj(x)  # [batch_size, seq_len, vocab_size]
        return output

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Generate new tokens given a context

        Note: for apple MPS, top_k is limited max 16 vor older torchs! ((01/2023) implementation limitation)
        See: https://github.com/pytorch/pytorch/issues/78915
        Solved in: https://github.com/pytorch/pytorch/pull/94639 (03/2023)

        :param idx: the context (B,T) tensor of indices
        :param max_new_tokens: the maximum number of tokens to generate
        :param temperature: the temperature to use for sampling
        :param top_k: the number of top tokens to consider
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last sequence_len tokens
            idx_cond = idx[:, -self.sequence_len :]
            # print(idx_cond.shape)
            # get the predictions
            logits = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply temperature
            if temperature != 1.0 and temperature > 0.0:
                logits = logits / temperature
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

    def generate_from_prompt(self, model, tokenizer, prompt="The", max_len=50):
        model.eval()
        input_ids = torch.tensor([tokenizer.encode(prompt)]).to(device)
        with torch.no_grad():
            for _ in range(max_len):
                logits = model(input_ids)
                next_token = torch.argmax(logits[:, -1, :], dim=-1)
                input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=1)
        return tokenizer.decode(input_ids[0].tolist())

    def generate_with_beam(self, model, tokenizer, prompt="The", max_len=50, temperature=1.0, top_k=30, beam_width=3):
    # def generate(model, tokenizer, prompt="The", max_len=50, temperature=1.0, top_k=30, beam_width=3):
        """
        Beam search generation with static abort condition.
        
        Args:
            model: LatentRecurrentDepthModel
            tokenizer: Your custom/botok tokenizer (no [EOS])
            prompt (str): Starting text
            max_len (int): Max output length
            temperature (float): Softmax temperature
            top_k (int): Sample from top k tokens
            beam_width (int): Number of beams
        """
        model.eval()
        device = next(model.parameters()).device
        input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)  # [1, seq_len]
        beams = [(input_ids, 0.0)]  # (sequence, log_prob)
        
        with torch.no_grad():
            for step in range(max_len):
                new_beams = []
                for seq, score in beams:
                    # Forward pass
                    logits = model(seq)  # [1, seq_len, vocab_size]
                    next_logits = logits[0, -1, :] / temperature
                    
                    # Top-k sampling
                    top_k_logits, top_k_indices = torch.topk(next_logits, top_k)
                    probs = F.softmax(top_k_logits, dim=-1)
                    
                    # Sample beam_width candidates
                    next_tokens = torch.multinomial(probs, num_samples=beam_width)
                    for i in range(beam_width):
                        token_id = top_k_indices[next_tokens[i]].unsqueeze(0).unsqueeze(0)  # [1, 1]
                        log_prob = torch.log(probs[next_tokens[i]]).item()
                        new_seq = torch.cat([seq, token_id], dim=1)
                        # Repetition penalty
                        penalty = 1.0 if new_seq[0, -1].item() not in new_seq[0, -5:-1] else 0.9
                        new_beams.append((new_seq, score + log_prob * penalty))
                
                # Sort and prune beams
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
                
                # Static abort: all beams at max_len or repeating last 5 tokens
                all_max_len = all(len(seq[0]) >= max_len for seq, _ in beams)
                all_repeating = all(
                    len(seq[0]) > 5 and seq[0, -5:].tolist() == [seq[0, -1].item()] * 5 
                    for seq, _ in beams
                )
                if all_max_len or all_repeating:
                    break
        
        best_seq, _ = beams[0]
        return tokenizer.decode(best_seq[0].tolist())

In [23]:
def init_weights(m):
    if isinstance(m, (nn.Linear, nn.GRU, nn.LSTM)):
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param, gain=1.0)
            elif 'bias' in name:
                nn.init.zeros_(param)

In [24]:
print("creating model...")
try:
    # Colab + torch 2 -> lots of garbage.
    if model is not None:
        del model
except:
    pass

groks = True
special_init = True

if groks is False:
    model = MultiHeadSelfAttentionWithMemory(vocab_size=params['vocab_size'], embedding_size=params['embedding_size'],
                                           sequence_len=params['sequence_len'],
                                           num_heads=params['heads'], num_layers=params['mhsa_layers'], dropout=params['dropout'],
                                           use_recur=params['use_recur'], share_weights=params['share_weights'], device=device)
else:
    model = LatentRecurrentDepthModel(
        vocab_size=params['vocab_size'],
        d_model=params['embedding_size'], nhead=params['heads'], d_ff=params['embedding_size']*4,
        sequence_len=params['sequence_len'],
        n1_prelude=3, n2_recurrent=1, n3_coda=2, num_steps=3
    )
    if special_init is True:
        model.apply(init_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=params['learning_rate'])

model = model.to(device)
if use_existing_model_from_checkpoint is True:
    params_load = MJ.load_checkpoint(params, model, optimizer, file_path=model_file_path, updatable_keys=updatable_keys, device=device, log=log) # torch.device("cpu"))
    if params_load is not None:
        params = params_load
model = model.to(device)
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

if use_torch_compile is True:
    if device == 'cuda':
        print("Compiling...")
        model = torch.compile(model)
        print("Compile ok.")
        try:
            torch.set_float32_matmul_precision('high')
        except:
            print("Seems no tensor cores for that.")
    # elif str(device) == 'mps':
    #     print("Compiling...")
    #     model = torch.compile(model)
    #     print("Compile ok.")

if 'current_epoch' in params:
    ep = params['current_epoch']
else:
    ep=0
if 'current_loss' in params:
    ls = params['current_loss']
else:
    ls=0

if ep==0 and ls==0:
    start_iter = 0
else:
    start_iter = ep
    current_loss = ls

# print the number of parameters in the model
print(model)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")

creating model...
Compiling...
Compile ok.
OptimizedModule(
  (_orig_mod): LatentRecurrentDepthModel(
    (embedding): Embedding(32768, 512)
    (pos_encoding): PositionalEncodingGrok()
    (prelude): ModuleList(
      (0-2): 3 x TransformerBlock(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (ff): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=2048, out_features=512, bias=True)
        )
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (recurrent): ModuleList(
      (0): LatentRecurrentBlock(
        (self_attn): MultiheadAttention(
 

In [25]:
        
class MultiHeadSelfAttentionWithMemory(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_size,
        sequence_len,
        num_heads,
        num_layers,
        dropout=0.1,
        use_recur=False,
        share_weights=True,
        device=None,
    ):
        super().__init__()
        if device is None:
            raise ValueError(
                "Device is None at MultiHeadSelfAttentionWithMemory"
            )
        self.device = device
        self.sequence_len = sequence_len
        self.use_recur = use_recur
        context_sub_layers = num_layers // 2
        self.context_sub_layers = context_sub_layers
        dims = embedding_size
        self.dims = dims
        # each token directly reads off the logits for the next token from a lookup table
        self.position_embedding_table = nn.Embedding(
            sequence_len, embedding_size, device=device
        )

        if share_weights is True:
            self.embedding = nn.Embedding(vocab_size, embedding_size)
            self.out_proj = nn.Linear(embedding_size, vocab_size, bias=False)
            self.embedding.weight = self.out_proj.weight  # torch.nn.Parameter(self.out_proj.weight.transpose(1,0))
            self.out_proj = self.out_proj.to(device)
            self.embdding = self.embedding.to(device)
        else:
            self.embedding = nn.Embedding(vocab_size, embedding_size, device=device)
            self.out_proj = nn.Linear(embedding_size, vocab_size, device=device)


        if use_recur is True:
            self.rec=nn.RNN(dims, dims, num_layers=3, batch_first=True, device=device)
            self.lq=nn.Linear(dims, dims, device=device)
            self.lk=nn.Linear(dims, dims, device=device)
            self.lv=nn.Linear(dims, dims, device=device)
            self.lnorm=nn.BatchNorm1d(dims, device=device)
            self.sm = nn.Softmax(dim=2)
            self.proj1 = nn.Linear(dims, dims, device=device)
            self.sm2 = nn.Softmax(dim=2)
            encoder_layer = nn.TransformerEncoderLayer(d_model=dims, nhead=num_heads, dim_feedforward=dims*4, dropout=dropout, batch_first=True, device=device)
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=context_sub_layers)
            encoder_layer2 = nn.TransformerEncoderLayer(d_model=dims, nhead=num_heads, dim_feedforward=dims*4, dropout=dropout, batch_first=True, device=device)
            self.transformer2 = nn.TransformerEncoder(encoder_layer2, num_layers=num_layers - context_sub_layers)
        else:
            encoder_layer = nn.TransformerEncoderLayer(d_model=dims, nhead=num_heads, activation="gelu", dim_feedforward=dims*4, dropout=dropout, batch_first=True, device=device)
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.out_proj.bias.data.zero_()
        self.out_proj.weight.data.uniform_(-initrange, initrange)

    def forward(self, idx):
        B, D = idx.shape
        tok_emb = self.embedding(idx)
        # idx and targets are both (B,D) tensor of integers
        # tok_emb = self.token_embedding_table(idx)  # (B,D,C)

        # XXX: move to init, make not trainable:
        if self.device is None:
            pos_emb = self.position_embedding_table(torch.arange(self.sequence_len))
        else:
            pos_emb = self.position_embedding_table(
                torch.arange(D, device=self.device)
            )  # (D,C)

        x = tok_emb + pos_emb  # (B,D,C)
        
        # x = self.pos_encoder(x) 
        x_mask = nn.Transformer.generate_square_subsequent_mask(D).to(self.device)
        x = self.transformer(x, x_mask)

        if self.use_recur is True:
            skip = x
            x = self.rec(x)[0] + x
            xk = self.lk(x).permute((0,2,1))
            xv = self.lv(x)
            xq = self.lq(x)
            xqk = torch.matmul(xq, xk)
            sm = self.sm(xqk)/math.sqrt(D)
            att = torch.matmul(sm, xv)
            x = self.lnorm(att)
            x = self.sm2(self.proj1(x))
            x = x + skip    
            x = self.transformer2(x, x_mask)
            
        logits = self.out_proj(x)
        return logits

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Generate new tokens given a context

        Note: for apple MPS, top_k is limited max 16 vor older torchs! ((01/2023) implementation limitation)
        See: https://github.com/pytorch/pytorch/issues/78915
        Solved in: https://github.com/pytorch/pytorch/pull/94639 (03/2023)

        :param idx: the context (B,T) tensor of indices
        :param max_new_tokens: the maximum number of tokens to generate
        :param temperature: the temperature to use for sampling
        :param top_k: the number of top tokens to consider
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last sequence_len tokens
            idx_cond = idx[:, -self.sequence_len :]
            # print(idx_cond.shape)
            # get the predictions
            logits = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply temperature
            if temperature != 1.0 and temperature > 0.0:
                logits = logits / temperature
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

In [26]:
@torch.no_grad()
def estimate_loss(device):
    # XXX: this does take data for train and val from SAME pool!
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(params['test_iterations'])
        for k in range(params['test_iterations']):
            # if k % (params['test_iterations']/10 + 1) == 0:
            #     print(".", end="", flush=True)
            X, Y = get_torch_batch(td, params['batch_size'], device, split)
            logits = model(X)
            loss = get_loss(logits, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    print("\r", end="", flush=True)
    mloss = (out['train']+out['val'])/2.0
    return mloss

def generate_sample(td, device, prompt=' ', toks=100, state=None, temperature=1.0, top_k=None, pad=True, with_beam=True):
    if with_beam is True:
        txt = model.generate_with_beam(model,td,prompt,toks, temperature=temperature, top_k=top_k, beam_width=5)
    else:
        model.eval()
        if pad is True:
            while len(prompt)<params['sequence_len']*4:
                if len(prompt)==params['sequence_len']*4-1:
                    prompt = '\n' + prompt
                else:
                    prompt = ' ' + prompt
        context = torch.tensor([td.encode(prompt)]).to(device)
        answer = model.generate(context, max_new_tokens=toks, temperature=temperature, top_k=top_k)
        txt = td.decode(answer[0].tolist())
    # Identify memorisation of text by highlighting verbatim quotes from sources
    # that are longer than 10 chars. HTML colorcoded output for source identification:
    td.source_highlight(txt, min_quote_size=10, dark_mode=False, display_ref_anchor=False)
    model.train()
    return txt


In [27]:
# @torch.jit.script
# @torch.compile
criterion = nn.CrossEntropyLoss()

def get_loss(logits, yb):
    output_flat = logits.reshape(-1, params['vocab_size'])
    # output_flat = logits.view(-1, params['vocab_size'])
    # print(output_flat.shape)
    ybr = yb.reshape(-1)
    # print(ybr.shape)
    loss = criterion(output_flat, ybr)
    return loss
    
def do_train_step(xb, yb, device, state=None):
    model.train()
    logits = model(xb)
    loss = get_loss(logits, yb)
    
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), params['grad_clip']).cpu()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item(), norm

In [28]:
def start_tu_session():
    if indra_avail is True:
        with open('indra_creds.json', 'r') as f:
            creds = json.load(f)
            tu = TrainUtils(indra_server_profile_name='default', username=creds['username'], password=creds['password'])
    else:
        tu = TrainUtils()
    tu.train_session_start(model_name=model_name, model_description="Torch-poet tests", model_version=1, model_params=params, indra_subdomain="torch_poet/first_tests/1", status_string_size=110)
    return tu

In [29]:
def lr_schedule(optim, n_iter, warmup, max_lr, decay, min_lr):
    if n_iter<warmup and warmup>0:
        lr = (n_iter+1)/warmup*max_lr
    elif n_iter<warmup+decay and decay>0:
        i = n_iter-warmup
        lr = (decay-i)/decay*(max_lr-min_lr)+min_lr
    else:
        lr = min_lr

    for g in optim.param_groups:
        g['lr'] = lr
    return lr


In [30]:
def train(train_utils):
    global start_iter
    dt0 = time.time()
    sdt = datetime.datetime.now(tz=local_timezone).strftime("%Y-%m-%d %H:%M:%S")
    print(f"training, start at {sdt}...")
    gen_id = 0
    last_print=0
    iter_bench = 1
    tu = train_utils
    lr = params['learning_rate']
    # current_loss = estimate_loss(device)
    inputs = ["What is the difference between good and evil? The difference ", "How did everything come into existence? The origin ", "What was at the beginning of time? Time itself ", "How are physics, quantum-mechanics and consciousness related? The relation between ", "How to attain complete self-awareness? Complete ", "What is the nature of reality? The nature ", "How be a good human being? A human "]
    for iter in range(start_iter, params['max_iterations']):
        # every once in a while evaluate the loss on train and val sets
        if (iter + 1) % params['sample_every_n_iterations'] == 0 or iter == params['max_iterations'] - 1:
            dt = time.time()
            print(f"\rloss eval", end="", flush=True)
            current_loss = estimate_loss(device)
            print(
                f"step {iter+1}: train loss {current_loss:.4f}, time {(dt-dt0)/iter_bench:.3f} sec/iter                       "
            )
            iter_bench = 1
            sdt = datetime.datetime.now(tz=local_timezone).strftime("%Y-%m-%d %H:%M:%S")
            print(f"Sample at {sdt}:", flush=True)
            for temperature in [0.75, 1.0, 1.1]:
                print(f"--------temperature: {temperature} ---------")
                prompt = inputs[gen_id%len(inputs)]
                # print(f"Prompt: {prompt}")
                # generate_sample(td=td, device=device, prompt=prompt, toks=params['sample_size'], temperature=temperature, top_k=10, with_beam=False)
                print(f"Prompt: {prompt}")
                generate_sample(td=td, device=device, prompt=prompt, toks=params['sample_size'], temperature=temperature, top_k=10, with_beam=True)
            print("-------------------------------------------")
            gen_id += 1
            dt0 = time.time()
    
        if params['lr_schedule'] is True:
            lr = lr_schedule(optimizer, iter, params['warmup'], params['lr_max'], params['decay'], params['lr_min'])
    
        xb, yb = get_torch_batch(td, params['batch_size'], device, "train")
        cur_loss, cur_norm = do_train_step(xb, yb, device=device)

        
        nt = time.time()
        if (nt-last_print)>1:
            rec = {
                'epoch': iter/num_batches,
                'batch': iter%params['sample_every_n_iterations'],
                'num_batches': params['sample_every_n_iterations'],
                'loss': cur_loss,
                'learning_rate': lr,
                'gradient_norm': cur_norm.item(),
            }
            status_string, record = train_utils.train_state(rec)
            print(status_string, end="\r")
            last_print=nt
        
        start_iter = iter
        iter_bench += 1
        if (iter+1)%params['save_every_n_iterations'] == 0:
            MJ.save_checkpoint(params, model, optimizer, iter, current_loss, file_path=model_file_path, log=log)

In [None]:
tu = start_tu_session()
try:
    train(train_utils = tu)
except KeyboardInterrupt:
    print(f"\nTraining interrupted.")
tu.train_session_end()

training, start at 2025-02-28 16:33:37...
step 1024: train loss 6.9965, time 0.145 sec/iter                       00051 grad_norm: 1.750 Sec/It: 0.143  
Sample at 2025-02-28 16:36:16:
--------temperature: 0.75 ---------
Prompt: What is the difference between good and evil? The difference 


--------temperature: 1.0 ---------
Prompt: What is the difference between good and evil? The difference 


--------temperature: 1.1 ---------
Prompt: What is the difference between good and evil? The difference 


-------------------------------------------
step 2048: train loss 6.4830, time 0.142 sec/iter                       00102 grad_norm: 2.428 Sec/It: 0.143  
Sample at 2025-02-28 16:39:04:
--------temperature: 0.75 ---------
Prompt: How did everything come into existence? The origin 


--------temperature: 1.0 ---------
Prompt: How did everything come into existence? The origin 


--------temperature: 1.1 ---------
Prompt: How did everything come into existence? The origin 


-------------------------------------------
step 3072: train loss 6.2291, time 0.142 sec/iter                       00153 grad_norm: 2.195 Sec/It: 0.143  
Sample at 2025-02-28 16:41:45:
--------temperature: 0.75 ---------
Prompt: What was at the beginning of time? Time itself 


--------temperature: 1.0 ---------
Prompt: What was at the beginning of time? Time itself 


--------temperature: 1.1 ---------
Prompt: What was at the beginning of time? Time itself 


-------------------------------------------
step 4096: train loss 5.9733, time 0.142 sec/iter                       00200 grad_norm: 4.342 Sec/It: 0.143  
Sample at 2025-02-28 16:44:27:
--------temperature: 0.75 ---------
Prompt: How are physics, quantum-mechanics and consciousness related? The relation between 


--------temperature: 1.0 ---------
Prompt: How are physics, quantum-mechanics and consciousness related? The relation between 


--------temperature: 1.1 ---------
Prompt: How are physics, quantum-mechanics and consciousness related? The relation between 


-------------------------------------------
step 5120: train loss 5.7695, time 0.143 sec/iter                       00198 grad_norm: 5.157 Sec/It: 0.143  
Sample at 2025-02-28 16:47:11:
--------temperature: 0.75 ---------
Prompt: How to attain complete self-awareness? Complete 


--------temperature: 1.0 ---------
Prompt: How to attain complete self-awareness? Complete 


--------temperature: 1.1 ---------
Prompt: How to attain complete self-awareness? Complete 


-------------------------------------------
step 6144: train loss 5.6477, time 0.142 sec/iter                       00196 grad_norm: 4.871 Sec/It: 0.143  
Sample at 2025-02-28 16:49:54:
--------temperature: 0.75 ---------
Prompt: What is the nature of reality? The nature 


--------temperature: 1.0 ---------
Prompt: What is the nature of reality? The nature 


--------temperature: 1.1 ---------
Prompt: What is the nature of reality? The nature 


-------------------------------------------
step 7168: train loss 5.5089, time 0.142 sec/iter                       00194 grad_norm: 4.024 Sec/It: 0.142  
Sample at 2025-02-28 16:52:36:
--------temperature: 0.75 ---------
Prompt: How be a good human being? A human 


--------temperature: 1.0 ---------
Prompt: How be a good human being? A human 


--------temperature: 1.1 ---------
Prompt: How be a good human being? A human 


-------------------------------------------
step 8192: train loss 5.4827, time 0.142 sec/iter                       00192 grad_norm: 4.934 Sec/It: 0.143  
Sample at 2025-02-28 16:55:29:
--------temperature: 0.75 ---------
Prompt: What is the difference between good and evil? The difference 


--------temperature: 1.0 ---------
Prompt: What is the difference between good and evil? The difference 


--------temperature: 1.1 ---------
Prompt: What is the difference between good and evil? The difference 


-------------------------------------------
step 9216: train loss 5.3657, time 0.143 sec/iter                       00191 grad_norm: 3.284 Sec/It: 0.143  
Sample at 2025-02-28 16:58:17:
--------temperature: 0.75 ---------
Prompt: How did everything come into existence? The origin 


--------temperature: 1.0 ---------
Prompt: How did everything come into existence? The origin 


--------temperature: 1.1 ---------
Prompt: How did everything come into existence? The origin 


-------------------------------------------
step 10240: train loss 5.2779, time 0.142 sec/iter                       0189 grad_norm: 3.103 Sec/It: 0.143  
Sample at 2025-02-28 17:01:05:
--------temperature: 0.75 ---------
Prompt: What was at the beginning of time? Time itself 


--------temperature: 1.0 ---------
Prompt: What was at the beginning of time? Time itself 


--------temperature: 1.1 ---------
Prompt: What was at the beginning of time? Time itself 


-------------------------------------------
step 11264: train loss 5.2029, time 0.142 sec/iter                       0187 grad_norm: 3.922 Sec/It: 0.142  
Sample at 2025-02-28 17:03:52:
--------temperature: 0.75 ---------
Prompt: How are physics, quantum-mechanics and consciousness related? The relation between 


--------temperature: 1.0 ---------
Prompt: How are physics, quantum-mechanics and consciousness related? The relation between 


--------temperature: 1.1 ---------
Prompt: How are physics, quantum-mechanics and consciousness related? The relation between 


-------------------------------------------
step 12288: train loss 5.1483, time 0.142 sec/iter                       0185 grad_norm: 4.009 Sec/It: 0.143  
Sample at 2025-02-28 17:06:34:
--------temperature: 0.75 ---------
Prompt: How to attain complete self-awareness? Complete 


--------temperature: 1.0 ---------
Prompt: How to attain complete self-awareness? Complete 


--------temperature: 1.1 ---------
Prompt: How to attain complete self-awareness? Complete 


-------------------------------------------
step 13312: train loss 5.1006, time 0.143 sec/iter                       0183 grad_norm: 3.168 Sec/It: 0.142  
Sample at 2025-02-28 17:09:22:
--------temperature: 0.75 ---------
Prompt: What is the nature of reality? The nature 


--------temperature: 1.0 ---------
Prompt: What is the nature of reality? The nature 


--------temperature: 1.1 ---------
Prompt: What is the nature of reality? The nature 


-------------------------------------------
Ep: 0.05 Bat: 759/1024       ⦊██████████████▉     ⦉ loss: 5.0482 lr: 0.000182 grad_norm: 5.996 Sec/It: 0.143  

In [None]:
# for t in [0.5, 1.5]:
#     print(f"------Temperature {t}--------")
#     generate_sample(td, device, prompt="How are consciousness and quantum mechanics related?", toks=150, temperature=t, top_k=16)

In [None]:
# Test-code below, unfinished.

In [None]:
texts = []
enc_texts = []
for i in range(500):
    e = td[i*50000][:256]
    tx = torch.tensor([e]).to(device)
    enc_texts.append(tx)
    texts.append(td.decode(e))


In [None]:
enc_texts[0].shape

In [None]:
emb_text = []
cont_text = []
for et in enc_texts:
    emb_text.append(model.embedding(et))
    cont_text.append(model.context(et))

In [None]:
emb_text[0].shape, cont_text[0].shape

In [None]:
emb_vec = []
cont_vec = []
for i in range(len(emb_text)):
    emb_vec.append(emb_text[i][0].sum(axis=0))
    cont_vec.append(cont_text[i][0].sum(axis=0))

In [None]:
def cos_vec(a, b):
    al = torch.sqrt(torch.dot(a,a))
    bl = torch.sqrt(torch.dot(b,b))
    an = a/al
    bn = b/bl
    return torch.dot(an,bn)

In [None]:
for i in range(0):
    best = -10.0
    ind = -1
    for j in range(len(emb_vec)):
        if i==j:
            continue
        cos_val = cos_vec(emb_vec[i], emb_vec[j])
        if cos_val > best:
            best = cos_val
            ind = j
    # print(f"{texts[i][:20]} ->{best}: {texts[ind][:20]}")
print()        
for i in range(200):
    best = 0
    ind = -1
    for j in range(len(emb_vec)):
        if i==j:
            continue
        cos_val = cos_vec(cont_vec[i], cont_vec[j])
        if cos_val > best:
            best = cos_val
            ind = j
    print(f"{texts[i]} \n\n->{best}:\n\n {texts[ind]}")
    print("---------------------------------------------")


In [None]:
td[200002]