# Step 0a - Install dependencies

In [2]:
!pip install pandas numpy
!pip install torch torchvision torchaudio
!pip install datasets sentencepiece



Collecting datasets
  Using cached datasets-2.18.0-py3-none-any.whl.metadata (20 kB)
Collecting sentencepiece
  Using cached sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (7.7 kB)
Collecting pyarrow>=12.0.0 (from datasets)
  Using cached pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl.metadata (3.0 kB)
Collecting pyarrow-hotfix (from datasets)
  Using cached pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting requests>=2.19.0 (from datasets)
  Using cached requests-2.31.0-py3-none-any.whl.metadata (4.6 kB)
Collecting tqdm>=4.62.1 (from datasets)
  Using cached tqdm-4.66.2-py3-none-any.whl.metadata (57 kB)
Collecting xxhash (from datasets)
  Using cached xxhash-3.4.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Using cached multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.

# Step 0b - Import module dependencies

In [45]:
import os
import random
from datasets import load_dataset
import sentencepiece as spm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

# Step 0c - Constants

In [24]:
# Directory that holds the exported SentencePiece training corpus and model files.
sentencepiece_output_dir = 'sentencepiece_models'
# File name (inside the output directory) of the plain-text training corpus.
sentencepiece_corpus_filename = "tiny_stories_texts.txt"
# Path prefix handed to the SentencePiece trainer; the model is loaded from "<prefix>.model".
sentencepiece_model_prefix = os.path.join(sentencepiece_output_dir, 'tiny_stories_spm_sampled')

# Exact length, in token IDs (BOS and EOS included), of every prepared sequence.
story_token_max_length = 20

# When True, only the first `small_data_set_size` stories of each split are used.
use_small_dataset = True
small_data_set_size = 10

# Step 0d - Load datasets

Read the TinyStories dataset and extract the training and validation stories:

In [32]:
# Load the Tiny Stories dataset
dataset = load_dataset("roneneldan/TinyStories")

# Split the dataset into training and validation sets
train_dataset = dataset['train']
valid_dataset = dataset['validation']

if use_small_dataset:
    print("Using small datasets")
    # Slice the rows *before* selecting the column so that only the requested
    # stories are materialised, instead of reading the whole 'text' column and
    # then discarding almost all of it.
    train_stories = train_dataset[:small_data_set_size]['text']
    valid_stories = valid_dataset[:small_data_set_size]['text']
else:
    train_stories = train_dataset['text']
    valid_stories = valid_dataset['text']

print(f"Training stories set size: {len(train_stories)}")
print(f"Validation stories set size: {len(valid_stories)}")


Repo card metadata block was not found. Setting CardData to empty.


Using small datasets
Training stories set size: 10
Valisation stories set size: 10


Combine the training and validation stories, take a random 10% sample, and export that sample to a text file used to train the SentencePiece model:

In [15]:
# Fraction of all stories written to the tokenizer training corpus, and the seed
# that makes the selection reproducible across kernel restarts.
sentencepiece_corpus_sample_fraction = 0.1
sentencepiece_corpus_seed = 42

# Make sure the output directory exists (no error if it is already there).
os.makedirs(sentencepiece_output_dir, exist_ok=True)

# Full path of the corpus file inside the output directory.
sentencepiece_corpus_file_path = os.path.join(sentencepiece_output_dir, sentencepiece_corpus_filename)

# Combine texts from training and validation sets (list() keeps this working
# when the column accessor returns a list-like object rather than a plain list).
all_texts = list(train_dataset['text']) + list(valid_dataset['text'])

# Shuffle with a dedicated, seeded generator: the same corpus (and therefore the
# same tokenizer) is produced on every run, and the global `random` state is untouched.
random.Random(sentencepiece_corpus_seed).shuffle(all_texts)

# Keep only the leading fraction of the shuffled stories.
sample_size = int(sentencepiece_corpus_sample_fraction * len(all_texts))
sampled_text = all_texts[:sample_size]

# Write each sampled story followed by a newline.
# NOTE(review): a story that itself contains newlines will span several lines
# in the file - confirm that is acceptable for tokenizer training.
with open(sentencepiece_corpus_file_path, 'w', encoding='utf-8') as f:
    for story in sampled_text:
        f.write(story + '\n')

Next, train the SentencePiece model on the exported corpus:

In [19]:
# Train a unigram tokenizer with an 8000-piece vocabulary on the exported corpus.
# `sentencepiece_model_prefix` is passed as the trainer's `model_prefix`.
spm.SentencePieceTrainer.train(
    input=sentencepiece_corpus_file_path,
    model_prefix=sentencepiece_model_prefix,
    vocab_size=8000,
    character_coverage=0.9995,
    model_type='unigram',
)

sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: sentencepiece_models/tiny_stories_texts.txt
  input_format: 
  model_prefix: sentencepiece_models/tiny_stories_spm_sampled
  model_type: UNIGRAM
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piec

Next, load the trained SentencePiece model:

In [21]:
# The tokenizer is loaded from the file named "<model prefix>.model".
spm_model_path = sentencepiece_model_prefix + ".model"
sp = spm.SentencePieceProcessor(model_file=spm_model_path)

# Step X - Generate the input data and the labels

In [42]:
def prepare_data(stories, sp, max_length):
    """Tokenize stories into fixed-length token-ID sequences.

    Each story is encoded with ``sp``, truncated to ``max_length - 2`` tokens,
    wrapped as ``[BOS] + tokens + [EOS]`` and right-padded with the EOS id until
    it is exactly ``max_length`` long.

    Args:
        stories: Iterable of story strings.
        sp: Tokenizer exposing ``bos_id()``, ``eos_id()`` and
            ``encode(text, out_type=int)`` returning a list of ints.
        max_length: Exact length of every returned sequence. Must be at least 2
            so that BOS and EOS always fit.

    Returns:
        A tuple ``(inputs, labels)`` of two lists of lists of ints, one entry per
        story. ``labels[i]`` is an independent copy of ``inputs[i]``; no shift is
        applied here.
        NOTE(review): confirm the training loop shifts labels for next-token
        prediction, since this function does not.

    Raises:
        ValueError: If ``max_length`` is smaller than 2.
    """
    # With max_length < 2 the truncation slice below would turn negative and
    # silently drop tokens from the end of the story, so reject it up front.
    if max_length < 2:
        raise ValueError(f"max_length must be at least 2 to hold BOS and EOS, got {max_length}")

    inputs, labels = [], []
    bos_id, eos_id = sp.bos_id(), sp.eos_id()

    for story in stories:
        # Tokenize the story, leaving room for the BOS and EOS markers.
        tokens = sp.encode(story, out_type=int)[:max_length - 2]

        # Wrap with BOS/EOS, then pad short stories with EOS up to max_length.
        input_ids = [bos_id] + tokens + [eos_id]
        input_ids += [eos_id] * (max_length - len(input_ids))

        # Labels mirror the inputs; copy so the two lists never alias each other.
        label_ids = list(input_ids)

        # Sanity checks on the invariants promised above.
        assert len(input_ids) == max_length, f"Input sequence length does not match max_length. Length: {len(input_ids)}"
        assert input_ids[0] == bos_id, "Input sequence does not start with bos_id."
        assert input_ids[-1] == eos_id, "Input sequence does not end with eos_id."

        inputs.append(input_ids)
        labels.append(label_ids)

    return inputs, labels


def assert_max_length(data, max_length):
    """Assert that no entry in ``data`` is longer than ``max_length`` items.

    Args:
        data: Iterable of sized entries (e.g. lists of token IDs).
        max_length: Largest allowed ``len(entry)``.

    Raises:
        AssertionError: On the first entry whose length exceeds ``max_length``.
    """
    within_limit = (len(entry) <= max_length for entry in data)
    assert all(within_limit), f"Entry exceeds max_length of {max_length} tokens."


def _check_prepared_split(stories, inputs, labels):
    """Check one prepared split: one sequence per story, none over the length limit."""
    assert len(inputs) == len(stories)
    assert_max_length(inputs, story_token_max_length)
    assert_max_length(labels, story_token_max_length)


# Training split: tokenize, then validate.
train_inputs, train_labels = prepare_data(train_stories, sp, story_token_max_length)
_check_prepared_split(train_stories, train_inputs, train_labels)

# Validation split: same treatment.
valid_inputs, valid_labels = prepare_data(valid_stories, sp, story_token_max_length)
_check_prepared_split(valid_stories, valid_inputs, valid_labels)

[[1, 50, 26, 5, 8, 38, 59, 81, 24, 120, 8, 2001, 21, 13, 198, 3, 12, 168, 10, 2], [1, 56, 60, 8, 37, 5, 39, 9, 8, 38, 169, 81, 5341, 3, 5341, 82, 7, 69, 287, 2], [1, 50, 26, 5, 8, 38, 326, 81, 2004, 9, 917, 447, 6, 1211, 3, 14, 47, 8, 48, 2], [1, 56, 60, 8, 37, 5, 21, 8, 1135, 475, 28, 489, 5, 39, 9, 8, 38, 2938, 156, 2], [1, 56, 60, 8, 37, 5, 39, 9, 8, 38, 59, 81, 24, 3, 24, 129, 7, 651, 22, 2], [1, 56, 60, 8, 37, 5, 21, 8, 48, 643, 5, 39, 9, 8, 753, 2571, 3, 17, 753, 2], [1, 56, 60, 8, 37, 5, 21, 8, 216, 755, 5, 39, 9, 8, 2010, 38, 59, 81, 24, 2], [1, 56, 60, 8, 37, 5, 21, 8, 1109, 755, 5, 39, 249, 8, 38, 91, 81, 119, 3, 2], [1, 56, 60, 8, 37, 5, 39, 9, 8, 964, 38, 142, 81, 149, 3, 149, 82, 7, 286, 2], [1, 50, 26, 5, 8, 287, 887, 81, 119, 70, 34, 8, 630, 21, 25, 260, 169, 3, 14, 2]]


# Step X - Setup dataset

In [44]:
class TinyStoriesDataset(Dataset):
    """Map-style dataset pairing token-ID input sequences with label sequences.

    Args:
        inputs: Sequence of token-ID sequences (one per example).
        labels: Sequence of label-ID sequences, aligned index-by-index with ``inputs``.

    Raises:
        ValueError: If ``inputs`` and ``labels`` differ in length.
    """

    def __init__(self, inputs, labels):
        # A mismatch would otherwise only surface later as an IndexError from
        # __getitem__, or as silently ignored trailing labels.
        if len(inputs) != len(labels):
            raise ValueError(
                f"inputs and labels must have the same length, got {len(inputs)} and {len(labels)}"
            )
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        """Return the number of examples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(input_ids, label_ids)`` for example ``idx`` as ``torch.long`` tensors."""
        input_ids = torch.tensor(self.inputs[idx], dtype=torch.long)
        label_ids = torch.tensor(self.labels[idx], dtype=torch.long)
        return input_ids, label_ids

# Wrap the prepared training sequences in a Dataset and batch them (32 per batch, shuffled).
# NOTE(review): this rebinds `train_dataset`, which earlier cells bind to the
# Hugging Face training split and index with ['text']. Re-running those cells
# after this one will fail or see the wrong object - consider a distinct name
# here once downstream uses of `train_dataset` have been checked.
train_dataset = TinyStoriesDataset(train_inputs, train_labels)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


# Step X - Transformer decoder

In [47]:
class TransformerDecoderLayer(nn.Module):
    """Single decoder block: masked self-attention and a feed-forward network.

    Both sub-layers use a residual connection followed by LayerNorm, with
    dropout applied after each normalisation.

    Args:
        embed_size: Size of the last (feature) dimension of the input.
        heads: Number of attention heads.
        forward_expansion: Hidden width of the feed-forward network as a
            multiple of ``embed_size``.
        dropout: Dropout probability, used inside attention and after each norm.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super().__init__()
        # batch_first=True: inputs are (batch, seq, embed), the layout produced
        # by TransformerDecoder's embeddings.
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_size,
            num_heads=heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, add_norm_x, src_mask):
        """Apply the block.

        Args:
            add_norm_x: Tensor of shape ``(batch, seq, embed_size)``.
            src_mask: Attention mask forwarded as ``attn_mask`` to
                ``nn.MultiheadAttention`` (e.g. a ``(seq, seq)`` causal mask),
                or ``None`` for unrestricted attention.

        Returns:
            Tensor with the same shape as ``add_norm_x``.
        """
        # MultiheadAttention returns (output, weights); only the output is
        # needed, so skip computing the weights.
        attention_output, _ = self.attention(
            add_norm_x, add_norm_x, add_norm_x,
            attn_mask=src_mask,
            need_weights=False,
        )
        hidden = self.dropout(self.norm1(attention_output + add_norm_x))
        ff_output = self.feed_forward(hidden)
        # Residual around the feed-forward network uses the post-attention state.
        out = self.norm2(ff_output + hidden)
        return self.dropout(out)
    

class TransformerDecoder(nn.Module):
    """Decoder-only transformer: token + learned position embeddings, a stack of
    ``TransformerDecoderLayer`` blocks, and a linear projection to vocabulary logits.

    Args:
        vocab_size: Number of token IDs (embedding rows and output logits).
        embed_size: Embedding / hidden size.
        num_layers: Number of decoder blocks.
        heads: Attention heads per block.
        device: Stored as ``self.device`` for callers; ``forward`` takes the
            device from its input instead.
        forward_expansion: Feed-forward width multiplier for each block.
        dropout: Dropout probability.
        max_length: Number of learned positions, i.e. the longest supported sequence.
    """

    def __init__(self, vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        super().__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_size, heads, forward_expansion, dropout)
            for _ in range(num_layers)
        ])

        self.fully_connected_layer_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        """Compute vocabulary logits for a batch of token-ID sequences.

        Args:
            x: Integer tensor of shape ``(batch, seq)`` holding token IDs.
            src_mask: Attention mask handed unchanged to every layer.

        Returns:
            Tensor of shape ``(batch, seq, vocab_size)``.

        Raises:
            ValueError: If ``seq`` exceeds the number of learned positions.
        """
        _, seq_length = x.shape
        max_positions = self.position_embedding.num_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"Sequence length {seq_length} exceeds the maximum supported length {max_positions}"
            )

        # Build the position indices directly on the input's device: no extra
        # host-to-device copy, and no mismatch if the model was moved after
        # construction. Shape (1, seq) broadcasts over the batch dimension.
        positions = torch.arange(seq_length, device=x.device).unsqueeze(0)

        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, src_mask)

        return self.fully_connected_layer_out(x)
