# Why Snakes?

This is the training notebook for YS-120M-base-v0.1.

This code is MIT licensed. Use it at your own risk.

There's a lot of garbage code in here because I iterated fast to figure out what worked and to make small improvements. The best use of this project is to learn from it, but the junk gets in the way. You are better off looking to the main branch for updates. Those aren't checked in as I write this, but the v0.2 changes will make the code more legible, even if still imperfect.

My personal machine has an Intel 13th generation processor, which is among those impacted by the flaw that causes a random Blue Screen of Death. In an effort to push through, I moved training to Colab. It turns out that Colab likes to check whether you are human at random intervals and stops if you aren't watching at that moment. If you don't mind a fairly annoying experience, $50 to train a model on an L4 isn't outrageous. It's more frustrating when you already own a better GPU and can't use it because you have a flawed CPU.

### You are going to ask:

* Why is it only 120 million parameters?
 - I needed it plus a chunk of training data to fit on an L4 (22.5 GB of GPU memory)
 - I inexplicably stored the token IDs as longs (64-bit) rather than unsigned 16-bit ints.
 - After a lot of experimentation, I found that I needed a batch size of at least 64 to (mostly) train stably.
 - I had already spent a lot of time training before it occurred to me to start with a shorter sequence length and a bigger batch size, get closer to fully trained, and only then switch to a longer sequence length with a smaller batch size.
 - I wanted to finish v0.1 before starting a new round of training with my learnings -- otherwise, I'd keep learning and you'd never see a v0.1

* Why didn't you use a BPE-based tokenizer?
 - BPE is super well suited to the transformer architecture where you have a limited context size.
 - With Mamba 2, I have an unlimited context size. What's more, a lot of the tokens that are fractions of a word don't have any inherent meaning.
 - BPE compresses text in a meaningful way and, most importantly, compresses out spaces, which would otherwise use up valuable tokens.
 - With Mamba 2, I have an unlimited context and can give the spaces meaning. Think of Python, where multiple spaces mean something.
 - I can also treat things like capitalization, italics, underline, etc. as a token: a modifier on the coming word. That also carries meaning without changing the base word directly, making it easier for the model to realize this is the same word with another token influencing it, just like a normal sentence.

* Why didn't you use the HuggingFace library?
 - I appreciate what HuggingFace has done. They are among my favorite companies, but this project is about paving new roads.
 - I didn't want to battle learning the ins and outs of making a custom tokenizer and a custom model while learning to build the architecture I wanted.
 - Maybe some day I can contribute this to HuggingFace, but not until after I have the model I want.

* Why didn't you use the Mamba 2 reference project?
 - It only runs on Linux with cuda. I wanted my project to work for people with a GPU or a CPU.

* Does the world really need another small LLM?
 - Probably not, but I need to learn, and if we as a community don't continue to work toward a functioning architecture that isn't the Transformer architecture, only the largest of companies will be able to train models that are useful.
 - While I super appreciate the revolution that Transformers have given us, they've also created a moat that allows huge companies to hoard the knowledge, experience, and capability.
 - I want to know if I can do it--make a useful LLM.

* If you want to make a useful LLM, why release v0.1 which you specifically say isn't useful?
 - Because I want to give really determined people a chance to learn from what I did. Even though it is really imperfect and not very useful, it's still slightly useful.
 - I will release a v0.2 which I expect to be better and easier to learn from, but these things take time and we're in a technology cycle that moves so fast that sitting on the knowledge helps nobody.

* Why didn't you use more training data?
 - I decided to generate my own training data.
 - I like the idea of curating the data, and while I have only minimally curated it to this point, over time, I can continue to improve it.
 - Unfortunately, generating synthetic data takes time, and I wanted to release a v0.1 model as soon as I could. This is also why v0.2 will be trained on more data.
 - More data means slower training epochs, which means a longer time before I have a first version that I can evaluate.
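
On the token-storage point in the first answer: the vocabulary here has 33,438 entries, so every token ID fits in an unsigned 16-bit integer (maximum 65,535) but not in a signed one (maximum 32,767). Below is a minimal sketch of the memory difference using only the standard library. The helper name `token_buffer_bytes` and the one-million-token count are illustrative assumptions, not values from this notebook.

```python
from array import array

VOCAB_SIZE = 33_438  # vocabulary size quoted in this notebook


def token_buffer_bytes(num_tokens, typecode):
    """Bytes needed to hold num_tokens IDs in an array of the given typecode."""
    return num_tokens * array(typecode).itemsize


num_tokens = 1_000_000  # illustrative corpus size (assumption)
as_int64 = token_buffer_bytes(num_tokens, "q")   # signed 64-bit, like torch.long
as_uint16 = token_buffer_bytes(num_tokens, "H")  # unsigned 16-bit

print(f"int64:  {as_int64:,} bytes")
print(f"uint16: {as_uint16:,} bytes")
print(f"largest ID fits in uint16: {VOCAB_SIZE - 1 <= 0xFFFF}")
print(f"largest ID fits in int16:  {VOCAB_SIZE - 1 <= 0x7FFF}")
```

The saving applies to how tokens are stored. As far as I know, `nn.Embedding` expects int or long index tensors, so a cast just before the lookup would still be needed.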

### Baseline categorical cross entropy loss

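Baseline categorical cross entropy loss lets us compare our loss against a model that has a perfectly equal chance of predicting any token.

Baseline categorical cross entropy loss for our vocab = ln(33,438) ≈ 10.42. This uses the natural log, which is what PyTorch's cross entropy loss is measured in; the base-10 log of the same number is 4.52.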

We beat the baseline categorical cross entropy loss on our first training round, so at least the results aren't random.
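
The uniform baseline can be checked with a few lines of arithmetic. This is a small sketch, assuming the loss is measured in nats (natural log), as PyTorch's `nn.CrossEntropyLoss` does; the constant names are my own.

```python
import math

VOCAB_SIZE = 33_438   # vocabulary size quoted above
STARTING_LOSS = 5.2   # first-round loss quoted in the Perplexity section

# A model that gives every token probability 1/V has cross entropy -ln(1/V) = ln(V).
uniform_baseline = math.log(VOCAB_SIZE)

print(f"uniform baseline (nats): {uniform_baseline:.2f}")
print(f"log10 of vocab size:     {math.log10(VOCAB_SIZE):.2f}")
print(f"beats baseline:          {STARTING_LOSS < uniform_baseline}")
```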

### Perplexity

Perplexity tells us how well we are training relative to our vocabulary size; categorical cross entropy on its own is rather difficult to interpret. A good perplexity isn't a great indicator of how good your LLM is, but if you have a bad perplexity, your model is bad.

* Starting perplexity (e^loss) = e^5.2 = 181.27
* Final perplexity (e^loss) = e^? = ? [will fill in when I have the final results]
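
The conversion from loss to perplexity is a single exponential. Below is a small sketch, assuming the loss is in nats; the function name `perplexity` is my own. A model that guesses uniformly has a perplexity equal to the vocabulary size, which gives the starting number above something to be compared against.

```python
import math


def perplexity(cross_entropy_nats):
    """Perplexity is e raised to the cross entropy loss (in nats)."""
    return math.exp(cross_entropy_nats)


starting = perplexity(5.2)                 # starting loss quoted above
uniform = perplexity(math.log(33_438))     # uniform guessing over the vocabulary

print(f"starting perplexity: {starting:.2f}")
print(f"uniform perplexity:  {uniform:.0f}")
```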



In [1]:
import pickle
import string
import time
import os
import re
import random
import math

import numpy as np

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import TensorDataset, DataLoader, random_split, IterableDataset


In [2]:
%env YS_LLM_BASE_PATH=/content/drive/MyDrive/ys-llm/

env: YS_LLM_BASE_PATH=/content/drive/MyDrive/ys-llm/


In [3]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0


In [4]:
# from numba import cuda
# device = cuda.get_current_device()
# device.reset()

In [5]:
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.version.cuda)
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())
print(torch.cuda.get_device_name(torch.cuda.current_device()))
print(torch.cuda.get_device_properties(0).total_memory)

2.3.1+cu121
True
12.1
True
1
0
NVIDIA L4
23802544128


In [6]:
# %env CUDA_LAUNCH_BLOCKING=1

In [7]:
torch.cuda.is_available()

True

In [8]:
# /content/drive/MyDrive/ys-llm/training_data
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:

DEBUG_TOKENIZER = False


class Tokenizer:
    def __init__(self):
        self.index_to_word = {}
        self.word_to_index = {}
        self.current_index = 0
        self.initialized = False
        self.special_tokens = ["<upper>", "<shout>", "<start>", "<end>", "<space>", "<newline>"]

    def _add_to_vocab(self, word):
        if word not in self.word_to_index:
            self.index_to_word[self.current_index] = word
            self.word_to_index[word] = self.current_index
            self.current_index += 1

    def _initialize_vocabulary(self):
        if not self.initialized:
            for token in self.special_tokens:
                self._add_to_vocab(token)
            self._add_to_vocab('_')
            # Note:
            # We have to keep the model small, and there's no intention
            # of being exclusionary, but I have to focus on the text that I know.
            # If you are taking and modifying this code, I highly recommend
            # streamlining this for the language you will train on and use.
            # I don't have the money to train a model that can handle all languages.

            # Initialize single-letter printable tokens for basic ASCII and common Western European characters
            ascii_range = list(range(32, 127))  # Basic printable ASCII
            extended_chars = [
                'é', 'è', 'ê', 'ë', 'à', 'á', 'â', 'ä', 'å', 'ç', 'ì', 'í', 'î', 'ï',
                'ñ', 'ò', 'ó', 'ô', 'ö', 'ù', 'ú', 'û', 'ü', 'ý', 'ÿ', 'ø', 'å', 'Æ',
                'æ', 'œ', 'ß', 'ñ', '¿', '¡', '€', '£'
            ]
            extended_chars += [char.upper() for char in extended_chars if char.islower()]

            # Add ASCII characters to vocabulary
            for i in ascii_range:
                char = chr(i)
                self._add_to_vocab(char)

            # Add common Western European accented characters and symbols
            for char in extended_chars:
                self._add_to_vocab(char)

            self.initialized = True

    def learn_new_vocab(self, document: str):
        if not self.initialized:
            self._initialize_vocabulary()
        words = self.to_words(document)
        for word in words:
            if word not in self.word_to_index:
                self._add_to_vocab(word)

    @staticmethod
    def _split_text(text):
        tokens = []

        # Define the regular expressions for the tokens, ordered by priority
        space_regex = r'\s+'  # One or more spaces
        tag_regex = r'<[a-z]+>'  # <lowercase letters>
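        # NOTE: cap_word_regex also matches a lone capital (the `*` allows zero
        # lowercase letters) and is tried before all_caps_regex, so an all-caps run
        # such as "NASA" is split into single letters and <shout> is never emitted.
        # Swapping the order would change tokenization, so retrain if you do.
        cap_word_regex = r'[A-Z][a-zà-ÿ]*'  # Capital followed by lowercase, including accented characters
        all_caps_regex = r'[A-Z][A-ZÀ-ß]+'  # All capital letters in a row, including accented uppercase characters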
        lower_word_regex = r'[a-zà-ÿ]+'  # One or more lowercase letters, including accented characters
        single_cap_regex = r'[A-ZÀ-ß]'  # A single capital word (like 'A' or accented capital)
        number_regex = r'\d+'  # One or more numbers
        symbol_regex = r'[^\w\s]|_'  # Single symbol or punctuation mark

        # Create a combined regex to match all tokens
        combined_regex = f'({space_regex}|{tag_regex}|{cap_word_regex}|{all_caps_regex}|{lower_word_regex}|{single_cap_regex}|{number_regex}|{symbol_regex})'

        # Find all matches using the combined regex
        matches = re.findall(combined_regex, text)

        for match in matches:
            tokens.append(match)

        return tokens

    def to_words(self, text: str):
        tokens = self._split_text(text)

        result = []
        for token in tokens:
            if not token:
                continue
            if DEBUG_TOKENIZER:
                print(f"Processing token: {token}")  # Debugging statement
            if '\n' in token:
                result.append('<newline>')
            elif token.isspace():
                result.append('<space>')
            elif re.match(r'<[a-z]+>', token):
                result.append(token)  # tag
            elif re.match(r'[A-Z][a-zà-ÿ]*', token):
                result.append('<upper>')
                result.append(token.lower())
            elif re.match(r'[A-Z][A-ZÀ-ß]+', token):
                result.append('<shout>')
                result.append(token.lower())
            elif re.match(r'[a-zà-ÿ]+', token):
                result.append(token)
            elif re.match(r'[A-ZÀ-ß]', token):
                result.append(token.lower())
            else:
                result.append(token)

        if DEBUG_TOKENIZER:
            print(f"Final token list: {result}")  # Debugging statement
        return result

    def tokenize(self, text):
        if not self.initialized:
            self._initialize_vocabulary()
        tokens = []
        words = self.to_words(text)
        for word in words:
            if word in self.word_to_index:
                tokens.append(self.word_to_index[word])
            else:
                if DEBUG_TOKENIZER:
                    print(f"Word '{word}' not found in vocabulary. Using individual characters.")
                for letter in word:
                    if letter in self.word_to_index:
                        tokens.append(self.word_to_index[letter])
                    elif DEBUG_TOKENIZER:
                        print(f"Character '{letter}' wasn't in vocabulary. Skipping!")
        return tokens

    def detokenize(self, tokens):
        text = ''
        capitalize_next = False
        shout_next = False
        for token in tokens:
            word = self.index_to_word[token]
            if word == '<space>':
                text += ' '
            elif word == '<newline>':
                text += '\n'
            elif word == '<upper>':
                capitalize_next = True
            elif word == '<shout>':
                shout_next = True
            elif word == '_':
                text += '_'
            elif word in self.special_tokens:
                continue  # Skip rendering of standalone special tokens
            else:
                if capitalize_next:
                    word = word.capitalize()
                    capitalize_next = False
                if shout_next:
                    word = word.upper()
                    shout_next = False
                text += word
        return text

    def vocab_size(self):
        return self.current_index

    def save(self, filepath):
        with open(filepath, 'wb') as f:
            pickle.dump(
                (self.index_to_word, self.word_to_index, self.current_index, self.initialized), f)

    def load(self, filepath):
        with open(filepath, 'rb') as f:
            self.index_to_word, self.word_to_index, self.current_index, self.initialized = pickle.load(f)
        if DEBUG_TOKENIZER:
            print("Tokenizer vocabulary size: ", self.vocab_size())

    def print_vocabulary(self):
        print(self.word_to_index)
        print("Tokenizer vocabulary size: ", self.vocab_size())

    def get_end_token(self):
        if not self.initialized:
            self._initialize_vocabulary()
        return self.word_to_index['<end>']

    def get_start_token(self):
        if not self.initialized:
            self._initialize_vocabulary()
        return self.word_to_index['<start>']
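
The modifier-token idea behind this tokenizer can be shown in isolation. The sketch below is a simplified stand-in, not the `Tokenizer` class above: the function names `to_marked_words` and `from_marked_words` are hypothetical, it handles only ASCII letters, it collapses any run of whitespace into a single `<space>` marker, and it checks for all-caps before capitalized words.

```python
import re


def to_marked_words(text):
    """Split text into lowercase words plus case and space markers."""
    out = []
    # Pieces: a whitespace run, a run of ASCII letters, or any other single character.
    for piece in re.findall(r"\s+|[A-Za-z]+|[^A-Za-z\s]", text):
        if piece.isspace():
            out.append("<space>")
        elif len(piece) > 1 and piece.isupper():
            out += ["<shout>", piece.lower()]   # e.g. "WORLD"
        elif piece[0].isupper():
            out += ["<upper>", piece.lower()]   # e.g. "Hello" or "A"
        else:
            out.append(piece)
    return out


def from_marked_words(tokens):
    """Rebuild text, applying each case marker to the word that follows it."""
    text, pending = "", None
    for tok in tokens:
        if tok in ("<upper>", "<shout>"):
            pending = tok
        elif tok == "<space>":
            text += " "
        else:
            if pending == "<upper>":
                tok = tok.capitalize()
            elif pending == "<shout>":
                tok = tok.upper()
            pending = None
            text += tok
    return text


tokens = to_marked_words("Hello WORLD again!")
print(tokens)
print(from_marked_words(tokens))
```

The point of the design is that "hello", "Hello" and "HELLO" all share one vocabulary entry, with the case carried by a separate token.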


In [10]:
ATTENTION_DEBUG = False


class Attention(nn.Module):
    def __init__(self, state_dim, input_dim, output_dim, block_size=32, dropout_rate=0.1):
        super(Attention, self).__init__()
        self.state_dim = state_dim
        self.block_size = block_size
        self.P = nn.Parameter(torch.eye(state_dim).repeat(state_dim, 1, 1) * 0.1 + 0.01)
        self.Q = nn.Parameter(torch.eye(state_dim, input_dim) * 0.1 + 0.01)
        self.R = nn.Parameter(torch.eye(output_dim, state_dim) * 0.1 + 0.01)
        self.S = nn.Parameter(torch.eye(output_dim, input_dim) * 0.1 + 0.01)
        self.layer_norm = nn.LayerNorm(state_dim)
        self.input_dropout = nn.Dropout(p=dropout_rate)
        # self.attn_dropout = nn.Dropout(p=dropout_rate)
        self.output_dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        device = x.device
        batch_size, sequence_length, input_dim = x.shape
        outputs = []
        state = torch.zeros(batch_size, self.state_dim, device=device)

        x = self.input_dropout(x)

        p_expanded = self.P.unsqueeze(0)
        q_expanded = self.Q.expand(batch_size, -1, -1)
        s_expanded = self.S.expand(batch_size, -1, -1)

        for start in range(0, sequence_length, self.block_size):
            end = min(start + self.block_size, sequence_length)
            x_block = x[:, start:end, :]

            for t in range(x_block.shape[1]):
                input_t = x_block[:, t, :].unsqueeze(-1)

                state_quad = torch.matmul(state.unsqueeze(-1), state.unsqueeze(-2))
                state = torch.einsum('bij,bijk->bk', state_quad, p_expanded) + \
                        torch.matmul(q_expanded, input_t).squeeze(-1)

                state = self.layer_norm(state)

                # state = self.attn_dropout(state)

                output_t = torch.matmul(self.R, state.unsqueeze(-1)).squeeze(-1) + \
                           torch.matmul(s_expanded, input_t).squeeze(-1)
                outputs.append(output_t.unsqueeze(1))

            if ATTENTION_DEBUG:
                if torch.isnan(state).any():
                    print(f"NaN detected at step {start}-{end} in state.")
                    raise ValueError("NaN detected in state after block processing")

                state_norm = torch.norm(state, p=2, dim=-1)
                if torch.any(state_norm > 1e5):
                    print(f"Exploding state detected at step {start}-{end} with norm: {state_norm.max().item()}")
                elif torch.any(state_norm < 1e-5):
                    print(f"Vanishing state detected at step {start}-{end} with norm: {state_norm.min().item()}")

            state = state.detach()

        final_output = torch.cat(outputs, dim=1)

        final_output = self.output_dropout(final_output)

        if ATTENTION_DEBUG:
            if torch.isnan(final_output).any():
                raise ValueError("NaN detected in final output")

            final_output_norm = torch.norm(final_output, p=2, dim=-1)
            if torch.any(final_output_norm > 1e5):
                print(f"Exploding final output detected with norm: {final_output_norm.max().item()}")
            elif torch.any(final_output_norm < 1e-5):
                print(f"Vanishing final output detected with norm: {final_output_norm.min().item()}")

        return final_output
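
The per-token update inside the inner loop above is easier to see without batching. Below is a plain-Python sketch of one step, assuming nested lists instead of tensors and leaving out the LayerNorm and dropout that the real `forward` applies; `matvec` and `quadratic_step` are hypothetical helper names.

```python
def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(M[r][c] * v[c] for c in range(len(v))) for r in range(len(M))]


def quadratic_step(state, x, P, Q, R, S):
    """One recurrence step: the new state is quadratic in the old state plus a linear input term."""
    n = len(state)
    # Mirrors einsum('bij,bijk->bk'): sum over i, j of state[i] * state[j] * P[i][j][k]
    quad = [
        sum(state[i] * state[j] * P[i][j][k] for i in range(n) for j in range(n))
        for k in range(n)
    ]
    new_state = [q + lin for q, lin in zip(quad, matvec(Q, x))]
    # Output mixes the new state (via R) with a direct path from the input (via S)
    output = [a + b for a, b in zip(matvec(R, new_state), matvec(S, x))]
    return new_state, output


# Tiny example with state_dim = input_dim = output_dim = 2
P = [[[1, 1], [1, 1]], [[1, 1], [1, 1]]]
identity = [[1, 0], [0, 1]]
zeros = [[0, 0], [0, 0]]
new_state, output = quadratic_step([1, 2], [0.5, -1], P, identity, identity, zeros)
print(new_state, output)
```

Because the state enters as an outer product with itself, `P` needs `state_dim ** 3` entries, which is why it dominates the model's size.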

In [11]:
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=448, state_dim=448, output_dim=448):
        super(LanguageModel, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.state_space_model = Attention(state_dim=state_dim, input_dim=embedding_dim, output_dim=output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)
        self.output_layer = nn.Linear(output_dim, vocab_size)

    def forward(self, input_tokens):
        embedded = self.embedding(input_tokens)
        if len(embedded.shape) == 2:
            embedded = embedded.unsqueeze(0)
        context_representation = self.state_space_model(embedded)
        context_representation = self.layer_norm(context_representation)
        output = self.output_layer(context_representation)
        return output

    def predict_next_token(self, input_tokens):
        output = self.forward(input_tokens)
        next_token_id = torch.argmax(output[:, -1, :], dim=-1).squeeze()
        return next_token_id.item()

    def predict_next_token_softmax(self, input_tokens, top_p=0.9):
        output = self.forward(input_tokens)

        logits = output[:, -1, :].squeeze()
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        cutoff_index = torch.searchsorted(cumulative_probs, top_p)
        top_p_probs = sorted_probs[:cutoff_index + 1]
        top_p_indices = sorted_indices[:cutoff_index + 1]
        next_token_id = torch.multinomial(top_p_probs, 1)

        return top_p_indices[next_token_id].item()
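
The top-p (nucleus) sampling in `predict_next_token_softmax` can be illustrated without tensors. Below is a minimal sketch in plain Python, assuming a list of probabilities that already sums to 1. `top_p_candidates` and `sample_top_p` are hypothetical helper names, and `random.choices` stands in for `torch.multinomial`.

```python
import random


def top_p_candidates(probs, top_p):
    """Indices of the most likely tokens whose cumulative probability first reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


def sample_top_p(probs, top_p, rng=random):
    """Sample one index from the kept set, weighted by the original probabilities."""
    kept = top_p_candidates(probs, top_p)
    weights = [probs[i] for i in kept]  # need not sum to 1; choices() normalizes
    return rng.choices(kept, weights=weights, k=1)[0]


probs = [0.5, 0.3, 0.15, 0.05]
print(top_p_candidates(probs, 0.9))
print(sample_top_p(probs, 0.9))
```

A lower `top_p` keeps fewer candidates and makes the output more predictable; a value near 1 keeps almost the whole distribution.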

In [12]:

class ModelInterface:
    def __init__(self, model_save_path="model.bin", tokenizer_save_path="tokenizer.pkl"):
        self.tokenizer = Tokenizer()
        self.tokenizer.load(tokenizer_save_path)
        self.end_token = self.tokenizer.get_end_token()
        self.model = LanguageModel(vocab_size=self.tokenizer.vocab_size())
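        # Load onto the CPU first so a checkpoint saved on a GPU also loads on CPU-only machines;
        # the model is moved to the selected device below.
        self.model.load_state_dict(torch.load(model_save_path, map_location="cpu"), strict=False)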
        self.model.eval()
        self.device_name = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {self.device_name}")
        self.device = torch.device(self.device_name)
        self.model.to(self.device)

    def count_parameters(self):
        total_params = sum(p.numel() for p in self.model.parameters())
        return total_params

    def vocab_size(self):
        return self.tokenizer.vocab_size()

    def complete(self, current_context: str, top_p: float = 0.9, max_tokens: int = 100000):
        tokens = self.tokenizer.tokenize(current_context)
        for i in range(max_tokens):
            input_tokens = torch.tensor(tokens).to(self.device)
            if top_p >= 1.0 or top_p <= 0.0:
                next_token = self.model.predict_next_token(input_tokens)
            else:
                next_token = self.model.predict_next_token_softmax(input_tokens, top_p)
            tokens.append(next_token)
            if next_token == self.end_token:
                break
        result = self.tokenizer.detokenize(tokens)
        return result

    def prompt(self, prompt: str, top_p: float = 0.9, max_tokens: int = 100000):
        new_prompt = prompt
        if '<start>' not in prompt:
            new_prompt = prompt + '\n<start>'

        result = self.complete(new_prompt, top_p, max_tokens)
        return result.replace('<end>', '').strip()
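
For reference, the total that `count_parameters` reports can be estimated by hand from the layer shapes defined above. This is a rough tally, assuming the default dimension of 448 everywhere and the 33,438-entry vocabulary quoted earlier; the dictionary labels are just descriptive.

```python
VOCAB_SIZE = 33_438  # vocabulary size quoted earlier in this notebook
DIM = 448            # default embedding_dim, state_dim and output_dim

counts = {
    "embedding": VOCAB_SIZE * DIM,
    "P (state x state x state)": DIM ** 3,
    "Q, R, S": 3 * DIM * DIM,
    "two LayerNorms (weight + bias)": 2 * 2 * DIM,
    "output_layer (weight + bias)": DIM * VOCAB_SIZE + VOCAB_SIZE,
}
total = sum(counts.values())

for name, n in counts.items():
    print(f"{name:32s} {n:>12,}")
print(f"{'total':32s} {total:>12,}")
```

Under these assumptions the cubic `P` tensor accounts for roughly three quarters of the parameters.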


In [13]:
class TextDataset(IterableDataset):
    def __init__(self, files, tokenizer, context_length, shuffle_files=True):
        self.files = files
        self.tokenizer = tokenizer
        self.context_length = context_length
        self._length = None
        if shuffle_files:
            random.shuffle(self.files)

    def __iter__(self):
        for file in self.files:
            with open(file, 'r', encoding='utf-8') as f:
                text = f.read()
                tokens = self.tokenizer.tokenize(text)
                for i in range(0, len(tokens) - self.context_length, self.context_length):
                    x = tokens[i:i + self.context_length]
                    y = tokens[i + 1:i + 1 + self.context_length]
                    yield torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

    def __len__(self):
        if self._length is None:
            # Calculate the length by iterating through the dataset
            count = 0
            for file in self.files:
                with open(file, 'r', encoding='utf-8') as f:
                    text = f.read()
                    tokens = self.tokenizer.tokenize(text)
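                    # Same count as the windows __iter__ yields, including when the
                    # token count is an exact multiple of context_length
                    count += max(0, (len(tokens) - 1) // self.context_length)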
            self._length = count
        return self._length
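
The windowing that `__iter__` performs is easier to follow on a toy sequence. Below is a small sketch of the same slicing in plain Python, without tensors or file reading; `make_windows` is a hypothetical name.

```python
def make_windows(tokens, context_length):
    """Yield non-overlapping (input, target) pairs where the target is the input shifted by one."""
    for i in range(0, len(tokens) - context_length, context_length):
        x = tokens[i:i + context_length]
        y = tokens[i + 1:i + 1 + context_length]
        yield x, y


for x, y in make_windows(list(range(7)), context_length=3):
    print(x, "->", y)
```

Each target position holds the token that follows the matching input position, so the model is trained to predict the next token at every step of the window.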


In [14]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = np.inf
        self.wait = 0
        self.stop_training = False

    def check(self, current_loss):
        if (self.best_loss - current_loss) > self.min_delta:
            self.best_loss = current_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop_training = True
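
To see how `patience` and `min_delta` interact, it helps to trace the same rule over a list of losses. Below is a standalone sketch of that logic; `first_stop_epoch` is a hypothetical helper and the loss values are made up for illustration.

```python
def first_stop_epoch(losses, patience=5, min_delta=0.001):
    """Return the 1-based epoch at which early stopping would trigger, or None."""
    best, wait = float("inf"), 0
    for epoch, loss in enumerate(losses, start=1):
        if best - loss > min_delta:
            best, wait = loss, 0      # meaningful improvement: reset the counter
        else:
            wait += 1                 # improvement too small (or none)
            if wait >= patience:
                return epoch
    return None


made_up_losses = [1.0, 0.9, 0.8995, 0.8999, 0.95]
print(first_stop_epoch(made_up_losses, patience=2))
```

Improvements smaller than `min_delta` count as no improvement, so a loss that creeps down very slowly still runs out of patience.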

In [15]:

def train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, epochs=100, max_grad_norm=1.0,
                patience=5, accumulation_steps=4, no_validation=False):
    print("Training model...")
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device_name}")
    device = torch.device(device_name)

    model.to(device)
    early_stopping = EarlyStopping(patience=patience)

    if device_name == "cuda":
        # Initialize GradScaler for mixed precision if using GPU
        scaler = torch.cuda.amp.GradScaler()
    else:
        scaler = None  # No scaler needed for CPU

    base_path = os.getenv("YS_LLM_BASE_PATH", "./")

    print("First epoch starting...")
    parameters_shown = False
    best_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        optimizer.zero_grad()  # Clear any leftover gradients at the start of each epoch

        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            if device_name == "cuda":
                # Use autocast for mixed precision on GPU
                with torch.cuda.amp.autocast():
                    outputs = model(inputs)
                    batch_size, sequence_length, vocab_size = outputs.shape
                    outputs = outputs.view(batch_size * sequence_length, vocab_size)
                    targets = targets.view(batch_size * sequence_length)

                    loss = criterion(outputs, targets) / accumulation_steps
                # Scale the loss before backpropagation
                scaler.scale(loss).backward()
                # if not parameters_shown:
                #     total_params = sum(p.numel() for p in model.parameters())
                #     print(f"Total number of parameters: {total_params}")
                #     parameters_shown = True
            else:
                # Full precision for CPU
                outputs = model(inputs)
                batch_size, sequence_length, vocab_size = outputs.shape
                outputs = outputs.view(batch_size * sequence_length, vocab_size)
                targets = targets.view(batch_size * sequence_length)

                loss = criterion(outputs, targets) / accumulation_steps
                loss.backward()

            # Accumulate gradients and update model after a certain number of steps
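            if (i + 1) % accumulation_steps == 0:
                if device_name == "cuda":
                    # Unscale first so clipping sees the true gradient magnitudes
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    # scaler.step skips the update when the gradients contain inf/NaN
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    optimizer.step()
                optimizer.zero_grad()  # Reset gradients for the next accumulation cycle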

            epoch_loss += loss.item() * accumulation_steps  # Multiply to undo the earlier division

        if not no_validation:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    if device_name == "cuda":
                        with torch.cuda.amp.autocast():  # Use autocast for validation as well on GPU
                            outputs = model(inputs)
                            outputs = outputs.view(-1, outputs.size(-1))
                            targets = targets.view(-1)
                            loss = criterion(outputs, targets)
                    else:
                        outputs = model(inputs)
                        outputs = outputs.view(-1, outputs.size(-1))
                        targets = targets.view(-1)
                        loss = criterion(outputs, targets)

                    val_loss += loss.item()

            val_loss /= len(val_loader)
            print(f'Epoch {epoch + 1}, Loss: {epoch_loss / len(train_loader)}, Val Loss: {val_loss}')
            scheduler.step()

            if math.isnan(val_loss) or math.isnan(epoch_loss) or math.isinf(val_loss) or math.isinf(epoch_loss):
                print("Stopping due to numerical instability.")
                return False

            if epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(model.state_dict(), f"{base_path}model_checkpoint.bin")
                print(f"Checkpoint saved at epoch {epoch + 1} to {base_path}model_checkpoint.bin")
                if not parameters_shown:
                    total_params = sum(p.numel() for p in model.parameters())
                    print(f"Total number of parameters: {total_params}")
                    parameters_shown = True

            early_stopping.check(val_loss)
            if early_stopping.stop_training:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break
        else:
            # Monitor training loss for stopping if no validation
            print(f'Epoch {epoch + 1}, Loss: {epoch_loss / len(train_loader)}')
            scheduler.step()

            if math.isnan(epoch_loss) or math.isinf(epoch_loss):
                print("Stopping due to numerical instability.")
                return False

            if epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(model.state_dict(), f"{base_path}model_checkpoint.bin")
                print(f"New best model saved at epoch {epoch + 1}")
                if not parameters_shown:
                    total_params = sum(p.numel() for p in model.parameters())
                    print(f"Total number of parameters: {total_params}")
                    parameters_shown = True

            early_stopping.check(epoch_loss)
            if early_stopping.stop_training:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

    return True


def safe_tensor_conversion(data_list, dtype=torch.long):
    try:
        tensor = torch.tensor(data_list, dtype=dtype)
        return tensor
    except Exception as e:
        print(f"Error during tensor conversion: {e}")
        previous_len = None
        for i, data in enumerate(data_list):
            data_len = len(data)
            if previous_len is None:
                print(f"Data length at index {i}: {data_len}")
                previous_len = data_len
            if data_len != previous_len:
                print(f"Data length at index {i}: {data_len}")
                print(f"Data at index {i}: {data}")
            for token in data:
                if token is None:
                    print(f"None found in data at index {i}: {data}")
        raise e


def prepare_training_data(texts, tokenizer, sequence_length=5):
    x_train_list, y_train_list = [], []
    for text in texts:
        tokens = tokenizer.tokenize(text)
        if len(tokens) < sequence_length + 1:
            continue  # Skip sequences that are too short
        for i in range(len(tokens) - sequence_length):
            input_sequence = tokens[i:i + sequence_length]
            target_sequence = tokens[i + 1:i + sequence_length + 1]
            if len(input_sequence) == sequence_length and len(target_sequence) == sequence_length:
                x_train_list.append(input_sequence)  # Input sequence
                y_train_list.append(target_sequence)  # Target sequence
            else:
                print(
                    f"Skipping a sequence with incorrect length: x_seq={len(input_sequence)}, y_seq={len(target_sequence)}")

    x_train = safe_tensor_conversion(x_train_list)
    y_train = safe_tensor_conversion(y_train_list)

    # print(f"x_train shape: {x_train.shape}")  # Expect (num_sequences, sequence_length)
    # print(f"y_train shape: {y_train.shape}")  # Expect (num_sequences, sequence_length)

    return x_train, y_train


def split_dataset(dataset, val_split=0.2):
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    return train_dataset, val_dataset


def do_train(training_sequence_length=5, batch_size=64, max_epochs=100, patience=5, model_save_path="model.bin",
             tokenizer_save_path="tokenizer.pkl", no_validation=False):
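    # Pass no_validation through so the file split matches the training mode
    tokenizer, train_loader, val_loader, number_of_samples = build_tokenizer_and_load_tokens(
        training_sequence_length, batch_size, tokenizer_save_path, no_validation=no_validation)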

    vocab_size = tokenizer.vocab_size()
    print(f"Vocabulary size: {vocab_size}")

    model = LanguageModel(tokenizer.vocab_size())

    base_path = os.getenv("YS_LLM_BASE_PATH", "./")
    if os.path.exists(f"{base_path}model_checkpoint.bin"):
        # map_location lets a GPU-saved checkpoint resume on a CPU-only machine
        model.load_state_dict(torch.load(f"{base_path}model_checkpoint.bin", map_location="cpu"))
        print(f"Resumed training from {base_path}model_checkpoint.bin")
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Number of parameters: {total_params}")

    total_training_steps = (number_of_samples // batch_size) * max_epochs
    if number_of_samples % batch_size != 0:
        # Add one step for each epoch to cover the last incomplete batch
        total_training_steps += max_epochs

    optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.0005, total_steps=total_training_steps)
    criterion = nn.CrossEntropyLoss()

    model_trained = train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, epochs=max_epochs,
                                patience=patience, no_validation=no_validation)
    if model_trained:
        print("Saving model...")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Number of parameters: {total_params}")
        print(f"Vocabulary size: {vocab_size}")


def build_tokenizer_and_load_tokens(training_sequence_length, batch_size, tokenizer_save_path, no_validation=False):
    print("Loading and tokenizing training data...")
    training_file_names = get_training_file_names()
    # Split file names for training and validation
    random.shuffle(training_file_names)
    if no_validation:
        train_files = training_file_names
        val_files = []
    else:
        split_index = int(0.8 * len(training_file_names))
        train_files = training_file_names[:split_index]
        val_files = training_file_names[split_index:]

    # Initialize the tokenizer
    tokenizer = Tokenizer()
    if os.path.exists(tokenizer_save_path):
        tokenizer.load(tokenizer_save_path)
        print("Loaded existing tokenizer.")
    else:
        # Update tokenizer vocab by iterating over each document
        for file_path in training_file_names:
            with open(file_path, 'r', encoding='utf-8') as file:
                document = file.read()
                tokenizer.learn_new_vocab(document)  # Add document content to vocab
                del document  # Free memory as soon as possible

        # Save the tokenizer with the built vocabulary
        tokenizer.save(tokenizer_save_path)

    # DataLoader for batching (shuffle is done within the TextDataset)
    train_dataset = TextDataset(train_files, tokenizer, training_sequence_length)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

    if no_validation:
        val_loader = None
    else:
        val_dataset = TextDataset(val_files, tokenizer, training_sequence_length)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    number_of_samples = len(train_dataset)

    return tokenizer, train_loader, val_loader, number_of_samples


def get_training_file_names(directory="training_data"):
    base_path = os.getenv("YS_LLM_BASE_PATH", "./")
    file_names = []
    for filename in os.listdir(f"{base_path}{directory}"):
        if filename.endswith(".txt") or filename.endswith(".md"):
            filepath = os.path.join(f"{base_path}{directory}", filename)
            file_names.append(filepath)
    return file_names


def notebook_do_train(no_validation=False):
    base_path = os.getenv("YS_LLM_BASE_PATH", "./")
    model_path = f"{base_path}model.bin"
    tokenizer_path = f"{base_path}tokenizer.pkl"
    do_train(training_sequence_length=96, batch_size=20,
             max_epochs=400, patience=10,
             model_save_path=model_path,
             tokenizer_save_path=tokenizer_path,
             no_validation=no_validation)


In [19]:
def interactive_interface():
    print("This is a usage example. This LLM will completely make up things and be wrong. Don't follow its advice and "
          "don't expect good results. There is no warranty. Use at your own risk.")
    base_path = os.getenv("YS_LLM_BASE_PATH", "./")
    model_path = f"{base_path}model.bin"
    tokenizer_path = f"{base_path}tokenizer.pkl"
    print("Loading...")
    interface = ModelInterface(model_save_path=model_path, tokenizer_save_path=tokenizer_path)
    params = interface.count_parameters()
    print(f"Number of parameters: {params}")
    print(f"Vocabulary size: {interface.vocab_size()}")
    print("Type a prompt that the model will write a story about. Type 'exit' to exit.")
    while True:
        next_input = input("> ")
        if next_input == 'exit':
            break
        result = interface.complete(next_input.strip(), top_p=1.0, max_tokens=96)
        print(f"\n{result}")

In [17]:
# Uncomment to train
# notebook_do_train(no_validation=True)

# Training runs for v0.1: 120M Parameters/33,438 token Vocab

Starting with a sequence length of 72 was a mistake. I knew that it wasn't efficient, but I was trying to maximize memory usage. I should have stuck to a multiple of the attention block size and just increased the batch size. That would have been more stable and made future training easier. Two of those first three runs, all at a sequence length of 72, were numerically unstable.
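
As a rough illustration of "stick to a multiple of the block size and increase the batch size instead", here is a minimal sketch. The helper names `snap_to_block` and `rebalance` are hypothetical, and `block_size=32` is an assumed value chosen only for the example; the real block size depends on the model configuration. It also assumes memory use scales roughly with tokens per batch, which is only an approximation.

```python
def snap_to_block(sequence_length: int, block_size: int) -> int:
    """Round sequence_length down to the nearest positive multiple of block_size."""
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    return max(block_size, (sequence_length // block_size) * block_size)


def rebalance(sequence_length: int, batch_size: int, block_size: int) -> tuple[int, int]:
    """Snap the length to a block multiple and grow the batch to keep tokens per batch roughly equal."""
    tokens_per_batch = sequence_length * batch_size
    new_length = snap_to_block(sequence_length, block_size)
    new_batch = max(1, tokens_per_batch // new_length)
    return new_length, new_batch


# block_size=32 is an assumption for illustration, not a value taken from the model
print(rebalance(72, 64, block_size=32))  # (64, 72)
```

Under that assumption, the 72 x 64 configuration would become 64 x 72: the same number of tokens per batch, but with a length that lines up with the block boundaries.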

There were three cases of numeric instability through the rest of the training, but that's one of the challenges of working with Mamba 2. It feels like you are always walking the edge. Most of the time, a training run was interrupted because Colab stopped after some random time. This was frustrating, especially since I paid for the usage and paid extra for epochs that weren't allowed to finish. I speculate they charged me for at least 12 hours of time where they killed the run partway through an epoch, before a save. I had to get up in the middle of the night to restart it. It's not an experience I recommend, and I won't be using Colab after my compute units are gone. I read and re-read the Colab rules and don't see how this breaks them, but I suspect they are treating my account as free tier despite my paying for compute units.

| sequence length | batch size | epochs | loss start | loss end |
|---|---|---|---|---|
| 72 | 64 | 9 | ~5.2 | ~3.1 |
| 72 | 64 | 21 | ~3.1 | ~2.4 |
| 72 | 64 | 14? | ~2.37 | ~2.0 |
| 32 | 256 | 11 | ~2.25 | ~2.0 |
| 32 | 256 | 17 | ~2.05 | ~1.7 |
| 32 | 256 | 9 | ~1.88 | ~1.6 |
| 32 | 256 | 35 | ~1.64 | ~1.09 |
| 32 | 256 | 9 | ~1.35 | ~1.04 |
| 32 | 256 | 34 | ~1.22 | ~0.75 |
| 64 | 80 | 4 | ~1.43 | ~1.07 |
| 64 | 80 | 24 | ~1.16 | ~0.6 |
| 96 | 20 | 3 | ~1.21 | ~1.02 |
| 96 | 20 | 7 | ~1.14 | ~0.75 |
| 96 | 20 | 7 | ~0.94 | ~0.64 |
| 96 | 20 | 5 | ~0.82 | ~0.61 |
| 96 | 20 | 1 | ~0.69 | - |
| 96 | 20 | 3 | ~0.69 | ~0.59 |
| 96 | 20 | 3 | ~0.65 | ~0.56 |
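
To make the four configurations in the table easier to compare, this small sketch computes how many tokens each one processes per optimizer step (sequence length times batch size). The function name `tokens_per_step` is just for illustration; the pairs are copied from the table.

```python
# (sequence length, batch size) pairs copied from the table above
configs = [(72, 64), (32, 256), (64, 80), (96, 20)]


def tokens_per_step(sequence_length: int, batch_size: int) -> int:
    """Tokens seen in one batch, ignoring any padding."""
    return sequence_length * batch_size


for length, batch in configs:
    print(f"length={length:>3} batch={batch:>3} tokens/step={tokens_per_step(length, batch)}")
# length= 72 batch= 64 tokens/step=4608
# length= 32 batch=256 tokens/step=8192
# length= 64 batch= 80 tokens/step=5120
# length= 96 batch= 20 tokens/step=1920
```

The final 96 x 20 runs see fewer than a quarter of the tokens per step that the 32 x 256 runs did, which is worth keeping in mind when comparing epoch counts across rows.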

## Learnings for v0.2

For v0.2, I plan on starting off with a sequence length equal to the attention block size and the biggest batch I can manage. That helps with numeric stability (not getting 0 or infinity as a weight). Then, as the model gets closer to trained, I'll switch to longer and longer sequence lengths (multiples of the attention block size) to ensure it has used multiple blocks during training.
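
One possible way to organize that plan is sketched below. This is not the v0.2 code: the `LengthCurriculum` class, the plateau rule for deciding when to move on, and the values `block_size=32` and `token_budget=8192` are all assumptions made for illustration. It starts at one block with the largest batch the token budget allows, and adds one block whenever the epoch loss has not improved for `patience` epochs.

```python
class LengthCurriculum:
    """Grow the sequence length by one block each time the loss plateaus."""

    def __init__(self, block_size, token_budget, max_blocks, patience=3):
        self.block_size = block_size
        self.token_budget = token_budget  # rough cap on tokens per batch (assumed)
        self.max_blocks = max_blocks
        self.patience = patience
        self.blocks = 1  # start at exactly one block
        self.best_loss = float("inf")
        self.stale_epochs = 0

    def current(self):
        """Return (sequence_length, batch_size) for the current stage."""
        sequence_length = self.blocks * self.block_size
        batch_size = max(1, self.token_budget // sequence_length)
        return sequence_length, batch_size

    def update(self, epoch_loss):
        """Record one epoch's loss and advance a stage after `patience` epochs without improvement."""
        if epoch_loss < self.best_loss:
            self.best_loss = epoch_loss
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        if self.stale_epochs >= self.patience and self.blocks < self.max_blocks:
            self.blocks += 1
            self.best_loss = float("inf")  # losses are not comparable across lengths
            self.stale_epochs = 0
        return self.current()


curriculum = LengthCurriculum(block_size=32, token_budget=8192, max_blocks=3, patience=2)
print(curriculum.current())  # (32, 256)
for loss in [2.0, 1.5, 1.5, 1.6]:
    stage = curriculum.update(loss)
print(stage)  # (64, 128)
```

Resetting the best loss at each stage reflects what the v0.1 table shows: the loss jumps back up whenever the sequence length changes, so it would be misleading to compare it with the previous stage's best.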

In [21]:
# Run the interactive demo (comment this out to skip it)
interactive_interface()

This is a usage example. This LLM will completely make up things and be wrong. Don't follow its advice and don't expect good results. There is no warranty. Use at your own risk.
Loading...
Using cuda
Number of parameters: 120513182
Vocabulary size: 33438
Type a prompt that the model will write a story about. Type 'exit' to exit.
> write a story

write a story about a 65-year-old geologist named Emily who embarks on a solo expedition into the American Southwest, only to discover a massive underground complex that holds secrets from her family's past, and must navigate a treacherous web of family loyalty and precious valuable 
> exit
