Dan's todo list for tomorrow:

* Set up MLflow for testing
* Look into whether argmax is the right choice for decoding
* Save progress (checkpoints) after each epoch
* Train on bigger datasets. UPDATE: this does not seem to make a significant difference.
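For the "Save progress after each epoch" item, here is a minimal checkpointing sketch that calls `torch.save`/`torch.load` on state dicts. The helper names, the dictionary keys, and the `checkpoints/` path used below are illustrative assumptions. Nothing like them exists in the notebook yet.

```python
import os

import torch


def save_checkpoint(path, model, optimizer, epoch, val_loss):
    """Write model and optimizer state so training can resume after `epoch`."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "val_loss": val_loss,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer=None, device="cpu"):
    """Restore model (and optionally optimizer) state in place; return the saved epoch."""
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
```

Inside `train`, you could call `save_checkpoint(f"checkpoints/epoch_{epoch + 1}.pt", model, optimizer, epoch + 1, val_loss)` after `validate`. That keeps one file per epoch, and `load_checkpoint` returns the epoch to resume from.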

# Step 0a - Install dependencies

In [2]:
!pip install pandas numpy
!pip install torch torchvision torchaudio
!pip install datasets sentencepiece




# Step 0b - Import module dependencies

In [1]:
import os
import random
from datasets import load_dataset
import sentencepiece as spm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader



# Step 0c - Configs

In [7]:
sentencepiece_output_dir = 'sentencepiece_models'
sentencepiece_corpus_filename = "tiny_stories_texts.txt"
sentencepiece_model_prefix = os.path.join(sentencepiece_output_dir, 'tiny_stories_spm_sampled')

story_token_max_length = 20

use_small_dataset = True
small_data_set_size = 10000

vocabulary_size = 8000

embedding_size = 256
num_decoder_layers = 6
num_heads = 8
forward_layer_expansion = 4
dropout = 0.1

learning_rate = 0.001

epochs = 25


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Running models on: {device}")

inference_debug_log_enabled = False

Running models on: cpu


# Step 0d - Load datasets

Load the TinyStories dataset:

In [8]:
# Load the Tiny Stories dataset
dataset = load_dataset("roneneldan/TinyStories")

# Split the dataset into training and validation sets
train_dataset = dataset['train']
valid_dataset = dataset['validation']

train_stories = train_dataset['text']
valid_stories = valid_dataset['text']

print(f"Training stories set size (Pre resize): {len(train_stories)}")
print(f"Validation stories set size (Pre resize): {len(valid_stories)}")

if use_small_dataset:
    print("Using small datasets")
    train_stories = train_stories[:small_data_set_size]
    valid_stories = valid_stories[:small_data_set_size]

print(f"Training stories set size: {len(train_stories)}")
print(f"Validation stories set size: {len(valid_stories)}")




Training stories set size (Pre resize): 2119719
Validation stories set size (Pre resize): 21990
Using small datasets
Training stories set size: 10000
Validation stories set size: 10000


Combine the stories from the full (not reduced) training and validation sets, shuffle them, and write a random 10% sample to a text file (one story per line) for training the SentencePiece model:

In [9]:
# Create the output directory if it doesn't exist yet
os.makedirs(sentencepiece_output_dir, exist_ok=True)

# Path of the corpus file (one story per line) used to train SentencePiece
sentencepiece_corpus_file_path = os.path.join(sentencepiece_output_dir, sentencepiece_corpus_filename)


# Combine texts from training and validation sets
all_texts = train_dataset['text'] + valid_dataset['text']

random.shuffle(all_texts)

# Sample a smaller subset of the dataset, e.g., 10% of the data
sample_size = int(0.1 * len(all_texts))
sampled_text = all_texts[:sample_size]

# Save all texts to a single file, one story per line
with open(sentencepiece_corpus_file_path, 'w', encoding='utf-8') as f:
    for story in sampled_text:
        f.write(story + '\n')

Next, train the SentencePiece model:

In [10]:
spm.SentencePieceTrainer.train(
    input=sentencepiece_corpus_file_path,
    model_prefix=sentencepiece_model_prefix,
    vocab_size=vocabulary_size,
    character_coverage=0.9995,
    model_type='unigram',
)

sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: sentencepiece_models/tiny_stories_texts.txt
  input_format: 
  model_prefix: sentencepiece_models/tiny_stories_spm_sampled
  model_type: UNIGRAM
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piec

Next, load the trained SentencePiece model:

In [11]:
spm_model_path = f"{sentencepiece_model_prefix}.model"
sp = spm.SentencePieceProcessor(model_file=spm_model_path)

# Step X - Generate the input data and the labels

In [12]:
def prepare_data(stories, sp, max_length):
    inputs, labels = [], []
    bos_id, eos_id = sp.bos_id(), sp.eos_id()
    
    for story in stories:
        # Tokenize the story and truncate if necessary
        tokens = sp.encode(story, out_type=int)[:max_length -1]

        # Inputs start with BOS; labels are the same tokens shifted left by one and end with EOS
        input_ids = [bos_id] + tokens
        label_ids = tokens + [eos_id]

        # Pad both sequences to exactly max_length. The SentencePiece model has no pad token
        # (pad_id is -1), so eos_id doubles as padding; this only happens for stories shorter
        # than max_length - 1 tokens
        input_ids = (input_ids + [eos_id] * max_length)[:max_length]
        label_ids = (label_ids + [eos_id] * max_length)[:max_length]

        # Assertions to ensure each sequence meets the specified criteria
        assert len(input_ids) == max_length, f"Input sequence length does not match max_length. Length: {len(input_ids)}"
        assert len(label_ids) == max_length, f"Label sequence length does not match max_length. Length: {len(label_ids)}"
        assert input_ids[0] == bos_id, "Input sequence does not start with bos_id."
        assert label_ids[-1] == eos_id, "Label sequence does not end with eos_id."

        inputs.append(input_ids)
        labels.append(label_ids)
    
    return inputs, labels

def assert_max_length(data, max_length):
    for entry in data:
        # Each entry should not exceed max_length tokens
        assert len(entry) <= max_length, f"Entry exceeds max_length of {max_length} tokens."


train_inputs, train_labels = prepare_data(train_stories, sp, story_token_max_length)
assert len(train_inputs) == len(train_stories)
assert_max_length(train_inputs, story_token_max_length)
assert_max_length(train_labels, story_token_max_length)

print(train_inputs[0])
print(train_labels[0])

valid_inputs, valid_labels = prepare_data(valid_stories, sp, story_token_max_length)
assert len(valid_inputs) == len(valid_stories)
assert_max_length(valid_inputs, story_token_max_length)
assert_max_length(valid_labels, story_token_max_length)

[1, 50, 26, 5, 8, 38, 58, 79, 24, 123, 8, 1901, 21, 13, 199, 3, 12, 167, 10, 9]
[50, 26, 5, 8, 38, 58, 79, 24, 123, 8, 1901, 21, 13, 199, 3, 12, 167, 10, 9, 2]
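The two printed lists show the intended shift. The labels are the inputs moved one position to the left, with `</s>` (id 2) at the end. The SentencePiece model has no pad token (`pad_id: -1` in the training log), so `prepare_data` pads short stories with `eos_id`, and those padded positions then count towards the loss.

A common alternative is to pad labels with -100, the default `ignore_index` of `nn.CrossEntropyLoss`, so that padded positions are skipped. This is sketched below as an assumption, not as what the notebook does, and `make_example` is a hypothetical helper.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def make_example(tokens, bos_id, eos_id, max_length, pad_id):
    """Build one (input, label) pair where labels are the inputs shifted left by one.

    Padded label positions use IGNORE_INDEX so the loss skips them; `pad_id` only
    fills the input side, after all real tokens.
    """
    tokens = tokens[: max_length - 1]
    input_ids = [bos_id] + tokens
    label_ids = tokens + [eos_id]
    n_pad = max_length - len(input_ids)
    input_ids = input_ids + [pad_id] * n_pad
    label_ids = label_ids + [IGNORE_INDEX] * n_pad
    return input_ids, label_ids
```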


# Step X - Setup dataset

In [13]:
class TinyStoriesDataset(Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return torch.tensor(self.inputs[idx], dtype=torch.long), torch.tensor(self.labels[idx], dtype=torch.long)


# Step X - Transformer decoder

In [14]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super(TransformerDecoderLayer, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads = heads, dropout = dropout)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion* embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        attention_output, _ = self.attention(x, x, x, attn_mask=src_mask)
        x = self.dropout(self.norm1(attention_output + x))
        forward = self.feed_forward(x)
        out = self.norm2(forward + x)
        return self.dropout(out)
    

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        super(TransformerDecoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_size, heads, forward_expansion, dropout)
            for _ in range(num_layers)
        ])

        self.fully_connected_layer_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, src_mask)

        out = self.fully_connected_layer_out(x)

        return out
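Two details of the decoder above may be worth checking against the PyTorch docs.

* `nn.MultiheadAttention` defaults to `batch_first=False`, so it reads its input as `(seq_len, batch, embed)`. The embeddings here are `(batch, seq_len, embed)`. As written, attention therefore runs across the stories in a batch at each position, not across the tokens of one story. With a batch of one at inference, each token would attend only to itself.
* Training passes `src_mask=None`. Once the layout is fixed, nothing would stop a position from attending to later tokens.

The sketch below is a hypothetical `CausalDecoderLayer` that assumes `batch_first=True` and a boolean upper-triangular mask. It is one possible fix, not verified against the results in this notebook.

```python
import torch
import torch.nn as nn


def causal_mask(seq_length, device=None):
    """Boolean mask where True marks positions a query may NOT attend to (the future)."""
    ones = torch.ones(seq_length, seq_length, dtype=torch.bool, device=device)
    return torch.triu(ones, diagonal=1)


class CausalDecoderLayer(nn.Module):
    """Variant of TransformerDecoderLayer with batch-first attention and a causal mask."""

    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super().__init__()
        # batch_first=True: inputs are (batch, seq_len, embed), like the embeddings above
        self.attention = nn.MultiheadAttention(embed_size, heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        mask = causal_mask(x.size(1), device=x.device)
        attention_output, _ = self.attention(x, x, x, attn_mask=mask, need_weights=False)
        # Post-norm residual blocks, with dropout on each sublayer output before the residual add
        x = self.norm1(x + self.dropout(attention_output))
        return self.norm2(x + self.dropout(self.feed_forward(x)))
```

To use it, `TransformerDecoder.forward` would call `layer(x)` for each layer, and the position embedding would still limit the context to `max_length` tokens. The sketch also moves dropout onto each sublayer's output before the residual add, following the placement described in the original Transformer paper.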


# Step X - Train

In [17]:
model = TransformerDecoder(
    vocab_size = vocabulary_size,
    embed_size = embedding_size,
    num_layers = num_decoder_layers,
    heads = num_heads,
    device = device,
    forward_expansion = forward_layer_expansion,
    dropout = dropout,
    max_length = story_token_max_length
).to(device)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

# Assuming `train_inputs` and `train_labels` are your processed datasets
train_dataset_processed = TinyStoriesDataset(train_inputs, train_labels)
train_loader = DataLoader(train_dataset_processed, batch_size=32, shuffle=True)

validation_dataset_processed = TinyStoriesDataset(valid_inputs, valid_labels)
validation_loader = DataLoader(validation_dataset_processed, batch_size=32, shuffle=False)


def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs, src_mask=None)
            loss = criterion(outputs.transpose(1,2), labels)
            val_loss += loss.item()
    return val_loss / len(loader)

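def train(model, train_loader, valid_loader, optimizer, criterion, device):
    for epoch in range(epochs):
        # validate() and generate_text_simple() put the model in eval mode,
        # so switch back to training mode (dropout on) at the start of every epoch
        model.train()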
        epoch_loss = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs, src_mask = None)
            loss = criterion(outputs.transpose(1,2), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        train_loss = epoch_loss / len(train_loader)
        val_loss = validate(model, valid_loader, criterion, device)

        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
        generated_text = generate_text_simple(model, "The ancient house", sp, device, story_token_max_length)
        print(f"Text generated after Epoch {epoch+1}: {generated_text}")  

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train(model, train_loader, validation_loader, optimizer, criterion, device)

The model has 8,847,680 trainable parameters
Epoch 1, Train Loss: 2.9769, Validation Loss: 2.3731
Text generated after Epoch 1: The ancient house , there was a little girl named Lily . She loved to play with her and loved
Epoch 2, Train Loss: 2.2877, Validation Loss: 2.2454
Text generated after Epoch 2: The ancient house , there was a little girl named Lily . She loved little girl named Lily . One
Epoch 3, Train Loss: 2.1558, Validation Loss: 2.2004
Text generated after Epoch 3: The ancient house , there was a little girl named . She was a little and loved to play with
Epoch 4, Train Loss: 2.0751, Validation Loss: 2.1956
Text generated after Epoch 4: The ancient house , there was a little girl . She loved to play with his to play with his
Epoch 5, Train Loss: 2.0109, Validation Loss: 2.2001
Text generated after Epoch 5: The ancient house and Ben . She was a little girl named Lily . She loved to play with his
Epoch 6, Train Loss: 1.9559, Validation Loss: 2.1867
Text generated after Epoc

# Step X - inference methods

In [16]:
repetition_threshold = 3

def generate_text_simple(model, start_prompt, sp, device, max_length):
    model.eval()
    words = start_prompt.split()
    token_ids = sp.encode(start_prompt, out_type=int)

    if max(token_ids) >= vocabulary_size:
        raise ValueError(f"Token ID {max(token_ids)} exceeds vocab size of {vocabulary_size}")

    consecutive_repetitions = 0
    last_token_id = None

    for _ in range(max_length):
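        # The position embedding only covers max_length positions (assumed equal to the model's
        # max_length), so feed the model at most the last max_length tokens
        input_ids = torch.tensor([token_ids[-max_length:]], device=device)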

        if inference_debug_log_enabled:
            print(f"Input IDs: {input_ids}")
            print(f"Shape: {input_ids.shape}")

        with torch.no_grad():
            outputs = model(input_ids, src_mask=None)
            predictions = outputs[:, -1, :]
            predicted_id = torch.argmax(predictions, dim=-1).item()

        # Check for consecutive repetitions
        if predicted_id == last_token_id:
            consecutive_repetitions += 1
        else:
            consecutive_repetitions = 0  # Reset the counter if the current token is different

        last_token_id = predicted_id  # Update the last seen token ID

        # Exit if the repetition threshold is reached
        if consecutive_repetitions >= repetition_threshold:
            print(f"Stopping early due to repeated token ({predicted_id}) detected {repetition_threshold} times in a row.")
            break

        if predicted_id == sp.eos_id():
            break

        token_ids.append(predicted_id)
        generated_word = sp.DecodeIds([predicted_id])
        words.append(generated_word)

    generated_text = ' '.join(words)
    return generated_text

generated_text = generate_text_simple(model, "One day, a little girl", sp, device, story_token_max_length)
print(generated_text)

One day, a little girl named Timmy . She was a time . She loved to play with her
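The generated text has spaces before punctuation ("Timmy . She"). This happens because `generate_text_simple` decodes each piece on its own and joins the results with spaces. For the same reason, a word that SentencePiece split into several pieces would come out split. Decoding the whole id list once, for example with `sp.decode(token_ids)`, should avoid both problems.

To illustrate what that joint decode does, here is a rough pure-Python approximation. SentencePiece marks word starts with "▁" (U+2581), so concatenating the pieces and turning the markers into spaces recovers the text. The approximation ignores byte fallback, special tokens and other normalization, and the pieces in the example are made up.

```python
WORD_BOUNDARY = "\u2581"  # the marker SentencePiece puts at the start of a word


def join_pieces(pieces):
    """Approximate SentencePiece detokenization: concatenate, then turn markers into spaces."""
    return "".join(pieces).replace(WORD_BOUNDARY, " ").strip()


print(join_pieces(["\u2581The", "\u2581an", "cient", "\u2581house", ","]))  # The ancient house,
```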


# Debugging notes

## Things to explore

1. [Doing inference with strings from the dataset](#Doing-inference-with-strings-from-the-dataset)
1. Do I understand my architecture?
1. [Is my model too simple? (Under fitting)](#Is-my-model-too-simple?-(Under-fitting))
1. [Is my model overfitting?](#Is-my-model-overfitting?)
1. [Learning rate optimizations](#Learning-rate-optimizations)
1. [Sequence Length Handling](#Sequence-Length-Handling)
1. [Repetition Penalty](#Repetition-Penalty)


## Investigations

### Doing inference with strings from the dataset

I tried using the "One day, a little girl" string from the first entry in the dataset to see whether the model would perform better with an input it had already seen. This didn't help much.

Here's the resulting output:

```
Epoch 1, Train Loss: 1.1776, Validation Loss: 0.3109
Stopping early due to repeated token (180) detected 3 times in a row.
Text generated after Epoch 1: The ancient house house house house
Epoch 2, Train Loss: 0.1383, Validation Loss: 0.0950
Stopping early due to repeated token (180) detected 3 times in a row.
Text generated after Epoch 2: The ancient house house house house
Epoch 3, Train Loss: 0.0280, Validation Loss: 0.0669
Stopping early due to repeated token (180) detected 3 times in a row.
Text generated after Epoch 3: The ancient house house house house

...
```

### Is my model too simple? (Under fitting)

TODO: Update with investigation.

Notes from ChatGPT:

`Underfitting: If the model is too simple, it might not have learned the underlying patterns in the data sufficiently. Consider increasing the model complexity by adding more layers or increasing the embedding size.`
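Before adding layers or widening the embeddings, it may help to see where the current 8,847,680 parameters sit. The function below reproduces that count by hand from the config. It assumes PyTorch's `nn.MultiheadAttention` parameter layout: a packed Q/K/V projection plus an output projection, all with biases. `decoder_param_count` is a hypothetical helper.

```python
def decoder_param_count(vocab, embed, layers, expansion, max_len):
    """Break down the trainable parameters of the TransformerDecoder defined above."""
    embeddings = vocab * embed + max_len * embed                  # word + position tables
    attention = (3 * embed * embed + 3 * embed) + (embed * embed + embed)  # in_proj + out_proj
    norms = 2 * (2 * embed)                                       # two LayerNorms (weight + bias)
    hidden = expansion * embed
    feed_forward = (embed * hidden + hidden) + (hidden * embed + embed)
    per_layer = attention + norms + feed_forward
    output_head = embed * vocab + vocab
    return {
        "embeddings": embeddings,
        "decoder_layers": layers * per_layer,
        "output_head": output_head,
        "total": embeddings + layers * per_layer + output_head,
    }


print(decoder_param_count(8000, 256, 6, 4, 20)["total"])  # 8847680
```

The notebook's config is `vocabulary_size=8000`, `embedding_size=256`, 6 layers, expansion 4 and `story_token_max_length=20`. With it, the breakdown is:

* 2,053,120 parameters in the embeddings
* 4,738,560 in the decoder layers (789,760 each)
* 2,056,000 in the output head

So almost half of the parameters sit in the vocabulary-sized tables rather than in the layers that model context.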

### Is my model overfitting?


TODO: Investigate regularization techniques if needed.

`Overfitting: If the model has memorized the training data rather than learning to generalize, it may perform poorly on slightly different inputs or validation data. Regularization techniques (e.g., dropout, weight decay) or more training data could help.`
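If regularization does turn out to be needed, two cheap options are decoupled weight decay through `torch.optim.AdamW` and stopping once the validation loss stops improving. `EarlyStopping` and `make_optimizer` are hypothetical helpers, and `patience=3` and `weight_decay=0.01` are arbitrary starting values.

```python
import torch


class EarlyStopping:
    """Signal a stop after `patience` epochs without a validation-loss improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience  # True means "stop training"


def make_optimizer(model, lr=1e-3, weight_decay=0.01):
    # AdamW applies weight decay decoupled from the adaptive gradient update
    return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
```

In `train`, adding `if stopper.step(val_loss): break` after `validate` would end training early. Fed the validation losses recorded in the debugging run (0.3109, 0.0950, 0.0669, 0.0638, 0.0648), a patience of 1 would stop after epoch 5.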

#### Examination of training and validation loss

Given the following training output:

```
Epoch 1, Train Loss: 1.1776, Validation Loss: 0.3109
Stopping early due to repeated token (180) detected 3 times in a row.
Text generated after Epoch 1: The ancient house house house house
Epoch 2, Train Loss: 0.1383, Validation Loss: 0.0950
Stopping early due to repeated token (180) detected 3 times in a row.
Text generated after Epoch 2: The ancient house house house house
Epoch 3, Train Loss: 0.0280, Validation Loss: 0.0669
Stopping early due to repeated token (180) detected 3 times in a row.
Text generated after Epoch 3: The ancient house house house house
Epoch 4, Train Loss: 0.0032, Validation Loss: 0.0638
Stopping early due to repeated token (180) detected 3 times in a row.
Text generated after Epoch 4: The ancient house house house house
Epoch 5, Train Loss: 0.0008, Validation Loss: 0.0648
```

The training loss keeps falling (to 0.0008 by epoch 5) while the validation loss levels off around 0.064, which on its own would point to overfitting. However, the model already repeats the same token (180, "house") from the first epoch onwards, so overfitting is most likely not the cause of the repetition problem.
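A next-token training loss of 0.0008 is unusually low for story text. One possible explanation, not confirmed here, is that the model can see the token it is asked to predict. The two hypothetical helpers below test this on any callable that maps `(batch, seq_len)` token ids to `(batch, seq_len, vocab)` logits:

* `future_leak` checks whether changing the last token alters the logits at earlier positions.
* `batch_leak` checks whether changing the other stories in the batch alters the first story's logits.

A correctly masked decoder should return False for both.

```python
import torch


def _random_ids(vocab_size, seq_len, batch, seed):
    generator = torch.Generator().manual_seed(seed)
    return torch.randint(0, vocab_size, (batch, seq_len), generator=generator)


def future_leak(model_fn, vocab_size, seq_len=8, batch=2, seed=0):
    """True if changing the last token alters the logits at any earlier position."""
    ids = _random_ids(vocab_size, seq_len, batch, seed)
    changed = ids.clone()
    changed[:, -1] = (ids[:, -1] + 1) % vocab_size  # guaranteed to differ
    with torch.no_grad():
        before, after = model_fn(ids)[:, :-1], model_fn(changed)[:, :-1]
    return not torch.allclose(before, after, atol=1e-6)


def batch_leak(model_fn, vocab_size, seq_len=8, batch=2, seed=0):
    """True if changing the other stories in the batch alters the logits of story 0."""
    ids = _random_ids(vocab_size, seq_len, batch, seed)
    changed = ids.clone()
    changed[1:] = (ids[1:] + 1) % vocab_size
    with torch.no_grad():
        before, after = model_fn(ids)[0], model_fn(changed)[0]
    return not torch.allclose(before, after, atol=1e-6)
```

To check the notebook's model, call `model.eval()` first so that dropout doesn't add noise. Then something like `future_leak(lambda ids: model(ids.to(device), src_mask=None).cpu(), vocabulary_size, seq_len=story_token_max_length)` would run the check. `nn.MultiheadAttention` defaults to `batch_first=False`, so I'd expect `batch_leak` to return True for the current `TransformerDecoderLayer`.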


### Learning rate optimizations

TODO: investigate

From ChatGPT:
```
The choice of learning_rate can significantly affect training. Too high a learning rate can cause the model to converge too quickly to a suboptimal solution, while too low a rate can slow down training or cause it to stall.
Solution: Consider using a learning rate scheduler to adjust the rate during training. torch.optim.lr_scheduler provides several options, such as StepLR or ReduceLROnPlateau, which can help improve training dynamics.
```
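As a starting point for the scheduler suggestion, here is a minimal `ReduceLROnPlateau` setup driven by the validation loss. `make_scheduler` is a hypothetical helper, and `factor=0.5` and `patience=1` are arbitrary values to tune.

```python
import torch


def make_scheduler(optimizer, factor=0.5, patience=1):
    """Multiply the learning rate by `factor` once the validation loss stops improving.

    ReduceLROnPlateau lowers the LR after more than `patience` consecutive epochs
    without an improvement in the value passed to `scheduler.step(...)`.
    """
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=factor, patience=patience
    )
```

Create the scheduler once next to the optimizer, then call `scheduler.step(val_loss)` after `validate` in each epoch of `train`. With `patience=1`, the learning rate is halved after two consecutive epochs without improvement.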

### Sequence Length Handling

From ChatGPT:

```
The fixed story_token_max_length determines how much context the model considers during training and inference. If this length is not optimally chosen, it could impact the model's ability to generate coherent text.
Solution: Experiment with different sequence lengths. Also, ensure that your padding and truncation strategies during preprocessing align with the model's expectations.
```
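One concrete consequence of the current setting: with `story_token_max_length = 20`, `prepare_data` keeps only the first 19 tokens of each story. The model therefore only trains on story openings, which may be part of why samples keep returning to "there was a little girl named Lily".

One way to use whole stories without enlarging the context is to cut each tokenized story into consecutive full-length windows. `chunk_story` is a hypothetical helper. Its `stride` argument controls how far each window moves, so a stride below `max_length` gives overlapping windows.

```python
def chunk_story(token_ids, bos_id, eos_id, max_length, stride=None):
    """Split one tokenized story into (input, label) windows of exactly max_length tokens.

    The stream is [bos] + tokens + [eos]; each window's labels are its inputs shifted by one.
    Incomplete windows at the end are dropped so every example is full-length.
    """
    stride = stride or max_length
    stream = [bos_id] + list(token_ids) + [eos_id]
    pairs = []
    for start in range(0, len(stream) - max_length, stride):
        window = stream[start : start + max_length + 1]
        pairs.append((window[:-1], window[1:]))
    return pairs
```

This version has two limitations:

* Stories shorter than `max_length` tokens (including BOS/EOS) produce no windows.
* The tail after the last full window is dropped.

Either could be padded instead.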

### Repetition Penalty

From ChatGPT:

```
Your method to stop generating text after a certain number of repeated tokens is a practical approach to handle repetitive output. However, this doesn't address the underlying cause of why the model prefers these repetitions.
Solution: Consider implementing more nuanced sampling methods during inference, such as top-k or top-p (nucleus) sampling, which can encourage diversity in the generated text.
```
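Here is a sketch of both ideas, applied to a single 1-D logits vector:

* a CTRL-style repetition penalty, which divides positive logits and multiplies negative ones for tokens already generated;
* temperature, top-k and top-p sampling in place of `argmax`.

The function names and default values are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def apply_repetition_penalty(logits, previous_ids, penalty=1.2):
    """CTRL-style repetition penalty on 1-D logits (penalty > 1 discourages repeats)."""
    logits = logits.clone()
    if not previous_ids:
        return logits
    ids = torch.tensor(sorted(set(previous_ids)), dtype=torch.long)
    scores = logits[ids]
    logits[ids] = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits


def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0, generator=None):
    """Sample one token id from 1-D logits with optional temperature, top-k and top-p filtering."""
    logits = logits / temperature
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k).values[-1]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_ids = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        # Drop a token once the probability mass ranked above it already exceeds top_p;
        # the most likely token is always kept
        remove = probs.cumsum(dim=-1) - probs > top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(0, sorted_ids, sorted_logits)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator).item()
```

In `generate_text_simple`, the `torch.argmax` line could become `predicted_id = sample_next_token(apply_repetition_penalty(predictions[0], token_ids), temperature=0.8, top_k=50, top_p=0.9)`.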