# Step 0a - Install dependencies

In [2]:
!pip install pandas numpy
!pip install torch torchvision torchaudio
!pip install datasets sentencepiece



Collecting datasets
  Using cached datasets-2.18.0-py3-none-any.whl.metadata (20 kB)
Collecting sentencepiece
  Using cached sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (7.7 kB)
Collecting pyarrow>=12.0.0 (from datasets)
  Using cached pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl.metadata (3.0 kB)
Collecting pyarrow-hotfix (from datasets)
  Using cached pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting requests>=2.19.0 (from datasets)
  Using cached requests-2.31.0-py3-none-any.whl.metadata (4.6 kB)
Collecting tqdm>=4.62.1 (from datasets)
  Using cached tqdm-4.66.2-py3-none-any.whl.metadata (57 kB)
Collecting xxhash (from datasets)
  Using cached xxhash-3.4.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Using cached multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.

# Step 0b - Import module dependencies

In [45]:
import os
import random
from datasets import load_dataset
import sentencepiece as spm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

# Step 0c - Configs

In [77]:
sentencepiece_output_dir = 'sentencepiece_models'
sentencepiece_corpus_filename = f"tiny_stories_texts.txt"
sentencepiece_model_prefix = os.path.join(sentencepiece_output_dir, 'tiny_stories_spm_sampled')

story_token_max_length = 20

use_small_dataset = True
small_data_set_size = 100

vocabulary_size = 8000

embedding_size = 256
num_decoder_layers = 6
num_heads = 8
forward_layer_expansion = 4
dropout = 0.1

learning_rate = 0.001

epochs = 100


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Running models on: {device}")

Running models on: cpu


# Step 0c - Load datasets

Read the tiny stories data set:

In [78]:
# Load the Tiny Stories dataset
dataset = load_dataset("roneneldan/TinyStories")

# Split the dataset into training and validation sets
train_dataset = dataset['train']
valid_dataset = dataset['validation']

train_stories = train_dataset['text']
valid_stories = valid_dataset['text']

if use_small_dataset is True:
    print("Using small datasets")
    train_stories = train_stories[:small_data_set_size]
    valid_stories = valid_stories[:small_data_set_size]

print(f"Training stories set size: {len(train_stories)}")
print(f"Valisation stories set size: {len(valid_stories)}")


Repo card metadata block was not found. Setting CardData to empty.


Using small datasets
Training stories set size: 100
Valisation stories set size: 100


Gather all of the data set and export it to a text file for training of the sentence piece model:

In [79]:
# Specify the directory where you want to save the files
if not os.path.exists(sentencepiece_output_dir):
    os.makedirs(sentencepiece_output_dir)

# Save all texts to a single file in the specified directory, one story per line
sentencepiece_corpus_file_path = os.path.join(sentencepiece_output_dir, sentencepiece_corpus_filename)


# Combine texts from training and validation sets
all_texts = train_dataset['text'] + valid_dataset['text']

random.shuffle(all_texts)

# Sample a smaller subset of the dataset, e.g., 10% of the data
sample_size = int(0.1 * len(all_texts))
sampled_text = all_texts[:sample_size]

# Save all texts to a single file, one story per line
with open(sentencepiece_corpus_file_path, 'w', encoding='utf-8') as f:
    for story in sampled_text:
        f.write(story + '\n')

Next generate the sentence piece model:

In [81]:
spm.SentencePieceTrainer.train(input=sentencepiece_corpus_file_path, model_prefix=sentencepiece_model_prefix, vocab_size=vocabulary_size, character_coverage=0.9995, model_type='unigram')

sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: sentencepiece_models/tiny_stories_texts.txt
  input_format: 
  model_prefix: sentencepiece_models/tiny_stories_spm_sampled
  model_type: UNIGRAM
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piec

Next initialize the sentence piece model:

In [82]:
spm_model_path = f"{sentencepiece_model_prefix}.model"
sp = spm.SentencePieceProcessor(model_file=spm_model_path)

# Step X - Generate the input data and the labels

In [83]:
def prepare_data(stories, sp, max_length):
    inputs, labels = [], []
    bos_id, eos_id = sp.bos_id(), sp.eos_id()
    
    for story in stories:
        # Tokenize the story and truncate if necessary
        tokens = sp.encode(story, out_type=int)[:max_length - 2]

        # Prepend BOS and append EOS token IDs
        input_ids = [bos_id] + tokens + [eos_id]
        label_ids = [bos_id] + tokens + [eos_id]  # Adjusted to ensure labels also start with bos_id and end with eos_id

        # Ensure the final lists are of max_length
        # This might already be ensured by previous steps, but double-checking to align with the assertion requirements
        input_ids = (input_ids + [eos_id] * max_length)[:max_length]  # Padding with eos_id if necessary, though this should be rare given earlier truncation
        label_ids = (label_ids + [eos_id] * max_length)[:max_length]

        # Assertions to ensure each sequence meets the specified criteria
        assert len(input_ids) == max_length, f"Input sequence length does not match max_length. Length: {len(input_ids)}"
        assert len(label_ids) == max_length, f"Label sequence length does not match max_length. Length: {len(label_ids)}"
        assert input_ids[0] == bos_id, "Input sequence does not start with bos_id."
        assert label_ids[0] == bos_id, "Label sequence does not start with bos_id."
        assert input_ids[-1] == eos_id, "Input sequence does not end with eos_id."
        assert label_ids[-1] == eos_id, "Label sequence does not end with eos_id."

        inputs.append(input_ids)
        labels.append(label_ids)
    
    return inputs, labels


def assert_max_length(data, max_length):
    for entry in data:
        # Each entry should not exceed max_length tokens
        assert len(entry) <= max_length, f"Entry exceeds max_length of {max_length} tokens."


train_inputs, train_labels = prepare_data(train_stories, sp, story_token_max_length)
assert(len(train_inputs) == len(train_stories))
assert_max_length(train_inputs, story_token_max_length)
assert_max_length(train_labels, story_token_max_length)

valid_inputs, valid_labels = prepare_data(valid_stories, sp, story_token_max_length)
assert(len(valid_inputs) == len(valid_stories))
assert_max_length(valid_inputs, story_token_max_length)
assert_max_length(valid_labels, story_token_max_length)

# Step X - Setup dataset

In [84]:
class TinyStoriesDataset(Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return torch.tensor(self.inputs[idx], dtype=torch.long), torch.tensor(self.labels[idx], dtype=torch.long)


# Step X - Transformer decoder

In [85]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super(TransformerDecoderLayer, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads = heads, dropout = dropout)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion* embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        attention_output, _ = self.attention(x, x, x, attn_mask=src_mask)
        x = self.dropout(self.norm1(attention_output + x))
        forward = self.feed_forward(x)
        out = self.norm2(forward + x)
        return self.dropout(out)
    

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        super(TransformerDecoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_size, heads, forward_expansion, dropout)
            for _ in range(num_layers)
        ])

        self.fully_connected_layer_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, src_mask)

        out = self.fully_connected_layer_out(x)

        return out


# Step X - Train

In [86]:
model = TransformerDecoder(
    vocab_size = vocabulary_size,
    embed_size = embedding_size,
    num_layers = num_decoder_layers,
    heads = num_heads,
    device = device,
    forward_expansion = forward_layer_expansion,
    dropout = dropout,
    max_length = story_token_max_length
).to(device)

# Assuming `train_inputs` and `train_labels` are your processed datasets
train_dataset_processed = TinyStoriesDataset(train_inputs, train_labels)
train_loader = DataLoader(train_dataset_processed, batch_size=32, shuffle=True)

validation_dataset_processed = TinyStoriesDataset(valid_inputs, valid_labels)
validation_loader = DataLoader(validation_dataset_processed, batch_size=32, shuffle=False)


def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs, src_mask=None)
            loss = criterion(outputs.transpose(1,2), labels)
            val_loss += loss.item()
    return val_loss / len(loader)

def train(model, train_loader, valid_loader, optimizer, criterion, device):
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs, src_mask = None)
            loss = criterion(outputs.transpose(1,2), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        train_loss = epoch_loss / len(train_loader)
        val_loss = validate(model, valid_loader, criterion, device)

        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')  

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train(model, train_loader, validation_loader, optimizer, criterion, device)

Epoch 1, Train Loss: 7.2935, Validation Loss: 4.4145
Epoch 2, Train Loss: 4.4164, Validation Loss: 3.1123
Epoch 3, Train Loss: 3.2259, Validation Loss: 2.4775
Epoch 4, Train Loss: 2.5634, Validation Loss: 1.8930
Epoch 5, Train Loss: 1.9534, Validation Loss: 1.6019
Epoch 6, Train Loss: 1.5956, Validation Loss: 1.3947
Epoch 7, Train Loss: 1.1596, Validation Loss: 1.2515
Epoch 8, Train Loss: 0.9121, Validation Loss: 1.1474
Epoch 9, Train Loss: 0.7144, Validation Loss: 1.0663
Epoch 10, Train Loss: 0.6498, Validation Loss: 1.0226
Epoch 11, Train Loss: 0.3975, Validation Loss: 0.9799
Epoch 12, Train Loss: 0.3428, Validation Loss: 0.9517
Epoch 13, Train Loss: 0.2278, Validation Loss: 0.9381
Epoch 14, Train Loss: 0.1916, Validation Loss: 0.9047
Epoch 15, Train Loss: 0.1129, Validation Loss: 0.8962
Epoch 16, Train Loss: 0.0824, Validation Loss: 0.8926
Epoch 17, Train Loss: 0.0631, Validation Loss: 0.8841
Epoch 18, Train Loss: 0.0490, Validation Loss: 0.8789
Epoch 19, Train Loss: 0.0380, Validat

# Step X - inference methods

In [99]:
repetition_threshold = 3

def generate_text_simple(model, start_prompt, sp, device, max_length):
    model.eval()
    words = start_prompt.split()
    token_ids = sp.encode(start_prompt, out_type=int)

    if max(token_ids) >= vocabulary_size:
        raise ValueError(f"Token ID {max(token_ids)} exceeds vocab size of {vocabulary_size}")

    consecutive_repetitions = 0
    last_token_id = None

    for _ in range(max_length):
        input_ids = torch.tensor([token_ids], device=device)

        print(f"Input IDs: {input_ids}")
        print(f"Shape: {input_ids.shape}")

        with torch.no_grad():
            outputs = model(input_ids, src_mask=None)
            predictions = outputs[:, -1, :]
            predicted_id = torch.argmax(predictions, axis=-1).item()

        # Check for consecutive repetitions
        if predicted_id == last_token_id:
            consecutive_repetitions += 1
        else:
            consecutive_repetitions = 0  # Reset the counter if the current token is different

        last_token_id = predicted_id  # Update the last seen token ID

        # Exit if the repetition threshold is reached
        if consecutive_repetitions >= repetition_threshold:
            print(f"Stopping early due to repeated token ({predicted_id}) detected {repetition_threshold} times in a row.")
            break

        if predicted_id == sp.eos_id():
            break

        token_ids.append(predicted_id)
        generated_word = sp.DecodeIds([predicted_id])
        words.append(generated_word)

    generated_text = ' '.join(words)
    return generated_text

generated_text = generate_text_simple(model, "The ancient castle", sp, device, story_token_max_length)
print(generated_text)

Input IDs: tensor([[  17, 1534,  604]])
Shape: torch.Size([1, 3])
Input IDs: tensor([[  17, 1534,  604,    4]])
Shape: torch.Size([1, 4])
Input IDs: tensor([[  17, 1534,  604,    4,    4]])
Shape: torch.Size([1, 5])
Input IDs: tensor([[  17, 1534,  604,    4,    4,    4]])
Shape: torch.Size([1, 6])
Stopping early due to repeated token (4) detected 3 times in a row.
The ancient castle and and and
