<a href="https://colab.research.google.com/github/domschl/torch-transformer-poet/blob/main/torch_transformer_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Torch-Transformer-Poet

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that supports writing more environment-independent code. It will access your Google Drive when used with Google Colab.

In [1]:
!pip install -U ml-indie-tools



In [2]:
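import os
import sys
if 'google.colab' in sys.modules:
    # from: https://github.com/pytorch/pytorch/issues/107960  (libcuda not found)
    # `!export ...` runs in a throw-away subshell and does not change the kernel's environment,
    # so the variables are set in the kernel process itself:
    os.environ["LC_ALL"] = "en_US.UTF-8"
    os.environ["LD_LIBRARY_PATH"] = "/usr/lib64-nvidia"
    os.environ["LIBRARY_PATH"] = "/usr/local/cuda/lib64/stubs"
    !ldconfig /usr/lib64-nvidia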

In [3]:
import logging
import os
import copy
import json
import time
import datetime
import random
import math
import numpy as np
from zoneinfo import ZoneInfo

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

In [4]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

from ml_indie_tools.Calibre_Dataset import Calibre_Dataset
from ml_indie_tools.Folder_Dataset import Folder_Dataset

import ml_indie_tools.pytorch_meta_tools as MJ

In [5]:
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("Main")  # getLogger() attaches to the root logger, so basicConfig's handler prints INFO messages
log.setLevel(logging.INFO)

## Preliminary

A PyTorch deep multi-head self-attention model for text generation, following Andrej Karpathy's [ng-video-lecture](https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py).

This code can run on CPU, NVIDIA GPUs (CUDA), or Apple Silicon (MPS). Google Colab is supported too; select the corresponding Colab runtime (menu: **`Runtime / Change runtime type`**).

## 0. Environment

In [6]:
cached_batch_data = None   # Cache for time-consuming training data, to avoid regenerating it if already cached.

ml_env = MLEnv(platform='pt', accelerator='fastest')
ml_env.describe()

'OS: Linux, Python: 3.12.3, Jupyter Notebook Pytorch: 2.4.0.dev20240501+cu121, GPU: NVIDIA GeForce RTX 4070 Ti (/  285W |       4MiB), CPU'

## 1. Project configuration

In [7]:
# project_name = 'women_writers'
# project_name='research'
project_name='neo_philosophers'
model_cpu = None
model_name=f'tr_{project_name}_v3_pt'

use_preprocessed_data = True                      # Use already tokenized data
use_existing_model_from_checkpoint = False         # Try to load checkpoint of training
use_torch_compile = False                           # Requires a modern graphics card with torch compile backend support
skip_additional_texts = True                       # Don't look for other data sources in `additional_texts.json`

if 'google.colab' in sys.modules:  # Google Colab servers use UTC; convert log timestamps to local time:
    local_timezone = ZoneInfo('Europe/Berlin')
else:
    local_timezone = None

# NOTICE: If running on Google Colab, this will request access to Google Drive. Google Drive is used to store snapshots
# and training data. See project ml-indie-tools: https://github.com/domschl/ml-indie-tools
#
# Note: you need to allow pop-ups in your browser for Colab, otherwise you won't see the Google Drive login box, and Drive access will fail!

root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device("mps") if torch.backends.mps.is_available() else device

print(f"Root path (all projects) : {root_path} (This will be '.' (current dir) for local projects, and a google drive path for Colab)")
print(f"Project path             : {project_path} (Changes to the file system happen only below this project path)")
print(f"Model path (snapshots)   : {model_path} (Model weights and snapshots are stored here)")
print(f"Data path (training data): {data_path} (Training data will be downloaded here)")
print(f"Log dir (tensorboard)    : {log_path} (it doesn't work to put logs on gdrive due to caching, hence local dir)")

Root path (all projects) : . (This will be '.' (current dir) for local projects, and a google drive path for Colab)
Project path             : . (Changes to the file system happen only below this project path)
Model path (snapshots)   : ./model/tr_neo_philosophers_v3_pt (Model weights and snapshots are stored here)
Data path (training data): ./data (Training data will be downloaded here)
Log dir (tensorboard)    : ./logs (it doesn't work to put logs on gdrive due to caching, hence local dir)


## 2.1 Text data from Project Gutenberg

The `Text_Dataset` and `Gutenberg_Dataset` classes provide helpers for training
data preparation, encoding, batch generation, and formatted source display.
`Gutenberg_Dataset` reads books from Project Gutenberg, and `Text_Dataset`
supports the creation of training batches. The output functions support
highlighting, so generated texts can be compared with the actual sources to
identify verbatim (memorized) passages.
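The highlighting itself is done by `td.source_highlight`. As a rough illustration of the idea only (not the library's actual algorithm), the sketch below finds every maximal span of a generated text, at least `min_len` characters long, that appears verbatim in one of the sources. `find_verbatim_quotes` is a hypothetical helper, and the naive substring search is only practical for short texts.

```python
def find_verbatim_quotes(generated, sources, min_len=10):
    """Return (start, end, source_index) for maximal spans of `generated`
    that occur verbatim in one of `sources` and are at least `min_len` chars long.
    Naive substring search: fine for short samples, slow for large corpora."""
    quotes = []
    i = 0
    n = len(generated)
    while i <= n - min_len:
        best_end, best_src = -1, -1
        for s_idx, src in enumerate(sources):
            end = i + min_len
            if generated[i:end] not in src:
                continue
            # Grow the span as long as it is still found verbatim in this source
            while end < n and generated[i:end + 1] in src:
                end += 1
            if end > best_end:
                best_end, best_src = end, s_idx
        if best_src >= 0:
            quotes.append((i, best_end, best_src))
            i = best_end  # continue scanning after the quoted span
        else:
            i += 1
    return quotes

sources = ["the starry heavens above me and the moral law within me"]
generated = "I admire the moral law within me, said he."
print(find_verbatim_quotes(generated, sources))  # [(8, 32, 0)]
```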

In [8]:
use_dark_mode=False # Set to False for a white background. The HTML text comparison uses background colors to identify different sources; those colors depend on the theme type.

In [9]:
token_file = os.path.join(data_path,f"{project_name}_tokens.json")
if use_preprocessed_data is True:
    if os.path.exists(token_file):
        td = Text_Dataset()
        td.load_tokenizer(token_file)
    else:
        use_preprocessed_data = False

INFO:Datasets:Loading tokenizer from ./data/neo_philosophers_tokens.json
INFO:Datasets:Loading tokenizer done.
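The cell above follows a simple load-if-cached pattern: an expensive artifact (here the tokenizer) is built once, saved as JSON, and reloaded on later runs. A generic stdlib sketch of that pattern follows; `load_or_build` and `build_fn` are hypothetical names, not part of ml-indie-tools.

```python
import json
import os
import tempfile

def load_or_build(path, build_fn):
    """Load a JSON-serializable artifact from `path`, or build it with `build_fn()`
    and cache it there if the file does not exist yet.
    Returns (artifact, loaded_from_cache)."""
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f), True
    artifact = build_fn()
    with open(path, "w", encoding="utf-8") as f:
        json.dump(artifact, f)
    return artifact, False

with tempfile.TemporaryDirectory() as tmp:
    cache_file = os.path.join(tmp, "tokens.json")
    first, cached1 = load_or_build(cache_file, lambda: {"vocab": ["a", "b"]})
    # The second build function is never called, because the cache file now exists
    second, cached2 = load_or_build(cache_file, lambda: {"vocab": []})
print(cached1, cached2, second)  # False True {'vocab': ['a', 'b']}
```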


In [10]:
if use_preprocessed_data is False:
    cache_dir = os.path.join(data_path, 'gutenberg_cache')
    gd = Gutenberg_Dataset(cache_dir=cache_dir)

    if project_name == 'women_writers':  # sample searches
        search_spec= {
            "author": ["Emily Brontë", "Jane Austen", "Virginia Woolf"],
            "language": ["english"]
        }
        book_list=gd.search(search_spec)
    elif project_name == 'neo_philosophers':
        search_spec = {
            "author": ["Immanuel Kant", "Friedrich Nietzsche", "Wilhelm Hegel", "Arthur Schopenhauer"],
            "language": ["english", "german"]
        }
        book_list=gd.search(search_spec)
        search_spec = {
            "author": ["Plato", "Platon"],
            "title": ["Timaeus", "Critias", "Symposium"],
            "language": ["english", "german"]
        }
        book_list+=gd.search(search_spec)
        search_spec = {
            "title": ["Buddh", "Sutra"],
            "language": ["english", "german"]
        }
        book_list+=gd.search(search_spec)
    else:
        search_spec = {}
        book_list = []

    book_cnt = len(book_list)
    print(f"{book_cnt} matching books found with search {search_spec}.")

    if book_cnt > 0:
        if book_cnt<80:
            # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts,
            # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
            book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)
        else:
            logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")

        for i in range(len(book_list)):
            if 'author' not in book_list[i]:
                book_list[i]['author']='unknown'
            print(f"{i}: {book_list[i]['title']} - {book_list[i]['author']}, {book_list[i]['ebook_id']}")

        if project_name == 'women_writers':
            select = ("Bennett", "1342", "5670", "1245", "161", "141", "121", "105", "Susan", "Wuthering", "Emma", "Voyage")  # List unique single-words from title or ebook_id to select a given book
            sub_book_list = [book_list[i] for i in range(len(book_list)) if not set([book_list[i]['ebook_id']]+book_list[i]['title'].split(' ')).isdisjoint(set(select))]
        else:
            sub_book_list = book_list

        print("Using:")
        for i in range(len(sub_book_list)):
            if 'author' not in sub_book_list[i]:
                sub_book_list[i]['author']='unknown'
            print(f"{i+1}: {sub_book_list[i]['title']} - {sub_book_list[i]['author']}")

        td = Text_Dataset(sub_book_list)
    else:
        td = Text_Dataset()
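For `women_writers`, a book is kept if its `ebook_id` or any single word of its title appears in `select`, tested with `set.isdisjoint`. A standalone version of that filter on a few illustrative records (`select_books` is a hypothetical helper; the records only mimic the `title`/`ebook_id` fields used above):

```python
def select_books(book_list, select):
    """Keep books whose ebook_id or any space-separated title word is in `select`."""
    wanted = set(select)
    return [
        book for book in book_list
        if not wanted.isdisjoint([book["ebook_id"]] + book["title"].split(" "))
    ]

books = [
    {"ebook_id": "1342", "title": "Pride and Prejudice"},
    {"ebook_id": "768", "title": "Wuthering Heights"},
    {"ebook_id": "999", "title": "Some Other Book"},
]
print([b["title"] for b in select_books(books, ("1342", "Wuthering"))])
# ['Pride and Prejudice', 'Wuthering Heights']
```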

## 2.2 Additional training material from folders or Calibre library

If `skip_additional_texts` is `False`, this looks for a file `additional_texts.json` in the `project_path` shown above.

```json
{
  "local_texts": ["/some/directory/that/contains/texts"],
  "calibre": "/home/myuser/Calibre Library"
}
```

If the folder(s) defined in `local_texts` contain text files with the default extensions `.txt`, `.md`, `.org`, or `.py` (configurable), they are added to the training data. Folders are searched recursively.

If the path defined in `calibre` contains a Calibre database, all text files (`.txt` only) within that library are added to the training data.
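`Folder_Dataset` does the actual scanning. The idea of a recursive search by file extension can be sketched with `pathlib`; `find_text_files` and its defaults are assumptions modeled on the description above and on the `max_file_size` argument used below, not the library's API.

```python
import tempfile
from pathlib import Path

def find_text_files(root, extensions=(".txt", ".md", ".org", ".py"), max_file_size=100_000):
    """Recursively collect files under `root` whose suffix is in `extensions`
    and whose size does not exceed `max_file_size` bytes."""
    return sorted(
        p for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in extensions and p.stat().st_size <= max_file_size
    )

with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "sub").mkdir()
    (Path(tmp) / "a.txt").write_text("hello")
    (Path(tmp) / "sub" / "b.md").write_text("# notes")
    (Path(tmp) / "sub" / "c.pdf").write_bytes(b"%PDF")  # wrong extension, skipped
    found = [p.relative_to(tmp).as_posix() for p in find_text_files(tmp)]
print(found)  # ['a.txt', 'sub/b.md']
```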

In [11]:
if use_preprocessed_data is False and skip_additional_texts is False:
    additional = os.path.join(project_path, "additional_texts.json")
    print(f"Looking for description of additional sources in {additional}")
    if os.path.exists(additional) is True:
        with open(additional, 'r') as f:
            add_desc = json.load(f)
            if 'local_texts' in add_desc:
                fd = Folder_Dataset()
                for text_path in add_desc['local_texts']:
                    print(f"Loading texts from {text_path}")
                    fd.load_index(text_path, use_aliases=False, max_file_size=100000)
                td.load_texts(fd.records[:10000])
            if 'calibre' in add_desc:
                cal_path = add_desc['calibre']
                if os.path.exists(cal_path):
                    print(f"Loading text from calibre at {cal_path}")
                    cd = Calibre_Dataset(cal_path)
                    cd.load_index(max_file_size=100000000)
                    td.load_texts(cd.records[:1000])

## 2.3 Tokenize data

In [12]:
if use_preprocessed_data is False:
    MAX_TOKENS = 10000  # Maximum number of unique tokens; this becomes vocab_size
    MAX_NGRAM_LEN = 4   # Maximum n-gram length of a single token
    CHUNK_SIZE = 500000 # Split larger texts into chunks (if not None)

    print("")
    print(f"Starting tokenizer with token length from 1..{MAX_NGRAM_LEN} with a max of {MAX_TOKENS} unique tokens,")
    print("this can take considerable time...")

    # Better tested NGRAM tokenizer:
    # td.init_tokenizer(tokenizer='ngram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS) 
    # or alternative 'BYTEGRAM' (more experimental, can encode arbitrary UTF-8)
    # td.init_tokenizer(tokenizer='bytegram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS, chunk_size=CHUNK_SIZE)
    td.init_tokenizer(tokenizer='bytegram', max_ngrams=MAX_NGRAM_LEN, max_tokens=MAX_TOKENS, chunk_size=CHUNK_SIZE)
    td.save_tokenizer(token_file)
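`init_tokenizer` builds a vocabulary of up to `MAX_TOKENS` entries from n-grams of length 1 to `MAX_NGRAM_LEN`. The toy sketch below illustrates the general idea only: count n-grams, keep the most frequent ones, always include single characters so any text over the training alphabet can be encoded, and encode greedily by longest match. The actual selection and encoding in ml-indie-tools may differ; `build_ngram_vocab` and `encode_greedy` are hypothetical.

```python
from collections import Counter

def build_ngram_vocab(text, max_ngrams=4, max_tokens=50):
    """Toy vocabulary: all single characters plus the most frequent
    n-grams of length 2..max_ngrams, up to max_tokens entries in total."""
    chars = sorted(set(text))
    counts = Counter(
        text[i:i + n]
        for n in range(2, max_ngrams + 1)
        for i in range(len(text) - n + 1)
    )
    budget = max(0, max_tokens - len(chars))
    return chars + [gram for gram, _ in counts.most_common(budget)]

def encode_greedy(text, vocab, max_ngrams=4):
    """Encode by always taking the longest vocabulary entry at the current position."""
    index = {tok: i for i, tok in enumerate(vocab)}
    ids, pos = [], 0
    while pos < len(text):
        for n in range(min(max_ngrams, len(text) - pos), 0, -1):
            piece = text[pos:pos + n]
            if piece in index:
                ids.append(index[piece])
                pos += n
                break
        else:
            raise ValueError(f"character {text[pos]!r} not in vocabulary")
    return ids

corpus = "the thing in itself is the thing"
vocab = build_ngram_vocab(corpus, max_ngrams=4, max_tokens=30)
ids = encode_greedy("the thing", vocab)
decoded = "".join(vocab[i] for i in ids)
print(len(ids), decoded)
```

Multi-character tokens mean fewer tokens per text, so a fixed `sequence_len` covers more characters than a pure character-level model would.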

## 3. Model metadata

In [13]:
params = None
updatable_keys=['learning_rate', 'batch_size', 'current_epoch', 'current_loss',
                 'sample_every_n_iterations', 'sample_size', 'save_every_n_iterations', 'max_iterations']
attn_layers = 8
dims = 256
sequence_length = 256

params = { # Multi-head self-attention
        'meta_name_template': '{mhsa_layers}x{heads}x{units}x{vocab_size}',

        'mhsa_layers': attn_layers,
        'heads': 8,
        'vocab_size': td.get_unique_token_count(),
        'sequence_len': sequence_length,
        'embedding_size': dims,
        'test_iterations': 20,  # number of batches per split used for loss estimation

        'batch_size': 64,      # A100: 80, V100: 32; None: set depending on the graphics card
        'learning_rate': 2e-4,   # None: set depending on the graphics hardware

        'sample_every_n_iterations': 256,
        'sample_size': 128,
        'save_every_n_iterations': 4096,

        'max_iterations': 100000000  # maximum number of training iterations
    }

model_file_path = MJ.get_model_filename(model_path)
if use_existing_model_from_checkpoint is True:
    params = MJ.load_model_metadata_from_checkpoint(params, updatable_keys, model_file_path, device=device, log=log) # torch.device('cpu'))
if params is None or use_existing_model_from_checkpoint is False:
    use_existing_model_from_checkpoint = False
# print(params)

## 4. Batch handling

In [14]:
joint_training=0
td.init_getitem(sample_type='encoded', sample_length=params['sequence_len']+1+joint_training, content_stepping=1)
num_records = len(td)
print(f"{num_records} records")

15392802 records


In [15]:
def get_sample_sub_batch(sample_batch, batch_size, sub_index=0):
    # Split ONE encoded sample into input X and next-token target y (shifted by one token),
    # starting at offset sub_index, and repeat it batch_size times.
    # Note: all rows of the returned batch are identical; get_sample_batch() draws independent rows.
    Xi = np.array(sample_batch[sub_index:-1], dtype=np.int32)
    yi = np.array(sample_batch[sub_index+1:], dtype=np.int32)
    return np.tile(Xi, (batch_size, 1)), np.tile(yi, (batch_size, 1))

def get_sample_batch(td, batch_size):
    # Draw an independent random sample (sequence_len+1 tokens) for every row of the batch
    samples = [td.get_random_item() for _ in range(batch_size)]
    smpX = np.array([s[:-1] for s in samples], dtype=np.int32)  # inputs
    smpy = np.array([s[1:] for s in samples], dtype=np.int32)   # targets, shifted by one token
    return smpX, smpy
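Each training row is a window of `sequence_len + 1` tokens: the first `sequence_len` tokens are the input, and the same window shifted by one position is the next-token target. The sketch below shows this with plain Python lists and random offsets into a token stream; `make_batch` is a hypothetical stand-in for `td.get_random_item()` plus the X/y split.

```python
import random

def make_batch(tokens, batch_size, sequence_len, rng=random):
    """Draw batch_size random windows of length sequence_len+1 and split each into
    input X = window[:-1] and target y = window[1:] (next-token prediction)."""
    X, y = [], []
    for _ in range(batch_size):
        start = rng.randrange(len(tokens) - sequence_len)
        window = tokens[start:start + sequence_len + 1]
        X.append(window[:-1])
        y.append(window[1:])
    return X, y

tokens = list(range(100))  # stand-in for an encoded text
X, y = make_batch(tokens, batch_size=2, sequence_len=8, rng=random.Random(0))
print(X[0], y[0])
```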

In [16]:
num_batches = num_records // params['batch_size']
print(f"num_batches = {num_batches}")

num_batches = 240512


In [17]:
x, y = get_sample_batch(td, 2)
x.shape, y.shape

((2, 256), (2, 256))

In [18]:
sample_data = None

def get_torch_subbatch(td, batch_size, device, split=None, sub_index=0):
    global sample_data
    if sub_index==0:
        sample_data = td.get_random_item()
    x, y = get_sample_sub_batch(sample_data, batch_size, sub_index)
    # int64 token ids on the target device (integer tensors never require gradients)
    tx = torch.as_tensor(x, dtype=torch.long, device=device)
    ty = torch.as_tensor(y, dtype=torch.long, device=device)
    return tx, ty

def get_torch_batch(td, batch_size, device, split=None):  # split is currently unused
    x, y = get_sample_batch(td, batch_size)
    tx = torch.as_tensor(x, dtype=torch.long, device=device)
    ty = torch.as_tensor(y, dtype=torch.long, device=device)
    return tx, ty

def get_zero_state(batch_size, sequence_len, hidden_size, device):
    zstate = torch.zeros(batch_size, sequence_len, hidden_size, device=device)
    zstate.requires_grad = False
    return zstate

## 5. Model, loss and training helpers

In [19]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, device=None):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model, device=device)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

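    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]`` (batch_first, as used by the model below)
        """
        # pe is stored as [max_len, 1, d_model]; index it by sequence position (dim 1 of x),
        # not by batch index, and transpose to [1, seq_len, d_model] so it broadcasts over the batch.
        x = x + self.pe[:x.size(1)].transpose(0, 1)
        return self.dropout(x)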
        
class MultiHeadSelfAttentionWithMemory(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_size,
        sequence_len,
        num_heads,
        num_layers,
        device=None,
    ):
        super().__init__()
        if device is None:
            raise ValueError(
                "Device is None at MultiHeadSelfAttentionWithMemory"
            )
        self.device = device
        self.sequence_len = sequence_len
        context_sub_layers = num_layers // 2
        self.context_sub_layers = context_sub_layers
        dims = embedding_size
        self.dims = dims
        # token embedding table: maps each token id to an embedding_size vector
        self.embedding = nn.Embedding(vocab_size, embedding_size, device=device)
        self.pos_encoder = PositionalEncoding(dims, dropout=0.0, max_len=10000, device=device)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dims, nhead=num_heads, dim_feedforward=dims*4, dropout=0.0, batch_first=True, device=device)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) # context_sub_layers)
        # encoder_layer2 = nn.TransformerEncoderLayer(d_model=dims, nhead=num_heads, dim_feedforward=dims*4, dropout=0.0, device=device)
        # self.transformer2 = nn.TransformerEncoder(encoder_layer2, num_layers=num_layers - context_sub_layers)
        self.out_proj = nn.Linear(dims, vocab_size, device=device)


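    def forward(self, idx):
        B, T = idx.shape
        # Scale token embeddings by sqrt(embedding_size), not by the (variable) sequence length
        x = self.embedding(idx) * math.sqrt(self.dims)
        x = self.pos_encoder(x)
        # Causal mask: position t may only attend to positions <= t
        x_mask = nn.Transformer.generate_square_subsequent_mask(T).to(idx.device)
        x = self.transformer(x, mask=x_mask)
        # x = self.transformer2(x, x_mask)
        logits = self.out_proj(x)
        return logits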

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Generate new tokens given a context

        Note: on Apple MPS, older PyTorch versions limited top_k to at most 16 (implementation limitation as of 01/2023).
        See: https://github.com/pytorch/pytorch/issues/78915
        Solved in: https://github.com/pytorch/pytorch/pull/94639 (03/2023)

        :param idx: the context (B,T) tensor of indices
        :param max_new_tokens: the maximum number of tokens to generate
        :param temperature: the temperature to use for sampling
        :param top_k: the number of top tokens to consider
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last sequence_len tokens
            idx_cond = idx[:, -self.sequence_len :]
            # print(idx_cond.shape)
            # get the predictions
            logits = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply temperature
            if temperature != 1.0 and temperature > 0.0:
                logits = logits / temperature
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
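`PositionalEncoding` implements the standard sinusoidal encoding. With $d$ = `d_model`, $pos$ = `position`, and `div_term` $= \exp\left(2i \cdot (-\ln 10000 / d)\right) = 10000^{-2i/d}$, the buffer `pe` holds

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)
$$

so even embedding dimensions get the sine and odd dimensions the cosine of the same frequency.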

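`nn.Transformer.generate_square_subsequent_mask` returns a square float mask with `0.0` on and below the diagonal and `-inf` above it. Added to the attention scores, it lets position $t$ attend only to positions $\le t$. A dependency-free sketch of the same mask plus a masked softmax over one row of scores (`causal_mask` and `masked_softmax` are illustrative helpers, not PyTorch functions):

```python
import math

def causal_mask(T):
    """Additive causal mask: 0.0 where key j <= query i, -inf where j > i."""
    return [[0.0 if j <= i else float("-inf") for j in range(T)] for i in range(T)]

def masked_softmax(scores, mask_row):
    """Softmax over one row of attention scores after adding the mask row."""
    shifted = [s + m for s, m in zip(scores, mask_row)]
    top = max(shifted)
    exps = [math.exp(s - top) for s in shifted]  # exp(-inf) == 0.0, so masked keys get zero weight
    total = sum(exps)
    return [e / total for e in exps]

mask = causal_mask(3)
weights = masked_softmax([1.0, 1.0, 1.0], mask[1])  # query at position 1
print(weights)  # [0.5, 0.5, 0.0]
```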

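`generate` samples each new token by (1) keeping only the `top_k` largest logits (the others become `-inf`), (2) dividing by `temperature`, (3) applying softmax, and (4) drawing from the result with `torch.multinomial`. The same steps for a single logit vector in plain Python; the helper names are illustrative, and `random.choices` stands in for `torch.multinomial`.

```python
import math
import random

def top_k_filter(logits, k):
    """Set every logit below the k-th largest to -inf (ties at the threshold are kept)."""
    k = min(k, len(logits))
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else float("-inf") for x in logits]

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Return a sampled index and the probability vector it was drawn from."""
    if top_k is not None:
        logits = top_k_filter(logits, top_k)
    if temperature != 1.0 and temperature > 0.0:
        logits = [x / temperature for x in logits]
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    idx = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return idx, probs

idx, probs = sample_next([2.0, 1.0, 0.5, -1.0], temperature=0.75, top_k=2, rng=random.Random(42))
print(idx, [round(p, 3) for p in probs])
```

A temperature below 1 sharpens the distribution toward the most likely tokens, while `top_k` rules out the long tail entirely.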
In [20]:
print("creating model...")
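try:
    # Free a model left over from a previous run of this cell (Colab + torch 2 otherwise accumulates garbage).
    if model is not None:
        del model
except NameError:  # first run: no model defined yet
    pass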


model = MultiHeadSelfAttentionWithMemory(vocab_size=params['vocab_size'], embedding_size=params['embedding_size'],
                                       sequence_len=params['sequence_len'],
                                       num_heads=params['heads'], num_layers=params['mhsa_layers'],
                                       device=device)

optimizer = torch.optim.AdamW(model.parameters(), lr=params['learning_rate'])

model = model.to(device)
if use_existing_model_from_checkpoint is True:
    params_load = MJ.load_checkpoint(params, model, optimizer, file_path=model_file_path, updatable_keys=updatable_keys, device=device, log=log) # torch.device("cpu"))
    if params_load is not None:
        params = params_load
model = model.to(device)
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

if use_torch_compile is True:
    if device == 'cuda':
        print("Compiling...")
        model = torch.compile(model)
        print("Compile ok.")
        try:
            torch.set_float32_matmul_precision('high')
        except:
            print("Seems no tensor cores for that.")
    # elif str(device) == 'mps':
    #     print("Compiling...")
    #     model = torch.compile(model)
    #     print("Compile ok.")

if 'current_epoch' in params:
    ep = params['current_epoch']
else:
    ep=0
if 'current_loss' in params:
    ls = params['current_loss']
else:
    ls=0

if ep==0 and ls==0:
    start_iter = 0
else:
    start_iter = ep
    current_loss = ls

# print the number of parameters in the model
print(model)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")

creating model...
MultiHeadSelfAttentionWithMemory(
  (embedding): Embedding(10000, 256)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (out_proj): Linear(in_features=256, out_features=10000, bias=True)
)
11.44808 M parameters
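The 11.44808 M figure can be checked by hand. `nn.Embedding` has no bias, each `TransformerEncoderLayer` holds a packed Q/K/V input projection, the attention output projection, two feed-forward `Linear` layers and two `LayerNorm`s, and the sinusoidal `pe` tensor is a buffer, not a parameter. A small counting function (a hypothetical helper, assuming `dim_feedforward = 4 * embedding_size` as in the model above):

```python
def transformer_param_count(vocab_size, dims, layers, ff_mult=4):
    """Count trainable parameters of embedding + nn.TransformerEncoder + output projection."""
    embedding = vocab_size * dims                       # nn.Embedding (no bias)
    attn_in = 3 * dims * dims + 3 * dims                # packed Q/K/V in_proj weight + bias
    attn_out = dims * dims + dims                       # attention out_proj
    ff = dims * (ff_mult * dims) + ff_mult * dims       # linear1
    ff += (ff_mult * dims) * dims + dims                # linear2
    norms = 2 * (dims + dims)                           # norm1, norm2 (weight + bias)
    per_layer = attn_in + attn_out + ff + norms         # 789,760 for dims=256
    out_proj = dims * vocab_size + vocab_size           # final nn.Linear
    return embedding + layers * per_layer + out_proj

print(transformer_param_count(10000, 256, 8) / 1e6, "M parameters")  # 11.44808 M parameters
```

The two vocabulary-sized matrices (embedding and output projection) account for about 5.1 M of the 11.4 M parameters.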


In [21]:
@torch.no_grad()
def estimate_loss(device):
    # Note: 'train' and 'val' batches are drawn from the SAME pool, so 'val' is not a held-out estimate.
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(params['test_iterations'])
        for k in range(params['test_iterations']):
            X, Y = get_torch_batch(td, params['batch_size'], device, split)
            logits = model(X)  # the model returns logits only
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1))
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    print("\r", end="", flush=True)
    mloss = (out['train']+out['val'])/2.0
    return mloss

def generate_sample(td, device, prompt=' ', toks=100, state=None, temperature=1.0, top_k=None, pad=False):
    # generate from the model
    # context = torch.zeros((1, 1), dtype=torch.long, device=device)
    model.eval()
    if pad is True:
        while len(prompt)<params['sequence_len']:
            if len(prompt)==params['sequence_len']-1:
                prompt = '\n' + prompt
            else:
                prompt = ' ' + prompt
    context = torch.tensor([td.encode(prompt)]).to(device)
    answer = model.generate(context, max_new_tokens=toks, temperature=temperature, top_k=top_k)
    txt = td.decode(answer[0].tolist())
    # Identify memorisation of text by highlighting verbatim quotes from sources
    # that are longer than 10 chars. HTML colorcoded output for source identification:
    td.source_highlight(txt, min_quote_size=10, dark_mode=use_dark_mode, display_ref_anchor=False)
    return txt


In [22]:
# @torch.jit.script
# @torch.compile
criterion = nn.CrossEntropyLoss()
def do_train_step(xb, yb, device, state=None):
    model.train()
    logits = model(xb)                                   # (B, T, vocab_size)
    # Flatten batch and time: every position is one next-token classification example
    output_flat = logits.view(-1, params['vocab_size'])  # (B*T, vocab_size)
    ybr = yb.view(-1)                                    # (B*T,)
    loss = criterion(output_flat, ybr)

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 0.5 before the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()

    return loss.item()
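`nn.CrossEntropyLoss` takes raw logits, applies log-softmax, and averages $-\log p(\text{target})$ over all `batch_size * sequence_len` positions. This is why `do_train_step` flattens logits to `(-1, vocab_size)` and targets to `(-1,)`. A useful sanity check: a model that predicts uniformly over $V$ tokens has loss $\ln V$, i.e. $\ln 10000 \approx 9.21$ here, so the logged losses of about 7.0 after 256 steps are already below chance level. A plain-Python sketch (`cross_entropy` is an illustrative helper, not the PyTorch implementation):

```python
import math

def cross_entropy(logits_rows, targets):
    """Mean negative log-likelihood of targets under softmax(logits), one row per position."""
    total = 0.0
    for row, target in zip(logits_rows, targets):
        top = max(row)
        log_sum_exp = top + math.log(sum(math.exp(x - top) for x in row))
        total += log_sum_exp - row[target]  # -log softmax(row)[target]
    return total / len(targets)

vocab_size = 10000
uniform = [[0.0] * vocab_size for _ in range(2)]  # two positions, uniform predictions
print(round(cross_entropy(uniform, [3, 42]), 4), round(math.log(vocab_size), 4))  # 9.2103 9.2103
```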

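`torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)` treats all gradients as one long vector. If its global L2 norm exceeds `max_norm` (0.5 here), every gradient is scaled by `max_norm / norm` (PyTorch adds a small epsilon to the denominator); otherwise the gradients are left unchanged. A list-based sketch of that rule (`clip_by_global_norm` is an illustrative helper):

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors so their combined L2 norm is at most max_norm.
    Returns the clipped gradients and the norm before clipping."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [[g * scale for g in vec] for vec in grads], total_norm

clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=0.5)
print(norm, clipped)
```

Because the scale is shared by all parameters, clipping limits the step size without changing the gradient's direction.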
In [23]:
dt0 = time.time()
sdt = datetime.datetime.now(tz=local_timezone).strftime("%Y-%m-%d %H:%M:%S")
print(f"training, start at {sdt}...")
gen_id = 0
iter_bench = 1
cur_loss_m = 0
cur_loss_m_avg = 25
# current_loss = estimate_loss(device)
inputs = ["What is the difference between good and evil? The difference ", "How did everything come into existence? The origin ", "What was at the beginning of time? Time itself ", "How are physics, quantum-mechanics and consciousness related? The relation between ", "How to attain complete self-awareness? Complete ", "What is the nature of reality? The nature ", "How to be a good human being? A human "]
for iter in range(start_iter, params['max_iterations']):
    # every once in a while evaluate the loss on train and val sets
    if (iter + 1) % params['sample_every_n_iterations'] == 0 or iter == params['max_iterations'] - 1:
        dt = time.time()
        print(f"\rloss eval", end="", flush=True)
        current_loss = cur_loss_m # estimate_loss(device)
        print(
            f"step {iter+1}: train loss {current_loss:.4f}, time {(dt-dt0)/iter_bench:.3f} sec/iter"
        )
        iter_bench = 1
        sdt = datetime.datetime.now(tz=local_timezone).strftime("%Y-%m-%d %H:%M:%S")
        print(f"Sample at {sdt}:", flush=True)
        for temperature in [0.75]:
            print(f"--------temperature: {temperature} ---------")
            prompt = inputs[gen_id%len(inputs)]
            print(f"Prompt: {prompt}")
            generate_sample(td=td, device=device, prompt=prompt, toks=params['sample_size'], temperature=temperature, top_k=16)
        print("-------------------------------------------")
        gen_id += 1
        dt0 = time.time()

    xb, yb = get_torch_batch(td, params['batch_size'], device, "train")
    cur_loss = do_train_step(xb, yb, device=device)
    if cur_loss_m == 0:
        cur_loss_m = cur_loss
    else:
        cur_loss_m = (cur_loss + cur_loss_m * (cur_loss_m_avg-1))/cur_loss_m_avg
    print(f"\rIteration: {iter+1:5d}/{((iter+1)//params['sample_every_n_iterations']+1)*params['sample_every_n_iterations']}/{params['max_iterations']} loss: {cur_loss_m:.4f}", end="", flush=True)

    start_iter = iter
    iter_bench += 1
    if (iter+1)%params['save_every_n_iterations'] == 0:
        MJ.save_checkpoint(params, model, optimizer, iter, current_loss, file_path=model_file_path, log=log)


training, start at 2024-05-02 13:12:47...
loss eval:   255/256/100000000 loss: 7.0350step 256: train loss 7.0350, time 0.097 sec/iter
Sample at 2024-05-02 13:13:12:
--------temperature: 0.75 ---------
Prompt: What is the difference between good and evil? The difference 


-------------------------------------------
loss eval:   511/512/100000000 loss: 6.8775step 512: train loss 6.8775, time 0.097 sec/iter
Sample at 2024-05-02 13:13:41:
--------temperature: 0.75 ---------
Prompt: How did everything come into existence? The origin 


-------------------------------------------
loss eval:   767/768/100000000 loss: 6.7218step 768: train loss 6.7218, time 0.097 sec/iter
Sample at 2024-05-02 13:14:11:
--------temperature: 0.75 ---------
Prompt: What was at the beginning of time? Time itself 


-------------------------------------------
Iteration:   888/1024/100000000 loss: 6.6877

KeyboardInterrupt: 
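The loss printed during training is not the raw per-batch loss but a running average: `cur_loss_m = (cur_loss + cur_loss_m * (cur_loss_m_avg - 1)) / cur_loss_m_avg`. This is an exponential moving average with smoothing factor $\alpha = 1/25$ (`cur_loss_m_avg = 25`), seeded with the first batch loss, so a change in the true loss takes on the order of 25 iterations to show up fully. Restated as a function (`ema_update` is an illustrative helper):

```python
def ema_update(current_mean, new_value, window=25):
    """One step of the running loss average used in the training loop:
    new_mean = (new_value + current_mean * (window - 1)) / window,
    i.e. an exponential moving average with alpha = 1 / window."""
    if current_mean == 0:  # first step: seed with the raw value
        return new_value
    return (new_value + current_mean * (window - 1)) / window

m = 0
for loss in [8.0, 7.0, 7.0]:
    m = ema_update(m, loss)
print(round(m, 4))  # 7.9216
```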

In [None]:
# for t in [0.5, 1.5]:
#     print(f"------Temperature {t}--------")
#     generate_sample(td, device, prompt="How are consciousness and quantum mechanics related?", toks=150, temperature=t, top_k=16)

In [None]:
# Test-code below, unfinished.

In [None]:
texts = []
enc_texts = []
for i in range(500):
    e = td[i*50000][:256]
    tx = torch.tensor([e]).to(device)
    enc_texts.append(tx)
    texts.append(td.decode(e))


In [None]:
enc_texts[0].shape

In [None]:
emb_text = []
cont_text = []
for et in enc_texts:
    emb_text.append(model.embedding(et))
    cont_text.append(model.context(et))

In [None]:
emb_text[0].shape, cont_text[0].shape

In [None]:
emb_vec = []
cont_vec = []
for i in range(len(emb_text)):
    emb_vec.append(emb_text[i][0].sum(axis=0))
    cont_vec.append(cont_text[i][0].sum(axis=0))

In [None]:
def cos_vec(a, b):
    al = torch.sqrt(torch.dot(a,a))
    bl = torch.sqrt(torch.dot(b,b))
    an = a/al
    bn = b/bl
    return torch.dot(an,bn)

In [None]:
for i in range(0):
    best = -10.0
    ind = -1
    for j in range(len(emb_vec)):
        if i==j:
            continue
        cos_val = cos_vec(emb_vec[i], emb_vec[j])
        if cos_val > best:
            best = cos_val
            ind = j
    # print(f"{texts[i][:20]} ->{best}: {texts[ind][:20]}")
print()        
for i in range(200):
    best = -float("inf")  # cosine similarity can be negative
    ind = -1
    for j in range(len(emb_vec)):
        if i==j:
            continue
        cos_val = cos_vec(cont_vec[i], cont_vec[j])
        if cos_val > best:
            best = cos_val
            ind = j
    print(f"{texts[i]} \n\n->{best}:\n\n {texts[ind]}")
    print("---------------------------------------------")
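The loop above is a brute-force nearest-neighbor search by cosine similarity. Since cosine similarity lies in $[-1, 1]$, the running best should start below $-1$ (for instance at `-inf`); otherwise a text whose neighbors all have negative similarity keeps `ind = -1`. A compact stdlib sketch of the same search (`nearest_by_cosine` is illustrative; vectors are plain lists):

```python
import math

def cosine(a, b):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def nearest_by_cosine(vectors, i):
    """Index and similarity of the vector most similar to vectors[i], excluding i itself."""
    best, best_j = -math.inf, -1
    for j, v in enumerate(vectors):
        if j == i:
            continue
        sim = cosine(vectors[i], v)
        if sim > best:
            best, best_j = sim, j
    return best_j, best

vecs = [[1.0, 0.0], [-1.0, 0.1], [-1.0, -0.1]]
print(nearest_by_cosine(vecs, 0))  # all neighbors are negative, a match is still found
```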


In [None]:
td[200002]